From d517dfb5fd5a683b64175bfadaa566329ae4cf2c Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Tue, 23 Mar 2021 17:39:32 +0800 Subject: [PATCH] [refactor] Cleanup python imports --- python/taichi/core/__init__.py | 1 + python/taichi/core/primitive_types.py | 60 ++++++++ python/taichi/lang/__init__.py | 6 +- python/taichi/lang/expr.py | 5 +- python/taichi/lang/impl.py | 126 ++++++++-------- .../taichi/lang/{kernel.py => kernel_impl.py} | 16 +- python/taichi/lang/matrix.py | 48 +++--- python/taichi/lang/ops.py | 20 +-- python/taichi/lang/snode.py | 3 +- python/taichi/lang/transformer.py | 13 +- python/taichi/lang/util.py | 138 +++++++----------- tests/python/test_kernel_template_mapper.py | 9 +- 12 files changed, 237 insertions(+), 208 deletions(-) create mode 100644 python/taichi/core/primitive_types.py rename python/taichi/lang/{kernel.py => kernel_impl.py} (97%) diff --git a/python/taichi/core/__init__.py b/python/taichi/core/__init__.py index f2bd9164d180e..0f8b19ce421cd 100644 --- a/python/taichi/core/__init__.py +++ b/python/taichi/core/__init__.py @@ -2,6 +2,7 @@ from taichi.core.settings import * from taichi.core.record import * from taichi.core.logging import * +from taichi.core.primitive_types import * ti_core.build = build ti_core.load_module = load_module diff --git a/python/taichi/core/primitive_types.py b/python/taichi/core/primitive_types.py new file mode 100644 index 0000000000000..57923abced992 --- /dev/null +++ b/python/taichi/core/primitive_types.py @@ -0,0 +1,60 @@ +from taichi.core.util import ti_core + +# Real types + +float32 = ti_core.DataType_f32 +f32 = float32 +float64 = ti_core.DataType_f64 +f64 = float64 + +real_types = [f32, f64, float] +real_type_ids = [id(t) for t in real_types] + +# Integer types + +int8 = ti_core.DataType_i8 +i8 = int8 +int16 = ti_core.DataType_i16 +i16 = int16 +int32 = ti_core.DataType_i32 +i32 = int32 +int64 = ti_core.DataType_i64 +i64 = int64 + +uint8 = ti_core.DataType_u8 +u8 = uint8 +uint16 = ti_core.DataType_u16 +u16 = uint16 +uint32 = ti_core.DataType_u32 +u32 = uint32 +uint64 = ti_core.DataType_u64 +u64 = uint64 + +integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int] +integer_type_ids = [id(t) for t in integer_types] + +types = real_types + integer_types +type_ids = [id(t) for t in types] + +__all__ = [ + 'float32', + 'f32', + 'float64', + 'f64', + 'int8', + 'i8', + 'int16', + 'i16', + 'int32', + 'i32', + 'int64', + 'i64', + 'uint8', + 'u8', + 'uint16', + 'u16', + 'uint32', + 'u32', + 'uint64', + 'u64', +] diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index 424866d1f338c..07510569d00a8 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -3,13 +3,17 @@ from copy import deepcopy as _deepcopy from taichi.lang.impl import * +from taichi.lang.kernel_arguments import ext_arr, template +from taichi.lang.kernel_impl import (KernelArgError, KernelDefError, + data_oriented, func, kernel, pyfunc) from taichi.lang.matrix import Matrix, Vector from taichi.lang.ndrange import GroupedNDRange, ndrange +from taichi.lang.ops import * from taichi.lang.quant_impl import quant from taichi.lang.runtime_ops import async_flush, sync from taichi.lang.transformer import TaichiSyntaxError from taichi.lang.type_factory_impl import type_factory -from taichi.lang.util import deprecated +from taichi.lang.util import * core = taichi_lang_core diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 1c6221a103fac..5718cc9c37cf8 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -1,8 +1,9 @@ from taichi.lang import impl from taichi.lang.common_ops import TaichiOperations from taichi.lang.core import taichi_lang_core -from taichi.lang.util import (deprecated, is_taichi_class, python_scope, - to_numpy_type, to_pytorch_type) +from taichi.lang.util import (is_taichi_class, python_scope, to_numpy_type, + to_pytorch_type) +from taichi.misc.util import deprecated import taichi as ti diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index bcd7b61320d6e..23ab07ec7088a 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -1,17 +1,22 @@ import numbers -from .core import taichi_lang_core -from .expr import Expr -from .snode import SNode -from .util import * -from .exception import TaichiSyntaxError +import numpy as np +from taichi.core import util as cutil +from taichi.lang import ops as ops_mod +from taichi.lang.exception import TaichiSyntaxError +from taichi.lang.expr import Expr, make_expr_group +from taichi.lang.snode import SNode +from taichi.lang.util import (cook_dtype, is_taichi_class, python_scope, + taichi_scope) +from taichi.misc.util import deprecated, get_traceback + +import taichi as ti @taichi_scope def expr_init(rhs): - import taichi as ti if rhs is None: - return Expr(taichi_lang_core.expr_alloca()) + return Expr(cutil.ti_core.expr_alloca()) if is_taichi_class(rhs): return rhs.variable() else: @@ -21,19 +26,18 @@ def expr_init(rhs): return tuple(expr_init(e) for e in rhs) elif isinstance(rhs, dict): return dict((key, expr_init(val)) for key, val in rhs.items()) - elif isinstance(rhs, taichi_lang_core.DataType): + elif isinstance(rhs, cutil.ti_core.DataType): return rhs elif isinstance(rhs, ti.ndrange): return rhs elif hasattr(rhs, '_data_oriented'): return rhs else: - return Expr(taichi_lang_core.expr_var(Expr(rhs).ptr)) + return Expr(cutil.ti_core.expr_var(Expr(rhs).ptr)) @taichi_scope def expr_init_list(xs, expected): - import taichi as ti if not isinstance(xs, (list, tuple, ti.Matrix)): raise TypeError(f'Cannot unpack type: {type(xs)}') if isinstance(xs, ti.Matrix): @@ -55,7 +59,6 @@ def expr_init_list(xs, expected): @taichi_scope def expr_init_func( rhs): # temporary solution to allow passing in fields as arguments - import taichi as ti if isinstance(rhs, Expr) and rhs.ptr.is_global_var(): return rhs if isinstance(rhs, ti.Matrix) and rhs.is_global(): @@ -72,7 +75,7 @@ def begin_frontend_struct_for(group, loop_range): f'({group.size()} != {len(loop_range.shape)}). Maybe you wanted to ' 'use "for I in ti.grouped(x)" to group all indices into a single vector I?' ) - taichi_lang_core.begin_frontend_struct_for(group, loop_range.ptr) + cutil.ti_core.begin_frontend_struct_for(group, loop_range.ptr) def begin_frontend_if(cond): @@ -83,7 +86,7 @@ def begin_frontend_if(cond): ' if all(x == y):\n' 'or\n' ' if any(x != y):\n') - taichi_lang_core.begin_frontend_if(Expr(cond).ptr) + cutil.ti_core.begin_frontend_if(Expr(cond).ptr) def wrap_scalar(x): @@ -95,7 +98,6 @@ def wrap_scalar(x): @taichi_scope def subscript(value, *indices): - import numpy as np _taichi_skip_traceback = 1 if isinstance(value, np.ndarray): return value.__getitem__(*indices) @@ -134,7 +136,7 @@ def subscript(value, *indices): raise IndexError( f'Field with dim {field_dim} accessed with indices of dim {index_dim}' ) - return Expr(taichi_lang_core.subscript(value.ptr, indices_expr_group)) + return Expr(cutil.ti_core.subscript(value.ptr, indices_expr_group)) else: return value[indices] @@ -162,7 +164,7 @@ def chain_compare(comparators, ops): now = lhs != rhs else: assert False, f'Unknown operator {ops[i]}' - ret = logical_and(ret, now) + ret = ops_mod.logical_and(ret, now) return ret @@ -198,8 +200,8 @@ def __init__(self, kernels=None): self.inside_kernel = False self.global_vars = [] self.print_preprocessed = False - self.default_fp = f32 - self.default_ip = i32 + self.default_fp = ti.f32 + self.default_ip = ti.i32 self.target_tape = None self.inside_complex_kernel = False self.kernels = kernels or [] @@ -208,18 +210,18 @@ def get_num_compiled_functions(self): return len(self.compiled_functions) + len(self.compiled_grad_functions) def set_default_fp(self, fp): - assert fp in [f32, f64] + assert fp in [ti.f32, ti.f64] self.default_fp = fp default_cfg().default_fp = self.default_fp def set_default_ip(self, ip): - assert ip in [i32, i64] + assert ip in [ti.i32, ti.i64] self.default_ip = ip default_cfg().default_ip = self.default_ip def create_program(self): if self.prog is None: - self.prog = taichi_lang_core.Program() + self.prog = cutil.ti_core.Program() def materialize(self): if self.materialized: @@ -232,9 +234,8 @@ def layout(): for func in self.layout_functions: func() - import taichi as ti ti.trace('Materializing layout...') - taichi_lang_core.layout(layout) + cutil.ti_core.layout(layout) self.materialized = True not_placed = [] for var in self.global_vars: @@ -250,8 +251,8 @@ def layout(): f'{bar}Please consider specifying a shape for them. E.g.,' + '\n\n x = ti.field(float, shape=(2, 3))') - for func in self.materialize_callbacks: - func() + for callback in self.materialize_callbacks: + callback() self.materialize_callbacks = [] def print_snode_tree(self): @@ -294,7 +295,6 @@ def _clamp_unsigned_to_range(npty, val): # to deal with: |val| does't fall into the valid range of either # the signed or the unsigned type. return val - import taichi as ti new_val = val - cap ti.warn( f'Constant {val} has exceeded the range of {iif.bits} int, clamped to {new_val}' @@ -304,27 +304,26 @@ def _clamp_unsigned_to_range(npty, val): @taichi_scope def make_constant_expr(val): - import numpy as np _taichi_skip_traceback = 1 if isinstance(val, (int, np.integer)): - if pytaichi.default_ip in {i32, u32}: + if pytaichi.default_ip in {ti.i32, ti.u32}: # It is not always correct to do such clamp without the type info on # the LHS, but at least this makes assigning constant to unsigned # int work. See https://github.com/taichi-dev/taichi/issues/2060 return Expr( - taichi_lang_core.make_const_expr_i32( + cutil.ti_core.make_const_expr_i32( _clamp_unsigned_to_range(np.int32, val))) - elif pytaichi.default_ip in {i64, u64}: + elif pytaichi.default_ip in {ti.i64, ti.u64}: return Expr( - taichi_lang_core.make_const_expr_i64( + cutil.ti_core.make_const_expr_i64( _clamp_unsigned_to_range(np.int64, val))) else: assert False elif isinstance(val, (float, np.floating, np.ndarray)): - if pytaichi.default_fp == f32: - return Expr(taichi_lang_core.make_const_expr_f32(val)) - elif pytaichi.default_fp == f64: - return Expr(taichi_lang_core.make_const_expr_f64(val)) + if pytaichi.default_fp == ti.f32: + return Expr(cutil.ti_core.make_const_expr_f32(val)) + elif pytaichi.default_fp == ti.f64: + return Expr(cutil.ti_core.make_const_expr_f64(val)) else: assert False else: @@ -338,7 +337,7 @@ def reset(): pytaichi = PyTaichi(old_kernels) for k in old_kernels: k.reset() - taichi_lang_core.reset_default_compile_config() + cutil.ti_core.reset_default_compile_config() @taichi_scope @@ -368,9 +367,8 @@ def __init__(self): pass def __getattribute__(self, item): - import taichi as ti - ti.get_runtime().create_program() - root = SNode(ti.get_runtime().prog.get_root()) + get_runtime().create_program() + root = SNode(get_runtime().prog.get_root()) return getattr(root, item) def __repr__(self): @@ -417,16 +415,16 @@ def field(dtype, shape=None, offset=None, needs_grad=False): del _taichi_skip_traceback # primal - x = Expr(taichi_lang_core.make_id_expr("")) + x = Expr(cutil.ti_core.make_id_expr("")) x.declaration_tb = get_traceback(stacklevel=2) - x.ptr = taichi_lang_core.global_new(x.ptr, dtype) + x.ptr = cutil.ti_core.global_new(x.ptr, dtype) x.ptr.set_is_primal(True) pytaichi.global_vars.append(x) - if taichi_lang_core.needs_grad(dtype): + if cutil.ti_core.needs_grad(dtype): # adjoint - x_grad = Expr(taichi_lang_core.make_id_expr("")) - x_grad.ptr = taichi_lang_core.global_new(x_grad.ptr, dtype) + x_grad = Expr(cutil.ti_core.make_id_expr("")) + x_grad.ptr = cutil.ti_core.global_new(x_grad.ptr, dtype) x_grad.ptr.set_is_primal(False) x.set_grad(x_grad) @@ -450,7 +448,7 @@ def __init__(self, soa=False): @python_scope def layout(func): assert not pytaichi.materialized, "All layout must be specified before the first kernel launch / data access." - warning( + ti.warning( f"@ti.layout will be deprecated in the future, use ti.root directly to specify data layout anytime before the data structure materializes.", PendingDeprecationWarning, stacklevel=3) @@ -488,7 +486,8 @@ def vars2entries(vars): def add_separators(vars): for i, var in enumerate(vars): - if i: yield sep + if i: + yield sep yield var yield end @@ -509,16 +508,15 @@ def fused_string(entries): entries = vars2entries(vars) entries = fused_string(entries) contentries = [entry2content(entry) for entry in entries] - taichi_lang_core.create_print(contentries) + cutil.ti_core.create_print(contentries) @taichi_scope def ti_assert(cond, msg, extra_args): - # Mostly a wrapper to help us convert from ti.Expr (defined in Python) to - # taichi_lang_core.Expr (defined in C++) - import taichi as ti - taichi_lang_core.create_assert_stmt( - ti.Expr(cond).ptr, msg, [ti.Expr(x).ptr for x in extra_args]) + # Mostly a wrapper to help us convert from Expr (defined in Python) to + # cutil.ti_core.Expr (defined in C++) + cutil.ti_core.create_assert_stmt( + Expr(cond).ptr, msg, [Expr(x).ptr for x in extra_args]) @taichi_scope @@ -552,16 +550,16 @@ def one(x): @taichi_scope def get_external_tensor_dim(var): - return taichi_lang_core.get_external_tensor_dim(var) + return cutil.ti_core.get_external_tensor_dim(var) @taichi_scope def get_external_tensor_shape_along_axis(var, i): - return taichi_lang_core.get_external_tensor_shape_along_axis(var, i) + return cutil.ti_core.get_external_tensor_shape_along_axis(var, i) def indices(*x): - return [taichi_lang_core.Index(i) for i in x] + return [cutil.ti_core.Index(i) for i in x] index = indices @@ -572,12 +570,12 @@ def static(x, *xs): if len(xs): # for python-ish pointer assign: x, y = ti.static(y, x) return [static(x)] + [static(x) for x in xs] import types - import taichi as ti + if isinstance(x, (bool, int, float, range, list, tuple, enumerate, ti.ndrange, ti.GroupedNDRange, zip, filter, map)) or x is None: return x - elif isinstance(x, (ti.Expr, ti.Matrix)) and x.is_global(): + elif isinstance(x, (Expr, ti.Matrix)) and x.is_global(): return x elif isinstance(x, (types.FunctionType, types.MethodType)): return x @@ -589,7 +587,6 @@ def static(x, *xs): @taichi_scope def grouped(x): - import taichi as ti if isinstance(x, ti.ndrange): return x.grouped() else: @@ -597,21 +594,16 @@ def grouped(x): def stop_grad(x): - taichi_lang_core.stop_grad(x.snode.ptr) + cutil.ti_core.stop_grad(x.snode.ptr) def current_cfg(): - return taichi_lang_core.current_compile_config() + return cutil.ti_core.current_compile_config() def default_cfg(): - return taichi_lang_core.default_compile_config() - - -from .kernel import * -from .ops import * -from .kernel_arguments import * + return cutil.ti_core.default_compile_config() def call_internal(name): - taichi_lang_core.create_internal_func_stmt(name) + cutil.ti_core.create_internal_func_stmt(name) diff --git a/python/taichi/lang/kernel.py b/python/taichi/lang/kernel_impl.py similarity index 97% rename from python/taichi/lang/kernel.py rename to python/taichi/lang/kernel_impl.py index 60d0b5c0cf8b5..224fafaa3ce57 100644 --- a/python/taichi/lang/kernel.py +++ b/python/taichi/lang/kernel_impl.py @@ -5,6 +5,7 @@ import re import numpy as np +from taichi.core import primitive_types from taichi.lang import impl, util from taichi.lang.ast_checker import KernelSimplicityASTChecker from taichi.lang.core import taichi_lang_core @@ -12,6 +13,7 @@ from taichi.lang.kernel_arguments import ext_arr, template from taichi.lang.shell import _shell_pop_print, oinspect from taichi.lang.transformer import ASTTransformer +from taichi.misc.util import obsolete import taichi as ti @@ -142,7 +144,7 @@ def extract_arguments(self): if i == 0 and self.classfunc: annotation = template() else: - if id(annotation) in util.type_ids: + if id(annotation) in primitive_types.type_ids: ti.warning( 'Data type annotations are unnecessary for Taichi' ' functions, consider removing it', @@ -284,7 +286,7 @@ def extract_arguments(self): else: if isinstance(annotation, (template, ext_arr)): pass - elif id(annotation) in util.type_ids: + elif id(annotation) in primitive_types.type_ids: pass else: _taichi_skip_traceback = 1 @@ -389,11 +391,11 @@ def func__(*args): continue provided = type(v) # Note: do not use sth like "needed == f32". That would be slow. - if id(needed) in util.real_type_ids: + if id(needed) in primitive_types.real_type_ids: if not isinstance(v, (float, int)): raise KernelArgError(i, needed.to_string(), provided) launch_ctx.set_arg_float(actual_argument_slot, float(v)) - elif id(needed) in util.integer_type_ids: + elif id(needed) in primitive_types.integer_type_ids: if not isinstance(v, int): raise KernelArgError(i, needed.to_string(), provided) launch_ctx.set_arg_int(actual_argument_slot, int(v)) @@ -468,7 +470,7 @@ def call_back(): ti.sync() if has_ret: - if id(ret_dt) in util.integer_type_ids: + if id(ret_dt) in primitive_types.integer_type_ids: ret = t_kernel.get_ret_int(0) else: ret = t_kernel.get_ret_float(0) @@ -588,8 +590,8 @@ def kernel(func): return _kernel_impl(func, level_of_class_stackframe=3) -classfunc = util.obsolete('@ti.classfunc', '@ti.func directly') -classkernel = util.obsolete('@ti.classkernel', '@ti.kernel directly') +classfunc = obsolete('@ti.classfunc', '@ti.func directly') +classkernel = obsolete('@ti.classkernel', '@ti.kernel directly') class _BoundedDifferentiableMethod: diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 7bf2a2e1f33a4..795478c4a9a2f 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1,13 +1,17 @@ -from . import expr -from . import impl import copy import numbers -import numpy as np -from .util import taichi_scope, python_scope, deprecated, to_numpy_type, to_pytorch_type, in_python_scope, is_taichi_class, warning -from .common_ops import TaichiOperations -from .exception import TaichiSyntaxError from collections.abc import Iterable +import numpy as np +from taichi.lang import expr, impl +from taichi.lang import ops as ops_mod +from taichi.lang import kernel_impl as kern_mod +from taichi.lang.common_ops import TaichiOperations +from taichi.lang.exception import TaichiSyntaxError +from taichi.lang.util import (in_python_scope, is_taichi_class, python_scope, + taichi_scope, to_numpy_type, to_pytorch_type) +from taichi.misc.util import deprecated, warning + class Matrix(TaichiOperations): is_taichi_class = True @@ -440,7 +444,7 @@ def cast(self, dtype): _taichi_skip_traceback = 1 ret = self.copy() for i in range(len(self.entries)): - ret.entries[i] = impl.cast(ret.entries[i], dtype) + ret.entries[i] = ops_mod.cast(ret.entries[i], dtype) return ret def trace(self): @@ -504,7 +508,7 @@ def E(x, y): inversed = deprecated('a.inversed()', 'a.inverse()')(inverse) - @impl.pyfunc + @kern_mod.pyfunc def normalized(self, eps=0): impl.static( impl.static_assert(self.m == 1, @@ -521,7 +525,7 @@ def transposed(a): def T(self): return self.transpose() - @impl.pyfunc + @kern_mod.pyfunc def transpose(self): ret = Matrix([[self[i, j] for i in range(self.n)] for j in range(self.m)]) @@ -606,25 +610,25 @@ def sum(self): ret = ret + self.entries[i] return ret - @impl.pyfunc + @kern_mod.pyfunc def norm(self, eps=0): - return impl.sqrt(self.norm_sqr() + eps) + return ops_mod.sqrt(self.norm_sqr() + eps) - @impl.pyfunc + @kern_mod.pyfunc def norm_inv(self, eps=0): - return impl.rsqrt(self.norm_sqr() + eps) + return ops_mod.rsqrt(self.norm_sqr() + eps) - @impl.pyfunc + @kern_mod.pyfunc def norm_sqr(self): return (self**2).sum() - @impl.pyfunc + @kern_mod.pyfunc def max(self): - return impl.ti_max(*self.entries) + return ops_mod.ti_max(*self.entries) - @impl.pyfunc + @kern_mod.pyfunc def min(self): - return impl.ti_min(*self.entries) + return ops_mod.ti_min(*self.entries) def any(self): import taichi as ti @@ -950,7 +954,7 @@ def __hash__(self): # using matrices as template arguments. return id(self) - @impl.pyfunc + @kern_mod.pyfunc def dot(self, other): impl.static( impl.static_assert(self.m == 1, "lhs for dot is not a vector")) @@ -958,7 +962,7 @@ def dot(self, other): impl.static_assert(other.m == 1, "rhs for dot is not a vector")) return (self * other).sum() - @impl.pyfunc + @kern_mod.pyfunc def _cross3d(self, other): ret = Matrix([ self[1] * other[2] - self[2] * other[1], @@ -967,7 +971,7 @@ def _cross3d(self, other): ]) return ret - @impl.pyfunc + @kern_mod.pyfunc def _cross2d(self, other): ret = self[0] * other[1] - self[1] * other[0] return ret @@ -984,7 +988,7 @@ def cross(self, other): "Cross product is only supported between pairs of 2D/3D vectors" ) - @impl.pyfunc + @kern_mod.pyfunc def outer_product(self, other): impl.static( impl.static_assert(self.m == 1, diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index c56ac28b5385d..cf50b70a4b9de 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -5,9 +5,9 @@ import operator as ops import traceback +from taichi.lang import impl from taichi.lang.exception import TaichiSyntaxError from taichi.lang.expr import Expr, make_expr_group -from taichi.lang.impl import expr_init from taichi.lang.util import (cook_dtype, is_taichi_class, taichi_lang_core, taichi_scope) @@ -271,7 +271,7 @@ def logical_not(a): def random(dtype=float): dtype = cook_dtype(dtype) x = Expr(ti_core.make_rand_expr(dtype)) - return expr_init(x) + return impl.expr_init(x) # NEXT: add matpow(self, power) @@ -437,43 +437,43 @@ def py_select(cond, a, b): @writeback_binary def atomic_add(a, b): - return expr_init( + return impl.expr_init( Expr(ti_core.expr_atomic_add(a.ptr, b.ptr), tb=stack_info())) @writeback_binary def atomic_sub(a, b): - return expr_init( + return impl.expr_init( Expr(ti_core.expr_atomic_sub(a.ptr, b.ptr), tb=stack_info())) @writeback_binary def atomic_min(a, b): - return expr_init( + return impl.expr_init( Expr(ti_core.expr_atomic_min(a.ptr, b.ptr), tb=stack_info())) @writeback_binary def atomic_max(a, b): - return expr_init( + return impl.expr_init( Expr(ti_core.expr_atomic_max(a.ptr, b.ptr), tb=stack_info())) @writeback_binary def atomic_and(a, b): - return expr_init( + return impl.expr_init( Expr(ti_core.expr_atomic_bit_and(a.ptr, b.ptr), tb=stack_info())) @writeback_binary def atomic_or(a, b): - return expr_init( + return impl.expr_init( Expr(ti_core.expr_atomic_bit_or(a.ptr, b.ptr), tb=stack_info())) @writeback_binary def atomic_xor(a, b): - return expr_init( + return impl.expr_init( Expr(ti_core.expr_atomic_bit_xor(a.ptr, b.ptr), tb=stack_info())) @@ -514,7 +514,7 @@ def ti_all(a): def append(l, indices, val): - a = expr_init( + a = impl.expr_init( ti_core.insert_append(l.snode.ptr, make_expr_group(indices), Expr(val).ptr)) return a diff --git a/python/taichi/lang/snode.py b/python/taichi/lang/snode.py index aacb97fa1184e..3441b991cd1c0 100644 --- a/python/taichi/lang/snode.py +++ b/python/taichi/lang/snode.py @@ -7,7 +7,8 @@ from taichi.core import util as cutil from taichi.lang import impl from taichi.lang.expr import Expr -from taichi.lang.util import deprecated, is_taichi_class +from taichi.lang.util import is_taichi_class +from taichi.misc.util import deprecated class SNode: diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index baa51a6f75a25..62860f99743c6 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -503,7 +503,7 @@ def visit_struct_for(self, node, is_grouped): if 1: ___loop_var = 0 {} = ti.lang.expr.make_var_vector(size=len(___loop_var.loop_range().shape)) - ___expr_group = ti.make_expr_group({}) + ___expr_group = ti.lang.expr.make_expr_group({}) ti.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range()) ti.core.end_frontend_range_for() '''.format(vars, vars) @@ -516,7 +516,7 @@ def visit_struct_for(self, node, is_grouped): if 1: {} ___loop_var = 0 - ___expr_group = ti.make_expr_group({}) + ___expr_group = ti.lang.expr.make_expr_group({}) ti.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range()) ti.core.end_frontend_range_for() '''.format(var_decl, vars) @@ -678,7 +678,8 @@ def visit_FunctionDef(self, node): # Treat return type if node.returns is not None: - ret_init = self.parse_stmt('ti.decl_scalar_ret(0)') + ret_init = self.parse_stmt( + 'ti.lang.kernel_arguments.decl_scalar_ret(0)') ret_init.value.args[0] = node.returns self.returns = node.returns arg_decls.append(ret_init) @@ -691,7 +692,8 @@ def visit_FunctionDef(self, node): continue import taichi as ti if isinstance(self.func.arguments[i], ti.ext_arr): - arg_init = self.parse_stmt('x = ti.decl_ext_arr_arg(0, 0)') + arg_init = self.parse_stmt( + 'x = ti.lang.kernel_arguments.decl_ext_arr_arg(0, 0)') arg_init.targets[0].id = arg.arg self.create_variable(arg.arg) array_dt = self.arg_features[i][0] @@ -704,7 +706,8 @@ def visit_FunctionDef(self, node): "{}".format(array_dim)) arg_decls.append(arg_init) else: - arg_init = self.parse_stmt('x = ti.decl_scalar_arg(0)') + arg_init = self.parse_stmt( + 'x = ti.lang.kernel_arguments.decl_scalar_arg(0)') arg_init.targets[0].id = arg.arg dt = arg.annotation arg_init.value.args[0] = dt diff --git a/python/taichi/lang/util.py b/python/taichi/lang/util.py index 51b47fce620db..006ac06c04f3a 100644 --- a/python/taichi/lang/util.py +++ b/python/taichi/lang/util.py @@ -1,8 +1,12 @@ -from .core import taichi_lang_core -from taichi.misc.util import warning, deprecated, obsolete, get_traceback -import numpy as np +import functools import os +import numpy as np +from taichi.lang import impl +from taichi.lang.core import taichi_lang_core + +import taichi as ti + _has_pytorch = False _env_torch = os.environ.get('TI_ENABLE_TORCH', '1') @@ -28,89 +32,51 @@ def is_taichi_class(rhs): return taichi_class -# Real types - -float32 = taichi_lang_core.DataType_f32 -f32 = float32 -float64 = taichi_lang_core.DataType_f64 -f64 = float64 - -real_types = [f32, f64, float] -real_type_ids = [id(t) for t in real_types] - -# Integer types - -int8 = taichi_lang_core.DataType_i8 -i8 = int8 -int16 = taichi_lang_core.DataType_i16 -i16 = int16 -int32 = taichi_lang_core.DataType_i32 -i32 = int32 -int64 = taichi_lang_core.DataType_i64 -i64 = int64 - -uint8 = taichi_lang_core.DataType_u8 -u8 = uint8 -uint16 = taichi_lang_core.DataType_u16 -u16 = uint16 -uint32 = taichi_lang_core.DataType_u32 -u32 = uint32 -uint64 = taichi_lang_core.DataType_u64 -u64 = uint64 - -integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int] -integer_type_ids = [id(t) for t in integer_types] - -types = real_types + integer_types -type_ids = [id(t) for t in types] - - def to_numpy_type(dt): - if dt == f32: + if dt == ti.f32: return np.float32 - elif dt == f64: + elif dt == ti.f64: return np.float64 - elif dt == i32: + elif dt == ti.i32: return np.int32 - elif dt == i64: + elif dt == ti.i64: return np.int64 - elif dt == i8: + elif dt == ti.i8: return np.int8 - elif dt == i16: + elif dt == ti.i16: return np.int16 - elif dt == u8: + elif dt == ti.u8: return np.uint8 - elif dt == u16: + elif dt == ti.u16: return np.uint16 - elif dt == u32: + elif dt == ti.u32: return np.uint32 - elif dt == u64: + elif dt == ti.u64: return np.uint64 else: assert False def to_pytorch_type(dt): - import torch - if dt == f32: + if dt == ti.f32: return torch.float32 - elif dt == f64: + elif dt == ti.f64: return torch.float64 - elif dt == i32: + elif dt == ti.i32: return torch.int32 - elif dt == i64: + elif dt == ti.i64: return torch.int64 - elif dt == i8: + elif dt == ti.i8: return torch.int8 - elif dt == i16: + elif dt == ti.i16: return torch.int16 - elif dt == u8: + elif dt == ti.u8: return torch.uint8 - elif dt == u16: + elif dt == ti.u16: return torch.uint16 - elif dt == u32: + elif dt == ti.u32: return torch.uint32 - elif dt == u64: + elif dt == ti.u64: return torch.uint64 else: assert False @@ -121,68 +87,66 @@ def to_taichi_type(dt): return dt if dt == np.float32: - return f32 + return ti.f32 elif dt == np.float64: - return f64 + return ti.f64 elif dt == np.int32: - return i32 + return ti.i32 elif dt == np.int64: - return i64 + return ti.i64 elif dt == np.int8: - return i8 + return ti.i8 elif dt == np.int16: - return i16 + return ti.i16 elif dt == np.uint8: - return u8 + return ti.u8 elif dt == np.uint16: - return u16 + return ti.u16 elif dt == np.uint32: - return u32 + return ti.u32 elif dt == np.uint64: - return u64 + return ti.u64 if has_pytorch(): if dt == torch.float32: - return f32 + return ti.f32 elif dt == torch.float64: - return f64 + return ti.f64 elif dt == torch.int32: - return i32 + return ti.i32 elif dt == torch.int64: - return i64 + return ti.i64 elif dt == torch.int8: - return i8 + return ti.i8 elif dt == torch.int16: - return i16 + return ti.i16 elif dt == torch.uint8: - return u8 + return ti.u8 elif dt == torch.uint16: - return u16 + return ti.u16 elif dt == torch.uint32: - return u32 + return ti.u32 elif dt == torch.uint64: - return u64 + return ti.u64 raise AssertionError("Unknown type {}".format(dt)) def cook_dtype(dtype): - from .impl import get_runtime _taichi_skip_traceback = 1 if isinstance(dtype, taichi_lang_core.DataType): return dtype elif isinstance(dtype, taichi_lang_core.Type): return taichi_lang_core.DataType(dtype) elif dtype is float: - return get_runtime().default_fp + return impl.get_runtime().default_fp elif dtype is int: - return get_runtime().default_ip + return impl.get_runtime().default_ip else: raise ValueError(f'Invalid data type {dtype}') def in_taichi_scope(): - from . import impl return impl.inside_kernel() @@ -191,8 +155,6 @@ def in_python_scope(): def taichi_scope(func): - import functools - @functools.wraps(func) def wrapped(*args, **kwargs): _taichi_skip_traceback = 1 @@ -204,8 +166,6 @@ def wrapped(*args, **kwargs): def python_scope(func): - import functools - @functools.wraps(func) def wrapped(*args, **kwargs): _taichi_skip_traceback = 1 diff --git a/tests/python/test_kernel_template_mapper.py b/tests/python/test_kernel_template_mapper.py index 4e4177ec08dd0..3b68c36e84f00 100644 --- a/tests/python/test_kernel_template_mapper.py +++ b/tests/python/test_kernel_template_mapper.py @@ -1,3 +1,4 @@ +from taichi.lang.kernel_impl import KernelTemplateMapper import taichi as ti @@ -8,7 +9,7 @@ def test_kernel_template_mapper(): ti.root.place(x, y) - mapper = ti.KernelTemplateMapper( + mapper = KernelTemplateMapper( (ti.template(), ti.template(), ti.template()), template_slot_locations=(0, 1, 2)) assert mapper.lookup((0, 0, 0))[0] == 0 @@ -17,14 +18,14 @@ def test_kernel_template_mapper(): assert mapper.lookup((0, 0, 1))[0] == 2 assert mapper.lookup((0, 1, 0))[0] == 1 - mapper = ti.KernelTemplateMapper((ti.i32, ti.i32, ti.i32), ()) + mapper = KernelTemplateMapper((ti.i32, ti.i32, ti.i32), ()) assert mapper.lookup((0, 0, 0))[0] == 0 assert mapper.lookup((0, 1, 0))[0] == 0 assert mapper.lookup((0, 0, 0))[0] == 0 assert mapper.lookup((0, 0, 1))[0] == 0 assert mapper.lookup((0, 1, 0))[0] == 0 - mapper = ti.KernelTemplateMapper((ti.i32, ti.template(), ti.i32), (1, )) + mapper = KernelTemplateMapper((ti.i32, ti.template(), ti.i32), (1, )) assert mapper.lookup((0, x, 0))[0] == 0 assert mapper.lookup((0, y, 0))[0] == 1 assert mapper.lookup((0, x, 1))[0] == 0 @@ -41,7 +42,7 @@ def test_kernel_template_mapper_numpy(): import numpy as np - mapper = ti.KernelTemplateMapper(annotations, (0, 1, 2)) + mapper = KernelTemplateMapper(annotations, (0, 1, 2)) assert mapper.lookup((0, 0, np.ones(shape=(1, 2, 3), dtype=np.float32)))[0] == 0 assert mapper.lookup((0, 0, np.ones(shape=(1, 2, 4),