Skip to content

Commit

Permalink
expression: add more optional properties for EvalContext (#51725)
Browse files Browse the repository at this point in the history
ref #51477
  • Loading branch information
lcwangchao authored Mar 21, 2024
1 parent 73328e5 commit 21e7939
Show file tree
Hide file tree
Showing 32 changed files with 867 additions and 164 deletions.
4 changes: 2 additions & 2 deletions pkg/executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(ctx AggFuncUpdateContext, pr Pa
}
decimalCount := types.NewDecFromInt(p.count)
finalResult := new(types.MyDecimal)
err := types.DecimalDiv(&p.sum, decimalCount, finalResult, ctx.GetSessionVars().GetDivPrecisionIncrement())
err := types.DecimalDiv(&p.sum, decimalCount, finalResult, ctx.GetDivPrecisionIncrement())
if err != nil {
return err
}
Expand Down Expand Up @@ -285,7 +285,7 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(ctx AggFuncUpdateC
}
decimalCount := types.NewDecFromInt(p.count)
finalResult := new(types.MyDecimal)
err := types.DecimalDiv(&p.sum, decimalCount, finalResult, ctx.GetSessionVars().GetDivPrecisionIncrement())
err := types.DecimalDiv(&p.sum, decimalCount, finalResult, ctx.GetDivPrecisionIncrement())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/aggregation/avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (af *avgFunction) GetResult(evalCtx *AggEvaluateContext) (d types.Datum) {
x := evalCtx.Value.GetMysqlDecimal()
y := types.NewDecFromInt(evalCtx.Count)
to := new(types.MyDecimal)
err := types.DecimalDiv(x, y, to, evalCtx.Ctx.GetSessionVars().GetDivPrecisionIncrement())
err := types.DecimalDiv(x, y, to, evalCtx.Ctx.GetDivPrecisionIncrement())
terror.Log(err)
frac := af.RetTp.GetDecimal()
if frac == -1 {
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ func (s *builtinArithmeticDivideDecimalSig) evalDecimal(ctx EvalContext, row chu
}

c := &types.MyDecimal{}
err = types.DecimalDiv(a, b, c, ctx.GetSessionVars().GetDivPrecisionIncrement())
err = types.DecimalDiv(a, b, c, ctx.GetDivPrecisionIncrement())
if err == types.ErrDivByZero {
return c, true, handleDivisionByZeroError(ctx)
} else if err == types.ErrTruncated {
Expand Down Expand Up @@ -829,7 +829,7 @@ func (s *builtinArithmeticIntDivideDecimalSig) evalInt(ctx EvalContext, row chun
}

c := &types.MyDecimal{}
err = types.DecimalDiv(num[0], num[1], c, ctx.GetSessionVars().GetDivPrecisionIncrement())
err = types.DecimalDiv(num[0], num[1], c, ctx.GetDivPrecisionIncrement())
if err == types.ErrDivByZero {
return 0, true, handleDivisionByZeroError(ctx)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_arithmetic_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (b *builtinArithmeticDivideDecimalSig) vecEvalDecimal(ctx EvalContext, inpu
if result.IsNull(i) {
continue
}
err = types.DecimalDiv(&x[i], &y[i], &to, ctx.GetSessionVars().GetDivPrecisionIncrement())
err = types.DecimalDiv(&x[i], &y[i], &to, ctx.GetDivPrecisionIncrement())
if err == types.ErrDivByZero {
if err = handleDivisionByZeroError(ctx); err != nil {
return err
Expand Down Expand Up @@ -596,7 +596,7 @@ func (b *builtinArithmeticIntDivideDecimalSig) vecEvalInt(ctx EvalContext, input
}

c := &types.MyDecimal{}
err = types.DecimalDiv(&num[0][i], &num[1][i], c, ctx.GetSessionVars().GetDivPrecisionIncrement())
err = types.DecimalDiv(&num[0][i], &num[1][i], c, ctx.GetDivPrecisionIncrement())
if err == types.ErrDivByZero {
if err = handleDivisionByZeroError(ctx); err != nil {
return err
Expand Down
29 changes: 24 additions & 5 deletions pkg/expression/builtin_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import (
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/expression/contextopt"
"github.com/pingcap/tidb/pkg/parser/auth"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
Expand Down Expand Up @@ -1021,12 +1023,14 @@ func (c *validatePasswordStrengthFunctionClass) getFunction(ctx BuildContext, ar
return nil, err
}
bf.tp.SetFlen(21)
sig := &builtinValidatePasswordStrengthSig{bf}
sig := &builtinValidatePasswordStrengthSig{baseBuiltinFunc: bf}
return sig, nil
}

type builtinValidatePasswordStrengthSig struct {
baseBuiltinFunc
contextopt.SessionVarsPropReader
contextopt.CurrentUserPropReader
}

func (b *builtinValidatePasswordStrengthSig) Clone() builtinFunc {
Expand All @@ -1035,10 +1039,25 @@ func (b *builtinValidatePasswordStrengthSig) Clone() builtinFunc {
return newSig
}

// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface.
func (b *builtinValidatePasswordStrengthSig) RequiredOptionalEvalProps() context.OptionalEvalPropKeySet {
return b.SessionVarsPropReader.RequiredOptionalEvalProps() |
b.CurrentUserPropReader.RequiredOptionalEvalProps()
}

// evalInt evals VALIDATE_PASSWORD_STRENGTH(str).
// See https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_validate-password-strength
func (b *builtinValidatePasswordStrengthSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) {
globalVars := ctx.GetSessionVars().GlobalVarsAccessor
user, err := b.CurrentUser(ctx)
if err != nil {
return 0, true, err
}

vars, err := b.GetSessionVars(ctx)
if err != nil {
return 0, true, err
}
globalVars := vars.GlobalVarsAccessor
str, isNull, err := b.args[0].EvalString(ctx, row)
if err != nil || isNull {
return 0, true, err
Expand All @@ -1050,11 +1069,11 @@ func (b *builtinValidatePasswordStrengthSig) evalInt(ctx EvalContext, row chunk.
} else if !variable.TiDBOptOn(validation) {
return 0, false, nil
}
return b.validateStr(ctx, str, &globalVars)
return b.validateStr(str, user, &globalVars)
}

func (b *builtinValidatePasswordStrengthSig) validateStr(ctx EvalContext, str string, globalVars *variable.GlobalVarAccessor) (int64, bool, error) {
if warn, err := pwdValidator.ValidateUserNameInPassword(str, ctx.GetSessionVars()); err != nil {
func (b *builtinValidatePasswordStrengthSig) validateStr(str string, user *auth.UserIdentity, globalVars *variable.GlobalVarAccessor) (int64, bool, error) {
if warn, err := pwdValidator.ValidateUserNameInPassword(str, user, globalVars); err != nil {
return 0, true, err
} else if len(warn) > 0 {
return 0, false, nil
Expand Down
14 changes: 12 additions & 2 deletions pkg/expression/builtin_encryption_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,16 @@ func (b *builtinValidatePasswordStrengthSig) vectorized() bool {
}

func (b *builtinValidatePasswordStrengthSig) vecEvalInt(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
user, err := b.CurrentUser(ctx)
if err != nil {
return err
}

vars, err := b.GetSessionVars(ctx)
if err != nil {
return err
}

n := input.NumRows()
buf, err := b.bufAllocator.get()
if err != nil {
Expand All @@ -888,7 +898,7 @@ func (b *builtinValidatePasswordStrengthSig) vecEvalInt(ctx EvalContext, input *
result.ResizeInt64(n, false)
result.MergeNulls(buf)
i64s := result.Int64s()
globalVars := ctx.GetSessionVars().GlobalVarsAccessor
globalVars := vars.GlobalVarsAccessor
enableValidation := false
validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable)
if err != nil {
Expand All @@ -901,7 +911,7 @@ func (b *builtinValidatePasswordStrengthSig) vecEvalInt(ctx EvalContext, input *
}
if !enableValidation {
i64s[i] = 0
} else if score, isNull, err := b.validateStr(ctx, buf.GetString(i), &globalVars); err != nil {
} else if score, isNull, err := b.validateStr(buf.GetString(i), user, &globalVars); err != nil {
return err
} else if !isNull {
i64s[i] = score
Expand Down
Loading

0 comments on commit 21e7939

Please sign in to comment.