Skip to content

Commit

Permalink
[aot] Guard C-API interfaces with try-catch (#6060)
Browse files Browse the repository at this point in the history
  • Loading branch information
PENGUINLIONG authored Sep 15, 2022
1 parent ebe258e commit a9f2905
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 23 deletions.
1 change: 1 addition & 0 deletions c_api/include/taichi/taichi_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ typedef enum TiError {
TI_ERROR_ARGUMENT_OUT_OF_RANGE = -6,
TI_ERROR_ARGUMENT_NOT_FOUND = -7,
TI_ERROR_INVALID_INTEROP = -8,
TI_ERROR_INVALID_STATE = -9,
TI_ERROR_MAX_ENUM = 0xffffffff,
} TiError;

Expand Down
113 changes: 96 additions & 17 deletions c_api/src/taichi_core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ const char *describe_error(TiError error) {
return "argument not found";
case TI_ERROR_INVALID_INTEROP:
return "invalid interop";
case TI_ERROR_INVALID_STATE:
return "invalid state";
default:
return "unknown error";
}
Expand Down Expand Up @@ -104,6 +106,8 @@ Runtime &Event::runtime() {
// -----------------------------------------------------------------------------

TiError ti_get_last_error(uint64_t message_size, char *message) {
TiError out = TI_ERROR_INVALID_STATE;
TI_CAPI_TRY_CATCH_BEGIN();
// Emit message only if the output buffer is property provided.
if (message_size > 0 && message != nullptr) {
size_t n = thread_error_cache.message.size();
Expand All @@ -113,11 +117,14 @@ TiError ti_get_last_error(uint64_t message_size, char *message) {
std::memcpy(message, thread_error_cache.message.data(), n);
message[n] = '\0';
}
return thread_error_cache.error;
out = thread_error_cache.error;
TI_CAPI_TRY_CATCH_END();
return out;
}
// C-API errors MUST be set via this interface. No matter from internal or
// external procedures.
void ti_set_last_error(TiError error, const char *message) {
TI_CAPI_TRY_CATCH_BEGIN();
if (error < TI_ERROR_SUCCESS) {
TI_WARN("C-API error: ({}) {}", describe_error(error), message);
if (message != nullptr) {
Expand All @@ -130,48 +137,61 @@ void ti_set_last_error(TiError error, const char *message) {
thread_error_cache.error = TI_ERROR_SUCCESS;
thread_error_cache.message.clear();
}
TI_CAPI_TRY_CATCH_END();
}

TiRuntime ti_create_runtime(TiArch arch) {
TiRuntime out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
switch (arch) {
#ifdef TI_WITH_VULKAN
case TI_ARCH_VULKAN: {
return (TiRuntime)(static_cast<Runtime *>(new VulkanRuntimeOwned));
out = (TiRuntime)(static_cast<Runtime *>(new VulkanRuntimeOwned));
break;
}
#endif // TI_WITH_VULKAN
#ifdef TI_WITH_OPENGL
case TI_ARCH_OPENGL: {
return (TiRuntime)(static_cast<Runtime *>(new OpenglRuntime));
out = (TiRuntime)(static_cast<Runtime *>(new OpenglRuntime));
break;
}
#endif // TI_WITH_OPENGL
#ifdef TI_WITH_LLVM
case TI_ARCH_X64: {
return (TiRuntime)(static_cast<Runtime *>(
out = (TiRuntime)(static_cast<Runtime *>(
new capi::LlvmRuntime(taichi::Arch::x64)));
break;
}
case TI_ARCH_ARM64: {
return (TiRuntime)(static_cast<Runtime *>(
out = (TiRuntime)(static_cast<Runtime *>(
new capi::LlvmRuntime(taichi::Arch::arm64)));
break;
}
case TI_ARCH_CUDA: {
return (TiRuntime)(static_cast<Runtime *>(
out = (TiRuntime)(static_cast<Runtime *>(
new capi::LlvmRuntime(taichi::Arch::cuda)));
break;
}
#endif // TI_WITH_LLVM
default: {
TI_CAPI_NOT_SUPPORTED(arch);
return TI_NULL_HANDLE;
}
}
return TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_END();
return out;
}
void ti_destroy_runtime(TiRuntime runtime) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
delete (Runtime *)runtime;
TI_CAPI_TRY_CATCH_END();
}

TiMemory ti_allocate_memory(TiRuntime runtime,
const TiMemoryAllocateInfo *create_info) {
TiMemory out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL_RV(runtime);
TI_CAPI_ARGUMENT_NULL_RV(create_info);

Expand All @@ -196,35 +216,46 @@ TiMemory ti_allocate_memory(TiRuntime runtime,
params.export_sharing = create_info->export_sharing;
params.usage = usage;

TiMemory devmem = ((Runtime *)runtime)->allocate_memory(params);
return devmem;
out = ((Runtime *)runtime)->allocate_memory(params);
TI_CAPI_TRY_CATCH_END();
return out;
}

void ti_free_memory(TiRuntime runtime, TiMemory devmem) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(devmem);

Runtime *runtime2 = (Runtime *)runtime;
runtime2->free_memory(devmem);
TI_CAPI_TRY_CATCH_END();
}

void *ti_map_memory(TiRuntime runtime, TiMemory devmem) {
void *out = nullptr;
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL_RV(runtime);
TI_CAPI_ARGUMENT_NULL_RV(devmem);

Runtime *runtime2 = (Runtime *)runtime;
return runtime2->get().map(devmem2devalloc(*runtime2, devmem));
out = runtime2->get().map(devmem2devalloc(*runtime2, devmem));
TI_CAPI_TRY_CATCH_END();
return out;
}
void ti_unmap_memory(TiRuntime runtime, TiMemory devmem) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(devmem);

Runtime *runtime2 = (Runtime *)runtime;
runtime2->get().unmap(devmem2devalloc(*runtime2, devmem));
TI_CAPI_TRY_CATCH_END();
}

TiImage ti_allocate_image(TiRuntime runtime,
const TiImageAllocateInfo *allocate_info) {
TiImage out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL_RV(runtime);
TI_CAPI_ARGUMENT_NULL_RV(allocate_info);

Expand Down Expand Up @@ -275,43 +306,58 @@ TiImage ti_allocate_image(TiRuntime runtime,
params.export_sharing = false;
params.usage = usage;

TiImage devimg = ((Runtime *)runtime)->allocate_image(params);
return devimg;
out = ((Runtime *)runtime)->allocate_image(params);
TI_CAPI_TRY_CATCH_END();
return out;
}
void ti_free_image(TiRuntime runtime, TiImage image) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(image);

((Runtime *)runtime)->free_image(image);
TI_CAPI_TRY_CATCH_END();
}

TiSampler ti_create_sampler(TiRuntime runtime,
const TiSamplerCreateInfo *create_info) {
TiSampler out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_NOT_SUPPORTED(ti_create_sampler);
return TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_END();
return out;
}
void ti_destroy_sampler(TiRuntime runtime, TiSampler sampler) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_NOT_SUPPORTED(ti_destroy_sampler);
TI_CAPI_TRY_CATCH_END();
}

TiEvent ti_create_event(TiRuntime runtime) {
TiEvent out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL_RV(runtime);

Runtime *runtime2 = (Runtime *)runtime;
std::unique_ptr<taichi::lang::DeviceEvent> event =
runtime2->get().create_event();
Event *event2 = new Event(*runtime2, std::move(event));
return (TiEvent)event2;
out = (TiEvent)event2;
TI_CAPI_TRY_CATCH_END();
return out;
}
void ti_destroy_event(TiEvent event) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(event);

delete (Event *)event;
TI_CAPI_TRY_CATCH_END();
}

void ti_copy_memory_device_to_device(TiRuntime runtime,
const TiMemorySlice *dst_memory,
const TiMemorySlice *src_memory) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(dst_memory);
TI_CAPI_ARGUMENT_NULL(dst_memory->memory);
Expand All @@ -325,11 +371,13 @@ void ti_copy_memory_device_to_device(TiRuntime runtime,
auto src = devmem2devalloc(*runtime2, src_memory->memory)
.get_ptr(src_memory->offset);
runtime2->buffer_copy(dst, src, dst_memory->size);
TI_CAPI_TRY_CATCH_END();
}

void ti_copy_texture_device_to_device(TiRuntime runtime,
const TiImageSlice *dst_texture,
const TiImageSlice *src_texture) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(dst_texture);
TI_CAPI_ARGUMENT_NULL(dst_texture->image);
Expand All @@ -353,10 +401,12 @@ void ti_copy_texture_device_to_device(TiRuntime runtime,
params.height = dst_texture->extent.height;
params.depth = dst_texture->extent.depth;
runtime2->copy_image(dst, src, params);
TI_CAPI_TRY_CATCH_END();
}
void ti_transition_texture(TiRuntime runtime,
TiImage texture,
TiImageLayout layout) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(texture);

Expand All @@ -376,9 +426,12 @@ void ti_transition_texture(TiRuntime runtime,
}

runtime2->transition_image(image, layout2);
TI_CAPI_TRY_CATCH_END();
}

TiAotModule ti_load_aot_module(TiRuntime runtime, const char *module_path) {
TiAotModule out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL_RV(runtime);
TI_CAPI_ARGUMENT_NULL_RV(module_path);

Expand All @@ -388,15 +441,21 @@ TiAotModule ti_load_aot_module(TiRuntime runtime, const char *module_path) {
ti_set_last_error(TI_ERROR_CORRUPTED_DATA, module_path);
return TI_NULL_HANDLE;
}
return aot_module;
out = aot_module;
TI_CAPI_TRY_CATCH_END();
return out;
}
void ti_destroy_aot_module(TiAotModule aot_module) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(aot_module);

delete (AotModule *)aot_module;
TI_CAPI_TRY_CATCH_END();
}

TiKernel ti_get_aot_module_kernel(TiAotModule aot_module, const char *name) {
TiKernel out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL_RV(aot_module);
TI_CAPI_ARGUMENT_NULL_RV(name);

Expand All @@ -408,11 +467,15 @@ TiKernel ti_get_aot_module_kernel(TiAotModule aot_module, const char *name) {
return TI_NULL_HANDLE;
}

return (TiKernel)kernel;
out = (TiKernel)kernel;
TI_CAPI_TRY_CATCH_END();
return out;
}

TiComputeGraph ti_get_aot_module_compute_graph(TiAotModule aot_module,
const char *name) {
TiComputeGraph out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL_RV(aot_module);
TI_CAPI_ARGUMENT_NULL_RV(name);

Expand All @@ -424,13 +487,16 @@ TiComputeGraph ti_get_aot_module_compute_graph(TiAotModule aot_module,
return TI_NULL_HANDLE;
}

return (TiComputeGraph)cgraph;
out = (TiComputeGraph)cgraph;
TI_CAPI_TRY_CATCH_END();
return out;
}

void ti_launch_kernel(TiRuntime runtime,
TiKernel kernel,
uint32_t arg_count,
const TiArgument *args) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(kernel);
if (arg_count > 0) {
Expand Down Expand Up @@ -482,12 +548,14 @@ void ti_launch_kernel(TiRuntime runtime,
}
}
((taichi::lang::aot::Kernel *)kernel)->launch(&runtime_context);
TI_CAPI_TRY_CATCH_END();
}

void ti_launch_compute_graph(TiRuntime runtime,
TiComputeGraph compute_graph,
uint32_t arg_count,
const TiNamedArgument *args) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(compute_graph);
if (arg_count > 0) {
Expand Down Expand Up @@ -615,36 +683,47 @@ void ti_launch_compute_graph(TiRuntime runtime,
}
}
((taichi::lang::aot::CompiledGraph *)compute_graph)->run(arg_map);
TI_CAPI_TRY_CATCH_END();
}

void ti_signal_event(TiRuntime runtime, TiEvent event) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(event);

((Runtime *)runtime)->signal_event(&((Event *)event)->get());
TI_CAPI_TRY_CATCH_END();
}

void ti_reset_event(TiRuntime runtime, TiEvent event) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(event);

((Runtime *)runtime)->reset_event(&((Event *)event)->get());
TI_CAPI_TRY_CATCH_END();
}

void ti_wait_event(TiRuntime runtime, TiEvent event) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);
TI_CAPI_ARGUMENT_NULL(event);

((Runtime *)runtime)->wait_event(&((Event *)event)->get());
TI_CAPI_TRY_CATCH_END();
}

void ti_submit(TiRuntime runtime) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);

((Runtime *)runtime)->submit();
TI_CAPI_TRY_CATCH_END();
}
void ti_wait(TiRuntime runtime) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);

((Runtime *)runtime)->wait();
TI_CAPI_TRY_CATCH_END();
}
Loading

0 comments on commit a9f2905

Please sign in to comment.