Skip to content

Commit

Permalink
Update on "[lang] Record the element_type of the AnyArray"
Browse files Browse the repository at this point in the history
<!--
copilot:all
-->
### <samp>🤖 Generated by Copilot at 0176872</samp>

### Summary
📝🔬🆕

<!--
1.  📝 - This emoji can represent the change of modifying the `decl_ndarray_arg` function to pass the `element_type` to the `AnyArray` constructor, since this is a code change that involves writing or editing some code.
2.  🔬 - This emoji can represent the change of adding `element_type` argument to `AnyArray` constructor and using it in `get_type` and `grad` methods, since this is a code change that involves improving the type handling and gradient computation of arbitrary arrays, which are related to scientific or mathematical operations.
3.  🆕 - This emoji can represent the change of introducing a new argument to the `AnyArray` constructor, since this is a code change that involves adding a new feature or functionality to the existing code.
-->
Improved type handling and gradient computation of `AnyArray` arguments in Taichi kernels and functions. Added `element_type` parameter to `AnyArray` constructor and `decl_ndarray_arg` function.

> _To declare an array argument_
> _We pass the element type along_
> _This helps `AnyArray`_
> _To know what to say_
> _When it calls `get_type` or `grad`_

### Walkthrough
*  Add `element_type` argument to `AnyArray` constructor and store it as an attribute ([link](https://github.com/taichi-dev/taichi/pull/8192/files?diff=unified&w=0#diff-2e623ee0b0eec1b200fead36c0627a3c54738f6d83d79757398dc67decc01da8L18-R20), [link](https://github.com/taichi-dev/taichi/pull/8192/files?diff=unified&w=0#diff-575efc738df7b1202370c2531ec82232dc7f287b2bec4999af03ef40da4f5deeL111-R111))
*  Return stored `element_type` in `AnyArray.get_type` method instead of inferring from pointer ([link](https://github.com/taichi-dev/taichi/pull/8192/files?diff=unified&w=0#diff-2e623ee0b0eec1b200fead36c0627a3c54738f6d83d79757398dc67decc01da8L36-R37))
*  Pass stored `element_type` to `AnyArray` constructor in `AnyArray.grad` method to preserve type information in gradient array ([link](https://github.com/taichi-dev/taichi/pull/8192/files?diff=unified&w=0#diff-2e623ee0b0eec1b200fead36c0627a3c54738f6d83d79757398dc67decc01da8L43-R44))






[ghstack-poisoned]
  • Loading branch information
lin-hitonami committed Jun 16, 2023
2 parents 0176872 + ad5b75e commit b49adb0
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1677,20 +1677,22 @@ std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
expanded_exprs.push_back(expr);
} else {
// Expand TensorType expr
// clang-format off
/*
Before:
TensorType<4 x i32> index = Expr;
After:
TensorType<4 x i32>* id_expr = FrontendAllocaStmt(TensorType<4 x i32>)
After:
TensorType<4 x i32>* id_expr = FrontendAllocaStmt(TensorType<4 x i32>)
i32 ind0 = IndexExpression(id_expr, 0)
i32 ind1 = IndexExpression(id_expr, 1)
i32 ind2 = IndexExpression(id_expr, 2)
i32 ind3 = IndexExpression(id_expr, 3)
i32 ind1 = IndexExpression(id_expr, 1)
i32 ind2 = IndexExpression(id_expr, 2)
i32 ind3 = IndexExpression(id_expr, 3)
return {ind0, ind1, ind2, ind3}
return {ind0, ind1, ind2, ind3}
*/
*/
// clang-format on
auto tensor_type = expr->ret_type.ptr_removed()->cast<TensorType>();

Expr id_expr;
Expand Down

0 comments on commit b49adb0

Please sign in to comment.