Skip to content

Commit

Permalink
Truncated-Normal Operation
Browse files Browse the repository at this point in the history
  • Loading branch information
AzizZayed committed Jun 12, 2021
1 parent fccdc01 commit 2b1c526
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
package ai.djl.ndarray;

import ai.djl.Device;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.PairList;
import ai.djl.util.RandomUtils;
import java.nio.Buffer;
import java.nio.file.Path;
import java.util.UUID;
Expand Down Expand Up @@ -157,24 +157,19 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy
/** {@inheritDoc} */
@Override
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
float leftClip = loc - 2 * scale;
float rightClip = loc + 2 * scale;
int sampleSize = (int) shape.size();

Shape sampleShape = new Shape();
double[] dist = new double[sampleSize];

for (int i = 0; i < sampleSize; i++) {

NDArray sample = randomNormal(loc, scale, sampleShape, DataType.FLOAT64);
while (sample.getDouble() < leftClip || sample.getDouble() > rightClip) {
sample = randomNormal(loc, scale, sampleShape, DataType.FLOAT64);
double sample = RandomUtils.nextGaussian();
while (sample < -2 || sample > 2) {
sample = RandomUtils.nextGaussian();
}

dist[i] = sample.getDouble();
dist[i] = sample;
}

return create(dist).reshape(shape).toType(dataType, false);
return create(dist).addi(loc).muli(scale).reshape(shape).toType(dataType, false);
}

/** {@inheritDoc} */
Expand Down

0 comments on commit 2b1c526

Please sign in to comment.