diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java index a8a183de317..451fc9676e7 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java @@ -123,7 +123,10 @@ public void zeroGradients() { NDManager systemManager = MxNDManager.getSystemManager(); for (NDArray array : systemManager.getManagedArrays()) { if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); + // To prevent memory leak we must close gradient after use. + try (NDArray gradient = array.getGradient()) { + gradient.subi(gradient); + } } } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index d090e08decb..c46671597b3 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -76,7 +76,10 @@ public void zeroGradients() { NDManager systemManager = PtNDManager.getSystemManager(); for (NDArray array : systemManager.getManagedArrays()) { if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); + // To prevent memory leak we must close gradient after use. + try (NDArray gradient = array.getGradient()) { + gradient.subi(gradient); + } } } }