Skip to content

Commit

Permalink
[Lang] support v.x/v.y/v.z/v.w for easy Vector subscript (#1133)
Browse files Browse the repository at this point in the history
* [Lang] support v.x, v.y, v.z for easy Vector subscript

* [skip ci] Proxy.xyzw, add test

* fix test_cli

* [skip ci] enforce code format

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
archibate and taichi-gardener authored Jun 4, 2020
1 parent 022d3f4 commit 2ad7d22
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
88 changes: 88 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,62 @@ def subscript(self, *indices):
j = 0
return self(i, j)

@property
def x(self):
if impl.inside_kernel():
return self.subscript(0)
else:
return self[0]

@property
def y(self):
if impl.inside_kernel():
return self.subscript(1)
else:
return self[1]

@property
def z(self):
if impl.inside_kernel():
return self.subscript(2)
else:
return self[2]

@property
def w(self):
if impl.inside_kernel():
return self.subscript(3)
else:
return self[3]

@x.setter
def x(self, value):
if impl.inside_kernel():
self.subscript(0).assign(value)
else:
self[0] = value

@y.setter
def y(self, value):
if impl.inside_kernel():
self.subscript(1).assign(value)
else:
self[1] = value

@z.setter
def z(self, value):
if impl.inside_kernel():
self.subscript(2).assign(value)
else:
self[2] = value

@w.setter
def w(self, value):
if impl.inside_kernel():
self.subscript(3).assign(value)
else:
self[3] = value

class Proxy:
def __init__(self, mat, index):
self.mat = mat
Expand All @@ -242,6 +298,38 @@ def __setitem__(self, key, value):
key = [key]
self.mat(*key)[self.index] = value

@property
def x(self):
return self[0]

@property
def y(self):
return self[1]

@property
def z(self):
return self[2]

@property
def w(self):
return self[3]

@x.setter
def x(self, value):
self[0] = value

@y.setter
def y(self, value):
self[1] = value

@z.setter
def z(self, value):
self[2] = value

@w.setter
def w(self, value):
self[3] = value

# host access
def __getitem__(self, index):
return Matrix.Proxy(self, index)
Expand Down
20 changes: 20 additions & 0 deletions tests/python/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,23 @@ def func():
func()
assert np.allclose(m.to_numpy()[1, 0, 0, :, :], np.array([[2, 4], [6, 8]]))
assert np.allclose(v.to_numpy()[1, 0, 0, :], np.array([10, 12]))


@ti.host_arch_only
def test_vector_xyzw_accessor():
u = ti.Vector(2, dt=ti.i32, shape=(2, 2, 1))
v = ti.Vector(4, dt=ti.i32, shape=(2, 2, 1))

u[1, 0, 0].y = 3
v[1, 0, 0].w = 4

@ti.kernel
def func():
u[1, 0, 0].x = 8 * u[1, 0, 0].y
v[1, 0, 0].z = 1 - v[1, 0, 0].w
v[1, 0, 0].x = 6

func()
assert u[1, 0, 0].x == 24
assert u[1, 0, 0].y == 3
assert np.allclose(v.to_numpy()[1, 0, 0, :], np.array([6, 0, -3, 4]))

0 comments on commit 2ad7d22

Please sign in to comment.