diff --git a/types/convert.go b/types/convert.go index ca59c6620b095..c86b342108704 100644 --- a/types/convert.go +++ b/types/convert.go @@ -369,23 +369,30 @@ func getValidIntPrefix(sc *stmtctx.StatementContext, str string) (string, error) return floatStrToIntStr(sc, floatPrefix, str) } -// roundIntStr is to round int string base on the number following dot. +// roundIntStr is to round a **valid int string** base on the number following dot. func roundIntStr(numNextDot byte, intStr string) string { if numNextDot < '5' { return intStr } retStr := []byte(intStr) - for i := len(intStr) - 1; i >= 0; i-- { - if retStr[i] != '9' { - retStr[i]++ + idx := len(intStr) - 1 + for ; idx >= 1; idx-- { + if retStr[idx] != '9' { + retStr[idx]++ break } - if i == 0 { - retStr[i] = '1' + retStr[idx] = '0' + } + if idx == 0 { + if intStr[0] == '9' { + retStr[0] = '1' + retStr = append(retStr, '0') + } else if isDigit(intStr[0]) { + retStr[0]++ + } else { + retStr[1] = '1' retStr = append(retStr, '0') - break } - retStr[i] = '0' } return string(retStr) } @@ -429,6 +436,7 @@ func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr st } return intStr, nil } + // intCnt and digits contain the prefix `+/-` if validFloat[0] is `+/-` var intCnt int digits := make([]byte, 0, len(validFloat)) if dotIdx == -1 { @@ -451,7 +459,7 @@ func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr st intCnt += exp if intCnt <= 0 { intStr = "0" - if intCnt == 0 && len(digits) > 0 { + if intCnt == 0 && len(digits) > 0 && isDigit(digits[0]) { intStr = roundIntStr(digits[0], intStr) } return intStr, nil diff --git a/types/convert_test.go b/types/convert_test.go index 01015cfdeb407..a14371b88e7a5 100644 --- a/types/convert_test.go +++ b/types/convert_test.go @@ -666,6 +666,21 @@ func (s *testTypeConvertSuite) TestConvert(c *C) { signedAccept(c, mysql.TypeNewDecimal, dec, "-0.00123") } +func (s *testTypeConvertSuite) TestRoundIntStr(c *C) { + cases := []struct { + a string + b byte + c string + }{ + {"+999", '5', "+1000"}, + {"999", '5', "1000"}, + {"-999", '5', "-1000"}, + } + for _, cc := range cases { + c.Assert(roundIntStr(cc.b, cc.a), Equals, cc.c) + } +} + func (s *testTypeConvertSuite) TestGetValidFloat(c *C) { tests := []struct { origin string @@ -693,15 +708,31 @@ func (s *testTypeConvertSuite) TestGetValidFloat(c *C) { _, err := strconv.ParseFloat(prefix, 64) c.Assert(err, IsNil) } - floatStr, err := floatStrToIntStr(sc, "1e9223372036854775807", "1e9223372036854775807") - c.Assert(err, IsNil) - c.Assert(floatStr, Equals, "1") - floatStr, err = floatStrToIntStr(sc, "125e342", "125e342.83") - c.Assert(err, IsNil) - c.Assert(floatStr, Equals, "125") - floatStr, err = floatStrToIntStr(sc, "1e21", "1e21") - c.Assert(err, IsNil) - c.Assert(floatStr, Equals, "1") + + tests2 := []struct { + origin string + expected string + }{ + {"1e9223372036854775807", "1"}, + {"125e342", "125"}, + {"1e21", "1"}, + {"1e5", "100000"}, + {"-123.45678e5", "-12345678"}, + {"+0.5", "1"}, + {"-0.5", "-1"}, + {".5e0", "1"}, + {"+.5e0", "+1"}, + {"-.5e0", "-1"}, + {".5", "1"}, + {"123.456789e5", "12345679"}, + {"123.456784e5", "12345678"}, + {"+999.9999e2", "+100000"}, + } + for _, t := range tests2 { + str, err := floatStrToIntStr(sc, t.origin, t.origin) + c.Assert(err, IsNil) + c.Assert(str, Equals, t.expected, Commentf("%v, %v", t.origin, t.expected)) + } } // TestConvertTime tests time related conversion.