Skip to content

Commit

Permalink
Linear model train params change to enum. see #71
Browse files Browse the repository at this point in the history
  • Loading branch information
weibozhao committed Apr 9, 2020
1 parent aa15681 commit db3fd54
Show file tree
Hide file tree
Showing 17 changed files with 88 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.params.classification.LinearBinaryClassTrainParams;
import com.alibaba.alink.params.classification.LinearSvmTrainParams;

import org.apache.flink.ml.api.misc.param.Params;

Expand All @@ -13,14 +12,13 @@
*
*/
public final class LinearSvmTrainBatchOp extends BaseLinearModelTrainBatchOp<LinearSvmTrainBatchOp>
implements LinearSvmTrainParams <LinearSvmTrainBatchOp> {
implements LinearBinaryClassTrainParams<LinearSvmTrainBatchOp> {

public LinearSvmTrainBatchOp() {
this(new Params());
}

public LinearSvmTrainBatchOp(Params params) {
super(params.set(LinearBinaryClassTrainParams.L_2, 1.0 / params.get(LinearSvmTrainParams.C)),
LinearModelType.SVM, "Linear SVM");
super(params, LinearModelType.SVM, "Linear SVM");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.params.classification.LogisticRegressionTrainParams;
import com.alibaba.alink.params.classification.LinearBinaryClassTrainParams;

import org.apache.flink.ml.api.misc.param.Params;

Expand All @@ -12,7 +12,7 @@
*
*/
public final class LogisticRegressionTrainBatchOp extends BaseLinearModelTrainBatchOp<LogisticRegressionTrainBatchOp>
implements LogisticRegressionTrainParams <LogisticRegressionTrainBatchOp> {
implements LinearBinaryClassTrainParams<LogisticRegressionTrainBatchOp> {

public LogisticRegressionTrainBatchOp() {
this(new Params());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.operator.common.linear.SoftmaxObjFunc;
import com.alibaba.alink.operator.common.optim.Lbfgs;
import com.alibaba.alink.operator.common.optim.OptimMethod;
import com.alibaba.alink.operator.common.optim.OptimizerFactory;
import com.alibaba.alink.operator.common.optim.Owlqn;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
Expand Down Expand Up @@ -212,8 +211,8 @@ public void reduce(Iterable<Integer> values, Collector<OptimObjFunc> out) throws
});

// solve the opt problem.
if (params.contains("optimMethod")) {
OptimMethod method = OptimMethod.valueOf(params.get(LinearTrainParams.OPTIM_METHOD).toUpperCase());
if (params.contains(LinearTrainParams.OPTIM_METHOD)) {
OptimMethod method = params.get(LinearTrainParams.OPTIM_METHOD);
return OptimizerFactory.create(objFunc, trainData, coefDim, params, method)
.optimize();
} else if (params.get(SoftmaxTrainParams.L_1) > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.linear.unarylossfunc.*;
import com.alibaba.alink.operator.common.optim.Lbfgs;
import com.alibaba.alink.operator.common.optim.OptimMethod;
import com.alibaba.alink.operator.common.optim.OptimizerFactory;
import com.alibaba.alink.operator.common.optim.Owlqn;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
Expand Down Expand Up @@ -258,7 +257,7 @@ public static DataSet<Tuple2<DenseVector, double[]>> optimize(Params params,
.fromElements(getObjFunction(modelType, params));

if (params.contains(LinearTrainParams.OPTIM_METHOD)) {
OptimMethod method = OptimMethod.valueOf(params.get(LinearTrainParams.OPTIM_METHOD).toUpperCase());
LinearTrainParams.OptimMethod method = params.get(LinearTrainParams.OPTIM_METHOD);
return OptimizerFactory.create(objFunc, trainData, coefficientDim, params, method).optimize();
} else if (params.get(HasL1.L_1) > 0) {
return new Owlqn(objFunc, trainData, coefficientDim, params).optimize();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100;
import com.alibaba.alink.params.shared.linear.LinearTrainParams.OptimMethod;
import com.alibaba.alink.params.shared.optim.HasNumSearchStepDv4;

import org.apache.flink.api.java.DataSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100;
import com.alibaba.alink.params.shared.linear.LinearTrainParams.OptimMethod;
import com.alibaba.alink.params.shared.optim.HasNumSearchStepDv4;

import org.apache.flink.api.java.DataSet;
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple3;
Expand All @@ -16,11 +17,11 @@ public static Optimizer create(
DataSet <Tuple3 <Double, Double, Vector>> trainData,
DataSet <Integer> coefDim,
Params params,
OptimMethod method) {
LinearTrainParams.OptimMethod method) {
switch (method) {
case SGD:
return new Sgd(objFunc, trainData, coefDim, params);
case NEWTON:
case Newton:
return new Newton(objFunc, trainData, coefDim, params);
case GD:
return new Gd(objFunc, trainData, coefDim, params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.alibaba.alink.params.shared.linear.HasL1;
import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100;
import com.alibaba.alink.params.shared.linear.LinearTrainParams.OptimMethod;
import com.alibaba.alink.params.shared.optim.HasNumSearchStepDv4;

import org.apache.flink.api.java.DataSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.params.shared.linear.LinearTrainParams.OptimMethod;
import com.alibaba.alink.params.shared.optim.SgdParams;

import org.apache.flink.api.java.DataSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.OptimMethod;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.params.shared.linear.LinearTrainParams.OptimMethod;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package com.alibaba.alink.operator.common.optim.subfunc;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.operator.common.optim.OptimMethod;
import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.ComputeFunction;
import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100;
import com.alibaba.alink.params.shared.linear.HasEpsilonDv0000001;
import com.alibaba.alink.params.shared.linear.LinearTrainParams.OptimMethod;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.alibaba.alink.params.shared.linear;

import com.alibaba.alink.params.ParamUtil;
import com.alibaba.alink.params.shared.colname.HasFeatureColsDefaultAsNull;
import com.alibaba.alink.params.shared.colname.HasLabelCol;
import com.alibaba.alink.params.shared.colname.HasVectorColDefaultAsNull;
Expand All @@ -11,29 +12,58 @@

/**
* parameters of linear training.
*
*/
public interface LinearTrainParams<T> extends
HasWithIntercept <T>,
HasMaxIterDefaultAs100<T>,
HasEpsilonDv0000001 <T>,
HasFeatureColsDefaultAsNull <T>,
HasLabelCol <T>,
HasWeightColDefaultAsNull <T>,
HasVectorColDefaultAsNull <T>,
HasStandardization <T> {

ParamInfo<String> OPTIM_METHOD = ParamInfoFactory
.createParamInfo("optimMethod", String.class)
.setDescription("optimization method")
.setHasDefaultValue(null)
.build();

default String getOptimMethod() {
return get(OPTIM_METHOD);
}

default T setOptimMethod(String value) {
return set(OPTIM_METHOD, value);
}
HasWithIntercept<T>,
HasMaxIterDefaultAs100<T>,
HasEpsilonDv0000001<T>,
HasFeatureColsDefaultAsNull<T>,
HasLabelCol<T>,
HasWeightColDefaultAsNull<T>,
HasVectorColDefaultAsNull<T>,
HasStandardization<T> {

ParamInfo<OptimMethod> OPTIM_METHOD = ParamInfoFactory
.createParamInfo("optimMethod", OptimMethod.class)
.setDescription("optimization method")
.setHasDefaultValue(null)
.build();

default OptimMethod getOptimMethod() {
return get(OPTIM_METHOD);
}

default T setOptimMethod(String value) {
return set(OPTIM_METHOD, ParamUtil.searchEnum(OPTIM_METHOD, value));
}

default T setOptimMethod(OptimMethod value) {
return set(OPTIM_METHOD, value);
}

/**
* Optimization Type.
*/
enum OptimMethod {
/**
* LBFGS method
*/
LBFGS,
/**
* Gradient Descent method
*/
GD,
/**
* Newton method
*/
Newton,
/**
* Stochastic Gradient Descent method
*/
SGD,
/**
* OWLQN method
*/
OWLQN
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,30 @@

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.LinearSvmTrainBatchOp;

import org.apache.flink.ml.api.misc.param.Params;

import com.alibaba.alink.params.classification.LinearBinaryClassTrainParams;
import com.alibaba.alink.params.classification.LinearSvmPredictParams;
import com.alibaba.alink.params.classification.LinearSvmTrainParams;
import com.alibaba.alink.pipeline.Trainer;

/**
* Linear svm pipeline op.
*
*/
public class LinearSvm extends Trainer <LinearSvm, LinearSvmModel>
implements LinearSvmTrainParams <LinearSvm>, LinearSvmPredictParams <LinearSvm> {
public class LinearSvm extends Trainer<LinearSvm, LinearSvmModel>
implements LinearBinaryClassTrainParams<LinearSvm>, LinearSvmPredictParams<LinearSvm> {

public LinearSvm() {
super();
}
public LinearSvm() {
super();
}

public LinearSvm(Params params) {
super(params);
}
public LinearSvm(Params params) {
super(params);
}

@Override
protected BatchOperator train(BatchOperator in) {
return new LinearSvmTrainBatchOp(this.getParams()).linkFrom(in);
}
@Override
protected BatchOperator train(BatchOperator in) {
return new LinearSvmTrainBatchOp(this.getParams()).linkFrom(in);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp;
import org.apache.flink.ml.api.misc.param.Params;

import com.alibaba.alink.params.classification.LinearBinaryClassTrainParams;
import com.alibaba.alink.params.classification.LogisticRegressionPredictParams;
import com.alibaba.alink.params.classification.LogisticRegressionTrainParams;
import com.alibaba.alink.pipeline.Trainer;

/**
* Logistic regression is a popular method to predict a categorical response.
*
*/
public class LogisticRegression extends Trainer <LogisticRegression, LogisticRegressionModel> implements
LogisticRegressionTrainParams <LogisticRegression>,
LogisticRegressionPredictParams <LogisticRegression> {
LogisticRegressionPredictParams <LogisticRegression>,
LinearBinaryClassTrainParams<LogisticRegression> {

public LogisticRegression() {super();}

Expand Down

0 comments on commit db3fd54

Please sign in to comment.