Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds documentation about the new protobuf serialization format and updates the helper programs to use it #279

Closed
wants to merge 8 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ public Evaluator<Event, AnomalyEvaluation> getEvaluator() {
return evaluator;
}

@Override
public Class<Event> getTypeWitness() {
return Event.class;
}

@Override
public int hashCode() {
return "AnomalyFactory".hashCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ public Evaluator<Label,LabelEvaluation> getEvaluator() {
return evaluator;
}

@Override
public Class<Label> getTypeWitness() {
return Label.class;
}

@Override
public int hashCode() {
return "LabelFactory".hashCode();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,11 +29,10 @@
import org.tribuo.util.Util;

import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.logging.Logger;

Expand Down Expand Up @@ -78,6 +77,9 @@ public String getOptionsDescription() {
*/
@Option(charName = 't', longName = "trainer-name", usage = "Name of the trainer in the configuration file.")
public SequenceTrainer<Label> trainer;

@Option(charName = 'p', longName = "protobuf-model", usage = "Load the model from a protobuf. Optional")
public boolean protobufFormat;
}

/**
Expand Down Expand Up @@ -111,19 +113,31 @@ public static void main(String[] args) throws ClassNotFoundException, IOExceptio
break;
default:
if ((o.trainDataset != null) && (o.testDataset != null)) {
logger.info("Loading training data from " + o.trainDataset);
try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.trainDataset.toFile())));
ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.testDataset.toFile())))) {
@SuppressWarnings("unchecked") // deserialising a generic dataset.
SequenceDataset<Label> tmpTrain = (SequenceDataset<Label>) ois.readObject();
train = tmpTrain;
if (o.protobufFormat) {
logger.info("Loading protobuf format training data from " + o.trainDataset);
SequenceDataset<?> tmpTrain = SequenceDataset.deserializeFromFile(o.trainDataset);
train = SequenceDataset.castDataset(tmpTrain, Label.class);
logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
logger.info("Found " + train.getFeatureIDMap().size() + " features");
logger.info("Loading testing data from " + o.testDataset);
@SuppressWarnings("unchecked") // deserialising a generic dataset.
SequenceDataset<Label> tmpTest = (SequenceDataset<Label>) oits.readObject();
test = tmpTest;
logger.info("Loading protobuf format testing data from " + o.testDataset);
SequenceDataset<?> tmpTest = SequenceDataset.deserializeFromFile(o.testDataset);
test = SequenceDataset.castDataset(tmpTest, Label.class);
logger.info(String.format("Loaded %d testing examples", test.size()));
} else {
logger.info("Loading training data from " + o.trainDataset);
try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(Files.newInputStream(o.trainDataset)));
ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(Files.newInputStream(o.testDataset)))) {
@SuppressWarnings("unchecked") // deserialising a generic dataset.
SequenceDataset<Label> tmpTrain = (SequenceDataset<Label>) ois.readObject();
train = tmpTrain;
logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
logger.info("Found " + train.getFeatureIDMap().size() + " features");
logger.info("Loading testing data from " + o.testDataset);
@SuppressWarnings("unchecked") // deserialising a generic dataset.
SequenceDataset<Label> tmpTest = (SequenceDataset<Label>) oits.readObject();
test = tmpTest;
logger.info(String.format("Loaded %d testing examples", test.size()));
}
}
} else {
logger.warning("Unknown dataset " + o.datasetName);
Expand All @@ -148,7 +162,7 @@ public static void main(String[] args) throws ClassNotFoundException, IOExceptio
System.out.println(evaluation.getConfusionMatrix().toString());

if (o.outputPath != null) {
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputPath.toFile()))) {
try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(o.outputPath))) {
oos.writeObject(model);
logger.info("Serialized model to file: " + o.outputPath);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,15 +32,14 @@
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.data.DataOptions;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -72,6 +71,9 @@ public String getOptionsDescription() {
*/
@Option(charName = 'd', longName = "output-directory", usage = "Directory to write out the models and test reports.")
public File directory;

@Option(longName = "write-protobuf-models", usage = "Write out models in protobuf format.")
public boolean protobuf;
}

/**
Expand Down Expand Up @@ -124,8 +126,12 @@ public static void main(String[] args) throws IOException {
logger.info("Found two trainers with the name " + name);
}
String outputPath = o.directory.toString()+"/"+name;
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputPath+".model"))) {
oos.writeObject(curModel);
if (o.protobuf) {
curModel.serializeToFile(Paths.get(outputPath + ".model"));
} else {
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputPath + ".model"))) {
oos.writeObject(curModel);
}
}
try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outputPath+".output"), StandardCharsets.UTF_8))) {
writer.println("Model = " + name);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -116,6 +116,9 @@ public String getOptionsDescription() {
*/
@Option(charName = 'v', longName = "testing-file", usage = "Path to the testing file.")
public Path testingPath;

@Option(longName = "read-protobuf-model", usage = "Load the model in protobuf format.")
public boolean protobufModel;
}

/**
Expand All @@ -129,16 +132,17 @@ public static Pair<Model<Label>,Dataset<Label>> load(ConfigurableTestOptions o)
Path modelPath = o.modelPath;
Path datasetPath = o.testingPath;
logger.info(String.format("Loading model from %s", modelPath));
Model<Label> model;
try (ObjectInputStream mois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(modelPath.toFile())))) {
model = (Model<Label>) mois.readObject();
boolean valid = model.validate(Label.class);
if (!valid) {
throw new ClassCastException("Failed to cast deserialised Model to Model<Label>");
Model<?> tmpModel;
if (o.protobufModel) {
tmpModel = Model.deserializeFromFile(modelPath);
} else {
try (ObjectInputStream mois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(modelPath.toFile())))) {
tmpModel = (Model<?>) mois.readObject();
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Unknown class in serialised model", e);
}
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Unknown class in serialised model", e);
}
Model<Label> model = tmpModel.castModel(Label.class);
logger.info(String.format("Loading data from %s", datasetPath));
Dataset<Label> test;
switch (o.inputFormat) {
Expand All @@ -154,6 +158,18 @@ public static Pair<Model<Label>,Dataset<Label>> load(ConfigurableTestOptions o)
throw new IllegalArgumentException("Unknown class in serialised dataset", e);
}
break;
case SERIALIZED_PROTOBUF:
//
// Load Tribuo protobuf serialised datasets.
Dataset<?> tmp = Dataset.deserializeFromFile(datasetPath);
if (tmp.validate(Label.class)) {
test = Dataset.castDataset(tmp, Label.class);
test = ImmutableDataset.copyDataset(test,model.getFeatureIDMap(),model.getOutputIDInfo());
logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
} else {
throw new IllegalArgumentException("Invalid test dataset type, expected Label.class");
}
break;
case LIBSVM:
//
// Load the libsvm text-based data format.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -117,26 +117,36 @@ public void startShell() {
* @param path The path to load the model from.
* @return A status message.
*/
@Command(usage = "<filename> - Load a model from disk.", completers="fileCompleter")
public String loadModel(CommandInterpreter ci, File path) {
try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
@SuppressWarnings("unchecked") // deserialising generically typed model.
Model<Label> m = (Model<Label>) ois.readObject();
model = m;
} catch (ClassNotFoundException e) {
logger.log(Level.SEVERE,"Failed to load class from stream " + path.getAbsolutePath(),e);
return "Failed to load model";
} catch (FileNotFoundException e) {
logger.log(Level.SEVERE,"Failed to open file " + path.getAbsolutePath(),e);
return "Failed to load model";
} catch (IOException e) {
logger.log(Level.SEVERE,"IOException when reading from file " + path.getAbsolutePath(),e);
return "Failed to load model";
@Command(usage = "<filename> <load-protobuf> - Load a model from disk.", completers="fileCompleter")
public String loadModel(CommandInterpreter ci, File path, boolean protobuf) {
String output = "Failed to load model";
if (protobuf) {
try {
Model<?> tmpModel = Model.deserializeFromFile(path.toPath());
model = tmpModel.castModel(Label.class);
output = "Loaded model from path " + path.getAbsolutePath();
} catch (IllegalStateException e) {
logger.log(Level.SEVERE, "Failed to deserialize protobuf when reading from file " + path.getAbsolutePath(), e);
} catch (IOException e) {
logger.log(Level.SEVERE, "IOException when reading from file " + path.getAbsolutePath(), e);
}
} else {
try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
Model<?> tmpModel = (Model<?>) ois.readObject();
model = tmpModel.castModel(Label.class);
output = "Loaded model from path " + path.getAbsolutePath();
} catch (ClassNotFoundException e) {
logger.log(Level.SEVERE, "Failed to load class from stream " + path.getAbsolutePath(), e);
} catch (FileNotFoundException e) {
logger.log(Level.SEVERE, "Failed to open file " + path.getAbsolutePath(), e);
} catch (IOException e) {
logger.log(Level.SEVERE, "IOException when reading from file " + path.getAbsolutePath(), e);
}
}

limeText = new LIMEText(new SplittableRandom(1),model,limeTrainer,numSamples,extractor,tokenizer);

return "Loaded model from path " + path.toString();
return output;
}

/**
Expand Down Expand Up @@ -293,6 +303,12 @@ public static class LIMETextCLIOptions implements Options {
*/
@Option(charName = 'f', longName = "filename", usage = "Model file to load. Optional.")
public String modelFilename;

/**
* Load the model from a protobuf. Optional.
*/
@Option(charName = 'p', longName = "protobuf-model", usage = "Load the model from a protobuf. Optional")
public boolean protobufFormat;
}

/**
Expand All @@ -305,7 +321,7 @@ public static void main(String[] args) {
ConfigurationManager cm = new ConfigurationManager(args, options, false);
LIMETextCLI driver = new LIMETextCLI();
if (options.modelFilename != null) {
logger.log(Level.INFO, driver.loadModel(driver.shell, new File(options.modelFilename)));
logger.log(Level.INFO, driver.loadModel(driver.shell, new File(options.modelFilename), options.protobufFormat));
}
driver.startShell();
} catch (UsageException e) {
Expand Down
Loading