Skip to content

Commit

Permalink
Update on "[Lang] Support ndarray argument for real function"
Browse files Browse the repository at this point in the history
<!--
copilot:all
-->
### <samp>🤖 Generated by Copilot at a936066</samp>

### Summary
🧮🧪🎁

<!--
1.  🧮 for adding the binding for the `insert_ndarray_param` method.
2.  🧪 for modifying and adding the tests for the functionality.
3.  🎁 for enabling the feature of passing `ti.ndarray` as arguments.
-->
This pull request adds support for passing `ti.ndarray` as arguments to `ti.func` and `ti.experimental.real_func`, which allows users to write more flexible and generic functions that can operate on different types of arrays. It also updates and adds some tests to verify the correctness and error handling of this feature.

> _`AnyArray` is the key to unleash the power_
> _Pass it to the `func` and `real_func` in the hour_
> _Invoke the `insert_ndarray_param` to bind the data_
> _Break the limits of the backends and the taichi_

### Walkthrough
*  Import `AnyArray` class to support unified indexing and arithmetic operations for `ti.Matrix` and `ti.ndarray` ([link](https://github.com/taichi-dev/taichi/pull/8188/files?diff=unified&w=0#diff-a157043b38542c8145447ff342fda65fe4d54fb777fe514daa70007e83e20dc1R15))
*  Modify `Func.__call__` to return `arg_features` of arguments, which are tuples of element type, dimension, needs_grad, and boundary attributes ([link](https://github.com/taichi-dev/taichi/pull/8188/files?diff=unified&w=0#diff-a157043b38542c8145447ff342fda65fe4d54fb777fe514daa70007e83e20dc1L229-R235))
*  Add branch to `Func.func_call_rvalue` to handle `AnyArray` arguments with `ndarray_type.NdarrayType` annotations, and append pointer to `AnyArray` data to `non_template_args` ([link](https://github.com/taichi-dev/taichi/pull/8188/files?diff=unified&w=0#diff-a157043b38542c8145447ff342fda65fe4d54fb777fe514daa70007e83e20dc1R261-R266))
*  Modify `Func.do_compile` to take `arg_features` as input and pass them to `_get_tree_and_ctx`, which generates AST and context for function ([link](https://github.com/taichi-dev/taichi/pull/8188/files?diff=unified&w=0#diff-a157043b38542c8145447ff342fda65fe4d54fb777fe514daa70007e83e20dc1L277-R287))
*  Add branch to `TaichiCallableTemplateMapper.extract_arg` to return `arg_features` for `AnyArray` arguments ([link](https://github.com/taichi-dev/taichi/pull/8188/files?diff=unified&w=0#diff-a157043b38542c8145447ff342fda65fe4d54fb777fe514daa70007e83e20dc1R415-R418))
*  Add binding for `Function.insert_ndarray_param` to Python interface, which inserts pointer to `AnyArray` data to function parameters and sets flag ([link](https://github.com/taichi-dev/taichi/pull/8188/files?diff=unified&w=0#diff-af631a0c71978fe591e17005f01f7c06bc30ae36c65df306bbb3b08ade770167R720))
*  Modify and add tests for passing `ti.ndarray` to `ti.func` and `ti.experimental.real_func` in `test_ndarray.py` ([link](https://github.com/taichi-dev/taichi/pull/8188/files?diff=unified&w=0#diff-ca3c8d1edb25b6a7f4affbb79b2e3e74f73b3757e5d465258ce42ea9eb09fbc0L1008-R1051))






[ghstack-poisoned]
  • Loading branch information
lin-hitonami committed Jun 16, 2023
2 parents a9344ef + fe51245 commit bcd32b1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
7 changes: 3 additions & 4 deletions python/taichi/lang/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ class AnyArray:
layout (Layout): Memory layout.
"""

def __init__(self, ptr, element_type):
def __init__(self, ptr):
assert ptr.is_external_tensor_expr()
self.element_type = element_type
self.ptr = ptr
self.ptr.type_check(impl.get_runtime().prog.config())

Expand All @@ -34,14 +33,14 @@ def layout(self):

def get_type(self):
return NdarrayTypeMetadata(
self.element_type, None, _ti_core.get_external_tensor_needs_grad(self.ptr)
_ti_core.get_external_tensor_element_type(self.ptr), None, _ti_core.get_external_tensor_needs_grad(self.ptr)
) # AnyArray can take any shape

@property
@taichi_scope
def grad(self):
"""Returns the gradient of this array."""
return AnyArray(_ti_core.make_external_tensor_grad_expr(self.ptr), self.element_type)
return AnyArray(_ti_core.make_external_tensor_grad_expr(self.ptr))

@property
@taichi_scope
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def decl_sparse_matrix(dtype, name):

def decl_ndarray_arg(element_type, ndim, name, needs_grad, boundary):
arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(element_type, ndim, name, needs_grad)
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary), element_type)
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary))


def decl_texture_arg(num_dimensions, name):
Expand Down
6 changes: 6 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,12 @@ void export_lang(py::module &m) {
return expr.cast<ExternalTensorExpression>()->needs_grad;
});

m.def("get_external_tensor_element_type", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
auto external_tensor_expr = expr.cast<ExternalTensorExpression>();
return external_tensor_expr->dt;
});

m.def("get_external_tensor_element_shape", [](const Expr &expr) {
TI_ASSERT(expr.is<ExternalTensorExpression>());
auto external_tensor_expr = expr.cast<ExternalTensorExpression>();
Expand Down

0 comments on commit bcd32b1

Please sign in to comment.