Skip to content

Commit

Permalink
Fix parsing failure on WITH clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaho12 committed Oct 24, 2024
1 parent fd418e6 commit 32ddf59
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import io.trino.sql.tree.Statement;
import io.trino.sql.tree.Table;
import io.trino.sql.tree.TableFunctionInvocation;
import io.trino.sql.tree.WithQuery;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.HttpMethod;

Expand All @@ -68,12 +69,14 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.io.BaseEncoding.base64Url;
import static io.airlift.json.JsonCodec.jsonCodec;
import static java.lang.Math.toIntExact;
Expand All @@ -90,6 +93,7 @@ public class TrinoQueryProperties
private String queryType = "";
private String resourceGroupQueryType = "";
private Set<QualifiedName> tables = ImmutableSet.of();
private final Set<QualifiedName> temporaryTables = new HashSet<>();
private final Optional<String> defaultCatalog;
private final Optional<String> defaultSchema;
private Set<String> catalogs = ImmutableSet.of();
Expand Down Expand Up @@ -201,12 +205,17 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {

getNames(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
tables = tableBuilder.build();
catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator());

Set<QualifiedName> filteredTables = tables.stream()
.filter(table -> !temporaryTables.contains(table))
.collect(toImmutableSet());

catalogBuilder.addAll(filteredTables.stream().map(q -> q.getParts().getFirst()).iterator());
catalogs = catalogBuilder.build();
schemaBuilder.addAll(tables.stream().map(q -> q.getParts().get(1)).iterator());
schemaBuilder.addAll(filteredTables.stream().map(q -> q.getParts().get(1)).iterator());
schemas = schemaBuilder.build();
catalogSchemaBuilder.addAll(
tables.stream().map(qualifiedName -> format("%s.%s", qualifiedName.getParts().getFirst(), qualifiedName.getParts().get(1))).iterator());
filteredTables.stream().map(qualifiedName -> format("%s.%s", qualifiedName.getParts().getFirst(), qualifiedName.getParts().get(1))).iterator());
catalogSchemas = catalogSchemaBuilder.build();
}
catch (IOException e) {
Expand Down Expand Up @@ -336,8 +345,14 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
case SetSchemaAuthorization s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSource()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
case SetTableAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
case SetViewAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
case Table s -> tableBuilder.add(qualifyName(s.getName()));
case Table s -> {
// ignore temporary tables as they can have various table parts
if (!temporaryTables.contains(s.getName())) {
tableBuilder.add(qualifyName(s.getName()));
}
}
case TableFunctionInvocation s -> tableBuilder.add(qualifyName(s.getName()));
case WithQuery withQuery -> temporaryTables.add(QualifiedName.of(withQuery.getName().getValue()));
default -> {}
}

Expand Down Expand Up @@ -385,6 +400,7 @@ private QualifiedName qualifyName(QualifiedName table)
throws RequestParsingException
{
List<String> tableParts = table.getParts();

return switch (tableParts.size()) {
case 1 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), defaultSchema.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst());
case 2 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst(), tableParts.get(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,37 @@ void testTrinoQueryPropertiesTableExtraction(String query, Set<String> catalogs,
assertThat(trinoQueryProperties.getCatalogs()).isEqualTo(catalogs);
}

@Test
void testWithQueryNameExcluded()
throws IOException
{
String query = """
WITH dos AS (SELECT c1 from cat.schem.tbl1),
uno as (SELECT c1 FROM dos)
SELECT c1 FROM uno, dos
""";
HttpServletRequest mockRequestWithDefaults = prepareMockRequest();
when(mockRequestWithDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG);
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA);

TrinoQueryProperties trinoQueryPropertiesWithDefaults = new TrinoQueryProperties(mockRequestWithDefaults, requestAnalyzerConfig);

Set<QualifiedName> tablesWithDefaults = trinoQueryPropertiesWithDefaults.getTables();
assertThat(tablesWithDefaults).containsExactlyInAnyOrder(
QualifiedName.of("cat", "schem", "tbl1")
);

HttpServletRequest mockRequestNoDefaults = prepareMockRequest();
when(mockRequestNoDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));

TrinoQueryProperties trinoQueryPropertiesNoDefaults = new TrinoQueryProperties(mockRequestNoDefaults, requestAnalyzerConfig);
Set<QualifiedName> tablesNoDefaults = trinoQueryPropertiesNoDefaults.getTables();
assertThat(tablesNoDefaults).containsExactlyInAnyOrder(
QualifiedName.of("cat", "schem", "tbl1")
);
}

private HttpServletRequest prepareMockRequest()
{
HttpServletRequest mockRequest = mock(HttpServletRequest.class);
Expand Down

0 comments on commit 32ddf59

Please sign in to comment.