diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/procedure/RegisterTableProcedure.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/procedure/RegisterTableProcedure.java index a36c7edb2aac8..26c006c9440f0 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/procedure/RegisterTableProcedure.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/procedure/RegisterTableProcedure.java @@ -14,6 +14,7 @@ package io.trino.plugin.deltalake.procedure; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; import com.google.inject.Provider; import io.trino.filesystem.Location; @@ -32,6 +33,7 @@ import io.trino.plugin.deltalake.transactionlog.TransactionLogAccess; import io.trino.spi.TrinoException; import io.trino.spi.classloader.ThreadContextClassLoader; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.SchemaNotFoundException; import io.trino.spi.connector.SchemaTableName; @@ -68,7 +70,7 @@ public class RegisterTableProcedure static { try { - REGISTER_TABLE = lookup().unreflect(RegisterTableProcedure.class.getMethod("registerTable", ConnectorSession.class, String.class, String.class, String.class)); + REGISTER_TABLE = lookup().unreflect(RegisterTableProcedure.class.getMethod("registerTable", ConnectorAccessControl.class, ConnectorSession.class, String.class, String.class, String.class)); } catch (ReflectiveOperationException e) { throw new AssertionError(e); @@ -110,6 +112,7 @@ public Procedure get() } public void registerTable( + ConnectorAccessControl accessControl, ConnectorSession clientSession, String schemaName, String tableName, @@ -117,6 +120,7 @@ public void registerTable( { try (ThreadContextClassLoader _ = new ThreadContextClassLoader(getClass().getClassLoader())) { doRegisterTable( + accessControl, clientSession, schemaName, tableName, @@ -125,6 +129,7 @@ public void registerTable( } private void doRegisterTable( + ConnectorAccessControl accessControl, ConnectorSession session, String schemaName, String tableName, @@ -138,6 +143,7 @@ private void doRegisterTable( checkProcedureArgument(!isNullOrEmpty(tableLocation), "table_location cannot be null or empty"); SchemaTableName schemaTableName = new SchemaTableName(schemaName, tableName); + accessControl.checkCanCreateTable(null, schemaTableName, ImmutableMap.of()); DeltaLakeMetadata metadata = metadataFactory.create(session.getIdentity()); metadata.beginQuery(session); try (UncheckedCloseable ignore = () -> metadata.cleanupQuery(session)) { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java index 5ab2b44b5c358..5be05e21c041a 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java @@ -81,6 +81,7 @@ import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static io.trino.testing.MaterializedResult.resultBuilder; import static io.trino.testing.QueryAssertions.copyTpchTables; +import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_TABLE; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.EXECUTE_FUNCTION; import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN; import static io.trino.testing.TestingAccessControlManager.privilege; @@ -4820,6 +4821,20 @@ public void testDuplicatedFieldNames() } } + @Test + void testRegisterTableAccessControl() + { + String tableName = "test_register_table_access_control_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " AS SELECT 1 a", 1); + String tableLocation = metastore.getTable(SCHEMA, tableName).orElseThrow().getStorage().getLocation(); + metastore.dropTable(SCHEMA, tableName, false); + + assertAccessDenied( + "CALL system.register_table(CURRENT_SCHEMA, '" + tableName + "', '" + tableLocation + "')", + "Cannot create table .*", + privilege(tableName, CREATE_TABLE)); + } + @Test public void testMetastoreAfterCreateTable() { diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/RegisterTableProcedure.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/RegisterTableProcedure.java index d404f6a3b3b8e..62918206085ca 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/RegisterTableProcedure.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/procedure/RegisterTableProcedure.java @@ -14,6 +14,7 @@ package io.trino.plugin.iceberg.procedure; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; import com.google.inject.Provider; import io.trino.filesystem.Location; @@ -25,6 +26,7 @@ import io.trino.plugin.iceberg.fileio.ForwardingFileIo; import io.trino.spi.TrinoException; import io.trino.spi.classloader.ThreadContextClassLoader; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.procedure.Procedure; @@ -64,7 +66,7 @@ public class RegisterTableProcedure static { try { - REGISTER_TABLE = lookup().unreflect(RegisterTableProcedure.class.getMethod("registerTable", ConnectorSession.class, String.class, String.class, String.class, String.class)); + REGISTER_TABLE = lookup().unreflect(RegisterTableProcedure.class.getMethod("registerTable", ConnectorAccessControl.class, ConnectorSession.class, String.class, String.class, String.class, String.class)); } catch (ReflectiveOperationException e) { throw new AssertionError(e); @@ -98,6 +100,7 @@ public Procedure get() } public void registerTable( + ConnectorAccessControl accessControl, ConnectorSession clientSession, String schemaName, String tableName, @@ -106,6 +109,7 @@ public void registerTable( { try (ThreadContextClassLoader _ = new ThreadContextClassLoader(getClass().getClassLoader())) { doRegisterTable( + accessControl, clientSession, schemaName, tableName, @@ -115,6 +119,7 @@ public void registerTable( } private void doRegisterTable( + ConnectorAccessControl accessControl, ConnectorSession clientSession, String schemaName, String tableName, @@ -130,6 +135,7 @@ private void doRegisterTable( metadataFileName.ifPresent(RegisterTableProcedure::validateMetadataFileName); SchemaTableName schemaTableName = new SchemaTableName(schemaName, tableName); + accessControl.checkCanCreateTable(null, schemaTableName, ImmutableMap.of()); TrinoCatalog catalog = catalogFactory.create(clientSession.getIdentity()); if (!catalog.namespaceExists(clientSession, schemaTableName.getSchemaName())) { throw new TrinoException(SCHEMA_NOT_FOUND, format("Schema '%s' does not exist", schemaTableName.getSchemaName())); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergRegisterTableProcedure.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergRegisterTableProcedure.java index c860e33a49b8f..c95aed097d367 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergRegisterTableProcedure.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergRegisterTableProcedure.java @@ -59,6 +59,8 @@ import static io.trino.plugin.iceberg.IcebergTestUtils.getFileSystemFactory; import static io.trino.plugin.iceberg.IcebergUtil.METADATA_FOLDER_NAME; import static io.trino.plugin.iceberg.IcebergUtil.getLatestMetadataLocation; +import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_TABLE; +import static io.trino.testing.TestingAccessControlManager.privilege; import static io.trino.testing.TestingConnectorSession.SESSION; import static io.trino.testing.TestingNames.randomNameSuffix; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -541,6 +543,20 @@ public void testRegisterHadoopTableAndRead() assertUpdate("DROP TABLE " + tempTableName); } + @Test + void testRegisterTableAccessControl() + { + String tableName = "test_register_table_access_control_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " AS SELECT 1 a", 1); + String tableLocation = getTableLocation(tableName); + assertUpdate("CALL system.unregister_table(CURRENT_SCHEMA, '" + tableName + "')"); + + assertAccessDenied( + "CALL system.register_table(CURRENT_SCHEMA, '" + tableName + "', '" + tableLocation + "')", + "Cannot create table .*", + privilege(tableName, CREATE_TABLE)); + } + private String getTableLocation(String tableName) { Pattern locationPattern = Pattern.compile(".*location = '(.*?)'.*", Pattern.DOTALL);