Skip to content

Commit

Permalink
[Feature][Flink] Support Decimal Type with configurable precision and…
Browse files Browse the repository at this point in the history
… scale
  • Loading branch information
CheneyYin committed Sep 2, 2023
1 parent d8c92a1 commit a105c6c
Show file tree
Hide file tree
Showing 13 changed files with 134 additions and 29 deletions.
1 change: 1 addition & 0 deletions release-note.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
- [Core] [API] Add copy method to Catalog codes (#4414)
- [Core] [API] Add options check before create source and sink and transform in FactoryUtil (#4424)
- [Core] [Shade] Add guava shade module (#4358)
- [Core] [Flink] Support Decimal Type with configurable precision and scale (#5419)

### Connector-V2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.seatunnel.shade.com.typesafe.config.Config;

import org.apache.seatunnel.api.env.EnvCommonOptions;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.common.Constants;
import org.apache.seatunnel.common.config.CheckResult;
import org.apache.seatunnel.common.constants.JobMode;
Expand Down Expand Up @@ -51,8 +52,11 @@

import java.net.URL;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

@Slf4j
Expand All @@ -64,7 +68,8 @@ public class FlinkRuntimeEnvironment implements RuntimeEnvironment {
private StreamExecutionEnvironment environment;

private StreamTableEnvironment tableEnvironment;

private Map<String, SeaTunnelRowType> stagedTypes = new LinkedHashMap<>();
private Optional<SeaTunnelRowType> defaultType = Optional.empty();
private JobMode jobMode;

private String jobName = Constants.LOGO;
Expand Down Expand Up @@ -331,6 +336,24 @@ public void registerResultTable(Config config, DataStream<Row> dataStream) {
}
}

/**
 * Stages a row type under a table name so that it can later be looked up with
 * {@link #type(String)}. A type previously staged under the same name is replaced.
 *
 * @param tblName table name used as the lookup key
 * @param type row type to associate with {@code tblName}
 */
public void stageType(String tblName, SeaTunnelRowType type) {
    this.stagedTypes.put(tblName, type);
}

/**
 * Sets the default row type returned by {@link #defaultType()}, overwriting any
 * previously stored default.
 *
 * @param type row type to store; must not be {@code null} because it is wrapped with
 *     {@link Optional#of(Object)}
 */
public void stageDefaultType(SeaTunnelRowType type) {
    defaultType = Optional.of(type);
}

/**
 * Looks up the row type staged under the given table name.
 *
 * <p>Uses a single map lookup and treats a {@code null} staged value the same as a missing
 * entry, instead of failing with a {@link NullPointerException}.
 *
 * @param tblName table name used when the type was staged
 * @return the staged row type, or {@link Optional#empty()} when nothing usable is staged
 *     under {@code tblName}
 */
public Optional<SeaTunnelRowType> type(String tblName) {
    return Optional.ofNullable(stagedTypes.get(tblName));
}

/**
 * Returns the default row type.
 *
 * @return the type stored by {@link #stageDefaultType(SeaTunnelRowType)}, or
 *     {@link Optional#empty()} when none has been stored yet
 */
public Optional<SeaTunnelRowType> defaultType() {
    return defaultType;
}

public static FlinkRuntimeEnvironment getInstance(Config config) {
if (INSTANCE == null) {
synchronized (FlinkRuntimeEnvironment.class) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelSinkPluginDiscovery;
import org.apache.seatunnel.translation.flink.sink.FlinkSink;
import org.apache.seatunnel.translation.flink.utils.TypeConverterUtils;

import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSink;
Expand Down Expand Up @@ -101,8 +100,8 @@ public List<DataStream<Row>> execute(List<DataStream<Row>> upstreamDataStreams)
SeaTunnelSink<SeaTunnelRow, Serializable, Serializable, Serializable> seaTunnelSink =
plugins.get(i);
DataStream<Row> stream = fromSourceTable(sinkConfig).orElse(input);
seaTunnelSink.setTypeInfo(
(SeaTunnelRowType) TypeConverterUtils.convert(stream.getType()));
SeaTunnelRowType sourceType = initSourceType(sinkConfig, stream);
seaTunnelSink.setTypeInfo(sourceType);
if (SupportDataSaveMode.class.isAssignableFrom(seaTunnelSink.getClass())) {
SupportDataSaveMode saveModeSink = (SupportDataSaveMode) seaTunnelSink;
DataSaveMode dataSaveMode = saveModeSink.getUserConfigSaveMode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import org.apache.seatunnel.shade.com.typesafe.config.Config;

import org.apache.seatunnel.api.common.JobContext;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.common.utils.ReflectionUtils;
import org.apache.seatunnel.core.starter.execution.PluginExecuteProcessor;
import org.apache.seatunnel.core.starter.flink.utils.TableUtil;
import org.apache.seatunnel.translation.flink.utils.TypeConverterUtils;

import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
Expand Down Expand Up @@ -86,6 +88,36 @@ protected void registerResultTable(Config pluginConfig, DataStream<Row> dataStre
flinkRuntimeEnvironment.registerResultTable(pluginConfig, dataStream);
}

/**
 * Stages the row type produced by a plugin in the runtime environment.
 *
 * <p>If the environment has no default type yet, {@code type} becomes the default; this
 * method never overwrites an existing default. When the plugin config declares
 * {@code result_table_name}, the type is also staged under that table name.
 *
 * @param pluginConfig config of the plugin that produces {@code type}
 * @param type row type produced by the plugin
 */
protected void stageType(Config pluginConfig, SeaTunnelRowType type) {
    boolean hasDefault = flinkRuntimeEnvironment.defaultType().isPresent();
    if (!hasDefault) {
        flinkRuntimeEnvironment.stageDefaultType(type);
    }

    if (!pluginConfig.hasPath("result_table_name")) {
        return;
    }
    flinkRuntimeEnvironment.stageType(pluginConfig.getString("result_table_name"), type);
}

/**
 * Resolves the staged input row type for a plugin.
 *
 * @param pluginConfig config of the consuming plugin
 * @return the type staged under the table named by {@code SOURCE_TABLE_NAME} when the config
 *     has that path, otherwise the environment's default type; either may be empty
 */
protected Optional<SeaTunnelRowType> sourceType(Config pluginConfig) {
    if (!pluginConfig.hasPath(SOURCE_TABLE_NAME)) {
        return flinkRuntimeEnvironment.defaultType();
    }
    return flinkRuntimeEnvironment.type(pluginConfig.getString(SOURCE_TABLE_NAME));
}

/**
 * Determines the input row type for a plugin.
 *
 * <p>The staged type from {@link #sourceType(Config)} is preferred. Only when none is staged
 * is the type derived by converting the Flink type information of {@code stream}.
 *
 * @param sinkConfig config of the consuming plugin
 * @param stream input stream whose Flink type is used as the fallback
 * @return the row type to use for the plugin's input
 */
protected SeaTunnelRowType initSourceType(Config sinkConfig, DataStream<Row> stream) {
    Optional<SeaTunnelRowType> staged = sourceType(sinkConfig);
    if (staged.isPresent()) {
        return staged.get();
    }
    // Nothing staged: fall back to the type converted from the stream's Flink type.
    return (SeaTunnelRowType) TypeConverterUtils.convert(stream.getType());
}

protected abstract List<T> initializePlugins(
List<URL> jarPaths, List<? extends Config> pluginConfigs);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.seatunnel.shade.com.typesafe.config.Config;

import org.apache.seatunnel.api.env.EnvCommonOptions;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.common.Constants;
import org.apache.seatunnel.common.config.CheckResult;
import org.apache.seatunnel.common.constants.JobMode;
Expand Down Expand Up @@ -51,8 +52,11 @@

import java.net.URL;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

@Slf4j
Expand All @@ -65,6 +69,9 @@ public class FlinkRuntimeEnvironment implements RuntimeEnvironment {

private StreamTableEnvironment tableEnvironment;

private Map<String, SeaTunnelRowType> stagedTypes = new LinkedHashMap<>();
private Optional<SeaTunnelRowType> defaultType = Optional.empty();

private JobMode jobMode;

private String jobName = Constants.LOGO;
Expand Down Expand Up @@ -331,6 +338,24 @@ public void registerResultTable(Config config, DataStream<Row> dataStream) {
}
}

/**
 * Stages a row type under a table name so that it can later be looked up with
 * {@link #type(String)}. A type previously staged under the same name is replaced.
 *
 * @param tblName table name used as the lookup key
 * @param type row type to associate with {@code tblName}
 */
public void stageType(String tblName, SeaTunnelRowType type) {
    this.stagedTypes.put(tblName, type);
}

/**
 * Sets the default row type returned by {@link #defaultType()}, overwriting any
 * previously stored default.
 *
 * @param type row type to store; must not be {@code null} because it is wrapped with
 *     {@link Optional#of(Object)}
 */
public void stageDefaultType(SeaTunnelRowType type) {
    defaultType = Optional.of(type);
}

/**
 * Looks up the row type staged under the given table name.
 *
 * <p>Uses a single map lookup and treats a {@code null} staged value the same as a missing
 * entry, instead of failing with a {@link NullPointerException}.
 *
 * @param tblName table name used when the type was staged
 * @return the staged row type, or {@link Optional#empty()} when nothing usable is staged
 *     under {@code tblName}
 */
public Optional<SeaTunnelRowType> type(String tblName) {
    return Optional.ofNullable(stagedTypes.get(tblName));
}

/**
 * Returns the default row type.
 *
 * @return the type stored by {@link #stageDefaultType(SeaTunnelRowType)}, or
 *     {@link Optional#empty()} when none has been stored yet
 */
public Optional<SeaTunnelRowType> defaultType() {
    return defaultType;
}

public static FlinkRuntimeEnvironment getInstance(Config config) {
if (INSTANCE == null) {
synchronized (FlinkRuntimeEnvironment.class) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelSinkPluginDiscovery;
import org.apache.seatunnel.translation.flink.sink.FlinkSink;
import org.apache.seatunnel.translation.flink.utils.TypeConverterUtils;

import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSink;
Expand Down Expand Up @@ -102,8 +101,8 @@ public List<DataStream<Row>> execute(List<DataStream<Row>> upstreamDataStreams)
SeaTunnelSink<SeaTunnelRow, Serializable, Serializable, Serializable> seaTunnelSink =
plugins.get(i);
DataStream<Row> stream = fromSourceTable(sinkConfig).orElse(input);
seaTunnelSink.setTypeInfo(
(SeaTunnelRowType) TypeConverterUtils.convert(stream.getType()));
SeaTunnelRowType sourceType = initSourceType(sinkConfig, stream);
seaTunnelSink.setTypeInfo(sourceType);
if (SupportDataSaveMode.class.isAssignableFrom(seaTunnelSink.getClass())) {
SupportDataSaveMode saveModeSink = (SupportDataSaveMode) seaTunnelSink;
DataSaveMode dataSaveMode = saveModeSink.getUserConfigSaveMode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.seatunnel.api.common.JobContext;
import org.apache.seatunnel.api.source.SeaTunnelSource;
import org.apache.seatunnel.api.source.SupportCoordinate;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.common.constants.JobMode;
import org.apache.seatunnel.core.starter.enums.PluginType;
import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
Expand Down Expand Up @@ -74,13 +75,15 @@ public List<DataStream<Row>> execute(List<DataStream<Row>> upstreamDataStreams)
boolean bounded =
internalSource.getBoundedness()
== org.apache.seatunnel.api.source.Boundedness.BOUNDED;
Config pluginConfig = pluginConfigs.get(i);
DataStreamSource<Row> sourceStream =
addSource(
executionEnvironment,
sourceFunction,
"SeaTunnel " + internalSource.getClass().getSimpleName(),
bounded);
Config pluginConfig = pluginConfigs.get(i);
stageType(pluginConfig, (SeaTunnelRowType) internalSource.getProducedType());

if (pluginConfig.hasPath(CommonOptions.PARALLELISM.key())) {
int parallelism = pluginConfig.getInt(CommonOptions.PARALLELISM.key());
sourceStream.setParallelism(parallelism);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import org.apache.seatunnel.shade.com.typesafe.config.Config;

import org.apache.seatunnel.api.common.JobContext;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.transform.SeaTunnelTransform;
import org.apache.seatunnel.core.starter.exception.TaskExecuteException;
import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
Expand Down Expand Up @@ -97,7 +97,10 @@ public List<DataStream<Row>> execute(List<DataStream<Row>> upstreamDataStreams)
SeaTunnelTransform<SeaTunnelRow> transform = plugins.get(i);
Config pluginConfig = pluginConfigs.get(i);
DataStream<Row> stream = fromSourceTable(pluginConfig).orElse(input);
input = flinkTransform(transform, stream);
SeaTunnelRowType sourceType = initSourceType(pluginConfig, stream);
transform.setTypeInfo(sourceType);
input = flinkTransform(sourceType, transform, stream);
stageType(pluginConfig, (SeaTunnelRowType) transform.getProducedType());
registerResultTable(pluginConfig, input);
result.add(input);
} catch (Exception e) {
Expand All @@ -111,11 +114,10 @@ public List<DataStream<Row>> execute(List<DataStream<Row>> upstreamDataStreams)
return result;
}

protected DataStream<Row> flinkTransform(SeaTunnelTransform transform, DataStream<Row> stream) {
SeaTunnelDataType seaTunnelDataType = TypeConverterUtils.convert(stream.getType());
transform.setTypeInfo(seaTunnelDataType);
protected DataStream<Row> flinkTransform(
SeaTunnelRowType sourceType, SeaTunnelTransform transform, DataStream<Row> stream) {
TypeInformation rowTypeInfo = TypeConverterUtils.convert(transform.getProducedType());
FlinkRowConverter transformInputRowConverter = new FlinkRowConverter(seaTunnelDataType);
FlinkRowConverter transformInputRowConverter = new FlinkRowConverter(sourceType);
FlinkRowConverter transformOutputRowConverter =
new FlinkRowConverter(transform.getProducedType());
DataStream<Row> output =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ env {

source {
FakeSource {
row.num = 100000
schema = {
fields {
c_map = "map<string, string>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ sink {
row_rules = [
{
rule_type = MAX_ROW
rule_value = 5
rule_value = 100000
}
],
field_rules = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ public void convertShortType() {

@Test
public void convertBigDecimalType() {
/**
* To solve lost precision and scale of {@link DecimalType}, use {@link
* BasicTypeInfo.STRING_TYPE_INFO} as the convert result of {@link DecimalType} instance.
*/
Assertions.assertEquals(
BasicTypeInfo.BIG_DEC_TYPE_INFO,
TypeConverterUtils.convert(new DecimalType(30, 2)));
BasicTypeInfo.STRING_TYPE_INFO, TypeConverterUtils.convert(new DecimalType(30, 2)));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.seatunnel.translation.flink.serialization;

import org.apache.seatunnel.api.table.type.DecimalType;
import org.apache.seatunnel.api.table.type.MapType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
Expand All @@ -28,6 +29,8 @@
import org.apache.flink.types.RowKind;

import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
Expand Down Expand Up @@ -68,6 +71,14 @@ private static Object convert(Object field, SeaTunnelDataType<?> dataType) {
case MAP:
return convertMap(
(Map<?, ?>) field, (MapType<?, ?>) dataType, FlinkRowConverter::convert);

/**
* To solve lost precision and scale of {@link DecimalType}, use {@link String} as
* the convert result of {@link BigDecimal} instance.
*/
case DECIMAL:
BigDecimal decimal = (BigDecimal) field;
return decimal.toString();
default:
return field;
}
Expand Down Expand Up @@ -122,6 +133,17 @@ private static Object reconvert(Object field, SeaTunnelDataType<?> dataType) {
case MAP:
return convertMap(
(Map<?, ?>) field, (MapType<?, ?>) dataType, FlinkRowConverter::reconvert);

/**
* To solve lost precision and scale of {@link DecimalType}, create {@link
* BigDecimal} instance from {@link String} type field.
*/
case DECIMAL:
    DecimalType decimalType = (DecimalType) dataType;
    // BigDecimal is immutable: setScale returns a new instance, so its result must be the
    // value returned. Otherwise the scale declared by the DecimalType is never applied.
    return new BigDecimal((String) field)
            .setScale(decimalType.getScale(), RoundingMode.HALF_UP);
default:
return field;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.MapTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo;

import java.math.BigDecimal;
import java.time.LocalDate;
Expand Down Expand Up @@ -70,11 +69,13 @@ public class TypeConverterUtils {
BridgedType.of(BasicType.DOUBLE_TYPE, BasicTypeInfo.DOUBLE_TYPE_INFO));
BRIDGED_TYPES.put(
Void.class, BridgedType.of(BasicType.VOID_TYPE, BasicTypeInfo.VOID_TYPE_INFO));
// TODO: there is a still an unresolved issue that the BigDecimal type will lose the
// precision and scale
/**
* To solve lost precision and scale of {@link DecimalType}, use {@link
* BasicTypeInfo.STRING_TYPE_INFO} as the payload of {@link DecimalType}.
*/
BRIDGED_TYPES.put(
BigDecimal.class,
BridgedType.of(new DecimalType(38, 18), BasicTypeInfo.BIG_DEC_TYPE_INFO));
BridgedType.of(new DecimalType(38, 18), BasicTypeInfo.STRING_TYPE_INFO));

// data time types
BRIDGED_TYPES.put(
Expand Down Expand Up @@ -134,10 +135,7 @@ public static SeaTunnelDataType<?> convert(TypeInformation<?> dataType) {
if (bridgedType != null) {
return bridgedType.getSeaTunnelType();
}
if (dataType instanceof BigDecimalTypeInfo) {
BigDecimalTypeInfo decimalType = (BigDecimalTypeInfo) dataType;
return new DecimalType(decimalType.precision(), decimalType.scale());
}

if (dataType instanceof MapTypeInfo) {
MapTypeInfo<?, ?> mapTypeInfo = (MapTypeInfo<?, ?>) dataType;
return new MapType<>(
Expand All @@ -160,10 +158,7 @@ public static TypeInformation<?> convert(SeaTunnelDataType<?> dataType) {
if (bridgedType != null) {
return bridgedType.getFlinkType();
}
if (dataType instanceof DecimalType) {
DecimalType decimalType = (DecimalType) dataType;
return new BigDecimalTypeInfo(decimalType.getPrecision(), decimalType.getScale());
}

if (dataType instanceof MapType) {
MapType<?, ?> mapType = (MapType<?, ?>) dataType;
return new MapTypeInfo<>(
Expand Down

0 comments on commit a105c6c

Please sign in to comment.