Skip to content

Commit

Permalink
[tensoflow] Add truncated normal operation (deepjavalibrary#1005)
Browse files Browse the repository at this point in the history
  • Loading branch information
AzizZayed authored Jun 10, 2021
1 parent d8e7e1d commit 0aec8ca
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 0 deletions.
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy
throw new UnsupportedOperationException("Not supported!");
}

/** {@inheritDoc} */
@Override
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
throw new UnsupportedOperationException("Not supported!");
}

/** {@inheritDoc} */
@Override
public NDArray randomMultinomial(int n, NDArray pValues) {
Expand Down
59 changes: 59 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,65 @@ default NDArray randomNormal(
return newSubManager(device).randomNormal(loc, scale, shape, dataType);
}

/**
* Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation
* 1, discarding and re-drawing any samples that are more than two standard deviations from the
* mean.
*
* <p>Samples are distributed according to a normal distribution parametrized by mean = 0 and
* standard deviation = 1.
*
* @param shape the output {@link Shape}
* @return the drawn samples {@link NDArray}
*/
default NDArray truncatedNormal(Shape shape) {
return truncatedNormal(0f, 1f, shape, DataType.FLOAT32);
}

/**
* Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation
* 1, discarding and re-drawing any samples that are more than two standard deviations from the
* mean.
*
* @param shape the output {@link Shape}
* @param dataType the {@link DataType} of the {@link NDArray}
* @return the drawn samples {@link NDArray}
*/
default NDArray truncatedNormal(Shape shape, DataType dataType) {
return truncatedNormal(0.0f, 1.0f, shape, dataType);
}

/**
* Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any
* samples that are more than two standard deviations from the mean.
*
* @param loc the mean (centre) of the distribution
* @param scale the standard deviation (spread or "width") of the distribution
* @param shape the output {@link Shape}
* @param dataType the {@link DataType} of the {@link NDArray}
* @return the drawn samples {@link NDArray}
*/
NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType);

/**
* Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any
* samples that are more than two standard deviations from the mean.
*
* @param loc the mean (centre) of the distribution
* @param scale the standard deviation (spread or "width") of the distribution
* @param shape the output {@link Shape}
* @param dataType the {@link DataType} of the {@link NDArray}
* @param device the {@link Device} of the {@link NDArray}
* @return the drawn samples {@link NDArray}
*/
default NDArray truncatedNormal(
float loc, float scale, Shape shape, DataType dataType, Device device) {
if (device == null || device.equals(getDevice())) {
return truncatedNormal(loc, scale, shape, dataType);
}
return newSubManager(device).truncatedNormal(loc, scale, shape, dataType);
}

/**
* Draw samples from a multinomial distribution.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,31 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy
}
}

/** {@inheritDoc} */
@Override
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
if (DataType.STRING.equals(dataType)) {
throw new IllegalArgumentException("String data type is not supported!");
}
NDArray axes = create(shape.getShape());
TfOpExecutor opBuilder =
opExecutor("TruncatedNormal").addInput(axes).addParam("dtype", dataType);
Integer seed = getEngine().getSeed();
if (seed != null) {
// seed1 is graph-level seed
// set it to default graph seed used by tensorflow
// https://github.com/tensorflow/tensorflow/blob/85c8b2a817f95a3e979ecd1ed95bff1dc1335cff/tensorflow/python/framework/random_seed.py#L31
opBuilder.addParam("seed", 87654321);
opBuilder.addParam("seed2", seed);
}
try (NDArray array = opBuilder.buildSingletonOrThrow();
NDArray temp = array.mul(scale)) {
return temp.add(loc);
} finally {
axes.close();
}
}

/** {@inheritDoc} */
@Override
public TfNDManager newSubManager(Device device) {
Expand Down

0 comments on commit 0aec8ca

Please sign in to comment.