-
Notifications
You must be signed in to change notification settings - Fork 655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix the usage of the repeat function for embedding #2590
Fix the usage of the repeat function for embedding #2590
Conversation
Can you add this as a test as well to show what is being fixed and ensure it doesn't break again? You can add it to |
Codecov ReportPatch coverage:
❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more. Additional details and impacted files@@ Coverage Diff @@
## master #2590 +/- ##
============================================
+ Coverage 72.08% 73.69% +1.60%
- Complexity 5126 6952 +1826
============================================
Files 473 684 +211
Lines 21970 30333 +8363
Branches 2351 3138 +787
============================================
+ Hits 15838 22354 +6516
- Misses 4925 6453 +1528
- Partials 1207 1526 +319
☔ View full report in Codecov by Sentry. |
Description
embeddedTokens = embedding.embed(manager, { 1, 2 }) ;
gets the following error.
Exception in thread "main" java.lang.IllegalArgumentException: The desired shape has too many dimensions
at ai.djl.pytorch.engine.PtNDArray.repeatsToMatchShape(PtNDArray.java:1330)
at ai.djl.pytorch.engine.PtNDArray.repeat(PtNDArray.java:1323)
at ai.djl.pytorch.engine.PtNDArray.repeat(PtNDArray.java:39)
at ai.djl.nn.core.ConstantEmbedding.embed(ConstantEmbedding.java:108)
To fix this error, simplify to repeat the length of the input and reshape.