From 0aec8cad5aa5623a58217b2495c7e5d72205f582 Mon Sep 17 00:00:00 2001 From: Abd-El-Aziz Zayed <48853777+AzizZayed@users.noreply.github.com> Date: Thu, 10 Jun 2021 12:46:16 -0700 Subject: [PATCH] [tensoflow] Add truncated normal operation (#1005) --- .../java/ai/djl/ndarray/BaseNDManager.java | 6 ++ .../main/java/ai/djl/ndarray/NDManager.java | 59 +++++++++++++++++++ .../ai/djl/tensorflow/engine/TfNDManager.java | 25 ++++++++ 3 files changed, 90 insertions(+) diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 64b52355710..2fcd9d86423 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -153,6 +153,12 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy throw new UnsupportedOperationException("Not supported!"); } + /** {@inheritDoc} */ + @Override + public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) { + throw new UnsupportedOperationException("Not supported!"); + } + /** {@inheritDoc} */ @Override public NDArray randomMultinomial(int n, NDArray pValues) { diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index dae10ec8c99..826ee66d17f 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -1232,6 +1232,65 @@ default NDArray randomNormal( return newSubManager(device).randomNormal(loc, scale, shape, dataType); } + /** + * Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation + * 1, discarding and re-drawing any samples that are more than two standard deviations from the + * mean. + * + *
Samples are distributed according to a normal distribution parametrized by mean = 0 and + * standard deviation = 1. + * + * @param shape the output {@link Shape} + * @return the drawn samples {@link NDArray} + */ + default NDArray truncatedNormal(Shape shape) { + return truncatedNormal(0f, 1f, shape, DataType.FLOAT32); + } + + /** + * Draws random samples from a normal (Gaussian) distribution with mean 0 and standard deviation + * 1, discarding and re-drawing any samples that are more than two standard deviations from the + * mean. + * + * @param shape the output {@link Shape} + * @param dataType the {@link DataType} of the {@link NDArray} + * @return the drawn samples {@link NDArray} + */ + default NDArray truncatedNormal(Shape shape, DataType dataType) { + return truncatedNormal(0.0f, 1.0f, shape, dataType); + } + + /** + * Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any + * samples that are more than two standard deviations from the mean. + * + * @param loc the mean (centre) of the distribution + * @param scale the standard deviation (spread or "width") of the distribution + * @param shape the output {@link Shape} + * @param dataType the {@link DataType} of the {@link NDArray} + * @return the drawn samples {@link NDArray} + */ + NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType); + + /** + * Draws random samples from a normal (Gaussian) distribution, discarding and re-drawing any + * samples that are more than two standard deviations from the mean. + * + * @param loc the mean (centre) of the distribution + * @param scale the standard deviation (spread or "width") of the distribution + * @param shape the output {@link Shape} + * @param dataType the {@link DataType} of the {@link NDArray} + * @param device the {@link Device} of the {@link NDArray} + * @return the drawn samples {@link NDArray} + */ + default NDArray truncatedNormal( + float loc, float scale, Shape shape, DataType dataType, Device device) { + if (device == null || device.equals(getDevice())) { + return truncatedNormal(loc, scale, shape, dataType); + } + return newSubManager(device).truncatedNormal(loc, scale, shape, dataType); + } + /** * Draw samples from a multinomial distribution. * diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java index 71e731f47c9..a200673b8c8 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java @@ -242,6 +242,31 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy } } + /** {@inheritDoc} */ + @Override + public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) { + if (DataType.STRING.equals(dataType)) { + throw new IllegalArgumentException("String data type is not supported!"); + } + NDArray axes = create(shape.getShape()); + TfOpExecutor opBuilder = + opExecutor("TruncatedNormal").addInput(axes).addParam("dtype", dataType); + Integer seed = getEngine().getSeed(); + if (seed != null) { + // seed1 is graph-level seed + // set it to default graph seed used by tensorflow + // https://github.com/tensorflow/tensorflow/blob/85c8b2a817f95a3e979ecd1ed95bff1dc1335cff/tensorflow/python/framework/random_seed.py#L31 + opBuilder.addParam("seed", 87654321); + opBuilder.addParam("seed2", seed); + } + try (NDArray array = opBuilder.buildSingletonOrThrow(); + NDArray temp = array.mul(scale)) { + return temp.add(loc); + } finally { + axes.close(); + } + } + /** {@inheritDoc} */ @Override public TfNDManager newSubManager(Device device) {