Skip to content

Commit

Permalink
[bug] Fix optimization of exp with negative exponent (#8398)
Browse files Browse the repository at this point in the history
Issue: fixes #8269 

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 7a8b410</samp>

Fix the return type of exponent statements that are simplified by
`alg_simp` and add a test case for this optimization. This improves the
correctness and performance of code generation for exponent operations.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 7a8b410</samp>

* Simplify negative exponents by using reciprocal function
([link](https://github.com/taichi-dev/taichi/pull/8398/files?diff=unified&w=0#diff-77d8ca8e4dc6081988bd6dddb74bb9a5485af28ce3e0b43bc06d123256695513R282),
[link](https://github.com/taichi-dev/taichi/pull/8398/files?diff=unified&w=0#diff-77d8ca8e4dc6081988bd6dddb74bb9a5485af28ce3e0b43bc06d123256695513R288))
- Set return type of new exponent statement to match original right-hand
side
([link](https://github.com/taichi-dev/taichi/pull/8398/files?diff=unified&w=0#diff-77d8ca8e4dc6081988bd6dddb74bb9a5485af28ce3e0b43bc06d123256695513R282))
- Set return type of result statement to match original exponent
statement
([link](https://github.com/taichi-dev/taichi/pull/8398/files?diff=unified&w=0#diff-77d8ca8e4dc6081988bd6dddb74bb9a5485af28ce3e0b43bc06d123256695513R288))
* Add test case for negative exponent simplification in
`test_optimization.py`
([link](https://github.com/taichi-dev/taichi/pull/8398/files?diff=unified&w=0#diff-b8b031f0789413acece482512df4af5b8419a2a2dea3624b26114bbb9b57d334L156-R169))

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored Oct 31, 2023
1 parent ae83197 commit 1ae0e46
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
2 changes: 2 additions & 0 deletions taichi/transforms/alg_simp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,13 @@ class AlgSimp : public BasicStmtVisitor {

cast_to_result_type(one, stmt);
auto new_exponent = Stmt::make<UnaryOpStmt>(UnaryOpType::neg, stmt->rhs);
new_exponent->ret_type = stmt->rhs->ret_type;
auto a_to_n = Stmt::make<BinaryOpStmt>(BinaryOpType::pow, stmt->lhs,
new_exponent.get());
a_to_n->ret_type = stmt->ret_type;
auto result =
Stmt::make<BinaryOpStmt>(BinaryOpType::div, one, a_to_n.get());
result->ret_type = stmt->ret_type;
stmt->replace_usages_with(result.get());
modifier.insert_before(stmt, std::move(new_exponent));
modifier.insert_before(stmt, std::move(a_to_n));
Expand Down
16 changes: 16 additions & 0 deletions tests/python/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,19 @@ def my_cast(x: ti.f32) -> ti.u32:
return ti.cast(y, ti.u32)

assert my_cast(-1) == 4294967295


@test_utils.test()
def test_negative_exp():
@ti.dataclass
class Particle:
epsilon: ti.f32

@ti.kernel
def test() -> ti.f32:
p1 = Particle()
p1.epsilon = 1.0
e = p1.epsilon
return e**-1

assert test() == 1.0

0 comments on commit 1ae0e46

Please sign in to comment.