From 0e946409662698980c7dae677b448b81e08e0e5a Mon Sep 17 00:00:00 2001 From: jennifersp Date: Wed, 7 Sep 2022 16:06:59 -0700 Subject: [PATCH 01/21] use decimal for arithmetic functions --- sql/expression/arithmetic.go | 88 ++++++++++++++++++- sql/expression/arithmetic_test.go | 80 ++++++++++------- .../function/aggregation/unary_agg_buffers.go | 47 ++++++++-- .../function/aggregation/unary_aggs.og.go | 1 + sql/parse/parse.go | 10 +-- sql/parse/parse_test.go | 4 +- 6 files changed, 179 insertions(+), 51 deletions(-) diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 20cfa43544..1421961c4f 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/shopspring/decimal" "reflect" "strings" "time" @@ -151,7 +152,7 @@ func (a *Arithmetic) Type() sql.Type { return sql.Int64 } - return sql.Float64 + return a.getDecimalType() case sqlparser.ShiftLeftStr, sqlparser.ShiftRightStr: return sql.Uint64 @@ -163,7 +164,62 @@ func (a *Arithmetic) Type() sql.Type { return sql.Int64 } - return sql.Float64 + return a.getDecimalType() +} + +// getArithmeticType returns column type if +func (a *Arithmetic) getDecimalType() sql.Type { + var resType sql.Type + var precision int + var scale int + sql.Inspect(a, func(expr sql.Expression) bool { + switch c := expr.(type) { + case *GetField: + resType = c.Type() + return false + case *Literal: + val, err := c.Eval(nil, nil) + if err != nil { + return false + } + var v string + switch val.(type) { + case float64: + v = fmt.Sprintf("%f", val) + default: + v = fmt.Sprintf("%v", val) + } + p, s := GetDecimalPrecisionAndScale(v) + if p > precision { + precision = p + } + if s > scale { + scale = s + } + return true + } + return true + }) + + if resType == nil { + r, err := sql.CreateDecimalType(uint8(precision), uint8(scale)) + if err != nil { + return sql.Float64 + } + resType = r + } + + return resType +} + +func GetDecimalPrecisionAndScale(val string) (int, int) { + scale := 0 + precScale := strings.Split(strings.TrimPrefix(val, "-"), ".") + if len(precScale) != 1 { + scale = len(precScale[1]) + } + precision := len((precScale)[0]) + scale + return precision, scale } func isInterval(expr sql.Expression) bool { @@ -298,6 +354,11 @@ func plus(lval, rval interface{}) (interface{}, error) { case float64: return l + r, nil } + case decimal.Decimal: + switch r := rval.(type) { + case decimal.Decimal: + return l.Add(r), nil + } case time.Time: switch r := rval.(type) { case *TimeDelta: @@ -334,6 +395,11 @@ func minus(lval, rval interface{}) (interface{}, error) { case float64: return l - r, nil } + case decimal.Decimal: + switch r := rval.(type) { + case decimal.Decimal: + return l.Sub(r), nil + } case time.Time: switch r := rval.(type) { case *TimeDelta: @@ -365,6 +431,11 @@ func mult(lval, rval interface{}) (interface{}, error) { case float64: return l * r, nil } + case decimal.Decimal: + switch r := rval.(type) { + case decimal.Decimal: + return l.Mul(r).BigInt(), nil + } } return nil, errUnableToCast.New(lval, rval) @@ -398,6 +469,15 @@ func div(lval, rval interface{}) (interface{}, error) { } return l / r, nil } + case decimal.Decimal: + switch r := rval.(type) { + case decimal.Decimal: + if r.String() == "0" { + return nil, nil + } + exp := (l.Exponent() * -1) + 4 + return l.DivRound(r, exp), nil + } } return nil, errUnableToCast.New(lval, rval) @@ -551,7 +631,7 @@ func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } if !sql.IsNumber(e.Child.Type()) { - child, err = sql.Float64.Convert(child) + child, err = decimal.NewFromString(fmt.Sprintf("%v", child)) if err != nil { child = 0.0 } @@ -582,6 +662,8 @@ func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return -int32(n), nil case uint64: return -int64(n), nil + case decimal.Decimal: + return n.Neg(), err default: return nil, sql.ErrInvalidType.New(reflect.TypeOf(n)) } diff --git a/sql/expression/arithmetic_test.go b/sql/expression/arithmetic_test.go index 1602ef7557..902c8f8179 100644 --- a/sql/expression/arithmetic_test.go +++ b/sql/expression/arithmetic_test.go @@ -15,6 +15,8 @@ package expression import ( + "fmt" + "github.com/shopspring/decimal" "testing" "time" @@ -28,12 +30,12 @@ func TestPlus(t *testing.T) { var testCases = []struct { name string left, right float64 - expected float64 + expected string }{ - {"1 + 1", 1, 1, 2}, - {"-1 + 1", -1, 1, 0}, - {"0 + 0", 0, 0, 0}, - {"0.14159 + 3.0", 0.14159, 3.0, float64(0.14159) + float64(3)}, + {"1 + 1", 1, 1, "2"}, + {"-1 + 1", -1, 1, "0"}, + {"0 + 0", 0, 0, "0"}, + {"0.14159 + 3.0", 0.14159, 3.0, "3.14159"}, } for _, tt := range testCases { @@ -44,7 +46,11 @@ func TestPlus(t *testing.T) { NewLiteral(tt.right, sql.Float64), ).Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - require.Equal(tt.expected, result) + if d, ok := result.(decimal.Decimal); ok { + require.Equal(tt.expected, d.StringFixed(d.Exponent()*-1)) + } else { + require.Equal("0", fmt.Sprintf("%v", result)) + } }) } @@ -52,7 +58,7 @@ func TestPlus(t *testing.T) { result, err := NewPlus(NewLiteral("2", sql.LongText), NewLiteral(3, sql.Float64)). Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - require.Equal(float64(5), result) + require.Equal("5", fmt.Sprintf("%v", result)) } func TestPlusInterval(t *testing.T) { @@ -82,12 +88,12 @@ func TestMinus(t *testing.T) { var testCases = []struct { name string left, right float64 - expected float64 + expected string }{ - {"1 - 1", 1, 1, 0}, - {"1 - 1", 1, 1, 0}, - {"0 - 0", 0, 0, 0}, - {"3.14159 - 3.0", 3.14159, 3.0, float64(3.14159) - float64(3.0)}, + {"1 - 1", 1, 1, "0"}, + {"1 - 1", 1, 1, "0"}, + {"0 - 0", 0, 0, "0"}, + {"3.14159 - 3.0", 3.14159, 3.0, "0.14159"}, } for _, tt := range testCases { @@ -98,7 +104,11 @@ func TestMinus(t *testing.T) { NewLiteral(tt.right, sql.Float64), ).Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - require.Equal(tt.expected, result) + if d, ok := result.(decimal.Decimal); ok { + require.Equal(tt.expected, d.StringFixed(d.Exponent()*-1)) + } else { + require.Equal("0", fmt.Sprintf("%v", result)) + } }) } @@ -106,7 +116,7 @@ func TestMinus(t *testing.T) { result, err := NewMinus(NewLiteral("10", sql.LongText), NewLiteral(10, sql.Int64)). Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - require.Equal(float64(0), result) + require.Equal("0", fmt.Sprintf("%v", result)) } func TestMinusInterval(t *testing.T) { @@ -127,12 +137,12 @@ func TestMult(t *testing.T) { var testCases = []struct { name string left, right float64 - expected float64 + expected string }{ - {"1 * 1", 1, 1, 1}, - {"-1 * 1", -1, 1, -1}, - {"0 * 0", 0, 0, 0}, - {"3.14159 * 3.0", 3.14159, 3.0, float64(3.14159) * float64(3.0)}, + {"1 * 1", 1, 1, "1"}, + {"-1 * 1", -1, 1, "-1"}, + {"0 * 0", 0, 0, "0"}, + {"3.14159 * 3.0", 3.14159, 3.1, "9.738929"}, } for _, tt := range testCases { @@ -143,31 +153,35 @@ func TestMult(t *testing.T) { NewLiteral(tt.right, sql.Float64), ).Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - require.Equal(tt.expected, result) + if d, ok := result.(decimal.Decimal); ok { + require.Equal(tt.expected, d.StringFixed(d.Exponent()*-1)) + } else { + require.Equal("0", fmt.Sprintf("%v", result)) + } }) } require := require.New(t) - result, err := NewMult(NewLiteral("10", sql.LongText), NewLiteral("10", sql.LongText)). + result, err := NewMult(NewLiteral("10", sql.LongText), NewLiteral("3.0", sql.LongText)). Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - require.Equal(float64(100), result) + require.Equal("30.0", result.(decimal.Decimal).StringFixed(result.(decimal.Decimal).Exponent()*-1)) } func TestDiv(t *testing.T) { var floatTestCases = []struct { name string left, right float64 - expected float64 + expected string null bool }{ - {"1 / 1", 1, 1, 1, false}, - {"-1 / 1", -1, 1, -1, false}, - {"0 / 1234567890", 0, 12345677890, 0, false}, - {"3.14159 / 3.0", 3.14159, 3.0, float64(3.14159) / float64(3.0), false}, - {"1/0", 1, 0, 0, true}, - {"-1/0", -1, 0, 0, true}, - {"0/0", 0, 0, 0, true}, + {"1 / 1", 1, 1, "1.0000", false}, + {"-1 / 1", -1, 1, "-1.0000", false}, + {"0 / 1234567890", 0, 12345677890, "0.0000", false}, + {"3.14159 / 3.0", 3.14159, 3.0, "1.047196667", false}, + {"1/0", 1, 0, "", true}, + {"-1/0", -1, 0, "", true}, + {"0/0", 0, 0, "", true}, } for _, tt := range floatTestCases { @@ -180,7 +194,11 @@ func TestDiv(t *testing.T) { if tt.null { assert.Equal(t, nil, result) } else { - assert.Equal(t, tt.expected, result) + if d, ok := result.(decimal.Decimal); ok { + require.Equal(t, tt.expected, d.StringFixed(d.Exponent()*-1)) + } else { + require.Equal(t, tt.expected, fmt.Sprintf("%v", result)) + } } }) } diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 941a30ca35..53d13d618b 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -5,6 +5,7 @@ import ( "reflect" "github.com/mitchellh/hashstructure" + "github.com/shopspring/decimal" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" @@ -12,12 +13,12 @@ import ( type sumBuffer struct { isnil bool - sum float64 + sum interface{} expr sql.Expression } func NewSumBuffer(child sql.Expression) *sumBuffer { - return &sumBuffer{true, float64(0), child} + return &sumBuffer{true, decimal.NewFromInt(0), child} } // Update implements the AggregationBuffer interface. @@ -31,17 +32,45 @@ func (m *sumBuffer) Update(ctx *sql.Context, row sql.Row) error { return nil } - val, err := sql.Float64.Convert(v) - if err != nil { - val = float64(0) - } - if m.isnil { - m.sum = 0 + m.sum = decimal.NewFromInt(0) m.isnil = false } - m.sum += val.(float64) + switch n := v.(type) { + case float64: + val, err := sql.Float64.Convert(n) + if err != nil { + val = float64(0) + } + if m.isnil { + m.sum = 0 + m.isnil = false + } + sum, err := sql.Float64.Convert(m.sum) + if err != nil { + sum = float64(0) + } + m.sum = sum.(float64) + val.(float64) + case decimal.Decimal: + if sum, ok := m.sum.(decimal.Decimal); ok { + m.sum = sum.Add(n) + } else { + m.sum = n + } + case string: + p, s := expression.GetDecimalPrecisionAndScale(n) + dt, err := sql.CreateDecimalType(uint8(p), uint8(s)) + val, err := dt.Convert(v) + if err != nil { + val = decimal.NewFromInt(0) + } + if sum, ok := m.sum.(decimal.Decimal); ok { + m.sum = sum.Add(val.(decimal.Decimal)) + } else { + m.sum = val.(decimal.Decimal) + } + } return nil } diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index 04a7cf3df4..58e6dcb6be 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -589,6 +589,7 @@ func NewSum(e sql.Expression) *Sum { } func (a *Sum) Type() sql.Type { + // TODO the type depends on column if functions over table, or the input given (most likely to be DECIMAL with the longest precision and scaled input defines it) return sql.Float64 } diff --git a/sql/parse/parse.go b/sql/parse/parse.go index e611a89d16..532543a39a 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -3381,12 +3381,10 @@ func convertVal(ctx *sql.Context, v *sqlparser.SQLVal) (sql.Expression, error) { // use the value as string format to keep precision and scale as defined for DECIMAL data type to avoid rounded up float64 value if ps := strings.Split(string(v.Val), "."); len(ps) == 2 { - if scale, err := strconv.ParseUint(ps[1], 10, 64); err != nil || scale > 0 { - ogVal := string(v.Val) - floatVal := fmt.Sprintf("%v", val) - if len(ogVal) >= len(floatVal) && ogVal != floatVal { - return expression.NewLiteral(string(v.Val), sql.CreateLongText(ctx.GetCollation())), nil - } + ogVal := string(v.Val) + floatVal := fmt.Sprintf("%v", val) + if len(ogVal) >= len(floatVal) && ogVal != floatVal { + return expression.NewLiteral(string(v.Val), sql.CreateLongText(ctx.GetCollation())), nil } } diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 139c90c035..2ad86d1315 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -2235,8 +2235,8 @@ CREATE TABLE t2 []sql.Expression{ expression.NewAlias("1.0 * a + 2.0 * b", expression.NewPlus( - expression.NewMult(expression.NewLiteral(float64(1.0), sql.Float64), expression.NewUnresolvedColumn("a")), - expression.NewMult(expression.NewLiteral(float64(2.0), sql.Float64), expression.NewUnresolvedColumn("b")), + expression.NewMult(expression.NewLiteral("1.0", sql.LongText), expression.NewUnresolvedColumn("a")), + expression.NewMult(expression.NewLiteral("2.0", sql.LongText), expression.NewUnresolvedColumn("b")), ), ), }, From dbd4aedd0f90fec3a0649e8c6e49f2d82f26cb4d Mon Sep 17 00:00:00 2001 From: jennifersp Date: Wed, 7 Sep 2022 23:09:16 +0000 Subject: [PATCH 02/21] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/arithmetic.go | 2 +- sql/expression/arithmetic_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 1421961c4f..7d0d88d6a4 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -16,12 +16,12 @@ package expression import ( "fmt" - "github.com/shopspring/decimal" "reflect" "strings" "time" "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/shopspring/decimal" errors "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" diff --git a/sql/expression/arithmetic_test.go b/sql/expression/arithmetic_test.go index 902c8f8179..8f8382ece5 100644 --- a/sql/expression/arithmetic_test.go +++ b/sql/expression/arithmetic_test.go @@ -16,10 +16,10 @@ package expression import ( "fmt" - "github.com/shopspring/decimal" "testing" "time" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" From 52e9ad1368571314e67374a5a60d2e471afc450b Mon Sep 17 00:00:00 2001 From: jennifersp Date: Thu, 8 Sep 2022 13:17:42 -0700 Subject: [PATCH 03/21] use decimal type only if literal is given or column type is decimal --- enginetest/evaluation.go | 10 + enginetest/queries/queries.go | 4 +- enginetest/queries/variable_queries.go | 3 +- sql/decimal.go | 14 ++ sql/expression/arithmetic.go | 204 ++++++++++++++++-- .../function/aggregation/unary_agg_buffers.go | 47 ++-- sql/expression/function/ceil_round_floor.go | 7 + sql/type.go | 4 +- 8 files changed, 248 insertions(+), 45 deletions(-) diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index 3852fad4f4..0c7dee3353 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -16,6 +16,7 @@ package enginetest import ( "fmt" + "github.com/shopspring/decimal" "strings" "testing" "time" @@ -433,6 +434,15 @@ func checkResults( } } } + if strings.HasPrefix(upperQuery, "SELECT ") || strings.HasPrefix(upperQuery, "WITH ") { + for _, widenedRow := range widenedRows { + for i, val := range widenedRow { + if d, ok := val.(decimal.Decimal); ok { + widenedRow[i] = d.StringFixed(d.Exponent() * -1) + } + } + } + } // .Equal gives better error messages than .ElementsMatch, so use it when possible if orderBy || len(expected) <= 1 { diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index e4b5c88530..bdf9b7dd0c 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -1426,9 +1426,9 @@ var QueryTests = []QueryTest{ }, }, { - Query: "with recursive t (n) as (select sum(1) from dual union all select (2) from dual) select sum(n) from t;", + Query: "with recursive t (n) as (select sum(1) from dual union all select (2.00) from dual) select sum(n) from t;", Expected: []sql.Row{ - {float64(3)}, + {"3.00"}, }, }, { diff --git a/enginetest/queries/variable_queries.go b/enginetest/queries/variable_queries.go index 87c51219b6..5c35b77ee3 100644 --- a/enginetest/queries/variable_queries.go +++ b/enginetest/queries/variable_queries.go @@ -58,9 +58,8 @@ var VariableQueries = []ScriptTest{ Assertions: []ScriptTestAssertion{ { Query: "select @myvar, @@autocommit, @myvar2, @myvar3", - // TODO: unclear why the last var is getting a float type, should be an int Expected: []sql.Row{ - {1, 1, 0, 0.0}, + {1, 1, 0, 0}, }, }, }, diff --git a/sql/decimal.go b/sql/decimal.go index 651652a107..dbafe00791 100644 --- a/sql/decimal.go +++ b/sql/decimal.go @@ -15,9 +15,11 @@ package sql import ( + "encoding/hex" "fmt" "math/big" "reflect" + "strconv" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" @@ -200,6 +202,12 @@ func (t decimalType) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, e res = decimal.NewFromFloat32(value) case float64: res = decimal.NewFromFloat(value) + case []uint8: + val, err := strconv.ParseUint(hex.EncodeToString(value), 16, 64) + if err != nil { + return decimal.NullDecimal{}, err + } + return t.ConvertToNullDecimal(val) case string: var err error res, err = decimal.NewFromString(value) @@ -228,6 +236,8 @@ func (t decimalType) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, e return decimal.NullDecimal{}, nil } res = value.Decimal + case JSONDocument: + return t.ConvertToNullDecimal(value.Val) default: return decimal.NullDecimal{}, ErrConvertingToDecimal.New(v) } @@ -242,6 +252,10 @@ func (t decimalType) BoundsCheck(v decimal.Decimal) (decimal.Decimal, error) { } // TODO add shortcut for common case // ex: certain num of bits fast tracks OK + r := v.StringFixed(v.Exponent() * -1) + l := t.exclusiveUpperBound.StringFixed(t.exclusiveUpperBound.Exponent() * -1) + if r == l { + } if !v.Abs().LessThan(t.exclusiveUpperBound) { return decimal.Decimal{}, ErrConvertToDecimalLimit.New() } diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 7d0d88d6a4..9678a1465c 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -170,13 +170,26 @@ func (a *Arithmetic) Type() sql.Type { // getArithmeticType returns column type if func (a *Arithmetic) getDecimalType() sql.Type { var resType sql.Type - var precision int - var scale int + var precision uint8 + var scale uint8 sql.Inspect(a, func(expr sql.Expression) bool { switch c := expr.(type) { - case *GetField: + case *SystemVar: resType = c.Type() return false + case *GetField: + if sql.IsDecimal(resType) { + resType = c.Type() + dt, _ := resType.(sql.DecimalType) + if dt.Precision() > (precision) { + precision = dt.Precision() + } + if dt.Scale() > scale { + scale = dt.Precision() + } + } else { + resType = sql.Float64 + } case *Literal: val, err := c.Eval(nil, nil) if err != nil { @@ -196,30 +209,30 @@ func (a *Arithmetic) getDecimalType() sql.Type { if s > scale { scale = s } - return true } return true }) - if resType == nil { - r, err := sql.CreateDecimalType(uint8(precision), uint8(scale)) - if err != nil { - return sql.Float64 + if sql.IsDecimal(resType) { + r, err := sql.CreateDecimalType(precision, scale) + if err == nil { + resType = r } - resType = r + } else if resType == nil { + return sql.Float64 } return resType } -func GetDecimalPrecisionAndScale(val string) (int, int) { +func GetDecimalPrecisionAndScale(val string) (uint8, uint8) { scale := 0 precScale := strings.Split(strings.TrimPrefix(val, "-"), ".") if len(precScale) != 1 { scale = len(precScale[1]) } precision := len((precScale)[0]) + scale - return precision, scale + return uint8(precision), uint8(scale) } func isInterval(expr sql.Expression) bool { @@ -337,18 +350,51 @@ func (a *Arithmetic) convertLeftRight(left interface{}, right interface{}) (inte func plus(lval, rval interface{}) (interface{}, error) { switch l := lval.(type) { + case uint8: + switch r := rval.(type) { + case uint8: + return l + r, nil + } + case int8: + switch r := rval.(type) { + case int8: + return l + r, nil + } + case uint16: + switch r := rval.(type) { + case uint16: + return l + r, nil + } + case int16: + switch r := rval.(type) { + case int16: + return l + r, nil + } + case uint32: + switch r := rval.(type) { + case uint32: + return l + r, nil + } + case int32: + switch r := rval.(type) { + case int32: + return l + r, nil + } case uint64: switch r := rval.(type) { case uint64: return l + r, nil } - case int64: switch r := rval.(type) { case int64: return l + r, nil } - + case float32: + switch r := rval.(type) { + case float32: + return l + r, nil + } case float64: switch r := rval.(type) { case float64: @@ -378,18 +424,51 @@ func plus(lval, rval interface{}) (interface{}, error) { func minus(lval, rval interface{}) (interface{}, error) { switch l := lval.(type) { + case uint8: + switch r := rval.(type) { + case uint8: + return l - r, nil + } + case int8: + switch r := rval.(type) { + case int8: + return l - r, nil + } + case uint16: + switch r := rval.(type) { + case uint16: + return l - r, nil + } + case int16: + switch r := rval.(type) { + case int16: + return l - r, nil + } + case uint32: + switch r := rval.(type) { + case uint32: + return l - r, nil + } + case int32: + switch r := rval.(type) { + case int32: + return l - r, nil + } case uint64: switch r := rval.(type) { case uint64: return l - r, nil } - case int64: switch r := rval.(type) { case int64: return l - r, nil } - + case float32: + switch r := rval.(type) { + case float32: + return l - r, nil + } case float64: switch r := rval.(type) { case float64: @@ -414,18 +493,51 @@ func minus(lval, rval interface{}) (interface{}, error) { func mult(lval, rval interface{}) (interface{}, error) { switch l := lval.(type) { + case uint8: + switch r := rval.(type) { + case uint8: + return l * r, nil + } + case int8: + switch r := rval.(type) { + case int8: + return l * r, nil + } + case uint16: + switch r := rval.(type) { + case uint16: + return l * r, nil + } + case int16: + switch r := rval.(type) { + case int16: + return l * r, nil + } + case uint32: + switch r := rval.(type) { + case uint32: + return l * r, nil + } + case int32: + switch r := rval.(type) { + case int32: + return l * r, nil + } case uint64: switch r := rval.(type) { case uint64: return l * r, nil } - case int64: switch r := rval.(type) { case int64: return l * r, nil } - + case float32: + switch r := rval.(type) { + case float32: + return l * r, nil + } case float64: switch r := rval.(type) { case float64: @@ -443,6 +555,54 @@ func mult(lval, rval interface{}) (interface{}, error) { func div(lval, rval interface{}) (interface{}, error) { switch l := lval.(type) { + case uint8: + switch r := rval.(type) { + case uint8: + if r == 0 { + return nil, nil + } + return l / r, nil + } + case int8: + switch r := rval.(type) { + case int8: + if r == 0 { + return nil, nil + } + return l / r, nil + } + case uint16: + switch r := rval.(type) { + case uint16: + if r == 0 { + return nil, nil + } + return l / r, nil + } + case int16: + switch r := rval.(type) { + case int16: + if r == 0 { + return nil, nil + } + return l / r, nil + } + case uint32: + switch r := rval.(type) { + case uint32: + if r == 0 { + return nil, nil + } + return l / r, nil + } + case int32: + switch r := rval.(type) { + case int32: + if r == 0 { + return nil, nil + } + return l / r, nil + } case uint64: switch r := rval.(type) { case uint64: @@ -451,7 +611,6 @@ func div(lval, rval interface{}) (interface{}, error) { } return l / r, nil } - case int64: switch r := rval.(type) { case int64: @@ -460,7 +619,14 @@ func div(lval, rval interface{}) (interface{}, error) { } return l / r, nil } - + case float32: + switch r := rval.(type) { + case float32: + if r == 0 { + return nil, nil + } + return l / r, nil + } case float64: switch r := rval.(type) { case float64: diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 53d13d618b..d38aad6f60 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -13,7 +13,7 @@ import ( type sumBuffer struct { isnil bool - sum interface{} + sum interface{} // sum is either decimal.Decimal or float64 expr sql.Expression } @@ -32,31 +32,16 @@ func (m *sumBuffer) Update(ctx *sql.Context, row sql.Row) error { return nil } - if m.isnil { - m.sum = decimal.NewFromInt(0) - m.isnil = false - } - switch n := v.(type) { - case float64: - val, err := sql.Float64.Convert(n) - if err != nil { - val = float64(0) - } + case decimal.Decimal: if m.isnil { - m.sum = 0 + m.sum = decimal.NewFromInt(0) m.isnil = false } - sum, err := sql.Float64.Convert(m.sum) - if err != nil { - sum = float64(0) - } - m.sum = sum.(float64) + val.(float64) - case decimal.Decimal: if sum, ok := m.sum.(decimal.Decimal); ok { m.sum = sum.Add(n) } else { - m.sum = n + m.sum = decimal.NewFromFloat(m.sum.(float64)).Add(n) } case string: p, s := expression.GetDecimalPrecisionAndScale(n) @@ -65,11 +50,33 @@ func (m *sumBuffer) Update(ctx *sql.Context, row sql.Row) error { if err != nil { val = decimal.NewFromInt(0) } + if m.isnil { + m.sum = decimal.NewFromInt(0) + m.isnil = false + } if sum, ok := m.sum.(decimal.Decimal); ok { + r := sum.StringFixed(sum.Exponent() * -1) + i := val.(decimal.Decimal).StringFixed(val.(decimal.Decimal).Exponent() * -1) + if r == i { + } m.sum = sum.Add(val.(decimal.Decimal)) } else { - m.sum = val.(decimal.Decimal) + m.sum = decimal.NewFromFloat(m.sum.(float64)).Add(val.(decimal.Decimal)) + } + default: + val, err := sql.Float64.Convert(n) + if err != nil { + val = float64(0) + } + if m.isnil { + m.sum = 0 + m.isnil = false + } + sum, err := sql.Float64.Convert(m.sum) + if err != nil { + sum = float64(0) } + m.sum = sum.(float64) + val.(float64) } return nil diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index d2eaad2c4d..6617fb0659 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -17,6 +17,7 @@ package function import ( "encoding/hex" "fmt" + "github.com/shopspring/decimal" "math" "reflect" "strconv" @@ -98,6 +99,8 @@ func (c *Ceil) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return math.Ceil(num), nil case float32: return float32(math.Ceil(float64(num))), nil + case decimal.Decimal: + return num.Ceil(), nil default: return nil, sql.ErrInvalidType.New(reflect.TypeOf(num)) } @@ -176,6 +179,8 @@ func (f *Floor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return math.Floor(num), nil case float32: return float32(math.Floor(float64(num))), nil + case decimal.Decimal: + return num.Floor(), nil default: return nil, sql.ErrInvalidType.New(reflect.TypeOf(num)) } @@ -313,6 +318,8 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // 586 / 100 // 5.86 switch xNum := xVal.(type) { + case decimal.Decimal: + return xNum.Round(int32(dVal)), nil case float64: return math.Round(xNum*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal), nil case float32: diff --git a/sql/type.go b/sql/type.go index 46dec649df..4206de16e4 100644 --- a/sql/type.go +++ b/sql/type.go @@ -621,7 +621,7 @@ func IsNumber(t Type) bool { // IsSigned checks if t is a signed type. func IsSigned(t Type) bool { - return t == Int8 || t == Int16 || t == Int32 || t == Int64 + return t == Int8 || t == Int16 || t == Int24 || t == Int32 || t == Int64 } // IsText checks if t is a CHAR, VARCHAR, TEXT, BINARY, VARBINARY, or BLOB (including TEXT and BLOB variants). @@ -681,7 +681,7 @@ func IsTuple(t Type) bool { // IsUnsigned checks if t is an unsigned type. func IsUnsigned(t Type) bool { - return t == Uint8 || t == Uint16 || t == Uint32 || t == Uint64 + return t == Uint8 || t == Uint16 || t == Uint24 || t == Uint32 || t == Uint64 } // NumColumns returns the number of columns in a type. This is one for all From 504de7c2d117bae0c1e33a2d54be0b306fcfde2c Mon Sep 17 00:00:00 2001 From: jennifersp Date: Thu, 8 Sep 2022 20:18:56 +0000 Subject: [PATCH 04/21] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/evaluation.go | 2 +- sql/expression/function/ceil_round_floor.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index 0c7dee3353..ea4fcf5a5f 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -16,11 +16,11 @@ package enginetest import ( "fmt" - "github.com/shopspring/decimal" "strings" "testing" "time" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/src-d/go-errors.v1" diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index 6617fb0659..710d1e3ce0 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -17,11 +17,12 @@ package function import ( "encoding/hex" "fmt" - "github.com/shopspring/decimal" "math" "reflect" "strconv" + "github.com/shopspring/decimal" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" ) From caf1ae40eb95c26497aec36f4721cf9ab3910bca Mon Sep 17 00:00:00 2001 From: jennifersp Date: Thu, 8 Sep 2022 14:31:11 -0700 Subject: [PATCH 05/21] rm decimal type changes --- sql/decimal.go | 19 +--- sql/decimal_test.go | 96 +++++++++---------- sql/expression/arithmetic_test.go | 92 ++++++++---------- .../function/aggregation/unary_aggs.og.go | 1 - sql/expression/function/ceil_round_floor.go | 8 -- 5 files changed, 90 insertions(+), 126 deletions(-) diff --git a/sql/decimal.go b/sql/decimal.go index dbafe00791..a1b343df57 100644 --- a/sql/decimal.go +++ b/sql/decimal.go @@ -15,16 +15,13 @@ package sql import ( - "encoding/hex" "fmt" - "math/big" - "reflect" - "strconv" - "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" + "math/big" + "reflect" ) const ( @@ -202,12 +199,6 @@ func (t decimalType) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, e res = decimal.NewFromFloat32(value) case float64: res = decimal.NewFromFloat(value) - case []uint8: - val, err := strconv.ParseUint(hex.EncodeToString(value), 16, 64) - if err != nil { - return decimal.NullDecimal{}, err - } - return t.ConvertToNullDecimal(val) case string: var err error res, err = decimal.NewFromString(value) @@ -236,8 +227,6 @@ func (t decimalType) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, e return decimal.NullDecimal{}, nil } res = value.Decimal - case JSONDocument: - return t.ConvertToNullDecimal(value.Val) default: return decimal.NullDecimal{}, ErrConvertingToDecimal.New(v) } @@ -252,10 +241,6 @@ func (t decimalType) BoundsCheck(v decimal.Decimal) (decimal.Decimal, error) { } // TODO add shortcut for common case // ex: certain num of bits fast tracks OK - r := v.StringFixed(v.Exponent() * -1) - l := t.exclusiveUpperBound.StringFixed(t.exclusiveUpperBound.Exponent() * -1) - if r == l { - } if !v.Abs().LessThan(t.exclusiveUpperBound) { return decimal.Decimal{}, ErrConvertToDecimalLimit.New() } diff --git a/sql/decimal_test.go b/sql/decimal_test.go index 8e0ed764d9..26ccd13b65 100644 --- a/sql/decimal_test.go +++ b/sql/decimal_test.go @@ -213,54 +213,54 @@ func TestDecimalConvert(t *testing.T) { expectedVal interface{} expectedErr bool }{ - {1, 0, nil, nil, false}, - {1, 0, byte(0), "0", false}, - {1, 0, int8(3), "3", false}, - {1, 0, "-3.7e0", "-4", false}, - {1, 0, uint(4), "4", false}, - {1, 0, int16(9), "9", false}, - {1, 0, "0.00000000000000000003e20", "3", false}, - {1, 0, float64(-9.4), "-9", false}, - {1, 0, float32(9.5), "", true}, - {1, 0, int32(-10), "", true}, - - {1, 1, 0, "0.0", false}, - {1, 1, .01, "0.0", false}, - {1, 1, .1, "0.1", false}, - {1, 1, ".22", "0.2", false}, - {1, 1, .55, "0.6", false}, - {1, 1, "-.7863294659345624", "-0.8", false}, - {1, 1, "2634193746329327479.32030573792e-19", "0.3", false}, - {1, 1, 1, "", true}, - {1, 1, new(big.Rat).SetInt64(2), "", true}, - - {5, 0, 0, "0", false}, - {5, 0, 5000.2, "5000", false}, - {5, 0, "7742", "7742", false}, - {5, 0, new(big.Float).SetFloat64(-4723.875), "-4724", false}, - {5, 0, 99999, "99999", false}, - {5, 0, "0xf8e1", "63713", false}, - {5, 0, "0b1001110101100110", "40294", false}, - {5, 0, new(big.Rat).SetFrac64(999999, 10), "", true}, - {5, 0, 673927, "", true}, - - {10, 5, 0, "0.00000", false}, - {10, 5, "99999.999994", "99999.99999", false}, - {10, 5, "5.5729136e3", "5572.91360", false}, - {10, 5, "600e-2", "6.00000", false}, - {10, 5, new(big.Rat).SetFrac64(-22, 7), "-3.14286", false}, - {10, 5, 100000, "", true}, - {10, 5, "-99999.999995", "", true}, - - {65, 0, "99999999999999999999999999999999999999999999999999999999999999999", - "99999999999999999999999999999999999999999999999999999999999999999", false}, - {65, 0, "99999999999999999999999999999999999999999999999999999999999999999.1", - "99999999999999999999999999999999999999999999999999999999999999999", false}, - {65, 0, "99999999999999999999999999999999999999999999999999999999999999999.99", "", true}, - - {65, 12, "16976349273982359874209023948672021737840592720387475.2719128737543572927374503832837350563300243035038234972093785", - "16976349273982359874209023948672021737840592720387475.271912873754", false}, - {65, 12, "99999999999999999999999999999999999999999999999999999.9999999999999", "", true}, + //{1, 0, nil, nil, false}, + //{1, 0, byte(0), "0", false}, + //{1, 0, int8(3), "3", false}, + //{1, 0, "-3.7e0", "-4", false}, + //{1, 0, uint(4), "4", false}, + //{1, 0, int16(9), "9", false}, + //{1, 0, "0.00000000000000000003e20", "3", false}, + //{1, 0, float64(-9.4), "-9", false}, + //{1, 0, float32(9.5), "", true}, + //{1, 0, int32(-10), "", true}, + // + //{1, 1, 0, "0.0", false}, + //{1, 1, .01, "0.0", false}, + //{1, 1, .1, "0.1", false}, + //{1, 1, ".22", "0.2", false}, + //{1, 1, .55, "0.6", false}, + //{1, 1, "-.7863294659345624", "-0.8", false}, + //{1, 1, "2634193746329327479.32030573792e-19", "0.3", false}, + //{1, 1, 1, "", true}, + //{1, 1, new(big.Rat).SetInt64(2), "", true}, + // + //{5, 0, 0, "0", false}, + //{5, 0, 5000.2, "5000", false}, + //{5, 0, "7742", "7742", false}, + //{5, 0, new(big.Float).SetFloat64(-4723.875), "-4724", false}, + //{5, 0, 99999, "99999", false}, + //{5, 0, "0xf8e1", "63713", false}, + //{5, 0, "0b1001110101100110", "40294", false}, + //{5, 0, new(big.Rat).SetFrac64(999999, 10), "", true}, + //{5, 0, 673927, "", true}, + // + //{10, 5, 0, "0.00000", false}, + //{10, 5, "99999.999994", "99999.99999", false}, + //{10, 5, "5.5729136e3", "5572.91360", false}, + //{10, 5, "600e-2", "6.00000", false}, + //{10, 5, new(big.Rat).SetFrac64(-22, 7), "-3.14286", false}, + //{10, 5, 100000, "", true}, + //{10, 5, "-99999.999995", "", true}, + // + //{65, 0, "99999999999999999999999999999999999999999999999999999999999999999", + // "99999999999999999999999999999999999999999999999999999999999999999", false}, + //{65, 0, "99999999999999999999999999999999999999999999999999999999999999999.1", + // "99999999999999999999999999999999999999999999999999999999999999999", false}, + //{65, 0, "99999999999999999999999999999999999999999999999999999999999999999.99", "", true}, + // + //{65, 12, "16976349273982359874209023948672021737840592720387475.2719128737543572927374503832837350563300243035038234972093785", + // "16976349273982359874209023948672021737840592720387475.271912873754", false}, + //{65, 12, "99999999999999999999999999999999999999999999999999999.9999999999999", "", true}, {20, 10, []byte{32}, nil, true}, {20, 10, time.Date(2019, 12, 12, 12, 12, 12, 0, time.UTC), nil, true}, diff --git a/sql/expression/arithmetic_test.go b/sql/expression/arithmetic_test.go index 8f8382ece5..8c3bebb6d1 100644 --- a/sql/expression/arithmetic_test.go +++ b/sql/expression/arithmetic_test.go @@ -15,11 +15,10 @@ package expression import ( - "fmt" + "github.com/shopspring/decimal" "testing" "time" - "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -30,12 +29,12 @@ func TestPlus(t *testing.T) { var testCases = []struct { name string left, right float64 - expected string + expected float64 }{ - {"1 + 1", 1, 1, "2"}, - {"-1 + 1", -1, 1, "0"}, - {"0 + 0", 0, 0, "0"}, - {"0.14159 + 3.0", 0.14159, 3.0, "3.14159"}, + {"1 + 1", 1, 1, 2}, + {"-1 + 1", -1, 1, 0}, + {"0 + 0", 0, 0, 0}, + {"0.14159 + 3.0", 0.14159, 3.0, float64(0.14159) + float64(3)}, } for _, tt := range testCases { @@ -46,11 +45,7 @@ func TestPlus(t *testing.T) { NewLiteral(tt.right, sql.Float64), ).Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - if d, ok := result.(decimal.Decimal); ok { - require.Equal(tt.expected, d.StringFixed(d.Exponent()*-1)) - } else { - require.Equal("0", fmt.Sprintf("%v", result)) - } + require.Equal(tt.expected, result) }) } @@ -58,7 +53,7 @@ func TestPlus(t *testing.T) { result, err := NewPlus(NewLiteral("2", sql.LongText), NewLiteral(3, sql.Float64)). Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - require.Equal("5", fmt.Sprintf("%v", result)) + require.Equal(float64(5), result) } func TestPlusInterval(t *testing.T) { @@ -88,12 +83,12 @@ func TestMinus(t *testing.T) { var testCases = []struct { name string left, right float64 - expected string + expected float64 }{ - {"1 - 1", 1, 1, "0"}, - {"1 - 1", 1, 1, "0"}, - {"0 - 0", 0, 0, "0"}, - {"3.14159 - 3.0", 3.14159, 3.0, "0.14159"}, + {"1 - 1", 1, 1, 0}, + {"1 - 1", 1, 1, 0}, + {"0 - 0", 0, 0, 0}, + {"3.14159 - 3.0", 3.14159, 3.0, float64(3.14159) - float64(3.0)}, } for _, tt := range testCases { @@ -104,11 +99,7 @@ func TestMinus(t *testing.T) { NewLiteral(tt.right, sql.Float64), ).Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - if d, ok := result.(decimal.Decimal); ok { - require.Equal(tt.expected, d.StringFixed(d.Exponent()*-1)) - } else { - require.Equal("0", fmt.Sprintf("%v", result)) - } + require.Equal(tt.expected, result) }) } @@ -116,7 +107,7 @@ func TestMinus(t *testing.T) { result, err := NewMinus(NewLiteral("10", sql.LongText), NewLiteral(10, sql.Int64)). Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - require.Equal("0", fmt.Sprintf("%v", result)) + require.Equal(float64(0), result) } func TestMinusInterval(t *testing.T) { @@ -137,12 +128,12 @@ func TestMult(t *testing.T) { var testCases = []struct { name string left, right float64 - expected string + expected float64 }{ - {"1 * 1", 1, 1, "1"}, - {"-1 * 1", -1, 1, "-1"}, - {"0 * 0", 0, 0, "0"}, - {"3.14159 * 3.0", 3.14159, 3.1, "9.738929"}, + {"1 * 1", 1, 1, 1}, + {"-1 * 1", -1, 1, -1}, + {"0 * 0", 0, 0, 0}, + {"3.14159 * 3.0", 3.14159, 3.0, float64(3.14159) * float64(3.0)}, } for _, tt := range testCases { @@ -153,35 +144,31 @@ func TestMult(t *testing.T) { NewLiteral(tt.right, sql.Float64), ).Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - if d, ok := result.(decimal.Decimal); ok { - require.Equal(tt.expected, d.StringFixed(d.Exponent()*-1)) - } else { - require.Equal("0", fmt.Sprintf("%v", result)) - } + require.Equal(tt.expected, result) }) } require := require.New(t) - result, err := NewMult(NewLiteral("10", sql.LongText), NewLiteral("3.0", sql.LongText)). + result, err := NewMult(NewLiteral("10", sql.LongText), NewLiteral("10", sql.LongText)). Eval(sql.NewEmptyContext(), sql.NewRow()) require.NoError(err) - require.Equal("30.0", result.(decimal.Decimal).StringFixed(result.(decimal.Decimal).Exponent()*-1)) + require.Equal(float64(100), result) } func TestDiv(t *testing.T) { var floatTestCases = []struct { name string left, right float64 - expected string + expected float64 null bool }{ - {"1 / 1", 1, 1, "1.0000", false}, - {"-1 / 1", -1, 1, "-1.0000", false}, - {"0 / 1234567890", 0, 12345677890, "0.0000", false}, - {"3.14159 / 3.0", 3.14159, 3.0, "1.047196667", false}, - {"1/0", 1, 0, "", true}, - {"-1/0", -1, 0, "", true}, - {"0/0", 0, 0, "", true}, + {"1 / 1", 1, 1, 1, false}, + {"-1 / 1", -1, 1, -1, false}, + {"0 / 1234567890", 0, 12345677890, 0, false}, + {"3.14159 / 3.0", 3.14159, 3.0, float64(3.14159) / float64(3.0), false}, + {"1/0", 1, 0, 0, true}, + {"-1/0", -1, 0, 0, true}, + {"0/0", 0, 0, 0, true}, } for _, tt := range floatTestCases { @@ -194,11 +181,7 @@ func TestDiv(t *testing.T) { if tt.null { assert.Equal(t, nil, result) } else { - if d, ok := result.(decimal.Decimal); ok { - require.Equal(t, tt.expected, d.StringFixed(d.Exponent()*-1)) - } else { - require.Equal(t, tt.expected, fmt.Sprintf("%v", result)) - } + assert.Equal(t, tt.expected, result) } }) } @@ -523,8 +506,8 @@ func TestUnaryMinus(t *testing.T) { {"uint64", uint64(1), sql.Uint64, int64(-1)}, {"float32", float32(1), sql.Float32, float32(-1)}, {"float64", float64(1), sql.Float64, float64(-1)}, - {"int text", "1", sql.LongText, float64(-1)}, - {"float text", "1.2", sql.LongText, float64(-1.2)}, + {"int text", "1", sql.LongText, "-1"}, + {"float text", "1.2", sql.LongText, "-1.2"}, {"nil", nil, sql.LongText, nil}, } @@ -533,7 +516,12 @@ func TestUnaryMinus(t *testing.T) { f := NewUnaryMinus(NewLiteral(tt.input, tt.typ)) result, err := f.Eval(sql.NewEmptyContext(), nil) require.NoError(t, err) - require.Equal(t, tt.expected, result) + if dt, ok := result.(decimal.Decimal); ok { + require.Equal(t, tt.expected, dt.StringFixed(dt.Exponent()*-1)) + } else { + require.Equal(t, tt.expected, result) + } + }) } } diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index 58e6dcb6be..04a7cf3df4 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -589,7 +589,6 @@ func NewSum(e sql.Expression) *Sum { } func (a *Sum) Type() sql.Type { - // TODO the type depends on column if functions over table, or the input given (most likely to be DECIMAL with the longest precision and scaled input defines it) return sql.Float64 } diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index 710d1e3ce0..d2eaad2c4d 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -21,8 +21,6 @@ import ( "reflect" "strconv" - "github.com/shopspring/decimal" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" ) @@ -100,8 +98,6 @@ func (c *Ceil) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return math.Ceil(num), nil case float32: return float32(math.Ceil(float64(num))), nil - case decimal.Decimal: - return num.Ceil(), nil default: return nil, sql.ErrInvalidType.New(reflect.TypeOf(num)) } @@ -180,8 +176,6 @@ func (f *Floor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return math.Floor(num), nil case float32: return float32(math.Floor(float64(num))), nil - case decimal.Decimal: - return num.Floor(), nil default: return nil, sql.ErrInvalidType.New(reflect.TypeOf(num)) } @@ -319,8 +313,6 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // 586 / 100 // 5.86 switch xNum := xVal.(type) { - case decimal.Decimal: - return xNum.Round(int32(dVal)), nil case float64: return math.Round(xNum*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal), nil case float32: From 8a1e9982404f817727508f2120b9c4b54df325ba Mon Sep 17 00:00:00 2001 From: jennifersp Date: Thu, 8 Sep 2022 21:32:25 +0000 Subject: [PATCH 06/21] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/decimal.go | 5 +++-- sql/expression/arithmetic_test.go | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/decimal.go b/sql/decimal.go index a1b343df57..651652a107 100644 --- a/sql/decimal.go +++ b/sql/decimal.go @@ -16,12 +16,13 @@ package sql import ( "fmt" + "math/big" + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" - "math/big" - "reflect" ) const ( diff --git a/sql/expression/arithmetic_test.go b/sql/expression/arithmetic_test.go index 8c3bebb6d1..cda8cd2447 100644 --- a/sql/expression/arithmetic_test.go +++ b/sql/expression/arithmetic_test.go @@ -15,10 +15,10 @@ package expression import ( - "github.com/shopspring/decimal" "testing" "time" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" From 8fe4762a33374edbdc98b387ad338c87609bb791 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Thu, 8 Sep 2022 15:05:10 -0700 Subject: [PATCH 07/21] fix tests --- sql/decimal_test.go | 96 +++++++++---------- sql/expression/arithmetic.go | 13 ++- .../function/aggregation/sum_test.go | 31 ++++-- 3 files changed, 80 insertions(+), 60 deletions(-) diff --git a/sql/decimal_test.go b/sql/decimal_test.go index 26ccd13b65..8e0ed764d9 100644 --- a/sql/decimal_test.go +++ b/sql/decimal_test.go @@ -213,54 +213,54 @@ func TestDecimalConvert(t *testing.T) { expectedVal interface{} expectedErr bool }{ - //{1, 0, nil, nil, false}, - //{1, 0, byte(0), "0", false}, - //{1, 0, int8(3), "3", false}, - //{1, 0, "-3.7e0", "-4", false}, - //{1, 0, uint(4), "4", false}, - //{1, 0, int16(9), "9", false}, - //{1, 0, "0.00000000000000000003e20", "3", false}, - //{1, 0, float64(-9.4), "-9", false}, - //{1, 0, float32(9.5), "", true}, - //{1, 0, int32(-10), "", true}, - // - //{1, 1, 0, "0.0", false}, - //{1, 1, .01, "0.0", false}, - //{1, 1, .1, "0.1", false}, - //{1, 1, ".22", "0.2", false}, - //{1, 1, .55, "0.6", false}, - //{1, 1, "-.7863294659345624", "-0.8", false}, - //{1, 1, "2634193746329327479.32030573792e-19", "0.3", false}, - //{1, 1, 1, "", true}, - //{1, 1, new(big.Rat).SetInt64(2), "", true}, - // - //{5, 0, 0, "0", false}, - //{5, 0, 5000.2, "5000", false}, - //{5, 0, "7742", "7742", false}, - //{5, 0, new(big.Float).SetFloat64(-4723.875), "-4724", false}, - //{5, 0, 99999, "99999", false}, - //{5, 0, "0xf8e1", "63713", false}, - //{5, 0, "0b1001110101100110", "40294", false}, - //{5, 0, new(big.Rat).SetFrac64(999999, 10), "", true}, - //{5, 0, 673927, "", true}, - // - //{10, 5, 0, "0.00000", false}, - //{10, 5, "99999.999994", "99999.99999", false}, - //{10, 5, "5.5729136e3", "5572.91360", false}, - //{10, 5, "600e-2", "6.00000", false}, - //{10, 5, new(big.Rat).SetFrac64(-22, 7), "-3.14286", false}, - //{10, 5, 100000, "", true}, - //{10, 5, "-99999.999995", "", true}, - // - //{65, 0, "99999999999999999999999999999999999999999999999999999999999999999", - // "99999999999999999999999999999999999999999999999999999999999999999", false}, - //{65, 0, "99999999999999999999999999999999999999999999999999999999999999999.1", - // "99999999999999999999999999999999999999999999999999999999999999999", false}, - //{65, 0, "99999999999999999999999999999999999999999999999999999999999999999.99", "", true}, - // - //{65, 12, "16976349273982359874209023948672021737840592720387475.2719128737543572927374503832837350563300243035038234972093785", - // "16976349273982359874209023948672021737840592720387475.271912873754", false}, - //{65, 12, "99999999999999999999999999999999999999999999999999999.9999999999999", "", true}, + {1, 0, nil, nil, false}, + {1, 0, byte(0), "0", false}, + {1, 0, int8(3), "3", false}, + {1, 0, "-3.7e0", "-4", false}, + {1, 0, uint(4), "4", false}, + {1, 0, int16(9), "9", false}, + {1, 0, "0.00000000000000000003e20", "3", false}, + {1, 0, float64(-9.4), "-9", false}, + {1, 0, float32(9.5), "", true}, + {1, 0, int32(-10), "", true}, + + {1, 1, 0, "0.0", false}, + {1, 1, .01, "0.0", false}, + {1, 1, .1, "0.1", false}, + {1, 1, ".22", "0.2", false}, + {1, 1, .55, "0.6", false}, + {1, 1, "-.7863294659345624", "-0.8", false}, + {1, 1, "2634193746329327479.32030573792e-19", "0.3", false}, + {1, 1, 1, "", true}, + {1, 1, new(big.Rat).SetInt64(2), "", true}, + + {5, 0, 0, "0", false}, + {5, 0, 5000.2, "5000", false}, + {5, 0, "7742", "7742", false}, + {5, 0, new(big.Float).SetFloat64(-4723.875), "-4724", false}, + {5, 0, 99999, "99999", false}, + {5, 0, "0xf8e1", "63713", false}, + {5, 0, "0b1001110101100110", "40294", false}, + {5, 0, new(big.Rat).SetFrac64(999999, 10), "", true}, + {5, 0, 673927, "", true}, + + {10, 5, 0, "0.00000", false}, + {10, 5, "99999.999994", "99999.99999", false}, + {10, 5, "5.5729136e3", "5572.91360", false}, + {10, 5, "600e-2", "6.00000", false}, + {10, 5, new(big.Rat).SetFrac64(-22, 7), "-3.14286", false}, + {10, 5, 100000, "", true}, + {10, 5, "-99999.999995", "", true}, + + {65, 0, "99999999999999999999999999999999999999999999999999999999999999999", + "99999999999999999999999999999999999999999999999999999999999999999", false}, + {65, 0, "99999999999999999999999999999999999999999999999999999999999999999.1", + "99999999999999999999999999999999999999999999999999999999999999999", false}, + {65, 0, "99999999999999999999999999999999999999999999999999999999999999999.99", "", true}, + + {65, 12, "16976349273982359874209023948672021737840592720387475.2719128737543572927374503832837350563300243035038234972093785", + "16976349273982359874209023948672021737840592720387475.271912873754", false}, + {65, 12, "99999999999999999999999999999999999999999999999999999.9999999999999", "", true}, {20, 10, []byte{32}, nil, true}, {20, 10, time.Date(2019, 12, 12, 12, 12, 12, 0, time.UTC), nil, true}, diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 9678a1465c..a7e91ccf0a 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -152,7 +152,7 @@ func (a *Arithmetic) Type() sql.Type { return sql.Int64 } - return a.getDecimalType() + return a.getArithmeticTypeFromExpr() case sqlparser.ShiftLeftStr, sqlparser.ShiftRightStr: return sql.Uint64 @@ -164,11 +164,15 @@ func (a *Arithmetic) Type() sql.Type { return sql.Int64 } - return a.getDecimalType() + return a.getArithmeticTypeFromExpr() } -// getArithmeticType returns column type if -func (a *Arithmetic) getDecimalType() sql.Type { +// getArithmeticTypeFromExpr returns a type that left and right values to be converted into. +// If there is system variable, return type should be the type of that system variable. +// For any non-DECIMAL column type, it will use default sql.Float64 type. +// For DECIMAL column type, or any Literal values, the return type will the DECIMAL type with +// the highest precision and scale calculated out of all Literals and DECIMAL column type definition. +func (a *Arithmetic) getArithmeticTypeFromExpr() sql.Type { var resType sql.Type var precision uint8 var scale uint8 @@ -225,6 +229,7 @@ func (a *Arithmetic) getDecimalType() sql.Type { return resType } +// GetDecimalPrecisionAndScale returns precision and scale for given string formatted float/double number. func GetDecimalPrecisionAndScale(val string) (uint8, uint8) { scale := 0 precScale := strings.Split(strings.TrimPrefix(val, "-"), ".") diff --git a/sql/expression/function/aggregation/sum_test.go b/sql/expression/function/aggregation/sum_test.go index 967c8eb08b..f16f3c9ff0 100644 --- a/sql/expression/function/aggregation/sum_test.go +++ b/sql/expression/function/aggregation/sum_test.go @@ -15,6 +15,7 @@ package aggregation import ( + "github.com/shopspring/decimal" "testing" "github.com/stretchr/testify/require" @@ -34,17 +35,17 @@ func TestSum(t *testing.T) { { "string int values", []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, - float64(10), + "10", }, { "string float values", []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, - float64(10.5), + "10.5", }, { "string non-int values", []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, - float64(0), + "0", }, { "float values", @@ -85,7 +86,11 @@ func TestSum(t *testing.T) { result, err := buf.Eval(sql.NewEmptyContext()) require.NoError(err) - require.Equal(tt.expected, result) + if dt, ok := result.(decimal.Decimal); ok { + require.Equal(tt.expected, dt.StringFixed(dt.Exponent()*-1)) + } else { + require.Equal(tt.expected, result) + } }) } } @@ -107,17 +112,23 @@ func TestSumWithDistinct(t *testing.T) { { "string int values", []sql.Row{{"1"}, {"1"}, {"2"}, {"2"}, {"3"}, {"3"}, {"4"}, {"4"}}, - float64(10), + "10", }, + // TODO : DISTINCT returns incorrect result, it currently returns 11.00 + //{ + // "string int values", + // []sql.Row{{"1.00"}, {"1"}, {"2"}, {"2"}, {"3"}, {"3"}, {"4"}, {"4"}}, + // "10", + //}, { "string float values", []sql.Row{{"1.5"}, {"1.5"}, {"1.5"}, {"1.5"}, {"2"}, {"3"}, {"4"}}, - float64(10.5), + "10.5", }, { "string non-int values", []sql.Row{{"a"}, {"b"}, {"b"}, {"c"}, {"c"}, {"d"}}, - float64(0), + "0", }, { "float values", @@ -158,7 +169,11 @@ func TestSumWithDistinct(t *testing.T) { result, err := buf.Eval(sql.NewEmptyContext()) require.NoError(err) - require.Equal(tt.expected, result) + if dt, ok := result.(decimal.Decimal); ok { + require.Equal(tt.expected, dt.StringFixed(dt.Exponent()*-1)) + } else { + require.Equal(tt.expected, result) + } }) } } From 7a358d8ae95f28d88605ec4fa67263ba10a79f75 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Thu, 8 Sep 2022 22:06:17 +0000 Subject: [PATCH 08/21] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/function/aggregation/sum_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/expression/function/aggregation/sum_test.go b/sql/expression/function/aggregation/sum_test.go index f16f3c9ff0..60e3133120 100644 --- a/sql/expression/function/aggregation/sum_test.go +++ b/sql/expression/function/aggregation/sum_test.go @@ -15,9 +15,9 @@ package aggregation import ( - "github.com/shopspring/decimal" "testing" + "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "github.com/dolthub/go-mysql-server/sql" From 7cbb691c0b8bf80b33c807c28c068d3d121149b7 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Thu, 8 Sep 2022 15:14:21 -0700 Subject: [PATCH 09/21] sum type is either decimal or float --- sql/expression/function/aggregation/unary_aggs.og.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index 04a7cf3df4..fd10db1fe0 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -589,7 +589,7 @@ func NewSum(e sql.Expression) *Sum { } func (a *Sum) Type() sql.Type { - return sql.Float64 + return a.Child.Type() } func (a *Sum) IsNullable() bool { From 37308fe821589fd79b017257ab2f61da4af40d24 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Thu, 8 Sep 2022 15:35:19 -0700 Subject: [PATCH 10/21] fix tests --- enginetest/queries/column_alias_queries.go | 8 +++---- enginetest/queries/queries.go | 28 +++++++++++----------- enginetest/queries/script_queries.go | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/enginetest/queries/column_alias_queries.go b/enginetest/queries/column_alias_queries.go index c47f63664e..bf2e2eb825 100644 --- a/enginetest/queries/column_alias_queries.go +++ b/enginetest/queries/column_alias_queries.go @@ -79,7 +79,7 @@ var ColumnAliasQueries = []QueryTest{ }, { Name: "COL2", - Type: sql.Float64, + Type: sql.Int64, }, }, // TODO: SUM should be integer typed for integers @@ -98,7 +98,7 @@ var ColumnAliasQueries = []QueryTest{ }, { Name: "COL2", - Type: sql.Float64, + Type: sql.Int64, }, }, Expected: []sql.Row{ @@ -116,7 +116,7 @@ var ColumnAliasQueries = []QueryTest{ }, { Name: "coL2", - Type: sql.Float64, + Type: sql.Int64, }, }, Expected: []sql.Row{ @@ -134,7 +134,7 @@ var ColumnAliasQueries = []QueryTest{ }, { Name: "TimeStamp", - Type: sql.Float64, + Type: sql.Int64, }, }, Expected: []sql.Row{ diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index af48cebfca..8d70bda49c 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -490,15 +490,15 @@ var QueryTests = []QueryTest{ { Query: "SELECT pk DIV 2, SUM(c3) + sum(c3) as sum FROM one_pk GROUP BY 1 ORDER BY 1", Expected: []sql.Row{ - {int64(0), float64(28)}, - {int64(1), float64(108)}, + {int64(0), int64(28)}, + {int64(1), int64(108)}, }, }, { Query: "SELECT pk DIV 2, SUM(c3) + min(c3) as sum_and_min FROM one_pk GROUP BY 1 ORDER BY 1", Expected: []sql.Row{ - {int64(0), float64(16)}, - {int64(1), float64(76)}, + {int64(0), int64(16)}, + {int64(1), int64(76)}, }, ExpectedColumns: sql.Schema{ { @@ -507,15 +507,15 @@ var QueryTests = []QueryTest{ }, { Name: "sum_and_min", - Type: sql.Float64, + Type: sql.Int64, }, }, }, { Query: "SELECT pk DIV 2, SUM(`c3`) + min( c3 ) FROM one_pk GROUP BY 1 ORDER BY 1", Expected: []sql.Row{ - {int64(0), float64(16)}, - {int64(1), float64(76)}, + {int64(0), int64(16)}, + {int64(1), int64(76)}, }, ExpectedColumns: sql.Schema{ { @@ -524,7 +524,7 @@ var QueryTests = []QueryTest{ }, { Name: "SUM(`c3`) + min( c3 )", - Type: sql.Float64, + Type: sql.Int64, }, }, }, @@ -3775,9 +3775,9 @@ var QueryTests = []QueryTest{ { Query: "SELECT SUM(i) + 1, i FROM mytable GROUP BY i ORDER BY i", Expected: []sql.Row{ - {float64(2), int64(1)}, - {float64(3), int64(2)}, - {float64(4), int64(3)}, + {int64(2), int64(1)}, + {int64(3), int64(2)}, + {int64(4), int64(3)}, }, }, { @@ -6079,9 +6079,9 @@ var QueryTests = []QueryTest{ { Query: "select sum(x.i) + y.i from mytable as x, mytable as y where x.i = y.i GROUP BY x.i", Expected: []sql.Row{ - {float64(2)}, - {float64(4)}, - {float64(6)}, + {int64(2)}, + {int64(4)}, + {int64(6)}, }, }, { diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 580f8380ac..ae675b8aa5 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -1256,7 +1256,7 @@ var ScriptTests = []ScriptTest{ }, { Query: "SELECT SUM( DISTINCT + col1 ) * - 22 - - ( - COUNT( * ) ) col0 FROM tab1 AS cor0", - Expected: []sql.Row{{float64(-1455)}}, + Expected: []sql.Row{{int64(-1455)}}, }, { Query: "SELECT MIN (DISTINCT col1) from tab1 GROUP BY col0 ORDER BY col0", From 661569839bf5c34ce735c777aaa2b0329a2b86a4 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Thu, 8 Sep 2022 15:47:11 -0700 Subject: [PATCH 11/21] fix tests --- sql/analyzer/aggregations_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/analyzer/aggregations_test.go b/sql/analyzer/aggregations_test.go index d6e75eacc1..3e7a4fec1c 100644 --- a/sql/analyzer/aggregations_test.go +++ b/sql/analyzer/aggregations_test.go @@ -62,7 +62,7 @@ func TestFlattenAggregationExprs(t *testing.T) { expected: plan.NewProject( []sql.Expression{ expression.NewArithmetic( - expression.NewGetField(0, sql.Float64, "SUM(foo.a)", false), + expression.NewGetField(0, sql.Int64, "SUM(foo.a)", false), expression.NewLiteral(int64(1), sql.Int64), "+", ), @@ -102,7 +102,7 @@ func TestFlattenAggregationExprs(t *testing.T) { []sql.Expression{ expression.NewAlias("x", expression.NewArithmetic( - expression.NewGetField(0, sql.Float64, "SUM(foo.a)", false), + expression.NewGetField(0, sql.Int64, "SUM(foo.a)", false), expression.NewLiteral(int64(1), sql.Int64), "+", )), @@ -144,7 +144,7 @@ func TestFlattenAggregationExprs(t *testing.T) { expected: plan.NewProject( []sql.Expression{ expression.NewArithmetic( - expression.NewGetField(0, sql.Float64, "SUM(foo.a)", false), + expression.NewGetField(0, sql.Int64, "SUM(foo.a)", false), expression.NewGetField(1, sql.Int64, "COUNT(foo.a)", false), "/", ), @@ -189,7 +189,7 @@ func TestFlattenAggregationExprs(t *testing.T) { expected: plan.NewProject( []sql.Expression{ expression.NewArithmetic( - expression.NewGetField(0, sql.Float64, "SUM(foo.a)", false), + expression.NewGetField(0, sql.Int64, "SUM(foo.a)", false), expression.NewGetFieldWithTable(1, sql.Int64, "bar", "a", false), "+", ), From f5fa7bfdf3eff0ffae803f735cfffe9393c210bc Mon Sep 17 00:00:00 2001 From: jennifersp Date: Thu, 8 Sep 2022 15:54:00 -0700 Subject: [PATCH 12/21] add sum(decimal) test --- enginetest/queries/script_queries.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index ae675b8aa5..e8133132db 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -2154,6 +2154,25 @@ var ScriptTests = []ScriptTest{ }, }, }, + { + Name: "sum on DECIMAL type column returns the same type result", + SetUpScript: []string{ + "create table decimal_table (id int, val decimal(18,16));", + "insert into decimal_table values (1,-2.5633000000000384);", + "insert into decimal_table values (2,2.5633000000000370);", + "insert into decimal_table values (3,0.0000000000000004);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT val FROM decimal_table;", + Expected: []sql.Row{{"-2.5633000000000384"}, {"2.5633000000000370"}, {"0.0000000000000004"}}, + }, + { + Query: "SELECT sum(val) FROM decimal_table;", + Expected: []sql.Row{{"-0.0000000000000010"}}, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ From b58348816f4d31c49faa5635dad5005841a44199 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 12 Sep 2022 11:41:32 -0700 Subject: [PATCH 13/21] add avg() --- enginetest/evaluation.go | 4 ++ enginetest/queries/queries.go | 2 +- enginetest/queries/script_queries.go | 25 ++++++- .../function/aggregation/sum_test.go | 1 + .../function/aggregation/unary_agg_buffers.go | 69 ++++++++++++++----- .../function/aggregation/unary_aggs.go | 2 - .../function/aggregation/unary_aggs.og.go | 3 +- 7 files changed, 82 insertions(+), 24 deletions(-) diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index ea4fcf5a5f..547d65bd10 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -434,6 +434,10 @@ func checkResults( } } } + + // The result from SELECT or WITH queries can be decimal.Decimal type. + // The exact expected value cannot be defined in enginetests, so convert the result to string format, + // which is the value we get on sql shell. if strings.HasPrefix(upperQuery, "SELECT ") || strings.HasPrefix(upperQuery, "WITH ") { for _, widenedRow := range widenedRows { for i, val := range widenedRow { diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 8d70bda49c..1106caf36c 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -3921,7 +3921,7 @@ var QueryTests = []QueryTest{ { Query: `SELECT AVG(23.222000)`, Expected: []sql.Row{ - {float64(23.222)}, + {"23.2220000000"}, }, }, { diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index e8133132db..5580e94f29 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -2155,7 +2155,7 @@ var ScriptTests = []ScriptTest{ }, }, { - Name: "sum on DECIMAL type column returns the same type result", + Name: "sum() and avg() on DECIMAL type column returns the DECIMAL type result", SetUpScript: []string{ "create table decimal_table (id int, val decimal(18,16));", "insert into decimal_table values (1,-2.5633000000000384);", @@ -2171,6 +2171,29 @@ var ScriptTests = []ScriptTest{ Query: "SELECT sum(val) FROM decimal_table;", Expected: []sql.Row{{"-0.0000000000000010"}}, }, + { + Query: "SELECT avg(val) FROM decimal_table;", + Expected: []sql.Row{{"-0.00000000000000033333"}}, + }, + }, + }, + { + Name: "sum() and avg() on non-DECIMAL type column returns the DOUBLE type result", + SetUpScript: []string{ + "create table decimal_table (id int, val1 double, val2 float);", + "insert into decimal_table values (1,-2.5633000000000384, 2.3);", + "insert into decimal_table values (2,2.5633000000000370, 2.4);", + "insert into decimal_table values (3,0.0000000000000004, 5.3);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT sum(id), sum(val1), sum(val2) FROM decimal_table;", + Expected: []sql.Row{{float64(6), -9.322676295501879e-16, 10.000000238418579}}, + }, + { + Query: "SELECT avg(id), avg(val1), avg(val2) FROM decimal_table;", + Expected: []sql.Row{{float64(2), -3.107558765167293e-16, 3.333333412806193}}, + }, }, }, } diff --git a/sql/expression/function/aggregation/sum_test.go b/sql/expression/function/aggregation/sum_test.go index 60e3133120..4832f9b29b 100644 --- a/sql/expression/function/aggregation/sum_test.go +++ b/sql/expression/function/aggregation/sum_test.go @@ -115,6 +115,7 @@ func TestSumWithDistinct(t *testing.T) { "10", }, // TODO : DISTINCT returns incorrect result, it currently returns 11.00 + // https://github.com/dolthub/dolt/issues/4298 //{ // "string int values", // []sql.Row{{"1.00"}, {"1"}, {"2"}, {"2"}, {"3"}, {"3"}, {"4"}, {"4"}}, diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index d38aad6f60..9a0ccfe3d7 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -45,7 +45,7 @@ func (m *sumBuffer) Update(ctx *sql.Context, row sql.Row) error { } case string: p, s := expression.GetDecimalPrecisionAndScale(n) - dt, err := sql.CreateDecimalType(uint8(p), uint8(s)) + dt, err := sql.CreateDecimalType(p, s) val, err := dt.Convert(v) if err != nil { val = decimal.NewFromInt(0) @@ -55,10 +55,6 @@ func (m *sumBuffer) Update(ctx *sql.Context, row sql.Row) error { m.isnil = false } if sum, ok := m.sum.(decimal.Decimal); ok { - r := sum.StringFixed(sum.Exponent() * -1) - i := val.(decimal.Decimal).StringFixed(val.(decimal.Decimal).Exponent() * -1) - if r == i { - } m.sum = sum.Add(val.(decimal.Decimal)) } else { m.sum = decimal.NewFromFloat(m.sum.(float64)).Add(val.(decimal.Decimal)) @@ -136,7 +132,7 @@ func (l *lastBuffer) Dispose() { } type avgBuffer struct { - sum float64 + sum interface{} // sum is either decimal.Decimal or float64 rows int64 expr sql.Expression } @@ -161,12 +157,36 @@ func (a *avgBuffer) Update(ctx *sql.Context, row sql.Row) error { return nil } - v, err = sql.Float64.Convert(v) - if err != nil { - v = float64(0) + switch n := v.(type) { + case decimal.Decimal: + if sum, ok := a.sum.(decimal.Decimal); ok { + a.sum = sum.Add(n) + } else { + a.sum = decimal.NewFromFloat(a.sum.(float64)).Add(n) + } + case string: + p, s := expression.GetDecimalPrecisionAndScale(n) + dt, err := sql.CreateDecimalType(p, s) + val, err := dt.Convert(v) + if err != nil { + val = decimal.NewFromInt(0) + } + if sum, ok := a.sum.(decimal.Decimal); ok { + a.sum = sum.Add(val.(decimal.Decimal)) + } else { + a.sum = decimal.NewFromFloat(a.sum.(float64)).Add(val.(decimal.Decimal)) + } + default: + val, err := sql.Float64.Convert(n) + if err != nil { + val = float64(0) + } + sum, err := sql.Float64.Convert(a.sum) + if err != nil { + sum = float64(0) + } + a.sum = sum.(float64) + val.(float64) } - - a.sum += v.(float64) a.rows += 1 return nil @@ -175,15 +195,28 @@ func (a *avgBuffer) Update(ctx *sql.Context, row sql.Row) error { // Eval implements the AggregationBuffer interface. func (a *avgBuffer) Eval(ctx *sql.Context) (interface{}, error) { // This case is triggered when no rows exist. - if a.sum == 0 && a.rows == 0 { - return nil, nil - } + switch s := a.sum.(type) { + case float64: + if a.sum == 0 && a.rows == 0 { + return nil, nil + } - if a.rows == 0 { - return float64(0), nil - } + if a.rows == 0 { + return float64(0), nil + } - return a.sum / float64(a.rows), nil + return s / float64(a.rows), nil + case decimal.Decimal: + if s.IsZero() && a.rows == 0 { + return nil, nil + } + if a.rows == 0 { + return decimal.NewFromInt(0), nil + } + scale := (s.Exponent() * -1) + 4 + return s.DivRound(decimal.NewFromInt(a.rows), scale), nil + } + return nil, nil } // Dispose implements the Disposable interface. diff --git a/sql/expression/function/aggregation/unary_aggs.go b/sql/expression/function/aggregation/unary_aggs.go index b49ecb182c..a23aae473d 100644 --- a/sql/expression/function/aggregation/unary_aggs.go +++ b/sql/expression/function/aggregation/unary_aggs.go @@ -8,7 +8,6 @@ var UnaryAggDefs support.GenDefs = []support.AggDef{ // alphabetically sorted { Name: "Avg", Desc: "returns the average value of expr in all rows.", - RetType: "sql.Float64", Nullable: true, }, { @@ -56,7 +55,6 @@ var UnaryAggDefs support.GenDefs = []support.AggDef{ // alphabetically sorted { Name: "Sum", Desc: "returns the sum of expr in all rows", - RetType: "sql.Float64", Nullable: false, }, } diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index fd10db1fe0..8c20320c84 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -4,7 +4,6 @@ package aggregation import ( "fmt" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/transform" @@ -29,7 +28,7 @@ func NewAvg(e sql.Expression) *Avg { } func (a *Avg) Type() sql.Type { - return sql.Float64 + return a.Child.Type() } func (a *Avg) IsNullable() bool { From 88455b06ea5c93fa83184675ba257e0eb6568b6f Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 12 Sep 2022 18:43:09 +0000 Subject: [PATCH 14/21] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/function/aggregation/unary_aggs.og.go | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index 8c20320c84..c40b5692fa 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -4,6 +4,7 @@ package aggregation import ( "fmt" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/transform" From 76fd05525d65e3f9cf724e7076dfafd5e80ea71d Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 12 Sep 2022 13:44:20 -0700 Subject: [PATCH 15/21] fix --- enginetest/queries/queries.go | 6 ++++++ sql/expression/convert.go | 3 +++ .../function/aggregation/unary_agg_buffers.go | 14 +------------- sql/parse/parse.go | 11 ++++++++++- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 1106caf36c..e8c6659dab 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -3924,6 +3924,12 @@ var QueryTests = []QueryTest{ {"23.2220000000"}, }, }, + { + Query: `SELECT AVG("23.222000")`, + Expected: []sql.Row{ + {23.222}, + }, + }, { Query: `SELECT DATABASE()`, Expected: []sql.Row{ diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 9876854cf6..f9ccec03a7 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -166,6 +166,9 @@ func convertValue(val interface{}, castTo string, originType sql.Type) (interfac } return b, nil case ConvertToChar, ConvertToNChar: + if sql.IsDecimal(originType) { + return val, nil + } s, err := sql.LongText.Convert(val) if err != nil { return nil, nil diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 9a0ccfe3d7..442c110135 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -164,18 +164,6 @@ func (a *avgBuffer) Update(ctx *sql.Context, row sql.Row) error { } else { a.sum = decimal.NewFromFloat(a.sum.(float64)).Add(n) } - case string: - p, s := expression.GetDecimalPrecisionAndScale(n) - dt, err := sql.CreateDecimalType(p, s) - val, err := dt.Convert(v) - if err != nil { - val = decimal.NewFromInt(0) - } - if sum, ok := a.sum.(decimal.Decimal); ok { - a.sum = sum.Add(val.(decimal.Decimal)) - } else { - a.sum = decimal.NewFromFloat(a.sum.(float64)).Add(val.(decimal.Decimal)) - } default: val, err := sql.Float64.Convert(n) if err != nil { @@ -197,7 +185,7 @@ func (a *avgBuffer) Eval(ctx *sql.Context) (interface{}, error) { // This case is triggered when no rows exist. switch s := a.sum.(type) { case float64: - if a.sum == 0 && a.rows == 0 { + if s == 0 && a.rows == 0 { return nil, nil } diff --git a/sql/parse/parse.go b/sql/parse/parse.go index bb2a283345..3d21e461f0 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -3427,7 +3427,16 @@ func convertVal(ctx *sql.Context, v *sqlparser.SQLVal) (sql.Expression, error) { ogVal := string(v.Val) floatVal := fmt.Sprintf("%v", val) if len(ogVal) >= len(floatVal) && ogVal != floatVal { - return expression.NewLiteral(string(v.Val), sql.CreateLongText(ctx.GetCollation())), nil + p, s := expression.GetDecimalPrecisionAndScale(ogVal) + dt, err := sql.CreateDecimalType(p, s) + if err != nil { + return expression.NewLiteral(string(v.Val), sql.CreateLongText(ctx.GetCollation())), nil + } + dVal, err := dt.Convert(ogVal) + if err != nil { + return expression.NewLiteral(string(v.Val), sql.CreateLongText(ctx.GetCollation())), nil + } + return expression.NewLiteral(dVal, dt), nil } } From 1a8de573dead2aac74283ca4a8f80ff2895a0eb0 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 12 Sep 2022 13:53:13 -0700 Subject: [PATCH 16/21] float to int --- sql/analyzer/resolve_having_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/analyzer/resolve_having_test.go b/sql/analyzer/resolve_having_test.go index 1b76ede420..7e47dcf8a3 100644 --- a/sql/analyzer/resolve_having_test.go +++ b/sql/analyzer/resolve_having_test.go @@ -44,7 +44,7 @@ func TestResolveHaving(t *testing.T) { ), expected: plan.NewHaving( expression.NewGreaterThan( - expression.NewGetField(0, sql.Float64, "x", true), + expression.NewGetField(0, sql.Int64, "x", true), expression.NewLiteral(int64(5), sql.Int64), ), plan.NewGroupBy( @@ -75,7 +75,7 @@ func TestResolveHaving(t *testing.T) { ), expected: plan.NewHaving( expression.NewGreaterThan( - expression.NewGetField(0, sql.Float64, "x", true), + expression.NewGetField(0, sql.Int64, "x", true), expression.NewLiteral(int64(5), sql.Int64), ), plan.NewGroupBy( @@ -106,7 +106,7 @@ func TestResolveHaving(t *testing.T) { ), expected: plan.NewProject( []sql.Expression{ - expression.NewGetField(0, sql.Float64, "x", true), + expression.NewGetField(0, sql.Int64, "x", true), expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false), }, plan.NewHaving( From 115aaad703d8aab10d4e1c7034463476cbbd553c Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 12 Sep 2022 14:14:33 -0700 Subject: [PATCH 17/21] fix parse_test --- sql/parse/parse_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index b90a089ce5..858b81a72d 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -16,6 +16,7 @@ package parse import ( "fmt" + "github.com/shopspring/decimal" "math" "sort" "testing" @@ -1706,7 +1707,7 @@ CREATE TABLE t2 &expression.DefaultColumn{}, }}), false, []string{"col1", "col2"}, []sql.Expression{}, false), `INSERT INTO test (decimal_col) VALUES (11981.5923291839784651)`: plan.NewInsertInto(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("test", ""), plan.NewValues([][]sql.Expression{{ - expression.NewLiteral("11981.5923291839784651", sql.LongText), + expression.NewLiteral(decimal.RequireFromString("11981.5923291839784651"), sql.MustCreateDecimalType(21, 16)), }}), false, []string{"decimal_col"}, []sql.Expression{}, false), `INSERT INTO test (decimal_col) VALUES (119815923291839784651.11981592329183978465111981592329183978465144)`: plan.NewInsertInto(sql.UnresolvedDatabase(""), plan.NewUnresolvedTable("test", ""), plan.NewValues([][]sql.Expression{{ expression.NewLiteral("119815923291839784651.11981592329183978465111981592329183978465144", sql.LongText), @@ -2235,8 +2236,8 @@ CREATE TABLE t2 []sql.Expression{ expression.NewAlias("1.0 * a + 2.0 * b", expression.NewPlus( - expression.NewMult(expression.NewLiteral("1.0", sql.LongText), expression.NewUnresolvedColumn("a")), - expression.NewMult(expression.NewLiteral("2.0", sql.LongText), expression.NewUnresolvedColumn("b")), + expression.NewMult(expression.NewLiteral(decimal.RequireFromString("1.0"), sql.MustCreateDecimalType(2, 1)), expression.NewUnresolvedColumn("a")), + expression.NewMult(expression.NewLiteral(decimal.RequireFromString("2.0"), sql.MustCreateDecimalType(2, 1)), expression.NewUnresolvedColumn("b")), ), ), }, From 0a8852c1f2c4245e5ee0ac535fedab1a75611299 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 12 Sep 2022 21:15:33 +0000 Subject: [PATCH 18/21] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/parse/parse_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 858b81a72d..d3d0a97505 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -16,7 +16,6 @@ package parse import ( "fmt" - "github.com/shopspring/decimal" "math" "sort" "testing" @@ -25,6 +24,7 @@ import ( "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/pmezard/go-difflib/difflib" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/src-d/go-errors.v1" From aaff56bbb62abfb3d0f8ac7d6d7d071177e6af07 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 12 Sep 2022 17:04:19 -0700 Subject: [PATCH 19/21] simplify avg update --- enginetest/queries/script_queries.go | 12 ++-- .../function/aggregation/unary_agg_buffers.go | 65 ++++++------------- 2 files changed, 25 insertions(+), 52 deletions(-) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 5580e94f29..6b0812bc9f 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -2180,18 +2180,18 @@ var ScriptTests = []ScriptTest{ { Name: "sum() and avg() on non-DECIMAL type column returns the DOUBLE type result", SetUpScript: []string{ - "create table decimal_table (id int, val1 double, val2 float);", - "insert into decimal_table values (1,-2.5633000000000384, 2.3);", - "insert into decimal_table values (2,2.5633000000000370, 2.4);", - "insert into decimal_table values (3,0.0000000000000004, 5.3);", + "create table float_table (id int, val1 double, val2 float);", + "insert into float_table values (1,-2.5633000000000384, 2.3);", + "insert into float_table values (2,2.5633000000000370, 2.4);", + "insert into float_table values (3,0.0000000000000004, 5.3);", }, Assertions: []ScriptTestAssertion{ { - Query: "SELECT sum(id), sum(val1), sum(val2) FROM decimal_table;", + Query: "SELECT sum(id), sum(val1), sum(val2) FROM float_table ORDER BY id;", Expected: []sql.Row{{float64(6), -9.322676295501879e-16, 10.000000238418579}}, }, { - Query: "SELECT avg(id), avg(val1), avg(val2) FROM decimal_table;", + Query: "SELECT avg(id), avg(val1), avg(val2) FROM float_table ORDER BY id;;", Expected: []sql.Row{{float64(2), -3.107558765167293e-16, 3.333333412806193}}, }, }, diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 442c110135..78be8e1d4f 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -18,7 +18,7 @@ type sumBuffer struct { } func NewSumBuffer(child sql.Expression) *sumBuffer { - return &sumBuffer{true, decimal.NewFromInt(0), child} + return &sumBuffer{true, nil, child} } // Update implements the AggregationBuffer interface. @@ -32,6 +32,15 @@ func (m *sumBuffer) Update(ctx *sql.Context, row sql.Row) error { return nil } + // decimal.Decimal values are evaluated to string value even though the Literal expr type is Decimal type, + // so convert it to appropriate Decimal type + if s, isStr := v.(string); isStr && sql.IsDecimal(m.expr.Type()) { + val, err := m.expr.Type().Convert(s) + if err == nil { + v = val + } + } + switch n := v.(type) { case decimal.Decimal: if m.isnil { @@ -43,29 +52,13 @@ func (m *sumBuffer) Update(ctx *sql.Context, row sql.Row) error { } else { m.sum = decimal.NewFromFloat(m.sum.(float64)).Add(n) } - case string: - p, s := expression.GetDecimalPrecisionAndScale(n) - dt, err := sql.CreateDecimalType(p, s) - val, err := dt.Convert(v) - if err != nil { - val = decimal.NewFromInt(0) - } - if m.isnil { - m.sum = decimal.NewFromInt(0) - m.isnil = false - } - if sum, ok := m.sum.(decimal.Decimal); ok { - m.sum = sum.Add(val.(decimal.Decimal)) - } else { - m.sum = decimal.NewFromFloat(m.sum.(float64)).Add(val.(decimal.Decimal)) - } default: val, err := sql.Float64.Convert(n) if err != nil { val = float64(0) } if m.isnil { - m.sum = 0 + m.sum = float64(0) m.isnil = false } sum, err := sql.Float64.Convert(m.sum) @@ -132,49 +125,25 @@ func (l *lastBuffer) Dispose() { } type avgBuffer struct { - sum interface{} // sum is either decimal.Decimal or float64 + sum *sumBuffer // sum is either decimal.Decimal or float64 rows int64 expr sql.Expression } func NewAvgBuffer(child sql.Expression) *avgBuffer { const ( - sum = float64(0) rows = int64(0) ) - return &avgBuffer{sum, rows, child} + return &avgBuffer{NewSumBuffer(child), rows, child} } // Update implements the AggregationBuffer interface. func (a *avgBuffer) Update(ctx *sql.Context, row sql.Row) error { - v, err := a.expr.Eval(ctx, row) + err := a.sum.Update(ctx, row) if err != nil { return err } - - if v == nil { - return nil - } - - switch n := v.(type) { - case decimal.Decimal: - if sum, ok := a.sum.(decimal.Decimal); ok { - a.sum = sum.Add(n) - } else { - a.sum = decimal.NewFromFloat(a.sum.(float64)).Add(n) - } - default: - val, err := sql.Float64.Convert(n) - if err != nil { - val = float64(0) - } - sum, err := sql.Float64.Convert(a.sum) - if err != nil { - sum = float64(0) - } - a.sum = sum.(float64) + val.(float64) - } a.rows += 1 return nil @@ -182,8 +151,12 @@ func (a *avgBuffer) Update(ctx *sql.Context, row sql.Row) error { // Eval implements the AggregationBuffer interface. func (a *avgBuffer) Eval(ctx *sql.Context) (interface{}, error) { + sum, err := a.sum.Eval(ctx) + if err != nil { + return nil, err + } // This case is triggered when no rows exist. - switch s := a.sum.(type) { + switch s := sum.(type) { case float64: if s == 0 && a.rows == 0 { return nil, nil From 93bd7df46d249a937d6b731fce1af03fe881d3a2 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 13 Sep 2022 09:43:48 -0700 Subject: [PATCH 20/21] use sum for avg --- .../function/aggregation/sum_test.go | 27 +++++++------------ .../function/aggregation/unary_agg_buffers.go | 18 ++++++++++--- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/sql/expression/function/aggregation/sum_test.go b/sql/expression/function/aggregation/sum_test.go index 4832f9b29b..6e5470c553 100644 --- a/sql/expression/function/aggregation/sum_test.go +++ b/sql/expression/function/aggregation/sum_test.go @@ -17,7 +17,6 @@ package aggregation import ( "testing" - "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "github.com/dolthub/go-mysql-server/sql" @@ -35,17 +34,17 @@ func TestSum(t *testing.T) { { "string int values", []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, - "10", + float64(10), }, { "string float values", []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, - "10.5", + float64(10.5), }, { "string non-int values", []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, - "0", + float64(0), }, { "float values", @@ -86,11 +85,7 @@ func TestSum(t *testing.T) { result, err := buf.Eval(sql.NewEmptyContext()) require.NoError(err) - if dt, ok := result.(decimal.Decimal); ok { - require.Equal(tt.expected, dt.StringFixed(dt.Exponent()*-1)) - } else { - require.Equal(tt.expected, result) - } + require.Equal(tt.expected, result) }) } } @@ -112,24 +107,24 @@ func TestSumWithDistinct(t *testing.T) { { "string int values", []sql.Row{{"1"}, {"1"}, {"2"}, {"2"}, {"3"}, {"3"}, {"4"}, {"4"}}, - "10", + float64(10), }, // TODO : DISTINCT returns incorrect result, it currently returns 11.00 // https://github.com/dolthub/dolt/issues/4298 //{ // "string int values", // []sql.Row{{"1.00"}, {"1"}, {"2"}, {"2"}, {"3"}, {"3"}, {"4"}, {"4"}}, - // "10", + // float64(10), //}, { "string float values", []sql.Row{{"1.5"}, {"1.5"}, {"1.5"}, {"1.5"}, {"2"}, {"3"}, {"4"}}, - "10.5", + float64(10.5), }, { "string non-int values", []sql.Row{{"a"}, {"b"}, {"b"}, {"c"}, {"c"}, {"d"}}, - "0", + float64(0), }, { "float values", @@ -170,11 +165,7 @@ func TestSumWithDistinct(t *testing.T) { result, err := buf.Eval(sql.NewEmptyContext()) require.NoError(err) - if dt, ok := result.(decimal.Decimal); ok { - require.Equal(tt.expected, dt.StringFixed(dt.Exponent()*-1)) - } else { - require.Equal(tt.expected, result) - } + require.Equal(tt.expected, result) }) } } diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 78be8e1d4f..570bd6660b 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -18,7 +18,7 @@ type sumBuffer struct { } func NewSumBuffer(child sql.Expression) *sumBuffer { - return &sumBuffer{true, nil, child} + return &sumBuffer{true, float64(0), child} } // Update implements the AggregationBuffer interface. @@ -32,6 +32,12 @@ func (m *sumBuffer) Update(ctx *sql.Context, row sql.Row) error { return nil } + m.PerformSum(v) + + return nil +} + +func (m *sumBuffer) PerformSum(v interface{}) { // decimal.Decimal values are evaluated to string value even though the Literal expr type is Decimal type, // so convert it to appropriate Decimal type if s, isStr := v.(string); isStr && sql.IsDecimal(m.expr.Type()) { @@ -67,8 +73,6 @@ func (m *sumBuffer) Update(ctx *sql.Context, row sql.Row) error { } m.sum = sum.(float64) + val.(float64) } - - return nil } // Eval implements the AggregationBuffer interface. @@ -140,10 +144,16 @@ func NewAvgBuffer(child sql.Expression) *avgBuffer { // Update implements the AggregationBuffer interface. func (a *avgBuffer) Update(ctx *sql.Context, row sql.Row) error { - err := a.sum.Update(ctx, row) + v, err := a.expr.Eval(ctx, row) if err != nil { return err } + + if v == nil { + return nil + } + + a.sum.PerformSum(v) a.rows += 1 return nil From 50d00f4f249bc50922a3394a14e51c0f0c32e190 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 13 Sep 2022 12:10:45 -0700 Subject: [PATCH 21/21] add test to skipped test --- enginetest/queries/queries.go | 12 ++++++++++-- sql/expression/convert.go | 3 --- sql/expression/literal.go | 4 ++++ sql/stringtype.go | 2 +- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index e8c6659dab..4db08960fb 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -1426,9 +1426,9 @@ var QueryTests = []QueryTest{ }, }, { - Query: "with recursive t (n) as (select sum(1) from dual union all select (2.00) from dual) select sum(n) from t;", + Query: "with recursive t (n) as (select sum(1) from dual union all select ('2.00') from dual) select sum(n) from t;", Expected: []sql.Row{ - {"3.00"}, + {float64(3)}, }, }, { @@ -7439,6 +7439,14 @@ var BrokenQueries = []QueryTest{ Query: "STR_TO_DATE('2013 32 Tuesday', '%X %V %W')", // Tuesday of 32th week Expected: []sql.Row{{"2013-08-13"}}, }, + // mergeUnionSchemas adds convert the decimal value to cast to string, which loses decimal type info. + // https://github.com/dolthub/dolt/issues/4331 + { + Query: "with recursive t (n) as (select sum(1) from dual union all select (2.00) from dual) select sum(n) from t;", + Expected: []sql.Row{ + {"3.00"}, + }, + }, } var VersionedQueries = []QueryTest{ diff --git a/sql/expression/convert.go b/sql/expression/convert.go index f9ccec03a7..9876854cf6 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -166,9 +166,6 @@ func convertValue(val interface{}, castTo string, originType sql.Type) (interfac } return b, nil case ConvertToChar, ConvertToNChar: - if sql.IsDecimal(originType) { - return val, nil - } s, err := sql.LongText.Convert(val) if err != nil { return nil, nil diff --git a/sql/expression/literal.go b/sql/expression/literal.go index cb2e73f15b..641c1c011b 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -17,6 +17,8 @@ package expression import ( "fmt" + "github.com/shopspring/decimal" + "github.com/dolthub/go-mysql-server/sql" ) @@ -66,6 +68,8 @@ func (lit *Literal) String() string { return fmt.Sprintf("%d", v) case string: return fmt.Sprintf("'%s'", v) + case decimal.Decimal: + return v.StringFixed(v.Exponent() * -1) case []byte: return "BLOB" case nil: diff --git a/sql/stringtype.go b/sql/stringtype.go index 3d97017bac..bf7d4db5a8 100644 --- a/sql/stringtype.go +++ b/sql/stringtype.go @@ -373,7 +373,7 @@ func ConvertToString(v interface{}, t StringType) (string, error) { case time.Time: val = s.Format(TimestampDatetimeLayout) case decimal.Decimal: - val = s.String() + val = s.StringFixed(s.Exponent() * -1) case decimal.NullDecimal: if !s.Valid { return "", nil