Test accumulating gradient collector #2111

Merged · 1 commit merged on Nov 1, 2022
8 changes: 8 additions & 0 deletions api/src/main/java/ai/djl/training/GradientCollector.java
@@ -21,6 +21,14 @@
* performed within the try-with-resources are recorded and the variables marked. When {@link
* #backward(NDArray) backward function} is called, gradients are collected w.r.t previously marked
* variables.
*
* <p>The typical behavior is to open a gradient collector for each batch and close it at the end
* of the batch. In this way, the gradient is reset between batches. If the gradient collector is
* instead left open across multiple calls to {@link #backward(NDArray) backward}, the collected
* gradients are accumulated (added together).
*
* <p>Due to limitations in most engines, the gradient collectors are global. This means that only
* one can be used at a time. If multiple are opened, an error will be thrown.
*/
public interface GradientCollector extends AutoCloseable {

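To make the per-batch pattern from the new Javadoc concrete, here is a minimal sketch (not part of this diff); it assumes the PyTorch engine and uses only API calls that already appear in this PR's tests:

```java
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.training.GradientCollector;

public class GradientCollectorUsageSketch {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray a = manager.create(0.0f);
            a.setRequiresGradient(true);

            // One collector per "batch": opening a new collector resets the gradient,
            // so every iteration observes d(2a)/da = 2 rather than an accumulated sum.
            for (int batch = 0; batch < 3; batch++) {
                try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
                    NDArray loss = a.mul(2);
                    gc.backward(loss);
                    System.out.println(a.getGradient().getFloat()); // prints 2.0 each time
                }
            }
        }
    }
}
```

Keeping a single collector open across the loop instead would accumulate the gradient (2, 4, 6), which is exactly what `testAccumulateGradients` below verifies.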
@@ -18,15 +18,26 @@
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.GradientCollector;

import java.util.concurrent.atomic.AtomicBoolean;

/** {@code PtGradientCollector} is the PyTorch implementation of {@link GradientCollector}. */
public final class PtGradientCollector implements GradientCollector {

private boolean gradModel;
private static AtomicBoolean isCollecting = new AtomicBoolean();

/** Constructs a new {@code PtGradientCollector} instance. */
public PtGradientCollector() {
gradModel = JniUtils.isGradMode();
JniUtils.setGradMode(true);

boolean wasCollecting = isCollecting.getAndSet(true);
if (wasCollecting) {
throw new IllegalStateException(
"A PtGradientCollector is already collecting. Only one can be collecting at a"
+ " time");
}

zeroGradients();
}

@@ -73,6 +84,7 @@ public void close() {
if (!gradModel) {
JniUtils.setGradMode(false);
}
isCollecting.set(false);
// TODO: do some clean up if necessary
}
}
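Worth spelling out from the diff above: `close()` clears the `isCollecting` flag, so the one-at-a-time restriction only applies to overlapping collectors; opening collectors sequentially remains legal. A small sketch of both cases (assumes the PyTorch engine):

```java
// Sequential collectors: allowed, because close() resets the isCollecting flag.
try (GradientCollector first = Engine.getInstance().newGradientCollector()) {
    // ... backward passes for one batch ...
}
try (GradientCollector second = Engine.getInstance().newGradientCollector()) {
    // ... backward passes for the next batch ...
}

// Overlapping collectors: constructing the second one throws IllegalStateException,
// which testMultipleGradientCollectors below asserts.
```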
@@ -24,6 +24,7 @@
public class TrainAirfoilWithTabNetTest {
@Test
public void testTrainAirfoilWithTabNet() throws TranslateException, IOException {
TestRequirements.nightly();
TestRequirements.engine("MXNet", "PyTorch");
String[] args = new String[] {"-g", "1", "-e", "20", "-b", "32"};
TrainingResult result = TrainAirfoilWithTabNet.runExample(args);
@@ -121,6 +121,46 @@ public void testClearGradients() {
}
}

/** Tests that the gradients do accumulate within the same gradient collector. */
@Test
public void testAccumulateGradients() {
// TODO: MXNet support for accumulating gradients does not currently work
TestRequirements.notEngine("MXNet");
try (NDManager manager = NDManager.newBaseManager()) {
NDArray a = manager.create(0.0f);
a.setRequiresGradient(true);

try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
for (int i = 1; i <= 3; i++) {
NDArray b = a.mul(2);
gc.backward(b);
Assert.assertEquals(a.getGradient().getFloat(), 2.0f * i);
}
}
}
}

/**
* Ensures that a gradient collector cannot be opened while another one is still active, because
* gradient collectors are global.
*/
@Test
@SuppressWarnings({"try", "PMD.UseTryWithResources"})
public void testMultipleGradientCollectors() {
Assert.assertThrows(
() -> {
GradientCollector gc2 = null;
try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
gc2 = Engine.getInstance().newGradientCollector();
gc2.close();
} finally {
if (gc2 != null) {
gc2.close();
}
}
});
}

@Test
public void testFreezeParameters() {
try (Model model = Model.newInstance("model")) {