Skip to content

Commit

Permalink
[aot] Added third-party render thread task injection for Unity (taich…
Browse files Browse the repository at this point in the history
…i-dev#7151)

Issue: #

### Brief Summary

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 555de3c commit 1ea2f18
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 12 deletions.
7 changes: 7 additions & 0 deletions c_api/include/taichi/taichi_unity.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@ extern "C" {
// Handle `TixNativeBufferUnity`
typedef struct TixNativeBufferUnity_t *TixNativeBufferUnity;

// Callback `TixAsyncTaskUnity`
typedef void(TI_API_CALL *TixAsyncTaskUnity)(void *user_data);

// Function `tix_import_native_runtime_unity`
TI_DLL_EXPORT TiRuntime TI_API_CALL tix_import_native_runtime_unity();

// Function `tix_enqueue_task_async_unity`
TI_DLL_EXPORT void TI_API_CALL
tix_enqueue_task_async_unity(void *user_data, TixAsyncTaskUnity async_task);

// Function `tix_launch_kernel_async_unity`
TI_DLL_EXPORT void TI_API_CALL
tix_launch_kernel_async_unity(TiRuntime runtime,
Expand Down
25 changes: 25 additions & 0 deletions c_api/taichi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,17 @@
"type": "handle",
"is_dispatchable": false
},
{
"name": "async_task",
"vendor": "unity",
"type": "callback",
"parameters": [
{
"name": "user_data",
"type": "void*"
}
]
},
{
"name": "import_native_runtime",
"vendor": "unity",
Expand All @@ -1519,6 +1530,20 @@
}
]
},
{
"name": "enqueue_task_async",
"vendor": "unity",
"type": "function",
"parameters": [
{
"name": "user_data",
"type": "void*"
},
{
"type": "callback.async_task"
}
]
},
{
"name": "launch_kernel_async",
"vendor": "unity",
Expand Down
30 changes: 23 additions & 7 deletions misc/generate_c_api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import re
from os import system

from taichi_json import (Alias, BitField, BuiltInType, Definition, EntryBase,
Enumeration, Field, Function, Handle, Module,
Structure, Union)
from taichi_json import (Alias, BitField, BuiltInType, Callback, Definition,
EntryBase, Enumeration, Field, Function, Handle,
Module, Structure, Union)


def get_type_name(x: EntryBase):
ty = type(x)
if ty in [BuiltInType]:
return x.type_name
elif ty in [Alias, Handle, Enumeration, Structure, Union]:
elif ty in [Alias, Handle, Enumeration, Structure, Union, Callback]:
return x.name.upper_camel_case
elif ty in [BitField]:
return x.name.extend('flags').upper_camel_case
Expand Down Expand Up @@ -110,6 +110,21 @@ def get_declr(module: Module, x: EntryBase, with_docs=False):
out += [f" {get_field(variant)};"]
out += ["} " + get_type_name(x) + ";"]

elif ty is Callback:
return_value_type = "void" if x.return_value_type == None else get_type_name(
x.return_value_type)
out += [
f"typedef {return_value_type} (TI_API_CALL *{get_type_name(x)})("
]
if x.params:
for i, param in enumerate(x.params):
if i != 0:
out[-1] += ","
if with_docs:
out += get_api_field_ref(module, x, param.name)
out += [f" {get_field(param)}"]
out += [");"]

elif ty is Function:
return_value_type = "void" if x.return_value_type == None else get_type_name(
x.return_value_type)
Expand Down Expand Up @@ -143,7 +158,8 @@ def get_human_readable_name(x: EntryBase):
elif ty is Definition:
return f"{x.name.screaming_snake_case}"

elif isinstance(x, (Handle, Enumeration, BitField, Structure, Union)):
elif isinstance(
x, (Handle, Enumeration, BitField, Structure, Union, Callback)):
return f"{get_type_name(x)}"

elif ty is Function:
Expand All @@ -162,7 +178,7 @@ def get_title(x: EntryBase):
extra += " (Device Command)"

if isinstance(x, (Alias, Definition, Handle, Enumeration, BitField,
Structure, Union, Function)):
Structure, Union, Callback, Function)):
return f"{type(x).__name__} `{get_human_readable_name(x)}`" + extra
else:
raise RuntimeError(f"'{x.id}' doesn't need title")
Expand Down Expand Up @@ -237,7 +253,7 @@ def get_human_readable_field_name(x: EntryBase, field_name: str):
if str(field.name) == field_name:
out = str(field.name)
break
elif isinstance(x, Function):
elif isinstance(x, (Callback, Function)):
for field in x.params:
if str(field.name) == field_name:
out = str(field.name)
Expand Down
4 changes: 3 additions & 1 deletion misc/generate_c_api_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def print_module_doc(module: Module):
else:
print(f"WARNING: `{x}` is not documented")
out += [""]
out += [""]

if out[-1]:
out += [""]

return '\n'.join(out)

Expand Down
17 changes: 13 additions & 4 deletions misc/generate_unity_language_binding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import re

from taichi_json import (Alias, BitField, BuiltInType, Definition, EntryBase,
Enumeration, Field, Function, Handle, Module,
Structure, Union)
from taichi_json import (Alias, BitField, BuiltInType, Callback, Definition,
EntryBase, Enumeration, Field, Function, Handle,
Module, Structure, Union)

RESERVED_WORD_TRANSFORM = {
'event': 'event_',
Expand All @@ -14,7 +14,7 @@ def get_type_name(x: EntryBase):
ty = type(x)
if ty in [BuiltInType]:
return x.type_name
elif ty in [Alias, Handle, Enumeration, Structure, Union]:
elif ty in [Alias, Handle, Enumeration, Structure, Union, Callback]:
return x.name.upper_camel_case
elif ty in [BitField]:
return x.name.extend('flag_bits').upper_camel_case
Expand Down Expand Up @@ -154,6 +154,15 @@ def get_declr(x: EntryBase):
out += ["}"]
return '\n'.join(out)

elif ty is Callback:
out = [
"[StructLayout(LayoutKind.Sequential)]",
"public struct " + get_type_name(x) + " {",
" public IntPtr Inner;",
"}",
]
return '\n'.join(out)

elif ty is Function:

out = []
Expand Down
17 changes: 17 additions & 0 deletions misc/taichi_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,21 @@ def __init__(self, j):
self.variants += [Field(x)]


class Callback(EntryBase):
def __init__(self, j):
super().__init__(j, "callback")
self.return_value_type = None
self.params = []

if "parameters" in j:
for x in j["parameters"]:
field = Field(x)
if field.name.snake_case == "@return":
self.return_value_type = field.type
else:
self.params += [field]


class Function(EntryBase):
def __init__(self, j):
super().__init__(j, "function")
Expand Down Expand Up @@ -374,6 +389,8 @@ def __init__(self, version: Version, j: dict,
self.declr_reg.register(Structure(k))
elif ty == "union":
self.declr_reg.register(Union(k))
elif ty == "callback":
self.declr_reg.register(Callback(k))
elif ty == "function":
self.declr_reg.register(Function(k))
else:
Expand Down

0 comments on commit 1ea2f18

Please sign in to comment.