Skip to content

Commit

Permalink
[lang] Record the element_type of the AnyArray
Browse files Browse the repository at this point in the history
ghstack-source-id: 938770dea60ac879c40fe25d4f67a50b44b03234
Pull Request resolved: #8192
  • Loading branch information
lin-hitonami authored and Taichi Gardener committed Jun 19, 2023
1 parent 619b957 commit 3653f49
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/taichi/lang/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def layout(self):

def get_type(self):
return NdarrayTypeMetadata(
self.ptr.get_ret_type().ptr_removed(), None, _ti_core.get_external_tensor_needs_grad(self.ptr)
_ti_core.get_external_tensor_element_type(self.ptr), None, _ti_core.get_external_tensor_needs_grad(self.ptr)
) # AnyArray can take any shape

@property
Expand Down
6 changes: 6 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,12 @@ void export_lang(py::module &m) {
return expr.cast<ExternalTensorExpression>()->needs_grad;
});

m.def("get_external_tensor_element_type", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
auto external_tensor_expr = expr.cast<ExternalTensorExpression>();
return external_tensor_expr->dt;
});

m.def("get_external_tensor_element_shape", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
auto external_tensor_expr = expr.cast<ExternalTensorExpression>();
Expand Down

0 comments on commit 3653f49

Please sign in to comment.