Skip to content

Commit

Permalink
expression: finish to remove SessionVars and other complex objects …
Browse files Browse the repository at this point in the history
…from `EvalContext` (#52015)

close #51477
  • Loading branch information
lcwangchao authored Mar 25, 2024
1 parent 411e945 commit 639fa00
Show file tree
Hide file tree
Showing 33 changed files with 979 additions and 164 deletions.
2 changes: 2 additions & 0 deletions pkg/executor/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ go_library(
"//pkg/expression",
"//pkg/expression/aggregation",
"//pkg/expression/context",
"//pkg/expression/contextimpl",
"//pkg/infoschema",
"//pkg/keyspace",
"//pkg/kv",
Expand Down Expand Up @@ -396,6 +397,7 @@ go_test(
"//pkg/executor/sortexec",
"//pkg/expression",
"//pkg/expression/aggregation",
"//pkg/expression/contextimpl",
"//pkg/infoschema",
"//pkg/kv",
"//pkg/meta",
Expand Down
6 changes: 5 additions & 1 deletion pkg/executor/cluster_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/pkg/config"
"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/expression/contextimpl"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/auth"
"github.com/pingcap/tidb/pkg/parser/mysql"
Expand Down Expand Up @@ -310,7 +311,10 @@ func TestSQLDigestTextRetriever(t *testing.T) {
updateDigest.String(): "",
},
}
err := r.RetrieveLocal(context.Background(), tk.Session().GetExprCtx())

sqlExec, err := contextimpl.NewSQLExecutor(tk.Session())
require.NoError(t, err)
err = r.RetrieveLocal(context.Background(), sqlExec)
require.NoError(t, err)
require.Equal(t, insertNormalized, r.SQLDigestsMap[insertDigest.String()])
require.Equal(t, "", r.SQLDigestsMap[updateDigest.String()])
Expand Down
23 changes: 19 additions & 4 deletions pkg/executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/executor/internal/pdhelper"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/expression/contextimpl"
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta/autoid"
Expand Down Expand Up @@ -2718,8 +2719,12 @@ func (e *tidbTrxTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Co
e.batchRetrieverHelper.batchSize = 1024
}

sqlExec, err := contextimpl.NewSQLExecutor(sctx)
if err != nil {
return nil, err
}

// The current TiDB node's address is needed by the CLUSTER_TIDB_TRX table.
var err error
var instanceAddr string
if e.table.Name.O == infoschema.ClusterTableTiDBTrx {
instanceAddr, err = infoschema.GetInstanceAddr(sctx)
Expand All @@ -2745,7 +2750,7 @@ func (e *tidbTrxTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Co
}
// Retrieve the SQL texts if necessary.
if sqlRetriever != nil {
err1 := sqlRetriever.RetrieveLocal(ctx, sctx.GetExprCtx())
err1 := sqlRetriever.RetrieveLocal(ctx, sqlExec)
if err1 != nil {
return errors.Trace(err1)
}
Expand Down Expand Up @@ -2866,7 +2871,13 @@ func (r *dataLockWaitsTableRetriever) retrieve(ctx context.Context, sctx session
sqlRetriever.SQLDigestsMap[digest] = ""
}
}
err := sqlRetriever.RetrieveGlobal(ctx, sctx.GetExprCtx())

sqlExec, err := contextimpl.NewSQLExecutor(sctx)
if err != nil {
return errors.Trace(err)
}

err = sqlRetriever.RetrieveGlobal(ctx, sqlExec)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -3064,7 +3075,11 @@ func (r *deadlocksTableRetriever) retrieve(ctx context.Context, sctx sessionctx.
}
// Retrieve the SQL texts if necessary.
if sqlRetriever != nil {
err1 := sqlRetriever.RetrieveGlobal(ctx, sctx.GetExprCtx())
sqlExec, err := contextimpl.NewSQLExecutor(sctx)
if err != nil {
return errors.Trace(err)
}
err1 := sqlRetriever.RetrieveGlobal(ctx, sqlExec)
if err1 != nil {
return errors.Trace(err1)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ go_library(
"//pkg/expression/context",
"//pkg/expression/contextopt",
"//pkg/extension",
"//pkg/infoschema/context",
"//pkg/kv",
"//pkg/parser",
"//pkg/parser/ast",
Expand All @@ -87,7 +88,6 @@ go_library(
"//pkg/parser/opcode",
"//pkg/parser/terror",
"//pkg/parser/types",
"//pkg/privilege",
"//pkg/sessionctx/stmtctx",
"//pkg/sessionctx/variable",
"//pkg/types",
Expand Down
129 changes: 94 additions & 35 deletions pkg/expression/builtin_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,12 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression/contextopt"
infoschema "github.com/pingcap/tidb/pkg/infoschema/context"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/privilege"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/plancodec"
"github.com/pingcap/tidb/pkg/util/printer"
Expand Down Expand Up @@ -906,12 +904,21 @@ func (c *tidbDecodeKeyFunctionClass) getFunction(ctx BuildContext, args []Expres
if err != nil {
return nil, err
}
sig := &builtinTiDBDecodeKeySig{bf}
sig := &builtinTiDBDecodeKeySig{baseBuiltinFunc: bf}
return sig, nil
}

// DecodeKeyFromString is used to decode key by expressions
var DecodeKeyFromString func(types.Context, infoschema.InfoSchemaMetaVersion, string) string

type builtinTiDBDecodeKeySig struct {
baseBuiltinFunc
contextopt.InfoSchemaPropReader
}

// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface.
func (b *builtinTiDBDecodeKeySig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.InfoSchemaPropReader.RequiredOptionalEvalProps()
}

func (b *builtinTiDBDecodeKeySig) Clone() builtinFunc {
Expand All @@ -926,11 +933,15 @@ func (b *builtinTiDBDecodeKeySig) evalString(ctx EvalContext, row chunk.Row) (st
if isNull || err != nil {
return "", isNull, err
}
decode := func(ctx EvalContext, s string) string { return s }
if fn := ctx.Value(TiDBDecodeKeyFunctionKey); fn != nil {
decode = fn.(func(ctx EvalContext, s string) string)
is, err := b.GetDomainInfoSchema(ctx)
if err != nil {
return "", true, err
}
return decode(ctx, s), false, nil

if fn := DecodeKeyFromString; fn != nil {
s = fn(ctx.TypeCtx(), is, s)
}
return s, false, nil
}

// TiDBDecodeKeyFunctionKeyType is used to identify the decoder function in context.
Expand All @@ -941,9 +952,6 @@ func (k TiDBDecodeKeyFunctionKeyType) String() string {
return "tidb_decode_key"
}

// TiDBDecodeKeyFunctionKey is used to identify the decoder function in context.
const TiDBDecodeKeyFunctionKey TiDBDecodeKeyFunctionKeyType = 0

type tidbDecodeSQLDigestsFunctionClass struct {
baseFunctionClass
}
Expand All @@ -953,8 +961,7 @@ func (c *tidbDecodeSQLDigestsFunctionClass) getFunction(ctx BuildContext, args [
return nil, err
}

pm := privilege.GetPrivilegeManager(ctx)
if pm != nil && !pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, "", "", "", mysql.ProcessPriv) {
if !ctx.RequestVerification("", "", "", mysql.ProcessPriv) {
return nil, errSpecificAccessDenied.GenWithStackByArgs("PROCESS")
}

Expand All @@ -968,12 +975,20 @@ func (c *tidbDecodeSQLDigestsFunctionClass) getFunction(ctx BuildContext, args [
if err != nil {
return nil, err
}
sig := &builtinTiDBDecodeSQLDigestsSig{bf}
sig := &builtinTiDBDecodeSQLDigestsSig{baseBuiltinFunc: bf}
return sig, nil
}

type builtinTiDBDecodeSQLDigestsSig struct {
baseBuiltinFunc
contextopt.SessionVarsPropReader
contextopt.SQLExecutorPropReader
}

// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface.
func (b *builtinTiDBDecodeSQLDigestsSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SessionVarsPropReader.RequiredOptionalEvalProps() |
b.SQLExecutorPropReader.RequiredOptionalEvalProps()
}

func (b *builtinTiDBDecodeSQLDigestsSig) Clone() builtinFunc {
Expand Down Expand Up @@ -1026,15 +1041,26 @@ func (b *builtinTiDBDecodeSQLDigestsSig) evalString(ctx EvalContext, row chunk.R
}
}

vars, err := b.GetSessionVars(ctx)
if err != nil {
return "", true, err
}

// Querying may take some time and it takes a context.Context as argument, which is not available here.
// We simply create a context with a timeout here.
timeout := time.Duration(ctx.GetSessionVars().GetMaxExecutionTime()) * time.Millisecond
timeout := time.Duration(vars.GetMaxExecutionTime()) * time.Millisecond
if timeout == 0 || timeout > 20*time.Second {
timeout = 20 * time.Second
}
goCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
err = retriever.RetrieveGlobal(goCtx, ctx)

exec, err := b.GetSQLExecutor(ctx)
if err != nil {
return "", true, err
}

err = retriever.RetrieveGlobal(goCtx, exec)
if err != nil {
if errors.Cause(err) == context.DeadlineExceeded || errors.Cause(err) == context.Canceled {
return "", true, errUnknown.GenWithStack("Retrieving cancelled internally with error: %v", err)
Expand Down Expand Up @@ -1189,13 +1215,20 @@ func (c *nextValFunctionClass) getFunction(ctx BuildContext, args []Expression)
if err != nil {
return nil, err
}
sig := &builtinNextValSig{bf}
sig := &builtinNextValSig{baseBuiltinFunc: bf}
bf.tp.SetFlen(10)
return sig, nil
}

type builtinNextValSig struct {
baseBuiltinFunc
contextopt.SequenceOperatorPropReader
contextopt.SessionVarsPropReader
}

func (b *builtinNextValSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SequenceOperatorPropReader.RequiredOptionalEvalProps() |
b.SessionVarsPropReader.RequiredOptionalEvalProps()
}

func (b *builtinNextValSig) Clone() builtinFunc {
Expand All @@ -1214,22 +1247,26 @@ func (b *builtinNextValSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool
db = ctx.CurrentDB()
}
// Check the tableName valid.
sequence, err := util.GetSequenceByName(ctx.GetInfoSchema(), model.NewCIStr(db), model.NewCIStr(seq))
sequence, err := b.GetSequenceOperator(ctx, db, seq)
if err != nil {
return 0, false, err
}
// Get session vars
vars, err := b.GetSessionVars(ctx)
if err != nil {
return 0, false, err
}
// Do the privilege check.
checker := privilege.GetPrivilegeManager(ctx)
user := ctx.GetSessionVars().User
if checker != nil && !checker.RequestVerification(ctx.GetSessionVars().ActiveRoles, db, seq, "", mysql.InsertPriv) {
user := vars.User
if !ctx.RequestVerification(db, seq, "", mysql.InsertPriv) {
return 0, false, errSequenceAccessDenied.GenWithStackByArgs("INSERT", user.AuthUsername, user.AuthHostname, seq)
}
nextVal, err := sequence.GetSequenceNextVal(ctx, db, seq)
nextVal, err := sequence.GetSequenceNextVal()
if err != nil {
return 0, false, err
}
// update the sequenceState.
ctx.GetSessionVars().SequenceState.UpdateState(sequence.GetSequenceID(), nextVal)
vars.SequenceState.UpdateState(sequence.GetSequenceID(), nextVal)
return nextVal, false, nil
}

Expand All @@ -1245,13 +1282,20 @@ func (c *lastValFunctionClass) getFunction(ctx BuildContext, args []Expression)
if err != nil {
return nil, err
}
sig := &builtinLastValSig{bf}
sig := &builtinLastValSig{baseBuiltinFunc: bf}
bf.tp.SetFlen(10)
return sig, nil
}

type builtinLastValSig struct {
baseBuiltinFunc
contextopt.SequenceOperatorPropReader
contextopt.SessionVarsPropReader
}

func (b *builtinLastValSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SequenceOperatorPropReader.RequiredOptionalEvalProps() |
b.SessionVarsPropReader.RequiredOptionalEvalProps()
}

func (b *builtinLastValSig) Clone() builtinFunc {
Expand All @@ -1270,17 +1314,21 @@ func (b *builtinLastValSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool
db = ctx.CurrentDB()
}
// Check the tableName valid.
sequence, err := util.GetSequenceByName(ctx.GetInfoSchema(), model.NewCIStr(db), model.NewCIStr(seq))
sequence, err := b.GetSequenceOperator(ctx, db, seq)
if err != nil {
return 0, false, err
}
// Get session vars
vars, err := b.GetSessionVars(ctx)
if err != nil {
return 0, false, err
}
// Do the privilege check.
checker := privilege.GetPrivilegeManager(ctx)
user := ctx.GetSessionVars().User
if checker != nil && !checker.RequestVerification(ctx.GetSessionVars().ActiveRoles, db, seq, "", mysql.SelectPriv) {
user := vars.User
if !ctx.RequestVerification(db, seq, "", mysql.SelectPriv) {
return 0, false, errSequenceAccessDenied.GenWithStackByArgs("SELECT", user.AuthUsername, user.AuthHostname, seq)
}
return ctx.GetSessionVars().SequenceState.GetLastValue(sequence.GetSequenceID())
return vars.SequenceState.GetLastValue(sequence.GetSequenceID())
}

type setValFunctionClass struct {
Expand All @@ -1295,13 +1343,20 @@ func (c *setValFunctionClass) getFunction(ctx BuildContext, args []Expression) (
if err != nil {
return nil, err
}
sig := &builtinSetValSig{bf}
sig := &builtinSetValSig{baseBuiltinFunc: bf}
bf.tp.SetFlen(args[1].GetType().GetFlen())
return sig, nil
}

type builtinSetValSig struct {
baseBuiltinFunc
contextopt.SequenceOperatorPropReader
contextopt.SessionVarsPropReader
}

func (b *builtinSetValSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SequenceOperatorPropReader.RequiredOptionalEvalProps() |
b.SequenceOperatorPropReader.RequiredOptionalEvalProps()
}

func (b *builtinSetValSig) Clone() builtinFunc {
Expand All @@ -1320,21 +1375,25 @@ func (b *builtinSetValSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool,
db = ctx.CurrentDB()
}
// Check the tableName valid.
sequence, err := util.GetSequenceByName(ctx.GetInfoSchema(), model.NewCIStr(db), model.NewCIStr(seq))
sequence, err := b.GetSequenceOperator(ctx, db, seq)
if err != nil {
return 0, false, err
}
// Get session vars
vars, err := b.GetSessionVars(ctx)
if err != nil {
return 0, false, err
}
// Do the privilege check.
checker := privilege.GetPrivilegeManager(ctx)
user := ctx.GetSessionVars().User
if checker != nil && !checker.RequestVerification(ctx.GetSessionVars().ActiveRoles, db, seq, "", mysql.InsertPriv) {
user := vars.User
if !ctx.RequestVerification(db, seq, "", mysql.InsertPriv) {
return 0, false, errSequenceAccessDenied.GenWithStackByArgs("INSERT", user.AuthUsername, user.AuthHostname, seq)
}
setValue, isNull, err := b.args[1].EvalInt(ctx, row)
if isNull || err != nil {
return 0, isNull, err
}
return sequence.SetSequenceVal(ctx, setValue, db, seq)
return sequence.SetSequenceVal(setValue)
}

func getSchemaAndSequence(sequenceName string) (string, string) {
Expand Down
Loading

0 comments on commit 639fa00

Please sign in to comment.