From 9ba8a1b71461ea029c709c64046415b2e29e7d09 Mon Sep 17 00:00:00 2001 From: gstu1130 Date: Tue, 16 Feb 2021 23:18:20 -0800 Subject: [PATCH] Use builder pattern for Parameter --- api/src/main/java/ai/djl/nn/Parameter.java | 211 ++++++++++++------ .../main/java/ai/djl/nn/ParameterType.java | 41 ---- .../ai/djl/nn/convolutional/Convolution.java | 14 +- .../djl/nn/convolutional/Deconvolution.java | 14 +- .../main/java/ai/djl/nn/core/Embedding.java | 29 +-- api/src/main/java/ai/djl/nn/core/Linear.java | 16 +- api/src/main/java/ai/djl/nn/core/Prelu.java | 10 +- .../main/java/ai/djl/nn/norm/BatchNorm.java | 29 ++- .../ai/djl/nn/recurrent/RecurrentBlock.java | 12 +- .../java/ai/djl/nn/transformer/BertBlock.java | 7 +- .../BertMaskedLanguageModelBlock.java | 7 +- .../ai/djl/nn/transformer/IdEmbedding.java | 7 +- .../ai/djl/mxnet/engine/MxSymbolBlock.java | 25 ++- 13 files changed, 268 insertions(+), 154 deletions(-) delete mode 100644 api/src/main/java/ai/djl/nn/ParameterType.java diff --git a/api/src/main/java/ai/djl/nn/Parameter.java b/api/src/main/java/ai/djl/nn/Parameter.java index 82af7f8a8e1..ba985d85979 100644 --- a/api/src/main/java/ai/djl/nn/Parameter.java +++ b/api/src/main/java/ai/djl/nn/Parameter.java @@ -40,61 +40,22 @@ public class Parameter implements AutoCloseable { private String id; private String name; private Block block; - private ParameterType type; - private DataType mandatoryDataType; + private Type type; private Initializer initializer; private NDArray array; private boolean requiresGrad; private SparseFormat gradientFormat; - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - */ - public Parameter(String name, Block block, ParameterType type) { - this(name, block, type, true, SparseFormat.DENSE); - } - - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - * @param requiresGrad whether this {@code Parameter} needs to compute gradients - */ - public Parameter(String name, Block block, ParameterType type, boolean requiresGrad) { - this(name, block, type, requiresGrad, SparseFormat.DENSE); - } - - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - * @param requireGrad whether this {@code Parameter} needs to compute gradients - * @param gradientFormat the {@link SparseFormat} of the gradient array - */ - public Parameter( - String name, - Block block, - ParameterType type, - boolean requireGrad, - SparseFormat gradientFormat) { + Parameter(Builder builder) { this.id = UUID.randomUUID().toString(); - this.name = name; - this.block = block; - this.type = type; - this.requiresGrad = requireGrad; - this.initializer = type.getInitializer(); - this.gradientFormat = gradientFormat; + this.name = builder.name; + this.block = builder.block; + this.type = builder.type; + this.array = builder.array; + this.requiresGrad = builder.requiresGrad; + this.initializer = + (builder.initializer != null) ? builder.initializer : type.getInitializer(); + this.gradientFormat = builder.gradientFormat; } /** @@ -120,7 +81,7 @@ public String getName() { * * @return the type of this {@code Parameter} */ - public ParameterType getType() { + public Type getType() { return type; } @@ -155,15 +116,6 @@ public boolean requireGradient() { return requiresGrad; } - /** - * Sets the mandatory data type for this {@code Parameter}. - * - * @param mandatoryDataType the mandatory data type for this {@code Parameter} - */ - public void setMandatoryDataType(DataType mandatoryDataType) { - this.mandatoryDataType = mandatoryDataType; - } - /** * Checks if this {@code Parameter} is initialized. * @@ -198,11 +150,7 @@ public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes Objects.requireNonNull(initializer, "No initializer has been set"); if (!isInitialized()) { Shape shape = block.getParameterShape(name, inputShapes); - array = - initializer.initialize( - manager, - shape, - mandatoryDataType == null ? dataType : mandatoryDataType); + array = initializer.initialize(manager, shape, dataType); array.setName(name); } @@ -273,4 +221,139 @@ public void close() { array = null; } } + + /** + * Creates a builder to build a {@code Parameter}. + * + *

The methods start with {@code set} are required fields, and {@code opt} for optional + * fields. + * + * @return a new builder + */ + public static Parameter.Builder builder() { + return new Parameter.Builder(); + } + + /** Enumerates the types of {@link Parameter}. */ + public enum Type { + WEIGHT(null), + BIAS(Initializer.ZEROS), + GAMMA(Initializer.ONES), + BETA(Initializer.ZEROS), + RUNNING_MEAN(Initializer.ZEROS), + RUNNING_VAR(Initializer.ONES), + OTHER(null); + + private final transient Initializer initializer; + + Type(Initializer initializer) { + this.initializer = initializer; + } + + /** + * Gets the {@link Initializer} of this {@code ParameterType}. + * + * @return the {@link Initializer} of this {@code ParameterType} + */ + public Initializer getInitializer() { + return initializer; + } + } + + /** A Builder to construct a {@code Parameter}. */ + public static final class Builder { + String name; + Block block; + Type type; + Initializer initializer; + NDArray array; + boolean requiresGrad = true; + SparseFormat gradientFormat; + + /** + * Sets the name of the {@code Parameter}. + * + * @param name the name of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder setName(String name) { + this.name = name; + return this; + } + + /** + * Sets the block of the {@code Parameter}. + * + * @param block the block of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder setBlock(Block block) { + this.block = block; + return this; + } + + /** + * Sets the {@code Type} of the {@code Parameter}. + * + * @param type the {@code Type} of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder setType(Type type) { + this.type = type; + return this; + } + + /** + * Sets the Initializer of the {@code Parameter}. + * + * @param initializer the Initializer of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder optInitializer(Initializer initializer) { + this.initializer = initializer; + return this; + } + + /** + * Sets the array of the {@code Parameter}. + * + * @param array the array of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder optArray(NDArray array) { + this.array = array; + return this; + } + + /** + * Sets if the {@code Parameter} requires gradient. + * + * @param requiresGrad if the {@code Parameter} requires gradient + * @return this {@code Parameter} + */ + public Builder optRequiresGrad(boolean requiresGrad) { + this.requiresGrad = requiresGrad; + return this; + } + + /** + * Sets the {@link SparseFormat} of the {@code Parameter}. + * + * @param gradientFormat the {@link SparseFormat} of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder optGradientFormat(SparseFormat gradientFormat) { + this.gradientFormat = gradientFormat; + return this; + } + + /** + * Builds a {@code Parameter} instance. + * + * @return the {@code Parameter} instance + */ + public Parameter build() { + return new Parameter(this); + } + } } diff --git a/api/src/main/java/ai/djl/nn/ParameterType.java b/api/src/main/java/ai/djl/nn/ParameterType.java deleted file mode 100644 index dde3cb23ab8..00000000000 --- a/api/src/main/java/ai/djl/nn/ParameterType.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.nn; - -import ai.djl.training.initializer.Initializer; - -/** Enumerates the types of {@link Parameter}. */ -public enum ParameterType { - WEIGHT(null), - BIAS(Initializer.ZEROS), - GAMMA(Initializer.ONES), - BETA(Initializer.ZEROS), - RUNNING_MEAN(Initializer.ZEROS), - RUNNING_VAR(Initializer.ONES), - OTHER(null); - - private final transient Initializer initializer; - - ParameterType(Initializer initializer) { - this.initializer = initializer; - } - - /** - * Gets the {@link Initializer} of this {@code ParameterType}. - * - * @return the {@link Initializer} of this {@code ParameterType} - */ - public Initializer getInitializer() { - return initializer; - } -} diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java index 9fcbfc83b48..48bae1ccbe6 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java @@ -22,7 +22,6 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -99,13 +98,22 @@ public Convolution(ConvolutionBuilder builder) { weight = addParameter( - new Parameter("weight", this, ParameterType.WEIGHT), + Parameter.builder() + .setName("weight") + .setBlock(this) + .setType(Parameter.Type.WEIGHT) + .build(), (inputShapes) -> new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape)); if (includeBias) { bias = addParameter( - new Parameter("bias", this, ParameterType.BIAS), new Shape(filters)); + Parameter.builder() + .setName("bias") + .setBlock(this) + .setType(Parameter.Type.BIAS) + .build(), + new Shape(filters)); } } diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java index 890a4db9eff..30d348cebc9 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java @@ -22,7 +22,6 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -78,13 +77,22 @@ public Deconvolution(DeconvolutionBuilder builder) { weight = addParameter( - new Parameter("weight", this, ParameterType.WEIGHT), + Parameter.builder() + .setName("weight") + .setBlock(this) + .setType(Parameter.Type.WEIGHT) + .build(), (inputShapes) -> new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape)); if (includeBias) { bias = addParameter( - new Parameter("bias", this, ParameterType.BIAS), new Shape(filters)); + Parameter.builder() + .setName("bias") + .setBlock(this) + .setType(Parameter.Type.BIAS) + .build(), + new Shape(filters)); } } diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java index f723500dacf..c38682ffd4b 100644 --- a/api/src/main/java/ai/djl/nn/core/Embedding.java +++ b/api/src/main/java/ai/djl/nn/core/Embedding.java @@ -25,7 +25,6 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -60,12 +59,14 @@ protected Embedding(BaseBuilder baseBuilder) { dataType = baseBuilder.dataType; embedding = addParameter( - new Parameter( - "embedding", - this, - ParameterType.WEIGHT, - true, - sparseGrad ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE), + Parameter.builder() + .setName("embedding") + .setBlock(this) + .setType(Parameter.Type.WEIGHT) + .optRequiresGrad(true) + .optGradientFormat( + sparseGrad ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE) + .build(), (inputShapes) -> new Shape(numItems, embeddingSize)); if (baseBuilder.fallthrough != null && baseBuilder.defaultItem != null) { throw new IllegalArgumentException( @@ -103,12 +104,14 @@ public Embedding(NDArray embedding, boolean sparseGrad) { dataType = embedding.getDataType(); this.embedding = addParameter( - new Parameter( - "embedding", - this, - ParameterType.WEIGHT, - true, - sparseGrad ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE), + Parameter.builder() + .setName("embedding") + .setBlock(this) + .setType(Parameter.Type.WEIGHT) + .optRequiresGrad(true) + .optGradientFormat( + sparseGrad ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE) + .build(), (inputShapes) -> new Shape(numItems, embeddingSize)); this.embedding.setArray(embedding); numItems = Math.toIntExact(embedding.getShape().size(0)); diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java index 34a8c020a18..9041b624984 100644 --- a/api/src/main/java/ai/djl/nn/core/Linear.java +++ b/api/src/main/java/ai/djl/nn/core/Linear.java @@ -21,7 +21,6 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import ai.djl.util.Preconditions; @@ -63,10 +62,21 @@ public class Linear extends AbstractBlock { // a callback, even if we do not used the callback parameter weight = addParameter( - new Parameter("weight", this, ParameterType.WEIGHT), + Parameter.builder() + .setName("weight") + .setBlock(this) + .setType(Parameter.Type.WEIGHT) + .build(), inputShapes -> new Shape(units, inputFeatures)); if (builder.bias) { - bias = addParameter(new Parameter("bias", this, ParameterType.BIAS), new Shape(units)); + bias = + addParameter( + Parameter.builder() + .setName("bias") + .setBlock(this) + .setType(Parameter.Type.BIAS) + .build(), + new Shape(units)); } } diff --git a/api/src/main/java/ai/djl/nn/core/Prelu.java b/api/src/main/java/ai/djl/nn/core/Prelu.java index abb268f9446..de125d6d648 100644 --- a/api/src/main/java/ai/djl/nn/core/Prelu.java +++ b/api/src/main/java/ai/djl/nn/core/Prelu.java @@ -20,7 +20,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -44,7 +43,14 @@ public class Prelu extends AbstractBlock { /** Creates a Parametric ReLU Block. */ public Prelu() { super(VERSION); - alpha = addParameter(new Parameter("alpha", this, ParameterType.OTHER), new Shape()); + alpha = + addParameter( + Parameter.builder() + .setName("alpha") + .setBlock(this) + .setType(Parameter.Type.OTHER) + .build(), + new Shape()); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java index cbe1bffd932..f8324bb209a 100644 --- a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java @@ -21,7 +21,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -88,20 +87,40 @@ public class BatchNorm extends AbstractBlock { // make gamma trainable if scale gamma = addParameter( - new Parameter("gamma", this, ParameterType.GAMMA, scale), + Parameter.builder() + .setName("gamma") + .setBlock(this) + .setType(Parameter.Type.GAMMA) + .optRequiresGrad(scale) + .build(), (inputShapes) -> new Shape(inChannels)); // make beta trainable if center beta = addParameter( - new Parameter("beta", this, ParameterType.BETA, center), + Parameter.builder() + .setName("beta") + .setBlock(this) + .setType(Parameter.Type.BETA) + .optRequiresGrad(center) + .build(), (inputShapes) -> new Shape(inChannels)); runningMean = addParameter( - new Parameter("runningMean", this, ParameterType.RUNNING_MEAN, false), + Parameter.builder() + .setName("runningMean") + .setBlock(this) + .setType(Parameter.Type.RUNNING_MEAN) + .optRequiresGrad(false) + .build(), (inputShapes) -> new Shape(inChannels)); runningVar = addParameter( - new Parameter("runningVar", this, ParameterType.RUNNING_VAR, false), + Parameter.builder() + .setName("runningVar") + .setBlock(this) + .setType(Parameter.Type.RUNNING_VAR) + .optRequiresGrad(false) + .build(), (inputShapes) -> new Shape(inChannels)); } diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java index ab99442f180..b474c9708e0 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java +++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java @@ -19,7 +19,6 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import java.io.DataInputStream; import java.io.IOException; @@ -67,20 +66,25 @@ public RecurrentBlock(BaseBuilder builder) { bidirectional = builder.bidirectional; returnState = builder.returnState; - ParameterType[] parameterTypes = {ParameterType.WEIGHT, ParameterType.BIAS}; + Parameter.Type[] parameterTypes = {Parameter.Type.WEIGHT, Parameter.Type.BIAS}; String[] directions = {"l"}; if (builder.bidirectional) { directions = new String[] {"l", "r"}; } String[] gateStrings = {"i2h", "h2h"}; - for (ParameterType parameterType : parameterTypes) { + for (Parameter.Type parameterType : parameterTypes) { for (int i = 0; i < numLayers; i++) { for (String direction : directions) { for (String gateString : gateStrings) { String name = direction + '_' + i + '_' + gateString + '_' + parameterType.name(); - addParameter(new Parameter(name, this, parameterType)); + addParameter( + Parameter.builder() + .setName(name) + .setBlock(this) + .setType(parameterType) + .build()); } } } diff --git a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java index 91dc57fbc04..4deabda9964 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java @@ -22,7 +22,6 @@ import ai.djl.nn.Activation; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.nn.core.Linear; import ai.djl.nn.norm.BatchNorm; import ai.djl.nn.norm.Dropout; @@ -77,7 +76,11 @@ private BertBlock(Builder builder) { // embedding for the position this.positionEmebdding = addParameter( - new Parameter(PARAM_POSITION_EMBEDDING, this, ParameterType.WEIGHT), + Parameter.builder() + .setName(PARAM_POSITION_EMBEDDING) + .setBlock(this) + .setType(Parameter.Type.WEIGHT) + .build(), new Shape(builder.maxSequenceLength, builder.embeddingSize)); // embedding for the input types this.typeEmbedding = diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java index ed7d1d10bf3..7d7c97e5288 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java @@ -19,7 +19,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.nn.core.Linear; import ai.djl.nn.norm.BatchNorm; import ai.djl.training.ParameterStore; @@ -59,7 +58,11 @@ public BertMaskedLanguageModelBlock( this.sequenceNorm = addChildBlock("sequenceNorm", BatchNorm.builder().optAxis(1).build()); this.dictionaryBias = addParameter( - new Parameter("dictionaryBias", this, ParameterType.BIAS), + Parameter.builder() + .setName("dictionaryBias") + .setBlock(this) + .setType(Parameter.Type.BIAS) + .build(), new Shape(bertBlock.getTokenDictionarySize())); this.hiddenActivation = hiddenActivation; } diff --git a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java index 42dcb54c230..b01c359bdaa 100644 --- a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java +++ b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java @@ -21,7 +21,6 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.util.Arrays; @@ -47,7 +46,11 @@ private IdEmbedding(Builder builder) { this.embeddingSize = builder.embeddingSize; this.embedding = addParameter( - new Parameter(EMBEDDING_PARAM_NAME, this, ParameterType.WEIGHT), + Parameter.builder() + .setName(EMBEDDING_PARAM_NAME) + .setBlock(this) + .setType(Parameter.Type.WEIGHT) + .build(), new Shape(dictionarySize, embeddingSize)); } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java index 2cfd4267701..50254346689 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java @@ -21,7 +21,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -79,9 +78,15 @@ public MxSymbolBlock(NDManager manager, Symbol symbol) { Set auxNameSet = new HashSet<>(Arrays.asList(symbol.getAuxNames())); for (String name : allNames) { - ParameterType type = inferType(name); + Parameter.Type type = inferType(name); boolean requireGrad = !auxNameSet.contains(name); - mxNetParams.add(new Parameter(name, this, type, requireGrad)); + mxNetParams.add( + Parameter.builder() + .setName(name) + .setBlock(this) + .setType(type) + .optRequiresGrad(requireGrad) + .build()); } first = true; } @@ -283,18 +288,18 @@ public void loadParameters(NDManager manager, DataInputStream is) } } - private static ParameterType inferType(String name) { + private static Parameter.Type inferType(String name) { if (name.endsWith("bias")) { - return ParameterType.BIAS; + return Parameter.Type.BIAS; } else if (name.endsWith("gamma")) { - return ParameterType.GAMMA; + return Parameter.Type.GAMMA; } else if (name.endsWith("beta")) { - return ParameterType.BETA; + return Parameter.Type.BETA; } else if (name.endsWith("moving_mean") || name.endsWith("running_mean")) { - return ParameterType.RUNNING_MEAN; + return Parameter.Type.RUNNING_MEAN; } else if (name.endsWith("moving_var") || name.endsWith("running_var")) { - return ParameterType.RUNNING_VAR; + return Parameter.Type.RUNNING_VAR; } - return ParameterType.OTHER; + return Parameter.Type.OTHER; } }