From 7a213c2008c8df428dc7d6858a9c0163a578bcf4 Mon Sep 17 00:00:00 2001 From: scott lewis <33612882+dk-lockdown@users.noreply.github.com> Date: Sat, 18 Jun 2022 11:43:25 +0800 Subject: [PATCH] feat: support multiple columns order by (#158) --- pkg/plan/result.go | 277 +++++++++++++++++++++++--------------- test/shd/sharding_test.go | 24 +++- 2 files changed, 193 insertions(+), 108 deletions(-) diff --git a/pkg/plan/result.go b/pkg/plan/result.go index fdde22a..0041e16 100644 --- a/pkg/plan/result.go +++ b/pkg/plan/result.go @@ -52,11 +52,17 @@ func (r ResultWithErrs) Swap(i, j int) { r[i], r[j] = r[j], r[i] } +type OrderField struct { + asc bool + fieldValueIndex int + value interface{} +} + type OrderByCell struct { - Index int - Val interface{} - Next bool - Row proto.Row + orderField []*OrderField + resultIndex int + next bool + row proto.Row } type OrderByCells []*OrderByCell @@ -64,39 +70,94 @@ type OrderByCells []*OrderByCell func (c OrderByCells) Len() int { return len(c) } func (c OrderByCells) Less(i, j int) bool { - switch val1 := c[i].Val.(type) { + var ( + index = 0 + res int + ) + for index < len(c[i].orderField) { + isAsc := c[i].orderField[index].asc + if isAsc { + res = compare(c[i].orderField[index].value, c[j].orderField[index].value) + } else { + res = compare(c[j].orderField[index].value, c[i].orderField[index].value) + } + if res != 0 { + return res > 0 + } + index++ + } + return res > 0 +} + +func compare(val1, val2 interface{}) int { + if val1 == nil && val2 == nil { + return 0 + } else if val1 == nil { + return -1 + } else if val2 == nil { + return 1 + } + switch v1 := val1.(type) { case int64: - val2 := c[j].Val.(int64) - if val1 < val2 { - return true + v2 := val2.(int64) + if v1 < v2 { + return -1 + } else if v1 == v2 { + return 0 + } + return 1 + case uint64: + v2 := val2.(uint64) + if v1 < v2 { + return -1 + } else if v1 == v2 { + return 0 } + return 1 case float32: - val2 := c[j].Val.(float32) - if val1 < val2 { - return true + v2 := val2.(float32) + if v1 < v2 { + return -1 + } else if v1 == v2 { + return 0 } + return 1 case float64: - val2 := c[j].Val.(float64) - if val1 < val2 { - return true + v2 := val2.(float64) + if v1 < v2 { + return -1 + } else if v1 == v2 { + return 0 } + return 1 case string: - val2 := c[j].Val.(string) - if val1 < val2 { - return true + v2 := val2.(string) + if v1 < v2 { + return -1 + } else if v1 == v2 { + return 0 } + return 1 case []uint8: - val2 := c[j].Val.([]uint8) - if string(val1) < string(val2) { - return true + v2 := val2.([]uint8) + if string(v1) < string(v2) { + return -1 + } else if string(v1) == string(v2) { + return 0 } + return 1 case time.Time: - val2 := c[j].Val.(time.Time) - if val1.Before(val2) { - return true + v2 := val2.(time.Time) + if v1.Before(v2) { + return -1 + } else if v1.Equal(v2) { + return 0 } + return 1 + default: + log.Panicf("unsupported value type, val1: %s, val2: %s", val1, val2) } - return false + return 0 } func (c OrderByCells) Swap(i, j int) { @@ -114,7 +175,7 @@ func mergeResult(ctx context.Context, results []*ResultWithErr, orderBy *ast.Ord return mergeResultWithOrderBy(ctx, results, orderBy) } if limit != nil { - log.Fatal("unsupported limit without order by") + log.Panic("unsupported limit without order by") } return nil, 0 } @@ -155,41 +216,29 @@ func mergeResultWithOutOrderByAndLimit(ctx context.Context, results []*ResultWit func mergeResultWithOrderByAndLimit(ctx context.Context, results []*ResultWithErr, orderBy *ast.OrderByClause, limit *Limit) (*mysql.MergeResult, uint16) { var ( - sb strings.Builder - fields []*mysql.Field - orderByField string - orderByIndex int - warning uint16 = 0 - desc bool - offset int64 - count int64 - rowCount int64 - commandType = proto.CommandType(ctx) - rows = make([]proto.Row, 0) - cells = make([]*OrderByCell, len(results)) - endResult = make([]bool, len(results)) + fields []*mysql.Field + orderByFields []*OrderField + warning uint16 = 0 + offset int64 + count int64 + rowCount int64 + commandType = proto.CommandType(ctx) + rows = make([]proto.Row, 0) + cells = make([]*OrderByCell, len(results)) + endResult = make([]bool, len(results)) ) - if len(orderBy.Items) > 0 { - // todo built-in lightweight sql engine sorting - } fields = results[0].Result.(*mysql.Result).Fields - restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) - if err := orderBy.Items[0].Expr.Restore(restoreCtx); err != nil { - log.Fatal(err) - } - orderByField = sb.String() - orderByIndex = getOrderByFieldIndex(orderByField, fields) - desc = orderBy.Items[0].Desc + orderByFields = castOrderByItemsToOrderField(orderBy, fields) offset = limit.Offset count = limit.Count rowCount = 0 for { pop := 0 for i, rlt := range results { - if cells[i] != nil && !cells[i].Next { + if cells[i] != nil && !cells[i].next { pop += 1 } - if (cells[i] == nil || cells[i].Next) && !endResult[i] { + if (cells[i] == nil || cells[i].next) && !endResult[i] { result := rlt.Result.(*mysql.Result) row, err := result.Rows.Next() if err != nil { @@ -201,14 +250,17 @@ func mergeResultWithOrderByAndLimit(ctx context.Context, results []*ResultWithEr textRow := &mysql.TextRow{Row: row} values, err := textRow.Decode() if err != nil { - log.Fatal(err) + log.Panic(err) + } + for _, of := range orderByFields { + of.value = values[of.fieldValueIndex].Val } - value := values[orderByIndex] + cells[i] = &OrderByCell{ - Index: i, - Val: value.Val, - Next: false, - Row: textRow, + orderField: orderByFields, + resultIndex: i, + next: false, + row: textRow, } } else { binaryRow := &mysql.BinaryRow{Row: row} @@ -216,12 +268,15 @@ func mergeResultWithOrderByAndLimit(ctx context.Context, results []*ResultWithEr if err != nil { log.Fatal(err) } - value := values[orderByIndex] + for _, of := range orderByFields { + of.value = values[of.fieldValueIndex].Val + } + cells[i] = &OrderByCell{ - Index: i, - Val: value.Val, - Next: false, - Row: binaryRow, + orderField: orderByFields, + resultIndex: i, + next: false, + row: binaryRow, } } } @@ -229,11 +284,11 @@ func mergeResultWithOrderByAndLimit(ctx context.Context, results []*ResultWithEr if pop == 0 { break } - cell := compareOrderByCells(cells, desc) + cell := compareOrderByCells(cells) rowCount += 1 - cells[cell.Index].Next = true + cells[cell.resultIndex].next = true if rowCount > offset { - rows = append(rows, cell.Row) + rows = append(rows, cell.row) if int64(len(rows)) == count { break } @@ -254,35 +309,26 @@ func mergeResultWithOrderByAndLimit(ctx context.Context, results []*ResultWithEr func mergeResultWithOrderBy(ctx context.Context, results []*ResultWithErr, orderBy *ast.OrderByClause) (*mysql.MergeResult, uint16) { var ( - sb strings.Builder - fields []*mysql.Field - orderByField string - orderByIndex int - warning uint16 = 0 - desc bool - commandType = proto.CommandType(ctx) - rows = make([]proto.Row, 0) - cells = make([]*OrderByCell, len(results)) - endResult = make([]bool, len(results)) + fields []*mysql.Field + orderByFields []*OrderField + warning uint16 = 0 + commandType = proto.CommandType(ctx) + // result rows + rows = make([]proto.Row, 0) + // OrderBy compare + cells = make([]*OrderByCell, len(results)) + // Record whether mysql.Result has been traversed + endResult = make([]bool, len(results)) ) - if len(orderBy.Items) > 0 { - // todo built-in lightweight sql engine sorting - } fields = results[0].Result.(*mysql.Result).Fields - restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) - if err := orderBy.Items[0].Expr.Restore(restoreCtx); err != nil { - log.Fatal(err) - } - orderByField = sb.String() - orderByIndex = getOrderByFieldIndex(orderByField, fields) - desc = orderBy.Items[0].Desc + orderByFields = castOrderByItemsToOrderField(orderBy, fields) for { pop := 0 for i, rlt := range results { - if cells[i] != nil && !cells[i].Next { + if cells[i] != nil && !cells[i].next { pop += 1 } - if (cells[i] == nil || cells[i].Next) && !endResult[i] { + if (cells[i] == nil || cells[i].next) && !endResult[i] { result := rlt.Result.(*mysql.Result) row, err := result.Rows.Next() if err != nil { @@ -294,14 +340,17 @@ func mergeResultWithOrderBy(ctx context.Context, results []*ResultWithErr, order textRow := &mysql.TextRow{Row: row} values, err := textRow.Decode() if err != nil { - log.Fatal(err) + log.Panic(err) + } + for _, of := range orderByFields { + of.value = values[of.fieldValueIndex].Val } - value := values[orderByIndex] + cells[i] = &OrderByCell{ - Index: i, - Val: value.Val, - Next: false, - Row: textRow, + orderField: orderByFields, + resultIndex: i, + next: false, + row: textRow, } } else { binaryRow := &mysql.BinaryRow{Row: row} @@ -309,12 +358,15 @@ func mergeResultWithOrderBy(ctx context.Context, results []*ResultWithErr, order if err != nil { log.Fatal(err) } - value := values[orderByIndex] + for _, of := range orderByFields { + of.value = values[of.fieldValueIndex].Val + } + cells[i] = &OrderByCell{ - Index: i, - Val: value.Val, - Next: false, - Row: binaryRow, + orderField: orderByFields, + resultIndex: i, + next: false, + row: binaryRow, } } } @@ -322,9 +374,9 @@ func mergeResultWithOrderBy(ctx context.Context, results []*ResultWithErr, order if pop == 0 { break } - cell := compareOrderByCells(cells, desc) - rows = append(rows, cell.Row) - cells[cell.Index].Next = true + cell := compareOrderByCells(cells) + rows = append(rows, cell.row) + cells[cell.resultIndex].next = true } for _, rlt := range results { @@ -339,20 +391,35 @@ func mergeResultWithOrderBy(ctx context.Context, results []*ResultWithErr, order return result, warning } -func compareOrderByCells(cells []*OrderByCell, desc bool) *OrderByCell { +func compareOrderByCells(cells []*OrderByCell) *OrderByCell { cellSlice := make([]*OrderByCell, 0) for _, cell := range cells { - if !cell.Next { + if !cell.next { cellSlice = append(cellSlice, cell) } } sort.Sort(OrderByCells(cellSlice)) - if desc { - return cellSlice[len(cellSlice)-1] - } return cellSlice[0] } +func castOrderByItemsToOrderField(orderBy *ast.OrderByClause, fields []*mysql.Field) []*OrderField { + var ( + sb strings.Builder + result []*OrderField + ) + for _, item := range orderBy.Items { + sb.Reset() + restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) + if err := item.Expr.Restore(restoreCtx); err != nil { + log.Fatal(err) + } + orderByField := sb.String() + orderByIndex := getOrderByFieldIndex(orderByField, fields) + result = append(result, &OrderField{asc: !item.Desc, fieldValueIndex: orderByIndex}) + } + return result +} + func getOrderByFieldIndex(orderByField string, fields []*mysql.Field) int { for i, field := range fields { if strings.EqualFold(orderByField, field.Name) { diff --git a/test/shd/sharding_test.go b/test/shd/sharding_test.go index db6150c..c508b97 100644 --- a/test/shd/sharding_test.go +++ b/test/shd/sharding_test.go @@ -28,7 +28,8 @@ const ( driverName = "mysql" dataSourceName = "dksl:123456@tcp(127.0.0.1:13306)/drug?timeout=10s&readTimeout=10s&writeTimeout=10s&parseTime=true&loc=Local&charset=utf8mb4,utf8" selectDrugResource = "select id, drug_res_type_id, base_type, sale_price from drug_resource where id between ? and ?" - selectDrugResourceOrderByIDDesc = "select id, drug_res_type_id, base_type, sale_price from drug_resource where id between ? and ? order by id desc" + selectDrugResourceOrderBy1 = "select id, drug_res_type_id, base_type, sale_price from drug_resource where id between ? and ? order by id desc" + selectDrugResourceOrderBy2 = "select id, drug_res_type_id, manufacturer_id, sale_price from drug_resource where id between ? and ? order by manufacturer_id desc, id asc" selectDrugResourceOrderByIDDescLimit = "select id, drug_res_type_id, base_type, sale_price from drug_resource where id between ? and ? order by id desc limit ?, ?" selectDrugResourceOrderByIDDescLimit2 = "select id, drug_res_type_id, base_type, sale_price from drug_resource where id between ? and ? order by id desc limit ?" @@ -79,7 +80,7 @@ func (suite *_ShardingSuite) TestSelect() { } func (suite *_ShardingSuite) TestSelectOrderBy() { - rows, err := suite.db.Query(selectDrugResourceOrderByIDDesc, 200, 210) + rows, err := suite.db.Query(selectDrugResourceOrderBy1, 200, 210) if suite.NoErrorf(err, "select row error: %v", err) { var ( id int64 @@ -95,6 +96,23 @@ func (suite *_ShardingSuite) TestSelectOrderBy() { } } +func (suite *_ShardingSuite) TestSelectOrderBy2() { + rows, err := suite.db.Query(selectDrugResourceOrderBy2, 200, 250) + if suite.NoErrorf(err, "select row error: %v", err) { + var ( + id int64 + drugResTypeId string + manufacturerId string + salePrice float32 + ) + for rows.Next() { + err := rows.Scan(&id, &drugResTypeId, &manufacturerId, &salePrice) + suite.NoError(err) + suite.T().Logf("id: %d, drug resource type id: %s, manufacturer id: %s, sale price: %v", id, drugResTypeId, manufacturerId, salePrice) + } + } +} + func (suite *_ShardingSuite) TestSelectOrderByAndLimit() { rows, err := suite.db.Query(selectDrugResourceOrderByIDDescLimit, 200, 300, 10, 20) if suite.NoErrorf(err, "select row error: %v", err) { @@ -152,7 +170,7 @@ func (suite *_ShardingSuite) TestUpdateDrugResource() { suite.Assert().Nil(err) suite.Assert().Equal(int64(11), affectedRows) - rows, err := suite.db.Query(selectDrugResourceOrderByIDDesc, 200, 210) + rows, err := suite.db.Query(selectDrugResourceOrderBy1, 200, 210) if suite.NoErrorf(err, "select row error: %v", err) { var ( id int64