diff --git a/sql/collations.go b/sql/collations.go index 5febdb8836..6cba60df71 100644 --- a/sql/collations.go +++ b/sql/collations.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "io" "unicode/utf8" "github.com/cespare/xxhash" @@ -813,16 +814,15 @@ func (c CollationID) Collation() Collation { return collationArray[c] } -// HashToUint returns a hash of the given decoded string based on the collation. Collations take each rune's weight into -// account, therefore two strings with technically different contents may hash to the same value, as the collation +// WriteWeightString writes the weights of each codepoint in the string into the given io.Writer. +// Two strings with technically different contents may generate the same weight string, as the collation // considers them the same string. -func (c CollationID) HashToUint(str string) (uint64, error) { - hash := xxhash.New() +func (c CollationID) WriteWeightString(hash io.Writer, str string) error { if c == Collation_binary { // Binary strings are almost always malformed due to their usage, therefore we treat them differently _, err := hash.Write(encodings.StringToBytes(str)) if err != nil { - return 0, err + return err } } else { getRuneWeight := collationArray[c].Sorter @@ -830,7 +830,7 @@ func (c CollationID) HashToUint(str string) (uint64, error) { // All strings (should) have been decoded at this point, so we can rely on Go's internal string encoding runeFromString, strRead := utf8.DecodeRuneInString(str) if strRead == 0 || strRead == utf8.RuneError { - return 0, ErrCollationMalformedString.New("hashing") + return ErrCollationMalformedString.New("hashing") } runeWeight := getRuneWeight(runeFromString) _, err := hash.Write([]byte{ @@ -840,11 +840,23 @@ func (c CollationID) HashToUint(str string) (uint64, error) { byte(runeWeight >> 24), }) if err != nil { - return 0, err + return err } str = str[strRead:] } } + return nil +} + +// HashToUint returns a hash of the given decoded string based on 
the collation. Collations take each rune's weight into +// account, therefore two strings with technically different contents may hash to the same value, as the collation +// considers them the same string. +func (c CollationID) HashToUint(str string) (uint64, error) { + hash := xxhash.New() + err := c.WriteWeightString(hash, str) + if err != nil { + return 0, err + } return hash.Sum64(), nil } diff --git a/sql/plan/group_by.go b/sql/plan/group_by.go index 91d0ee20b0..7e33130f54 100644 --- a/sql/plan/group_by.go +++ b/sql/plan/group_by.go @@ -420,12 +420,25 @@ func groupingKey( row sql.Row, ) (uint64, error) { hash := xxhash.New() - for _, expr := range exprs { + for i, expr := range exprs { v, err := expr.Eval(ctx, row) if err != nil { return 0, err } - _, err = hash.Write(([]byte)(fmt.Sprintf("%v,", v))) + + if i > 0 { + // separate each expression in the grouping key with a nil byte + if _, err = hash.Write([]byte{0}); err != nil { + return 0, err + } + } + + switch t := expr.Type().(type) { + case sql.StringType: + err = t.Collation().WriteWeightString(hash, v.(string)) + default: + _, err = fmt.Fprintf(hash, "%v", v) + } if err != nil { return 0, err } diff --git a/sql/plan/group_by_test.go b/sql/plan/group_by_test.go index f9b996e39d..b68aee737d 100644 --- a/sql/plan/group_by_test.go +++ b/sql/plan/group_by_test.go @@ -17,6 +17,7 @@ package plan import ( "testing" + "github.com/dolthub/vitess/go/vt/proto/query" "github.com/stretchr/testify/require" "github.com/dolthub/go-mysql-server/memory" @@ -157,6 +158,86 @@ func TestGroupByAggregationGrouping(t *testing.T) { require.Equal(expected, rows) } +func TestGroupByCollations(t *testing.T) { + tString := sql.MustCreateString(query.Type_VARCHAR, 255, sql.Collation_utf8mb4_0900_ai_ci) + tEnum := sql.MustCreateEnumType([]string{"col1_1", "col1_2"}, sql.Collation_utf8mb4_0900_ai_ci) + tSet := sql.MustCreateSetType([]string{"col1_1", "col1_2"}, sql.Collation_utf8mb4_0900_ai_ci) + + var testCases = []struct { + Type 
sql.Type + Value func(t *testing.T, v string) any + }{ + { + Type: tString, + Value: func(t *testing.T, v string) any { return v }, + }, + { + Type: tEnum, + Value: func(t *testing.T, v string) any { + conv, err := tEnum.Convert(v) + require.NoError(t, err) + return conv + }, + }, + { + Type: tSet, + Value: func(t *testing.T, v string) any { + conv, err := tSet.Convert(v) + require.NoError(t, err) + return conv + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.Type.String(), func(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + childSchema := sql.Schema{ + {Name: "col1", Type: tc.Type}, + {Name: "col2", Type: sql.Int64}, + } + + child := memory.NewTable("test", sql.NewPrimaryKeySchema(childSchema), nil) + + rows := []sql.Row{ + sql.NewRow(tc.Value(t, "col1_1"), int64(1111)), + sql.NewRow(tc.Value(t, "Col1_1"), int64(1111)), + sql.NewRow(tc.Value(t, "col1_2"), int64(4444)), + sql.NewRow(tc.Value(t, "col1_1"), int64(1111)), + sql.NewRow(tc.Value(t, "Col1_2"), int64(4444)), + } + + for _, r := range rows { + require.NoError(child.Insert(sql.NewEmptyContext(), r)) + } + + p := NewGroupBy( + []sql.Expression{ + aggregation.NewSum( + expression.NewGetFieldWithTable(1, sql.Int64, "test", "col2", false), + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, tc.Type, "test", "col1", false), + }, + NewResolvedTable(child, nil, nil), + ) + + rows, err := sql.NodeToRows(ctx, p) + require.NoError(err) + + expected := []sql.Row{ + {float64(3333)}, + {float64(8888)}, + } + + require.Equal(expected, rows) + }) + } +} + func BenchmarkGroupBy(b *testing.B) { table := benchmarkTable(b)