From c5eced17d00fcb4f9dc8f5e19b592396733d089b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Thu, 8 Feb 2024 10:10:15 +0800 Subject: [PATCH] expression: remove direct dependencies with `sessionctx.Context` for package `expression` (#51025) close pingcap/tidb#51024 --- pkg/expression/BUILD.bazel | 1 - pkg/expression/aggregation/BUILD.bazel | 14 ++++-- pkg/expression/aggregation/agg_to_pb.go | 3 +- pkg/expression/aggregation/aggregation.go | 3 +- .../aggregation/aggregation_test.go | 3 +- pkg/expression/aggregation/base_func.go | 47 +++++++++---------- pkg/expression/aggregation/descriptor.go | 19 ++++---- pkg/expression/aggregation/explain.go | 4 +- pkg/expression/aggregation/window_func.go | 9 ++-- pkg/expression/bench_test.go | 7 ++- pkg/expression/builtin_cast.go | 3 +- pkg/expression/builtin_cast_bench_test.go | 3 +- pkg/expression/builtin_compare.go | 3 +- pkg/expression/builtin_encryption_test.go | 5 +- .../builtin_regexp_vec_const_test.go | 3 +- pkg/expression/builtin_test.go | 3 +- pkg/expression/builtin_time.go | 3 +- pkg/expression/builtin_time_test.go | 18 ++++--- pkg/expression/builtin_vectorized_test.go | 5 +- pkg/expression/constant.go | 4 +- pkg/expression/constant_propagation.go | 15 +++--- pkg/expression/constant_test.go | 21 ++++----- pkg/expression/distsql_builtin.go | 9 ++-- pkg/expression/evaluator_test.go | 3 +- pkg/expression/expression.go | 9 ++-- pkg/expression/helper.go | 3 +- pkg/expression/typeinfer_test.go | 3 +- pkg/expression/util.go | 38 +++++++-------- pkg/planner/core/exhaust_physical_plans.go | 2 +- pkg/planner/util/byitem.go | 3 +- pkg/sessionctx/variable/session.go | 5 ++ 31 files changed, 129 insertions(+), 142 deletions(-) diff --git a/pkg/expression/BUILD.bazel b/pkg/expression/BUILD.bazel index 7be85235b372f..9aa27966af16a 100644 --- a/pkg/expression/BUILD.bazel +++ b/pkg/expression/BUILD.bazel @@ -207,7 +207,6 @@ go_test( "//pkg/parser/terror", "//pkg/planner/core", "//pkg/session", - "//pkg/sessionctx", "//pkg/sessionctx/stmtctx", "//pkg/sessionctx/variable", "//pkg/sessiontxn", diff --git a/pkg/expression/aggregation/BUILD.bazel b/pkg/expression/aggregation/BUILD.bazel index 9cb6a3734f5d7..0ee1b5338d395 100644 --- a/pkg/expression/aggregation/BUILD.bazel +++ b/pkg/expression/aggregation/BUILD.bazel @@ -1,5 +1,13 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +package_group( + name = "aggregation_friend", + packages = [ + "-//pkg/sessionctx/...", + "//...", + ], +) + go_library( name = "aggregation", srcs = [ @@ -21,7 +29,9 @@ go_library( "window_func.go", ], importpath = "github.com/pingcap/tidb/pkg/expression/aggregation", - visibility = ["//visibility:public"], + visibility = [ + ":aggregation_friend", + ], deps = [ "//pkg/expression", "//pkg/kv", @@ -30,7 +40,6 @@ go_library( "//pkg/parser/mysql", "//pkg/parser/terror", "//pkg/planner/util", - "//pkg/sessionctx", "//pkg/sessionctx/stmtctx", "//pkg/sessionctx/variable", "//pkg/types", @@ -64,7 +73,6 @@ go_test( "//pkg/kv", "//pkg/parser/ast", "//pkg/parser/mysql", - "//pkg/sessionctx", "//pkg/sessionctx/stmtctx", "//pkg/sessionctx/variable", "//pkg/testkit/testsetup", diff --git a/pkg/expression/aggregation/agg_to_pb.go b/pkg/expression/aggregation/agg_to_pb.go index 5eab6be36b2fb..fc0a207bc9175 100644 --- a/pkg/expression/aggregation/agg_to_pb.go +++ b/pkg/expression/aggregation/agg_to_pb.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/codec" @@ -184,7 +183,7 @@ func PBAggFuncModeToAggFuncMode(pbMode *tipb.AggFunctionMode) (mode AggFunctionM } // PBExprToAggFuncDesc converts pb to aggregate function. -func PBExprToAggFuncDesc(ctx sessionctx.Context, aggFunc *tipb.Expr, fieldTps []*types.FieldType) (*AggFuncDesc, error) { +func PBExprToAggFuncDesc(ctx expression.BuildContext, aggFunc *tipb.Expr, fieldTps []*types.FieldType) (*AggFuncDesc, error) { var name string switch aggFunc.Tp { case tipb.ExprType_Count: diff --git a/pkg/expression/aggregation/aggregation.go b/pkg/expression/aggregation/aggregation.go index 6db82f2761119..0a0ffd56635f2 100644 --- a/pkg/expression/aggregation/aggregation.go +++ b/pkg/expression/aggregation/aggregation.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -51,7 +50,7 @@ type Aggregation interface { } // NewDistAggFunc creates new Aggregate function for mock tikv. -func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, ctx sessionctx.Context) (Aggregation, *AggFuncDesc, error) { +func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, ctx expression.BuildContext) (Aggregation, *AggFuncDesc, error) { args := make([]expression.Expression, 0, len(expr.Children)) for _, child := range expr.Children { arg, err := expression.PBToExpr(ctx, child, fieldTps) diff --git a/pkg/expression/aggregation/aggregation_test.go b/pkg/expression/aggregation/aggregation_test.go index de7c3ad684a29..2156bf94b5aeb 100644 --- a/pkg/expression/aggregation/aggregation_test.go +++ b/pkg/expression/aggregation/aggregation_test.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -30,7 +29,7 @@ import ( ) type mockAggFuncSuite struct { - ctx sessionctx.Context + ctx expression.BuildContext rows []chunk.Row nullRow chunk.Row } diff --git a/pkg/expression/aggregation/base_func.go b/pkg/expression/aggregation/base_func.go index c0db910999068..df6aa389f001b 100644 --- a/pkg/expression/aggregation/base_func.go +++ b/pkg/expression/aggregation/base_func.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/mathutil" @@ -42,13 +41,13 @@ type baseFuncDesc struct { RetTp *types.FieldType } -func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) { +func newBaseFuncDesc(ctx expression.BuildContext, name string, args []expression.Expression) (baseFuncDesc, error) { b := baseFuncDesc{Name: strings.ToLower(name), Args: args} err := b.TypeInfer(ctx) return b, err } -func (a *baseFuncDesc) equal(ctx sessionctx.Context, other *baseFuncDesc) bool { +func (a *baseFuncDesc) equal(ctx expression.EvalContext, other *baseFuncDesc) bool { if a.Name != other.Name || len(a.Args) != len(other.Args) { return false } @@ -86,18 +85,18 @@ func (a *baseFuncDesc) String() string { } // TypeInfer infers the arguments and return types of an function. -func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error { +func (a *baseFuncDesc) TypeInfer(ctx expression.BuildContext) error { switch a.Name { case ast.AggFuncCount: - a.typeInfer4Count(ctx) + a.typeInfer4Count() case ast.AggFuncApproxCountDistinct: - a.typeInfer4ApproxCountDistinct(ctx) + a.typeInfer4ApproxCountDistinct() case ast.AggFuncApproxPercentile: return a.typeInfer4ApproxPercentile(ctx) case ast.AggFuncSum: - a.typeInfer4Sum(ctx) + a.typeInfer4Sum() case ast.AggFuncAvg: - a.typeInfer4Avg(ctx) + a.typeInfer4Avg() case ast.AggFuncGroupConcat: a.typeInfer4GroupConcat(ctx) case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow, @@ -116,9 +115,9 @@ func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error { case ast.WindowFuncLead, ast.WindowFuncLag: a.typeInfer4LeadLag(ctx) case ast.AggFuncVarPop, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp: - a.typeInfer4PopOrSamp(ctx) + a.typeInfer4PopOrSamp() case ast.AggFuncJsonArrayagg: - a.typeInfer4JsonArrayAgg(ctx) + a.typeInfer4JsonArrayAgg() case ast.AggFuncJsonObjectAgg: return a.typeInfer4JsonObjectAgg(ctx) default: @@ -127,7 +126,7 @@ func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error { return nil } -func (a *baseFuncDesc) typeInfer4Count(sessionctx.Context) { +func (a *baseFuncDesc) typeInfer4Count() { a.RetTp = types.NewFieldType(mysql.TypeLonglong) a.RetTp.SetFlen(21) a.RetTp.SetDecimal(0) @@ -136,11 +135,11 @@ func (a *baseFuncDesc) typeInfer4Count(sessionctx.Context) { types.SetBinChsClnFlag(a.RetTp) } -func (a *baseFuncDesc) typeInfer4ApproxCountDistinct(ctx sessionctx.Context) { - a.typeInfer4Count(ctx) +func (a *baseFuncDesc) typeInfer4ApproxCountDistinct() { + a.typeInfer4Count() } -func (a *baseFuncDesc) typeInfer4ApproxPercentile(ctx sessionctx.Context) error { +func (a *baseFuncDesc) typeInfer4ApproxPercentile(ctx expression.EvalContext) error { if len(a.Args) != 2 { return errors.New("APPROX_PERCENTILE should take 2 arguments") } @@ -182,7 +181,7 @@ func (a *baseFuncDesc) typeInfer4ApproxPercentile(ctx sessionctx.Context) error // typeInfer4Sum should return a "decimal", otherwise it returns a "double". // Because child returns integer or decimal type. -func (a *baseFuncDesc) typeInfer4Sum(sessionctx.Context) { +func (a *baseFuncDesc) typeInfer4Sum() { switch a.Args[0].GetType().GetType() { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) @@ -221,7 +220,7 @@ func (a *baseFuncDesc) TypeInfer4FinalCount(finalCountRetType *types.FieldType) // typeInfer4Avg should returns a "decimal", otherwise it returns a "double". // Because child returns integer or decimal type. -func (a *baseFuncDesc) typeInfer4Avg(sessionctx.Context) { +func (a *baseFuncDesc) typeInfer4Avg() { switch a.Args[0].GetType().GetType() { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) @@ -247,7 +246,7 @@ func (a *baseFuncDesc) typeInfer4Avg(sessionctx.Context) { types.SetBinChsClnFlag(a.RetTp) } -func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) { +func (a *baseFuncDesc) typeInfer4GroupConcat(ctx expression.BuildContext) { a.RetTp = types.NewFieldType(mysql.TypeVarString) charset, collate := charset.GetDefaultCharsetAndCollate() a.RetTp.SetCharset(charset) @@ -263,7 +262,7 @@ func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) { } } -func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) { +func (a *baseFuncDesc) typeInfer4MaxMin(ctx expression.BuildContext) { _, argIsScalaFunc := a.Args[0].(*expression.ScalarFunction) if argIsScalaFunc && a.Args[0].GetType().GetType() == mysql.TypeFloat { // For scalar function, the result of "float32" is set to the "float64" @@ -288,7 +287,7 @@ func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) { } } -func (a *baseFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) { +func (a *baseFuncDesc) typeInfer4BitFuncs(ctx expression.BuildContext) { a.RetTp = types.NewFieldType(mysql.TypeLonglong) a.RetTp.SetFlen(21) types.SetBinChsClnFlag(a.RetTp) @@ -296,12 +295,12 @@ func (a *baseFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) { a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0]) } -func (a *baseFuncDesc) typeInfer4JsonArrayAgg(sessionctx.Context) { +func (a *baseFuncDesc) typeInfer4JsonArrayAgg() { a.RetTp = types.NewFieldType(mysql.TypeJSON) types.SetBinChsClnFlag(a.RetTp) } -func (a *baseFuncDesc) typeInfer4JsonObjectAgg(ctx sessionctx.Context) error { +func (a *baseFuncDesc) typeInfer4JsonObjectAgg(ctx expression.BuildContext) error { a.RetTp = types.NewFieldType(mysql.TypeJSON) types.SetBinChsClnFlag(a.RetTp) a.Args[0] = expression.WrapWithCastAsString(ctx, a.Args[0]) @@ -333,7 +332,7 @@ func (a *baseFuncDesc) typeInfer4PercentRank() { a.RetTp.SetDecimal(mysql.NotFixedDec) } -func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) { +func (a *baseFuncDesc) typeInfer4LeadLag(ctx expression.BuildContext) { if len(a.Args) < 3 { a.typeInfer4MaxMin(ctx) } else { @@ -343,7 +342,7 @@ func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) { } } -func (a *baseFuncDesc) typeInfer4PopOrSamp(sessionctx.Context) { +func (a *baseFuncDesc) typeInfer4PopOrSamp() { // var_pop/std/var_samp/stddev_samp's return value type is double a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.SetFlen(mysql.MaxRealWidth) @@ -398,7 +397,7 @@ var noNeedCastAggFuncs = map[string]struct{}{ } // WrapCastForAggArgs wraps the args of an aggregate function with a cast function. -func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) { +func (a *baseFuncDesc) WrapCastForAggArgs(ctx expression.BuildContext) { if len(a.Args) == 0 { return } diff --git a/pkg/expression/aggregation/descriptor.go b/pkg/expression/aggregation/descriptor.go index 2cd95a8f6f62c..cbb09e4351747 100644 --- a/pkg/expression/aggregation/descriptor.go +++ b/pkg/expression/aggregation/descriptor.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/collate" @@ -48,7 +47,7 @@ type AggFuncDesc struct { // NewAggFuncDesc creates an aggregation function signature descriptor. // this func cannot be called twice as the TypeInfer has changed the type of args in the first time. -func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) { +func NewAggFuncDesc(ctx expression.BuildContext, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) { b, err := newBaseFuncDesc(ctx, name, args) if err != nil { return nil, err @@ -57,7 +56,7 @@ func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expre } // NewAggFuncDescForWindowFunc creates an aggregation function from window functions, where baseFuncDesc may be ready. -func NewAggFuncDescForWindowFunc(ctx sessionctx.Context, desc *WindowFuncDesc, hasDistinct bool) (*AggFuncDesc, error) { +func NewAggFuncDescForWindowFunc(ctx expression.BuildContext, desc *WindowFuncDesc, hasDistinct bool) (*AggFuncDesc, error) { if desc.RetTp == nil { // safety check return NewAggFuncDesc(ctx, desc.Name, desc.Args, hasDistinct) } @@ -91,7 +90,7 @@ func (a *AggFuncDesc) String() string { } // Equal checks whether two aggregation function signatures are equal. -func (a *AggFuncDesc) Equal(ctx sessionctx.Context, other *AggFuncDesc) bool { +func (a *AggFuncDesc) Equal(ctx expression.EvalContext, other *AggFuncDesc) bool { if a.HasDistinct != other.HasDistinct { return false } @@ -200,7 +199,7 @@ func (a *AggFuncDesc) Split(ordinal []int) (partialAggDesc, finalAggDesc *AggFun // +------+-----------+---------+---------+------------+-------------+------------+---------+---------+------+----------+ // | 1 | 1 | 95 | 95.0000 | 95 | 95 | 95 | 95 | 95 | NULL | NULL | // +------+-----------+---------+---------+------------+-------------+------------+---------+---------+------+----------+ -func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) { +func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) { switch a.Name { case ast.AggFuncCount: return a.evalNullValueInOuterJoin4Count(ctx, schema) @@ -219,7 +218,7 @@ func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx sessionctx.Context, schema *e } // GetAggFunc gets an evaluator according to the aggregation function signature. -func (a *AggFuncDesc) GetAggFunc(ctx sessionctx.Context) Aggregation { +func (a *AggFuncDesc) GetAggFunc(ctx expression.BuildContext) Aggregation { aggFunc := aggFunction{AggFuncDesc: a} switch a.Name { case ast.AggFuncSum: @@ -258,7 +257,7 @@ func (a *AggFuncDesc) GetAggFunc(ctx sessionctx.Context) Aggregation { } } -func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) { +func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) { for _, arg := range a.Args { result := expression.EvaluateExprWithNull(ctx, schema, arg) con, ok := result.(*expression.Constant) @@ -269,7 +268,7 @@ func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx sessionctx.Context, sch return types.NewDatum(1), true } -func (a *AggFuncDesc) evalNullValueInOuterJoin4Sum(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) { +func (a *AggFuncDesc) evalNullValueInOuterJoin4Sum(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) { result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0]) con, ok := result.(*expression.Constant) if !ok || con.Value.IsNull() { @@ -278,7 +277,7 @@ func (a *AggFuncDesc) evalNullValueInOuterJoin4Sum(ctx sessionctx.Context, schem return con.Value, true } -func (a *AggFuncDesc) evalNullValueInOuterJoin4BitAnd(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) { +func (a *AggFuncDesc) evalNullValueInOuterJoin4BitAnd(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) { result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0]) con, ok := result.(*expression.Constant) if !ok || con.Value.IsNull() { @@ -287,7 +286,7 @@ func (a *AggFuncDesc) evalNullValueInOuterJoin4BitAnd(ctx sessionctx.Context, sc return con.Value, true } -func (a *AggFuncDesc) evalNullValueInOuterJoin4BitOr(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) { +func (a *AggFuncDesc) evalNullValueInOuterJoin4BitOr(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) { result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0]) con, ok := result.(*expression.Constant) if !ok || con.Value.IsNull() { diff --git a/pkg/expression/aggregation/explain.go b/pkg/expression/aggregation/explain.go index dd28997d589a6..29f88499e1bd1 100644 --- a/pkg/expression/aggregation/explain.go +++ b/pkg/expression/aggregation/explain.go @@ -18,12 +18,12 @@ import ( "bytes" "fmt" + "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/sessionctx" ) // ExplainAggFunc generates explain information for a aggregation function. -func ExplainAggFunc(ctx sessionctx.Context, agg *AggFuncDesc, normalized bool) string { +func ExplainAggFunc(ctx expression.EvalContext, agg *AggFuncDesc, normalized bool) string { var buffer bytes.Buffer fmt.Fprintf(&buffer, "%s(", agg.Name) if agg.HasDistinct { diff --git a/pkg/expression/aggregation/window_func.go b/pkg/expression/aggregation/window_func.go index 897754f991562..5c8177a8fb0dc 100644 --- a/pkg/expression/aggregation/window_func.go +++ b/pkg/expression/aggregation/window_func.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tipb/go-tipb" ) @@ -31,7 +30,7 @@ type WindowFuncDesc struct { } // NewWindowFuncDesc creates a window function signature descriptor. -func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, skipCheckArgs bool) (*WindowFuncDesc, error) { +func NewWindowFuncDesc(ctx expression.BuildContext, name string, args []expression.Expression, skipCheckArgs bool) (*WindowFuncDesc, error) { // if we are in the prepare statement, skip the params check since it's not been initialized. if !skipCheckArgs { switch strings.ToLower(name) { @@ -124,7 +123,7 @@ func (s *WindowFuncDesc) Clone() *WindowFuncDesc { } // WindowFuncToPBExpr converts aggregate function to pb. -func WindowFuncToPBExpr(sctx sessionctx.Context, client kv.Client, desc *WindowFuncDesc) *tipb.Expr { +func WindowFuncToPBExpr(sctx expression.EvalContext, client kv.Client, desc *WindowFuncDesc) *tipb.Expr { pc := expression.NewPBConverter(client, sctx) tp := desc.GetTiPBExpr(true) if !client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tp)) { @@ -143,9 +142,9 @@ func WindowFuncToPBExpr(sctx sessionctx.Context, client kv.Client, desc *WindowF } // CanPushDownToTiFlash control whether a window function desc can be push down to tiflash. -func (s *WindowFuncDesc) CanPushDownToTiFlash(ctx sessionctx.Context) bool { +func (s *WindowFuncDesc) CanPushDownToTiFlash(ctx expression.EvalContext, client kv.Client) bool { // args - if !expression.CanExprsPushDown(ctx, s.Args, ctx.GetClient(), kv.TiFlash) { + if !expression.CanExprsPushDown(ctx, s.Args, client, kv.TiFlash) { return false } // window functions diff --git a/pkg/expression/bench_test.go b/pkg/expression/bench_test.go index 46296ac200614..81d38de491af1 100644 --- a/pkg/expression/bench_test.go +++ b/pkg/expression/bench_test.go @@ -34,7 +34,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/benchdaily" @@ -45,7 +44,7 @@ import ( ) type benchHelper struct { - ctx sessionctx.Context + ctx BuildContext exprs []Expression inputTypes []*types.FieldType @@ -1264,7 +1263,7 @@ func eType2FieldType(eType types.EvalType) *types.FieldType { } } -func genVecExprBenchCase(ctx sessionctx.Context, funcName string, testCase vecExprBenchCase) (expr Expression, fts []*types.FieldType, input *chunk.Chunk, output *chunk.Chunk) { +func genVecExprBenchCase(ctx BuildContext, funcName string, testCase vecExprBenchCase) (expr Expression, fts []*types.FieldType, input *chunk.Chunk, output *chunk.Chunk) { fts = make([]*types.FieldType, len(testCase.childrenTypes)) for i := range fts { if i < len(testCase.childrenFieldTypes) && testCase.childrenFieldTypes[i] != nil { @@ -1403,7 +1402,7 @@ func benchmarkVectorizedEvalOneVec(b *testing.B, vecExprCases vecExprBenchCases) } } -func genVecBuiltinFuncBenchCase(ctx sessionctx.Context, funcName string, testCase vecExprBenchCase) (baseFunc builtinFunc, fts []*types.FieldType, input *chunk.Chunk, result *chunk.Column) { +func genVecBuiltinFuncBenchCase(ctx BuildContext, funcName string, testCase vecExprBenchCase) (baseFunc builtinFunc, fts []*types.FieldType, input *chunk.Chunk, result *chunk.Column) { childrenNumber := len(testCase.childrenTypes) fts = make([]*types.FieldType, childrenNumber) for i := range fts { diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index 6fb6d47f13f4c..a4b0e3f360c92 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -35,7 +35,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" @@ -2058,7 +2057,7 @@ func CanImplicitEvalReal(expr Expression) bool { // BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union // Expression. -func BuildCastFunction4Union(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) { +func BuildCastFunction4Union(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression) { ctx.SetValue(inUnionCastContext, struct{}{}) defer func() { ctx.SetValue(inUnionCastContext, nil) diff --git a/pkg/expression/builtin_cast_bench_test.go b/pkg/expression/builtin_cast_bench_test.go index cb6084747f8c9..750fb3bf633ed 100644 --- a/pkg/expression/builtin_cast_bench_test.go +++ b/pkg/expression/builtin_cast_bench_test.go @@ -19,13 +19,12 @@ import ( "testing" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/mock" ) -func genCastIntAsInt(ctx sessionctx.Context) (*builtinCastIntAsIntSig, *chunk.Chunk, *chunk.Column) { +func genCastIntAsInt(ctx BuildContext) (*builtinCastIntAsIntSig, *chunk.Chunk, *chunk.Column) { col := &Column{RetType: types.NewFieldType(mysql.TypeLonglong), Index: 0} baseFunc, err := newBaseBuiltinFunc(ctx, "", []Expression{col}, types.NewFieldType(mysql.TypeLonglong)) if err != nil { diff --git a/pkg/expression/builtin_compare.go b/pkg/expression/builtin_compare.go index b34e5f62e9a6a..72f5e74ffb5bf 100644 --- a/pkg/expression/builtin_compare.go +++ b/pkg/expression/builtin_compare.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/opcode" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/collate" @@ -1338,7 +1337,7 @@ func GetAccurateCmpType(lhs, rhs Expression) types.EvalType { } // GetCmpFunction get the compare function according to two arguments. -func GetCmpFunction(ctx sessionctx.Context, lhs, rhs Expression) CompareFunc { +func GetCmpFunction(ctx BuildContext, lhs, rhs Expression) CompareFunc { switch GetAccurateCmpType(lhs, rhs) { case types.ETInt: return CompareInt diff --git a/pkg/expression/builtin_encryption_test.go b/pkg/expression/builtin_encryption_test.go index ec023c02a3e20..2cf772330f472 100644 --- a/pkg/expression/builtin_encryption_test.go +++ b/pkg/expression/builtin_encryption_test.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -281,7 +280,7 @@ func TestAESDecrypt(t *testing.T) { } } -func testNullInput(t *testing.T, ctx sessionctx.Context, fnName string) { +func testNullInput(t *testing.T, ctx BuildContext, fnName string) { err := ctx.GetSessionVars().SetSystemVar(variable.BlockEncryptionMode, "aes-128-ecb") require.NoError(t, err) fc := funcs[fnName] @@ -300,7 +299,7 @@ func testNullInput(t *testing.T, ctx sessionctx.Context, fnName string) { require.True(t, crypt.IsNull()) } -func testAmbiguousInput(t *testing.T, ctx sessionctx.Context, fnName string) { +func testAmbiguousInput(t *testing.T, ctx BuildContext, fnName string) { fc := funcs[fnName] arg := types.NewStringDatum("str") // test for modes that require init_vector diff --git a/pkg/expression/builtin_regexp_vec_const_test.go b/pkg/expression/builtin_regexp_vec_const_test.go index 4194e0ed47683..62d1636adce23 100644 --- a/pkg/expression/builtin_regexp_vec_const_test.go +++ b/pkg/expression/builtin_regexp_vec_const_test.go @@ -20,14 +20,13 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/mock" "github.com/stretchr/testify/require" ) -func genVecBuiltinRegexpBenchCaseForConstants(ctx sessionctx.Context) (baseFunc builtinFunc, childrenFieldTypes []*types.FieldType, input *chunk.Chunk, output *chunk.Column) { +func genVecBuiltinRegexpBenchCaseForConstants(ctx BuildContext) (baseFunc builtinFunc, childrenFieldTypes []*types.FieldType, input *chunk.Chunk, output *chunk.Column) { const ( numArgs = 2 batchSz = 1024 diff --git a/pkg/expression/builtin_test.go b/pkg/expression/builtin_test.go index 883a4ace1e8d4..d8fdb5f52c2cd 100644 --- a/pkg/expression/builtin_test.go +++ b/pkg/expression/builtin_test.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/chunk" @@ -260,7 +259,7 @@ func TestBuiltinFuncCache(t *testing.T) { // newFunctionForTest creates a new ScalarFunction using funcName and arguments, // it is different from expression.NewFunction which needs an additional retType argument. -func newFunctionForTest(ctx sessionctx.Context, funcName string, args ...Expression) (Expression, error) { +func newFunctionForTest(ctx BuildContext, funcName string, args ...Expression) (Expression, error) { fc, ok := funcs[funcName] if !ok { return nil, ErrFunctionNotExists.GenWithStackByArgs("FUNCTION", funcName) diff --git a/pkg/expression/builtin_time.go b/pkg/expression/builtin_time.go index 4e62dff0f7357..1c5c1da2b2f50 100644 --- a/pkg/expression/builtin_time.go +++ b/pkg/expression/builtin_time.go @@ -34,7 +34,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" @@ -2493,7 +2492,7 @@ func (c *nowFunctionClass) getFunction(ctx BuildContext, args []Expression) (bui } // GetStmtTimestamp directly calls getTimeZone with timezone -func GetStmtTimestamp(ctx sessionctx.Context) (time.Time, error) { +func GetStmtTimestamp(ctx EvalContext) (time.Time, error) { tz := getTimeZone(ctx) tVal, err := getStmtTimestamp(ctx) if err != nil { diff --git a/pkg/expression/builtin_time_test.go b/pkg/expression/builtin_time_test.go index 26726af38283f..b571c505f3c02 100644 --- a/pkg/expression/builtin_time_test.go +++ b/pkg/expression/builtin_time_test.go @@ -28,8 +28,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/testkit/testutil" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -821,7 +819,7 @@ func TestTime(t *testing.T) { require.NoError(t, err) } -func resetStmtContext(ctx sessionctx.Context) { +func resetStmtContext(ctx EvalContext) { ctx.GetSessionVars().StmtCtx.ResetStmtCache() } @@ -1186,7 +1184,7 @@ func TestSysDate(t *testing.T) { require.Error(t, err) } -func convertToTimeWithFsp(sc *stmtctx.StatementContext, arg types.Datum, tp byte, fsp int) (d types.Datum, err error) { +func convertToTimeWithFsp(tc types.Context, arg types.Datum, tp byte, fsp int) (d types.Datum, err error) { if fsp > types.MaxFsp { fsp = types.MaxFsp } @@ -1194,7 +1192,7 @@ func convertToTimeWithFsp(sc *stmtctx.StatementContext, arg types.Datum, tp byte f := types.NewFieldType(tp) f.SetDecimal(fsp) - d, err = arg.ConvertTo(sc.TypeCtx(), f) + d, err = arg.ConvertTo(tc, f) if err != nil { d.SetNull() return d, err @@ -1211,12 +1209,12 @@ func convertToTimeWithFsp(sc *stmtctx.StatementContext, arg types.Datum, tp byte return } -func convertToTime(sc *stmtctx.StatementContext, arg types.Datum, tp byte) (d types.Datum, err error) { - return convertToTimeWithFsp(sc, arg, tp, types.MaxFsp) +func convertToTime(tc types.Context, arg types.Datum, tp byte) (d types.Datum, err error) { + return convertToTimeWithFsp(tc, arg, tp, types.MaxFsp) } -func builtinDateFormat(ctx sessionctx.Context, args []types.Datum) (d types.Datum, err error) { - date, err := convertToTime(ctx.GetSessionVars().StmtCtx, args[0], mysql.TypeDatetime) +func builtinDateFormat(tc types.Context, args []types.Datum) (d types.Datum, err error) { + date, err := convertToTime(tc, args[0], mysql.TypeDatetime) if err != nil { return d, err } @@ -1288,7 +1286,7 @@ func TestFromUnixTime(t *testing.T) { require.NoError(t, err) v, err := evalBuiltinFunc(f, ctx, chunk.Row{}) require.NoError(t, err) - result, err := builtinDateFormat(ctx, []types.Datum{types.NewStringDatum(c.expect), format}) + result, err := builtinDateFormat(ctx.GetSessionVars().StmtCtx.TypeCtx(), []types.Datum{types.NewStringDatum(c.expect), format}) require.NoError(t, err) require.Equalf(t, result.GetString(), v.GetString(), "%+v", t) } diff --git a/pkg/expression/builtin_vectorized_test.go b/pkg/expression/builtin_vectorized_test.go index 507877be6e856..9cabe724d3505 100644 --- a/pkg/expression/builtin_vectorized_test.go +++ b/pkg/expression/builtin_vectorized_test.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/mock" @@ -79,7 +78,7 @@ func (p *mockVecPlusIntBuiltinFunc) vecEvalInt(ctx EvalContext, input *chunk.Chu return nil } -func genMockVecPlusIntBuiltinFunc(ctx sessionctx.Context) (*mockVecPlusIntBuiltinFunc, *chunk.Chunk, *chunk.Column) { +func genMockVecPlusIntBuiltinFunc(ctx BuildContext) (*mockVecPlusIntBuiltinFunc, *chunk.Chunk, *chunk.Column) { tp := types.NewFieldType(mysql.TypeLonglong) col1 := newColumn(0) col1.Index, col1.RetType = 0, tp @@ -430,7 +429,7 @@ func convertETType(eType types.EvalType) (mysqlType byte) { return } -func genMockRowDouble(ctx sessionctx.Context, eType types.EvalType, enableVec bool) (builtinFunc, *chunk.Chunk, *chunk.Column, error) { +func genMockRowDouble(ctx BuildContext, eType types.EvalType, enableVec bool) (builtinFunc, *chunk.Chunk, *chunk.Column, error) { mysqlType := convertETType(eType) tp := types.NewFieldType(mysqlType) col1 := newColumn(1) diff --git a/pkg/expression/constant.go b/pkg/expression/constant.go index a2d858d4c2945..18e7e8e2cd97e 100644 --- a/pkg/expression/constant.go +++ b/pkg/expression/constant.go @@ -19,7 +19,7 @@ import ( "unsafe" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/codec" @@ -130,7 +130,7 @@ type Constant struct { // ParamMarker indicates param provided by COM_STMT_EXECUTE. type ParamMarker struct { - ctx sessionctx.Context + ctx variable.SessionVarsProvider order int } diff --git a/pkg/expression/constant_propagation.go b/pkg/expression/constant_propagation.go index 5898b2b76b4a0..fcf0ca3a73886 100644 --- a/pkg/expression/constant_propagation.go +++ b/pkg/expression/constant_propagation.go @@ -20,7 +20,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/collate" @@ -38,7 +37,7 @@ type basePropConstSolver struct { eqList []*Constant // if eqList[i] != nil, it means col_i = eqList[i] unionSet *disjointset.IntSet // unionSet stores the relations like col_i = col_j columns []*Column // columns stores all columns appearing in the conditions - ctx sessionctx.Context + ctx BuildContext } func (s *basePropConstSolver) getColID(col *Column) int { @@ -124,7 +123,7 @@ func validEqualCond(cond Expression) (*Column, *Constant) { // for 'a, b, a < 3', it returns 'true, false, b < 3' // for 'a, b, sin(a) + cos(a) = 5', it returns 'true, false, returns sin(b) + cos(b) = 5' // for 'a, b, cast(a) < rand()', it returns 'false, true, cast(a) < rand()' -func tryToReplaceCond(ctx sessionctx.Context, src *Column, tgt *Column, cond Expression, nullAware bool) (bool, bool, Expression) { +func tryToReplaceCond(ctx BuildContext, src *Column, tgt *Column, cond Expression, nullAware bool) (bool, bool, Expression) { if src.RetType.GetType() != tgt.RetType.GetType() { return false, false, cond } @@ -359,7 +358,7 @@ func (s *propConstSolver) solve(conditions []Expression) []Expression { // PropagateConstant propagate constant values of deterministic predicates in a condition. // This is a constant propagation logic for expression list such as ['a=1', 'a=b'] -func PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression { +func PropagateConstant(ctx BuildContext, conditions []Expression) []Expression { return newPropConstSolver().PropagateConstant(ctx, conditions) } @@ -664,7 +663,7 @@ func (s *propOuterJoinConstSolver) solve(joinConds, filterConds []Expression) ([ } // propagateConstantDNF find DNF item from CNF, and propagate constant inside DNF. -func propagateConstantDNF(ctx sessionctx.Context, conds []Expression) []Expression { +func propagateConstantDNF(ctx BuildContext, conds []Expression) []Expression { for i, cond := range conds { if dnf, ok := cond.(*ScalarFunction); ok && dnf.FuncName.L == ast.LogicOr { dnfItems := SplitDNFItems(cond) @@ -683,7 +682,7 @@ func propagateConstantDNF(ctx sessionctx.Context, conds []Expression) []Expressi // Second step is to extract `outerCol = innerCol` from join conditions, and derive new join // conditions based on this column equal condition and `outerCol` related // expressions in join conditions and filter conditions; -func PropConstOverOuterJoin(ctx sessionctx.Context, joinConds, filterConds []Expression, +func PropConstOverOuterJoin(ctx BuildContext, joinConds, filterConds []Expression, outerSchema, innerSchema *Schema, nullSensitive bool) ([]Expression, []Expression) { solver := &propOuterJoinConstSolver{ outerSchema: outerSchema, @@ -697,7 +696,7 @@ func PropConstOverOuterJoin(ctx sessionctx.Context, joinConds, filterConds []Exp // PropagateConstantSolver is a constant propagate solver. type PropagateConstantSolver interface { - PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression + PropagateConstant(ctx BuildContext, conditions []Expression) []Expression } // newPropConstSolver returns a PropagateConstantSolver. @@ -708,7 +707,7 @@ func newPropConstSolver() PropagateConstantSolver { } // PropagateConstant propagate constant values of deterministic predicates in a condition. -func (s *propConstSolver) PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression { +func (s *propConstSolver) PropagateConstant(ctx BuildContext, conditions []Expression) []Expression { s.ctx = ctx return s.solve(conditions) } diff --git a/pkg/expression/constant_test.go b/pkg/expression/constant_test.go index 9dc8f2f9aede9..652779fe9bb46 100644 --- a/pkg/expression/constant_test.go +++ b/pkg/expression/constant_test.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/mock" @@ -59,11 +58,11 @@ func newFunctionWithMockCtx(funcName string, args ...Expression) Expression { return newFunction(mock.NewContext(), funcName, args...) } -func newFunction(ctx sessionctx.Context, funcName string, args ...Expression) Expression { +func newFunction(ctx BuildContext, funcName string, args ...Expression) Expression { return newFunctionWithType(ctx, funcName, types.NewFieldType(mysql.TypeLonglong), args...) } -func newFunctionWithType(ctx sessionctx.Context, funcName string, tp *types.FieldType, args ...Expression) Expression { +func newFunctionWithType(ctx BuildContext, funcName string, tp *types.FieldType, args ...Expression) Expression { return NewFunctionInternal(ctx, funcName, tp, args...) } @@ -195,47 +194,47 @@ func TestConstantPropagation(t *testing.T) { func TestConstantFolding(t *testing.T) { tests := []struct { - condition func(ctx sessionctx.Context) Expression + condition func(ctx BuildContext) Expression result string }{ { - condition: func(ctx sessionctx.Context) Expression { + condition: func(ctx BuildContext) Expression { return newFunction(ctx, ast.LT, newColumn(0), newFunction(ctx, ast.Plus, newLonglong(1), newLonglong(2))) }, result: "lt(Column#0, 3)", }, { - condition: func(ctx sessionctx.Context) Expression { + condition: func(ctx BuildContext) Expression { return newFunction(ctx, ast.LT, newColumn(0), newFunction(ctx, ast.Greatest, newLonglong(1), newLonglong(2))) }, result: "lt(Column#0, 2)", }, { - condition: func(ctx sessionctx.Context) Expression { + condition: func(ctx BuildContext) Expression { return newFunction(ctx, ast.EQ, newColumn(0), newFunction(ctx, ast.Rand)) }, result: "eq(cast(Column#0, double BINARY), rand())", }, { - condition: func(ctx sessionctx.Context) Expression { + condition: func(ctx BuildContext) Expression { return newFunction(ctx, ast.IsNull, newLonglong(1)) }, result: "0", }, { - condition: func(ctx sessionctx.Context) Expression { + condition: func(ctx BuildContext) Expression { return newFunction(ctx, ast.EQ, newColumn(0), newFunction(ctx, ast.UnaryNot, newFunctionWithMockCtx(ast.Plus, newLonglong(1), newLonglong(1)))) }, result: "eq(Column#0, 0)", }, { - condition: func(ctx sessionctx.Context) Expression { + condition: func(ctx BuildContext) Expression { return newFunction(ctx, ast.LT, newColumn(0), newFunction(ctx, ast.Plus, newColumn(1), newFunctionWithMockCtx(ast.Plus, newLonglong(2), newLonglong(1)))) }, result: "lt(Column#0, plus(Column#1, 3))", }, { - condition: func(ctx sessionctx.Context) Expression { + condition: func(ctx BuildContext) Expression { expr := newFunction(ctx, ast.ConcatWS, newColumn(0), NewNull()) ctx.GetSessionVars().StmtCtx.InNullRejectCheck = true return expr diff --git a/pkg/expression/distsql_builtin.go b/pkg/expression/distsql_builtin.go index 3022b506c51d2..d4028ce56e634 100644 --- a/pkg/expression/distsql_builtin.go +++ b/pkg/expression/distsql_builtin.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/codec" @@ -38,7 +37,7 @@ func PbTypeToFieldType(tp *tipb.FieldType) *types.FieldType { return ft } -func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *tipb.FieldType, args []Expression) (f builtinFunc, e error) { +func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.FieldType, args []Expression) (f builtinFunc, e error) { fieldTp := PbTypeToFieldType(tp) base, err := newBaseBuiltinFuncWithFieldType(fieldTp, args) if err != nil { @@ -1090,7 +1089,7 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti return f, nil } -func newDistSQLFunctionBySig(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *tipb.FieldType, args []Expression) (Expression, error) { +func newDistSQLFunctionBySig(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.FieldType, args []Expression) (Expression, error) { f, err := getSignatureByPB(ctx, sigCode, tp, args) if err != nil { return nil, err @@ -1103,7 +1102,7 @@ func newDistSQLFunctionBySig(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, } // PBToExprs converts pb structures to expressions. -func PBToExprs(ctx sessionctx.Context, pbExprs []*tipb.Expr, fieldTps []*types.FieldType) ([]Expression, error) { +func PBToExprs(ctx BuildContext, pbExprs []*tipb.Expr, fieldTps []*types.FieldType) ([]Expression, error) { exprs := make([]Expression, 0, len(pbExprs)) for _, expr := range pbExprs { e, err := PBToExpr(ctx, expr, fieldTps) @@ -1119,7 +1118,7 @@ func PBToExprs(ctx sessionctx.Context, pbExprs []*tipb.Expr, fieldTps []*types.F } // PBToExpr converts pb structure to expression. -func PBToExpr(ctx sessionctx.Context, expr *tipb.Expr, tps []*types.FieldType) (Expression, error) { +func PBToExpr(ctx BuildContext, expr *tipb.Expr, tps []*types.FieldType) (Expression, error) { sc := ctx.GetSessionVars().StmtCtx switch expr.Tp { case tipb.ExprType_ColumnRef: diff --git a/pkg/expression/evaluator_test.go b/pkg/expression/evaluator_test.go index b883c2bfe4c16..e98bf8daabf50 100644 --- a/pkg/expression/evaluator_test.go +++ b/pkg/expression/evaluator_test.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/collate" @@ -90,7 +89,7 @@ func datumsToConstants(datums []types.Datum) []Expression { return constants } -func primitiveValsToConstants(ctx sessionctx.Context, args []any) []Expression { +func primitiveValsToConstants(ctx BuildContext, args []any) []Expression { cons := datumsToConstants(types.MakeDatums(args...)) char, col := ctx.GetSessionVars().GetCharsetInfo() for i, arg := range args { diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index aaa47d7d9b44e..661c4c414d693 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/opcode" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/generatedexpr" @@ -862,7 +861,7 @@ func SplitDNFItems(onExpr Expression) []Expression { // EvaluateExprWithNull sets columns in schema as null and calculate the final result of the scalar function. // If the Expression is a non-constant value, it means the result is unknown. -func EvaluateExprWithNull(ctx sessionctx.Context, schema *Schema, expr Expression) Expression { +func EvaluateExprWithNull(ctx BuildContext, schema *Schema, expr Expression) Expression { if MaybeOverOptimized4PlanCache(ctx, []Expression{expr}) { ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackError("%v affects null check")) } @@ -873,7 +872,7 @@ func EvaluateExprWithNull(ctx sessionctx.Context, schema *Schema, expr Expressio return evaluateExprWithNull(ctx, schema, expr) } -func evaluateExprWithNull(ctx sessionctx.Context, schema *Schema, expr Expression) Expression { +func evaluateExprWithNull(ctx BuildContext, schema *Schema, expr Expression) Expression { switch x := expr.(type) { case *ScalarFunction: args := make([]Expression, len(x.GetArgs())) @@ -898,7 +897,7 @@ func evaluateExprWithNull(ctx sessionctx.Context, schema *Schema, expr Expressio // If the Expression is a non-constant value, it means the result is unknown. // The returned bool values indicates whether the value is influenced by the Null Constant transformed from schema column // when the value is Null Constant. -func evaluateExprWithNullInNullRejectCheck(ctx sessionctx.Context, schema *Schema, expr Expression) (Expression, bool) { +func evaluateExprWithNullInNullRejectCheck(ctx BuildContext, schema *Schema, expr Expression) (Expression, bool) { switch x := expr.(type) { case *ScalarFunction: args := make([]Expression, len(x.GetArgs())) @@ -960,7 +959,7 @@ func evaluateExprWithNullInNullRejectCheck(ctx sessionctx.Context, schema *Schem } // TableInfo2SchemaAndNames converts the TableInfo to the schema and name slice. -func TableInfo2SchemaAndNames(ctx sessionctx.Context, dbName model.CIStr, tbl *model.TableInfo) (*Schema, []*types.FieldName, error) { +func TableInfo2SchemaAndNames(ctx BuildContext, dbName model.CIStr, tbl *model.TableInfo) (*Schema, []*types.FieldName, error) { cols, names, err := ColumnInfos2ColumnsAndNames(ctx, dbName, tbl.Name, tbl.Cols(), tbl) if err != nil { return nil, nil, err diff --git a/pkg/expression/helper.go b/pkg/expression/helper.go index 271dd439b6a56..0ee35281c75ab 100644 --- a/pkg/expression/helper.go +++ b/pkg/expression/helper.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" driver "github.com/pingcap/tidb/pkg/types/parser_driver" "github.com/pingcap/tidb/pkg/util/logutil" @@ -60,7 +59,7 @@ func IsValidCurrentTimestampExpr(exprNode ast.ExprNode, fieldType *types.FieldTy } // GetTimeCurrentTimestamp is used for generating a timestamp for some special cases: cast null value to timestamp type with not null flag. -func GetTimeCurrentTimestamp(ctx sessionctx.Context, tp byte, fsp int) (d types.Datum, err error) { +func GetTimeCurrentTimestamp(ctx BuildContext, tp byte, fsp int) (d types.Datum, err error) { var t types.Time t, err = getTimeCurrentTimeStamp(ctx, tp, fsp) if err != nil { diff --git a/pkg/expression/typeinfer_test.go b/pkg/expression/typeinfer_test.go index ce77470663757..280743fe2ff0d 100644 --- a/pkg/expression/typeinfer_test.go +++ b/pkg/expression/typeinfer_test.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" plannercore "github.com/pingcap/tidb/pkg/planner/core" "github.com/pingcap/tidb/pkg/session" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/sessiontxn" "github.com/pingcap/tidb/pkg/testkit" @@ -111,7 +110,7 @@ func TestInferType(t *testing.T) { tests = append(tests, s.createTestCase4MiscellaneousFunc()...) tests = append(tests, s.createTestCase4GetVarFunc()...) - sctx := testKit.Session().(sessionctx.Context) + sctx := testKit.Session() require.NoError(t, sctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, mysql.DefaultCharset)) require.NoError(t, sctx.GetSessionVars().SetSystemVar(variable.CollationConnection, mysql.DefaultCollationName)) diff --git a/pkg/expression/util.go b/pkg/expression/util.go index e6982a7ec2bd2..56ea0859de392 100644 --- a/pkg/expression/util.go +++ b/pkg/expression/util.go @@ -30,7 +30,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/opcode" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" driver "github.com/pingcap/tidb/pkg/types/parser_driver" "github.com/pingcap/tidb/pkg/util/chunk" @@ -271,7 +271,7 @@ func extractColumnsAndCorColumns(result []*Column, expr Expression) []*Column { } // ExtractConstantEqColumnsOrScalar detects the constant equal relationship from CNF exprs. -func ExtractConstantEqColumnsOrScalar(ctx sessionctx.Context, result []Expression, exprs []Expression) []Expression { +func ExtractConstantEqColumnsOrScalar(ctx BuildContext, result []Expression, exprs []Expression) []Expression { // exprs are CNF expressions, EQ condition only make sense in the top level of every expr. for _, expr := range exprs { result = extractConstantEqColumnsOrScalar(ctx, result, expr) @@ -279,7 +279,7 @@ func ExtractConstantEqColumnsOrScalar(ctx sessionctx.Context, result []Expressio return result } -func extractConstantEqColumnsOrScalar(ctx sessionctx.Context, result []Expression, expr Expression) []Expression { +func extractConstantEqColumnsOrScalar(ctx BuildContext, result []Expression, expr Expression) []Expression { switch v := expr.(type) { case *ScalarFunction: if v.FuncName.L == ast.EQ || v.FuncName.L == ast.NullEQ { @@ -410,7 +410,7 @@ func SetExprColumnInOperand(expr Expression) Expression { // ColumnSubstitute substitutes the columns in filter to expressions in select fields. // e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. // TODO: remove this function and only use ColumnSubstituteImpl since this function swallows the error, which seems unsafe. -func ColumnSubstitute(ctx sessionctx.Context, expr Expression, schema *Schema, newExprs []Expression) Expression { +func ColumnSubstitute(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression) Expression { _, _, resExpr := ColumnSubstituteImpl(ctx, expr, schema, newExprs, false) return resExpr } @@ -420,7 +420,7 @@ func ColumnSubstitute(ctx sessionctx.Context, expr Expression, schema *Schema, n // // 1: substitute them all once find col in schema. // 2: nothing in expr can be substituted. -func ColumnSubstituteAll(ctx sessionctx.Context, expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { +func ColumnSubstituteAll(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { _, hasFail, resExpr := ColumnSubstituteImpl(ctx, expr, schema, newExprs, true) return hasFail, resExpr } @@ -430,7 +430,7 @@ func ColumnSubstituteAll(ctx sessionctx.Context, expr Expression, schema *Schema // @return bool means whether the expr has changed. // @return bool means whether the expr should change (has the dependency in schema, while the corresponding expr has some compatibility), but finally fallback. // @return Expression, the original expr or the changed expr, it depends on the first @return bool. -func ColumnSubstituteImpl(ctx sessionctx.Context, expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) { +func ColumnSubstituteImpl(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) { switch v := expr.(type) { case *Column: id := schema.ColumnIndex(v) @@ -586,7 +586,7 @@ Loop: // SubstituteCorCol2Constant will substitute correlated column to constant value which it contains. // If the args of one scalar function are all constant, we will substitute it to constant. -func SubstituteCorCol2Constant(ctx sessionctx.Context, expr Expression) (Expression, error) { +func SubstituteCorCol2Constant(ctx BuildContext, expr Expression) (Expression, error) { switch x := expr.(type) { case *ScalarFunction: allConstant := true @@ -710,7 +710,7 @@ var symmetricOp = map[opcode.Op]opcode.Op{ opcode.NullEQ: opcode.NullEQ, } -func pushNotAcrossArgs(ctx sessionctx.Context, exprs []Expression, not bool) ([]Expression, bool) { +func pushNotAcrossArgs(ctx BuildContext, exprs []Expression, not bool) ([]Expression, bool) { newExprs := make([]Expression, 0, len(exprs)) flag := false for _, expr := range exprs { @@ -752,7 +752,7 @@ func noPrecisionLossCastCompatible(cast, argCol *types.FieldType) bool { return true } -func unwrapCast(sctx sessionctx.Context, parentF *ScalarFunction, castOffset int) (Expression, bool) { +func unwrapCast(sctx BuildContext, parentF *ScalarFunction, castOffset int) (Expression, bool) { _, collation := parentF.CharsetAndCollation() cast, ok := parentF.GetArgs()[castOffset].(*ScalarFunction) if !ok || cast.FuncName.L != ast.Cast { @@ -788,7 +788,7 @@ func unwrapCast(sctx sessionctx.Context, parentF *ScalarFunction, castOffset int // eliminateCastFunction will detect the original arg before and the cast type after, once upon // there is no precision loss between them, current cast wrapper can be eliminated. For string // type, collation is also taken into consideration. (mainly used to build range or point) -func eliminateCastFunction(sctx sessionctx.Context, expr Expression) (_ Expression, changed bool) { +func eliminateCastFunction(sctx BuildContext, expr Expression) (_ Expression, changed bool) { f, ok := expr.(*ScalarFunction) if !ok { return expr, false @@ -871,7 +871,7 @@ func eliminateCastFunction(sctx sessionctx.Context, expr Expression) (_ Expressi // Input `not` indicates whether there's a `NOT` be pushed down. // Output `changed` indicates whether the output expression differs from the // input `expr` because of the pushed-down-not. -func pushNotAcrossExpr(ctx sessionctx.Context, expr Expression, not bool) (_ Expression, changed bool) { +func pushNotAcrossExpr(ctx BuildContext, expr Expression, not bool) (_ Expression, changed bool) { if f, ok := expr.(*ScalarFunction); ok { switch f.FuncName.L { case ast.UnaryNot: @@ -935,7 +935,7 @@ func GetExprInsideIsTruth(expr Expression) Expression { } // PushDownNot pushes the `not` function down to the expression's arguments. -func PushDownNot(ctx sessionctx.Context, expr Expression) Expression { +func PushDownNot(ctx BuildContext, expr Expression) Expression { newExpr, _ := pushNotAcrossExpr(ctx, expr, false) return newExpr } @@ -944,7 +944,7 @@ func PushDownNot(ctx sessionctx.Context, expr Expression) Expression { // 1: deeper cast embedded in other complicated function will not be considered. // 2: cast args should be one for original base column and one for constant. // 3: some collation compatibility and precision loss will be considered when remove this cast func. -func EliminateNoPrecisionLossCast(sctx sessionctx.Context, expr Expression) Expression { +func EliminateNoPrecisionLossCast(sctx BuildContext, expr Expression) Expression { newExpr, _ := eliminateCastFunction(sctx, expr) return newExpr } @@ -998,7 +998,7 @@ func Contains(exprs []Expression, e Expression) bool { // ExtractFiltersFromDNFs checks whether the cond is DNF. If so, it will get the extracted part and the remained part. // The original DNF will be replaced by the remained part or just be deleted if remained part is nil. // And the extracted part will be appended to the end of the orignal slice. -func ExtractFiltersFromDNFs(ctx sessionctx.Context, conditions []Expression) []Expression { +func ExtractFiltersFromDNFs(ctx BuildContext, conditions []Expression) []Expression { var allExtracted []Expression for i := len(conditions) - 1; i >= 0; i-- { if sf, ok := conditions[i].(*ScalarFunction); ok && sf.FuncName.L == ast.LogicOr { @@ -1015,7 +1015,7 @@ func ExtractFiltersFromDNFs(ctx sessionctx.Context, conditions []Expression) []E } // extractFiltersFromDNF extracts the same condition that occurs in every DNF item and remove them from dnf leaves. -func extractFiltersFromDNF(ctx sessionctx.Context, dnfFunc *ScalarFunction) ([]Expression, Expression) { +func extractFiltersFromDNF(ctx BuildContext, dnfFunc *ScalarFunction) ([]Expression, Expression) { dnfItems := FlattenDNFConditions(dnfFunc) codeMap := make(map[string]int) hashcode2Expr := make(map[string]Expression) @@ -1082,7 +1082,7 @@ func extractFiltersFromDNF(ctx sessionctx.Context, dnfFunc *ScalarFunction) ([]E // the original expression must satisfy the derived expression. Return nil when the derived expression is universal set. // A running example is: for schema of t1, `(t1.a=1 and t2.a=1) or (t1.a=2 and t2.a=2)` would be derived as // `t1.a=1 or t1.a=2`, while `t1.a=1 or t2.a=1` would get nil. -func DeriveRelaxedFiltersFromDNF(ctx sessionctx.Context, expr Expression, schema *Schema) Expression { +func DeriveRelaxedFiltersFromDNF(ctx BuildContext, expr Expression, schema *Schema) Expression { sf, ok := expr.(*ScalarFunction) if !ok || sf.FuncName.L != ast.LogicOr { return nil @@ -1164,7 +1164,7 @@ func DatumToConstant(d types.Datum, tp byte, flag uint) *Constant { } // ParamMarkerExpression generate a getparam function expression. -func ParamMarkerExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr, needParam bool) (*Constant, error) { +func ParamMarkerExpression(ctx variable.SessionVarsProvider, v *driver.ParamMarkerExpr, needParam bool) (*Constant, error) { useCache := ctx.GetSessionVars().StmtCtx.UseCache isPointExec := ctx.GetSessionVars().StmtCtx.PointExec tp := types.NewFieldType(mysql.TypeUnspecified) @@ -1221,7 +1221,7 @@ func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr { } // PosFromPositionExpr generates a position value from PositionExpr. -func PosFromPositionExpr(ctx sessionctx.Context, v *ast.PositionExpr) (int, bool, error) { +func PosFromPositionExpr(ctx BuildContext, v *ast.PositionExpr) (int, bool, error) { if v.P == nil { return v.N, false, nil } @@ -1392,7 +1392,7 @@ func RemoveDupExprs(exprs []Expression) []Expression { } // GetUint64FromConstant gets a uint64 from constant expression. -func GetUint64FromConstant(ctx sessionctx.Context, expr Expression) (uint64, bool, bool) { +func GetUint64FromConstant(ctx EvalContext, expr Expression) (uint64, bool, bool) { con, ok := expr.(*Constant) if !ok { logutil.BgLogger().Warn("not a constant expression", zap.String("expression", expr.ExplainInfo(ctx))) diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index 3693b00b1eca9..e659e28986291 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -2870,7 +2870,7 @@ func (lw *LogicalWindow) tryToGetMppWindows(prop *property.PhysicalProperty) []P { allSupported := true for _, windowFunc := range lw.WindowFuncDescs { - if !windowFunc.CanPushDownToTiFlash(lw.SCtx()) { + if !windowFunc.CanPushDownToTiFlash(lw.SCtx(), lw.SCtx().GetClient()) { lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( "MPP mode may be blocked because window function `" + windowFunc.Name + "` or its arguments are not supported now.") allSupported = false diff --git a/pkg/planner/util/byitem.go b/pkg/planner/util/byitem.go index 878ea0fb046fb..3b429d4a0a4ac 100644 --- a/pkg/planner/util/byitem.go +++ b/pkg/planner/util/byitem.go @@ -18,7 +18,6 @@ import ( "fmt" "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/util/size" ) @@ -42,7 +41,7 @@ func (by *ByItems) Clone() *ByItems { } // Equal checks whether two ByItems are equal. -func (by *ByItems) Equal(ctx sessionctx.Context, other *ByItems) bool { +func (by *ByItems) Equal(ctx expression.EvalContext, other *ByItems) bool { return by.Expr.Equal(ctx, other.Expr) && by.Desc == other.Desc } diff --git a/pkg/sessionctx/variable/session.go b/pkg/sessionctx/variable/session.go index b41383d2e8c4e..9247b15bd7fd7 100644 --- a/pkg/sessionctx/variable/session.go +++ b/pkg/sessionctx/variable/session.go @@ -663,6 +663,11 @@ type HookContext interface { GetStore() kv.Storage } +// SessionVarsProvider provides the session variables. +type SessionVarsProvider interface { + GetSessionVars() *SessionVars +} + // SessionVars is to handle user-defined or global variables in the current session. type SessionVars struct { Concurrency