-
Notifications
You must be signed in to change notification settings - Fork 440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add matmul ONNX op support #1638
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1638 +/- ##
==========================================
+ Coverage 86.35% 86.38% +0.03%
==========================================
Files 693 693
Lines 80271 80473 +202
==========================================
+ Hits 69315 69519 +204
+ Misses 10956 10954 -2 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this OP.
I think we have opportunity here to fix the broadcasting issue as well.
Burn supports a partial broadcasting (see my issue regarding this: #1499) where if one of the dimensions is 1, then it's broadcasted. But broadcasting won't work if dimensions do not match (2 vs 1) because Burn cannot dynamically expand static constant type (D). However, there is an easy workaround the user can apply (as I mentioned in #1499) - namely use unsqueeze
operations. But the user needs to know which dimension is greater. So we can do the same here since we are just generating rust code. We just need to know if dimensions do not match, we add unsqueeze
.
I hope it's easy change for you but with this we will be 100% ONNX compatible. The same logic can be used for other ops such as addition.
Oh it wasn't clear even to me that this type of broadcasting was supported lol. Will fix! Thanks for the comments. |
Same! That's why I have added this ticket (#1499) to document it in details. I think many will be confused by it too. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for completing this task and making matmul broadcastable in broadest sense.
Also thanks for explaining difference case scenarios offline to me. We definitely need a documentation to explain Burn's broadcasting rules.
* Mul onnx op already supported * Add matmul onnx op checks and tests * Add missing eq derives * Change supscript symbol * Remove dead code * Add support for matmul broadcast * No more broadcasting restrictions * Add results comment for mm, mv and vm
Related Issues/PRs
Progress towards #1544
Changes
Added matmul ONNX op import support and marked mul (already supported).
Testing
Added unit tests for codegen and expected outputs