diff --git a/go/tools/asthelpergen/asthelpergen.go b/go/tools/asthelpergen/asthelpergen.go index 89aa40b6127..d7d6187fa02 100644 --- a/go/tools/asthelpergen/asthelpergen.go +++ b/go/tools/asthelpergen/asthelpergen.go @@ -78,6 +78,9 @@ type ( } ) +// exprInterfacePath is the path of the sqlparser.Expr interface. +const exprInterfacePath = "vitess.io/vitess/go/vt/sqlparser.Expr" + func (gen *astHelperGen) iface() *types.Interface { return gen._iface } @@ -200,22 +203,15 @@ func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) { scopes[pkg.PkgPath] = pkg.Types.Scope() } - pos := strings.LastIndexByte(options.RootInterface, '.') - if pos < 0 { - return nil, fmt.Errorf("unexpected input type: %s", options.RootInterface) - } - - pkgname := options.RootInterface[:pos] - typename := options.RootInterface[pos+1:] - - scope := scopes[pkgname] - if scope == nil { - return nil, fmt.Errorf("no scope found for type '%s'", options.RootInterface) + tt, err := findTypeObject(options.RootInterface, scopes) + if err != nil { + return nil, err } - tt := scope.Lookup(typename) - if tt == nil { - return nil, fmt.Errorf("no type called '%s' found in '%s'", typename, pkgname) + exprType, _ := findTypeObject(exprInterfacePath, scopes) + var exprInterface *types.Interface + if exprType != nil { + exprInterface = exprType.Type().(*types.Named).Underlying().(*types.Interface) } nt := tt.Type().(*types.Named) @@ -224,7 +220,7 @@ func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) { newEqualsGen(pName, &options.Equals), newCloneGen(pName, &options.Clone), newVisitGen(pName), - newRewriterGen(pName, types.TypeString(nt, noQualifier)), + newRewriterGen(pName, types.TypeString(nt, noQualifier), exprInterface), newCOWGen(pName, nt), ) @@ -236,6 +232,28 @@ func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) { return it, nil } +// findTypeObject finds the types.Object for the given interface from the given scopes. +func findTypeObject(interfaceToFind string, scopes map[string]*types.Scope) (types.Object, error) { + pos := strings.LastIndexByte(interfaceToFind, '.') + if pos < 0 { + return nil, fmt.Errorf("unexpected input type: %s", interfaceToFind) + } + + pkgname := interfaceToFind[:pos] + typename := interfaceToFind[pos+1:] + + scope := scopes[pkgname] + if scope == nil { + return nil, fmt.Errorf("no scope found for type '%s'", interfaceToFind) + } + + tt := scope.Lookup(typename) + if tt == nil { + return nil, fmt.Errorf("no type called '%s' found in '%s'", typename, pkgname) + } + return tt, nil +} + var _ generatorSPI = (*astHelperGen)(nil) func (gen *astHelperGen) scope() *types.Scope { diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 4804ef8d874..cc8b18a78e9 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -30,18 +30,21 @@ const ( type rewriteGen struct { ifaceName string file *jen.File + // exprInterface is used to store the sqlparser.Expr interface + exprInterface *types.Interface } var _ generator = (*rewriteGen)(nil) -func newRewriterGen(pkgname string, ifaceName string) *rewriteGen { +func newRewriterGen(pkgname string, ifaceName string, exprInterface *types.Interface) *rewriteGen { file := jen.NewFile(pkgname) file.HeaderComment(licenseFileHeader) file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.") return &rewriteGen{ - ifaceName: ifaceName, - file: file, + ifaceName: ifaceName, + file: file, + exprInterface: exprInterface, } } @@ -105,7 +108,7 @@ func (r *rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generat } fields := r.rewriteAllStructFields(t, strct, spi, true) - stmts := []jen.Code{executePre()} + stmts := []jen.Code{r.executePre(t)} stmts = append(stmts, fields...) stmts = append(stmts, executePost(len(fields) > 0)) stmts = append(stmts, returnTrue()) @@ -130,7 +133,7 @@ func (r *rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi ge return nil } */ - stmts = append(stmts, executePre()) + stmts = append(stmts, r.executePre(t)) fields := r.rewriteAllStructFields(t, strct, spi, false) stmts = append(stmts, fields...) stmts = append(stmts, executePost(len(fields) > 0)) @@ -225,9 +228,19 @@ func setupCursor() []jen.Code { jen.Id("a.cur.node = node"), } } -func executePre() jen.Code { +func (r *rewriteGen) executePre(t types.Type) jen.Code { curStmts := setupCursor() - curStmts = append(curStmts, jen.If(jen.Id("!a.pre(&a.cur)")).Block(returnTrue())) + if r.exprInterface != nil && types.Implements(t, r.exprInterface) { + curStmts = append(curStmts, jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"), + jen.If(jen.Id("a.cur.revisit").Block( + jen.Id("a.cur.revisit").Op("=").False(), + jen.Return(jen.Id("a.rewriteExpr(parent, a.cur.node.(Expr), replacer)")), + )), + jen.If(jen.Id("kontinue").Block(jen.Return(jen.True()))), + ) + } else { + curStmts = append(curStmts, jen.If(jen.Id("!a.pre(&a.cur)")).Block(returnTrue())) + } return jen.If(jen.Id("a.pre!= nil").Block(curStmts...)) } @@ -251,7 +264,7 @@ func (r *rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) return nil } - stmts := []jen.Code{executePre(), executePost(false), returnTrue()} + stmts := []jen.Code{r.executePre(t), executePost(false), returnTrue()} r.rewriteFunc(t, stmts) return nil } diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 0266876e201..ec71a9038e9 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1046,7 +1046,12 @@ func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1078,7 +1083,12 @@ func (a *application) rewriteRefOfAnyValue(parent SQLNode, node *AnyValue, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1105,7 +1115,12 @@ func (a *application) rewriteRefOfArgument(parent SQLNode, node *Argument, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1129,7 +1144,12 @@ func (a *application) rewriteRefOfArgumentLessWindowExpr(parent SQLNode, node *A a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1156,7 +1176,12 @@ func (a *application) rewriteRefOfAssignmentExpr(parent SQLNode, node *Assignmen a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1220,7 +1245,12 @@ func (a *application) rewriteRefOfAvg(parent SQLNode, node *Avg, replacer replac a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1271,7 +1301,12 @@ func (a *application) rewriteRefOfBetweenExpr(parent SQLNode, node *BetweenExpr, a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1308,7 +1343,12 @@ func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, r a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1340,7 +1380,12 @@ func (a *application) rewriteRefOfBitAnd(parent SQLNode, node *BitAnd, replacer a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1367,7 +1412,12 @@ func (a *application) rewriteRefOfBitOr(parent SQLNode, node *BitOr, replacer re a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1394,7 +1444,12 @@ func (a *application) rewriteRefOfBitXor(parent SQLNode, node *BitXor, replacer a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1453,7 +1508,12 @@ func (a *application) rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1494,7 +1554,12 @@ func (a *application) rewriteRefOfCastExpr(parent SQLNode, node *CastExpr, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1563,7 +1628,12 @@ func (a *application) rewriteRefOfCharExpr(parent SQLNode, node *CharExpr, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1617,7 +1687,12 @@ func (a *application) rewriteRefOfColName(parent SQLNode, node *ColName, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1649,7 +1724,12 @@ func (a *application) rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1862,7 +1942,12 @@ func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *Compariso a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1931,7 +2016,12 @@ func (a *application) rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -1995,7 +2085,12 @@ func (a *application) rewriteRefOfConvertUsingExpr(parent SQLNode, node *Convert a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -2022,7 +2117,12 @@ func (a *application) rewriteRefOfCount(parent SQLNode, node *Count, replacer re a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -2049,7 +2149,12 @@ func (a *application) rewriteRefOfCountStar(parent SQLNode, node *CountStar, rep a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -2194,7 +2299,12 @@ func (a *application) rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeF a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -2253,7 +2363,12 @@ func (a *application) rewriteRefOfDefault(parent SQLNode, node *Default, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -2581,7 +2696,12 @@ func (a *application) rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, r a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -2704,7 +2824,12 @@ func (a *application) rewriteRefOfExtractFuncExpr(parent SQLNode, node *ExtractF a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -2731,7 +2856,12 @@ func (a *application) rewriteRefOfExtractValueExpr(parent SQLNode, node *Extract a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -2763,7 +2893,12 @@ func (a *application) rewriteRefOfExtractedSubquery(parent SQLNode, node *Extrac a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -2805,7 +2940,12 @@ func (a *application) rewriteRefOfFirstOrLastValueExpr(parent SQLNode, node *Fir a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3013,7 +3153,12 @@ func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3050,7 +3195,12 @@ func (a *application) rewriteRefOfGTIDFuncExpr(parent SQLNode, node *GTIDFuncExp a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3092,7 +3242,12 @@ func (a *application) rewriteRefOfGeoHashFromLatLongExpr(parent SQLNode, node *G a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3129,7 +3284,12 @@ func (a *application) rewriteRefOfGeoHashFromPointExpr(parent SQLNode, node *Geo a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3161,7 +3321,12 @@ func (a *application) rewriteRefOfGeoJSONFromGeomExpr(parent SQLNode, node *GeoJ a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3198,7 +3363,12 @@ func (a *application) rewriteRefOfGeomCollPropertyFuncExpr(parent SQLNode, node a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3230,7 +3400,12 @@ func (a *application) rewriteRefOfGeomFormatExpr(parent SQLNode, node *GeomForma a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3262,7 +3437,12 @@ func (a *application) rewriteRefOfGeomFromGeoHashExpr(parent SQLNode, node *Geom a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3294,7 +3474,12 @@ func (a *application) rewriteRefOfGeomFromGeoJSONExpr(parent SQLNode, node *Geom a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3331,7 +3516,12 @@ func (a *application) rewriteRefOfGeomFromTextExpr(parent SQLNode, node *GeomFro a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3368,7 +3558,12 @@ func (a *application) rewriteRefOfGeomFromWKBExpr(parent SQLNode, node *GeomFrom a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3405,7 +3600,12 @@ func (a *application) rewriteRefOfGeomPropertyFuncExpr(parent SQLNode, node *Geo a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3469,7 +3669,12 @@ func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupCon a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3727,7 +3932,12 @@ func (a *application) rewriteRefOfInsertExpr(parent SQLNode, node *InsertExpr, r a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3769,7 +3979,12 @@ func (a *application) rewriteRefOfIntervalDateExpr(parent SQLNode, node *Interva a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3801,7 +4016,12 @@ func (a *application) rewriteRefOfIntervalFuncExpr(parent SQLNode, node *Interva a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3833,7 +4053,12 @@ func (a *application) rewriteRefOfIntroducerExpr(parent SQLNode, node *Introduce a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3860,7 +4085,12 @@ func (a *application) rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3887,7 +4117,12 @@ func (a *application) rewriteRefOfJSONArrayExpr(parent SQLNode, node *JSONArrayE a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3914,7 +4149,12 @@ func (a *application) rewriteRefOfJSONAttributesExpr(parent SQLNode, node *JSONA a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3946,7 +4186,12 @@ func (a *application) rewriteRefOfJSONContainsExpr(parent SQLNode, node *JSONCon a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -3987,7 +4232,12 @@ func (a *application) rewriteRefOfJSONContainsPathExpr(parent SQLNode, node *JSO a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4028,7 +4278,12 @@ func (a *application) rewriteRefOfJSONExtractExpr(parent SQLNode, node *JSONExtr a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4064,7 +4319,12 @@ func (a *application) rewriteRefOfJSONKeysExpr(parent SQLNode, node *JSONKeysExp a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4096,7 +4356,12 @@ func (a *application) rewriteRefOfJSONObjectExpr(parent SQLNode, node *JSONObjec a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4159,7 +4424,12 @@ func (a *application) rewriteRefOfJSONOverlapsExpr(parent SQLNode, node *JSONOve a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4191,7 +4461,12 @@ func (a *application) rewriteRefOfJSONPrettyExpr(parent SQLNode, node *JSONPrett a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4218,7 +4493,12 @@ func (a *application) rewriteRefOfJSONQuoteExpr(parent SQLNode, node *JSONQuoteE a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4245,7 +4525,12 @@ func (a *application) rewriteRefOfJSONRemoveExpr(parent SQLNode, node *JSONRemov a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4277,7 +4562,12 @@ func (a *application) rewriteRefOfJSONSchemaValidFuncExpr(parent SQLNode, node * a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4309,7 +4599,12 @@ func (a *application) rewriteRefOfJSONSchemaValidationReportFuncExpr(parent SQLN a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4341,7 +4636,12 @@ func (a *application) rewriteRefOfJSONSearchExpr(parent SQLNode, node *JSONSearc a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4392,7 +4692,12 @@ func (a *application) rewriteRefOfJSONStorageFreeExpr(parent SQLNode, node *JSON a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4419,7 +4724,12 @@ func (a *application) rewriteRefOfJSONStorageSizeExpr(parent SQLNode, node *JSON a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4492,7 +4802,12 @@ func (a *application) rewriteRefOfJSONUnquoteExpr(parent SQLNode, node *JSONUnqu a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4519,7 +4834,12 @@ func (a *application) rewriteRefOfJSONValueExpr(parent SQLNode, node *JSONValueE a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4566,7 +4886,12 @@ func (a *application) rewriteRefOfJSONValueMergeExpr(parent SQLNode, node *JSONV a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4598,7 +4923,12 @@ func (a *application) rewriteRefOfJSONValueModifierExpr(parent SQLNode, node *JS a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4802,7 +5132,12 @@ func (a *application) rewriteRefOfLagLeadExpr(parent SQLNode, node *LagLeadExpr, a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4881,7 +5216,12 @@ func (a *application) rewriteRefOfLineStringExpr(parent SQLNode, node *LineStrin a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4908,7 +5248,12 @@ func (a *application) rewriteRefOfLinestrPropertyFuncExpr(parent SQLNode, node * a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4940,7 +5285,12 @@ func (a *application) rewriteRefOfLiteral(parent SQLNode, node *Literal, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -4988,7 +5338,12 @@ func (a *application) rewriteRefOfLocateExpr(parent SQLNode, node *LocateExpr, r a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5073,7 +5428,12 @@ func (a *application) rewriteRefOfLockingFunc(parent SQLNode, node *LockingFunc, a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5105,7 +5465,12 @@ func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, rep a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5141,7 +5506,12 @@ func (a *application) rewriteRefOfMax(parent SQLNode, node *Max, replacer replac a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5168,7 +5538,12 @@ func (a *application) rewriteRefOfMemberOfExpr(parent SQLNode, node *MemberOfExp a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5200,7 +5575,12 @@ func (a *application) rewriteRefOfMin(parent SQLNode, node *Min, replacer replac a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5259,7 +5639,12 @@ func (a *application) rewriteRefOfMultiLinestringExpr(parent SQLNode, node *Mult a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5286,7 +5671,12 @@ func (a *application) rewriteRefOfMultiPointExpr(parent SQLNode, node *MultiPoin a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5313,7 +5703,12 @@ func (a *application) rewriteRefOfMultiPolygonExpr(parent SQLNode, node *MultiPo a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5340,7 +5735,12 @@ func (a *application) rewriteRefOfNTHValueExpr(parent SQLNode, node *NTHValueExp a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5387,7 +5787,12 @@ func (a *application) rewriteRefOfNamedWindow(parent SQLNode, node *NamedWindow, a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5478,7 +5883,12 @@ func (a *application) rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5505,7 +5915,12 @@ func (a *application) rewriteRefOfNtileExpr(parent SQLNode, node *NtileExpr, rep a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5561,7 +5976,12 @@ func (a *application) rewriteRefOfNullVal(parent SQLNode, node *NullVal, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5585,7 +6005,12 @@ func (a *application) rewriteRefOfOffset(parent SQLNode, node *Offset, replacer a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -5676,7 +6101,12 @@ func (a *application) rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -6194,7 +6624,12 @@ func (a *application) rewriteRefOfPerformanceSchemaFuncExpr(parent SQLNode, node a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -6221,7 +6656,12 @@ func (a *application) rewriteRefOfPointExpr(parent SQLNode, node *PointExpr, rep a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -6253,7 +6693,12 @@ func (a *application) rewriteRefOfPointPropertyFuncExpr(parent SQLNode, node *Po a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -6285,7 +6730,12 @@ func (a *application) rewriteRefOfPolygonExpr(parent SQLNode, node *PolygonExpr, a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -6312,7 +6762,12 @@ func (a *application) rewriteRefOfPolygonPropertyFuncExpr(parent SQLNode, node * a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -6452,7 +6907,12 @@ func (a *application) rewriteRefOfRegexpInstrExpr(parent SQLNode, node *RegexpIn a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -6504,7 +6964,12 @@ func (a *application) rewriteRefOfRegexpLikeExpr(parent SQLNode, node *RegexpLik a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -6541,7 +7006,12 @@ func (a *application) rewriteRefOfRegexpReplaceExpr(parent SQLNode, node *Regexp a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -6593,7 +7063,12 @@ func (a *application) rewriteRefOfRegexpSubstrExpr(parent SQLNode, node *RegexpS a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -7398,7 +7873,12 @@ func (a *application) rewriteRefOfStd(parent SQLNode, node *Std, replacer replac a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -7425,7 +7905,12 @@ func (a *application) rewriteRefOfStdDev(parent SQLNode, node *StdDev, replacer a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -7452,7 +7937,12 @@ func (a *application) rewriteRefOfStdPop(parent SQLNode, node *StdPop, replacer a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -7479,7 +7969,12 @@ func (a *application) rewriteRefOfStdSamp(parent SQLNode, node *StdSamp, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -7686,7 +8181,12 @@ func (a *application) rewriteRefOfSubquery(parent SQLNode, node *Subquery, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -7713,7 +8213,12 @@ func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, r a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -7750,7 +8255,12 @@ func (a *application) rewriteRefOfSum(parent SQLNode, node *Sum, replacer replac a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -7993,7 +8503,12 @@ func (a *application) rewriteRefOfTimestampDiffExpr(parent SQLNode, node *Timest a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -8025,7 +8540,12 @@ func (a *application) rewriteRefOfTrimFuncExpr(parent SQLNode, node *TrimFuncExp a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -8084,7 +8604,12 @@ func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, rep a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -8313,7 +8838,12 @@ func (a *application) rewriteRefOfUpdateXMLExpr(parent SQLNode, node *UpdateXMLE a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -8554,7 +9084,12 @@ func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFun a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -8581,7 +9116,12 @@ func (a *application) rewriteRefOfVarPop(parent SQLNode, node *VarPop, replacer a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -8608,7 +9148,12 @@ func (a *application) rewriteRefOfVarSamp(parent SQLNode, node *VarSamp, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -8635,7 +9180,12 @@ func (a *application) rewriteRefOfVariable(parent SQLNode, node *Variable, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -8662,7 +9212,12 @@ func (a *application) rewriteRefOfVariance(parent SQLNode, node *Variance, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -8754,7 +9309,12 @@ func (a *application) rewriteRefOfWeightStringFuncExpr(parent SQLNode, node *Wei a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -8987,7 +9547,12 @@ func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -9833,7 +10398,12 @@ func (a *application) rewriteBoolVal(parent SQLNode, node BoolVal, replacer repl a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } @@ -9854,7 +10424,12 @@ func (a *application) rewriteListArg(parent SQLNode, node ListArg, replacer repl a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + a.cur.revisit = false + return a.rewriteExpr(parent, a.cur.node.(Expr), replacer) + } + if kontinue { return true } } diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index 05d371bad13..cfcf75fa0f9 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -126,7 +126,7 @@ func (c *Cursor) ReplacerF() func(newNode SQLNode) { // and the new node visited. func (c *Cursor) ReplaceAndRevisit(newNode SQLNode) { switch newNode.(type) { - case SelectExprs: + case SelectExprs, Expr: default: // We need to add support to the generated code for when to look at the revisit flag. At the moment it is only // there for slices of SQLNode implementations diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index ca1ebc6d2f4..38363d6efbc 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -123,7 +123,7 @@ func handleOrderBy(r *earlyRewriter, cursor *sqlparser.Cursor, node sqlparser.Or func rewriteOrExpr(cursor *sqlparser.Cursor, node *sqlparser.OrExpr) { newNode := rewriteOrFalse(*node) if newNode != nil { - cursor.Replace(newNode) + cursor.ReplaceAndRevisit(newNode) } } diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index 2846bfd9366..b1c8ebe03d2 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -444,3 +444,50 @@ func TestSemTableDependenciesAfterExpandStar(t *testing.T) { }) } } + +// TestConstantFolding tests that the rewriter is able to do various constant foldings properly. +func TestConstantFolding(t *testing.T) { + ks := &vindexes.Keyspace{ + Name: "main", + Sharded: false, + } + schemaInfo := &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t1": { + Keyspace: ks, + Name: sqlparser.NewIdentifierCS("t1"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("a"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewIdentifierCI("b"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewIdentifierCI("c"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: true, + }, + }, + } + cDB := "db" + tcases := []struct { + sql string + expSQL string + }{{ + sql: "select 1 from t1 where (a, b) in ::fkc_vals and (2 is null or (1 is null or a in (1)))", + expSQL: "select 1 from t1 where (a, b) in ::fkc_vals and a in (1)", + }, { + sql: "select 1 from t1 where (false or (false or a in (1)))", + expSQL: "select 1 from t1 where a in (1)", + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.Parse(tcase.sql) + require.NoError(t, err) + _, err = Analyze(ast, cDB, schemaInfo) + require.NoError(t, err) + require.Equal(t, tcase.expSQL, sqlparser.String(ast)) + }) + } +}