Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test reference files missed in PR #261 and #258 + support to set DJL Engine random seed #263

Merged
merged 11 commits into from
Sep 19, 2022
7 changes: 7 additions & 0 deletions moa/src/main/java/moa/classifiers/deeplearning/CAND.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ public class CAND extends AbstractClassifier implements MultiClassClassifier, Ca
"Stats dump file name",
"" );

public IntOption djlRandomSeed = new IntOption(
"djlRandomSeed",
'S',
"Random seed for DJL Engine",
10, 0, Integer.MAX_VALUE);

@Override
public void resetLearningImpl() {
if (nn != null) {
Expand Down Expand Up @@ -428,6 +434,7 @@ class MLPConfigs{
this.nn[i].numberOfLayers.setValue(numberOfLayersInEachMLP.getValue());
this.nn[i].deltaForADWIN = nnConfigs[i].deltaForADWIN;
this.nn[i].backPropLossThreshold.setValue(backPropLossThreshold.getValue());
this.nn[i].djlRandomSeed.setValue(djlRandomSeed.getValue());
this.nn[i].initializeNetwork(instance);
}

Expand Down
17 changes: 14 additions & 3 deletions moa/src/main/java/moa/classifiers/deeplearning/MLP.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.MultiChoiceOption;

import ai.djl.engine.Engine;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
Expand All @@ -48,9 +49,7 @@

import java.text.DecimalFormat;
import java.lang.Math;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;

import static ai.djl.ndarray.types.DataType.FLOAT32;

Expand Down Expand Up @@ -220,6 +219,12 @@ public static class NormalizeInfo{
new String[]{"GPU (use CPU if not available)", "CPU"},
deviceTypeOptionCPU);

public IntOption djlRandomSeed = new IntOption(
"djlRandomSeed",
'S',
"Random seed for DJL Engine",
10, 0, Integer.MAX_VALUE);


public double deltaForADWIN = 1.0E-5;

Expand Down Expand Up @@ -443,6 +448,12 @@ public void initializeNetwork(Instance inst) {
return;
}

Set<String> engines = Engine.getAllEngines();
Iterator<String> engineIterator = engines.iterator();
while (engineIterator.hasNext()){
Engine.getEngine(engineIterator.next()).setRandomSeed(djlRandomSeed.getValue());
}

votes = new double [inst.numClasses()];

if (useNormalization.isSet() || useOneHotEncode.isSet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ protected void adjustEnsembleSize(int nClasses) {
for (int i = this.nEstimators; i < nClasses; i++) {
this.ensemble.add(this.baseLearner.copy());
this.nEstimators ++;
this.adwinEnsemble.add(new ADWIN());
if (this.driftDetection) {
this.adwinEnsemble.add(new ADWIN());
}
this.lambdaSc.add(0.0);
this.lambdaSw.add(0.0);
this.epsilon.add(0.0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ protected void adjustEnsembleSize(int nClasses) {
for (int i = this.nEstimators; i < nClasses; i++) {
this.ensemble.add(this.baseLearner.copy());
this.nEstimators ++;
this.adwinEnsemble.add(new ADWIN());
if (this.driftDetection) {
this.adwinEnsemble.add(new ADWIN());
}
this.lambdaTP.add(0.0);
this.lambdaTN.add(0.0);
this.lambdaFP.add(0.0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ protected void adjustEnsembleSize(int nClasses) {
for (int i = this.nEstimators; i < nClasses; i++) {
this.ensemble.add(this.baseLearner.copy());
this.nEstimators ++;
this.adwinEnsemble.add(new ADWIN());
if (this.driftDetection) {
this.adwinEnsemble.add(new ADWIN());
}
this.lambdaFP.add(0.0);
this.lambdaFN.add(0.0);
this.lambdaSum.add(0.0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ protected void adjustEnsembleSize(int nClasses) {
for (int i = this.nEstimators; i < nClasses; i++) {
this.ensemble.add(this.baseLearner.copy());
this.nEstimators ++;
this.adwinEnsemble.add(new ADWIN());
if (this.driftDetection) {
this.adwinEnsemble.add(new ADWIN());
}
this.lambdaSc.add(0.0);
this.lambdaPos.add(0.0);
this.lambdaNeg.add(0.0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ protected void adjustEnsembleSize(int nClasses) {
for (int i = this.nEstimators; i < nClasses; i++) {
this.ensemble.add(this.baseLearner.copy());
this.nEstimators ++;
this.adwinEnsemble.add(new ADWIN());
if (this.driftDetection) {
this.adwinEnsemble.add(new ADWIN());
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ protected void adjustEnsembleSize(int nClasses) {
for (int i = this.nEstimators; i < nClasses; i++) {
this.ensemble.add(this.baseLearner.copy());
this.nEstimators ++;
this.adwinEnsemble.add(new ADWIN());
if (this.driftDetection) {
this.adwinEnsemble.add(new ADWIN());
}
}
}
}
Expand Down
145 changes: 145 additions & 0 deletions moa/src/test/resources/moa/classifiers/deeplearning/CAND.ref
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
--> classification-out0.arff
moa.classifiers.deeplearning.CAND -o 2 -B 4 -h -n

Index
10000
Votes
0: 0.00238989
1: -0.00239092
Measurements
classified instances: 9999
classifications correct (percent): 58.04580458
Kappa Statistic (percent): 0.4860987
Kappa Temporal Statistic (percent): 11.57251265
Kappa M Statistic (percent): -2.44200244
Model measurements
model training instances: 9999

Index
20000
Votes
0: 0.00238989
1: -0.00239092
Measurements
classified instances: 19999
classifications correct (percent): 57.20786039
Kappa Statistic (percent): 0.47593559
Kappa Temporal Statistic (percent): 10.65873264
Kappa M Statistic (percent): -2.58930712
Model measurements
model training instances: 19999

Index
30000
Votes
0: -0.40148243
1: 0.40148208
Measurements
classified instances: 29999
classifications correct (percent): 57.39857995
Kappa Statistic (percent): 0.49809449
Kappa Temporal Statistic (percent): 11.72204186
Kappa M Statistic (percent): -2.33824471
Model measurements
model training instances: 29999

Index
40000
Votes
0: -0.37733972
1: -0.31610203
Measurements
classified instances: 39999
classifications correct (percent): 57.30893272
Kappa Statistic (percent): 0.91250658
Kappa Temporal Statistic (percent): 11.75649837
Kappa M Statistic (percent): -2.1291866
Model measurements
model training instances: 39999

Index
50000
Votes
0: -0.37733972
1: -0.31610203
Measurements
classified instances: 49999
classifications correct (percent): 57.3751475
Kappa Statistic (percent): 1.11019024
Kappa Temporal Statistic (percent): 11.78077655
Kappa M Statistic (percent): -1.99081164
Model measurements
model training instances: 49999

Index
60000
Votes
0: -0.37733972
1: -0.31610203
Measurements
classified instances: 59999
classifications correct (percent): 57.22595377
Kappa Statistic (percent): 1.3553988
Kappa Temporal Statistic (percent): 11.54919869
Kappa M Statistic (percent): -1.83318784
Model measurements
model training instances: 59999

Index
70000
Votes
0: -0.37733972
1: -0.31610203
Measurements
classified instances: 69999
classifications correct (percent): 57.10795869
Kappa Statistic (percent): 1.46362385
Kappa Temporal Statistic (percent): 11.58750258
Kappa M Statistic (percent): -1.76247289
Model measurements
model training instances: 69999

Index
80000
Votes
0: -0.37733972
1: -0.31610203
Measurements
classified instances: 79999
classifications correct (percent): 57.18946487
Kappa Statistic (percent): 1.72083945
Kappa Temporal Statistic (percent): 11.84329069
Kappa M Statistic (percent): -1.57788587
Model measurements
model training instances: 79999

Index
90000
Votes
0: -0.37733972
1: -0.31610203
Measurements
classified instances: 89999
classifications correct (percent): 57.20063556
Kappa Statistic (percent): 1.74312949
Kappa Temporal Statistic (percent): 11.82556942
Kappa M Statistic (percent): -1.57964135
Model measurements
model training instances: 89999

Index
100000
Votes
0: -0.37733972
1: -0.31610203
Measurements
classified instances: 99999
classifications correct (percent): 57.17857179
Kappa Statistic (percent): 1.93990852
Kappa Temporal Statistic (percent): 11.78930455
Kappa M Statistic (percent): -1.51486416
Model measurements
model training instances: 99999



Loading