Skip to content

Commit

Permalink
[pytorch] Implements gammaln operator for PyTorch
Browse files Browse the repository at this point in the history
Fixes #3259
  • Loading branch information
frankfliu committed Jun 18, 2024
1 parent 04d6561 commit 4b9c9d5
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ public PtNDArray exp() {
/** {@inheritDoc} */
@Override
public NDArray gammaln() {
    // Delegates to the native lgamma implementation via JNI (JniUtils.gammaln ->
    // PyTorchLibrary.torchLgamma). The previous UnsupportedOperationException stub
    // was removed by this commit; keeping it here would make the return unreachable.
    return JniUtils.gammaln(this);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,11 @@ public static PtNDArray exp(PtNDArray ndArray) {
ndArray.getManager(), PyTorchLibrary.LIB.torchExp(ndArray.getHandle()));
}

/**
 * Computes the element-wise log-gamma of the given array via the native
 * {@code torchLgamma} binding.
 *
 * @param ndArray the input array
 * @return a new {@code PtNDArray} holding the element-wise log-gamma values
 */
public static PtNDArray gammaln(PtNDArray ndArray) {
    long resultHandle = PyTorchLibrary.LIB.torchLgamma(ndArray.getHandle());
    return new PtNDArray(ndArray.getManager(), resultHandle);
}

public static PtNDArray log(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.torchLog(ndArray.getHandle()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,8 @@ native long[] torchUnique(

native long torchExp(long handle);

// Native binding for element-wise log-gamma (torch lgamma); added by this
// commit and called from JniUtils.gammaln. Takes/returns raw tensor handles.
native long torchLgamma(long handle);

native long torchLog(long handle);

native long torchLog10(long handle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,14 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchExp(JNIEnv*
API_END_RETURN()
}

// JNI entry point backing PyTorchLibrary.torchLgamma: interprets jhandle as a
// torch::Tensor*, applies element-wise lgamma, and returns the address of a
// newly heap-allocated result tensor (ownership passes to the Java side,
// matching the surrounding bindings such as torchLog).
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLgamma(JNIEnv* env, jobject jthis, jlong jhandle) {
  API_BEGIN()
  const auto* self = reinterpret_cast<torch::Tensor*>(jhandle);
  const auto* out = new torch::Tensor(self->lgamma());
  return reinterpret_cast<uintptr_t>(out);
  API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLog(JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.engine.Engine;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -26,10 +26,15 @@ public class TrainTimeSeriesTest {

@Test
public void testTrainTimeSeries() throws TranslateException, IOException {
TestRequirements.linux();
String[] args;
Engine engine = Engine.getEngine("PyTorch");
if (engine.getGpuCount() > 0) {
// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
args = new String[] {"-g", "1", "-e", "5", "-b", "32", "--engine", "MXNet"};
} else {
args = new String[] {"-g", "1", "-e", "5", "-b", "32"};
}

// TODO: PyTorch -- PtNDArray.gammaln not implemented
String[] args = new String[] {"-g", "1", "-e", "5", "-b", "32", "--engine", "MXNet"};
TrainingResult result = TrainTimeSeries.runExample(args);
Assert.assertNotNull(result);
float loss = result.getTrainLoss();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public NDArray evaluate(NDList labels, NDList predictions) {
NDArray lossWeights = predictions.get("loss_weights");
NDArray weightedValue =
NDArrays.where(lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike());
NDArray sumWeights = lossWeights.sum().maximum(1.);
NDArray sumWeights = lossWeights.sum().maximum(1f);
loss = weightedValue.sum().div(sumWeights);
}
return loss;
Expand Down

0 comments on commit 4b9c9d5

Please sign in to comment.