Skip to content

Commit

Permalink
expression: remove direct dependencies with sessionctx.Context for …
Browse files Browse the repository at this point in the history
…package `expression` (#51025)

close #51024
  • Loading branch information
lcwangchao authored Feb 8, 2024
1 parent b94a2a8 commit c5eced1
Show file tree
Hide file tree
Showing 31 changed files with 129 additions and 142 deletions.
1 change: 0 additions & 1 deletion pkg/expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ go_test(
"//pkg/parser/terror",
"//pkg/planner/core",
"//pkg/session",
"//pkg/sessionctx",
"//pkg/sessionctx/stmtctx",
"//pkg/sessionctx/variable",
"//pkg/sessiontxn",
Expand Down
14 changes: 11 additions & 3 deletions pkg/expression/aggregation/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

package_group(
name = "aggregation_friend",
packages = [
"-//pkg/sessionctx/...",
"//...",
],
)

go_library(
name = "aggregation",
srcs = [
Expand All @@ -21,7 +29,9 @@ go_library(
"window_func.go",
],
importpath = "github.com/pingcap/tidb/pkg/expression/aggregation",
visibility = ["//visibility:public"],
visibility = [
":aggregation_friend",
],
deps = [
"//pkg/expression",
"//pkg/kv",
Expand All @@ -30,7 +40,6 @@ go_library(
"//pkg/parser/mysql",
"//pkg/parser/terror",
"//pkg/planner/util",
"//pkg/sessionctx",
"//pkg/sessionctx/stmtctx",
"//pkg/sessionctx/variable",
"//pkg/types",
Expand Down Expand Up @@ -64,7 +73,6 @@ go_test(
"//pkg/kv",
"//pkg/parser/ast",
"//pkg/parser/mysql",
"//pkg/sessionctx",
"//pkg/sessionctx/stmtctx",
"//pkg/sessionctx/variable",
"//pkg/testkit/testsetup",
Expand Down
3 changes: 1 addition & 2 deletions pkg/expression/aggregation/agg_to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/codec"
Expand Down Expand Up @@ -184,7 +183,7 @@ func PBAggFuncModeToAggFuncMode(pbMode *tipb.AggFunctionMode) (mode AggFunctionM
}

// PBExprToAggFuncDesc converts pb to aggregate function.
func PBExprToAggFuncDesc(ctx sessionctx.Context, aggFunc *tipb.Expr, fieldTps []*types.FieldType) (*AggFuncDesc, error) {
func PBExprToAggFuncDesc(ctx expression.BuildContext, aggFunc *tipb.Expr, fieldTps []*types.FieldType) (*AggFuncDesc, error) {
var name string
switch aggFunc.Tp {
case tipb.ExprType_Count:
Expand Down
3 changes: 1 addition & 2 deletions pkg/expression/aggregation/aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -51,7 +50,7 @@ type Aggregation interface {
}

// NewDistAggFunc creates new Aggregate function for mock tikv.
func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, ctx sessionctx.Context) (Aggregation, *AggFuncDesc, error) {
func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, ctx expression.BuildContext) (Aggregation, *AggFuncDesc, error) {
args := make([]expression.Expression, 0, len(expr.Children))
for _, child := range expr.Children {
arg, err := expression.PBToExpr(ctx, child, fieldTps)
Expand Down
3 changes: 1 addition & 2 deletions pkg/expression/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand All @@ -30,7 +29,7 @@ import (
)

type mockAggFuncSuite struct {
ctx sessionctx.Context
ctx expression.BuildContext
rows []chunk.Row
nullRow chunk.Row
}
Expand Down
47 changes: 23 additions & 24 deletions pkg/expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mathutil"
Expand All @@ -42,13 +41,13 @@ type baseFuncDesc struct {
RetTp *types.FieldType
}

func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) {
func newBaseFuncDesc(ctx expression.BuildContext, name string, args []expression.Expression) (baseFuncDesc, error) {
b := baseFuncDesc{Name: strings.ToLower(name), Args: args}
err := b.TypeInfer(ctx)
return b, err
}

func (a *baseFuncDesc) equal(ctx sessionctx.Context, other *baseFuncDesc) bool {
func (a *baseFuncDesc) equal(ctx expression.EvalContext, other *baseFuncDesc) bool {
if a.Name != other.Name || len(a.Args) != len(other.Args) {
return false
}
Expand Down Expand Up @@ -86,18 +85,18 @@ func (a *baseFuncDesc) String() string {
}

// TypeInfer infers the arguments and return types of an function.
func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error {
func (a *baseFuncDesc) TypeInfer(ctx expression.BuildContext) error {
switch a.Name {
case ast.AggFuncCount:
a.typeInfer4Count(ctx)
a.typeInfer4Count()
case ast.AggFuncApproxCountDistinct:
a.typeInfer4ApproxCountDistinct(ctx)
a.typeInfer4ApproxCountDistinct()
case ast.AggFuncApproxPercentile:
return a.typeInfer4ApproxPercentile(ctx)
case ast.AggFuncSum:
a.typeInfer4Sum(ctx)
a.typeInfer4Sum()
case ast.AggFuncAvg:
a.typeInfer4Avg(ctx)
a.typeInfer4Avg()
case ast.AggFuncGroupConcat:
a.typeInfer4GroupConcat(ctx)
case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow,
Expand All @@ -116,9 +115,9 @@ func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error {
case ast.WindowFuncLead, ast.WindowFuncLag:
a.typeInfer4LeadLag(ctx)
case ast.AggFuncVarPop, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp:
a.typeInfer4PopOrSamp(ctx)
a.typeInfer4PopOrSamp()
case ast.AggFuncJsonArrayagg:
a.typeInfer4JsonArrayAgg(ctx)
a.typeInfer4JsonArrayAgg()
case ast.AggFuncJsonObjectAgg:
return a.typeInfer4JsonObjectAgg(ctx)
default:
Expand All @@ -127,7 +126,7 @@ func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error {
return nil
}

func (a *baseFuncDesc) typeInfer4Count(sessionctx.Context) {
func (a *baseFuncDesc) typeInfer4Count() {
a.RetTp = types.NewFieldType(mysql.TypeLonglong)
a.RetTp.SetFlen(21)
a.RetTp.SetDecimal(0)
Expand All @@ -136,11 +135,11 @@ func (a *baseFuncDesc) typeInfer4Count(sessionctx.Context) {
types.SetBinChsClnFlag(a.RetTp)
}

func (a *baseFuncDesc) typeInfer4ApproxCountDistinct(ctx sessionctx.Context) {
a.typeInfer4Count(ctx)
func (a *baseFuncDesc) typeInfer4ApproxCountDistinct() {
a.typeInfer4Count()
}

func (a *baseFuncDesc) typeInfer4ApproxPercentile(ctx sessionctx.Context) error {
func (a *baseFuncDesc) typeInfer4ApproxPercentile(ctx expression.EvalContext) error {
if len(a.Args) != 2 {
return errors.New("APPROX_PERCENTILE should take 2 arguments")
}
Expand Down Expand Up @@ -182,7 +181,7 @@ func (a *baseFuncDesc) typeInfer4ApproxPercentile(ctx sessionctx.Context) error

// typeInfer4Sum should return a "decimal", otherwise it returns a "double".
// Because child returns integer or decimal type.
func (a *baseFuncDesc) typeInfer4Sum(sessionctx.Context) {
func (a *baseFuncDesc) typeInfer4Sum() {
switch a.Args[0].GetType().GetType() {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
Expand Down Expand Up @@ -221,7 +220,7 @@ func (a *baseFuncDesc) TypeInfer4FinalCount(finalCountRetType *types.FieldType)

// typeInfer4Avg should returns a "decimal", otherwise it returns a "double".
// Because child returns integer or decimal type.
func (a *baseFuncDesc) typeInfer4Avg(sessionctx.Context) {
func (a *baseFuncDesc) typeInfer4Avg() {
switch a.Args[0].GetType().GetType() {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
Expand All @@ -247,7 +246,7 @@ func (a *baseFuncDesc) typeInfer4Avg(sessionctx.Context) {
types.SetBinChsClnFlag(a.RetTp)
}

func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) {
func (a *baseFuncDesc) typeInfer4GroupConcat(ctx expression.BuildContext) {
a.RetTp = types.NewFieldType(mysql.TypeVarString)
charset, collate := charset.GetDefaultCharsetAndCollate()
a.RetTp.SetCharset(charset)
Expand All @@ -263,7 +262,7 @@ func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) {
}
}

func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) {
func (a *baseFuncDesc) typeInfer4MaxMin(ctx expression.BuildContext) {
_, argIsScalaFunc := a.Args[0].(*expression.ScalarFunction)
if argIsScalaFunc && a.Args[0].GetType().GetType() == mysql.TypeFloat {
// For scalar function, the result of "float32" is set to the "float64"
Expand All @@ -288,20 +287,20 @@ func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) {
}
}

func (a *baseFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) {
func (a *baseFuncDesc) typeInfer4BitFuncs(ctx expression.BuildContext) {
a.RetTp = types.NewFieldType(mysql.TypeLonglong)
a.RetTp.SetFlen(21)
types.SetBinChsClnFlag(a.RetTp)
a.RetTp.AddFlag(mysql.UnsignedFlag | mysql.NotNullFlag)
a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0])
}

func (a *baseFuncDesc) typeInfer4JsonArrayAgg(sessionctx.Context) {
func (a *baseFuncDesc) typeInfer4JsonArrayAgg() {
a.RetTp = types.NewFieldType(mysql.TypeJSON)
types.SetBinChsClnFlag(a.RetTp)
}

func (a *baseFuncDesc) typeInfer4JsonObjectAgg(ctx sessionctx.Context) error {
func (a *baseFuncDesc) typeInfer4JsonObjectAgg(ctx expression.BuildContext) error {
a.RetTp = types.NewFieldType(mysql.TypeJSON)
types.SetBinChsClnFlag(a.RetTp)
a.Args[0] = expression.WrapWithCastAsString(ctx, a.Args[0])
Expand Down Expand Up @@ -333,7 +332,7 @@ func (a *baseFuncDesc) typeInfer4PercentRank() {
a.RetTp.SetDecimal(mysql.NotFixedDec)
}

func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) {
func (a *baseFuncDesc) typeInfer4LeadLag(ctx expression.BuildContext) {
if len(a.Args) < 3 {
a.typeInfer4MaxMin(ctx)
} else {
Expand All @@ -343,7 +342,7 @@ func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) {
}
}

func (a *baseFuncDesc) typeInfer4PopOrSamp(sessionctx.Context) {
func (a *baseFuncDesc) typeInfer4PopOrSamp() {
// var_pop/std/var_samp/stddev_samp's return value type is double
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.SetFlen(mysql.MaxRealWidth)
Expand Down Expand Up @@ -398,7 +397,7 @@ var noNeedCastAggFuncs = map[string]struct{}{
}

// WrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
func (a *baseFuncDesc) WrapCastForAggArgs(ctx expression.BuildContext) {
if len(a.Args) == 0 {
return
}
Expand Down
19 changes: 9 additions & 10 deletions pkg/expression/aggregation/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/util"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/collate"
Expand All @@ -48,7 +47,7 @@ type AggFuncDesc struct {

// NewAggFuncDesc creates an aggregation function signature descriptor.
// this func cannot be called twice as the TypeInfer has changed the type of args in the first time.
func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) {
func NewAggFuncDesc(ctx expression.BuildContext, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) {
b, err := newBaseFuncDesc(ctx, name, args)
if err != nil {
return nil, err
Expand All @@ -57,7 +56,7 @@ func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expre
}

// NewAggFuncDescForWindowFunc creates an aggregation function from window functions, where baseFuncDesc may be ready.
func NewAggFuncDescForWindowFunc(ctx sessionctx.Context, desc *WindowFuncDesc, hasDistinct bool) (*AggFuncDesc, error) {
func NewAggFuncDescForWindowFunc(ctx expression.BuildContext, desc *WindowFuncDesc, hasDistinct bool) (*AggFuncDesc, error) {
if desc.RetTp == nil { // safety check
return NewAggFuncDesc(ctx, desc.Name, desc.Args, hasDistinct)
}
Expand Down Expand Up @@ -91,7 +90,7 @@ func (a *AggFuncDesc) String() string {
}

// Equal checks whether two aggregation function signatures are equal.
func (a *AggFuncDesc) Equal(ctx sessionctx.Context, other *AggFuncDesc) bool {
func (a *AggFuncDesc) Equal(ctx expression.EvalContext, other *AggFuncDesc) bool {
if a.HasDistinct != other.HasDistinct {
return false
}
Expand Down Expand Up @@ -200,7 +199,7 @@ func (a *AggFuncDesc) Split(ordinal []int) (partialAggDesc, finalAggDesc *AggFun
// +------+-----------+---------+---------+------------+-------------+------------+---------+---------+------+----------+
// | 1 | 1 | 95 | 95.0000 | 95 | 95 | 95 | 95 | 95 | NULL | NULL |
// +------+-----------+---------+---------+------------+-------------+------------+---------+---------+------+----------+
func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) {
switch a.Name {
case ast.AggFuncCount:
return a.evalNullValueInOuterJoin4Count(ctx, schema)
Expand All @@ -219,7 +218,7 @@ func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx sessionctx.Context, schema *e
}

// GetAggFunc gets an evaluator according to the aggregation function signature.
func (a *AggFuncDesc) GetAggFunc(ctx sessionctx.Context) Aggregation {
func (a *AggFuncDesc) GetAggFunc(ctx expression.BuildContext) Aggregation {
aggFunc := aggFunction{AggFuncDesc: a}
switch a.Name {
case ast.AggFuncSum:
Expand Down Expand Up @@ -258,7 +257,7 @@ func (a *AggFuncDesc) GetAggFunc(ctx sessionctx.Context) Aggregation {
}
}

func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) {
for _, arg := range a.Args {
result := expression.EvaluateExprWithNull(ctx, schema, arg)
con, ok := result.(*expression.Constant)
Expand All @@ -269,7 +268,7 @@ func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx sessionctx.Context, sch
return types.NewDatum(1), true
}

func (a *AggFuncDesc) evalNullValueInOuterJoin4Sum(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
func (a *AggFuncDesc) evalNullValueInOuterJoin4Sum(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) {
result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
con, ok := result.(*expression.Constant)
if !ok || con.Value.IsNull() {
Expand All @@ -278,7 +277,7 @@ func (a *AggFuncDesc) evalNullValueInOuterJoin4Sum(ctx sessionctx.Context, schem
return con.Value, true
}

func (a *AggFuncDesc) evalNullValueInOuterJoin4BitAnd(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
func (a *AggFuncDesc) evalNullValueInOuterJoin4BitAnd(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) {
result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
con, ok := result.(*expression.Constant)
if !ok || con.Value.IsNull() {
Expand All @@ -287,7 +286,7 @@ func (a *AggFuncDesc) evalNullValueInOuterJoin4BitAnd(ctx sessionctx.Context, sc
return con.Value, true
}

func (a *AggFuncDesc) evalNullValueInOuterJoin4BitOr(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) {
func (a *AggFuncDesc) evalNullValueInOuterJoin4BitOr(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) {
result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
con, ok := result.(*expression.Constant)
if !ok || con.Value.IsNull() {
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/aggregation/explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ import (
"bytes"
"fmt"

"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/sessionctx"
)

// ExplainAggFunc generates explain information for a aggregation function.
func ExplainAggFunc(ctx sessionctx.Context, agg *AggFuncDesc, normalized bool) string {
func ExplainAggFunc(ctx expression.EvalContext, agg *AggFuncDesc, normalized bool) string {
var buffer bytes.Buffer
fmt.Fprintf(&buffer, "%s(", agg.Name)
if agg.HasDistinct {
Expand Down
Loading

0 comments on commit c5eced1

Please sign in to comment.