diff --git a/api/src/main/java/ai/djl/modality/nlp/Decoder.java b/api/src/main/java/ai/djl/modality/nlp/Decoder.java
index 72a6ead095f..e6dd38fedf3 100644
--- a/api/src/main/java/ai/djl/modality/nlp/Decoder.java
+++ b/api/src/main/java/ai/djl/modality/nlp/Decoder.java
@@ -64,8 +64,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
- return block.getOutputShapes(manager, inputShapes);
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
+ return block.getOutputShapes(inputShapes);
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/modality/nlp/Encoder.java b/api/src/main/java/ai/djl/modality/nlp/Encoder.java
index 7ad4fbdcf6f..e3b1631322f 100644
--- a/api/src/main/java/ai/djl/modality/nlp/Encoder.java
+++ b/api/src/main/java/ai/djl/modality/nlp/Encoder.java
@@ -79,8 +79,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
- return block.getOutputShapes(manager, inputShapes);
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
+ return block.getOutputShapes(inputShapes);
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
index afc0958ace9..a7c47380a9d 100644
--- a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
+++ b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
@@ -97,19 +97,18 @@ public NDList forward(
* @param manager the NDManager to initialize the parameters
* @param dataType the datatype of the parameters
* @param inputShapes the shapes of the inputs to the block
- * @return the shapes of the outputs of the block
*/
@Override
- public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
+ public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
encoder.initialize(manager, dataType, inputShapes[0]);
- return decoder.initialize(manager, dataType, inputShapes[1]);
+ decoder.initialize(manager, dataType, inputShapes[1]);
}
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
- return decoder.getOutputShapes(manager, new Shape[] {inputShapes[1]});
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
+ return decoder.getOutputShapes(new Shape[] {inputShapes[1]});
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java
index f848f25f3e5..4fd7dc71277 100644
--- a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java
+++ b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java
@@ -88,7 +88,7 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
- return trainableWordEmbedding.getOutputShapes(manager, inputShapes);
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
+ return trainableWordEmbedding.getOutputShapes(inputShapes);
}
}
diff --git a/api/src/main/java/ai/djl/nn/AbstractBlock.java b/api/src/main/java/ai/djl/nn/AbstractBlock.java
index 6d4c2c4ecd2..7f67d8431a6 100644
--- a/api/src/main/java/ai/djl/nn/AbstractBlock.java
+++ b/api/src/main/java/ai/djl/nn/AbstractBlock.java
@@ -30,7 +30,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
-import java.util.function.Function;
+import java.util.function.Predicate;
/**
* {@code AbstractBlock} is an abstract implementation of {@link Block}.
@@ -43,12 +43,11 @@
*
* <li>Define a version for serializing parameter and metadata and pass it to the parent
* constructor
- * <li>Use {@link AbstractBlock#addParameter(Parameter, Shape)} or {@link
- * AbstractBlock#addParameter(Parameter, Function)} to add parameters to your block in the
+ * <li>Use {@link AbstractBlock#addParameter(Parameter)} to add parameters to your block in the
* constructor if necessary.
* <li>Use {@link AbstractBlock#addChildBlock(String, Block)} to add child blocks if necessary.
- * <li>Override {@link AbstractBlock#getOutputShapes(NDManager, Shape[])} to determine the shape
- * of your custom block's output based on the input it will receive.
+ * <li>Override {@link Block#getOutputShapes(Shape[])} to determine the shape of your custom
+ * block's output based on the input it will receive.
* <li>Override {@link AbstractBlock#initializeChildBlocks(NDManager, DataType, Shape...)} if you
* added child blocks to initialize them based on the input shape your block will receive. You
* can skip this if your block does not contain child blocks
@@ -61,9 +60,9 @@
*
*
* If you use {@link AbstractBlock#addParameter(Parameter)} to add parameters, you have to take
- * care of parameter initialization yourself. In this case, you need to override {@link
- * AbstractBlock#getParameterShape(String, Shape[])} to determine the shape of your parameters. If
- * you use the other variants of {@code addParameter} this is done for you.
+ * care of parameter initialization yourself. In this case, you need to call {@link
+ * Parameter#setShape(Shape)} on your parameters if you know their shape up front, or override
+ * {@link AbstractBlock#prepare(Shape[])} to set the shapes once the input shapes are known.
*/
// Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers
// of this API know the children and parameters are always iterated over in insertion order.
@@ -99,14 +98,6 @@ public abstract class AbstractBlock implements Block {
*/
protected LinkedHashMap<String, Parameter> parameters = new LinkedHashMap<>();
- /**
- * Callbacks to determine the shape of a parameter. Values may be null in which case extending
- * classes need to override {@link Block#getParameterShape(String, Shape[])} and implement
- * parameter shape resolution manually.
- */
- protected LinkedHashMap<String, Function<Shape[], Shape>> parameterShapeCallbacks =
- new LinkedHashMap<>();
-
/**
* Builds an empty block with the given version for parameter serialization.
*
@@ -195,73 +186,20 @@ protected final <B extends Block> B addChildBlock(String name, B block) {
return block;
}
- /**
- * Adds a parameter to this block. If parameters are added with this method, subclasses need to
- * override {@link Block#getParameterShape(String, Shape[])} and return the shapes of parameters
- * themselves.
- *
- * @param parameter the parameter to add, not null
- * @param <P> the specific parameter subclass
- * @return the parameter passed as arguments to make it easier to create and assign paramters in
- * one line
- */
- protected final <P extends Parameter> P addParameter(P parameter) {
- return addParameter(parameter, (Function<Shape[], Shape>) null);
- }
-
/**
* Adds a parameter to this block. If parameters are added with this method, initialization of
* the parameter works out of the box
*
- * @param parameter the parameter to add, not null
- * @param shape the shape of the parameter
* @param <P> the specific parameter subclass
- * @return the parameter passed as arguments to make it easier to create and assign paramters in
- * one line
- */
- protected final <P extends Parameter> P addParameter(P parameter, Shape shape) {
- return addParameter(parameter, (inputShapes) -> shape);
- }
-
- /**
- * Adds a parameter to this block. If parameters are added with this method, intialization of
- * the parameter works out of the box
- *
* @param parameter the parameter to add, not null
- * @param shapeCallback the method to call once the input shape of this block is known to
- * determine the shape of the given parameter
- * @param <P> the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign parameters
* in one line
*/
- protected final <P extends Parameter> P addParameter(
- P parameter, Function<Shape[], Shape> shapeCallback) {
+ protected final <P extends Parameter> P addParameter(P parameter) {
parameters.put(parameter.getName(), parameter);
- parameterShapeCallbacks.put(parameter.getName(), shapeCallback);
return parameter;
}
- /** {@inheritDoc} */
- @Override
- public Shape getParameterShape(String name, Shape[] inputShapes) {
- Function<Shape[], Shape> callback = parameterShapeCallbacks.get(name);
- if (callback == null) {
- Parameter parameter = parameters.get(name);
- if (parameter == null) {
- throw new IllegalArgumentException(
- "No parameter named " + name + " found in this block.");
- } else {
- throw new IllegalStateException(
- "No shape initializer for parameter "
- + name
- + "found. "
- + "Either pass an initializer for the shape when adding the "
- + "parameter or override getParameterShape in the subclass.");
- }
- }
- return callback.apply(inputShapes);
- }
-
/** {@inheritDoc} */
@Override
public BlockList getChildren() {
@@ -285,13 +223,9 @@ public PairList<String, Shape> describeInput() {
/** {@inheritDoc} */
@Override
- public void setInitializer(Initializer initializer) {
- for (Parameter parameter : parameters.values()) {
- parameter.setInitializer(initializer, false);
- }
- for (Block child : children.values()) {
- child.setInitializer(initializer);
- }
+ public void setInitializer(Initializer initializer, Parameter.Type params) {
+ Predicate<Parameter> predicate = parameter -> parameter.getType().equals(params);
+ setInitializer(initializer, predicate);
}
/** {@inheritDoc} */
@@ -301,18 +235,50 @@ public void setInitializer(Initializer initializer, String paramName) {
if (parameter == null) {
throw new IllegalArgumentException("Could not find parameter " + paramName);
}
- parameter.setInitializer(initializer, true);
+ parameter.setInitializer(initializer);
}
/** {@inheritDoc} */
@Override
- public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
+ public void setInitializer(Initializer initializer, Predicate<Parameter> predicate) {
+ List<Parameter> params = getParameters().values();
+ for (Parameter param : params) {
+ if (predicate.test(param)) {
+ param.setInitializer(initializer);
+ }
+ }
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
+ // if parameters are initialized, skip it
+ if (!isInitialized()) {
+ // setShape for all params
+ prepare(inputShapes);
+ }
for (Parameter parameter : parameters.values()) {
- parameter.initialize(manager, dataType, inputShapes);
+ parameter.initialize(manager, dataType);
}
initializeChildBlocks(manager, dataType, inputShapes);
- return getOutputShapes(manager, inputShapes);
+ }
+
+ /**
+ * Performs any action necessary before initialization. For example, keep the input information
+ * or verify the layout.
+ *
+ * @param inputShapes the expected shapes of the input
+ */
+ protected void beforeInitialize(Shape... inputShapes) {
+ if (inputNames.isEmpty()) {
+ // automatically assign input names
+ inputNames = new ArrayList<>();
+ for (int i = 0; i < inputShapes.length; ++i) {
+ inputNames.add("data" + i);
+ }
+ }
+ this.inputShapes = inputShapes;
}
/**
@@ -355,20 +321,11 @@ public ParameterList getDirectParameters() {
}
/**
- * Performs any action necessary before initialization.
+ * Sets the shape of {@link Parameter}s.
*
- * @param inputShapes the expected shapes of the input
+ * @param inputShapes the shapes of inputs
*/
- protected void beforeInitialize(Shape[] inputShapes) {
- if (inputNames.isEmpty()) {
- // automatically assign input names
- inputNames = new ArrayList<>();
- for (int i = 0; i < inputShapes.length; ++i) {
- inputNames.add("data" + i);
- }
- }
- this.inputShapes = inputShapes;
- }
+ protected void prepare(Shape[] inputShapes) {}
/** {@inheritDoc} */
@Override
@@ -494,7 +451,7 @@ public String toString() {
appendShape(sb, inputShapeDescription.values().toArray(new Shape[0]));
sb.append(" -> ");
Shape[] outputShapes =
- getOutputShapes(null, inputShapeDescription.values().toArray(new Shape[0]));
+ getOutputShapes(inputShapeDescription.values().toArray(new Shape[0]));
appendShape(sb, outputShapes);
} else {
sb.append("Uninitialized");
diff --git a/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java b/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java
index f870eea148f..fdd7fcc0d6c 100644
--- a/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java
+++ b/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java
@@ -12,7 +12,6 @@
*/
package ai.djl.nn;
-import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
/** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */
@@ -29,7 +28,7 @@ public AbstractSymbolBlock(byte version) {
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("not implement!");
}
}
diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java
index 6febeff5f6f..4bd3fb09e2b 100644
--- a/api/src/main/java/ai/djl/nn/Block.java
+++ b/api/src/main/java/ai/djl/nn/Block.java
@@ -25,6 +25,7 @@
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
+import java.util.function.Predicate;
/**
* A {@code Block} is a composable function that forms a neural network.
@@ -158,11 +159,12 @@ default NDList forward(
}
/**
- * Sets an {@link Initializer} to the block.
+ * Sets an {@link Initializer} to all the parameters of the given {@link Parameter.Type} in the block.
*
* @param initializer the initializer to set
+ * @param type the {@link Parameter.Type} of the parameters to set the initializer on
*/
- void setInitializer(Initializer initializer);
+ void setInitializer(Initializer initializer, Parameter.Type type);
/**
* Sets an {@link Initializer} to the specified direct parameter of the block, overriding the
@@ -173,15 +175,22 @@ default NDList forward(
*/
void setInitializer(Initializer initializer, String paramName);
+ /**
+ * Sets an {@link Initializer} to all the parameters that match the given {@link Predicate} in the block.
+ *
+ * @param initializer the initializer to be set
+ * @param predicate the predicate that selects which parameters the initializer is set on
+ */
+ void setInitializer(Initializer initializer, Predicate<Parameter> predicate);
+
/**
* Initializes the parameters of the block. This method must be called before calling `forward`.
*
* @param manager the NDManager to initialize the parameters
* @param dataType the datatype of the parameters
* @param inputShapes the shapes of the inputs to the block
- * @return the shapes of the outputs of the block
*/
- Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes);
+ void initialize(NDManager manager, DataType dataType, Shape... inputShapes);
/**
* Returns a boolean whether the block is initialized.
@@ -232,25 +241,13 @@ default NDList forward(
*/
ParameterList getParameters();
- /**
- * Returns the shape of the specified direct parameter of this block given the shapes of the
- * input to the block.
- *
- * @param name the name of the parameter
- * @param inputShapes the shapes of the input to the block
- * @return the shape of the parameter specified
- * @throws IllegalArgumentException if the parameter name specified is invalid
- */
- Shape getParameterShape(String name, Shape[] inputShapes);
-
/**
* Returns the expected output shapes of the block for the specified input shapes.
*
- * @param manager an NDManager
* @param inputShapes the shapes of the inputs
* @return the expected output shapes of the block
*/
- Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes);
+ Shape[] getOutputShapes(Shape[] inputShapes);
/**
* Writes the parameters of the block to the given outputStream.
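
A hypothetical usage sketch of the three setInitializer overloads on the reworked Block interface (the block, parameter names and initializer choices below are arbitrary):

```java
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.core.Linear;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.NormalInitializer;

public final class SetInitializerExample {

    public static void main(String[] args) {
        Block block = Linear.builder().setUnits(64).build();

        // set an initializer on every WEIGHT parameter in the block tree
        block.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);

        // target a single direct parameter by name
        block.setInitializer(Initializer.ONES, "bias");

        // or select parameters with an arbitrary predicate
        block.setInitializer(Initializer.ZEROS, param -> !param.requireGradient());
    }
}
```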
diff --git a/api/src/main/java/ai/djl/nn/LambdaBlock.java b/api/src/main/java/ai/djl/nn/LambdaBlock.java
index 939e0e09b29..1b6fc0d3bc5 100644
--- a/api/src/main/java/ai/djl/nn/LambdaBlock.java
+++ b/api/src/main/java/ai/djl/nn/LambdaBlock.java
@@ -70,11 +70,11 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
- try (NDManager subManager = manager.newSubManager()) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
+ try (NDManager manager = NDManager.newBaseManager()) {
NDList input = new NDList(inputShapes.length);
for (Shape shape : inputShapes) {
- input.add(subManager.zeros(shape));
+ input.add(manager.zeros(shape));
}
NDList output = lambda.apply(input);
Shape[] outputShapes = new Shape[output.size()];
diff --git a/api/src/main/java/ai/djl/nn/ParallelBlock.java b/api/src/main/java/ai/djl/nn/ParallelBlock.java
index 92ce9f6349e..9d04569a63f 100644
--- a/api/src/main/java/ai/djl/nn/ParallelBlock.java
+++ b/api/src/main/java/ai/djl/nn/ParallelBlock.java
@@ -149,16 +149,16 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
Preconditions.checkArgument(!children.isEmpty(), "The parallel block is empty");
- try (NDManager subManager = manager.newSubManager()) {
+ try (NDManager manager = NDManager.newBaseManager()) {
List<NDList> inputs = new ArrayList<>();
for (Block block : children.values()) {
- Shape[] shapes = block.getOutputShapes(manager, inputShapes);
+ Shape[] shapes = block.getOutputShapes(inputShapes);
NDList output = new NDList(shapes.length);
for (Shape shape : shapes) {
- output.add(subManager.create(shape));
+ output.add(manager.create(shape));
}
inputs.add(output);
}
diff --git a/api/src/main/java/ai/djl/nn/Parameter.java b/api/src/main/java/ai/djl/nn/Parameter.java
index 7d3c2954cde..369111057c7 100644
--- a/api/src/main/java/ai/djl/nn/Parameter.java
+++ b/api/src/main/java/ai/djl/nn/Parameter.java
@@ -19,6 +19,7 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.training.initializer.Initializer;
+import ai.djl.training.initializer.XavierInitializer;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
@@ -42,62 +43,23 @@ public class Parameter implements AutoCloseable {
private String id;
private String name;
- private Block block;
- private ParameterType type;
- private DataType mandatoryDataType;
+ private Shape shape;
+ private Type type;
private Initializer initializer;
private NDArray array;
private boolean requiresGrad;
private SparseFormat gradientFormat;
- /**
- * Creates a {@code Parameter} with the given name, and parameter type, and associated with the
- * given {@link Block}.
- *
- * @param name the name of the {@code Parameter}
- * @param block the block with which this {@code Parameter} is associated
- * @param type the type of this {@code Parameter}
- */
- public Parameter(String name, Block block, ParameterType type) {
- this(name, block, type, true, SparseFormat.DENSE);
- }
-
- /**
- * Creates a {@code Parameter} with the given name, and parameter type, and associated with the
- * given {@link Block}.
- *
- * @param name the name of the {@code Parameter}
- * @param block the block with which this {@code Parameter} is associated
- * @param type the type of this {@code Parameter}
- * @param requiresGrad whether this {@code Parameter} needs to compute gradients
- */
- public Parameter(String name, Block block, ParameterType type, boolean requiresGrad) {
- this(name, block, type, requiresGrad, SparseFormat.DENSE);
- }
-
- /**
- * Creates a {@code Parameter} with the given name, and parameter type, and associated with the
- * given {@link Block}.
- *
- * @param name the name of the {@code Parameter}
- * @param block the block with which this {@code Parameter} is associated
- * @param type the type of this {@code Parameter}
- * @param requireGrad whether this {@code Parameter} needs to compute gradients
- * @param gradientFormat the {@link SparseFormat} of the gradient array
- */
- public Parameter(
- String name,
- Block block,
- ParameterType type,
- boolean requireGrad,
- SparseFormat gradientFormat) {
+ Parameter(Builder builder) {
this.id = UUID.randomUUID().toString();
- this.name = name;
- this.block = block;
- this.type = type;
- this.requiresGrad = requireGrad;
- this.initializer = type.getInitializer();
- this.gradientFormat = gradientFormat;
+ this.name = builder.name;
+ this.shape = builder.shape;
+ this.type = builder.type;
+ this.array = builder.array;
+ this.requiresGrad = builder.requiresGrad;
+ this.initializer =
+ (builder.initializer != null) ? builder.initializer : type.getInitializer();
+ this.gradientFormat = builder.gradientFormat;
}
/**
@@ -123,7 +85,7 @@ public String getName() {
*
* @return the type of this {@code Parameter}
*/
- public ParameterType getType() {
+ public Type getType() {
return type;
}
@@ -133,10 +95,26 @@ public ParameterType getType() {
* @param array the {@link NDArray} that contains values of this {@code Parameter}
*/
public void setArray(NDArray array) {
+ if (shape != null) {
+ throw new IllegalStateException("array has been set! Use either setArray or setShape");
+ }
this.array = array;
+ shape = array.getShape();
array.setName(name);
}
+ /**
+ * Sets the shape of this {@code Parameter}.
+ *
+ * @param shape the shape of this {@code Parameter}
+ */
+ public void setShape(Shape shape) {
+ if (array != null) {
+ throw new IllegalStateException("array has been set! Use either setArray or setShape");
+ }
+ this.shape = shape;
+ }
+
/**
* Gets the values of this {@code Parameter} as an {@link NDArray}.
*
@@ -158,15 +136,6 @@ public boolean requireGradient() {
return requiresGrad;
}
- /**
- * Sets the mandatory data type for this {@code Parameter}.
- *
- * @param mandatoryDataType the mandatory data type for this {@code Parameter}
- */
- public void setMandatoryDataType(DataType mandatoryDataType) {
- this.mandatoryDataType = mandatoryDataType;
- }
-
/**
* Checks if this {@code Parameter} is initialized.
*
@@ -181,12 +150,9 @@ public boolean isInitialized() {
* flag is true, sets the initializer regardless.
*
* @param initializer the initializer to be set
- * @param overwrite if true, set the initializer regardless of whether its already set or not
*/
- public void setInitializer(Initializer initializer, boolean overwrite) {
- if (overwrite || this.initializer == null) {
- this.initializer = initializer;
- }
+ public void setInitializer(Initializer initializer) {
+ this.initializer = initializer;
}
/**
@@ -195,17 +161,12 @@ public void setInitializer(Initializer initializer, boolean overwrite) {
*
* @param manager an NDManager to create the arrays
* @param dataType the datatype of the {@code Parameter}
- * @param inputShapes the expected input shapes
*/
- public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) {
+ public void initialize(NDManager manager, DataType dataType) {
Objects.requireNonNull(initializer, "No initializer has been set");
+ Objects.requireNonNull(shape, "No parameter shape has been set");
if (!isInitialized()) {
- Shape shape = block.getParameterShape(name, inputShapes);
- array =
- initializer.initialize(
- manager,
- shape,
- mandatoryDataType == null ? dataType : mandatoryDataType);
+ array = initializer.initialize(manager, shape, dataType);
array.setName(name);
}
@@ -266,6 +227,8 @@ public void load(NDManager manager, DataInputStream dis)
}
array = manager.decode(dis);
+ // set the shape of the parameter so that prepare() can be skipped
+ shape = array.getShape();
}
/** {@inheritDoc} */
@@ -276,4 +239,141 @@ public void close() {
array = null;
}
}
+
+ /**
+ * Creates a builder to build a {@code Parameter}.
+ *
+ * The methods starting with {@code set} are required fields, and {@code opt} for optional
+ * fields.
+ *
+ * @return a new builder
+ */
+ public static Parameter.Builder builder() {
+ return new Parameter.Builder();
+ }
+
+ /** Enumerates the types of {@link Parameter}. */
+ public enum Type {
+ WEIGHT(
+ new XavierInitializer(
+ XavierInitializer.RandomType.GAUSSIAN, XavierInitializer.FactorType.IN, 2)),
+ BIAS(Initializer.ZEROS),
+ GAMMA(Initializer.ONES),
+ BETA(Initializer.ZEROS),
+ RUNNING_MEAN(Initializer.ZEROS),
+ RUNNING_VAR(Initializer.ONES),
+ OTHER(null);
+
+ private final transient Initializer initializer;
+
+ Type(Initializer initializer) {
+ this.initializer = initializer;
+ }
+
+ /**
+ * Gets the {@link Initializer} of this {@code Parameter.Type}.
+ *
+ * @return the {@link Initializer} of this {@code Parameter.Type}
+ */
+ public Initializer getInitializer() {
+ return initializer;
+ }
+ }
+
+ /** A Builder to construct a {@code Parameter}. */
+ public static final class Builder {
+ String name;
+ Shape shape;
+ Type type;
+ Initializer initializer;
+ NDArray array;
+ boolean requiresGrad = true;
+ SparseFormat gradientFormat;
+
+ /**
+ * Sets the name of the {@code Parameter}.
+ *
+ * @param name the name of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder setName(String name) {
+ this.name = name;
+ return this;
+ }
+
+ /**
+ * Sets the {@code Type} of the {@code Parameter}.
+ *
+ * @param type the {@code Type} of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder setType(Type type) {
+ this.type = type;
+ return this;
+ }
+
+ /**
+ * Sets the shape of the {@code Parameter}.
+ *
+ * @param shape the shape of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder optShape(Shape shape) {
+ this.shape = shape;
+ return this;
+ }
+
+ /**
+ * Sets the Initializer of the {@code Parameter}.
+ *
+ * @param initializer the Initializer of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder optInitializer(Initializer initializer) {
+ this.initializer = initializer;
+ return this;
+ }
+
+ /**
+ * Sets the array of the {@code Parameter}.
+ *
+ * @param array the array of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder optArray(NDArray array) {
+ this.array = array;
+ return this;
+ }
+
+ /**
+ * Sets if the {@code Parameter} requires gradient.
+ *
+ * @param requiresGrad if the {@code Parameter} requires gradient
+ * @return this {@code Parameter}
+ */
+ public Builder optRequiresGrad(boolean requiresGrad) {
+ this.requiresGrad = requiresGrad;
+ return this;
+ }
+
+ /**
+ * Sets the {@link SparseFormat} of the {@code Parameter}.
+ *
+ * @param gradientFormat the {@link SparseFormat} of the {@code Parameter}
+ * @return this {@code Parameter}
+ */
+ public Builder optGradientFormat(SparseFormat gradientFormat) {
+ this.gradientFormat = gradientFormat;
+ return this;
+ }
+
+ /**
+ * Builds a {@code Parameter} instance.
+ *
+ * @return the {@code Parameter} instance
+ */
+ public Parameter build() {
+ return new Parameter(this);
+ }
+ }
}
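
A minimal sketch of the builder-based Parameter lifecycle introduced here, for the case where the shape is not known at construction time (the name and shape are illustrative):

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;

public final class ParameterBuilderExample {

    public static void main(String[] args) {
        Parameter weight =
                Parameter.builder()
                        .setName("weight")
                        .setType(Parameter.Type.WEIGHT) // WEIGHT now defaults to XavierInitializer
                        .build();

        // the shape can be supplied later, typically from Block.prepare(Shape[])
        weight.setShape(new Shape(64, 128));

        try (NDManager manager = NDManager.newBaseManager()) {
            // initialize no longer needs the input shapes, only a manager and data type
            weight.initialize(manager, DataType.FLOAT32);
            NDArray array = weight.getArray();
            System.out.println(array.getShape()); // (64, 128)
        }
    }
}
```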
diff --git a/api/src/main/java/ai/djl/nn/ParameterType.java b/api/src/main/java/ai/djl/nn/ParameterType.java
deleted file mode 100644
index dde3cb23ab8..00000000000
--- a/api/src/main/java/ai/djl/nn/ParameterType.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
- * with the License. A copy of the License is located at
- *
- * http://aws.amazon.com/apache2.0/
- *
- * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
- * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
- * and limitations under the License.
- */
-package ai.djl.nn;
-
-import ai.djl.training.initializer.Initializer;
-
-/** Enumerates the types of {@link Parameter}. */
-public enum ParameterType {
- WEIGHT(null),
- BIAS(Initializer.ZEROS),
- GAMMA(Initializer.ONES),
- BETA(Initializer.ZEROS),
- RUNNING_MEAN(Initializer.ZEROS),
- RUNNING_VAR(Initializer.ONES),
- OTHER(null);
-
- private final transient Initializer initializer;
-
- ParameterType(Initializer initializer) {
- this.initializer = initializer;
- }
-
- /**
- * Gets the {@link Initializer} of this {@code ParameterType}.
- *
- * @return the {@link Initializer} of this {@code ParameterType}
- */
- public Initializer getInitializer() {
- return initializer;
- }
-}
diff --git a/api/src/main/java/ai/djl/nn/SequentialBlock.java b/api/src/main/java/ai/djl/nn/SequentialBlock.java
index d0c6f168fb8..8908c1ad6c8 100644
--- a/api/src/main/java/ai/djl/nn/SequentialBlock.java
+++ b/api/src/main/java/ai/djl/nn/SequentialBlock.java
@@ -154,19 +154,20 @@ protected NDList forwardInternal(
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
Shape[] shapes = inputShapes;
for (Block child : getChildren().values()) {
- shapes = child.initialize(manager, dataType, shapes);
+ child.initialize(manager, dataType, shapes);
+ shapes = child.getOutputShapes(shapes);
}
}
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
+ public Shape[] getOutputShapes(Shape[] inputs) {
if (children.isEmpty()) {
throw new IllegalArgumentException("The sequential block is empty");
}
Shape[] current = inputs;
for (Block block : children.values()) {
- current = block.getOutputShapes(manager, current);
+ current = block.getOutputShapes(current);
}
return current;
}
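
From the caller's side, the void initialize plus manager-free getOutputShapes results in a two-step pattern: initialize first, then query shapes from the input shapes alone. An illustrative sketch with arbitrary layer sizes and input shape:

```java
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;

public final class SequentialShapeExample {

    public static void main(String[] args) {
        SequentialBlock net =
                new SequentialBlock()
                        .add(Linear.builder().setUnits(128).build())
                        .add(Activation::relu)
                        .add(Linear.builder().setUnits(10).build());

        try (NDManager manager = NDManager.newBaseManager()) {
            // initialize no longer returns the output shapes
            net.initialize(manager, DataType.FLOAT32, new Shape(32, 784));

            // output shapes are queried separately, from the input shapes alone
            Shape[] outputShapes = net.getOutputShapes(new Shape[] {new Shape(32, 784)});
            System.out.println(outputShapes[0]); // (32, 10)
        }
    }
}
```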
diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java
index 6334935804c..1d211f1d903 100644
--- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java
+++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java
@@ -16,13 +16,11 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
-import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
@@ -102,13 +100,17 @@ public Convolution(ConvolutionBuilder<?> builder) {
weight =
addParameter(
- new Parameter("weight", this, ParameterType.WEIGHT),
- (inputShapes) ->
- new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape));
+ Parameter.builder()
+ .setName("weight")
+ .setType(Parameter.Type.WEIGHT)
+ .build());
if (includeBias) {
bias =
addParameter(
- new Parameter("bias", this, ParameterType.BIAS), new Shape(filters));
+ Parameter.builder()
+ .setName("bias")
+ .setType(Parameter.Type.BIAS)
+ .build());
}
}
@@ -149,15 +151,24 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- protected void beforeInitialize(Shape[] inputs) {
- super.beforeInitialize(inputs);
- Shape inputShape = inputs[0];
- Block.validateLayout(getExpectedLayout(), inputShape.getLayout());
+ protected void beforeInitialize(Shape... inputShapes) {
+ super.beforeInitialize(inputShapes);
+ Block.validateLayout(getExpectedLayout(), inputShapes[0].getLayout());
}
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
+ protected void prepare(Shape[] inputs) {
+ long inputChannel = inputs[0].get(1);
+ weight.setShape(new Shape(filters, inputChannel / groups).addAll(kernelShape));
+ if (bias != null) {
+ bias.setShape(new Shape(filters));
+ }
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Shape[] getOutputShapes(Shape[] inputs) {
long[] shape = new long[numDimensions()];
shape[0] = inputs[0].get(0);
shape[1] = filters;
diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
index 890a4db9eff..10059208a5c 100644
--- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
+++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
@@ -16,13 +16,11 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
-import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
@@ -78,13 +76,17 @@ public Deconvolution(DeconvolutionBuilder<?> builder) {
weight =
addParameter(
- new Parameter("weight", this, ParameterType.WEIGHT),
- (inputShapes) ->
- new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape));
+ Parameter.builder()
+ .setName("weight")
+ .setType(Parameter.Type.WEIGHT)
+ .build());
if (includeBias) {
bias =
addParameter(
- new Parameter("bias", this, ParameterType.BIAS), new Shape(filters));
+ Parameter.builder()
+ .setName("bias")
+ .setType(Parameter.Type.BIAS)
+ .build());
}
}
@@ -126,15 +128,24 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- protected void beforeInitialize(Shape[] inputs) {
- super.beforeInitialize(inputs);
- Shape inputShape = inputs[0];
- Block.validateLayout(getExpectedLayout(), inputShape.getLayout());
+ protected void beforeInitialize(Shape... inputShapes) {
+ super.beforeInitialize(inputShapes);
+ Block.validateLayout(getExpectedLayout(), inputShapes[0].getLayout());
}
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
+ protected void prepare(Shape[] inputs) {
+ long inputChannel = inputs[0].get(1);
+ weight.setShape(new Shape(filters, inputChannel / groups).addAll(kernelShape));
+ if (bias != null) {
+ bias.setShape(new Shape(filters));
+ }
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Shape[] getOutputShapes(Shape[] inputs) {
long[] shape = new long[numDimensions()];
shape[0] = inputs[0].get(0);
shape[1] = filters;
diff --git a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java
index e2dceddcef4..ce5c486d701 100644
--- a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java
+++ b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java
@@ -57,7 +57,7 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {inputShapes[0].addAll(embedding.getShape())};
}
diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java
index 48736aaa8e1..2e65295c1f7 100644
--- a/api/src/main/java/ai/djl/nn/core/Embedding.java
+++ b/api/src/main/java/ai/djl/nn/core/Embedding.java
@@ -23,7 +23,6 @@
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
@@ -56,8 +55,11 @@ protected Embedding(BaseBuilder<T> baseBuilder) {
sparseFormat = baseBuilder.sparseFormat;
embedding =
addParameter(
- new Parameter("embedding", this, ParameterType.WEIGHT, true, sparseFormat),
- (inputShapes) -> new Shape(numEmbeddings, embeddingSize));
+ Parameter.builder()
+ .setName("embedding")
+ .setType(Parameter.Type.WEIGHT)
+ .optGradientFormat(sparseFormat)
+ .build());
if (baseBuilder.fallthrough != null && baseBuilder.defaultItem != null) {
throw new IllegalArgumentException(
"You can not specify both a fallthrough and a defaultItem");
@@ -93,15 +95,25 @@ public Embedding(NDArray embedding, SparseFormat format) {
this.sparseFormat = format;
this.embedding =
addParameter(
- new Parameter("embedding", this, ParameterType.WEIGHT, true, sparseFormat),
- (inputShapes) -> new Shape(numEmbeddings, embeddingSize));
+ Parameter.builder()
+ .setName("embedding")
+ .setType(Parameter.Type.WEIGHT)
+ .optGradientFormat(sparseFormat)
+ .build());
this.embedding.setArray(embedding);
inputShapes = new Shape[] {new Shape(-1)};
}
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public void prepare(Shape[] inputShapes) {
+ // numItems will be adjusted by embedding array or fallthroughEmbedding
+ embedding.setShape(new Shape(numEmbeddings, embeddingSize));
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))};
}
diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java
index 34a8c020a18..79459348fb2 100644
--- a/api/src/main/java/ai/djl/nn/core/Linear.java
+++ b/api/src/main/java/ai/djl/nn/core/Linear.java
@@ -16,12 +16,10 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
-import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;
@@ -59,14 +57,19 @@ public class Linear extends AbstractBlock {
Linear(Builder builder) {
super(VERSION);
units = builder.units;
- // "inputFeatures" is only known after "beforeInitialize" is called, hence we need
- // a callback, even if we do not used the callback parameter
weight =
addParameter(
- new Parameter("weight", this, ParameterType.WEIGHT),
- inputShapes -> new Shape(units, inputFeatures));
+ Parameter.builder()
+ .setName("weight")
+ .setType(Parameter.Type.WEIGHT)
+ .build());
if (builder.bias) {
- bias = addParameter(new Parameter("bias", this, ParameterType.BIAS), new Shape(units));
+ bias =
+ addParameter(
+ Parameter.builder()
+ .setName("bias")
+ .setType(Parameter.Type.BIAS)
+ .build());
}
}
@@ -86,8 +89,8 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
- return new Shape[] {inputShape.addAll(new Shape(units))};
+ public Shape[] getOutputShapes(Shape[] inputs) {
+ return new Shape[] {inputs[0].slice(0, inputs[0].dimension() - 1).add(units)};
}
/** {@inheritDoc} */
@@ -99,13 +102,24 @@ public PairList<String, Shape> describeInput() {
/** {@inheritDoc} */
@Override
- public void beforeInitialize(Shape[] inputShapes) {
+ protected void beforeInitialize(Shape... inputShapes) {
super.beforeInitialize(inputShapes);
+ Preconditions.checkArgument(inputShapes.length == 1, "Linear block only supports 1 input");
Shape input = inputShapes[0];
inputFeatures = input.get(input.dimension() - 1);
inputShape = input.slice(0, input.dimension() - 1);
}
+ /** {@inheritDoc} */
+ @Override
+ public void prepare(Shape[] inputShapes) {
+ Shape input = inputShapes[0];
+ weight.setShape(new Shape(units, input.get(input.dimension() - 1)));
+ if (bias != null) {
+ bias.setShape(new Shape(units));
+ }
+ }
+
/** {@inheritDoc} */
@Override
protected void saveMetadata(DataOutputStream os) throws IOException {
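
For Linear in particular, the weight shape is now resolved in prepare() from the last input dimension, and getOutputShapes derives the result from the inputs argument instead of cached state, preserving all leading dimensions. A small sketch with arbitrary sizes:

```java
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.core.Linear;

public final class LinearShapeExample {

    public static void main(String[] args) {
        Linear linear = Linear.builder().setUnits(256).build();

        try (NDManager manager = NDManager.newBaseManager()) {
            // prepare() sets the weight shape to (units, lastInputDim) = (256, 512)
            linear.initialize(manager, DataType.FLOAT32, new Shape(16, 20, 512));

            // leading dimensions are kept; only the last one becomes `units`
            Shape[] out = linear.getOutputShapes(new Shape[] {new Shape(16, 20, 512)});
            System.out.println(out[0]); // (16, 20, 256)
        }
    }
}
```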
diff --git a/api/src/main/java/ai/djl/nn/core/Prelu.java b/api/src/main/java/ai/djl/nn/core/Prelu.java
index abb268f9446..ced9d709e87 100644
--- a/api/src/main/java/ai/djl/nn/core/Prelu.java
+++ b/api/src/main/java/ai/djl/nn/core/Prelu.java
@@ -15,12 +15,10 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
-import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
@@ -44,7 +42,13 @@ public class Prelu extends AbstractBlock {
/** Creates a Parametric ReLU Block. */
public Prelu() {
super(VERSION);
- alpha = addParameter(new Parameter("alpha", this, ParameterType.OTHER), new Shape());
+ alpha =
+ addParameter(
+ Parameter.builder()
+ .setName("alpha")
+ .setType(Parameter.Type.WEIGHT)
+ .optShape(new Shape())
+ .build());
}
/** {@inheritDoc} */
@@ -61,7 +65,7 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
+ public Shape[] getOutputShapes(Shape[] inputs) {
return new Shape[] {inputs[0]};
}
diff --git a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java
index 7f5f243646b..dc93360811a 100644
--- a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java
+++ b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java
@@ -16,12 +16,10 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
-import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
@@ -86,26 +84,37 @@ public class BatchNorm extends AbstractBlock {
momentum = builder.momentum;
center = builder.center;
scale = builder.scale;
- // When creating parameters we use a callback as "inChannels" is set before initialization,
- // it is not known yet.
+
// make gamma trainable if scale
gamma =
addParameter(
- new Parameter("gamma", this, ParameterType.GAMMA, scale),
- (inputShapes) -> new Shape(inChannels));
+ Parameter.builder()
+ .setName("gamma")
+ .setType(Parameter.Type.GAMMA)
+ .optRequiresGrad(scale)
+ .build());
// make beta trainable if center
beta =
addParameter(
- new Parameter("beta", this, ParameterType.BETA, center),
- (inputShapes) -> new Shape(inChannels));
+ Parameter.builder()
+ .setName("beta")
+ .setType(Parameter.Type.BETA)
+ .optRequiresGrad(center)
+ .build());
runningMean =
addParameter(
- new Parameter("runningMean", this, ParameterType.RUNNING_MEAN, false),
- (inputShapes) -> new Shape(inChannels));
+ Parameter.builder()
+ .setName("runningMean")
+ .setType(Parameter.Type.RUNNING_MEAN)
+ .optRequiresGrad(false)
+ .build());
runningVar =
addParameter(
- new Parameter("runningVar", this, ParameterType.RUNNING_VAR, false),
- (inputShapes) -> new Shape(inChannels));
+ Parameter.builder()
+ .setName("runningVar")
+ .setType(Parameter.Type.RUNNING_VAR)
+ .optRequiresGrad(false)
+ .build());
}
/** {@inheritDoc} */
@@ -135,17 +144,26 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {inputShapes[0]};
}
/** {@inheritDoc} */
@Override
- public void beforeInitialize(Shape[] inputShapes) {
+ protected void beforeInitialize(Shape... inputShapes) {
super.beforeInitialize(inputShapes);
inChannels = inputShapes[0].size(axis);
}
+ /** {@inheritDoc} */
+ @Override
+ public void prepare(Shape[] inputShapes) {
+ gamma.setShape(new Shape(inChannels));
+ beta.setShape(new Shape(inChannels));
+ runningMean.setShape(new Shape(inChannels));
+ runningVar.setShape(new Shape(inChannels));
+ }
+
/** {@inheritDoc} */
@Override
protected void saveMetadata(DataOutputStream os) throws IOException {
diff --git a/api/src/main/java/ai/djl/nn/norm/Dropout.java b/api/src/main/java/ai/djl/nn/norm/Dropout.java
index 7a44c954288..20bcd201138 100644
--- a/api/src/main/java/ai/djl/nn/norm/Dropout.java
+++ b/api/src/main/java/ai/djl/nn/norm/Dropout.java
@@ -15,7 +15,6 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
-import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
@@ -76,7 +75,7 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {inputShapes[0]};
}
diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java
index 1251c32ee49..c51620f5dc5 100644
--- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java
+++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java
@@ -13,13 +13,13 @@
package ai.djl.nn.recurrent;
import ai.djl.MalformedModelException;
-import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
+import ai.djl.nn.ParameterList;
+import ai.djl.util.Pair;
import java.io.DataInputStream;
import java.io.IOException;
@@ -67,7 +67,7 @@ public RecurrentBlock(BaseBuilder<?> builder) {
bidirectional = builder.bidirectional;
returnState = builder.returnState;
- ParameterType[] parameterTypes = {ParameterType.WEIGHT, ParameterType.BIAS};
+ Parameter.Type[] parameterTypes = {Parameter.Type.WEIGHT, Parameter.Type.BIAS};
String[] directions = {"l"};
if (builder.bidirectional) {
directions = new String[] {"l", "r"};
@@ -75,12 +75,13 @@ public RecurrentBlock(BaseBuilder<?> builder) {
String[] gateStrings = {"i2h", "h2h"};
for (int i = 0; i < numLayers; i++) {
- for (ParameterType parameterType : parameterTypes) {
+ for (Parameter.Type parameterType : parameterTypes) {
for (String direction : directions) {
for (String gateString : gateStrings) {
String name =
direction + '_' + i + '_' + gateString + '_' + parameterType.name();
- addParameter(new Parameter(name, this, parameterType));
+ addParameter(
+ Parameter.builder().setName(name).setType(parameterType).build());
}
}
}
@@ -89,7 +90,7 @@ public RecurrentBlock(BaseBuilder> builder) {
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
+ public Shape[] getOutputShapes(Shape[] inputs) {
Shape inputShape = inputs[0];
Shape outputShape =
new Shape(inputShape.get(0), inputShape.get(1), stateSize * getNumDirections());
@@ -109,31 +110,34 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
/** {@inheritDoc} */
@Override
- public void beforeInitialize(Shape[] inputs) {
- super.beforeInitialize(inputs);
- Shape inputShape = inputs[0];
- Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout());
+ protected void beforeInitialize(Shape... inputShapes) {
+ super.beforeInitialize(inputShapes);
+ Block.validateLayout(EXPECTED_LAYOUT, inputShapes[0].getLayout());
}
/** {@inheritDoc} */
@Override
- public Shape getParameterShape(String name, Shape[] inputShapes) {
- int layer = Integer.parseInt(name.split("_")[1]);
- Shape shape = inputShapes[0];
- long inputs = shape.get(2);
- if (layer > 0) {
- inputs = stateSize * getNumDirections();
- }
- if (name.contains("BIAS")) {
- return new Shape(gates * stateSize);
- }
- if (name.contains("i2h")) {
- return new Shape(gates * stateSize, inputs);
- }
- if (name.contains("h2h")) {
- return new Shape(gates * stateSize, stateSize);
+ public void prepare(Shape[] inputs) {
+ Shape inputShape = inputs[0];
+ ParameterList parameters = getDirectParameters();
+ for (Pair<String, Parameter> pair : parameters) {
+ String name = pair.getKey();
+ Parameter parameter = pair.getValue();
+ int layer = Integer.parseInt(name.split("_")[1]);
+ long inputSize = inputShape.get(2);
+ if (layer > 0) {
+ inputSize = stateSize * getNumDirections();
+ }
+ if (name.contains("BIAS")) {
+ parameter.setShape(new Shape(gates * stateSize));
+ } else if (name.contains("i2h")) {
+ parameter.setShape(new Shape(gates * stateSize, inputSize));
+ } else if (name.contains("h2h")) {
+ parameter.setShape(new Shape(gates * stateSize, stateSize));
+ } else {
+ throw new IllegalArgumentException("Invalid parameter name");
+ }
}
- throw new IllegalArgumentException("Invalid parameter name");
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java
index 066fc0c4229..b8ca2caf120 100644
--- a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java
@@ -22,7 +22,6 @@
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.norm.Dropout;
@@ -77,8 +76,12 @@ private BertBlock(Builder builder) {
// embedding for the position
this.positionEmebdding =
addParameter(
- new Parameter(PARAM_POSITION_EMBEDDING, this, ParameterType.WEIGHT),
- new Shape(builder.maxSequenceLength, builder.embeddingSize));
+ Parameter.builder()
+ .setName(PARAM_POSITION_EMBEDDING)
+ .setType(Parameter.Type.WEIGHT)
+ .optShape(
+ new Shape(builder.maxSequenceLength, builder.embeddingSize))
+ .build());
// embedding for the input types
this.typeEmbedding =
addChildBlock(
@@ -153,7 +156,7 @@ public int getTypeDictionarySize() {
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
long batch = inputShapes[0].get(0);
long seqLength = inputShapes[0].get(1);
return new Shape[] {
@@ -164,11 +167,12 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
/** {@inheritDoc} */
@Override
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
- beforeInitialize(inputShapes);
+ super.beforeInitialize(inputShapes);
inputNames = Arrays.asList("tokenIds", "typeIds", "masks");
Shape[] tokenShape = {inputShapes[0]};
Shape[] typeShape = {inputShapes[1]};
- Shape[] embeddingOutput = this.tokenEmbedding.initialize(manager, dataType, tokenShape);
+ this.tokenEmbedding.initialize(manager, dataType, tokenShape);
+ Shape[] embeddingOutput = this.tokenEmbedding.getOutputShapes(tokenShape);
this.typeEmbedding.initialize(manager, dataType, typeShape);
this.embeddingNorm.initialize(manager, dataType, embeddingOutput);
this.embeddingDropout.initialize(manager, dataType, embeddingOutput);
diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java
index ab9a7216803..7fc336434b6 100644
--- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java
@@ -19,7 +19,6 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.training.ParameterStore;
@@ -59,8 +58,11 @@ public BertMaskedLanguageModelBlock(
this.sequenceNorm = addChildBlock("sequenceNorm", BatchNorm.builder().optAxis(1).build());
this.dictionaryBias =
addParameter(
- new Parameter("dictionaryBias", this, ParameterType.BIAS),
- new Shape(bertBlock.getTokenDictionarySize()));
+ Parameter.builder()
+ .setName("dictionaryBias")
+ .setType(Parameter.Type.BIAS)
+ .optShape(new Shape(bertBlock.getTokenDictionarySize()))
+ .build());
this.hiddenActivation = hiddenActivation;
}
@@ -140,7 +142,7 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(final NDManager manager, final Shape[] inputShapes) {
+ public Shape[] getOutputShapes(final Shape[] inputShapes) {
int batchSize = (int) inputShapes[0].get(0);
int indexCount = (int) inputShapes[1].get(1);
int dictionarySize = (int) inputShapes[2].get(0);
diff --git a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java
index 40f9a717b65..aac91ab5745 100644
--- a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java
@@ -53,7 +53,7 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {new Shape(inputShapes[0].get(0), 2)};
}
}
diff --git a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java
index 2478811d037..d7f44089564 100644
--- a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java
@@ -51,7 +51,8 @@ public BertPretrainingBlock(final BertBlock.Builder builder) {
@Override
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
inputNames = Arrays.asList("tokenIds", "typeIds", "sequenceMasks", "maskedIndices");
- Shape[] bertOutputShapes = bertBlock.initialize(manager, dataType, inputShapes);
+ bertBlock.initialize(manager, dataType, inputShapes);
+ Shape[] bertOutputShapes = bertBlock.getOutputShapes(inputShapes);
Shape embeddedSequence = bertOutputShapes[0];
Shape pooledOutput = bertOutputShapes[1];
Shape maskedIndices = inputShapes[2];
@@ -97,7 +98,7 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
long batchSize = inputShapes[0].get(0);
long maskedIndexCount = inputShapes[3].get(1);
return new Shape[] {
diff --git a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
index 4aa6484f766..2b3c52d3526 100644
--- a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
+++ b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
@@ -21,7 +21,6 @@
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Arrays;
@@ -47,13 +46,16 @@ private IdEmbedding(Builder builder) {
this.embeddingSize = builder.embeddingSize;
this.embedding =
addParameter(
- new Parameter(EMBEDDING_PARAM_NAME, this, ParameterType.WEIGHT),
- new Shape(dictionarySize, embeddingSize));
+ Parameter.builder()
+ .setName(EMBEDDING_PARAM_NAME)
+ .setType(Parameter.Type.WEIGHT)
+ .optShape(new Shape(dictionarySize, embeddingSize))
+ .build());
}
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))};
}
diff --git a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java
index 10d4aa20b40..9d3367f9cf5 100644
--- a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java
@@ -32,8 +32,6 @@ public class PointwiseFeedForwardBlock extends AbstractBlock {
private static final byte VERSION = 1;
- private Shape outputShape;
-
/**
* Creates a pointwise feed-forward block.
*
@@ -49,7 +47,7 @@ public PointwiseFeedForwardBlock(
super(VERSION);
// add hidden layers with activation
int count = 0;
- for (final int hiddenSize : hiddenSizes) {
+ for (int hiddenSize : hiddenSizes) {
addChildBlock(
"linear_" + count, Linear.builder().optBias(true).setUnits(hiddenSize).build());
addChildBlock("activation_" + count, new LambdaBlock(activationFunction));
@@ -61,8 +59,11 @@ public PointwiseFeedForwardBlock(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
- return new Shape[] {outputShape};
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
+ for (Block child : children.values()) {
+ inputShapes = child.getOutputShapes(inputShapes);
+ }
+ return inputShapes;
}
/** {@inheritDoc} */
@@ -75,16 +76,16 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
}
// Now that we know the input shape, we can determine the reshape necessary
// to shape the input and re-shape the output
- final Shape inputShape = inputShapes[0];
+ Shape inputShape = inputShapes[0];
if (inputShape.dimension() < 2) {
throw new IllegalArgumentException(
"Pointwise feed forward blocks need an input of at least dimension 2.");
}
Shape lastShape = inputShape;
for (Block child : children.values()) {
- lastShape = child.initialize(manager, dataType, lastShape)[0];
+ child.initialize(manager, dataType, lastShape);
+ lastShape = child.getOutputShapes(new Shape[] {lastShape})[0];
}
- outputShape = lastShape;
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java
index ffee4df219c..fa293f4b37a 100644
--- a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java
@@ -146,7 +146,7 @@ public Linear getResultProjection() {
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
// Return shape is the shape of the query. For 2 or less inputs we have self-attention, i.e.
// the shape of the output is the shape of the input
if (inputShapes.length == 1 || inputShapes.length == 2) {
diff --git a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java
index 83fb87446e9..4cbabecada7 100644
--- a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java
@@ -82,7 +82,7 @@ public TransformerEncoderBlock(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
return inputShapes;
}
diff --git a/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java b/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java
index be9f7bbfec6..b9580914201 100644
--- a/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java
+++ b/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java
@@ -13,23 +13,23 @@
package ai.djl.training;
import ai.djl.Device;
+import ai.djl.nn.Parameter;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
-import ai.djl.training.initializer.XavierInitializer;
-import ai.djl.training.initializer.XavierInitializer.FactorType;
-import ai.djl.training.initializer.XavierInitializer.RandomType;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
+import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.function.Predicate;
/** {@code DefaultTrainingConfig} is an implementation of the {@link TrainingConfig} interface. */
public class DefaultTrainingConfig implements TrainingConfig {
- private Initializer initializer;
+ private PairList<Initializer, Predicate<Parameter>> initializers = new PairList<>();
private Optimizer optimizer;
private Device[] devices;
private Loss loss;
@@ -38,15 +38,12 @@ public class DefaultTrainingConfig implements TrainingConfig {
/**
* Creates an instance of {@code DefaultTrainingConfig} with the given {@link Loss}. {@code
- * DefaultTrainingConfig} creates a default {@link TrainingConfig} with the {@link
- * XavierInitializer} as initialiser, {@link Adam} as optimiser, and the given {@link Loss}. The
- * evaluators and listeners are left to the user's discretion.
+ * DefaultTrainingConfig} creates a default {@link TrainingConfig} with {@link Adam} as the
+ * optimiser and the given {@link Loss}. The evaluators and listeners are left to the user's discretion.
*
* @param loss the loss to use for training
*/
public DefaultTrainingConfig(Loss loss) {
- // Defaults to initializer defined in https://arxiv.org/abs/1502.01852
- this.initializer = new XavierInitializer(RandomType.GAUSSIAN, FactorType.IN, 2);
optimizer = Adam.builder().build();
this.loss = loss;
evaluators = new ArrayList<>();
@@ -58,10 +55,38 @@ public DefaultTrainingConfig(Loss loss) {
* href="https://arxiv.org/abs/1502.01852">paper).
*
* @param initializer the initialer to use for the parameters
+ * @param type the {@link Parameter.Type} of the parameters
* @return this {@code DefaultTrainingConfig}
*/
- public DefaultTrainingConfig optInitializer(Initializer initializer) {
- this.initializer = initializer;
+ public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type) {
+ initializers.add(initializer, parameter -> parameter.getType().equals(type));
+ return this;
+ }
+
+ /**
+ * Sets the {@link Initializer} to use for the parameters (default from paper).
+ *
+ * @param initializer the initializer to use for the parameters
+ * @param name the name of the parameter
+ * @return this {@code DefaultTrainingConfig}
+ */
+ public DefaultTrainingConfig optInitializer(Initializer initializer, String name) {
+ initializers.add(initializer, parameter -> parameter.getName().equals(name));
+ return this;
+ }
+
+ /**
+ * Sets the {@link Initializer} to use for the parameters (default from paper).
+ *
+ * @param initializer the initializer to use for the parameters
+ * @param predicate the predicate to identify parameter
+ * @return this {@code DefaultTrainingConfig}
+ */
+ public DefaultTrainingConfig optInitializer(
+ Initializer initializer, Predicate<Parameter> predicate) {
+ initializers.add(initializer, predicate);
return this;
}
@@ -120,8 +145,8 @@ public Device[] getDevices() {
/** {@inheritDoc} */
@Override
- public Initializer getInitializer() {
- return initializer;
+ public PairList<Initializer, Predicate<Parameter>> getInitializers() {
+ return initializers;
}
/** {@inheritDoc} */
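
A hedged usage sketch of the new per-parameter initializer overloads added above (not part of this patch; the loss, initializers, and predicate are illustrative, and the usual DJL imports are assumed):

    // Sketch only: one initializer per parameter selector.
    DefaultTrainingConfig config =
            new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    // Xavier for every WEIGHT parameter
                    .optInitializer(new XavierInitializer(), Parameter.Type.WEIGHT)
                    // zeros for any parameter named exactly "bias"
                    .optInitializer(Initializer.ZEROS, "bias")
                    // arbitrary predicate-based selection
                    .optInitializer(Initializer.ONES, p -> p.getName().endsWith("gamma"));
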
diff --git a/api/src/main/java/ai/djl/training/TrainingConfig.java b/api/src/main/java/ai/djl/training/TrainingConfig.java
index 46748232035..0a4b5928266 100644
--- a/api/src/main/java/ai/djl/training/TrainingConfig.java
+++ b/api/src/main/java/ai/djl/training/TrainingConfig.java
@@ -13,12 +13,15 @@
package ai.djl.training;
import ai.djl.Device;
+import ai.djl.nn.Parameter;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
+import ai.djl.util.PairList;
import java.util.List;
+import java.util.function.Predicate;
/**
* An interface that is responsible for holding the configuration required by {@link Trainer}.
@@ -64,11 +67,11 @@ public interface TrainingConfig {
Device[] getDevices();
/**
- * Gets the {@link Initializer} to initialize the parameters of the model.
+ * Gets a list of {@link Initializer} and Predicate to initialize the parameters of the model.
*
* @return an {@link Initializer}
*/
- Initializer getInitializer();
+ PairList<Initializer, Predicate<Parameter>> getInitializers();
/**
* Gets the {@link Optimizer} to use during training.
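
For engine implementations, the contract change means a trainer now receives a list of (initializer, predicate) pairs instead of a single initializer. A sketch of how that list is consumed, mirroring the MxModel change later in this patch; `block` and `trainingConfig` are assumed to be in scope:

    // Sketch only: apply each initializer to the parameters its predicate selects.
    PairList<Initializer, Predicate<Parameter>> initializers = trainingConfig.getInitializers();
    for (Pair<Initializer, Predicate<Parameter>> pair : initializers) {
        if (pair.getKey() != null && pair.getValue() != null) {
            block.setInitializer(pair.getKey(), pair.getValue());
        }
    }
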
diff --git a/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java b/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java
index 64ee6279f71..8cf76c7b8f0 100644
--- a/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java
+++ b/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java
@@ -80,7 +80,7 @@ public XavierInitializer(RandomType randomType, FactorType factorType, float mag
/** Creates a new instance of {@code XavierInitializer}. */
public XavierInitializer() {
- this(RandomType.UNIFORM, FactorType.AVG, 3f);
+ this(RandomType.UNIFORM, FactorType.AVG, 6f);
}
/** {@inheritDoc} */
diff --git a/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java b/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java
index abd5e9ec2a2..562dd3a0de9 100644
--- a/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java
+++ b/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java
@@ -31,7 +31,7 @@ private ShapeUtils() {}
* @return the corresponding output shape for the provided input
*/
public static Shape outputShapeForBlock(NDManager manager, Block block, Shape inputShape) {
- Shape[] outputs = block.getOutputShapes(manager, new Shape[] {inputShape});
+ Shape[] outputs = block.getOutputShapes(new Shape[] {inputShape});
return outputs[0];
}
diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java
index 3efdd00af77..dc8028c5c92 100644
--- a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java
+++ b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java
@@ -17,6 +17,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Blocks;
+import ai.djl.nn.Parameter;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
@@ -48,7 +49,7 @@ public class AirfoilRandomAccessTest {
public void testAirfoilRemote() throws IOException, TranslateException {
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
- .optInitializer(Initializer.ONES);
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
try (Model model = Model.newInstance("model")) {
model.setBlock(Blocks.identityBlock());
diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java
index 20bf7f8ac56..716f04c578f 100644
--- a/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java
+++ b/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java
@@ -17,6 +17,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Blocks;
+import ai.djl.nn.Parameter;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
@@ -48,7 +49,7 @@ public class AmesRandomAccessTest {
public void testAmesRandomAccessRemote() throws IOException, TranslateException {
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
- .optInitializer(Initializer.ONES);
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
try (Model model = Model.newInstance("model")) {
model.setBlock(Blocks.identityBlock());
diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java
index 75bcb3e914a..063ec07fb13 100644
--- a/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java
+++ b/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java
@@ -16,6 +16,7 @@
import ai.djl.basicdataset.cv.PikachuDetection;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Blocks;
+import ai.djl.nn.Parameter;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
@@ -42,7 +43,7 @@ public void testPikachuRemote() throws IOException, TranslateException {
.build();
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
- .optInitializer(new NormalInitializer(0.01f));
+ .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT);
try (Model model = Model.newInstance("model")) {
model.setBlock(Blocks.identityBlock());
try (Trainer trainer = model.newTrainer(config)) {
diff --git a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java
index 53eca2e14b4..7d1cda2adc2 100644
--- a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java
+++ b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java
@@ -20,6 +20,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.Parameter;
import ai.djl.nn.transformer.BertBlock;
import ai.djl.nn.transformer.BertPretrainingBlock;
import ai.djl.nn.transformer.BertPretrainingLoss;
@@ -135,7 +136,8 @@ private static Model createBertPretrainingModel(Dictionary dictionary) {
model.setBlock(
new BertPretrainingBlock(
BERT_BUILDER.setTokenDictionarySize(dictionary.tokens.size())));
- model.getBlock().setInitializer(new TruncatedNormalInitializer(0.02f));
+ model.getBlock()
+ .setInitializer(new TruncatedNormalInitializer(0.02f), Parameter.Type.WEIGHT);
return model;
}
diff --git a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java
index 13ce0f2f742..76078106a22 100644
--- a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java
+++ b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java
@@ -31,7 +31,6 @@
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
-import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
@@ -119,7 +118,6 @@ public static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
- .optInitializer(new XavierInitializer())
.optDevices(Device.getDevices(arguments.getMaxGpus()))
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
.addTrainingListeners(listener);
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java
index b1631b6b73f..b68fc81e1c7 100644
--- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java
@@ -13,15 +13,18 @@
package ai.djl.fasttext;
import ai.djl.Device;
+import ai.djl.nn.Parameter;
import ai.djl.training.TrainingConfig;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
+import ai.djl.util.PairList;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
+import java.util.function.Predicate;
/** An interface that is responsible for holding the configuration required by fastText training. */
public class FtTrainingConfig implements TrainingConfig {
@@ -247,7 +250,7 @@ public Device[] getDevices() {
/** {@inheritDoc} */
@Override
- public Initializer getInitializer() {
+ public PairList<Initializer, Predicate<Parameter>> getInitializers() {
return null;
}
diff --git a/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java b/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java
index f105d46c8cd..2f6e41cda36 100644
--- a/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java
@@ -20,7 +20,6 @@
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.training.ParameterStore;
-import ai.djl.training.initializer.XavierInitializer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -31,41 +30,34 @@ public class SingleShotDetectionTest {
@Test
public void testClassPredictorBlocks() {
- try (NDManager manager = NDManager.newBaseManager()) {
- Block block = SingleShotDetection.getClassPredictionBlock(5, 10);
- Assert.assertEquals(
- block.getOutputShapes(manager, new Shape[] {new Shape(2, 8, 20, 20)})[0],
- new Shape(2, 55, 20, 20));
- block = SingleShotDetection.getClassPredictionBlock(3, 10);
- Assert.assertEquals(
- block.getOutputShapes(manager, new Shape[] {new Shape(2, 16, 10, 10)})[0],
- new Shape(2, 33, 10, 10));
- }
+ Block block = SingleShotDetection.getClassPredictionBlock(5, 10);
+ Assert.assertEquals(
+ block.getOutputShapes(new Shape[] {new Shape(2, 8, 20, 20)})[0],
+ new Shape(2, 55, 20, 20));
+ block = SingleShotDetection.getClassPredictionBlock(3, 10);
+ Assert.assertEquals(
+ block.getOutputShapes(new Shape[] {new Shape(2, 16, 10, 10)})[0],
+ new Shape(2, 33, 10, 10));
}
@Test
public void testAnchorPredictorBlocks() {
- try (NDManager manager = NDManager.newBaseManager()) {
- Block block = SingleShotDetection.getAnchorPredictionBlock(5);
- Assert.assertEquals(
- block.getOutputShapes(manager, new Shape[] {new Shape(2, 8, 20, 20)})[0],
- new Shape(2, 20, 20, 20));
- block = SingleShotDetection.getClassPredictionBlock(3, 10);
- Assert.assertEquals(
- block.getOutputShapes(manager, new Shape[] {new Shape(2, 16, 10, 10)})[0],
- new Shape(2, 33, 10, 10));
- }
+ Block block = SingleShotDetection.getAnchorPredictionBlock(5);
+ Assert.assertEquals(
+ block.getOutputShapes(new Shape[] {new Shape(2, 8, 20, 20)})[0],
+ new Shape(2, 20, 20, 20));
+ block = SingleShotDetection.getClassPredictionBlock(3, 10);
+ Assert.assertEquals(
+ block.getOutputShapes(new Shape[] {new Shape(2, 16, 10, 10)})[0],
+ new Shape(2, 33, 10, 10));
}
@Test
public void testDownSamplingBlock() {
- try (NDManager manager = NDManager.newBaseManager()) {
- Block sequentialBlock = SingleShotDetection.getDownSamplingBlock(10);
- Assert.assertEquals(
- sequentialBlock
- .getOutputShapes(manager, new Shape[] {new Shape(2, 3, 20, 20)})[0],
- new Shape(2, 10, 10, 10));
- }
+ Block sequentialBlock = SingleShotDetection.getDownSamplingBlock(10);
+ Assert.assertEquals(
+ sequentialBlock.getOutputShapes(new Shape[] {new Shape(2, 3, 20, 20)})[0],
+ new Shape(2, 10, 10, 10));
}
@Test
@@ -97,7 +89,6 @@ public void testSingleShotDetectionShape() {
.setSizes(sizes)
.setBaseNetwork(block)
.build();
- ssd.setInitializer(new XavierInitializer());
ssd.initialize(manager, DataType.FLOAT32, new Shape(32, 3, 256, 256));
ParameterStore ps = new ParameterStore(manager, false);
NDList output =
@@ -105,8 +96,7 @@ public void testSingleShotDetectionShape() {
Assert.assertEquals(output.get(0).getShape(), new Shape(1, 5444, 4));
Assert.assertEquals(output.get(1).getShape(), new Shape(32, 5444, 2));
Assert.assertEquals(output.get(2).getShape(), new Shape(32, 21776));
- Shape[] outputShapes =
- ssd.getOutputShapes(manager, new Shape[] {new Shape(32, 3, 256, 256)});
+ Shape[] outputShapes = ssd.getOutputShapes(new Shape[] {new Shape(32, 3, 256, 256)});
Assert.assertEquals(outputShapes[0], new Shape(1, 5444, 4));
Assert.assertEquals(outputShapes[1], new Shape(32, 5444, 2));
Assert.assertEquals(outputShapes[2], new Shape(32, 21776));
diff --git a/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java b/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java
index 131df85c53e..cc6d43853c8 100644
--- a/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java
@@ -23,7 +23,6 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.recurrent.LSTM;
import ai.djl.training.ParameterStore;
-import ai.djl.training.initializer.XavierInitializer;
import java.util.Arrays;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -50,7 +49,6 @@ public void testEncoder() {
.optReturnState(true)
.build());
try (NDManager manager = NDManager.newBaseManager(TestUtils.getDevices()[0])) {
- encoder.setInitializer(new XavierInitializer());
encoder.initialize(manager, DataType.FLOAT32, new Shape(4, 7));
NDList output =
encoder.forward(
diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java
index f956bfb46bc..ae0ea251e24 100644
--- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java
@@ -159,7 +159,7 @@ public void testOutputShapes() {
Shape currentShape = x.getShape();
Block alexNet = AlexNet.builder().build();
- alexNet.setInitializer(Initializer.ONES);
+ alexNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
alexNet.initialize(manager, DataType.FLOAT32, currentShape);
Map<String, Shape> shapeMap = new ConcurrentHashMap<>();
@@ -169,7 +169,7 @@ public void testOutputShapes() {
alexNet.getChildren()
.get(i)
.getValue()
- .getOutputShapes(manager, new Shape[] {currentShape});
+ .getOutputShapes(new Shape[] {currentShape});
currentShape = newShape[0];
shapeMap.put(alexNet.getChildren().get(i).getKey(), currentShape);
}
@@ -188,7 +188,7 @@ public void testForwardMethod() {
Block alexNet = AlexNet.builder().build();
int batchSize = 1;
NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224));
- alexNet.setInitializer(Initializer.ONES);
+ alexNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
alexNet.initialize(manager, DataType.FLOAT32, x.getShape());
NDArray xHat =
alexNet.forward(new ParameterStore(manager, true), new NDList(x), false)
diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java
index a149b7e40e3..4befb517a2f 100644
--- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java
@@ -100,7 +100,7 @@ public void testOutputShapes() {
Shape currentShape = x.getShape();
Block googLeNet = GoogLeNet.builder().build();
- googLeNet.setInitializer(Initializer.ONES);
+ googLeNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
googLeNet.initialize(manager, DataType.FLOAT32, currentShape);
Map<String, Shape> shapeMap = new ConcurrentHashMap<>();
@@ -111,7 +111,7 @@ public void testOutputShapes() {
.getChildren()
.get(i)
.getValue()
- .getOutputShapes(manager, new Shape[] {currentShape});
+ .getOutputShapes(new Shape[] {currentShape});
currentShape = newShape[0];
shapeMap.put(googLeNet.getChildren().get(i).getKey(), currentShape);
}
@@ -130,7 +130,7 @@ public void testForwardMethod() {
Block googLeNet = GoogLeNet.builder().build();
int batchSize = 1;
NDArray x = manager.ones(new Shape(batchSize, 1, 28, 28));
- googLeNet.setInitializer(Initializer.ONES);
+ googLeNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
googLeNet.initialize(manager, DataType.FLOAT32, x.getShape());
NDArray xHat =
googLeNet
diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java
index 39744eeefe6..6160ed5f7e4 100644
--- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java
@@ -137,7 +137,7 @@ public void testOutputShapes() {
Shape currentShape = x.getShape();
Block leNet = LeNet.builder().build();
- leNet.setInitializer(Initializer.ONES);
+ leNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
leNet.initialize(manager, DataType.FLOAT32, currentShape);
Map<String, Shape> shapeMap = new ConcurrentHashMap<>();
@@ -147,7 +147,7 @@ public void testOutputShapes() {
leNet.getChildren()
.get(i)
.getValue()
- .getOutputShapes(manager, new Shape[] {currentShape});
+ .getOutputShapes(new Shape[] {currentShape});
currentShape = newShape[0];
shapeMap.put(leNet.getChildren().get(i).getKey(), currentShape);
}
@@ -165,7 +165,7 @@ public void testForwardMethod() {
Block leNet = LeNet.builder().build();
int batchSize = 1;
NDArray x = manager.ones(new Shape(batchSize, 1, 28, 28));
- leNet.setInitializer(Initializer.ONES);
+ leNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
leNet.initialize(manager, DataType.FLOAT32, x.getShape());
NDArray xHat =
leNet.forward(new ParameterStore(manager, true), new NDList(x), true)
diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java
index cac77221d69..f8544debc70 100644
--- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java
@@ -152,17 +152,14 @@ public void testOutputShapes() {
Shape currentShape = x.getShape();
Block nin = NiN.builder().build();
- nin.setInitializer(Initializer.ONES);
+ nin.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
nin.initialize(manager, DataType.FLOAT32, currentShape);
Map<String, Shape> shapeMap = new ConcurrentHashMap<>();
for (int i = 0; i < nin.getChildren().size(); i++) {
Shape[] newShape =
- nin.getChildren()
- .get(i)
- .getValue()
- .getOutputShapes(manager, new Shape[] {currentShape});
+ nin.getChildren().get(i).getValue().getOutputShapes(new Shape[] {currentShape});
currentShape = newShape[0];
shapeMap.put(nin.getChildren().get(i).getKey(), currentShape);
}
@@ -180,7 +177,7 @@ public void testForwardMethod() {
Block nin = NiN.builder().build();
int batchSize = 1;
NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224));
- nin.setInitializer(Initializer.ONES);
+ nin.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
nin.initialize(manager, DataType.FLOAT32, x.getShape());
NDArray xHat =
nin.forward(new ParameterStore(manager, true), new NDList(x), false)
diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java
index fd50e39943f..21c6779f36f 100644
--- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java
@@ -57,7 +57,7 @@ public void testTrain() {
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.optDevices(Device.getDevices(2))
- .optInitializer(Initializer.ONES);
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
Block resNet50 =
ResNetV1.builder()
@@ -123,7 +123,7 @@ public void testLoadTrain()
TrainingConfig config =
new DefaultTrainingConfig(Loss.l1Loss())
.optDevices(Device.getDevices(2))
- .optInitializer(Initializer.ONES);
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
try (Trainer trainer = model.newTrainer(config)) {
int batchSize = 2;
Shape inputShape = new Shape(batchSize, 3, 32, 32);
@@ -131,8 +131,7 @@ public void testLoadTrain()
trainer.initialize(inputShape);
NDManager manager = trainer.getManager();
- Shape[] outputShape =
- model.getBlock().getOutputShapes(manager, new Shape[] {inputShape});
+ Shape[] outputShape = model.getBlock().getOutputShapes(new Shape[] {inputShape});
NDArray data = manager.ones(new Shape(batchSize, 3, 32, 32));
NDArray label = manager.ones(outputShape[0]);
diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java
index 267bd41cf80..114907144aa 100644
--- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java
@@ -40,7 +40,7 @@ public void testTrain() {
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.optDevices(Device.getDevices(2))
- .optInitializer(Initializer.ONES);
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
Block squeezeNet = SqueezeNet.squeezenet(10);
try (Model model = Model.newInstance("squeezenet")) {
model.setBlock(squeezeNet);
diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java
index b8520edd8f9..5c22bacecf1 100644
--- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java
@@ -107,17 +107,14 @@ public void testOutputShapes() {
Shape currentShape = x.getShape();
Block vgg = VGG.builder().build();
- vgg.setInitializer(Initializer.ONES);
+ vgg.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
vgg.initialize(manager, DataType.FLOAT32, currentShape);
Map<String, Shape> shapeMap = new ConcurrentHashMap<>();
for (int i = 0; i < vgg.getChildren().size(); i++) {
Shape[] newShape =
- vgg.getChildren()
- .get(i)
- .getValue()
- .getOutputShapes(manager, new Shape[] {currentShape});
+ vgg.getChildren().get(i).getValue().getOutputShapes(new Shape[] {currentShape});
currentShape = newShape[0];
shapeMap.put(vgg.getChildren().get(i).getKey(), currentShape);
}
@@ -137,8 +134,9 @@ public void testForwardMethod() {
Block vgg = VGG.builder().build();
int batchSize = 1;
NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224));
- vgg.setInitializer(Initializer.ONES);
+ vgg.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
vgg.initialize(manager, DataType.FLOAT32, x.getShape());
+
NDArray xHat =
vgg.forward(new ParameterStore(manager, true), new NDList(x), false)
.singletonOrThrow();
diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java
index 8e97ed38b87..2dbfac8ef2e 100644
--- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java
@@ -18,6 +18,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Blocks;
+import ai.djl.nn.Parameter;
import ai.djl.testing.Assertions;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.GradientCollector;
@@ -139,7 +140,7 @@ public void testAddScalar() {
try (Trainer trainer =
model.newTrainer(
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES))) {
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) {
try (GradientCollector gradCol = trainer.newGradientCollector()) {
lhs.attachGradient();
result = NDArrays.add(lhs, 2);
@@ -360,7 +361,7 @@ public void testDot() {
try (Trainer trainer =
model.newTrainer(
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES))) {
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) {
try (GradientCollector gradCol = trainer.newGradientCollector()) {
lhs.attachGradient();
result = NDArrays.dot(lhs, rhs);
diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java
index b0289856ae9..c133b382327 100644
--- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java
@@ -62,7 +62,8 @@ public class BlockCoreTest {
@Test
public void testLinear() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
long outSize = 3;
Block block = Linear.builder().setUnits(outSize).build();
@@ -124,7 +125,8 @@ public void testLinear() throws IOException, MalformedModelException {
@Test
public void testLinearWithDefinedLayout() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
long outSize = 3;
Block block = Linear.builder().setUnits(outSize).build();
@@ -176,7 +178,8 @@ public void testLinearWithDefinedLayout() throws IOException, MalformedModelExce
@Test
public void testBatchNorm() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
Block block = BatchNorm.builder().build();
try (Model model = Model.newInstance("model")) {
@@ -203,7 +206,8 @@ public void testBatchNorm() throws IOException, MalformedModelException {
@Test
public void testDropout() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
Block block = Dropout.builder().optRate(.5f).build();
try (Model model = Model.newInstance("model")) {
@@ -229,7 +233,8 @@ public void testDropout() throws IOException, MalformedModelException {
@Test
public void testEmbedding() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
TrainableWordEmbedding block =
TrainableWordEmbedding.builder()
@@ -262,7 +267,8 @@ public void testEmbedding() throws IOException, MalformedModelException {
@Test
public void testConv1d() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
Block block =
Conv1d.builder().setKernelShape(new Shape(2)).setFilters(1).optBias(false).build();
@@ -283,7 +289,7 @@ public void testConv1d() throws IOException, MalformedModelException {
NDArray out = trainer.forward(new NDList(data)).singletonOrThrow();
Assert.assertEquals(out, expected);
- Shape[] outputShape = block.getOutputShapes(manager, new Shape[] {inputShape});
+ Shape[] outputShape = block.getOutputShapes(new Shape[] {inputShape});
Assert.assertEquals(out.getShape(), outputShape[0]);
testEncode(manager, block);
@@ -294,7 +300,8 @@ public void testConv1d() throws IOException, MalformedModelException {
@Test
public void testConv1dTranspose() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
Block block =
Conv1dTranspose.builder()
@@ -317,7 +324,7 @@ public void testConv1dTranspose() throws IOException, MalformedModelException {
NDArray out = trainer.forward(new NDList(data)).singletonOrThrow();
Assert.assertEquals(out, expected);
- Shape[] outputShape = block.getOutputShapes(manager, new Shape[] {inputShape});
+ Shape[] outputShape = block.getOutputShapes(new Shape[] {inputShape});
Assert.assertEquals(out.getShape(), outputShape[0]);
testEncode(manager, block);
@@ -328,7 +335,8 @@ public void testConv1dTranspose() throws IOException, MalformedModelException {
@Test
public void testConv2d() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
Block block = Conv2d.builder().setKernelShape(new Shape(2, 2)).setFilters(1).build();
try (Model model = Model.newInstance("model")) {
@@ -359,7 +367,8 @@ public void testConv2d() throws IOException, MalformedModelException {
@Test
public void testConv2dTranspose() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
Block block =
Conv2dTranspose.builder().setKernelShape(new Shape(2, 2)).setFilters(1).build();
@@ -396,7 +405,8 @@ public void testConv2dTranspose() throws IOException, MalformedModelException {
@Test
public void testConv3d() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
Block block = Conv3d.builder().setKernelShape(new Shape(2, 2, 2)).setFilters(1).build();
try (Model model = Model.newInstance("model")) {
@@ -422,8 +432,7 @@ public void testConv3d() throws IOException, MalformedModelException {
NDArray result = trainer.forward(new NDList(data)).singletonOrThrow();
Assert.assertEquals(result, expected);
- Shape[] outputShape =
- block.getOutputShapes(manager, new Shape[] {new Shape(1, 1, 3, 3, 3)});
+ Shape[] outputShape = block.getOutputShapes(new Shape[] {new Shape(1, 1, 3, 3, 3)});
Assert.assertEquals(result.getShape(), outputShape[0]);
testEncode(manager, block);
@@ -437,7 +446,7 @@ public void testRNNTanh() throws IOException, MalformedModelException {
Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true);
TrainingConfig config =
new DefaultTrainingConfig(loss)
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optDevices(TestUtils.getDevices());
Block block =
RNN.builder()
@@ -484,7 +493,7 @@ public void testRNNRelu() throws IOException, MalformedModelException {
Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true);
TrainingConfig config =
new DefaultTrainingConfig(loss)
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optDevices(TestUtils.getDevices());
Block block =
RNN.builder()
@@ -534,7 +543,7 @@ public void testLstm() throws IOException, MalformedModelException {
Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true);
TrainingConfig config =
new DefaultTrainingConfig(loss)
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optDevices(TestUtils.getDevices());
Block block =
LSTM.builder()
@@ -585,7 +594,7 @@ public void testGRU() throws IOException, MalformedModelException {
Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true);
TrainingConfig config =
new DefaultTrainingConfig(loss)
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optDevices(TestUtils.getDevices());
GRU block =
GRU.builder()
@@ -638,7 +647,8 @@ public void testGRU() throws IOException, MalformedModelException {
@Test
public void testSequentialBlock() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
SequentialBlock block = new SequentialBlock();
block.addSingleton(x -> x.mul(6.5f));
block.add(Linear.builder().setUnits(10).build());
@@ -678,7 +688,8 @@ public void testSequentialBlock() throws IOException, MalformedModelException {
@Test
public void testParallelBlock() throws IOException, MalformedModelException {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
ParallelBlock block =
new ParallelBlock(
list ->
diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java
index 6f5f4e1ebc3..44a4e9816dd 100644
--- a/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java
@@ -18,6 +18,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.Parameter;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
@@ -29,7 +30,8 @@
public class PoolingOperationsTest {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
@Test
public void testMaxPool1d() {
diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java
index 9595e88ef15..df8a64f4803 100644
--- a/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java
@@ -19,6 +19,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.Parameter;
import ai.djl.nn.transformer.ScaledDotProductAttentionBlock;
import ai.djl.training.GradientCollector;
import ai.djl.training.ParameterStore;
@@ -752,7 +753,7 @@ public void testMaskedAttention() {
.optAttentionProbsDropoutProb(0.0f)
.build();
- block.setInitializer(new NormalInitializer());
+ block.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
block.getKeyProjection().setInitializer(keyKernelInitializer, "weight");
block.getValueProjection().setInitializer(valueKernelInitializer, "weight");
block.getQueryProjection().setInitializer(queryKernelInitializer, "weight");
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java
index 443370b10d8..b3e8502a9ca 100644
--- a/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java
@@ -18,6 +18,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
+import ai.djl.nn.Parameter;
import ai.djl.testing.Assertions;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
@@ -30,7 +31,8 @@
public class ActivationTest {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
@Test
public void testRelu() {
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java b/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java
index a86705162ef..630ad57d3e6 100644
--- a/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java
@@ -18,6 +18,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Blocks;
+import ai.djl.nn.Parameter;
import ai.djl.testing.Assertions;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.ParameterStore;
@@ -30,7 +31,8 @@
public class BlocksTest {
TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
@Test
public void testFlattenBlock() {
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java b/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java
index 4b5dd466c31..8d06c85efd8 100644
--- a/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java
@@ -20,6 +20,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Blocks;
+import ai.djl.nn.Parameter;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
@@ -49,7 +50,8 @@
public class DatasetTest {
private TrainingConfig config =
- new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES);
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
@Test
public void testSequenceSampler() throws IOException, TranslateException {
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java
index 154583e59e3..96680659f78 100644
--- a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java
@@ -19,6 +19,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Blocks;
+import ai.djl.nn.Parameter;
import ai.djl.nn.core.Linear;
import ai.djl.testing.Assertions;
import ai.djl.training.DefaultTrainingConfig;
@@ -49,7 +50,7 @@ public void testAutograd() {
try (Trainer trainer =
model.newTrainer(
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES))) {
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) {
try (GradientCollector gradCol = trainer.newGradientCollector()) {
NDArray lhs =
manager.create(new float[] {6, -9, -12, 15, 0, 4}, new Shape(2, 3));
@@ -87,7 +88,7 @@ public void testTrain() throws IOException, TranslateException {
TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
.addTrainingListeners(new EvaluatorTrainingListener())
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optOptimizer(optimizer);
try (Model model = Model.newInstance("linear")) {
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java b/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java
index 2e919ca78d5..6023b1b054f 100644
--- a/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java
@@ -21,7 +21,6 @@
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.testing.Assertions;
-import ai.djl.training.initializer.XavierInitializer;
import java.io.IOException;
import java.nio.file.Paths;
import org.testng.Assert;
@@ -36,7 +35,6 @@ public void testModelSaveAndLoad() throws IOException, MalformedModelException {
block.add(BatchNorm.builder().build());
try (Model saveModel = Model.newInstance("saveModel");
Model loadModel = Model.newInstance("loadModel")) {
- block.setInitializer(new XavierInitializer());
block.initialize(saveModel.getNDManager(), DataType.FLOAT32, new Shape(1, 3, 32, 32));
ParameterList savedParameters = block.getParameters();
saveModel.setBlock(block);
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java b/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java
index 8e472c626f2..5657f584805 100644
--- a/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java
@@ -20,6 +20,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
+import ai.djl.nn.Parameter;
import ai.djl.nn.core.Linear;
import ai.djl.testing.Assertions;
import ai.djl.training.DefaultTrainingConfig;
@@ -46,7 +47,7 @@ public void testSgd() {
Device[] devices = Device.getDevices(1);
TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optOptimizer(sgd)
.optDevices(devices);
Block block = Linear.builder().setUnits(CHANNELS).build();
@@ -78,7 +79,7 @@ public void testSgdWithMomentum() {
Device[] devices = Device.getDevices(1);
TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optOptimizer(optim)
.optDevices(devices);
Block block = Linear.builder().setUnits(CHANNELS).build();
@@ -118,7 +119,7 @@ public void testNag() {
Device[] devices = Device.getDevices(1);
TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optOptimizer(optim)
.optDevices(devices);
Block block = Linear.builder().setUnits(CHANNELS).build();
@@ -147,7 +148,7 @@ public void testAdam() {
Device[] devices = Device.getDevices(1);
TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optOptimizer(optim)
.optDevices(devices);
Block block = Linear.builder().setUnits(CHANNELS).build();
@@ -176,7 +177,7 @@ public void testAdagrad() {
Device[] devices = Device.getDevices(1);
TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optOptimizer(optim)
.optDevices(devices);
Block block = Linear.builder().setUnits(CHANNELS).build();
@@ -209,7 +210,7 @@ public void testRMSProp() {
Device[] devices = Device.getDevices(1);
TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optOptimizer(optim)
.optDevices(devices);
Block block = Linear.builder().setUnits(CHANNELS).build();
@@ -243,7 +244,7 @@ public void testRMSPropAlex() {
Device[] devices = Device.getDevices(1);
TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optOptimizer(optim)
.optDevices(devices);
Block block = Linear.builder().setUnits(CHANNELS).build();
@@ -273,7 +274,7 @@ public void testAdadelta() {
Device[] devices = Device.getDevices(1);
TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES)
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT)
.optOptimizer(optim)
.optDevices(devices);
diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java
index 4d9486c70cd..430c1aa66aa 100644
--- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java
+++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java
@@ -111,46 +111,48 @@ private NDArray concatPredictions(NDList output) {
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
- // TODO: output shape is wrong
- Shape[] childInputShapes = inputShapes;
- Shape[] anchorShapes = new Shape[features.size()];
- Shape[] classPredictionShapes = new Shape[features.size()];
- Shape[] anchorPredictionShapes = new Shape[features.size()];
- for (int i = 0; i < features.size(); i++) {
- childInputShapes = features.get(i).getOutputShapes(manager, childInputShapes);
- anchorShapes[i] =
- multiBoxPriors
- .get(i)
- .generateAnchorBoxes(manager.ones(childInputShapes[0]))
- .getShape();
- classPredictionShapes[i] =
- classPredictionBlocks.get(i).getOutputShapes(manager, childInputShapes)[0];
- anchorPredictionShapes[i] =
- anchorPredictionBlocks.get(i).getOutputShapes(manager, childInputShapes)[0];
- }
- Shape anchorOutputShape = new Shape();
- for (Shape shape : anchorShapes) {
- anchorOutputShape = concatShape(anchorOutputShape, shape, 1);
- }
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
+ try (NDManager manager = NDManager.newBaseManager()) {
+ // TODO: output shape is wrong
+ Shape[] childInputShapes = inputShapes;
+ Shape[] anchorShapes = new Shape[features.size()];
+ Shape[] classPredictionShapes = new Shape[features.size()];
+ Shape[] anchorPredictionShapes = new Shape[features.size()];
+ for (int i = 0; i < features.size(); i++) {
+ childInputShapes = features.get(i).getOutputShapes(childInputShapes);
+ anchorShapes[i] =
+ multiBoxPriors
+ .get(i)
+ .generateAnchorBoxes(manager.ones(childInputShapes[0]))
+ .getShape();
+ classPredictionShapes[i] =
+ classPredictionBlocks.get(i).getOutputShapes(childInputShapes)[0];
+ anchorPredictionShapes[i] =
+ anchorPredictionBlocks.get(i).getOutputShapes(childInputShapes)[0];
+ }
+ Shape anchorOutputShape = new Shape();
+ for (Shape shape : anchorShapes) {
+ anchorOutputShape = concatShape(anchorOutputShape, shape, 1);
+ }
- NDList classPredictions = new NDList();
- for (Shape shape : classPredictionShapes) {
- classPredictions.add(manager.ones(shape));
- }
- NDArray classPredictionOutput = concatPredictions(classPredictions);
- Shape classPredictionOutputShape =
- classPredictionOutput
- .reshape(classPredictionOutput.size(0), -1, numClasses + 1)
- .getShape();
- NDList anchorPredictions = new NDList();
- for (Shape shape : anchorPredictionShapes) {
- anchorPredictions.add(manager.ones(shape));
+ NDList classPredictions = new NDList();
+ for (Shape shape : classPredictionShapes) {
+ classPredictions.add(manager.ones(shape));
+ }
+ NDArray classPredictionOutput = concatPredictions(classPredictions);
+ Shape classPredictionOutputShape =
+ classPredictionOutput
+ .reshape(classPredictionOutput.size(0), -1, numClasses + 1)
+ .getShape();
+ NDList anchorPredictions = new NDList();
+ for (Shape shape : anchorPredictionShapes) {
+ anchorPredictions.add(manager.ones(shape));
+ }
+ Shape anchorPredictionOutputShape = concatPredictions(anchorPredictions).getShape();
+ return new Shape[] {
+ anchorOutputShape, classPredictionOutputShape, anchorPredictionOutputShape
+ };
}
- Shape anchorPredictionOutputShape = concatPredictions(anchorPredictions).getShape();
- return new Shape[] {
- anchorOutputShape, classPredictionOutputShape, anchorPredictionOutputShape
- };
}
private Shape concatShape(Shape shape, Shape concat, int axis) {
@@ -177,15 +179,15 @@ private Shape concatShape(Shape shape, Shape concat, int axis) {
/** {@inheritDoc} */
@Override
- public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
+ public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
Shape[] shapes = inputShapes;
for (int i = 0; i < features.size(); i++) {
- shapes = features.get(i).initialize(manager, dataType, shapes);
+ features.get(i).initialize(manager, dataType, shapes);
+ shapes = features.get(i).getOutputShapes(shapes);
classPredictionBlocks.get(i).initialize(manager, dataType, shapes);
anchorPredictionBlocks.get(i).initialize(manager, dataType, shapes);
}
- return getOutputShapes(manager, inputShapes);
}
/** {@inheritDoc} */
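
Because `getOutputShapes` no longer receives an NDManager, blocks that need temporary arrays for shape inference (as SingleShotDetection does above for anchor generation) now open and close a short-lived manager themselves. A minimal sketch of that pattern; `featureShape` and `multiBoxPrior` are illustrative placeholders standing in for the per-feature values used in the hunk above:

    // Sketch only: scratch manager owned and closed by the shape-inference code itself.
    try (NDManager scratch = NDManager.newBaseManager()) {
        NDArray dummy = scratch.ones(featureShape);                        // placeholder feature map
        Shape anchorShape = multiBoxPrior.generateAnchorBoxes(dummy).getShape();
        // use anchorShape to assemble the returned Shape[]; nothing leaks past this block
    }
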
diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java
index ee8afeba541..54c58bf1a01 100644
--- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java
+++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java
@@ -24,6 +24,8 @@
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
+import ai.djl.util.Pair;
+import ai.djl.util.PairList;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.Files;
@@ -33,6 +35,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -122,12 +125,16 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
/** {@inheritDoc} */
@Override
public Trainer newTrainer(TrainingConfig trainingConfig) {
- Initializer initializer = trainingConfig.getInitializer();
+ PairList<Initializer, Predicate<Parameter>> initializer = trainingConfig.getInitializers();
if (block == null) {
throw new IllegalStateException(
"You must set a block for the model before creating a new trainer");
}
- block.setInitializer(initializer);
+ for (Pair<Initializer, Predicate<Parameter>> pair : initializer) {
+ if (pair.getKey() != null && pair.getValue() != null) {
+ block.setInitializer(pair.getKey(), pair.getValue());
+ }
+ }
return new Trainer(this, trainingConfig);
}
diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java
index ffc04996a4b..3b63b96a928 100644
--- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java
+++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java
@@ -21,7 +21,6 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.Parameter;
-import ai.djl.nn.ParameterType;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
@@ -201,7 +200,7 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
if (outputShapes == null) {
String[] outputNames = symbol.getOutputNames();
outputShapes = new Shape[outputNames.length];
@@ -232,9 +231,7 @@ public void removeLastBlock() {
}
}
- /** {@inheritDoc} */
- @Override
- public Shape getParameterShape(String name, Shape[] inputShapes) {
+ private Shape getParameterShape(String name, Shape[] inputShapes) {
if (paramShapes == null) {
PairList<String, Shape> pairs = new PairList<>();
for (int i = 0; i < inputNames.size(); i++) {
@@ -314,25 +311,32 @@ private void initBlock() {
Set<String> auxNameSet = new HashSet<>(Arrays.asList(symbol.getAuxNames()));
for (String name : allNames) {
- ParameterType type = inferType(name);
+ Parameter.Type type = inferType(name);
boolean requireGrad = !auxNameSet.contains(name);
- mxNetParams.add(new Parameter(name, this, type, requireGrad));
+ mxNetParams.add(
+ Parameter.builder()
+ .setName(name)
+ .setType(type)
+ .optRequiresGrad(requireGrad)
+ .build());
}
first = true;
}
- private static ParameterType inferType(String name) {
+ private static Parameter.Type inferType(String name) {
if (name.endsWith("bias")) {
- return ParameterType.BIAS;
+ return Parameter.Type.BIAS;
} else if (name.endsWith("gamma")) {
- return ParameterType.GAMMA;
+ return Parameter.Type.GAMMA;
} else if (name.endsWith("beta")) {
- return ParameterType.BETA;
+ return Parameter.Type.BETA;
} else if (name.endsWith("moving_mean") || name.endsWith("running_mean")) {
- return ParameterType.RUNNING_MEAN;
+ return Parameter.Type.RUNNING_MEAN;
} else if (name.endsWith("moving_var") || name.endsWith("running_var")) {
- return ParameterType.RUNNING_VAR;
+ return Parameter.Type.RUNNING_VAR;
+ } else if (name.endsWith("weight")) {
+ return Parameter.Type.WEIGHT;
}
- return ParameterType.OTHER;
+ return Parameter.Type.OTHER;
}
}
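
The MxSymbolBlock changes replace the removed ParameterType enum and the Parameter constructor with the nested Parameter.Type enum and a builder. A minimal sketch of constructing a parameter under the new API (the name and gradient flag are illustrative):

    // Sketch: build a gradient-carrying weight parameter with the new builder.
    Parameter weight =
            Parameter.builder()
                    .setName("fc1_weight")          // illustrative name
                    .setType(Parameter.Type.WEIGHT)
                    .optRequiresGrad(true)          // auxiliary statistics would pass false
                    .build();
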
diff --git a/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java b/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java
index 0ba9177404f..299136587c7 100644
--- a/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java
+++ b/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java
@@ -19,6 +19,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Blocks;
+import ai.djl.nn.Parameter;
import ai.djl.testing.Assertions;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.GradientCollector;
@@ -38,7 +39,7 @@ public void testMxAutograd() {
try (Trainer trainer =
model.newTrainer(
new DefaultTrainingConfig(Loss.l2Loss())
- .optInitializer(Initializer.ONES))) {
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) {
try (GradientCollector gradCol = trainer.newGradientCollector()) {
NDArray lhs =
manager.create(new float[] {6, -9, -12, 15, 0, 4}, new Shape(2, 3));
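
Because the global optInitializer overload is gone, tests now bind the ONES initializer to weight parameters explicitly. A sketch of the updated trainer setup, assuming a model whose block is already set and an illustrative input shape:

    // Sketch: initializer applies to WEIGHT parameters when the trainer initializes.
    try (Trainer trainer =
            model.newTrainer(
                    new DefaultTrainingConfig(Loss.l2Loss())
                            .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) {
        trainer.initialize(new Shape(2, 3)); // illustrative input shape; weights filled with ones here
    }
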
diff --git a/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java b/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java
index d855bb90ced..74fde4bf89a 100644
--- a/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java
+++ b/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java
@@ -25,6 +25,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
+import ai.djl.nn.Parameter;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;
@@ -85,7 +86,7 @@ public void trainWithNewParam()
throws IOException, ModelNotFoundException, MalformedModelException {
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
- .optInitializer(Initializer.ONES);
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
try (Model model = MxModelZoo.MLP.loadModel()) {
model.getBlock().clear();
try (Trainer trainer = model.newTrainer(config)) {
@@ -113,7 +114,7 @@ public void trainWithExistParam()
throws IOException, ModelNotFoundException, MalformedModelException {
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
- .optInitializer(Initializer.ONES);
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
try (Model model = MxModelZoo.MLP.loadModel()) {
try (Trainer trainer = model.newTrainer(config)) {
NDManager manager = trainer.getManager();
@@ -140,7 +141,7 @@ public void trainWithCustomLayer()
throws IOException, ModelNotFoundException, MalformedModelException {
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
- .optInitializer(Initializer.ONES);
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
try (Model model = MxModelZoo.MLP.loadModel()) {
NDManager manager = model.getNDManager();
SymbolBlock mlp = (SymbolBlock) model.getBlock();
@@ -149,7 +150,6 @@ public void trainWithCustomLayer()
newMlp.add(mlp);
Linear linear = Linear.builder().setUnits(10).build();
- linear.setInitializer(Initializer.ONES);
newMlp.add(linear);
model.setBlock(newMlp);
diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java
index 522a928b32b..43b3a960f4c 100644
--- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java
+++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java
@@ -98,7 +98,7 @@ private NDList getOutputs(PpNDArray[] outputs, boolean foreignEngine, NDManager
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[0];
}
}
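
Every SymbolBlock override here drops the NDManager argument from getOutputShapes, so for downstream custom blocks the migration is a one-line signature change. A sketch for a hypothetical dense block (the units field is illustrative):

    // Sketch: migrated override, no NDManager parameter anymore.
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        // Keep the batch dimension, replace the feature dimension with the layer width.
        return new Shape[] {new Shape(inputShapes[0].get(0), units)}; // units: illustrative field
    }
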
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java
index 51ec90497ac..3d56947e34d 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java
@@ -17,10 +17,13 @@
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ndarray.types.DataType;
+import ai.djl.nn.Parameter;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
+import ai.djl.util.Pair;
+import ai.djl.util.PairList;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.Files;
@@ -28,6 +31,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import java.util.function.Predicate;
import java.util.stream.Collectors;
/**
@@ -125,12 +129,16 @@ private Path findModelFile(String prefix) {
/** {@inheritDoc} */
@Override
public Trainer newTrainer(TrainingConfig trainingConfig) {
- Initializer initializer = trainingConfig.getInitializer();
+ PairList<Initializer, Predicate<Parameter>> initializer = trainingConfig.getInitializers();
if (block == null) {
throw new IllegalStateException(
"You must set a block for the model before creating a new trainer");
}
- block.setInitializer(initializer);
+ for (Pair<Initializer, Predicate<Parameter>> pair : initializer) {
+ if (pair.getKey() != null && pair.getValue() != null) {
+ block.setInitializer(pair.getKey(), pair.getValue());
+ }
+ }
return new Trainer(this, trainingConfig);
}
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
index 7cf0e5e3ac8..964649cd5a1 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
@@ -155,7 +155,7 @@ public PairList<String, Shape> describeOutput() {
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[0];
}
diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java
index 43f667d2dc2..7a6f8bdb753 100644
--- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java
+++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java
@@ -131,8 +131,8 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
- return new Shape[0];
+ public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
+ throw new IllegalStateException("TfSymbolBlock can't be initialized");
}
/** {@inheritDoc} */
@@ -197,7 +197,7 @@ public final PairList<String, Shape> describeOutput() {
/** {@inheritDoc} */
@Override
- public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[0];
}
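
TfSymbolBlock now rejects initialize outright instead of silently returning an empty shape array, so callers holding a generic Block should not assume initialization is always legal. A hedged sketch of a guard on the caller side, assuming an already-imported (and therefore already-initialized) block:

    // Sketch: only initialize blocks that still have uninitialized parameters.
    if (!block.isInitialized()) {
        block.initialize(manager, DataType.FLOAT32, new Shape(1, 3, 224, 224)); // illustrative shape
    }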