Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backend][amd] Support device print using hostcall #3476

Merged
merged 8 commits into from
Apr 1, 2024

Conversation

antiagainst
Copy link
Collaborator

@antiagainst antiagainst commented Mar 27, 2024

This commit add device printf support in the AMD backend.
It moves the existing NVIDIA lowering logic to the common
conversion library and adds a new method in TargetInfo
for target specific code generation.

This right now only supports the hostcall mode, which
requires PCIe atomics. There is also a buffered mode,
see https://rocm.docs.amd.com/en/docs-5.7.0/release.html#non-hostcall-hip-printf
for details.

Follows implementation in https://reviews.llvm.org/D110448.

@antiagainst antiagainst force-pushed the amd-print branch 5 times, most recently from 40a194d to 8cb5489 Compare March 30, 2024 17:58
@antiagainst antiagainst force-pushed the amd-print branch 3 times, most recently from 04ac6d1 to 28fabee Compare March 31, 2024 18:24
@antiagainst antiagainst force-pushed the amd-print branch 2 times, most recently from ba7c48a to 8241010 Compare March 31, 2024 22:47
@antiagainst antiagainst changed the title [backend] Add printf support to AMD backend [backend][amd] Support device print using hostcall Mar 31, 2024
@antiagainst antiagainst marked this pull request as ready for review March 31, 2024 22:52
@antiagainst
Copy link
Collaborator Author

Note that the commits are structured to make reviewing easier. You can pretty much look at commits one by one; some of them just shuffle code around and have a NFC marker on it. The meaty parts are the last two commits.

Copy link
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome work! Thanks Lei!

message = call(printStrFn, arguments).getResult();

// Emit the intrinsic function call to handle arguments iteratively.
// We can only handle at most 7 values each time.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this going to be a problem? For example if I have a 4D tensor, we will print out 3 program_id's, plus 4 tensor indices, plus one value. So the value will not be printed in the same printf statement as the 7 other values. But then will it be interleaved with other threads' printfs? If so that will make it basically useless...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I haven't looked into how print with hostcall is implemented in the driver stack. But these function calls have a isLast parameter chaining them together--only the last print function will set isLast as 1. That makes me think they are "atomic" in a sense. @scxiao do you know if that's indeed the case? Good to stress test and see how it behaves with more arguments.

@zahimoud zahimoud merged commit 38cd5ab into triton-lang:main Apr 1, 2024
5 checks passed
@antiagainst antiagainst deleted the amd-print branch April 1, 2024 20:37
ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this pull request Aug 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants