Skip to content

Commit

Permalink
Update on "[lang] Record the element_type of the AnyArray"
Browse files Browse the repository at this point in the history
<!--
copilot:all
-->
### <samp>🤖 Generated by Copilot at 0176872</samp>

### Summary
📝🔬🆕

<!--
1.  📝 - This emoji can represent the change of modifying the `decl_ndarray_arg` function to pass the `element_type` to the `AnyArray` constructor, since this is a code change that involves writing or editing some code.
2.  🔬 - This emoji can represent the change of adding `element_type` argument to `AnyArray` constructor and using it in `get_type` and `grad` methods, since this is a code change that involves improving the type handling and gradient computation of arbitrary arrays, which are related to scientific or mathematical operations.
3.  🆕 - This emoji can represent the change of introducing a new argument to the `AnyArray` constructor, since this is a code change that involves adding a new feature or functionality to the existing code.
-->
Improved type handling and gradient computation of `AnyArray` arguments in Taichi kernels and functions. Added `element_type` parameter to `AnyArray` constructor and `decl_ndarray_arg` function.

> _To declare an array argument_
> _We pass the element type along_
> _This helps `AnyArray`_
> _To know what to say_
> _When it calls `get_type` or `grad`_

### Walkthrough
*  Add `element_type` argument to `AnyArray` constructor and store it as an attribute ([link](https://github.com/taichi-dev/taichi/pull/8192/files?diff=unified&w=0#diff-2e623ee0b0eec1b200fead36c0627a3c54738f6d83d79757398dc67decc01da8L18-R20), [link](https://github.com/taichi-dev/taichi/pull/8192/files?diff=unified&w=0#diff-575efc738df7b1202370c2531ec82232dc7f287b2bec4999af03ef40da4f5deeL111-R111))
*  Return stored `element_type` in `AnyArray.get_type` method instead of inferring from pointer ([link](https://github.com/taichi-dev/taichi/pull/8192/files?diff=unified&w=0#diff-2e623ee0b0eec1b200fead36c0627a3c54738f6d83d79757398dc67decc01da8L36-R37))
*  Pass stored `element_type` to `AnyArray` constructor in `AnyArray.grad` method to preserve type information in gradient array ([link](https://github.com/taichi-dev/taichi/pull/8192/files?diff=unified&w=0#diff-2e623ee0b0eec1b200fead36c0627a3c54738f6d83d79757398dc67decc01da8L43-R44))






[ghstack-poisoned]
  • Loading branch information
lin-hitonami committed Jun 16, 2023
2 parents b49adb0 + 6a10318 commit d183bb1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
7 changes: 3 additions & 4 deletions python/taichi/lang/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ class AnyArray:
layout (Layout): Memory layout.
"""

def __init__(self, ptr, element_type):
def __init__(self, ptr):
assert ptr.is_external_tensor_expr()
self.element_type = element_type
self.ptr = ptr
self.ptr.type_check(impl.get_runtime().prog.config())

Expand All @@ -34,14 +33,14 @@ def layout(self):

def get_type(self):
return NdarrayTypeMetadata(
self.element_type, None, _ti_core.get_external_tensor_needs_grad(self.ptr)
_ti_core.get_external_tensor_element_type(self.ptr), None, _ti_core.get_external_tensor_needs_grad(self.ptr)
) # AnyArray can take any shape

@property
@taichi_scope
def grad(self):
"""Returns the gradient of this array."""
return AnyArray(_ti_core.make_external_tensor_grad_expr(self.ptr), self.element_type)
return AnyArray(_ti_core.make_external_tensor_grad_expr(self.ptr))

@property
@taichi_scope
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def decl_sparse_matrix(dtype, name):

def decl_ndarray_arg(element_type, ndim, name, needs_grad, boundary):
arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(element_type, ndim, name, needs_grad)
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary), element_type)
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary))


def decl_texture_arg(num_dimensions, name):
Expand Down
6 changes: 6 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,12 @@ void export_lang(py::module &m) {
return expr.cast<ExternalTensorExpression>()->needs_grad;
});

m.def("get_external_tensor_element_type", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
auto external_tensor_expr = expr.cast<ExternalTensorExpression>();
return external_tensor_expr->dt;
});

m.def("get_external_tensor_element_shape", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
auto external_tensor_expr = expr.cast<ExternalTensorExpression>();
Expand Down

0 comments on commit d183bb1

Please sign in to comment.