Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tensoflow] Add truncated-normal operation #1005

Merged
merged 1 commit into from
Jun 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy
throw new UnsupportedOperationException("Not supported!");
}

/** {@inheritDoc} */
@Override
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
throw new UnsupportedOperationException("Not supported!");
}

/** {@inheritDoc} */
@Override
public NDArray randomMultinomial(int n, NDArray pValues) {
Expand Down
59 changes: 59 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,65 @@ default NDArray randomNormal(
return newSubManager(device).randomNormal(loc, scale, shape, dataType);
}

/**
* Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation
* 1, discarding and re-drawing any samples that are more than two standard deviations from the
* mean.
*
* <p>Samples are distributed according to a normal distribution parametrized by mean = 0 and
* standard deviation = 1.
*
* @param shape the output {@link Shape}
* @return the drawn samples {@link NDArray}
*/
default NDArray truncatedNormal(Shape shape) {
return truncatedNormal(0f, 1f, shape, DataType.FLOAT32);
}

/**
* Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation
* 1, discarding and re-drawing any samples that are more than two standard deviations from the
* mean.
*
* @param shape the output {@link Shape}
* @param dataType the {@link DataType} of the {@link NDArray}
* @return the drawn samples {@link NDArray}
*/
default NDArray truncatedNormal(Shape shape, DataType dataType) {
return truncatedNormal(0.0f, 1.0f, shape, dataType);
}

/**
* Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any
* samples that are more than two standard deviations from the mean.
*
* @param loc the mean (centre) of the distribution
* @param scale the standard deviation (spread or "width") of the distribution
* @param shape the output {@link Shape}
* @param dataType the {@link DataType} of the {@link NDArray}
* @return the drawn samples {@link NDArray}
*/
NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType);

/**
* Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any
* samples that are more than two standard deviations from the mean.
*
* @param loc the mean (centre) of the distribution
* @param scale the standard deviation (spread or "width") of the distribution
* @param shape the output {@link Shape}
* @param dataType the {@link DataType} of the {@link NDArray}
* @param device the {@link Device} of the {@link NDArray}
* @return the drawn samples {@link NDArray}
*/
default NDArray truncatedNormal(
float loc, float scale, Shape shape, DataType dataType, Device device) {
if (device == null || device.equals(getDevice())) {
return truncatedNormal(loc, scale, shape, dataType);
}
return newSubManager(device).truncatedNormal(loc, scale, shape, dataType);
}

/**
* Draw samples from a multinomial distribution.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,31 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy
}
}

/** {@inheritDoc} */
@Override
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
if (DataType.STRING.equals(dataType)) {
throw new IllegalArgumentException("String data type is not supported!");
}
NDArray axes = create(shape.getShape());
TfOpExecutor opBuilder =
opExecutor("TruncatedNormal").addInput(axes).addParam("dtype", dataType);
Integer seed = getEngine().getSeed();
if (seed != null) {
// seed1 is graph-level seed
// set it to default graph seed used by tensorflow
// https://github.com/tensorflow/tensorflow/blob/85c8b2a817f95a3e979ecd1ed95bff1dc1335cff/tensorflow/python/framework/random_seed.py#L31
opBuilder.addParam("seed", 87654321);
opBuilder.addParam("seed2", seed);
}
try (NDArray array = opBuilder.buildSingletonOrThrow();
NDArray temp = array.mul(scale)) {
return temp.add(loc);
} finally {
axes.close();
}
}

/** {@inheritDoc} */
@Override
public TfNDManager newSubManager(Device device) {
Expand Down