diff --git a/python/taichi/lang/any_array.py b/python/taichi/lang/any_array.py index 86c6f8b44869d..fa552857a932c 100644 --- a/python/taichi/lang/any_array.py +++ b/python/taichi/lang/any_array.py @@ -33,7 +33,7 @@ def layout(self): def get_type(self): return NdarrayTypeMetadata( - self.ptr.get_ret_type().ptr_removed(), None, _ti_core.get_external_tensor_needs_grad(self.ptr) + _ti_core.get_external_tensor_element_type(self.ptr), None, _ti_core.get_external_tensor_needs_grad(self.ptr) ) # AnyArray can take any shape @property diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 3ea61396421b8..f13307e5b506d 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -1012,6 +1012,12 @@ void export_lang(py::module &m) { return expr.cast()->needs_grad; }); + m.def("get_external_tensor_element_type", [](const Expr &expr) { + TI_ASSERT(expr.is()); + auto external_tensor_expr = expr.cast(); + return external_tensor_expr->dt; + }); + m.def("get_external_tensor_element_shape", [](const Expr &expr) { TI_ASSERT(expr.is()); auto external_tensor_expr = expr.cast();