//-------------------------------------------------------------------------------------------------------------------------------------------------------------
//
// Copyright 2023 Apple Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//-------------------------------------------------------------------------------------------------------------------------------------------------------------

#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include <Foundation/Foundation.hpp>
#include <Metal/Metal.hpp>
#include <QuartzCore/QuartzCore.hpp>

#define IR_RUNTIME_METALCPP
#define IR_PRIVATE_IMPLEMENTATION
#include <metal_irconverter_runtime/metal_irconverter_runtime.h>

#include <metal_irconverter/metal_irconverter.h>

#include <fstream>
#include <simd/simd.h>
#include <utility>
#include <variant>
#include <vector>

// Include RenderCore after metal-cpp and metal-irconverter to generate
// their implementations inline in this file.
#include "RenderCore.hpp"

#include "MathUtils.hpp"

// Number of elements in a statically sized array. The argument is
// parenthesized before indexing so that cast or dereference expressions
// (for example `*pArray`) are indexed as a whole. Passing a pointer instead
// of an array yields a meaningless result.
#define NUM_ELEMS(arr) (sizeof(arr) / sizeof((arr)[0]))

// Size, in pixels, of the texture the ray tracing pass writes; also used as
// the compute dispatch grid size.
static constexpr uint32_t kTextureWidth                  = 1600;
static constexpr uint32_t kTextureHeight                 = 1200;
// Capacity, in bytes, given to each per-frame bump allocator.
static constexpr uint64_t kPerFrameBumpAllocatorCapacity = 1024; // 1 KiB

// Retains the device and creates every object the renderer needs: the command
// queue, both pipelines, the output texture, the sampler, the scene buffers
// and acceleration structures, the frame-pacing semaphore, and one bump
// allocator per in-flight frame.
//
// pDevice          - Metal device to render with; retained for the lifetime
//                    of this object.
// shaderSearchPath - forwarded to the pipeline builders.
RenderCore::RenderCore(MTL::Device* pDevice, const std::string& shaderSearchPath)
    : _pDevice(pDevice->retain())
    , _frame(0)
    , _animationIndex(0)
    , _aspectRatio(1.f)
{
    // The queue must exist before buildBuffers(), which submits the
    // acceleration structure builds to it.
    _pCommandQueue = _pDevice->newCommandQueue();

    buildRenderPipelines(shaderSearchPath);
    buildComputePipelines(shaderSearchPath);

    buildTextures();
    buildSamplers();
    buildBuffers();

    // One semaphore slot per frame that may be in flight at once.
    _semaphore = dispatch_semaphore_create(kMaxFramesInFlight);

    // Each in-flight frame gets its own allocator for transient argument data.
    size_t frameSlot = 0;
    while (frameSlot < kMaxFramesInFlight)
    {
        _bufferAllocator[frameSlot] = new BumpAllocator(pDevice, kPerFrameBumpAllocatorCapacity, MTL::ResourceStorageModeShared);
        ++frameSlot;
    }
}

// Releases the vertex buffer and then the index buffer of the given mesh.
// The mesh structure itself is left untouched.
static void releaseMesh(IndexedMesh* pIndexedMesh)
{
    IndexedMesh& mesh = *pIndexedMesh;

    mesh.pVertices->release();
    mesh.pIndices->release();
}

// Waits for all in-flight frames to finish, then releases every object
// created by the constructor.
RenderCore::~RenderCore()
{
    // Every draw() takes one semaphore slot and hands it back from the command
    // buffer's completed handler. Acquiring all slots here means every pending
    // handler has already signalled, so none of them can run against a
    // destroyed object.
    for (size_t i = 0; i < kMaxFramesInFlight; ++i)
    {
        dispatch_semaphore_wait(_semaphore, DISPATCH_TIME_FOREVER);
    }

    // Return the slots: libdispatch traps if a semaphore is disposed while its
    // count is below the value it was created with.
    for (size_t i = 0; i < kMaxFramesInFlight; ++i)
    {
        dispatch_semaphore_signal(_semaphore);
    }

    _pTriangleTexture->release();
    _pSampler->release();

    for (size_t i = 0; i < kMaxFramesInFlight; ++i)
    {
        delete _bufferAllocator[i];
    }

    _pComputeRTPipeline->release();
    _pPresentPipeline->release();
    _pScratch->release();
    releaseMesh(&_screenMesh);
    _pCommandQueue->release();
    _pDevice->release();
    dispatch_release(_semaphore);
}

// Creates the pipeline state that draw() binds for the present pass.
//
// shaderSearchPath - forwarded unchanged to shader_pipeline::newPresentPipeline.
void RenderCore::buildRenderPipelines(const std::string& shaderSearchPath)
{
    auto pPresentPipeline = shader_pipeline::newPresentPipeline(shaderSearchPath, _pDevice);
    _pPresentPipeline     = pPresentPipeline;
}

// Creates the compute pipeline state that raytrace() dispatches.
//
// shaderSearchPath - forwarded unchanged to shader_pipeline::newComputeRTPipeline.
void RenderCore::buildComputePipelines(const std::string& shaderSearchPath)
{
    auto pRaytracingPipeline = shader_pipeline::newComputeRTPipeline(shaderSearchPath, _pDevice);
    _pComputeRTPipeline      = pRaytracingPipeline;
}

// Creates the private-storage BGRA8 texture that the ray tracing pass writes
// and the present pass samples.
void RenderCore::buildTextures()
{
    MTL::TextureDescriptor* pTextureDesc = MTL::TextureDescriptor::alloc()->init();
    pTextureDesc->setWidth(kTextureWidth);
    pTextureDesc->setHeight(kTextureHeight);
    pTextureDesc->setDepth(1);
    pTextureDesc->setPixelFormat(MTL::PixelFormatBGRA8Unorm);
    // Converter versions before major version 3 get a single-slice 2D array
    // texture, newer ones a plain 2D texture.
    // NOTE(review): presumably this mirrors how each converter version binds
    // 2D textures -- confirm against the converter documentation.
#if IR_VERSION_MAJOR < 3
    pTextureDesc->setTextureType(MTL::TextureType2DArray);
#else
    pTextureDesc->setTextureType(MTL::TextureType2D);
#endif
    pTextureDesc->setArrayLength(1);
    pTextureDesc->setStorageMode(MTL::StorageModePrivate);
    // setUsage() takes MTL::TextureUsage flags. The MTL::ResourceUsage flags
    // are a different option set whose "sample" bit has the same value as
    // TextureUsageRenderTarget, so using them marks the texture as a render
    // target by accident. The texture is only written by the compute pass and
    // read by the present pass.
    pTextureDesc->setUsage(MTL::TextureUsageShaderRead | MTL::TextureUsageShaderWrite);

    _pTriangleTexture = _pDevice->newTexture(pTextureDesc);

    pTextureDesc->release();

    assert(_pTriangleTexture);
}

void RenderCore::buildSamplers()
{
    MTL::SamplerDescriptor* pSampDesc = MTL::SamplerDescriptor::alloc()->init()->autorelease();
    pSampDesc->setSupportArgumentBuffers(true);
    pSampDesc->setMagFilter(MTL::SamplerMinMagFilterLinear);
    pSampDesc->setMinFilter(MTL::SamplerMinMagFilterLinear);
    pSampDesc->setRAddressMode(MTL::SamplerAddressModeRepeat);
    pSampDesc->setSAddressMode(MTL::SamplerAddressModeRepeat);
    _pSampler = _pDevice->newSamplerState(pSampDesc);
    assert(_pSampler);
}

void RenderCore::buildBuffers()
{
    using simd::float2;
    using simd::float3;
    
    // Screen mesh helps present the ray traced objects on screen:
    _screenMesh = mesh_utils::newScreenQuad(_pDevice);

    // Some extra scratch memory for argument buffers:

    MTL::HeapDescriptor* pHeapDesc = MTL::HeapDescriptor::alloc()->init()->autorelease();
    pHeapDesc->setSize(4 * 1024); // 4 KiB
    pHeapDesc->setStorageMode(MTL::StorageModeShared);
    pHeapDesc->setHazardTrackingMode(MTL::HazardTrackingModeTracked);
    pHeapDesc->setType(MTL::HeapTypeAutomatic);
    _pScratch = _pDevice->newHeap(pHeapDesc);
    
    // Build scene data
    
    float3 vertexData[] = {
        { -1.0, -1.0, 0.0 },
        { +1.0, -1.0, 0.0 },
        {  0.0, +1.0, 0.0 }
    };
    
    uint16_t indexData[] = {
        0, 1, 2
    };
    
    auto vertexBuffer = NS::TransferPtr(_pDevice->newBuffer(sizeof(vertexData), MTL::ResourceStorageModeShared));
    auto indexBuffer = NS::TransferPtr(_pDevice->newBuffer(sizeof(indexData), MTL::ResourceStorageModeShared));
    
    memcpy(vertexBuffer->contents(), vertexData, sizeof(vertexData));
    memcpy(indexBuffer->contents(), indexData, sizeof(indexData));
    
    // Build primitive acceleration structure
    {
        auto pGeometryDescriptor = NS::TransferPtr(MTL::AccelerationStructureTriangleGeometryDescriptor::alloc()->init());
        pGeometryDescriptor->setVertexBuffer(vertexBuffer.get());
        pGeometryDescriptor->setVertexBufferOffset(0);
        pGeometryDescriptor->setVertexStride(sizeof(float3));
        
        pGeometryDescriptor->setIndexBuffer(indexBuffer.get());
        pGeometryDescriptor->setIndexBufferOffset(0);
        pGeometryDescriptor->setIndexType(MTL::IndexTypeUInt16);
        
        pGeometryDescriptor->setTriangleCount(NUM_ELEMS(indexData)/3);
        
        MTL::AccelerationStructureTriangleGeometryDescriptor* descriptors[] = {
            pGeometryDescriptor.get()
        };
        
        auto geometryDescriptors = NS::TransferPtr((NS::Array *)CFArrayCreate(CFAllocatorGetDefault(),
                                                                              (const void**)&descriptors, 1,
                                                                              &kCFTypeArrayCallBacks));
        
        auto pPrimitiveDescriptors = NS::TransferPtr(MTL::PrimitiveAccelerationStructureDescriptor::alloc()->init());
        pPrimitiveDescriptors->setGeometryDescriptors(geometryDescriptors.get());
        
        auto primitiveSizes = _pDevice->accelerationStructureSizes(pPrimitiveDescriptors.get());
        _primitiveAccelerationStructure = NS::TransferPtr(_pDevice->newAccelerationStructure(primitiveSizes.accelerationStructureSize));
    
    
        auto pScratch = NS::TransferPtr(_pDevice->newBuffer(primitiveSizes.buildScratchBufferSize, MTL::ResourceStorageModePrivate));
        
        assert(_pCommandQueue);
        MTL::CommandBuffer* pCmd = _pCommandQueue->commandBuffer();
        MTL::AccelerationStructureCommandEncoder* pBuildEnc = pCmd->accelerationStructureCommandEncoder();
        pBuildEnc->buildAccelerationStructure(_primitiveAccelerationStructure.get(),
                                              pPrimitiveDescriptors.get(),
                                              pScratch.get(), 0);
        pBuildEnc->endEncoding();
        pCmd->commit();
        pCmd->waitUntilCompleted();
    }
    
    // Build instance acceleration structure
    {
        
        MTL::AccelerationStructure* primitiveStructures[] = {
            _primitiveAccelerationStructure.get()
        };
        
        auto primitiveAccelerationStructures = NS::TransferPtr((NS::Array *)CFArrayCreate(CFAllocatorGetDefault(),
                                                                                          (const void **)primitiveStructures, 1,
                                                                                          &kCFTypeArrayCallBacks));
        
        auto pInstanceAccelDescriptor = NS::TransferPtr(MTL::InstanceAccelerationStructureDescriptor::alloc()->init());
        pInstanceAccelDescriptor->setInstancedAccelerationStructures(primitiveAccelerationStructures.get());
        
        const uint32_t kInstanceCount = 2;
        pInstanceAccelDescriptor->setInstanceCount(kInstanceCount);
        
        MTL::AccelerationStructureInstanceDescriptor instanceDescs[kInstanceCount] = { {
            .accelerationStructureIndex = 0,
            .intersectionFunctionTableOffset = 0,
            .mask = 0xFF,
            .options = MTL::AccelerationStructureInstanceOptionOpaque | MTL::AccelerationStructureInstanceOptionDisableTriangleCulling,
            .transformationMatrix = MTL::PackedFloat4x3({ 0.25, 0, 0 },     /* col 0 */
                                                        { 0, 0.25, 0 },     /* col 1 */
                                                        { 0, 0, 1 },        /* col 2 */
                                                        { -0.55, 0, -1.5 }) /* col 3 */
            },
            {
            .accelerationStructureIndex = 0,
            .intersectionFunctionTableOffset = 0,
            .mask = 0xFF,
            .options = MTL::AccelerationStructureInstanceOptionOpaque | MTL::AccelerationStructureInstanceOptionDisableTriangleCulling,
            .transformationMatrix = MTL::PackedFloat4x3(MTL::PackedFloat3( 0.5, 0, 0 ),     /* col 0 */
                                                        MTL::PackedFloat3( 0, 0.5, 0 ),     /* col 1 */
                                                        MTL::PackedFloat3( 0, 0, 1 ),       /* col 2 */
                                                        MTL::PackedFloat3( 0.25, 0, -1.5 )) /* col 3 */
            }};
        
        auto pInstanceBuffer = NS::TransferPtr(_pDevice->newBuffer(kInstanceCount * sizeof(MTL::AccelerationStructureInstanceDescriptor), MTL::ResourceStorageModeShared));
        memcpy(pInstanceBuffer->contents(), instanceDescs, kInstanceCount * sizeof(MTL::AccelerationStructureInstanceDescriptor));
        
        pInstanceAccelDescriptor->setInstanceDescriptorBuffer(pInstanceBuffer.get());
        
        auto instanceSizes = _pDevice->accelerationStructureSizes(pInstanceAccelDescriptor.get());
        _instanceAccelerationStructure = NS::TransferPtr(_pDevice->newAccelerationStructure(instanceSizes.accelerationStructureSize));
        
        auto pScratch = NS::TransferPtr(_pDevice->newBuffer(instanceSizes.buildScratchBufferSize, MTL::ResourceStorageModePrivate));
        
        assert(_pCommandQueue);
        MTL::CommandBuffer* pCmd = _pCommandQueue->commandBuffer();
        MTL::AccelerationStructureCommandEncoder* pEnc = pCmd->accelerationStructureCommandEncoder();
        pEnc->buildAccelerationStructure(_instanceAccelerationStructure.get(),
                                         pInstanceAccelDescriptor.get(),
                                         pScratch.get(), 0);
        pEnc->endEncoding();
        pCmd->commit();
        pCmd->waitUntilCompleted();
    }
}

// Encodes the compute pass that ray traces the scene into _pTriangleTexture.
//
// pCommandBuffer - command buffer to encode into; must not be null.
void RenderCore::raytrace(MTL::CommandBuffer* pCommandBuffer)
{
    assert(pCommandBuffer);

    // Layout of the top-level argument buffer handed to the kernel: two GPU
    // addresses, one for the acceleration structure header and one for the
    // UAV descriptor table.
    struct TopLevelAB
    {
        uint64_t accelStructureHeaderAddr;
        uint64_t uavTableAddr;
    };

    MTL::ComputeCommandEncoder* pEncoder = pCommandBuffer->computeCommandEncoder();
    pEncoder->setComputePipelineState(_pComputeRTPipeline);

    // Bind resources (according to root signature):

    // UAV table: a single descriptor entry that points at the output texture.
    MTL::Buffer* pOutputTable = _pScratch->newBuffer(sizeof(IRDescriptorTableEntry), MTL::ResourceStorageModeShared)->autorelease();
    auto* pOutputEntry        = (IRDescriptorTableEntry*)pOutputTable->contents();
    IRDescriptorTableSetTexture(pOutputEntry, _pTriangleTexture, 0, 0);

    // The top-level argument buffer comes out of this frame's bump allocator.
    BumpAllocator* pFrameAllocator   = _bufferAllocator[_frame];
    auto [pTopLevel, topLevelOffset] = pFrameAllocator->allocate<TopLevelAB>();

    // This sample uses the same header buffer to store the instance contributions to avoid loading
    // from a second buffer, and for simplicity, but this is not a requirement.
    uint32_t instanceContributions[]       = { 0, 0 };
    const NS::UInteger contributionsOffset = sizeof(IRRaytracingAccelerationStructureGPUHeader);
    const NS::UInteger headerBytes         = contributionsOffset + sizeof(uint32_t) * NUM_ELEMS(instanceContributions);
    MTL::Buffer* pHeaderBuffer             = _pDevice->newBuffer(headerBytes, MTL::ResourceStorageModeShared)->autorelease();

    uint8_t* pHeaderContents = (uint8_t *)pHeaderBuffer->contents();
    IRRaytracingSetAccelerationStructure(pHeaderContents,
                                         _instanceAccelerationStructure->gpuResourceID(),
                                         pHeaderContents + contributionsOffset,
                                         instanceContributions, NUM_ELEMS(instanceContributions));

    pTopLevel->accelStructureHeaderAddr = pHeaderBuffer->gpuAddress();
    pTopLevel->uavTableAddr             = pOutputTable->gpuAddress();

    // Everything the kernel reaches only through GPU addresses has to be
    // declared to the encoder explicitly.
    MTL::Resource* readOnlyResources[] = {
        _primitiveAccelerationStructure.get(),
        _instanceAccelerationStructure.get(),
        pHeaderBuffer,
        pOutputTable
    };
    pEncoder->useResources(readOnlyResources, NUM_ELEMS(readOnlyResources), MTL::ResourceUsageRead);
    pEncoder->useResource(_pTriangleTexture, MTL::ResourceUsageWrite);

    pEncoder->setBuffer(pFrameAllocator->baseBuffer(), topLevelOffset, kIRArgumentBufferBindPoint);

    // Dispatch threads: one thread per texel, in one-dimensional threadgroups
    // of the largest size the pipeline allows.
    const MTL::Size threadsPerThreadgroup(_pComputeRTPipeline->maxTotalThreadsPerThreadgroup(), 1, 1);
    const MTL::Size threadsPerGrid(kTextureWidth, kTextureHeight, 1);
    pEncoder->dispatchThreads(threadsPerGrid, threadsPerThreadgroup);

    pEncoder->endEncoding();
}

// Encodes an indexed draw of the screen mesh that samples pTexture.
//
// pRenderEnc - render encoder to record into.
// pTexture   - texture to present; sampled with _pSampler.
void RenderCore::presentTexture(MTL::RenderCommandEncoder* pRenderEnc, MTL::Texture* pTexture)
{
    // Top-level argument buffer layout for the fragment stage: GPU addresses
    // of the texture table and the sampler table.
    struct PresentTLAB
    {
        uint64_t srvTable;
        uint64_t smpTable;
    };

    // All transient data for this draw comes out of the current frame's
    // bump allocator.
    BumpAllocator* pFrameAllocator = _bufferAllocator[_frame];

    auto [pTextureEntry, textureEntryOffset] = pFrameAllocator->allocate<IRDescriptorTableEntry>();
    auto [pSamplerEntry, samplerEntryOffset] = pFrameAllocator->allocate<IRDescriptorTableEntry>();
    auto [pTopLevel, topLevelOffset]         = pFrameAllocator->allocate<PresentTLAB>();

    IRDescriptorTableSetTexture(pTextureEntry, pTexture, 0, 0);
    IRDescriptorTableSetSampler(pSamplerEntry, _pSampler, 0);

    // Both tables live in the allocator's base buffer, so their GPU addresses
    // are the buffer's address plus the offsets returned by allocate().
    auto pArgumentStorage    = pFrameAllocator->baseBuffer();
    const auto storageAddress = pArgumentStorage->gpuAddress();
    pTopLevel->srvTable = storageAddress + textureEntryOffset;
    pTopLevel->smpTable = storageAddress + samplerEntryOffset;

    // The texture is only referenced through the descriptor table, so it is
    // declared to the encoder explicitly.
    pRenderEnc->useResource(pTexture, MTL::ResourceUsageRead);

    pRenderEnc->setVertexBuffer(_screenMesh.pVertices, 0, kIRVertexBufferBindPoint);
    pRenderEnc->setFragmentBuffer(pArgumentStorage, topLevelOffset, kIRArgumentBufferBindPoint);

    pRenderEnc->drawIndexedPrimitives(MTL::PrimitiveTypeTriangle, _screenMesh.numIndices, _screenMesh.indexType, _screenMesh.pIndices, 0);
}


// Hook for per-frame scene updates. draw() calls it once per frame before
// encoding the ray tracing pass; it currently does nothing.
void RenderCore::updateWorld()
{
}

// Encodes and submits one frame: the ray tracing compute pass followed by a
// render pass that presents the traced texture on the drawable.
//
// pRenderPass - render pass descriptor for the present pass.
// pDrawable   - drawable presented once the command buffer has been scheduled.
//
// Blocks while kMaxFramesInFlight frames are already in flight.
void RenderCore::draw(MTL::RenderPassDescriptor* pRenderPass, CA::MetalDrawable* pDrawable)
{
    NS::AutoreleasePool* pPool = NS::AutoreleasePool::alloc()->init();

//#define CAPTURE
#ifdef CAPTURE
    MTL::CaptureDescriptor* pCapDesc = MTL::CaptureDescriptor::alloc()->init()->autorelease();
    pCapDesc->setDestination(MTL::CaptureDestinationDeveloperTools);
    pCapDesc->setCaptureObject(_pDevice);
    
    NS::Error* pError = nullptr;
    MTL::CaptureManager* pCapMan = MTL::CaptureManager::sharedCaptureManager();
    if (!pCapMan->startCapture(pCapDesc, &pError))
    {
        printf("%s\n", pError->localizedDescription()->utf8String());
        __builtin_trap();
    }
#endif

    _frame = (_frame + 1) % kMaxFramesInFlight;

    // Wait for the signal to start encoding the next frame.
    dispatch_semaphore_wait(_semaphore, DISPATCH_TIME_FOREVER);

    // Reset the bump allocator for this new frame.
    _bufferAllocator[_frame]->reset();

    // The completed handler runs asynchronously, possibly after this object
    // has been destroyed. It therefore captures a retained semaphore handle
    // by value instead of reaching through `this`, and drops that reference
    // once it has signalled.
    dispatch_semaphore_t frameSemaphore = _semaphore;
    dispatch_retain(frameSemaphore);

    MTL::CommandBuffer* pCmd = _pCommandQueue->commandBuffer();
    pCmd->addCompletedHandler(^void(MTL::CommandBuffer* pCompletedCmd) {
        (void)pCompletedCmd;
        dispatch_semaphore_signal(frameSemaphore);
        dispatch_release(frameSemaphore);
    });

    // Update scene data to produce any animations
    updateWorld();

    // Update the texture:
    raytrace(pCmd);
    
    // Present the traced texture with the screen mesh.
    MTL::RenderCommandEncoder* pEnc = pCmd->renderCommandEncoder(pRenderPass);
    pEnc->setRenderPipelineState(_pPresentPipeline);
    presentTexture(pEnc, _pTriangleTexture);
    pEnc->endEncoding();
    pCmd->presentDrawable(pDrawable);
    
    pCmd->commit();
    
#ifdef CAPTURE
    pCapMan->stopCapture();
#endif
    
    pPool->release();
}

// Records the drawable's aspect ratio (width / height).
//
// width, height - new drawable size. Sizes that are not strictly positive
//                 (including NaN) are ignored and the previous aspect ratio
//                 is kept, so a zero-sized drawable cannot store an infinite
//                 or NaN ratio.
void RenderCore::resizeDrawable(float width, float height)
{
    if (!(width > 0.f) || !(height > 0.f))
    {
        return;
    }

    _aspectRatio = width / height;
}
