diff --git a/api/src/main/java/ai/djl/nn/Activation.java b/api/src/main/java/ai/djl/nn/Activation.java
index 2d9e0f44a9ec..5143eb5d8ddb 100644
--- a/api/src/main/java/ai/djl/nn/Activation.java
+++ b/api/src/main/java/ai/djl/nn/Activation.java
@@ -66,8 +66,8 @@ public static NDList relu(NDList arrays) {
      * @param array the input singleton {@link NDArray}
      * @return the {@link NDArray} after applying ReLU6 activation
      */
-    public static NDArray relu6(NDArray array){
-        return NDArrays.minimum(6,array.getNDArrayInternal().relu());
+    public static NDArray relu6(NDArray array) {
+        return NDArrays.minimum(6, array.getNDArrayInternal().relu());
     }
 
     /**
diff --git a/api/src/main/java/ai/djl/nn/core/SparseMax.java b/api/src/main/java/ai/djl/nn/core/SparseMax.java
index 784c0a2a4979..d9e5e45557fb 100644
--- a/api/src/main/java/ai/djl/nn/core/SparseMax.java
+++ b/api/src/main/java/ai/djl/nn/core/SparseMax.java
@@ -1,9 +1,21 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
 package ai.djl.nn.core;
 
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.NDArrays;
 import ai.djl.ndarray.NDList;
-import ai.djl.ndarray.NDManager;
 import ai.djl.ndarray.types.DataType;
 import ai.djl.ndarray.types.Shape;
 import ai.djl.nn.AbstractBlock;
@@ -13,63 +25,81 @@ import java.util.stream.IntStream;
 
 /**
- * {@code SparseMax} contains a generic implementation of sparsemax function
- * the definition of SparseMax can be referred to https://arxiv.org/pdf/1602.02068.pdf.
- * {@code SparseMax} is a simpler implementation of sparseMax function,
- * where we set K as a hyperParameter(default 3). We only do softmax on those max-K data,
- * and we set all the other value as 0.
+ * {@code SparseMax} contains a generic implementation of the sparsemax function, as defined in
+ * https://arxiv.org/pdf/1602.02068.pdf. This is a simplified implementation of sparseMax where K
+ * is a hyperParameter (default 3): softmax is applied over the largest K values, and all other
+ * values are set to 0.
  */
 public class SparseMax extends AbstractBlock {
     private static final Byte VERSION = 1;
 
     private int axis;
     private int topK;
-    private NDManager manager;
 
-    public SparseMax(){
-        this(-1,3);
+    /** Creates a sparseMax activation function with default parameters. */
+    public SparseMax() {
+        this(-1, 3);
     }
 
-    public SparseMax(int axis){
-        this(axis,3);
+    /**
+     * Creates a sparseMax activation function along a given axis.
+     *
+     * @param axis the axis along which to apply sparseMax
+     */
+    public SparseMax(int axis) {
+        this(axis, 3);
    }
 
-    public SparseMax(int axis,int K){
+    /**
+     * Creates a sparseMax activation function along a given axis with hyperParameter K.
+     *
+     * @param axis the axis along which to apply sparseMax
+     * @param topK the hyperParameter K
+     */
+    public SparseMax(int axis, int topK) {
         super(VERSION);
         this.axis = axis;
-        this.topK = K;
+        this.topK = topK;
     }
 
+    /** {@inheritDoc} */
     @Override
     public Shape[] getOutputShapes(Shape[] inputShapes) {
-        //the shape of input and output are the same
+        // the shape of input and output are the same
         return new Shape[0];
     }
 
+    /** {@inheritDoc} */
     @Override
-    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
+    protected NDList forwardInternal(
+            ParameterStore parameterStore,
+            NDList inputs,
+            boolean training,
+            PairList<String, Object> params) {
         /*
         A simple implementation of sparseMax, where we only calculate softMax with largest K data
         */
-        manager = inputs.getManager();
         NDArray input = inputs.singletonOrThrow();
-        if(axis!=-1){
-            input = input.swapAxes(axis,-1);
+        if (axis != -1) {
+            input = input.swapAxes(axis, -1);
         }
 
-        //level should be: the max i-th is index j in input
-        NDArray level = input.argSort(-1,false).toType(DataType.INT64,false);
-        int lastDimSize = (int)input.size(input.getShape().dimension()-1);
+        // level[..., i] is the index in input of the i-th largest element
+        NDArray level = input.argSort(-1, false).toType(DataType.INT64, false);
+        int lastDimSize = (int) input.size(input.getShape().dimension() - 1);
 
-        //maskTopK should be: the topK in input is 1 and other is zero
-        NDArray maskTopK = NDArrays.add(IntStream.range(0,topK).mapToObj(
-                j-> level.get("..., {}",j).oneHot(lastDimSize)
-        ).toArray(NDArray[]::new));
+        // maskTopK is 1 at the positions of the topK largest elements of input, 0 elsewhere
+        NDArray maskTopK =
+                NDArrays.add(
+                        IntStream.range(0, topK)
+                                .mapToObj(j -> level.get("..., {}", j).oneHot(lastDimSize))
+                                .toArray(NDArray[]::new));
 
-        NDArray expSum = input.exp().mul(maskTopK).sum(new int[]{-1},true).broadcast(input.getShape());
+        NDArray expSum =
+                input.exp().mul(maskTopK).sum(new int[] {-1}, true).broadcast(input.getShape());
         NDArray output = input.exp().mul(maskTopK).div(expSum);
 
-        if(axis!=-1) {
+        if (axis != -1) {
             output = output.swapAxes(axis, -1);
         }
 
         return new NDList(output);
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java
index 1b30d4b541e7..8ff3aa36fd3f 100644
--- a/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java
@@ -19,6 +19,7 @@
 import ai.djl.ndarray.types.Shape;
 import ai.djl.nn.Activation;
 import ai.djl.nn.Parameter;
+import ai.djl.nn.core.SparseMax;
 import ai.djl.testing.Assertions;
 import ai.djl.training.DefaultTrainingConfig;
 import ai.djl.training.Trainer;
@@ -241,4 +242,28 @@ public void testPrelu() {
             }
         }
     }
+
+    @Test
+    public void testSparseMax() {
+        try (Model model = Model.newInstance("model")) {
+            model.setBlock(new SparseMax());
+
+            try (Trainer trainer = model.newTrainer(config)) {
+                trainer.initialize(new Shape(4));
+                NDManager manager = trainer.getManager();
+                NDArray data = manager.create(new float[] {0, 1, 2, 3});
+                double expSum = Math.exp(1) + Math.exp(2) + Math.exp(3);
+                NDArray expected =
+                        manager.create(
+                                new float[] {
+                                    0,
+                                    (float) (Math.exp(1) / expSum),
+                                    (float) (Math.exp(2) / expSum),
+                                    (float) (Math.exp(3) / expSum)
+                                });
+                NDArray result = trainer.forward(new NDList(data)).singletonOrThrow();
+                Assertions.assertAlmostEquals(result, expected);
+            }
+        }
+    }
 }
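
Reviewer note, not part of the patch: below is a minimal, self-contained sketch of driving the new SparseMax block directly through Block.forward, without the Model/Trainer harness used in testSparseMax. It assumes the default engine is on the classpath; the class name SparseMaxExample and the axis/topK arguments are illustrative only.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.core.SparseMax;
import ai.djl.training.ParameterStore;

public final class SparseMaxExample {

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // sparseMax over the last axis, keeping only the 2 largest entries
            SparseMax sparseMax = new SparseMax(-1, 2);
            // the block has no parameters; initializing it up front is a defensive step
            sparseMax.initialize(manager, DataType.FLOAT32, new Shape(4));

            NDArray input = manager.create(new float[] {0, 1, 2, 3});
            NDList output =
                    sparseMax.forward(
                            new ParameterStore(manager, false), new NDList(input), false);

            // expected: softmax over the two largest entries, zeros elsewhere
            System.out.println(output.singletonOrThrow());
        }
    }
}

With input [0, 1, 2, 3] and topK = 2 this prints approximately [0, 0, 0.2689, 0.7311]: exp-normalization restricted to the two largest values, which is the same behavior testSparseMax asserts for the default topK of 3.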