diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 366c48def1b80..b99fd57d09b5e 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -471,6 +471,17 @@ func (b *builtinCastIntAsRealSig) evalReal(row chunk.Row) (res float64, isNull b return res, false, err } +// checkDecimalSize checks the scale and precision limit when casting to a decimal +func checkDecimalSize(tp *types.FieldType, val string) error { + if tp.Flen > mysql.MaxDecimalWidth { + return types.ErrTooBigPrecision.GenWithStackByArgs(tp.Flen, val, mysql.MaxDecimalWidth) + } + if tp.Decimal > mysql.MaxDecimalScale { + return types.ErrTooBigScale.GenWithStackByArgs(tp.Decimal, val, mysql.MaxDecimalScale) + } + return nil +} + type builtinCastIntAsDecimalSig struct { baseBuiltinCastFunc } @@ -499,6 +510,10 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyDe } res = types.NewDecFromUint(uVal) } + err = checkDecimalSize(b.tp, strconv.FormatInt(val, 10)) + if err != nil { + return nil, false, err + } res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx) return res, isNull, err } @@ -877,6 +892,10 @@ func (b *builtinCastDecimalAsDecimalSig) evalDecimal(row chunk.Row) (res *types. if isNull || err != nil { return res, isNull, err } + err = checkDecimalSize(b.tp, evalDecimal.String()) + if err != nil { + return nil, false, err + } res = &types.MyDecimal{} if !(b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) && evalDecimal.IsNegative()) { *res = *evalDecimal @@ -1172,10 +1191,20 @@ func (b *builtinCastStringAsDecimalSig) evalDecimal(row chunk.Row) (res *types.M res = new(types.MyDecimal) sc := b.ctx.GetSessionVars().StmtCtx if !(b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) && res.IsNegative()) { - err = sc.HandleTruncate(res.FromString([]byte(val))) + parseErr := res.FromString([]byte(val)) + err = sc.HandleTruncate(parseErr) if err != nil { return res, false, err } + if terror.ErrorEqual(parseErr, types.ErrBadNumber) { + return res, false, nil + } + } + // If cast a string to decimal with ErrBadNumber, returns zero decimal directly + // so check the decimal limit after parsing decimal from string + err = checkDecimalSize(b.tp, val) + if err != nil { + return nil, false, err } res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, sc) return res, false, err diff --git a/expression/builtin_cast_test.go b/expression/builtin_cast_test.go index c0636e4b34a04..630018bdeb995 100644 --- a/expression/builtin_cast_test.go +++ b/expression/builtin_cast_test.go @@ -1375,3 +1375,64 @@ func (s *testEvaluatorSuite) TestWrapWithCastAsJSON(c *C) { c.Assert(ok, IsTrue) c.Assert(output, Equals, input) } + +func (s *testEvaluatorSuite) TestCastToDecimalError(c *C) { + var sig builtinFunc + ctx, _ := s.ctx, s.ctx.GetSessionVars().StmtCtx + widthCases := []struct { + flen int + decimal int + err error + }{ + {mysql.MaxDecimalWidth, mysql.MaxDecimalScale + 1, types.ErrTooBigScale}, + {mysql.MaxDecimalWidth + 1, mysql.MaxDecimalScale, types.ErrTooBigPrecision}, + {mysql.MaxDecimalWidth, mysql.MaxDecimalScale, nil}, + } + castToDecCases := []struct { + before *Column + row chunk.MutRow + }{ + // cast int as decimal. + { + &Column{RetType: types.NewFieldType(mysql.TypeLonglong), Index: 0}, + chunk.MutRowFromDatums([]types.Datum{types.NewIntDatum(1234)}), + }, + // cast string as decimal. + { + &Column{RetType: types.NewFieldType(mysql.TypeString), Index: 0}, + chunk.MutRowFromDatums([]types.Datum{types.NewStringDatum("1234")}), + }, + // cast decimal as decimal. + { + &Column{RetType: types.NewFieldType(mysql.TypeNewDecimal), Index: 0}, + chunk.MutRowFromDatums([]types.Datum{types.NewDecimalDatum(types.NewDecFromStringForTest("1234"))}), + }, + } + + for _, width := range widthCases { + for i, t := range castToDecCases { + args := []Expression{t.before} + tp := types.NewFieldType(mysql.TypeNewDecimal) + tp.Flen, tp.Decimal = width.flen, width.decimal + tp.Charset = charset.CharsetUTF8 + decFunc := newBaseBuiltinCastFunc(newBaseBuiltinFunc(ctx, args), false) + decFunc.tp = tp + switch i { + case 0: + sig = &builtinCastIntAsDecimalSig{decFunc} + case 1: + sig = &builtinCastStringAsDecimalSig{decFunc} + case 2: + sig = &builtinCastDecimalAsDecimalSig{decFunc} + } + _, isNull, err := sig.evalDecimal(t.row.ToRow()) + c.Assert(isNull, Equals, false) + if width.err != nil { + c.Assert(err, NotNil) + c.Assert(terror.ErrorEqual(err, width.err), IsTrue) + } else { + c.Assert(err, IsNil) + } + } + } +}