Skip to content

Commit

Permalink
local model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Nov 4, 2022
1 parent 003a864 commit 532f874
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import java.nio.FloatBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -77,47 +76,29 @@ public static Map<String, Float> predict()
throws IOException, TranslateException, ModelException {
NDManager manager = NDManager.newBaseManager(null, "MXNet");

// If users want to use local repository, then the dataset can be loaded as follows
// Repository repository = Repository.newInstance("local_dataset",
// Paths.get("root/m5-forecasting-accuracy"));
// M5Dataset dataset = M5Dataset.builder()
// .optRepository(repository)
// .build();

// To use local dataset, users can load data as follows
// Repository repository = Repository.newInstance("local_dataset",
// Paths.get("rootPath/m5-forecasting-accuracy"));
// Then set `Builder.optRepository(repository)`
M5Dataset dataset = M5Dataset.builder().setManager(manager).build();

// If users want to use local model, do the following:
Path modelPath = Paths.get("/Users/fenkexin/Downloads/m5forecast.zip");
int predictionLength = 4;
Criteria<TimeSeriesData, Forecast> criteria =
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelPath(modelPath)
.optEngine("MXNet")
.optTranslatorFactory(new DeferredTranslatorFactory())
.optArgument("prediction_length", predictionLength)
.optArgument("freq", "D")
.optArgument("use_feat_dynamic_real", "false")
.optArgument("use_feat_static_cat", "false")
.optArgument("use_feat_static_real", "false")
.optProgress(new ProgressBar())
.build();

// String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast";
// int predictionLength = 4;
// Criteria<TimeSeriesData, Forecast> criteria =
// Criteria.builder()
// .setTypes(TimeSeriesData.class, Forecast.class)
// .optModelUrls(modelUrl)
// .optEngine("MXNet")
// .optTranslatorFactory(new DeferredTranslatorFactory())
// .optArgument("prediction_length", predictionLength)
// .optArgument("freq", "W")
// .optArgument("use_feat_dynamic_real", "false")
// .optArgument("use_feat_static_cat", "false")
// .optArgument("use_feat_static_real", "false")
// .optProgress(new ProgressBar())
// .build();
// The modelUrl can be replaced by local model path. E.g.,
// String modelUrl = "rootPath/m5forecast.zip";
String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast";
int predictionLength = 4;
Criteria<TimeSeriesData, Forecast> criteria =
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelUrls(modelUrl)
.optEngine("MXNet")
.optTranslatorFactory(new DeferredTranslatorFactory())
.optArgument("prediction_length", predictionLength)
.optArgument("freq", "W")
.optArgument("use_feat_dynamic_real", "false")
.optArgument("use_feat_static_cat", "false")
.optArgument("use_feat_static_real", "false")
.optProgress(new ProgressBar())
.build();

try (ZooModel<TimeSeriesData, Forecast> model = criteria.loadModel();
Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,10 @@ private static RandomAccessDataset getData(Dataset.Usage usage, int batchSize)
float[] mean = {0.485f, 0.456f, 0.406f};
float[] std = {0.229f, 0.224f, 0.225f};

// If users want to use local repository, then the dataset can be loaded as follows
// Repository repository = Repository.newInstance("banana", Paths.get("local_data_root/banana/train"));
// FruitsFreshAndRotten dataset =
// FruitsFreshAndRotten.builder()
// .optRepository(repository)
// ...
// To use local dataset, users can load it as follows
// Repository repository = Repository.newInstance("banana",
// Paths.get("local_data_root/banana/train"));
// Then set `Builder.optRepository(repository)`
FruitsFreshAndRotten dataset =
FruitsFreshAndRotten.builder()
.optUsage(usage)
Expand Down

0 comments on commit 532f874

Please sign in to comment.