[IR][refactor] Convert loop_var into LoopIndexStmt #953

Merged: 22 commits merged on May 14, 2020

xumingkuan
Contributor

Related PR = #932 (comment)

There are more places than I thought that need to be refactored...

@xumingkuan
Contributor Author

The offload pass and all backends still need to be modified.

If we put this pass before demote_dense_struct_fors, it also needs to be modified.

I think ideally we should put this pass right after lower (lower_ast). WDYT?

@@ -217,19 +217,20 @@ class OffloadedStmt : public Stmt {

class LoopIndexStmt : public Stmt {
 public:
  Stmt *loop;
  int index;
  bool is_struct_for;
Contributor Author

I think the field is_struct_for would be unnecessary.

Member

Yes, please feel free to remove that.
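
A minimal sketch of why the flag can be dropped, assuming the kind of loop can be recovered from the `loop` pointer itself. The classes below are toy stand-ins, not the real Taichi definitions, and the `is_struct_for()` helper is hypothetical.

```cpp
#include <cassert>

// Toy stand-ins for the IR classes; NOT the real Taichi definitions.
struct Stmt {
  virtual ~Stmt() = default;
};
struct RangeForStmt : Stmt {};
struct StructForStmt : Stmt {};

// A loop index refers to its enclosing loop and to one of that loop's axes.
struct LoopIndexStmt : Stmt {
  Stmt *loop;  // the loop this index belongs to (non-owning)
  int index;   // which index of that loop this statement stands for

  LoopIndexStmt(Stmt *loop, int index) : loop(loop), index(index) {}

  // Hypothetical helper: the loop kind is derived from `loop`,
  // so no separate boolean field has to be stored and kept in sync.
  bool is_struct_for() const {
    return dynamic_cast<const StructForStmt *>(loop) != nullptr;
  }
};
```

With this shape, a `LoopIndexStmt` built on a `StructForStmt` reports `is_struct_for() == true` without any extra state. The real IR may also point `loop` at other statement kinds, which this sketch ignores.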

@yuanming-hu
Member

> The offload pass and all backends still need to be modified.
>
> If we put this pass before demote_dense_struct_fors, it also needs to be modified.
>
> I think ideally we should put this pass right after lower (lower_ast). WDYT?

You mean moving offload to right after lower? I'm afraid that will lead to some issues in autodiff.

@xumingkuan
Contributor Author

> > The offload pass and all backends still need to be modified.
> > If we put this pass before demote_dense_struct_fors, it also needs to be modified.
> > I think ideally we should put this pass right after lower (lower_ast). WDYT?
>
> You mean moving offload to right after lower? I'm afraid that will lead to some issues in autodiff.

Oh, I mean moving irpass::convert_into_loop_index(ir); to right after lower.

@yuanming-hu
Member

Oh, I see. That sounds good. Actually, let's make it part of lower.

void lower(IRNode *root) {
  auto offsets = LowerAST::run(root);
  FixStructForOffsets::run(root, offsets);
}
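
The suggestion to make the conversion part of lower can be sketched as follows. This is a toy model only: `IRNode` here just records pass names, the three free functions are placeholders rather than the real passes, and running the conversion last is an assumption, not something the thread settles.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the IR root: it only records which passes ran.
struct IRNode {
  std::vector<std::string> applied_passes;
};

// Placeholder passes; the real ones transform the IR instead of logging.
void lower_ast(IRNode *root) {
  root->applied_passes.push_back("lower_ast");
}
void fix_struct_for_offsets(IRNode *root) {
  root->applied_passes.push_back("fix_struct_for_offsets");
}
void convert_into_loop_index(IRNode *root) {
  root->applied_passes.push_back("convert_into_loop_index");
}

// Folding the conversion into lower() means every caller of lower()
// gets loop-index form without having to schedule an extra pass.
void lower(IRNode *root) {
  lower_ast(root);
  fix_struct_for_offsets(root);
  convert_into_loop_index(root);  // assumed position: last step of lower()
}
```

The design point is only that later passes never see the pre-conversion form once the call lives inside `lower`.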

@xumingkuan
Contributor Author

Why is this?

taichi/taichi/ir/ir.cpp

Lines 243 to 246 in e0ef399

void Stmt::replace_with(Stmt *new_stmt) {
  auto root = get_ir_root();
  irpass::replace_all_usages_with(root, this, new_stmt);
  // Note: the current structure should have been destroyed now..

I think I need

irpass::replace_all_usages_with(body.get(), old_loop_vars[i], alloca);

to not throw exceptions. I'm changing the above line to irpass::replace_statements_with(...) and trying to add an argument for irpass::replace_statements_with so as to not throw exceptions, but I'm getting Windows fatal exception: access violation.

@yuanming-hu
Member

I have no idea on this. get_ir_root will follow the parent field of blocks, which might have been nullptr - could this be the cause?

@xumingkuan
Contributor Author

> I have no idea on this. get_ir_root will follow the parent field of blocks, which might have been nullptr - could this be the cause?

Very probably...
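
To illustrate the suspected failure mode, here is a sketch of a root lookup that tolerates a detached statement. The `Block` and `Stmt` types and the `find_ir_root` function are hypothetical toys, not the real `get_ir_root`; the only point is the null check before following `parent`.

```cpp
#include <cassert>

// Toy IR: a block knows its parent block; NOT the real Taichi classes.
struct Block {
  Block *parent = nullptr;
};
struct Stmt {
  Block *parent = nullptr;  // nullptr while the statement is detached
};

// Walk up the parent chain. A detached statement yields nullptr
// instead of dereferencing a null pointer (an access violation).
Block *find_ir_root(const Stmt *stmt) {
  if (stmt == nullptr || stmt->parent == nullptr)
    return nullptr;
  Block *block = stmt->parent;
  while (block->parent != nullptr)
    block = block->parent;
  return block;
}
```

A caller can then decide what to do with a `nullptr` result, for example skipping the replacement for statements that have not been inserted into a block yet.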

@xumingkuan
Contributor Author

I wonder in which case this is false:

if (loop_var->ret_type.data_type == DataType::i32) {

Since there are no loop_vars after this PR, I'm afraid that the information on data type will be lost.

@yuanming-hu
Member

Thanks for pointing that out. Shall we simply add a rule to enforce the data type of loop variables to be i32? In the future maybe i64 should also be allowed. Maybe the data type should be a member of the for statements so that LoopVarStmt can infer from it.

@xumingkuan
Contributor Author

> Thanks for pointing that out. Shall we simply add a rule to enforce the data type of loop variables to be i32? In the future maybe i64 should also be allowed. Maybe the data type should be a member of the for statements so that LoopVarStmt can infer from it.

I think in LLVM and CUDA backends, the loop variables are i32 by default. So I'm going to make them i32 in all backends by default for now.

@yuanming-hu
Member

> > Thanks for pointing that out. Shall we simply add a rule to enforce the data type of loop variables to be i32? In the future maybe i64 should also be allowed. Maybe the data type should be a member of the for statements so that LoopVarStmt can infer from it.
>
> I think in LLVM and CUDA backends, the loop variables are i32 by default. So I'm going to make them i32 in all backends by default for now.

Sounds good. i32 should work for a while.
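
The idea of storing the data type on the for statement, with i32 as the default, might look like the sketch below. All names here (`ForStmt`, `LoopIndex`, `index_type`) are hypothetical and simplified; nothing in the thread confirms this is how it was implemented.

```cpp
#include <cassert>

enum class DataType { i32, i64 };

// Hypothetical for statement that owns the type of its loop indices.
struct ForStmt {
  DataType index_type = DataType::i32;  // default discussed in the thread
};

// A loop index does not store a type of its own; it asks its loop.
struct LoopIndex {
  const ForStmt *loop;
  int index;

  DataType data_type() const { return loop->index_type; }
};
```

Allowing i64 later would then be a change to the loop alone, and every index referring to it would follow.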

Comment on lines -20 to 21
- TI_ASSERT(total_bits <= 31);
+ TI_ASSERT(total_bits <= 30);

  auto upper_bound = 1 << total_bits;
Contributor Author

1 << 31 looks dangerous, so I suppose we should not support tensors of size 2^31 for now.
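
The concern is that 2^31 does not fit in a signed 32-bit integer, whose maximum is 2^31 - 1, so `1 << 31` cannot produce the intended positive bound. A small sketch, with hypothetical function names, of the bounded 32-bit version next to a widened 64-bit alternative:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Returns 2^total_bits as a signed 32-bit value.
// The limit of 30 mirrors the tightened assertion in the diff.
std::int32_t upper_bound_i32(int total_bits) {
  assert(total_bits >= 0 && total_bits <= 30);
  return std::int32_t{1} << total_bits;
}

// Widened alternative (an assumption, not what the PR does):
// shifting a 64-bit one keeps 2^31 representable.
std::int64_t upper_bound_i64(int total_bits) {
  assert(total_bits >= 0 && total_bits <= 62);
  return std::int64_t{1} << total_bits;
}
```

With 30 bits the 32-bit result is 1073741824, still below the `int32_t` maximum of 2147483647, whereas 2^31 = 2147483648 only fits in the 64-bit variant.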

xumingkuan marked this pull request as ready for review May 14, 2020 00:13
@xumingkuan
Contributor Author

TODO: Remove RangeForStmt::loop_var and StructForStmt::loop_vars. Maybe in the next PR.

Member

@yuanming-hu left a comment

Looks great! Thank you so much for refactoring this. The LLVM backend modifications look good to me. (I didn't look at the OpenGL and Metal backends.)

Btw, would it be possible to remove RangeForStmt::loop_var and StructForStmt::loop_vars in the future? How easy/hard is that? Sorry I missed your last message. Great work!

@xumingkuan
Contributor Author

> Looks great! Thank you so much for refactoring this. The LLVM backend modifications look good to me. (I didn't look at the OpenGL and Metal backends.)
>
> Btw, would it be possible to remove RangeForStmt::loop_var and StructForStmt::loop_vars in the future? How easy/hard is that?

After lower_ast, it's easy -- I've already set them to nullptrs. I think only StructForStmt::loop_vars->size() contains some information now, but it should be the same as StructForStmt::snode->num_active_indices.

But during lower_ast, they are used temporarily, and maybe we need an std::unordered_map.

I wonder if it's feasible to create LoopIndexStmts directly in LowerAST::visit(FrontendForStmt *stmt). In this way, we can even obviate the need of eliminating identical LoopIndexStmts in simplify. (Because we don't eliminate them but we do eliminate identical local loads now, the performance is impaired.)
[Image: benchmark20200513]
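
As a rough illustration of what eliminating identical LoopIndexStmts would involve, the sketch below treats two loop indices as identical when they refer to the same loop and the same index, and maps each one to its first occurrence. The `LoopIndex` struct and `canonical_positions` function are hypothetical; the real simplify pass works on IR statements, not on a flat vector.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <vector>

// Toy loop index: identity of the enclosing loop plus the index position.
struct LoopIndex {
  const void *loop;
  int index;
};

// Strict weak ordering on (loop, index); std::less gives a total
// order over pointers, which plain `<` does not guarantee.
struct LoopIndexLess {
  bool operator()(const LoopIndex &a, const LoopIndex &b) const {
    if (a.loop != b.loop)
      return std::less<const void *>()(a.loop, b.loop);
    return a.index < b.index;
  }
};

// For each statement, return the position of the first statement
// with the same (loop, index); later duplicates can reuse that one.
std::vector<std::size_t> canonical_positions(
    const std::vector<LoopIndex> &stmts) {
  std::map<LoopIndex, std::size_t, LoopIndexLess> first_seen;
  std::vector<std::size_t> result;
  result.reserve(stmts.size());
  for (std::size_t i = 0; i < stmts.size(); ++i) {
    // emplace keeps the existing entry when the key was seen before.
    auto it = first_seen.emplace(stmts[i], i).first;
    result.push_back(it->second);
  }
  return result;
}
```

Creating the loop indices directly during lowering, as proposed above, would avoid generating the duplicates that such a step has to clean up.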

@yuanming-hu
Member

> I wonder if it's feasible to create LoopIndexStmts directly in LowerAST::visit(FrontendForStmt *stmt). In this way, we can even obviate the need of eliminating identical LoopIndexStmts in simplify. (Because we don't eliminate them but we do eliminate identical local loads now, the performance is impaired.)

My feeling is that this is doable, but probably something in transformer.py needs to be changed. You can actually directly insert LoopIndexStmt in transformer.py and fix the loop pointers later in lower_ast.

@xumingkuan
Contributor Author

> > I wonder if it's feasible to create LoopIndexStmts directly in LowerAST::visit(FrontendForStmt *stmt). In this way, we can even obviate the need of eliminating identical LoopIndexStmts in simplify. (Because we don't eliminate them but we do eliminate identical local loads now, the performance is impaired.)
>
> My feeling is that this is doable, but probably something in transformer.py needs to be changed. You can actually directly insert LoopIndexStmt in transformer.py and fix the loop pointers later in lower_ast.

I would prefer not touching transformer.py, as there are too many kinds of "frontend for"s... As the allocas are inserted in LowerAST::visit(FrontendForStmt *stmt), I think maybe I can replace them with LoopIndexStmts.

Currently, for this IR,

kernel {
  $0 : for @tmp1, @tmp2, @tmp3 where S2place_i32 active {
    $1 = alloca @tmp7
    @tmp7 = @tmp1
  }
}

after LowerAST::run, we have

kernel {
  <i32 x1> $99 = alloca
  <i32 x1> $100 = alloca
  <i32 x1> $101 = alloca
  $102 : for $99, $100, $101 where S1dense active, step 1 {
    $103 = alloca
    $104 = local load [ [$99[0]]]
    $105 : local store [$103 <- $104]
  }
}

Shall we modify

void flatten(FlattenContext *ctx) override {
  ctx->push_back(std::make_unique<LocalLoadStmt>(
      LocalAddress(ctx->current_block->lookup_var(id), 0)));
  stmt = ctx->back_stmt();
}

and Stmt *Block::lookup_var(const Identifier &ident) to treat loop vars specifically?
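
One way to picture treating loop vars specifically during lookup is a scope that keeps loop variables apart from ordinary allocas and reports which kind it found, so the caller can emit a loop index instead of a local load. Everything below (`Scope`, `Resolved`, `lookup`) is a hypothetical toy and not the signature of `Block::lookup_var`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// What an identifier resolved to in the toy scope below.
enum class Kind { NotFound, Alloca, LoopIndex };

struct Resolved {
  Kind kind;
  int slot;  // alloca id, or position of the loop index
};

struct Scope {
  const Scope *parent = nullptr;
  std::unordered_map<std::string, int> allocas;    // ordinary locals
  std::unordered_map<std::string, int> loop_vars;  // loop variables

  // Search this scope, then the enclosing ones. Loop variables are
  // reported separately so no alloca/load pair is needed for them.
  Resolved lookup(const std::string &name) const {
    for (const Scope *s = this; s != nullptr; s = s->parent) {
      auto lv = s->loop_vars.find(name);
      if (lv != s->loop_vars.end())
        return {Kind::LoopIndex, lv->second};
      auto al = s->allocas.find(name);
      if (al != s->allocas.end())
        return {Kind::Alloca, al->second};
    }
    return {Kind::NotFound, -1};
  }
};
```

In the IR dump above, a lookup of `@tmp1` would then resolve to a loop index rather than to the `$99` alloca followed by a local load.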

@yuanming-hu
Member

Sounds good. Feel free to pick the most convenient solution for now (keep in mind the presence of #925)...

@xumingkuan
Contributor Author

> Feel free to pick the most convenient solution for now

Shall I do it in this PR or the next?

> keep in mind the presence of #925

I think even if lower_ast is in python, dealing with for loops in lower_ast will still be easier than in transformer.

xumingkuan changed the title from "[IR][Refactor] Convert loop_var into LoopIndexStmt" to "[IR][refactor] Convert loop_var into LoopIndexStmt" on May 14, 2020
@yuanming-hu
Member

> > Feel free to pick the most convenient solution for now
>
> Shall I do it in this PR or the next?

Next PR sounds good.

> > keep in mind the presence of #925
>
> I think even if lower_ast is in python, dealing with for loops in lower_ast will still be easier than in transformer.

That's true. Testing if a variable is a loop variable is easier in AST than in transformer.py.

Collaborator

@archibate left a comment

OpenGL is OK, thank you!

@archibate
Collaborator

codegen_kernel_statements                    194 ->   225   +16.0%
codegen_kernel_statements                     47 ->    55   +17.0%
codegen_kernel_statements                     53 ->    58    +9.4%
codegen_kernel_statements                     26 ->     0  -100.0%
codegen_kernel_statements                     57 ->   150  +163.2%
codegen_kernel_statements                     15 ->    16    +6.7%
codegen_kernel_statements                    135 ->     0  -100.0%
codegen_kernel_statements                     72 ->   146  +102.8%
codegen_kernel_statements                     16 ->    27   +68.8%
codegen_kernel_statements                     21 ->    40   +90.5%
codegen_kernel_statements                     17 ->    19   +11.8%
codegen_kernel_statements                     37 ->     0  -100.0%
codegen_kernel_statements                     37 ->    38    +2.7%
codegen_kernel_statements                      0 ->    72    +inf%
codegen_kernel_statements                     18 ->    20   +11.1%
codegen_kernel_statements                     24 ->    10   -58.3%
codegen_kernel_statements                     53 ->    28   -47.2%
codegen_kernel_statements                    135 ->   150   +11.1%
codegen_kernel_statements                     74 ->    62   -16.2%
codegen_kernel_statements                     40 ->    39    -2.5%
codegen_kernel_statements                     53 ->    17   -67.9%
codegen_kernel_statements                     24 ->    12   -50.0%
codegen_kernel_statements                     22 ->   136  +518.2%
codegen_kernel_statements                     17 ->    19   +11.8%
codegen_kernel_statements                     16 ->    12   -25.0%
codegen_kernel_statements                     50 ->    51    +2.0%
codegen_kernel_statements                     32 ->    16   -50.0%
codegen_kernel_statements                      0 ->    11    +inf%
codegen_kernel_statements                     16 ->    18   +12.5%
codegen_kernel_statements                     18 ->    53  +194.4%
codegen_kernel_statements                     20 ->    34   +70.0%
codegen_kernel_statements                     27 ->    18   -33.3%
codegen_kernel_statements                     34 ->    36    +5.9%
codegen_kernel_statements                     27 ->    29    +7.4%
codegen_kernel_statements                     12 ->    23   +91.7%
codegen_kernel_statements                     14 ->    85  +507.1%
codegen_kernel_statements                     37 ->    38    +2.7%
codegen_kernel_statements                    158 ->     0  -100.0%
codegen_kernel_statements                     37 ->    38    +2.7%
codegen_kernel_statements                     60 ->    66   +10.0%
codegen_kernel_statements                     15 ->     0  -100.0%
codegen_kernel_statements                     42 ->    44    +4.8%
codegen_kernel_statements                     16 ->    18   +12.5%
codegen_kernel_statements                      3 ->    52 +1633.3%
codegen_kernel_statements                     12 ->    65  +441.7%
codegen_kernel_statements                   2789 ->    11   -99.6%
codegen_kernel_statements                     50 ->    37   -26.0%
codegen_kernel_statements                   2085 ->  2084    -0.0%
codegen_kernel_statements                      0 ->    22    +inf%
codegen_kernel_statements                     69 ->     3   -95.7%
codegen_kernel_statements                     12 ->    51  +325.0%
codegen_kernel_statements                   2788 ->  2787    -0.0%
codegen_kernel_statements                     71 ->    83   +16.9%
codegen_kernel_statements                     38 ->    41    +7.9%
codegen_kernel_statements                     16 ->    18   +12.5%
codegen_kernel_statements                     28 ->   158  +464.3%
codegen_kernel_statements                     17 ->    18    +5.9%
codegen_kernel_statements                      0 ->  2788    +inf%
codegen_kernel_statements                     38 ->    53   +39.5%
codegen_kernel_statements                     52 ->     0  -100.0%
codegen_kernel_statements                     23 ->    24    +4.3%
codegen_kernel_statements                     76 ->    70    -7.9%
codegen_kernel_statements                     17 ->    18    +5.9%
codegen_kernel_statements                     10 ->    32  +220.0%
codegen_kernel_statements                     17 ->    24   +41.2%
codegen_kernel_statements                     48 ->    16   -66.7%
codegen_kernel_statements                     68 ->    37   -45.6%
codegen_kernel_statements                     38 ->    37    -2.6%
codegen_kernel_statements                      4 ->    50 +1150.0%
codegen_kernel_statements                    150 ->   148    -1.3%
codegen_kernel_statements                     20 ->    19    -5.0%
codegen_kernel_statements                   2727 ->  2726    -0.0%
codegen_kernel_statements                      4 ->    16  +300.0%
codegen_kernel_statements                     18 ->    17    -5.6%
codegen_kernel_statements                     64 ->    65    +1.6%
codegen_kernel_statements                     16 ->   101  +531.2%
codegen_kernel_statements                     17 ->    23   +35.3%
codegen_kernel_statements                     18 ->    24   +33.3%
codegen_kernel_statements                     34 ->    35    +2.9%
codegen_kernel_statements                     28 ->    22   -21.4%
codegen_kernel_statements                     20 ->    12   -40.0%
codegen_kernel_statements                     16 ->    52  +225.0%
codegen_kernel_statements                     10 ->    72  +620.0%
codegen_kernel_statements                     55 ->     0  -100.0%
codegen_kernel_statements                     20 ->    21    +5.0%
codegen_kernel_statements                     76 ->    69    -9.2%
codegen_kernel_statements                     28 ->    29    +3.6%
codegen_kernel_statements                     34 ->     0  -100.0%
codegen_kernel_statements                     18 ->    19    +5.6%
codegen_kernel_statements                     38 ->    10   -73.7%
codegen_kernel_statements                     70 ->    56   -20.0%
codegen_kernel_statements                     38 ->    36    -5.3%
codegen_kernel_statements                     34 ->    48   +41.2%
codegen_kernel_statements                     61 ->   158  +159.0%
codegen_kernel_statements                     16 ->    18   +12.5%
codegen_kernel_statements                     56 ->     4   -92.9%
codegen_kernel_statements                     28 ->    10   -64.3%
codegen_kernel_statements                      0 ->    52    +inf%
codegen_kernel_statements                     20 ->    34   +70.0%
codegen_kernel_statements                     29 ->    33   +13.8%
codegen_kernel_statements                      0 ->    12    +inf%
codegen_kernel_statements                     32 ->    56   +75.0%
codegen_kernel_statements                     45 ->    52   +15.6%
codegen_kernel_statements                    100 ->   117   +17.0%
codegen_kernel_statements                      4 ->    22  +450.0%
codegen_kernel_statements                    148 ->    38   -74.3%
codegen_kernel_statements                     12 ->    51  +325.0%
codegen_kernel_statements                     22 ->    28   +27.3%
codegen_kernel_statements                     20 ->    26   +30.0%
codegen_kernel_statements                     23 ->    27   +17.4%
codegen_kernel_statements                     24 ->     0  -100.0%
codegen_kernel_statements                     16 ->    17    +6.2%
codegen_kernel_statements                     41 ->    43    +4.9%
codegen_kernel_statements                     62 ->    27   -56.5%
codegen_kernel_statements                     26 ->     4   -84.6%
codegen_kernel_statements                      0 ->    53    +inf%
codegen_kernel_statements                     14 ->    24   +71.4%
codegen_kernel_statements                     26 ->    14   -46.2%
codegen_kernel_statements                     69 ->   116   +68.1%
codegen_kernel_statements                     19 ->    57  +200.0%
codegen_kernel_statements                     51 ->    62   +21.6%
codegen_kernel_statements                    112 ->    78   -30.4%
codegen_kernel_statements                     28 ->    48   +71.4%
codegen_kernel_statements                     36 ->     0  -100.0%
codegen_kernel_statements                     63 ->    64    +1.6%
codegen_kernel_statements                    159 ->   261   +64.2%
codegen_kernel_statements                     62 ->    18   -71.0%
codegen_kernel_statements                     62 ->    28   -54.8%
codegen_kernel_statements                     62 ->    61    -1.6%
codegen_kernel_statements                     24 ->    26    +8.3%
codegen_kernel_statements                     18 ->    19    +5.6%
codegen_kernel_statements                      0 ->    28    +inf%
codegen_kernel_statements                     18 ->    17    -5.6%
codegen_kernel_statements                     40 ->    45   +12.5%
codegen_kernel_statements                    186 ->   225   +21.0%
codegen_kernel_statements                    158 ->     0  -100.0%
codegen_kernel_statements                     52 ->    70   +34.6%
codegen_kernel_statements                     10 ->    34  +240.0%
codegen_kernel_statements                     18 ->    17    -5.6%
codegen_kernel_statements                     37 ->    38    +2.7%
codegen_kernel_statements                     41 ->    43    +4.9%
codegen_kernel_statements                     14 ->    13    -7.1%
codegen_kernel_statements                      0 ->    15    +inf%
codegen_kernel_statements                     56 ->    15   -73.2%
codegen_kernel_statements                     11 ->    27  +145.5%
codegen_kernel_statements                      0 ->   148    +inf%
codegen_kernel_statements                     16 ->    46  +187.5%
codegen_kernel_statements                    112 ->    78   -30.4%
codegen_kernel_statements                     11 ->    18   +63.6%
codegen_kernel_statements                     23 ->    65  +182.6%
codegen_kernel_statements                     90 ->    32   -64.4%
codegen_kernel_statements                     72 ->     0  -100.0%
codegen_kernel_statements                     16 ->    18   +12.5%
codegen_kernel_statements                     49 ->    29   -40.8%
codegen_kernel_statements                     52 ->     4   -92.3%
codegen_kernel_statements                     18 ->    17    -5.6%
codegen_kernel_statements                    136 ->    16   -88.2%
codegen_kernel_statements                      0 ->    26    +inf%
codegen_kernel_statements                     14 ->    16   +14.3%
codegen_kernel_statements                     21 ->    26   +23.8%
codegen_kernel_statements                     18 ->    20   +11.1%
codegen_kernel_statements                     53 ->    54    +1.9%
codegen_kernel_statements                    146 ->     0  -100.0%

Member

@k-ye left a comment

Great! I can confirm that the change works for Metal as well, thx!

@@ -72,8 +72,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
   std::vector<OffloadedTask> offloaded_tasks;
   BasicBlock *func_body_bb;

-  std::unordered_map<OffloadedStmt *, std::vector<llvm::Value *>>
-      offloaded_loop_vars_llvm;
+  std::unordered_map<Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;
Member

nit: is it possible to use const Stmt* here? i don't think the key is modified?
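
A sketch of what the nit amounts to, under the assumption that the map is only used for lookups keyed by the loop statement. `Stmt` is a toy type here and `int` stands in for `llvm::Value *` so the example needs no LLVM headers; `LoopVarTable` and `loop_var_or` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Toy statement type; the real key would be an IR statement.
struct Stmt {
  int id;
};

// Keying on `const Stmt *` documents that the map never mutates the
// statement, and lets callers holding only a const pointer look it up.
using LoopVarTable = std::unordered_map<const Stmt *, std::vector<int>>;

// Returns the value stored for (loop, index), or `fallback` if absent.
int loop_var_or(const LoopVarTable &table, const Stmt *loop,
                std::size_t index, int fallback) {
  auto it = table.find(loop);
  if (it == table.end() || index >= it->second.size())
    return fallback;
  return it->second[index];
}
```

A non-const `Stmt *` still converts implicitly to the key type, so existing insertion sites would not need to change.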

@xumingkuan
Contributor Author

> codegen_kernel_statements                    194 ->   225   +16.0%
> codegen_kernel_statements                     47 ->    55   +17.0%
> codegen_kernel_statements                     53 ->    58    +9.4%
> codegen_kernel_statements                     26 ->     0  -100.0%
> ...
> codegen_kernel_statements                      0 ->    26    +inf%
> codegen_kernel_statements                     14 ->    16   +14.3%
> codegen_kernel_statements                     21 ->    26   +23.8%
> codegen_kernel_statements                     18 ->    20   +11.1%
> codegen_kernel_statements                     53 ->    54    +1.9%
> codegen_kernel_statements                    146 ->     0  -100.0%

I think maybe we're encountering this issue #656 (comment), that is, some tests are mismatched, so there are some -100.0% and +inf%.
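
To see how mismatched rows produce those two values, here is a sketch of a relative-change computation. It assumes the benchmark tool uses the usual (after - before) / before formula, which the thread does not confirm; `percent_change` is a hypothetical name, and returning 0 for the 0 -> 0 case is an arbitrary choice made here.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Relative change in percent between two statement counts.
double percent_change(int before, int after) {
  if (before == 0) {
    if (after == 0)
      return 0.0;  // assumption: no change when both sides are empty
    // A row that exists only on the "after" side has no baseline.
    return std::numeric_limits<double>::infinity();
  }
  return 100.0 * (after - before) / before;
}
```

A test that is present only before the change shows up as 26 -> 0, which is -100.0%, and a test present only after shows up as 0 -> 26, which is +inf%, matching the pattern in the quoted output.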

xumingkuan merged commit c6086e0 into taichi-dev:master on May 14, 2020
xumingkuan deleted the loop-index branch on May 27, 2020 20:13

5 participants