Make XavierInitializer default value & Improve setInitializer (#664)
stu1130 committed Feb 26, 2021
1 parent 95de5df commit 836aac6
Showing 37 changed files with 190 additions and 103 deletions.
24 changes: 16 additions & 8 deletions api/src/main/java/ai/djl/nn/AbstractBlock.java
@@ -31,6 +31,7 @@
import java.util.List;
import java.util.Locale;
import java.util.function.Function;
import java.util.function.Predicate;

/**
* {@code AbstractBlock} is an abstract implementation of {@link Block}.
@@ -285,13 +286,9 @@ public PairList<String, Shape> describeInput() {

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer) {
for (Parameter parameter : parameters.values()) {
parameter.setInitializer(initializer, false);
}
for (Block child : children.values()) {
child.setInitializer(initializer);
}
public void setInitializer(Initializer initializer, Parameter.Type params) {
Predicate<Parameter> predicate = parameter -> parameter.getType().equals(params);
setInitializer(initializer, predicate);
}

/** {@inheritDoc} */
@@ -301,7 +298,18 @@ public void setInitializer(Initializer initializer, String paramName) {
if (parameter == null) {
throw new IllegalArgumentException("Could not find parameter " + paramName);
}
parameter.setInitializer(initializer, true);
parameter.setInitializer(initializer);
}

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer, Predicate<Parameter> predicate) {
List<Parameter> params = getParameters().values();
for (Parameter param : params) {
if (predicate.test(param)) {
param.setInitializer(initializer);
}
}
}

/** {@inheritDoc} */
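
For reference, a minimal usage sketch of the reworked setInitializer overloads (not part of this commit; it assumes a Linear block and the stock constants from ai.djl.training.initializer):

import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.core.Linear;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.NormalInitializer;

public class SetInitializerSketch {
    public static void main(String[] args) {
        // Any Block works the same way; Linear is just an illustrative choice.
        Block block = Linear.builder().setUnits(10).build();

        // Overload 1: every parameter of the given type, collected recursively.
        block.setInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT);

        // Overload 2: one direct parameter, looked up by name.
        block.setInitializer(Initializer.ZEROS, "bias");

        // Overload 3: any parameter matched by an arbitrary predicate.
        block.setInitializer(Initializer.ONES, p -> p.getName().endsWith("weight"));
    }
}
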
14 changes: 12 additions & 2 deletions api/src/main/java/ai/djl/nn/Block.java
@@ -25,6 +25,7 @@
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.function.Predicate;

/**
* A {@code Block} is a composable function that forms a neural network.
@@ -158,11 +159,12 @@ default NDList forward(
}

/**
* Sets an {@link Initializer} to the block.
* Sets an {@link Initializer} to all the parameters that match the given {@link Parameter.Type} in the block.
*
* @param initializer the initializer to set
* @param type the {@link Parameter.Type} of the parameters to set the initializer for
*/
void setInitializer(Initializer initializer);
void setInitializer(Initializer initializer, Parameter.Type type);

/**
* Sets an {@link Initializer} to the specified direct parameter of the block, overriding the
@@ -173,6 +175,14 @@
*/
void setInitializer(Initializer initializer, String paramName);

/**
* Sets an {@link Initializer} to all the parameters that match the given {@link Predicate} in the block.
*
* @param initializer the initializer to be set
* @param predicate the predicate that selects the parameters to initialize
*/
void setInitializer(Initializer initializer, Predicate<Parameter> predicate);

/**
* Initializes the parameters of the block. This method must be called before calling `forward`.
*
14 changes: 6 additions & 8 deletions api/src/main/java/ai/djl/nn/Parameter.java
@@ -19,10 +19,10 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Objects;
import java.util.UUID;

/**
@@ -133,12 +133,9 @@ public boolean isInitialized() {
* flag is true, sets the initializer regardless.
*
* @param initializer the initializer to be set
* @param overwrite if true, set the initializer regardless of whether its already set or not
*/
public void setInitializer(Initializer initializer, boolean overwrite) {
if (overwrite || this.initializer == null) {
this.initializer = initializer;
}
public void setInitializer(Initializer initializer) {
this.initializer = initializer;
}

/**
@@ -150,7 +147,6 @@ public void setInitializer(Initializer initializer, boolean overwrite) {
* @param inputShapes the expected input shapes
*/
public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) {
Objects.requireNonNull(initializer, "No initializer has been set");
if (!isInitialized()) {
Shape shape = block.getParameterShape(name, inputShapes);
array = initializer.initialize(manager, shape, dataType);
@@ -239,7 +235,9 @@ public static Parameter.Builder builder() {

/** Enumerates the types of {@link Parameter}. */
public enum Type {
WEIGHT(null),
WEIGHT(
new XavierInitializer(
XavierInitializer.RandomType.GAUSSIAN, XavierInitializer.FactorType.IN, 2)),
BIAS(Initializer.ZEROS),
GAMMA(Initializer.ONES),
BETA(Initializer.ZEROS),
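
The practical effect of the new WEIGHT default is that a block can now be initialized without any explicit initializer. A sketch of that flow (assuming the usual Model/Trainer API; the model name, block, and shape are illustrative):

import ai.djl.Model;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.core.Linear;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.loss.Loss;

public class DefaultInitializerSketch {
    public static void main(String[] args) {
        try (Model model = Model.newInstance("mlp")) {
            model.setBlock(Linear.builder().setUnits(10).build());
            // No setInitializer/optInitializer call: WEIGHT falls back to the
            // XavierInitializer(GAUSSIAN, IN, 2) default above, BIAS falls back to ZEROS.
            try (Trainer trainer = model.newTrainer(new DefaultTrainingConfig(Loss.l2Loss()))) {
                trainer.initialize(new Shape(1, 20));
            }
        }
    }
}
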
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/nn/core/Prelu.java
@@ -48,7 +48,7 @@ public Prelu() {
Parameter.builder()
.setName("alpha")
.setBlock(this)
.setType(Parameter.Type.OTHER)
.setType(Parameter.Type.WEIGHT)
.build(),
new Shape());
}
51 changes: 38 additions & 13 deletions api/src/main/java/ai/djl/training/DefaultTrainingConfig.java
@@ -13,23 +13,23 @@
package ai.djl.training;

import ai.djl.Device;
import ai.djl.nn.Parameter;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.initializer.XavierInitializer.FactorType;
import ai.djl.training.initializer.XavierInitializer.RandomType;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;

/** {@code DefaultTrainingConfig} is an implementation of the {@link TrainingConfig} interface. */
public class DefaultTrainingConfig implements TrainingConfig {

private Initializer initializer;
private PairList<Initializer, Predicate<Parameter>> initializers = new PairList<>();
private Optimizer optimizer;
private Device[] devices;
private Loss loss;
@@ -38,15 +38,12 @@ public class DefaultTrainingConfig implements TrainingConfig {

/**
* Creates an instance of {@code DefaultTrainingConfig} with the given {@link Loss}. {@code
* DefaultTrainingConfig} creates a default {@link TrainingConfig} with the {@link
* XavierInitializer} as initialiser, {@link Adam} as optimiser, and the given {@link Loss}. The
* evaluators and listeners are left to the user's discretion.
* DefaultTrainingConfig} creates a default {@link TrainingConfig} with {@link Adam} as the optimizer
* and the given {@link Loss}. The evaluators and listeners are left to the user's discretion.
*
* @param loss the loss to use for training
*/
public DefaultTrainingConfig(Loss loss) {
// Defaults to initializer defined in https://arxiv.org/abs/1502.01852
this.initializer = new XavierInitializer(RandomType.GAUSSIAN, FactorType.IN, 2);
optimizer = Adam.builder().build();
this.loss = loss;
evaluators = new ArrayList<>();
@@ -58,10 +55,38 @@ public DefaultTrainingConfig(Loss loss) {
* href="https://arxiv.org/abs/1502.01852">paper</a>).
*
* @param initializer the initializer to use for the parameters
* @param type the {@link Parameter.Type} of the parameters
* @return this {@code DefaultTrainingConfig}
*/
public DefaultTrainingConfig optInitializer(Initializer initializer) {
this.initializer = initializer;
public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type) {
initializers.add(initializer, parameter -> parameter.getType().equals(type));
return this;
}

/**
* Sets the {@link Initializer} to use for the parameters (default from <a
* href="https://arxiv.org/abs/1502.01852">paper</a>).
*
* @param initializer the initializer to use for the parameters
* @param name the name of the parameter
* @return this {@code DefaultTrainingConfig}
*/
public DefaultTrainingConfig optInitializer(Initializer initializer, String name) {
initializers.add(initializer, parameter -> parameter.getName().equals(name));
return this;
}

/**
* Sets the {@link Initializer} to use for the parameters (default from <a
* href="https://arxiv.org/abs/1502.01852">paper</a>).
*
* @param initializer the initializer to use for the parameters
* @param predicate the predicate to identify the parameters to initialize
* @return this {@code DefaultTrainingConfig}
*/
public DefaultTrainingConfig optInitializer(
Initializer initializer, Predicate<Parameter> predicate) {
initializers.add(initializer, predicate);
return this;
}

Expand Down Expand Up @@ -120,8 +145,8 @@ public Device[] getDevices() {

/** {@inheritDoc} */
@Override
public Initializer getInitializer() {
return initializer;
public PairList<Initializer, Predicate<Parameter>> getInitializers() {
return initializers;
}

/** {@inheritDoc} */
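
A sketch of how the three optInitializer overloads can be stacked on one config (illustrative only; the "beta" name and the GAMMA predicate are example choices, not values mandated by this commit):

import ai.djl.nn.Parameter;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.loss.Loss;

public class TrainingConfigSketch {
    public static DefaultTrainingConfig buildConfig() {
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                // all WEIGHT parameters
                .optInitializer(new XavierInitializer(), Parameter.Type.WEIGHT)
                // a single parameter selected by name
                .optInitializer(Initializer.ZEROS, "beta")
                // any parameter matched by a predicate
                .optInitializer(Initializer.ONES, p -> p.getType() == Parameter.Type.GAMMA);
    }
}
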
7 changes: 5 additions & 2 deletions api/src/main/java/ai/djl/training/TrainingConfig.java
@@ -13,12 +13,15 @@
package ai.djl.training;

import ai.djl.Device;
import ai.djl.nn.Parameter;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.PairList;
import java.util.List;
import java.util.function.Predicate;

/**
* An interface that is responsible for holding the configuration required by {@link Trainer}.
@@ -64,11 +67,11 @@ public interface TrainingConfig {
Device[] getDevices();

/**
* Gets the {@link Initializer} to initialize the parameters of the model.
* Gets a list of {@link Initializer} and {@link Predicate} pairs used to initialize the parameters of the model.
*
* @return a {@link PairList} of {@link Initializer} and {@link Predicate} pairs
*/
Initializer getInitializer();
PairList<Initializer, Predicate<Parameter>> getInitializers();

/**
* Gets the {@link Optimizer} to use during training.
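
How a consumer of this contract might apply the configured pairs is sketched below. This is an assumption about usage rather than the Trainer code from this commit, and the helper name InitializerApplier is made up for illustration:

import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import java.util.function.Predicate;

final class InitializerApplier {
    private InitializerApplier() {}

    /** Applies each configured initializer to the parameters selected by its predicate. */
    static void apply(TrainingConfig config, Block block) {
        if (config.getInitializers() == null) {
            return; // e.g. FtTrainingConfig below returns null
        }
        for (Pair<Initializer, Predicate<Parameter>> pair : config.getInitializers()) {
            block.setInitializer(pair.getKey(), pair.getValue());
        }
    }
}
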
api/src/main/java/ai/djl/training/initializer/XavierInitializer.java
@@ -80,7 +80,7 @@ public XavierInitializer(RandomType randomType, FactorType factorType, float mag

/** Creates a new instance of {@code XavierInitializer}. */
public XavierInitializer() {
this(RandomType.UNIFORM, FactorType.AVG, 3f);
this(RandomType.UNIFORM, FactorType.AVG, 6f);
}

/** {@inheritDoc} */
AirfoilRandomAccessTest.java
@@ -17,6 +17,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Blocks;
import ai.djl.nn.Parameter;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
@@ -48,7 +49,7 @@ public class AirfoilRandomAccessTest {
public void testAirfoilRemote() throws IOException, TranslateException {
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.optInitializer(Initializer.ONES);
.optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);

try (Model model = Model.newInstance("model")) {
model.setBlock(Blocks.identityBlock());
AmesRandomAccessTest.java
@@ -17,6 +17,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Blocks;
import ai.djl.nn.Parameter;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
@@ -48,7 +49,7 @@ public class AmesRandomAccessTest {
public void testAmesRandomAccessRemote() throws IOException, TranslateException {
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.optInitializer(Initializer.ONES);
.optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);

try (Model model = Model.newInstance("model")) {
model.setBlock(Blocks.identityBlock());
PikachuDetection dataset test
@@ -16,6 +16,7 @@
import ai.djl.basicdataset.cv.PikachuDetection;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Blocks;
import ai.djl.nn.Parameter;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
@@ -42,7 +43,7 @@ public void testPikachuRemote() throws IOException, TranslateException {
.build();
TrainingConfig config =
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.optInitializer(new NormalInitializer(0.01f));
.optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT);
try (Model model = Model.newInstance("model")) {
model.setBlock(Blocks.identityBlock());
try (Trainer trainer = model.newTrainer(config)) {
BERT pretraining example
@@ -20,6 +20,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.transformer.BertBlock;
import ai.djl.nn.transformer.BertPretrainingBlock;
import ai.djl.nn.transformer.BertPretrainingLoss;
@@ -135,7 +136,8 @@ private static Model createBertPretrainingModel(Dictionary dictionary) {
model.setBlock(
new BertPretrainingBlock(
BERT_BUILDER.setTokenDictionarySize(dictionary.tokens.size())));
model.getBlock().setInitializer(new TruncatedNormalInitializer(0.02f));
model.getBlock()
.setInitializer(new TruncatedNormalInitializer(0.02f), Parameter.Type.WEIGHT);
return model;
}

Training example (setupTrainingConfig)
@@ -31,7 +31,6 @@
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
@@ -119,7 +118,6 @@ public static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {

return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
.optInitializer(new XavierInitializer())
.optDevices(Device.getDevices(arguments.getMaxGpus()))
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
.addTrainingListeners(listener);
FtTrainingConfig.java (fastText extension)
@@ -13,15 +13,18 @@
package ai.djl.fasttext;

import ai.djl.Device;
import ai.djl.nn.Parameter;
import ai.djl.training.TrainingConfig;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.PairList;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

/** An interface that is responsible for holding the configuration required by fastText training. */
public class FtTrainingConfig implements TrainingConfig {
@@ -247,7 +250,7 @@ public Device[] getDevices() {

/** {@inheritDoc} */
@Override
public Initializer getInitializer() {
public PairList<Initializer, Predicate<Parameter>> getInitializers() {
return null;
}

