Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: check decimal scale and precision before casting #11644

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,17 @@ func (b *builtinCastIntAsRealSig) evalReal(row chunk.Row) (res float64, isNull b
return res, false, err
}

// checkDecimalSize checks the scale and precision limit when casting to a decimal
func checkDecimalSize(tp *types.FieldType, val string) error {
if tp.Flen > mysql.MaxDecimalWidth {
return types.ErrTooBigPrecision.GenWithStackByArgs(tp.Flen, val, mysql.MaxDecimalWidth)
}
if tp.Decimal > mysql.MaxDecimalScale {
return types.ErrTooBigScale.GenWithStackByArgs(tp.Decimal, val, mysql.MaxDecimalScale)
}
return nil
}

type builtinCastIntAsDecimalSig struct {
baseBuiltinCastFunc
}
Expand Down Expand Up @@ -499,6 +510,10 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyDe
}
res = types.NewDecFromUint(uVal)
}
err = checkDecimalSize(b.tp, strconv.FormatInt(val, 10))
if err != nil {
return nil, false, err
}
res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx)
return res, isNull, err
}
Expand Down Expand Up @@ -877,6 +892,10 @@ func (b *builtinCastDecimalAsDecimalSig) evalDecimal(row chunk.Row) (res *types.
if isNull || err != nil {
return res, isNull, err
}
err = checkDecimalSize(b.tp, evalDecimal.String())
if err != nil {
return nil, false, err
}
res = &types.MyDecimal{}
if !(b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) && evalDecimal.IsNegative()) {
*res = *evalDecimal
Expand Down Expand Up @@ -1172,10 +1191,20 @@ func (b *builtinCastStringAsDecimalSig) evalDecimal(row chunk.Row) (res *types.M
res = new(types.MyDecimal)
sc := b.ctx.GetSessionVars().StmtCtx
if !(b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) && res.IsNegative()) {
err = sc.HandleTruncate(res.FromString([]byte(val)))
parseErr := res.FromString([]byte(val))
err = sc.HandleTruncate(parseErr)
if err != nil {
return res, false, err
}
if terror.ErrorEqual(parseErr, types.ErrBadNumber) {
return res, false, nil
}
}
// If cast a string to decimal with ErrBadNumber, returns zero decimal directly
// so check the decimal limit after parsing decimal from string
err = checkDecimalSize(b.tp, val)
if err != nil {
return nil, false, err
}
res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, sc)
return res, false, err
Expand Down
61 changes: 61 additions & 0 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1375,3 +1375,64 @@ func (s *testEvaluatorSuite) TestWrapWithCastAsJSON(c *C) {
c.Assert(ok, IsTrue)
c.Assert(output, Equals, input)
}

func (s *testEvaluatorSuite) TestCastToDecimalError(c *C) {
var sig builtinFunc
ctx, _ := s.ctx, s.ctx.GetSessionVars().StmtCtx
widthCases := []struct {
flen int
decimal int
err error
}{
{mysql.MaxDecimalWidth, mysql.MaxDecimalScale + 1, types.ErrTooBigScale},
{mysql.MaxDecimalWidth + 1, mysql.MaxDecimalScale, types.ErrTooBigPrecision},
{mysql.MaxDecimalWidth, mysql.MaxDecimalScale, nil},
}
castToDecCases := []struct {
before *Column
row chunk.MutRow
}{
// cast int as decimal.
{
&Column{RetType: types.NewFieldType(mysql.TypeLonglong), Index: 0},
chunk.MutRowFromDatums([]types.Datum{types.NewIntDatum(1234)}),
},
// cast string as decimal.
{
&Column{RetType: types.NewFieldType(mysql.TypeString), Index: 0},
chunk.MutRowFromDatums([]types.Datum{types.NewStringDatum("1234")}),
},
// cast decimal as decimal.
{
&Column{RetType: types.NewFieldType(mysql.TypeNewDecimal), Index: 0},
chunk.MutRowFromDatums([]types.Datum{types.NewDecimalDatum(types.NewDecFromStringForTest("1234"))}),
},
}

for _, width := range widthCases {
for i, t := range castToDecCases {
args := []Expression{t.before}
tp := types.NewFieldType(mysql.TypeNewDecimal)
tp.Flen, tp.Decimal = width.flen, width.decimal
tp.Charset = charset.CharsetUTF8
decFunc := newBaseBuiltinCastFunc(newBaseBuiltinFunc(ctx, args), false)
decFunc.tp = tp
switch i {
case 0:
sig = &builtinCastIntAsDecimalSig{decFunc}
case 1:
sig = &builtinCastStringAsDecimalSig{decFunc}
case 2:
sig = &builtinCastDecimalAsDecimalSig{decFunc}
}
_, isNull, err := sig.evalDecimal(t.row.ToRow())
c.Assert(isNull, Equals, false)
if width.err != nil {
c.Assert(err, NotNil)
c.Assert(terror.ErrorEqual(err, width.err), IsTrue)
} else {
c.Assert(err, IsNil)
}
}
}
}