From 70b30af3aa1aa514f97f41703b3d0538102649dc Mon Sep 17 00:00:00 2001 From: cainingnk Date: Thu, 23 Jul 2020 10:59:51 +0800 Subject: [PATCH] Refine pca, feature selector model info. see #117 --- .../batch/feature/ChiSqSelectorBatchOp.java | 45 ++- .../batch/feature/PcaTrainBatchOp.java | 298 ++++++++++++++---- .../feature/VectorChiSqSelectorBatchOp.java | 37 ++- .../statistics/ChiSquareTestBatchOp.java | 60 +++- .../batch/statistics/CorrelationBatchOp.java | 34 ++ .../VectorChiSquareTestBatchOp.java | 52 ++- .../statistics/VectorCorrelationBatchOp.java | 33 ++ .../statistics/VectorSummarizerBatchOp.java | 43 ++- .../ChiSqSelectorModelDataConverter.java | 19 +- .../feature/ChisqSelectorModelInfo.java | 149 +++++++++ .../ChisqSelectorModelInfoBatchOp.java | 27 ++ .../common/feature/ChisqSelectorUtil.java | 206 ++++++++++++ .../common/feature/pca/PcaModelData.java | 28 +- .../common/feature/pca/PcaModelMapper.java | 49 +-- .../common/statistics/ChiSquareTest.java | 166 +--------- .../statistics/ChiSquareTestResult.java | 25 +- .../common/statistics/ChiSquareTestUtil.java | 111 ++----- .../params/feature/HasCalculationType.java | 21 +- .../params/feature/PcaPredictParams.java | 42 --- .../alink/params/feature/PcaTrainParams.java | 15 +- .../alibaba/alink/pipeline/feature/PCA.java | 4 +- .../feature/ChiSqSelectorBatchOpTest.java | 27 +- .../VectorChiSqSelectorBatchOpTest.java | 20 +- .../statistics/ChiSquareTestBatchOpTest.java | 2 +- .../VectorChiSquareTestBatchOpTest.java | 2 +- .../common/statistics/ChiSquareTestTest.java | 53 +++- .../alink/pipeline/feature/PCATest.java | 140 +++++--- 27 files changed, 1169 insertions(+), 539 deletions(-) create mode 100644 core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorModelInfo.java create mode 100644 core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorModelInfoBatchOp.java create mode 100644 core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorUtil.java diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java index 1004e7ddc..70498db6f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java @@ -1,15 +1,26 @@ package com.alibaba.alink.operator.batch.feature; +import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.feature.ChiSqSelectorModelDataConverter; +import com.alibaba.alink.operator.common.feature.ChisqSelectorModelInfo; +import com.alibaba.alink.operator.common.feature.ChisqSelectorModelInfoBatchOp; +import com.alibaba.alink.operator.common.feature.ChisqSelectorUtil; import com.alibaba.alink.operator.common.statistics.ChiSquareTestUtil; -import com.alibaba.alink.operator.batch.BatchOperator; -import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.params.feature.ChiSqSelectorParams; -import org.apache.flink.util.Preconditions; - +import org.apache.flink.api.java.DataSet; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; +/** + * chi-square selector for table. + */ public final class ChiSqSelectorBatchOp extends BatchOperator - implements ChiSqSelectorParams { + implements ChiSqSelectorParams, + WithModelInfoBatchOp { + + private static final long serialVersionUID = 942267749590810559L; public ChiSqSelectorBatchOp() { super(null); @@ -32,23 +43,23 @@ public ChiSqSelectorBatchOp linkFrom(BatchOperator... inputs) { double fdr = getParams().get(FDR); double fwe = getParams().get(FWE); - setOutputTable(ChiSquareTestUtil.selector(in, selectedColNames, labelColName, - selectorType, numTopFeatures, percentile, fpr, fdr, fwe)); + DataSet chiSquareTest = + ChiSquareTestUtil.test(in, selectedColNames, labelColName); - return this; - } + DataSet model = chiSquareTest.mapPartition( + new ChisqSelectorUtil.ChiSquareSelector(selectedColNames, selectorType, numTopFeatures, percentile, fpr, fdr, fwe)) + .name("FilterFeature") + .setParallelism(1); + setOutputTable(DataSetConversionUtil.toTable(in.getMLEnvironmentId(), model, new ChiSqSelectorModelDataConverter().getModelSchema())); - public String[] collectResult() { - Preconditions.checkArgument(null != this.getOutputTable(), "Please link from or link to."); + return this; + } - int[] indices = new ChiSqSelectorModelDataConverter().load(this.collect()); - String[] selectedColNames = new String[indices.length]; - for (int i = 0; i < indices.length; i++) { - selectedColNames[i] = this.getSelectedCols()[i]; - } - return selectedColNames; + @Override + public ChisqSelectorModelInfoBatchOp getModelInfoBatchOp() { + return new ChisqSelectorModelInfoBatchOp(getParams()).linkFrom(this); } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java index 1e992e989..d5beda565 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java @@ -1,26 +1,34 @@ package com.alibaba.alink.operator.batch.feature; +import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; import com.alibaba.alink.common.linalg.*; +import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.feature.pca.PcaModelData; +import com.alibaba.alink.operator.common.feature.pca.PcaModelDataConverter; +import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer; +import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; +import com.alibaba.alink.params.feature.PcaTrainParams; +import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; -import com.alibaba.alink.operator.common.feature.pca.PcaModelDataConverter; -import com.alibaba.alink.operator.common.feature.pca.PcaModelData; -import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer; -import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; - -import org.apache.flink.ml.api.misc.param.Params; -import com.alibaba.alink.params.feature.PcaTrainParams; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Random; /** @@ -29,8 +37,10 @@ * The calculation is done using eigen on the correlation or covariance matrix. */ public final class PcaTrainBatchOp extends BatchOperator - implements PcaTrainParams { + implements PcaTrainParams, + WithModelInfoBatchOp { + private static final long serialVersionUID = 6098674439183289020L; /** * block size when transmit */ @@ -40,7 +50,7 @@ public final class PcaTrainBatchOp extends BatchOperator * default constructor */ public PcaTrainBatchOp() { - super(null); + this(null); } /** @@ -50,9 +60,8 @@ public PcaTrainBatchOp() { * selectedColNames: compute col names. when input is table, not tensor. * tensorColName: compute tensor col. when input is tensor. * isSparse: true is sparse tensor, false is dense tensor. default is false. - * pcaType: compute type, be CORR, COV_SAMPLE, COVAR_POP. - * CORR is correlation matrix,COV_SAMPLE is covariance of sample,COVAR_POP is covariance of - * population. + * pcaType: compute type, be CORR, COV_SAMPLE, COV_POPULATION. + * CORR is correlation matrix,COV is covariance * p: number of principal component */ public PcaTrainBatchOp(Params params) { @@ -80,7 +89,7 @@ public PcaTrainBatchOp linkFrom(BatchOperator... inputs) { VectorSplit vectorSplit = new VectorSplit(); //combine vector - VecCombine vecCombine = new VecCombine(calcType.name(), k, selectedColNames, vectorColName); + VecCombine vecCombine = new VecCombine(calcType, k, selectedColNames, vectorColName); DataSet srt = data .mapPartition(new StatisticsHelper.VectorSummarizerPartition(true)) @@ -90,13 +99,28 @@ public PcaTrainBatchOp linkFrom(BatchOperator... inputs) { //convert model to table this.setOutput(srt, new PcaModelDataConverter().getModelSchema()); + //cal model summary + DataSet modelSummary = srt + .mapPartition(new ModelInfoMapPartition(getCalculationType())) + .setParallelism(1); + + Table[] tables = new Table[1]; + tables[0] = DataSetConversionUtil.toTable(getMLEnvironmentId(), modelSummary, + new String[]{"modelinfo"}, + new TypeInformation[]{Types.STRING}); + + this.setSideOutputTables(tables); + return this; } + /** * split rowNum,sum, squareSum, dot vector */ public static class VectorSplit extends RichFlatMapFunction> { + private static final long serialVersionUID = 4372448784539139888L; + @Override public void flatMap(BaseVectorSummarizer srt, Collector> collector) throws Exception { @@ -138,9 +162,14 @@ public void flatMap(BaseVectorSummarizer srt, Collector, Row> { - protected String pcaType; + private static final long serialVersionUID = 2228432228822829081L; + protected CalculationType pcaType; protected int p; protected String[] featureColNames; protected String tensorColName; - public VecCombine(String pcaType, int p, String[] featureColNames, String tensorColName) { + public VecCombine(CalculationType pcaType, int p, String[] featureColNames, String tensorColName) { this.pcaType = pcaType; this.p = p; this.featureColNames = featureColNames; @@ -186,8 +216,8 @@ public VecCombine(String pcaType, int p, String[] featureColNames, String tensor * @param colNum col number * @return covariance matrix */ - public static double[][] getCov(double[] counts, double[] sums, double[] dotProduct, - int colNum) { + static double[][] getCov(double[] counts, double[] sums, double[] dotProduct, + int colNum) { double[][] cov = new double[colNum][colNum]; double d = 0; int idx = 0; @@ -202,6 +232,36 @@ public static double[][] getCov(double[] counts, double[] sums, double[] dotProd return cov; } + static double[] dotProdctionCut(double[] dotProduct, List nonEqualColIdx, int nAll) { + int nCut = nonEqualColIdx.size(); + double[] dotProductCut = new double[nCut * (nCut + 1) / 2]; + int idx = 0; + int idxOrigin = 0; + for (int i = 0; i < nAll; i++) { + if (nonEqualColIdx.contains(i)) { + for (int j = i; j < nAll; j++) { + if (nonEqualColIdx.contains(j)) { + dotProductCut[idx] = dotProduct[idxOrigin + j - i]; + idx++; + } + } + } + idxOrigin += (nAll - i); + } + return dotProductCut; + } + + static double[] vectorCut(double[] vec, List nonEqualColIdx) { + int nCut = nonEqualColIdx.size(); + double[] vecCut = new double[nCut]; + int i = 0; + for (int idx : nonEqualColIdx) { + vecCut[i] = vec[idx]; + i++; + } + return vecCut; + } + /** * get correlation matrix * @@ -302,41 +362,23 @@ public void mapPartition(Iterable> splitVec, Collec int nxNe = nonEqualColIdx.size(); int nxAll = nx; if (nxNe != nx) { - double[] countsNe = new double[nxNe]; - double[] sumsNe = new double[nxNe]; - double[] sum2sNe = new double[nxNe]; - double[] dotProductNe = new double[nxNe * (nxNe + 1) / 2]; - - int i = 0; - for (int idx : nonEqualColIdx) { - countsNe[i] = counts[idx]; - sumsNe[i] = sums[idx]; - sum2sNe[i] = sum2s[idx]; - dotProductNe[i] = dotProduct[idx]; - i++; - } - - counts = countsNe; - sums = sumsNe; - sum2s = sum2sNe; - dotProduct = dotProductNe; - + counts = vectorCut(counts, nonEqualColIdx); + sums = vectorCut(sums, nonEqualColIdx); + sum2s = vectorCut(sum2s, nonEqualColIdx); + dotProduct = dotProdctionCut(dotProduct, nonEqualColIdx, nxAll); nx = nxNe; } PcaModelData pcr = new PcaModelData(); //get correlation or covariance matrix - CalculationType pcaTypeEnum = CalculationType.valueOf(pcaType.toUpperCase()); - double[][] corr = null; - switch (pcaTypeEnum) { + switch (pcaType) { case CORR: corr = getCorr(counts, sums, sum2s, dotProduct, nx); break; - case COVAR_POP: - case COV_SAMPLE: + case COV: corr = getCov(counts, sums, dotProduct, nx); break; default: @@ -345,14 +387,6 @@ public void mapPartition(Iterable> splitVec, Collec DenseMatrix calculateMatrix = new DenseMatrix(corr); - if (pcaTypeEnum.equals(CalculationType.COVAR_POP)) { - double cnt = counts[0]; - if (cnt > 1) { - calculateMatrix.scaleEqual(cnt / (cnt - 1)); - } else { - throw new RuntimeException("record num is less than 2!"); - } - } //get mean and stddev pcr.means = new double[nx]; @@ -368,8 +402,7 @@ public void mapPartition(Iterable> splitVec, Collec "k is larger than vector size. k: " + p + " vectorSize: " + calculateMatrix.numCols()); } - //get eig values and eig vectors - scala.Tuple2 eigValueAndVector = EigenSolver.solve(calculateMatrix, p, 10e-8, 300); + scala.Tuple2 eigValueAndVector = solve(calculateMatrix, p); if (eigValueAndVector._1.size() < p) { throw new RuntimeException("Fail to converge when solving eig value problem."); } @@ -390,6 +423,7 @@ public void mapPartition(Iterable> splitVec, Collec buildModel(pcr, nonEqualColIdx, nxAll, model); } + /** * build pca model. * @@ -410,6 +444,10 @@ protected void buildModel(PcaModelData modelData, List nonEqualColIndic } } + public synchronized static scala.Tuple2 solve(DenseMatrix calculateMatrix, int p) { + return EigenSolver.solve(calculateMatrix, p, 10e-8, 300); + } + /** * dense vector or sparse vector to dense vector. @@ -425,4 +463,154 @@ private static DenseVector toDenseVector(Vector vector) { } } + @Override + public PcaModelInfoBatchOp getModelInfoBatchOp() { + return new PcaModelInfoBatchOp(getParams()).linkFrom(this.getSideOutput(0)); + } + + public static class PcaModelInfoBatchOp + extends ExtractModelInfoBatchOp { + + public PcaModelInfoBatchOp() { + this(null); + } + + public PcaModelInfoBatchOp(Params params) { + super(params); + } + + @Override + protected PcaModelInfo createModelInfo(List rows) { + return JsonConverter.fromJson((String) rows.get(0).getField(0), PcaModelInfo.class); + } + + } + + private static class ModelInfoMapPartition implements MapPartitionFunction { + private CalculationType calculationType; + + public ModelInfoMapPartition(CalculationType calculationType) { + this.calculationType = calculationType; + } + + @Override + public void mapPartition(Iterable values, Collector out) throws Exception { + List rows = new ArrayList<>(); + values.forEach(k -> rows.add(k)); + + PcaModelData data = new PcaModelDataConverter().load(rows); + PcaModelInfo summary = new PcaModelInfo(); + summary.featureCols = data.featureColNames; + summary.egenValues = data.lambda; + summary.p = data.p; + summary.nx = data.nx; + summary.calculationType = this.calculationType; + summary.egenVectors = data.coef; + + Row outRow = new Row(1); + outRow.setField(0, JsonConverter.toJson(summary)); + out.collect(outRow); + } + } + + public static class PcaModelInfo { + private CalculationType calculationType; + private String[] featureCols; + private double[] egenValues; + private double[][] egenVectors; + private int p; + private int nx; + + public String[] getCols() { + return featureCols; + } + + public double[] getEgenValues() { + return egenValues; + } + + public double[][] getEgenVectors() { + return egenVectors; + } + + public double[] getProportions() { + double[] propertions = new double[p]; + for (int i = 0; i < p; i++) { + propertions[i] = egenValues[i] / nx; + } + return propertions; + } + + public double[] getCumulatives() { + double[] cumulatives = new double[p]; + double sum = 0; + for (int i = 0; i < p; i++) { + double cur = egenValues[i] / nx; + sum += cur; + cumulatives[i] = sum; + } + return cumulatives; + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(); + sbd.append(PrettyDisplayUtils.displayHeadline("PCA", '-')); + sbd.append("CalculationType: " + calculationType.name() + "\n"); + sbd.append("Number of Principal Component: " + this.p + "\n"); + sbd.append("\n"); + sbd.append("EigenValues: \n"); + String[] colColNames = new String[]{"Prin", "Eigenvalue", "Proportion", "Cumulative"}; + double[] proportions = getProportions(); + double[] cumulatives = getCumulatives(); + + Object[][] vals = new Object[p][4]; + for (int i = 0; i < p; i++) { + vals[i][0] = "Prin" + i; + vals[i][1] = egenValues[i]; + vals[i][2] = proportions[i]; + vals[i][3] = cumulatives[i]; + } + + sbd.append(PrettyDisplayUtils.indentLines(PrettyDisplayUtils.displayTable(vals, p, 4, null, colColNames, null), 4)); + + sbd.append("\n"); + sbd.append("\n"); + + sbd.append("EigenVectors: \n"); + String[] vecColNames = new String[p + 1]; + vecColNames[0] = "colName"; + for (int i = 0; i < p; i++) { + vecColNames[i + 1] = "Prin" + i; + } + + Object[][] vecVals = new Object[nx][p + 1]; + ; + if (featureCols != null) { + for (int j = 0; j < nx; j++) { + vecVals[j][0] = featureCols[j]; + } + } else { + for (int j = 0; j < nx; j++) { + vecVals[j][0] = j; + } + } + for (int i = 0; i < p; i++) { + for (int j = 0; j < nx; j++) { + vecVals[j][i + 1] = egenVectors[i][j]; + } + } + + sbd.append(PrettyDisplayUtils.indentLines( + PrettyDisplayUtils.displayTable(vecVals, + nx, p + 1, null, vecColNames, + null, 100, 100), + 4)); + + + return sbd.toString(); + } + + } + } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java index 4a7a4418c..ffa590230 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java @@ -1,16 +1,29 @@ package com.alibaba.alink.operator.batch.feature; +import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.feature.ChiSqSelectorModelDataConverter; +import com.alibaba.alink.operator.common.feature.ChisqSelectorModelInfo; +import com.alibaba.alink.operator.common.feature.ChisqSelectorModelInfoBatchOp; +import com.alibaba.alink.operator.common.feature.ChisqSelectorUtil; import com.alibaba.alink.operator.common.statistics.ChiSquareTestUtil; import com.alibaba.alink.params.feature.VectorChiSqSelectorParams; +import org.apache.flink.api.java.DataSet; import org.apache.flink.ml.api.misc.param.Params; -import org.apache.flink.util.Preconditions; +import org.apache.flink.types.Row; + +/** + * chi-square selector for vector. + */ public final class VectorChiSqSelectorBatchOp extends BatchOperator - implements VectorChiSqSelectorParams { + implements VectorChiSqSelectorParams, + WithModelInfoBatchOp { + + private static final long serialVersionUID = 2668694739982519452L; - public VectorChiSqSelectorBatchOp() { + public VectorChiSqSelectorBatchOp() { super(null); } @@ -31,15 +44,21 @@ public VectorChiSqSelectorBatchOp linkFrom(BatchOperator... inputs) { double fdr = getParams().get(FDR); double fwe = getParams().get(FWE); - setOutputTable(ChiSquareTestUtil.vectorSelector(in, vectorColName, labelColName, - selectorType, numTopFeatures, percentile, fpr, fdr, fwe)); + DataSet chiSquareTest = + ChiSquareTestUtil.vectorTest(in, vectorColName, labelColName); + + DataSet model = chiSquareTest.mapPartition( + new ChisqSelectorUtil.ChiSquareSelector(null, selectorType, numTopFeatures, percentile, fpr, fdr, fwe)) + .name("FilterFeature") + .setParallelism(1); + + setOutputTable(DataSetConversionUtil.toTable(in.getMLEnvironmentId(), model, new ChiSqSelectorModelDataConverter().getModelSchema())); return this; } - public int[] collectResult() { - Preconditions.checkArgument(null != this.getOutputTable(), "Please link from or link to."); - return new ChiSqSelectorModelDataConverter().load(this.collect()); + @Override + public ChisqSelectorModelInfoBatchOp getModelInfoBatchOp() { + return new ChisqSelectorModelInfoBatchOp(getParams()).linkFrom(this); } - } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOp.java index 76c1a33c5..2ccf315f7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOp.java @@ -2,6 +2,7 @@ import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.operator.common.statistics.ChiSquareTestResult; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; import org.apache.commons.lang.ArrayUtils; import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.utils.TableUtil; @@ -12,6 +13,7 @@ import org.apache.flink.util.Preconditions; import java.util.List; +import java.util.function.Consumer; /** * Chi-square test is chi-square independence test. @@ -54,17 +56,70 @@ public ChiSquareTestBatchOp linkFrom(BatchOperator... inputs) { this.setOutputTable(ChiSquareTestUtil.buildResult( ChiSquareTestUtil.test(in, selectedColNames, labelColName), selectedColNames, + null, getMLEnvironmentId())); return this; } - public ChiSquareTestResult[] collectChiSquareTestResult() { + /** + * Collect result. + */ + public ChiSquareTestResult[] collectChiSquareTest() { Preconditions.checkArgument(null != this.getOutputTable(), "Please link from or link to."); + return toResult(this.collect()); + } + + /** + * lazy collect result. + */ + @SafeVarargs + public final ChiSquareTestBatchOp lazyCollectChiSquareTest(Consumer... callbacks) { + this.lazyCollect(d -> { + ChiSquareTestResult[] summary = toResult(d); + for (Consumer callback : callbacks) { + callback.accept(summary); + } + }); + return this; + } + + /** + * lazy print. + */ + public final ChiSquareTestBatchOp lazyPrintChiSquareTest() { + return lazyPrintChiSquareTest(null); + } - List rows = this.collect(); + /** + * lazy print with title. + */ + public final ChiSquareTestBatchOp lazyPrintChiSquareTest(String title) { + lazyCollectChiSquareTest(new Consumer() { + @Override + public void accept(ChiSquareTestResult[] summary) { + if (title != null) { + System.out.println(title); + } + + System.out.println(PrettyDisplayUtils.displayHeadline("ChiSquareTest", '-')); + Object[][] data = new Object[summary.length][3]; + for (int i = 0; i < summary.length; i++) { + data[i][0] = summary[i].getP(); + data[i][1] = summary[i].getValue(); + data[i][2] = summary[i].getDf(); + } + String re = PrettyDisplayUtils.displayTable(data, summary.length, 3, + getSelectedCols(), new String[]{"p", "value", "df"}, "col"); + System.out.println(re); + } + }); + return this; + } + + private ChiSquareTestResult[] toResult(List rows) { //get result ChiSquareTestResult[] result = new ChiSquareTestResult[rows.size()]; String[] selectedColNames = getSelectedCols(); @@ -80,5 +135,4 @@ public ChiSquareTestResult[] collectChiSquareTestResult() { return result; } - } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/CorrelationBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/CorrelationBatchOp.java index edb65e283..feb04b382 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/CorrelationBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/CorrelationBatchOp.java @@ -9,6 +9,7 @@ import com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult; import com.alibaba.alink.operator.common.statistics.basicstatistic.SpearmanCorrelation; import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; import com.alibaba.alink.params.statistics.CorrelationParams; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -20,6 +21,8 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; +import java.util.function.Consumer; + /** * Calculating the correlation between two series of data is a common operation in Statistics. */ @@ -94,6 +97,37 @@ public CorrelationResult collectCorrelationResult() { Preconditions.checkArgument(null != this.getOutputTable(), "Please link from or link to."); return new CorrelationDataConverter().load(this.collect()); } + + @SafeVarargs + public final CorrelationBatchOp lazyCollectCorrelation(Consumer... callbacks) { + this.lazyCollect(d -> { + CorrelationResult correlationResult = new CorrelationDataConverter().load(d); + for (Consumer callback : callbacks) { + callback.accept(correlationResult); + } + }); + return this; + } + + public final CorrelationBatchOp lazyPrintCorrelation() { + return lazyPrintCorrelation(null); + } + + public final CorrelationBatchOp lazyPrintCorrelation(String title) { + lazyCollectCorrelation(new Consumer() { + @Override + public void accept(CorrelationResult summary) { + if (title != null) { + System.out.println(title); + } + + System.out.println(PrettyDisplayUtils.displayHeadline("Correlation", '-')); + System.out.println(summary.toString()); + } + }); + return this; + } + } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOp.java index 33fc77494..dc1fe53a8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOp.java @@ -2,6 +2,7 @@ import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.operator.common.statistics.ChiSquareTestResult; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.utils.TableUtil; @@ -12,6 +13,7 @@ import org.apache.flink.util.Preconditions; import java.util.List; +import java.util.function.Consumer; /** * Chi-square test is chi-square independence test. @@ -32,7 +34,6 @@ public VectorChiSquareTestBatchOp() { public VectorChiSquareTestBatchOp(Params params) { super(params); } - /** * overwrite linkFrom in BatchOperator * @@ -55,6 +56,7 @@ public VectorChiSquareTestBatchOp linkFrom(BatchOperator... inputs) { this.setOutputTable(ChiSquareTestUtil.buildResult( ChiSquareTestUtil.vectorTest(in, selectedColName, labelColName) , null, + selectedColName, getMLEnvironmentId())); return this; @@ -63,14 +65,56 @@ public VectorChiSquareTestBatchOp linkFrom(BatchOperator... inputs) { /** * @return ChiSquareTestResult[] */ - public ChiSquareTestResult[] collectChiSquareTestResult() { + public ChiSquareTestResult[] collectChiSquareTest() { Preconditions.checkArgument(null != this.getOutputTable(), "Please link from or link to."); - List rows = this.collect(); + return toResult(this.collect()); + } + + @SafeVarargs + public final VectorChiSquareTestBatchOp lazyCollectChiSquareTest(Consumer... callbacks) { + this.lazyCollect(d -> { + ChiSquareTestResult[] summary = toResult(d); + for (Consumer callback : callbacks) { + callback.accept(summary); + } + }); + return this; + } + + public final VectorChiSquareTestBatchOp lazyPrintChiSquareTest() { + return lazyPrintChiSquareTest(null); + } + + public final VectorChiSquareTestBatchOp lazyPrintChiSquareTest(String title) { + lazyCollectChiSquareTest(new Consumer() { + @Override + public void accept(ChiSquareTestResult[] summary) { + if (title != null) { + System.out.println(title); + } + + System.out.println(PrettyDisplayUtils.displayHeadline("ChiSquareTest", '-')); + Object[][] data = new Object[summary.length][3]; + String[] colNames = new String[summary.length]; + for (int i = 0; i < summary.length; i++) { + data[i][0] = summary[i].getP(); + data[i][1] = summary[i].getValue(); + data[i][2] = summary[i].getDf(); + colNames[i] = String.valueOf(i); + } + String re = PrettyDisplayUtils.displayTable(data, summary.length, 3, + colNames, new String[]{"p", "value", "df"}, "col"); + System.out.println(re); + } + }); + return this; + } + private ChiSquareTestResult[] toResult(List rows) { ChiSquareTestResult[] result = new ChiSquareTestResult[rows.size()]; for (Row row : rows) { - int id = ((Long) row.getField(0)).intValue(); + int id = Integer.parseInt(String.valueOf(row.getField(0))); result[id] = JsonConverter.fromJson((String) row.getField(1), ChiSquareTestResult.class); } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java index fe60d5247..ba8f17ddd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java @@ -8,6 +8,7 @@ import com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult; import com.alibaba.alink.operator.common.statistics.basicstatistic.SpearmanCorrelation; import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; @@ -20,6 +21,8 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; +import java.util.function.Consumer; + /** * Calculating the correlation between two series of data is a common operation in Statistics. */ @@ -80,6 +83,36 @@ public CorrelationResult collectCorrelation() { Preconditions.checkArgument(null != this.getOutputTable(), "Please link from or link to."); return new CorrelationDataConverter().load(this.collect()); } + + @SafeVarargs + public final VectorCorrelationBatchOp lazyCollectCorrelation(Consumer... callbacks) { + this.lazyCollect(d -> { + CorrelationResult correlationResult = new CorrelationDataConverter().load(d); + for (Consumer callback : callbacks) { + callback.accept(correlationResult); + } + }); + return this; + } + + public final VectorCorrelationBatchOp lazyPrintCorrelation() { + return lazyPrintCorrelation(null); + } + + public final VectorCorrelationBatchOp lazyPrintCorrelation(String title) { + lazyCollectCorrelation(new Consumer() { + @Override + public void accept(CorrelationResult summary) { + if (title != null) { + System.out.println(title); + } + System.out.println(PrettyDisplayUtils.displayHeadline("Correlation", '-')); + System.out.println(summary.toString()); + } + }); + return this; + } + } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorSummarizerBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorSummarizerBatchOp.java index 3967129e1..94248014f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorSummarizerBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorSummarizerBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.operator.common.statistics.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.operator.common.statistics.basicstatistic.VectorSummaryDataConverter; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; import com.alibaba.alink.params.statistics.VectorSummarizerParams; import org.apache.flink.api.common.functions.FlatMapFunction; @@ -14,6 +15,8 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; +import java.util.function.Consumer; + /** * It is summary of table, support count, mean, variance, min, max, sum. */ @@ -43,6 +46,42 @@ public VectorSummarizerBatchOp linkFrom(BatchOperator... inputs) { return this; } + + + public BaseVectorSummary collectVectorSummary() { + Preconditions.checkArgument(null != this.getOutputTable(), "Please link from or link to."); + return new VectorSummaryDataConverter().load(this.collect()); + } + + @SafeVarargs + public final VectorSummarizerBatchOp lazyCollectVectorSummary(Consumer... callbacks) { + this.lazyCollect(d -> { + BaseVectorSummary summary = new VectorSummaryDataConverter().load(d); + for (Consumer callback : callbacks) { + callback.accept(summary); + } + }); + return this; + } + + public final VectorSummarizerBatchOp lazyPrintVectorSummary() { + return lazyPrintVectorSummary(null); + } + + public final VectorSummarizerBatchOp lazyPrintVectorSummary(String title) { + lazyCollectVectorSummary(new Consumer() { + @Override + public void accept(BaseVectorSummary summary) { + if (title != null) { + System.out.println(title); + } + System.out.println(PrettyDisplayUtils.displayHeadline("Summary", '-')); + System.out.println(summary.toString()); + } + }); + return this; + } + /** * vector summary build model. */ @@ -60,9 +99,5 @@ public void flatMap(BaseVectorSummary srt, Collector collector) throws Exce } } - public BaseVectorSummary collectVectorSummary() { - Preconditions.checkArgument(null != this.getOutputTable(), "Please link from or link to."); - return new VectorSummaryDataConverter().load(this.collect()); - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/ChiSqSelectorModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/ChiSqSelectorModelDataConverter.java index bc94ccf77..53704f7c2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/ChiSqSelectorModelDataConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/ChiSqSelectorModelDataConverter.java @@ -1,15 +1,17 @@ package com.alibaba.alink.operator.common.feature; +import com.alibaba.alink.common.model.SimpleModelDataConverter; import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.params.feature.BasedChisqSelectorParams; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.api.misc.param.Params; -import com.alibaba.alink.common.model.SimpleModelDataConverter; + import java.util.Collections; /** * ChiSqSelector model. */ -public class ChiSqSelectorModelDataConverter extends SimpleModelDataConverter { +public class ChiSqSelectorModelDataConverter extends SimpleModelDataConverter { public ChiSqSelectorModelDataConverter() { } @@ -17,23 +19,22 @@ public ChiSqSelectorModelDataConverter() { /** * Serialize the model to "Tuple2>" * - * @param modelData: selected col indices + * @param modelInfo: selected col indices */ @Override - public Tuple2> serializeModel(int[] modelData) { - return Tuple2.of(new Params(), Collections.singletonList(JsonConverter.toJson(modelData))); + public Tuple2> serializeModel(ChisqSelectorModelInfo modelInfo) { + return Tuple2.of(new Params(), Collections.singletonList(JsonConverter.toJson(modelInfo))); } /** - * - * @param meta The model meta data. + * @param meta The model meta data. * @param modelData: json * @return */ @Override - public int[] deserializeModel(Params meta, Iterable modelData) { + public ChisqSelectorModelInfo deserializeModel(Params meta, Iterable modelData) { String json = modelData.iterator().next(); - return JsonConverter.fromJson(json, int[].class); + return JsonConverter.fromJson(json, ChisqSelectorModelInfo.class); } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorModelInfo.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorModelInfo.java new file mode 100644 index 000000000..8e0eb4aa8 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorModelInfo.java @@ -0,0 +1,149 @@ +package com.alibaba.alink.operator.common.feature; + +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.statistics.ChiSquareTestResult; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; +import com.alibaba.alink.params.feature.BasedChisqSelectorParams; +import org.apache.flink.types.Row; + +import java.util.Arrays; +import java.util.List; + +/** + * Chisq selector model info. + */ +public class ChisqSelectorModelInfo { + protected ChiSquareTestResult[] chiSqs; + protected String[] colNames; + protected String[] siftOutColNames; + protected BasedChisqSelectorParams.SelectorType selectorType; + protected int numTopFeatures; + protected double percentile; + protected double fpr; + protected double fdr; + protected double fwe; + + public ChisqSelectorModelInfo() { + + } + + public ChisqSelectorModelInfo(List rows) { + ChisqSelectorModelInfo modelInfo = new ChiSqSelectorModelDataConverter().load(rows); + this.chiSqs = modelInfo.chiSqs; + this.colNames = modelInfo.colNames; + this.siftOutColNames = modelInfo.siftOutColNames; + this.selectorType = modelInfo.selectorType; + this.numTopFeatures = modelInfo.numTopFeatures; + this.percentile = modelInfo.percentile; + this.fpr = modelInfo.fpr; + this.fdr = modelInfo.fdr; + this.fwe = modelInfo.fwe; + } + + public double chisq(String colName) { + return chiSqs[getIdx(chiSqs, colName)].getValue(); + } + + public double pValue(String colName) { + return chiSqs[getIdx(chiSqs, colName)].getP(); + } + + public BasedChisqSelectorParams.SelectorType getSelectorType() { + return selectorType; + } + + public int getNumTopFeatures() { + return numTopFeatures; + } + + public double getPercentile() { + return percentile; + } + + public double getFpr() { + return fpr; + } + + public double getFdr() { + return fdr; + } + + public double getFwe() { + return fwe; + } + + public double getSelectorNum() { + return this.siftOutColNames.length; + } + + public String[] getColNames() { + return this.colNames; + } + + public String[] getSiftOutColNames() { + return this.siftOutColNames; + } + + + @Override + public String toString() { + int n = this.chiSqs.length; + StringBuilder sbd = new StringBuilder() + .append(PrettyDisplayUtils.displayHeadline("ChisqSelectorModelInfo", '-')); + + sbd.append("Number of Selector Features: " + getSelectorNum() + "\n"); + sbd.append("Type of Selector: " + this.selectorType.name() + "\n"); + + switch (this.selectorType) { + case NumTopFeatures: + sbd.append("Number of Top Features: " + this.numTopFeatures + "\n"); + break; + case PERCENTILE: + sbd.append("Percentile of Features: " + this.percentile + "\n"); + break; + case FDR: + sbd.append("FDR of Features: " + this.fdr + "\n"); + break; + case FPR: + sbd.append("FPR of Features: " + this.fpr + "\n"); + break; + case FWE: + sbd.append("FWE of Features: " + this.fwe + "\n"); + break; + + } + String[] colcolNames = new String[]{"ColName", "ChiSquare", "PValue", "DF", "Selected"}; + Object[][] vals = new Object[n][5]; + + if (colNames == null) { + colcolNames[0] = "VectorIndex"; + } + List chisqList = Arrays.asList(chiSqs); + chisqList.sort(new ChisqSelectorUtil.RowAscComparator(false, true)); + for (int i = 0; i < n && i < chiSqs.length; i++) { + ChiSquareTestResult chisq = chisqList.get(i); + vals[i][0] = chisq.getColName(); + vals[i][1] = chisq.getValue(); + vals[i][2] = chisq.getP(); + vals[i][3] = chisq.getDf(); + vals[i][4] = i < getSelectorNum(); + } + + sbd.append("Selector Indices: " + "\n"); + sbd.append(PrettyDisplayUtils.displayTable(vals, n, 5, null, colcolNames, null)); + + return sbd.toString(); + } + + static int getIdx(ChiSquareTestResult[] test, String colName) { + for (int i = 0; i < test.length; i++) { + if (colName.equals(test[i].getColName())) { + return i; + } + } + return -1; + } + + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorModelInfoBatchOp.java new file mode 100644 index 000000000..dc21c4765 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorModelInfoBatchOp.java @@ -0,0 +1,27 @@ +package com.alibaba.alink.operator.common.feature; + +import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import java.util.List; + +/** + * Chisq selector model info. + */ +public class ChisqSelectorModelInfoBatchOp + extends ExtractModelInfoBatchOp { + + public ChisqSelectorModelInfoBatchOp() { + this(null); + } + + public ChisqSelectorModelInfoBatchOp(Params params) { + super(params); + } + + @Override + protected ChisqSelectorModelInfo createModelInfo(List rows) { + return new ChisqSelectorModelInfo(rows); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorUtil.java new file mode 100644 index 000000000..51c24c0dc --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/ChisqSelectorUtil.java @@ -0,0 +1,206 @@ +package com.alibaba.alink.operator.common.feature; + +import com.alibaba.alink.operator.common.statistics.ChiSquareTestResult; +import com.alibaba.alink.params.feature.BasedChisqSelectorParams; +import com.google.common.primitives.Ints; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; + +/** + * Chisq selector util. + */ +public class ChisqSelectorUtil { + + /** + * chi-square selector for table data. + * + * @param chiSquareTest: first entry is colIdx, second entry is chi-square test. + * @param selectorType: "NumTopFeatures", "percentile", "fpr", "fdr", "fwe" + * @param numTopFeatures: if selectorType is NumTopFeatures, select the largest NumTopFeatures features. + * @param percentile: if selectorType is percentile, select the largest percentile * numFeatures features. + * @param fpr: if selectorType is fpr, select feature which chi-square value less than fpr. + * @param fdr: if selectorType is fdr, select feature which chi-square value less than fdr * (i + 1) / n. + * @param fwe: if selectorType is fwe, select feature which chi-square value less than fwe / n. + * @return selected col indices. + */ + public static int[] selector(List chiSquareTest, + BasedChisqSelectorParams.SelectorType selectorType, + int numTopFeatures, + double percentile, + double fpr, + double fdr, + double fwe) { + + + int len = chiSquareTest.size(); + + List selectedColIndices = new ArrayList<>(); + switch (selectorType) { + case NumTopFeatures: + chiSquareTest.sort(new RowAscComparator(false, true)); + + for (int i = 0; i < numTopFeatures && i < len; i++) { + selectedColIndices.add(getIdx(chiSquareTest.get(i))); + } + + break; + case PERCENTILE: + chiSquareTest.sort(new RowAscComparator(false, true)); + int size = (int) (len * percentile); + if (size == 0) { + size = 1; + } + for (int i = 0; i < size && i < len; i++) { + selectedColIndices.add(getIdx(chiSquareTest.get(i))); + } + break; + case FPR: + for (ChiSquareTestResult row : chiSquareTest) { + if (row.getValue() < fpr) { + selectedColIndices.add(getIdx(row)); + } + } + break; + case FDR: + chiSquareTest.sort(new RowAscComparator(false, true)); + int maxIdx = 0; + for (int i = 0; i < len; i++) { + ChiSquareTestResult row = chiSquareTest.get(i); + if (row.getValue() <= fdr * (i + 1) / len) { + maxIdx = i; + } + } + + for (int i = 0; i <= maxIdx; i++) { + selectedColIndices.add(getIdx(chiSquareTest.get(i))); + } + Collections.sort(selectedColIndices); + break; + case FWE: + for (ChiSquareTestResult row : chiSquareTest) { + if (row.getValue() <= fwe / len) { + selectedColIndices.add(getIdx(row)); + } + } + break; + } + + return Ints.toArray(selectedColIndices); + } + + + /** + * chi-square selector and build model. + */ + public static class ChiSquareSelector implements MapPartitionFunction { + private static final long serialVersionUID = -482962272562482883L; + private String[] cols; + private BasedChisqSelectorParams.SelectorType selectorType; + private int numTopFeatures; + private double percentile; + private double fpr; + private double fdr; + private double fwe; + + public ChiSquareSelector(String[] cols, + BasedChisqSelectorParams.SelectorType selectorType, int numTopFeatures, + double percentile, double fpr, + double fdr, double fwe) { + this.cols = cols; + this.selectorType = selectorType; + this.numTopFeatures = numTopFeatures; + this.percentile = percentile; + this.fpr = fpr; + this.fdr = fdr; + this.fwe = fwe; + } + + @Override + public void mapPartition(Iterable iterable, Collector collector) { + List chiSquareTest = new ArrayList<>(); + for (Row row : iterable) { + //f0: id, f1:p, f2: chisq, f3: df + chiSquareTest.add(new ChiSquareTestResult( + (double) row.getField(3), + (double) row.getField(1), + (double) row.getField(2), + row.getField(0).toString())); + } + + int[] selectedIndices = selector(chiSquareTest, + selectorType, + numTopFeatures, + percentile, + fpr, + fdr, + fwe); + + ChisqSelectorModelInfo modelInfo = new ChisqSelectorModelInfo(); + modelInfo.chiSqs = chiSquareTest.toArray(new ChiSquareTestResult[0]); + modelInfo.colNames = cols; + modelInfo.fwe = fwe; + modelInfo.fdr = fdr; + modelInfo.fpr = fpr; + modelInfo.percentile = percentile; + modelInfo.numTopFeatures = numTopFeatures; + modelInfo.selectorType = selectorType; + + modelInfo.siftOutColNames = new String[selectedIndices.length]; + + for (int i = 0; i < selectedIndices.length; i++) { + modelInfo.siftOutColNames[i] = + cols == null ? String.valueOf(selectedIndices[i]) : cols[selectedIndices[i]]; + } + if (cols != null) { + for (int i = 0; i < modelInfo.chiSqs.length; i++) { + modelInfo.chiSqs[i].setColName(cols[getIdx(modelInfo.chiSqs[i])]); + } + } + + new ChiSqSelectorModelDataConverter().save(modelInfo, collector); + } + } + + /** + * row asc comparator. + */ + static class RowAscComparator implements Comparator { + private boolean isChisq; + private boolean isDes; + + public RowAscComparator(boolean isChisq, boolean isDes) { + this.isChisq = isChisq; + this.isDes = isDes; + } + + @Override + public int compare(ChiSquareTestResult o1, ChiSquareTestResult o2) { + double d1; + double d2; + if (isChisq) { + d1 = o1.getValue(); + d2 = o2.getValue(); + } else { + d1 = o1.getP(); + d2 = o2.getP(); + } + + return isDes ? Double.compare(d1, d2) : -Double.compare(d1, d2); + } + } + + /** + * find index. + */ + static int getIdx(ChiSquareTestResult test) { + return (int) Math.round(Double.parseDouble(test.getColName())); + } + + +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelData.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelData.java index 402060c6a..1d9a877dd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelData.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelData.java @@ -1,5 +1,7 @@ package com.alibaba.alink.operator.common.feature.pca; +import com.alibaba.alink.params.feature.HasCalculationType; + public class PcaModelData { /** @@ -15,7 +17,7 @@ public class PcaModelData { /** * pca type */ - public String pcaType; + public HasCalculationType.CalculationType pcaType; /** * name of calculate cols @@ -70,13 +72,11 @@ public class PcaModelData { * @param vec data * @return principal component */ - public double[] calcPrinValue(double[] vec) { + double[] calcPrinValue(double[] vec) { int nx = vec.length; double[] v = new double[nx]; double[] r = new double[p]; - for (int i = 0; i < nx; i++) { - v[i] = vec[i]; - } + System.arraycopy(vec, 0, v, 0, nx); for (int k = 0; k < p; k++) { r[k] = 0; for (int i = 0; i < nx; i++) { @@ -90,24 +90,30 @@ public double[] calcPrinValue(double[] vec) { public String toString() { java.io.CharArrayWriter cw = new java.io.CharArrayWriter(); java.io.PrintWriter pw = new java.io.PrintWriter(cw, true); - int nx = nameX.length; - pw.println("Eigenvalues of the CorrelationBak : "); - pw.println(" \tEigenvalue \tProportion \tCumulative"); + int nx = featureColNames.length; + pw.println("Eigenvalues of the Correlation : "); + pw.println(" \tEigenvalue \tProportion \tCumulative"); double sum = 0; for (int i = 0; i < p; i++) { double cur = lambda[i] / nx; sum += cur; - pw.println("Prin" + (i + 1) + " \t" + lambda[i] + " \t" + cur + " \t" + sum); + pw.println("Prin" + (i + 1) + " \t" + trim(lambda[i]) + " \t" + trim(cur) + " \t" + trim(sum)); } + pw.println("Principle Components : "); for (int i = 0; i < p; i++) { - pw.print("Prin" + (i + 1) + " = " + coef[i][0] + " * " + nameX[0]); + pw.print("Prin" + (i + 1) + " = " + coef[i][0] + " * " + featureColNames[0]); for (int j = 1; j < nx; j++) { - pw.print(" + " + coef[i][j] + " * " + nameX[j]); + pw.print(" + " + coef[i][j] + " * " + featureColNames[j]); } pw.println(); } return cw.toString(); } + private String trim(double val) { + return String.format("%.8f", val); + } + + } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java index dc07020ae..9ea038d48 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java @@ -1,6 +1,7 @@ package com.alibaba.alink.operator.common.feature.pca; import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.common.mapper.ModelMapper; @@ -21,25 +22,22 @@ */ public class PcaModelMapper extends ModelMapper { + private static final long serialVersionUID = -6656670267982283314L; private PcaModelData model = null; private int[] featureIdxs = null; private boolean isVector; - private PcaPredictParams.TransformType transformType = null; - private String pcaType = null; + private HasCalculationType.CalculationType pcaType = null; private double[] sourceMean = null; private double[] sourceStd = null; - private double[] scoreStd = null; private OutputColsHelper outputColsHelper; public PcaModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { super(modelSchema, dataSchema, params); - transformType = this.params.get(PcaPredictParams.TRANSFORM_TYPE); - String[] keepColNames = this.params.get(PcaPredictParams.RESERVED_COLS); String predResultColName = this.params.get(PcaPredictParams.PREDICTION_COL); this.outputColsHelper = new OutputColsHelper(dataSchema, predResultColName, Types.STRING(), keepColNames); @@ -76,42 +74,15 @@ public void loadModel(List modelRows) { this.featureIdxs = checkGetColIndices(isVector, featureColNames, vectorColName); this.pcaType = model.pcaType; int nx = model.means.length; - int p = model.p; - - HasCalculationType.CalculationType pcaTypeEnum = HasCalculationType.CalculationType.valueOf(this.pcaType.toUpperCase()); //transform mean, stdDevs and scoreStd sourceMean = new double[nx]; sourceStd = new double[nx]; - scoreStd = new double[nx]; - Arrays.fill(sourceStd, 1); - Arrays.fill(scoreStd, 1); - if (HasCalculationType.CalculationType.CORR.equals(pcaTypeEnum)) { + if (HasCalculationType.CalculationType.CORR == this.pcaType) { sourceStd = model.stddevs; - } - - switch (transformType) { - case SUBMEAN: - sourceMean = model.means; - break; - case NORMALIZATION: - sourceMean = model.means; - for (int i = 0; i < p; i++) { - double tmp = 0; - for (int j = 0; j < nx; j++) { - for (int k = 0; k < nx; k++) { - tmp += model.coef[i][j] * model.coef[i][k] * model.cov[j][k]; - } - } - scoreStd[i] = Math.sqrt(tmp); - } - break; - case SIMPLE: - break; - default: - throw new IllegalArgumentException("Error transformType: " + transformType); + sourceMean = model.means; } } @@ -126,6 +97,11 @@ public Row map(Row in) throws Exception { double[] data = new double[this.model.nx]; if (isVector) { Vector parsed = VectorUtil.getVector(in.getField(featureIdxs[0])); + if (parsed instanceof SparseVector) { + if (parsed.size() < 0) { + ((SparseVector) parsed).setSize(model.nx); + } + } for (int i = 0; i < parsed.size(); i++) { data[i] = parsed.get(i); } @@ -154,12 +130,7 @@ public Row map(Row in) throws Exception { } predictData = model.calcPrinValue(data); } - for (int i = 0; i < predictData.length; i++) { - predictData[i] /= this.scoreStd[i]; - } return outputColsHelper.getResultRow(in, Row.of(VectorUtil.toString(new DenseVector(predictData)))); } - - } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java index 03525dc37..90fd2c3b5 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java @@ -1,9 +1,10 @@ package com.alibaba.alink.operator.common.statistics; +import com.alibaba.alink.common.utils.DataSetConversionUtil; +import org.apache.commons.math3.distribution.ChiSquaredDistribution; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.GroupReduceFunction; import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; @@ -13,17 +14,7 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.utils.DataSetConversionUtil; -import com.alibaba.alink.operator.common.feature.ChiSqSelectorModelDataConverter; -import com.alibaba.alink.params.feature.BasedChisqSelectorParams; -import com.google.common.primitives.Ints; -import org.apache.commons.math3.distribution.ChiSquaredDistribution; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; import java.util.HashMap; -import java.util.List; import java.util.Map; /** @@ -32,15 +23,16 @@ public class ChiSquareTest { /** - * @param in: the last col is label col, others are selectedCols. + * @param in: the last col is label col, others are selectedCols. * @param sessionId: sessionId * @return 3 cols, 1th is colId, 2th is pValue, 3th is chi-square value - * */ protected static DataSet test(DataSet in, Long sessionId) { //flatting data to triple. DataSet dataSet = in .flatMap(new FlatMapFunction() { + private static final long serialVersionUID = -5007568317570417558L; + @Override public void flatMap(Row row, Collector result) { int n = row.getArity() - 1; @@ -68,13 +60,15 @@ public void flatMap(Row row, Collector result) { .select("col,feature,label,count(1) as count2")) .groupBy("col").reduceGroup( new GroupReduceFunction>() { + private static final long serialVersionUID = 3320220768468472007L; + @Override public void reduce(Iterable iterable, Collector> collector) { Map, Long> map = new HashMap<>(); int colIdx = -1; for (Row row : iterable) { map.put(Tuple2.of(row.getField(1).toString(), - row.getField(2).toString()), + row.getField(2).toString()), (long) row.getField(3)); colIdx = (Integer) row.getField(0); } @@ -84,90 +78,12 @@ public void reduce(Iterable iterable, Collector> .map(new ChiSquareTestFromCrossTable()); } - /** - * chi-square selector for table data. - * @param chiSquareTest: first entry is colIdx, second entry is chi-square test. - * @param selectorType: "numTopFeatures", "percentile", "fpr", "fdr", "fwe" - * @param numTopFeatures: if selectorType is numTopFeatures, select the largest numTopFeatures features. - * @param percentile: if selectorType is percentile, select the largest percentile * numFeatures features. - * @param fpr: if selectorType is fpr, select feature which chi-square value less than fpr. - * @param fdr: if selectorType is fdr, select feature which chi-square value less than fdr * (i + 1) / n. - * @param fwe: if selectorType is fwe, select feature which chi-square value less than fwe / n. - * @return selected col indices. - */ - protected static int[] selector(List chiSquareTest, - BasedChisqSelectorParams.SelectorType selectorType, - int numTopFeatures, - double percentile, - double fpr, - double fdr, - double fwe) { - - - int len = chiSquareTest.size(); - - List selectedColIndices = new ArrayList<>(); - switch (selectorType) { - case NumTopFeatures: - chiSquareTest.sort(new RowAscComparator()); - - for (int i = 0; i < numTopFeatures && i < len; i++) { - selectedColIndices.add((int) chiSquareTest.get(i).getField(0)); - } - - break; - case PERCENTILE: - chiSquareTest.sort(new RowAscComparator()); - int size = (int) (len * percentile); - if (size == 0) { - size = 1; - } - for (int i = 0; i < size && i < len; i++) { - selectedColIndices.add((int) chiSquareTest.get(i).getField(0)); - } - break; - case FPR: - for (Row row : chiSquareTest) { - if ((double) row.getField(1) < fpr) { - selectedColIndices.add((int) row.getField(0)); - } - } - break; - case FDR: - chiSquareTest.sort(new RowAscComparator()); - int maxIdx = 0; - for (int i = 0; i < len; i++) { - Row row = chiSquareTest.get(i); - if ((double) row.getField(1) <= fdr * (i + 1) / len) { - maxIdx = i; - } - } - - for (int i = 0; i <= maxIdx; i++) { - selectedColIndices.add((int) chiSquareTest.get(i).getField(0)); - } - Collections.sort(selectedColIndices); - break; - case FWE: - for (Row row : chiSquareTest) { - if ((double) row.getField(1) <= fwe / len) { - selectedColIndices.add((int) row.getField(0)); - } - } - break; - default: - throw new RuntimeException("Selector Type not support. " + selectorType); - } - - return Ints.toArray(selectedColIndices); - } - /** * @param crossTabWithId: f0 is id, f1 is cross table * @return tuple4: f0 is id which is id of cross table, f1 is pValue, f2 is chi-square Value, f3 is df */ - protected static Tuple4 test(Tuple2 crossTabWithId) { + public static Tuple4 test(Tuple2 crossTabWithId) { int colIdx = crossTabWithId.f0; Crosstab crosstab = crossTabWithId.f1; @@ -200,67 +116,15 @@ protected static Tuple4 test(Tuple2 { - private BasedChisqSelectorParams.SelectorType selectorType; - private int numTopFeatures; - private double percentile; - private double fpr; - private double fdr; - private double fwe; - - ChiSquareSelector(BasedChisqSelectorParams.SelectorType selectorType, int numTopFeatures, - double percentile, double fpr, - double fdr, double fwe) { - this.selectorType = selectorType; - this.numTopFeatures = numTopFeatures; - this.percentile = percentile; - this.fpr = fpr; - this.fdr = fdr; - this.fwe = fwe; - } - - @Override - public void mapPartition(Iterable iterable, Collector collector) { - List chiSquareTest = new ArrayList<>(); - for (Row row : iterable) { - chiSquareTest.add(row); - } - - int[] selectedIndices = selector(chiSquareTest, - selectorType, - numTopFeatures, - percentile, - fpr, - fdr, - fwe); - - new ChiSqSelectorModelDataConverter().save(selectedIndices, collector); - } - } - - /** - * row asc comparator. - */ - static class RowAscComparator implements Comparator { - @Override - public int compare(Row o1, Row o2) { - double d1 = (double) o1.getField(1); - double d2 = (double) o2.getField(1); - - return Double.compare(d1, d2); - } + return Tuple4.of(colIdx, p, chiSq, (double) (rowLen - 1) * (colLen - 1)); } /** * calculate chi-square test value from cross table. */ - public static class ChiSquareTestFromCrossTable implements MapFunction, Row> { + static class ChiSquareTestFromCrossTable implements MapFunction, Row> { + + private static final long serialVersionUID = 4588157669356711825L; ChiSquareTestFromCrossTable() { } @@ -276,7 +140,7 @@ public static class ChiSquareTestFromCrossTable implements MapFunction crossTabWithId) throws Exception { - Tuple4 tuple4 = test(crossTabWithId); + Tuple4 tuple4 = ChiSquareTest.test(crossTabWithId); Row row = new Row(4); row.setField(0, tuple4.f0); @@ -287,4 +151,6 @@ public Row map(Tuple2 crossTabWithId) throws Exception { return row; } } + + } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestResult.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestResult.java index 37064d5f5..5c5ec6312 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestResult.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestResult.java @@ -1,13 +1,16 @@ package com.alibaba.alink.operator.common.statistics; +import java.io.Serializable; + /** * chi-square test result. */ -public class ChiSquareTestResult { +public class ChiSquareTestResult implements Serializable { + private static final long serialVersionUID = 7324787640414737723L; /** - * comment: pearson chi-square independence test + * col name. */ - private String comment; + private String colName; /** * freedom @@ -22,29 +25,29 @@ public class ChiSquareTestResult { */ private double value; - public ChiSquareTestResult() { } + /** * @param df: degree freedom * @param p: p value * @param value: chi-square test value - * @param comment: comment + * @param colName: colName */ public ChiSquareTestResult(double df, double p, double value, - String comment) { + String colName) { this.df = df; this.p = p; this.value = value; - this.comment = comment; + this.colName = colName; } - public String getComment() { - return comment; + public String getColName() { + return colName; } public double getDf() { @@ -58,4 +61,8 @@ public double getP() { public double getValue() { return value; } + + public void setColName(String colName) { + this.colName = colName; + } } \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java index 701aa252c..9e97da9c9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java @@ -1,10 +1,8 @@ package com.alibaba.alink.operator.common.statistics; -import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.feature.ChiSqSelectorModelDataConverter; -import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.DataSetConversionUtil; -import com.alibaba.alink.params.feature.BasedChisqSelectorParams; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.operator.batch.BatchOperator; import org.apache.commons.lang3.ArrayUtils; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -20,9 +18,10 @@ public class ChiSquareTestUtil { /** * chi-square test for vector data. - * @param in: input + * + * @param in: input * @param vectorColName: vector col name - * @param labelColName: label col name + * @param labelColName: label col name * @return chi-square test result */ public static DataSet vectorTest(BatchOperator in, @@ -37,9 +36,10 @@ public static DataSet vectorTest(BatchOperator in, /** * chi-square test for table data. - * @param in: input + * + * @param in: input * @param selectedColNames: selected col names - * @param labelColName: label col name + * @param labelColName: label col name * @return chi-square test result */ public static DataSet test(BatchOperator in, @@ -49,97 +49,28 @@ public static DataSet test(BatchOperator in, return ChiSquareTest.test(in.getDataSet(), in.getMLEnvironmentId()); } - /** - * chi-square selector for table data. - * @param in: input - * @param selectedColNames: selected col names - * @param labelColName: label col name - * @param selectorType: "numTopFeatures", "percentile", "fpr", "fdr", "fwe" - * @param numTopFeatures: if selectorType is numTopFeatures, select the largest numTopFeatures features. - * @param percentile: if selectorType is percentile, select the largest percentile * numFeatures features. - * @param fpr: if selectorType is fpr, select feature which chi-square value less than fpr. - * @param fdr: if selectorType is fdr, select feature which chi-square value less than fdr * (i + 1) / n. - * @param fwe: if selectorType is fwe, select feature which chi-square value less than fwe / n. - * @return selected col indices. - */ - public static Table selector(BatchOperator in, - String[] selectedColNames, - String labelColName, - BasedChisqSelectorParams.SelectorType selectorType, - int numTopFeatures, - double percentile, - double fpr, - double fdr, - double fwe) { - DataSet chiSquareTest = - ChiSquareTestUtil.test(in, selectedColNames, labelColName); - - DataSet model = chiSquareTest.mapPartition( - new ChiSquareTest.ChiSquareSelector(selectorType, numTopFeatures, percentile, fpr, fdr, fwe)) - .name("FilterFeature") - .setParallelism(1); - - return DataSetConversionUtil.toTable(in.getMLEnvironmentId(), model, new ChiSqSelectorModelDataConverter().getModelSchema()); - } - - /** - * chi-square selector for vector data. - * @param in: input - * @param selectedColName: selected vector name - * @param labelColName: label col name - * @param selectorType: "numTopFeatures", "percentile", "fpr", "fdr", "fwe" - * @param numTopFeatures: if selectorType is numTopFeatures, select the largest numTopFeatures features. - * @param percentile: if selectorType is percentile, select the largest percentile * numFeatures features. - * @param fpr: if selectorType is fpr, select feature which chi-square value less than fpr. - * @param fdr: if selectorType is fdr, select feature which chi-square value less than fdr * (i + 1) / n. - * @param fwe: if selectorType is fwe, select feature which chi-square value less than fwe / n. - * @return selected col indices. - */ - public static Table vectorSelector(BatchOperator in, - String selectedColName, - String labelColName, - BasedChisqSelectorParams.SelectorType selectorType, - int numTopFeatures, - double percentile, - double fpr, - double fdr, - double fwe) { - DataSet chiSquareTest = - ChiSquareTestUtil.vectorTest(in, selectedColName, labelColName); - - DataSet model = chiSquareTest.mapPartition( - new ChiSquareTest.ChiSquareSelector(selectorType, numTopFeatures, percentile, fpr, fdr, fwe)) - .name("FilterFeature") - .setParallelism(1); - - return DataSetConversionUtil.toTable(in.getMLEnvironmentId(), model, new ChiSqSelectorModelDataConverter().getModelSchema()); - } - - /** * build chi-square test result, it is for table and vector. */ public static Table buildResult(DataSet in, - String[] selectedColNames, + String[] selectedCols, + String vectorCol, Long sessionId) { String[] outColNames = new String[]{"col", "chisquare_test"}; - TypeInformation[] outColTypes; - if (selectedColNames == null) { - outColTypes = new TypeInformation[]{Types.LONG, Types.STRING}; - } else { - outColTypes = new TypeInformation[]{Types.STRING, Types.STRING}; - } + TypeInformation[] outColTypes = new TypeInformation[]{Types.STRING, Types.STRING}; return DataSetConversionUtil.toTable(sessionId, - in.map(new BuildResult(selectedColNames)), + in.map(new BuildResult(selectedCols)), outColNames, outColTypes); } + /** * chi-square test build result. */ private static class BuildResult implements MapFunction { + private static final long serialVersionUID = 3043216661405231563L; private String[] selectedColNames; BuildResult(String[] selectedColNames) { @@ -148,22 +79,22 @@ private static class BuildResult implements MapFunction { @Override public Row map(Row row) throws Exception { + int id = (Integer) row.getField(0); double p = (double) row.getField(1); double value = (double) row.getField(2); double df = (double) row.getField(3); - ChiSquareTestResult ctr = new ChiSquareTestResult(df, p, value, "chi-square test"); + String colName = selectedColNames != null ? selectedColNames[id] : String.valueOf(id); + + ChiSquareTestResult ctr = new ChiSquareTestResult(df, p, value, colName); Row out = new Row(2); - int id = (Integer) row.getField(0); - if (selectedColNames != null) { - out.setField(0, selectedColNames[id]); - } else { - out.setField(0, (long) id); - } + out.setField(0, colName); + out.setField(1, JsonConverter.toJson(ctr)); return out; } } + } diff --git a/core/src/main/java/com/alibaba/alink/params/feature/HasCalculationType.java b/core/src/main/java/com/alibaba/alink/params/feature/HasCalculationType.java index 36c602cc0..4dc6251ae 100644 --- a/core/src/main/java/com/alibaba/alink/params/feature/HasCalculationType.java +++ b/core/src/main/java/com/alibaba/alink/params/feature/HasCalculationType.java @@ -1,16 +1,18 @@ package com.alibaba.alink.params.feature; +import com.alibaba.alink.params.ParamUtil; import org.apache.flink.ml.api.misc.param.ParamInfo; import org.apache.flink.ml.api.misc.param.ParamInfoFactory; import org.apache.flink.ml.api.misc.param.WithParams; -import com.alibaba.alink.params.ParamUtil; - public interface HasCalculationType extends WithParams { - + /** + * @cn-name 计算类型 + * @cn 计算类型,包含"CORR", "COV"两种。 + */ ParamInfo CALCULATION_TYPE = ParamInfoFactory .createParamInfo("calculationType", CalculationType.class) - .setDescription("compute type, be CORR, COV_SAMPLE, COVAR_POP.") + .setDescription("compute type, be CORR, COV.") .setHasDefaultValue(CalculationType.CORR) .setAlias(new String[]{"calcType", "pcaType"}) .build(); @@ -32,17 +34,12 @@ default T setCalculationType(String value) { */ enum CalculationType { /** - * correlation + * Correlation */ CORR, /** - * sample variance - */ - COV_SAMPLE, - - /** - * population variance + * Covariance */ - COVAR_POP + COV } } diff --git a/core/src/main/java/com/alibaba/alink/params/feature/PcaPredictParams.java b/core/src/main/java/com/alibaba/alink/params/feature/PcaPredictParams.java index c4baa1084..0c4c8a51d 100644 --- a/core/src/main/java/com/alibaba/alink/params/feature/PcaPredictParams.java +++ b/core/src/main/java/com/alibaba/alink/params/feature/PcaPredictParams.java @@ -1,9 +1,5 @@ package com.alibaba.alink.params.feature; -import org.apache.flink.ml.api.misc.param.ParamInfo; -import org.apache.flink.ml.api.misc.param.ParamInfoFactory; - -import com.alibaba.alink.params.ParamUtil; import com.alibaba.alink.params.shared.colname.HasPredictionCol; import com.alibaba.alink.params.shared.colname.HasReservedCols; import com.alibaba.alink.params.shared.colname.HasVectorColDefaultAsNull; @@ -15,42 +11,4 @@ public interface PcaPredictParams extends HasReservedCols, HasPredictionCol, HasVectorColDefaultAsNull { - - ParamInfo TRANSFORM_TYPE = ParamInfoFactory - .createParamInfo("transformType", TransformType.class) - .setDescription("'SIMPLE' or 'SUBMEAN', SIMPLE is data * model, SUBMEAN is (data - mean) * model") - .setHasDefaultValue(TransformType.SIMPLE) - .build(); - - default TransformType getTransformType() { - return get(TRANSFORM_TYPE); - } - - default T setTransformType(TransformType value) { - return set(TRANSFORM_TYPE, value); - } - - default T setTransformType(String value) { - return set(TRANSFORM_TYPE, ParamUtil.searchEnum(TRANSFORM_TYPE, value)); - } - - /** - * pca transform type. - */ - enum TransformType { - /** - * data * model - */ - SIMPLE, - - /** - * (data - mean) * model - */ - SUBMEAN, - - /** - * (data - mean) / stdVar * model - */ - NORMALIZATION - } } diff --git a/core/src/main/java/com/alibaba/alink/params/feature/PcaTrainParams.java b/core/src/main/java/com/alibaba/alink/params/feature/PcaTrainParams.java index 04ec6cc70..07cff613b 100644 --- a/core/src/main/java/com/alibaba/alink/params/feature/PcaTrainParams.java +++ b/core/src/main/java/com/alibaba/alink/params/feature/PcaTrainParams.java @@ -1,8 +1,5 @@ package com.alibaba.alink.params.feature; -import com.alibaba.alink.params.dataproc.HasWithMean; -import com.alibaba.alink.params.dataproc.HasWithStd; - import com.alibaba.alink.params.shared.colname.HasSelectedColsDefaultAsNull; import com.alibaba.alink.params.shared.colname.HasVectorColDefaultAsNull; @@ -10,10 +7,8 @@ * Trait for parameter PcaTrain. */ public interface PcaTrainParams extends - HasSelectedColsDefaultAsNull, - HasVectorColDefaultAsNull, - HasWithMean, - HasWithStd, - HasK, - HasCalculationType { -} + HasSelectedColsDefaultAsNull, + HasVectorColDefaultAsNull, + HasK, + HasCalculationType { +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/pipeline/feature/PCA.java b/core/src/main/java/com/alibaba/alink/pipeline/feature/PCA.java index aab35d424..0ca5b5525 100644 --- a/core/src/main/java/com/alibaba/alink/pipeline/feature/PCA.java +++ b/core/src/main/java/com/alibaba/alink/pipeline/feature/PCA.java @@ -1,5 +1,6 @@ package com.alibaba.alink.pipeline.feature; +import com.alibaba.alink.common.lazy.HasLazyPrintModelInfo; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.feature.PcaTrainBatchOp; import org.apache.flink.ml.api.misc.param.Params; @@ -14,7 +15,8 @@ */ public class PCA extends Trainer implements PcaTrainParams, - PcaPredictParams { + PcaPredictParams, + HasLazyPrintModelInfo { public PCA() { super(); diff --git a/core/src/test/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOpTest.java b/core/src/test/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOpTest.java index ec5e35016..1f71db245 100644 --- a/core/src/test/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOpTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOpTest.java @@ -1,17 +1,19 @@ package com.alibaba.alink.operator.batch.feature; +import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.MemSourceBatchOp; +import com.alibaba.alink.operator.common.feature.ChisqSelectorModelInfo; import org.apache.flink.types.Row; +import org.junit.Assert; import org.junit.Test; import java.util.Arrays; - -import static org.junit.Assert.assertArrayEquals; +import java.util.function.Consumer; public class ChiSqSelectorBatchOpTest { @Test - public void test() { + public void test() throws Exception { Row[] testArray = new Row[]{ Row.of("a", 1L, 1, 2.0, true), @@ -31,9 +33,22 @@ public void test() { selector.linkFrom(data); - String[] selectedColNames = selector.collectResult(); - - assertArrayEquals(new String[]{"f_string", "f_long"}, selectedColNames); + selector.lazyPrintModelInfo(); + + selector.lazyCollectModelInfo( + new Consumer() { + @Override + public void accept(ChisqSelectorModelInfo chisqSelectorSummary) { + Assert.assertEquals(chisqSelectorSummary.chisq("f_long"), 8.0, 10e-10); + Assert.assertEquals(chisqSelectorSummary.chisq("f_int"), 8.0, 10e-10); + Assert.assertEquals(chisqSelectorSummary.chisq("f_string"), 5.0, 10e-10); + Assert.assertEquals(chisqSelectorSummary.chisq("f_double"), 5.0, 10e-10); + Assert.assertEquals(chisqSelectorSummary.pValue("f_double"), 0.2872974951836462, 10e-10); + } + } + ); + + BatchOperator.execute(); } } \ No newline at end of file diff --git a/core/src/test/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOpTest.java b/core/src/test/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOpTest.java index b642e228c..9e1fcbbd0 100644 --- a/core/src/test/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOpTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOpTest.java @@ -1,12 +1,13 @@ package com.alibaba.alink.operator.batch.feature; import com.alibaba.alink.operator.batch.source.MemSourceBatchOp; +import com.alibaba.alink.operator.common.feature.ChisqSelectorModelInfo; import org.apache.flink.types.Row; +import org.junit.Assert; import org.junit.Test; import java.util.Arrays; - -import static org.junit.Assert.assertArrayEquals; +import java.util.function.Consumer; public class VectorChiSqSelectorBatchOpTest { @@ -32,7 +33,18 @@ public void testDense() { selector.linkFrom(source); - int[] selectedIndices = selector.collectResult(); - assertArrayEquals(selectedIndices, new int[] {2, 0}); + selector.lazyPrintModelInfo(); + + selector.lazyCollectModelInfo( + new Consumer() { + @Override + public void accept(ChisqSelectorModelInfo chisqSelectorSummary) { + Assert.assertEquals(chisqSelectorSummary.chisq("0"), 4.0, 10e-10); + Assert.assertEquals(chisqSelectorSummary.chisq("1"), 2.0, 10e-10); + Assert.assertEquals(chisqSelectorSummary.chisq("2"), 4.0, 10e-10); + } + } + ); + } } \ No newline at end of file diff --git a/core/src/test/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOpTest.java b/core/src/test/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOpTest.java index f0eb0b53a..306046979 100644 --- a/core/src/test/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOpTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOpTest.java @@ -33,6 +33,6 @@ public void test() { test.linkFrom(source); - Assert.assertEquals(test.collectChiSquareTestResult()[0].getP(), 0.004301310843500827, 10e-4); + Assert.assertEquals(test.collectChiSquareTest()[0].getP(), 0.004301310843500827, 10e-4); } } \ No newline at end of file diff --git a/core/src/test/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOpTest.java b/core/src/test/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOpTest.java index fe60ffddb..c49626f92 100644 --- a/core/src/test/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOpTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOpTest.java @@ -30,7 +30,7 @@ public void test() { test.linkFrom(source); - ChiSquareTestResult[] result = test.collectChiSquareTestResult(); + ChiSquareTestResult[] result = test.collectChiSquareTest(); Assert.assertEquals(result[0].getP(), 0.3864762307712323, 10e-4); Assert.assertEquals(result[0].getDf(), 1.0, 10e-4); diff --git a/core/src/test/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestTest.java b/core/src/test/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestTest.java index 5195fde48..b1f1ee986 100644 --- a/core/src/test/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestTest.java @@ -1,12 +1,12 @@ package com.alibaba.alink.operator.common.statistics; -import com.alibaba.alink.operator.common.statistics.ChiSquareTest; -import com.alibaba.alink.operator.common.statistics.Crosstab; +import com.alibaba.alink.operator.common.feature.ChisqSelectorUtil; +import com.alibaba.alink.params.feature.BasedChisqSelectorParams; +import org.apache.flink.api.common.functions.util.ListCollector; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.types.Row; - -import com.alibaba.alink.params.feature.BasedChisqSelectorParams; +import org.junit.Assert; import org.junit.Test; import java.util.ArrayList; @@ -48,7 +48,7 @@ public void testChiSqSelector() { assertEquals(2, selectedIndices.length); assertEquals(0, selectedIndices[0]); - assertEquals(2, selectedIndices[1]); + assertEquals(1, selectedIndices[1]); } @Test @@ -64,7 +64,7 @@ public void testChiSqSelector2() { assertEquals(2, selectedIndices.length); assertEquals(0, selectedIndices[0]); - assertEquals(2, selectedIndices[1]); + assertEquals(1, selectedIndices[1]); } @Test @@ -114,19 +114,44 @@ public void testChiSqSelector5() { assertEquals(0, selectedIndices[0]); } + @Test + public void testChisqSelectorMap() { + ChisqSelectorUtil.ChiSquareSelector selector = + new ChisqSelectorUtil.ChiSquareSelector(null, BasedChisqSelectorParams.SelectorType.NumTopFeatures, + 5, 0, 0, 0, 0); + + List rowList = new ArrayList<>(); + ListCollector rows = new ListCollector(rowList); + + List test = new ArrayList<>(); + test.add(Row.of("1", 0.1, 0.1, 1.0)); + test.add(Row.of("2", 0.2, 0.2, 2.0)); + test.add(Row.of("3", 0.3, 0.3, 3.0)); + test.add(Row.of("4", 0.4, 0.4, 4.0)); + + selector.mapPartition(test, rows); + + for(Row row: rowList) { + if((long)row.getField(0) == 1048576) { + Assert.assertEquals("{\"chiSqs\":[{\"colName\":\"1\",\"df\":1.0,\"p\":0.1,\"value\":0.1},{\"colName\":\"2\",\"df\":2.0,\"p\":0.2,\"value\":0.2},{\"colName\":\"3\",\"df\":3.0,\"p\":0.3,\"value\":0.3},{\"colName\":\"4\",\"df\":4.0,\"p\":0.4,\"value\":0.4}],\"colNames\":null,\"siftOutColNames\":[\"1\",\"2\",\"3\",\"4\"],\"selectorType\":\"NumTopFeatures\",\"numTopFeatures\":5,\"percentile\":0.0,\"fpr\":0.0,\"fdr\":0.0,\"fwe\":0.0}" + , (String)row.getField(1)); + } + } + } + private int[] testSelector(BasedChisqSelectorParams.SelectorType selectorType, int numTopFeatures, double percentile, double fpr, double fdr, double fwe) { - List data = new ArrayList<>(); - data.add(Row.of(0, 0.1, 1.0)); - data.add(Row.of(1, 0.3, 2.0)); - data.add(Row.of(2, 0.2, 4.0)); - data.add(Row.of(3, 0.4, 3.0)); - data.add(Row.of(4, 0.5, 4.0)); - - return ChiSquareTest.selector(data, selectorType, numTopFeatures, percentile, fpr, fdr, fwe); + List data = new ArrayList<>(); + data.add(new ChiSquareTestResult(0, 1.0, 0.1, "0")); + data.add(new ChiSquareTestResult(1, 2.0, 0.3, "1")); + data.add(new ChiSquareTestResult(2, 4.0, 0.2, "2")); + data.add(new ChiSquareTestResult(3, 3.0, 0.4, "3")); + data.add(new ChiSquareTestResult(4, 4.0, 0.5, "4")); + + return ChisqSelectorUtil.selector(data, selectorType, numTopFeatures, percentile, fpr, fdr, fwe); } } \ No newline at end of file diff --git a/core/src/test/java/com/alibaba/alink/pipeline/feature/PCATest.java b/core/src/test/java/com/alibaba/alink/pipeline/feature/PCATest.java index dff147460..926ce124b 100644 --- a/core/src/test/java/com/alibaba/alink/pipeline/feature/PCATest.java +++ b/core/src/test/java/com/alibaba/alink/pipeline/feature/PCATest.java @@ -1,40 +1,29 @@ package com.alibaba.alink.pipeline.feature; -import com.alibaba.alink.common.MLEnvironmentFactory; -import com.alibaba.alink.common.linalg.DenseVector; -import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.MemSourceBatchOp; import com.alibaba.alink.operator.batch.statistics.VectorSummarizerBatchOp; -import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; -import org.apache.flink.table.api.Table; -import org.apache.flink.types.Row; - -import com.alibaba.alink.common.linalg.SparseVector; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; +import java.util.function.Consumer; public class PCATest { - private Table trainSparseBatch; - private Table predictSparseBatch; + @Test + public void test() throws Exception { + testTable(); + testSparse(); + testDense(); - @Before - public void setUp() throws Exception { - genSparseTensor(); + BatchOperator.execute(); } - @Test - public void testPipeline() throws Exception { - String[] colNames = new String[] {"id", "vec"}; + private void testDense() { + String[] colNames = new String[]{"id", "vec"}; - Object[][] data = new Object[][] { + Object[][] data = new Object[][]{ {1, "0.1 0.2 0.3 0.4"}, {2, "0.2 0.1 0.2 0.6"}, {3, "0.2 0.3 0.5 0.4"}, @@ -51,6 +40,8 @@ public void testPipeline() throws Exception { .setReservedCols("id") .setVectorCol("vec"); + pca.enableLazyPrintModelInfo(); + PCAModel model = pca.fit(source); BatchOperator predict = model.transform(source); @@ -59,38 +50,91 @@ public void testPipeline() throws Exception { summarizerOp.linkFrom(predict); - BaseVectorSummary summary = summarizerOp.collectVectorSummary(); - - Assert.assertEquals(4.840575043553453, Math.abs(summary.sum().get(0)), 10e-4); + summarizerOp.lazyCollectVectorSummary( + new Consumer() { + @Override + public void accept(BaseVectorSummary summary) { + Assert.assertEquals(3.4416913763379853E-15, Math.abs(summary.sum().get(0)), 10e-8); + } + } + ); } - private void genSparseTensor() { - int row = 100; - int col = 10; - Random random = new Random(2018L); - - String[] colNames = new String[2]; - colNames[0] = "id"; - colNames[1] = "matrix"; - - int[] indices = new int[col]; - for (int i = 0; i < col; i++) { - indices[i] = i; - } - - List rows = new ArrayList<>(); - double[] data = new double[col]; - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j++) { - data[j] = random.nextDouble(); + private void testSparse() { + String[] colNames = new String[]{"id", "vec"}; + + Object[][] data = new Object[][]{ + {1, "0:0.1 1:0.2 2:0.3 3:0.4"}, + {2, "0:0.2 1:0.1 2:0.2 3:0.6"}, + {3, "0:0.2 1:0.3 2:0.5 3:0.4"}, + {4, "0:0.3 1:0.1 2:0.3 3:0.7"}, + {5, "0:0.4 1:0.2 2:0.4 3:0.4"} + }; + + MemSourceBatchOp source = new MemSourceBatchOp(data, colNames); + + PCA pca = new PCA() + .setK(3) + .setCalculationType("CORR") + .setPredictionCol("pred") + .setReservedCols("id") + .setVectorCol("vec"); + + pca.enableLazyPrintModelInfo(); + + PCAModel model = pca.fit(source); + BatchOperator predict = model.transform(source); + + VectorSummarizerBatchOp summarizerOp = new VectorSummarizerBatchOp() + .setSelectedCol("pred"); + + summarizerOp.linkFrom(predict); + + summarizerOp.lazyCollectVectorSummary(new Consumer() { + @Override + public void accept(BaseVectorSummary summary) { + Assert.assertEquals(3.4416913763379853E-15, Math.abs(summary.sum().get(0)), 10e-8); } - SparseVector tensor = new SparseVector(col, indices, data); + }); + + } + + public void testTable() throws Exception { + String[] colNames = new String[]{"id", "f0", "f1", "f2", "f3"}; + + Object[][] data = new Object[][]{ + {1, 0.1, 0.2, 0.3, 0.4}, + {2, 0.2, 0.1, 0.2, 0.6}, + {3, 0.2, 0.3, 0.5, 0.4}, + {4, 0.3, 0.1, 0.3, 0.7}, + {5, 0.4, 0.2, 0.4, 0.4} + }; + + MemSourceBatchOp source = new MemSourceBatchOp(data, colNames); + + PCA pca = new PCA() + .setK(3) + .setCalculationType("CORR") + .setPredictionCol("pred") + .setReservedCols("id") + .setSelectedCols("f0", "f1", "f2", "f3"); + + pca.enableLazyPrintModelInfo(); + + PCAModel model = pca.fit(source); + BatchOperator predict = model.transform(source); + + VectorSummarizerBatchOp summarizerOp = new VectorSummarizerBatchOp() + .setSelectedCol("pred"); - rows.add(Row.of(i, VectorUtil.toString(tensor))); - } + summarizerOp.linkFrom(predict); - trainSparseBatch = MLEnvironmentFactory.getDefault().createBatchTable(rows, colNames); - predictSparseBatch = MLEnvironmentFactory.getDefault().createBatchTable(rows, colNames); + summarizerOp.lazyCollectVectorSummary(new Consumer() { + @Override + public void accept(BaseVectorSummary summary) { + Assert.assertEquals(3.1086244689504383E-15, Math.abs(summary.sum().get(0)), 10e-8); + } + }); } } \ No newline at end of file