Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

types/json: fix JSON comparison for int and float (#17622) #17717

Merged
merged 2 commits into from
Jun 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4155,6 +4155,9 @@ func (s *testIntegrationSuite) TestFuncJSON(c *C) {
json_length('[1, 2, 3]')
`)
r.Check(testkit.Rows("1 0 0 1 2 3"))

// #16267
tk.MustQuery(`select json_array(922337203685477580) = json_array(922337203685477581);`).Check(testkit.Rows("0"))
}

func (s *testIntegrationSuite) TestColumnInfoModified(c *C) {
Expand Down
82 changes: 62 additions & 20 deletions types/json/binary_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"fmt"
"sort"
"unicode/utf8"
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/util/hack"
Expand Down Expand Up @@ -638,6 +637,41 @@ func compareFloat64PrecisionLoss(x, y float64) int {
return 1
}

func compareInt64(x int64, y int64) int {
if x < y {
return -1
} else if x == y {
return 0
}

return 1
}

func compareUint64(x uint64, y uint64) int {
if x < y {
return -1
} else if x == y {
return 0
}

return 1
}

func compareInt64Uint64(x int64, y uint64) int {
if x < 0 {
return -1
}
return compareUint64(uint64(x), y)
}

func compareFloat64Int64(x float64, y int64) int {
return compareFloat64PrecisionLoss(x, float64(y))
}

func compareFloat64Uint64(x float64, y uint64) int {
return compareFloat64PrecisionLoss(x, float64(y))
}

// CompareBinary compares two binary json objects. Returns -1 if left < right,
// 0 if left == right, else returns 1.
func CompareBinary(left, right BinaryJSON) int {
Expand All @@ -653,10 +687,33 @@ func CompareBinary(left, right BinaryJSON) int {
case TypeCodeLiteral:
// false is less than true.
cmp = int(right.Value[0]) - int(left.Value[0])
case TypeCodeInt64, TypeCodeUint64, TypeCodeFloat64:
leftFloat := i64AsFloat64(left.GetInt64(), left.TypeCode)
rightFloat := i64AsFloat64(right.GetInt64(), right.TypeCode)
cmp = compareFloat64PrecisionLoss(leftFloat, rightFloat)
case TypeCodeInt64:
switch right.TypeCode {
case TypeCodeInt64:
cmp = compareInt64(left.GetInt64(), right.GetInt64())
case TypeCodeUint64:
cmp = compareInt64Uint64(left.GetInt64(), right.GetUint64())
case TypeCodeFloat64:
cmp = -compareFloat64Int64(right.GetFloat64(), left.GetInt64())
}
case TypeCodeUint64:
switch right.TypeCode {
case TypeCodeInt64:
cmp = -compareInt64Uint64(right.GetInt64(), left.GetUint64())
case TypeCodeUint64:
cmp = compareUint64(left.GetUint64(), right.GetUint64())
case TypeCodeFloat64:
cmp = -compareFloat64Uint64(right.GetFloat64(), left.GetUint64())
}
case TypeCodeFloat64:
switch right.TypeCode {
case TypeCodeInt64:
cmp = compareFloat64Int64(left.GetFloat64(), right.GetInt64())
case TypeCodeUint64:
cmp = compareFloat64Uint64(left.GetFloat64(), right.GetUint64())
case TypeCodeFloat64:
cmp = compareFloat64PrecisionLoss(left.GetFloat64(), right.GetFloat64())
}
case TypeCodeString:
cmp = bytes.Compare(left.GetString(), right.GetString())
case TypeCodeArray:
Expand All @@ -682,21 +739,6 @@ func CompareBinary(left, right BinaryJSON) int {
return cmp
}

func i64AsFloat64(i64 int64, typeCode TypeCode) float64 {
switch typeCode {
case TypeCodeLiteral, TypeCodeInt64:
return float64(i64)
case TypeCodeUint64:
u64 := *(*uint64)(unsafe.Pointer(&i64))
return float64(u64)
case TypeCodeFloat64:
return *(*float64)(unsafe.Pointer(&i64))
default:
msg := fmt.Sprintf(unknownTypeCodeErrorMsg, typeCode)
panic(msg)
}
}

// MergeBinary merges multiple BinaryJSON into one according the following rules:
// 1) adjacent arrays are merged to a single array;
// 2) adjacent object are merged to a single object;
Expand Down
55 changes: 43 additions & 12 deletions types/json/binary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package json

import (
"math"
"testing"

. "github.com/pingcap/check"
Expand Down Expand Up @@ -259,22 +260,52 @@ func (s *testJSONSuite) TestCompareBinary(c *C) {
jObject := mustParseBinaryFromString(c, `{"a": "b"}`)

var tests = []struct {
left BinaryJSON
right BinaryJSON
left BinaryJSON
right BinaryJSON
result int
}{
{jNull, jIntegerSmall},
{jIntegerSmall, jIntegerLarge},
{jIntegerLarge, jStringSmall},
{jStringSmall, jStringLarge},
{jStringLarge, jObject},
{jObject, jArraySmall},
{jArraySmall, jArrayLarge},
{jArrayLarge, jBoolFalse},
{jBoolFalse, jBoolTrue},
{jNull, jIntegerSmall, -1},
{jIntegerSmall, jIntegerLarge, -1},
{jIntegerLarge, jStringSmall, -1},
{jStringSmall, jStringLarge, -1},
{jStringLarge, jObject, -1},
{jObject, jArraySmall, -1},
{jArraySmall, jArrayLarge, -1},
{jArrayLarge, jBoolFalse, -1},
{jBoolFalse, jBoolTrue, -1},
{CreateBinary(int64(922337203685477580)), CreateBinary(int64(922337203685477580)), 0},
{CreateBinary(int64(922337203685477580)), CreateBinary(int64(922337203685477581)), -1},
{CreateBinary(int64(922337203685477581)), CreateBinary(int64(922337203685477580)), 1},

{CreateBinary(int64(-1)), CreateBinary(uint64(18446744073709551615)), -1},
{CreateBinary(int64(922337203685477580)), CreateBinary(uint64(922337203685477581)), -1},
{CreateBinary(int64(2)), CreateBinary(uint64(1)), 1},
{CreateBinary(int64(math.MaxInt64)), CreateBinary(uint64(math.MaxInt64)), 0},

{CreateBinary(uint64(18446744073709551615)), CreateBinary(int64(-1)), 1},
{CreateBinary(uint64(922337203685477581)), CreateBinary(int64(922337203685477580)), 1},
{CreateBinary(uint64(1)), CreateBinary(int64(2)), -1},
{CreateBinary(uint64(math.MaxInt64)), CreateBinary(int64(math.MaxInt64)), 0},

{CreateBinary(float64(9.0)), CreateBinary(int64(9)), 0},
{CreateBinary(float64(8.9)), CreateBinary(int64(9)), -1},
{CreateBinary(float64(9.1)), CreateBinary(int64(9)), 1},

{CreateBinary(float64(9.0)), CreateBinary(uint64(9)), 0},
{CreateBinary(float64(8.9)), CreateBinary(uint64(9)), -1},
{CreateBinary(float64(9.1)), CreateBinary(uint64(9)), 1},

{CreateBinary(int64(9)), CreateBinary(float64(9.0)), 0},
{CreateBinary(int64(9)), CreateBinary(float64(8.9)), 1},
{CreateBinary(int64(9)), CreateBinary(float64(9.1)), -1},

{CreateBinary(uint64(9)), CreateBinary(float64(9.0)), 0},
{CreateBinary(uint64(9)), CreateBinary(float64(8.9)), 1},
{CreateBinary(uint64(9)), CreateBinary(float64(9.1)), -1},
}
for _, tt := range tests {
cmp := CompareBinary(tt.left, tt.right)
c.Assert(cmp < 0, IsTrue)
c.Assert(cmp == tt.result, IsTrue, Commentf("left: %v, right: %v, expect: %v, got: %v", tt.left, tt.right, tt.result, cmp))
}
}

Expand Down