-
Notifications
You must be signed in to change notification settings - Fork 655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Truncated-Normal Operation #1015
Conversation
@@ -156,7 +158,20 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy | |||
/** {@inheritDoc} */ | |||
@Override | |||
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) { | |||
throw new UnsupportedOperationException("Not supported!"); | |||
Random random = new Random(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use RandomUtils.
@@ -13,11 +13,13 @@ | |||
package ai.djl.ndarray; | |||
|
|||
import ai.djl.Device; | |||
import ai.djl.ndarray.internal.NDArrayEx; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove unused imports
Codecov Report
@@ Coverage Diff @@
## master #1015 +/- ##
============================================
+ Coverage 69.97% 70.03% +0.06%
- Complexity 5106 5111 +5
============================================
Files 504 504
Lines 22601 22616 +15
Branches 2367 2369 +2
============================================
+ Hits 15815 15840 +25
+ Misses 5519 5513 -6
+ Partials 1267 1263 -4
Continue to review full report at Codecov.
|
dist[i] = sample; | ||
} | ||
|
||
return create(dist).addi(loc).muli(scale).reshape(shape).toType(dataType, false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@AzizZayed I think you need to multiply before you add. Otherwise, it will scale the addition too.
Truncated-Normal Operation Truncated-Normal Operation (deepjavalibrary#1015) * [tensoflow] Add truncated normal operation * Add truncated-normal to all engines that do not support it * Truncated-Normal Operation
* [tensoflow] Add truncated normal operation * Add truncated-normal to all engines that do not support it Truncated-Normal Operation Truncated-Normal Operation (#1015) * [tensoflow] Add truncated normal operation * Add truncated-normal to all engines that do not support it * Truncated-Normal Operation
Description
The same as #1005 but for other engines.