diff --git a/python/taichi/lang/any_array.py b/python/taichi/lang/any_array.py index ca585296a7c09..fa552857a932c 100644 --- a/python/taichi/lang/any_array.py +++ b/python/taichi/lang/any_array.py @@ -15,9 +15,8 @@ class AnyArray: layout (Layout): Memory layout. """ - def __init__(self, ptr, element_type): + def __init__(self, ptr): assert ptr.is_external_tensor_expr() - self.element_type = element_type self.ptr = ptr self.ptr.type_check(impl.get_runtime().prog.config()) @@ -34,14 +33,14 @@ def layout(self): def get_type(self): return NdarrayTypeMetadata( - self.element_type, None, _ti_core.get_external_tensor_needs_grad(self.ptr) + _ti_core.get_external_tensor_element_type(self.ptr), None, _ti_core.get_external_tensor_needs_grad(self.ptr) ) # AnyArray can take any shape @property @taichi_scope def grad(self): """Returns the gradient of this array.""" - return AnyArray(_ti_core.make_external_tensor_grad_expr(self.ptr), self.element_type) + return AnyArray(_ti_core.make_external_tensor_grad_expr(self.ptr)) @property @taichi_scope diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py index 04ecbd83ee9d4..94f237622a725 100644 --- a/python/taichi/lang/kernel_arguments.py +++ b/python/taichi/lang/kernel_arguments.py @@ -108,7 +108,7 @@ def decl_sparse_matrix(dtype, name): def decl_ndarray_arg(element_type, ndim, name, needs_grad, boundary): arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(element_type, ndim, name, needs_grad) - return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary), element_type) + return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary)) def decl_texture_arg(num_dimensions, name): diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 3ea61396421b8..f13307e5b506d 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -1012,6 +1012,12 @@ void export_lang(py::module &m) { return expr.cast()->needs_grad; }); + m.def("get_external_tensor_element_type", [](const Expr &expr) { + TI_ASSERT(expr.is()); + auto external_tensor_expr = expr.cast(); + return external_tensor_expr->dt; + }); + m.def("get_external_tensor_element_shape", [](const Expr &expr) { TI_ASSERT(expr.is()); auto external_tensor_expr = expr.cast();