From ed487d8a92817c16af2af5165dc7aece53fe5175 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Fri, 9 Jun 2023 17:11:17 -0700 Subject: [PATCH 1/5] Add ANTLR grammar file Signed-off-by: Chen Dai --- flint/build.sbt | 13 ++- .../main/antlr4/FlintSparkSqlExtensions.g4 | 103 ++++++++++++++++++ flint/project/plugins.sbt | 1 + 3 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 flint/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 diff --git a/flint/build.sbt b/flint/build.sbt index 02a05dcfad..51e308bf00 100644 --- a/flint/build.sbt +++ b/flint/build.sbt @@ -24,6 +24,8 @@ ThisBuild / scalastyleConfig := baseDirectory.value / "scalastyle-config.xml" */ ThisBuild / Test / parallelExecution := false +// enablePlugins(Antlr4Plugin) + // Run as part of compile task. lazy val compileScalastyle = taskKey[Unit]("compileScalastyle") @@ -56,7 +58,7 @@ lazy val flintCore = (project in file("flint-core")) lazy val flintSparkIntegration = (project in file("flint-spark-integration")) .dependsOn(flintCore) - .enablePlugins(AssemblyPlugin) + .enablePlugins(AssemblyPlugin, Antlr4Plugin) .settings( commonSettings, name := "flint-spark-integration", @@ -70,6 +72,15 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) "org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test", "com.github.sbt" % "junit-interface" % "0.13.3" % "test"), libraryDependencies ++= deps(sparkVersion), + // ANTLR settings + antlr4Version in Antlr4 := "4.7", + // antlr4PackageName in Antlr4 := Some("org.opensearch.flint.spark.sql"), + antlr4GenListener in Antlr4 := true, + antlr4GenVisitor in Antlr4 := true, + antlr4TreatWarningsAsErrors in Antlr4 := true, + // antlr4Generate in Antlr4 := + // Seq(file("flint-spark-integration/src/main/antlr/FlintSparkSqlExtensions.g4")), + // Assembly settings assemblyPackageScala / assembleArtifact := false, assembly / assemblyOption ~= { _.withIncludeScala(false) diff --git a/flint/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 b/flint/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 new file mode 100644 index 0000000000..61bd1a0927 --- /dev/null +++ b/flint/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This file contains code from the Apache Spark project (original license above). 
+ * It contains modifications, which are licensed as follows: + */ + +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +grammar FlintSparkSqlExtensions; + + +// Flint SQL Syntax Extension + +singleStatement + : statement EOF + ; + +statement + : skippingIndexStatement + ; + +skippingIndexStatement + : dropSkippingIndexStatement + ; + +dropSkippingIndexStatement + : DROP SKIPPING INDEX ON multipartIdentifier + ; + + +// Flint Lexer Extension + +DROP: 'DROP'; +INDEX: 'INDEX'; +ON: 'ON'; +SKIPPING : 'SKIPPING'; + + +// Copy from Spark 3.3.1 SqlBaseParser.g4 and SqlBaseLexer.g4 + +multipartIdentifier + : parts+=errorCapturingIdentifier (DOT parts+=errorCapturingIdentifier)* + ; + +// this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table` +// replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise +// valid expressions such as "a-b" can be recognized as an identifier +errorCapturingIdentifier + : identifier errorCapturingIdentifierExtra + ; + +// extra left-factoring grammar +errorCapturingIdentifierExtra + : (MINUS identifier)+ #errorIdent + | #realIdent + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +DOT: '.'; +MINUS: '-'; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; diff --git a/flint/project/plugins.sbt b/flint/project/plugins.sbt index 93655827af..0fe5dd1ab8 100644 --- a/flint/project/plugins.sbt +++ b/flint/project/plugins.sbt @@ -7,3 +7,4 @@ addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") addSbtPlugin("com.lightbend.sbt" % "sbt-java-formatter" % "0.8.0") addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.1.0") +addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.8.3") From 3ef5953ad43fb56dd9d99312cc475ca1d3611b12 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Mon, 19 Jun 2023 17:27:23 -0700 Subject: [PATCH 2/5] Add Flint command builder and IT Signed-off-by: Chen Dai --- flint/build.sbt | 13 +- .../main/antlr4/FlintSparkSqlExtensions.g4 | 84 +------- .../src/main/antlr4/SparkSqlBase.g4 | 153 ++++++++++++++ .../flint/spark/FlintSparkExtensions.scala | 5 +- .../spark/sql/FlintSparkSqlCommand.scala | 31 +++ .../sql/FlintSparkSqlCommandBuilder.scala | 30 +++ .../flint/spark/sql/FlintSparkSqlParser.scala | 189 ++++++++++++++++++ .../spark/sql/FlintSparkSqlParserSuite.scala | 14 ++ .../flint/spark/FlintSparkSqlSuite.scala | 62 ++++++ 9 files changed, 493 insertions(+), 88 deletions(-) create mode 100644 flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 create mode 100644 flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommand.scala create mode 100644 flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommandBuilder.scala create mode 100644 flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala create mode 100644 flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala create mode 100644 flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala diff --git a/flint/build.sbt b/flint/build.sbt index 51e308bf00..3b497d2d44 
100644 --- a/flint/build.sbt +++ b/flint/build.sbt @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ import Dependencies._ +import com.simplytyped.Antlr4Plugin.autoImport.antlr4Version lazy val scala212 = "2.12.14" lazy val sparkVersion = "3.3.1" @@ -73,13 +74,11 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) "com.github.sbt" % "junit-interface" % "0.13.3" % "test"), libraryDependencies ++= deps(sparkVersion), // ANTLR settings - antlr4Version in Antlr4 := "4.7", - // antlr4PackageName in Antlr4 := Some("org.opensearch.flint.spark.sql"), - antlr4GenListener in Antlr4 := true, - antlr4GenVisitor in Antlr4 := true, - antlr4TreatWarningsAsErrors in Antlr4 := true, - // antlr4Generate in Antlr4 := - // Seq(file("flint-spark-integration/src/main/antlr/FlintSparkSqlExtensions.g4")), + Antlr4 / antlr4Version := "4.7", + Antlr4 / antlr4PackageName := Some("org.opensearch.flint.spark.sql"), + Antlr4 / antlr4GenListener := true, + Antlr4 / antlr4GenVisitor := true, + // antlr4TreatWarningsAsErrors in Antlr4 := true, // Assembly settings assemblyPackageScala / assembleArtifact := false, assembly / assemblyOption ~= { diff --git a/flint/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 b/flint/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 index 61bd1a0927..383b9d63ba 100644 --- a/flint/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 +++ b/flint/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 @@ -1,25 +1,3 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * This file contains code from the Apache Spark project (original license above). 
- * It contains modifications, which are licensed as follows: - */ - /* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 @@ -27,11 +5,13 @@ grammar FlintSparkSqlExtensions; +import SparkSqlBase; + // Flint SQL Syntax Extension singleStatement - : statement EOF + : statement SEMICOLON* EOF ; statement @@ -43,61 +23,5 @@ skippingIndexStatement ; dropSkippingIndexStatement - : DROP SKIPPING INDEX ON multipartIdentifier - ; - - -// Flint Lexer Extension - -DROP: 'DROP'; -INDEX: 'INDEX'; -ON: 'ON'; -SKIPPING : 'SKIPPING'; - - -// Copy from Spark 3.3.1 SqlBaseParser.g4 and SqlBaseLexer.g4 - -multipartIdentifier - : parts+=errorCapturingIdentifier (DOT parts+=errorCapturingIdentifier)* - ; - -// this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table` -// replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise -// valid expressions such as "a-b" can be recognized as an identifier -errorCapturingIdentifier - : identifier errorCapturingIdentifierExtra - ; - -// extra left-factoring grammar -errorCapturingIdentifierExtra - : (MINUS identifier)+ #errorIdent - | #realIdent - ; - -identifier - : IDENTIFIER #unquotedIdentifier - | quotedIdentifier #quotedIdentifierAlternative - ; - -quotedIdentifier - : BACKQUOTED_IDENTIFIER - ; - -DOT: '.'; -MINUS: '-'; - -IDENTIFIER - : (LETTER | DIGIT | '_')+ - ; - -BACKQUOTED_IDENTIFIER - : '`' ( ~'`' | '``' )* '`' - ; - -fragment DIGIT - : [0-9] - ; - -fragment LETTER - : [A-Z] + : DROP SKIPPING INDEX ON tableName=multipartIdentifier ; diff --git a/flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 b/flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 new file mode 100644 index 0000000000..adcf33225f --- /dev/null +++ b/flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 @@ -0,0 +1,153 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * This file contains code from the Apache Spark project (original license below). + * It contains modifications, which are licensed as above: + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar SparkSqlBase; + +// Copy from Spark 3.3.1 SqlBaseParser.g4 and SqlBaseLexer.g4 + +@members { + /** + * When true, parser should throw ParseExcetion for unclosed bracketed comment. + */ + public boolean has_unclosed_bracketed_comment = false; + + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. 
+ * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } + + /** + * This method will be called when the character stream ends and try to find out the + * unclosed bracketed comment. + * If the method be called, it means the end of the entire character stream match, + * and we set the flag and fail later. + */ + public void markUnclosedComment() { + has_unclosed_bracketed_comment = true; + } +} + + +multipartIdentifier + : parts+=identifier (DOT parts+=identifier)* + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +nonReserved + : DROP | SKIPPING | INDEX + ; + + +// Flint Lexer Extension + +SKIPPING : 'SKIPPING'; + + +SEMICOLON: ';'; + +DOT: '.'; +DROP: 'DROP'; +INDEX: 'INDEX'; +MINUS: '-'; +ON: 'ON'; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? ( BRACKETED_COMMENT | . )*? ('*/' | {markUnclosedComment();} EOF) -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . 
+ ; \ No newline at end of file diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala index 430f933416..55fda3c10e 100644 --- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark -import org.opensearch.flint.spark.skipping.ApplyFlintSparkSkippingIndex +import org.opensearch.flint.spark.sql.FlintSparkSqlParser import org.apache.spark.sql.SparkSessionExtensions @@ -15,6 +15,9 @@ import org.apache.spark.sql.SparkSessionExtensions class FlintSparkExtensions extends (SparkSessionExtensions => Unit) { override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectParser { (spark, parser) => + new FlintSparkSqlParser(parser) + } extensions.injectOptimizerRule { spark => new FlintSparkOptimizer(spark) } diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommand.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommand.scala new file mode 100644 index 0000000000..fafe681172 --- /dev/null +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommand.scala @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.sql + +import org.opensearch.flint.spark.FlintSpark + +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.types.StringType + +/** + * Flint Spark SQL command. + * + * Note that currently Flint SQL layer is thin with all core logic in FlintSpark. May create + * separate command for each Flint SQL statement in future as needed. + * + * @param block + * code block that triggers Flint core API + */ +case class FlintSparkSqlCommand(block: FlintSpark => Seq[Row]) extends LeafRunnableCommand { + + override def run(sparkSession: SparkSession): Seq[Row] = block(new FlintSpark(sparkSession)) +} + +object FlintSparkSqlCommand { + val DEFAULT_OUTPUT = Seq(AttributeReference("Result", StringType, nullable = true)()) +} diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommandBuilder.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommandBuilder.scala new file mode 100644 index 0000000000..1e5aeaf79e --- /dev/null +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommandBuilder.scala @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.sql + +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex +import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.DropSkippingIndexStatementContext + +import org.apache.spark.sql.catalyst.plans.logical.Command + +/** + * Flint Spark AST builder that builds Spark command for Flint index statement. 
+ */ +class FlintSparkSqlCommandBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command] { + + override def visitDropSkippingIndexStatement( + ctx: DropSkippingIndexStatementContext): Command = { + FlintSparkSqlCommand { flint => + val tableName = ctx.tableName.getText + val indexName = FlintSparkSkippingIndex.getSkippingIndexName(tableName) + flint.deleteIndex(indexName) + Seq.empty + } + } + + override def aggregateResult(aggregate: Command, nextResult: Command): Command = + if (nextResult != null) nextResult else aggregate; +} diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala new file mode 100644 index 0000000000..5390ee229a --- /dev/null +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala @@ -0,0 +1,189 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.sql + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} +import org.antlr.v4.runtime.tree.TerminalNodeImpl +import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Flint SQL parser that extends Spark SQL parser to parse Flint command first and fall back to + * Spark parser for unrecognized statement. + * + * @param sparkParser + * Spark SQL parser + */ +class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface { + + /** + * Flint command builder. 
This has to be lazy because Spark.conf in FlintSpark will create + * Parser and thus cause stack overflow + */ + private val flintCmdBuilder = new FlintSparkSqlCommandBuilder() + + override def parsePlan(sqlText: String): LogicalPlan = { + val flintLexer = new FlintSparkSqlExtensionsLexer( + new UpperCaseCharStream(CharStreams.fromString(sqlText))) + flintLexer.removeErrorListeners() + flintLexer.addErrorListener(ParseErrorListener) + + val tokenStream = new CommonTokenStream(flintLexer) + val flintParser = new FlintSparkSqlExtensionsParser(tokenStream) + flintParser.addParseListener(PostProcessor) + // parser.addParseListener(UnclosedCommentProcessor(command, tokenStream)) + flintParser.removeErrorListeners() + flintParser.addErrorListener(ParseErrorListener) + + try { + val ctx = flintParser.singleStatement() + flintCmdBuilder.visit(ctx) match { + case plan: LogicalPlan => plan + case _ => sparkParser.parsePlan(sqlText) + } + } catch { + case e: ParseException => sparkParser.parsePlan(sqlText) + } + + } + + /* + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { flintParser => + flintCmdBuilder.visit(flintParser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => sparkParser.parsePlan(sqlText) + } + } + */ + + override def parseExpression(sqlText: String): Expression = sparkParser.parseExpression(sqlText) + + override def parseTableIdentifier(sqlText: String): TableIdentifier = + sparkParser.parseTableIdentifier(sqlText) + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + sparkParser.parseFunctionIdentifier(sqlText) + + override def parseMultipartIdentifier(sqlText: String): Seq[String] = + sparkParser.parseMultipartIdentifier(sqlText) + + override def parseTableSchema(sqlText: String): StructType = + sparkParser.parseTableSchema(sqlText) + + override def parseDataType(sqlText: String): DataType = sparkParser.parseDataType(sqlText) + + override def parseQuery(sqlText: String): LogicalPlan = sparkParser.parseQuery(sqlText) + + protected def parse[T](sqlText: String)(toResult: FlintSparkSqlExtensionsParser => T): T = { + val lexer = new FlintSparkSqlExtensionsLexer( + new UpperCaseCharStream(CharStreams.fromString(sqlText))) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new FlintSparkSqlExtensionsParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.seek(0) // rewind input stream + parser.reset() + + // Try Again. 
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } catch { + case e: ParseException => throw e + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException( + Option(sqlText), + e.message, + position, + position, + e.errorClass, + e.messageParameters) + } + } +} + +class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume() + override def getSourceName: String = wrapped.getSourceName + override def index(): Int = wrapped.index + override def mark(): Int = wrapped.mark + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size + + override def getText(interval: Interval): String = { + // ANTLR 4.7's CodePointCharStream implementations have bugs when + // getText() is called with an empty stream, or intervals where + // the start > end. See + // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix + // that is not yet in a released ANTLR artifact. + if (size() > 0 && (interval.b - interval.a >= 0)) { + wrapped.getText(interval) + } else { + "" + } + } + + override def LA(i: Int): Int = { + val la = wrapped.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } +} + +case object PostProcessor extends FlintSparkSqlExtensionsBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. 
*/ + override def exitNonReserved(ctx: NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier(ctx: ParserRuleContext, stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + val newToken = new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins) + parent.addChild(new TerminalNodeImpl(f(newToken))) + } +} diff --git a/flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala b/flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala new file mode 100644 index 0000000000..917c820581 --- /dev/null +++ b/flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.sql + +import org.apache.spark.SparkFunSuite + +class FlintSparkSqlParserSuite extends SparkFunSuite { + + test("Skipping index statement should pass") { + } +} diff --git a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala new file mode 100644 index 0000000000..aeb51ecad0 --- /dev/null +++ b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import scala.Option.empty + +import org.opensearch.flint.OpenSearchSuite +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper + +import org.apache.spark.FlintSuite +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.flint.config.FlintSparkConf.{HOST_ENDPOINT, HOST_PORT} + +class FlintSparkSqlSuite extends QueryTest with FlintSuite with OpenSearchSuite { + + /** Flint Spark high level API for assertion */ + private lazy val flint: FlintSpark = { + setFlintSparkConf(HOST_ENDPOINT, openSearchHost) + setFlintSparkConf(HOST_PORT, openSearchPort) + new FlintSpark(spark) + } + + /** Test table and index name */ + private val testTable = "test" + private val testIndex = getSkippingIndexName(testTable) + + override def beforeAll(): Unit = { + super.beforeAll() + + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + } + + test("drop index test") { + flint + .skippingIndex() + .onTable(testTable) + .addPartitionIndex("year") + .create() + + sql(s"DROP SKIPPING INDEX ON $testTable").show + + flint.describeIndex(testIndex) shouldBe empty + } +} From 04d0faf5f42c6d2366746aa0fae958b648c7c5e6 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Tue, 20 Jun 2023 10:19:31 -0700 Subject: [PATCH 3/5] Refactor flint parser logic Signed-off-by: Chen Dai --- flint/build.sbt | 4 - .../src/main/antlr4/SparkSqlBase.g4 | 4 +- ...er.scala => FlintSparkSqlAstBuilder.scala} | 6 +- .../flint/spark/sql/FlintSparkSqlParser.scala | 
125 +++++++++++------- .../spark/sql/FlintSparkSqlParserSuite.scala | 14 -- 5 files changed, 86 insertions(+), 67 deletions(-) rename flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/{FlintSparkSqlCommandBuilder.scala => FlintSparkSqlAstBuilder.scala} (83%) delete mode 100644 flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala diff --git a/flint/build.sbt b/flint/build.sbt index 3b497d2d44..de198dfbab 100644 --- a/flint/build.sbt +++ b/flint/build.sbt @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ import Dependencies._ -import com.simplytyped.Antlr4Plugin.autoImport.antlr4Version lazy val scala212 = "2.12.14" lazy val sparkVersion = "3.3.1" @@ -25,8 +24,6 @@ ThisBuild / scalastyleConfig := baseDirectory.value / "scalastyle-config.xml" */ ThisBuild / Test / parallelExecution := false -// enablePlugins(Antlr4Plugin) - // Run as part of compile task. lazy val compileScalastyle = taskKey[Unit]("compileScalastyle") @@ -78,7 +75,6 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) Antlr4 / antlr4PackageName := Some("org.opensearch.flint.spark.sql"), Antlr4 / antlr4GenListener := true, Antlr4 / antlr4GenVisitor := true, - // antlr4TreatWarningsAsErrors in Antlr4 := true, // Assembly settings assemblyPackageScala / assembleArtifact := false, assembly / assemblyOption ~= { diff --git a/flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 b/flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 index adcf33225f..e0c579e6f6 100644 --- a/flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 +++ b/flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 @@ -104,11 +104,13 @@ nonReserved ; -// Flint Lexer Extension +// Flint lexical tokens SKIPPING : 'SKIPPING'; +// Spark lexical tokens + SEMICOLON: ';'; DOT: '.'; diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommandBuilder.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala similarity index 83% rename from flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommandBuilder.scala rename to flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala index 1e5aeaf79e..a74585bab1 100644 --- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommandBuilder.scala +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark.sql -import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.DropSkippingIndexStatementContext import org.apache.spark.sql.catalyst.plans.logical.Command @@ -13,13 +13,13 @@ import org.apache.spark.sql.catalyst.plans.logical.Command /** * Flint Spark AST builder that builds Spark command for Flint index statement. 
*/ -class FlintSparkSqlCommandBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command] { +class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command] { override def visitDropSkippingIndexStatement( ctx: DropSkippingIndexStatementContext): Command = { FlintSparkSqlCommand { flint => val tableName = ctx.tableName.getText - val indexName = FlintSparkSkippingIndex.getSkippingIndexName(tableName) + val indexName = getSkippingIndexName(tableName) flint.deleteIndex(indexName) Seq.empty } diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala index 5390ee229a..42479e6e6b 100644 --- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala @@ -3,6 +3,28 @@ * SPDX-License-Identifier: Apache-2.0 */ +/* + * This file contains code from the Apache Spark project (original license below). + * It contains modifications, which are licensed as above: + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.opensearch.flint.spark.sql import org.antlr.v4.runtime._ @@ -32,41 +54,15 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface * Flint command builder. 
This has to be lazy because Spark.conf in FlintSpark will create * Parser and thus cause stack overflow */ - private val flintCmdBuilder = new FlintSparkSqlCommandBuilder() - - override def parsePlan(sqlText: String): LogicalPlan = { - val flintLexer = new FlintSparkSqlExtensionsLexer( - new UpperCaseCharStream(CharStreams.fromString(sqlText))) - flintLexer.removeErrorListeners() - flintLexer.addErrorListener(ParseErrorListener) - - val tokenStream = new CommonTokenStream(flintLexer) - val flintParser = new FlintSparkSqlExtensionsParser(tokenStream) - flintParser.addParseListener(PostProcessor) - // parser.addParseListener(UnclosedCommentProcessor(command, tokenStream)) - flintParser.removeErrorListeners() - flintParser.addErrorListener(ParseErrorListener) + private val flintAstBuilder = new FlintSparkSqlAstBuilder() + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { flintParser => try { - val ctx = flintParser.singleStatement() - flintCmdBuilder.visit(ctx) match { - case plan: LogicalPlan => plan - case _ => sparkParser.parsePlan(sqlText) - } + flintAstBuilder.visit(flintParser.singleStatement()) } catch { - case e: ParseException => sparkParser.parsePlan(sqlText) - } - - } - - /* - override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { flintParser => - flintCmdBuilder.visit(flintParser.singleStatement()) match { - case plan: LogicalPlan => plan - case _ => sparkParser.parsePlan(sqlText) + case _: ParseException => sparkParser.parsePlan(sqlText) } } - */ override def parseExpression(sqlText: String): Expression = sparkParser.parseExpression(sqlText) @@ -94,7 +90,7 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface val tokenStream = new CommonTokenStream(lexer) val parser = new FlintSparkSqlExtensionsParser(tokenStream) - parser.addParseListener(PostProcessor) + parser.addParseListener(FlintPostProcessor) parser.removeErrorListeners() parser.addErrorListener(ParseErrorListener) @@ -114,7 +110,10 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface toResult(parser) } } catch { - case e: ParseException => throw e + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(sqlText) case e: AnalysisException => val position = Origin(e.line, e.startPosition) throw new ParseException( @@ -137,18 +136,7 @@ class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { override def seek(where: Int): Unit = wrapped.seek(where) override def size(): Int = wrapped.size - override def getText(interval: Interval): String = { - // ANTLR 4.7's CodePointCharStream implementations have bugs when - // getText() is called with an empty stream, or intervals where - // the start > end. See - // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix - // that is not yet in a released ANTLR artifact. - if (size() > 0 && (interval.b - interval.a >= 0)) { - wrapped.getText(interval) - } else { - "" - } - } + override def getText(interval: Interval): String = wrapped.getText(interval) override def LA(i: Int): Int = { val la = wrapped.LA(i) @@ -157,7 +145,7 @@ class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { } } -case object PostProcessor extends FlintSparkSqlExtensionsBaseListener { +case object FlintPostProcessor extends FlintSparkSqlExtensionsBaseListener { /** Remove the back ticks from an Identifier. 
*/ override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { @@ -187,3 +175,50 @@ case object PostProcessor extends FlintSparkSqlExtensionsBaseListener { parent.addChild(new TerminalNodeImpl(f(newToken))) } } + +/* +/** + * The ParseErrorListener converts parse errors into AnalysisExceptions. + */ +case object FlintParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val (start, stop) = offendingSymbol match { + case token: CommonToken => + val start = Origin(Some(line), Some(token.getCharPositionInLine)) + val length = token.getStopIndex - token.getStartIndex + 1 + val stop = Origin(Some(line), Some(token.getCharPositionInLine + length)) + (start, stop) + case _ => + val start = Origin(Some(line), Some(charPositionInLine)) + (start, start) + } + throw new FlintParseException(None, msg, start, stop) + } +} + +/** + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. + */ +class FlintParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin, + errorClass: Option[String] = None, + messageParameters: Array[String] = Array.empty) + extends AnalysisException( + message, + start.line, + start.startPosition, + None, + None, + errorClass, + messageParameters) +*/ diff --git a/flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala b/flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala deleted file mode 100644 index 917c820581..0000000000 --- a/flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala +++ /dev/null @@ -1,14 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.flint.spark.sql - -import org.apache.spark.SparkFunSuite - -class FlintSparkSqlParserSuite extends SparkFunSuite { - - test("Skipping index statement should pass") { - } -} From 1a67cea9a94b5befb6bf07675bc5cab6a1292014 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Tue, 20 Jun 2023 10:52:51 -0700 Subject: [PATCH 4/5] Polish comments for PR review Signed-off-by: Chen Dai --- flint/build.sbt | 2 +- .../spark/sql/FlintSparkSqlAstBuilder.scala | 2 +- .../spark/sql/FlintSparkSqlCommand.scala | 8 +-- .../flint/spark/sql/FlintSparkSqlParser.scala | 59 ++----------------- .../flint/spark/FlintSparkSqlSuite.scala | 4 +- 5 files changed, 11 insertions(+), 64 deletions(-) diff --git a/flint/build.sbt b/flint/build.sbt index de198dfbab..e0f5ffa185 100644 --- a/flint/build.sbt +++ b/flint/build.sbt @@ -71,7 +71,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) "com.github.sbt" % "junit-interface" % "0.13.3" % "test"), libraryDependencies ++= deps(sparkVersion), // ANTLR settings - Antlr4 / antlr4Version := "4.7", + Antlr4 / antlr4Version := "4.8", Antlr4 / antlr4PackageName := Some("org.opensearch.flint.spark.sql"), Antlr4 / antlr4GenListener := true, Antlr4 / antlr4GenVisitor := true, diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala 
index a74585bab1..aef18051ba 100644 --- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala @@ -18,7 +18,7 @@ class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command override def visitDropSkippingIndexStatement( ctx: DropSkippingIndexStatementContext): Command = { FlintSparkSqlCommand { flint => - val tableName = ctx.tableName.getText + val tableName = ctx.tableName.getText // TODO: handle schema name val indexName = getSkippingIndexName(tableName) flint.deleteIndex(indexName) Seq.empty diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommand.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommand.scala index fafe681172..ca39a293c0 100644 --- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommand.scala +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommand.scala @@ -8,12 +8,10 @@ package org.opensearch.flint.spark.sql import org.opensearch.flint.spark.FlintSpark import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.command.LeafRunnableCommand -import org.apache.spark.sql.types.StringType /** - * Flint Spark SQL command. + * Flint Spark SQL DDL command. * * Note that currently Flint SQL layer is thin with all core logic in FlintSpark. May create * separate command for each Flint SQL statement in future as needed. @@ -25,7 +23,3 @@ case class FlintSparkSqlCommand(block: FlintSpark => Seq[Row]) extends LeafRunna override def run(sparkSession: SparkSession): Seq[Row] = block(new FlintSpark(sparkSession)) } - -object FlintSparkSqlCommand { - val DEFAULT_OUTPUT = Seq(AttributeReference("Result", StringType, nullable = true)()) -} diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala index 42479e6e6b..0fa146b9de 100644 --- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala @@ -42,24 +42,21 @@ import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.types.{DataType, StructType} /** - * Flint SQL parser that extends Spark SQL parser to parse Flint command first and fall back to - * Spark parser for unrecognized statement. + * Flint SQL parser that extends Spark SQL parser with Flint SQL statements. * * @param sparkParser * Spark SQL parser */ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface { - /** - * Flint command builder. This has to be lazy because Spark.conf in FlintSpark will create - * Parser and thus cause stack overflow - */ + /** Flint AST builder. 
*/ private val flintAstBuilder = new FlintSparkSqlAstBuilder() override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { flintParser => try { flintAstBuilder.visit(flintParser.singleStatement()) } catch { + // Fall back to Spark parse plan logic if flint cannot parse case _: ParseException => sparkParser.parsePlan(sqlText) } } @@ -82,6 +79,9 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface override def parseQuery(sqlText: String): LogicalPlan = sparkParser.parseQuery(sqlText) + + // Starting from here is copied and modified from Spark 3.3.1 + protected def parse[T](sqlText: String)(toResult: FlintSparkSqlExtensionsParser => T): T = { val lexer = new FlintSparkSqlExtensionsLexer( new UpperCaseCharStream(CharStreams.fromString(sqlText))) @@ -175,50 +175,3 @@ case object FlintPostProcessor extends FlintSparkSqlExtensionsBaseListener { parent.addChild(new TerminalNodeImpl(f(newToken))) } } - -/* -/** - * The ParseErrorListener converts parse errors into AnalysisExceptions. - */ -case object FlintParseErrorListener extends BaseErrorListener { - override def syntaxError( - recognizer: Recognizer[_, _], - offendingSymbol: scala.Any, - line: Int, - charPositionInLine: Int, - msg: String, - e: RecognitionException): Unit = { - val (start, stop) = offendingSymbol match { - case token: CommonToken => - val start = Origin(Some(line), Some(token.getCharPositionInLine)) - val length = token.getStopIndex - token.getStartIndex + 1 - val stop = Origin(Some(line), Some(token.getCharPositionInLine + length)) - (start, stop) - case _ => - val start = Origin(Some(line), Some(charPositionInLine)) - (start, start) - } - throw new FlintParseException(None, msg, start, stop) - } -} - -/** - * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It - * contains fields and an extended error message that make reporting and diagnosing errors easier. 
- */ -class FlintParseException( - val command: Option[String], - message: String, - val start: Origin, - val stop: Origin, - errorClass: Option[String] = None, - messageParameters: Array[String] = Array.empty) - extends AnalysisException( - message, - start.line, - start.startPosition, - None, - None, - errorClass, - messageParameters) -*/ diff --git a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala index aeb51ecad0..69e140b6a8 100644 --- a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala +++ b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala @@ -48,14 +48,14 @@ class FlintSparkSqlSuite extends QueryTest with FlintSuite with OpenSearchSuite |""".stripMargin) } - test("drop index test") { + test("drop skipping index") { flint .skippingIndex() .onTable(testTable) .addPartitionIndex("year") .create() - sql(s"DROP SKIPPING INDEX ON $testTable").show + sql(s"DROP SKIPPING INDEX ON $testTable") flint.describeIndex(testIndex) shouldBe empty } From 32d73203a533fd35e5dca38fe46d390df337fccd Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Tue, 20 Jun 2023 14:18:41 -0700 Subject: [PATCH 5/5] Fix compile error in IT Signed-off-by: Chen Dai --- .../scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala index 69e140b6a8..fffe1fe295 100644 --- a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala +++ b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSqlSuite.scala @@ -52,7 +52,7 @@ class FlintSparkSqlSuite extends QueryTest with FlintSuite with OpenSearchSuite flint .skippingIndex() .onTable(testTable) - .addPartitionIndex("year") + .addPartitions("year") .create() sql(s"DROP SKIPPING INDEX ON $testTable")