Added Multiplication block #2110

Merged (3 commits) on Nov 8, 2022
7 changes: 3 additions & 4 deletions api/src/main/java/ai/djl/nn/core/LinearCollection.java
@@ -35,8 +35,7 @@
  * {@link Linear} block, the involved shapes have typically one split dimension added which
  * separates the different linear transformations from each other. Another difference to a {@link
  * Linear} block is that the weight is not transposed (to align with the internally used algebraic
- * operation {@link NDArray#matMul(NDArray)} ). There is currently only a single batch dimension
- * \(x_1\) supported.
+ * operation {@link NDArray#matMul(NDArray)} ).
*
* <p>It has the following shapes:
*
@@ -117,15 +116,15 @@ protected void beforeInitialize(Shape... inputShapes) {
super.beforeInitialize(inputShapes);
Preconditions.checkArgument(inputShapes.length == 1, "Linear block only support 1 input");
Shape input = inputShapes[0];
-        inputFeatures = input.slice(1, input.dimension()).size();
+        inputFeatures = input.slice(1).size();
inputShape = input.slice(0, 1);
}

/** {@inheritDoc} */
@Override
public void prepare(Shape[] inputShapes) {
Shape input = inputShapes[0];
-        weight.setShape(input.slice(1, input.dimension()).add(units));
+        weight.setShape(input.slice(1).add(units));
if (bias != null) {
bias.setShape(input.slice(1, input.dimension() - 1).add(units));
}
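For context: Shape#slice(int) slices from the given index through the last dimension, so the one-argument call is equivalent to the old slice(1, input.dimension()) form; the LinearCollection change above is a pure simplification with no behavior change. A minimal sketch of the equivalence (the shape values here are illustrative only, not from the PR):

import ai.djl.ndarray.types.Shape;

// A batch of 4 samples, one split dimension of size 5, and 3 features.
Shape input = new Shape(4, 5, 3);

// Both calls drop the batch axis and keep everything from index 1 onward.
Shape explicitEnd = input.slice(1, input.dimension()); // (5, 3)
Shape shorthand = input.slice(1);                      // (5, 3)

assert explicitEnd.equals(shorthand);
assert shorthand.size() == 15; // 5 * 3, the inputFeatures value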
195 changes: 195 additions & 0 deletions api/src/main/java/ai/djl/nn/core/Multiplication.java
@@ -0,0 +1,195 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;

/**
 * A Multiplication block performs an element-wise multiplication of inputs and weights, as
 * opposed to a {@link Linear} block, which additionally sums the element-wise products.
*
 * <p>As with a {@link LinearCollection}, multiple split dimensions are supported, but they
 * remain optional (i.e. \(t\) can be zero). Other differences from a {@link Linear} block are
 * that the weight has an additional dimension of size 1 interspersed (to broadcast the weight
 * across every input of the batch when applying the internally used algebraic operation {@link
 * NDArray#mul(NDArray)}) and that biases are not supported.
*
 * <p>Caution: the output channel is the left-most dimension, as opposed to its traditional
 * right-most position. Because the output has one dimension more than the output of a {@link
 * Linear} block, it is more efficient (and therefore recommended) to apply an aggregating
 * function such as the sum first, and only then move the first axis of the smaller, aggregated
 * {@link NDArray} into the last position.
*
* <p>It has the following shapes:
*
* <ul>
* <li>input X: [x_1, s_1, s_2, …, s_t, input_dim]
* <li>weight W: [units, 1, s_1, s_2, …, s_t, input_dim]
* <li>output Y: [units, x_1, s_1, s_2, …, s_t, input_dim]
* </ul>
*
* <p>The Multiplication block should be constructed using {@link Multiplication.Builder}.
*/
public class Multiplication extends AbstractBlock {

private static final byte VERSION = 1;

private long units;
private long inputFeatures;

private Shape inputShape;

private Parameter weight;

Multiplication(Builder builder) {
super(VERSION);
units = builder.units;
weight =
addParameter(
Parameter.builder()
.setName("weight")
.setType(Parameter.Type.WEIGHT)
.build());
}

/** {@inheritDoc} */
@Override
protected NDList forwardInternal(
ParameterStore parameterStore,
NDList inputs,
boolean training,
PairList<String, Object> params) {
NDArray input = inputs.singletonOrThrow();
Device device = input.getDevice();
NDArray weightArr = parameterStore.getValue(weight, device, training);
return multiply(input, weightArr);
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(Shape[] inputs) {
return new Shape[] {new Shape(units).addAll(inputs[0])};
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeInput() {
return new PairList<>(
Collections.singletonList("linearInput"), Collections.singletonList(inputShape));
}

/** {@inheritDoc} */
@Override
protected void beforeInitialize(Shape... inputShapes) {
super.beforeInitialize(inputShapes);
Preconditions.checkArgument(
        inputShapes.length == 1, "Multiplication block only supports 1 input");
Shape input = inputShapes[0];
inputFeatures = input.slice(1).size();
inputShape = input.slice(0, 1);
}

/** {@inheritDoc} */
@Override
public void prepare(Shape[] inputShapes) {
Shape input = inputShapes[0];
weight.setShape(new Shape(units, 1).addAll(input.slice(1)));
}

/** {@inheritDoc} */
@Override
protected void saveMetadata(DataOutputStream os) throws IOException {
os.writeLong(units);
os.writeLong(inputFeatures);
os.write(inputShape.getEncoded());
}

/** {@inheritDoc} */
@Override
public void loadMetadata(byte loadVersion, DataInputStream is)
throws IOException, MalformedModelException {
if (loadVersion == VERSION) {
units = is.readLong();
inputFeatures = is.readLong();
} else {
throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
}
inputShape = Shape.decode(is);
}

/**
* Applies an element-wise multiplication to the incoming data.
*
* @param input The incoming data
* @param weight The weight of this block
* @return element-wise multiplication of input and weight using broadcasting rules
*/
public NDList multiply(NDArray input, NDArray weight) {
NDArray resultArr = input.mul(weight);
return new NDList(resultArr);
}

/**
* Creates a builder to build a {@code Multiplication}.
*
* @return a new builder
*/
public static Builder builder() {
return new Builder();
}

/** The Builder to construct a {@link Multiplication} type of {@link Block}. */
public static final class Builder {

private long units;

Builder() {}

/**
* Sets the number of output channels.
*
* @param units the number of desired output channels
* @return this Builder
*/
public Builder setUnits(long units) {
this.units = units;
return this;
}

/**
* Returns the constructed {@code Multiplication}.
*
* @return the constructed {@code Multiplication}
* @throws IllegalArgumentException if the required parameter (units) has not been
*     set
*/
public Multiplication build() {
Preconditions.checkArgument(units > 0, "You must specify units");
return new Multiplication(this);
}
}
}
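To make the shape contract concrete, here is a minimal usage sketch (not part of the PR; the manager setup, initializer choice, and shape values are illustrative assumptions). It builds a Multiplication block with 2 units, feeds a [4, 3] batch, and then aggregates the [2, 4, 3] output back to a Linear-style [batch, units] layout, as the class Javadoc recommends:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.core.Multiplication;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.NormalInitializer;

try (NDManager manager = NDManager.newBaseManager()) {
    Block block = Multiplication.builder().setUnits(2).build();
    block.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
    block.initialize(manager, DataType.FLOAT32, new Shape(4, 3));

    NDArray x = manager.randomUniform(0f, 1f, new Shape(4, 3)); // [x_1, input_dim]
    NDArray y =
            block.forward(new ParameterStore(manager, false), new NDList(x), false)
                    .singletonOrThrow(); // [units, x_1, input_dim] = [2, 4, 3]

    // Aggregate first, then move the unit axis into last position, as the Javadoc advises.
    NDArray linearLike = y.sum(new int[] {-1}).transpose(); // [x_1, units] = [4, 2]
}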
@@ -36,6 +36,7 @@
import ai.djl.nn.convolutional.Conv3d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.core.LinearCollection;
import ai.djl.nn.core.Multiplication;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.norm.GhostBatchNorm;
@@ -333,6 +334,72 @@ public void testLinearCollection() throws IOException, MalformedModelException {
}
}

@Test
public void testMultiplication() throws IOException, MalformedModelException {

// 4 samples times 3 features
float[][] dataArr = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}};

// 2 units times 3 features
float[][] weightArr = {{0, 1, 2}, {10, 11, 12}};

// store sum on Multiplication block's result
NDArray sum;

try (NDManager sharedManager = NDManager.newBaseManager()) {

// test algebraic expectation of Multiplication block
long outSize = 2;
Block block = Multiplication.builder().setUnits(outSize).build();
try (Model model = Model.newInstance("model")) {
model.setBlock(block);

TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
.optInitializer(
(m, s, t) -> m.create(weightArr).expandDims(1),
Parameter.Type.WEIGHT);
try (Trainer trainer = model.newTrainer(config)) {
Shape inputShape = new Shape(4, 3);
trainer.initialize(inputShape);

NDManager manager = trainer.getManager();
NDArray data = manager.create(dataArr);
NDArray result = trainer.forward(new NDList(data)).singletonOrThrow();
NDArray expected = data.mul(manager.create(weightArr).expandDims(1));
Assert.assertEquals(result, expected);

testEncode(manager, block);

sum = result.sum(new int[] {-1}).transpose();
sum.attach(sharedManager);
}
}

// test "sum" is equal to linear transformation without bias
block = Linear.builder().setUnits(outSize).optBias(false).build();
try (Model model = Model.newInstance("model")) {
model.setBlock(block);

TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
.optInitializer(
(m, s, t) -> m.create(weightArr), Parameter.Type.WEIGHT);
try (Trainer trainer = model.newTrainer(config)) {
Shape inputShape = new Shape(4, 3);
trainer.initialize(inputShape);

NDManager manager = trainer.getManager();
NDArray data = manager.create(dataArr);
NDArray result = trainer.forward(new NDList(data)).singletonOrThrow();
NDArray expected = data.dot(manager.create(weightArr).transpose());
Assert.assertEquals(result, expected);
Assert.assertEquals(result, sum);
}
}
}
}
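The cross-check against Linear in the second half of the test works because summing the broadcast products over the feature axis recovers a bias-free linear map: with input \(X\) of shape [4, 3] and weight \(W\) of shape [2, 3], the Multiplication output is \(R[j,i,k] = W[j,k] \, X[i,k]\), so

\( \sum_k R[j,i,k] = \sum_k X[i,k] \, W[j,k] = (X W^\top)[i,j], \)

which is exactly why result.sum(new int[] {-1}).transpose() above must equal data.dot(weightArr.transpose()).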

@SuppressWarnings("try")
@Test
public void testBatchNorm() throws IOException, MalformedModelException {