Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add matmul ONNX op support #1638

Merged
merged 9 commits into from
Apr 18, 2024
Merged

Add matmul ONNX op support #1638

merged 9 commits into from
Apr 18, 2024

Conversation

laggui
Copy link
Member

@laggui laggui commented Apr 15, 2024

Related Issues/PRs

Progress towards #1544

Changes

Added matmul ONNX op import support and marked mul (already supported).

Testing

Added unit tests for codegen and expected outputs

@laggui laggui requested a review from antimora April 15, 2024 16:01
Copy link

codecov bot commented Apr 15, 2024

Codecov Report

Attention: Patch coverage is 96.13527% with 8 lines in your changes are missing coverage. Please review.

Project coverage is 86.38%. Comparing base (2a721a9) to head (1174d64).

Files Patch % Lines
crates/burn-import/src/burn/node/matmul.rs 94.77% 7 Missing ⚠️
crates/burn-import/src/onnx/dim_inference.rs 94.73% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1638      +/-   ##
==========================================
+ Coverage   86.35%   86.38%   +0.03%     
==========================================
  Files         693      693              
  Lines       80271    80473     +202     
==========================================
+ Hits        69315    69519     +204     
+ Misses      10956    10954       -2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@antimora antimora left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this OP.

I think we have opportunity here to fix the broadcasting issue as well.

Burn supports a partial broadcasting (see my issue regarding this: #1499) where if one of the dimensions is 1, then it's broadcasted. But broadcasting won't work if dimensions do not match (2 vs 1) because Burn cannot dynamically expand static constant type (D). However, there is an easy workaround the user can apply (as I mentioned in #1499) - namely use unsqueeze operations. But the user needs to know which dimension is greater. So we can do the same here since we are just generating rust code. We just need to know if dimensions do not match, we add unsqueeze.

I hope it's easy change for you but with this we will be 100% ONNX compatible. The same logic can be used for other ops such as addition.

@laggui
Copy link
Member Author

laggui commented Apr 15, 2024

Oh it wasn't clear even to me that this type of broadcasting was supported lol.

Will fix! Thanks for the comments.

@antimora
Copy link
Collaborator

Oh it wasn't clear even to me that this type of broadcasting was supported lol.

Will fix! Thanks for the comments.

Same! That's why I have added this ticket (#1499) to document it in details. I think many will be confused by it too.

Copy link
Member

@nathanielsimard nathanielsimard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Collaborator

@antimora antimora left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for completing this task and making matmul broadcastable in broadest sense.

Also thanks for explaining difference case scenarios offline to me. We definitely need a documentation to explain Burn's broadcasting rules.

@laggui laggui merged commit 7705fd9 into main Apr 18, 2024
15 checks passed
@laggui laggui deleted the feat/onnx/matmul branch April 18, 2024 13:20
syl20bnr pushed a commit that referenced this pull request Apr 26, 2024
* Mul onnx op already supported

* Add matmul onnx op checks and tests

* Add missing eq derives

* Change supscript symbol

* Remove dead code

* Add support for matmul broadcast

* No more broadcasting restrictions

* Add results comment for mm, mv and vm
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants