
sparseMax finished
warthecatalyst committed Sep 20, 2022
1 parent cb9f1f0 commit 7c44d33
Showing 3 changed files with 84 additions and 29 deletions.
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/nn/Activation.java
@@ -66,8 +66,8 @@ public static NDList relu(NDList arrays) {
     * @param array the input singleton {@link NDArray}
     * @return the {@link NDArray} after applying ReLU6 activation
     */
-    public static NDArray relu6(NDArray array){
-        return NDArrays.minimum(6,array.getNDArrayInternal().relu());
+    public static NDArray relu6(NDArray array) {
+        return NDArrays.minimum(6, array.getNDArrayInternal().relu());
    }

    /**
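For context, here is a minimal usage sketch of the reformatted relu6 helper. It is not part of the commit; the demo class name and input values are illustrative:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Activation;

public class Relu6Demo {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // relu6 computes min(max(x, 0), 6), clipping activations to the range [0, 6]
            NDArray x = manager.create(new float[] {-2f, 3f, 9f});
            System.out.println(Activation.relu6(x)); // expected values: [0, 3, 6]
        }
    }
}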
84 changes: 57 additions & 27 deletions api/src/main/java/ai/djl/nn/core/SparseMax.java
@@ -1,9 +1,21 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ *     http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */

package ai.djl.nn.core;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
@@ -13,63 +25,81 @@
import java.util.stream.IntStream;

/**
- * {@code SparseMax} contains a generic implementation of sparsemax function
- * the definition of SparseMax can be referred to https://arxiv.org/pdf/1602.02068.pdf.
- * {@code SparseMax} is a simpler implementation of sparseMax function,
- * where we set K as a hyperParameter(default 3). We only do softmax on those max-K data,
- * and we set all the other value as 0.
+ * {@code SparseMax} contains a generic implementation of the sparsemax function, as defined in
+ * https://arxiv.org/pdf/1602.02068.pdf. This implementation is a simplified variant of sparsemax
+ * that takes K as a hyperparameter (default 3): softmax is computed over the K largest values
+ * only, and all other values are set to 0.
 */
public class SparseMax extends AbstractBlock {
    private static final Byte VERSION = 1;

    private int axis;
    private int topK;
    private NDManager manager;

-    public SparseMax(){
-        this(-1,3);
+    /** creates a sparseMax activation function with default parameters. */
+    public SparseMax() {
+        this(-1, 3);
    }

-    public SparseMax(int axis){
-        this(axis,3);
+    /**
+     * creates a sparseMax activation function along a given axis.
+     *
+     * @param axis the axis to do sparseMax for
+     */
+    public SparseMax(int axis) {
+        this(axis, 3);
    }

-    public SparseMax(int axis,int K){
+    /**
+     * creates a sparseMax activation function along a given axis with hyperParameter K.
+     *
+     * @param axis the axis to do sparseMax for
+     * @param topK hyperParameter K
+     */
+    public SparseMax(int axis, int topK) {
        super(VERSION);
        this.axis = axis;
-        this.topK = K;
+        this.topK = topK;
    }

    /** {@inheritDoc} */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
-        //the shape of input and output are the same
+        // the shape of input and output are the same
        return new Shape[0];
    }

    /** {@inheritDoc} */
    @Override
-    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
+    protected NDList forwardInternal(
+            ParameterStore parameterStore,
+            NDList inputs,
+            boolean training,
+            PairList<String, Object> params) {
        /*
        A simple implementation of sparseMax, where we only calculate softMax over the largest K values
        */
        manager = inputs.getManager();
        NDArray input = inputs.singletonOrThrow();
-        if(axis!=-1){
-            input = input.swapAxes(axis,-1);
+        if (axis != -1) {
+            input = input.swapAxes(axis, -1);
        }

-        //level should be: the max i-th is index j in input
-        NDArray level = input.argSort(-1,false).toType(DataType.INT64,false);
-        int lastDimSize = (int)input.size(input.getShape().dimension()-1);
+        // level[..., i] is the index in input of the i-th largest value
+        NDArray level = input.argSort(-1, false).toType(DataType.INT64, false);
+        int lastDimSize = (int) input.size(input.getShape().dimension() - 1);

-        //maskTopK should be: the topK in input is 1 and other is zero
-        NDArray maskTopK = NDArrays.add(IntStream.range(0,topK).mapToObj(
-                j-> level.get("..., {}",j).oneHot(lastDimSize)
-        ).toArray(NDArray[]::new));
+        // maskTopK is 1 at the positions of the topK largest values of input and 0 elsewhere
+        NDArray maskTopK =
+                NDArrays.add(
+                        IntStream.range(0, topK)
+                                .mapToObj(j -> level.get("..., {}", j).oneHot(lastDimSize))
+                                .toArray(NDArray[]::new));

-        NDArray expSum = input.exp().mul(maskTopK).sum(new int[]{-1},true).broadcast(input.getShape());
+        NDArray expSum =
+                input.exp().mul(maskTopK).sum(new int[] {-1}, true).broadcast(input.getShape());
        NDArray output = input.exp().mul(maskTopK).div(expSum);

-        if(axis!=-1) {
+        if (axis != -1) {
            output = output.swapAxes(axis, -1);
        }
        return new NDList(output);
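Note that the sparsemax of the cited paper is the Euclidean projection of the logits onto the probability simplex, whereas this block computes a top-K-masked softmax. To make the forward pass concrete, here is a minimal sketch, not part of the commit, that reproduces the computation by hand for a 1-D input with the default K = 3; the class name and input values are illustrative:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public class SparseMaxByHand {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray input = manager.create(new float[] {0, 1, 2, 3});
            // top-3 mask: 1 at the three largest entries (indices 1, 2, 3), 0 elsewhere
            NDArray mask = manager.create(new float[] {0, 1, 1, 1});
            // masked softmax: exp(x) * mask / sum(exp(x) * mask)
            NDArray expSum =
                    input.exp().mul(mask).sum(new int[] {-1}, true).broadcast(input.getShape());
            NDArray output = input.exp().mul(mask).div(expSum);
            System.out.println(output); // approximately [0, 0.0900, 0.2447, 0.6652]
        }
    }
}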
@@ -19,6 +19,7 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Parameter;
+import ai.djl.nn.core.SparseMax;
import ai.djl.testing.Assertions;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
@@ -241,4 +242,28 @@ public void testPrelu() {
            }
        }
    }
+
+    @Test
+    public void testSparseMax() {
+        try (Model model = Model.newInstance("model")) {
+            model.setBlock(new SparseMax());
+
+            try (Trainer trainer = model.newTrainer(config)) {
+                trainer.initialize(new Shape(4));
+                NDManager manager = trainer.getManager();
+                NDArray data = manager.create(new float[] {0, 1, 2, 3});
+                double expSum = Math.exp(1) + Math.exp(2) + Math.exp(3);
+                NDArray expected =
+                        manager.create(
+                                new float[] {
+                                    0,
+                                    (float) (Math.exp(1) / expSum),
+                                    (float) (Math.exp(2) / expSum),
+                                    (float) (Math.exp(3) / expSum)
+                                });
+                NDArray result = trainer.forward(new NDList(data)).singletonOrThrow();
+                Assertions.assertAlmostEquals(result, expected);
+            }
+        }
+    }
}
