From b0c32a79eed47ba1399fa02e8c5ef9b89df0b777 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 24 Jul 2023 16:58:27 +0200 Subject: [PATCH 1/8] json: Add better parsing and weights logic Signed-off-by: Dirkjan Bussink --- go/mysql/json/helpers.go | 26 +++++- go/mysql/json/parser.go | 16 ++-- go/mysql/json/weights.go | 153 ++++++++++++++++++++++++++++++++++ go/mysql/json/weights_test.go | 28 +++++++ 4 files changed, 216 insertions(+), 7 deletions(-) create mode 100644 go/mysql/json/weights.go create mode 100644 go/mysql/json/weights_test.go diff --git a/go/mysql/json/helpers.go b/go/mysql/json/helpers.go index bc9995b48cb..aa4b6fff688 100644 --- a/go/mysql/json/helpers.go +++ b/go/mysql/json/helpers.go @@ -18,6 +18,8 @@ package json import ( "vitess.io/vitess/go/sqltypes" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vthash" ) @@ -25,7 +27,7 @@ const hashPrefixJSON = 0xCCBB func (v *Value) Hash(h *vthash.Hasher) { h.Write16(hashPrefixJSON) - _, _ = h.Write(v.ToRawBytes()) + _, _ = h.Write(v.WeightString(nil)) } func (v *Value) ToRawBytes() []byte { @@ -81,6 +83,28 @@ func NewOpaqueValue(raw string) *Value { return &Value{s: raw, t: TypeOpaque} } +func NewFromSQL(v sqltypes.Value) (*Value, error) { + switch { + case v.Type() == sqltypes.TypeJSON: + var p Parser + return p.ParseBytes(v.Raw()) + case v.IsSigned(): + return NewNumber(v.RawStr(), NumberTypeSigned), nil + case v.IsUnsigned(): + return NewNumber(v.RawStr(), NumberTypeUnsigned), nil + case v.IsDecimal(): + return NewNumber(v.RawStr(), NumberTypeDecimal), nil + case v.IsFloat(): + return NewNumber(v.RawStr(), NumberTypeFloat), nil + case v.IsText(): + return NewString(v.RawStr()), nil + case v.IsBinary(): + return NewBlob(v.RawStr()), nil + default: + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot coerce %v as a JSON type", v) + } +} + func (v *Value) Depth() int { max := func(a, b int) int { if a > b { diff --git a/go/mysql/json/parser.go b/go/mysql/json/parser.go index b660884508d..c9d71e010a8 100644 --- a/go/mysql/json/parser.go +++ b/go/mysql/json/parser.go @@ -704,6 +704,15 @@ func (v *Value) MarshalTime() string { return "" } +func (v *Value) marshalFloat(dst []byte) []byte { + f, _ := v.Float64() + buf := format.FormatFloat(f) + if bytes.IndexByte(buf, '.') == -1 && bytes.IndexByte(buf, 'e') == -1 { + buf = append(buf, '.', '0') + } + return append(dst, buf...) +} + // MarshalTo appends marshaled v to dst and returns the result. func (v *Value) MarshalTo(dst []byte) []byte { switch v.t { @@ -744,12 +753,7 @@ func (v *Value) MarshalTo(dst []byte) []byte { return dst case TypeNumber: if v.NumberType() == NumberTypeFloat { - f, _ := v.Float64() - buf := format.FormatFloat(f) - if bytes.IndexByte(buf, '.') == -1 && bytes.IndexByte(buf, 'e') == -1 { - buf = append(buf, '.', '0') - } - return append(dst, buf...) + return v.marshalFloat(dst) } return append(dst, v.s...) case TypeBoolean: diff --git a/go/mysql/json/weights.go b/go/mysql/json/weights.go new file mode 100644 index 00000000000..6b65465b239 --- /dev/null +++ b/go/mysql/json/weights.go @@ -0,0 +1,153 @@ +package json + +import ( + "encoding/binary" + "strings" + + "vitess.io/vitess/go/hack" + "vitess.io/vitess/go/mysql/fastparse" +) + +const ( + JSON_KEY_NULL = '\x00' + JSON_KEY_NUMBER_NEG = '\x01' + JSON_KEY_NUMBER_ZERO = '\x02' + JSON_KEY_NUMBER_POS = '\x03' + JSON_KEY_STRING = '\x04' + JSON_KEY_OBJECT = '\x05' + JSON_KEY_ARRAY = '\x06' + JSON_KEY_FALSE = '\x07' + JSON_KEY_TRUE = '\x08' + JSON_KEY_DATE = '\x09' + JSON_KEY_TIME = '\x0A' + JSON_KEY_DATETIME = '\x0B' + JSON_KEY_OPAQUE = '\x0C' +) + +// numericWeightString generates a fixed-width weight string for any JSON +// number. It requires the `num` representation to be normalized, otherwise +// the resulting string will not sort. +func (v *Value) numericWeightString(dst []byte, num string) []byte { + const MaxPadLength = 30 + + var ( + exponent string + exp int64 + significant string + negative bool + original = len(dst) + ) + + if num[0] == '-' { + negative = true + num = num[1:] + } + + if i := strings.IndexByte(num, 'e'); i >= 0 { + exponent = num[i+1:] + num = num[:i] + } + + significant = num + for len(significant) > 0 { + if significant[0] >= '1' && significant[0] <= '9' { + break + } + significant = significant[1:] + } + if len(significant) == 0 { + return append(dst, JSON_KEY_NUMBER_ZERO) + } + + if len(exponent) > 0 { + exp, _ = fastparse.ParseInt64(exponent, 10) + } else { + dec := strings.IndexByte(num, '.') + ofs := len(num) - len(significant) + if dec < 0 { + exp = int64(len(significant) - 1) + } else if ofs < dec { + exp = int64(dec - ofs - 1) + } else { + exp = int64(dec - ofs) + } + } + + if negative { + dst = append(dst, JSON_KEY_NUMBER_NEG) + dst = binary.BigEndian.AppendUint16(dst, uint16(-exp)^(1<<15)) + + for _, ch := range []byte(significant) { + if ch >= '0' && ch <= '9' { + dst = append(dst, '9'-ch+'0') + } + } + for len(dst)-original < MaxPadLength { + dst = append(dst, '9') + } + } else { + dst = append(dst, JSON_KEY_NUMBER_POS) + dst = binary.BigEndian.AppendUint16(dst, uint16(exp)^(1<<15)) + + for _, ch := range []byte(significant) { + if ch >= '0' && ch <= '9' { + dst = append(dst, ch) + } + } + for len(dst)-original < MaxPadLength { + dst = append(dst, '0') + } + } + + return dst +} + +func (v *Value) WeightString(dst []byte) []byte { + switch v.Type() { + case TypeNull: + dst = append(dst, JSON_KEY_NULL) + case TypeNumber: + if v.NumberType() == NumberTypeFloat { + f := v.marshalFloat(nil) + dst = v.numericWeightString(dst, hack.String(f)) + } else { + dst = v.numericWeightString(dst, v.s) + } + case TypeString: + dst = append(dst, JSON_KEY_STRING) + dst = append(dst, v.s...) + case TypeObject: + // MySQL compat: we follow the same behavior as MySQL does for weight strings in JSON, + // where Objects and Arrays are only sorted by their length and not by the values + // of their contents. + // Note that in MySQL, generating the weight string of a JSON Object or Array will actually + // print a warning in the logs! We're not printing anything. + dst = append(dst, JSON_KEY_OBJECT) + dst = binary.BigEndian.AppendUint32(dst, uint32(v.o.Len())) + case TypeArray: + dst = append(dst, JSON_KEY_ARRAY) + dst = binary.BigEndian.AppendUint32(dst, uint32(len(v.a))) + case TypeBoolean: + switch v { + case ValueTrue: + dst = append(dst, JSON_KEY_TRUE) + case ValueFalse: + dst = append(dst, JSON_KEY_FALSE) + default: + panic("invalid JSON Boolean") + } + case TypeDate: + dst = append(dst, JSON_KEY_DATE) + dst = append(dst, v.MarshalDate()...) + case TypeDateTime: + dst = append(dst, JSON_KEY_DATETIME) + dst = append(dst, v.MarshalDateTime()...) + case TypeTime: + dst = append(dst, JSON_KEY_TIME) + dst = append(dst, v.MarshalTime()...) + case TypeOpaque, TypeBit, TypeBlob: + dst = append(dst, JSON_KEY_OPAQUE) + dst = append(dst, v.s...) + } + return dst +} diff --git a/go/mysql/json/weights_test.go b/go/mysql/json/weights_test.go new file mode 100644 index 00000000000..9442a64aa06 --- /dev/null +++ b/go/mysql/json/weights_test.go @@ -0,0 +1,28 @@ +package json + +import ( + "bytes" + "testing" + + "vitess.io/vitess/go/mysql/format" +) + +func TestWeightStrings(t *testing.T) { + var cases = []struct { + l, r *Value + }{ + {NewNumber("-2.3742940301417033", NumberTypeFloat), NewNumber("-0.024384053736998118", NumberTypeFloat)}, + {NewNumber("2.3742940301417033", NumberTypeFloat), NewNumber("20.3742940301417033", NumberTypeFloat)}, + {NewNumber(string(format.FormatFloat(1000000000000000.0)), NumberTypeFloat), NewNumber("100000000000000000", NumberTypeDecimal)}, + } + + for _, tc := range cases { + l := tc.l.WeightString(nil) + r := tc.r.WeightString(nil) + + if bytes.Compare(l, r) >= 0 { + t.Errorf("expected %s < %s\nl = %v\n = %v\nr = %v\n = %v", + tc.l.String(), tc.r.String(), l, string(l), r, string(r)) + } + } +} From a736f831165f2bcc32059a2420ea7aefbdca9fcf Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 24 Jul 2023 17:10:19 +0200 Subject: [PATCH 2/8] datetime: Fix parsing integers into datetime A 0 time is still valid. Signed-off-by: Dirkjan Bussink --- go/mysql/datetime/parse.go | 2 +- go/mysql/datetime/parse_test.go | 51 +++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/go/mysql/datetime/parse.go b/go/mysql/datetime/parse.go index 1d94a9ba8a5..52861127cde 100644 --- a/go/mysql/datetime/parse.go +++ b/go/mysql/datetime/parse.go @@ -321,7 +321,7 @@ func ParseDateTimeInt64(i int64) (dt DateTime, ok bool) { if i == 0 { return dt, true } - if t == 0 || d == 0 { + if d == 0 { return dt, false } dt.Time, ok = ParseTimeInt64(t) diff --git a/go/mysql/datetime/parse_test.go b/go/mysql/datetime/parse_test.go index 6b5b489d167..7f2397621e1 100644 --- a/go/mysql/datetime/parse_test.go +++ b/go/mysql/datetime/parse_test.go @@ -17,6 +17,7 @@ limitations under the License. package datetime import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -290,3 +291,53 @@ func TestParseDateTime(t *testing.T) { }) } } + +func TestParseDateTimeInt64(t *testing.T) { + type datetime struct { + year int + month int + day int + hour int + minute int + second int + nanosecond int + } + tests := []struct { + input int64 + output datetime + l int + err bool + }{ + {input: 1, output: datetime{}, err: true}, + {input: 20221012000000, output: datetime{2022, 10, 12, 0, 0, 0, 0}}, + {input: 20221012112233, output: datetime{2022, 10, 12, 11, 22, 33, 0}}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%d", test.input), func(t *testing.T) { + got, ok := ParseDateTimeInt64(test.input) + if test.err { + if !got.IsZero() { + assert.Equal(t, test.output.year, got.Date.Year()) + assert.Equal(t, test.output.month, got.Date.Month()) + assert.Equal(t, test.output.day, got.Date.Day()) + assert.Equal(t, test.output.hour, got.Time.Hour()) + assert.Equal(t, test.output.minute, got.Time.Minute()) + assert.Equal(t, test.output.second, got.Time.Second()) + assert.Equal(t, test.output.nanosecond, got.Time.Nanosecond()) + } + assert.Falsef(t, ok, "did not fail to parse %s", test.input) + return + } + + require.True(t, ok) + assert.Equal(t, test.output.year, got.Date.Year()) + assert.Equal(t, test.output.month, got.Date.Month()) + assert.Equal(t, test.output.day, got.Date.Day()) + assert.Equal(t, test.output.hour, got.Time.Hour()) + assert.Equal(t, test.output.minute, got.Time.Minute()) + assert.Equal(t, test.output.second, got.Time.Second()) + assert.Equal(t, test.output.nanosecond, got.Time.Nanosecond()) + }) + } +} From 9f6c61321c44945c0047e5cf827d195beb49bf2e Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 25 Jul 2023 11:49:21 +0200 Subject: [PATCH 3/8] Fix parsing datetime A valid month value is at least 1, not at least 0. Signed-off-by: Dirkjan Bussink --- go/mysql/datetime/parse_test.go | 1 + go/mysql/datetime/timeparts.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/go/mysql/datetime/parse_test.go b/go/mysql/datetime/parse_test.go index 7f2397621e1..6ed342edfb3 100644 --- a/go/mysql/datetime/parse_test.go +++ b/go/mysql/datetime/parse_test.go @@ -236,6 +236,7 @@ func TestParseDateTime(t *testing.T) { {input: "20221012111213.123456", output: datetime{2022, 10, 12, 11, 12, 13, 123456000}, l: 6}, {input: "221012111213.123456", output: datetime{2022, 10, 12, 11, 12, 13, 123456000}, l: 6}, {input: "2022101211121321321312", output: datetime{2022, 10, 12, 11, 12, 13, 0}, err: true}, + {input: "3284004416225113510", output: datetime{}, err: true}, {input: "2012-12-31 11:30:45", output: datetime{2012, 12, 31, 11, 30, 45, 0}}, {input: "2012^12^31 11+30+45", output: datetime{2012, 12, 31, 11, 30, 45, 0}}, {input: "2012/12/31 11*30*45", output: datetime{2012, 12, 31, 11, 30, 45, 0}}, diff --git a/go/mysql/datetime/timeparts.go b/go/mysql/datetime/timeparts.go index 2c25acc9653..a774099a93a 100644 --- a/go/mysql/datetime/timeparts.go +++ b/go/mysql/datetime/timeparts.go @@ -48,7 +48,7 @@ func (tp *timeparts) toDateTime(prec int) (DateTime, int, bool) { if tp.yday > 0 { return DateTime{}, 0, false } else { - if tp.month < 0 { + if tp.month < 1 { tp.month = int(time.January) } if tp.day < 0 { From 9aab43f1580f82cc31d30054620da15b30c0efc9 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 25 Jul 2023 11:51:05 +0200 Subject: [PATCH 4/8] sqltypes: Use faster integer parsing logic Signed-off-by: Dirkjan Bussink --- go/sqltypes/bind_variables_test.go | 14 +- go/sqltypes/value.go | 22 ++- go/sqltypes/value_test.go | 10 +- go/vt/vtgate/engine/merge_sort_test.go | 4 +- .../vtgate/evalengine/api_arithmetic_test.go | 140 +++++++++++------- 5 files changed, 112 insertions(+), 78 deletions(-) diff --git a/go/sqltypes/bind_variables_test.go b/go/sqltypes/bind_variables_test.go index 40925d228a1..3e83a74a331 100644 --- a/go/sqltypes/bind_variables_test.go +++ b/go/sqltypes/bind_variables_test.go @@ -329,7 +329,7 @@ func TestValidateBindVarables(t *testing.T) { Value: []byte("a"), }, }, - err: `v: strconv.ParseInt: parsing "a": invalid syntax`, + err: `v: cannot parse int64 from "a"`, }, { in: map[string]*querypb.BindVariable{ "v": { @@ -340,7 +340,7 @@ func TestValidateBindVarables(t *testing.T) { }}, }, }, - err: `v: strconv.ParseInt: parsing "a": invalid syntax`, + err: `v: cannot parse int64 from "a"`, }} for _, tcase := range tcases { err := ValidateBindVariables(tcase.in) @@ -500,31 +500,31 @@ func TestValidateBindVariable(t *testing.T) { Type: querypb.Type_INT64, Value: []byte(InvalidNeg), }, - err: "out of range", + err: `cannot parse int64 from "-9223372036854775809": overflow`, }, { in: &querypb.BindVariable{ Type: querypb.Type_INT64, Value: []byte(InvalidPos), }, - err: "out of range", + err: `cannot parse int64 from "18446744073709551616": overflow`, }, { in: &querypb.BindVariable{ Type: querypb.Type_UINT64, Value: []byte("-1"), }, - err: "invalid syntax", + err: `cannot parse uint64 from "-1"`, }, { in: &querypb.BindVariable{ Type: querypb.Type_UINT64, Value: []byte(InvalidPos), }, - err: "out of range", + err: `cannot parse uint64 from "18446744073709551616": overflow`, }, { in: &querypb.BindVariable{ Type: querypb.Type_FLOAT64, Value: []byte("a"), }, - err: "invalid syntax", + err: `unparsed tail left after parsing float64 from "a"`, }, { in: &querypb.BindVariable{ Type: querypb.Type_EXPRESSION, diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index d4a017798f0..2d77c3b72ea 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -29,7 +29,8 @@ import ( "vitess.io/vitess/go/bytes2" "vitess.io/vitess/go/hack" - + "vitess.io/vitess/go/mysql/decimal" + "vitess.io/vitess/go/mysql/fastparse" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/vtrpc" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -74,17 +75,22 @@ type ( func NewValue(typ querypb.Type, val []byte) (v Value, err error) { switch { case IsSigned(typ): - if _, err := strconv.ParseInt(string(val), 10, 64); err != nil { + if _, err := fastparse.ParseInt64(hack.String(val), 10); err != nil { return NULL, err } return MakeTrusted(typ, val), nil case IsUnsigned(typ): - if _, err := strconv.ParseUint(string(val), 10, 64); err != nil { + if _, err := fastparse.ParseUint64(hack.String(val), 10); err != nil { + return NULL, err + } + return MakeTrusted(typ, val), nil + case IsFloat(typ): + if _, err := fastparse.ParseFloat64(hack.String(val)); err != nil { return NULL, err } return MakeTrusted(typ, val), nil - case IsFloat(typ) || typ == Decimal: - if _, err := strconv.ParseFloat(string(val), 64); err != nil { + case IsDecimal(typ): + if _, err := decimal.NewFromMySQL(val); err != nil { return NULL, err } return MakeTrusted(typ, val), nil @@ -286,7 +292,7 @@ func (v Value) ToInt64() (int64, error) { return 0, ErrIncompatibleTypeCast } - return strconv.ParseInt(v.RawStr(), 10, 64) + return fastparse.ParseInt64(v.RawStr(), 10) } func (v Value) ToInt32() (int32, error) { @@ -313,7 +319,7 @@ func (v Value) ToFloat64() (float64, error) { return 0, ErrIncompatibleTypeCast } - return strconv.ParseFloat(v.RawStr(), 64) + return fastparse.ParseFloat64(v.RawStr()) } // ToUint16 returns the value as MySQL would return it as a uint16. @@ -332,7 +338,7 @@ func (v Value) ToUint64() (uint64, error) { return 0, ErrIncompatibleTypeCast } - return strconv.ParseUint(v.RawStr(), 10, 64) + return fastparse.ParseUint64(v.RawStr(), 10) } func (v Value) ToUint32() (uint32, error) { diff --git a/go/sqltypes/value_test.go b/go/sqltypes/value_test.go index 82aea752480..86c751f3d0d 100644 --- a/go/sqltypes/value_test.go +++ b/go/sqltypes/value_test.go @@ -165,23 +165,23 @@ func TestNewValue(t *testing.T) { }, { inType: Int64, inVal: InvalidNeg, - outErr: "out of range", + outErr: `cannot parse int64 from "-9223372036854775809": overflow`, }, { inType: Int64, inVal: InvalidPos, - outErr: "out of range", + outErr: `cannot parse int64 from "18446744073709551616": overflow`, }, { inType: Uint64, inVal: "-1", - outErr: "invalid syntax", + outErr: `cannot parse uint64 from "-1"`, }, { inType: Uint64, inVal: InvalidPos, - outErr: "out of range", + outErr: `cannot parse uint64 from "18446744073709551616": overflow`, }, { inType: Float64, inVal: "a", - outErr: "invalid syntax", + outErr: `unparsed tail left after parsing float64 from "a"`, }, { inType: Expression, inVal: "a", diff --git a/go/vt/vtgate/engine/merge_sort_test.go b/go/vt/vtgate/engine/merge_sort_test.go index 74c8d320d2b..e8823e9e6d5 100644 --- a/go/vt/vtgate/engine/merge_sort_test.go +++ b/go/vt/vtgate/engine/merge_sort_test.go @@ -370,7 +370,7 @@ func TestMergeSortDataFailures(t *testing.T) { }} err := testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error { return nil }) - want := `strconv.ParseInt: parsing "2.1": invalid syntax` + want := `unparsed tail left after parsing int64 from "2.1": ".1"` require.EqualError(t, err, want) // Create a new VCursor because the previous MergeSort will still @@ -386,7 +386,7 @@ func TestMergeSortDataFailures(t *testing.T) { ), }} err = testMergeSort(shardResults, orderBy, func(qr *sqltypes.Result) error { return nil }) - want = `strconv.ParseInt: parsing "1.1": invalid syntax` + want = `unparsed tail left after parsing int64 from "1.1": ".1"` require.EqualError(t, err, want) } diff --git a/go/vt/vtgate/evalengine/api_arithmetic_test.go b/go/vt/vtgate/evalengine/api_arithmetic_test.go index 0a0abc84a30..7d5a6e00c71 100644 --- a/go/vt/vtgate/evalengine/api_arithmetic_test.go +++ b/go/vt/vtgate/evalengine/api_arithmetic_test.go @@ -117,12 +117,12 @@ func TestArithmetics(t *testing.T) { // testing for error for parsing float value to uint64 v1: TestValue(sqltypes.Uint64, "1.2"), v2: NewInt64(2), - err: "strconv.ParseUint: parsing \"1.2\": invalid syntax", + err: "unparsed tail left after parsing uint64 from \"1.2\": \".2\"", }, { // testing for error for parsing float value to uint64 v1: NewUint64(2), v2: TestValue(sqltypes.Uint64, "1.2"), - err: "strconv.ParseUint: parsing \"1.2\": invalid syntax", + err: "unparsed tail left after parsing uint64 from \"1.2\": \".2\"", }, { // uint64 - uint64 v1: NewUint64(8), @@ -253,11 +253,11 @@ func TestArithmetics(t *testing.T) { }, { v1: TestValue(sqltypes.Int64, "1.2"), v2: NewInt64(2), - err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", }, { v1: NewInt64(2), v2: TestValue(sqltypes.Int64, "1.2"), - err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", }, { // testing for uint64 overflow with max uint64 + int value v1: NewUint64(maxUint64), @@ -320,12 +320,12 @@ func TestArithmetics(t *testing.T) { // testing for error in types v1: TestValue(sqltypes.Int64, "1.2"), v2: NewInt64(2), - err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", }, { // testing for error in types v1: NewInt64(2), v2: TestValue(sqltypes.Int64, "1.2"), - err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", }, { // testing for uint/int v1: NewUint64(4), @@ -384,12 +384,12 @@ func TestArithmetics(t *testing.T) { // testing for error in types v1: TestValue(sqltypes.Int64, "1.2"), v2: NewInt64(2), - err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", }, { // testing for error in types v1: NewInt64(2), v2: TestValue(sqltypes.Int64, "1.2"), - err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", }, { // testing for uint*int v1: NewUint64(4), @@ -479,12 +479,12 @@ func TestNullSafeAdd(t *testing.T) { // Make sure underlying error is returned for LHS. v1: TestValue(sqltypes.Int64, "1.2"), v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unparsed tail left after parsing int64 from \"1.2\": \".2\""), }, { // Make sure underlying error is returned for RHS. v1: NewInt64(2), v2: TestValue(sqltypes.Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unparsed tail left after parsing int64 from \"1.2\": \".2\""), }, { // Make sure underlying error is returned while adding. v1: NewInt64(-1), @@ -540,7 +540,7 @@ func TestCast(t *testing.T) { }, { typ: sqltypes.Int24, v: TestValue(sqltypes.VarChar, "bad int"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseInt: parsing "bad int": invalid syntax`), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, `cannot parse int64 from "bad int"`), }, { typ: sqltypes.Uint64, v: TestValue(sqltypes.Uint32, "32"), @@ -552,7 +552,7 @@ func TestCast(t *testing.T) { }, { typ: sqltypes.Uint24, v: TestValue(sqltypes.Int64, "-1"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseUint: parsing "-1": invalid syntax`), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, `cannot parse uint64 from "-1"`), }, { typ: sqltypes.Float64, v: TestValue(sqltypes.Int64, "64"), @@ -572,7 +572,7 @@ func TestCast(t *testing.T) { }, { typ: sqltypes.Float64, v: TestValue(sqltypes.VarChar, "bad float"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseFloat: parsing "bad float": invalid syntax`), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, `unparsed tail left after parsing float64 from "bad float": "bad float"`), }, { typ: sqltypes.VarChar, v: TestValue(sqltypes.Int64, "64"), @@ -701,7 +701,7 @@ func TestToFloat64(t *testing.T) { out: 1.2, }, { v: TestValue(sqltypes.Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unparsed tail left after parsing int64 from \"1.2\": \".2\""), }} for _, tcase := range tcases { t.Run(tcase.v.String(), func(t *testing.T) { @@ -847,11 +847,11 @@ func TestNewIntegralNumeric(t *testing.T) { }, { // Only valid Int64 allowed if type is Int64. v: TestValue(sqltypes.Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unparsed tail left after parsing int64 from \"1.2\": \".2\""), }, { // Only valid Uint64 allowed if type is Uint64. v: TestValue(sqltypes.Uint64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unparsed tail left after parsing uint64 from \"1.2\": \".2\""), }, { v: TestValue(sqltypes.VarChar, "abcd"), err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), @@ -1147,17 +1147,24 @@ func TestMin(t *testing.T) { err: vterrors.New(vtrpcpb.Code_UNKNOWN, "cannot compare strings, collation is unknown or unsupported (collation ID: 0)"), }} for _, tcase := range tcases { - v, err := Min(tcase.v1, tcase.v2, collations.Unknown) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Min error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } + t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { + v, err := Min(tcase.v1, tcase.v2, collations.Unknown) + if tcase.err == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + } + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Min error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + return + } - if !reflect.DeepEqual(v, tcase.min) { - t.Errorf("Min(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.min) - } + if !reflect.DeepEqual(v, tcase.min) { + t.Errorf("Min(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.min) + } + }) } } @@ -1205,17 +1212,24 @@ func TestMinCollate(t *testing.T) { }, } for _, tcase := range tcases { - got, err := Min(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), tcase.collation) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } + t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { + got, err := Min(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), tcase.collation) + if tcase.err == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + } + if !vterrors.Equals(err, tcase.err) { + t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + return + } - if got.ToString() == tcase.out { - t.Errorf("NullsafeCompare(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) - } + if got.ToString() == tcase.out { + t.Errorf("NullsafeCompare(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) + } + }) } } @@ -1254,17 +1268,24 @@ func TestMax(t *testing.T) { err: vterrors.New(vtrpcpb.Code_UNKNOWN, "cannot compare strings, collation is unknown or unsupported (collation ID: 0)"), }} for _, tcase := range tcases { - v, err := Max(tcase.v1, tcase.v2, collations.Unknown) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Max error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } + t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { + v, err := Max(tcase.v1, tcase.v2, collations.Unknown) + if tcase.err == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + } + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Max error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + return + } - if !reflect.DeepEqual(v, tcase.max) { - t.Errorf("Max(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.max) - } + if !reflect.DeepEqual(v, tcase.max) { + t.Errorf("Max(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.max) + } + }) } } @@ -1312,17 +1333,24 @@ func TestMaxCollate(t *testing.T) { }, } for _, tcase := range tcases { - got, err := Max(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), tcase.collation) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } + t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { + got, err := Max(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), tcase.collation) + if tcase.err == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + } + if !vterrors.Equals(err, tcase.err) { + t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + return + } - if got.ToString() != tcase.out { - t.Errorf("NullsafeCompare(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) - } + if got.ToString() != tcase.out { + t.Errorf("NullsafeCompare(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) + } + }) } } From 07a777fd7dd92a0953bcff36ce7eac9cbe592320 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 25 Jul 2023 12:07:49 +0200 Subject: [PATCH 5/8] vindexes: Fix collation passing into vindex comparisons Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/engine/hash_join.go | 2 +- go/vt/vtgate/vindexes/consistent_lookup.go | 4 +--- go/vt/vtgate/vindexes/consistent_lookup_test.go | 10 ++++++++-- go/vt/vtgate/vindexes/lookup_test.go | 5 +++++ go/vt/vtgate/vindexes/vindex.go | 2 ++ 5 files changed, 17 insertions(+), 6 deletions(-) diff --git a/go/vt/vtgate/engine/hash_join.go b/go/vt/vtgate/engine/hash_join.go index 1fb889c8fd4..89f552f98c1 100644 --- a/go/vt/vtgate/engine/hash_join.go +++ b/go/vt/vtgate/engine/hash_join.go @@ -98,7 +98,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma for _, currentLHSRow := range lftRows { lhsVal := currentLHSRow[hj.LHSKey] // hash codes can give false positives, so we need to check with a real comparison as well - cmp, err := evalengine.NullsafeCompare(joinVal, lhsVal, collations.Unknown) + cmp, err := evalengine.NullsafeCompare(joinVal, lhsVal, hj.Collation) if err != nil { return nil, err } diff --git a/go/vt/vtgate/vindexes/consistent_lookup.go b/go/vt/vtgate/vindexes/consistent_lookup.go index 553742b13cc..4836b0d6502 100644 --- a/go/vt/vtgate/vindexes/consistent_lookup.go +++ b/go/vt/vtgate/vindexes/consistent_lookup.go @@ -23,7 +23,6 @@ import ( "fmt" "vitess.io/vitess/go/mysql" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/sqltypes" @@ -413,8 +412,7 @@ func (lu *clCommon) Delete(ctx context.Context, vcursor VCursor, rowsColValues [ func (lu *clCommon) Update(ctx context.Context, vcursor VCursor, oldValues []sqltypes.Value, ksid []byte, newValues []sqltypes.Value) error { equal := true for i := range oldValues { - // TODO(king-11) make collation aware - result, err := evalengine.NullsafeCompare(oldValues[i], newValues[i], collations.Unknown) + result, err := evalengine.NullsafeCompare(oldValues[i], newValues[i], vcursor.ConnCollation()) // errors from NullsafeCompare can be ignored. if they are real problems, we'll see them in the Create/Update if err != nil || result != 0 { equal = false diff --git a/go/vt/vtgate/vindexes/consistent_lookup_test.go b/go/vt/vtgate/vindexes/consistent_lookup_test.go index 17adbcf748f..59035776edd 100644 --- a/go/vt/vtgate/vindexes/consistent_lookup_test.go +++ b/go/vt/vtgate/vindexes/consistent_lookup_test.go @@ -29,6 +29,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" @@ -451,7 +453,7 @@ func TestConsistentLookupNoUpdate(t *testing.T) { vc.verifyLog(t, []string{}) } -func TestConsistentLookupUpdateBecauseUncomparableTypes(t *testing.T) { +func TestConsistentLookupUpdateBecauseComparableTypes(t *testing.T) { lookup := createConsistentLookup(t, "consistent_lookup", false) vc := &loggingVCursor{} @@ -475,7 +477,7 @@ func TestConsistentLookupUpdateBecauseUncomparableTypes(t *testing.T) { err = lookup.(Lookup).Update(context.Background(), vc, []sqltypes.Value{literal, literal}, []byte("test"), []sqltypes.Value{literal, literal}) require.NoError(t, err) - require.NotEmpty(t, vc.log) + vc.verifyLog(t, []string{}) vc.log = nil }) } @@ -524,6 +526,10 @@ func (vc *loggingVCursor) InTransactionAndIsDML() bool { return false } +func (vc *loggingVCursor) ConnCollation() collations.ID { + return collations.Default() +} + type bv struct { Name string Bv string diff --git a/go/vt/vtgate/vindexes/lookup_test.go b/go/vt/vtgate/vindexes/lookup_test.go index 1051a394787..a59fcbf1da9 100644 --- a/go/vt/vtgate/vindexes/lookup_test.go +++ b/go/vt/vtgate/vindexes/lookup_test.go @@ -22,6 +22,7 @@ import ( "strings" "testing" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/test/utils" "github.com/stretchr/testify/assert" @@ -113,6 +114,10 @@ func (vc *vcursor) execute(query string, bindvars map[string]*querypb.BindVariab panic("unexpected") } +func (vc *vcursor) ConnCollation() collations.ID { + return collations.Default() +} + func lookupCreateVindexTestCase( testName string, vindexParams map[string]string, diff --git a/go/vt/vtgate/vindexes/vindex.go b/go/vt/vtgate/vindexes/vindex.go index 141e1e61efe..a5295681248 100644 --- a/go/vt/vtgate/vindexes/vindex.go +++ b/go/vt/vtgate/vindexes/vindex.go @@ -21,6 +21,7 @@ import ( "fmt" "sort" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/sqlparser" @@ -42,6 +43,7 @@ type ( ExecuteKeyspaceID(ctx context.Context, keyspace string, ksid []byte, query string, bindVars map[string]*querypb.BindVariable, rollbackOnError, autocommit bool) (*sqltypes.Result, error) InTransactionAndIsDML() bool LookupRowLockShardSession() vtgatepb.CommitOrder + ConnCollation() collations.ID } // Vindex defines the interface required to register a vindex. From 7a366ee3bd0998f771816f60a92f23e8c1fc2b17 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 25 Jul 2023 12:28:23 +0200 Subject: [PATCH 6/8] evalengine: Add broader support for type comparisons This implements additional type support for the Nullsafe* family of functions. It implements a fast path for common types with equal coercion and then falls back to the generic evalengine logic for all other cases. Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/api_compare.go | 144 +++++++++++------- go/vt/vtgate/evalengine/api_compare_test.go | 27 ++-- go/vt/vtgate/evalengine/api_hash.go | 54 +++++-- go/vt/vtgate/evalengine/api_hash_test.go | 159 +++++++++++++++----- go/vt/vtgate/evalengine/api_types.go | 38 ----- go/vt/vtgate/evalengine/compare.go | 2 +- go/vt/vtgate/evalengine/eval.go | 136 +++++++++++++++-- 7 files changed, 392 insertions(+), 168 deletions(-) diff --git a/go/vt/vtgate/evalengine/api_compare.go b/go/vt/vtgate/evalengine/api_compare.go index 9ecc03a3c6f..7289af093f2 100644 --- a/go/vt/vtgate/evalengine/api_compare.go +++ b/go/vt/vtgate/evalengine/api_compare.go @@ -85,23 +85,6 @@ func minmax(v1, v2 sqltypes.Value, min bool, collation collations.ID) (sqltypes. return v2, nil } -// isByteComparable returns true if the type is binary or date/time. -func isByteComparable(typ sqltypes.Type, collationID collations.ID) bool { - if sqltypes.IsBinary(typ) { - return true - } - if sqltypes.IsText(typ) { - return collationID == collations.CollationBinaryID - } - switch typ { - case sqltypes.Timestamp, sqltypes.Date, sqltypes.Time, sqltypes.Datetime, sqltypes.Enum, - sqltypes.Set, sqltypes.TypeJSON, sqltypes.Bit, sqltypes.Geometry: - return true - default: - return false - } -} - // NullsafeCompare returns 0 if v1==v2, -1 if v1v2. // NULL is the lowest value. If any value is // numeric, then a numeric comparison is performed after @@ -122,52 +105,99 @@ func NullsafeCompare(v1, v2 sqltypes.Value, collationID collations.ID) (int, err return 1, nil } - if isByteComparable(v1.Type(), collationID) && isByteComparable(v2.Type(), collationID) { - return bytes.Compare(v1.Raw(), v2.Raw()), nil + // We have a fast path here for the case where both values are + // the same type, and it's one of the basic types we can compare + // directly. This is a common case for equality checks. + if v1.Type() == v2.Type() { + switch { + case sqltypes.IsSigned(v1.Type()): + i1, err := v1.ToInt64() + if err != nil { + return 0, err + } + i2, err := v2.ToInt64() + if err != nil { + return 0, err + } + switch { + case i1 < i2: + return -1, nil + case i1 > i2: + return 1, nil + default: + return 0, nil + } + case sqltypes.IsUnsigned(v1.Type()): + u1, err := v1.ToUint64() + if err != nil { + return 0, err + } + u2, err := v2.ToUint64() + if err != nil { + return 0, err + } + switch { + case u1 < u2: + return -1, nil + case u1 > u2: + return 1, nil + default: + return 0, nil + } + case sqltypes.IsBinary(v1.Type()), v1.Type() == sqltypes.Date, + v1.Type() == sqltypes.Datetime, v1.Type() == sqltypes.Timestamp: + // We can't optimize for Time here, since Time is not sortable + // based on the raw bytes. This is because of cases like + // '24:00:00' and '101:00:00' which are both valid times and + // order wrong based on the raw bytes. + return bytes.Compare(v1.Raw(), v2.Raw()), nil + case sqltypes.IsText(v1.Type()): + if collationID == collations.CollationBinaryID { + return bytes.Compare(v1.Raw(), v2.Raw()), nil + } + coll := collationID.Get() + if coll == nil { + return 0, UnsupportedCollationError{ID: collationID} + } + result := coll.Collate(v1.Raw(), v2.Raw(), false) + switch { + case result < 0: + return -1, nil + case result > 0: + return 1, nil + default: + return 0, nil + } + } } - typ, err := CoerceTo(v1.Type(), v2.Type()) // TODO systay we should add a method where this decision is done at plantime + v1eval, err := valueToEval(v1, collations.TypedCollation{ + Collation: collationID, + Coercibility: collations.CoerceImplicit, + Repertoire: collations.RepertoireUnicode, + }) if err != nil { return 0, err } - switch { - case sqltypes.IsText(typ): - collation := collationID.Get() - if collation == nil { - return 0, UnsupportedCollationError{ID: collationID} - } - - v1Bytes, err := v1.ToBytes() - if err != nil { - return 0, err - } - v2Bytes, err := v2.ToBytes() - if err != nil { - return 0, err - } - - switch result := collation.Collate(v1Bytes, v2Bytes, false); { - case result < 0: - return -1, nil - case result > 0: - return 1, nil - default: - return 0, nil - } - - case sqltypes.IsNumber(typ): - v1cast, err := valueToEvalCast(v1, typ) - if err != nil { - return 0, err - } - v2cast, err := valueToEvalCast(v2, typ) - if err != nil { - return 0, err - } - return compareNumeric(v1cast, v2cast) + v2eval, err := valueToEval(v2, collations.TypedCollation{ + Collation: collationID, + Coercibility: collations.CoerceImplicit, + Repertoire: collations.RepertoireUnicode, + }) + if err != nil { + return 0, err + } - default: - return 0, UnsupportedComparisonError{Type1: v1.Type(), Type2: v2.Type()} + out, err := evalCompare(v1eval, v2eval) + if err != nil { + return 0, err + } + if out == 0 { + return 0, nil + } + if out > 0 { + return 1, nil } + return -1, nil } diff --git a/go/vt/vtgate/evalengine/api_compare_test.go b/go/vt/vtgate/evalengine/api_compare_test.go index 603f3ebe676..70e83dda6be 100644 --- a/go/vt/vtgate/evalengine/api_compare_test.go +++ b/go/vt/vtgate/evalengine/api_compare_test.go @@ -1214,17 +1214,24 @@ func TestNullsafeCompareCollate(t *testing.T) { }, } for _, tcase := range tcases { - got, err := NullsafeCompare(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), tcase.collation) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } + t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { + got, err := NullsafeCompare(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), tcase.collation) + if tcase.err == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + } + if !vterrors.Equals(err, tcase.err) { + t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + return + } - if got != tcase.out { - t.Errorf("NullsafeCompare(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) - } + if got != tcase.out { + t.Errorf("NullsafeCompare(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) + } + }) } } diff --git a/go/vt/vtgate/evalengine/api_hash.go b/go/vt/vtgate/evalengine/api_hash.go index 60dce8232d2..5d5aac98457 100644 --- a/go/vt/vtgate/evalengine/api_hash.go +++ b/go/vt/vtgate/evalengine/api_hash.go @@ -34,7 +34,7 @@ type HashCode = uint64 // NullsafeHashcode returns an int64 hashcode that is guaranteed to be the same // for two values that are considered equal by `NullsafeCompare`. func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType sqltypes.Type) (HashCode, error) { - e, err := valueToEvalCast(v, coerceType) + e, err := valueToEvalCast(v, coerceType, collation) if err != nil { return 0, err } @@ -93,10 +93,10 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat f = float64(uval) case v.IsFloat() || v.IsDecimal(): f, err = v.ToFloat64() - case v.IsQuoted(): + case v.IsText(), v.IsBinary(): f, _ = fastparse.ParseFloat64(v.RawStr()) default: - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected type %v", v.Type()) + return nullsafeHashcode128Default(hash, v, collation, coerceTo) } if err != nil { return err @@ -107,10 +107,12 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat case sqltypes.IsSigned(coerceTo): var i int64 var err error + var neg bool switch { case v.IsSigned(): i, err = v.ToInt64() + neg = i < 0 case v.IsUnsigned(): var uval uint64 uval, err = v.ToUint64() @@ -122,7 +124,8 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat return ErrHashCoercionIsNotExact } i = int64(fval) - case v.IsQuoted(): + neg = i < 0 + case v.IsText(), v.IsBinary(): i, err = fastparse.ParseInt64(v.RawStr(), 10) if err != nil { fval, _ := fastparse.ParseFloat64(v.RawStr()) @@ -131,13 +134,14 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat } i, err = int64(fval), nil } + neg = i < 0 default: - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected type %v", v.Type()) + return nullsafeHashcode128Default(hash, v, collation, coerceTo) } if err != nil { return err } - if i < 0 { + if neg { hash.Write16(hashPrefixIntegralNegative) } else { hash.Write16(hashPrefixIntegralPositive) @@ -147,11 +151,12 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat case sqltypes.IsUnsigned(coerceTo): var u uint64 var err error - + var neg bool switch { case v.IsSigned(): var ival int64 ival, err = v.ToInt64() + neg = ival < 0 u = uint64(ival) case v.IsUnsigned(): u, err = v.ToUint64() @@ -161,23 +166,29 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat if fval != math.Trunc(fval) || fval < 0 { return ErrHashCoercionIsNotExact } + neg = fval < 0 u = uint64(fval) - case v.IsQuoted(): + case v.IsText(), v.IsBinary(): u, err = fastparse.ParseUint64(v.RawStr(), 10) if err != nil { fval, _ := fastparse.ParseFloat64(v.RawStr()) if fval != math.Trunc(fval) || fval < 0 { return ErrHashCoercionIsNotExact } + neg = fval < 0 u, err = uint64(fval), nil } default: - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected type %v", v.Type()) + return nullsafeHashcode128Default(hash, v, collation, coerceTo) } if err != nil { return err } - hash.Write16(hashPrefixIntegralPositive) + if neg { + hash.Write16(hashPrefixIntegralNegative) + } else { + hash.Write16(hashPrefixIntegralPositive) + } hash.Write64(u) case sqltypes.IsBinary(coerceTo): @@ -211,13 +222,30 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat fval, _ := fastparse.ParseFloat64(v.RawStr()) dec = decimal.NewFromFloat(fval) default: - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a decimal: %v", v) + return nullsafeHashcode128Default(hash, v, collation, coerceTo) } hash.Write16(hashPrefixDecimal) dec.Hash(hash) - default: - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected type %v", v.Type()) + return nullsafeHashcode128Default(hash, v, collation, coerceTo) } return nil } + +func nullsafeHashcode128Default(hash *vthash.Hasher, v sqltypes.Value, collation collations.ID, coerceTo sqltypes.Type) error { + // Slow path to handle all other types. This uses the generic + // logic for value casting to ensure we match MySQL here. + e, err := valueToEvalCast(v, coerceTo, collation) + if err != nil { + return err + } + switch e := e.(type) { + case nil: + hash.Write16(hashPrefixNil) + return nil + case hashable: + e.Hash(hash) + return nil + } + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected type %v", coerceTo) +} diff --git a/go/vt/vtgate/evalengine/api_hash_test.go b/go/vt/vtgate/evalengine/api_hash_test.go index 55200eb0619..96b7dbac424 100644 --- a/go/vt/vtgate/evalengine/api_hash_test.go +++ b/go/vt/vtgate/evalengine/api_hash_test.go @@ -25,6 +25,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vthash" "vitess.io/vitess/go/mysql/collations" @@ -37,24 +40,33 @@ func TestHashCodes(t *testing.T) { equal bool err error }{ - {sqltypes.NewInt64(-1), sqltypes.NewUint64(^uint64(0)), true, nil}, - {sqltypes.NewUint64(^uint64(0)), sqltypes.NewInt64(-1), true, nil}, {sqltypes.NewFloat64(-1), sqltypes.NewVarChar("-1"), true, nil}, {sqltypes.NewDecimal("-1"), sqltypes.NewVarChar("-1"), true, nil}, + {sqltypes.NewDate("2000-01-01"), sqltypes.NewInt64(20000101), true, nil}, + {sqltypes.NewDatetime("2000-01-01 11:22:33"), sqltypes.NewInt64(20000101112233), true, nil}, + {sqltypes.NewTime("11:22:33"), sqltypes.NewInt64(112233), true, nil}, + {sqltypes.NewInt64(20000101), sqltypes.NewDate("2000-01-01"), true, nil}, + {sqltypes.NewInt64(20000101112233), sqltypes.NewDatetime("2000-01-01 11:22:33"), true, nil}, + {sqltypes.NewInt64(112233), sqltypes.NewTime("11:22:33"), true, nil}, + {sqltypes.MakeTrusted(sqltypes.TypeJSON, []byte(`{"1": "foo", "2": "bar"}`)), sqltypes.MakeTrusted(sqltypes.TypeJSON, []byte(`{"2": "bar", "1": "foo"}`)), true, nil}, + {sqltypes.MakeTrusted(sqltypes.TypeJSON, []byte(`{"1": "foo", "2": "bar"}`)), sqltypes.NewVarChar(`{"2": "bar", "1": "foo"}`), false, nil}, + {sqltypes.NewVarChar(`{"2": "bar", "1": "foo"}`), sqltypes.MakeTrusted(sqltypes.TypeJSON, []byte(`{"1": "foo", "2": "bar"}`)), false, nil}, } for _, tc := range cases { - cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.CollationUtf8mb4ID) - require.NoError(t, err) - require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) + t.Run(fmt.Sprintf("%v %s %v", tc.static, equality(tc.equal).Operator(), tc.dynamic), func(t *testing.T) { + cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.CollationUtf8mb4ID) + require.NoError(t, err) + require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) - h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type()) - require.NoError(t, err) + h1, err := NullsafeHashcode(tc.static, collations.CollationUtf8mb4ID, tc.static.Type()) + require.NoError(t, err) - h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type()) - require.ErrorIs(t, err, tc.err) + h2, err := NullsafeHashcode(tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type()) + require.ErrorIs(t, err, tc.err) - assert.Equalf(t, tc.equal, h1 == h2, "HASH(%v) %s HASH(%v) (expected %s)", tc.static, equality(h1 == h2).Operator(), tc.dynamic, equality(tc.equal)) + assert.Equalf(t, tc.equal, h1 == h2, "HASH(%v) %s HASH(%v) (expected %s)", tc.static, equality(h1 == h2).Operator(), tc.dynamic, equality(tc.equal)) + }) } } @@ -70,7 +82,7 @@ func TestHashCodesRandom(t *testing.T) { v1, v2 := randomValues() cmp, err := NullsafeCompare(v1, v2, collation) require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) - typ, err := CoerceTo(v1.Type(), v2.Type()) + typ, err := coerceTo(v1.Type(), v2.Type()) require.NoError(t, err) hash1, err := NullsafeHashcode(v1, collation, typ) @@ -107,32 +119,43 @@ func TestHashCodes128(t *testing.T) { equal bool err error }{ - {sqltypes.NewInt64(-1), sqltypes.NewUint64(^uint64(0)), true, nil}, - {sqltypes.NewUint64(^uint64(0)), sqltypes.NewInt64(-1), true, nil}, + {sqltypes.NewInt64(-1), sqltypes.NewUint64(^uint64(0)), false, nil}, + {sqltypes.NewUint64(^uint64(0)), sqltypes.NewInt64(-1), false, nil}, {sqltypes.NewInt64(-1), sqltypes.NewVarChar("-1"), true, nil}, {sqltypes.NewVarChar("-1"), sqltypes.NewInt64(-1), true, nil}, {sqltypes.NewInt64(23), sqltypes.NewFloat64(23.0), true, nil}, {sqltypes.NewInt64(23), sqltypes.NewFloat64(23.1), false, ErrHashCoercionIsNotExact}, {sqltypes.NewUint64(^uint64(0)), sqltypes.NewFloat64(-1.0), false, ErrHashCoercionIsNotExact}, {sqltypes.NewUint64(42), sqltypes.NewFloat64(42.0), true, nil}, + {sqltypes.NewDate("2000-01-01"), sqltypes.NewInt64(20000101), true, nil}, + {sqltypes.NewDatetime("2000-01-01 11:22:33"), sqltypes.NewInt64(20000101112233), true, nil}, + {sqltypes.NewTime("11:22:33"), sqltypes.NewInt64(112233), true, nil}, + {sqltypes.NewInt64(20000101), sqltypes.NewDate("2000-01-01"), true, nil}, + {sqltypes.NewInt64(20000101112233), sqltypes.NewDatetime("2000-01-01 11:22:33"), true, nil}, + {sqltypes.NewInt64(112233), sqltypes.NewTime("11:22:33"), true, nil}, + {sqltypes.MakeTrusted(sqltypes.TypeJSON, []byte(`{"1": "foo", "2": "bar"}`)), sqltypes.MakeTrusted(sqltypes.TypeJSON, []byte(`{"2": "bar", "1": "foo"}`)), true, nil}, + {sqltypes.MakeTrusted(sqltypes.TypeJSON, []byte(`{"1": "foo", "2": "bar"}`)), sqltypes.NewVarChar(`{"2": "bar", "1": "foo"}`), false, nil}, + {sqltypes.NewVarChar(`{"2": "bar", "1": "foo"}`), sqltypes.MakeTrusted(sqltypes.TypeJSON, []byte(`{"1": "foo", "2": "bar"}`)), false, nil}, } for _, tc := range cases { - cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.CollationUtf8mb4ID) - require.NoError(t, err) - require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) - - hasher1 := vthash.New() - err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type()) - require.NoError(t, err) - - hasher2 := vthash.New() - err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type()) - require.ErrorIs(t, err, tc.err) - - h1 := hasher1.Sum128() - h2 := hasher2.Sum128() - assert.Equalf(t, tc.equal, h1 == h2, "HASH(%v) %s HASH(%v) (expected %s)", tc.static, equality(h1 == h2).Operator(), tc.dynamic, equality(tc.equal)) + t.Run(fmt.Sprintf("%v %s %v", tc.static, equality(tc.equal).Operator(), tc.dynamic), func(t *testing.T) { + cmp, err := NullsafeCompare(tc.static, tc.dynamic, collations.CollationUtf8mb4ID) + require.NoError(t, err) + require.Equalf(t, tc.equal, cmp == 0, "got %v %s %v (expected %s)", tc.static, equality(cmp == 0).Operator(), tc.dynamic, equality(tc.equal)) + + hasher1 := vthash.New() + err = NullsafeHashcode128(&hasher1, tc.static, collations.CollationUtf8mb4ID, tc.static.Type()) + require.NoError(t, err) + + hasher2 := vthash.New() + err = NullsafeHashcode128(&hasher2, tc.dynamic, collations.CollationUtf8mb4ID, tc.static.Type()) + require.ErrorIs(t, err, tc.err) + + h1 := hasher1.Sum128() + h2 := hasher2.Sum128() + assert.Equalf(t, tc.equal, h1 == h2, "HASH(%v) %s HASH(%v) (expected %s)", tc.static, equality(h1 == h2).Operator(), tc.dynamic, equality(tc.equal)) + }) } } @@ -148,7 +171,7 @@ func TestHashCodesRandom128(t *testing.T) { v1, v2 := randomValues() cmp, err := NullsafeCompare(v1, v2, collation) require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) - typ, err := CoerceTo(v1.Type(), v2.Type()) + typ, err := coerceTo(v1.Type(), v2.Type()) require.NoError(t, err) hasher1 := vthash.New() @@ -207,6 +230,10 @@ var randomGenerators = []func() sqltypes.Value{ randomVarChar, randomComplexVarChar, randomDecimal, + randomDate, + randomDatetime, + randomTimestamp, + randomTime, } func randomValue() sqltypes.Value { @@ -214,15 +241,73 @@ func randomValue() sqltypes.Value { return randomGenerators[r]() } -func randomNull() sqltypes.Value { return sqltypes.NULL } -func randomInt8() sqltypes.Value { return sqltypes.NewInt8(int8(rand.Intn(255))) } -func randomInt32() sqltypes.Value { return sqltypes.NewInt32(rand.Int31()) } -func randomInt64() sqltypes.Value { return sqltypes.NewInt64(rand.Int63()) } -func randomUint32() sqltypes.Value { return sqltypes.NewUint32(rand.Uint32()) } -func randomUint64() sqltypes.Value { return sqltypes.NewUint64(rand.Uint64()) } -func randomDecimal() sqltypes.Value { return sqltypes.NewDecimal(fmt.Sprintf("%d", rand.Int63())) } -func randomVarChar() sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf("%d", rand.Int63())) } +func randTime() time.Time { + min := time.Date(1970, 1, 0, 0, 0, 0, 0, time.UTC).Unix() + max := time.Date(2070, 1, 0, 0, 0, 0, 0, time.UTC).Unix() + delta := max - min + + sec := rand.Int63n(delta) + min + return time.Unix(sec, 0) +} + +func randomNull() sqltypes.Value { return sqltypes.NULL } +func randomInt8() sqltypes.Value { return sqltypes.NewInt8(int8(rand.Intn(255))) } +func randomInt32() sqltypes.Value { return sqltypes.NewInt32(rand.Int31()) } +func randomInt64() sqltypes.Value { return sqltypes.NewInt64(rand.Int63()) } +func randomUint32() sqltypes.Value { return sqltypes.NewUint32(rand.Uint32()) } +func randomUint64() sqltypes.Value { return sqltypes.NewUint64(rand.Uint64()) } +func randomDecimal() sqltypes.Value { return sqltypes.NewDecimal(fmt.Sprintf("%d", rand.Int63())) } +func randomVarChar() sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf("%d", rand.Int63())) } +func randomDate() sqltypes.Value { return sqltypes.NewDate(randTime().Format(time.DateOnly)) } +func randomDatetime() sqltypes.Value { return sqltypes.NewDatetime(randTime().Format(time.DateTime)) } +func randomTimestamp() sqltypes.Value { return sqltypes.NewTimestamp(randTime().Format(time.DateTime)) } +func randomTime() sqltypes.Value { return sqltypes.NewTime(randTime().Format(time.TimeOnly)) } func randomComplexVarChar() sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf(" \t %f apa", float64(rand.Intn(1000))*1.10)) } + +// coerceTo takes two input types, and decides how they should be coerced before compared +func coerceTo(v1, v2 sqltypes.Type) (sqltypes.Type, error) { + if v1 == v2 { + return v1, nil + } + if sqltypes.IsNull(v1) || sqltypes.IsNull(v2) { + return sqltypes.Null, nil + } + if (sqltypes.IsText(v1) || sqltypes.IsBinary(v1)) && (sqltypes.IsText(v2) || sqltypes.IsBinary(v2)) { + return sqltypes.VarChar, nil + } + if sqltypes.IsDateOrTime(v1) { + return v1, nil + } + if sqltypes.IsDateOrTime(v2) { + return v2, nil + } + + if sqltypes.IsNumber(v1) || sqltypes.IsNumber(v2) { + switch { + case sqltypes.IsText(v1) || sqltypes.IsBinary(v1) || sqltypes.IsText(v2) || sqltypes.IsBinary(v2): + return sqltypes.Float64, nil + case sqltypes.IsFloat(v2) || v2 == sqltypes.Decimal || sqltypes.IsFloat(v1) || v1 == sqltypes.Decimal: + return sqltypes.Float64, nil + case sqltypes.IsSigned(v1): + switch { + case sqltypes.IsUnsigned(v2): + return sqltypes.Uint64, nil + case sqltypes.IsSigned(v2): + return sqltypes.Int64, nil + default: + return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v vs %v", v1, v2) + } + case sqltypes.IsUnsigned(v1): + switch { + case sqltypes.IsSigned(v2) || sqltypes.IsUnsigned(v2): + return sqltypes.Uint64, nil + default: + return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v vs %v", v1, v2) + } + } + } + return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v vs %v", v1, v2) +} diff --git a/go/vt/vtgate/evalengine/api_types.go b/go/vt/vtgate/evalengine/api_types.go index 734f35b35aa..c0334da5784 100644 --- a/go/vt/vtgate/evalengine/api_types.go +++ b/go/vt/vtgate/evalengine/api_types.go @@ -24,44 +24,6 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -// CoerceTo takes two input types, and decides how they should be coerced before compared -func CoerceTo(v1, v2 sqltypes.Type) (sqltypes.Type, error) { - if v1 == v2 { - return v1, nil - } - if sqltypes.IsNull(v1) || sqltypes.IsNull(v2) { - return sqltypes.Null, nil - } - if (sqltypes.IsText(v1) || sqltypes.IsBinary(v1)) && (sqltypes.IsText(v2) || sqltypes.IsBinary(v2)) { - return sqltypes.VarChar, nil - } - if sqltypes.IsNumber(v1) || sqltypes.IsNumber(v2) { - switch { - case sqltypes.IsText(v1) || sqltypes.IsBinary(v1) || sqltypes.IsText(v2) || sqltypes.IsBinary(v2): - return sqltypes.Float64, nil - case sqltypes.IsFloat(v2) || v2 == sqltypes.Decimal || sqltypes.IsFloat(v1) || v1 == sqltypes.Decimal: - return sqltypes.Float64, nil - case sqltypes.IsSigned(v1): - switch { - case sqltypes.IsUnsigned(v2): - return sqltypes.Uint64, nil - case sqltypes.IsSigned(v2): - return sqltypes.Int64, nil - default: - return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v vs %v", v1, v2) - } - case sqltypes.IsUnsigned(v1): - switch { - case sqltypes.IsSigned(v2) || sqltypes.IsUnsigned(v2): - return sqltypes.Uint64, nil - default: - return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v vs %v", v1, v2) - } - } - } - return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v vs %v", v1, v2) -} - // Cast converts a Value to the target type. func Cast(v sqltypes.Value, typ sqltypes.Type) (sqltypes.Value, error) { if v.Type() == typ || v.IsNull() { diff --git a/go/vt/vtgate/evalengine/compare.go b/go/vt/vtgate/evalengine/compare.go index deee5fdb520..ef0cafb6127 100644 --- a/go/vt/vtgate/evalengine/compare.go +++ b/go/vt/vtgate/evalengine/compare.go @@ -139,7 +139,7 @@ func compareStrings(l, r eval) (int, error) { } collation := col.Collation.Get() if collation == nil { - panic("unknown collation after coercion") + return 0, vterrors.Errorf(vtrpcpb.Code_UNKNOWN, "cannot compare strings, collation is unknown or unsupported (collation ID: %d)", col.Collation) } return collation.Collate(l.ToRawBytes(), r.ToRawBytes(), false), nil } diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index d11bba24dde..c02efaac534 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -17,7 +17,6 @@ limitations under the License. package evalengine import ( - "fmt" "strconv" "unicode/utf8" @@ -199,14 +198,18 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID) (eval, error) { return evalToInt64(e), nil case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64: return evalToInt64(e).toUint64(), nil - case sqltypes.Date, sqltypes.Datetime, sqltypes.Year, sqltypes.TypeJSON, sqltypes.Time, sqltypes.Bit: - return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported type conversion: %s", typ.String()) + case sqltypes.Date: + return evalToDate(e), nil + case sqltypes.Datetime, sqltypes.Timestamp: + return evalToDateTime(e, -1), nil + case sqltypes.Time: + return evalToTime(e, -1), nil default: - panic(fmt.Sprintf("BUG: emitted unknown type: %s", typ)) + return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported type conversion: %s", typ.String()) } } -func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type) (eval, error) { +func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.ID) (eval, error) { switch { case typ == sqltypes.Null: return nil, nil @@ -226,7 +229,16 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type) (eval, error) { fval, _ := fastparse.ParseFloat64(v.RawStr()) return newEvalFloat(fval), nil default: - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a float: %v", v) + e, err := valueToEval(v, collations.TypedCollation{ + Collation: collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + }) + if err != nil { + return nil, err + } + f, _ := evalToFloat(e) + return f, nil } case sqltypes.IsDecimal(typ): @@ -248,7 +260,15 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type) (eval, error) { fval, _ := fastparse.ParseFloat64(v.RawStr()) dec = decimal.NewFromFloat(fval) default: - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a decimal: %v", v) + e, err := valueToEval(v, collations.TypedCollation{ + Collation: collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + }) + if err != nil { + return nil, err + } + return evalToDecimal(e, 0, 0), nil } return &evalDecimal{dec: dec, length: -dec.Exponent()}, nil @@ -260,8 +280,19 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type) (eval, error) { case v.IsUnsigned(): uval, err := v.ToUint64() return newEvalInt64(int64(uval)), err + case v.IsText() || v.IsBinary(): + i, err := fastparse.ParseInt64(v.RawStr(), 10) + return newEvalInt64(i), err default: - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a signed int: %v", v) + e, err := valueToEval(v, collations.TypedCollation{ + Collation: collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + }) + if err != nil { + return nil, err + } + return evalToInt64(e), nil } case sqltypes.IsUnsigned(typ): @@ -272,18 +303,99 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type) (eval, error) { case v.IsUnsigned(): uval, err := v.ToUint64() return newEvalUint64(uval), err + case v.IsText() || v.IsBinary(): + u, err := fastparse.ParseUint64(v.RawStr(), 10) + return newEvalUint64(u), err default: - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a unsigned int: %v", v) + e, err := valueToEval(v, collations.TypedCollation{ + Collation: collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + }) + if err != nil { + return nil, err + } + i := evalToInt64(e) + return newEvalUint64(uint64(i.i)), nil } case sqltypes.IsText(typ) || sqltypes.IsBinary(typ): switch { case v.IsText() || v.IsBinary(): - // TODO: collation - return newEvalRaw(v.Type(), v.Raw(), collationBinary), nil + return newEvalRaw(v.Type(), v.Raw(), collations.TypedCollation{ + Collation: collation, + Coercibility: collations.CoerceImplicit, + Repertoire: collations.RepertoireUnicode, + }), nil + case sqltypes.IsText(typ): + e, err := valueToEval(v, collations.TypedCollation{ + Collation: collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + }) + if err != nil { + return nil, err + } + return evalToVarchar(e, collation, true) default: - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a text: %v", v) + e, err := valueToEval(v, collations.TypedCollation{ + Collation: collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + }) + if err != nil { + return nil, err + } + return evalToBinary(e), nil + } + + case typ == sqltypes.TypeJSON: + return json.NewFromSQL(v) + case typ == sqltypes.Date: + e, err := valueToEval(v, collations.TypedCollation{ + Collation: collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + }) + if err != nil { + return nil, err + } + // Separate return here to avoid nil wrapped in interface type + d := evalToDate(e) + if d == nil { + return nil, nil + } + return d, nil + case typ == sqltypes.Datetime || typ == sqltypes.Timestamp: + e, err := valueToEval(v, collations.TypedCollation{ + Collation: collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + }) + if err != nil { + return nil, err + } + // Separate return here to avoid nil wrapped in interface type + dt := evalToDateTime(e, -1) + if dt == nil { + return nil, nil + } + return dt, nil + case typ == sqltypes.Time: + e, err := valueToEval(v, collations.TypedCollation{ + Collation: collation, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + }) + if err != nil { + return nil, err + } + // Separate return here to avoid nil wrapped in interface type + t := evalToTime(e, -1) + if t == nil { + return nil, nil } + return t, nil } return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value: %v", v) } From 0e8661b7684e9c342b03b340e8dc927765ab7766 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 25 Jul 2023 13:53:31 +0200 Subject: [PATCH 7/8] evalengine: Use available function for collation Signed-off-by: Dirkjan Bussink --- go/mysql/json/weights.go | 16 +++++++ go/mysql/json/weights_test.go | 16 +++++++ go/vt/vtgate/evalengine/api_literal.go | 30 ++++--------- go/vt/vtgate/evalengine/collation.go | 27 ++++++++++++ go/vt/vtgate/evalengine/eval.go | 60 +++++--------------------- go/vt/vtgate/evalengine/expr_bvar.go | 12 +----- go/vt/vtgate/evalengine/translate.go | 8 ---- 7 files changed, 80 insertions(+), 89 deletions(-) create mode 100644 go/vt/vtgate/evalengine/collation.go diff --git a/go/mysql/json/weights.go b/go/mysql/json/weights.go index 6b65465b239..262fe96e9cf 100644 --- a/go/mysql/json/weights.go +++ b/go/mysql/json/weights.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package json import ( diff --git a/go/mysql/json/weights_test.go b/go/mysql/json/weights_test.go index 9442a64aa06..9bbcd548e50 100644 --- a/go/mysql/json/weights_test.go +++ b/go/mysql/json/weights_test.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package json import ( diff --git a/go/vt/vtgate/evalengine/api_literal.go b/go/vt/vtgate/evalengine/api_literal.go index 271798a5c7f..00c033d1d2f 100644 --- a/go/vt/vtgate/evalengine/api_literal.go +++ b/go/vt/vtgate/evalengine/api_literal.go @@ -194,39 +194,27 @@ func NewLiteralBinaryFromBit(val []byte) (*Literal, error) { // NewBindVar returns a bind variable func NewBindVar(key string, typ sqltypes.Type, col collations.ID) *BindVariable { return &BindVariable{ - Key: key, - Type: typ, - Collation: collations.TypedCollation{ - Collation: col, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }, + Key: key, + Type: typ, + Collation: defaultCoercionCollation(col), } } // NewBindVarTuple returns a bind variable containing a tuple func NewBindVarTuple(key string, col collations.ID) *BindVariable { return &BindVariable{ - Key: key, - Type: sqltypes.Tuple, - Collation: collations.TypedCollation{ - Collation: col, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }, + Key: key, + Type: sqltypes.Tuple, + Collation: defaultCoercionCollation(col), } } // NewColumn returns a column expression func NewColumn(offset int, typ sqltypes.Type, col collations.ID) *Column { return &Column{ - Offset: offset, - Type: typ, - Collation: collations.TypedCollation{ - Collation: col, - Coercibility: collations.CoerceImplicit, - Repertoire: collations.RepertoireUnicode, - }, + Offset: offset, + Type: typ, + Collation: defaultCoercionCollation(col), } } diff --git a/go/vt/vtgate/evalengine/collation.go b/go/vt/vtgate/evalengine/collation.go new file mode 100644 index 00000000000..9d53a9d8ea9 --- /dev/null +++ b/go/vt/vtgate/evalengine/collation.go @@ -0,0 +1,27 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import "vitess.io/vitess/go/mysql/collations" + +func defaultCoercionCollation(id collations.ID) collations.TypedCollation { + return collations.TypedCollation{ + Collation: id, + Coercibility: collations.CoerceCoercible, + Repertoire: collations.RepertoireUnicode, + } +} diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index c02efaac534..bbd5a9809ca 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -229,11 +229,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I fval, _ := fastparse.ParseFloat64(v.RawStr()) return newEvalFloat(fval), nil default: - e, err := valueToEval(v, collations.TypedCollation{ - Collation: collation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + e, err := valueToEval(v, defaultCoercionCollation(collation)) if err != nil { return nil, err } @@ -260,11 +256,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I fval, _ := fastparse.ParseFloat64(v.RawStr()) dec = decimal.NewFromFloat(fval) default: - e, err := valueToEval(v, collations.TypedCollation{ - Collation: collation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + e, err := valueToEval(v, defaultCoercionCollation(collation)) if err != nil { return nil, err } @@ -284,11 +276,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I i, err := fastparse.ParseInt64(v.RawStr(), 10) return newEvalInt64(i), err default: - e, err := valueToEval(v, collations.TypedCollation{ - Collation: collation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + e, err := valueToEval(v, defaultCoercionCollation(collation)) if err != nil { return nil, err } @@ -307,11 +295,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I u, err := fastparse.ParseUint64(v.RawStr(), 10) return newEvalUint64(u), err default: - e, err := valueToEval(v, collations.TypedCollation{ - Collation: collation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + e, err := valueToEval(v, defaultCoercionCollation(collation)) if err != nil { return nil, err } @@ -322,27 +306,15 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I case sqltypes.IsText(typ) || sqltypes.IsBinary(typ): switch { case v.IsText() || v.IsBinary(): - return newEvalRaw(v.Type(), v.Raw(), collations.TypedCollation{ - Collation: collation, - Coercibility: collations.CoerceImplicit, - Repertoire: collations.RepertoireUnicode, - }), nil + return newEvalRaw(v.Type(), v.Raw(), defaultCoercionCollation(collation)), nil case sqltypes.IsText(typ): - e, err := valueToEval(v, collations.TypedCollation{ - Collation: collation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + e, err := valueToEval(v, defaultCoercionCollation(collation)) if err != nil { return nil, err } return evalToVarchar(e, collation, true) default: - e, err := valueToEval(v, collations.TypedCollation{ - Collation: collation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + e, err := valueToEval(v, defaultCoercionCollation(collation)) if err != nil { return nil, err } @@ -352,11 +324,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I case typ == sqltypes.TypeJSON: return json.NewFromSQL(v) case typ == sqltypes.Date: - e, err := valueToEval(v, collations.TypedCollation{ - Collation: collation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + e, err := valueToEval(v, defaultCoercionCollation(collation)) if err != nil { return nil, err } @@ -367,11 +335,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I } return d, nil case typ == sqltypes.Datetime || typ == sqltypes.Timestamp: - e, err := valueToEval(v, collations.TypedCollation{ - Collation: collation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + e, err := valueToEval(v, defaultCoercionCollation(collation)) if err != nil { return nil, err } @@ -382,11 +346,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I } return dt, nil case typ == sqltypes.Time: - e, err := valueToEval(v, collations.TypedCollation{ - Collation: collation, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + e, err := valueToEval(v, defaultCoercionCollation(collation)) if err != nil { return nil, err } diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 387438c4310..9172f8abc3c 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -57,11 +57,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { tuple := make([]eval, 0, len(bvar.Values)) for _, value := range bvar.Values { - e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), collations.TypedCollation{ - Collation: collations.DefaultCollationForType(value.Type), - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), defaultCoercionCollation(collations.DefaultCollationForType(value.Type))) if err != nil { return nil, err } @@ -77,11 +73,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { if bv.typed() { typ = bv.Type } - return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), collations.TypedCollation{ - Collation: collations.DefaultCollationForType(typ), - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - }) + return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), defaultCoercionCollation(collations.DefaultCollationForType(typ))) } } diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 998827db095..7a63899b318 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -183,14 +183,6 @@ func (ast *astCompiler) translateIsExpr(left sqlparser.Expr, op sqlparser.IsExpr }, nil } -func defaultCoercionCollation(id collations.ID) collations.TypedCollation { - return collations.TypedCollation{ - Collation: id, - Coercibility: collations.CoerceCoercible, - Repertoire: collations.RepertoireUnicode, - } -} - func (ast *astCompiler) translateBindVar(arg *sqlparser.Argument) (Expr, error) { bvar := NewBindVar(arg.Name, arg.Type, ast.cfg.Collation) From f2b110ac1418b4aec41597164006e83a46774b96 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Tue, 25 Jul 2023 14:50:46 +0200 Subject: [PATCH 8/8] evalengine: Add back fallback to binary types Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/api_compare_test.go | 90 +++++++++++++-------- go/vt/vtgate/evalengine/eval.go | 2 + go/vt/vtgate/evalengine/expr_compare.go | 16 ++++ 3 files changed, 73 insertions(+), 35 deletions(-) diff --git a/go/vt/vtgate/evalengine/api_compare_test.go b/go/vt/vtgate/evalengine/api_compare_test.go index 70e83dda6be..e4fb5d38470 100644 --- a/go/vt/vtgate/evalengine/api_compare_test.go +++ b/go/vt/vtgate/evalengine/api_compare_test.go @@ -767,6 +767,18 @@ func TestCompareTime(t *testing.T) { out: &F, op: sqlparser.GreaterThanOp, row: []sqltypes.Value{sqltypes.NewTime("02:46:02"), sqltypes.NewTime("10:42:50")}, }, + { + name: "time is greater than time", + v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + out: &T, op: sqlparser.GreaterThanOp, + row: []sqltypes.Value{sqltypes.NewTime("101:14:35"), sqltypes.NewTime("13:01:38")}, + }, + { + name: "time is not greater than time", + v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), + out: &F, op: sqlparser.GreaterThanOp, + row: []sqltypes.Value{sqltypes.NewTime("24:46:02"), sqltypes.NewTime("101:42:50")}, + }, { name: "time is less than time", v1: NewColumn(0, sqltypes.Time, collations.CollationBinaryID), v2: NewColumn(1, sqltypes.Time, collations.CollationBinaryID), @@ -1100,42 +1112,50 @@ func TestNullsafeCompare(t *testing.T) { v1, v2 sqltypes.Value out int err error - }{{ - // All nulls. - v1: NULL, - v2: NULL, - out: 0, - }, { - // LHS null. - v1: NULL, - v2: NewInt64(1), - out: -1, - }, { - // RHS null. - v1: NewInt64(1), - v2: NULL, - out: 1, - }, { - // LHS Text - v1: TestValue(sqltypes.VarChar, "abcd"), - v2: TestValue(sqltypes.VarChar, "abcd"), - out: 0, - }, { - // Make sure underlying error is returned for LHS. - v1: TestValue(sqltypes.Float64, "0.0"), - v2: TestValue(sqltypes.VarChar, " 6736380880502626304.000000 aa"), - out: -1, - }} + }{ + { + v1: NULL, + v2: NULL, + out: 0, + }, + { + v1: NULL, + v2: NewInt64(1), + out: -1, + }, + { + v1: NewInt64(1), + v2: NULL, + out: 1, + }, + { + v1: TestValue(sqltypes.VarChar, "abcd"), + v2: TestValue(sqltypes.VarChar, "abcd"), + out: 0, + }, + { + v1: TestValue(sqltypes.Float64, "0.0"), + v2: TestValue(sqltypes.VarChar, " 6736380880502626304.000000 aa"), + out: -1, + }, + { + v1: TestValue(sqltypes.Enum, "foo"), + v2: TestValue(sqltypes.Enum, "bar"), + out: 1, + }, + } for _, tcase := range tcases { - got, err := NullsafeCompare(tcase.v1, tcase.v2, collation) - if tcase.err != nil { - require.EqualError(t, err, tcase.err.Error()) - continue - } - require.NoError(t, err) - if got != tcase.out { - t.Errorf("NullsafeCompare(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), got, tcase.out) - } + t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { + got, err := NullsafeCompare(tcase.v1, tcase.v2, collation) + if tcase.err != nil { + require.EqualError(t, err, tcase.err.Error()) + return + } + require.NoError(t, err) + if got != tcase.out { + t.Errorf("NullsafeCompare(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), got, tcase.out) + } + }) } } diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index bbd5a9809ca..fbc3cbca57d 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -434,6 +434,8 @@ func valueToEval(value sqltypes.Value, collation collations.TypedCollation) (eva var p json.Parser j, err := p.ParseBytes(value.Raw()) return j, wrap(err) + case fallbackBinary(tt): + return newEvalRaw(tt, value.Raw(), collation), nil default: return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported: %q %s", value, value.Type()) } diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index 3aca0cc1151..8e3b4ee322a 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -17,6 +17,8 @@ limitations under the License. package evalengine import ( + "bytes" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -233,6 +235,8 @@ func evalCompare(left, right eval) (comp int, err error) { return compareJSON(left, right) case lt == sqltypes.Tuple || rt == sqltypes.Tuple: return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: evalCompare: tuple comparison should be handled early") + case lt == rt && fallbackBinary(lt): + return bytes.Compare(left.ToRawBytes(), right.ToRawBytes()), nil default: // Quoting MySQL Docs: // @@ -247,6 +251,18 @@ func evalCompare(left, right eval) (comp int, err error) { } } +// fallbackBinary compares two values of the same type using the fallback binary comparison. +// This is for types we don't yet properly support otherwise but do end up being used +// for comparisons, for example when using vdiff. +// TODO: Clean this up as we add more properly supported types and comparisons. +func fallbackBinary(t sqltypes.Type) bool { + switch t { + case sqltypes.Bit, sqltypes.Enum, sqltypes.Set, sqltypes.Geometry: + return true + } + return false +} + func evalCompareTuplesNullSafe(left, right []eval) (bool, error) { if len(left) != len(right) { panic("did not typecheck cardinality")