-
Notifications
You must be signed in to change notification settings - Fork 1
/
RandomForestImageClassifierTrainer.java
42 lines (33 loc) · 1.29 KB
/
RandomForestImageClassifierTrainer.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import java.io.File;
import weka.classifiers.functions.LibSVM;
import weka.classifiers.trees.RandomForest;
import weka.core.Attribute;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ArffLoader;
import weka.core.converters.Loader;
import weka.gui.beans.Classifier;
/**
 * Command-line entry point that trains a Weka {@link RandomForest} on an ARFF training set and
 * serializes the trained model to disk with {@link SerializationHelper}.
 *
 * <p>Usage: {@code java RandomForestImageClassifierTrainer [trainArff] [modelFile]}. Both
 * arguments are optional and default to {@value #DEFAULT_TRAIN_PATH} and
 * {@value #DEFAULT_MODEL_PATH} respectively.
 */
public class RandomForestImageClassifierTrainer
{
    /** Training ARFF file read when no first command-line argument is given. */
    private static final String DEFAULT_TRAIN_PATH = "train.arff";

    /** Output model file written when no second command-line argument is given. */
    private static final String DEFAULT_MODEL_PATH = "classifier500.model";

    /** Name of the ARFF attribute used as the class (label) attribute. */
    private static final String CLASS_ATTRIBUTE_NAME = "class";

    /** Number of trees passed to {@link RandomForest#setNumTrees}. */
    private static final int NUM_TREES = 500;

    /** Maximum tree depth passed to {@link RandomForest#setMaxDepth}. */
    private static final int MAX_DEPTH = 30;

    /**
     * Loads the training set, trains the random forest, reports training time, and saves the model.
     *
     * @param args optional {@code [trainArff] [modelFile]} paths; missing entries use the defaults
     * @throws IllegalArgumentException if the training data has no attribute named
     *     {@value #CLASS_ATTRIBUTE_NAME}
     * @throws Exception propagated from Weka loading, training, or serialization
     */
    public static void main(String[] args) throws Exception
    {
        String trainPath = args.length > 0 ? args[0] : DEFAULT_TRAIN_PATH;
        String modelPath = args.length > 1 ? args[1] : DEFAULT_MODEL_PATH;

        Instances trainDataSet = loadTrainingData(new File(trainPath));

        RandomForest classifier = new RandomForest();
        classifier.setNumTrees(NUM_TREES);
        classifier.setMaxDepth(MAX_DEPTH);
        classifier.setDebug(true);

        // nanoTime is monotonic, so the measured duration is immune to wall-clock adjustments.
        final long startNanos = System.nanoTime();
        classifier.buildClassifier(trainDataSet);
        final double executionSeconds = (System.nanoTime() - startNanos) / 1e9;
        System.out.println("Total execution time: " + executionSeconds + " s");

        SerializationHelper.write(modelPath, classifier);
        // Report the path that was actually written (previously the message named a different file).
        System.out.println("Saved trained model to " + modelPath);
    }

    /**
     * Reads an ARFF file in batch mode and marks the {@value #CLASS_ATTRIBUTE_NAME} attribute as
     * the class attribute.
     *
     * @param trainFile ARFF file to load
     * @return the loaded instances with their class attribute set
     * @throws IllegalArgumentException if no attribute named {@value #CLASS_ATTRIBUTE_NAME} exists
     * @throws Exception propagated from the Weka ARFF loader
     */
    private static Instances loadTrainingData(File trainFile) throws Exception
    {
        ArffLoader loader = new ArffLoader();
        loader.setSource(trainFile);
        loader.setRetrieval(Loader.BATCH);
        Instances data = loader.getDataSet();

        // Instances.attribute(String) returns null for an unknown name; fail with a clear message
        // instead of passing null on to setClass.
        Attribute classAttribute = data.attribute(CLASS_ATTRIBUTE_NAME);
        if (classAttribute == null)
        {
            throw new IllegalArgumentException(
                "Training file " + trainFile + " has no attribute named '"
                    + CLASS_ATTRIBUTE_NAME + "'");
        }
        data.setClass(classAttribute);
        return data;
    }
}