Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove NDManager on getOutputShapes #710

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/modality/nlp/Decoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
return block.getOutputShapes(manager, inputShapes);
public Shape[] getOutputShapes(Shape[] inputShapes) {
return block.getOutputShapes(inputShapes);
}

/** {@inheritDoc} */
Expand Down
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/modality/nlp/Encoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
return block.getOutputShapes(manager, inputShapes);
public Shape[] getOutputShapes(Shape[] inputShapes) {
return block.getOutputShapes(inputShapes);
}

/** {@inheritDoc} */
Expand Down
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ public void initialize(NDManager manager, DataType dataType, Shape... inputShape

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
return decoder.getOutputShapes(manager, new Shape[] {inputShapes[1]});
public Shape[] getOutputShapes(Shape[] inputShapes) {
return decoder.getOutputShapes(new Shape[] {inputShapes[1]});
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
return trainableWordEmbedding.getOutputShapes(manager, inputShapes);
public Shape[] getOutputShapes(Shape[] inputShapes) {
return trainableWordEmbedding.getOutputShapes(inputShapes);
}
}
6 changes: 3 additions & 3 deletions api/src/main/java/ai/djl/nn/AbstractBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
* <li>Use {@link AbstractBlock#addParameter(Parameter)} to add parameters to your block in the
* constructor if necessary.
* <li>Use {@link AbstractBlock#addChildBlock(String, Block)} to add child blocks if necessary.
* <li>Override {@link AbstractBlock#getOutputShapes(NDManager, Shape[])} to determine the shape
* of your custom block's output based on the input it will receive.
* <li>Override {@link Block#getOutputShapes(Shape[])} to determine the shape of your custom
* block's output based on the input it will receive.
* <li>Override {@link AbstractBlock#initializeChildBlocks(NDManager, DataType, Shape...)} if you
* added child blocks to initialize them based on the input shape your block will receive. You
* can skip this if your block does not contain child blocks
Expand Down Expand Up @@ -451,7 +451,7 @@ public String toString() {
appendShape(sb, inputShapeDescription.values().toArray(new Shape[0]));
sb.append(" -> ");
Shape[] outputShapes =
getOutputShapes(null, inputShapeDescription.values().toArray(new Shape[0]));
getOutputShapes(inputShapeDescription.values().toArray(new Shape[0]));
appendShape(sb, outputShapes);
} else {
sb.append("Uninitialized");
Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.nn;

import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

/** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */
Expand All @@ -29,7 +28,7 @@ public AbstractSymbolBlock(byte version) {

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("not implement!");
}
}
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,10 @@ default NDList forward(
/**
* Returns the expected output shapes of the block for the specified input shapes.
*
* @param manager an NDManager
* @param inputShapes the shapes of the inputs
* @return the expected output shapes of the block
*/
Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes);
Shape[] getOutputShapes(Shape[] inputShapes);

/**
* Writes the parameters of the block to the given outputStream.
Expand Down
6 changes: 3 additions & 3 deletions api/src/main/java/ai/djl/nn/LambdaBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ protected NDList forwardInternal(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
try (NDManager subManager = manager.newSubManager()) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
try (NDManager manager = NDManager.newBaseManager()) {
NDList input = new NDList(inputShapes.length);
for (Shape shape : inputShapes) {
input.add(subManager.zeros(shape));
input.add(manager.zeros(shape));
}
NDList output = lambda.apply(input);
Shape[] outputShapes = new Shape[output.size()];
Expand Down
8 changes: 4 additions & 4 deletions api/src/main/java/ai/djl/nn/ParallelBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,16 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
Preconditions.checkArgument(!children.isEmpty(), "The parallel block is empty");

try (NDManager subManager = manager.newSubManager()) {
try (NDManager manager = NDManager.newBaseManager()) {
List<NDList> inputs = new ArrayList<>();
for (Block block : children.values()) {
Shape[] shapes = block.getOutputShapes(manager, inputShapes);
Shape[] shapes = block.getOutputShapes(inputShapes);
NDList output = new NDList(shapes.length);
for (Shape shape : shapes) {
output.add(subManager.create(shape));
output.add(manager.create(shape));
}
inputs.add(output);
}
Expand Down
6 changes: 3 additions & 3 deletions api/src/main/java/ai/djl/nn/SequentialBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -155,19 +155,19 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
Shape[] shapes = inputShapes;
for (Block child : getChildren().values()) {
child.initialize(manager, dataType, shapes);
shapes = child.getOutputShapes(manager, shapes);
shapes = child.getOutputShapes(shapes);
}
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
public Shape[] getOutputShapes(Shape[] inputs) {
if (children.isEmpty()) {
throw new IllegalArgumentException("The sequential block is empty");
}
Shape[] current = inputs;
for (Block block : children.values()) {
current = block.getOutputShapes(manager, current);
current = block.getOutputShapes(current);
}
return current;
}
Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/convolutional/Convolution.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
Expand Down Expand Up @@ -169,7 +168,7 @@ protected void prepare(Shape[] inputs) {

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
public Shape[] getOutputShapes(Shape[] inputs) {
long[] shape = new long[numDimensions()];
shape[0] = inputs[0].get(0);
shape[1] = filters;
Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
Expand Down Expand Up @@ -146,7 +145,7 @@ protected void prepare(Shape[] inputs) {

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
public Shape[] getOutputShapes(Shape[] inputs) {
long[] shape = new long[numDimensions()];
shape[0] = inputs[0].get(0);
shape[1] = filters;
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ protected NDList forwardInternal(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {inputShapes[0].addAll(embedding.getShape())};
}

Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/nn/core/Embedding.java
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public void prepare(Shape[] inputShapes) {

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))};
}

Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/core/Linear.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
Expand Down Expand Up @@ -90,7 +89,7 @@ protected NDList forwardInternal(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
public Shape[] getOutputShapes(Shape[] inputs) {
return new Shape[] {inputs[0].slice(0, inputs[0].dimension() - 1).add(units)};
}

Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/core/Prelu.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
Expand Down Expand Up @@ -66,7 +65,7 @@ protected NDList forwardInternal(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
public Shape[] getOutputShapes(Shape[] inputs) {
return new Shape[] {inputs[0]};
}

Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/norm/BatchNorm.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
Expand Down Expand Up @@ -145,7 +144,7 @@ protected NDList forwardInternal(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {inputShapes[0]};
}

Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/norm/Dropout.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
Expand Down Expand Up @@ -76,7 +75,7 @@ protected NDList forwardInternal(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {inputShapes[0]};
}

Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
package ai.djl.nn.recurrent;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
Expand Down Expand Up @@ -91,7 +90,7 @@ public RecurrentBlock(BaseBuilder<?> builder) {

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
public Shape[] getOutputShapes(Shape[] inputs) {
Shape inputShape = inputs[0];
Shape outputShape =
new Shape(inputShape.get(0), inputShape.get(1), stateSize * getNumDirections());
Expand Down
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/nn/transformer/BertBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ public int getTypeDictionarySize() {

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
long batch = inputShapes[0].get(0);
long seqLength = inputShapes[0].get(1);
return new Shape[] {
Expand All @@ -172,7 +172,7 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
Shape[] tokenShape = {inputShapes[0]};
Shape[] typeShape = {inputShapes[1]};
this.tokenEmbedding.initialize(manager, dataType, tokenShape);
Shape[] embeddingOutput = this.tokenEmbedding.getOutputShapes(manager, tokenShape);
Shape[] embeddingOutput = this.tokenEmbedding.getOutputShapes(tokenShape);
this.typeEmbedding.initialize(manager, dataType, typeShape);
this.embeddingNorm.initialize(manager, dataType, embeddingOutput);
this.embeddingDropout.initialize(manager, dataType, embeddingOutput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ protected NDList forwardInternal(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(final NDManager manager, final Shape[] inputShapes) {
public Shape[] getOutputShapes(final Shape[] inputShapes) {
int batchSize = (int) inputShapes[0].get(0);
int indexCount = (int) inputShapes[1].get(1);
int dictionarySize = (int) inputShapes[2].get(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ protected NDList forwardInternal(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {new Shape(inputShapes[0].get(0), 2)};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public BertPretrainingBlock(final BertBlock.Builder builder) {
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
inputNames = Arrays.asList("tokenIds", "typeIds", "sequenceMasks", "maskedIndices");
bertBlock.initialize(manager, dataType, inputShapes);
Shape[] bertOutputShapes = bertBlock.getOutputShapes(manager, inputShapes);
Shape[] bertOutputShapes = bertBlock.getOutputShapes(inputShapes);
Shape embeddedSequence = bertOutputShapes[0];
Shape pooledOutput = bertOutputShapes[1];
Shape maskedIndices = inputShapes[2];
Expand Down Expand Up @@ -98,7 +98,7 @@ protected NDList forwardInternal(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
long batchSize = inputShapes[0].get(0);
long maskedIndexCount = inputShapes[3].get(1);
return new Shape[] {
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ private IdEmbedding(Builder builder) {

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ public PointwiseFeedForwardBlock(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
for (Block child : children.values()) {
inputShapes = child.getOutputShapes(manager, inputShapes);
inputShapes = child.getOutputShapes(inputShapes);
}
return inputShapes;
}
Expand All @@ -84,7 +84,7 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
Shape lastShape = inputShape;
for (Block child : children.values()) {
child.initialize(manager, dataType, lastShape);
lastShape = getOutputShapes(manager, new Shape[] {lastShape})[0];
lastShape = getOutputShapes(new Shape[] {lastShape})[0];
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public Linear getResultProjection() {

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
// Return shape is the shape of the query. For 2 or less inputs we have self-attention, i.e.
// the shape of the output is the shape of the input
if (inputShapes.length == 1 || inputShapes.length == 2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public TransformerEncoderBlock(

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
return inputShapes;
}

Expand Down
2 changes: 1 addition & 1 deletion api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ private ShapeUtils() {}
* @return the corresponding output shape for the provided input
*/
public static Shape outputShapeForBlock(NDManager manager, Block block, Shape inputShape) {
Shape[] outputs = block.getOutputShapes(manager, new Shape[] {inputShape});
Shape[] outputs = block.getOutputShapes(new Shape[] {inputShape});
return outputs[0];
}

Expand Down
Loading