Delete DataManager #691

Merged 1 commit on Feb 26, 2021
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/modality/nlp/Encoder.java
@@ -62,6 +62,15 @@ protected NDList forwardInternal(
return block.forward(parameterStore, inputs, training, params);
}

@Override
protected NDList forwardInternal(
ParameterStore parameterStore,
NDList data,
NDList labels,
PairList<String, Object> params) {
return super.forwardInternal(parameterStore, data, labels, params);
}

/** {@inheritDoc} */
@Override
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
43 changes: 43 additions & 0 deletions api/src/main/java/ai/djl/nn/AbstractBlock.java
@@ -130,12 +130,55 @@ public final NDList forward(
return forwardInternal(parameterStore, inputs, training, params);
}

/** {@inheritDoc} */
@Override
public NDList forward(
ParameterStore parameterStore,
NDList data,
NDList labels,
PairList<String, Object> params) {
NDManager paramsManager = parameterStore.getManager();
if (!isInitialized()) {
initialize(paramsManager, DataType.FLOAT32, data.getShapes());
}
return forwardInternal(parameterStore, data, labels, params);
}

/**
* A helper for {@link Block#forward(ParameterStore, NDList, boolean, PairList)} after
* initialization.
*
* @param parameterStore the parameter store
* @param inputs the input NDList
* @param training true for a training forward pass
* @param params optional parameters
* @return the output of the forward pass
*/
protected abstract NDList forwardInternal(
ParameterStore parameterStore,
NDList inputs,
boolean training,
PairList<String, Object> params);

/**
* A helper for {@link Block#forward(ParameterStore, NDList, NDList, PairList)} after
* initialization.
*
* @param parameterStore the parameter store
* @param data the input data NDList
* @param labels the input labels NDList
* @param params optional parameters
* @return the output of the forward pass
* @see #forward(ParameterStore, NDList, boolean, PairList)
*/
protected NDList forwardInternal(
ParameterStore parameterStore,
NDList data,
NDList labels,
PairList<String, Object> params) {
return forwardInternal(parameterStore, data, true, params);
}

/**
* Use this to add a child block to this block.
*
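For orientation, here is a rough usage sketch of the new label-aware entry point. It is not part of the PR; block, manager, and ps stand for an existing Block, NDManager, and ParameterStore.

    NDList data = new NDList(manager.ones(new Shape(2, 4)));
    NDList labels = new NDList(manager.zeros(new Shape(2)));

    // forward(ps, data, labels, params) initializes the block if necessary and
    // dispatches to forwardInternal(ps, data, labels, params); blocks that do not
    // override the label-aware variant fall back to forwardInternal(ps, data, true, params).
    NDList output = block.forward(ps, data, labels, null);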
14 changes: 14 additions & 0 deletions api/src/main/java/ai/djl/nn/ParallelBlock.java
@@ -125,6 +125,20 @@ protected NDList forwardInternal(
.collect(Collectors.toList()));
}

/** {@inheritDoc} */
@Override
protected NDList forwardInternal(
ParameterStore parameterStore,
NDList data,
NDList labels,
PairList<String, Object> params) {
return function.apply(
children.values()
.stream()
.map(block -> block.forward(parameterStore, data, labels, params))
.collect(Collectors.toList()));
}

/** {@inheritDoc} */
@Override
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
14 changes: 14 additions & 0 deletions api/src/main/java/ai/djl/nn/SequentialBlock.java
@@ -135,6 +135,20 @@ protected NDList forwardInternal(
return current;
}

/** {@inheritDoc} */
@Override
protected NDList forwardInternal(
ParameterStore parameterStore,
NDList data,
NDList labels,
PairList<String, Object> params) {
NDList current = data;
for (Block block : children.values()) {
current = block.forward(parameterStore, current, labels, params);
}
return current;
}

/** {@inheritDoc} */
@Override
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
44 changes: 0 additions & 44 deletions api/src/main/java/ai/djl/training/DataManager.java

This file was deleted.

19 changes: 0 additions & 19 deletions api/src/main/java/ai/djl/training/DefaultTrainingConfig.java
@@ -33,7 +33,6 @@ public class DefaultTrainingConfig implements TrainingConfig {
private Optimizer optimizer;
private Device[] devices;
private Loss loss;
private DataManager dataManager;
private List<Evaluator> evaluators;
private List<TrainingListener> listeners;

@@ -50,7 +49,6 @@ public DefaultTrainingConfig(Loss loss) {
this.initializer = new XavierInitializer(RandomType.GAUSSIAN, FactorType.IN, 2);
optimizer = Adam.builder().build();
this.loss = loss;
dataManager = DataManager.DEFAULT_DATA_MANAGER;
evaluators = new ArrayList<>();
listeners = new ArrayList<>();
}
@@ -89,17 +87,6 @@ public DefaultTrainingConfig optOptimizer(Optimizer optimizer) {
return this;
}

/**
* Sets the {@link DataManager} to be used during training.
*
* @param dataManager the {@link DataManager} to be set
* @return this {@code DefaultTrainingConfig}
*/
public DefaultTrainingConfig optDataManager(DataManager dataManager) {
this.dataManager = dataManager;
return this;
}

/**
* Adds an {@link Evaluator} that needs to be computed during training.
*
@@ -149,12 +136,6 @@ public Loss getLossFunction() {
return loss;
}

/** {@inheritDoc} */
@Override
public DataManager getDataManager() {
return dataManager;
}

/** {@inheritDoc} */
@Override
public List<Evaluator> getEvaluators() {
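With optDataManager removed, a training configuration is assembled from the loss, evaluators, devices, and listeners only. A minimal sketch under those assumptions (identifiers are illustrative):

    import ai.djl.Device;
    import ai.djl.training.DefaultTrainingConfig;
    import ai.djl.training.evaluator.Accuracy;
    import ai.djl.training.loss.Loss;

    DefaultTrainingConfig config =
            new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .addEvaluator(new Accuracy())
                    .optDevices(Device.getDevices(1));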
8 changes: 4 additions & 4 deletions api/src/main/java/ai/djl/training/EasyTrain.java
@@ -87,8 +87,8 @@ public static void trainBatch(Trainer trainer, Batch batch) {
new BatchData(batch, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());
try (GradientCollector collector = trainer.newGradientCollector()) {
for (Batch split : splits) {
- NDList data = trainer.getDataManager().getData(split);
- NDList labels = trainer.getDataManager().getLabels(split);
+ NDList data = split.getData();
+ NDList labels = split.getLabels();
NDList preds = trainer.forward(data, labels);
long time = System.nanoTime();
NDArray lossValue = trainer.getLoss().evaluate(labels, preds);
@@ -123,8 +123,8 @@ public static void validateBatch(Trainer trainer, Batch batch) {
new BatchData(batch, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());

for (Batch split : splits) {
- NDList data = trainer.getDataManager().getData(split);
- NDList labels = trainer.getDataManager().getLabels(split);
+ NDList data = split.getData();
+ NDList labels = split.getLabels();
NDList preds = trainer.evaluate(data);
batchData.getLabels().put(labels.get(0).getDevice(), labels);
batchData.getPredictions().put(preds.get(0).getDevice(), preds);
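Since the trainer now takes data and labels straight from the batch split, per-batch preprocessing that previously lived in a custom DataManager.getData override has to move into the model itself, typically as a leading block. A rough sketch of that pattern, with simple scaling as a stand-in preprocessing step:

    import ai.djl.nn.SequentialBlock;
    import ai.djl.nn.core.Linear;

    SequentialBlock net =
            new SequentialBlock()
                    // preprocessing formerly done in DataManager.getData(batch)
                    .addSingleton(a -> a.div(255f))
                    .add(Linear.builder().setUnits(2).build());

The sentiment analysis example later in this diff applies the same idea by wrapping its text embedding in a singleton block.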
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/training/ParallelTrain.java
@@ -60,8 +60,8 @@ public void trainBatch(Trainer trainer, Batch batch) {
executor.submit(
() -> {
try (GradientCollector collector = trainer.newGradientCollector()) {
- NDList data = trainer.getDataManager().getData(split);
- NDList labels = trainer.getDataManager().getLabels(split);
+ NDList data = split.getData();
+ NDList labels = split.getLabels();
NDList preds = trainer.forward(data);
long time = System.nanoTime();
NDArray lossValue = trainer.getLoss().evaluate(labels, preds);
11 changes: 0 additions & 11 deletions api/src/main/java/ai/djl/training/Trainer.java
@@ -72,7 +72,6 @@ public class Trainer implements AutoCloseable {
private ParameterStore parameterStore;
private List<Evaluator> evaluators;
private Loss loss;
private DataManager dataManager;

private boolean gradientsChecked;

@@ -89,7 +88,6 @@ public Trainer(Model model, TrainingConfig trainingConfig) {
manager.setName("trainer");
devices = trainingConfig.getDevices();
loss = trainingConfig.getLossFunction();
dataManager = trainingConfig.getDataManager();
Objects.requireNonNull(loss, "You must specify a loss for the trainer");
evaluators = new ArrayList<>(trainingConfig.getEvaluators());
evaluators.add(loss); // track loss as an evaluator by default
@@ -240,15 +238,6 @@ public Model getModel() {
return model;
}

/**
* Returns the {@link DataManager}.
*
* @return the {@link DataManager}
*/
public DataManager getDataManager() {
return dataManager;
}

/**
* Gets all {@link Evaluator}s.
*
7 changes: 0 additions & 7 deletions api/src/main/java/ai/djl/training/TrainingConfig.java
@@ -84,13 +84,6 @@ public interface TrainingConfig {
*/
Loss getLossFunction();

/**
* Gets the {@link DataManager} that computes data and labels from the output of dataset.
*
* @return a {@link DataManager}
*/
DataManager getDataManager();

/**
* Returns the list of {@link Evaluator}s that should be computed during training.
*
Changes to the sentiment analysis training example:
@@ -42,12 +42,10 @@
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
- import ai.djl.training.DataManager;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
- import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
@@ -111,13 +109,13 @@ public static TrainingResult runExample(String[] args)
getDataset(Dataset.Usage.TRAIN, executorService, arguments);
StanfordMovieReview validateSet =
getDataset(Dataset.Usage.TEST, executorService, arguments);
- model.setBlock(getModel());
+ model.setBlock(getModel(modelZooTextEmbedding));

// setup training configuration
- DefaultTrainingConfig config = setupTrainingConfig(arguments, modelZooTextEmbedding);
+ DefaultTrainingConfig config = setupTrainingConfig(arguments);
try (Trainer trainer = model.newTrainer(config)) {
trainer.setMetrics(new Metrics());
- Shape encoderInputShape = new Shape(arguments.getBatchSize(), 10, 50);
+ Shape encoderInputShape = new Shape(arguments.getBatchSize(), 10);

// initialize trainer with proper input shape
trainer.initialize(encoderInputShape);
@@ -143,8 +141,16 @@
}
}

- private static Block getModel() {
+ private static Block getModel(ModelZooTextEmbedding embedding) {
return new SequentialBlock()
.addSingleton(
a -> {
try {
return embedding.embedText(a);
Review thread on this line:

Contributor: The problem with putting it here is that I would like to freeze this block and not use it for training. If we remove DataManager, I will not be able to freeze the block.

Contributor (author): This would be frozen, because the ModelZooTextEmbedding freezes itself (it uses a separate predictor). In a different case, you should be able to just call embedding.freeze() to freeze only the embedding.

Contributor: Can you verify that the PyTorch engine will work in this case? I am not sure the training is handled correctly there.

Contributor (author): There is no word embedding for PyTorch in the model zoo, so I can't test it. But is there a reason that PyTorch might not work?

Contributor: It should work by default. I am just concerned whether the NDArrays inside will be affected. In the DistilBert training for the Amazon dataset, we put the embedding layer into a lambda function to avoid parameter collection.
} catch (EmbeddingException e) {
throw new IllegalStateException(e);
}
})
.add(
LSTM.builder()
.setNumLayers(2)
@@ -163,8 +169,7 @@ private static Block getModel() {
.add(Linear.builder().setUnits(2).build());
}

- public static DefaultTrainingConfig setupTrainingConfig(
-         Arguments arguments, ModelZooTextEmbedding embedding) {
+ public static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
String outputDir = arguments.getOutputDir();
SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
listener.setSaveModelCallback(
Expand All @@ -177,7 +182,6 @@ public static DefaultTrainingConfig setupTrainingConfig(
});

return new DefaultTrainingConfig(new SoftmaxCrossEntropyLoss())
- .optDataManager(new EmbeddingDataManager(embedding))
.addEvaluator(new Accuracy())
.optDevices(Device.getDevices(arguments.getMaxGpus()))
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
@@ -243,22 +247,4 @@ public Batchifier getBatchifier() {
.build();
}
}

private static final class EmbeddingDataManager extends DataManager {

private ModelZooTextEmbedding embedding;

public EmbeddingDataManager(ModelZooTextEmbedding embedding) {
this.embedding = embedding;
}

@Override
public NDList getData(Batch batch) {
try {
return new NDList(embedding.embedText(batch.getData().head()));
} catch (EmbeddingException e) {
throw new IllegalArgumentException(e.getMessage(), e);
}
}
}
}