diff --git a/api/src/main/java/ai/djl/nn/Parameter.java b/api/src/main/java/ai/djl/nn/Parameter.java index 7d3c2954cde..5b6a13ecb01 100644 --- a/api/src/main/java/ai/djl/nn/Parameter.java +++ b/api/src/main/java/ai/djl/nn/Parameter.java @@ -43,61 +43,22 @@ public class Parameter implements AutoCloseable { private String id; private String name; private Block block; - private ParameterType type; - private DataType mandatoryDataType; + private Type type; private Initializer initializer; private NDArray array; private boolean requiresGrad; private SparseFormat gradientFormat; - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - */ - public Parameter(String name, Block block, ParameterType type) { - this(name, block, type, true, SparseFormat.DENSE); - } - - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - * @param requiresGrad whether this {@code Parameter} needs to compute gradients - */ - public Parameter(String name, Block block, ParameterType type, boolean requiresGrad) { - this(name, block, type, requiresGrad, SparseFormat.DENSE); - } - - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - * @param requireGrad whether this {@code Parameter} needs to compute gradients - * @param gradientFormat the {@link SparseFormat} of the gradient array - */ - public Parameter( - String name, - Block block, - ParameterType type, - boolean requireGrad, - SparseFormat gradientFormat) { + Parameter(Builder builder) { this.id = UUID.randomUUID().toString(); - this.name = name; - this.block = block; - this.type = type; - this.requiresGrad = requireGrad; - this.initializer = type.getInitializer(); - this.gradientFormat = gradientFormat; + this.name = builder.name; + this.block = builder.block; + this.type = builder.type; + this.array = builder.array; + this.requiresGrad = builder.requiresGrad; + this.initializer = + (builder.initializer != null) ? builder.initializer : type.getInitializer(); + this.gradientFormat = builder.gradientFormat; } /** @@ -123,7 +84,7 @@ public String getName() { * * @return the type of this {@code Parameter} */ - public ParameterType getType() { + public Type getType() { return type; } @@ -158,15 +119,6 @@ public boolean requireGradient() { return requiresGrad; } - /** - * Sets the mandatory data type for this {@code Parameter}. - * - * @param mandatoryDataType the mandatory data type for this {@code Parameter} - */ - public void setMandatoryDataType(DataType mandatoryDataType) { - this.mandatoryDataType = mandatoryDataType; - } - /** * Checks if this {@code Parameter} is initialized. * @@ -201,11 +153,7 @@ public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes Objects.requireNonNull(initializer, "No initializer has been set"); if (!isInitialized()) { Shape shape = block.getParameterShape(name, inputShapes); - array = - initializer.initialize( - manager, - shape, - mandatoryDataType == null ? dataType : mandatoryDataType); + array = initializer.initialize(manager, shape, dataType); array.setName(name); } @@ -276,4 +224,139 @@ public void close() { array = null; } } + + /** + * Creates a builder to build a {@code Parameter}. + * + *
The methods start with {@code set} are required fields, and {@code opt} for optional
+ * fields.
+ *
+ * @return a new builder
+ */
+ public static Parameter.Builder builder() {
+ return new Parameter.Builder();
+ }
+
+ /** Enumerates the types of {@link Parameter}. */
+ public enum Type {
+ WEIGHT(null),
+ BIAS(Initializer.ZEROS),
+ GAMMA(Initializer.ONES),
+ BETA(Initializer.ZEROS),
+ RUNNING_MEAN(Initializer.ZEROS),
+ RUNNING_VAR(Initializer.ONES),
+ OTHER(null);
+
+ private final transient Initializer initializer;
+
+ Type(Initializer initializer) {
+ this.initializer = initializer;
+ }
+
+ /**
+ * Gets the {@link Initializer} of this {@code ParameterType}.
+ *
+ * @return the {@link Initializer} of this {@code ParameterType}
+ */
+ public Initializer getInitializer() {
+ return initializer;
+ }
+ }
+
+ /** A Builder to construct a {@code Parameter}. */
+ public static final class Builder {
+ String name;
+ Block block;
+ Type type;
+ Initializer initializer;
+ NDArray array;
+ boolean requiresGrad = true;
+ SparseFormat gradientFormat;
+
+ /**
+ * Sets the name of the {@code Parameter}.
+ *
+ * @param name the name of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder setName(String name) {
+ this.name = name;
+ return this;
+ }
+
+ /**
+ * Sets the block of the {@code Parameter}.
+ *
+ * @param block the block of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder setBlock(Block block) {
+ this.block = block;
+ return this;
+ }
+
+ /**
+ * Sets the {@code Type} of the {@code Parameter}.
+ *
+ * @param type the {@code Type} of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder setType(Type type) {
+ this.type = type;
+ return this;
+ }
+
+ /**
+ * Sets the Initializer of the {@code Parameter}.
+ *
+ * @param initializer the Initializer of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder optInitializer(Initializer initializer) {
+ this.initializer = initializer;
+ return this;
+ }
+
+ /**
+ * Sets the array of the {@code Parameter}.
+ *
+ * @param array the array of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder optArray(NDArray array) {
+ this.array = array;
+ return this;
+ }
+
+ /**
+ * Sets if the {@code Parameter} requires gradient.
+ *
+ * @param requiresGrad if the {@code Parameter} requires gradient
+ * @return this {@code Parameter}
+ */
+ public Builder optRequiresGrad(boolean requiresGrad) {
+ this.requiresGrad = requiresGrad;
+ return this;
+ }
+
+ /**
+ * Sets the {@link SparseFormat} of the {@code Parameter}.
+ *
+ * @param gradientFormat the {@link SparseFormat} of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder optGradientFormat(SparseFormat gradientFormat) {
+ this.gradientFormat = gradientFormat;
+ return this;
+ }
+
+ /**
+ * Builds a {@code Parameter} instance.
+ *
+ * @return the {@code Parameter} instance
+ */
+ public Parameter build() {
+ return new Parameter(this);
+ }
+ }
}
diff --git a/api/src/main/java/ai/djl/nn/ParameterType.java b/api/src/main/java/ai/djl/nn/ParameterType.java
deleted file mode 100644
index dde3cb23ab8..00000000000
--- a/api/src/main/java/ai/djl/nn/ParameterType.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
- * with the License. A copy of the License is located at
- *
- * http://aws.amazon.com/apache2.0/
- *
- * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
- * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
- * and limitations under the License.
- */
-package ai.djl.nn;
-
-import ai.djl.training.initializer.Initializer;
-
-/** Enumerates the types of {@link Parameter}. */
-public enum ParameterType {
- WEIGHT(null),
- BIAS(Initializer.ZEROS),
- GAMMA(Initializer.ONES),
- BETA(Initializer.ZEROS),
- RUNNING_MEAN(Initializer.ZEROS),
- RUNNING_VAR(Initializer.ONES),
- OTHER(null);
-
- private final transient Initializer initializer;
-
- ParameterType(Initializer initializer) {
- this.initializer = initializer;
- }
-
- /**
- * Gets the {@link Initializer} of this {@code ParameterType}.
- *
- * @return the {@link Initializer} of this {@code ParameterType}
- */
- public Initializer getInitializer() {
- return initializer;
- }
-}
diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java
index 6334935804c..c2fa87f0a41 100644
--- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java
+++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java
@@ -22,7 +22,6 @@
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
@@ -102,13 +101,22 @@ public Convolution(ConvolutionBuilder> builder) {
weight =
addParameter(
- new Parameter("weight", this, ParameterType.WEIGHT),
+ Parameter.builder()
+ .setName("weight")
+ .setBlock(this)
+ .setType(Parameter.Type.WEIGHT)
+ .build(),
(inputShapes) ->
new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape));
if (includeBias) {
bias =
addParameter(
- new Parameter("bias", this, ParameterType.BIAS), new Shape(filters));
+ Parameter.builder()
+ .setName("bias")
+ .setBlock(this)
+ .setType(Parameter.Type.BIAS)
+ .build(),
+ new Shape(filters));
}
}
diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
index 890a4db9eff..30d348cebc9 100644
--- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
+++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
@@ -22,7 +22,6 @@
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
@@ -78,13 +77,22 @@ public Deconvolution(DeconvolutionBuilder> builder) {
weight =
addParameter(
- new Parameter("weight", this, ParameterType.WEIGHT),
+ Parameter.builder()
+ .setName("weight")
+ .setBlock(this)
+ .setType(Parameter.Type.WEIGHT)
+ .build(),
(inputShapes) ->
new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape));
if (includeBias) {
bias =
addParameter(
- new Parameter("bias", this, ParameterType.BIAS), new Shape(filters));
+ Parameter.builder()
+ .setName("bias")
+ .setBlock(this)
+ .setType(Parameter.Type.BIAS)
+ .build(),
+ new Shape(filters));
}
}
diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java
index 48736aaa8e1..52ec3e2f306 100644
--- a/api/src/main/java/ai/djl/nn/core/Embedding.java
+++ b/api/src/main/java/ai/djl/nn/core/Embedding.java
@@ -23,7 +23,6 @@
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
@@ -56,7 +55,13 @@ protected Embedding(BaseBuilder