/*
 * Copyright (c) 2021-2022 Apple Inc. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE INC. OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#pragma once

#import "API.h"
#import "Adapter.h"
#import "HardwareCapabilities.h"
#import "Queue.h"
#import "WebGPU.h"
#import "WebGPUExt.h"
#import <CoreVideo/CVMetalTextureCache.h>
#import <CoreVideo/CoreVideo.h>
#import <IOSurface/IOSurfaceRef.h>
#import <Metal/Metal.h>
#import <simd/matrix_types.h>
#import <wtf/CompletionHandler.h>
#import <wtf/FastMalloc.h>
#import <wtf/Function.h>
#import <wtf/Ref.h>
#import <wtf/RetainReleaseSwift.h>
#import <wtf/TZoneMalloc.h>
#import <wtf/ThreadSafeWeakPtr.h>
#import <wtf/Vector.h>
#import <wtf/text/WTFString.h>

// Empty base type that WebGPU::Device publicly derives from.
// NOTE(review): presumably the concrete struct behind the C API's WGPUDevice handle — confirm against WebGPU.h.
struct WGPUDeviceImpl {
};

// Forward declaration of the WGSL pipeline layout type that appears in the
// signature of WebGPU::Device::addPipelineLayouts().
namespace WGSL {
struct PipelineLayout;
}

namespace WebGPU {

// Forward declarations of the WebGPU types named in Device's interface.
// Kept in alphabetical order.
class BindGroup;
class BindGroupLayout;
class Buffer;
class CommandEncoder;
class ComputePipeline;
class ExternalTexture;
class Instance;
class PipelineLayout;
class PresentationContext;
class QuerySet;
class RenderBundleEncoder;
class RenderPipeline;
class Sampler;
class ShaderModule;
class Texture;
class XRBinding;
class XRProjectionLayer;
class XRSubImage;
class XRView;

// Type returned by Device::shaderValidationState(): Metal's MTLShaderValidation
// when ENABLE(WEBGPU_BY_DEFAULT) is set, otherwise a plain uint32_t stand-in.
// NOTE(review): presumably the fallback exists because MTLShaderValidation is not
// available in every configuration this header builds in — confirm.
#if ENABLE(WEBGPU_BY_DEFAULT)
using GPUShaderValidation = MTLShaderValidation;
#else
using GPUShaderValidation = uint32_t;
#endif

// https://gpuweb.github.io/gpuweb/#gpudevice
//
// WebGPU device object. It holds an id<MTLDevice>, a default Queue, the
// HardwareCapabilities (limits, features, base capabilities) it was created with,
// an error-scope stack, and the device-lost / uncaptured-error callbacks.
// The type is thread-safe ref-counted and supports thread-safe weak pointers;
// SWIFT_SHARED_REFERENCE at the end of the class names refDevice()/derefDevice()
// as its reference-counting functions for Swift.
class Device : public WGPUDeviceImpl, public ThreadSafeRefCountedAndCanMakeThreadSafeWeakPtr<Device> {
    WTF_MAKE_TZONE_ALLOCATED(Device);
public:
    // Factories. createInvalid() goes through the Adapter-only constructor.
    // NOTE(review): presumably that leaves m_device nil so isValid() is false — confirm in the implementation.
    static Ref<Device> create(id<MTLDevice>, String&& deviceLabel, HardwareCapabilities&&, Adapter&);
    static Ref<Device> createInvalid(Adapter& adapter)
    {
        return adoptRef(*new Device(adapter));
    }

    ~Device();

    // Object creation entry points, each taking the matching WGPU*Descriptor.
    Ref<BindGroup> createBindGroup(const WGPUBindGroupDescriptor&);
    Ref<BindGroupLayout> createBindGroupLayout(const WGPUBindGroupLayoutDescriptor&, bool isGeneratedLayout = false);
    Ref<XRBinding> createXRBinding();
    Ref<XRSubImage> createXRSubImage();
    Ref<XRView> createXRView();
    Ref<Buffer> createBuffer(const WGPUBufferDescriptor&);
    Ref<CommandEncoder> createCommandEncoder(const WGPUCommandEncoderDescriptor&);
    // NOTE(review): the NSString* half of the returned pair presumably carries an error message — confirm in the implementation.
    std::pair<Ref<ComputePipeline>, NSString*> createComputePipeline(const WGPUComputePipelineDescriptor&, bool isAsync = false);
    // Asynchronous variant: the result is delivered through `callback` as a status, the pipeline and a message.
    void createComputePipelineAsync(const WGPUComputePipelineDescriptor&, CompletionHandler<void(WGPUCreatePipelineAsyncStatus, Ref<ComputePipeline>&&, String&& message)>&& callback);
    Ref<ExternalTexture> createExternalTexture(const WGPUExternalTextureDescriptor&);
    void updateExternalTexture(ExternalTexture&) const;
    Ref<PipelineLayout> createPipelineLayout(const WGPUPipelineLayoutDescriptor&, bool isAutogenerated = false);
    Ref<QuerySet> createQuerySet(const WGPUQuerySetDescriptor&);
    Ref<RenderBundleEncoder> createRenderBundleEncoder(const WGPURenderBundleEncoderDescriptor&);
    // NOTE(review): same signature as the private generatePipelineLayout() below and a tool-style name;
    // looks like leftover refactoring output. Confirm whether it is defined or called anywhere before removing.
    Ref<PipelineLayout> extracted(const Vector<Vector<WGPUBindGroupLayoutEntry>> &bindGroupEntries);

    // NOTE(review): as with createComputePipeline(), the NSString* is presumably an error message — confirm.
    std::pair<Ref<RenderPipeline>, NSString*> createRenderPipeline(const WGPURenderPipelineDescriptor&, bool isAsync = false);
    // Asynchronous variant: the result is delivered through `callback` as a status, the pipeline and a message.
    void createRenderPipelineAsync(const WGPURenderPipelineDescriptor&, CompletionHandler<void(WGPUCreatePipelineAsyncStatus, Ref<RenderPipeline>&&, String&& message)>&& callback);
    Ref<Sampler> createSampler(const WGPUSamplerDescriptor&);
    Ref<ShaderModule> createShaderModule(const WGPUShaderModuleDescriptor&);
    Ref<PresentationContext> createSwapChain(PresentationContext&, const WGPUSwapChainDescriptor&);
    Ref<Texture> createTexture(const WGPUTextureDescriptor&);
    void destroy();
    size_t enumerateFeatures(WGPUFeatureName* features);
    bool getLimits(WGPUSupportedLimits&);
    // getQueue() hands out a reference to the default queue; protectedQueue() returns a new Ref to the same queue.
    Queue& getQueue() const { return m_defaultQueue; }
    Ref<Queue> protectedQueue() const { return m_defaultQueue; }
    bool hasFeature(WGPUFeatureName) const;
    // Error scopes and device-level callbacks.
    bool popErrorScope(CompletionHandler<void(WGPUErrorType, String&&)>&& callback);
    void pushErrorScope(WGPUErrorFilter);
    void setDeviceLostCallback(Function<void(WGPUDeviceLostReason, String&&)>&&);
    void setUncapturedErrorCallback(Function<void(WGPUErrorType, String&&)>&&);
    void setLabel(String&&);

    // True when an underlying id<MTLDevice> is set.
    bool isValid() const { return m_device; }
    bool isLost() const { return m_isLost; }
    // limits() returns a reference into m_capabilities; limitsCopy() returns the same limits by value.
    const WGPULimits& limits() const { return m_capabilities.limits; }
    const WGPULimits limitsCopy() const { return m_capabilities.limits; }
    const Vector<WGPUFeatureName>& features() const { return m_capabilities.features; }
    const HardwareCapabilities::BaseCapabilities& baseCapabilities() const { return m_capabilities.baseCapabilities; }

    id<MTLDevice> device() const { return m_device; }
    // Error generation helpers, one per error category.
    void generateAValidationError(NSString * message);
    void generateAValidationError(String&& message);
    void generateAnOutOfMemoryError(String&& message);
    void generateAnInternalError(String&& message);

    // m_instance is held weakly, so the returned RefPtr may be null.
    RefPtr<Instance> instance() const { return m_instance.get(); }
    // Always false when compiled for x86_64; otherwise forwards to the MTLDevice's hasUnifiedMemory property.
#if CPU(X86_64)
    bool hasUnifiedMemory() const { return false; }
#else
    bool hasUnifiedMemory() const { return m_device.hasUnifiedMemory; }
#endif

    // Returns limits.maxBindGroupsPlusVertexBuffers after asserting that it is non-zero.
    uint32_t maxBuffersPlusVertexBuffersForVertexStage() const
    {
        ASSERT(m_capabilities.limits.maxBindGroupsPlusVertexBuffers > 0);
        return m_capabilities.limits.maxBindGroupsPlusVertexBuffers;
    }

    // Both the fragment and the compute stage report limits.maxBindGroups.
    uint32_t maxBuffersForFragmentStage() const { return m_capabilities.limits.maxBindGroups; }

    uint32_t maxBuffersForComputeStage() const { return m_capabilities.limits.maxBindGroups; }
    // Forwards to WGSL::vertexBufferIndexForBindGroup(), passing
    // maxBuffersPlusVertexBuffersForVertexStage() - 1 as the second argument.
    // The assertion guards that subtraction against wrapping around at zero.
    uint32_t vertexBufferIndexForBindGroup(uint32_t groupIndex) const
    {
        ASSERT(maxBuffersPlusVertexBuffersForVertexStage() > 0);
        return WGSL::vertexBufferIndexForBindGroup(groupIndex, maxBuffersPlusVertexBuffersForVertexStage() - 1);
    }

    // Metal resource creation helpers.
    id<MTLBuffer> newBufferWithBytes(const void*, size_t, MTLResourceOptions) const;
    id<MTLBuffer> newBufferWithBytesNoCopy(void*, size_t, MTLResourceOptions) const;
    id<MTLTexture> newTextureWithDescriptor(MTLTextureDescriptor *, IOSurfaceRef = nullptr, NSUInteger plane = 0) const;

    static bool isStencilOnlyFormat(MTLPixelFormat);
    bool shouldStopCaptureAfterSubmit();
    id<MTLBuffer> placeholderBuffer() const { return m_placeholderBuffer; }

    id<MTLTexture> placeholderTexture(WGPUTextureFormat) const;
    bool isDestroyed() const;
    NSString *errorValidatingTextureCreation(const WGPUTextureDescriptor&, const Vector<WGPUTextureFormat>& viewFormats);
    // Accessors for internal Metal buffers and pipeline states.
    // NOTE(review): these presumably create the m_*PSO / m_dispatchCall* members below on first use — confirm in the implementation.
    id<MTLBuffer> dispatchCallBuffer();
    id<MTLComputePipelineState> dispatchCallPipelineState(id<MTLFunction>);
    id<MTLRenderPipelineState> indexBufferClampPipeline(MTLIndexType, NSUInteger rasterSampleCount);
    id<MTLRenderPipelineState> indexedIndirectBufferClampPipeline(NSUInteger rasterSampleCount);
    id<MTLRenderPipelineState> indirectBufferClampPipeline(NSUInteger rasterSampleCount);
    id<MTLRenderPipelineState> icbCommandClampPipeline(MTLIndexType, NSUInteger rasterSampleCount);
    id<MTLFunction> icbCommandClampFunction(MTLIndexType);
    id<MTLRenderPipelineState> copyIndexIndirectArgsPipeline(NSUInteger rasterSampleCount);
    id<MTLBuffer> safeCreateBuffer(NSUInteger length, MTLStorageMode, MTLCPUCacheMode = MTLCPUCacheModeDefaultCache, MTLHazardTrackingMode = MTLHazardTrackingModeDefault) const;
    id<MTLBuffer> safeCreateBuffer(NSUInteger) const;
    void loseTheDevice(WGPUDeviceLostReason);
    int bufferIndexForICBContainer() const;
    void setOwnerWithIdentity(id<MTLResource>) const;
    // Result of createExternalTextureFromPixelBuffer(): up to two textures plus
    // the UV remapping and colour-space conversion matrices that go with them.
    struct ExternalTextureData {
        id<MTLTexture> texture0 { nil };
        id<MTLTexture> texture1 { nil };
        simd::float3x2 uvRemappingMatrix;
        simd::float4x3 colorSpaceConversionMatrix;
    };
    ExternalTextureData createExternalTextureFromPixelBuffer(CVPixelBufferRef, WGPUColorSpace) const;
    RefPtr<XRSubImage> getXRViewSubImage(XRProjectionLayer&);
    const std::optional<const MachSendRight> webProcessID() const;
    // On x86_64: true when the MTLDevice name contains "intel" (localized, case-insensitive).
    // On every other CPU this is a compile-time false.
#if CPU(X86_64)
    bool isIntel() const { return [m_device.name localizedCaseInsensitiveContainsString:@"intel"]; }
#else
    constexpr bool isIntel() const { return false; }
#endif
    void pauseErrorReporting(bool pauseReporting);
    // Timestamp support.
    bool enableEncoderTimestamps() const;
    id<MTLCounterSampleBuffer> timestampsBuffer(id<MTLCommandBuffer>, size_t);
    void resolveTimestampsForBuffer(id<MTLCommandBuffer>);
    id<MTLSharedEvent> resolveTimestampsSharedEvent();
    uint32_t maxVerticesPerDrawCall() const { return m_maxVerticesPerDrawCall; }

private:
    // Constructors are private; instances are made through create() / createInvalid().
    Device(id<MTLDevice>, id<MTLCommandQueue> defaultQueue, HardwareCapabilities&&, Adapter&);
    Device(Adapter&);

    struct ErrorScope;
    ErrorScope* currentErrorScope(WGPUErrorFilter);
    std::optional<WGPUErrorType> validatePopErrorScope() const;
    bool validateCreateIOSurfaceBackedTexture(const WGPUTextureDescriptor&, const Vector<WGPUTextureFormat>& viewFormats, IOSurfaceRef backing);

    bool validateRenderPipeline(const WGPURenderPipelineDescriptor&);

    void makeInvalid();
    NSString* addPipelineLayouts(Vector<Vector<WGPUBindGroupLayoutEntry>>&, const std::optional<WGSL::PipelineLayout>&);
    Ref<PipelineLayout> generatePipelineLayout(const Vector<Vector<WGPUBindGroupLayoutEntry>> &bindGroupEntries);

    void captureFrameIfNeeded() const;
    GPUShaderValidation shaderValidationState() const;

    // An error as recorded in an error scope: its category and message.
    struct Error {
        WGPUErrorType type;
        String message;
    };
    // One entry of m_errorScopeStack: the filter the scope was pushed with and
    // the error recorded for it, if any.
    struct ErrorScope {
        std::optional<Error> error;
        const WGPUErrorFilter filter;
    };

    // NOTE: data member order determines initialisation and destruction order; do not reorder casually.
    id<MTLDevice> m_device { nil };
    const Ref<Queue> m_defaultQueue;

    Function<void(WGPUErrorType, String&&)> m_uncapturedErrorCallback;
    Vector<ErrorScope> m_errorScopeStack;
    RefPtr<XRSubImage> m_xrSubImage;

    Function<void(WGPUDeviceLostReason, String&&)> m_deviceLostCallback;
    bool m_isLost { false };
    bool m_destroyed { false };
    id<NSObject> m_deviceObserver { nil };

    HardwareCapabilities m_capabilities { };

    // Placeholder resources and the dispatch-call helper objects.
    id<MTLBuffer> m_placeholderBuffer { nil };
    id<MTLTexture> m_placeholderTexture { nil };
    id<MTLTexture> m_placeholderDepthStencilTexture { nil };
    id<MTLBuffer> m_dispatchCallBuffer { nil };
    id<MTLComputePipelineState> m_dispatchCallPipelineState { nil };

    // Internal render pipeline states.
    // NOTE(review): the "MS" suffix presumably marks the multisampled (rasterSampleCount > 1) variant — confirm.
    id<MTLRenderPipelineState> m_indexBufferClampUintPSO { nil };
    id<MTLRenderPipelineState> m_indexBufferClampUshortPSO { nil };
    id<MTLRenderPipelineState> m_indexBufferClampUintPSOMS { nil };
    id<MTLRenderPipelineState> m_indexBufferClampUshortPSOMS { nil };

    id<MTLRenderPipelineState> m_indexedIndirectBufferClampPSO { nil };
    id<MTLRenderPipelineState> m_indexedIndirectBufferClampPSOMS { nil };

    id<MTLRenderPipelineState> m_indirectBufferClampPSO { nil };
    id<MTLRenderPipelineState> m_indirectBufferClampPSOMS { nil };

    id<MTLRenderPipelineState> m_icbCommandClampUintPSO { nil };
    id<MTLRenderPipelineState> m_icbCommandClampUshortPSO { nil };
    id<MTLRenderPipelineState> m_icbCommandClampUintPSOMS { nil };
    id<MTLRenderPipelineState> m_icbCommandClampUshortPSOMS { nil };

    id<MTLRenderPipelineState> m_copyIndexedIndirectArgsPSO { nil };
    id<MTLRenderPipelineState> m_copyIndexedIndirectArgsPSOMS { nil };

    const Ref<Adapter> m_adapter;
    // Weak, so the Device does not keep its Instance alive.
    const ThreadSafeWeakPtr<Instance> m_instance;
#if HAVE(COREVIDEO_METAL_SUPPORT)
    RetainPtr<CVMetalTextureCacheRef> m_coreVideoTextureCache;
#endif
    // Per-command-buffer timestamp bookkeeping. Explicitly nil-initialised, like every other
    // Objective-C object member of this class, so the value never depends on ARC's implicit zeroing.
    NSMapTable<id<MTLCommandBuffer>, id<MTLCounterSampleBuffer>>* m_sampleCounterBuffers { nil };
    NSMapTable<id<MTLCommandBuffer>, NSMutableArray<id<MTLBuffer>>*>* m_resolvedSampleCounterBuffers { nil };
    id<MTLSharedEvent> m_resolveTimestampsSharedEvent { nil };
    // NOTE(review): "supress" is a misspelling of "suppress". The name is left unchanged here because
    // the out-of-line implementation presumably refers to it; rename both together.
    bool m_supressAllErrors { false };
    const uint32_t m_maxVerticesPerDrawCall { 0 };
} SWIFT_SHARED_REFERENCE(refDevice, derefDevice);

} // namespace WebGPU

// Function named first in SWIFT_SHARED_REFERENCE on WebGPU::Device.
// Forwards the pointer to WTF::ref().
inline void refDevice(WebGPU::Device* device)
{
    WTF::ref(device);
}

// Function named second in SWIFT_SHARED_REFERENCE on WebGPU::Device.
// Forwards the pointer to WTF::deref().
inline void derefDevice(WebGPU::Device* device)
{
    WTF::deref(device);
}
