Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dense feature vector optimisations to CRF and LinearSGD models #112

Merged
merged 11 commits into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.classification.sequence;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import org.tribuo.classification.Label;
import org.tribuo.classification.sequence.example.SequenceDataGenerator;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceModel;
import org.tribuo.sequence.SequenceTrainer;
import org.tribuo.util.Util;

import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Path;
import java.util.logging.Logger;

/**
 * Build and run a sequence classifier on a generated or serialized dataset using the trainer specified in the configuration file.
 */
public class SeqTrainTest {

    private static final Logger logger = Logger.getLogger(SeqTrainTest.class.getName());

    /**
     * Command line options for {@link SeqTrainTest}.
     */
    public static class SeqTrainTestOptions implements Options {
        @Override
        public String getOptionsDescription() {
            return "Trains and tests a sequence classification model on the specified dataset.";
        }

        @Option(charName = 'd', longName = "dataset-name", usage = "Name of the example dataset, options are {gorilla}.")
        public String datasetName = "";
        @Option(charName = 'f', longName = "output-path", usage = "Path to serialize model to.")
        public Path outputPath;
        @Option(charName = 'u', longName = "train-dataset", usage = "Path to a serialised SequenceDataset used for training.")
        public Path trainDataset = null;
        @Option(charName = 'v', longName = "test-dataset", usage = "Path to a serialised SequenceDataset used for testing.")
        public Path testDataset = null;
        @Option(charName = 't', longName = "trainer-name", usage = "Name of the trainer in the configuration file.")
        public SequenceTrainer<Label> trainer;
    }

    /**
     * Entry point. Loads or generates the train/test datasets, trains the configured
     * trainer, evaluates the resulting model, prints the evaluation and confusion
     * matrix, and optionally serializes the model to disk.
     * @param args the command line arguments
     * @throws ClassNotFoundException if it failed to load the model.
     * @throws IOException if there is any error reading the examples.
     */
    public static void main(String[] args) throws ClassNotFoundException, IOException {

        //
        // Use the labs format logging.
        LabsLogFormatter.setAllLogFormatters();

        SeqTrainTestOptions o = new SeqTrainTestOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args, o);
        } catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }

        SequenceDataset<Label> train;
        SequenceDataset<Label> test;
        switch (o.datasetName) {
            case "Gorilla":
            case "gorilla":
                logger.info("Generating gorilla dataset");
                train = SequenceDataGenerator.generateGorillaDataset(1);
                test = SequenceDataGenerator.generateGorillaDataset(1);
                break;
            default:
                // No named dataset - fall back to deserialising train/test datasets from disk.
                if ((o.trainDataset != null) && (o.testDataset != null)) {
                    logger.info("Loading training data from " + o.trainDataset);
                    try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.trainDataset.toFile())));
                         ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.testDataset.toFile())))) {
                        @SuppressWarnings("unchecked") // deserialising a generic dataset.
                        SequenceDataset<Label> tmpTrain = (SequenceDataset<Label>) ois.readObject();
                        train = tmpTrain;
                        logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                        logger.info("Found " + train.getFeatureIDMap().size() + " features");
                        logger.info("Loading testing data from " + o.testDataset);
                        @SuppressWarnings("unchecked") // deserialising a generic dataset.
                        SequenceDataset<Label> tmpTest = (SequenceDataset<Label>) oits.readObject();
                        test = tmpTest;
                        logger.info(String.format("Loaded %d testing examples", test.size()));
                    }
                } else {
                    logger.warning("Unknown dataset " + o.datasetName);
                    logger.info(cm.usage());
                    return;
                }
        }

        // Fail fast with a usage message rather than an NPE if no trainer was configured.
        if (o.trainer == null) {
            logger.warning("No trainer specified, supply one via the -t/--trainer-name argument and a configuration file.");
            logger.info(cm.usage());
            return;
        }

        logger.info("Training using " + o.trainer.toString());
        final long trainStart = System.currentTimeMillis();
        SequenceModel<Label> model = o.trainer.train(train);
        final long trainStop = System.currentTimeMillis();
        logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));

        LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator();
        final long testStart = System.currentTimeMillis();
        LabelSequenceEvaluation evaluation = labelEvaluator.evaluate(model,test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        System.out.println();
        System.out.println(evaluation.getConfusionMatrix().toString());

        if (o.outputPath != null) {
            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputPath.toFile()))) {
                oos.writeObject(model);
                logger.info("Serialized model to file: " + o.outputPath);
            }
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeature
this.scoreAggregation = scoreAggregation;
}

/**
* For OLCUT.
*/
private ViterbiTrainer() { }

/**
* The viterbi train method is unique because it delegates to a regular
* {@link Model} train method, but before it does, it adds features derived
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.tribuo.classification.sgd;

import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;

import java.util.SplittableRandom;
Expand All @@ -26,11 +27,13 @@
public class Util {
/**
* In place shuffle of the features, labels and weights.
* @deprecated In favour of {@link org.tribuo.common.sgd.AbstractLinearSGDTrainer#shuffleInPlace}.
* @param features Input features.
* @param labels Input labels.
* @param weights Input weights.
* @param rng SplittableRandom number generator.
*/
@Deprecated
public static void shuffleInPlace(SparseVector[] features, int[] labels, double[] weights, SplittableRandom rng) {
int size = features.length;
// Shuffle array
Expand Down Expand Up @@ -142,13 +145,13 @@ public ExampleArray(SparseVector[] features, int[] labels, double[] weights) {
* @param weights Input weights.
* @param rng SplittableRandom number generator.
*/
public static void shuffleInPlace(SparseVector[][] features, int[][] labels, double[] weights, SplittableRandom rng) {
public static void shuffleInPlace(SGDVector[][] features, int[][] labels, double[] weights, SplittableRandom rng) {
int size = features.length;
// Shuffle array
for (int i = size; i > 1; i--) {
int j = rng.nextInt(i);
//swap features
SparseVector[] tmpFeature = features[i-1];
SGDVector[] tmpFeature = features[i-1];
features[i-1] = features[j];
features[j] = tmpFeature;
//swap labels
Expand All @@ -170,9 +173,9 @@ public static void shuffleInPlace(SparseVector[][] features, int[][] labels, dou
* @param rng SplittableRandom number generator.
* @return A tuple of shuffled features, labels and weights.
*/
public static SequenceExampleArray shuffle(SparseVector[][] features, int[][] labels, double[] weights, SplittableRandom rng) {
public static SequenceExampleArray shuffle(SGDVector[][] features, int[][] labels, double[] weights, SplittableRandom rng) {
int size = features.length;
SparseVector[][] newFeatures = new SparseVector[size][];
SGDVector[][] newFeatures = new SGDVector[size][];
int[][] newLabels = new int[size][];
double[] newWeights = new double[size];
for (int i = 0; i < newFeatures.length; i++) {
Expand All @@ -184,7 +187,7 @@ public static SequenceExampleArray shuffle(SparseVector[][] features, int[][] la
for (int i = size; i > 1; i--) {
int j = rng.nextInt(i);
//swap features
SparseVector[] tmpFeature = newFeatures[i-1];
SGDVector[] tmpFeature = newFeatures[i-1];
newFeatures[i-1] = newFeatures[j];
newFeatures[j] = tmpFeature;
//swap labels
Expand All @@ -203,11 +206,11 @@ public static SequenceExampleArray shuffle(SparseVector[][] features, int[][] la
* A nominal tuple. One day it'll be a record, but not today.
*/
public static class SequenceExampleArray {
public final SparseVector[][] features;
public final SGDVector[][] features;
public final int[][] labels;
public final double[] weights;

public SequenceExampleArray(SparseVector[][] features, int[][] labels, double[] weights) {
SequenceExampleArray(SGDVector[][] features, int[][] labels, double[] weights) {
this.features = features;
this.labels = labels;
this.weights = weights;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.tribuo.classification.Label;
import org.tribuo.classification.sequence.ConfidencePredictingSequenceModel;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.provenance.ModelProvenance;
Expand All @@ -43,9 +44,9 @@
import static org.tribuo.Model.BIAS_FEATURE;

/**
* An inference time model for a CRF trained using SGD.
* An inference time model for a linear chain CRF trained using SGD.
* <p>
* Can be switched to use belief propagation, or constrained BP, at test time instead of the standard Viterbi.
* Can be switched to use Viterbi, belief propagation, or constrained BP at test time. By default it uses Viterbi.
* <p>
* See:
* <pre>
Expand Down Expand Up @@ -128,7 +129,7 @@ public DenseVector getFeatureWeights(String featureName) {

@Override
public List<Prediction<Label>> predict(SequenceExample<Label> example) {
SparseVector[] features = convert(example,featureIDMap);
SGDVector[] features = convertToVector(example,featureIDMap);
List<Prediction<Label>> output = new ArrayList<>();
if (confidenceType == ConfidenceType.MULTIPLY) {
DenseVector[] marginals = parameters.predictMarginals(features);
Expand Down Expand Up @@ -193,7 +194,7 @@ public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
q.poll();
q.offer(curr);
}
ArrayList<Pair<String,Double>> b = new ArrayList<>();
List<Pair<String,Double>> b = new ArrayList<>();
while (q.size() > 0) {
b.add(q.poll());
}
Expand Down Expand Up @@ -228,7 +229,7 @@ public <SUB extends Subsequence> List<Double> scoreSubsequences(SequenceExample<
* @return The scores.
*/
public List<Double> scoreChunks(SequenceExample<Label> example, List<Chunk> chunks) {
SparseVector[] features = convert(example,featureIDMap);
SGDVector[] features = convertToVector(example,featureIDMap);
return parameters.predictConfidenceUsingCBP(features,chunks);
}

Expand Down Expand Up @@ -258,11 +259,13 @@ public String generateWeightsString() {

/**
* Converts a {@link SequenceExample} into an array of {@link SparseVector}s suitable for CRF prediction.
* @deprecated As it's replaced with {@link #convertToVector} which is more flexible.
* @param example The sequence example to convert
* @param featureIDMap The feature id map, used to discover the number of features.
* @param <T> The type parameter of the sequence example.
* @return An array of {@link SparseVector}.
*/
@Deprecated
public static <T extends Output<T>> SparseVector[] convert(SequenceExample<T> example, ImmutableFeatureMap featureIDMap) {
int length = example.size();
if (length == 0) {
Expand All @@ -282,11 +285,13 @@ public static <T extends Output<T>> SparseVector[] convert(SequenceExample<T> ex

/**
* Converts a {@link SequenceExample} into an array of {@link SparseVector}s and labels suitable for CRF prediction.
* @deprecated As it's replaced with {@link #convertToVector} which is more flexible.
* @param example The sequence example to convert
* @param featureIDMap The feature id map, used to discover the number of features.
* @param labelIDMap The label id map, used to get the index of the labels.
* @return A {@link Pair} of an int array of labels and an array of {@link SparseVector}.
*/
@Deprecated
public static Pair<int[],SparseVector[]> convert(SequenceExample<Label> example, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
int length = example.size();
if (length == 0) {
Expand All @@ -305,4 +310,64 @@ public static Pair<int[],SparseVector[]> convert(SequenceExample<Label> example,
}
return new Pair<>(labels,features);
}

/**
 * Converts a {@link SequenceExample} into an array of {@link SGDVector}s suitable for CRF prediction.
 * <p>
 * Each element of the sequence becomes one vector; dense when the element contains
 * every feature in the feature map, sparse otherwise.
 * @param example The sequence example to convert
 * @param featureIDMap The feature id map, used to discover the number of features.
 * @param <T> The type parameter of the sequence example.
 * @return An array of {@link SGDVector}.
 */
public static <T extends Output<T>> SGDVector[] convertToVector(SequenceExample<T> example, ImmutableFeatureMap featureIDMap) {
    int numSteps = example.size();
    if (numSteps == 0) {
        throw new IllegalArgumentException("SequenceExample is empty, " + example.toString());
    }
    int numFeatures = featureIDMap.size();
    SGDVector[] vectors = new SGDVector[numSteps];
    int idx = 0;
    for (Example<T> step : example) {
        // Use a dense representation when every feature is present, otherwise stay sparse.
        SGDVector vec = (step.size() == numFeatures)
                ? DenseVector.createDenseVector(step, featureIDMap, false)
                : SparseVector.createSparseVector(step, featureIDMap, false);
        if (vec.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + step.toString());
        }
        vectors[idx] = vec;
        idx++;
    }
    return vectors;
}

/**
 * Converts a {@link SequenceExample} into an array of {@link SGDVector}s and labels suitable for CRF prediction.
 * <p>
 * Each element of the sequence becomes one vector; dense when the element contains
 * every feature in the feature map, sparse otherwise.
 * @param example The sequence example to convert
 * @param featureIDMap The feature id map, used to discover the number of features.
 * @param labelIDMap The label id map, used to get the index of the labels.
 * @return A {@link Pair} of an int array of labels and an array of {@link SGDVector}.
 */
public static Pair<int[],SGDVector[]> convertToVector(SequenceExample<Label> example, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
    int length = example.size();
    if (length == 0) {
        throw new IllegalArgumentException("SequenceExample is empty, " + example.toString());
    }
    int featureSpaceSize = featureIDMap.size();
    int[] labels = new int[length];
    SGDVector[] features = new SGDVector[length];
    int i = 0;
    for (Example<Label> e : example) {
        labels[i] = labelIDMap.getID(e.getOutput());
        // Dense when the example has every feature in the space, sparse otherwise.
        if (e.size() == featureSpaceSize) {
            features[i] = DenseVector.createDenseVector(e, featureIDMap, false);
        } else {
            features[i] = SparseVector.createSparseVector(e, featureIDMap, false);
        }
        if (features[i].numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + e.toString());
        }
        i++;
    }
    return new Pair<>(labels,features);
}
}
Loading