Skip to content

Commit

Permalink
[bug] Fix argpack with struct and vector (#8403)
Browse files Browse the repository at this point in the history
Issue: #8385 

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 33db838</samp>

Added support for struct arguments in kernels using `ArgPack` types and
fixed a bug with nested struct offsets. Added tests for `ArgPack` with
struct and vector types.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 33db838</samp>

* Import `Struct` class from `taichi.lang.struct` to support struct
arguments
([link](https://github.com/taichi-dev/taichi/pull/8403/files?diff=unified&w=0#diff-ca7ad70b7884646820303ff5b04a8fd2aa7396cf4b1c7a5947a2bdf651d37a48L12-R12))
* Update `isinstance` check to include `Struct` as a possible data type
([link](https://github.com/taichi-dev/taichi/pull/8403/files?diff=unified&w=0#diff-ca7ad70b7884646820303ff5b04a8fd2aa7396cf4b1c7a5947a2bdf651d37a48L343-R343))
* Fix offset bug for nested structs by moving iterator back one position
([link](https://github.com/taichi-dev/taichi/pull/8403/files?diff=unified&w=0#diff-aea65318c1059505b73cd497b15f153d5439631cfad19f0f55061f5efa8f4272L209-R209))
* Add three tests for passing struct and vector types as `ArgPack`
elements to kernels in `test_argpack.py`
([link](https://github.com/taichi-dev/taichi/pull/8403/files?diff=unified&w=0#diff-b84304006655e02b7d4e51469ac99f189f031d811e50a32477fb6b6c19ab4fd6R27-R53))
  • Loading branch information
lin-hitonami authored Nov 3, 2023
1 parent fc4c8f1 commit ff251e1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/taichi/lang/argpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
TaichiSyntaxError,
)
from taichi.lang.matrix import MatrixType
from taichi.lang.struct import StructType
from taichi.lang.struct import StructType, Struct
from taichi.lang.util import cook_dtype
from taichi.types import (
ndarray_type,
Expand Down Expand Up @@ -340,7 +340,7 @@ def __call__(self, *args, **kwargs):

# If dtype is CompoundType and data is a scalar, it cannot be
# casted in the self.cast call later. We need an initialization here.
if isinstance(dtype, CompoundType) and not isinstance(data, (dict, ArgPack)):
if isinstance(dtype, CompoundType) and not isinstance(data, (dict, ArgPack, Struct)):
data = dtype(data)

d[name] = data
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ size_t ArgPackType::get_element_offset(const std::vector<int> &indices) const {
type_now = tensor_type->get_element_type();
} else if (auto struct_type = type_now->cast<StructType>()) {
std::vector<int> indices_for_struct;
indices_for_struct.assign(++it, indices.end());
indices_for_struct.assign(it, indices.end());
offset += struct_type->get_element_offset(indices_for_struct);
return offset;
} else {
Expand Down
27 changes: 27 additions & 0 deletions tests/python/test_argpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,33 @@ def foo(pack: pack_type) -> ti.f32:
assert foo(pack2) == test_utils.approx(2 + 2.1, rel=1e-3)


@test_utils.test()
def test_argpack_with_struct():
struct_type = ti.types.struct(a=ti.i32, c=ti.f32)
pack_type = ti.types.argpack(d=ti.f32, element=struct_type)

@ti.kernel
def foo(pack: pack_type) -> ti.f32:
tmp = pack.element.a + pack.element.c
return tmp + pack.d

pack = pack_type(d=2.1, element=struct_type(a=2, c=2.1))
assert foo(pack) == test_utils.approx(2 + 2.1 + 2.1, rel=1e-3)


@test_utils.test()
def test_argpack_with_vector():
pack_type = ti.types.argpack(a=ti.i32, b=ti.types.vector(3, ti.f32), c=ti.f32)
pack = pack_type(a=1, b=[1.0, 2.0, 3.0], c=2.1)

@ti.kernel
def foo(pack: pack_type) -> ti.f32:
tmp = pack.a * pack.c
return tmp + pack.b[1]

assert foo(pack) == test_utils.approx(1 * 2.1 + 2.0, rel=1e-3)


@test_utils.test()
def test_argpack_multiple():
arr = ti.ndarray(dtype=ti.math.vec3, shape=(4, 4))
Expand Down

0 comments on commit ff251e1

Please sign in to comment.