Skip to content

Commit

Permalink
[vulkan] Device API explicit semaphores (taichi-dev#4852)
Browse files Browse the repository at this point in the history
* Device API explicit semaphores

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Destroy the semaphore before the context

* Fix type warnings

* fix nits

* return nullptr for devices that don't need semaphores

* test out no semaphores between same queue

* Use native command list instead of emulated for dx11

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove the in-queue semaphore

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use flush instead of sync in places

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix possible null semaphore

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and k-ye committed May 5, 2022
1 parent 6e11eaf commit 6221388
Show file tree
Hide file tree
Showing 24 changed files with 375 additions and 261 deletions.
2 changes: 1 addition & 1 deletion python/taichi/ui/staging_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def copy_image_u8_to_u8(src: ti.template(), dst: ti.template(),
num_components: ti.template()):
for i, j in src:
for k in ti.static(range(num_components)):
dst[i, j][k] = src[i, j][k]
dst[i, j][k] = ti.cast(src[i, j][k], ti.u8)
if num_components < 4:
# alpha channel
dst[i, j][3] = u8(255)
Expand Down
3 changes: 3 additions & 0 deletions taichi/aot/module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ class TargetDevice : public Device {
Stream *get_compute_stream() override {
TI_NOT_IMPLEMENTED;
}
void wait_idle() override {
TI_NOT_IMPLEMENTED;
}
};

} // namespace aot
Expand Down
11 changes: 9 additions & 2 deletions taichi/backends/cpu/cpu_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,13 @@ class CpuStream : public Stream {
~CpuStream() override{};

std::unique_ptr<CommandList> new_command_list() override{TI_NOT_IMPLEMENTED};
void submit(CommandList *cmdlist) override{TI_NOT_IMPLEMENTED};
void submit_synced(CommandList *cmdlist) override{TI_NOT_IMPLEMENTED};
StreamSemaphore submit(CommandList *cmdlist,
const std::vector<StreamSemaphore> &wait_semaphores =
{}) override{TI_NOT_IMPLEMENTED};
StreamSemaphore submit_synced(
CommandList *cmdlist,
const std::vector<StreamSemaphore> &wait_semaphores = {}) override{
TI_NOT_IMPLEMENTED};

void command_sync() override{TI_NOT_IMPLEMENTED};
};
Expand Down Expand Up @@ -111,6 +116,8 @@ class CpuDevice : public LlvmDevice {

Stream *get_compute_stream() override{TI_NOT_IMPLEMENTED};

void wait_idle() override{TI_NOT_IMPLEMENTED};

private:
std::vector<AllocInfo> allocations_;
std::unordered_map<int, std::unique_ptr<VirtualMemoryAllocator>>
Expand Down
11 changes: 9 additions & 2 deletions taichi/backends/cuda/cuda_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,13 @@ class CudaStream : public Stream {
~CudaStream() override{};

std::unique_ptr<CommandList> new_command_list() override{TI_NOT_IMPLEMENTED};
void submit(CommandList *cmdlist) override{TI_NOT_IMPLEMENTED};
void submit_synced(CommandList *cmdlist) override{TI_NOT_IMPLEMENTED};
StreamSemaphore submit(CommandList *cmdlist,
const std::vector<StreamSemaphore> &wait_semaphores =
{}) override{TI_NOT_IMPLEMENTED};
StreamSemaphore submit_synced(
CommandList *cmdlist,
const std::vector<StreamSemaphore> &wait_semaphores = {}) override{
TI_NOT_IMPLEMENTED};

void command_sync() override{TI_NOT_IMPLEMENTED};
};
Expand Down Expand Up @@ -123,6 +128,8 @@ class CudaDevice : public LlvmDevice {

Stream *get_compute_stream() override{TI_NOT_IMPLEMENTED};

void wait_idle() override{TI_NOT_IMPLEMENTED};

private:
std::vector<AllocInfo> allocations_;
void validate_device_alloc(const DeviceAllocation alloc) {
Expand Down
26 changes: 22 additions & 4 deletions taichi/backends/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,26 @@ inline bool operator&(AllocUsage a, AllocUsage b) {
return static_cast<int>(a) & static_cast<int>(b);
}

class StreamSemaphoreObject {
public:
virtual ~StreamSemaphoreObject() {
}
};

using StreamSemaphore = std::shared_ptr<StreamSemaphoreObject>;

class Stream {
public:
virtual ~Stream(){};
virtual ~Stream() {
}

virtual std::unique_ptr<CommandList> new_command_list() = 0;
virtual void submit(CommandList *cmdlist) = 0;
virtual void submit_synced(CommandList *cmdlist) = 0;
virtual StreamSemaphore submit(
CommandList *cmdlist,
const std::vector<StreamSemaphore> &wait_semaphores = {}) = 0;
virtual StreamSemaphore submit_synced(
CommandList *cmdlist,
const std::vector<StreamSemaphore> &wait_semaphores = {}) = 0;

virtual void command_sync() = 0;
};
Expand Down Expand Up @@ -457,6 +470,9 @@ class Device {
// Each thraed will acquire its own stream
virtual Stream *get_compute_stream() = 0;

// Wait for all tasks to complete (task from all streams)
virtual void wait_idle() = 0;

// Mapping can fail and will return nullptr
virtual void *map_range(DevicePtr ptr, uint64_t size) = 0;
virtual void *map(DeviceAllocation alloc) = 0;
Expand Down Expand Up @@ -498,8 +514,10 @@ class Surface {
virtual ~Surface() {
}

virtual StreamSemaphore acquire_next_image() = 0;
virtual DeviceAllocation get_target_image() = 0;
virtual void present_image() = 0;
virtual void present_image(
const std::vector<StreamSemaphore> &wait_semaphores = {}) = 0;
virtual std::pair<uint32_t, uint32_t> get_size() = 0;
virtual int get_image_count() = 0;
virtual BufferFormat image_format() = 0;
Expand Down
176 changes: 78 additions & 98 deletions taichi/backends/dx/dx_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,40 +79,54 @@ Dx11ResourceBinder::~Dx11ResourceBinder() {
}

Dx11CommandList::Dx11CommandList(Dx11Device *ti_device) : device_(ti_device) {
HRESULT hr;
hr = device_->d3d11_device()->CreateDeferredContext(0,
&d3d11_deferred_context_);
check_dx_error(hr, "create deferred context");
}

Dx11CommandList::~Dx11CommandList() {
for (ID3D11Buffer *cb : used_spv_workgroup_cb) {
cb->Release();
}
if (d3d11_command_list_) {
d3d11_command_list_->Release();
}
d3d11_deferred_context_->Release();
}

void Dx11CommandList::bind_pipeline(Pipeline *p) {
Dx11Pipeline *pipeline = static_cast<Dx11Pipeline *>(p);
std::unique_ptr<CmdBindPipeline> cmd =
std::make_unique<CmdBindPipeline>(this);
cmd->compute_shader_ = pipeline->get_program();
recorded_commands_.push_back(std::move(cmd));
d3d11_deferred_context_->CSSetShader(pipeline->get_program(), nullptr, 0);
}

void Dx11CommandList::bind_resources(ResourceBinder *binder_) {
Dx11ResourceBinder *binder = static_cast<Dx11ResourceBinder *>(binder_);

// UAV
for (auto &[binding, alloc_id] : binder->uav_binding_to_alloc_id()) {
std::unique_ptr<CmdBindUAVBufferToIndex> cmd =
std::make_unique<CmdBindUAVBufferToIndex>(this);
ID3D11UnorderedAccessView *uav = device_->alloc_id_to_uav(alloc_id);
cmd->binding = binding;
cmd->uav = uav;
recorded_commands_.push_back(std::move(cmd));
d3d11_deferred_context_->CSSetUnorderedAccessViews(binding, 1, &uav,
nullptr);
}

// CBV
for (auto &[binding, alloc_id] : binder->cb_binding_to_alloc_id()) {
std::unique_ptr<CmdBindConstantBufferToIndex> cmd =
std::make_unique<CmdBindConstantBufferToIndex>(this);
cmd->binding = binding;
cmd->cb_buffer = device_->create_or_get_cb_buffer(alloc_id);
cmd->buffer = device_->alloc_id_to_buffer(alloc_id);
recorded_commands_.push_back(std::move(cmd));
auto cb_buffer = device_->create_or_get_cb_buffer(alloc_id);
auto buffer = device_->alloc_id_to_buffer(alloc_id);

D3D11_BUFFER_DESC desc;
buffer->GetDesc(&desc);
D3D11_BOX box{};
box.left = 0;
box.right = desc.ByteWidth;
box.top = 0;
box.bottom = 1; // 1 past the end!
box.front = 0;
box.back = 1;
d3d11_deferred_context_->CopySubresourceRegion(cb_buffer, 0, 0, 0, 0,
buffer, 0, &box);
d3d11_deferred_context_->CSSetConstantBuffers(binding, 1, &cb_buffer);

cb_slot_watermark_ = std::max(cb_slot_watermark_, int(binding));
}
Expand Down Expand Up @@ -140,68 +154,26 @@ void Dx11CommandList::buffer_copy(DevicePtr dst, DevicePtr src, size_t size) {
}

void Dx11CommandList::buffer_fill(DevicePtr ptr, size_t size, uint32_t data) {
std::unique_ptr<Dx11CommandList::CmdBufferFill> cmd =
std::make_unique<CmdBufferFill>(this);
ID3D11Buffer *buf = device_->alloc_id_to_buffer(ptr.alloc_id);
ID3D11UnorderedAccessView *uav = device_->alloc_id_to_uav(ptr.alloc_id);
cmd->uav = uav;
D3D11_BUFFER_DESC desc;
buf->GetDesc(&desc);
cmd->size = desc.ByteWidth;
recorded_commands_.push_back(std::move(cmd));
}

void Dx11CommandList::CmdBufferFill::execute() {
ID3D11DeviceContext *context = cmdlist_->device_->d3d11_context();
const UINT values[4] = {data, data, data, data};
context->ClearUnorderedAccessViewUint(uav, values);
}

void Dx11CommandList::CmdBindPipeline::execute() {
ID3D11DeviceContext *context = cmdlist_->device_->d3d11_context();
context->CSSetShader(compute_shader_, nullptr, 0);
}

void Dx11CommandList::CmdBindUAVBufferToIndex::execute() {
cmdlist_->device_->d3d11_context()->CSSetUnorderedAccessViews(binding, 1,
&uav, nullptr);
}

void Dx11CommandList::CmdBindConstantBufferToIndex::execute() {
D3D11_BUFFER_DESC desc;
buffer->GetDesc(&desc);
D3D11_BOX box{};
box.left = 0;
box.right = desc.ByteWidth;
box.top = 0;
box.bottom = 1; // 1 past the end!
box.front = 0;
box.back = 1;
cmdlist_->device_->d3d11_context()->CopySubresourceRegion(cb_buffer, 0, 0, 0,
0, buffer, 0, &box);
cmdlist_->device_->d3d11_context()->CSSetConstantBuffers(binding, 1,
&cb_buffer);
}

void Dx11CommandList::CmdDispatch::execute() {
cmdlist_->device_->set_spirv_cross_numworkgroups(x, y, z,
spirv_cross_num_wg_cb_slot_);
cmdlist_->device_->d3d11_context()->Dispatch(x, y, z);
d3d11_deferred_context_->ClearUnorderedAccessViewUint(uav, values);
}

void Dx11CommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) {
std::unique_ptr<CmdDispatch> cmd = std::make_unique<CmdDispatch>(this);
cmd->x = x;
cmd->y = y;
cmd->z = z;

// Set SPIRV_Cross_NumWorkgroups's CB slot based on the watermark
cmd->spirv_cross_num_wg_cb_slot_ = cb_slot_watermark_ + 1;
auto cb_slot = cb_slot_watermark_ + 1;
auto spirv_cross_numworkgroups_cb =
device_->set_spirv_cross_numworkgroups(x, y, z, cb_slot);
d3d11_deferred_context_->CSSetConstantBuffers(cb_slot, 1,
&spirv_cross_numworkgroups_cb);
used_spv_workgroup_cb.push_back(spirv_cross_numworkgroups_cb);

// Reset watermark
cb_slot_watermark_ = -1;

recorded_commands_.push_back(std::move(cmd));
d3d11_deferred_context_->Dispatch(x, y, z);
}

void Dx11CommandList::begin_renderpass(int x0,
Expand Down Expand Up @@ -260,19 +232,14 @@ void Dx11CommandList::image_to_buffer(DevicePtr dst_buf,
}

void Dx11CommandList::run_commands() {
for (const auto &cmd : recorded_commands_) {
cmd->execute();
if (!d3d11_command_list_) {
HRESULT hr;
hr =
d3d11_deferred_context_->FinishCommandList(FALSE, &d3d11_command_list_);
check_dx_error(hr, "error finishing command list");
}
}

int Dx11CommandList::cb_count() {
int ret = 0;
for (const auto &cmd : recorded_commands_) {
if (dynamic_cast<CmdBindConstantBufferToIndex *>(cmd.get()) != nullptr) {
ret++;
}
}
return ret;
device_->d3d11_context()->ExecuteCommandList(d3d11_command_list_, TRUE);
}

namespace {
Expand Down Expand Up @@ -739,6 +706,9 @@ void Dx11Device::image_to_buffer(DevicePtr dst_buf,
TI_NOT_IMPLEMENTED;
}

void Dx11Device::wait_idle() {
}

ID3D11Buffer *Dx11Device::alloc_id_to_buffer(uint32_t alloc_id) {
return alloc_id_to_buffer_.at(alloc_id);
}
Expand Down Expand Up @@ -766,33 +736,35 @@ ID3D11Buffer *Dx11Device::create_or_get_cb_buffer(uint32_t alloc_id) {
return cb_buf;
}

void Dx11Device::set_spirv_cross_numworkgroups(uint32_t x,
uint32_t y,
uint32_t z,
int cb_slot) {
if (spirv_cross_numworkgroups_ == nullptr) {
ID3D11Buffer *temp;
create_raw_buffer(device_, 16, nullptr, &temp);
create_cpu_accessible_buffer_copy(device_, temp,
&spirv_cross_numworkgroups_);
temp->Release();
}
if (spirv_cross_numworkgroups_cb_ == nullptr) {
create_constant_buffer_copy(device_, spirv_cross_numworkgroups_,
&spirv_cross_numworkgroups_cb_);
}
ID3D11Buffer *Dx11Device::set_spirv_cross_numworkgroups(uint32_t x,
uint32_t y,
uint32_t z,
int cb_slot) {
ID3D11Buffer *spirv_cross_numworkgroups;
ID3D11Buffer *temp;
create_raw_buffer(device_, 16, nullptr, &temp);
create_cpu_accessible_buffer_copy(device_, temp, &spirv_cross_numworkgroups);
temp->Release();

ID3D11Buffer *spirv_cross_numworkgroups_cb;
create_constant_buffer_copy(device_, spirv_cross_numworkgroups,
&spirv_cross_numworkgroups_cb);

D3D11_MAPPED_SUBRESOURCE mapped;
context_->Map(spirv_cross_numworkgroups_, 0, D3D11_MAP_WRITE, 0, &mapped);
d3d11_context()->Map(spirv_cross_numworkgroups, 0, D3D11_MAP_WRITE, 0,
&mapped);
uint32_t *u = reinterpret_cast<uint32_t *>(mapped.pData);
u[0] = x;
u[1] = y;
u[2] = z;
context_->Unmap(spirv_cross_numworkgroups_, 0);
d3d11_context()->Unmap(spirv_cross_numworkgroups, 0);

d3d11_context()->CopyResource(spirv_cross_numworkgroups_cb,
spirv_cross_numworkgroups);

context_->CopyResource(spirv_cross_numworkgroups_cb_,
spirv_cross_numworkgroups_);
context_->CSSetConstantBuffers(cb_slot, 1, &spirv_cross_numworkgroups_cb_);
spirv_cross_numworkgroups->Release();

return spirv_cross_numworkgroups_cb;
}

Dx11Stream::Dx11Stream(Dx11Device *device_) : device_(device_) {
Expand All @@ -805,15 +777,23 @@ std::unique_ptr<CommandList> Dx11Stream::new_command_list() {
return std::make_unique<Dx11CommandList>(device_);
}

void Dx11Stream::submit(CommandList *cmdlist) {
StreamSemaphore Dx11Stream::submit(
CommandList *cmdlist,
const std::vector<StreamSemaphore> &wait_semaphores) {
Dx11CommandList *dx_cmd_list = static_cast<Dx11CommandList *>(cmdlist);
dx_cmd_list->run_commands();

return nullptr;
}

// No difference for DX11
void Dx11Stream::submit_synced(CommandList *cmdlist) {
StreamSemaphore Dx11Stream::submit_synced(
CommandList *cmdlist,
const std::vector<StreamSemaphore> &wait_semaphores) {
Dx11CommandList *dx_cmd_list = static_cast<Dx11CommandList *>(cmdlist);
dx_cmd_list->run_commands();

return nullptr;
}

void Dx11Stream::command_sync() {
Expand Down
Loading

0 comments on commit 6221388

Please sign in to comment.