From fd4f92297b49d1c33b6f9e123bc143bd7dff7665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Fri, 14 Jun 2024 16:08:03 +0200 Subject: [PATCH] Handle Nullability for Columns from Outer Tables (#16174) Signed-off-by: Andres Taylor --- .../endtoend/vtgate/queries/misc/misc_test.go | 14 +++ .../endtoend/vtgate/queries/misc/schema.sql | 20 +++- go/vt/vtgate/evalengine/compiler.go | 4 + .../planbuilder/operator_transformers.go | 12 +- .../vtgate/planbuilder/operators/ast_to_op.go | 2 +- .../vtgate/planbuilder/operators/distinct.go | 1 - go/vt/vtgate/planbuilder/operators/filter.go | 2 +- .../vtgate/planbuilder/operators/hash_join.go | 4 +- go/vt/vtgate/planbuilder/operators/insert.go | 4 +- go/vt/vtgate/planbuilder/operators/join.go | 4 +- .../planbuilder/operators/projection.go | 2 +- .../planbuilder/operators/queryprojection.go | 2 +- .../planbuilder/operators/sharded_routing.go | 4 +- .../planbuilder/operators/union_merging.go | 4 +- .../plancontext/planning_context.go | 21 ++++ .../plancontext/planning_context_test.go | 108 ++++++++++++++++++ go/vt/vtgate/semantics/semantic_state.go | 1 + 17 files changed, 184 insertions(+), 25 deletions(-) create mode 100644 go/vt/vtgate/planbuilder/plancontext/planning_context_test.go diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index ed2221eaf7d..6712275592a 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -371,3 +371,17 @@ func TestAlterTableWithView(t *testing.T) { mcmp.AssertMatches("select * from v1", `[[INT64(1) INT64(1)]]`) } + +func TestHandleNullableColumn(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 21, "vtgate") + require.NoError(t, + utils.WaitForAuthoritative(t, keyspaceName, "tbl", clusterInstance.VtgateProcess.ReadVSchema)) + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into t1(id1, id2) values (0,0), (1,1), (2,2)") + mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (0,0,0), (1,1,6)") + // This query tests that we handle nullable columns correctly + // tbl.nonunq_col is not nullable according to the schema, but because of the left join, it can be NULL + mcmp.ExecWithColumnCompare(`select * from t1 left join tbl on t1.id2 = tbl.id where t1.id1 = 6 or tbl.nonunq_col = 6`) +} diff --git a/go/test/endtoend/vtgate/queries/misc/schema.sql b/go/test/endtoend/vtgate/queries/misc/schema.sql index ceac0c07e6d..f87d7c19078 100644 --- a/go/test/endtoend/vtgate/queries/misc/schema.sql +++ b/go/test/endtoend/vtgate/queries/misc/schema.sql @@ -1,5 +1,15 @@ -create table if not exists t1( - id1 bigint, - id2 bigint, - primary key(id1) -) Engine=InnoDB; \ No newline at end of file +create table t1 +( + id1 bigint, + id2 bigint, + primary key (id1) +) Engine=InnoDB; + +create table tbl +( + id bigint, + unq_col bigint, + nonunq_col bigint, + primary key (id), + unique (unq_col) +) Engine = InnoDB; diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 3a9b204596f..387dbe44cc2 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -106,6 +106,10 @@ func (t *Type) Nullable() bool { return true // nullable by default for unknown types } +func (t *Type) SetNullability(n bool) { + t.nullable = n +} + func (t *Type) Valid() bool { return t.init } diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index cadfba91772..486cadf2fe8 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -291,7 +291,7 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega oa.aggregates = append(oa.aggregates, aggrParam) } for _, groupBy := range op.Grouping { - typ, _ := ctx.SemTable.TypeForExpr(groupBy.Inner) + typ, _ := ctx.TypeForExpr(groupBy.Inner) oa.groupByKeys = append(oa.groupByKeys, &engine.GroupByParams{ KeyCol: groupBy.ColOffset, WeightStringCol: groupBy.WSOffset, @@ -332,7 +332,7 @@ func createMemorySort(ctx *plancontext.PlanningContext, src logicalPlan, orderin } for idx, order := range ordering.Order { - typ, _ := ctx.SemTable.TypeForExpr(order.SimplifiedExpr) + typ, _ := ctx.TypeForExpr(order.SimplifiedExpr) ms.eMemorySort.OrderBy = append(ms.eMemorySort.OrderBy, evalengine.OrderByParams{ Col: ordering.Offset[idx], WeightStringCol: ordering.WOffset[idx], @@ -389,7 +389,7 @@ func getEvalEngingeExpr(ctx *plancontext.PlanningContext, pe *operators.ProjExpr case *operators.EvalEngine: return e.EExpr, nil case operators.Offset: - typ, _ := ctx.SemTable.TypeForExpr(pe.EvalExpr) + typ, _ := ctx.TypeForExpr(pe.EvalExpr) return evalengine.NewColumn(int(e), typ, pe.EvalExpr), nil default: return nil, vterrors.VT13001("project not planned for: %s", pe.String()) @@ -560,7 +560,7 @@ func buildRouteLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route eroute, err := routeToEngineRoute(ctx, op, hints) for _, order := range op.Ordering { - typ, _ := ctx.SemTable.TypeForExpr(order.AST) + typ, _ := ctx.TypeForExpr(order.AST) eroute.OrderBy = append(eroute.OrderBy, evalengine.OrderByParams{ Col: order.Offset, WeightStringCol: order.WOffset, @@ -877,11 +877,11 @@ func transformHashJoin(ctx *plancontext.PlanningContext, op *operators.HashJoin) var missingTypes []string - ltyp, found := ctx.SemTable.TypeForExpr(op.JoinComparisons[0].LHS) + ltyp, found := ctx.TypeForExpr(op.JoinComparisons[0].LHS) if !found { missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].LHS)) } - rtyp, found := ctx.SemTable.TypeForExpr(op.JoinComparisons[0].RHS) + rtyp, found := ctx.TypeForExpr(op.JoinComparisons[0].RHS) if !found { missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].RHS)) } diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 8a46109e959..55b29a146c7 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -224,7 +224,7 @@ func getOperatorFromJoinTableExpr(ctx *plancontext.PlanningContext, tableExpr *s case sqlparser.NormalJoinType: return createInnerJoin(ctx, tableExpr, lhs, rhs) case sqlparser.LeftJoinType, sqlparser.RightJoinType: - return createOuterJoin(tableExpr, lhs, rhs) + return createOuterJoin(ctx, tableExpr, lhs, rhs) default: panic(vterrors.VT13001("unsupported: %s", tableExpr.Join.ToString())) } diff --git a/go/vt/vtgate/planbuilder/operators/distinct.go b/go/vt/vtgate/planbuilder/operators/distinct.go index eeddd928f66..e3784dbb904 100644 --- a/go/vt/vtgate/planbuilder/operators/distinct.go +++ b/go/vt/vtgate/planbuilder/operators/distinct.go @@ -56,7 +56,6 @@ func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) Operator { offset := d.Source.AddColumn(ctx, true, false, aeWrap(weightStringFor(e))) wsCol = &offset } - d.Columns = append(d.Columns, engine.CheckCol{ Col: idx, WsCol: wsCol, diff --git a/go/vt/vtgate/planbuilder/operators/filter.go b/go/vt/vtgate/planbuilder/operators/filter.go index c2432a40da9..0570d61860d 100644 --- a/go/vt/vtgate/planbuilder/operators/filter.go +++ b/go/vt/vtgate/planbuilder/operators/filter.go @@ -123,7 +123,7 @@ func (f *Filter) Compact(*plancontext.PlanningContext) (Operator, *ApplyResult) func (f *Filter) planOffsets(ctx *plancontext.PlanningContext) Operator { cfg := &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), } diff --git a/go/vt/vtgate/planbuilder/operators/hash_join.go b/go/vt/vtgate/planbuilder/operators/hash_join.go index 0ad46bcbc82..f997ed5205d 100644 --- a/go/vt/vtgate/planbuilder/operators/hash_join.go +++ b/go/vt/vtgate/planbuilder/operators/hash_join.go @@ -332,7 +332,7 @@ func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Exp rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, r.post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) cfg := &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), } @@ -432,7 +432,7 @@ func (hj *HashJoin) addSingleSidedColumn( rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, r.post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) cfg := &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), } diff --git a/go/vt/vtgate/planbuilder/operators/insert.go b/go/vt/vtgate/planbuilder/operators/insert.go index 7c6e242ae9c..75466500fe6 100644 --- a/go/vt/vtgate/planbuilder/operators/insert.go +++ b/go/vt/vtgate/planbuilder/operators/insert.go @@ -506,7 +506,7 @@ func insertRowsPlan(ctx *plancontext.PlanningContext, insOp *Insert, ins *sqlpar colNum, _ := findOrAddColumn(ins, col) for rowNum, row := range rows { innerpv, err := evalengine.Translate(row[colNum], &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), }) @@ -637,7 +637,7 @@ func modifyForAutoinc(ctx *plancontext.PlanningContext, ins *sqlparser.Insert, v } var err error gen.Values, err = evalengine.Translate(autoIncValues, &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), }) diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 787d7fedfcc..8e685beb4cb 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -83,7 +83,7 @@ func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult return newOp, Rewrote("merge querygraphs into a single one") } -func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator { +func createOuterJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator { if tableExpr.Join == sqlparser.RightJoinType { lhs, rhs = rhs, lhs } @@ -93,6 +93,8 @@ func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Oper } predicate := tableExpr.Condition.On sqlparser.RemoveKeyspaceInCol(predicate) + // mark the RHS as outer tables so we know which columns are nullable + ctx.OuterTables = ctx.OuterTables.Merge(TableID(rhs)) return &Join{LHS: lhs, RHS: rhs, LeftJoin: true, Predicate: predicate} } diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index f46cbf21928..38164b71a94 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -586,7 +586,7 @@ func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) Operator { // for everything else, we'll turn to the evalengine eexpr, err := evalengine.Translate(rewritten, &evalengine.Config{ - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Collation: ctx.SemTable.Collation, Environment: ctx.VSchema.Environment(), }) diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index c9ea589381c..d34422a8d4d 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -107,7 +107,7 @@ func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.T } switch aggr.OpCode { case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct: - typ, _ := ctx.SemTable.TypeForExpr(aggr.Func.GetArg()) + typ, _ := ctx.TypeForExpr(aggr.Func.GetArg()) return typ } diff --git a/go/vt/vtgate/planbuilder/operators/sharded_routing.go b/go/vt/vtgate/planbuilder/operators/sharded_routing.go index ef6117b1d8e..6818311c0dd 100644 --- a/go/vt/vtgate/planbuilder/operators/sharded_routing.go +++ b/go/vt/vtgate/planbuilder/operators/sharded_routing.go @@ -540,7 +540,7 @@ func (tr *ShardedRouting) planCompositeInOpArg( Key: right.String(), Index: idx, } - if typ, found := ctx.SemTable.TypeForExpr(col); found { + if typ, found := ctx.TypeForExpr(col); found { value.Type = typ.Type() value.Collation = typ.Collation() } @@ -654,7 +654,7 @@ func makeEvalEngineExpr(ctx *plancontext.PlanningContext, n sqlparser.Expr) eval for _, expr := range ctx.SemTable.GetExprAndEqualities(n) { ee, _ := evalengine.Translate(expr, &evalengine.Config{ Collation: ctx.SemTable.Collation, - ResolveType: ctx.SemTable.TypeForExpr, + ResolveType: ctx.TypeForExpr, Environment: ctx.VSchema.Environment(), }) if ee != nil { diff --git a/go/vt/vtgate/planbuilder/operators/union_merging.go b/go/vt/vtgate/planbuilder/operators/union_merging.go index 1fb7d4fb454..c2fd79cd026 100644 --- a/go/vt/vtgate/planbuilder/operators/union_merging.go +++ b/go/vt/vtgate/planbuilder/operators/union_merging.go @@ -202,8 +202,8 @@ func createMergedUnion( continue } deps = deps.Merge(ctx.SemTable.RecursiveDeps(rae.Expr)) - rt, foundR := ctx.SemTable.TypeForExpr(rae.Expr) - lt, foundL := ctx.SemTable.TypeForExpr(lae.Expr) + rt, foundR := ctx.TypeForExpr(rae.Expr) + lt, foundL := ctx.TypeForExpr(lae.Expr) if foundR && foundL { types := []sqltypes.Type{rt.Type(), lt.Type()} t := evalengine.AggregateTypes(types) diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 3c2a1c97434..90a6bdac6f8 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -20,6 +20,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -57,6 +58,10 @@ type PlanningContext struct { // Statement contains the originally parsed statement Statement sqlparser.Statement + + // OuterTables contains the tables that are outer to the current query + // Used to set the nullable flag on the columns + OuterTables semantics.TableSet } // CreatePlanningContext initializes a new PlanningContext with the given parameters. @@ -201,3 +206,19 @@ func (ctx *PlanningContext) RewriteDerivedTableExpression(expr sqlparser.Expr, t } return modifiedExpr } + +// TypeForExpr returns the type of the given expression, with nullable set if the expression is from an outer table. +func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { + t, found := ctx.SemTable.TypeForExpr(e) + if !found { + return t, found + } + deps := ctx.SemTable.RecursiveDeps(e) + // If the expression is from an outer table, it should be nullable + // There are some exceptions to this, where an expression depending on the outer side + // will never return NULL, but it's better to be conservative here. + if deps.IsOverlapping(ctx.OuterTables) { + t.SetNullability(true) + } + return t, true +} diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context_test.go b/go/vt/vtgate/planbuilder/plancontext/planning_context_test.go new file mode 100644 index 00000000000..70faa61737d --- /dev/null +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context_test.go @@ -0,0 +1,108 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plancontext + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vtgate/evalengine" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +func TestOuterTableNullability(t *testing.T) { + // Tests that columns from outer tables are nullable, + // even though the semantic state says that they are not nullable. + // This is because the outer table may not have a matching row. + // All columns are marked as NOT NULL in the schema. + query := "select * from t1 left join t2 on t1.a = t2.a where t1.a+t2.a/abs(t2.boing)" + ctx, columns := prepareContextAndFindColumns(t, query) + + // Check if the columns are correctly marked as nullable. + for _, col := range columns { + colName := "column: " + sqlparser.String(col) + t.Run(colName, func(t *testing.T) { + // Extract the column type from the context and the semantic state. + // The context should mark the column as nullable. + ctxType, found := ctx.TypeForExpr(col) + require.True(t, found, colName) + stType, found := ctx.SemTable.TypeForExpr(col) + require.True(t, found, colName) + ctxNullable := ctxType.Nullable() + stNullable := stType.Nullable() + + switch col.Qualifier.Name.String() { + case "t1": + assert.False(t, ctxNullable, colName) + assert.False(t, stNullable, colName) + case "t2": + assert.True(t, ctxNullable, colName) + + // The semantic state says that the column is not nullable. Don't trust it. + assert.False(t, stNullable, colName) + } + }) + } +} + +func prepareContextAndFindColumns(t *testing.T, query string) (ctx *PlanningContext, columns []*sqlparser.ColName) { + parser := sqlparser.NewTestParser() + ast, err := parser.Parse(query) + require.NoError(t, err) + semTable := semantics.EmptySemTable() + t1 := semantics.SingleTableSet(0) + t2 := semantics.SingleTableSet(1) + stmt := ast.(*sqlparser.Select) + expr := stmt.Where.Expr + + // Instead of using the semantic analysis, we manually set the types for the columns. + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + col, ok := node.(*sqlparser.ColName) + if !ok { + return true, nil + } + + switch col.Qualifier.Name.String() { + case "t1": + semTable.Recursive[col] = t1 + case "t2": + semTable.Recursive[col] = t2 + } + + intNotNull := evalengine.NewType(sqltypes.Int64, collations.Unknown) + intNotNull.SetNullability(false) + semTable.ExprTypes[col] = intNotNull + columns = append(columns, col) + return false, nil + }, nil, expr) + + ctx = &PlanningContext{ + SemTable: semTable, + joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, + skipPredicates: map[sqlparser.Expr]any{}, + ReservedArguments: map[sqlparser.Expr]string{}, + Statement: stmt, + OuterTables: t2, // t2 is the outer table. + } + return +} diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 76a51efd160..9a8721108b3 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -614,6 +614,7 @@ func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.Sel } // TypeForExpr returns the type of expressions in the query +// Note that PlanningContext has the same method, and you should use that if you have a PlanningContext func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { if typ, found := st.ExprTypes[e]; found { return typ, true