Add a matmul test from int8, bf16 #2718
Conversation
Force-pushed from cbc7049 to 0bd49d4
I generally approve of this PR, but could we split it into two PRs -- one that refactors without functionally changing anything, and a separate one that adds the new modes? PS: I think we only need to narrow down the range for the case of FP16 accumulation.
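For context, the narrowing the reviewer refers to could look roughly like the sketch below; the helper name and the scale factor are illustrative assumptions, not code from the PR:

```python
import torch

def make_operands(m, k, n, op_dtype, acc_dtype):
    # Hypothetical helper: shrink the value range only when accumulating in
    # float16, so the fp32-accumulated torch reference stays comparable.
    scale = 0.1 if acc_dtype == torch.float16 else 1.0
    a = (torch.randn(m, k) * scale).to(op_dtype)
    b = (torch.randn(k, n) * scale).to(op_dtype)
    return a, b
```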
In this PR we are simplifying the matmul test without changing the logic. It's the first PR from splitting triton-lang#2718. Follow-ups will add the `output_dtype` parameter and the bf16, int8 matmul test. Basically:
- I cleaned/refactored the files a bit.
- I renamed `dot_out_dtype` to `acc_dtype` because it was confusing for me.
- I added `supported_acc_dtypes` for the allowed accumulator types in relation to the types of the operands a and b (a rough sketch follows below).
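A minimal sketch of what such a `supported_acc_dtypes` table might look like; the concrete entries here are assumptions for illustration, not the exact mapping added in the PR:

```python
import torch

# Hypothetical mapping from (a_dtype, b_dtype) to the accumulator dtypes the
# test accepts; the first entry doubles as the default. Entries are
# illustrative assumptions, not the PR's actual table.
supported_acc_dtypes = {
    (torch.float16, torch.float16): (torch.float32, torch.float16),
    (torch.bfloat16, torch.bfloat16): (torch.float32,),
    (torch.float32, torch.float32): (torch.float32,),
    (torch.int8, torch.int8): (torch.int32,),
}

def default_acc_dtype(a_dtype, b_dtype):
    # The first allowed accumulator type is treated as the default.
    return supported_acc_dtypes[(a_dtype, b_dtype)][0]
```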
In this PR we are adding a matmul test from `int8`, `bf16`.
- First, I included two new params:
  - `acc_dtype`: so users of the test class can specify the type used internally in the dot, instead of the one chosen by default from the two operand types. There are several restrictions on these types anyway.
  - `output_dtype`: the return type of the matmul. I included a few tests for the case of a dot between two `float16` operands.
- I had to modify test_matmul to use a small range of values to prevent numerical issues. When testing with two `float16` operands and `acc_dtype` `float16`, since I can't force torch to use `float16` internally (it uses `float32`), I was having precision issues when comparing the results with triton. Anyway, the goal of this test shouldn't be testing precision.
- I also needed to include `torch.int8` in the possible datatypes.
- I cleaned/refactored the files a bit.
- Finally, I tried to simplify the logic of matmul a bit, because after adding these two parameters it was hard to follow why we needed every part of the code, so I included a `supported_acc_dtypes` for the allowed accumulator types in relation to the types of the operands a and b. (A rough test skeleton is sketched below.)
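Putting the two parameters together, a pytest-style skeleton in the spirit of this PR could look like the following sketch; the kernel launch is replaced by a stand-in, and the parameter combinations and tolerances are assumptions rather than the test's actual matrix:

```python
import pytest
import torch

@pytest.mark.parametrize("ab_dtype, acc_dtype, output_dtype", [
    (torch.float16, torch.float32, torch.float16),
    (torch.float16, torch.float16, torch.float16),   # narrowed-range case
    (torch.bfloat16, torch.float32, torch.bfloat16),
    (torch.int8, torch.int32, torch.int32),
])
def test_matmul_dtypes(ab_dtype, acc_dtype, output_dtype):
    m, k, n = 64, 64, 64
    if ab_dtype == torch.int8:
        # Small integer operands keep int32 accumulation exact.
        a = torch.randint(-3, 3, (m, k), dtype=ab_dtype)
        b = torch.randint(-3, 3, (k, n), dtype=ab_dtype)
        ref = (a.to(torch.int64) @ b.to(torch.int64)).to(output_dtype)
    else:
        # Narrow the range only for fp16 accumulation, since the torch
        # reference always accumulates in fp32.
        scale = 0.1 if acc_dtype == torch.float16 else 1.0
        a = (torch.randn(m, k) * scale).to(ab_dtype)
        b = (torch.randn(k, n) * scale).to(ab_dtype)
        ref = (a.to(torch.float32) @ b.to(torch.float32)).to(output_dtype)
    # In the real test this is where the Triton matmul kernel would be
    # launched with acc_dtype / output_dtype; the reference stands in here.
    out = ref.clone()
    torch.testing.assert_close(out.to(torch.float32), ref.to(torch.float32))
```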
Force-pushed from 0bd49d4 to 87be2ae
Sorry for taking so long to reply. I will split it into 3 PRs: the first without any real change (#2760), the second to add `acc_dtype` and `return_dtype` set by users, and the last to add the int8, bfloat16 test.
I had a discussion about this topic with @gflegar in openxla#6 (comment). I don't have a strong opinion; I'm ok with doing what you suggest, but it makes the test a bit more complex when we shouldn't be testing precision here. What do you think?