From 88be974a08dc0a1bea0e5d24b3e814c1494668d7 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Sun, 10 May 2020 23:14:06 +0800 Subject: [PATCH 01/18] [skip ci] add test --- python/taichi/lang/expr.py | 60 ++++++++--------- python/taichi/lang/matrix.py | 106 ++++++++++-------------------- python/taichi/lang/ops.py | 71 +++++++++++++++----- tests/python/test_element_wise.py | 29 +++++++- 4 files changed, 145 insertions(+), 121 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index ee9129d16a40c..6da9e879e9c98 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -46,51 +46,49 @@ def stack_info(): # remove the confusing last line return '\n'.join(raw.split('\n')[:-3]) + '\n' + def __neg__(self): + import taichi as ti + return ti.neg(self) + def __add__(self, other): - other = Expr(other) - return Expr(taichi_lang_core.expr_add(self.ptr, other.ptr), - tb=self.stack_info()) + import taichi as ti + return ti.add(self, other) __radd__ = __add__ - def __neg__(self): - return Expr(taichi_lang_core.expr_neg(self.ptr), tb=self.stack_info()) - def __sub__(self, other): - other = Expr(other) - return Expr(taichi_lang_core.expr_sub(self.ptr, other.ptr), - tb=self.stack_info()) + import taichi as ti + return ti.sub(self, other) def __rsub__(self, other): - other = Expr(other) - return Expr(taichi_lang_core.expr_sub(other.ptr, self.ptr)) + import taichi as ti + return ti.sub(other, self) def __mul__(self, other): - if is_taichi_class(other) and hasattr(other, '__rmul__'): - return other.__rmul__(self) - else: - other = Expr(other) - return Expr(taichi_lang_core.expr_mul(self.ptr, other.ptr)) + import taichi as ti + return ti.mul(self, other) __rmul__ = __mul__ def __truediv__(self, other): - return Expr(taichi_lang_core.expr_truediv(self.ptr, Expr(other).ptr)) + import taichi as ti + return ti.truediv(self, other) def __rtruediv__(self, other): - return Expr(taichi_lang_core.expr_truediv(Expr(other).ptr, 
self.ptr)) + import taichi as ti + return ti.truediv(other, self) def __floordiv__(self, other): - return Expr(taichi_lang_core.expr_floordiv(self.ptr, Expr(other).ptr)) + import taichi as ti + return ti.floordiv(self, other) def __rfloordiv__(self, other): - return Expr(taichi_lang_core.expr_floordiv(Expr(other).ptr, self.ptr)) + import taichi as ti + return ti.floordiv(other, self) def __mod__(self, other): - other = Expr(other) - quotient = Expr(taichi_lang_core.expr_floordiv(self.ptr, other.ptr)) - multiply = Expr(taichi_lang_core.expr_mul(other.ptr, quotient.ptr)) - return Expr(taichi_lang_core.expr_sub(self.ptr, multiply.ptr)) + import taichi as ti + return ti.mod(self, other) def __iadd__(self, other): self.atomic_add(other) @@ -99,17 +97,16 @@ def __isub__(self, other): self.atomic_sub(other) def __imul__(self, other): - self.assign(Expr(taichi_lang_core.expr_mul(self.ptr, other.ptr))) + import taichi as ti + self.assign(ti.mul(self, other)) def __itruediv__(self, other): - self.assign( - Expr(taichi_lang_core.expr_truediv(self.ptr, - Expr(other).ptr))) + import taichi as ti + self.assign(ti.truediv(self, other)) def __ifloordiv__(self, other): - self.assign( - Expr(taichi_lang_core.expr_floordiv(self.ptr, - Expr(other).ptr))) + import taichi as ti + self.assign(ti.floordiv(self, other)) def __iand__(self, other): self.atomic_and(other) @@ -120,6 +117,7 @@ def __ior__(self, other): def __ixor__(self, other): self.atomic_xor(other) + # TODO: ti.cmp_le def __le__(self, other): other = Expr(other) return Expr(taichi_lang_core.expr_cmp_le(self.ptr, other.ptr)) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 7f7393ee9d596..e097190d39574 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -186,94 +186,55 @@ def __rpow__(self, other): ret(i, j).assign(other(i, j)**self(i, j)) return ret - @broadcast_if_scalar - def __div__(self, other): - assert self.n == other.n and self.m == other.m - ret = 
Matrix(self.n, self.m) - for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(self(i, j) / other(i, j)) - return ret - - @broadcast_if_scalar - def __rtruediv__(self, other): - assert self.n == other.n and self.m == other.m - ret = Matrix(self.n, self.m) - for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(other(i, j) / self(i, j)) - return ret - def broadcast(self, scalar): ret = Matrix(self.n, self.m, empty=True) for i in range(self.n * self.m): ret.entries[i] = scalar return ret - @broadcast_if_scalar - def __truediv__(self, other): - assert self.n == other.n and self.m == other.m - ret = Matrix(self.n, self.m) - for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(self(i, j) / other(i, j)) - return ret + def __neg__(self): + import taichi as ti + return ti.neg(self) - @broadcast_if_scalar - def __floordiv__(self, other): - assert self.n == other.n and self.m == other.m - ret = Matrix(self.n, self.m) - for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(self(i, j) // other(i, j)) - return ret + def __add__(self, other): + import taichi as ti + return ti.add(self, other) + + __radd__ = __add__ + + def __sub__(self, other): + import taichi as ti + return ti.sub(self, other) + + def __rsub__(self, other): + import taichi as ti + return ti.sub(other, self) - @broadcast_if_scalar def __mul__(self, other): - assert self.n == other.n and self.m == other.m - ret = Matrix(self.n, self.m) - for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(self(i, j) * other(i, j)) - return ret + import taichi as ti + return ti.mul(self, other) __rmul__ = __mul__ - @broadcast_if_scalar - def __add__(self, other): - assert self.n == other.n and self.m == other.m - ret = Matrix(self.n, self.m) - for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(self(i, j) + other(i, j)) - return ret + def __truediv__(self, other): + import taichi as ti + return ti.truediv(self, other) - __radd__ = 
__add__ + def __rtruediv__(self, other): + import taichi as ti + return ti.truediv(other, self) - @broadcast_if_scalar - def __sub__(self, other): - assert self.n == other.n and self.m == other.m - ret = Matrix(self.n, self.m) - for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(self(i, j) - other(i, j)) - return ret + def __floordiv__(self, other): + import taichi as ti + return ti.floordiv(self, other) - def __neg__(self): - ret = Matrix(self.n, self.m) - for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(-self(i, j)) - return ret + def __rfloordiv__(self, other): + import taichi as ti + return ti.floordiv(other, self) - @broadcast_if_scalar - def __rsub__(self, other): - assert self.n == other.n and self.m == other.m - ret = Matrix(self.n, self.m) - for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(other(i, j) - self(i, j)) - return ret + def __mod__(self, other): + import taichi as ti + return ti.mod(self, other) def linearize_entry_id(self, *args): assert 1 <= len(args) <= 2 @@ -524,6 +485,7 @@ def diag(dim, val): def loop_range(self): return self.entries[0] + # TODO @broadcast_if_scalar def augassign(self, other, op): if not isinstance(other, Matrix): diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index 9bf64010532f1..ef5929ece6312 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -20,10 +20,12 @@ def stack_info(): def unary(foo): import taichi as ti + imp_foo = lambda x: foo(Expr(x)) + @functools.wraps(foo) def wrapped(a): if ti.is_taichi_class(a): - return a.element_wise_unary(foo) + return a.element_wise_unary(imp_foo) else: return foo(Expr(a)) @@ -37,15 +39,17 @@ def wrapped(a): def binary(foo): import taichi as ti + imp_foo = lambda x, y: foo(Expr(x), Expr(y)) + rev_foo = lambda x, y: foo(Expr(y), Expr(x)) + @functools.wraps(foo) def wrapped(a, b): if ti.is_taichi_class(a): - return a.element_wise_binary(foo, b) + return a.element_wise_binary(imp_foo, b) 
elif ti.is_taichi_class(b): - rev_foo = lambda x, y: foo(y, x) return b.element_wise_binary(rev_foo, a) else: - return foo(Expr(a), Expr(b)) + return imp_foo(a, b) binary_ops.append(wrapped) return wrapped @@ -84,6 +88,11 @@ def sqr(obj): return obj * obj +@unary +def neg(expr): + return Expr(taichi_lang_core.expr_neg(expr.ptr), tb=stack_info()) + + @unary def sin(expr): return Expr(taichi_lang_core.expr_sin(expr.ptr), tb=stack_info()) @@ -121,32 +130,32 @@ def ceil(expr): @unary def inv(expr): - return Expr(taichi_lang_core.expr_inv(expr.ptr)) + return Expr(taichi_lang_core.expr_inv(expr.ptr), tb=stack_info()) @unary def tan(expr): - return Expr(taichi_lang_core.expr_tan(expr.ptr)) + return Expr(taichi_lang_core.expr_tan(expr.ptr), tb=stack_info()) @unary def tanh(expr): - return Expr(taichi_lang_core.expr_tanh(expr.ptr)) + return Expr(taichi_lang_core.expr_tanh(expr.ptr), tb=stack_info()) @unary def exp(expr): - return Expr(taichi_lang_core.expr_exp(expr.ptr)) + return Expr(taichi_lang_core.expr_exp(expr.ptr), tb=stack_info()) @unary def log(expr): - return Expr(taichi_lang_core.expr_log(expr.ptr)) + return Expr(taichi_lang_core.expr_log(expr.ptr), tb=stack_info()) @unary def abs(expr): - return Expr(taichi_lang_core.expr_abs(expr.ptr)) + return Expr(taichi_lang_core.expr_abs(expr.ptr), tb=stack_info()) def random(dt=None): @@ -156,29 +165,61 @@ def random(dt=None): return Expr(taichi_lang_core.make_rand_expr(dt)) +@binary +def add(a, b): + return Expr(taichi_lang_core.expr_add(a.ptr, b.ptr), tb=stack_info()) + + +@binary +def sub(a, b): + return Expr(taichi_lang_core.expr_sub(a.ptr, b.ptr), tb=stack_info()) + + +@binary +def mul(a, b): + return Expr(taichi_lang_core.expr_mul(a.ptr, b.ptr), tb=stack_info()) + + +@binary +def mod(a, b): + quotient = Expr(taichi_lang_core.expr_floordiv(a.ptr, b.ptr)) + multiply = Expr(taichi_lang_core.expr_mul(b.ptr, quotient.ptr)) + return Expr(taichi_lang_core.expr_sub(a.ptr, multiply.ptr)) + + +@binary +def floordiv(a, b): + 
return Expr(taichi_lang_core.expr_floordiv(a.ptr, b.ptr), tb=stack_info()) + + +@binary +def truediv(a, b): + return Expr(taichi_lang_core.expr_truediv(a.ptr, b.ptr), tb=stack_info()) + + @binary def max(a, b): - return Expr(taichi_lang_core.expr_max(a.ptr, b.ptr)) + return Expr(taichi_lang_core.expr_max(a.ptr, b.ptr), tb=stack_info()) @binary def min(a, b): - return Expr(taichi_lang_core.expr_min(a.ptr, b.ptr)) + return Expr(taichi_lang_core.expr_min(a.ptr, b.ptr), tb=stack_info()) @binary def atan2(a, b): - return Expr(taichi_lang_core.expr_atan2(a.ptr, b.ptr)) + return Expr(taichi_lang_core.expr_atan2(a.ptr, b.ptr), tb=stack_info()) @binary def raw_div(a, b): - return Expr(taichi_lang_core.expr_div(a.ptr, b.ptr)) + return Expr(taichi_lang_core.expr_div(a.ptr, b.ptr), tb=stack_info()) @binary def raw_mod(a, b): - return Expr(taichi_lang_core.expr_mod(a.ptr, b.ptr)) + return Expr(taichi_lang_core.expr_mod(a.ptr, b.ptr), tb=stack_info()) def ti_max(*args): diff --git a/tests/python/test_element_wise.py b/tests/python/test_element_wise.py index 762bbf36322e0..cc7c15b37adbe 100644 --- a/tests/python/test_element_wise.py +++ b/tests/python/test_element_wise.py @@ -1,6 +1,7 @@ import taichi as ti from taichi import approx from random import random, randint, seed +import operator as ops import math @@ -12,7 +13,7 @@ def rand(dtype): if ti.core.is_integral(dtype): return randint(1, 5) else: - return float(randint(1, 5)) / 5 + return float(randint(1, 5)) / 5 + 0.01 # prevent floordiv step @ti.host_arch_only @@ -82,7 +83,28 @@ def func(): assert c[None][i, j] == approx(expected) -def test_matrix_element_wise_binary(): +def stest_matrix_element_wise_unary_infix(): + seed(5156) + for n, m in [(5, 4), (3, 1)]: + _test_matrix_element_wise_unary(ti.f32, n, m, ops.neg, ops.neg) + + +def test_matrix_element_wise_binary_infix(): + seed(4399) + for n, m in [(5, 4), (3, 1)]: + _test_matrix_element_wise_binary(ti.f32, n, m, ops.add, ops.add) + _test_matrix_element_wise_binary(ti.f32, 
n, m, ops.sub, ops.sub) + _test_matrix_element_wise_binary(ti.f32, n, m, ops.mul, ops.mul) + _test_matrix_element_wise_binary(ti.i32, n, m, ops.mod, ops.mod) + _test_matrix_element_wise_binary(ti.f32, n, m, ops.truediv, ops.truediv) + _test_matrix_element_wise_binary(ti.f32, n, m, ops.floordiv, ops.floordiv) + _test_matrix_element_wise_binary(ti.i32, n, m, ops.add, ops.add) + _test_matrix_element_wise_binary(ti.i32, n, m, ops.sub, ops.sub) + _test_matrix_element_wise_binary(ti.i32, n, m, ops.mul, ops.mul) + _test_matrix_element_wise_binary(ti.i32, n, m, ops.mod, ops.mod) + + +def stest_matrix_element_wise_binary(): seed(666) for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_binary(ti.f32, n, m, ti.atan2, math.atan2) @@ -93,9 +115,10 @@ def test_matrix_element_wise_binary(): _test_matrix_element_wise_binary(ti.f32, n, m, pow, pow) _test_matrix_element_wise_binary(ti.i32, n, m, pow, pow) _test_matrix_element_wise_binary(ti.i32, n, m, ti.raw_mod, _c_mod) + # TODO: add ti.raw_div -def test_matrix_element_wise_unary(): +def stest_matrix_element_wise_unary(): seed(233) for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_unary(ti.f32, n, m, ti.sin, math.sin) From e15eba82d3e2213bcde56e2491e11ae4e3e27386 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Sun, 10 May 2020 23:29:27 +0800 Subject: [PATCH 02/18] fix linalg (hopefully don't harm performance) --- python/taichi/lang/matrix.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index e097190d39574..a694e57e5601c 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -137,8 +137,11 @@ def assign(self, other): if not isinstance(other, Matrix): other = Matrix(other) assert other.n == self.n and other.m == self.m + temp = Matrix(self.n, self.m) for i in range(self.n * self.m): - self.entries[i].assign(other.entries[i]) + temp.entries[i].assign(other.entries[i]) + for i in range(self.n * 
self.m): + self.entries[i].assign(temp.entries[i]) def element_wise_binary(self, foo, other): ret = Matrix(self.n, self.m) From 02d6e8d92bb6ec2f296a89a502c51e68973cd944 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Sun, 10 May 2020 23:47:14 +0800 Subject: [PATCH 03/18] better fix --- python/taichi/lang/expr.py | 8 ++++---- python/taichi/lang/matrix.py | 14 ++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 6da9e879e9c98..6549b522afe53 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -50,6 +50,10 @@ def __neg__(self): import taichi as ti return ti.neg(self) + def __abs__(self): + import taichi as ti + return ti.abs(self) + def __add__(self, other): import taichi as ti return ti.add(self, other) @@ -341,10 +345,6 @@ def __pow__(self, power, modulo=None): else: return ret - def __abs__(self): - import taichi as ti - return ti.abs(self) - def __ti_int__(self): import taichi as ti return ti.cast(self, ti.get_runtime().default_ip) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index a694e57e5601c..8e1662e9b01a9 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -137,11 +137,8 @@ def assign(self, other): if not isinstance(other, Matrix): other = Matrix(other) assert other.n == self.n and other.m == self.m - temp = Matrix(self.n, self.m) for i in range(self.n * self.m): - temp.entries[i].assign(other.entries[i]) - for i in range(self.n * self.m): - self.entries[i].assign(temp.entries[i]) + self.entries[i].assign(other.entries[i]) def element_wise_binary(self, foo, other): ret = Matrix(self.n, self.m) @@ -199,6 +196,10 @@ def __neg__(self): import taichi as ti return ti.neg(self) + def __abs__(self): + import taichi as ti + return ti.abs(self) + def __add__(self, other): import taichi as ti return ti.add(self, other) @@ -343,7 +344,7 @@ def abs(self): def trace(self): assert self.n 
== self.m - sum = self(0, 0) + sum = expr.Expr(self(0, 0)) for i in range(1, self.n): sum = sum + self(i, i) return sum @@ -354,8 +355,9 @@ def inverse(self): return Matrix([1 / self(0, 0)]) elif self.n == 2: inv_det = impl.expr_init(1.0 / self.determinant(self)) + # Dis: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626344323 return inv_det * Matrix([[self(1, 1), -self(0, 1)], - [-self(1, 0), self(0, 0)]]) + [-self(1, 0), self(0, 0)]]).variable() elif self.n == 3: n = 3 import taichi as ti From e3f8b477d6bccb6f5a03d71f21e08a6d47fe7d86 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Sun, 10 May 2020 23:53:56 +0800 Subject: [PATCH 04/18] [skip ci] fix typo --- tests/python/test_element_wise.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/python/test_element_wise.py b/tests/python/test_element_wise.py index cc7c15b37adbe..83a916de43732 100644 --- a/tests/python/test_element_wise.py +++ b/tests/python/test_element_wise.py @@ -83,7 +83,7 @@ def func(): assert c[None][i, j] == approx(expected) -def stest_matrix_element_wise_unary_infix(): +def test_matrix_element_wise_unary_infix(): seed(5156) for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_unary(ti.f32, n, m, ops.neg, ops.neg) @@ -95,16 +95,19 @@ def test_matrix_element_wise_binary_infix(): _test_matrix_element_wise_binary(ti.f32, n, m, ops.add, ops.add) _test_matrix_element_wise_binary(ti.f32, n, m, ops.sub, ops.sub) _test_matrix_element_wise_binary(ti.f32, n, m, ops.mul, ops.mul) - _test_matrix_element_wise_binary(ti.i32, n, m, ops.mod, ops.mod) + _test_matrix_element_wise_binary(ti.f32, n, m, ops.mod, ops.mod) + _test_matrix_element_wise_binary(ti.f32, n, m, ops.pow, ops.pow) _test_matrix_element_wise_binary(ti.f32, n, m, ops.truediv, ops.truediv) _test_matrix_element_wise_binary(ti.f32, n, m, ops.floordiv, ops.floordiv) _test_matrix_element_wise_binary(ti.i32, n, m, ops.add, ops.add) _test_matrix_element_wise_binary(ti.i32, n, m, 
ops.sub, ops.sub) _test_matrix_element_wise_binary(ti.i32, n, m, ops.mul, ops.mul) _test_matrix_element_wise_binary(ti.i32, n, m, ops.mod, ops.mod) + _test_matrix_element_wise_binary(ti.i32, n, m, ops.pow, ops.pow) + # TODO: add pow(f32, i32) -def stest_matrix_element_wise_binary(): +def test_matrix_element_wise_binary(): seed(666) for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_binary(ti.f32, n, m, ti.atan2, math.atan2) @@ -118,7 +121,7 @@ def stest_matrix_element_wise_binary(): # TODO: add ti.raw_div -def stest_matrix_element_wise_unary(): +def test_matrix_element_wise_unary(): seed(233) for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_unary(ti.f32, n, m, ti.sin, math.sin) From 3898cdf209c2cb53f979ea5a25723e2f7e892e13 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Sun, 10 May 2020 23:57:41 +0800 Subject: [PATCH 05/18] fix asin dom err --- tests/python/test_element_wise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_element_wise.py b/tests/python/test_element_wise.py index 83a916de43732..58fc5419cd8d5 100644 --- a/tests/python/test_element_wise.py +++ b/tests/python/test_element_wise.py @@ -13,7 +13,7 @@ def rand(dtype): if ti.core.is_integral(dtype): return randint(1, 5) else: - return float(randint(1, 5)) / 5 + 0.01 # prevent floordiv step + return float(randint(1, 5)) / 5 - 0.01 # prevent floordiv step @ti.host_arch_only From f0bfd3843c3b8a5615175041b4ca7433d0d8a888 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Mon, 11 May 2020 00:15:41 +0800 Subject: [PATCH 06/18] enhanced pow --- python/taichi/lang/expr.py | 37 ++++++++---------------------------- python/taichi/lang/matrix.py | 26 ++++++++----------------- python/taichi/lang/ops.py | 32 +++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 47 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 6549b522afe53..8a0f4a49fed57 100644 --- a/python/taichi/lang/expr.py 
+++ b/python/taichi/lang/expr.py @@ -94,6 +94,14 @@ def __mod__(self, other): import taichi as ti return ti.mod(self, other) + def __pow__(self, other, modulo=None): + import taichi as ti + return ti.pow(self, other) + + def __rpow__(self, other, modulo=None): + import taichi as ti + return ti.pow(other, self) + def __iadd__(self, other): self.atomic_add(other) @@ -316,35 +324,6 @@ def fill(self, val): from .meta import fill_tensor fill_tensor(self, val) - def __rpow__(self, power, modulo=None): - # Python will try Matrix.__pow__ first so we don't have to worry whether `power` is `Matrix` - return Expr(power).__pow__(self, modulo) - - def __pow__(self, power, modulo=None): - import taichi as ti - if ti.is_taichi_class(power): - return power.element_wise_binary(lambda x, y: pow(y, x), self) - if not isinstance(power, int) or abs(power) > 100: - return Expr(taichi_lang_core.expr_pow(self.ptr, Expr(power).ptr)) - if power == 0: - return Expr(1) - negative = power < 0 - power = abs(power) - tmp = self - ret = None - while power: - if power & 1: - if ret is None: - ret = tmp - else: - ret = ti.expr_init(ret * tmp) - tmp = ti.expr_init(tmp * tmp) - power >>= 1 - if negative: - return 1 / ret - else: - return ret - def __ti_int__(self): import taichi as ti return ti.cast(self, ti.get_runtime().default_ip) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 8e1662e9b01a9..9ef3d34351688 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -168,24 +168,6 @@ def __matmul__(self, other): ret(i, j).assign(ret(i, j) + self(i, k) * other(k, j)) return ret - @broadcast_if_scalar - def __pow__(self, other): - assert self.n == other.n and self.m == other.m - ret = Matrix(self.n, self.m) - for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(self(i, j)**other(i, j)) - return ret - - @broadcast_if_scalar - def __rpow__(self, other): - assert self.n == other.n and self.m == other.m - ret = Matrix(self.n, self.m) - 
for i in range(self.n): - for j in range(self.m): - ret(i, j).assign(other(i, j)**self(i, j)) - return ret - def broadcast(self, scalar): ret = Matrix(self.n, self.m, empty=True) for i in range(self.n * self.m): @@ -240,6 +222,14 @@ def __mod__(self, other): import taichi as ti return ti.mod(self, other) + def __pow__(self, other, modulo=None): + import taichi as ti + return ti.pow(self, other) + + def __rpow__(self, other, modulo=None): + import taichi as ti + return ti.pow(other, self) + def linearize_entry_id(self, *args): assert 1 <= len(args) <= 2 if len(args) == 1 and isinstance(args[0], (list, tuple)): diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index ef5929ece6312..83333ef3c16cb 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -187,6 +187,38 @@ def mod(a, b): return Expr(taichi_lang_core.expr_sub(a.ptr, multiply.ptr)) +@binary +def raw_pow(a, b): + return Expr(taichi_lang_core.expr_pow(a.ptr, b.ptr), tb=stack_info()) + + +def pow(self, power): + import taichi as ti + if not isinstance(power, int) or abs(power) > 100: + return raw_pow(self, power) + if power == 0: + return self * 0 + Expr(1) # TODO: rid hack, use {Expr,Matrix}.dup().fill(1) + negative = power < 0 + power = abs(power) + tmp = self + ret = None + while power: + if power & 1: + if ret is None: + ret = tmp + else: + ret = ti.expr_init(ret * tmp) + tmp = ti.expr_init(tmp * tmp) + power >>= 1 + if negative: + return 1 / ret + else: + return ret + + +# NEXT: add matpow(self, power) + + @binary def floordiv(a, b): return Expr(taichi_lang_core.expr_floordiv(a.ptr, b.ptr), tb=stack_info()) From afdd65ce84fedd830d507bd5234ec07bce140b5f Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Mon, 11 May 2020 00:18:38 +0800 Subject: [PATCH 07/18] [skip ci] reduce pow optimization from 100 to 50 --- python/taichi/lang/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py 
index 83333ef3c16cb..cb6d32232436b 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -194,7 +194,7 @@ def raw_pow(a, b): def pow(self, power): import taichi as ti - if not isinstance(power, int) or abs(power) > 100: + if not isinstance(power, int) or abs(power) > 50: return raw_pow(self, power) if power == 0: return self * 0 + Expr(1) # TODO: rid hack, use {Expr,Matrix}.dup().fill(1) From 8696e82642867d786a1c7ccd22bc74891897d186 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Mon, 11 May 2020 00:22:39 +0800 Subject: [PATCH 08/18] [skip ci] balance test load pressure --- tests/python/test_element_wise.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/tests/python/test_element_wise.py b/tests/python/test_element_wise.py index 58fc5419cd8d5..1be29e072293a 100644 --- a/tests/python/test_element_wise.py +++ b/tests/python/test_element_wise.py @@ -87,9 +87,10 @@ def test_matrix_element_wise_unary_infix(): seed(5156) for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_unary(ti.f32, n, m, ops.neg, ops.neg) + _test_matrix_element_wise_unary(ti.i32, n, m, ops.neg, ops.neg) -def test_matrix_element_wise_binary_infix(): +def test_matrix_element_wise_binary_infix_f32(): seed(4399) for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_binary(ti.f32, n, m, ops.add, ops.add) @@ -99,6 +100,11 @@ def test_matrix_element_wise_binary_infix(): _test_matrix_element_wise_binary(ti.f32, n, m, ops.pow, ops.pow) _test_matrix_element_wise_binary(ti.f32, n, m, ops.truediv, ops.truediv) _test_matrix_element_wise_binary(ti.f32, n, m, ops.floordiv, ops.floordiv) + + +def test_matrix_element_wise_binary_infix_i32(): + seed(6174) + for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_binary(ti.i32, n, m, ops.add, ops.add) _test_matrix_element_wise_binary(ti.i32, n, m, ops.sub, ops.sub) _test_matrix_element_wise_binary(ti.i32, n, m, ops.mul, ops.mul) @@ -107,21 +113,26 @@ def 
test_matrix_element_wise_binary_infix(): # TODO: add pow(f32, i32) -def test_matrix_element_wise_binary(): +def test_matrix_element_wise_binary_f32(): seed(666) for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_binary(ti.f32, n, m, ti.atan2, math.atan2) _test_matrix_element_wise_binary(ti.f32, n, m, ti.min, min) - _test_matrix_element_wise_binary(ti.i32, n, m, ti.min, min) _test_matrix_element_wise_binary(ti.f32, n, m, ti.max, max) - _test_matrix_element_wise_binary(ti.i32, n, m, ti.max, max) _test_matrix_element_wise_binary(ti.f32, n, m, pow, pow) + + +def test_matrix_element_wise_binary_i32(): + seed(985) + for n, m in [(5, 4), (3, 1)]: + _test_matrix_element_wise_binary(ti.i32, n, m, ti.min, min) + _test_matrix_element_wise_binary(ti.i32, n, m, ti.max, max) _test_matrix_element_wise_binary(ti.i32, n, m, pow, pow) _test_matrix_element_wise_binary(ti.i32, n, m, ti.raw_mod, _c_mod) # TODO: add ti.raw_div -def test_matrix_element_wise_unary(): +def test_matrix_element_wise_unary_1(): seed(233) for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_unary(ti.f32, n, m, ti.sin, math.sin) @@ -131,6 +142,11 @@ def test_matrix_element_wise_unary(): _test_matrix_element_wise_unary(ti.f32, n, m, ti.acos, math.acos) _test_matrix_element_wise_unary(ti.f32, n, m, ti.tanh, math.tanh) _test_matrix_element_wise_unary(ti.f32, n, m, ti.sqrt, math.sqrt) + + +def test_matrix_element_wise_unary_2(): + seed(211) + for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_unary(ti.f32, n, m, ti.exp, math.exp) _test_matrix_element_wise_unary(ti.f32, n, m, ti.log, math.log) _test_matrix_element_wise_unary(ti.f32, n, m, ti.ceil, math.ceil) From 61c6825873e2b54bc7cc5a456ef9fb8c1ebf6f1e Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Mon, 11 May 2020 00:34:38 +0800 Subject: [PATCH 09/18] share common_ops between Expr and Matrix --- python/taichi/lang/common_ops.py | 56 ++++++++++++++++++++++++++++++ python/taichi/lang/expr.py | 3 +- python/taichi/lang/matrix.py | 59 
++------------------------------ 3 files changed, 60 insertions(+), 58 deletions(-) create mode 100644 python/taichi/lang/common_ops.py diff --git a/python/taichi/lang/common_ops.py b/python/taichi/lang/common_ops.py new file mode 100644 index 0000000000000..821cf4e219eb9 --- /dev/null +++ b/python/taichi/lang/common_ops.py @@ -0,0 +1,56 @@ +class TaichiOperations: + def __neg__(self): + import taichi as ti + return ti.neg(self) + + def __abs__(self): + import taichi as ti + return ti.abs(self) + + def __add__(self, other): + import taichi as ti + return ti.add(self, other) + + __radd__ = __add__ + + def __sub__(self, other): + import taichi as ti + return ti.sub(self, other) + + def __rsub__(self, other): + import taichi as ti + return ti.sub(other, self) + + def __mul__(self, other): + import taichi as ti + return ti.mul(self, other) + + __rmul__ = __mul__ + + def __truediv__(self, other): + import taichi as ti + return ti.truediv(self, other) + + def __rtruediv__(self, other): + import taichi as ti + return ti.truediv(other, self) + + def __floordiv__(self, other): + import taichi as ti + return ti.floordiv(self, other) + + def __rfloordiv__(self, other): + import taichi as ti + return ti.floordiv(other, self) + + def __mod__(self, other): + import taichi as ti + return ti.mod(self, other) + + def __pow__(self, other, modulo=None): + import taichi as ti + return ti.pow(self, other) + + def __rpow__(self, other, modulo=None): + import taichi as ti + return ti.pow(other, self) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 8a0f4a49fed57..072fc3dffd18b 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -1,10 +1,11 @@ from .core import taichi_lang_core from .util import * +from .common_ops import TaichiOperations import traceback # Scalar, basic data type -class Expr: +class Expr(TaichiOperations): materialize_layout_callback = None layout_materialized = False diff --git a/python/taichi/lang/matrix.py 
b/python/taichi/lang/matrix.py index 9ef3d34351688..3451f047252a5 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -4,6 +4,7 @@ import numbers import numpy as np from .util import to_numpy_type, to_pytorch_type +from .common_ops import TaichiOperations def broadcast_if_scalar(func): @@ -15,7 +16,7 @@ def broadcasted(self, other, *args, **kwargs): return broadcasted -class Matrix: +class Matrix(TaichiOperations): is_taichi_class = True def __init__(self, @@ -174,62 +175,6 @@ def broadcast(self, scalar): ret.entries[i] = scalar return ret - def __neg__(self): - import taichi as ti - return ti.neg(self) - - def __abs__(self): - import taichi as ti - return ti.abs(self) - - def __add__(self, other): - import taichi as ti - return ti.add(self, other) - - __radd__ = __add__ - - def __sub__(self, other): - import taichi as ti - return ti.sub(self, other) - - def __rsub__(self, other): - import taichi as ti - return ti.sub(other, self) - - def __mul__(self, other): - import taichi as ti - return ti.mul(self, other) - - __rmul__ = __mul__ - - def __truediv__(self, other): - import taichi as ti - return ti.truediv(self, other) - - def __rtruediv__(self, other): - import taichi as ti - return ti.truediv(other, self) - - def __floordiv__(self, other): - import taichi as ti - return ti.floordiv(self, other) - - def __rfloordiv__(self, other): - import taichi as ti - return ti.floordiv(other, self) - - def __mod__(self, other): - import taichi as ti - return ti.mod(self, other) - - def __pow__(self, other, modulo=None): - import taichi as ti - return ti.pow(self, other) - - def __rpow__(self, other, modulo=None): - import taichi as ti - return ti.pow(other, self) - def linearize_entry_id(self, *args): assert 1 <= len(args) <= 2 if len(args) == 1 and isinstance(args[0], (list, tuple)): From bc6a35a93f4732923b0b6ec2aa5f7bd7321deaaf Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Mon, 11 May 2020 00:36:44 +0800 Subject: [PATCH 
10/18] [skip ci] add comments about the l-value problem --- python/taichi/lang/matrix.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 3451f047252a5..f7155c4428e4a 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -43,12 +43,14 @@ def __init__(self, assert row.n == rows[ 0].n, "input vectors must be the same shape" self.m = rows[0].n + # l-value copy: self.entries = [row(i) for row in rows for i in range(row.n)] elif isinstance(rows[0], list): for row in rows: assert len(row) == len( rows[0]), "input lists must be the same shape" self.m = len(rows[0]) + # l-value copy: self.entries = [x for row in rows for x in row] else: raise Exception( From a92e37e0be66d3f0a31a10ded039277534c092ca Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Mon, 11 May 2020 12:16:06 +0800 Subject: [PATCH 11/18] [skip ci] fix opengl pow test by adding fast_pow for rhs is int --- taichi/backends/opengl/codegen_opengl.cpp | 16 ++++++++ taichi/backends/opengl/opengl_kernel_util.h | 1 + .../backends/opengl/shaders/fast_pow.glsl.h | 38 +++++++++++++++++++ 3 files changed, 55 insertions(+) create mode 100644 taichi/backends/opengl/shaders/fast_pow.glsl.h diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index dc651e11db66a..8e977f87b05a8 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -169,6 +169,12 @@ class KernelGen : public IRVisitor { ); } // }}} + if (used.fast_pow) { + kernel_header += ( +#include "taichi/backends/opengl/shaders/fast_pow.glsl.h" + ); + } + line_appender_header_.append_raw(kernel_header); int threads_per_group = opengl_get_threads_per_group(); @@ -407,6 +413,16 @@ class KernelGen : public IRVisitor { emit("{} {} = atan({}, {});", dt_name, bin_name, lhs_name, rhs_name); } return; + } else if (bin->op_type == BinaryOpType::pow + && 
is_integral(bin->rhs->element_type())) { + // The GLSL `pow` is not so percise for `int`... e.g.: `pow(5, 3)` obtains 124 + // So that we have to use some hack to make it percise. + // Dis: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626354902 + emit("{} {} = {}(fast_pow_{}({}, {}));", dt_name, bin_name, dt_name, + data_type_short_name(bin->lhs->element_type()), + lhs_name, rhs_name); + used.fast_pow = true; + return; } const auto binop = binary_op_type_symbol(bin->op_type); if (is_opengl_binary_op_infix(bin->op_type)) { diff --git a/taichi/backends/opengl/opengl_kernel_util.h b/taichi/backends/opengl/opengl_kernel_util.h index b9388014b8b76..ee9f18e746438 100644 --- a/taichi/backends/opengl/opengl_kernel_util.h +++ b/taichi/backends/opengl/opengl_kernel_util.h @@ -20,6 +20,7 @@ struct UsedFeature { bool simulated_atomic_float{false}; bool int64{false}; bool global_temp{false}; + bool fast_pow{false}; }; struct StructCompiledResult { diff --git a/taichi/backends/opengl/shaders/fast_pow.glsl.h b/taichi/backends/opengl/shaders/fast_pow.glsl.h new file mode 100644 index 0000000000000..051ad939ce6eb --- /dev/null +++ b/taichi/backends/opengl/shaders/fast_pow.glsl.h @@ -0,0 +1,38 @@ +// vim: ft=glsl +// clang-format off +#include "taichi/util/macros.h" +STR( +int fast_pow_i32(int x, int y) +{ + if (y > 512) + return int(pow(x, y)); + + bool neg = y < 0; + y = abs(y); + int ret = 1; + while (y != 0) { + if ((y & 1) != 0) + ret *= x; + x *= x; + y >>= 1; + } + return neg ? 1 / ret : ret; +} + +float fast_pow_f32(float x, int y) +{ + if (y > 512) + return pow(x, y); + + bool neg = y < 0; + y = abs(y); + float ret = 1.0; + while (y != 0) { + if ((y & 1) != 0) + ret *= x; + x *= x; + y >>= 1; + } + return neg ? 
1.0 / ret : ret; +} +) From 52c3903bc235a1affab0f4c799a9c7ff3beffdde Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Mon, 11 May 2020 12:49:19 +0800 Subject: [PATCH 12/18] [skip ci] add comment --- python/taichi/lang/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index cb6d32232436b..7c8c2d41819c7 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -192,6 +192,7 @@ def raw_pow(a, b): return Expr(taichi_lang_core.expr_pow(a.ptr, b.ptr), tb=stack_info()) +# TODO: move this to a C++ pass (#944) def pow(self, power): import taichi as ti if not isinstance(power, int) or abs(power) > 50: From 073cef122c96fa8fc83228f25175b370685394a8 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Mon, 11 May 2020 22:26:27 +0800 Subject: [PATCH 13/18] [skip ci] nit comment --- python/taichi/lang/ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index 7c8c2d41819c7..d55786a74696b 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -198,7 +198,9 @@ def pow(self, power): if not isinstance(power, int) or abs(power) > 50: return raw_pow(self, power) if power == 0: - return self * 0 + Expr(1) # TODO: rid hack, use {Expr,Matrix}.dup().fill(1) + # TODO: remove the hack, use {Expr,Matrix}.dup().fill(1) + # also note that this can be solved by #940 + return self * 0 + Expr(1) negative = power < 0 power = abs(power) tmp = self From ede36659ab83d54d5841d543f732f2a9014c1c09 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Mon, 11 May 2020 22:27:12 +0800 Subject: [PATCH 14/18] delete --- python/taichi/lang/expr.py | 56 -------------------------------------- 1 file changed, 56 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 072fc3dffd18b..0255e3ba11f84 100644 --- a/python/taichi/lang/expr.py +++ 
b/python/taichi/lang/expr.py @@ -47,62 +47,6 @@ def stack_info(): # remove the confusing last line return '\n'.join(raw.split('\n')[:-3]) + '\n' - def __neg__(self): - import taichi as ti - return ti.neg(self) - - def __abs__(self): - import taichi as ti - return ti.abs(self) - - def __add__(self, other): - import taichi as ti - return ti.add(self, other) - - __radd__ = __add__ - - def __sub__(self, other): - import taichi as ti - return ti.sub(self, other) - - def __rsub__(self, other): - import taichi as ti - return ti.sub(other, self) - - def __mul__(self, other): - import taichi as ti - return ti.mul(self, other) - - __rmul__ = __mul__ - - def __truediv__(self, other): - import taichi as ti - return ti.truediv(self, other) - - def __rtruediv__(self, other): - import taichi as ti - return ti.truediv(other, self) - - def __floordiv__(self, other): - import taichi as ti - return ti.floordiv(self, other) - - def __rfloordiv__(self, other): - import taichi as ti - return ti.floordiv(other, self) - - def __mod__(self, other): - import taichi as ti - return ti.mod(self, other) - - def __pow__(self, other, modulo=None): - import taichi as ti - return ti.pow(self, other) - - def __rpow__(self, other, modulo=None): - import taichi as ti - return ti.pow(other, self) - def __iadd__(self, other): self.atomic_add(other) From c06a13f83367e429fa9ca4122fc57037318e0d01 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Mon, 11 May 2020 23:23:46 +0800 Subject: [PATCH 15/18] add comment --- python/taichi/lang/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 0255e3ba11f84..64a13cac23e7a 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -74,7 +74,7 @@ def __ior__(self, other): def __ixor__(self, other): self.atomic_xor(other) - # TODO: ti.cmp_le + # TODO: move to ops.py: ti.cmp_le def __le__(self, other): other = Expr(other) return 
Expr(taichi_lang_core.expr_cmp_le(self.ptr, other.ptr)) From db165fa9be8eee363cf48954406ca280fa53e865 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E4=BA=8E=E6=96=8C?= <1931127624@qq.com> Date: Tue, 12 May 2020 00:57:01 +0800 Subject: [PATCH 16/18] [skip ci] Apply suggestions from code review Co-authored-by: Yuanming Hu --- python/taichi/lang/common_ops.py | 8 ++++++-- python/taichi/lang/matrix.py | 2 +- taichi/backends/opengl/codegen_opengl.cpp | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/taichi/lang/common_ops.py b/python/taichi/lang/common_ops.py index 821cf4e219eb9..c8743cfe62254 100644 --- a/python/taichi/lang/common_ops.py +++ b/python/taichi/lang/common_ops.py @@ -11,7 +11,9 @@ def __add__(self, other): import taichi as ti return ti.add(self, other) - __radd__ = __add__ + def __radd__(self, other): + import taichi as ti + return ti.add(other, self) def __sub__(self, other): import taichi as ti @@ -25,7 +27,9 @@ def __mul__(self, other): import taichi as ti return ti.mul(self, other) - __rmul__ = __mul__ + def __rmul__(self, other): + import taichi as ti + return ti.mul(other, self) def __truediv__(self, other): import taichi as ti diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index f7155c4428e4a..43d1e3d14ef6a 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -292,7 +292,7 @@ def inverse(self): return Matrix([1 / self(0, 0)]) elif self.n == 2: inv_det = impl.expr_init(1.0 / self.determinant(self)) - # Dis: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626344323 + # Discussion: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626344323 return inv_det * Matrix([[self(1, 1), -self(0, 1)], [-self(1, 0), self(0, 0)]]).variable() elif self.n == 3: diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index 8e977f87b05a8..0c24622ec5a0d 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ 
b/taichi/backends/opengl/codegen_opengl.cpp @@ -417,7 +417,7 @@ class KernelGen : public IRVisitor { && is_integral(bin->rhs->element_type())) { // The GLSL `pow` is not so percise for `int`... e.g.: `pow(5, 3)` obtains 124 // So that we have to use some hack to make it percise. - // Dis: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626354902 + // Discussion: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626354902 emit("{} {} = {}(fast_pow_{}({}, {}));", dt_name, bin_name, dt_name, data_type_short_name(bin->lhs->element_type()), lhs_name, rhs_name); From 58daa5d422c3245f6fda040b1b630a3eaf2b3626 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Tue, 12 May 2020 09:38:15 +0800 Subject: [PATCH 17/18] nit comment --- tests/python/test_element_wise.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/python/test_element_wise.py b/tests/python/test_element_wise.py index 1be29e072293a..57dcd88978994 100644 --- a/tests/python/test_element_wise.py +++ b/tests/python/test_element_wise.py @@ -13,7 +13,9 @@ def rand(dtype): if ti.core.is_integral(dtype): return randint(1, 5) else: - return float(randint(1, 5)) / 5 - 0.01 # prevent floordiv step + # Prevent integer operands in pow and floordiv in GLSL + # Discussion: https://github.com/taichi-dev/taichi/pull/943#discussion_r423177941 + return float(randint(1, 5)) / 5 - 0.01 @ti.host_arch_only @@ -119,7 +121,7 @@ def test_matrix_element_wise_binary_f32(): _test_matrix_element_wise_binary(ti.f32, n, m, ti.atan2, math.atan2) _test_matrix_element_wise_binary(ti.f32, n, m, ti.min, min) _test_matrix_element_wise_binary(ti.f32, n, m, ti.max, max) - _test_matrix_element_wise_binary(ti.f32, n, m, pow, pow) + _test_matrix_element_wise_binary(ti.f32, n, m, ti.pow, pow) def test_matrix_element_wise_binary_i32(): @@ -127,7 +129,7 @@ def test_matrix_element_wise_binary_i32(): for n, m in [(5, 4), (3, 1)]: _test_matrix_element_wise_binary(ti.i32, n, m, ti.min, min) 
_test_matrix_element_wise_binary(ti.i32, n, m, ti.max, max) - _test_matrix_element_wise_binary(ti.i32, n, m, pow, pow) + _test_matrix_element_wise_binary(ti.i32, n, m, ti.pow, pow) _test_matrix_element_wise_binary(ti.i32, n, m, ti.raw_mod, _c_mod) # TODO: add ti.raw_div From 78ce3db51dea2f26074ff95a5eadcd3d95aafb79 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Tue, 12 May 2020 10:22:02 +0800 Subject: [PATCH 18/18] fix abs(power) --- python/taichi/lang/ops.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index d55786a74696b..bb712cbf1113b 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -195,14 +195,19 @@ def raw_pow(a, b): # TODO: move this to a C++ pass (#944) def pow(self, power): import taichi as ti - if not isinstance(power, int) or abs(power) > 50: + if not isinstance(power, int): return raw_pow(self, power) if power == 0: # TODO: remove the hack, use {Expr,Matrix}.dup().fill(1) # also note that this can be solved by #940 return self * 0 + Expr(1) + negative = power < 0 - power = abs(power) + # Why not simply use `power = abs(power)`? + # Because `abs` is overridden by the `ti.abs` above. + if negative: + power = -power + tmp = self ret = None while power: @@ -213,6 +218,7 @@ def pow(self, power): ret = ti.expr_init(ret * tmp) tmp = ti.expr_init(tmp * tmp) power >>= 1 + if negative: return 1 / ret else: