diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 824f20f29e3..07a39020f8b 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -48,6 +48,10 @@ func (code PulloutOpcode) String() string { return pulloutName[code] } +func (code PulloutOpcode) NeedsListArg() bool { + return code == PulloutIn || code == PulloutNotIn +} + // MarshalJSON serializes the PulloutOpcode as a JSON string. // It's used for testing and diagnostics. func (code PulloutOpcode) MarshalJSON() ([]byte, error) { diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index e0073c6e74b..55fcba6cd3b 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -242,7 +242,7 @@ func (sq *SubQuery) settleFilter(ctx *plancontext.PlanningContext, outer ops.Ope } var arg sqlparser.Expr - if sq.FilterType == opcode.PulloutIn || sq.FilterType == opcode.PulloutNotIn { + if sq.FilterType.NeedsListArg() { arg = sqlparser.NewListArg(sq.ArgName) } else { arg = sqlparser.NewArgument(sq.ArgName) diff --git a/go/vt/vtgate/planbuilder/operators/subquery_builder.go b/go/vt/vtgate/planbuilder/operators/subquery_builder.go index f2b99dd3aae..a0897b5ad4b 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_builder.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_builder.go @@ -368,28 +368,50 @@ type subqueryExtraction struct { cols []string } +func getOpCodeFromParent(parent sqlparser.SQLNode) *opcode.PulloutOpcode { + code := opcode.PulloutValue + switch parent := parent.(type) { + case *sqlparser.ExistsExpr: + return nil + case *sqlparser.ComparisonExpr: + switch parent.Operator { + case sqlparser.InOp: + code = opcode.PulloutIn + case sqlparser.NotInOp: + code = opcode.PulloutNotIn + } + } + return &code +} + func extractSubQueries(ctx *plancontext.PlanningContext, expr sqlparser.Expr, isDML bool) *subqueryExtraction { sqe := &subqueryExtraction{} - replaceWithArg := func(cursor *sqlparser.Cursor, sq *sqlparser.Subquery) { + replaceWithArg := func(cursor *sqlparser.Cursor, sq *sqlparser.Subquery, t opcode.PulloutOpcode) { sqName := ctx.GetReservedArgumentFor(sq) sqe.cols = append(sqe.cols, sqName) if isDML { - cursor.Replace(sqlparser.NewArgument(sqName)) + if t.NeedsListArg() { + cursor.Replace(sqlparser.NewListArg(sqName)) + } else { + cursor.Replace(sqlparser.NewArgument(sqName)) + } } else { cursor.Replace(sqlparser.NewColName(sqName)) } sqe.subq = append(sqe.subq, sq) } + expr = sqlparser.Rewrite(expr, nil, func(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case *sqlparser.Subquery: - if _, isExists := cursor.Parent().(*sqlparser.ExistsExpr); isExists { + t := getOpCodeFromParent(cursor.Parent()) + if t == nil { return true } - replaceWithArg(cursor, node) - sqe.pullOutCode = append(sqe.pullOutCode, opcode.PulloutValue) + replaceWithArg(cursor, node, *t) + sqe.pullOutCode = append(sqe.pullOutCode, *t) case *sqlparser.ExistsExpr: - replaceWithArg(cursor, node.Subquery) + replaceWithArg(cursor, node.Subquery, opcode.PulloutExists) sqe.pullOutCode = append(sqe.pullOutCode, opcode.PulloutExists) } return true diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 80f740f4cd8..3f9b4198ee5 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -388,16 +388,22 @@ func pushProjectionToOuterContainer(ctx *plancontext.PlanningContext, p *Project } func rewriteColNameToArgument(in sqlparser.Expr, se SubQueryExpression, subqueries ...*SubQuery) sqlparser.Expr { - cols := make(map[string]any) - for _, sq1 := range se { - for _, sq2 := range subqueries { - if sq1.ArgName == sq2.ArgName { - cols[sq1.ArgName] = nil + rewriteIt := func(s string) sqlparser.SQLNode { + for _, sq1 := range se { + if sq1.ArgName != s && sq1.HasValuesName != s { + continue + } + + for _, sq2 := range subqueries { + switch { + case s == sq2.ArgName && sq1.FilterType.NeedsListArg(): + return sqlparser.NewListArg(s) + case s == sq2.ArgName || s == sq2.HasValuesName: + return sqlparser.NewArgument(s) + } } } - } - if len(cols) <= 0 { - return in + return nil } // replace the ColNames with Argument inside the subquery @@ -406,10 +412,10 @@ func rewriteColNameToArgument(in sqlparser.Expr, se SubQueryExpression, subqueri if !ok || !col.Qualifier.IsEmpty() { return true } - if _, ok := cols[col.Name.String()]; !ok { + arg := rewriteIt(col.Name.String()) + if arg == nil { return true } - arg := sqlparser.NewArgument(col.Name.String()) cursor.Replace(arg) return true })