Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use builder pattern for Parameter #659

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 147 additions & 64 deletions api/src/main/java/ai/djl/nn/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,61 +40,22 @@ public class Parameter implements AutoCloseable {
private String id;
private String name;
private Block block;
private ParameterType type;
private DataType mandatoryDataType;
private Type type;
private Initializer initializer;
private NDArray array;
private boolean requiresGrad;
private SparseFormat gradientFormat;

/**
* Creates a {@code Parameter} with the given name, and parameter type, and associated with the
* given {@link Block}.
*
* @param name the name of the {@code Parameter}
* @param block the block with which this {@code Parameter} is associated
* @param type the type of this {@code Parameter}
*/
public Parameter(String name, Block block, ParameterType type) {
this(name, block, type, true, SparseFormat.DENSE);
}

/**
* Creates a {@code Parameter} with the given name, and parameter type, and associated with the
* given {@link Block}.
*
* @param name the name of the {@code Parameter}
* @param block the block with which this {@code Parameter} is associated
* @param type the type of this {@code Parameter}
* @param requiresGrad whether this {@code Parameter} needs to compute gradients
*/
public Parameter(String name, Block block, ParameterType type, boolean requiresGrad) {
this(name, block, type, requiresGrad, SparseFormat.DENSE);
}

/**
* Creates a {@code Parameter} with the given name, and parameter type, and associated with the
* given {@link Block}.
*
* @param name the name of the {@code Parameter}
* @param block the block with which this {@code Parameter} is associated
* @param type the type of this {@code Parameter}
* @param requireGrad whether this {@code Parameter} needs to compute gradients
* @param gradientFormat the {@link SparseFormat} of the gradient array
*/
public Parameter(
String name,
Block block,
ParameterType type,
boolean requireGrad,
SparseFormat gradientFormat) {
Parameter(Builder builder) {
this.id = UUID.randomUUID().toString();
this.name = name;
this.block = block;
this.type = type;
this.requiresGrad = requireGrad;
this.initializer = type.getInitializer();
this.gradientFormat = gradientFormat;
this.name = builder.name;
this.block = builder.block;
this.type = builder.type;
this.array = builder.array;
this.requiresGrad = builder.requiresGrad;
this.initializer =
(builder.initializer != null) ? builder.initializer : type.getInitializer();
this.gradientFormat = builder.gradientFormat;
}

/**
Expand All @@ -120,7 +81,7 @@ public String getName() {
*
* @return the type of this {@code Parameter}
*/
public ParameterType getType() {
public Type getType() {
return type;
}

Expand Down Expand Up @@ -155,15 +116,6 @@ public boolean requireGradient() {
return requiresGrad;
}

/**
* Sets the mandatory data type for this {@code Parameter}.
*
* @param mandatoryDataType the mandatory data type for this {@code Parameter}
*/
public void setMandatoryDataType(DataType mandatoryDataType) {
this.mandatoryDataType = mandatoryDataType;
}

/**
* Checks if this {@code Parameter} is initialized.
*
Expand Down Expand Up @@ -198,11 +150,7 @@ public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes
Objects.requireNonNull(initializer, "No initializer has been set");
if (!isInitialized()) {
Shape shape = block.getParameterShape(name, inputShapes);
array =
initializer.initialize(
manager,
shape,
mandatoryDataType == null ? dataType : mandatoryDataType);
array = initializer.initialize(manager, shape, dataType);
array.setName(name);
}

Expand Down Expand Up @@ -273,4 +221,139 @@ public void close() {
array = null;
}
}

/**
 * Returns a fresh {@link Builder} for assembling a {@code Parameter}.
 *
 * <p>Builder methods prefixed with {@code set} configure required fields, while methods
 * prefixed with {@code opt} configure optional ones.
 *
 * @return a new builder
 */
public static Parameter.Builder builder() {
    Parameter.Builder freshBuilder = new Parameter.Builder();
    return freshBuilder;
}

/**
 * The kinds of {@link Parameter}, each optionally paired with a default {@link Initializer}.
 */
public enum Type {
    /** Weight parameter; declares no default initializer. */
    WEIGHT(null),
    /** Bias parameter; defaults to {@link Initializer#ZEROS}. */
    BIAS(Initializer.ZEROS),
    /** Gamma parameter; defaults to {@link Initializer#ONES}. */
    GAMMA(Initializer.ONES),
    /** Beta parameter; defaults to {@link Initializer#ZEROS}. */
    BETA(Initializer.ZEROS),
    /** Running-mean parameter; defaults to {@link Initializer#ZEROS}. */
    RUNNING_MEAN(Initializer.ZEROS),
    /** Running-variance parameter; defaults to {@link Initializer#ONES}. */
    RUNNING_VAR(Initializer.ONES),
    /** Any other parameter; declares no default initializer. */
    OTHER(null);

    private final transient Initializer defaultInitializer;

    Type(Initializer defaultInitializer) {
        this.defaultInitializer = defaultInitializer;
    }

    /**
     * Gets the default {@link Initializer} associated with this {@code Type}.
     *
     * @return the default {@link Initializer}, or {@code null} when this {@code Type} declares
     *     none
     */
    public Initializer getInitializer() {
        return defaultInitializer;
    }
}

/**
 * A builder that assembles a {@code Parameter}.
 *
 * <p>Methods prefixed with {@code set} configure required fields and must be called before
 * {@link #build()}; methods prefixed with {@code opt} configure optional fields.
 */
public static final class Builder {
    String name;
    Block block;
    Type type;
    Initializer initializer;
    NDArray array;
    boolean requiresGrad = true;
    // Dense is the default gradient format, matching the previous constructor-based API.
    SparseFormat gradientFormat = SparseFormat.DENSE;

    /**
     * Sets the name of the {@code Parameter}.
     *
     * @param name the name of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder setName(String name) {
        this.name = name;
        return this;
    }

    /**
     * Sets the block of the {@code Parameter}.
     *
     * @param block the block of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder setBlock(Block block) {
        this.block = block;
        return this;
    }

    /**
     * Sets the {@code Type} of the {@code Parameter}.
     *
     * @param type the {@code Type} of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder setType(Type type) {
        this.type = type;
        return this;
    }

    /**
     * Sets the {@link Initializer} of the {@code Parameter}.
     *
     * @param initializer the {@link Initializer} of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder optInitializer(Initializer initializer) {
        this.initializer = initializer;
        return this;
    }

    /**
     * Sets the array of the {@code Parameter}.
     *
     * @param array the array of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder optArray(NDArray array) {
        this.array = array;
        return this;
    }

    /**
     * Sets whether the {@code Parameter} requires gradient. Defaults to {@code true}.
     *
     * @param requiresGrad whether the {@code Parameter} requires gradient
     * @return this {@code Builder}
     */
    public Builder optRequiresGrad(boolean requiresGrad) {
        this.requiresGrad = requiresGrad;
        return this;
    }

    /**
     * Sets the {@link SparseFormat} of the gradient of the {@code Parameter}. Defaults to
     * {@link SparseFormat#DENSE}.
     *
     * @param gradientFormat the {@link SparseFormat} of the gradient
     * @return this {@code Builder}
     */
    public Builder optGradientFormat(SparseFormat gradientFormat) {
        this.gradientFormat = gradientFormat;
        return this;
    }

    /**
     * Builds a {@code Parameter} instance.
     *
     * @return the {@code Parameter} instance
     * @throws NullPointerException if the name, block, or type has not been set
     */
    public Parameter build() {
        // Fail fast with a descriptive message instead of an unexplained NPE later on.
        Objects.requireNonNull(name, "Parameter name must be set");
        Objects.requireNonNull(block, "Parameter block must be set");
        Objects.requireNonNull(type, "Parameter type must be set");
        return new Parameter(this);
    }
}
}
41 changes: 0 additions & 41 deletions api/src/main/java/ai/djl/nn/ParameterType.java

This file was deleted.

14 changes: 11 additions & 3 deletions api/src/main/java/ai/djl/nn/convolutional/Convolution.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
Expand Down Expand Up @@ -99,13 +98,22 @@ public Convolution(ConvolutionBuilder<?> builder) {

weight =
addParameter(
new Parameter("weight", this, ParameterType.WEIGHT),
Parameter.builder()
.setName("weight")
.setBlock(this)
.setType(Parameter.Type.WEIGHT)
.build(),
(inputShapes) ->
new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape));
if (includeBias) {
bias =
addParameter(
new Parameter("bias", this, ParameterType.BIAS), new Shape(filters));
Parameter.builder()
.setName("bias")
.setBlock(this)
.setType(Parameter.Type.BIAS)
.build(),
new Shape(filters));
}
}

Expand Down
14 changes: 11 additions & 3 deletions api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
Expand Down Expand Up @@ -78,13 +77,22 @@ public Deconvolution(DeconvolutionBuilder<?> builder) {

weight =
addParameter(
new Parameter("weight", this, ParameterType.WEIGHT),
Parameter.builder()
.setName("weight")
.setBlock(this)
.setType(Parameter.Type.WEIGHT)
.build(),
(inputShapes) ->
new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape));
if (includeBias) {
bias =
addParameter(
new Parameter("bias", this, ParameterType.BIAS), new Shape(filters));
Parameter.builder()
.setName("bias")
.setBlock(this)
.setType(Parameter.Type.BIAS)
.build(),
new Shape(filters));
}
}

Expand Down
Loading