Adds Sparsemax block #2028

Merged · 8 commits · Sep 21, 2022
17 changes: 14 additions & 3 deletions api/src/main/java/ai/djl/nn/Activation.java
@@ -58,17 +58,28 @@ public static NDList relu(NDList arrays) {
         return new NDList(arrays.singletonOrThrow().getNDArrayInternal().relu());
     }
 
+    /**
+     * Applies ReLU6 activation on the input {@link NDArray}.
+     *
+     * <p>ReLU6 is defined by: \( y = min(6,max(0, x)) \)
+     *
+     * @param array the input {@link NDArray}
+     * @return the {@link NDArray} after applying ReLU6 activation
+     */
+    public static NDArray relu6(NDArray array) {
+        return NDArrays.minimum(6, array.getNDArrayInternal().relu());
+    }
+
     /**
      * Applies ReLU6 activation on the input singleton {@link NDList}.
      *
-     * <p>ReLU is defined by: \( y = min(6,max(0, x)) \)
+     * <p>ReLU6 is defined by: \( y = min(6,max(0, x)) \)
      *
      * @param arrays the input singleton {@link NDList}
      * @return the singleton {@link NDList} after applying ReLU6 activation
      */
     public static NDList relu6(NDList arrays) {
-        return new NDList(
-                NDArrays.minimum(6, arrays.singletonOrThrow().getNDArrayInternal().relu()));
+        return new NDList(relu6(arrays.singletonOrThrow()));
     }
 
     /**
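For quick reference, a minimal sketch of the new overload in action (the Relu6Demo driver class and the printed values are illustrative, not part of this PR):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Activation;

public final class Relu6Demo {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray x = manager.create(new float[] {-2f, 0.5f, 3f, 8f});
            // negatives clamp to 0 and values above 6 clamp to 6
            System.out.println(Activation.relu6(x)); // [0, 0.5, 3, 6]
        }
    }
}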
107 changes: 107 additions & 0 deletions api/src/main/java/ai/djl/nn/core/SparseMax.java
@@ -0,0 +1,107 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.nn.core;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

import java.util.stream.IntStream;

/**
 * {@code SparseMax} contains a generic implementation of the sparsemax function; sparsemax is
 * defined in <a href="https://arxiv.org/pdf/1602.02068.pdf">Martins and Astudillo, 2016</a>. This
 * is a simplified implementation in which K is a hyperparameter (default 3): softmax is computed
 * only over the largest K elements, \( y_i = exp(x_i) m_i / \sum_j exp(x_j) m_j \) where \( m \)
 * is the top-K indicator mask, and every other output value is set to 0.
 */
public class SparseMax extends AbstractBlock {
private static final byte VERSION = 1;

private int axis;
private int topK;

/** Creates a sparseMax activation function for the last axis with the default topK of 3. */
public SparseMax() {
this(-1, 3);
}

/**
* Creates a sparseMax activation function along a given axis with the default topK of 3.
*
* @param axis the axis along which to apply sparseMax
*/
public SparseMax(int axis) {
this(axis, 3);
}

/**
* Creates a sparseMax activation function along a given axis with a given topK.
*
* @param axis the axis along which to apply sparseMax
* @param topK the hyperparameter K, the number of largest elements kept in the softmax
*/
public SparseMax(int axis, int topK) {
super(VERSION);
this.axis = axis;
this.topK = topK;
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
// the output shape is the same as the input shape
return inputShapes;
}

/** {@inheritDoc} */
@Override
protected NDList forwardInternal(
ParameterStore parameterStore,
NDList inputs,
boolean training,
PairList<String, Object> params) {
/*
 A simple implementation of sparseMax: softMax is computed only over the largest K elements
 (exp(x) * mask / sum(exp(x) * mask)) and every other position is set to 0
*/
NDArray input = inputs.singletonOrThrow();
if (axis != -1) {
input = input.swapAxes(axis, -1);
}

// level[..., i] holds the index in input of the i-th largest element along the last axis
NDArray level = input.argSort(-1, false).toType(DataType.INT64, false);
int lastDimSize = (int) input.size(input.getShape().dimension() - 1);

// maskTopK is 1 at the positions of the top-K elements of input and 0 elsewhere
NDArray maskTopK =
NDArrays.add(
IntStream.range(0, topK)
.mapToObj(j -> level.get("..., {}", j).oneHot(lastDimSize))
.toArray(NDArray[]::new));

NDArray expSum =
input.exp().mul(maskTopK).sum(new int[] {-1}, true).broadcast(input.getShape());
NDArray output = input.exp().mul(maskTopK).div(expSum);

if (axis != -1) {
output = output.swapAxes(axis, -1);
}
return new NDList(output);
}
}
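To make the forward pass concrete, here is a minimal sketch of running the block on a 1-D input (the SparseMaxDemo driver, the explicit initialize call, and the printed values are assumptions for a standalone run, not part of this PR):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.core.SparseMax;
import ai.djl.training.ParameterStore;

public final class SparseMaxDemo {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            SparseMax sparseMax = new SparseMax(); // last axis, topK = 3
            sparseMax.initialize(manager, DataType.FLOAT32, new Shape(4));
            NDArray x = manager.create(new float[] {0f, 1f, 2f, 3f});
            NDList out =
                    sparseMax.forward(new ParameterStore(manager, false), new NDList(x), false);
            // the smallest entry is masked out; the top 3 share the softmax weight:
            // exp([1, 2, 3]) / (e^1 + e^2 + e^3) ≈ [0.0900, 0.2447, 0.6652]
            System.out.println(out.singletonOrThrow()); // ≈ [0, 0.0900, 0.2447, 0.6652]
        }
    }
}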
@@ -19,6 +19,7 @@
 import ai.djl.ndarray.types.Shape;
 import ai.djl.nn.Activation;
 import ai.djl.nn.Parameter;
+import ai.djl.nn.core.SparseMax;
 import ai.djl.testing.Assertions;
 import ai.djl.training.DefaultTrainingConfig;
 import ai.djl.training.Trainer;
@@ -241,4 +242,28 @@ public void testPrelu() {
             }
         }
     }
+
+    @Test
+    public void testSparseMax() {
+        try (Model model = Model.newInstance("model")) {
+            model.setBlock(new SparseMax());
+
+            try (Trainer trainer = model.newTrainer(config)) {
+                trainer.initialize(new Shape(4));
+                NDManager manager = trainer.getManager();
+                NDArray data = manager.create(new float[] {0, 1, 2, 3});
+                double expSum = Math.exp(1) + Math.exp(2) + Math.exp(3);
+                NDArray expected =
+                        manager.create(
+                                new float[] {
+                                    0,
+                                    (float) (Math.exp(1) / expSum),
+                                    (float) (Math.exp(2) / expSum),
+                                    (float) (Math.exp(3) / expSum)
+                                });
+                NDArray result = trainer.forward(new NDList(data)).singletonOrThrow();
+                Assertions.assertAlmostEquals(result, expected);
+            }
+        }
+    }
 }
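The test above only exercises the default last-axis constructor. A short sketch of the axis parameter under the same assumptions (SparseMaxAxisDemo is a hypothetical driver; values are illustrative):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.core.SparseMax;
import ai.djl.training.ParameterStore;

public final class SparseMaxAxisDemo {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // sparseMax down axis 0: each of the 2 columns keeps its top 3 of 4 entries
            SparseMax sparseMax = new SparseMax(0);
            sparseMax.initialize(manager, DataType.FLOAT32, new Shape(4, 2));
            NDArray x = manager.create(new float[] {0, 4, 1, 5, 2, 6, 3, 7}, new Shape(4, 2));
            NDList out =
                    sparseMax.forward(new ParameterStore(manager, false), new NDList(x), false);
            // per column the smallest entry becomes 0 and the remaining three sum to 1
            System.out.println(out.singletonOrThrow());
        }
    }
}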