//
//  Heise Vision Pro
//  Copyright © 2024 Gero Gerber. All rights reserved.
//

import RealityKit
import RealityKitContent
import SwiftUI

/// One vertex of the low-level mesh: a 3D position followed by a four-component color.
///
/// The stored-property order determines the in-memory layout that
/// `VertexPositionColor.vertexAttributes` reads through `MemoryLayout.offset(of:)`,
/// so keep `position` first and `color` second.
struct VertexPositionColor {
    /// Vertex position; defaults to the origin.
    var position = SIMD3<Float>.zero
    /// Vertex color with four float components (presumably RGBA — confirm against the material); defaults to all zeros.
    var color = SIMD4<Float>.zero
}

/// `LowLevelMesh` layout information for `VertexPositionColor`.
extension VertexPositionColor {
    /// Vertex attributes: a `float3` position and a `float4` color, each located at the
    /// byte offset of the matching stored property.
    ///
    /// The force-unwraps rely on `MemoryLayout.offset(of:)` returning a value for
    /// directly stored properties.
    static var vertexAttributes: [LowLevelMesh.Attribute] = [
        LowLevelMesh.Attribute(
            semantic: .position,
            format: .float3,
            offset: MemoryLayout<VertexPositionColor>.offset(of: \VertexPositionColor.position)!
        ),
        LowLevelMesh.Attribute(
            semantic: .color,
            format: .float4,
            offset: MemoryLayout<VertexPositionColor>.offset(of: \VertexPositionColor.color)!
        )
    ]

    /// A single interleaved vertex buffer (index 0) whose stride is one `VertexPositionColor`.
    static var vertexLayouts: [LowLevelMesh.Layout] = [
        LowLevelMesh.Layout(bufferIndex: 0, bufferStride: MemoryLayout<VertexPositionColor>.stride)
    ]

    /// A fresh descriptor pre-filled with the attributes and layouts above and 32-bit indices.
    ///
    /// Vertex and index capacities are not set here; the caller assigns them.
    static var descriptor: LowLevelMesh.Descriptor {
        var meshDescriptor = LowLevelMesh.Descriptor()
        meshDescriptor.indexType = .uint32
        meshDescriptor.vertexLayouts = vertexLayouts
        meshDescriptor.vertexAttributes = vertexAttributes
        return meshDescriptor
    }
}

/// Parameters handed to the mesh-update compute kernel.
///
/// `LowLevelMeshView.updateMeshGPU()` copies an instance of this struct into
/// buffer index 1 of the compute encoder with `setBytes`.
///
/// NOTE(review): the kernel-side declaration is not visible from this file. Presumably its
/// field order, types, and padding must match this struct byte-for-byte (mind that `Int`
/// is 64-bit here) — confirm against the Metal source before reordering or retyping fields.
struct MeshParams {
    /// Number of vertices along one edge of the square grid.
    let meshResolution: Int
    /// Edge length of the grid; the CPU path centers the grid around the origin.
    let meshSize: Float
    /// Current animation phase added to the sine argument.
    let timerPhase: Float
    /// Color written to every vertex.
    let color: SIMD4<Float>
}

/// Interactive demo that renders an animated, vertex-colored grid backed by a `LowLevelMesh`.
///
/// A 60 Hz timer advances the animation phase and rewrites the vertex buffer, either on
/// the CPU or through the `updateMeshKernel` compute function, depending on the
/// "Use Compute Shader" toggle. The triangle indices never change and are written once
/// when the view is created.
struct LowLevelMeshView: View {
    @State private var showWireframe = true
    @State private var vertexColor = Color(.sRGB, red: 0.98, green: 0.9, blue: 0.2)
    @State private var timerPhase: Float = 0
    @State private var useShader = false

    private let mesh: LowLevelMesh
    private let root = Entity()
    /// Edge length of the undisplaced grid.
    private let meshSize: Float = 0.3
    private let timer = Timer.publish(every: 1.0 / 60.0, on: .main, in: .common).autoconnect()

    /// Number of vertices along one edge of the grid.
    private static let meshResolution: Int = 100

    /// Largest distance the sine animation moves a vertex along x or y.
    /// Shared by the CPU vertex update and the bounding-box computation.
    private static let waveAmplitude: Float = 0.5 / 100

    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    let computePipeline: MTLComputePipelineState

    /// Creates the mesh storage, fills the static index buffer, and builds the Metal
    /// objects used by the GPU update path.
    ///
    /// Any failure is treated as unrecoverable and stops the process with a message
    /// naming the resource that could not be created.
    init() {
        var desc = VertexPositionColor.descriptor
        desc.vertexCapacity = Self.meshResolution * Self.meshResolution
        desc.indexCapacity = Self.indicesCount

        do {
            mesh = try LowLevelMesh(descriptor: desc)
        } catch {
            fatalError("Unable to create LowLevelMesh: \(error)")
        }

        guard let device = MTLCreateSystemDefaultDevice() else {
            fatalError("Metal is not available on this device")
        }
        guard let commandQueue = device.makeCommandQueue() else {
            fatalError("Unable to create a Metal command queue")
        }
        guard let updateFunction = device.makeDefaultLibrary()?.makeFunction(name: "updateMeshKernel") else {
            fatalError("Compute function 'updateMeshKernel' not found in the default Metal library")
        }

        do {
            computePipeline = try device.makeComputePipelineState(function: updateFunction)
        } catch {
            fatalError("Unable to create compute pipeline for 'updateMeshKernel': \(error)")
        }

        self.device = device
        self.commandQueue = commandQueue

        // The grid topology is fixed, so the indices are written exactly once here.
        // Both the CPU and the GPU update paths only touch the vertex buffer afterwards.
        writeIndices()
    }

    var body: some View {
        VStack {
            RealityView { content in
                await root.addChild(createMeshEntity())
                content.add(root)
            }
            .frame(depth: 0.5)
            .onChange(of: showWireframe) { _, _ in
                updateMaterial()
            }
            .onChange(of: vertexColor) { _, _ in
                updateMesh()
            }
            .onReceive(timer) { _ in
                // Advance the animation and rewrite the vertices on every tick.
                timerPhase += 0.1
                updateMesh()
            }
            Group {
                ColorPicker("Color", selection: $vertexColor)
                Toggle(isOn: $showWireframe) {
                    Text("Wireframe")
                }
                Toggle(isOn: $useShader) {
                    Text("Use Compute Shader")
                }
                .padding(.bottom)
            }
            .padding(.horizontal, 100)
        }
    }

    /// Total index count: two triangles (six indices) per grid cell.
    private static var indicesCount: Int {
        (meshResolution - 1) * (meshResolution - 1) * 6
    }

    /// Rewrites the vertex buffer using the currently selected path (GPU or CPU).
    private func updateMesh() {
        if useShader {
            updateMeshGPU()
        } else {
            updateMeshCPU()
        }
    }

    /// Fills the index buffer with two triangles per grid cell.
    ///
    /// Per cell the order is (topLeft, topRight, bottomLeft) followed by
    /// (topRight, bottomRight, bottomLeft).
    private func writeIndices() {
        let resolution = Self.meshResolution

        mesh.withUnsafeMutableIndices { rawIndices in
            let indices = rawIndices.bindMemory(to: UInt32.self)
            var cursor = 0

            for y in 0 ..< (resolution - 1) {
                for x in 0 ..< (resolution - 1) {
                    let topLeft = UInt32(y * resolution + x)
                    let topRight = topLeft + 1
                    let bottomLeft = UInt32((y + 1) * resolution + x)
                    let bottomRight = bottomLeft + 1

                    indices[cursor] = topLeft
                    indices[cursor + 1] = topRight
                    indices[cursor + 2] = bottomLeft

                    indices[cursor + 3] = topRight
                    indices[cursor + 4] = bottomRight
                    indices[cursor + 5] = bottomLeft

                    cursor += 6
                }
            }
        }
    }

    /// Recomputes every vertex on the CPU and publishes the updated mesh part.
    ///
    /// Each vertex sits on a regular grid centered on the origin; its x coordinate is
    /// displaced by a sine of its row and its y coordinate by a sine of its column.
    private func updateMeshCPU() {
        let resolution = Self.meshResolution
        let lastIndex = Float(resolution - 1)
        let size = meshSize
        let halfSize = meshSize / 2
        let phase = timerPhase
        let amplitude = Self.waveAmplitude
        // Converting the SwiftUI color once instead of once per vertex.
        let color = vertexColor.asFloat4

        // The undisplaced coordinate and the sine offset each depend on a single grid
        // index, so they are computed once per index rather than once per vertex.
        let base: [Float] = (0 ..< resolution).map { i in
            Float(i) / lastIndex * size - halfSize
        }
        let wave: [Float] = (0 ..< resolution).map { i in
            sin(Float(i) * 0.1 + phase) * amplitude
        }

        mesh.withUnsafeMutableBytes(bufferIndex: 0) { rawBytes in
            let vertices = rawBytes.bindMemory(to: VertexPositionColor.self)

            for y in 0 ..< resolution {
                let rowStart = y * resolution
                for x in 0 ..< resolution {
                    let position = SIMD3<Float>(base[x] + wave[y], base[y] + wave[x], 0)
                    vertices[rowStart + x] = VertexPositionColor(position: position, color: color)
                }
            }
        }

        replaceMesh()
    }

    /// Encodes a compute pass that rewrites the vertex buffer on the GPU, then publishes
    /// the updated mesh part. Returns silently if no command buffer or encoder is available.
    private func updateMeshGPU() {
        guard let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else { return }

        let vertexBuffer = mesh.replace(bufferIndex: 0, using: commandBuffer)

        computeEncoder.setComputePipelineState(computePipeline)
        computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 0)

        // NOTE(review): the kernel source is not visible here; presumably it declares a
        // parameter struct with the same byte layout as `MeshParams` — confirm.
        var params = MeshParams(meshResolution: Self.meshResolution,
                                meshSize: meshSize,
                                timerPhase: timerPhase,
                                color: vertexColor.asFloat4)
        computeEncoder.setBytes(&params, length: MemoryLayout<MeshParams>.size, index: 1)

        // The same size is used for the threadgroup grid and for the threads in each
        // group: meshResolution groups of meshResolution threads, i.e. one thread per vertex.
        let rowSize = MTLSize(width: Self.meshResolution, height: 1, depth: 1)
        computeEncoder.dispatchThreadgroups(rowSize, threadsPerThreadgroup: rowSize)

        computeEncoder.endEncoding()
        commandBuffer.commit()

        replaceMesh()
    }

    /// Replaces the mesh parts with a single triangle part covering all indices.
    private func replaceMesh() {
        // Pad the bounds by the sine amplitude so the box encloses the displaced
        // vertices, not just the undisplaced grid.
        let halfExtent = meshSize / 2.0 + Self.waveAmplitude
        let meshBounds = BoundingBox(min: [-halfExtent, -halfExtent, 0], max: [halfExtent, halfExtent, 0])

        mesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: Self.indicesCount,
                topology: .triangle,
                bounds: meshBounds
            )
        ])
    }

    /// Applies the current wireframe setting to the mesh entity's material.
    ///
    /// Does nothing until the mesh entity exists, or if its first material is not a
    /// `ShaderGraphMaterial`.
    private func updateMaterial() {
        guard let entity = root.children.first,
              let current = entity.components[ModelComponent.self]?.materials.first as? ShaderGraphMaterial
        else { return }

        var material = current
        material.triangleFillMode = showWireframe ? .lines : .fill
        entity.components[ModelComponent.self]?.materials = [material]
    }

    /// Builds the entity that displays the mesh.
    ///
    /// Populates the mesh once, loads `/Root/VertexColorMaterial` from `HelperScene` in the
    /// RealityKit content bundle, and wraps both in a `ModelComponent`. Loading failures
    /// crash via `try!`.
    private func createMeshEntity() async -> Entity {
        updateMesh()

        var material = try! await ShaderGraphMaterial(named: "/Root/VertexColorMaterial", from: "HelperScene", in: realityKitContentBundle)
        material.triangleFillMode = showWireframe ? .lines : .fill

        let resource = try! await MeshResource(from: mesh)
        let modelComponent = ModelComponent(mesh: resource, materials: [material])

        let entity = Entity()
        entity.name = "Mesh"
        entity.components.set(modelComponent)
        return entity
    }
}
