-
Notifications
You must be signed in to change notification settings - Fork 117
/
StackingUtil.java
86 lines (67 loc) · 2.61 KB
/
StackingUtil.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
/*
* Copyright (c) 2019 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sklearn.ensemble.stacking;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.Model;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Estimator;
public class StackingUtil {

	private StackingUtil(){
	}

	/**
	 * Encodes a stacking ensemble as a PMML model chain.
	 *
	 * <p>Every base estimator is encoded against the incoming schema. The
	 * {@code predictFunction} callback then derives zero or more features from
	 * each encoded base model. Those derived features, optionally followed by
	 * the original schema features, form the input schema of the final
	 * estimator.</p>
	 *
	 * @param estimators Base estimators, in stacking order.
	 * @param stackMethods Per-estimator stack method names; their count is
	 * validated against {@code estimators} via {@code ClassDictUtil.checkSize}.
	 * @param predictFunction Callback that derives stacked features from each
	 * base model. A {@code null} or empty result contributes nothing.
	 * @param finalEstimator Estimator that consumes the stacked features.
	 * @param passthrough When {@code true}, the original schema features are
	 * appended after the derived features.
	 * @param schema Incoming schema; its encoder is cast to {@link SkLearnEncoder}.
	 * @return A model chain of all base models followed by the final model,
	 * built with {@code MissingPredictionTreatment.RETURN_MISSING}.
	 */
	static
	public <E extends Estimator> MiningModel encodeStacking(List<? extends E> estimators, List<String> stackMethods, PredictFunction predictFunction, E finalEstimator, boolean passthrough, Schema schema){
		ClassDictUtil.checkSize(estimators, stackMethods);

		SkLearnEncoder encoder = (SkLearnEncoder)schema.getEncoder();
		Label label = schema.getLabel();
		List<? extends Feature> originalFeatures = schema.getFeatures();

		// Inputs for the final estimator, accumulated in base-estimator order
		List<Feature> finalInputs = new ArrayList<>();
		// Chain segments: every base model, then the final model
		List<Model> segmentModels = new ArrayList<>();

		for(int index = 0; index < estimators.size(); index++){
			E baseEstimator = estimators.get(index);
			String method = stackMethods.get(index);

			// NOTE(review): the leading argument looks like a 1-based segment id - confirm against Estimator#encode
			Model baseModel = baseEstimator.encode((index + 1), schema);

			List<Feature> derived = predictFunction.apply(index, baseModel, method, encoder);

			boolean contributes = (derived != null) && !derived.isEmpty();
			if(contributes){
				finalInputs.addAll(derived);
			}

			segmentModels.add(baseModel);
		}

		if(passthrough){
			finalInputs.addAll(originalFeatures);
		}

		// The final estimator shares the encoder and label, but sees only the stacked inputs
		Model finalModel = finalEstimator.encode(new Schema(encoder, label, finalInputs));
		segmentModels.add(finalModel);

		return MiningModelUtil.createModelChain(segmentModels, Segmentation.MissingPredictionTreatment.RETURN_MISSING);
	}

	/**
	 * Callback that derives stacked input features from one encoded base model.
	 */
	static
	public interface PredictFunction {

		/**
		 * @param index Zero-based position of the base estimator.
		 * @param model The encoded base model.
		 * @param stackMethod The stack method name configured for this base estimator.
		 * @param encoder The encoder shared by every base estimator and the final estimator.
		 * @return Features to feed into the final estimator, or {@code null}/empty for none.
		 */
		List<Feature> apply(int index, Model model, String stackMethod, SkLearnEncoder encoder);
	}
}