
Add a matmul test from int8, bf16 #2718

Closed

Conversation

karupayun (Collaborator):

In this PR we are adding a matmul test for int8, bf16.

  • First, I included two new params (see the sketch after this list):
    • acc_dtype: lets users of the test class specify the type used internally in the dot, rather than the one chosen by default from the two operand types. There are several restrictions on these types anyway.
    • output_dtype: the return type of the matmul. I included a few tests for the case of a dot between two float16 operands.
  • I had to modify test_matmul to use a small range of values to prevent numerical issues. When testing with two float16 operands and acc_dtype float16, I ran into precision issues when comparing the results with Triton, since I can't force torch to use float16 internally (it uses float32). In any case, the goal of this test shouldn't be to test precision.
  • I also needed to include torch.int8 in the possible datatypes.
  • I cleaned up/refactored the files a bit.
  • Finally, I tried to simplify the logic of matmul a bit: after adding these two parameters it was hard to follow why every part of the code was needed, so I introduced supported_acc_dtypes for the accumulator types allowed for each combination of the operand types a and b.
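
A minimal sketch (not the test code from this PR) of how the two new parameters could fit together; `supported_acc_dtypes`, `check_dtypes`, and the exact dtype combinations below are illustrative assumptions, not the merged implementation:

```python
import torch

# Assumed mapping from (a_dtype, b_dtype) operand pairs to the accumulator
# dtypes the test would accept; the real restrictions may differ.
supported_acc_dtypes = {
    (torch.float16, torch.float16): (torch.float32, torch.float16),
    (torch.bfloat16, torch.bfloat16): (torch.float32,),
    (torch.float32, torch.float32): (torch.float32,),
    (torch.int8, torch.int8): (torch.int32,),
}

def check_dtypes(a_dtype, b_dtype, acc_dtype):
    # Reject accumulator dtypes that are not allowed for this operand pair.
    allowed = supported_acc_dtypes.get((a_dtype, b_dtype), ())
    if acc_dtype not in allowed:
        raise ValueError(f"acc_dtype={acc_dtype} not supported for {(a_dtype, b_dtype)}")

def torch_reference(a, b, output_dtype):
    # torch accumulates float16 matmuls in float32 internally, so compute the
    # reference in float32 and only cast to output_dtype at the end.
    return torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(output_dtype)
```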

karupayun requested a review from ptillet as a code owner on November 29, 2023 16:33
karupayun force-pushed the feature/matmul-int8-bf16 branch from cbc7049 to 0bd49d4 on November 29, 2023 16:36
ptillet (Collaborator) commented on Nov 29, 2023:

I generally approve of this PR, but could we split it into two PRs -- one that refactors without functionally changing anything, and a separate one that adds the new modes?

PS: I think we only need to narrow down the range for the case of FP16 accumulation

karupayun added a commit to karupayun/triton that referenced this pull request Dec 5, 2023
In this PR we are simplifying the matmul test without changing the logic.

It's the first PR from splitting triton-lang#2718. Follow-ups will add the `output_dtype` parameter and the bf16, int8 matmul test.

Basically:
- I clean up/refactor the files a bit.
- I rename `dot_out_dtype` to `acc_dtype`, because the old name was confusing to me.
- I add `supported_acc_dtypes` for the accumulator types allowed for each combination of the operand types a and b.
karupayun force-pushed the feature/matmul-int8-bf16 branch from 0bd49d4 to 87be2ae on December 5, 2023 18:17
karupayun (Collaborator, Author) commented:

> I generally approve of this PR, but could we split it into two PRs -- one that refactors without functionally changing anything, and a separate one that adds the new modes?

Sorry for taking so long to reply. I will split it into three PRs: the first with no functional change (#2760), the second adding acc_dtype and return_dtype set by users, and the last adding the int8, bfloat16 test.

> PS: I think we only need to narrow down the range for the case of FP16 accumulation

I had a discussion about this topic with @gflegar in openxla#6 (comment). I don't have a strong opinion; I'm OK with doing what you suggest, but it makes the test a bit more complex when we shouldn't be testing precision here. What do you think?
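
As a rough illustration of the alternative, here is a sketch of narrowing the sampled value range only for FP16 accumulation; the bounds and the `make_operand` helper are assumptions for illustration, not code from this PR:

```python
import torch

def make_operand(shape, dtype, acc_dtype):
    # int8 inputs: small integers, so an int32 accumulator stays far from overflow.
    if dtype is torch.int8:
        return torch.randint(-3, 4, shape, dtype=torch.int8)
    # Only float16 accumulation loses precision quickly, so only then shrink
    # the sampled range; other accumulators keep the wider range.
    low, high = (-0.1, 0.1) if acc_dtype is torch.float16 else (-1.0, 1.0)
    return torch.empty(shape, dtype=torch.float32).uniform_(low, high).to(dtype)
```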

ptillet pushed a commit that referenced this pull request Dec 6, 2023
In this PR we are simplifying the matmul test without changing the behavior.

It's the first PR from splitting #2718. Follow-ups will add the `output_dtype` parameter and the bf16, int8 matmul test.

Basically:
- I clean up/refactor the files a bit.
- I rename `dot_out_dtype` to `acc_dtype`, because the old name was confusing to me.
- I add `supported_acc_dtypes` for the accumulator types allowed for each combination of the operand types a and b (see the sketch below).
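
A hedged sketch of how such a table could also supply a default accumulator type after the `dot_out_dtype` to `acc_dtype` rename; the mapping and `resolve_acc_dtype` are assumptions, not the merged code:

```python
import torch

# Assumed allowed accumulator dtypes per operand pair, widest option first.
supported_acc_dtypes = {
    (torch.float16, torch.float16): (torch.float32, torch.float16),
    (torch.bfloat16, torch.bfloat16): (torch.float32,),
    (torch.int8, torch.int8): (torch.int32,),
}

def resolve_acc_dtype(a_dtype, b_dtype, acc_dtype=None):
    allowed = supported_acc_dtypes[(a_dtype, b_dtype)]
    if acc_dtype is None:
        # No explicit request: fall back to the first (widest) allowed dtype.
        return allowed[0]
    assert acc_dtype in allowed, f"{acc_dtype} not allowed for {(a_dtype, b_dtype)}"
    return acc_dtype
```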
karupayun (Collaborator, Author) commented:

This PR was split into #2768, #2769 and #2760. All of them are already merged.

karupayun closed this on Dec 13, 2023
feihugis pushed a commit to feihugis/triton that referenced this pull request Feb 13, 2024
binarman pushed a commit to binarman/triton that referenced this pull request Apr 2, 2024