I'm trying to load a previously saved DJL model and continue training it to reduce the loss (i.e., improve the model). I'm using PyTorch as the engine.
I'm probably doing something wrong; if there is documentation about this, please point me to it.
These are my steps:
Create the model, train, predict, and save: OK, everything works fine.
Load the previously saved model, train, predict, and save:
Using PyTorch as the engine: I can load the previously saved model and continue training it, but I can't save the newly trained model; I get the exception "java.lang.IllegalArgumentException: key value size mismatch."
Using MXNet as the engine: I can load the previously saved model, but I can't continue training it; I get this exception: "ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: !AGInfo::IsNone(*i): Cannot differentiate node because it is not in a computational graph. You need to set is_recording to true or use autograd.record() to save computational graphs for backward. If you want to differentiate the same graph twice, you need to pass retain_graph=True to backward."
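For context on the save/load step: DJL names the saved parameter file after the model, with a zero-padded epoch suffix (e.g. xorModel-0000.params), and load() resolves the file from the model's name. Below is a minimal plain-Java sketch of that naming convention; it is illustrative only, not DJL's actual implementation, but it is worth keeping in mind when the model name and the name passed to save() differ across the three methods.

```java
// Illustrative sketch (NOT DJL source code) of the parameter-file naming
// convention DJL uses on Model.save(): "<modelName>-<epoch padded to 4 digits>.params".
public class ParamFileNaming {
    static String paramFileName(String modelName, int epoch) {
        return String.format("%s-%04d.params", modelName, epoch);
    }

    public static void main(String[] args) {
        // A freshly saved model at epoch 0:
        System.out.println(paramFileName("xorModel", 0)); // xorModel-0000.params
    }
}
```

If the name passed to save() does not match the name the next load() expects, the parameter file may not be found or may not line up with the block, so the names used by train, continueTraining, and continueTrainingThird are worth double-checking.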
Below is my entire code; it is a model that solves the XOR problem (an XOR gate).
To run this code you need DJL 0.25 and Apache Commons Lang.
It is a single-file program with no package declaration (the default package).
In the main function you will see calls to three methods: "train" trains the model for the first time, "continueTraining" trains it again, and a third method, "continueTrainingThird", trains it once more.
These methods are basically copies of each other with only a few changes.
There is also a class that implements a "RandomAccessDataset", along with the builder its interface requires,
and a "Translator" that takes a "Float[]" with 2 elements (0 or 1) and returns an Integer (0 or 1).
The models created will be saved in the current directory.
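The Translator's post-processing does nothing DJL-specific beyond thresholding the single network output at 0.5. As a quick plain-Java sanity check of that logic (a sketch, with the 0.5 threshold taken from the code below):

```java
// Plain-Java sketch of the Translator's post-processing: the network emits one
// float; anything above 0.5 is read as 1, anything at or below 0.5 as 0.
public class XorThreshold {
    static int postProcess(float modelOutput) {
        return modelOutput <= 0.5f ? 0 : 1;
    }

    public static void main(String[] args) {
        System.out.println(postProcess(0.93f)); // 1
        System.out.println(postProcess(0.07f)); // 0
    }
}
```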
XorModel.java
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Progress;
import org.apache.commons.lang3.ArrayUtils;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
public class XorModel {
public final static int nEpochsEach = 100;
public final static List<Map<String, Object>> data = new ArrayList<Map<String, Object>>();
static{
data.add( Map.of("problem", List.of(0,0),"solution",0) );
data.add( Map.of("problem", List.of(0,1), "solution", 1) );
data.add( Map.of("problem", List.of(1,0), "solution", 1) );
data.add( Map.of("problem", List.of(1,1), "solution", 0) );
}
public static class MyDataset extends RandomAccessDataset {
public final List<Map<String, Object>> data;
public MyDataset(BaseBuilder<?> builder, List<Map<String, Object>> data) {
super(builder);
this.data = data;
}
@Override
public Record get(NDManager ndManager, long l) throws IOException {
Map<String, Object> event = this.data.get((int) l);
NDArray problem = ndManager.create(
this.floatObjectToPrimitive(
((List<Integer>) event.get("problem")).stream()
.map(x->(float) x)
.toArray(Float[]::new)
)
);
NDArray solution = ndManager.create( new float[]{ (float) ((int) event.get("solution")) } );
return new Record(new NDList(problem), new NDList(solution));
}
private float[] floatObjectToPrimitive(Float[] x){
float[] result = new float[x.length];
for ( int i = 0; i < x.length; i++ ){
result[i] = x[i];
}
return result;
}
@Override
protected long availableSize() {
return this.data.size();
}
@Override
public void prepare() throws IOException, TranslateException {
super.prepare();
}
@Override
public void prepare(Progress progress) throws IOException, TranslateException {
}
}
public static class MyDatasetBuilder extends RandomAccessDataset.BaseBuilder<XorModel.MyDatasetBuilder> {
@Override
protected XorModel.MyDatasetBuilder self() {
return this;
}
public Dataset build(List<Map<String, Object>> data){
return new XorModel.MyDataset(this, data);
}
}
public static class MyTranslator implements ai.djl.translate.Translator<Float[], Integer> {
@Override
public Integer processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
float result = ndList.get(0).getFloat(0);
if ( result <= 0.5 ){
return 0;
}else{
return 1;
}
}
@Override
public NDList processInput(TranslatorContext translatorContext, Float[] floats) throws Exception {
NDManager manager = translatorContext.getNDManager();
NDArray data = manager.create( ArrayUtils.toPrimitive(floats) );
return new NDList( data );
}
}
public static void train(String ... args) throws TranslateException, IOException, MalformedModelException, ModelNotFoundException {
// Obtain the dataset for the model; the argument passed to build() is the
// backing data for the builder (the static list of maps defined above).
Dataset myDataset = new MyDatasetBuilder().setSampling(1, true).build(XorModel.data);
Model model = Model.newInstance("xorModel");
model.setBlock( new Mlp(2, 1, new int[]{10,10,10}) );
TrainingConfig trainingConfig = new DefaultTrainingConfig(Loss.l2Loss())
.optOptimizer(Optimizer.adam().build())
.addTrainingListeners(TrainingListener.Defaults.logging());
Trainer trainer = model.newTrainer(trainingConfig);
// Manual training: I train manually because I will need manual training in a
// future project, so it is a requirement for now.
// I still need to learn the tricks of low-level training.
GradientCollector gc = trainer.newGradientCollector();
int nEpochs = XorModel.nEpochsEach;
for ( int i = 0; i < nEpochs; i++ ) {
Iterable<Batch> batches = trainer.iterateDataset(myDataset);
for ( Batch batch : batches ) {
NDList problem = batch.getData();
NDList solution = batch.getLabels();
//System.out.println(batch.getProgress());
NDList prediction = trainer.forward( problem, solution );
NDArray loss = trainer.getLoss().evaluate( solution, prediction );
gc.backward( loss );
trainer.step();
System.out.print(
String.format(
"\repoch: %s; loss: %s",
i+"/"+nEpochs,
loss.mean().getFloat()
)
);
batch.close();
}
}
gc.close();
// Move past the in-place training-progress line.
System.out.println("\n");
ZooModel<Float[], Integer> completeModel = new ZooModel<>(model, new XorModel.MyTranslator());
List<Float[]> dataToPredict = List.of( new Float[]{0f,1f}, new Float[]{1f,0f}, new Float[]{0f,0f}, new Float[]{1f,1f} );
List<Integer> predictionsDone = completeModel.newPredictor().batchPredict(dataToPredict);
// The expected (correct) answer is [1,1,0,0].
System.out.println("predictions: " + predictionsDone);
// This prints "[1,1,0,0]", so the prediction is correct.
// Now save the model.
model.save(Paths.get("xorModel"), null);
trainer.close();
model.close();
}
public static void continueTraining(String ... args) throws Exception {
// Obtain the dataset for the model; the argument passed to build() is the
// backing data for the builder (the static list of maps defined above).
Dataset myDataset = new MyDatasetBuilder().setSampling(1, true).build(XorModel.data);
Model model = Model.newInstance("xorModel");
model.setBlock( new Mlp(2, 1, new int[]{10,10,10}) );
// Here is my problem: I can load the model, but continuing the training
// fails (see the exceptions described above).
model.load(Paths.get("xorModel"));
TrainingConfig trainingConfig = new DefaultTrainingConfig(Loss.l2Loss())
.optOptimizer(Optimizer.adam().build())
.addTrainingListeners(TrainingListener.Defaults.logging());
Trainer trainer = model.newTrainer(trainingConfig);
// Manual training: I train manually because I will need manual training in a
// future project, so it is a requirement for now.
// I still need to learn the tricks of low-level training.
GradientCollector gc = trainer.newGradientCollector();
int nEpochs = XorModel.nEpochsEach;
for ( int i = 0; i < nEpochs; i++ ) {
Iterable<Batch> batches = trainer.iterateDataset(myDataset);
for ( Batch batch : batches ) {
NDList problem = batch.getData();
NDList solution = batch.getLabels();
//System.out.println(batch.getProgress());
NDList prediction = trainer.forward( problem, solution );
NDArray loss = trainer.getLoss().evaluate( solution, prediction );
gc.backward( loss );
trainer.step();
System.out.print(
String.format(
"\repoch: %s; loss: %s",
i+"/"+nEpochs,
loss.mean().getFloat()
)
);
batch.close();
}
}
gc.close();
// Move past the in-place training-progress line.
System.out.println("\n");
ZooModel<Float[], Integer> completeModel = new ZooModel<>(model, new XorModel.MyTranslator());
List<Float[]> dataToPredict = List.of( new Float[]{0f,1f}, new Float[]{1f,0f}, new Float[]{0f,0f}, new Float[]{1f,1f} );
List<Integer> predictionsDone = completeModel.newPredictor().batchPredict(dataToPredict);
// The expected (correct) answer is [1,1,0,0].
System.out.println("predictions: " + predictionsDone);
// This prints "[1,1,0,0]", so the prediction is correct.
// Now save the model.
model.save(Paths.get("xorModel2"), "xorModel2");
trainer.close();
model.close();
}
public static void continueTrainingThird(String ... args) throws Exception {
// Obtain the dataset for the model; the argument passed to build() is the
// backing data for the builder (the static list of maps defined above).
Dataset myDataset = new MyDatasetBuilder().setSampling(1, true).build(XorModel.data);
Model model = Model.newInstance("xorModel2");
model.setBlock( new Mlp(2, 1, new int[]{10,10,10}) );
// Here is my problem: I can load the model, but continuing the training
// fails (see the exceptions described above).
model.load(Paths.get("xorModel2"));
TrainingConfig trainingConfig = new DefaultTrainingConfig(Loss.l2Loss())
.optOptimizer(Optimizer.adam().build())
.addTrainingListeners(TrainingListener.Defaults.logging());
Trainer trainer = model.newTrainer(trainingConfig);
// Manual training: I train manually because I will need manual training in a
// future project, so it is a requirement for now.
// I still need to learn the tricks of low-level training.
GradientCollector gc = trainer.newGradientCollector();
int nEpochs = XorModel.nEpochsEach;
for ( int i = 0; i < nEpochs; i++ ) {
Iterable<Batch> batches = trainer.iterateDataset(myDataset);
for ( Batch batch : batches ) {
NDList problem = batch.getData();
NDList solution = batch.getLabels();
//System.out.println(batch.getProgress());
NDList prediction = trainer.forward( problem, solution );
NDArray loss = trainer.getLoss().evaluate( solution, prediction );
gc.backward( loss );
trainer.step();
System.out.print(
String.format(
"\repoch: %s; loss: %s",
i+"/"+nEpochs,
loss.mean().getFloat()
)
);
batch.close();
}
}
gc.close();
// Move past the in-place training-progress line.
System.out.println("\n");
ZooModel<Float[], Integer> completeModel = new ZooModel<>(model, new XorModel.MyTranslator());
List<Float[]> dataToPredict = List.of( new Float[]{0f,1f}, new Float[]{1f,0f}, new Float[]{0f,0f}, new Float[]{1f,1f} );
List<Integer> predictionsDone = completeModel.newPredictor().batchPredict(dataToPredict);
// The expected (correct) answer is [1,1,0,0].
System.out.println("predictions: " + predictionsDone);
// This prints "[1,1,0,0]", so the prediction is correct.
// Now save the model.
model.save(Paths.get("xorModel3"), "xorModel3");
trainer.close();
model.close();
}
public static void main(String ... args) throws Exception{
// The train method trains the model for the first time.
XorModel.train();
System.out.println("model already trained once");
// Now that the model has been trained for the first time, train it again
// using the previously saved model.
System.out.println("training once more...");
XorModel.continueTraining();
// Train once more, using the model saved by the previous step.
System.out.println("training for the third time...");
XorModel.continueTrainingThird();
}
}