Commit 50f9f2d
Add enum in naivebayes, pca, vectorsizehint, imputer, lda
see #65
liulfy authored and shaomeng.wang committed Apr 9, 2020
1 parent ede9f3f commit 50f9f2d
Showing 36 changed files with 487 additions and 367 deletions.
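The commit applies one pattern across the listed operators: parameters that were free-form strings, re-parsed with valueOf(...) at each use site, become enum-typed. A condensed before/after, copied from the LdaTrainBatchOp hunk below (surrounding method body elided, so this excerpt is not standalone code):

    // before: string-typed parameter, parsed into an enum at the use site
    String optimizer = getMethod();
    LdaUtil.OptimizerMethod optimizerMethod =
        LdaUtil.OptimizerMethod.valueOf(optimizer.toUpperCase());

    // after: the getter is already enum-typed, so the parsing step disappears
    // and the switch over Method is checked by the compiler
    Method optimizer = getMethod();
    switch (optimizer) {
        case EM:
            gibbsSample(dataAndStat, numTopic, numIter, alpha, beta, resDocCountModel);
            break;
        case Online:
            online(dataAndStat, numTopic, numIter, alpha, beta, resDocCountModel);
            break;
    }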
NaiveBayesTextTrainBatchOp.java
@@ -1,5 +1,7 @@
package com.alibaba.alink.operator.batch.classification;

+import java.util.ArrayList;
+
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
@@ -11,6 +13,7 @@
import com.alibaba.alink.operator.common.statistics.StatisticsHelper;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.params.classification.NaiveBayesTextTrainParams;
+
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
@@ -25,8 +28,6 @@
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

-import java.util.ArrayList;
-
/**
* Text Naive Bayes Classifier.
*
@@ -35,8 +36,8 @@
*/

public final class NaiveBayesTextTrainBatchOp
-extends BatchOperator<NaiveBayesTextTrainBatchOp>
-implements NaiveBayesTextTrainParams<NaiveBayesTextTrainBatchOp> {
+extends BatchOperator<NaiveBayesTextTrainBatchOp>
+implements NaiveBayesTextTrainParams<NaiveBayesTextTrainBatchOp> {

/**
* Constructor.
@@ -63,43 +64,42 @@ public NaiveBayesTextTrainBatchOp(Params params) {
@Override
public NaiveBayesTextTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
-TypeInformation<?> labelType;
+TypeInformation <?> labelType;
String labelColName = getLabelCol();
-NaiveBayesTextModelDataConverter.BayesType bayesType = NaiveBayesTextModelDataConverter.BayesType
-.valueOf(getModelType().toUpperCase());
+ModelType modelType = getModelType();
String weightColName = getWeightCol();
double smoothing = getSmoothing();
String vectorColName = getVectorCol();

labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelColName);

String[] keepColNames = (weightColName == null) ? new String[] {labelColName}
-: new String[] {weightColName, labelColName};
-Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> dataSrt
-= StatisticsHelper.summaryHelper(in, null, vectorColName, keepColNames);
-DataSet<Tuple2<Vector, Row>> data = dataSrt.f0;
-DataSet<BaseVectorSummary> srt = dataSrt.f1;
+: new String[] {weightColName, labelColName};
+Tuple2 <DataSet <Tuple2 <Vector, Row>>, DataSet <BaseVectorSummary>> dataSrt
+= StatisticsHelper.summaryHelper(in, null, vectorColName, keepColNames);
+DataSet <Tuple2 <Vector, Row>> data = dataSrt.f0;
+DataSet <BaseVectorSummary> srt = dataSrt.f1;

-DataSet<Integer> vectorSize = srt.map(new MapFunction<BaseVectorSummary, Integer>() {
+DataSet <Integer> vectorSize = srt.map(new MapFunction <BaseVectorSummary, Integer>() {
@Override
public Integer map(BaseVectorSummary value) {
return value.vectorSize();
}
});

// Transform data in the form of label, weight, feature.
-DataSet<Tuple3<Object, Double, Vector>> trainData = data
-.mapPartition(new Transform());
-
-DataSet<Row> probs = trainData
-.groupBy(new SelectLabel())
-.reduceGroup(new ReduceItem())
-.withBroadcastSet(vectorSize, "vectorSize")
-.mapPartition(new GenerateModel(smoothing, bayesType, vectorColName, labelType))
-.withBroadcastSet(vectorSize, "vectorSize")
-.setParallelism(1);
-
-//save the model matrix.
+DataSet <Tuple3 <Object, Double, Vector>> trainData = data
+.mapPartition(new Transform());
+
+DataSet <Row> probs = trainData
+.groupBy(new SelectLabel())
+.reduceGroup(new ReduceItem())
+.withBroadcastSet(vectorSize, "vectorSize")
+.mapPartition(new GenerateModel(smoothing, modelType, vectorColName, labelType))
+.withBroadcastSet(vectorSize, "vectorSize")
+.setParallelism(1);
+
+//save the model matrix.
this.setOutput(probs, new NaiveBayesTextModelDataConverter(labelType).getModelSchema());
return this;
}
@@ -108,28 +108,28 @@ public Integer map(BaseVectorSummary value) {
* Generate model.
*/
public static class GenerateModel extends AbstractRichFunction
-implements MapPartitionFunction<Tuple3<Object, Double, Vector>, Row> {
+implements MapPartitionFunction <Tuple3 <Object, Double, Vector>, Row> {
private int numFeature;
private double smoothing;
-private NaiveBayesTextModelDataConverter.BayesType bayesType;
+private ModelType modelType;
private String vectorColName;
private TypeInformation labelType;

-GenerateModel(double smoothing, NaiveBayesTextModelDataConverter.BayesType bayesType,
+GenerateModel(double smoothing, ModelType modelType,
String vectorColName, TypeInformation labelType) {
this.smoothing = smoothing;
-this.bayesType = bayesType;
+this.modelType = modelType;
this.labelType = labelType;
this.vectorColName = vectorColName;
}

@Override
-public void mapPartition(Iterable <Tuple3<Object, Double, Vector>> values, Collector<Row> collector)
-throws Exception {
+public void mapPartition(Iterable <Tuple3 <Object, Double, Vector>> values, Collector <Row> collector)
+throws Exception {
double numDocs = 0.0;
-ArrayList <Tuple3<Object, Double, Vector>> modelArray = new ArrayList <>();
+ArrayList <Tuple3 <Object, Double, Vector>> modelArray = new ArrayList <>();

-for (Tuple3<Object, Double, Vector> tup : values) {
+for (Tuple3 <Object, Double, Vector> tup : values) {
numDocs += tup.f1;
modelArray.add(tup);
}
@@ -146,12 +146,12 @@ public void mapPartition(Iterable <Tuple3<Object, Double, Vector>> values, Colle
numTerm += feature.get(j);
}
double thetaLog = 0.0;
-switch (this.bayesType) {
-case MULTINOMIAL: {
+switch (this.modelType) {
+case Multinomial: {
thetaLog += Math.log(numTerm + this.numFeature * this.smoothing);
break;
}
-case BERNOULLI: {
+case Bernoulli: {
thetaLog += Math.log(modelArray.get(i).f1 + 2.0 * this.smoothing);
break;
}
@@ -173,37 +173,37 @@ public void mapPartition(Iterable <Tuple3<Object, Double, Vector>> values, Colle
trainResultData.label = labels;
trainResultData.theta = theta;
trainResultData.vectorColName = vectorColName;
-trainResultData.modelType = bayesType;
+trainResultData.modelType = modelType;

new NaiveBayesTextModelDataConverter(labelType).save(trainResultData, collector);
}

@Override
public void open(Configuration parameters) throws Exception {
this.numFeature = (Integer) getRuntimeContext()
.getBroadcastVariable("vectorSize").get(0);
.getBroadcastVariable("vectorSize").get(0);
}
}

/**
* Transform the data format.
*/
public static class Transform
-implements MapPartitionFunction<Tuple2<Vector, Row>, Tuple3<Object, Double, Vector>> {
+implements MapPartitionFunction <Tuple2 <Vector, Row>, Tuple3 <Object, Double, Vector>> {

@Override
-public void mapPartition(Iterable <Tuple2<Vector, Row>> values,
-Collector<Tuple3<Object, Double, Vector>> out)
-throws Exception {
-for (Tuple2<Vector, Row> in : values) {
+public void mapPartition(Iterable <Tuple2 <Vector, Row>> values,
+Collector <Tuple3 <Object, Double, Vector>> out)
+throws Exception {
+for (Tuple2 <Vector, Row> in : values) {
Vector feature = in.f0;
Object labelVal = in.f1.getArity() == 2 ? in.f1.getField(1) : in.f1.getField(0);
Double weightVal = in.f1.getArity() == 2 ?
in.f1.getField(0) instanceof Number ?
((Number) in.f1.getField(0)).doubleValue() :
Double.parseDouble(in.f1.getField(0).toString())
-: 1.0;
-out.collect(new Tuple3<>(labelVal, weightVal, feature));
+: 1.0;
+out.collect(new Tuple3 <>(labelVal, weightVal, feature));

}
}
@@ -212,10 +212,10 @@ public void mapPartition(Iterable <Tuple2<Vector, Row>> values,
/**
* Group by trainData with its label.
*/
-public static class SelectLabel implements KeySelector<Tuple3<Object, Double, Vector>, String> {
+public static class SelectLabel implements KeySelector <Tuple3 <Object, Double, Vector>, String> {

@Override
-public String getKey(Tuple3<Object, Double, Vector> t3) {
+public String getKey(Tuple3 <Object, Double, Vector> t3) {
return t3.f0.toString();
}
}
@@ -224,18 +224,18 @@ public String getKey(Tuple3<Object, Double, Vector> t3) {
* Calculate the sum of feature with same label and the label weight.
*/
public static class ReduceItem extends AbstractRichFunction
-implements GroupReduceFunction<Tuple3<Object, Double, Vector>, Tuple3<Object, Double, Vector>> {
+implements GroupReduceFunction <Tuple3 <Object, Double, Vector>, Tuple3 <Object, Double, Vector>> {
private int vectorSize = 0;

@Override
-public void reduce(Iterable <Tuple3<Object, Double, Vector>> rows,
-Collector<Tuple3<Object, Double, Vector>> out) {
+public void reduce(Iterable <Tuple3 <Object, Double, Vector>> rows,
+Collector <Tuple3 <Object, Double, Vector>> out) {
Object label = null;

double weightSum = 0.0;
Vector featureSum = new DenseVector(this.vectorSize);

-for (Tuple3<Object, Double, Vector> row : rows) {
+for (Tuple3 <Object, Double, Vector> row : rows) {
label = row.f0;
double w = row.f1;
weightSum += w;
@@ -252,7 +252,7 @@ public void reduce(Iterable <Tuple3<Object, Double, Vector>> rows,
}
}
}
-Tuple3<Object, Double, Vector> t3 = new Tuple3<>(label, weightSum, featureSum);
+Tuple3 <Object, Double, Vector> t3 = new Tuple3 <>(label, weightSum, featureSum);

out.collect(t3);
}
@@ -261,7 +261,7 @@ public void reduce(Iterable <Tuple3<Object, Double, Vector>> rows,
public void open(Configuration parameters) throws Exception {

this.vectorSize = (Integer) getRuntimeContext()
.getBroadcastVariable("vectorSize").get(0);
.getBroadcastVariable("vectorSize").get(0);
}

}
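For NaiveBayesTextTrainBatchOp, the old BayesType constants MULTINOMIAL/BERNOULLI become the ModelType values Multinomial/Bernoulli, and GenerateModel now dispatches on the enum directly. A minimal self-contained sketch of that dispatch, with the smoothing math copied from the hunk above; the wrapper class and the local enum are illustrative stand-ins for NaiveBayesTextTrainParams.ModelType, whose declaration is not part of this diff:

    public class ThetaLogSketch {
        enum ModelType { Multinomial, Bernoulli }

        // Log-denominator of the smoothed theta estimate for one label.
        static double thetaLog(ModelType modelType, double numTerm,
                               double labelWeight, int numFeature, double smoothing) {
            switch (modelType) {
                case Multinomial:
                    // multinomial: total term count, Laplace-smoothed over the vocabulary
                    return Math.log(numTerm + numFeature * smoothing);
                case Bernoulli:
                    // bernoulli: per-label document weight, two-sided smoothing
                    return Math.log(labelWeight + 2.0 * smoothing);
                default:
                    throw new IllegalArgumentException("Unsupported model type: " + modelType);
            }
        }

        public static void main(String[] args) {
            System.out.println(thetaLog(ModelType.Multinomial, 120.0, 0.0, 50, 1.0));
            System.out.println(thetaLog(ModelType.Bernoulli, 0.0, 30.0, 50, 1.0));
        }
    }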
LdaTrainBatchOp.java
@@ -78,7 +78,7 @@ public LdaTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
int numTopic = getTopicNum();
int numIter = getNumIter();
String vectorColName = getSelectedCol();
-String optimizer = getMethod();
+Method optimizer = getMethod();
getParams().set(SELECTED_COL, vectorColName);
final DataSet<DocCountVectorizerModelData> resDocCountModel = DocCountVectorizerTrainBatchOp
.generateDocCountModel(getParams(), in);
@@ -94,12 +94,11 @@ public LdaTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
= StatisticsHelper.summaryHelper(trainData, null, vectorColName);
double beta = getParams().get(BETA);
double alpha = getParams().get(ALPHA);
-LdaUtil.OptimizerMethod optimizerMethod = LdaUtil.OptimizerMethod.valueOf(optimizer.toUpperCase());
-switch (optimizerMethod) {
+switch (optimizer) {
case EM:
gibbsSample(dataAndStat, numTopic, numIter, alpha, beta, resDocCountModel);
break;
-case ONLINE:
+case Online:
online(dataAndStat, numTopic, numIter, alpha, beta, resDocCountModel);
break;
default:
ImputerTrainBatchOp.java
@@ -1,20 +1,19 @@
package com.alibaba.alink.operator.batch.dataproc;

+import com.alibaba.alink.common.MLEnvironmentFactory;
+import com.alibaba.alink.common.utils.RowCollector;
+import com.alibaba.alink.common.utils.TableUtil;
+import com.alibaba.alink.operator.batch.BatchOperator;
+import com.alibaba.alink.operator.common.dataproc.ImputerModelDataConverter;
+import com.alibaba.alink.operator.common.statistics.StatisticsHelper;
+import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary;
+import com.alibaba.alink.params.dataproc.ImputerTrainParams;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

-import com.alibaba.alink.operator.batch.BatchOperator;
-import com.alibaba.alink.operator.common.statistics.StatisticsHelper;
-import com.alibaba.alink.common.utils.RowCollector;
-import com.alibaba.alink.common.utils.TableUtil;
-import com.alibaba.alink.params.dataproc.ImputerTrainParams;
import org.apache.flink.util.Collector;

/**
@@ -27,7 +26,7 @@
* If value, will replace missing value with the value.
*/
public class ImputerTrainBatchOp extends BatchOperator<ImputerTrainBatchOp>
-implements ImputerTrainParams<ImputerTrainBatchOp> {
+implements ImputerTrainParams<ImputerTrainBatchOp> {

public ImputerTrainBatchOp() {
super(null);
@@ -41,16 +40,13 @@ public ImputerTrainBatchOp(Params params) {
public ImputerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
String[] selectedColNames = getSelectedCols();
-String strategy = getStrategy();
+Strategy strategy = getStrategy();

//result is statistic model with strategy.
ImputerModelDataConverter converter = new ImputerModelDataConverter();
converter.selectedColNames = selectedColNames;
converter.selectedColTypes = TableUtil.findColTypesWithAssertAndHint(in.getSchema(), selectedColNames);

-Params meta = new Params()
-.set(ImputerTrainParams.STRATEGY, strategy);
-
//if strategy is not min, max, mean
DataSet<Row> rows;
if (isNeedStatModel()) {
@@ -59,9 +55,12 @@ public ImputerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
TableUtil.findColTypesWithAssertAndHint(in.getSchema(), selectedColNames), strategy));

} else {
+if (!getParams().contains(ImputerTrainParams.FILL_VALUE)) {
+throw new RuntimeException("In VALUE strategy, the filling value is necessary.");
+}
String fillValue = getFillValue();
RowCollector collector = new RowCollector();
-converter.save(Tuple2.of(fillValue, null), collector);
+converter.save(Tuple3.of(Strategy.VALUE, null, fillValue), collector);
rows = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment().fromCollection(collector.getRows());
}

@@ -70,13 +69,13 @@
}

private boolean isNeedStatModel() {
-String strategy = getStrategy();
-if ("min".equals(strategy) || "max".equals(strategy) || "mean".equals(strategy)) {
+ImputerTrainParams.Strategy strategy = getStrategy();
+if (Strategy.MIN.equals(strategy) || Strategy.MAX.equals(strategy) || Strategy.MEAN.equals(strategy)) {
return true;
} else if ("value".equals(strategy)){
} else if (Strategy.VALUE.equals(strategy)){
return false;
} else {
-throw new IllegalArgumentException("Only support \"max\", \"mean\", \"min\" and \"value\" strategy.");
+throw new IllegalArgumentException("Only support \"MAX\", \"MEAN\", \"MIN\" and \"VALUE\" strategy.");
}
}

@@ -87,9 +86,9 @@ private boolean isNeedStatModel() {
public static class BuildImputerModel implements FlatMapFunction<TableSummary, Row> {
private String[] selectedColNames;
private TypeInformation[] selectedColTypes;
-private String strategy;
+private Strategy strategy;

-public BuildImputerModel(String[] selectedColNames, TypeInformation[] selectedColTypes, String strategy) {
+public BuildImputerModel(String[] selectedColNames, TypeInformation[] selectedColTypes, Strategy strategy) {
this.selectedColNames = selectedColNames;
this.selectedColTypes = selectedColTypes;
this.strategy = strategy;
@@ -102,9 +101,9 @@ public void flatMap(TableSummary srt, Collector<Row> collector) throws Exception
converter.selectedColNames = selectedColNames;
converter.selectedColTypes = selectedColTypes;

-converter.save(new Tuple2<>(strategy, srt), collector);
+converter.save(new Tuple3<>(strategy, srt, ""), collector);
}
}
}

}
}
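On the Imputer side, Strategy replaces the lower-case strategy strings, the model payload grows from Tuple2 (strategy, summary) to Tuple3 (strategy, summary, fillValue), and the VALUE path now fails fast when no fill value was supplied. A self-contained sketch of the new strategy dispatch; the local Strategy enum mirrors the MIN/MAX/MEAN/VALUE constants used above and stands in for ImputerTrainParams.Strategy, whose declaration is not shown in this diff:

    public class ImputerStrategySketch {
        enum Strategy { MIN, MAX, MEAN, VALUE }

        // true  -> a statistics pass over the input is needed (min/max/mean)
        // false -> the model is built directly from the user-supplied fill value
        static boolean isNeedStatModel(Strategy strategy) {
            switch (strategy) {
                case MIN:
                case MAX:
                case MEAN:
                    return true;
                case VALUE:
                    return false;
                default:
                    throw new IllegalArgumentException(
                        "Only support \"MAX\", \"MEAN\", \"MIN\" and \"VALUE\" strategy.");
            }
        }

        public static void main(String[] args) {
            System.out.println(isNeedStatModel(Strategy.MEAN));   // true
            System.out.println(isNeedStatModel(Strategy.VALUE));  // false
        }
    }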
