Skip to content

Commit

Permalink
Merge pull request #118 from crud89/execute-indirect
Browse files Browse the repository at this point in the history
Implement indirect draw and dispatch.
  • Loading branch information
crud89 authored Jun 9, 2024
2 parents b816a8b + 212a92e commit 818ec4f
Show file tree
Hide file tree
Showing 37 changed files with 1,710 additions and 161 deletions.
38 changes: 37 additions & 1 deletion src/Backends/DirectX12/include/litefx/backends/dx12.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1000,9 +1000,12 @@ namespace LiteFX::Rendering::Backends {
public:
using base_type = CommandBuffer<DirectX12CommandBuffer, IDirectX12Buffer, IDirectX12VertexBuffer, IDirectX12IndexBuffer, IDirectX12Image, DirectX12Barrier, DirectX12PipelineState, DirectX12BottomLevelAccelerationStructure, DirectX12TopLevelAccelerationStructure>;
using base_type::dispatch;
using base_type::dispatchIndirect;
using base_type::dispatchMesh;
using base_type::draw;
using base_type::drawIndirect;
using base_type::drawIndexed;
using base_type::drawIndexedIndirect;
using base_type::barrier;
using base_type::transfer;
using base_type::generateMipMaps;
Expand Down Expand Up @@ -1143,13 +1146,37 @@ namespace LiteFX::Rendering::Backends {
void dispatch(const Vector3u& threadCount) const noexcept override;

/// <inheritdoc />
void dispatchMesh (const Vector3u& threadCount) const noexcept override;
void dispatchIndirect(const IDirectX12Buffer& batchBuffer, UInt32 batchCount, UInt64 offset = 0) const noexcept override;

/// <inheritdoc />
void dispatchIndirect(const IDirectX12Buffer& batchBuffer, const IDirectX12Buffer& countBuffer, UInt64 offset = 0, UInt64 countOffset = 0, UInt32 maxBatches = std::numeric_limits<UInt32>::max()) const noexcept;

/// <inheritdoc />
void dispatchMesh(const Vector3u& threadCount) const noexcept override;

/// <inheritdoc />
void dispatchMeshIndirect(const IDirectX12Buffer& batchBuffer, UInt32 batchCount, UInt64 offset = 0) const noexcept override;

/// <inheritdoc />
void dispatchMeshIndirect(const IDirectX12Buffer& batchBuffer, const IDirectX12Buffer& countBuffer, UInt64 offset = 0, UInt64 countOffset = 0, UInt32 maxBatches = std::numeric_limits<UInt32>::max()) const noexcept override;

/// <inheritdoc />
void draw(UInt32 vertices, UInt32 instances = 1, UInt32 firstVertex = 0, UInt32 firstInstance = 0) const noexcept override;

/// <inheritdoc />
void drawIndirect(const IDirectX12Buffer& batchBuffer, UInt32 batchCount, UInt64 offset = 0) const noexcept override;

/// <inheritdoc />
void drawIndirect(const IDirectX12Buffer& batchBuffer, const IDirectX12Buffer& countBuffer, UInt64 offset = 0, UInt64 countOffset = 0, UInt32 maxBatches = std::numeric_limits<UInt32>::max()) const noexcept override;

/// <inheritdoc />
void drawIndexed(UInt32 indices, UInt32 instances = 1, UInt32 firstIndex = 0, Int32 vertexOffset = 0, UInt32 firstInstance = 0) const noexcept override;

/// <inheritdoc />
void drawIndexedIndirect(const IDirectX12Buffer& batchBuffer, UInt32 batchCount, UInt64 offset = 0) const noexcept override;

/// <inheritdoc />
void drawIndexedIndirect(const IDirectX12Buffer& batchBuffer, const IDirectX12Buffer& countBuffer, UInt64 offset = 0, UInt64 countOffset = 0, UInt32 maxBatches = std::numeric_limits<UInt32>::max()) const noexcept override;

/// <inheritdoc />
void pushConstants(const DirectX12PushConstantsLayout& layout, const void* const memory) const noexcept override;
Expand Down Expand Up @@ -1984,6 +2011,15 @@ namespace LiteFX::Rendering::Backends {
/// <seealso cref="DirectX12Texture::generateMipMaps" />
virtual DirectX12ComputePipeline& blitPipeline() const noexcept;

/// <summary>
/// Returns the command signatures for indirect dispatch and draw calls.
/// </summary>
/// <param name="dispatchSignature">The command signature used to execute indirect dispatches.</param>
/// <param name="dispatchMeshSignature">The command signature used to execute indirect mesh shader dispatches.</param>
/// <param name="drawSignature">The command signature used to execute indirect non-indexed draw calls.</param>
/// <param name="drawIndexedSignature">The command signature used to execute indirect indexed draw calls.</param>
virtual void indirectDrawSignatures(ComPtr<ID3D12CommandSignature>& dispatchSignature, ComPtr<ID3D12CommandSignature>& dispatchMeshSignature, ComPtr<ID3D12CommandSignature>& drawSignature, ComPtr<ID3D12CommandSignature>& drawIndexedSignature) const noexcept;

// GraphicsDevice interface.
public:
/// <inheritdoc />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ namespace LiteFX::Rendering::Backends {
/// <param name="parent">The parent pipeline layout builder.</param>
/// <param name="space">The space the descriptor set is bound to.</param>
/// <param name="stages">The shader stages, the descriptor set is accessible from.</param>
/// <param name="maxUnboundedArraySize">Ignored for DirectX 12, but required for compatibility.</param>
constexpr inline explicit DirectX12DescriptorSetLayoutBuilder(DirectX12PipelineLayoutBuilder& parent, UInt32 space = 0, ShaderStage stages = ShaderStage::Any, UInt32 maxUnboundedArraySize = 0);
constexpr inline explicit DirectX12DescriptorSetLayoutBuilder(DirectX12PipelineLayoutBuilder& parent, UInt32 space = 0, ShaderStage stages = ShaderStage::Any);
DirectX12DescriptorSetLayoutBuilder(const DirectX12DescriptorSetLayoutBuilder&) = delete;
DirectX12DescriptorSetLayoutBuilder(DirectX12DescriptorSetLayoutBuilder&&) = delete;
constexpr inline virtual ~DirectX12DescriptorSetLayoutBuilder() noexcept;
Expand Down Expand Up @@ -236,8 +235,7 @@ namespace LiteFX::Rendering::Backends {
/// </summary>
/// <param name="space">The space, the descriptor set is bound to.</param>
/// <param name="stages">The stages, the descriptor set will be accessible from.</param>
/// <param name="maxUnboundedArraySize">Unused for this backend.</param>
constexpr inline DirectX12DescriptorSetLayoutBuilder descriptorSet(UInt32 space = 0, ShaderStage stages = ShaderStage::Any, UInt32 maxUnboundedArraySize = 0);
constexpr inline DirectX12DescriptorSetLayoutBuilder descriptorSet(UInt32 space = 0, ShaderStage stages = ShaderStage::Any);

/// <summary>
/// Builds a new push constants layout for the pipeline layout.
Expand Down
6 changes: 3 additions & 3 deletions src/Backends/DirectX12/src/blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ class DirectX12BottomLevelAccelerationStructure::DirectX12BottomLevelAcceleratio
};

// Transition the buffer into UAV state. We create manual barriers here, as the special access flag is only required in this specific situation.
CD3DX12_BUFFER_BARRIER preBarrier[2] = {
CD3DX12_BUFFER_BARRIER preBarrier[1] = {
CD3DX12_BUFFER_BARRIER(afterCopy ? D3D12_BARRIER_SYNC_COPY_RAYTRACING_ACCELERATION_STRUCTURE : D3D12_BARRIER_SYNC_BUILD_RAYTRACING_ACCELERATION_STRUCTURE, D3D12_BARRIER_SYNC_EMIT_RAYTRACING_ACCELERATION_STRUCTURE_POSTBUILD_INFO, D3D12_BARRIER_ACCESS_RAYTRACING_ACCELERATION_STRUCTURE_WRITE, D3D12_BARRIER_ACCESS_RAYTRACING_ACCELERATION_STRUCTURE_READ, std::as_const(*m_buffer).handle().Get()),
CD3DX12_BUFFER_BARRIER(D3D12_BARRIER_SYNC_NONE, D3D12_BARRIER_SYNC_EMIT_RAYTRACING_ACCELERATION_STRUCTURE_POSTBUILD_INFO, D3D12_BARRIER_ACCESS_NO_ACCESS, D3D12_BARRIER_ACCESS_UNORDERED_ACCESS, std::as_const(*m_postBuildBuffer).handle().Get()),
//CD3DX12_BUFFER_BARRIER(D3D12_BARRIER_SYNC_NONE, D3D12_BARRIER_SYNC_EMIT_RAYTRACING_ACCELERATION_STRUCTURE_POSTBUILD_INFO, D3D12_BARRIER_ACCESS_NO_ACCESS, D3D12_BARRIER_ACCESS_UNORDERED_ACCESS, std::as_const(*m_postBuildBuffer).handle().Get()),
};
auto preBarrierGroup = CD3DX12_BARRIER_GROUP(2, preBarrier);
auto preBarrierGroup = CD3DX12_BARRIER_GROUP(1, preBarrier);
commandBuffer.handle()->Barrier(1, &preBarrierGroup);

// Emit the
Expand Down
22 changes: 5 additions & 17 deletions src/Backends/DirectX12/src/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,10 @@ void DirectX12Buffer::map(const void* const data, size_t size, UInt32 element)
if (element >= m_impl->m_elements) [[unlikely]]
throw ArgumentOutOfRangeException("element", 0u, m_impl->m_elements, element, "The element {0} is out of range. The buffer only contains {1} elements.", element, m_impl->m_elements);

size_t alignedSize = size;
size_t alignment = this->elementAlignment();

if (alignment > 0)
alignedSize = (size + alignment - 1) & ~(alignment - 1);

D3D12_RANGE mappedRange = {};
D3D12_RANGE mappedRange = { };
char* buffer;
raiseIfFailed(this->handle()->Map(0, &mappedRange, reinterpret_cast<void**>(&buffer)), "Unable to map buffer memory.");
auto result = ::memcpy_s(reinterpret_cast<void*>(buffer + (element * alignedSize)), alignedSize, data, size);
auto result = ::memcpy_s(reinterpret_cast<void*>(buffer + (element * this->alignedElementSize())), this->size(), data, size);
this->handle()->Unmap(0, nullptr);

if (result != 0) [[unlikely]]
Expand All @@ -117,18 +111,12 @@ void DirectX12Buffer::map(void* data, size_t size, UInt32 element, bool write)
if (element >= m_impl->m_elements) [[unlikely]]
throw ArgumentOutOfRangeException("element", 0u, m_impl->m_elements, element, "The element {0} is out of range. The buffer only contains {1} elements.", element, m_impl->m_elements);

size_t alignedSize = size;
size_t alignment = this->elementAlignment();

if (alignment > 0)
alignedSize = (size + alignment - 1) & ~(alignment - 1);

D3D12_RANGE mappedRange = {};
D3D12_RANGE mappedRange = { };
char* buffer;
raiseIfFailed(this->handle()->Map(0, &mappedRange, reinterpret_cast<void**>(&buffer)), "Unable to map buffer memory.");
auto result = write ?
::memcpy_s(reinterpret_cast<void*>(buffer + (element * alignedSize)), alignedSize, data, size) :
::memcpy_s(data, size, reinterpret_cast<void*>(buffer + (element * alignedSize)), alignedSize);
::memcpy_s(reinterpret_cast<void*>(buffer + (element * this->alignedElementSize())), this->size(), data, size) :
::memcpy_s(data, size, reinterpret_cast<void*>(buffer + (element * this->alignedElementSize())), size);

this->handle()->Unmap(0, nullptr);

Expand Down
46 changes: 45 additions & 1 deletion src/Backends/DirectX12/src/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class DirectX12CommandBuffer::DirectX12CommandBufferImpl : public Implement<Dire
const DirectX12Queue& m_queue;
Array<SharedPtr<const IStateResource>> m_sharedResources;
const DirectX12PipelineState* m_lastPipeline = nullptr;
ComPtr<ID3D12CommandSignature> m_dispatchSignature, m_drawSignature, m_drawIndexedSignature, m_dispatchMeshSignature;

public:
DirectX12CommandBufferImpl(DirectX12CommandBuffer* parent, const DirectX12Queue& queue) :
Expand All @@ -26,6 +27,9 @@ class DirectX12CommandBuffer::DirectX12CommandBufferImpl : public Implement<Dire
public:
ComPtr<ID3D12GraphicsCommandList7> initialize(bool begin, bool primary)
{
// Store the command signatures for indirect drawing.
m_queue.device().indirectDrawSignatures(m_dispatchSignature, m_dispatchMeshSignature, m_drawSignature, m_drawIndexedSignature);

// Create a command allocator.
D3D12_COMMAND_LIST_TYPE type;

Expand Down Expand Up @@ -275,7 +279,7 @@ void DirectX12CommandBuffer::generateMipMaps(IDirectX12Image& image) noexcept
this->bind(*samplerBindings, pipeline);

// Transition the texture into a read/write state.
DirectX12Barrier startBarrier(PipelineStage::None, PipelineStage::Compute);
DirectX12Barrier startBarrier(PipelineStage::All, PipelineStage::Compute);
startBarrier.transition(image, ResourceAccess::None, ResourceAccess::ShaderReadWrite, ImageLayout::Undefined, ImageLayout::ReadWrite);
this->barrier(startBarrier);
auto resource = resourceBindings.begin();
Expand Down Expand Up @@ -496,21 +500,61 @@ void DirectX12CommandBuffer::dispatch(const Vector3u& threadCount) const noexcep
this->handle()->Dispatch(threadCount.x(), threadCount.y(), threadCount.z());
}

void DirectX12CommandBuffer::dispatchIndirect(const IDirectX12Buffer& batchBuffer, UInt32 batchCount, UInt64 offset) const noexcept
{
this->handle()->ExecuteIndirect(m_impl->m_dispatchSignature.Get(), batchCount, batchBuffer.handle().Get(), offset, nullptr, 0);
}

void DirectX12CommandBuffer::dispatchIndirect(const IDirectX12Buffer& batchBuffer, const IDirectX12Buffer& countBuffer, UInt64 offset, UInt64 countOffset, UInt32 maxBatches) const noexcept
{
this->handle()->ExecuteIndirect(m_impl->m_dispatchSignature.Get(), std::min(maxBatches, static_cast<UInt32>(batchBuffer.alignedElementSize() / sizeof(IndirectDispatchBatch))), batchBuffer.handle().Get(), offset, countBuffer.handle().Get(), countOffset);
}

void DirectX12CommandBuffer::dispatchMesh(const Vector3u& threadCount) const noexcept
{
this->handle()->DispatchMesh(threadCount.x(), threadCount.y(), threadCount.z());
}

void DirectX12CommandBuffer::dispatchMeshIndirect(const IDirectX12Buffer& batchBuffer, UInt32 batchCount, UInt64 offset) const noexcept
{
this->handle()->ExecuteIndirect(m_impl->m_dispatchMeshSignature.Get(), batchCount, batchBuffer.handle().Get(), offset, nullptr, 0);
}

void DirectX12CommandBuffer::dispatchMeshIndirect(const IDirectX12Buffer& batchBuffer, const IDirectX12Buffer& countBuffer, UInt64 offset, UInt64 countOffset, UInt32 maxBatches) const noexcept
{
this->handle()->ExecuteIndirect(m_impl->m_dispatchMeshSignature.Get(), std::min(maxBatches, static_cast<UInt32>(batchBuffer.alignedElementSize() / sizeof(IndirectDispatchBatch))), batchBuffer.handle().Get(), offset, countBuffer.handle().Get(), countOffset);
}

void DirectX12CommandBuffer::draw(UInt32 vertices, UInt32 instances, UInt32 firstVertex, UInt32 firstInstance) const noexcept
{
this->handle()->DrawInstanced(vertices, instances, firstVertex, firstInstance);
}

void DirectX12CommandBuffer::drawIndirect(const IDirectX12Buffer& batchBuffer, UInt32 batchCount, UInt64 offset) const noexcept
{
this->handle()->ExecuteIndirect(m_impl->m_drawSignature.Get(), batchCount, batchBuffer.handle().Get(), offset, nullptr, 0);
}

void DirectX12CommandBuffer::drawIndirect(const IDirectX12Buffer& batchBuffer, const IDirectX12Buffer& countBuffer, UInt64 offset, UInt64 countOffset, UInt32 maxBatches) const noexcept
{
this->handle()->ExecuteIndirect(m_impl->m_drawSignature.Get(), std::min(maxBatches, static_cast<UInt32>(batchBuffer.alignedElementSize() / sizeof(IndirectBatch))), batchBuffer.handle().Get(), offset, countBuffer.handle().Get(), countOffset);
}

void DirectX12CommandBuffer::drawIndexed(UInt32 indices, UInt32 instances, UInt32 firstIndex, Int32 vertexOffset, UInt32 firstInstance) const noexcept
{
this->handle()->DrawIndexedInstanced(indices, instances, firstIndex, vertexOffset, firstInstance);
}

void DirectX12CommandBuffer::drawIndexedIndirect(const IDirectX12Buffer& batchBuffer, UInt32 batchCount, UInt64 offset) const noexcept
{
this->handle()->ExecuteIndirect(m_impl->m_drawIndexedSignature.Get(), batchCount, batchBuffer.handle().Get(), offset, nullptr, 0);
}

void DirectX12CommandBuffer::drawIndexedIndirect(const IDirectX12Buffer& batchBuffer, const IDirectX12Buffer& countBuffer, UInt64 offset, UInt64 countOffset, UInt32 maxBatches) const noexcept
{
this->handle()->ExecuteIndirect(m_impl->m_drawIndexedSignature.Get(), std::min(maxBatches, static_cast<UInt32>(batchBuffer.alignedElementSize() / sizeof(IndirectIndexedBatch))), batchBuffer.handle().Get(), offset, countBuffer.handle().Get(), countOffset);
}

void DirectX12CommandBuffer::pushConstants(const DirectX12PushConstantsLayout& layout, const void* const memory) const noexcept
{
std::ranges::for_each(layout.ranges(), [this, &layout, &memory](const DirectX12PushConstantsRange* range) { this->handle()->SetGraphicsRoot32BitConstants(range->rootParameterIndex(), range->size() / 4, reinterpret_cast<const char* const>(memory) + range->offset(), 0); });
Expand Down
11 changes: 6 additions & 5 deletions src/Backends/DirectX12/src/descriptor_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ void DirectX12DescriptorSet::update(UInt32 binding, const IDirectX12Buffer& buff
}
case DescriptorType::RWStructuredBuffer:
{
// TODO: Support counter in AppendStructuredBuffer.
for (UInt32 i(0); i < elementCount; ++i)
{
D3D12_UNORDERED_ACCESS_VIEW_DESC bufferView = {
Expand All @@ -160,11 +159,12 @@ void DirectX12DescriptorSet::update(UInt32 binding, const IDirectX12Buffer& buff
{
for (UInt32 i(0); i < elementCount; ++i)
{
// NOTE: One takes 4 byte size (sizeof(DWORD)) in DXGI_FORMAT_R32_TYPELESS format, which is required for raw buffers.
D3D12_SHADER_RESOURCE_VIEW_DESC bufferView = {
.Format = DXGI_FORMAT_R32_TYPELESS,
.ViewDimension = D3D12_SRV_DIMENSION_BUFFER,
.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
.Buffer = { .FirstElement = ((bufferElement + i) * buffer.alignedElementSize()) / 4, .NumElements = static_cast<UInt32>(buffer.alignedElementSize() / 4), .StructureByteStride = 0, .Flags = D3D12_BUFFER_SRV_FLAG_RAW }
.Buffer = { .FirstElement = (bufferElement + i) * sizeof(DWORD), .NumElements = static_cast<UInt32>(buffer.alignedElementSize() / sizeof(DWORD)), .StructureByteStride = 0, .Flags = D3D12_BUFFER_SRV_FLAG_RAW }
};

m_impl->m_layout.device().handle()->CreateShaderResourceView(buffer.handle().Get(), &bufferView, descriptorHandle);
Expand All @@ -177,10 +177,11 @@ void DirectX12DescriptorSet::update(UInt32 binding, const IDirectX12Buffer& buff
{
for (UInt32 i(0); i < elementCount; ++i)
{
// NOTE: Individual fields in a buffer are always required to be 4 bytes wide, while alignment between elements is 16 bytes (D3D12_RAW_UAV_SRV_BYTE_ALIGNMENT).
D3D12_UNORDERED_ACCESS_VIEW_DESC bufferView = {
.Format = DXGI_FORMAT_R32_TYPELESS,
.ViewDimension = D3D12_UAV_DIMENSION_BUFFER,
.Buffer = { .FirstElement = ((bufferElement + i) * buffer.alignedElementSize()) / 4, .NumElements = static_cast<UInt32>(buffer.alignedElementSize() / 4), .StructureByteStride = 0, .CounterOffsetInBytes = 0, .Flags = D3D12_BUFFER_UAV_FLAG_RAW }
.Buffer = { .FirstElement = (bufferElement + i) * sizeof(DWORD), .NumElements = static_cast<UInt32>(buffer.alignedElementSize() / sizeof(DWORD)), .StructureByteStride = 0, .CounterOffsetInBytes = 0, .Flags = D3D12_BUFFER_UAV_FLAG_RAW }
};

m_impl->m_layout.device().handle()->CreateUnorderedAccessView(buffer.handle().Get(), nullptr, &bufferView, descriptorHandle);
Expand All @@ -197,7 +198,7 @@ void DirectX12DescriptorSet::update(UInt32 binding, const IDirectX12Buffer& buff
.Format = DXGI_FORMAT_R32_TYPELESS, // TODO: Actually set the proper texel format.
.ViewDimension = D3D12_SRV_DIMENSION_BUFFER,
.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
.Buffer = { .FirstElement = ((bufferElement + i) * buffer.alignedElementSize()) / 4, .NumElements = static_cast<UInt32>(buffer.alignedElementSize() / 4), .StructureByteStride = 0, .Flags = D3D12_BUFFER_SRV_FLAG_RAW }
.Buffer = { .FirstElement = (bufferElement + i) * sizeof(DWORD), .NumElements = static_cast<UInt32>(buffer.alignedElementSize() / sizeof(DWORD)), .StructureByteStride = 0 }
};

m_impl->m_layout.device().handle()->CreateShaderResourceView(buffer.handle().Get(), &bufferView, descriptorHandle);
Expand All @@ -213,7 +214,7 @@ void DirectX12DescriptorSet::update(UInt32 binding, const IDirectX12Buffer& buff
D3D12_UNORDERED_ACCESS_VIEW_DESC bufferView = {
.Format = DXGI_FORMAT_R32_TYPELESS, // TODO: Actually set the proper texel format.
.ViewDimension = D3D12_UAV_DIMENSION_BUFFER,
.Buffer = { .FirstElement = ((bufferElement + i) * buffer.alignedElementSize()) / 4, .NumElements = static_cast<UInt32>(buffer.alignedElementSize() / 4), .StructureByteStride = 0, .CounterOffsetInBytes = 0, .Flags = D3D12_BUFFER_UAV_FLAG_RAW }
.Buffer = { .FirstElement = (bufferElement + i) * sizeof(DWORD), .NumElements = static_cast<UInt32>(buffer.alignedElementSize() / sizeof(DWORD)), .StructureByteStride = 0, .CounterOffsetInBytes = 0 }
};

m_impl->m_layout.device().handle()->CreateUnorderedAccessView(buffer.handle().Get(), nullptr, &bufferView, descriptorHandle);
Expand Down
Loading

0 comments on commit 818ec4f

Please sign in to comment.