Skip to content

Commit

Permalink
Fix mysqlsource error when i use collect method. see #45
Browse files Browse the repository at this point in the history
  • Loading branch information
shaomengwang committed Apr 9, 2020
1 parent db3fd54 commit ede9f3f
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
53 changes: 53 additions & 0 deletions core/src/main/java/com/alibaba/alink/common/io/MySqlDB.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
package com.alibaba.alink.common.io;

import org.apache.flink.api.java.io.jdbc.JDBCInputFormat;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableSchema;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.io.annotations.DBAnnotation;
import com.alibaba.alink.common.utils.DataSetConversionUtil;
import com.alibaba.alink.common.utils.DataStreamConversionUtil;
import com.alibaba.alink.operator.common.io.csv.CsvUtil;
import com.alibaba.alink.params.io.MySqlDBParams;
import com.alibaba.alink.params.io.MySqlSourceParams;

/**
* DB of MySql.
Expand Down Expand Up @@ -94,4 +103,48 @@ public void setPort(String port) {
this.port = port;
}

@Override
public Table getStreamTable(String tableName, Params params, Long sessionId) throws Exception {
if (!params.contains(MySqlSourceParams.SCHEMA_STR)) {
return super.getStreamTable(tableName, params, sessionId);
} else {
TableSchema schema = CsvUtil.schemaStr2Schema(params.get(MySqlSourceParams.SCHEMA_STR));

JDBCInputFormat inputFormat = JDBCInputFormat.buildJDBCInputFormat()
.setUsername(getUserName())
.setPassword(getPassword())
.setDrivername(getDriverName())
.setDBUrl(getDbUrl())
.setQuery("select * from " + tableName)
.setRowTypeInfo(new RowTypeInfo(schema.getFieldTypes(), schema.getFieldNames()))
.finish();

return DataStreamConversionUtil.toTable(
sessionId,
MLEnvironmentFactory.get(sessionId).getStreamExecutionEnvironment().createInput(inputFormat),
schema.getFieldNames(), schema.getFieldTypes());
}
}

@Override
public Table getBatchTable(String tableName, Params params, Long sessionId) throws Exception {
if (!params.contains(MySqlSourceParams.SCHEMA_STR)) {
return super.getBatchTable(tableName, params, sessionId);
} else {
TableSchema schema = CsvUtil.schemaStr2Schema(params.get(MySqlSourceParams.SCHEMA_STR));

JDBCInputFormat inputFormat = JDBCInputFormat.buildJDBCInputFormat()
.setUsername(getUserName())
.setPassword(getPassword())
.setDrivername(getDriverName())
.setDBUrl(getDbUrl())
.setQuery("select * from " + tableName)
.setRowTypeInfo(new RowTypeInfo(schema.getFieldTypes(), schema.getFieldNames()))
.finish();

return DataSetConversionUtil.toTable(sessionId,
MLEnvironmentFactory.get(sessionId).getExecutionEnvironment().createInput(inputFormat),
schema.getFieldNames(), schema.getFieldTypes());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import org.apache.flink.ml.api.misc.param.WithParams;

public interface HasSchemaStr_null<T> extends WithParams<T> {
public interface HasSchemaStrDefaultAsNull<T> extends WithParams<T> {
ParamInfo <String> SCHEMA_STR = ParamInfoFactory
.createParamInfo("schemaStr", String.class)
.setDescription("Formatted schema")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@

public interface MySqlSourceParams<T> extends WithParams<T>,
MySqlDBParams <T>,
HasInputTableName <T> {
HasInputTableName <T>,
HasSchemaStrDefaultAsNull<T>{
}

0 comments on commit ede9f3f

Please sign in to comment.