Use full-vectorized load instructions for load vectorization #445

Merged
merged 6 commits into ROCm:triton-mlir on Jan 18, 2024

Conversation

@htyu

htyu commented Jan 9, 2024

I'm not quite sure why the existing code for load vectorization uses segmented short-vectorized loads instead of a full 128-bit load. Using multiple copies of a shorter load creates a dependency on the LLVM backend (esp. the load and store vectorizer) for full vectorization. This can be fragile: in some cases I saw the vector combine pass and the jump threading pass screw it up, resulting in non-ideal vectorization.
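
For illustration only, here is a rough sketch of the two lowerings in the same LLVM-dialect notation used later in this thread (the value and pointer names are hypothetical, not taken from the actual codegen):

    // Segmented lowering: four 32-bit loads that the LLVM backend is expected
    // to re-vectorize into a single 128-bit load.
    %v0 = llvm.load %p0 : !llvm.ptr<i32>
    %v1 = llvm.load %p1 : !llvm.ptr<i32>
    %v2 = llvm.load %p2 : !llvm.ptr<i32>
    %v3 = llvm.load %p3 : !llvm.ptr<i32>

    // Full-vectorized lowering: emit the 128-bit load directly.
    %pv = llvm.bitcast %p0 : !llvm.ptr<i32> to !llvm.ptr<vector<4xi32>>
    %v  = llvm.load %pv : !llvm.ptr<vector<4xi32>>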

@zhanglx13

zhanglx13 commented Jan 9, 2024

@htyu Thanks for looking into this issue. However, I don't think this is the right way to solve the problem.

  • maxWordWidth refers to the size of a word, rather than the vector size, which is indicated by vec. So I don't think we should change it to 128 here.
  • That part of the code is shared by the NV and AMD GPU paths. Changing it will also break the NV path tests.
  • We are aware of the fact that in some cases the LLVM backend cannot (or has a good reason not to) vectorize global loads. For now we need to massage the address computation carefully (like the one here) so that the LLVM backend can do vectorized global loads for us. And if we really want vectorized global loads but the LLVM backend won't produce them, the right way would be to generate inline assembly directly, as the NV path does.

If there is an issue with global load vectorization in your customized kernel, we are happy to help.

@htyu
Author

htyu commented Jan 9, 2024

Thanks for the comments.

  • We are aware of the fact that in some cases the LLVM backend cannot (or has a good reason not to) vectorize global loads. For now we need to massage the address computation carefully

What LLVM instructions do you expect to generate with careful address computation? A fully vectorized load or a sequence of shorter loads? I'm also seeing a long load survive the LLVM backend more reliably than the latter.

Inline assembly should work, but I'm not sure whether it has side effects on other LLVM optimizations. What problem do you see with the long load?

@zhanglx13

What LLVM instructions do you expect to generate with careful address computation? A fully vectorized load or a sequence of shorter loads? I'm also seeing a long load survive the LLVM backend more reliably than the latter.

I expect a sequence of shorter loads. We haven't tried a fully vectorized global load at the LLVM level since we are trying to reuse as much code from the NV path as possible.
I agree that inline assembly can have issues related to memory synchronization, so I'm happy to see that we can avoid using inline assembly to solve the issue.

There are a lot of failed tests. Can you make them pass first?

@htyu
Author

htyu commented Jan 9, 2024

What LLVM instructions do you expect to generate with careful address computation? A fully vectorized load or a sequence of shorter loads? I'm also seeing a long load survive the LLVM backend more reliably than the latter.

I expect a sequence of shorter loads. We haven't tried a fully vectorized global load at the LLVM level since we are trying to reuse as much code from the NV path as possible. I agree that inline assembly can have issues related to memory synchronization, so I'm happy to see that we can avoid using inline assembly to solve the issue.

There are a lot of failed tests. Can you make them pass first?

Sure, I'll work on clearing the test failures and making sure the change doesn't affect the NV path.

P.S., the problem I was seeing is that the VectorCombine pass converted the original four i32 loads into four 2xi16 loads. Then the jump threading pass threaded the four 2xi16 loads by getting rid of the redundant mask checks for the first three loads. During the threading, the first three loads were further decomposed into six i16 loads, which were vectorized later. The fourth 2xi16 load was excluded from that vectorization because it was already vectorized. An 8xi16 load in the first place appeared to be immune to all of those issues.
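
Roughly, and only as an illustration of the load shapes described above (not the actual pass output; the names are hypothetical), the backend ended up with a mix like this, whereas emitting a single 8xi16 (128-bit) load up front avoids the splitting altogether:

    // After VectorCombine + jump threading: a mix of scalar and short-vector loads.
    %a = llvm.load %q0 : !llvm.ptr<i16>            // six i16 loads of this form, plus ...
    %b = llvm.load %q6 : !llvm.ptr<vector<2xi16>>  // ... one already-vectorized 2xi16 load

    // Full-width alternative: one 128-bit load from the start.
    %w = llvm.load %qv : !llvm.ptr<vector<8xi16>>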

@bertmaher

We root-caused the bad performance of RMS norm #422 to this issue. It seems like the combination of the for-loop and the control-flow-based load masking confuses the load/store vectorizer, and we end up with a dword+dword3 load instead of a dword4 load.

@htyu
Author

htyu commented Jan 9, 2024

We root-caused the bad performance of RMS norm #422 to this issue. It seems like the combination of the for-loop and the control-flow-based load masking confuses the load/store vectorizer, and we end up with a dword+dword3 load instead of a dword4 load.

This is exactly the issue I'm fixing here.

@zhanglx13

CC +@scxiao

@codego7250

We root-caused the bad performance of RMS norm #422 to this issue. It seems like the combination of the for-loop and the control-flow-based load masking confuses the load/store vectorizer, and we end up with a dword+dword3 load instead of a dword4 load.

This is a known issue. It can be worked around in hand-written code; for compiler-generated code, we may consider creating a new primitive-based solution later (one that relies on the user to ensure correctness). So far we stay compatible with the Triton NVIDIA GPU API, which puts correctness at a higher priority than performance.

@htyu
Author

htyu commented Jan 10, 2024

We root-caused the bad performance of RMS norm #422 to this issue. It seems like the combination of the for-loop and the control-flow-based load masking confuses the load/store vectorizer, and we end up with a dword+dword3 load instead of a dword4 load.

This is a known issue. It can be worked around in hand-written code; for compiler-generated code, we may consider creating a new primitive-based solution later (one that relies on the user to ensure correctness). So far we stay compatible with the Triton NVIDIA GPU API, which puts correctness at a higher priority than performance.

I think in this case we should still be able to fix the compiler for performance without affecting correctness.

One path I'm taking is to fix the LLVM GPU load/store vectorizer, where I saw that the scalar evolution pass was not able to infer that two addresses were consecutive. FYI: https://discourse.llvm.org/t/how-to-compare-scevs/76174

Since the redundant load masks (which we rely on LLVM to eliminate) were causing the issue, I'm also taking another route that avoids generating the redundant checks. Please see if the new version looks reasonable; I have yet to fix the test failures. The codegen now looks like:

     %129:4 = scf.if %128 -> (i32, i32, i32, i32) {
      %251 = llvm.addrspacecast %124 : !llvm.ptr<f32, 1> to !llvm.ptr<i32>
      %252 = llvm.load %251 : !llvm.ptr<i32>
      %253 = llvm.addrspacecast %125 : !llvm.ptr<f32, 1> to !llvm.ptr<i32>
      %254 = llvm.load %253 : !llvm.ptr<i32>
      %255 = llvm.addrspacecast %126 : !llvm.ptr<f32, 1> to !llvm.ptr<i32>
      %256 = llvm.load %255 : !llvm.ptr<i32>
      %257 = llvm.addrspacecast %127 : !llvm.ptr<f32, 1> to !llvm.ptr<i32>
      %258 = llvm.load %257 : !llvm.ptr<i32>
      scf.yield %252, %254, %256, %258 : i32, i32, i32, i32
    } else {
      %251 = llvm.mlir.constant(0 : i32) : i32
      %252 = llvm.mlir.constant(0 : i32) : i32
      %253 = llvm.mlir.constant(0 : i32) : i32
      %254 = llvm.mlir.constant(0 : i32) : i32
      scf.yield %251, %252, %253, %254 : i32, i32, i32, i32
    }

    %130 = llvm.bitcast %129#0 : i32 to vector<1xf32>
    %131 = llvm.mlir.constant(0 : index) : i32
    %132 = llvm.extractelement %130[%131 : i32] : vector<1xf32>
    %133 = llvm.bitcast %129#1 : i32 to vector<1xf32>
    %134 = llvm.mlir.constant(0 : index) : i32
    %135 = llvm.extractelement %133[%134 : i32] : vector<1xf32>
    %136 = llvm.bitcast %129#2 : i32 to vector<1xf32>
    %137 = llvm.mlir.constant(0 : index) : i32
    %138 = llvm.extractelement %136[%137 : i32] : vector<1xf32>
    %139 = llvm.bitcast %129#3 : i32 to vector<1xf32>
    %140 = llvm.mlir.constant(0 : index) : i32
    %141 = llvm.extractelement %139[%140 : i32] : vector<1xf32>

@codego7250

We root-caused the bad performance of RMS norm #422 to this issue. It seems like the combination of the for-loop and the control-flow-based load masking confuses the load/store vectorizer, and we end up with a dword+dword3 load instead of a dword4 load.

This is a known issue. It can be worked around in hand-written code; for compiler-generated code, we may consider creating a new primitive-based solution later (one that relies on the user to ensure correctness). So far we stay compatible with the Triton NVIDIA GPU API, which puts correctness at a higher priority than performance.

I think in this case we should still be able to fix the compiler for performance without affecting correctness.

One path I'm taking is to fix the LLVM GPU load/store vectorizer, where I saw that the scalar evolution pass was not able to infer that two addresses were consecutive. FYI: https://discourse.llvm.org/t/how-to-compare-scevs/76174

Since the redundant load masks (which we rely on LLVM to eliminate) were causing the issue, I'm also taking another route that avoids generating the redundant checks. Please see if the new version looks reasonable; I have yet to fix the test failures. The codegen now looks like:

[codegen snippet identical to the one quoted above; omitted]

This looks good. And there may be corner cases around the predicate, etc. Let's make sure this works for all of them.

@htyu
Author

htyu commented Jan 16, 2024

We root-caused the bad performance of RMS norm #422 to this issue. It seems like the combination of the for-loop and the control-flow-based load masking confuses the load/store vectorizer, and we end up with a dword+dword3 load instead of a dword4 load.

This is a known issue. It can be worked around in hand-written code; for compiler-generated code, we may consider creating a new primitive-based solution later (one that relies on the user to ensure correctness). So far we stay compatible with the Triton NVIDIA GPU API, which puts correctness at a higher priority than performance.

I think in this case we should still be able to fix the compiler for performance without affecting correctness.
One path I'm taking is to fix the LLVM GPU load/store vectorizer, where I saw that the scalar evolution pass was not able to infer that two addresses were consecutive. FYI: https://discourse.llvm.org/t/how-to-compare-scevs/76174
Since the redundant load masks (which we rely on LLVM to eliminate) were causing the issue, I'm also taking another route that avoids generating the redundant checks. Please see if the new version looks reasonable; I have yet to fix the test failures. The codegen now looks like:

[codegen snippet identical to the one quoted above; omitted]

This looks good. And there may be corner cases around the predicate, etc. Let's make sure this works for all of them.

Thanks. But on second thought, I'm inclined to generate a fully vectorized load when possible. That should make the code more immune to LLVM's uncertainty, and it also reduces the size of the LLVM IR, improving compile time. Please check my latest version and see if it looks good.
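
For reference, a rough sketch of what the fully vectorized lowering looks like, replacing the four guarded i32 loads shown earlier with a single guarded 128-bit load (value names are hypothetical; this is not the exact output of the patch):

     %v = scf.if %mask -> (vector<4xi32>) {
       // One 128-bit load under a single mask check.
       %p = llvm.addrspacecast %base : !llvm.ptr<f32, 1> to !llvm.ptr<vector<4xi32>>
       %l = llvm.load %p : !llvm.ptr<vector<4xi32>>
       scf.yield %l : vector<4xi32>
     } else {
       // Masked-off case yields zeros, as in the scalar version above.
       %z = llvm.mlir.constant(dense<0> : vector<4xi32>) : vector<4xi32>
       scf.yield %z : vector<4xi32>
     }
     // Reinterpret and extract the individual f32 elements as before.
     %f = llvm.bitcast %v : vector<4xi32> to vector<4xf32>
     %c0 = llvm.mlir.constant(0 : index) : i32
     %e0 = llvm.extractelement %f[%c0 : i32] : vector<4xf32>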

@zhanglx13

@htyu Thanks for fixing this issue.
And I tested on MI250 and MI300 that we don't need this trick anymore: https://github.com/ROCmSoftwarePlatform/triton/blob/e7033218d6a0f0f1129aa3adc1bfbbe57c84fd20/python/tutorials/03-matrix-multiplication.py#L253-L257

cc+ @scxiao

@zhanglx13 zhanglx13 self-requested a review January 18, 2024 15:32
@htyu
Author

htyu commented Jan 18, 2024

@htyu Thanks for fixing this issue. And I tested on MI250 and MI300 that we don't need this trick anymore:

https://github.com/ROCmSoftwarePlatform/triton/blob/e7033218d6a0f0f1129aa3adc1bfbbe57c84fd20/python/tutorials/03-matrix-multiplication.py#L253-L257

cc+ @scxiao

Thanks for giving it a try!

BTW, how should I land this patch?

@zhanglx13

@htyu I'll land it.
One more thing: do you think this method would also work for the NV path? If so, it would be better if we could merge the two paths.

@zhanglx13 zhanglx13 merged commit 315528f into ROCm:triton-mlir Jan 18, 2024
2 checks passed
@htyu
Author

htyu commented Jan 18, 2024

@htyu I'll land it. One more thing: do you think this method would also work for the NV path? If so, it would be better if we could merge the two paths.

I'll need to take a deeper look. NV loads come with those cache flags and I'm not sure how to express them in the LLVM dialect. But yeah, I'm in general not in favor of using asm volatiles; it'd be great to get rid of them.

@zhanglx13

@htyu Sounds good. Keep us posted!

jtang10 pushed a commit that referenced this pull request Jan 30, 2024
* Stablize load vectorization

* fix test failures

* Shared one mask check when decomposing a load

* Revert "fix test failures"

This reverts commit 75a461a.

* Emit vectorized loads

* Fix test failures due to using vectorized load
@zhanglx13

@htyu Since we are moving our dev work upstream and closing the perf gap between this fork and upstream, could you please upstream this PR?

@htyu
Author

htyu commented Apr 8, 2024

@htyu Since we are moving our dev work upstream and closing the perf gap between this fork and upstream, could you please upstream this PR?

Sure, will do.

Do you need me to upstream other PRs I made in this repo?

@zhanglx13

Yes, that would be great. Thank you very much ~
I think we can hold the dot-slicing-related PRs for the moment, since they are still experimental.

zahimoud pushed a commit to triton-lang/triton that referenced this pull request Apr 9, 2024
…orization (#3609)

The current implementation of load vectorization uses segmented
short-vectorized loads instead of a full 128-bit load. Using multiple
copies of a shorter load creates a dependency on the LLVM backend (esp.
the load and store vectorizer) for full vectorization. This could be
fragile, as in some cases the vector combine pass and the jump
threading pass screwed it up and resulted in non-ideal vectorization.

This is a backport of ROCm#445
@htyu
Author

htyu commented Apr 9, 2024

Upstreaming PR: triton-lang#3609
