From 5ab41a4c52ccb74960af3618cda25733948c8128 Mon Sep 17 00:00:00 2001 From: PENGUINLIONG Date: Tue, 28 Jun 2022 02:20:15 +0800 Subject: [PATCH 1/3] Separated api spec json parsing and codegen --- misc/generate_c_api.py | 476 ++++++++++------------------------------- misc/taichi_json.py | 270 +++++++++++++++++++++++ 2 files changed, 387 insertions(+), 359 deletions(-) create mode 100644 misc/taichi_json.py diff --git a/misc/generate_c_api.py b/misc/generate_c_api.py index 263fedad8162a..51fabf6e56d82 100644 --- a/misc/generate_c_api.py +++ b/misc/generate_c_api.py @@ -1,402 +1,160 @@ -import json -import re +from taichi_json import EntryBase, BuiltInType, Alias, Handle, Definition, \ + Handle, Enumeration, BitField, Field, Structure, Union, Function, Module #from os import system -class Name: - def __init__(self, name: str): - assert re.match('^[@a-z0-9_]+$', name) - self._segs = name.split("_") +def get_type_name(x: EntryBase): + ty = type(x) + if ty in [BuiltInType]: + return x.type_name + elif ty in [Alias, Handle, Enumeration, Structure, Union]: + return x.name.upper_camel_case + elif ty in [BitField]: + return x.name.extend('flag_bits').upper_camel_case + else: + raise RuntimeError(f"'{x.id}' is not a type") - @property - def snake_case(self) -> str: - return '_'.join(self._segs) - @property - def screaming_snake_case(self) -> str: - return '_'.join(x.upper() for x in self._segs) +def get_field(x: Field): + # `count` is an integer so it's a static array. + is_dyn_array = x.count and not isinstance(x.count, int) - @property - def upper_camel_case(self) -> str: - return ''.join(x.title() for x in self._segs) + is_ptr = x.by_ref or x.by_mut or is_dyn_array + const_q = "const" if not x.by_mut else "" + type_name = get_type_name(x.type) - def __repr__(self) -> str: - return self.snake_case + if is_ptr: + return f"{const_q} {type_name}* {x.name}" + elif x.count: + return f"{type_name} {x.name}[{x.count}]" + else: + return f"{type_name} {x.name}" -class DeclarationRegistry: - current = None +def get_declr(x: EntryBase): + ty = type(x) + if ty is BuiltInType: + return "" - def __init__(self): - # "xxx.yyy" -> Xxx(yyy) Look-up table. - self._inner = {} - self._imported = {} + elif ty is Alias: + return f"typedef {get_type_name(x.alias_of)} {get_type_name(x)};" - def resolve(self, id: str): - if id in self._inner: - return self._inner[id] - elif id in self._imported: - return self._imported[id] - else: - return None + elif ty is Definition: + return f"#define {x.name.screaming_snake_case} {x.value}" - def register(self, x): - self._inner[x.id] = x + elif ty is Handle: + return f"typedef struct {get_type_name(x)}_t* {get_type_name(x)};" - def import_declrs(self, other): - for x in other._inner.values(): - self._imported[x.id] = x - - def __iter__(self): - return iter(self._inner) - - @staticmethod - def set_current(declr_reg): - DeclarationRegistry.current = declr_reg - - -class Alias: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"alias.{self.name}" - self.alias_of = j["alias_of"] - - @property - def type_name(self) -> str: - return "Ti" + self.name.upper_camel_case - - def declr(self): - return f"typedef {self.alias_of} {self.type_name};" - - -class Definition: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"definition.{self.name}" - self.value = j["value"] - - def declr(self): - return f"#define TI_{self.name.screaming_snake_case} {self.value}" - - -class Handle: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"handle.{self.name}" - self.is_dispatchable = j["is_dispatchable"] - - @property - def type_name(self) -> str: - return "Ti" + self.name.upper_camel_case - - def declr(self): - return f"typedef struct {self.type_name}_t* {self.type_name};" - - -class Enumeration: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"enumeration.{self.name}" - - cases = {} - if "inc_cases" in j: - inc_file_name = "taichi/inc/" + j["inc_cases"] + ".inc.h" - with open(inc_file_name) as f: - for line in f.readlines(): - m = re.match(r"\w+\((\w+)\).*", line) - if m: - case_name = self.get_case_name(Name(m[1])) - cases[case_name] = len(cases) - else: - for name, value in j["cases"].items(): - cases[self.get_case_name(Name(name))] = value - self.cases = cases - - @property - def type_name(self): - return "Ti" + self.name.upper_camel_case - - def get_case_name(self, case_name: Name): - return "TI_" + self.name.screaming_snake_case + "_" + case_name.screaming_snake_case - - def declr(self): - out = ["typedef enum " + self.type_name + " {"] - for name, value in self.cases.items(): - out += [f" {name} = {value},"] - out += [f" {self.get_case_name(Name('max_enum'))} = 0xffffffff,"] - out += ["} " + self.type_name + ";"] + elif ty is Enumeration: + out = ["typedef enum " + get_type_name(x) + " {"] + for name, value in x.cases.items(): + out += [f" {name.screaming_snake_case} = {value},"] + out += [f" {x.name.extend('max_enum').screaming_snake_case} = 0xffffffff,"] + out += ["} " + get_type_name(x) + ";"] return '\n'.join(out) - -class BitField: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"bit_field.{self.name}" - - bits = {} - if "inc_bits" in j: - inc_file_name = "taichi/inc/" + j["inc_bits"] + ".inc.h" - with open(inc_file_name) as f: - for line in enumerate(f.readlines()): - m = re.match(r"\w+\((\w+)\).*", line) - if m: - bit_name = self.get_flag_name(Name(m[1])) - bits[bit_name] = len(bits) - else: - for name, value in j["bits"].items(): - bits[self.get_flag_name(Name(name))] = value - self.bits = bits - - @property - def type_name(self): - return "Ti" + self.name.upper_camel_case + "FlagBits" - - @property - def field_type_name(self): - return "Ti" + self.name.upper_camel_case + "Flags" - - def get_flag_name(self, flag_name: Name): - return "TI_" + self.name.screaming_snake_case + "_" + flag_name.screaming_snake_case + "_BIT" - - def declr(self): - out = ["typedef enum " + self.type_name + " {"] - for name, value in self.bits.items(): - out += [f" {name} = 1 << {value},"] - out += ["} " + self.type_name + ";"] - out += [f"typedef TiFlags {self.field_type_name};"] + elif ty is BitField: + out = ["typedef enum " + get_type_name(x) + " {"] + for name, value in x.bits.items(): + out += [f" {name.extend('bit').screaming_snake_case} = 1 << {value},"] + out += ["} " + get_type_name(x) + ";"] + out += [ + f"typedef TiFlags {x.name.extend('flags').upper_camel_case};"] return '\n'.join(out) - -class CType: - def __init__(self, name: str): - self.name = name - - @property - def type_name(self): - return self.name - - -class Field: - def __init__(self, j): - ty = DeclarationRegistry.current.resolve(j["type"]) - if ty != None: - # The type has been registered. - self.type = ty - if "name" in j: - self.name = Name(j["name"]) - else: - self.name = ty.name - else: - # The type is not (yet) registered, treat it as a untracked C type. - self.type = CType(j["type"]) - self.name = Name(j["name"]) - self.count = j["count"] if "count" in j else None - self.by_mut = j["by_mut"] if "by_mut" in j else None - self.by_ref = j["by_ref"] if "by_ref" in j else None - - def declr(self): - # `count` is an integer so it's a static array. - is_dyn_array = self.count and not isinstance(self.count, int) - - is_ptr = self.by_ref or self.by_mut or is_dyn_array - const_q = "const" if not self.by_mut else "" - - if is_ptr: - return f"{const_q} {self.type.type_name}* {self.name}" - elif self.count: - return f"{self.type.type_name} {self.name}[{self.count}]" - else: - return f"{self.type.type_name} {self.name}" - - -class Structure: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"structure.{self.name}" - self.fields = [] - if "fields" in j: - for x in j["fields"]: - self.fields += [Field(x)] - - @property - def type_name(self): - return "Ti" + self.name.upper_camel_case - - def declr(self): - out = ["typedef struct " + self.type_name + " {"] - for x in self.fields: - out += [f" {x.declr()};"] - out += ["} " + self.type_name + ";"] + elif ty is Structure: + out = ["typedef struct " + get_type_name(x) + " {"] + for field in x.fields: + out += [f" {get_field(field)};"] + out += ["} " + get_type_name(x) + ";"] return '\n'.join(out) - -class Union: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"union.{self.name}" - self.variants = [] - if "variants" in j: - for x in j["variants"]: - self.variants += [Field(x)] - - @property - def type_name(self): - return "Ti" + self.name.upper_camel_case - - def declr(self): - out = ["typedef union " + self.type_name + " {"] - for x in self.variants: - out += [f" {x.declr()};"] - out += ["} " + self.type_name + ";"] + elif ty is Union: + out = ["typedef union " + get_type_name(x) + " {"] + for variant in x.variants: + out += [f" {get_field(variant)};"] + out += ["} " + get_type_name(x) + ";"] return '\n'.join(out) - -class Function: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"function.{self.name}" - self.version = 1 - self.is_extension = False - self.return_value_type = None - self.params = [] - - if "version" in j: - self.version = j["version"] - - if "is_extension" in j: - self.is_extension = j["is_extension"] - - if "parameters" in j: - for x in j["parameters"]: - field = Field(x) - if field.name.snake_case == "@return": - self.return_value_type = field.type - else: - self.params += [field] - - @property - def func_name(self): - name = "ti_" + self.name.snake_case - if self.is_extension: - name += "_ext" - if self.version > 1: - name += f"_{self.version}" - return name - - def declr(self): - return_value_type = "void" if self.return_value_type == None else self.return_value_type.type_name + elif ty is Function: + return_value_type = "void" if x.return_value_type == None else get_type_name( + x.return_value_type) out = [ "TI_DLL_EXPORT " + return_value_type + " TI_API_CALL " + - self.func_name + "(" + x.name.snake_case + "(" ] - out += [',\n'.join(f" {param.declr()}" for param in self.params)] + if x.params: + out += [',\n'.join(f" {get_field(param)}" for param in x.params)] out += [");"] return '\n'.join(out) + else: + raise RuntimeError(f"'{x.id}' doesn't need declaration") -class Module: - all_modules = {} - - def __init__(self, j): - self.is_built_in = False - self.declr_reg = DeclarationRegistry() - self.required_modules = [] - - DeclarationRegistry.set_current(self.declr_reg) - - if "is_built_in" in j: - self.is_built_in = True - # Built-in headers are hand-written so we can return right away. - return - - if "required_modules" in j: - for x in j["required_modules"]: - assert x in Module.all_modules - module = Module.all_modules[x] - self.declr_reg.import_declrs(module.declr_reg) - self.required_modules += [x] - - if "declarations" in j: - for k in j["declarations"]: - ty = k["type"] - - if ty == "alias": - self.declr_reg.register(Alias(k)) - elif ty == "definition": - self.declr_reg.register(Definition(k)) - elif ty == "handle": - self.declr_reg.register(Handle(k)) - elif ty == "enumeration": - self.declr_reg.register(Enumeration(k)) - elif ty == "bit_field": - self.declr_reg.register(BitField(k)) - elif ty == "structure": - self.declr_reg.register(Structure(k)) - elif ty == "union": - self.declr_reg.register(Union(k)) - elif ty == "function": - self.declr_reg.register(Function(k)) - else: - print(f"ignored unrecognized type declaration '{k}'") - - DeclarationRegistry.set_current(None) - - def declr(self): - out = ["#pragma once"] - - for x in self.required_modules: - out += [f"#include <{x}>"] - out += [ - "", - "#ifdef __cplusplus", - 'extern "C" {', - "#endif // __cplusplus", - "", - ] +def print_module_header(module): + out = ["#pragma once"] - for x in self.declr_reg: - out += [ - "", - f"// {x}", - self.declr_reg.resolve(x).declr(), - ] + for x in module.required_modules: + out += [f"#include <{x}>"] + out += [ + "", + "#ifdef __cplusplus", + 'extern "C" {', + "#endif // __cplusplus", + "", + ] + + for x in module.declr_reg: out += [ "", - "#ifdef __cplusplus", - '} // extern "C"', - "#endif // __cplusplus", - "", + f"// {x}", + get_declr(module.declr_reg.resolve(x)), ] - return '\n'.join(out) - - @staticmethod - def generate_header(j): - module_name = j["name"] - module = Module(j) - Module.all_modules[module_name] = module + out += [ + "", + "#ifdef __cplusplus", + '} // extern "C"', + "#endif // __cplusplus", + "", + ] - if module.is_built_in: - return + return '\n'.join(out) - print(f"processing module '{module_name}'") - path = f"c_api/include/{module_name}" - with open(path, "w") as f: - f.write(module.declr()) - #system(f"clang-format {path} -i") +def generate_module_header(module): + if module.is_built_in: + return + print(f"processing module '{module.name}'") + path = f"c_api/include/{module.name}" + with open(path, "w") as f: + f.write(print_module_header(module)) -if __name__ == "__main__": - j = None - with open("c_api/taichi.json") as f: - j = json.load(f) + #system(f"clang-format {path} -i") - version = j["version"] - print("taichi c-api version is:", version) - for module in j["modules"]: - Module.generate_header(module) +if __name__ == "__main__": + builtin_tys = { + BuiltInType("uint64_t", "uint64_t"), + BuiltInType("int64_t", "int64_t"), + BuiltInType("uint32_t", "uint32_t"), + BuiltInType("int32_t", "int32_t"), + BuiltInType("float", "float"), + BuiltInType("const char*", "const char*"), + BuiltInType("const char**", "const char**"), + BuiltInType("void*", "void*"), + BuiltInType("const void*", "const void*"), + BuiltInType("VkInstance", "VkInstance"), + BuiltInType("VkPhysicalDevice", "VkPhysicalDevice"), + BuiltInType("VkDevice", "VkDevice"), + BuiltInType("VkQueue", "VkQueue"), + BuiltInType("VkBuffer", "VkBuffer"), + BuiltInType("VkBufferUsageFlags", "VkBufferUsageFlags"), + } + + for module in Module.load_all(builtin_tys): + generate_module_header(module) diff --git a/misc/taichi_json.py b/misc/taichi_json.py new file mode 100644 index 0000000000000..ee1798a224ecf --- /dev/null +++ b/misc/taichi_json.py @@ -0,0 +1,270 @@ +import json +import re + + +class Name: + def __init__(self, name: str, prefix=[], suffix=[]): + assert re.match('^[@a-z0-9_]+$', name) + self._segs = name.split("_") + self._prefix = prefix + self._suffix = suffix + + def extend(self, subname): + if isinstance(subname, str): + subname = Name(subname) + assert isinstance(subname, Name) + assert len(subname._prefix) == 0 and len(subname._suffix) == 0 + return Name('_'.join(self._segs + subname._segs), self._prefix, self._suffix) + + @property + def segs(self): + return self._prefix + self._segs + self._suffix + + @property + def snake_case(self) -> str: + return '_'.join(self.segs) + + @property + def screaming_snake_case(self) -> str: + return '_'.join(x.upper() for x in self.segs) + + @property + def upper_camel_case(self) -> str: + return ''.join(x.title() for x in self.segs) + + def __repr__(self) -> str: + return '_'.join(self._segs) + + +class DeclarationRegistry: + current = None + + def __init__(self, builtin_tys={}): + # "xxx.yyy" -> Xxx(yyy) Look-up table. + self._inner = {} + self._imported = {} + self._builtin_tys = dict((x.id, x) for x in builtin_tys) + + def resolve(self, id: str): + if id in self._builtin_tys: + return self._builtin_tys[id] + elif id in self._inner: + return self._inner[id] + elif id in self._imported: + return self._imported[id] + else: + return None + + def register(self, x): + self._inner[x.id] = x + + def import_declrs(self, other): + for x in other._inner.values(): + self._imported[x.id] = x + + def __iter__(self): + return iter(self._inner) + + @staticmethod + def set_current(declr_reg): + DeclarationRegistry.current = declr_reg + + +def load_inc_enums(name, inc_file_name): + path = "taichi/inc/" + inc_file_name + ".inc.h" + cases = {} + with open(path) as f: + for line in f.readlines(): + m = re.match(r"\w+\((\w+)\).*", line) + if m: + case_name = name.extend(m[1]) + cases[case_name] = len(cases) + return cases + + +class EntryBase: + def __init__(self, j, clazz: str): + assert "name" in j + self.vendor = None + self.is_extension = False + + prefix = [] + suffix = [] + if "vendor" in j: + vendor = j["vendor"] + prefix += ["tix"] + suffix += [vendor] + self.vendor = vendor + self.is_extension = True + elif "is_extension" in j: + prefix += ["ti"] + suffix += ["ext"] + self.is_extension = True + else: + prefix += ["ti"] + + if "version" in j: + version = int(j["version"]) + if version > 1: + suffix += [str(version)] + self.version = version + + self.name = Name(j["name"], prefix, suffix) + self.id = f"{clazz}.{self.name}" + + +class BuiltInType(EntryBase): + def __init__(self, id, type_name): + self.name = "value" + self.id = id + self.type_name = type_name + + +class Alias(EntryBase): + def __init__(self, j): + super().__init__(j, "alias") + self.alias_of = DeclarationRegistry.current.resolve(j["alias_of"]) + + +class Definition(EntryBase): + def __init__(self, j): + super().__init__(j, "definition") + self.value = j["value"] + + +class Handle(EntryBase): + def __init__(self, j): + super().__init__(j, "handle") + self.is_dispatchable = j["is_dispatchable"] + + +class Enumeration(EntryBase): + def __init__(self, j): + super().__init__(j, "enumeration") + if "inc_cases" in j: + self.cases = load_inc_enums(self.name, j["inc_cases"]) + else: + self.cases = dict((self.name.extend(name), value) + for name, value in j["cases"].items()) + + +class BitField(EntryBase): + def __init__(self, j): + super().__init__(j, "bit_field") + if "inc_cases" in j: + self.bits = load_inc_enums(self.name, j["inc_bits"]) + else: + self.bits = dict((self.name.extend(name), value) + for name, value in j["bits"].items()) + + +class Field: + def __init__(self, j): + ty = DeclarationRegistry.current.resolve(j["type"]) + assert ty != None, f"unknown type '{j['type']}'" + # The type has been registered. + self.type = ty + self.name = Name(j["name"]) if "name" in j else ty.name + self.count = j["count"] if "count" in j else None + self.by_mut = bool(j["by_mut"]) if "by_mut" in j else False + self.by_ref = bool(j["by_ref"]) if "by_ref" in j else False + + +class Structure(EntryBase): + def __init__(self, j): + super().__init__(j, "structure") + self.fields = [] + if "fields" in j: + for x in j["fields"]: + self.fields += [Field(x)] + + +class Union(EntryBase): + def __init__(self, j): + super().__init__(j, "union") + self.variants = [] + if "variants" in j: + for x in j["variants"]: + self.variants += [Field(x)] + + +class Function(EntryBase): + def __init__(self, j): + super().__init__(j, "function") + self.return_value_type = None + self.params = [] + + if "parameters" in j: + for x in j["parameters"]: + field = Field(x) + if field.name.snake_case == "@return": + self.return_value_type = field.type + else: + self.params += [field] + + +class Module: + all_modules = {} + + def __init__(self, j, builtin_tys): + self.name = j["name"] + self.is_built_in = False + self.declr_reg = DeclarationRegistry(builtin_tys) + self.required_modules = [] + + DeclarationRegistry.set_current(self.declr_reg) + + if "is_built_in" in j: + self.is_built_in = True + # Built-in headers are hand-written so we can return right away. + return + + if "required_modules" in j: + for x in j["required_modules"]: + assert x in Module.all_modules + module = Module.all_modules[x] + self.declr_reg.import_declrs(module.declr_reg) + self.required_modules += [x] + + if "declarations" in j: + for k in j["declarations"]: + try: + ty = k["type"] + + if ty == "alias": + self.declr_reg.register(Alias(k)) + elif ty == "definition": + self.declr_reg.register(Definition(k)) + elif ty == "handle": + self.declr_reg.register(Handle(k)) + elif ty == "enumeration": + self.declr_reg.register(Enumeration(k)) + elif ty == "bit_field": + self.declr_reg.register(BitField(k)) + elif ty == "structure": + self.declr_reg.register(Structure(k)) + elif ty == "union": + self.declr_reg.register(Union(k)) + elif ty == "function": + self.declr_reg.register(Function(k)) + else: + print(f"ignored unrecognized type declaration '{k}'") + except: + print("failed to generate declaration for:", k) + + DeclarationRegistry.set_current(None) + + @staticmethod + def load_all(builtin_tys): + j = None + with open("c_api/taichi.json") as f: + j = json.load(f) + + version = j["version"] + print("taichi c-api version is:", version) + + for k in j["modules"]: + module = Module(k, builtin_tys) + Module.all_modules[module.name] = module + + return list(Module.all_modules.values()) From 543c3e1a8609b7ab53403a1780458af73860cc3b Mon Sep 17 00:00:00 2001 From: PENGUINLIONG Date: Tue, 28 Jun 2022 02:20:45 +0800 Subject: [PATCH 2/3] Adapted unity C# binding generator --- misc/generate_unity_language_binding.py | 586 ++++++------------------ 1 file changed, 149 insertions(+), 437 deletions(-) diff --git a/misc/generate_unity_language_binding.py b/misc/generate_unity_language_binding.py index 39beb7624e15b..670acda21addb 100644 --- a/misc/generate_unity_language_binding.py +++ b/misc/generate_unity_language_binding.py @@ -1,534 +1,246 @@ -import json +from taichi_json import EntryBase, BuiltInType, Alias, Handle, Definition, \ + Handle, Enumeration, BitField, Field, Structure, Union, Function, Module import re -#from os import system -TYPE_MAP = { - "void": "void", - "int32_t": "int", - "uint32_t": "uint", - "int64_t": "long", - "uint64_t": "ulong", - "float": "float", - "const char*": "string", - "void*": "IntPtr", -} +def get_type_name(x: EntryBase): + ty = type(x) + if ty in [BuiltInType]: + return x.type_name + elif ty in [Alias, Handle, Enumeration, Structure, Union]: + return x.name.upper_camel_case + elif ty in [BitField]: + return x.name.extend('flag_bits').upper_camel_case + else: + raise RuntimeError(f"'{x.id}' is not a type") -class InternalAlias: - def __init__(self, name: str): - self.name = name +def get_c_function_param(x: Field): + # `count` is an integer so it's a static array. + is_dyn_array = x.count and not isinstance(x.count, int) - @property - def type_name(self): - self.name + if x.by_ref or x.by_mut: + return f"IntPtr {x.name}" + elif is_dyn_array: + return f"[MarshalAs(UnmanagedType.LPArray)] {get_type_name(x.type)}[] {x.name}" + elif x.count: + return f"{get_type_name(x.type)}[{x.count}] {x.name}" + else: + return f"{get_type_name(x.type)} {x.name}" -class Name: - def __init__(self, name: str): - assert re.match('^[@a-z0-9_]+$', name) - self._segs = name.split("_") +def get_function_param(x: Field): + # `count` is an integer so it's a static array. + is_dyn_array = x.count and not isinstance(x.count, int) - @property - def snake_case(self) -> str: - return '_'.join(self._segs) + if is_dyn_array: + return f"{get_type_name(x.type)}[] {x.name}" + elif x.count: + return f"{get_type_name(x.type)}[{x.count}] {x.name}" + else: + return f"{get_type_name(x.type)} {x.name}" - @property - def screaming_snake_case(self) -> str: - return '_'.join(x.upper() for x in self._segs) - @property - def upper_camel_case(self) -> str: - return ''.join(x.title() for x in self._segs) +def get_struct_field(x: Field): + # `count` is an integer so it's a static array. + is_dyn_array = x.count and not isinstance(x.count, int) - def __repr__(self) -> str: - return self.snake_case + is_ptr = x.by_ref or x.by_mut or is_dyn_array + out = "" + if is_ptr: + out += f"IntPtr {x.name}" + elif x.count: + out += f"[MarshalAs(UnmanagedType.ByValArray, SizeConst={x.count})] " + out += f"public {get_type_name(x.type)}[] {x.name}" + else: + out += f"public {get_type_name(x.type)} {x.name}" + return out -class DeclarationRegistry: - current = None - def __init__(self): - # "xxx.yyy" -> Xxx(yyy) Look-up table. - self._inner = {} - self._imported = {} - - def resolve(self, id: str): - if id in self._inner: - return self._inner[id] - elif id in self._imported: - return self._imported[id] - else: - return None - - def register(self, x): - self._inner[x.id] = x - - def import_declrs(self, other): - for x in other._inner.values(): - self._imported[x.id] = x - - def __iter__(self): - return iter(self._inner) - - @staticmethod - def set_current(declr_reg): - DeclarationRegistry.current = declr_reg +def get_union_variant(x: Field): + # `count` is an integer so it's a static array. + is_dyn_array = x.count and not isinstance(x.count, int) + is_ptr = x.by_ref or x.by_mut or is_dyn_array -class Alias: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"alias.{self.name}" - self.alias_of = j["alias_of"] - TYPE_MAP[self.type_name] = TYPE_MAP[self.alias_of] + out = "[FieldOffset(0)] " + if is_ptr: + out += f"IntPtr {x.name}" + elif x.count: + out += f"[MarshalAs(UnmanagedType.ByValArray, SizeConst={x.count})] " + out += f"public {get_type_name(x.type)}[] {x.name}" + else: + out += f"public {get_type_name(x.type)} {x.name}" + return out - @property - def type_name(self) -> str: - return TYPE_MAP[self.alias_of] - def declr(self): - return f"// using Ti{self.name.upper_camel_case} = {TYPE_MAP[self.alias_of]};" +def get_declr(x: EntryBase): + ty = type(x) + if ty is BuiltInType: + return "" + elif ty is Alias: + return f"// using {get_type_name(x)} = {get_type_name(x.alias_of)};" -class Definition: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"definition.{self.name}" - self.value = j["value"] - - def declr(self): + elif ty is Definition: out = [ "static partial class Def {", - f"public const uint {self.name.screaming_snake_case} = {self.value};", + f"public const uint {x.name.screaming_snake_case} = {x.value};", "}" ] return '\n'.join(out) - -class Handle: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"handle.{self.name}" - self.is_dispatchable = j["is_dispatchable"] - - @property - def type_name(self) -> str: - return "Ti" + self.name.upper_camel_case - - def declr(self): + elif ty is Handle: out = [ "[StructLayout(LayoutKind.Sequential)]", - "public struct " + self.type_name + " {", + "public struct " + get_type_name(x) + " {", " public IntPtr Inner;", "}", ] return '\n'.join(out) - -class Enumeration: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"enumeration.{self.name}" - - cases = {} - if "inc_cases" in j: - inc_file_name = "taichi/inc/" + j["inc_cases"] + ".inc.h" - with open(inc_file_name) as f: - for line in f.readlines(): - m = re.match(r"\w+\((\w+)\).*", line) - if m: - case_name = self.get_case_name(Name(m[1])) - cases[case_name] = len(cases) - else: - for name, value in j["cases"].items(): - cases[self.get_case_name(Name(name))] = value - self.cases = cases - - @property - def type_name(self): - return "Ti" + self.name.upper_camel_case - - def get_case_name(self, case_name: Name): - return "TI_" + self.name.screaming_snake_case + "_" + case_name.screaming_snake_case - - def declr(self): - out = ["public enum " + self.type_name + " {"] - for name, value in self.cases.items(): - out += [f" {name} = {value},"] - out += [f" {self.get_case_name(Name('max_enum'))} = 0x7fffffff,"] + elif ty is Enumeration: + out = ["public enum " + get_type_name(x) + " {"] + for name, value in x.cases.items(): + out += [f" {name.screaming_snake_case} = {value},"] + out += [f" {x.name.extend('max_enum').screaming_snake_case} = 0x7fffffff,"] out += ["}"] return '\n'.join(out) - -class BitField: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"bit_field.{self.name}" - - bits = {} - if "inc_bits" in j: - inc_file_name = "taichi/inc/" + j["inc_bits"] + ".inc.h" - with open(inc_file_name) as f: - for line in enumerate(f.readlines()): - m = re.match(r"\w+\((\w+)\).*", line) - if m: - bit_name = self.get_flag_name(Name(m[1])) - bits[bit_name] = len(bits) - else: - for name, value in j["bits"].items(): - bits[self.get_flag_name(Name(name))] = value - TYPE_MAP[self.type_name] = "uint" - self.bits = bits - - @property - def type_name(self): - return "Ti" + self.name.upper_camel_case + "FlagBits" - - @property - def field_type_name(self): - return self.name.upper_camel_case + "Flags" - - def get_flag_name(self, flag_name: Name): - return "TI_" + self.name.screaming_snake_case + "_" + flag_name.screaming_snake_case + "_BIT" - - def declr(self): - out = ["[Flags]", "public enum " + self.type_name + " {"] - for name, value in self.bits.items(): - out += [f" {name} = 1 << {value},"] + elif ty is BitField: + out = ["[Flags]", "public enum " + get_type_name(x) + " {"] + for name, value in x.bits.items(): + out += [f" {name.extend('bit').screaming_snake_case} = 1 << {value},"] out += ["};"] return '\n'.join(out) - -class CType: - def __init__(self, name: str): - self.name = TYPE_MAP[name] - - @property - def type_name(self): - return self.name - - -class Field: - def __init__(self, j): - ty = DeclarationRegistry.current.resolve(j["type"]) - if ty != None: - # The type has been registered. - self.type = ty - if "name" in j: - self.name = Name(j["name"]) - else: - self.name = ty.name - else: - # The type is not (yet) registered, treat it as a untracked C type. - self.type = CType(j["type"]) - self.name = Name(j["name"]) - self.count = j["count"] if "count" in j else None - self.by_mut = j["by_mut"] if "by_mut" in j else None - self.by_ref = j["by_ref"] if "by_ref" in j else None - - def declr_c_function_param(self): - # `count` is an integer so it's a static array. - is_dyn_array = self.count and not isinstance(self.count, int) - - if self.by_ref or self.by_mut: - return f"IntPtr {self.name}" - elif is_dyn_array: - return f"[MarshalAs(UnmanagedType.LPArray)] {self.type.type_name}[] {self.name}" - elif self.count: - return f"{self.type.type_name}[{self.count}] {self.name}" - else: - return f"{self.type.type_name} {self.name}" - - def declr_function_param(self): - # `count` is an integer so it's a static array. - is_dyn_array = self.count and not isinstance(self.count, int) - - is_ptr = self.by_ref or self.by_mut or is_dyn_array - - if is_dyn_array: - return f"{self.type.type_name}[] {self.name}" - elif self.count: - return f"{self.type.type_name}[{self.count}] {self.name}" - else: - return f"{self.type.type_name} {self.name}" - - def declr_struct_field(self): - # `count` is an integer so it's a static array. - is_dyn_array = self.count and not isinstance(self.count, int) - - is_ptr = self.by_ref or self.by_mut or is_dyn_array - - out = "" - if is_ptr: - out += f"IntPtr {self.name}" - elif self.count: - out += f"[MarshalAs(UnmanagedType.ByValArray, SizeConst={self.count})] " - out += f"public {self.type.type_name}[] {self.name}" - else: - out += f"public {self.type.type_name} {self.name}" - return out - - def declr_union_variant(self): - # `count` is an integer so it's a static array. - is_dyn_array = self.count and not isinstance(self.count, int) - - is_ptr = self.by_ref or self.by_mut or is_dyn_array - - out = "[FieldOffset(0)] " - if is_ptr: - out += f"IntPtr {self.name}" - elif self.count: - out += f"[MarshalAs(UnmanagedType.ByValArray, SizeConst={self.count})] " - out += f"public {self.type.type_name}[] {self.name}" - else: - out += f"public {self.type.type_name} {self.name}" - return out - - -class Structure: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"structure.{self.name}" - self.fields = [] - if "fields" in j: - for x in j["fields"]: - self.fields += [Field(x)] - - @property - def type_name(self): - return "Ti" + self.name.upper_camel_case - - def declr(self): + elif ty is Structure: out = [ "[StructLayout(LayoutKind.Sequential)]", - "public struct " + self.type_name + " {", + "public struct " + get_type_name(x) + " {", ] - for x in self.fields: - out += [f" {x.declr_struct_field()};"] + for field in x.fields: + out += [f" {get_struct_field(field)};"] out += ["}"] return '\n'.join(out) - -class Union: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"union.{self.name}" - self.variants = [] - if "variants" in j: - for x in j["variants"]: - self.variants += [Field(x)] - - @property - def type_name(self): - return "Ti" + self.name.upper_camel_case - - def declr(self): + elif ty is Union: out = [ "[StructLayout(LayoutKind.Explicit)]", - "public struct " + self.type_name + " {", + "public struct " + get_type_name(x) + " {", ] - for x in self.variants: - out += [f" {x.declr_union_variant()};"] + for variant in x.variants: + out += [f" {get_union_variant(variant)};"] out += ["}"] return '\n'.join(out) - -class Function: - def __init__(self, j): - self.name = Name(j["name"]) - self.id = f"function.{self.name}" - self.version = 1 - self.is_extension = False - self.return_value_type = None - self.params = [] - - if "version" in j: - self.version = j["version"] - - if "is_extension" in j: - self.is_extension = j["is_extension"] - - if "parameters" in j: - for x in j["parameters"]: - field = Field(x) - if field.name.snake_case == "@return": - self.return_value_type = field.type - else: - self.params += [field] - - @property - def c_func_name(self): - name = "ti_" + self.name.snake_case - if self.is_extension: - name += "_ext" - if self.version > 1: - name += f"_{self.version}" - return name - - @property - def func_name(self): - name = self.name.upper_camel_case - if self.is_extension: - name += "_ext" - if self.version > 1: - name += f"_{self.version}" - return name - - def declr(self): - return_value_type = "void" if self.return_value_type == None else self.return_value_type.type_name + elif ty is Function: + return_value_type = "void" if x.return_value_type == None else get_type_name( + x.return_value_type) out = [ "static partial class Ffi {", "#if (UNITY_IOS || UNITY_TVOS || UNITY_WEBGL) && !UNITY_EDITOR", ' [DllImport ("__Internal")]', "#else", - ' [DllImport("taichi_c_api")]', + ' [DllImport("taichi_unity")]' if x.vendor == "unity" else ' [DllImport("taichi_c_api")]', "#endif", "private static extern " + return_value_type + " " + - self.c_func_name + "(", - ',\n'.join(f" {param.declr_c_function_param()}" - for param in self.params), + x.name.snake_case + "(", + ',\n'.join(f" {get_c_function_param(param)}" + for param in x.params), ");", - "public static " + return_value_type + " " + self.func_name + "(", - ',\n'.join(f" {param.declr_function_param()}" - for param in self.params), + "public static " + return_value_type + + " " + x.name.upper_camel_case + "(", + ',\n'.join(f" {get_function_param(param)}" + for param in x.params), ") {", ] - for param in self.params: + for param in x.params: if (isinstance(param.type, Structure) or isinstance(param.type, Union)) and not param.count: out += [ - f" IntPtr hglobal_{param.name} = Marshal.AllocHGlobal(Marshal.SizeOf(typeof({param.type.type_name})));", + f" IntPtr hglobal_{param.name} = Marshal.AllocHGlobal(Marshal.SizeOf(typeof({get_type_name(param.type)})));", f" Marshal.StructureToPtr({param.name}, hglobal_{param.name}, false);" ] - if self.return_value_type: - out += [f" var rv = {self.c_func_name}("] + if x.return_value_type: + out += [f" var rv = {x.name.snake_case}("] else: - out += [f" {self.c_func_name}("] - for i, param in enumerate(self.params): + out += [f" {x.name.snake_case}("] + for i, param in enumerate(x.params): if (isinstance(param.type, Structure) or isinstance(param.type, Union)) and not param.count: out += [ - f" hglobal_{param.name}{','if i + 1 != len(self.params) else ''}" + f" hglobal_{param.name}{','if i + 1 != len(x.params) else ''}" ] else: out += [ - f" {param.name}{','if i + 1 != len(self.params) else ''}" + f" {param.name}{','if i + 1 != len(x.params) else ''}" ] out += [" );"] - for param in self.params: + for param in x.params: if (isinstance(param.type, Structure) or isinstance(param.type, Union)) and not param.count: out += [f" Marshal.FreeHGlobal(hglobal_{param.name});"] - if self.return_value_type: + if x.return_value_type: out += [f" return rv;"] out += ["}", "}"] return '\n'.join(out) + else: + raise RuntimeError(f"'{x.id}' doesn't need declaration") -class Module: - current = None - all_modules = {} - - def __init__(self, j): - Module.current = self - self.name = Name(j["name"][len("taichi/"):-len(".h")]) - self.is_built_in = False - self.declr_reg = DeclarationRegistry() - self.required_modules = [] - - DeclarationRegistry.set_current(self.declr_reg) - - if "is_built_in" in j: - self.is_built_in = True - # Built-in headers are hand-written so we can return right away. - return - - if "required_modules" in j: - for x in j["required_modules"]: - assert x in Module.all_modules - module = Module.all_modules[x] - self.declr_reg.import_declrs(module.declr_reg) - self.required_modules += [x] - - if "declarations" in j: - for k in j["declarations"]: - ty = k["type"] - - try: - if ty == "alias": - self.declr_reg.register(Alias(k)) - elif ty == "definition": - self.declr_reg.register(Definition(k)) - elif ty == "handle": - self.declr_reg.register(Handle(k)) - elif ty == "enumeration": - self.declr_reg.register(Enumeration(k)) - elif ty == "bit_field": - self.declr_reg.register(BitField(k)) - elif ty == "structure": - self.declr_reg.register(Structure(k)) - elif ty == "union": - self.declr_reg.register(Union(k)) - elif ty == "function": - self.declr_reg.register(Function(k)) - else: - print(f"ignored unrecognized type declaration '{k}'") - except KeyError as k: - print(f"Ignored declaration '{x}' with hidden type: {k}") - - DeclarationRegistry.set_current(None) - - def declr(self): - out = [ - "using System;", - "using System.Runtime.InteropServices;", - "using System.Collections.Generic;", - "", - "namespace Taichi {", - ] - for x in self.declr_reg: - out += [ - "", - f"// {x}", - self.declr_reg.resolve(x).declr(), - ] +def print_module_header(module): + out = [ + "using System;", + "using System.Runtime.InteropServices;", + "using System.Collections.Generic;", + "", + "namespace Taichi {", + ] + for x in module.declr_reg: out += [ "", - "} // namespace Taichi", - "", + f"// {x}", + get_declr(module.declr_reg.resolve(x)), ] - return '\n'.join(out) + out += [ + "", + "} // namespace Taichi", + "", + ] - @staticmethod - def generate_header(j): - module_name = j["name"] - module = Module(j) - Module.all_modules[module_name] = module + return '\n'.join(out) - if module.is_built_in: - return - print(f"processing module '{module_name}'") - assert re.match("taichi/\w+.h", module_name) - module_name = module_name[len("taichi/"):-len(".h")] - path = f"c_api/unity/{module_name}.cs" - with open(path, "w") as f: - f.write(module.declr()) +def generate_module_header(module): + if module.is_built_in: + return - #system(f"clang-format {path} -i") + print(f"processing module '{module.name}'") + assert re.match("taichi/\w+.h", module.name) + module_name = module.name[len("taichi/"):-len(".h")] + path = f"c_api/unity/{module_name}.cs" + with open(path, "w") as f: + f.write(print_module_header(module)) if __name__ == "__main__": - j = None - with open("c_api/taichi.json") as f: - j = json.load(f) - - version = j["version"] - print("taichi c-api version is:", version) - - for module in j["modules"]: - Module.generate_header(module) + builtin_tys = { + BuiltInType("void", "void"), + BuiltInType("int32_t", "int"), + BuiltInType("uint32_t", "uint"), + BuiltInType("int64_t", "long"), + BuiltInType("uint64_t", "ulong"), + BuiltInType("float", "float"), + BuiltInType("const char*", "string"), + BuiltInType("void*", "IntPtr"), + BuiltInType("const void*", "IntPtr"), + BuiltInType("alias.bool", "uint"), + } + + for module in Module.load_all(builtin_tys): + generate_module_header(module) From 8ef0b83594d7fb7fda0ade4c9194eb21fedc96f5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Jun 2022 18:41:57 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- misc/generate_c_api.py | 16 ++++++++++------ misc/generate_unity_language_binding.py | 24 +++++++++++++++--------- misc/taichi_json.py | 3 ++- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/misc/generate_c_api.py b/misc/generate_c_api.py index 51fabf6e56d82..f57dbf1e6deea 100644 --- a/misc/generate_c_api.py +++ b/misc/generate_c_api.py @@ -1,5 +1,6 @@ -from taichi_json import EntryBase, BuiltInType, Alias, Handle, Definition, \ - Handle, Enumeration, BitField, Field, Structure, Union, Function, Module +from taichi_json import (Alias, BitField, BuiltInType, Definition, EntryBase, + Enumeration, Field, Function, Handle, Module, + Structure, Union) #from os import system @@ -50,17 +51,20 @@ def get_declr(x: EntryBase): out = ["typedef enum " + get_type_name(x) + " {"] for name, value in x.cases.items(): out += [f" {name.screaming_snake_case} = {value},"] - out += [f" {x.name.extend('max_enum').screaming_snake_case} = 0xffffffff,"] + out += [ + f" {x.name.extend('max_enum').screaming_snake_case} = 0xffffffff," + ] out += ["} " + get_type_name(x) + ";"] return '\n'.join(out) elif ty is BitField: out = ["typedef enum " + get_type_name(x) + " {"] for name, value in x.bits.items(): - out += [f" {name.extend('bit').screaming_snake_case} = 1 << {value},"] + out += [ + f" {name.extend('bit').screaming_snake_case} = 1 << {value}," + ] out += ["} " + get_type_name(x) + ";"] - out += [ - f"typedef TiFlags {x.name.extend('flags').upper_camel_case};"] + out += [f"typedef TiFlags {x.name.extend('flags').upper_camel_case};"] return '\n'.join(out) elif ty is Structure: diff --git a/misc/generate_unity_language_binding.py b/misc/generate_unity_language_binding.py index 670acda21addb..b48a0d2c9c428 100644 --- a/misc/generate_unity_language_binding.py +++ b/misc/generate_unity_language_binding.py @@ -1,7 +1,9 @@ -from taichi_json import EntryBase, BuiltInType, Alias, Handle, Definition, \ - Handle, Enumeration, BitField, Field, Structure, Union, Function, Module import re +from taichi_json import (Alias, BitField, BuiltInType, Definition, EntryBase, + Enumeration, Field, Function, Handle, Module, + Structure, Union) + def get_type_name(x: EntryBase): ty = type(x) @@ -104,14 +106,18 @@ def get_declr(x: EntryBase): out = ["public enum " + get_type_name(x) + " {"] for name, value in x.cases.items(): out += [f" {name.screaming_snake_case} = {value},"] - out += [f" {x.name.extend('max_enum').screaming_snake_case} = 0x7fffffff,"] + out += [ + f" {x.name.extend('max_enum').screaming_snake_case} = 0x7fffffff," + ] out += ["}"] return '\n'.join(out) elif ty is BitField: out = ["[Flags]", "public enum " + get_type_name(x) + " {"] for name, value in x.bits.items(): - out += [f" {name.extend('bit').screaming_snake_case} = 1 << {value},"] + out += [ + f" {name.extend('bit').screaming_snake_case} = 1 << {value}," + ] out += ["};"] return '\n'.join(out) @@ -143,17 +149,17 @@ def get_declr(x: EntryBase): "#if (UNITY_IOS || UNITY_TVOS || UNITY_WEBGL) && !UNITY_EDITOR", ' [DllImport ("__Internal")]', "#else", - ' [DllImport("taichi_unity")]' if x.vendor == "unity" else ' [DllImport("taichi_c_api")]', + ' [DllImport("taichi_unity")]' + if x.vendor == "unity" else ' [DllImport("taichi_c_api")]', "#endif", "private static extern " + return_value_type + " " + x.name.snake_case + "(", ',\n'.join(f" {get_c_function_param(param)}" for param in x.params), ");", - "public static " + return_value_type + - " " + x.name.upper_camel_case + "(", - ',\n'.join(f" {get_function_param(param)}" - for param in x.params), + "public static " + return_value_type + " " + + x.name.upper_camel_case + "(", + ',\n'.join(f" {get_function_param(param)}" for param in x.params), ") {", ] for param in x.params: diff --git a/misc/taichi_json.py b/misc/taichi_json.py index ee1798a224ecf..c1c034c63a1db 100644 --- a/misc/taichi_json.py +++ b/misc/taichi_json.py @@ -14,7 +14,8 @@ def extend(self, subname): subname = Name(subname) assert isinstance(subname, Name) assert len(subname._prefix) == 0 and len(subname._suffix) == 0 - return Name('_'.join(self._segs + subname._segs), self._prefix, self._suffix) + return Name('_'.join(self._segs + subname._segs), self._prefix, + self._suffix) @property def segs(self):