Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support set SeaTunnelRowTypeInfo to SeaTunnelSink #1904

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.seatunnel.api.common.PluginIdentifierInterface;
import org.apache.seatunnel.api.serialization.Serializer;
import org.apache.seatunnel.api.table.type.SeaTunnelRowTypeInfo;

import java.io.IOException;
import java.io.Serializable;
Expand All @@ -40,6 +41,15 @@
*/
public interface SeaTunnelSink<IN, StateT, CommitInfoT, AggregatedCommitInfoT> extends Serializable, PluginIdentifierInterface {

/**
* Set the row type info of sink row data. This method will be automatically called by translation.
*
* @param seaTunnelRowTypeInfo The row type info of sink.
*/
default void setTypeInfo(SeaTunnelRowTypeInfo seaTunnelRowTypeInfo) {

}

/**
* This method will be called to creat {@link SinkWriter}
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@
public interface Converter<T1, T2> {

T2 convert(T1 dataType);

T1 reconvert(T2 dataType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ public class PojoType<T> implements SeaTunnelDataType<T> {
private final Field[] fields;
private final SeaTunnelDataType<?>[] fieldTypes;

public PojoType(Class<T> pojoClass) {
this(pojoClass, null, null);
}

public PojoType(Class<T> pojoClass, Field[] fields, SeaTunnelDataType<?>[] fieldTypes) {
this.pojoClass = pojoClass;
this.fields = fields;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@
import lombok.AllArgsConstructor;
import lombok.Data;

import java.io.Serializable;

@Data
@AllArgsConstructor
public class SeaTunnelRowTypeInfo {
public class SeaTunnelRowTypeInfo implements Serializable {
private static final long serialVersionUID = 1L;

/**
* The field name of the {@link SeaTunnelRow}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.seatunnel.api.sink.SeaTunnelSink;
import org.apache.seatunnel.api.sink.SinkWriter;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowTypeInfo;
import org.apache.seatunnel.connectors.seatunnel.console.state.ConsoleState;

import com.google.auto.service.AutoService;
Expand All @@ -29,9 +30,16 @@
@AutoService(SeaTunnelSink.class)
public class ConsoleSink implements SeaTunnelSink<SeaTunnelRow, ConsoleState, ConsoleCommitInfo, ConsoleAggregatedCommitInfo> {

private SeaTunnelRowTypeInfo seaTunnelRowTypeInfo;

@Override
public void setTypeInfo(SeaTunnelRowTypeInfo seaTunnelRowTypeInfo) {
this.seaTunnelRowTypeInfo = seaTunnelRowTypeInfo;
}

@Override
public SinkWriter<SeaTunnelRow, ConsoleCommitInfo, ConsoleState> createWriter(SinkWriter.Context context) {
return new ConsoleSinkWriter();
return new ConsoleSinkWriter(seaTunnelRowTypeInfo);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.seatunnel.api.sink.SinkWriter;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowTypeInfo;
import org.apache.seatunnel.connectors.seatunnel.console.state.ConsoleState;

import org.slf4j.Logger;
Expand All @@ -30,6 +31,12 @@ public class ConsoleSinkWriter implements SinkWriter<SeaTunnelRow, ConsoleCommit

private static final Logger LOGGER = LoggerFactory.getLogger(ConsoleSinkWriter.class);

private final SeaTunnelRowTypeInfo seaTunnelRowTypeInfo;

public ConsoleSinkWriter(SeaTunnelRowTypeInfo seaTunnelRowTypeInfo) {
this.seaTunnelRowTypeInfo = seaTunnelRowTypeInfo;
}

@Override
@SuppressWarnings("checkstyle:RegexpSingleline")
public void write(SeaTunnelRow element) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,49 @@
package org.apache.seatunnel.core.flink.execution;

import org.apache.seatunnel.api.sink.SeaTunnelSink;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowTypeInfo;
import org.apache.seatunnel.flink.FlinkEnvironment;
import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelSinkPluginDiscovery;
import org.apache.seatunnel.translation.flink.sink.FlinkSinkConverter;
import org.apache.seatunnel.translation.flink.utils.TypeConverterUtils;

import org.apache.seatunnel.shade.com.typesafe.config.Config;

import com.google.common.collect.Lists;
import org.apache.flink.api.connector.sink.Sink;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.types.Row;

import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import scala.Serializable;

public class SinkExecuteProcessor extends AbstractPluginExecuteProcessor<Sink<Row, Serializable, Serializable, Serializable>> {
public class SinkExecuteProcessor extends AbstractPluginExecuteProcessor<SeaTunnelSink<SeaTunnelRow, Serializable, Serializable, Serializable>> {

protected SinkExecuteProcessor(FlinkEnvironment flinkEnvironment,
List<? extends Config> pluginConfigs) {
super(flinkEnvironment, pluginConfigs);
}

@Override
protected List<Sink<Row, Serializable, Serializable, Serializable>> initializePlugins(List<? extends Config> pluginConfigs) {
protected List<SeaTunnelSink<SeaTunnelRow, Serializable, Serializable, Serializable>> initializePlugins(List<? extends Config> pluginConfigs) {
SeaTunnelSinkPluginDiscovery sinkPluginDiscovery = new SeaTunnelSinkPluginDiscovery();
List<URL> pluginJars = new ArrayList<>();
FlinkSinkConverter<SeaTunnelRow, Row, Serializable, Serializable, Serializable> flinkSinkConverter = new FlinkSinkConverter<>();
List<Sink<Row, Serializable, Serializable, Serializable>> sinks = pluginConfigs.stream().map(sinkConfig -> {
List<SeaTunnelSink<SeaTunnelRow, Serializable, Serializable, Serializable>> sinks = pluginConfigs.stream().map(sinkConfig -> {
PluginIdentifier pluginIdentifier = PluginIdentifier.of(
"seatunnel",
"sink",
sinkConfig.getString("plugin_name"));
pluginJars.addAll(sinkPluginDiscovery.getPluginJarPaths(Lists.newArrayList(pluginIdentifier)));
SeaTunnelSink<SeaTunnelRow, Serializable, Serializable, Serializable> pluginInstance =
sinkPluginDiscovery.getPluginInstance(pluginIdentifier);
return flinkSinkConverter.convert(pluginInstance, Collections.emptyMap());
return (SeaTunnelSink<SeaTunnelRow, Serializable, Serializable, Serializable>) sinkPluginDiscovery.getPluginInstance(pluginIdentifier);
}).collect(Collectors.toList());
flinkEnvironment.registerPlugin(pluginJars);
return sinks;
Expand All @@ -68,12 +69,23 @@ protected List<Sink<Row, Serializable, Serializable, Serializable>> initializePl
@Override
public List<DataStream<Row>> execute(List<DataStream<Row>> upstreamDataStreams) throws Exception {
DataStream<Row> input = upstreamDataStreams.get(0);
FlinkSinkConverter<SeaTunnelRow, Row, Serializable, Serializable, Serializable> flinkSinkConverter = new FlinkSinkConverter<>();
for (int i = 0; i < plugins.size(); i++) {
Config sinkConfig = pluginConfigs.get(i);
SeaTunnelSink<SeaTunnelRow, Serializable, Serializable, Serializable> seaTunnelSink = plugins.get(i);
DataStream<Row> stream = fromSourceTable(sinkConfig).orElse(input);
stream.sinkTo(plugins.get(i));
seaTunnelSink.setTypeInfo(getSeaTunnelRowTypeInfo(stream));
stream.sinkTo(flinkSinkConverter.convert(seaTunnelSink, Collections.emptyMap()));
}
// the sink is the last stream
return null;
}

private SeaTunnelRowTypeInfo getSeaTunnelRowTypeInfo(DataStream<Row> stream) {
RowTypeInfo typeInformation = (RowTypeInfo) stream.getType();
String[] fieldNames = typeInformation.getFieldNames();
SeaTunnelDataType<?>[] seaTunnelDataTypes = Arrays.stream(typeInformation.getFieldTypes())
.map(TypeConverterUtils::convertType).toArray(SeaTunnelDataType[]::new);
return new SeaTunnelRowTypeInfo(fieldNames, seaTunnelDataTypes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,9 @@ public BasicArrayTypeInfo<T1, T2> convert(ArrayType<T1> arrayType) {
}
throw new IllegalArgumentException("Unsupported basic type: " + elementType);
}

@Override
public ArrayType<T1> reconvert(BasicArrayTypeInfo<T1, T2> typeInformation) {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,10 @@ public BasicTypeConverter(BasicType<T1> seaTunnelDataType, TypeInformation<T1> f
public TypeInformation<T1> convert(BasicType<T1> seaTunnelDataType) {
return flinkTypeInformation;
}

@Override
public BasicType<T1> reconvert(TypeInformation<T1> dataType) {
return seaTunnelDataType;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,13 @@ public interface FlinkTypeConverter<T1, T2> extends Converter<T1, T2> {
@Override
T2 convert(T1 seaTunnelDataType);

/**
* Convert flink {@link TypeInformation} to SeaTunnel {@link SeaTunnelDataType}.
*
* @param typeInformation flink {@link TypeInformation}
* @return seatunnel {@link SeaTunnelDataType}
*/
@Override
T1 reconvert(T2 typeInformation);

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,9 @@ public PojoTypeInfo<T1> convert(PojoType<T1> seaTunnelDataType) {
Class<T1> pojoClass = seaTunnelDataType.getPojoClass();
return (PojoTypeInfo<T1>) PojoTypeInfo.of(pojoClass);
}

@Override
public PojoType<T1> reconvert(PojoTypeInfo<T1> typeInformation) {
return new PojoType<>(typeInformation.getTypeClass());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,9 @@ private TimestampTypeConverter() {
public TimestampDataTypeInfo convert(TimestampType seaTunnelDataType) {
return new TimestampDataTypeInfo(seaTunnelDataType.getPrecision());
}

@Override
public TimestampType reconvert(TimestampDataTypeInfo typeInformation) {
return new TimestampType(typeInformation.getPrecision());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.seatunnel.translation.flink.types.TimestampTypeConverter;

import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.EnumTypeInfo;
import org.apache.flink.api.java.typeutils.ListTypeInfo;
Expand All @@ -47,6 +48,14 @@ private TypeConverterUtils() {
throw new UnsupportedOperationException("TypeConverterUtils is a utility class and cannot be instantiated");
}

public static <T1, T2> SeaTunnelDataType<T2> convertType(TypeInformation<T1> dataType) {
if (dataType instanceof BasicTypeInfo) {
return (SeaTunnelDataType<T2>) convertBasicType((BasicTypeInfo<T1>) dataType);
}
// todo:
throw new IllegalArgumentException("Unsupported data type: " + dataType);
}

@SuppressWarnings("unchecked")
public static <T1, T2> TypeInformation<T2> convertType(SeaTunnelDataType<T1> dataType) {
if (dataType instanceof BasicType) {
Expand Down Expand Up @@ -146,6 +155,77 @@ public static <T> TypeInformation<T> convertBasicType(BasicType<T> basicType) {
throw new IllegalArgumentException("Unsupported basic type: " + basicType);
}

@SuppressWarnings("unchecked")
private static <T1> SeaTunnelDataType<T1> convertBasicType(BasicTypeInfo<T1> flinkDataType) {
Class<T1> physicalTypeClass = flinkDataType.getTypeClass();
if (physicalTypeClass == Boolean.class) {
TypeInformation<Boolean> booleanTypeInformation = (TypeInformation<Boolean>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.BOOLEAN_CONVERTER.reconvert(booleanTypeInformation);
}
if (physicalTypeClass == String.class) {
TypeInformation<String> stringBasicType = (TypeInformation<String>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.STRING_CONVERTER.reconvert(stringBasicType);
}
if (physicalTypeClass == Date.class) {
TypeInformation<Date> dateBasicType = (TypeInformation<Date>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.DATE_CONVERTER.reconvert(dateBasicType);
}
if (physicalTypeClass == Double.class) {
TypeInformation<Double> doubleBasicType = (TypeInformation<Double>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.DOUBLE_CONVERTER.reconvert(doubleBasicType);
}
if (physicalTypeClass == Integer.class) {
TypeInformation<Integer> integerBasicType = (TypeInformation<Integer>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.INTEGER_CONVERTER.reconvert(integerBasicType);
}
if (physicalTypeClass == Long.class) {
TypeInformation<Long> longBasicType = (TypeInformation<Long>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.LONG_CONVERTER.reconvert(longBasicType);
}
if (physicalTypeClass == Float.class) {
TypeInformation<Float> floatBasicType = (TypeInformation<Float>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.FLOAT_CONVERTER.reconvert(floatBasicType);
}
if (physicalTypeClass == Byte.class) {
TypeInformation<Byte> byteBasicType = (TypeInformation<Byte>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.BYTE_CONVERTER.reconvert(byteBasicType);
}
if (physicalTypeClass == Short.class) {
TypeInformation<Short> shortBasicType = (TypeInformation<Short>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.SHORT_CONVERTER.reconvert(shortBasicType);
}
if (physicalTypeClass == Character.class) {
TypeInformation<Character> characterBasicType = (TypeInformation<Character>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.CHARACTER_CONVERTER.reconvert(characterBasicType);
}
if (physicalTypeClass == BigInteger.class) {
TypeInformation<BigInteger> bigIntegerBasicType = (TypeInformation<BigInteger>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.BIG_INTEGER_CONVERTER.reconvert(bigIntegerBasicType);
}
if (physicalTypeClass == BigDecimal.class) {
TypeInformation<BigDecimal> bigDecimalBasicType = (TypeInformation<BigDecimal>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.BIG_DECIMAL_CONVERTER.reconvert(bigDecimalBasicType);
}
if (physicalTypeClass == Void.class) {
TypeInformation<Void> voidBasicType = (TypeInformation<Void>) flinkDataType;
return (SeaTunnelDataType<T1>)
BasicTypeConverter.NULL_CONVERTER.reconvert(voidBasicType);
}
throw new IllegalArgumentException("Unsupported flink type: " + flinkDataType);
}

public static <T1, T2> BasicArrayTypeInfo<T1, T2> convertArrayType(ArrayType<T1> arrayType) {
ArrayTypeConverter<T1, T2> arrayTypeConverter = new ArrayTypeConverter<>();
return arrayTypeConverter.convert(arrayType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,9 @@ public org.apache.spark.sql.types.ArrayType convert(ArrayType<T1> seaTunnelDataT
DataType elementType = TypeConverterUtils.convert(seaTunnelDataType.getElementType());
return DataTypes.createArrayType(elementType);
}

@Override
public ArrayType<T1> reconvert(org.apache.spark.sql.types.ArrayType dataType) {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,9 @@ public BasicTypeConverter(BasicType<T1> seatunnelDataType, DataType sparkDataTyp
public DataType convert(BasicType<T1> seaTunnelDataType) {
return sparkDataType;
}

@Override
public BasicType<T1> reconvert(DataType dataType) {
return seatunnelDataType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,9 @@ public class PojoTypeConverter<T1> implements SparkDataTypeConverter<PojoType<T1
public ObjectType convert(PojoType<T1> seaTunnelDataType) {
return new ObjectType(seaTunnelDataType.getPojoClass());
}

@Override
public PojoType<T1> reconvert(ObjectType dataType) {
return new PojoType<>((Class<T1>) dataType.cls());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,7 @@ public interface SparkDataTypeConverter<T1, T2> extends Converter<T1, T2> {
*/
@Override
T2 convert(T1 seaTunnelDataType);

@Override
T1 reconvert(T2 dataType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,9 @@ private TimestampTypeConverter() {
public org.apache.spark.sql.types.TimestampType convert(TimestampType seaTunnelDataType) {
return (org.apache.spark.sql.types.TimestampType) DataTypes.TimestampType;
}

@Override
public TimestampType reconvert(org.apache.spark.sql.types.TimestampType dataType) {
return new TimestampType(dataType.defaultSize());
}
}