diff --git a/api/src/main/java/ai/djl/modality/nlp/Decoder.java b/api/src/main/java/ai/djl/modality/nlp/Decoder.java index 72a6ead095f..e6dd38fedf3 100644 --- a/api/src/main/java/ai/djl/modality/nlp/Decoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/Decoder.java @@ -64,8 +64,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - return block.getOutputShapes(manager, inputShapes); + public Shape[] getOutputShapes(Shape[] inputShapes) { + return block.getOutputShapes(inputShapes); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/modality/nlp/Encoder.java b/api/src/main/java/ai/djl/modality/nlp/Encoder.java index 7ad4fbdcf6f..e3b1631322f 100644 --- a/api/src/main/java/ai/djl/modality/nlp/Encoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/Encoder.java @@ -79,8 +79,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - return block.getOutputShapes(manager, inputShapes); + public Shape[] getOutputShapes(Shape[] inputShapes) { + return block.getOutputShapes(inputShapes); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java index afc0958ace9..a7c47380a9d 100644 --- a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java @@ -97,19 +97,18 @@ public NDList forward( * @param manager the NDManager to initialize the parameters * @param dataType the datatype of the parameters * @param inputShapes the shapes of the inputs to the block - * @return the shapes of the outputs of the block */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) { beforeInitialize(inputShapes); encoder.initialize(manager, dataType, inputShapes[0]); - return decoder.initialize(manager, dataType, inputShapes[1]); + decoder.initialize(manager, dataType, inputShapes[1]); } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - return decoder.getOutputShapes(manager, new Shape[] {inputShapes[1]}); + public Shape[] getOutputShapes(Shape[] inputShapes) { + return decoder.getOutputShapes(new Shape[] {inputShapes[1]}); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java index f848f25f3e5..4fd7dc71277 100644 --- a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java +++ b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java @@ -88,7 +88,7 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... 
/** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - return trainableWordEmbedding.getOutputShapes(manager, inputShapes); + public Shape[] getOutputShapes(Shape[] inputShapes) { + return trainableWordEmbedding.getOutputShapes(inputShapes); } } diff --git a/api/src/main/java/ai/djl/nn/AbstractBlock.java b/api/src/main/java/ai/djl/nn/AbstractBlock.java index 6d4c2c4ecd2..7f67d8431a6 100644 --- a/api/src/main/java/ai/djl/nn/AbstractBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractBlock.java @@ -30,7 +30,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; -import java.util.function.Function; +import java.util.function.Predicate; /** * {@code AbstractBlock} is an abstract implementation of {@link Block}. @@ -43,12 +43,11 @@ * * *
<p>
If you use {@link AbstractBlock#addParameter(Parameter)} to add parameters, you have to take - * care of parameter initialization yourself. In this case, you need to override {@link - * AbstractBlock#getParameterShape(String, Shape[])} to determine the shape of your parameters. If - * you use the other variants of {@code addParameter} this is done for you. + * care of parameter initialization yourself. In this case, you need to call {@code setShape} on + * your parameters if you know the shape up front, or implement {@code prepare} to set the shape + * once the input shape is known. */ // Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers // of this API know the children and parameters are always iterated over in insertion order. @@ -99,14 +98,6 @@ public abstract class AbstractBlock implements Block { */ protected LinkedHashMap<String, Parameter> parameters = new LinkedHashMap<>(); - /** - * Callbacks to determine the shape of a parameter. Values may be null in which case extending - * classes need to override {@link Block#getParameterShape(String, Shape[])} and implement - * parameter shape resolution manually. - */ - protected LinkedHashMap<String, Function<Shape[], Shape>> parameterShapeCallbacks = - new LinkedHashMap<>(); - /** * Builds an empty block with the given version for parameter serialization. * @@ -195,73 +186,20 @@ protected final <B extends Block> B addChildBlock(String name, B block) { return block; } - /** - * Adds a parameter to this block. If parameters are added with this method, subclasses need to - * override {@link Block#getParameterShape(String, Shape[])} and return the shapes of parameters - * themselves. - * - * @param parameter the parameter to add, not null - * @param
<P>
the specific parameter subclass - * @return the parameter passed as arguments to make it easier to create and assign paramters in - * one line - */ - protected final
<P extends Parameter>
P addParameter(P parameter) { - return addParameter(parameter, (Function<Shape[], Shape>) null); - } - /** * Adds a parameter to this block. If parameters are added with this method, initialization of * the parameter works out of the box * - * @param parameter the parameter to add, not null - * @param shape the shape of the parameter - * @param
<P>
the specific parameter subclass - * @return the parameter passed as arguments to make it easier to create and assign paramters in - * one line - */ - protected final
<P extends Parameter>
P addParameter(P parameter, Shape shape) { - return addParameter(parameter, (inputShapes) -> shape); - } - - /** - * Adds a parameter to this block. If parameters are added with this method, intialization of - * the parameter works out of the box - * * @param parameter the parameter to add, not null - * @param shapeCallback the method to call once the input shape of this block is known to - * determine the shape of the given parameter - * @param
<P>
the specific parameter subclass * @return the parameter passed as arguments to make it easier to create and assign parameters * in one line */ - protected final
<P extends Parameter>
P addParameter( - P parameter, Function<Shape[], Shape> shapeCallback) { + protected final
<P extends Parameter>
P addParameter(P parameter) { parameters.put(parameter.getName(), parameter); - parameterShapeCallbacks.put(parameter.getName(), shapeCallback); return parameter; } - /** {@inheritDoc} */ - @Override - public Shape getParameterShape(String name, Shape[] inputShapes) { - Function<Shape[], Shape> callback = parameterShapeCallbacks.get(name); - if (callback == null) { - Parameter parameter = parameters.get(name); - if (parameter == null) { - throw new IllegalArgumentException( - "No parameter named " + name + " found in this block."); - } else { - throw new IllegalStateException( - "No shape initializer for parameter " - + name - + "found. " - + "Either pass an initializer for the shape when adding the " - + "parameter or override getParameterShape in the subclass."); - } - } - return callback.apply(inputShapes); - } - /** {@inheritDoc} */ @Override public BlockList getChildren() { @@ -285,13 +223,9 @@ public PairList<String, Shape> describeInput() { /** {@inheritDoc} */ @Override - public void setInitializer(Initializer initializer) { - for (Parameter parameter : parameters.values()) { - parameter.setInitializer(initializer, false); - } - for (Block child : children.values()) { - child.setInitializer(initializer); - } + public void setInitializer(Initializer initializer, Parameter.Type params) { + Predicate<Parameter> predicate = parameter -> parameter.getType().equals(params); + setInitializer(initializer, predicate); } /** {@inheritDoc} */ @@ -301,18 +235,50 @@ public void setInitializer(Initializer initializer, String paramName) { if (parameter == null) { throw new IllegalArgumentException("Could not find parameter " + paramName); } - parameter.setInitializer(initializer, true); + parameter.setInitializer(initializer); } /** {@inheritDoc} */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + public void setInitializer(Initializer initializer, Predicate<Parameter> predicate) { + List<Parameter> params = getParameters().values(); + for (Parameter param : params) { + if (predicate.test(param)) { + param.setInitializer(initializer); + } + } + } + + /** {@inheritDoc} */ + @Override + public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) { beforeInitialize(inputShapes); + // if parameters are initialized, skip it + if (!isInitialized()) { + // setShape for all params + prepare(inputShapes); + } for (Parameter parameter : parameters.values()) { - parameter.initialize(manager, dataType, inputShapes); + parameter.initialize(manager, dataType); } initializeChildBlocks(manager, dataType, inputShapes); - return getOutputShapes(manager, inputShapes); + } + + /** + * Performs any action necessary before initialization. For example, keep the input information + * or verify the layout. + * + * @param inputShapes the expected shapes of the input + */ + protected void beforeInitialize(Shape... inputShapes) { + if (inputNames.isEmpty()) { + // automatically assign input names + inputNames = new ArrayList<>(); + for (int i = 0; i < inputShapes.length; ++i) { + inputNames.add("data" + i); + } + } + this.inputShapes = inputShapes; } /** @@ -355,20 +321,11 @@ public ParameterList getDirectParameters() { } /** - * Performs any action necessary before initialization. + * Sets the shape of {@link Parameter}s.
* - * @param inputShapes the expected shapes of the input + * @param inputShapes the shapes of inputs */ - protected void beforeInitialize(Shape[] inputShapes) { - if (inputNames.isEmpty()) { - // automatically assign input names - inputNames = new ArrayList<>(); - for (int i = 0; i < inputShapes.length; ++i) { - inputNames.add("data" + i); - } - } - this.inputShapes = inputShapes; - } + protected void prepare(Shape[] inputShapes) {} /** {@inheritDoc} */ @Override @@ -494,7 +451,7 @@ public String toString() { appendShape(sb, inputShapeDescription.values().toArray(new Shape[0])); sb.append(" -> "); Shape[] outputShapes = - getOutputShapes(null, inputShapeDescription.values().toArray(new Shape[0])); + getOutputShapes(inputShapeDescription.values().toArray(new Shape[0])); appendShape(sb, outputShapes); } else { sb.append("Uninitialized"); diff --git a/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java b/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java index f870eea148f..fdd7fcc0d6c 100644 --- a/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java @@ -12,7 +12,6 @@ */ package ai.djl.nn; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; /** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */ @@ -29,7 +28,7 @@ public AbstractSymbolBlock(byte version) { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { throw new UnsupportedOperationException("not implement!"); } } diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 6febeff5f6f..4bd3fb09e2b 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -25,6 +25,7 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.util.function.Predicate; /** * A {@code Block} is a composable function that forms a neural network. @@ -158,11 +159,12 @@ default NDList forward( } /** - * Sets an {@link Initializer} to the block. + * Sets an {@link Initializer} to all the parameters of the given type in the block. * * @param initializer the initializer to set + * @param type the {@link Parameter.Type} of the parameters to set the initializer for */ - void setInitializer(Initializer initializer); + void setInitializer(Initializer initializer, Parameter.Type type); /** * Sets an {@link Initializer} to the specified direct parameter of the block, overriding the * * @param initializer the initializer to set * @param paramName the name of the parameter */ void setInitializer(Initializer initializer, String paramName); + /** + * Sets an {@link Initializer} to all the parameters that match the given predicate. + * + * @param initializer the initializer to be set + * @param predicate the predicate that selects the parameters to set + */ + void setInitializer(Initializer initializer, Predicate<Parameter> predicate); + /** * Initializes the parameters of the block. This method must be called before calling `forward`. * * @param manager the NDManager to initialize the parameters * @param dataType the datatype of the parameters * @param inputShapes the shapes of the inputs to the block - * @return the shapes of the outputs of the block */ - Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes); + void initialize(NDManager manager, DataType dataType, Shape... inputShapes); /** * Returns a boolean whether the block is initialized.
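Taken together, the Block and AbstractBlock changes above split shape inference out of initialization: initialize() no longer returns output shapes, and initializers are scoped by parameter type, name, or predicate. A minimal sketch of how calling code adapts under the new contract (hypothetical usage, not code from this patch):

import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.core.Linear;
import ai.djl.training.initializer.NormalInitializer;

public final class InitializeContractExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            Block block = Linear.builder().setUnits(10).build();
            // initializers are now scoped: this one only applies to WEIGHT parameters,
            // while the bias keeps the default ZEROS from Parameter.Type.BIAS
            block.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
            // initialize() no longer returns output shapes; prepare() resolves the
            // parameter shapes (units x inputFeatures) from the input shape
            block.initialize(manager, DataType.FLOAT32, new Shape(32, 20));
            // output shapes are queried separately, and no NDManager is needed
            Shape[] outputs = block.getOutputShapes(new Shape[] {new Shape(32, 20)});
            System.out.println(outputs[0]); // (32, 10)
        }
    }
}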
@@ -232,25 +241,13 @@ default NDList forward( */ ParameterList getParameters(); - /** - * Returns the shape of the specified direct parameter of this block given the shapes of the - * input to the block. - * - * @param name the name of the parameter - * @param inputShapes the shapes of the input to the block - * @return the shape of the parameter specified - * @throws IllegalArgumentException if the parameter name specified is invalid - */ - Shape getParameterShape(String name, Shape[] inputShapes); - /** * Returns the expected output shapes of the block for the specified input shapes. * - * @param manager an NDManager * @param inputShapes the shapes of the inputs * @return the expected output shapes of the block */ - Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes); + Shape[] getOutputShapes(Shape[] inputShapes); /** * Writes the parameters of the block to the given outputStream. diff --git a/api/src/main/java/ai/djl/nn/LambdaBlock.java b/api/src/main/java/ai/djl/nn/LambdaBlock.java index 939e0e09b29..1b6fc0d3bc5 100644 --- a/api/src/main/java/ai/djl/nn/LambdaBlock.java +++ b/api/src/main/java/ai/djl/nn/LambdaBlock.java @@ -70,11 +70,11 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - try (NDManager subManager = manager.newSubManager()) { + public Shape[] getOutputShapes(Shape[] inputShapes) { + try (NDManager manager = NDManager.newBaseManager()) { NDList input = new NDList(inputShapes.length); for (Shape shape : inputShapes) { - input.add(subManager.zeros(shape)); + input.add(manager.zeros(shape)); } NDList output = lambda.apply(input); Shape[] outputShapes = new Shape[output.size()]; diff --git a/api/src/main/java/ai/djl/nn/ParallelBlock.java b/api/src/main/java/ai/djl/nn/ParallelBlock.java index 92ce9f6349e..9d04569a63f 100644 --- a/api/src/main/java/ai/djl/nn/ParallelBlock.java +++ b/api/src/main/java/ai/djl/nn/ParallelBlock.java @@ -149,16 +149,16 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... 
/** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { Preconditions.checkArgument(!children.isEmpty(), "The parallel block is empty"); - try (NDManager subManager = manager.newSubManager()) { + try (NDManager manager = NDManager.newBaseManager()) { List<NDList> inputs = new ArrayList<>(); for (Block block : children.values()) { - Shape[] shapes = block.getOutputShapes(manager, inputShapes); + Shape[] shapes = block.getOutputShapes(inputShapes); NDList output = new NDList(shapes.length); for (Shape shape : shapes) { - output.add(subManager.create(shape)); + output.add(manager.create(shape)); } inputs.add(output); } diff --git a/api/src/main/java/ai/djl/nn/Parameter.java b/api/src/main/java/ai/djl/nn/Parameter.java index 7d3c2954cde..369111057c7 100644 --- a/api/src/main/java/ai/djl/nn/Parameter.java +++ b/api/src/main/java/ai/djl/nn/Parameter.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; import ai.djl.training.initializer.Initializer; +import ai.djl.training.initializer.XavierInitializer; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; @@ -42,62 +43,23 @@ public class Parameter implements AutoCloseable { private String id; private String name; - private Block block; - private ParameterType type; - private DataType mandatoryDataType; + private Shape shape; + private Type type; private Initializer initializer; private NDArray array; private boolean requiresGrad; private SparseFormat gradientFormat; - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - */ - public Parameter(String name, Block block, ParameterType type) { - this(name, block, type, true, SparseFormat.DENSE); - } - - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - * @param requiresGrad whether this {@code Parameter} needs to compute gradients - */ - public Parameter(String name, Block block, ParameterType type, boolean requiresGrad) { - this(name, block, type, requiresGrad, SparseFormat.DENSE); - } - - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}.
- * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - * @param requireGrad whether this {@code Parameter} needs to compute gradients - * @param gradientFormat the {@link SparseFormat} of the gradient array - */ - public Parameter( - String name, - Block block, - ParameterType type, - boolean requireGrad, - SparseFormat gradientFormat) { + Parameter(Builder builder) { this.id = UUID.randomUUID().toString(); - this.name = name; - this.block = block; - this.type = type; - this.requiresGrad = requireGrad; - this.initializer = type.getInitializer(); - this.gradientFormat = gradientFormat; + this.name = builder.name; + this.shape = builder.shape; + this.type = builder.type; + this.array = builder.array; + this.requiresGrad = builder.requiresGrad; + this.initializer = + (builder.initializer != null) ? builder.initializer : type.getInitializer(); + this.gradientFormat = builder.gradientFormat; } /** @@ -123,7 +85,7 @@ public String getName() { * * @return the type of this {@code Parameter} */ - public ParameterType getType() { + public Type getType() { return type; } @@ -133,10 +95,26 @@ public ParameterType getType() { * @param array the {@link NDArray} that contains values of this {@code Parameter} */ public void setArray(NDArray array) { + if (shape != null) { + throw new IllegalStateException("shape has been set! Use either setArray or setShape"); + } this.array = array; + shape = array.getShape(); array.setName(name); } + /** + * Sets the shape of this {@code Parameter}. + * + * @param shape the shape of this {@code Parameter} + */ + public void setShape(Shape shape) { + if (array != null) { + throw new IllegalStateException("array has been set! Use either setArray or setShape"); + } + this.shape = shape; + } + /** * Gets the values of this {@code Parameter} as an {@link NDArray}. * @@ -158,15 +136,6 @@ public boolean requireGradient() { return requiresGrad; } - /** - * Sets the mandatory data type for this {@code Parameter}. - * - * @param mandatoryDataType the mandatory data type for this {@code Parameter} - */ - public void setMandatoryDataType(DataType mandatoryDataType) { - this.mandatoryDataType = mandatoryDataType; - } - /** * Checks if this {@code Parameter} is initialized. * @@ -181,12 +150,9 @@ public boolean isInitialized() { * flag is true, sets the initializer regardless.
* * @param initializer the initializer to be set - * @param overwrite if true, set the initializer regardless of whether its already set or not */ - public void setInitializer(Initializer initializer, boolean overwrite) { - if (overwrite || this.initializer == null) { - this.initializer = initializer; - } + public void setInitializer(Initializer initializer) { + this.initializer = initializer; } /** @@ -195,17 +161,12 @@ public void setInitializer(Initializer initializer, boolean overwrite) { * * @param manager an NDManager to create the arrays * @param dataType the datatype of the {@code Parameter} - * @param inputShapes the expected input shapes */ - public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) { + public void initialize(NDManager manager, DataType dataType) { Objects.requireNonNull(initializer, "No initializer has been set"); + Objects.requireNonNull(shape, "No parameter shape has been set"); if (!isInitialized()) { - Shape shape = block.getParameterShape(name, inputShapes); - array = - initializer.initialize( - manager, - shape, - mandatoryDataType == null ? dataType : mandatoryDataType); + array = initializer.initialize(manager, shape, dataType); array.setName(name); } @@ -266,6 +227,8 @@ public void load(NDManager manager, DataInputStream dis) } array = manager.decode(dis); + // set the shape of the parameter so that prepare() can be skipped + shape = array.getShape(); } /** {@inheritDoc} */ @@ -276,4 +239,141 @@ public void close() { array = null; } } + + /** + * Creates a builder to build a {@code Parameter}. + * +
<p>
The methods that start with {@code set} are for required fields, and {@code opt} for optional + * fields. + * + * @return a new builder + */ + public static Parameter.Builder builder() { + return new Parameter.Builder(); + } + + /** Enumerates the types of {@link Parameter}. */ + public enum Type { + WEIGHT( + new XavierInitializer( + XavierInitializer.RandomType.GAUSSIAN, XavierInitializer.FactorType.IN, 2)), + BIAS(Initializer.ZEROS), + GAMMA(Initializer.ONES), + BETA(Initializer.ZEROS), + RUNNING_MEAN(Initializer.ZEROS), + RUNNING_VAR(Initializer.ONES), + OTHER(null); + + private final transient Initializer initializer; + + Type(Initializer initializer) { + this.initializer = initializer; + } + + /** + * Gets the {@link Initializer} of this {@code Type}. + * + * @return the {@link Initializer} of this {@code Type} + */ + public Initializer getInitializer() { + return initializer; + } + } + + /** A Builder to construct a {@code Parameter}. */ + public static final class Builder { + String name; + Shape shape; + Type type; + Initializer initializer; + NDArray array; + boolean requiresGrad = true; + SparseFormat gradientFormat; + + /** + * Sets the name of the {@code Parameter}. + * + * @param name the name of the {@code Parameter} + * @return this {@code Builder} + */ + public Builder setName(String name) { + this.name = name; + return this; + } + + /** + * Sets the {@code Type} of the {@code Parameter}. + * + * @param type the {@code Type} of the {@code Parameter} + * @return this {@code Builder} + */ + public Builder setType(Type type) { + this.type = type; + return this; + } + + /** + * Sets the shape of the {@code Parameter}. + * + * @param shape the shape of the {@code Parameter} + * @return this {@code Builder} + */ + public Builder optShape(Shape shape) { + this.shape = shape; + return this; + } + + /** + * Sets the Initializer of the {@code Parameter}. + * + * @param initializer the Initializer of the {@code Parameter} + * @return this {@code Builder} + */ + public Builder optInitializer(Initializer initializer) { + this.initializer = initializer; + return this; + } + + /** + * Sets the array of the {@code Parameter}. + * + * @param array the array of the {@code Parameter} + * @return this {@code Builder} + */ + public Builder optArray(NDArray array) { + this.array = array; + return this; + } + + /** + * Sets if the {@code Parameter} requires gradient. + * + * @param requiresGrad if the {@code Parameter} requires gradient + * @return this {@code Builder} + */ + public Builder optRequiresGrad(boolean requiresGrad) { + this.requiresGrad = requiresGrad; + return this; + } + + /** + * Sets the {@link SparseFormat} of the {@code Parameter}. + * + * @param gradientFormat the {@link SparseFormat} of the {@code Parameter} + * @return this {@code Builder} + */ + public Builder optGradientFormat(SparseFormat gradientFormat) { + this.gradientFormat = gradientFormat; + return this; + } + + /** + * Builds a {@code Parameter} instance. + * + * @return the {@code Parameter} instance + */ + public Parameter build() { + return new Parameter(this); + } + } } diff --git a/api/src/main/java/ai/djl/nn/ParameterType.java b/api/src/main/java/ai/djl/nn/ParameterType.java deleted file mode 100644 index dde3cb23ab8..00000000000 --- a/api/src/main/java/ai/djl/nn/ParameterType.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.nn; - -import ai.djl.training.initializer.Initializer; - -/** Enumerates the types of {@link Parameter}. */ -public enum ParameterType { - WEIGHT(null), - BIAS(Initializer.ZEROS), - GAMMA(Initializer.ONES), - BETA(Initializer.ZEROS), - RUNNING_MEAN(Initializer.ZEROS), - RUNNING_VAR(Initializer.ONES), - OTHER(null); - - private final transient Initializer initializer; - - ParameterType(Initializer initializer) { - this.initializer = initializer; - } - - /** - * Gets the {@link Initializer} of this {@code ParameterType}. - * - * @return the {@link Initializer} of this {@code ParameterType} - */ - public Initializer getInitializer() { - return initializer; - } -} diff --git a/api/src/main/java/ai/djl/nn/SequentialBlock.java b/api/src/main/java/ai/djl/nn/SequentialBlock.java index d0c6f168fb8..8908c1ad6c8 100644 --- a/api/src/main/java/ai/djl/nn/SequentialBlock.java +++ b/api/src/main/java/ai/djl/nn/SequentialBlock.java @@ -154,19 +154,20 @@ protected NDList forwardInternal( public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { Shape[] shapes = inputShapes; for (Block child : getChildren().values()) { - shapes = child.initialize(manager, dataType, shapes); + child.initialize(manager, dataType, shapes); + shapes = child.getOutputShapes(shapes); } } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { + public Shape[] getOutputShapes(Shape[] inputs) { if (children.isEmpty()) { throw new IllegalArgumentException("The sequential block is empty"); } Shape[] current = inputs; for (Block block : children.values()) { - current = block.getOutputShapes(manager, current); + current = block.getOutputShapes(current); } return current; } diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java index 6334935804c..1d211f1d903 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java @@ -16,13 +16,11 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.LayoutType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -102,13 +100,17 @@ public Convolution(ConvolutionBuilder builder) { weight = addParameter( - new Parameter("weight", this, ParameterType.WEIGHT), - (inputShapes) -> - new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape)); + Parameter.builder() + .setName("weight") + .setType(Parameter.Type.WEIGHT) + .build()); if (includeBias) { bias = addParameter( - new Parameter("bias", this, ParameterType.BIAS), new Shape(filters)); + Parameter.builder() + .setName("bias") + .setType(Parameter.Type.BIAS) + .build()); } } @@ -149,15 +151,24 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - 
protected void beforeInitialize(Shape[] inputs) { - super.beforeInitialize(inputs); - Shape inputShape = inputs[0]; - Block.validateLayout(getExpectedLayout(), inputShape.getLayout()); + protected void beforeInitialize(Shape... inputShapes) { + super.beforeInitialize(inputShapes); + Block.validateLayout(getExpectedLayout(), inputShapes[0].getLayout()); } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { + protected void prepare(Shape[] inputs) { + long inputChannel = inputs[0].get(1); + weight.setShape(new Shape(filters, inputChannel / groups).addAll(kernelShape)); + if (bias != null) { + bias.setShape(new Shape(filters)); + } + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputs) { long[] shape = new long[numDimensions()]; shape[0] = inputs[0].get(0); shape[1] = filters; diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java index 890a4db9eff..10059208a5c 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java @@ -16,13 +16,11 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.LayoutType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -78,13 +76,17 @@ public Deconvolution(DeconvolutionBuilder builder) { weight = addParameter( - new Parameter("weight", this, ParameterType.WEIGHT), - (inputShapes) -> - new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape)); + Parameter.builder() + .setName("weight") + .setType(Parameter.Type.WEIGHT) + .build()); if (includeBias) { bias = addParameter( - new Parameter("bias", this, ParameterType.BIAS), new Shape(filters)); + Parameter.builder() + .setName("bias") + .setType(Parameter.Type.BIAS) + .build()); } } @@ -126,15 +128,24 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - protected void beforeInitialize(Shape[] inputs) { - super.beforeInitialize(inputs); - Shape inputShape = inputs[0]; - Block.validateLayout(getExpectedLayout(), inputShape.getLayout()); + protected void beforeInitialize(Shape... 
inputShapes) { + super.beforeInitialize(inputShapes); + Block.validateLayout(getExpectedLayout(), inputShapes[0].getLayout()); } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { + protected void prepare(Shape[] inputs) { + long inputChannel = inputs[0].get(1); + weight.setShape(new Shape(filters, inputChannel / groups).addAll(kernelShape)); + if (bias != null) { + bias.setShape(new Shape(filters)); + } + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputs) { long[] shape = new long[numDimensions()]; shape[0] = inputs[0].get(0); shape[1] = filters; diff --git a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java index e2dceddcef4..ce5c486d701 100644 --- a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java +++ b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java @@ -57,7 +57,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0].addAll(embedding.getShape())}; } diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java index 48736aaa8e1..2e65295c1f7 100644 --- a/api/src/main/java/ai/djl/nn/core/Embedding.java +++ b/api/src/main/java/ai/djl/nn/core/Embedding.java @@ -23,7 +23,6 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -56,8 +55,11 @@ protected Embedding(BaseBuilder<T, ?> baseBuilder) { sparseFormat = baseBuilder.sparseFormat; embedding = addParameter( - new Parameter("embedding", this, ParameterType.WEIGHT, true, sparseFormat), - (inputShapes) -> new Shape(numEmbeddings, embeddingSize)); + Parameter.builder() + .setName("embedding") + .setType(Parameter.Type.WEIGHT) + .optGradientFormat(sparseFormat) + .build()); if (baseBuilder.fallthrough != null && baseBuilder.defaultItem != null) { throw new IllegalArgumentException( "You can not specify both a fallthrough and a defaultItem"); @@ -93,15 +95,25 @@ public Embedding(NDArray embedding, SparseFormat format) { this.sparseFormat = format; this.embedding = addParameter( - new Parameter("embedding", this, ParameterType.WEIGHT, true, sparseFormat), - (inputShapes) -> new Shape(numEmbeddings, embeddingSize)); + Parameter.builder() + .setName("embedding") + .setType(Parameter.Type.WEIGHT) + .optGradientFormat(sparseFormat) + .build()); this.embedding.setArray(embedding); inputShapes = new Shape[] {new Shape(-1)}; } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public void prepare(Shape[] inputShapes) { + // numEmbeddings will be adjusted by embedding array or fallthroughEmbedding + embedding.setShape(new Shape(numEmbeddings, embeddingSize)); + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))}; } diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java index 34a8c020a18..79459348fb2 100644 --- a/api/src/main/java/ai/djl/nn/core/Linear.java +++ b/api/src/main/java/ai/djl/nn/core/Linear.java @@ -16,12 +16,10 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import ai.djl.util.Preconditions; @@ -59,14 +57,19 @@ public class Linear extends AbstractBlock { Linear(Builder builder) { super(VERSION); units = builder.units; - // "inputFeatures" is only known after "beforeInitialize" is called, hence we need - // a callback, even if we do not used the callback parameter weight = addParameter( - new Parameter("weight", this, ParameterType.WEIGHT), - inputShapes -> new Shape(units, inputFeatures)); + Parameter.builder() + .setName("weight") + .setType(Parameter.Type.WEIGHT) + .build()); if (builder.bias) { - bias = addParameter(new Parameter("bias", this, ParameterType.BIAS), new Shape(units)); + bias = + addParameter( + Parameter.builder() + .setName("bias") + .setType(Parameter.Type.BIAS) + .build()); } } @@ -86,8 +89,8 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { - return new Shape[] {inputShape.addAll(new Shape(units))}; + public Shape[] getOutputShapes(Shape[] inputs) { + return new Shape[] {inputs[0].slice(0, inputs[0].dimension() - 1).add(units)}; } /** {@inheritDoc} */ @@ -99,13 +102,24 @@ public PairList<String, Shape> describeInput() { /** {@inheritDoc} */ @Override - public void beforeInitialize(Shape[] inputShapes) { + protected void beforeInitialize(Shape... inputShapes) { super.beforeInitialize(inputShapes); + Preconditions.checkArgument(inputShapes.length == 1, "Linear block only supports 1 input"); Shape input = inputShapes[0]; inputFeatures = input.get(input.dimension() - 1); inputShape = input.slice(0, input.dimension() - 1); } + /** {@inheritDoc} */ + @Override + public void prepare(Shape[] inputShapes) { + Shape input = inputShapes[0]; + weight.setShape(new Shape(units, input.get(input.dimension() - 1))); + if (bias != null) { + bias.setShape(new Shape(units)); + } + } + /** {@inheritDoc} */ @Override protected void saveMetadata(DataOutputStream os) throws IOException { diff --git a/api/src/main/java/ai/djl/nn/core/Prelu.java b/api/src/main/java/ai/djl/nn/core/Prelu.java index abb268f9446..ced9d709e87 100644 --- a/api/src/main/java/ai/djl/nn/core/Prelu.java +++ b/api/src/main/java/ai/djl/nn/core/Prelu.java @@ -15,12 +15,10 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -44,7 +42,13 @@ public class Prelu extends AbstractBlock { /** Creates a Parametric ReLU Block.
*/ public Prelu() { super(VERSION); - alpha = addParameter(new Parameter("alpha", this, ParameterType.OTHER), new Shape()); + alpha = + addParameter( + Parameter.builder() + .setName("alpha") + .setType(Parameter.Type.WEIGHT) + .optShape(new Shape()) + .build()); } /** {@inheritDoc} */ @@ -61,7 +65,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { + public Shape[] getOutputShapes(Shape[] inputs) { return new Shape[] {inputs[0]}; } diff --git a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java index 7f5f243646b..dc93360811a 100644 --- a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java @@ -16,12 +16,10 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -86,26 +84,37 @@ public class BatchNorm extends AbstractBlock { momentum = builder.momentum; center = builder.center; scale = builder.scale; - // When creating parameters we use a callback as "inChannels" is set before initialization, - // it is not known yet. + // make gamma trainable if scale gamma = addParameter( - new Parameter("gamma", this, ParameterType.GAMMA, scale), - (inputShapes) -> new Shape(inChannels)); + Parameter.builder() + .setName("gamma") + .setType(Parameter.Type.GAMMA) + .optRequiresGrad(scale) + .build()); // make beta trainable if center beta = addParameter( - new Parameter("beta", this, ParameterType.BETA, center), - (inputShapes) -> new Shape(inChannels)); + Parameter.builder() + .setName("beta") + .setType(Parameter.Type.BETA) + .optRequiresGrad(center) + .build()); runningMean = addParameter( - new Parameter("runningMean", this, ParameterType.RUNNING_MEAN, false), - (inputShapes) -> new Shape(inChannels)); + Parameter.builder() + .setName("runningMean") + .setType(Parameter.Type.RUNNING_MEAN) + .optRequiresGrad(false) + .build()); runningVar = addParameter( - new Parameter("runningVar", this, ParameterType.RUNNING_VAR, false), - (inputShapes) -> new Shape(inChannels)); + Parameter.builder() + .setName("runningVar") + .setType(Parameter.Type.RUNNING_VAR) + .optRequiresGrad(false) + .build()); } /** {@inheritDoc} */ @@ -135,17 +144,26 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0]}; } /** {@inheritDoc} */ @Override - public void beforeInitialize(Shape[] inputShapes) { + protected void beforeInitialize(Shape... 
inputShapes) { super.beforeInitialize(inputShapes); inChannels = inputShapes[0].size(axis); } + /** {@inheritDoc} */ + @Override + public void prepare(Shape[] inputShapes) { + gamma.setShape(new Shape(inChannels)); + beta.setShape(new Shape(inChannels)); + runningMean.setShape(new Shape(inChannels)); + runningVar.setShape(new Shape(inChannels)); + } + /** {@inheritDoc} */ @Override protected void saveMetadata(DataOutputStream os) throws IOException { diff --git a/api/src/main/java/ai/djl/nn/norm/Dropout.java b/api/src/main/java/ai/djl/nn/norm/Dropout.java index 7a44c954288..20bcd201138 100644 --- a/api/src/main/java/ai/djl/nn/norm/Dropout.java +++ b/api/src/main/java/ai/djl/nn/norm/Dropout.java @@ -15,7 +15,6 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; @@ -76,7 +75,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0]}; } diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java index 1251c32ee49..c51620f5dc5 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java +++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java @@ -13,13 +13,13 @@ package ai.djl.nn.recurrent; import ai.djl.MalformedModelException; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.LayoutType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; +import ai.djl.nn.ParameterList; +import ai.djl.util.Pair; import java.io.DataInputStream; import java.io.IOException; @@ -67,7 +67,7 @@ public RecurrentBlock(BaseBuilder builder) { bidirectional = builder.bidirectional; returnState = builder.returnState; - ParameterType[] parameterTypes = {ParameterType.WEIGHT, ParameterType.BIAS}; + Parameter.Type[] parameterTypes = {Parameter.Type.WEIGHT, Parameter.Type.BIAS}; String[] directions = {"l"}; if (builder.bidirectional) { directions = new String[] {"l", "r"}; @@ -75,12 +75,13 @@ public RecurrentBlock(BaseBuilder builder) { String[] gateStrings = {"i2h", "h2h"}; for (int i = 0; i < numLayers; i++) { - for (ParameterType parameterType : parameterTypes) { + for (Parameter.Type parameterType : parameterTypes) { for (String direction : directions) { for (String gateString : gateStrings) { String name = direction + '_' + i + '_' + gateString + '_' + parameterType.name(); - addParameter(new Parameter(name, this, parameterType)); + addParameter( + Parameter.builder().setName(name).setType(parameterType).build()); } } } @@ -89,7 +90,7 @@ public RecurrentBlock(BaseBuilder builder) { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { + public Shape[] getOutputShapes(Shape[] inputs) { Shape inputShape = inputs[0]; Shape outputShape = new Shape(inputShape.get(0), inputShape.get(1), stateSize * getNumDirections()); @@ -109,31 +110,34 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { /** {@inheritDoc} */ @Override - public void beforeInitialize(Shape[] inputs) { - super.beforeInitialize(inputs); - Shape inputShape = inputs[0]; - Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout()); + 
protected void beforeInitialize(Shape... inputShapes) { + super.beforeInitialize(inputShapes); + Block.validateLayout(EXPECTED_LAYOUT, inputShapes[0].getLayout()); } /** {@inheritDoc} */ @Override - public Shape getParameterShape(String name, Shape[] inputShapes) { - int layer = Integer.parseInt(name.split("_")[1]); - Shape shape = inputShapes[0]; - long inputs = shape.get(2); - if (layer > 0) { - inputs = stateSize * getNumDirections(); - } - if (name.contains("BIAS")) { - return new Shape(gates * stateSize); - } - if (name.contains("i2h")) { - return new Shape(gates * stateSize, inputs); - } - if (name.contains("h2h")) { - return new Shape(gates * stateSize, stateSize); + public void prepare(Shape[] inputs) { + Shape inputShape = inputs[0]; + ParameterList parameters = getDirectParameters(); + for (Pair<String, Parameter> pair : parameters) { + String name = pair.getKey(); + Parameter parameter = pair.getValue(); + int layer = Integer.parseInt(name.split("_")[1]); + long inputSize = inputShape.get(2); + if (layer > 0) { + inputSize = stateSize * getNumDirections(); + } + if (name.contains("BIAS")) { + parameter.setShape(new Shape(gates * stateSize)); + } else if (name.contains("i2h")) { + parameter.setShape(new Shape(gates * stateSize, inputSize)); + } else if (name.contains("h2h")) { + parameter.setShape(new Shape(gates * stateSize, stateSize)); + } else { + throw new IllegalArgumentException("Invalid parameter name"); + } } - throw new IllegalArgumentException("Invalid parameter name"); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java index 066fc0c4229..b8ca2caf120 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java @@ -22,7 +22,6 @@ import ai.djl.nn.Activation; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.nn.core.Linear; import ai.djl.nn.norm.BatchNorm; import ai.djl.nn.norm.Dropout; @@ -77,8 +76,12 @@ private BertBlock(Builder builder) { // embedding for the position this.positionEmebdding = addParameter( - new Parameter(PARAM_POSITION_EMBEDDING, this, ParameterType.WEIGHT), - new Shape(builder.maxSequenceLength, builder.embeddingSize)); + Parameter.builder() + .setName(PARAM_POSITION_EMBEDDING) + .setType(Parameter.Type.WEIGHT) + .optShape( + new Shape(builder.maxSequenceLength, builder.embeddingSize)) + .build()); // embedding for the input types this.typeEmbedding = addChildBlock( @@ -153,7 +156,7 @@ public int getTypeDictionarySize() { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { long batch = inputShapes[0].get(0); long seqLength = inputShapes[0].get(1); return new Shape[] { @@ -164,11 +167,12 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { /** {@inheritDoc} */ @Override public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
inputShapes) { - beforeInitialize(inputShapes); + super.beforeInitialize(inputShapes); inputNames = Arrays.asList("tokenIds", "typeIds", "masks"); Shape[] tokenShape = {inputShapes[0]}; Shape[] typeShape = {inputShapes[1]}; - Shape[] embeddingOutput = this.tokenEmbedding.initialize(manager, dataType, tokenShape); + this.tokenEmbedding.initialize(manager, dataType, tokenShape); + Shape[] embeddingOutput = this.tokenEmbedding.getOutputShapes(tokenShape); this.typeEmbedding.initialize(manager, dataType, typeShape); this.embeddingNorm.initialize(manager, dataType, embeddingOutput); this.embeddingDropout.initialize(manager, dataType, embeddingOutput); diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java index ab9a7216803..7fc336434b6 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java @@ -19,7 +19,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.nn.core.Linear; import ai.djl.nn.norm.BatchNorm; import ai.djl.training.ParameterStore; @@ -59,8 +58,11 @@ public BertMaskedLanguageModelBlock( this.sequenceNorm = addChildBlock("sequenceNorm", BatchNorm.builder().optAxis(1).build()); this.dictionaryBias = addParameter( - new Parameter("dictionaryBias", this, ParameterType.BIAS), - new Shape(bertBlock.getTokenDictionarySize())); + Parameter.builder() + .setName("dictionaryBias") + .setType(Parameter.Type.BIAS) + .optShape(new Shape(bertBlock.getTokenDictionarySize())) + .build()); this.hiddenActivation = hiddenActivation; } @@ -140,7 +142,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(final NDManager manager, final Shape[] inputShapes) { + public Shape[] getOutputShapes(final Shape[] inputShapes) { int batchSize = (int) inputShapes[0].get(0); int indexCount = (int) inputShapes[1].get(1); int dictionarySize = (int) inputShapes[2].get(0); diff --git a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java index 40f9a717b65..aac91ab5745 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java @@ -53,7 +53,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {new Shape(inputShapes[0].get(0), 2)}; } } diff --git a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java index 2478811d037..d7f44089564 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java @@ -51,7 +51,8 @@ public BertPretrainingBlock(final BertBlock.Builder builder) { @Override public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... 
inputShapes) { inputNames = Arrays.asList("tokenIds", "typeIds", "sequenceMasks", "maskedIndices"); - Shape[] bertOutputShapes = bertBlock.initialize(manager, dataType, inputShapes); + bertBlock.initialize(manager, dataType, inputShapes); + Shape[] bertOutputShapes = bertBlock.getOutputShapes(inputShapes); Shape embeddedSequence = bertOutputShapes[0]; Shape pooledOutput = bertOutputShapes[1]; Shape maskedIndices = inputShapes[2]; @@ -97,7 +98,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { long batchSize = inputShapes[0].get(0); long maskedIndexCount = inputShapes[3].get(1); return new Shape[] { diff --git a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java index 4aa6484f766..2b3c52d3526 100644 --- a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java +++ b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java @@ -21,7 +21,6 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.util.Arrays; @@ -47,13 +46,16 @@ private IdEmbedding(Builder builder) { this.embeddingSize = builder.embeddingSize; this.embedding = addParameter( - new Parameter(EMBEDDING_PARAM_NAME, this, ParameterType.WEIGHT), - new Shape(dictionarySize, embeddingSize)); + Parameter.builder() + .setName(EMBEDDING_PARAM_NAME) + .setType(Parameter.Type.WEIGHT) + .optShape(new Shape(dictionarySize, embeddingSize)) + .build()); } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))}; } diff --git a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java index 10d4aa20b40..9d3367f9cf5 100644 --- a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java @@ -32,8 +32,6 @@ public class PointwiseFeedForwardBlock extends AbstractBlock { private static final byte VERSION = 1; - private Shape outputShape; - /** * Creates a pointwise feed-forward block. * @@ -49,7 +47,7 @@ public PointwiseFeedForwardBlock( super(VERSION); // add hidden layers with activation int count = 0; - for (final int hiddenSize : hiddenSizes) { + for (int hiddenSize : hiddenSizes) { addChildBlock( "linear_" + count, Linear.builder().optBias(true).setUnits(hiddenSize).build()); addChildBlock("activation_" + count, new LambdaBlock(activationFunction)); @@ -61,8 +59,11 @@ public PointwiseFeedForwardBlock( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - return new Shape[] {outputShape}; + public Shape[] getOutputShapes(Shape[] inputShapes) { + for (Block child : children.values()) { + inputShapes = child.getOutputShapes(inputShapes); + } + return inputShapes; } /** {@inheritDoc} */ @@ -75,16 +76,16 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... 
} // Now that we know the input shape, we can determine the reshape necessary // to shape the input and re-shape the output - final Shape inputShape = inputShapes[0]; + Shape inputShape = inputShapes[0]; if (inputShape.dimension() < 2) { throw new IllegalArgumentException( "Pointwise feed forward blocks need an input of at least dimension 2."); } Shape lastShape = inputShape; for (Block child : children.values()) { - lastShape = child.initialize(manager, dataType, lastShape)[0]; + child.initialize(manager, dataType, lastShape); + lastShape = child.getOutputShapes(new Shape[] {lastShape})[0]; } - outputShape = lastShape; } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java index ffee4df219c..fa293f4b37a 100644 --- a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java @@ -146,7 +146,7 @@ public Linear getResultProjection() { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { // Return shape is the shape of the query. For 2 or less inputs we have self-attention, i.e. // the shape of the output is the shape of the input if (inputShapes.length == 1 || inputShapes.length == 2) { diff --git a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java index 83fb87446e9..4cbabecada7 100644 --- a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java @@ -82,7 +82,7 @@ public TransformerEncoderBlock( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return inputShapes; } diff --git a/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java b/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java index be9f7bbfec6..b9580914201 100644 --- a/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java +++ b/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java @@ -13,23 +13,23 @@ package ai.djl.training; import ai.djl.Device; +import ai.djl.nn.Parameter; import ai.djl.training.evaluator.Evaluator; import ai.djl.training.initializer.Initializer; -import ai.djl.training.initializer.XavierInitializer; -import ai.djl.training.initializer.XavierInitializer.FactorType; -import ai.djl.training.initializer.XavierInitializer.RandomType; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Adam; import ai.djl.training.optimizer.Optimizer; +import ai.djl.util.PairList; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.function.Predicate; /** {@code DefaultTrainingConfig} is an implementation of the {@link TrainingConfig} interface. */ public class DefaultTrainingConfig implements TrainingConfig { - private Initializer initializer; + private PairList<Initializer, Predicate<Parameter>> initializers = new PairList<>(); private Optimizer optimizer; private Device[] devices; private Loss loss; @@ -38,15 +38,12 @@ public class DefaultTrainingConfig implements TrainingConfig { /** * Creates an instance of {@code DefaultTrainingConfig} with the given {@link Loss}.
diff --git a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java index ffee4df219c..fa293f4b37a 100644 --- a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java @@ -146,7 +146,7 @@ public Linear getResultProjection() { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { // Return shape is the shape of the query. For 2 or less inputs we have self-attention, i.e. // the shape of the output is the shape of the input if (inputShapes.length == 1 || inputShapes.length == 2) { diff --git a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java index 83fb87446e9..4cbabecada7 100644 --- a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java @@ -82,7 +82,7 @@ public TransformerEncoderBlock( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return inputShapes; } diff --git a/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java b/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java index be9f7bbfec6..b9580914201 100644 --- a/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java +++ b/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java @@ -13,23 +13,23 @@ package ai.djl.training; import ai.djl.Device; +import ai.djl.nn.Parameter; import ai.djl.training.evaluator.Evaluator; import ai.djl.training.initializer.Initializer; -import ai.djl.training.initializer.XavierInitializer; -import ai.djl.training.initializer.XavierInitializer.FactorType; -import ai.djl.training.initializer.XavierInitializer.RandomType; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Adam; import ai.djl.training.optimizer.Optimizer; +import ai.djl.util.PairList; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.function.Predicate; /** {@code DefaultTrainingConfig} is an implementation of the {@link TrainingConfig} interface. */ public class DefaultTrainingConfig implements TrainingConfig { - private Initializer initializer; + private PairList<Initializer, Predicate<Parameter>> initializers = new PairList<>(); private Optimizer optimizer; private Device[] devices; private Loss loss; @@ -38,15 +38,12 @@ public class DefaultTrainingConfig implements TrainingConfig { /** * Creates an instance of {@code DefaultTrainingConfig} with the given {@link Loss}. {@code - * DefaultTrainingConfig} creates a default {@link TrainingConfig} with the {@link - * XavierInitializer} as initialiser, {@link Adam} as optimiser, and the given {@link Loss}. The - * evaluators and listeners are left to the user's discretion. + * DefaultTrainingConfig} creates a default {@link TrainingConfig} with {@link Adam} as optimiser + * and the given {@link Loss}. The evaluators and listeners are left to the user's discretion. * * @param loss the loss to use for training */ public DefaultTrainingConfig(Loss loss) { - // Defaults to initializer defined in https://arxiv.org/abs/1502.01852 - this.initializer = new XavierInitializer(RandomType.GAUSSIAN, FactorType.IN, 2); optimizer = Adam.builder().build(); this.loss = loss; evaluators = new ArrayList<>(); @@ -58,10 +55,38 @@ public DefaultTrainingConfig(Loss loss) { * href="https://arxiv.org/abs/1502.01852">paper</a>). * * @param initializer the initializer to use for the parameters + * @param type the {@link Parameter.Type} of the parameters * @return this {@code DefaultTrainingConfig} */ - public DefaultTrainingConfig optInitializer(Initializer initializer) { - this.initializer = initializer; + public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type) { + initializers.add(initializer, parameter -> parameter.getType().equals(type)); + return this; + } + + /** + * Sets the {@link Initializer} to use for the parameter with the given name. + * + * @param initializer the initializer to use for the parameter + * @param name the name of the parameter + * @return this {@code DefaultTrainingConfig} + */ + public DefaultTrainingConfig optInitializer(Initializer initializer, String name) { + initializers.add(initializer, parameter -> parameter.getName().equals(name)); + return this; + } + + /** + * Sets the {@link Initializer} to use for the parameters that match the given predicate. + * + * @param initializer the initializer to use for the parameters + * @param predicate the predicate that identifies the parameters + * @return this {@code DefaultTrainingConfig} + */ + public DefaultTrainingConfig optInitializer( + Initializer initializer, Predicate<Parameter> predicate) { + initializers.add(initializer, predicate); return this; } @@ -120,8 +145,8 @@ public Device[] getDevices() { /** {@inheritDoc} */ @Override - public Initializer getInitializer() { - return initializer; + public PairList<Initializer, Predicate<Parameter>> getInitializers() { + return initializers; } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/training/TrainingConfig.java b/api/src/main/java/ai/djl/training/TrainingConfig.java index 46748232035..0a4b5928266 100644 --- a/api/src/main/java/ai/djl/training/TrainingConfig.java +++ b/api/src/main/java/ai/djl/training/TrainingConfig.java @@ -13,12 +13,15 @@ package ai.djl.training; import ai.djl.Device; +import ai.djl.nn.Parameter; import ai.djl.training.evaluator.Evaluator; import ai.djl.training.initializer.Initializer; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Optimizer; +import ai.djl.util.PairList; import java.util.List; +import java.util.function.Predicate; /** * An interface that is responsible for holding the configuration required by {@link Trainer}. @@ -64,11 +67,11 @@ public interface TrainingConfig { Device[] getDevices(); /** - * Gets the {@link Initializer} to initialize the parameters of the model. + * Gets the list of {@link Initializer} and {@link Predicate} pairs used to initialize the parameters of the model. * - * @return an {@link Initializer} + * @return a {@link PairList} of {@link Initializer} and the {@link Predicate} selecting the parameters it applies to */ - Initializer getInitializer(); + PairList<Initializer, Predicate<Parameter>> getInitializers();
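For illustration, a minimal sketch (not part of the patch) of how the three optInitializer overloads compose on one config; the initializer choices, the "bias" name, and the "gamma" suffix are arbitrary examples:

TrainingConfig config =
        new DefaultTrainingConfig(Loss.l2Loss())
                // by parameter type: Xavier for every weight
                .optInitializer(new XavierInitializer(), Parameter.Type.WEIGHT)
                // by parameter name
                .optInitializer(Initializer.ZEROS, "bias")
                // by arbitrary predicate
                .optInitializer(Initializer.ONES, p -> p.getName().endsWith("gamma"));

When a Trainer is created, the engine walks this PairList in order and calls Block.setInitializer for each pair, as the MxModel and PtModel hunks later in this diff show.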
/** * Gets the {@link Optimizer} to use during training. diff --git a/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java b/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java index 64ee6279f71..8cf76c7b8f0 100644 --- a/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java +++ b/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java @@ -80,7 +80,7 @@ public XavierInitializer(RandomType randomType, FactorType factorType, float mag /** Creates a new instance of {@code XavierInitializer}. */ public XavierInitializer() { - this(RandomType.UNIFORM, FactorType.AVG, 3f); + this(RandomType.UNIFORM, FactorType.AVG, 6f); } /** {@inheritDoc} */ diff --git a/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java b/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java index abd5e9ec2a2..562dd3a0de9 100644 --- a/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java +++ b/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java @@ -31,7 +31,7 @@ private ShapeUtils() {} * @return the corresponding output shape for the provided input */ public static Shape outputShapeForBlock(NDManager manager, Block block, Shape inputShape) { - Shape[] outputs = block.getOutputShapes(manager, new Shape[] {inputShape}); + Shape[] outputs = block.getOutputShapes(new Shape[] {inputShape}); return outputs[0]; } diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java index 3efdd00af77..dc8028c5c92 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -48,7 +49,7 @@ public class AirfoilRandomAccessTest { public void testAirfoilRemote() throws IOException, TranslateException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = Model.newInstance("model")) { model.setBlock(Blocks.identityBlock()); diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java index 20bf7f8ac56..716f04c578f 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -48,7 +49,7 @@ public class AmesRandomAccessTest { public void testAmesRandomAccessRemote() throws IOException, TranslateException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = Model.newInstance("model")) { model.setBlock(Blocks.identityBlock()); diff --git 
a/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java index 75bcb3e914a..063ec07fb13 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java @@ -16,6 +16,7 @@ import ai.djl.basicdataset.cv.PikachuDetection; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -42,7 +43,7 @@ public void testPikachuRemote() throws IOException, TranslateException { .build(); TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(new NormalInitializer(0.01f)); + .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT); try (Model model = Model.newInstance("model")) { model.setBlock(Blocks.identityBlock()); try (Trainer trainer = model.newTrainer(config)) { diff --git a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java index 53eca2e14b4..7d1cda2adc2 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; import ai.djl.nn.transformer.BertBlock; import ai.djl.nn.transformer.BertPretrainingBlock; import ai.djl.nn.transformer.BertPretrainingLoss; @@ -135,7 +136,8 @@ private static Model createBertPretrainingModel(Dictionary dictionary) { model.setBlock( new BertPretrainingBlock( BERT_BUILDER.setTokenDictionarySize(dictionary.tokens.size()))); - model.getBlock().setInitializer(new TruncatedNormalInitializer(0.02f)); + model.getBlock() + .setInitializer(new TruncatedNormalInitializer(0.02f), Parameter.Type.WEIGHT); return model; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java index 13ce0f2f742..76078106a22 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java @@ -31,7 +31,6 @@ import ai.djl.training.dataset.Dataset; import ai.djl.training.dataset.RandomAccessDataset; import ai.djl.training.evaluator.Accuracy; -import ai.djl.training.initializer.XavierInitializer; import ai.djl.training.listener.SaveModelTrainingListener; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; @@ -119,7 +118,6 @@ public static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) - .optInitializer(new XavierInitializer()) .optDevices(Device.getDevices(arguments.getMaxGpus())) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java index b1631b6b73f..b68fc81e1c7 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java @@ -13,15 +13,18 @@ package ai.djl.fasttext; import 
ai.djl.Device; +import ai.djl.nn.Parameter; import ai.djl.training.TrainingConfig; import ai.djl.training.evaluator.Evaluator; import ai.djl.training.initializer.Initializer; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Optimizer; +import ai.djl.util.PairList; import java.nio.file.Path; import java.util.ArrayList; import java.util.List; +import java.util.function.Predicate; /** An interface that is responsible for holding the configuration required by fastText training. */ public class FtTrainingConfig implements TrainingConfig { @@ -247,7 +250,7 @@ public Device[] getDevices() { /** {@inheritDoc} */ @Override - public Initializer getInitializer() { + public PairList<Initializer, Predicate<Parameter>> getInitializers() { return null; } diff --git a/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java b/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java index f105d46c8cd..2f6e41cda36 100644 --- a/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java @@ -20,7 +20,6 @@ import ai.djl.nn.Block; import ai.djl.nn.SequentialBlock; import ai.djl.training.ParameterStore; -import ai.djl.training.initializer.XavierInitializer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -31,41 +30,34 @@ public class SingleShotDetectionTest { @Test public void testClassPredictorBlocks() { - try (NDManager manager = NDManager.newBaseManager()) { - Block block = SingleShotDetection.getClassPredictionBlock(5, 10); - Assert.assertEquals( - block.getOutputShapes(manager, new Shape[] {new Shape(2, 8, 20, 20)})[0], - new Shape(2, 55, 20, 20)); - block = SingleShotDetection.getClassPredictionBlock(3, 10); - Assert.assertEquals( - block.getOutputShapes(manager, new Shape[] {new Shape(2, 16, 10, 10)})[0], - new Shape(2, 33, 10, 10)); - } + Block block = SingleShotDetection.getClassPredictionBlock(5, 10); + Assert.assertEquals( + block.getOutputShapes(new Shape[] {new Shape(2, 8, 20, 20)})[0], + new Shape(2, 55, 20, 20)); + block = SingleShotDetection.getClassPredictionBlock(3, 10); + Assert.assertEquals( + block.getOutputShapes(new Shape[] {new Shape(2, 16, 10, 10)})[0], + new Shape(2, 33, 10, 10)); } @Test public void testAnchorPredictorBlocks() { - try (NDManager manager = NDManager.newBaseManager()) { - Block block = SingleShotDetection.getAnchorPredictionBlock(5); - Assert.assertEquals( - block.getOutputShapes(manager, new Shape[] {new Shape(2, 8, 20, 20)})[0], - new Shape(2, 20, 20, 20)); - block = SingleShotDetection.getClassPredictionBlock(3, 10); - Assert.assertEquals( - block.getOutputShapes(manager, new Shape[] {new Shape(2, 16, 10, 10)})[0], - new Shape(2, 33, 10, 10)); - } + Block block = SingleShotDetection.getAnchorPredictionBlock(5); + Assert.assertEquals( + block.getOutputShapes(new Shape[] {new Shape(2, 8, 20, 20)})[0], + new Shape(2, 20, 20, 20)); + block = SingleShotDetection.getClassPredictionBlock(3, 10); + Assert.assertEquals( + block.getOutputShapes(new Shape[] {new Shape(2, 16, 10, 10)})[0], + new Shape(2, 33, 10, 10)); } @Test public void testDownSamplingBlock() { - try (NDManager manager = NDManager.newBaseManager()) { - Block sequentialBlock = SingleShotDetection.getDownSamplingBlock(10); - Assert.assertEquals( - sequentialBlock - .getOutputShapes(manager, new Shape[] {new Shape(2, 3, 20, 20)})[0], - new Shape(2, 10, 
10, 10)); - } + Block sequentialBlock = SingleShotDetection.getDownSamplingBlock(10); + Assert.assertEquals( + sequentialBlock.getOutputShapes(new Shape[] {new Shape(2, 3, 20, 20)})[0], + new Shape(2, 10, 10, 10)); } @Test @@ -97,7 +89,6 @@ public void testSingleShotDetectionShape() { .setSizes(sizes) .setBaseNetwork(block) .build(); - ssd.setInitializer(new XavierInitializer()); ssd.initialize(manager, DataType.FLOAT32, new Shape(32, 3, 256, 256)); ParameterStore ps = new ParameterStore(manager, false); NDList output = @@ -105,8 +96,7 @@ Assert.assertEquals(output.get(0).getShape(), new Shape(1, 5444, 4)); Assert.assertEquals(output.get(1).getShape(), new Shape(32, 5444, 2)); Assert.assertEquals(output.get(2).getShape(), new Shape(32, 21776)); - Shape[] outputShapes = - ssd.getOutputShapes(manager, new Shape[] {new Shape(32, 3, 256, 256)}); + Shape[] outputShapes = ssd.getOutputShapes(new Shape[] {new Shape(32, 3, 256, 256)}); Assert.assertEquals(outputShapes[0], new Shape(1, 5444, 4)); Assert.assertEquals(outputShapes[1], new Shape(32, 5444, 2)); Assert.assertEquals(outputShapes[2], new Shape(32, 21776)); diff --git a/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java b/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java index 131df85c53e..cc6d43853c8 100644 --- a/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java @@ -23,7 +23,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.recurrent.LSTM; import ai.djl.training.ParameterStore; -import ai.djl.training.initializer.XavierInitializer; import java.util.Arrays; import org.testng.Assert; import org.testng.annotations.Test; @@ -50,7 +49,6 @@ public void testEncoder() { .optReturnState(true) .build()); try (NDManager manager = NDManager.newBaseManager(TestUtils.getDevices()[0])) { - encoder.setInitializer(new XavierInitializer()); encoder.initialize(manager, DataType.FLOAT32, new Shape(4, 7)); NDList output = encoder.forward( diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java index f956bfb46bc..ae0ea251e24 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java @@ -159,7 +159,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block alexNet = AlexNet.builder().build(); - alexNet.setInitializer(Initializer.ONES); + alexNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); alexNet.initialize(manager, DataType.FLOAT32, currentShape); Map<String, Shape> shapeMap = new ConcurrentHashMap<>(); @@ -169,7 +169,7 @@ alexNet.getChildren() .get(i) .getValue() - .getOutputShapes(manager, new Shape[] {currentShape}); + .getOutputShapes(new Shape[] {currentShape}); currentShape = newShape[0]; shapeMap.put(alexNet.getChildren().get(i).getKey(), currentShape); } @@ -188,7 +188,7 @@ public void testForwardMethod() { Block alexNet = AlexNet.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224)); - alexNet.setInitializer(Initializer.ONES); + alexNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); alexNet.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = alexNet.forward(new ParameterStore(manager, true), new NDList(x), false)
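The model-zoo tests above all probe shapes the same way; a small sketch (not part of the patch) of that pattern, where block stands for any initialized DJL Block and the input shape is arbitrary:

// Propagate an input shape through a block's children without an NDManager.
Shape current = new Shape(1, 1, 224, 224);
for (int i = 0; i < block.getChildren().size(); i++) {
    current = block.getChildren().get(i).getValue()
            .getOutputShapes(new Shape[] {current})[0];
}
// "current" now holds the block's final output shape for that input.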
diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java index a149b7e40e3..4befb517a2f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java @@ -100,7 +100,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block googLeNet = GoogLeNet.builder().build(); - googLeNet.setInitializer(Initializer.ONES); + googLeNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); googLeNet.initialize(manager, DataType.FLOAT32, currentShape); Map<String, Shape> shapeMap = new ConcurrentHashMap<>(); @@ -111,7 +111,7 @@ .getChildren() .get(i) .getValue() - .getOutputShapes(manager, new Shape[] {currentShape}); + .getOutputShapes(new Shape[] {currentShape}); currentShape = newShape[0]; shapeMap.put(googLeNet.getChildren().get(i).getKey(), currentShape); } @@ -130,7 +130,7 @@ public void testForwardMethod() { Block googLeNet = GoogLeNet.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 28, 28)); - googLeNet.setInitializer(Initializer.ONES); + googLeNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); googLeNet.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = googLeNet diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java index 39744eeefe6..6160ed5f7e4 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java @@ -137,7 +137,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block leNet = LeNet.builder().build(); - leNet.setInitializer(Initializer.ONES); + leNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); leNet.initialize(manager, DataType.FLOAT32, currentShape); Map<String, Shape> shapeMap = new ConcurrentHashMap<>(); @@ -147,7 +147,7 @@ leNet.getChildren() .get(i) .getValue() - .getOutputShapes(manager, new Shape[] {currentShape}); + .getOutputShapes(new Shape[] {currentShape}); currentShape = newShape[0]; shapeMap.put(leNet.getChildren().get(i).getKey(), currentShape); } @@ -165,7 +165,7 @@ public void testForwardMethod() { Block leNet = LeNet.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 28, 28)); - leNet.setInitializer(Initializer.ONES); + leNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); leNet.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = leNet.forward(new ParameterStore(manager, true), new NDList(x), true) diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java index cac77221d69..f8544debc70 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java @@ -152,17 +152,14 @@ public 
void testOutputShapes() { Shape currentShape = x.getShape(); Block nin = NiN.builder().build(); - nin.setInitializer(Initializer.ONES); + nin.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); nin.initialize(manager, DataType.FLOAT32, currentShape); Map<String, Shape> shapeMap = new ConcurrentHashMap<>(); for (int i = 0; i < nin.getChildren().size(); i++) { Shape[] newShape = - nin.getChildren() - .get(i) - .getValue() - .getOutputShapes(manager, new Shape[] {currentShape}); + nin.getChildren().get(i).getValue().getOutputShapes(new Shape[] {currentShape}); currentShape = newShape[0]; shapeMap.put(nin.getChildren().get(i).getKey(), currentShape); } @@ -180,7 +177,7 @@ public void testForwardMethod() { Block nin = NiN.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224)); - nin.setInitializer(Initializer.ONES); + nin.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); nin.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = nin.forward(new ParameterStore(manager, true), new NDList(x), false) diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java index fd50e39943f..21c6779f36f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java @@ -57,7 +57,7 @@ public void testTrain() { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .optDevices(Device.getDevices(2)) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block resNet50 = ResNetV1.builder() @@ -123,7 +123,7 @@ public void testLoadTrain() TrainingConfig config = new DefaultTrainingConfig(Loss.l1Loss()) .optDevices(Device.getDevices(2)) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Trainer trainer = model.newTrainer(config)) { int batchSize = 2; Shape inputShape = new Shape(batchSize, 3, 32, 32); @@ -131,8 +131,7 @@ public void testLoadTrain() trainer.initialize(inputShape); NDManager manager = trainer.getManager(); - Shape[] outputShape = - model.getBlock().getOutputShapes(manager, new Shape[] {inputShape}); + Shape[] outputShape = model.getBlock().getOutputShapes(new Shape[] {inputShape}); NDArray data = manager.ones(new Shape(batchSize, 3, 32, 32)); NDArray label = manager.ones(outputShape[0]); diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java index 267bd41cf80..114907144aa 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java @@ -40,7 +40,7 @@ public void testTrain() { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .optDevices(Device.getDevices(2)) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block squeezeNet = SqueezeNet.squeezenet(10); try (Model model = Model.newInstance("squeezenet")) { model.setBlock(squeezeNet); diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java 
b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java index b8520edd8f9..5c22bacecf1 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java @@ -107,17 +107,14 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block vgg = VGG.builder().build(); - vgg.setInitializer(Initializer.ONES); + vgg.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); vgg.initialize(manager, DataType.FLOAT32, currentShape); Map<String, Shape> shapeMap = new ConcurrentHashMap<>(); for (int i = 0; i < vgg.getChildren().size(); i++) { Shape[] newShape = - vgg.getChildren() - .get(i) - .getValue() - .getOutputShapes(manager, new Shape[] {currentShape}); + vgg.getChildren().get(i).getValue().getOutputShapes(new Shape[] {currentShape}); currentShape = newShape[0]; shapeMap.put(vgg.getChildren().get(i).getKey(), currentShape); } @@ -137,8 +134,9 @@ public void testForwardMethod() { Block vgg = VGG.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224)); - vgg.setInitializer(Initializer.ONES); + vgg.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); vgg.initialize(manager, DataType.FLOAT32, x.getShape()); + NDArray xHat = vgg.forward(new ParameterStore(manager, true), new NDList(x), false) .singletonOrThrow(); diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java index 8e97ed38b87..2dbfac8ef2e 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.GradientCollector; @@ -139,7 +140,7 @@ public void testAddScalar() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { lhs.attachGradient(); result = NDArrays.add(lhs, 2); @@ -360,7 +361,7 @@ public void testDot() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { lhs.attachGradient(); result = NDArrays.dot(lhs, rhs); diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java index b0289856ae9..c133b382327 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java @@ -62,7 +62,8 @@ public class BlockCoreTest { @Test public void testLinear() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); long outSize = 3; Block 
block = Linear.builder().setUnits(outSize).build(); @@ -124,7 +125,8 @@ public void testLinear() throws IOException, MalformedModelException { @Test public void testLinearWithDefinedLayout() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); long outSize = 3; Block block = Linear.builder().setUnits(outSize).build(); @@ -176,7 +178,8 @@ public void testLinearWithDefinedLayout() throws IOException, MalformedModelExce @Test public void testBatchNorm() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = BatchNorm.builder().build(); try (Model model = Model.newInstance("model")) { @@ -203,7 +206,8 @@ public void testBatchNorm() throws IOException, MalformedModelException { @Test public void testDropout() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Dropout.builder().optRate(.5f).build(); try (Model model = Model.newInstance("model")) { @@ -229,7 +233,8 @@ public void testDropout() throws IOException, MalformedModelException { @Test public void testEmbedding() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); TrainableWordEmbedding block = TrainableWordEmbedding.builder() @@ -262,7 +267,8 @@ public void testEmbedding() throws IOException, MalformedModelException { @Test public void testConv1d() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv1d.builder().setKernelShape(new Shape(2)).setFilters(1).optBias(false).build(); @@ -283,7 +289,7 @@ public void testConv1d() throws IOException, MalformedModelException { NDArray out = trainer.forward(new NDList(data)).singletonOrThrow(); Assert.assertEquals(out, expected); - Shape[] outputShape = block.getOutputShapes(manager, new Shape[] {inputShape}); + Shape[] outputShape = block.getOutputShapes(new Shape[] {inputShape}); Assert.assertEquals(out.getShape(), outputShape[0]); testEncode(manager, block); @@ -294,7 +300,8 @@ public void testConv1d() throws IOException, MalformedModelException { @Test public void testConv1dTranspose() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv1dTranspose.builder() @@ -317,7 +324,7 @@ public void testConv1dTranspose() throws IOException, MalformedModelException { NDArray out = trainer.forward(new NDList(data)).singletonOrThrow(); Assert.assertEquals(out, expected); - Shape[] outputShape = block.getOutputShapes(manager, new Shape[] {inputShape}); + Shape[] outputShape = 
block.getOutputShapes(new Shape[] {inputShape}); Assert.assertEquals(out.getShape(), outputShape[0]); testEncode(manager, block); @@ -328,7 +335,8 @@ public void testConv1dTranspose() throws IOException, MalformedModelException { @Test public void testConv2d() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv2d.builder().setKernelShape(new Shape(2, 2)).setFilters(1).build(); try (Model model = Model.newInstance("model")) { @@ -359,7 +367,8 @@ public void testConv2d() throws IOException, MalformedModelException { @Test public void testConv2dTranspose() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv2dTranspose.builder().setKernelShape(new Shape(2, 2)).setFilters(1).build(); @@ -396,7 +405,8 @@ public void testConv2dTranspose() throws IOException, MalformedModelException { @Test public void testConv3d() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv3d.builder().setKernelShape(new Shape(2, 2, 2)).setFilters(1).build(); try (Model model = Model.newInstance("model")) { @@ -422,8 +432,7 @@ public void testConv3d() throws IOException, MalformedModelException { NDArray result = trainer.forward(new NDList(data)).singletonOrThrow(); Assert.assertEquals(result, expected); - Shape[] outputShape = - block.getOutputShapes(manager, new Shape[] {new Shape(1, 1, 3, 3, 3)}); + Shape[] outputShape = block.getOutputShapes(new Shape[] {new Shape(1, 1, 3, 3, 3)}); Assert.assertEquals(result.getShape(), outputShape[0]); testEncode(manager, block); @@ -437,7 +446,7 @@ public void testRNNTanh() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); Block block = RNN.builder() @@ -484,7 +493,7 @@ public void testRNNRelu() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); Block block = RNN.builder() @@ -534,7 +543,7 @@ public void testLstm() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); Block block = LSTM.builder() @@ -585,7 +594,7 @@ public void testGRU() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, 
Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); GRU block = GRU.builder() @@ -638,7 +647,8 @@ public void testGRU() throws IOException, MalformedModelException { @Test public void testSequentialBlock() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); SequentialBlock block = new SequentialBlock(); block.addSingleton(x -> x.mul(6.5f)); block.add(Linear.builder().setUnits(10).build()); @@ -678,7 +688,8 @@ public void testSequentialBlock() throws IOException, MalformedModelException { @Test public void testParallelBlock() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); ParallelBlock block = new ParallelBlock( list -> diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java index 6f5f4e1ebc3..44a4e9816dd 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; import ai.djl.nn.pooling.Pool; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; @@ -29,7 +30,8 @@ public class PoolingOperationsTest { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testMaxPool1d() { diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java index 9595e88ef15..df8a64f4803 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; import ai.djl.nn.transformer.ScaledDotProductAttentionBlock; import ai.djl.training.GradientCollector; import ai.djl.training.ParameterStore; @@ -752,7 +753,7 @@ public void testMaskedAttention() { .optAttentionProbsDropoutProb(0.0f) .build(); - block.setInitializer(new NormalInitializer()); + block.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT); block.getKeyProjection().setInitializer(keyKernelInitializer, "weight"); block.getValueProjection().setInitializer(valueKernelInitializer, "weight"); block.getQueryProjection().setInitializer(queryKernelInitializer, "weight"); diff --git a/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java index 443370b10d8..b3e8502a9ca 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java @@ 
-18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Activation; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; @@ -30,7 +31,8 @@ public class ActivationTest { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testRelu() { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java b/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java index a86705162ef..630ad57d3e6 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.ParameterStore; @@ -30,7 +31,8 @@ public class BlocksTest { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testFlattenBlock() { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java b/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java index 4b5dd466c31..8d06c85efd8 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -49,7 +50,8 @@ public class DatasetTest { private TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testSequenceSampler() throws IOException, TranslateException { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java index 154583e59e3..96680659f78 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.nn.core.Linear; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; @@ -49,7 +50,7 @@ public void testAutograd() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { NDArray lhs = manager.create(new float[] {6, -9, -12, 15, 0, 4}, new Shape(2, 3)); @@ -87,7 +88,7 @@ public void 
testTrain() throws IOException, TranslateException { TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) .addTrainingListeners(new EvaluatorTrainingListener()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optimizer); try (Model model = Model.newInstance("linear")) { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java b/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java index 2e919ca78d5..6023b1b054f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java @@ -21,7 +21,6 @@ import ai.djl.nn.convolutional.Conv2d; import ai.djl.nn.norm.BatchNorm; import ai.djl.testing.Assertions; -import ai.djl.training.initializer.XavierInitializer; import java.io.IOException; import java.nio.file.Paths; import org.testng.Assert; @@ -36,7 +35,6 @@ public void testModelSaveAndLoad() throws IOException, MalformedModelException { block.add(BatchNorm.builder().build()); try (Model saveModel = Model.newInstance("saveModel"); Model loadModel = Model.newInstance("loadModel")) { - block.setInitializer(new XavierInitializer()); block.initialize(saveModel.getNDManager(), DataType.FLOAT32, new Shape(1, 3, 32, 32)); ParameterList savedParameters = block.getParameters(); saveModel.setBlock(block); diff --git a/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java b/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java index 8e472c626f2..5657f584805 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; +import ai.djl.nn.Parameter; import ai.djl.nn.core.Linear; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; @@ -46,7 +47,7 @@ public void testSgd() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(sgd) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -78,7 +79,7 @@ public void testSgdWithMomentum() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -118,7 +119,7 @@ public void testNag() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -147,7 +148,7 @@ public void testAdam() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -176,7 +177,7 @@ public void testAdagrad() { Device[] devices = Device.getDevices(1); 
TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -209,7 +210,7 @@ public void testRMSProp() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -243,7 +244,7 @@ public void testRMSPropAlex() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -273,7 +274,7 @@ public void testAdadelta() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java index 4d9486c70cd..430c1aa66aa 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java @@ -111,46 +111,48 @@ private NDArray concatPredictions(NDList output) { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - // TODO: output shape is wrong - Shape[] childInputShapes = inputShapes; - Shape[] anchorShapes = new Shape[features.size()]; - Shape[] classPredictionShapes = new Shape[features.size()]; - Shape[] anchorPredictionShapes = new Shape[features.size()]; - for (int i = 0; i < features.size(); i++) { - childInputShapes = features.get(i).getOutputShapes(manager, childInputShapes); - anchorShapes[i] = - multiBoxPriors - .get(i) - .generateAnchorBoxes(manager.ones(childInputShapes[0])) - .getShape(); - classPredictionShapes[i] = - classPredictionBlocks.get(i).getOutputShapes(manager, childInputShapes)[0]; - anchorPredictionShapes[i] = - anchorPredictionBlocks.get(i).getOutputShapes(manager, childInputShapes)[0]; - } - Shape anchorOutputShape = new Shape(); - for (Shape shape : anchorShapes) { - anchorOutputShape = concatShape(anchorOutputShape, shape, 1); - } + public Shape[] getOutputShapes(Shape[] inputShapes) { + try (NDManager manager = NDManager.newBaseManager()) { + // TODO: output shape is wrong + Shape[] childInputShapes = inputShapes; + Shape[] anchorShapes = new Shape[features.size()]; + Shape[] classPredictionShapes = new Shape[features.size()]; + Shape[] anchorPredictionShapes = new Shape[features.size()]; + for (int i = 0; i < features.size(); i++) { + childInputShapes = features.get(i).getOutputShapes(childInputShapes); + anchorShapes[i] = + multiBoxPriors + .get(i) + .generateAnchorBoxes(manager.ones(childInputShapes[0])) + .getShape(); + classPredictionShapes[i] = + classPredictionBlocks.get(i).getOutputShapes(childInputShapes)[0]; + anchorPredictionShapes[i] = + anchorPredictionBlocks.get(i).getOutputShapes(childInputShapes)[0]; + } + 
Shape anchorOutputShape = new Shape(); + for (Shape shape : anchorShapes) { + anchorOutputShape = concatShape(anchorOutputShape, shape, 1); + } - NDList classPredictions = new NDList(); - for (Shape shape : classPredictionShapes) { - classPredictions.add(manager.ones(shape)); - } - NDArray classPredictionOutput = concatPredictions(classPredictions); - Shape classPredictionOutputShape = - classPredictionOutput - .reshape(classPredictionOutput.size(0), -1, numClasses + 1) - .getShape(); - NDList anchorPredictions = new NDList(); - for (Shape shape : anchorPredictionShapes) { - anchorPredictions.add(manager.ones(shape)); + NDList classPredictions = new NDList(); + for (Shape shape : classPredictionShapes) { + classPredictions.add(manager.ones(shape)); + } + NDArray classPredictionOutput = concatPredictions(classPredictions); + Shape classPredictionOutputShape = + classPredictionOutput + .reshape(classPredictionOutput.size(0), -1, numClasses + 1) + .getShape(); + NDList anchorPredictions = new NDList(); + for (Shape shape : anchorPredictionShapes) { + anchorPredictions.add(manager.ones(shape)); + } + Shape anchorPredictionOutputShape = concatPredictions(anchorPredictions).getShape(); + return new Shape[] { + anchorOutputShape, classPredictionOutputShape, anchorPredictionOutputShape + }; } - Shape anchorPredictionOutputShape = concatPredictions(anchorPredictions).getShape(); - return new Shape[] { - anchorOutputShape, classPredictionOutputShape, anchorPredictionOutputShape - }; } private Shape concatShape(Shape shape, Shape concat, int axis) { @@ -177,15 +179,15 @@ private Shape concatShape(Shape shape, Shape concat, int axis) { /** {@inheritDoc} */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + public void initialize(NDManager manager, DataType dataType, Shape... 
inputShapes) { beforeInitialize(inputShapes); Shape[] shapes = inputShapes; for (int i = 0; i < features.size(); i++) { - shapes = features.get(i).initialize(manager, dataType, shapes); + features.get(i).initialize(manager, dataType, shapes); + shapes = features.get(i).getOutputShapes(shapes); classPredictionBlocks.get(i).initialize(manager, dataType, shapes); anchorPredictionBlocks.get(i).initialize(manager, dataType, shapes); } - return getOutputShapes(manager, inputShapes); } /** {@inheritDoc} */ diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java index ee8afeba541..54c58bf1a01 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java @@ -24,6 +24,8 @@ import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; import ai.djl.training.initializer.Initializer; +import ai.djl.util.Pair; +import ai.djl.util.PairList; import java.io.FileNotFoundException; import java.io.IOException; import java.nio.file.Files; @@ -33,6 +35,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Predicate; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -122,12 +125,16 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) /** {@inheritDoc} */ @Override public Trainer newTrainer(TrainingConfig trainingConfig) { - Initializer initializer = trainingConfig.getInitializer(); + PairList<Initializer, Predicate<Parameter>> initializer = trainingConfig.getInitializers(); if (block == null) { throw new IllegalStateException( "You must set a block for the model before creating a new trainer"); } - block.setInitializer(initializer); + for (Pair<Initializer, Predicate<Parameter>> pair : initializer) { + if (pair.getKey() != null && pair.getValue() != null) { + block.setInitializer(pair.getKey(), pair.getValue()); + } + } return new Trainer(this, trainingConfig); } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java index ffc04996a4b..3b63b96a928 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java @@ -21,7 +21,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -201,7 +200,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { if (outputShapes == null) { String[] outputNames = symbol.getOutputNames(); outputShapes = new Shape[outputNames.length]; @@ -232,9 +231,7 @@ public void removeLastBlock() { } } - /** {@inheritDoc} */ - @Override - public Shape getParameterShape(String name, Shape[] inputShapes) { + private Shape getParameterShape(String name, Shape[] inputShapes) { if (paramShapes == null) { PairList<String, Shape> pairs = new PairList<>(); for (int i = 0; i < inputNames.size(); i++) { @@ -314,25 +311,32 @@ private void initBlock() { Set<String> auxNameSet = new HashSet<>(Arrays.asList(symbol.getAuxNames())); for (String name : allNames) { - ParameterType type = inferType(name); + Parameter.Type type = inferType(name); boolean requireGrad = 
!auxNameSet.contains(name); - mxNetParams.add(new Parameter(name, this, type, requireGrad)); + mxNetParams.add( + Parameter.builder() + .setName(name) + .setType(type) + .optRequiresGrad(requireGrad) + .build()); } first = true; } - private static ParameterType inferType(String name) { + private static Parameter.Type inferType(String name) { if (name.endsWith("bias")) { - return ParameterType.BIAS; + return Parameter.Type.BIAS; } else if (name.endsWith("gamma")) { - return ParameterType.GAMMA; + return Parameter.Type.GAMMA; } else if (name.endsWith("beta")) { - return ParameterType.BETA; + return Parameter.Type.BETA; } else if (name.endsWith("moving_mean") || name.endsWith("running_mean")) { - return ParameterType.RUNNING_MEAN; + return Parameter.Type.RUNNING_MEAN; } else if (name.endsWith("moving_var") || name.endsWith("running_var")) { - return ParameterType.RUNNING_VAR; + return Parameter.Type.RUNNING_VAR; + } else if (name.endsWith("weight")) { + return Parameter.Type.WEIGHT; } - return ParameterType.OTHER; + return Parameter.Type.OTHER; } } diff --git a/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java b/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java index 0ba9177404f..299136587c7 100644 --- a/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java +++ b/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.GradientCollector; @@ -38,7 +39,7 @@ public void testMxAutograd() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { NDArray lhs = manager.create(new float[] {6, -9, -12, 15, 0, 4}, new Shape(2, 3)); diff --git a/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java b/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java index d855bb90ced..74fde4bf89a 100644 --- a/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java +++ b/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java @@ -25,6 +25,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; +import ai.djl.nn.Parameter; import ai.djl.nn.SequentialBlock; import ai.djl.nn.SymbolBlock; import ai.djl.nn.core.Linear; @@ -85,7 +86,7 @@ public void trainWithNewParam() throws IOException, ModelNotFoundException, MalformedModelException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = MxModelZoo.MLP.loadModel()) { model.getBlock().clear(); try (Trainer trainer = model.newTrainer(config)) { @@ -113,7 +114,7 @@ public void trainWithExistParam() throws IOException, ModelNotFoundException, MalformedModelException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = 
diff --git a/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java b/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java
index d855bb90ced..74fde4bf89a 100644
--- a/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java
+++ b/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java
@@ -25,6 +25,7 @@
 import ai.djl.ndarray.NDManager;
 import ai.djl.ndarray.types.Shape;
 import ai.djl.nn.Block;
+import ai.djl.nn.Parameter;
 import ai.djl.nn.SequentialBlock;
 import ai.djl.nn.SymbolBlock;
 import ai.djl.nn.core.Linear;
@@ -85,7 +86,7 @@ public void trainWithNewParam()
             throws IOException, ModelNotFoundException, MalformedModelException {
         TrainingConfig config =
                 new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
-                        .optInitializer(Initializer.ONES);
+                        .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
         try (Model model = MxModelZoo.MLP.loadModel()) {
             model.getBlock().clear();
             try (Trainer trainer = model.newTrainer(config)) {
@@ -113,7 +114,7 @@ public void trainWithExistParam()
             throws IOException, ModelNotFoundException, MalformedModelException {
         TrainingConfig config =
                 new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
-                        .optInitializer(Initializer.ONES);
+                        .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
         try (Model model = MxModelZoo.MLP.loadModel()) {
             try (Trainer trainer = model.newTrainer(config)) {
                 NDManager manager = trainer.getManager();
@@ -140,7 +141,7 @@ public void trainWithCustomLayer()
             throws IOException, ModelNotFoundException, MalformedModelException {
         TrainingConfig config =
                 new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
-                        .optInitializer(Initializer.ONES);
+                        .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
         try (Model model = MxModelZoo.MLP.loadModel()) {
             NDManager manager = model.getNDManager();
             SymbolBlock mlp = (SymbolBlock) model.getBlock();
@@ -149,7 +150,6 @@ public void trainWithCustomLayer()
             newMlp.add(mlp);

             Linear linear = Linear.builder().setUnits(10).build();
-            linear.setInitializer(Initializer.ONES);
             newMlp.add(linear);

             model.setBlock(newMlp);
diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java
index 522a928b32b..43b3a960f4c 100644
--- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java
+++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java
@@ -98,7 +98,7 @@ private NDList getOutputs(PpNDArray[] outputs, boolean foreignEngine, NDManager
     /** {@inheritDoc} */
     @Override
-    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+    public Shape[] getOutputShapes(Shape[] inputShapes) {
         return new Shape[0];
     }
 }
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java
index 51ec90497ac..3d56947e34d 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java
@@ -17,10 +17,13 @@
 import ai.djl.MalformedModelException;
 import ai.djl.Model;
 import ai.djl.ndarray.types.DataType;
+import ai.djl.nn.Parameter;
 import ai.djl.pytorch.jni.JniUtils;
 import ai.djl.training.Trainer;
 import ai.djl.training.TrainingConfig;
 import ai.djl.training.initializer.Initializer;
+import ai.djl.util.Pair;
+import ai.djl.util.PairList;
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.nio.file.Files;
@@ -28,6 +31,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Predicate;
 import java.util.stream.Collectors;

 /**
@@ -125,12 +129,16 @@ private Path findModelFile(String prefix) {
     /** {@inheritDoc} */
     @Override
     public Trainer newTrainer(TrainingConfig trainingConfig) {
-        Initializer initializer = trainingConfig.getInitializer();
+        PairList<Initializer, Predicate<Parameter>> initializer = trainingConfig.getInitializers();
         if (block == null) {
             throw new IllegalStateException(
                     "You must set a block for the model before creating a new trainer");
         }
-        block.setInitializer(initializer);
+        for (Pair<Initializer, Predicate<Parameter>> pair : initializer) {
+            if (pair.getKey() != null && pair.getValue() != null) {
+                block.setInitializer(pair.getKey(), pair.getValue());
+            }
+        }
         return new Trainer(this, trainingConfig);
     }
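The `PtModel.newTrainer` loop above is identical to the MxNet one; unrolled for a single entry it reads as follows (sketch only, with `pair` taken from `trainingConfig.getInitializers()`):

```java
Initializer init = pair.getKey();                // e.g. Initializer.ONES
Predicate<Parameter> selector = pair.getValue(); // e.g. matches Parameter.Type.WEIGHT
if (init != null && selector != null) {
    // registers init only for the parameters the predicate accepts
    block.setInitializer(init, selector);
}
```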
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
index 7cf0e5e3ac8..964649cd5a1 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
@@ -155,7 +155,7 @@ public PairList<String, Shape> describeOutput() {

     /** {@inheritDoc} */
     @Override
-    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+    public Shape[] getOutputShapes(Shape[] inputShapes) {
         return new Shape[0];
     }

diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java
index 43f667d2dc2..7a6f8bdb753 100644
--- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java
+++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java
@@ -131,8 +131,8 @@ protected NDList forwardInternal(
     /** {@inheritDoc} */
     @Override
-    public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
-        return new Shape[0];
+    public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
+        throw new IllegalStateException("TfSymbolBlock can't be initialized");
     }

     /** {@inheritDoc} */
@@ -197,7 +197,7 @@ public final PairList<String, Shape> describeOutput() {
     /** {@inheritDoc} */
     @Override
-    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
+    public Shape[] getOutputShapes(Shape[] inputShapes) {
         return new Shape[0];
     }
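Taken together, these signature changes split shape propagation out of initialization. A minimal sketch of the resulting calling convention, assuming an existing `block`, `manager`, and `inputShape` (names are illustrative, not from this patch):

```java
// initialize(...) now returns void, so output shapes are queried explicitly,
// as the SSD initializeChildBlocks change at the top of this section does:
block.initialize(manager, DataType.FLOAT32, inputShape);
Shape[] outputShapes = block.getOutputShapes(new Shape[] {inputShape});
```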