Skip to content

Commit

Permalink
[rhi] MetalSurface functions (#8274)
Browse files Browse the repository at this point in the history
Issue: #

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at a89e235</samp>

This pull request adds the ability to create and use a Metal surface for
rendering graphics using the RHI interface. It introduces the
`MetalSurface` class in `metal_device.h` and `metal_device.mm`, and
links the QuartzCore framework in `CMakeLists.txt`.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at a89e235</samp>

* Add the MetalSurface class and the create_surface method to support
rendering on Metal
([link](https://github.com/taichi-dev/taichi/pull/8274/files?diff=unified&w=0#diff-5b304e188996abd217fad85fd8b2434e729ad9c089eb9ef48a848c0fcc74d614R249-R292),
[link](https://github.com/taichi-dev/taichi/pull/8274/files?diff=unified&w=0#diff-5b304e188996abd217fad85fd8b2434e729ad9c089eb9ef48a848c0fcc74d614L260-R310),
[link](https://github.com/taichi-dev/taichi/pull/8274/files?diff=unified&w=0#diff-c350a27659df1ebcf2c854e2c21a80f82de87d3b325a80e00f6a6dd6ca53303bR608-R682),
[link](https://github.com/taichi-dev/taichi/pull/8274/files?diff=unified&w=0#diff-c350a27659df1ebcf2c854e2c21a80f82de87d3b325a80e00f6a6dd6ca53303bR707-R711))
* Add the QuartzCore framework and header to use the CAMetalLayer and
CAMetalDrawable classes for the MetalSurface
([link](https://github.com/taichi-dev/taichi/pull/8274/files?diff=unified&w=0#diff-464e3a16b2cf83d04b813312367cf6b4d56b4d313af162dde3237c8352374667L21-R21),
[link](https://github.com/taichi-dev/taichi/pull/8274/files?diff=unified&w=0#diff-5b304e188996abd217fad85fd8b2434e729ad9c089eb9ef48a848c0fcc74d614L13-R18),
[link](https://github.com/taichi-dev/taichi/pull/8274/files?diff=unified&w=0#diff-5b304e188996abd217fad85fd8b2434e729ad9c089eb9ef48a848c0fcc74d614L29-R36))
* Add helper functions to convert RHI enums to Metal enums for buffer
format, image dimension, and image usage
([link](https://github.com/taichi-dev/taichi/pull/8274/files?diff=unified&w=0#diff-c350a27659df1ebcf2c854e2c21a80f82de87d3b325a80e00f6a6dd6ca53303bR10-R84))
* Add a TODO comment to the MetalShaderResourceSet class to indicate the
need for raster resources support
([link](https://github.com/taichi-dev/taichi/pull/8274/files?diff=unified&w=0#diff-5b304e188996abd217fad85fd8b2434e729ad9c089eb9ef48a848c0fcc74d614L174-R180))

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Proton <feisuzhu@163.com>
Co-authored-by: Bob Cao <bobcaocheng@gmail.com>
  • Loading branch information
4 people committed Jul 13, 2023
1 parent edf8b9e commit 6994dbd
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 85 deletions.
2 changes: 1 addition & 1 deletion cmake/TaichiCAPITests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ if(TI_WITH_STATIC_C_API)
find_library(LIBZSTD_LIBRARY zstd REQUIRED)

target_link_libraries(${C_STATIC_API_TESTS_NAME} PRIVATE "-framework Cocoa" "-framework IOKit" "-framework CoreFoundation")
target_link_libraries(${C_STATIC_API_TESTS_NAME} PRIVATE "-framework Metal")
target_link_libraries(${C_STATIC_API_TESTS_NAME} PRIVATE "-framework Metal" "-framework QuartzCore")
target_link_libraries(${C_STATIC_API_TESTS_NAME} PRIVATE "${LIBZSTD_LIBRARY}")
target_link_libraries(${C_STATIC_API_TESTS_NAME} PRIVATE ZLIB::ZLIB)
target_link_options(${C_STATIC_API_TESTS_NAME} PRIVATE -Wl,-dead_strip)
Expand Down
2 changes: 1 addition & 1 deletion taichi/rhi/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ target_include_directories(${METAL_RHI}
${PROJECT_SOURCE_DIR}/external/glad/include
${PROJECT_SOURCE_DIR}/external/glfw/include
)
target_link_libraries(${METAL_RHI} PRIVATE spirv-cross-msl spirv-cross-core)
target_link_libraries(${METAL_RHI} PRIVATE spirv-cross-msl spirv-cross-core "-framework QuartzCore")
65 changes: 57 additions & 8 deletions taichi/rhi/metal/metal_device.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
#pragma once
#include <memory>
#include "taichi/common/logging.h"
#include "taichi/rhi/device.h"
#include "taichi/rhi/metal/metal_api.h"
#include "taichi/rhi/impl_support.h"
#include "taichi/rhi/metal/metal_api.h"
#include <memory>

// clang-format off
#if defined(__APPLE__) && defined(__OBJC__)
#import <CoreGraphics/CoreGraphics.h>
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <MetalKit/MetalKit.h>
#import <CoreGraphics/CoreGraphics.h>
#import <QuartzCore/QuartzCore.h>
#define DEFINE_METAL_ID_TYPE(x) typedef id<x> x##_id;
#define DEFINE_OBJC_TYPE(x) // Should be defined by included headers
#else
#define DEFINE_METAL_ID_TYPE(x) typedef struct x##_t *x##_id;
#define DEFINE_OBJC_TYPE(x) typedef void x;
#endif
// clang-format on

DEFINE_METAL_ID_TYPE(MTLDevice);
DEFINE_METAL_ID_TYPE(MTLBuffer);
Expand All @@ -26,8 +31,11 @@ DEFINE_METAL_ID_TYPE(MTLCommandQueue);
DEFINE_METAL_ID_TYPE(MTLCommandBuffer);
DEFINE_METAL_ID_TYPE(MTLBlitCommandEncoder);
DEFINE_METAL_ID_TYPE(MTLComputeCommandEncoder);
DEFINE_METAL_ID_TYPE(CAMetalDrawable);
DEFINE_OBJC_TYPE(CAMetalLayer);

#undef DEFINE_METAL_ID_TYPE
#undef DEFINE_OBJC_TYPE

namespace taichi::lang {

Expand Down Expand Up @@ -171,7 +179,7 @@ class MetalShaderResourceSet final : public ShaderResourceSet {

private:
const MetalDevice *device_;
std::vector<MetalShaderResource> resources_;
std::vector<MetalShaderResource> resources_; // TODO: need raster resources
};

class MetalCommandList final : public CommandList {
Expand Down Expand Up @@ -240,6 +248,50 @@ class MetalStream final : public Stream {
bool is_destroyed_{false};
};

class MetalSurface final : public Surface {
public:
MetalSurface(MetalDevice *device, const SurfaceConfig &config);
~MetalSurface() override;

CAMetalLayer *mtl_layer() {
return layer_;
}

StreamSemaphore acquire_next_image() override;
DeviceAllocation get_target_image() override;

void present_image(
const std::vector<StreamSemaphore> &wait_semaphores = {}) override;
std::pair<uint32_t, uint32_t> get_size() override;
int get_image_count() override;
BufferFormat image_format() override;
void resize(uint32_t width, uint32_t height) override;

DeviceAllocation get_depth_data(DeviceAllocation &depth_alloc) override {
TI_NOT_IMPLEMENTED;
}
DeviceAllocation get_image_data() override {
TI_NOT_IMPLEMENTED;
}

private:
void destroy_swap_chain();

SurfaceConfig config_;

BufferFormat image_format_{BufferFormat::unknown};

uint32_t width_{0};
uint32_t height_{0};

MTLTexture_id current_swap_chain_texture_;
std::unordered_map<MTLTexture_id, DeviceAllocation> swapchain_images_;
CAMetalDrawable_id current_drawable_;

MetalDevice *device_{nullptr};
CAMetalLayer *layer_;
};

class MetalDevice final : public GraphicsDevice {
public:
// `mtl_device` should be already retained.
Expand All @@ -256,10 +308,7 @@ class MetalDevice final : public GraphicsDevice {
static MetalDevice *create();
void destroy();

std::unique_ptr<Surface> create_surface(
const SurfaceConfig &config) override {
TI_NOT_IMPLEMENTED;
}
std::unique_ptr<Surface> create_surface(const SurfaceConfig &config) override;

RhiResult allocate_memory(const AllocParams &params,
DeviceAllocation *out_devalloc) override;
Expand Down
227 changes: 152 additions & 75 deletions taichi/rhi/metal/metal_device.mm
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,81 @@
namespace taichi::lang {
namespace metal {

MTLPixelFormat format2mtl(BufferFormat format) {
static const std::map<BufferFormat, MTLPixelFormat> map{
{BufferFormat::unknown, MTLPixelFormatInvalid},
{BufferFormat::r8, MTLPixelFormatR8Unorm},
{BufferFormat::rg8, MTLPixelFormatRG8Unorm},
{BufferFormat::rgba8, MTLPixelFormatRGBA8Unorm},
{BufferFormat::rgba8srgb, MTLPixelFormatRGBA8Unorm_sRGB},
{BufferFormat::bgra8, MTLPixelFormatBGRA8Unorm},
{BufferFormat::bgra8srgb, MTLPixelFormatBGRA8Unorm_sRGB},
{BufferFormat::r8u, MTLPixelFormatR8Uint},
{BufferFormat::rg8u, MTLPixelFormatRG8Uint},
{BufferFormat::rgba8u, MTLPixelFormatRGBA8Uint},
{BufferFormat::r8i, MTLPixelFormatR8Sint},
{BufferFormat::rg8i, MTLPixelFormatRG8Sint},
{BufferFormat::rgba8i, MTLPixelFormatRGBA8Sint},
{BufferFormat::r16, MTLPixelFormatR16Unorm},
{BufferFormat::rg16, MTLPixelFormatRG16Unorm},
{BufferFormat::rgb16, MTLPixelFormatInvalid},
{BufferFormat::rgba16, MTLPixelFormatRGBA16Unorm},
{BufferFormat::r16u, MTLPixelFormatR16Uint},
{BufferFormat::rg16u, MTLPixelFormatRG16Uint},
{BufferFormat::rgb16u, MTLPixelFormatInvalid},
{BufferFormat::rgba16u, MTLPixelFormatRGBA16Uint},
{BufferFormat::r16i, MTLPixelFormatR16Sint},
{BufferFormat::rg16i, MTLPixelFormatRG16Sint},
{BufferFormat::rgb16i, MTLPixelFormatInvalid},
{BufferFormat::rgba16i, MTLPixelFormatRGBA16Sint},
{BufferFormat::r16f, MTLPixelFormatR16Float},
{BufferFormat::rg16f, MTLPixelFormatRG16Float},
{BufferFormat::rgb16f, MTLPixelFormatInvalid},
{BufferFormat::rgba16f, MTLPixelFormatRGBA16Float},
{BufferFormat::r32u, MTLPixelFormatR32Uint},
{BufferFormat::rg32u, MTLPixelFormatRG32Uint},
{BufferFormat::rgb32u, MTLPixelFormatInvalid},
{BufferFormat::rgba32u, MTLPixelFormatRGBA32Uint},
{BufferFormat::r32i, MTLPixelFormatR32Sint},
{BufferFormat::rg32i, MTLPixelFormatRG32Sint},
{BufferFormat::rgb32i, MTLPixelFormatInvalid},
{BufferFormat::rgba32i, MTLPixelFormatRGBA32Sint},
{BufferFormat::r32f, MTLPixelFormatR32Float},
{BufferFormat::rg32f, MTLPixelFormatRG32Float},
{BufferFormat::rgb32f, MTLPixelFormatInvalid},
{BufferFormat::rgba32f, MTLPixelFormatRGBA32Float},
{BufferFormat::depth16, MTLPixelFormatDepth16Unorm},
{BufferFormat::depth24stencil8, MTLPixelFormatInvalid},
{BufferFormat::depth32f, MTLPixelFormatDepth32Float},
};
auto it = map.find(format);
RHI_ASSERT(it != map.end());
return it->second;
}
MTLTextureType dimension2mtl(ImageDimension dimension) {
static const std::map<ImageDimension, MTLTextureType> map = {
{ImageDimension::d1D, MTLTextureType1D},
{ImageDimension::d2D, MTLTextureType2D},
{ImageDimension::d3D, MTLTextureType3D},
};
auto it = map.find(dimension);
RHI_ASSERT(it != map.end());
return it->second;
}
MTLTextureUsage usage2mtl(ImageAllocUsage usage) {
MTLTextureUsage out = 0;
if (usage & ImageAllocUsage::Sampled) {
out |= MTLTextureUsageShaderRead;
}
if (usage & ImageAllocUsage::Storage) {
out |= MTLTextureUsageShaderWrite;
}
if (usage & ImageAllocUsage::Attachment) {
out |= MTLTextureUsageRenderTarget;
}
return out;
}

MetalMemory::MetalMemory(MTLBuffer_id mtl_buffer, bool can_map)
: mtl_buffer_(mtl_buffer), can_map_(can_map) {}
MetalMemory::~MetalMemory() {
Expand Down Expand Up @@ -530,6 +605,78 @@ DeviceCapabilityConfig collect_metal_device_caps(MTLDevice_id mtl_device) {
return std::make_unique<MetalSampler>(mtl_sampler_state);
}

MetalSurface::MetalSurface(MetalDevice *device, const SurfaceConfig &config)
: config_(config), device_(device) {

width_ = config.width;
height_ = config.height;

image_format_ = BufferFormat::bgra8;

layer_ = [CAMetalLayer layer];
layer_.device = device->mtl_device();
layer_.pixelFormat = format2mtl(image_format_);
layer_.drawableSize = CGSizeMake(width_, height_);
layer_.allowsNextDrawableTimeout = NO;
#if TARGET_OS_OSX
// Older versions may not have this property so check if it exists first.
layer_.displaySyncEnabled = config.vsync;
#endif
}

MetalSurface::~MetalSurface() {
destroy_swap_chain();
[layer_ release];
}

void MetalSurface::destroy_swap_chain() {
for (auto &alloc : swapchain_images_) {
device_->destroy_image(alloc.second);
}
swapchain_images_.clear();
}

StreamSemaphore MetalSurface::acquire_next_image() {
current_drawable_ = [layer_ nextDrawable];
current_swap_chain_texture_ = current_drawable_.texture;

if (swapchain_images_.count(current_swap_chain_texture_) == 0) {
swapchain_images_[current_swap_chain_texture_] =
device_->import_mtl_texture(current_drawable_.texture);
RHI_ASSERT(swapchain_images_.size() <=
50); // In case something goes wrong on Metal side, prevent this
// map of images from growing each frame unbounded.
}
return nullptr;
}

DeviceAllocation MetalSurface::get_target_image() {
return swapchain_images_.at(current_swap_chain_texture_);
}

void MetalSurface::present_image(
const std::vector<StreamSemaphore> &wait_semaphores) {

[current_drawable_ present];

device_->wait_idle();
}

std::pair<uint32_t, uint32_t> MetalSurface::get_size() {
return std::make_pair(width_, height_);
}

int MetalSurface::get_image_count() { return (int)layer_.maximumDrawableCount; }

BufferFormat MetalSurface::image_format() { return image_format_; }

void MetalSurface::resize(uint32_t width, uint32_t height) {
destroy_swap_chain();
width_ = width;
height_ = height;
layer_.drawableSize = CGSizeMake(width_, height_);
}

MetalDevice::MetalDevice(MTLDevice_id mtl_device) : mtl_device_(mtl_device) {
compute_stream_ = std::unique_ptr<MetalStream>(MetalStream::create(*this));

Expand All @@ -554,6 +701,11 @@ DeviceCapabilityConfig collect_metal_device_caps(MTLDevice_id mtl_device) {
}
}

std::unique_ptr<Surface>
MetalDevice::create_surface(const SurfaceConfig &config) {
return std::make_unique<MetalSurface>(this, config);
}

RhiResult MetalDevice::allocate_memory(const AllocParams &params,
DeviceAllocation *out_devalloc) {
if (params.export_sharing) {
Expand Down Expand Up @@ -599,81 +751,6 @@ DeviceCapabilityConfig collect_metal_device_caps(MTLDevice_id mtl_device) {
memory_allocs_.release(&get_memory(handle.alloc_id));
}

MTLPixelFormat format2mtl(BufferFormat format) {
static const std::map<BufferFormat, MTLPixelFormat> map{
{BufferFormat::unknown, MTLPixelFormatInvalid},
{BufferFormat::r8, MTLPixelFormatR8Unorm},
{BufferFormat::rg8, MTLPixelFormatRG8Unorm},
{BufferFormat::rgba8, MTLPixelFormatRGBA8Unorm},
{BufferFormat::rgba8srgb, MTLPixelFormatRGBA8Unorm_sRGB},
{BufferFormat::bgra8, MTLPixelFormatBGRA8Unorm},
{BufferFormat::bgra8srgb, MTLPixelFormatBGRA8Unorm_sRGB},
{BufferFormat::r8u, MTLPixelFormatR8Uint},
{BufferFormat::rg8u, MTLPixelFormatRG8Uint},
{BufferFormat::rgba8u, MTLPixelFormatRGBA8Uint},
{BufferFormat::r8i, MTLPixelFormatR8Sint},
{BufferFormat::rg8i, MTLPixelFormatRG8Sint},
{BufferFormat::rgba8i, MTLPixelFormatRGBA8Sint},
{BufferFormat::r16, MTLPixelFormatR16Unorm},
{BufferFormat::rg16, MTLPixelFormatRG16Unorm},
{BufferFormat::rgb16, MTLPixelFormatInvalid},
{BufferFormat::rgba16, MTLPixelFormatRGBA16Unorm},
{BufferFormat::r16u, MTLPixelFormatR16Uint},
{BufferFormat::rg16u, MTLPixelFormatRG16Uint},
{BufferFormat::rgb16u, MTLPixelFormatInvalid},
{BufferFormat::rgba16u, MTLPixelFormatRGBA16Uint},
{BufferFormat::r16i, MTLPixelFormatR16Sint},
{BufferFormat::rg16i, MTLPixelFormatRG16Sint},
{BufferFormat::rgb16i, MTLPixelFormatInvalid},
{BufferFormat::rgba16i, MTLPixelFormatRGBA16Sint},
{BufferFormat::r16f, MTLPixelFormatR16Float},
{BufferFormat::rg16f, MTLPixelFormatRG16Float},
{BufferFormat::rgb16f, MTLPixelFormatInvalid},
{BufferFormat::rgba16f, MTLPixelFormatRGBA16Float},
{BufferFormat::r32u, MTLPixelFormatR32Uint},
{BufferFormat::rg32u, MTLPixelFormatRG32Uint},
{BufferFormat::rgb32u, MTLPixelFormatInvalid},
{BufferFormat::rgba32u, MTLPixelFormatRGBA32Uint},
{BufferFormat::r32i, MTLPixelFormatR32Sint},
{BufferFormat::rg32i, MTLPixelFormatRG32Sint},
{BufferFormat::rgb32i, MTLPixelFormatInvalid},
{BufferFormat::rgba32i, MTLPixelFormatRGBA32Sint},
{BufferFormat::r32f, MTLPixelFormatR32Float},
{BufferFormat::rg32f, MTLPixelFormatRG32Float},
{BufferFormat::rgb32f, MTLPixelFormatInvalid},
{BufferFormat::rgba32f, MTLPixelFormatRGBA32Float},
{BufferFormat::depth16, MTLPixelFormatDepth16Unorm},
{BufferFormat::depth24stencil8, MTLPixelFormatInvalid},
{BufferFormat::depth32f, MTLPixelFormatDepth32Float},
};
auto it = map.find(format);
RHI_ASSERT(it != map.end());
return it->second;
}
MTLTextureType dimension2mtl(ImageDimension dimension) {
static const std::map<ImageDimension, MTLTextureType> map = {
{ImageDimension::d1D, MTLTextureType1D},
{ImageDimension::d2D, MTLTextureType2D},
{ImageDimension::d3D, MTLTextureType3D},
};
auto it = map.find(dimension);
RHI_ASSERT(it != map.end());
return it->second;
}
MTLTextureUsage usage2mtl(ImageAllocUsage usage) {
MTLTextureUsage out = 0;
if (usage & ImageAllocUsage::Sampled) {
out |= MTLTextureUsageShaderRead;
}
if (usage & ImageAllocUsage::Storage) {
out |= MTLTextureUsageShaderWrite;
}
if (usage & ImageAllocUsage::Attachment) {
out |= MTLTextureUsageRenderTarget;
}
return out;
}

DeviceAllocation MetalDevice::create_image(const ImageParams &params) {
if (params.export_sharing) {
RHI_LOG_ERROR("export sharing is not available in metal");
Expand Down

0 comments on commit 6994dbd

Please sign in to comment.