From a45da2af51716147a1eb15957a47cf0ca634e24a Mon Sep 17 00:00:00 2001 From: apstndb <803393+apstndb@users.noreply.github.com> Date: Mon, 27 Mar 2023 13:54:20 +0900 Subject: [PATCH] fix comment handling (#150) * fix comment handling * update whitespaces * use gsqlsep * update Go version to 1.18 * update to use InputStatement.StripComments * remove unused function * reflect review comments --- .github/workflows/run-tests.yaml | 2 +- cli.go | 4 +- cli_test.go | 34 ++- go.mod | 4 +- go.sum | 14 + separator.go | 286 ++------------------ separator_test.go | 436 ++++++++----------------------- statement.go | 92 +++---- statement_test.go | 5 + 9 files changed, 228 insertions(+), 649 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index ec703d9..a9d5399 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -33,7 +33,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-go@v2 with: - go-version: '1.17' + go-version: '1.18' - run: go version - run: make setup-emulator env: diff --git a/cli.go b/cli.go index b29f06f..b0c4e36 100644 --- a/cli.go +++ b/cli.go @@ -135,7 +135,7 @@ func (c *Cli) RunInteractive() int { continue } - stmt, err := BuildStatement(input.statement) + stmt, err := BuildStatementWithComments(input.statementWithoutComments, input.statement) if err != nil { c.PrintInteractiveError(err) continue @@ -467,7 +467,7 @@ func buildCommands(input string) ([]*command, error) { var cmds []*command var pendingDdls []string for _, separated := range separateInput(input) { - stmt, err := BuildStatement(separated.statement) + stmt, err := BuildStatementWithComments(separated.statementWithoutComments, separated.statement) if err != nil { return nil, err } diff --git a/cli_test.go b/cli_test.go index 99f5a61..b8b1697 100644 --- a/cli_test.go +++ b/cli_test.go @@ -65,6 +65,20 @@ func TestBuildCommands(t *testing.T) { {&BulkDdlStatement{[]string{"DROP TABLE t1", "DROP TABLE t2"}}, false}, {&SelectStatement{"SELECT 1"}, false}, }}, + { + ` + CREATE TABLE t1(pk INT64 /* NOT NULL*/, col INT64) PRIMARY KEY(pk); + INSERT t1(pk/*, col*/) VALUES(1/*, 2*/); + UPDATE t1 SET col = /* pk + */ col + 1 WHERE TRUE; + DELETE t1 WHERE TRUE /* AND pk = 1 */; + SELECT 0x1/**/A`, + []*command{ + {&BulkDdlStatement{[]string{"CREATE TABLE t1(pk INT64 , col INT64) PRIMARY KEY(pk)"}}, false}, + {&DmlStatement{"INSERT t1(pk/*, col*/) VALUES(1/*, 2*/)"}, false}, + {&DmlStatement{"UPDATE t1 SET col = /* pk + */ col + 1 WHERE TRUE"}, false}, + {&DmlStatement{"DELETE t1 WHERE TRUE /* AND pk = 1 */"}, false}, + {&SelectStatement{"SELECT 0x1/**/A"}, false}, + }}, } for _, test := range tests { @@ -90,32 +104,36 @@ func TestReadInteractiveInput(t *testing.T) { desc: "single line", input: "SELECT 1;\n", want: &inputStatement{ - statement: "SELECT 1", - delim: delimiterHorizontal, + statement: "SELECT 1", + statementWithoutComments: "SELECT 1", + delim: delimiterHorizontal, }, }, { desc: "multi lines", input: "SELECT\n* FROM\n t1\n;\n", want: &inputStatement{ - statement: "SELECT\n* FROM\n t1", - delim: delimiterHorizontal, + statement: "SELECT\n* FROM\n t1", + statementWithoutComments: "SELECT\n* FROM\n t1", + delim: delimiterHorizontal, }, }, { desc: "multi lines with vertical delimiter", input: "SELECT\n* FROM\n t1\\G\n", want: &inputStatement{ - statement: "SELECT\n* FROM\n t1", - delim: delimiterVertical, + statement: "SELECT\n* FROM\n t1", + statementWithoutComments: "SELECT\n* FROM\n t1", + delim: delimiterVertical, }, }, { desc: "multi lines with multiple comments", input: "SELECT\n/* comment */1,\n# comment\n2;\n", want: &inputStatement{ - statement: "SELECT\n1,\n2", - delim: delimiterHorizontal, + statement: "SELECT\n/* comment */1,\n# comment\n2", + statementWithoutComments: "SELECT\n 1,\n 2", + delim: delimiterHorizontal, }, }, { diff --git a/go.mod b/go.mod index d03366f..db1d9d9 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,11 @@ module github.com/cloudspannerecosystem/spanner-cli -go 1.17 +go 1.18 require ( cloud.google.com/go v0.110.0 cloud.google.com/go/spanner v1.44.0 + github.com/apstndb/gsqlsep v0.0.0-20230324124551-0e8335710080 github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e github.com/google/go-cmp v0.5.9 github.com/jessevdk/go-flags v1.4.0 @@ -35,6 +36,7 @@ require ( github.com/googleapis/gax-go/v2 v2.7.0 // indirect github.com/mattn/go-runewidth v0.0.8 // indirect go.opencensus.io v0.24.0 // indirect + golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 // indirect golang.org/x/net v0.8.0 // indirect golang.org/x/oauth2 v0.6.0 // indirect golang.org/x/sys v0.6.0 // indirect diff --git a/go.sum b/go.sum index aab348b..cb596bd 100644 --- a/go.sum +++ b/go.sum @@ -529,6 +529,12 @@ github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHG github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/v10 v10.0.1/go.mod h1:YvhnlEePVnBS4+0z3fhPfUy7W1Ikj0Ih0vcRo/gZ1M0= github.com/apache/thrift v0.16.0/go.mod h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU= +github.com/apstndb/gsqlsep v0.0.0-20230324010854-4f3bbc6e73c7 h1:asAMA/EjcHpdwICAAdqeruTgY+snsS+VFNLbMusf6Pg= +github.com/apstndb/gsqlsep v0.0.0-20230324010854-4f3bbc6e73c7/go.mod h1:NQogaK8AOkyzXEXMqGvc6hV0SKO8cUkcqqXJimyvTlw= +github.com/apstndb/gsqlsep v0.0.0-20230324122652-aa2e9a53e0d0 h1:xwCIoA+db34vipWGQzYFUhUCMHRc/wvhsjfPRolasXQ= +github.com/apstndb/gsqlsep v0.0.0-20230324122652-aa2e9a53e0d0/go.mod h1:NQogaK8AOkyzXEXMqGvc6hV0SKO8cUkcqqXJimyvTlw= +github.com/apstndb/gsqlsep v0.0.0-20230324124551-0e8335710080 h1:L1KrVddvtrBT7T2/408TOPu9xNr2I/OCJUUOugMezNk= +github.com/apstndb/gsqlsep v0.0.0-20230324124551-0e8335710080/go.mod h1:NQogaK8AOkyzXEXMqGvc6hV0SKO8cUkcqqXJimyvTlw= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -806,6 +812,7 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -821,6 +828,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= +golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 h1:LGJsf5LRplCck6jUCH3dBL2dmycNruWNF5xugkSlfXw= +golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -860,6 +869,7 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -912,6 +922,7 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.0.0-20220909164309-bea034e7d591/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.0.0-20221012135044-0b7e1fb9d458/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= @@ -1035,6 +1046,7 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1043,6 +1055,7 @@ golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= @@ -1127,6 +1140,7 @@ golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/separator.go b/separator.go index a7ff484..6246ad2 100644 --- a/separator.go +++ b/separator.go @@ -16,284 +16,28 @@ package main -import ( - "strings" -) - -type delimiter int +import "github.com/apstndb/gsqlsep" const ( - delimiterUndefined delimiter = iota - delimiterHorizontal - delimiterVertical + delimiterUndefined = "" + delimiterHorizontal = ";" + delimiterVertical = `\G` ) type inputStatement struct { - statement string - delim delimiter -} - -func (d delimiter) String() string { - switch d { - case delimiterUndefined: - return "" - case delimiterHorizontal: - return ";" - case delimiterVertical: - return `\G` - } - return "" + statement string + statementWithoutComments string + delim string } func separateInput(input string) []inputStatement { - return newSeparator(input).separate() -} - -type separator struct { - str []rune // remaining input - sb *strings.Builder -} - -func newSeparator(s string) *separator { - return &separator{ - str: []rune(s), - sb: &strings.Builder{}, - } -} - -func (s *separator) consumeRawString() { - // consume 'r' or 'R' - s.sb.WriteRune(s.str[0]) - s.str = s.str[1:] - - delim := s.consumeStringDelimiter() - s.consumeStringContent(delim, true) -} - -func (s *separator) consumeBytesString() { - // consume 'b' or 'B' - s.sb.WriteRune(s.str[0]) - s.str = s.str[1:] - - delim := s.consumeStringDelimiter() - s.consumeStringContent(delim, false) -} - -func (s *separator) consumeRawBytesString() { - // consume 'rb', 'Rb', 'rB', or 'RB' - s.sb.WriteRune(s.str[0]) - s.sb.WriteRune(s.str[1]) - s.str = s.str[2:] - - delim := s.consumeStringDelimiter() - s.consumeStringContent(delim, true) -} - -func (s *separator) consumeString() { - delim := s.consumeStringDelimiter() - s.consumeStringContent(delim, false) -} - -func (s *separator) consumeStringContent(delim string, raw bool) { - var i int - for i < len(s.str) { - // check end of string - switch { - // check single-quoted delim - case len(delim) == 1 && string(s.str[i]) == delim: - s.str = s.str[i+1:] - s.sb.WriteString(delim) - return - // check triple-quoted delim - case len(delim) == 3 && len(s.str) >= i+3 && string(s.str[i:i+3]) == delim: - s.str = s.str[i+3:] - s.sb.WriteString(delim) - return - } - - // escape sequence - if s.str[i] == '\\' { - if raw { - // raw string treats escape character as backslash - s.sb.WriteRune('\\') - i++ - continue - } - - // invalid escape sequence - if i+1 >= len(s.str) { - s.sb.WriteRune('\\') - return - } - - s.sb.WriteRune('\\') - s.sb.WriteRune(s.str[i+1]) - i += 2 - continue - } - s.sb.WriteRune(s.str[i]) - i++ - } - s.str = s.str[i:] - return -} - -func (s *separator) consumeStringDelimiter() string { - c := s.str[0] - // check triple-quoted delim - if len(s.str) >= 3 && s.str[1] == c && s.str[2] == c { - delim := strings.Repeat(string(c), 3) - s.sb.WriteString(delim) - s.str = s.str[3:] - return delim - } - s.str = s.str[1:] - s.sb.WriteRune(c) - return string(c) -} - -func (s *separator) skipComments() { - var i int - for i < len(s.str) { - var terminate string - if s.str[i] == '#' { - // single line comment "#" - terminate = "\n" - i++ - } else if i+1 < len(s.str) && s.str[i] == '-' && s.str[i+1] == '-' { - // single line comment "--" - terminate = "\n" - i += 2 - } else if i+1 < len(s.str) && s.str[i] == '/' && s.str[i+1] == '*' { - // multi line comments "/* */" - // NOTE: Nested multiline comments are not supported in Spanner. - // https://cloud.google.com/spanner/docs/lexical#multiline_comments - terminate = "*/" - i += 2 - } - - // no comment found - if terminate == "" { - return - } - - // not terminated, but end of string - if i >= len(s.str) { - s.str = s.str[len(s.str):] - return - } - - for ; i < len(s.str); i++ { - if l := len(terminate); l == 1 { - if string(s.str[i]) == terminate { - s.str = s.str[i+1:] - i = 0 - break - } - } else if l == 2 { - if i+1 < len(s.str) && string(s.str[i:i+2]) == terminate { - s.str = s.str[i+2:] - i = 0 - break - } - } - } - - // not terminated, but end of string - if i >= len(s.str) { - s.str = s.str[len(s.str):] - return - } - } -} - -// separate separates input string into multiple Spanner statements. -// This does not validate syntax of statements. -// -// NOTE: Logic for parsing a statement is mostly taken from spansql. -// https://github.com/googleapis/google-cloud-go/blob/master/spanner/spansql/parser.go -func (s *separator) separate() []inputStatement { - var statements []inputStatement - for len(s.str) > 0 { - s.skipComments() - if len(s.str) == 0 { - break - } - - switch s.str[0] { - // possibly string literal - case '"', '\'', 'r', 'R', 'b', 'B': - // valid string prefix: "b", "B", "r", "R", "br", "bR", "Br", "BR" - // https://cloud.google.com/spanner/docs/lexical#string_and_bytes_literals - raw, bytes, str := false, false, false - for i := 0; i < 3 && i < len(s.str); i++ { - switch { - case !raw && (s.str[i] == 'r' || s.str[i] == 'R'): - raw = true - continue - case !bytes && (s.str[i] == 'b' || s.str[i] == 'B'): - bytes = true - continue - case s.str[i] == '"' || s.str[i] == '\'': - str = true - switch { - case raw && bytes: - s.consumeRawBytesString() - case raw: - s.consumeRawString() - case bytes: - s.consumeBytesString() - default: - s.consumeString() - } - } - break - } - if !str { - s.sb.WriteRune(s.str[0]) - s.str = s.str[1:] - } - // quoted identifier - case '`': - s.sb.WriteRune(s.str[0]) - s.str = s.str[1:] - s.consumeStringContent("`", false) - // horizontal delim - case ';': - statements = append(statements, inputStatement{ - statement: strings.TrimSpace(s.sb.String()), - delim: delimiterHorizontal, - }) - s.sb.Reset() - s.str = s.str[1:] - // possibly vertical delim - case '\\': - if len(s.str) >= 2 && s.str[1] == 'G' { - statements = append(statements, inputStatement{ - statement: strings.TrimSpace(s.sb.String()), - delim: delimiterVertical, - }) - s.sb.Reset() - s.str = s.str[2:] - continue - } - s.sb.WriteRune(s.str[0]) - s.str = s.str[1:] - default: - s.sb.WriteRune(s.str[0]) - s.str = s.str[1:] - } - } - - // flush remained - if s.sb.Len() > 0 { - if str := strings.TrimSpace(s.sb.String()); len(str) > 0 { - statements = append(statements, inputStatement{ - statement: str, - delim: delimiterUndefined, - }) - s.sb.Reset() - } + var result []inputStatement + for _, stmt := range gsqlsep.SeparateInputPreserveComments(input, delimiterVertical) { + result = append(result, inputStatement{ + statement: stmt.Statement, + statementWithoutComments: stmt.StripComments().Statement, + delim: stmt.Terminator, + }) } - return statements + return result } diff --git a/separator_test.go b/separator_test.go index 2345f12..e631f38 100644 --- a/separator_test.go +++ b/separator_test.go @@ -22,271 +22,6 @@ import ( "github.com/google/go-cmp/cmp" ) -func TestSeparatorSkipComments(t *testing.T) { - for _, tt := range []struct { - desc string - str string - wantRemained string - }{ - { - desc: "single line comment (#)", - str: "# SELECT 1;\n", - wantRemained: "", - }, - { - desc: "single line comment (--)", - str: "-- SELECT 1;\n", - wantRemained: "", - }, - { - desc: "multiline comment", - str: "/* SELECT\n1; */", - wantRemained: "", - }, - { - desc: "single line comment (#) and statement", - str: "# SELECT 1;\nSELECT 2;", - wantRemained: "SELECT 2;", - }, - { - desc: "single line comment (--) and statement", - str: "-- SELECT 1;\nSELECT 2;", - wantRemained: "SELECT 2;", - }, - { - desc: "multiline comment and statement", - str: "/* SELECT\n1; */ SELECT 2;", - wantRemained: " SELECT 2;", - }, - { - desc: "single line comment (#) not terminated", - str: "# SELECT 1", - wantRemained: "", - }, - { - desc: "single line comment (--) not terminated", - str: "-- SELECT 1", - wantRemained: "", - }, - { - desc: "multiline comment not terminated", - str: "/* SELECT\n1;", - wantRemained: "", - }, - { - desc: "not comments", - str: "SELECT 1;", - wantRemained: "SELECT 1;", - }, - } { - t.Run(tt.desc, func(t *testing.T) { - s := newSeparator(tt.str) - s.skipComments() - - remained := string(s.str) - if remained != tt.wantRemained { - t.Errorf("consumeComments(%q) remained %q, but want = %q", tt.str, remained, tt.wantRemained) - } - }) - } -} - -func TestSeparatorConsumeString(t *testing.T) { - for _, tt := range []struct { - desc string - str string - want string - wantRemained string - }{ - { - desc: "double quoted string", - str: `"test" WHERE`, - want: `"test"`, - wantRemained: " WHERE", - }, - { - desc: "single quoted string", - str: `'test' WHERE`, - want: `'test'`, - wantRemained: " WHERE", - }, - { - desc: "tripled quoted string", - str: `"""test""" WHERE`, - want: `"""test"""`, - wantRemained: " WHERE", - }, - { - desc: "quoted string with escape sequence", - str: `"te\"st" WHERE`, - want: `"te\"st"`, - wantRemained: " WHERE", - }, - { - desc: "double quoted empty string", - str: `"" WHERE`, - want: `""`, - wantRemained: " WHERE", - }, - { - desc: "tripled quoted string with new line", - str: "'''t\ne\ns\nt''' WHERE", - want: "'''t\ne\ns\nt'''", - wantRemained: " WHERE", - }, - { - desc: "triple quoted empty string", - str: `"""""" WHERE`, - want: `""""""`, - wantRemained: " WHERE", - }, - { - desc: "multi-byte character in string", - str: `"テスト" WHERE`, - want: `"テスト"`, - wantRemained: " WHERE", - }, - } { - t.Run(tt.desc, func(t *testing.T) { - s := newSeparator(tt.str) - s.consumeString() - - got := s.sb.String() - if got != tt.want { - t.Errorf("consumeString(%q) = %q, but want = %q", tt.str, got, tt.want) - } - - remained := string(s.str) - if remained != tt.wantRemained { - t.Errorf("consumeString(%q) remained %q, but want = %q", tt.str, remained, tt.wantRemained) - } - }) - } -} - -func TestSeparatorConsumeRawString(t *testing.T) { - for _, tt := range []struct { - desc string - str string - want string - wantRemained string - }{ - { - desc: "raw string (r)", - str: `r"test" WHERE`, - want: `r"test"`, - wantRemained: " WHERE", - }, - { - desc: "raw string (R)", - str: `R'test' WHERE`, - want: `R'test'`, - wantRemained: " WHERE", - }, - { - desc: "raw string with escape sequence", - str: `r"test\abc" WHERE`, - want: `r"test\abc"`, - wantRemained: " WHERE", - }, - } { - t.Run(tt.desc, func(t *testing.T) { - s := newSeparator(tt.str) - s.consumeRawString() - - got := s.sb.String() - if got != tt.want { - t.Errorf("consumeRawString(%q) = %q, but want = %q", tt.str, got, tt.want) - } - - remained := string(s.str) - if remained != tt.wantRemained { - t.Errorf("consumeRawString(%q) remained %q, but want = %q", tt.str, remained, tt.wantRemained) - } - }) - } -} - -func TestSeparatorConsumeBytesString(t *testing.T) { - for _, tt := range []struct { - desc string - str string - want string - wantRemained string - }{ - { - desc: "bytes string (b)", - str: `b"test" WHERE`, - want: `b"test"`, - wantRemained: " WHERE", - }, - { - desc: "bytes string (B)", - str: `B'test' WHERE`, - want: `B'test'`, - wantRemained: " WHERE", - }, - { - desc: "bytes string with hex escape", - str: `b"\x12\x34\x56" WHERE`, - want: `b"\x12\x34\x56"`, - wantRemained: " WHERE", - }, - } { - t.Run(tt.desc, func(t *testing.T) { - s := newSeparator(tt.str) - s.consumeBytesString() - - got := s.sb.String() - if got != tt.want { - t.Errorf("consumeBytesString(%q) = %q, but want = %q", tt.str, got, tt.want) - } - - remained := string(s.str) - if remained != tt.wantRemained { - t.Errorf("consumeBytesString(%q) remained %q, but want = %q", tt.str, remained, tt.wantRemained) - } - }) - } -} - -func TestSeparatorConsumeRawBytesString(t *testing.T) { - for _, tt := range []struct { - desc string - str string - want string - wantRemained string - }{ - { - desc: "raw bytes string (rb)", - str: `rb"test" WHERE`, - want: `rb"test"`, - wantRemained: " WHERE", - }, - { - desc: "raw bytes string (RB)", - str: `RB"test" WHERE`, - want: `RB"test"`, - wantRemained: " WHERE", - }, - } { - t.Run(tt.desc, func(t *testing.T) { - s := newSeparator(tt.str) - s.consumeRawBytesString() - - got := s.sb.String() - if got != tt.want { - t.Errorf("consumeRawBytesString(%q) = %q, but want = %q", tt.str, got, tt.want) - } - - remained := string(s.str) - if remained != tt.wantRemained { - t.Errorf("consumeRawBytesString(%q) remained %q, but want = %q", tt.str, remained, tt.wantRemained) - } - }) - } -} - func TestSeparateInput(t *testing.T) { for _, tt := range []struct { desc string @@ -298,8 +33,9 @@ func TestSeparateInput(t *testing.T) { input: `SELECT "123";`, want: []inputStatement{ { - statement: `SELECT "123"`, - delim: delimiterHorizontal, + statement: `SELECT "123"`, + statementWithoutComments: `SELECT "123"`, + delim: delimiterHorizontal, }, }, }, @@ -308,12 +44,14 @@ func TestSeparateInput(t *testing.T) { input: `SELECT "123"; SELECT "456";`, want: []inputStatement{ { - statement: `SELECT "123"`, - delim: delimiterHorizontal, + statement: `SELECT "123"`, + statementWithoutComments: `SELECT "123"`, + delim: delimiterHorizontal, }, { - statement: `SELECT "456"`, - delim: delimiterHorizontal, + statement: `SELECT "456"`, + statementWithoutComments: `SELECT "456"`, + delim: delimiterHorizontal, }, }, }, @@ -322,12 +60,14 @@ func TestSeparateInput(t *testing.T) { input: "SELECT `1`, `2`; SELECT `3`, `4`;", want: []inputStatement{ { - statement: "SELECT `1`, `2`", - delim: delimiterHorizontal, + statement: "SELECT `1`, `2`", + statementWithoutComments: "SELECT `1`, `2`", + delim: delimiterHorizontal, }, { - statement: "SELECT `3`, `4`", - delim: delimiterHorizontal, + statement: "SELECT `3`, `4`", + statementWithoutComments: "SELECT `3`, `4`", + delim: delimiterHorizontal, }, }, }, @@ -336,8 +76,9 @@ func TestSeparateInput(t *testing.T) { input: `SELECT "123"\G`, want: []inputStatement{ { - statement: `SELECT "123"`, - delim: delimiterVertical, + statement: `SELECT "123"`, + statementWithoutComments: `SELECT "123"`, + delim: delimiterVertical, }, }, }, @@ -346,16 +87,19 @@ func TestSeparateInput(t *testing.T) { input: `SELECT "123"; SELECT "456"\G SELECT "789";`, want: []inputStatement{ { - statement: `SELECT "123"`, - delim: delimiterHorizontal, + statement: `SELECT "123"`, + statementWithoutComments: `SELECT "123"`, + delim: delimiterHorizontal, }, { - statement: `SELECT "456"`, - delim: delimiterVertical, + statement: `SELECT "456"`, + statementWithoutComments: `SELECT "456"`, + delim: delimiterVertical, }, { - statement: `SELECT "789"`, - delim: delimiterHorizontal, + statement: `SELECT "789"`, + statementWithoutComments: `SELECT "789"`, + delim: delimiterHorizontal, }, }, }, @@ -364,12 +108,14 @@ func TestSeparateInput(t *testing.T) { input: `SELECT * FROM t1 WHERE id = "123" AND "456"; DELETE FROM t2 WHERE true;`, want: []inputStatement{ { - statement: `SELECT * FROM t1 WHERE id = "123" AND "456"`, - delim: delimiterHorizontal, + statement: `SELECT * FROM t1 WHERE id = "123" AND "456"`, + statementWithoutComments: `SELECT * FROM t1 WHERE id = "123" AND "456"`, + delim: delimiterHorizontal, }, { - statement: `DELETE FROM t2 WHERE true`, - delim: delimiterHorizontal, + statement: `DELETE FROM t2 WHERE true`, + statementWithoutComments: `DELETE FROM t2 WHERE true`, + delim: delimiterHorizontal, }, }, }, @@ -378,8 +124,9 @@ func TestSeparateInput(t *testing.T) { input: `SELECT 1; ;`, want: []inputStatement{ { - statement: `SELECT 1`, - delim: delimiterHorizontal, + statement: `SELECT 1`, + statementWithoutComments: `SELECT 1`, + delim: delimiterHorizontal, }, { statement: ``, @@ -392,12 +139,14 @@ func TestSeparateInput(t *testing.T) { input: "SELECT 1;\n SELECT 2\\G\n", want: []inputStatement{ { - statement: `SELECT 1`, - delim: delimiterHorizontal, + statement: `SELECT 1`, + statementWithoutComments: `SELECT 1`, + delim: delimiterHorizontal, }, { - statement: `SELECT 2`, - delim: delimiterVertical, + statement: `SELECT 2`, + statementWithoutComments: `SELECT 2`, + delim: delimiterVertical, }, }, }, @@ -406,12 +155,14 @@ func TestSeparateInput(t *testing.T) { input: `SELECT "1;2;3"; SELECT 'TL;DR';`, want: []inputStatement{ { - statement: `SELECT "1;2;3"`, - delim: delimiterHorizontal, + statement: `SELECT "1;2;3"`, + statementWithoutComments: `SELECT "1;2;3"`, + delim: delimiterHorizontal, }, { - statement: `SELECT 'TL;DR'`, - delim: delimiterHorizontal, + statement: `SELECT 'TL;DR'`, + statementWithoutComments: `SELECT 'TL;DR'`, + delim: delimiterHorizontal, }, }, }, @@ -420,12 +171,14 @@ func TestSeparateInput(t *testing.T) { input: `SELECT r"1\G2\G3"\G SELECT r'4\G5\G6'\G`, want: []inputStatement{ { - statement: `SELECT r"1\G2\G3"`, - delim: delimiterVertical, + statement: `SELECT r"1\G2\G3"`, + statementWithoutComments: `SELECT r"1\G2\G3"`, + delim: delimiterVertical, }, { - statement: `SELECT r'4\G5\G6'`, - delim: delimiterVertical, + statement: `SELECT r'4\G5\G6'`, + statementWithoutComments: `SELECT r'4\G5\G6'`, + delim: delimiterVertical, }, }, }, @@ -434,12 +187,14 @@ func TestSeparateInput(t *testing.T) { input: "SELECT `1;2`; SELECT `3;4`;", want: []inputStatement{ { - statement: "SELECT `1;2`", - delim: delimiterHorizontal, + statement: "SELECT `1;2`", + statementWithoutComments: "SELECT `1;2`", + delim: delimiterHorizontal, }, { - statement: "SELECT `3;4`", - delim: delimiterHorizontal, + statement: "SELECT `3;4`", + statementWithoutComments: "SELECT `3;4`", + delim: delimiterHorizontal, }, }, }, @@ -448,12 +203,14 @@ func TestSeparateInput(t *testing.T) { input: "SELECT '123'\n; SELECT '456'\n\\G", want: []inputStatement{ { - statement: `SELECT '123'`, - delim: delimiterHorizontal, + statement: `SELECT '123'`, + statementWithoutComments: `SELECT '123'`, + delim: delimiterHorizontal, }, { - statement: `SELECT '456'`, - delim: delimiterVertical, + statement: `SELECT '456'`, + statementWithoutComments: `SELECT '456'`, + delim: delimiterVertical, }, }, }, @@ -462,41 +219,54 @@ func TestSeparateInput(t *testing.T) { input: "CREATE t1 (\nId INT64 NOT NULL\n) PRIMARY KEY (Id);", want: []inputStatement{ { - statement: "CREATE t1 (\nId INT64 NOT NULL\n) PRIMARY KEY (Id)", - delim: delimiterHorizontal, + statement: "CREATE t1 (\nId INT64 NOT NULL\n) PRIMARY KEY (Id)", + statementWithoutComments: "CREATE t1 (\nId INT64 NOT NULL\n) PRIMARY KEY (Id)", + delim: delimiterHorizontal, }, }, }, + { desc: `statement with multiple comments`, input: "# comment;\nSELECT /* comment */ 1; --comment\nSELECT 2;/* comment */", want: []inputStatement{ { - statement: "SELECT 1", - delim: delimiterHorizontal, + statement: "# comment;\nSELECT /* comment */ 1", + statementWithoutComments: "SELECT 1", + delim: delimiterHorizontal, }, { - statement: "SELECT 2", - delim: delimiterHorizontal, + statement: "--comment\nSELECT 2", + statementWithoutComments: "SELECT 2", + delim: delimiterHorizontal, + }, + { + statement: "/* comment */", }, }, }, { desc: `only comments`, input: "# comment;\n/* comment */--comment\n/* comment */", - want: nil, + want: []inputStatement{ + { + statement: "# comment;\n/* comment */--comment\n/* comment */", + }, + }, }, { desc: `second query ends in the middle of string`, input: `SELECT "123"; SELECT "45`, want: []inputStatement{ { - statement: `SELECT "123"`, - delim: delimiterHorizontal, + statement: `SELECT "123"`, + statementWithoutComments: `SELECT "123"`, + delim: delimiterHorizontal, }, { - statement: `SELECT "45`, - delim: delimiterUndefined, + statement: `SELECT "45`, + statementWithoutComments: `SELECT "45`, + delim: delimiterUndefined, }, }, }, @@ -505,8 +275,30 @@ func TestSeparateInput(t *testing.T) { input: `a"""""""""'''''''''b`, want: []inputStatement{ { - statement: `a"""""""""'''''''''b`, - delim: delimiterUndefined, + statement: `a"""""""""'''''''''b`, + statementWithoutComments: `a"""""""""'''''''''b`, + delim: delimiterUndefined, + }, + }, + }, + { + desc: `statement with multiple comments`, + input: "SELECT 0x1/* comment */A; SELECT 0x2--\nB; SELECT 0x3#\nC", + want: []inputStatement{ + { + statement: "SELECT 0x1/* comment */A", + statementWithoutComments: "SELECT 0x1 A", + delim: delimiterHorizontal, + }, + { + statement: "SELECT 0x2--\nB", + statementWithoutComments: "SELECT 0x2 B", + delim: delimiterHorizontal, + }, + { + statement: "SELECT 0x3#\nC", + statementWithoutComments: "SELECT 0x3 C", + delim: delimiterUndefined, }, }, }, diff --git a/statement.go b/statement.go index 76fdbfc..e3924e2 100644 --- a/statement.go +++ b/statement.go @@ -127,43 +127,47 @@ var ( ) func BuildStatement(input string) (Statement, error) { + return BuildStatementWithComments(input, input) +} + +func BuildStatementWithComments(stripped, raw string) (Statement, error) { switch { - case exitRe.MatchString(input): + case exitRe.MatchString(stripped): return &ExitStatement{}, nil - case useRe.MatchString(input): - matched := useRe.FindStringSubmatch(input) + case useRe.MatchString(stripped): + matched := useRe.FindStringSubmatch(stripped) return &UseStatement{Database: unquoteIdentifier(matched[1]), Role: unquoteIdentifier(matched[2])}, nil - case selectRe.MatchString(input): - return &SelectStatement{Query: input}, nil - case createDatabaseRe.MatchString(input): - return &CreateDatabaseStatement{CreateStatement: input}, nil - case createRe.MatchString(input): - return &DdlStatement{Ddl: input}, nil - case dropDatabaseRe.MatchString(input): - matched := dropDatabaseRe.FindStringSubmatch(input) + case selectRe.MatchString(stripped): + return &SelectStatement{Query: raw}, nil + case createDatabaseRe.MatchString(stripped): + return &CreateDatabaseStatement{CreateStatement: stripped}, nil + case createRe.MatchString(stripped): + return &DdlStatement{Ddl: stripped}, nil + case dropDatabaseRe.MatchString(stripped): + matched := dropDatabaseRe.FindStringSubmatch(stripped) return &DropDatabaseStatement{DatabaseId: unquoteIdentifier(matched[1])}, nil - case dropRe.MatchString(input): - return &DdlStatement{Ddl: input}, nil - case alterRe.MatchString(input): - return &DdlStatement{Ddl: input}, nil - case grantRe.MatchString(input): - return &DdlStatement{Ddl: input}, nil - case revokeRe.MatchString(input): - return &DdlStatement{Ddl: input}, nil - case truncateTableRe.MatchString(input): - matched := truncateTableRe.FindStringSubmatch(input) + case dropRe.MatchString(stripped): + return &DdlStatement{Ddl: stripped}, nil + case alterRe.MatchString(stripped): + return &DdlStatement{Ddl: stripped}, nil + case grantRe.MatchString(stripped): + return &DdlStatement{Ddl: stripped}, nil + case revokeRe.MatchString(stripped): + return &DdlStatement{Ddl: stripped}, nil + case truncateTableRe.MatchString(stripped): + matched := truncateTableRe.FindStringSubmatch(stripped) return &TruncateTableStatement{Table: unquoteIdentifier(matched[1])}, nil - case analyzeRe.MatchString(input): - return &DdlStatement{Ddl: input}, nil - case showDatabasesRe.MatchString(input): + case analyzeRe.MatchString(stripped): + return &DdlStatement{Ddl: stripped}, nil + case showDatabasesRe.MatchString(stripped): return &ShowDatabasesStatement{}, nil - case showCreateTableRe.MatchString(input): - matched := showCreateTableRe.FindStringSubmatch(input) + case showCreateTableRe.MatchString(stripped): + matched := showCreateTableRe.FindStringSubmatch(stripped) return &ShowCreateTableStatement{Table: unquoteIdentifier(matched[1])}, nil - case showTablesRe.MatchString(input): + case showTablesRe.MatchString(stripped): return &ShowTablesStatement{}, nil - case explainRe.MatchString(input): - matched := explainRe.FindStringSubmatch(input) + case explainRe.MatchString(stripped): + matched := explainRe.FindStringSubmatch(stripped) isAnalyze := matched[1] != "" isDML := dmlRe.MatchString(matched[2]) switch { @@ -176,26 +180,26 @@ func BuildStatement(input string) (Statement, error) { default: return &ExplainStatement{Explain: matched[2]}, nil } - case showColumnsRe.MatchString(input): - matched := showColumnsRe.FindStringSubmatch(input) + case showColumnsRe.MatchString(stripped): + matched := showColumnsRe.FindStringSubmatch(stripped) return &ShowColumnsStatement{Table: unquoteIdentifier(matched[1])}, nil - case showIndexRe.MatchString(input): - matched := showIndexRe.FindStringSubmatch(input) + case showIndexRe.MatchString(stripped): + matched := showIndexRe.FindStringSubmatch(stripped) return &ShowIndexStatement{Table: unquoteIdentifier(matched[1])}, nil - case dmlRe.MatchString(input): - return &DmlStatement{Dml: input}, nil - case pdmlRe.MatchString(input): - matched := pdmlRe.FindStringSubmatch(input) + case dmlRe.MatchString(stripped): + return &DmlStatement{Dml: raw}, nil + case pdmlRe.MatchString(stripped): + matched := pdmlRe.FindStringSubmatch(stripped) return &PartitionedDmlStatement{Dml: matched[1]}, nil - case beginRwRe.MatchString(input): - return newBeginRwStatement(input) - case beginRoRe.MatchString(input): - return newBeginRoStatement(input) - case commitRe.MatchString(input): + case beginRwRe.MatchString(stripped): + return newBeginRwStatement(stripped) + case beginRoRe.MatchString(stripped): + return newBeginRoStatement(stripped) + case commitRe.MatchString(stripped): return &CommitStatement{}, nil - case rollbackRe.MatchString(input): + case rollbackRe.MatchString(stripped): return &RollbackStatement{}, nil - case closeRe.MatchString(input): + case closeRe.MatchString(stripped): return &CloseStatement{}, nil } diff --git a/statement_test.go b/statement_test.go index 09f0b81..7a1f374 100644 --- a/statement_test.go +++ b/statement_test.go @@ -50,6 +50,11 @@ func TestBuildStatement(t *testing.T) { input: "SELECT\n*\nFROM t1", want: &SelectStatement{Query: "SELECT\n*\nFROM t1"}, }, + { + desc: "SELECT statement with comment", + input: "SELECT 0x1/**/A", + want: &SelectStatement{Query: "SELECT 0x1/**/A"}, + }, { desc: "WITH statement", input: "WITH sub AS (SELECT 1) SELECT * FROM sub",