Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add INSERT OVERWRITE to Trino SQL #11603

Closed
wants to merge 14 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.InsertMode;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
Expand Down Expand Up @@ -293,7 +294,7 @@ Optional<TableExecuteHandle> getTableHandleForExecute(
/**
* Begin insert query
*/
InsertTableHandle beginInsert(Session session, TableHandle tableHandle, List<ColumnHandle> columns);
InsertTableHandle beginInsert(Session session, TableHandle tableHandle, List<ColumnHandle> columns, Optional<InsertMode> insertMode);

/**
* @return whether connector handles missing columns during insert
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import io.trino.spi.connector.ConnectorViewDefinition;
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.InsertMode;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
Expand Down Expand Up @@ -900,13 +901,20 @@ public Optional<ConnectorOutputMetadata> finishCreateTable(Session session, Outp
}

@Override
public InsertTableHandle beginInsert(Session session, TableHandle tableHandle, List<ColumnHandle> columns)
public InsertTableHandle beginInsert(Session session, TableHandle tableHandle, List<ColumnHandle> columns, Optional<InsertMode> insertMode)
{
CatalogName catalogName = tableHandle.getCatalogName();
CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, catalogName);
ConnectorMetadata metadata = catalogMetadata.getMetadata(session);
ConnectorTransactionHandle transactionHandle = catalogMetadata.getTransactionHandleFor(catalogName);
ConnectorInsertTableHandle handle = metadata.beginInsert(session.toConnectorSession(catalogName), tableHandle.getConnectorHandle(), columns, getRetryPolicy(session).getRetryMode());
ConnectorInsertTableHandle handle;
if (insertMode.isPresent()) {
handle = metadata.beginInsert(session.toConnectorSession(catalogName), tableHandle.getConnectorHandle(), columns, getRetryPolicy(session).getRetryMode(), insertMode);
}
else {
handle = metadata.beginInsert(session.toConnectorSession(catalogName), tableHandle.getConnectorHandle(), columns, getRetryPolicy(session).getRetryMode());
}

return new InsertTableHandle(tableHandle.getCatalogName(), transactionHandle, handle);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnSchema;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.InsertMode;
import io.trino.spi.eventlistener.ColumnDetail;
import io.trino.spi.eventlistener.ColumnInfo;
import io.trino.spi.eventlistener.RoutineInfo;
Expand Down Expand Up @@ -1238,14 +1239,16 @@ public static final class Insert
private final TableHandle target;
private final List<ColumnHandle> columns;
private final Optional<TableLayout> newTableLayout;
private final Optional<InsertMode> insertMode;

public Insert(Table table, TableHandle target, List<ColumnHandle> columns, Optional<TableLayout> newTableLayout)
public Insert(Table table, TableHandle target, List<ColumnHandle> columns, Optional<TableLayout> newTableLayout, Optional<InsertMode> insertMode)
{
this.table = requireNonNull(table, "table is null");
this.target = requireNonNull(target, "target is null");
this.columns = requireNonNull(columns, "columns is null");
checkArgument(columns.size() > 0, "No columns given to insert");
this.newTableLayout = requireNonNull(newTableLayout, "newTableLayout is null");
// Validate like the other constructor arguments: "no mode" must be Optional.empty(), never null,
// so getInsertMode() can be dereferenced safely by the planner.
this.insertMode = requireNonNull(insertMode, "insertMode is null");
}

public Table getTable()
Expand All @@ -1267,6 +1270,11 @@ public Optional<TableLayout> getNewTableLayout()
{
return newTableLayout;
}

public Optional<InsertMode> getInsertMode()
{
return insertMode;
}
}

@Immutable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ColumnSchema;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.InsertMode;
import io.trino.spi.connector.PointerType;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.TableProcedureMetadata;
Expand Down Expand Up @@ -515,11 +516,20 @@ protected Scope visitInsert(Insert insert, Optional<Scope> scope)
insertColumns = tableColumns;
}

Optional<InsertMode> insertMode;
if (insert.isOverwrite()) {
insertMode = Optional.of(InsertMode.OVERWRITE);
}
else {
insertMode = Optional.empty();
}

analysis.setInsert(new Analysis.Insert(
insert.getTable(),
targetTableHandle.get(),
insertColumns.stream().map(columnHandles::get).collect(toImmutableList()),
newTableLayout));
newTableLayout,
insertMode));

List<Type> tableTypes = insertColumns.stream()
.map(insertColumn -> tableSchema.getColumn(insertColumn).getType())
Expand All @@ -543,7 +553,13 @@ protected Scope visitInsert(Insert insert, Optional<Scope> scope)
.map(Type::toString),
Column::new);

analysis.setUpdateType("INSERT");
if (insertMode.isPresent() && insertMode.get() == InsertMode.OVERWRITE) {
analysis.setUpdateType("INSERT OVERWRITE");
}
else {
analysis.setUpdateType("INSERT");
}

analysis.setUpdateTarget(
targetTable,
Optional.empty(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.InsertMode;
import io.trino.spi.security.AccessDeniedException;
import io.trino.spi.statistics.TableStatisticsMetadata;
import io.trino.spi.type.CharType;
Expand Down Expand Up @@ -401,7 +402,8 @@ private RelationPlan getInsertPlan(
TableHandle tableHandle,
List<ColumnHandle> insertColumns,
Optional<TableLayout> newTableLayout,
Optional<WriterTarget> materializedViewRefreshWriterTarget)
Optional<WriterTarget> materializedViewRefreshWriterTarget,
Optional<InsertMode> insertMode)
{
TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle);

Expand Down Expand Up @@ -497,7 +499,8 @@ private RelationPlan getInsertPlan(
tableHandle,
insertedTableColumnNames.stream()
.map(columns::get)
.collect(toImmutableList()));
.collect(toImmutableList()),
insertMode);
return createTableWriterPlan(
analysis,
plan.getRoot(),
Expand Down Expand Up @@ -530,7 +533,7 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement)
TableHandle tableHandle = insert.getTarget();
Query query = insertStatement.getQuery();
Optional<TableLayout> newTableLayout = insert.getNewTableLayout();
return getInsertPlan(analysis, insert.getTable(), query, tableHandle, insert.getColumns(), newTableLayout, Optional.empty());
return getInsertPlan(analysis, insert.getTable(), query, tableHandle, insert.getColumns(), newTableLayout, Optional.empty(), insert.getInsertMode());
}

private RelationPlan createRefreshMaterializedViewPlan(Analysis analysis)
Expand All @@ -553,7 +556,7 @@ private RelationPlan createRefreshMaterializedViewPlan(Analysis analysis)
viewAnalysis.getTable(),
tableHandle,
new ArrayList<>(analysis.getTables()));
return getInsertPlan(analysis, viewAnalysis.getTable(), query, tableHandle, viewAnalysis.getColumns(), newTableLayout, Optional.of(writerTarget));
// A materialized view refresh carries no insert mode. Optional.of(null) throws NullPointerException
// unconditionally, so an absent mode must be expressed as Optional.empty().
return getInsertPlan(analysis, viewAnalysis.getTable(), query, tableHandle, viewAnalysis.getColumns(), newTableLayout, Optional.of(writerTarget), Optional.empty());
}

private RelationPlan createTableWriterPlan(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ private WriterTarget createWriterTarget(WriterTarget target)
}
if (target instanceof InsertReference) {
InsertReference insert = (InsertReference) target;
return new InsertTarget(metadata.beginInsert(session, insert.getHandle(), insert.getColumns()), metadata.getTableMetadata(session, insert.getHandle()).getTable());
return new InsertTarget(metadata.beginInsert(session, insert.getHandle(), insert.getColumns(), insert.getInsertMode()), metadata.getTableMetadata(session, insert.getHandle()).getTable());
}
if (target instanceof DeleteTarget) {
DeleteTarget delete = (DeleteTarget) target;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.trino.metadata.TableLayout;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.InsertMode;
import io.trino.spi.connector.SchemaTableName;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
Expand Down Expand Up @@ -287,11 +288,13 @@ public static class InsertReference
{
private final TableHandle handle;
private final List<ColumnHandle> columns;
private final Optional<InsertMode> insertMode;

public InsertReference(TableHandle handle, List<ColumnHandle> columns)
public InsertReference(TableHandle handle, List<ColumnHandle> columns, Optional<InsertMode> insertMode)
{
this.handle = requireNonNull(handle, "handle is null");
this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null"));
// Validate like handle and columns: an absent mode is Optional.empty(), never null,
// because getInsertMode() is passed straight through to metadata.beginInsert.
this.insertMode = requireNonNull(insertMode, "insertMode is null");
}

public TableHandle getHandle()
Expand All @@ -304,6 +307,11 @@ public List<ColumnHandle> getColumns()
return columns;
}

public Optional<InsertMode> getInsertMode()
{
return insertMode;
}

@Override
public String toString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.InsertMode;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
Expand Down Expand Up @@ -369,7 +370,7 @@ public void cleanupQuery(Session session)
}

@Override
public InsertTableHandle beginInsert(Session session, TableHandle tableHandle, List<ColumnHandle> columns)
public InsertTableHandle beginInsert(Session session, TableHandle tableHandle, List<ColumnHandle> columns, Optional<InsertMode> insertMode)
{
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2054,96 +2054,108 @@ public void testExplainAnalyze()
@Test
public void testInsert()
{
assertFails("INSERT INTO t6 (a) SELECT b from t6")
testInsert("INTO");
testInsert("OVERWRITE");
}

private void testInsert(String insertMethod)
{
assertFails("INSERT " + insertMethod + " t6 (a) SELECT b from t6")
.hasErrorCode(TYPE_MISMATCH);
analyze("INSERT INTO t1 SELECT * FROM t1");
analyze("INSERT INTO t3 SELECT * FROM t3");
analyze("INSERT INTO t3 SELECT a, b FROM t3");
assertFails("INSERT INTO t1 VALUES (1, 2)")
analyze("INSERT " + insertMethod + " t1 SELECT * FROM t1");
analyze("INSERT " + insertMethod + " t3 SELECT * FROM t3");
analyze("INSERT " + insertMethod + " t3 SELECT a, b FROM t3");
assertFails("INSERT " + insertMethod + " t1 VALUES (1, 2)")
.hasErrorCode(TYPE_MISMATCH);
analyze("INSERT INTO t5 (a) VALUES(null)");
analyze("INSERT " + insertMethod + " t5 (a) VALUES(null)");

// ignore t5 hidden column
analyze("INSERT INTO t5 VALUES (1)");
analyze("INSERT " + insertMethod + " t5 VALUES (1)");

// fail if hidden column provided
assertFails("INSERT INTO t5 VALUES (1, 2)")
assertFails("INSERT " + insertMethod + " t5 VALUES (1, 2)")
.hasErrorCode(TYPE_MISMATCH);

// note b is VARCHAR, while a,c,d are BIGINT
analyze("INSERT INTO t6 (a) SELECT a from t6");
analyze("INSERT INTO t6 (a) SELECT c from t6");
analyze("INSERT INTO t6 (a,b,c,d) SELECT * from t6");
analyze("INSERT INTO t6 (A,B,C,D) SELECT * from t6");
analyze("INSERT INTO t6 (a,b,c,d) SELECT d,b,c,a from t6");
assertFails("INSERT INTO t6 (a) SELECT b from t6")
analyze("INSERT " + insertMethod + " t6 (a) SELECT a from t6");
analyze("INSERT " + insertMethod + " t6 (a) SELECT c from t6");
analyze("INSERT " + insertMethod + " t6 (a,b,c,d) SELECT * from t6");
analyze("INSERT " + insertMethod + " t6 (A,B,C,D) SELECT * from t6");
analyze("INSERT " + insertMethod + " t6 (a,b,c,d) SELECT d,b,c,a from t6");
assertFails("INSERT " + insertMethod + " t6 (a) SELECT b from t6")
.hasErrorCode(TYPE_MISMATCH);
assertFails("INSERT INTO t6 (unknown) SELECT * FROM t6")
assertFails("INSERT " + insertMethod + " t6 (unknown) SELECT * FROM t6")
.hasErrorCode(COLUMN_NOT_FOUND);
assertFails("INSERT INTO t6 (a, a) SELECT * FROM t6")
assertFails("INSERT " + insertMethod + " t6 (a, a) SELECT * FROM t6")
.hasErrorCode(DUPLICATE_COLUMN_NAME);
assertFails("INSERT INTO t6 (a, A) SELECT * FROM t6")
assertFails("INSERT " + insertMethod + " t6 (a, A) SELECT * FROM t6")
.hasErrorCode(DUPLICATE_COLUMN_NAME);

// b is bigint, while a is double, coercion is possible either way
analyze("INSERT INTO t7 (b) SELECT (a) FROM t7 ");
analyze("INSERT INTO t7 (a) SELECT (b) FROM t7");
analyze("INSERT " + insertMethod + " t7 (b) SELECT (a) FROM t7 ");
analyze("INSERT " + insertMethod + " t7 (a) SELECT (b) FROM t7");

// d is array of bigints, while c is array of doubles, coercion is possible either way
analyze("INSERT INTO t7 (d) SELECT (c) FROM t7 ");
analyze("INSERT INTO t7 (c) SELECT (d) FROM t7 ");
analyze("INSERT " + insertMethod + " t7 (d) SELECT (c) FROM t7 ");
analyze("INSERT " + insertMethod + " t7 (c) SELECT (d) FROM t7 ");

analyze("INSERT INTO t7 (d) VALUES (ARRAY[null])");
analyze("INSERT " + insertMethod + " t7 (d) VALUES (ARRAY[null])");

analyze("INSERT INTO t6 (d) VALUES (1), (2), (3)");
analyze("INSERT INTO t6 (a,b,c,d) VALUES (1, 'a', 1, 1), (2, 'b', 2, 2), (3, 'c', 3, 3), (4, 'd', 4, 4)");
analyze("INSERT " + insertMethod + " t6 (d) VALUES (1), (2), (3)");
analyze("INSERT " + insertMethod + " t6 (a,b,c,d) VALUES (1, 'a', 1, 1), (2, 'b', 2, 2), (3, 'c', 3, 3), (4, 'd', 4, 4)");

// coercion is allowed between compatible types
analyze("INSERT INTO t8 (tinyint_column, integer_column, decimal_column, real_column) VALUES (1e0, 1e0, 1e0, 1e0)");
analyze("INSERT INTO t8 (char_column, bounded_varchar_column, unbounded_varchar_column) VALUES (VARCHAR 'aa ', VARCHAR 'aa ', VARCHAR 'aa ')");
analyze("INSERT INTO t8 (tinyint_array_column) SELECT (bigint_array_column) FROM t8");
analyze("INSERT INTO t8 (row_column) VALUES (ROW(ROW(1e0, VARCHAR 'aa ')))");
analyze("INSERT INTO t8 (date_column) VALUES (TIMESTAMP '2019-11-18 22:13:40')");
analyze("INSERT " + insertMethod + " t8 (tinyint_column, integer_column, decimal_column, real_column) VALUES (1e0, 1e0, 1e0, 1e0)");
analyze("INSERT " + insertMethod + " t8 (char_column, bounded_varchar_column, unbounded_varchar_column) VALUES (VARCHAR 'aa ', VARCHAR 'aa ', VARCHAR 'aa ')");
analyze("INSERT " + insertMethod + " t8 (tinyint_array_column) SELECT (bigint_array_column) FROM t8");
analyze("INSERT " + insertMethod + " t8 (row_column) VALUES (ROW(ROW(1e0, VARCHAR 'aa ')))");
analyze("INSERT " + insertMethod + " t8 (date_column) VALUES (TIMESTAMP '2019-11-18 22:13:40')");

// coercion is not allowed between incompatible types
assertFails("INSERT INTO t8 (integer_column) VALUES ('text')")
assertFails("INSERT " + insertMethod + " t8 (integer_column) VALUES ('text')")
.hasErrorCode(TYPE_MISMATCH);
assertFails("INSERT INTO t8 (integer_column) VALUES (true)")
assertFails("INSERT " + insertMethod + " t8 (integer_column) VALUES (true)")
.hasErrorCode(TYPE_MISMATCH);
assertFails("INSERT INTO t8 (integer_column) VALUES (ROW(ROW(1e0)))")
assertFails("INSERT " + insertMethod + " t8 (integer_column) VALUES (ROW(ROW(1e0)))")
.hasErrorCode(TYPE_MISMATCH);
assertFails("INSERT INTO t8 (integer_column) VALUES (TIMESTAMP '2019-11-18 22:13:40')")
assertFails("INSERT " + insertMethod + " t8 (integer_column) VALUES (TIMESTAMP '2019-11-18 22:13:40')")
.hasErrorCode(TYPE_MISMATCH);
assertFails("INSERT INTO t8 (unbounded_varchar_column) VALUES (1)")
assertFails("INSERT " + insertMethod + " t8 (unbounded_varchar_column) VALUES (1)")
.hasErrorCode(TYPE_MISMATCH);

// coercion with potential loss is not allowed for nested bounded character string types
assertFails("INSERT INTO t8 (nested_bounded_varchar_column) VALUES (ROW(ROW(CAST('aa' AS varchar(10)))))")
assertFails("INSERT " + insertMethod + " t8 (nested_bounded_varchar_column) VALUES (ROW(ROW(CAST('aa' AS varchar(10)))))")
.hasErrorCode(TYPE_MISMATCH);
}

@Test
public void testInvalidInsert()
{
assertFails("INSERT INTO foo VALUES (1)")
testInvalidInsert("INTO");
testInvalidInsert("OVERWRITE");
}

private void testInvalidInsert(String insertMethod)
{
assertFails("INSERT " + insertMethod + " foo VALUES (1)")
.hasErrorCode(TABLE_NOT_FOUND);
assertFails("INSERT INTO v1 VALUES (1)")
assertFails("INSERT " + insertMethod + " v1 VALUES (1)")
.hasErrorCode(NOT_SUPPORTED);

// fail if inconsistent fields count
assertFails("INSERT INTO t1 (a) VALUES (1), (1, 2)")
assertFails("INSERT " + insertMethod + " t1 (a) VALUES (1), (1, 2)")
.hasErrorCode(TYPE_MISMATCH);
assertFails("INSERT INTO t1 (a, b) VALUES (1), (1, 2)")
assertFails("INSERT " + insertMethod + " t1 (a, b) VALUES (1), (1, 2)")
.hasErrorCode(TYPE_MISMATCH);
assertFails("INSERT INTO t1 (a, b) VALUES (1, 2), (1, 2), (1, 2, 3)")
assertFails("INSERT " + insertMethod + " t1 (a, b) VALUES (1, 2), (1, 2), (1, 2, 3)")
.hasErrorCode(TYPE_MISMATCH);
assertFails("INSERT INTO t1 (a, b) VALUES ('a', 'b'), ('a', 'b', 'c')")
assertFails("INSERT " + insertMethod + " t1 (a, b) VALUES ('a', 'b'), ('a', 'b', 'c')")
.hasErrorCode(TYPE_MISMATCH);

// fail if mismatched column types
assertFails("INSERT INTO t1 (a, b) VALUES ('a', 'b'), (1, 'b')")
assertFails("INSERT " + insertMethod + " t1 (a, b) VALUES ('a', 'b'), (1, 'b')")
.hasErrorCode(TYPE_MISMATCH);
assertFails("INSERT INTO t1 (a, b) VALUES ('a', 'b'), ('a', 'b'), (1, 'b')")
assertFails("INSERT " + insertMethod + " t1 (a, b) VALUES ('a', 'b'), ('a', 'b'), (1, 'b')")
.hasErrorCode(TYPE_MISMATCH);
}

Expand Down
Loading