Skip to content

Commit

Permalink
feat: Add convenience literal APIs (#47)
Browse files Browse the repository at this point in the history
* Introduce literal package

---------

Co-authored-by: Jacques Nadeau <jacques@apache.org>
  • Loading branch information
scgkiran and jacques-n authored Aug 21, 2024
1 parent e77df67 commit 597afdb
Show file tree
Hide file tree
Showing 11 changed files with 1,275 additions and 92 deletions.
35 changes: 23 additions & 12 deletions expr/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ func (*ProtoLiteral) isRootRef() {}
func (t *ProtoLiteral) GetType() types.Type { return t.Type }
func (t *ProtoLiteral) String() string {
switch literalType := t.Type.(type) {
case types.PrecisionTimeStampType, types.PrecisionTimeStampTzType:
case *types.PrecisionTimestampType, *types.PrecisionTimestampTzType:
return fmt.Sprintf("%s(%d)", literalType, t.Value)
}
return fmt.Sprintf("%s(%s)", t.Type, t.Value)
Expand Down Expand Up @@ -458,15 +458,15 @@ func (t *ProtoLiteral) ToProtoLiteral() *proto.Expression_Literal {
Scale: literalType.Scale,
},
}
case types.PrecisionTimeStampType:
case *types.PrecisionTimestampType:
v := t.Value.(uint64)
lit.LiteralType = &proto.Expression_Literal_PrecisionTimestamp_{
PrecisionTimestamp: &proto.Expression_Literal_PrecisionTimestamp{
Precision: literalType.GetPrecisionProtoVal(),
Value: int64(v),
},
}
case types.PrecisionTimeStampTzType:
case *types.PrecisionTimestampTzType:
v := t.Value.(uint64)
lit.LiteralType = &proto.Expression_Literal_PrecisionTimestampTz{
PrecisionTimestampTz: &proto.Expression_Literal_PrecisionTimestamp{
Expand Down Expand Up @@ -623,7 +623,8 @@ func NewFixedBinaryLiteral(val types.FixedBinary, nullable bool) *ByteSliceLiter
type allLiteralTypes interface {
PrimitiveLiteralValue | nestedLiteral | MapLiteralValue |
[]byte | types.UUID | types.FixedBinary | *types.IntervalYearToMonth |
*types.IntervalDayToSecond | *types.VarChar | *types.Decimal | *types.UserDefinedLiteral
*types.IntervalDayToSecond | *types.VarChar | *types.Decimal | *types.UserDefinedLiteral |
*types.PrecisionTimestamp | *types.PrecisionTimestampTz
}

func NewLiteral[T allLiteralTypes](val T, nullable bool) (Literal, error) {
Expand Down Expand Up @@ -711,6 +712,10 @@ func NewLiteral[T allLiteralTypes](val T, nullable bool) (Literal, error) {
Length: int32(v.Length),
},
}, nil
case *types.PrecisionTimestamp:
return NewPrecisionTimestampLiteral(v.PrecisionTimestamp.Value, types.TimePrecision(v.PrecisionTimestamp.Precision), getNullability(nullable)), nil
case *types.PrecisionTimestampTz:
return NewPrecisionTimestampTzLiteral(v.PrecisionTimestampTz.Value, types.TimePrecision(v.PrecisionTimestampTz.Precision), getNullability(nullable)), nil
}

return nil, substraitgo.ErrNotImplemented
Expand Down Expand Up @@ -955,7 +960,7 @@ func LiteralFromProto(l *proto.Expression_Literal) Literal {
if precTimeStamp.Value < 0 {
return nil
}
return NewPrecisionTimestampLiteral(uint64(precTimeStamp.Value), precision, nullability)
return NewPrecisionTimestampLiteral(precTimeStamp.Value, precision, nullability)
case *proto.Expression_Literal_PrecisionTimestampTz:
precTimeStamp := lit.PrecisionTimestampTz
precision, err := types.ProtoToTimePrecision(precTimeStamp.Precision)
Expand All @@ -965,27 +970,33 @@ func LiteralFromProto(l *proto.Expression_Literal) Literal {
if precTimeStamp.Value < 0 {
return nil
}
return NewPrecisionTimestampTzLiteral(uint64(precTimeStamp.Value), precision, nullability)
return NewPrecisionTimestampTzLiteral(precTimeStamp.Value, precision, nullability)
}
panic("unimplemented literal type")
}

// NewPrecisionTimestampLiteral it takes timestamp value which is in specified precision
// and nullable property (n) and returns a PrecisionTimestamp Literal
func NewPrecisionTimestampLiteral(value uint64, precision types.TimePrecision, n types.Nullability) Literal {
precisionType := types.NewPrecisionTimestampType(precision).WithNullability(n)
func NewPrecisionTimestampLiteral(value int64, precision types.TimePrecision, n types.Nullability) Literal {
return &ProtoLiteral{
Value: value,
Type: precisionType,
Type: &types.PrecisionTimestampType{
Precision: precision,
Nullability: n,
},
}
}

// NewPrecisionTimestampTzLiteral it takes timestamp value which is in specified precision
// and nullable property (n) and returns a PrecisionTimestampTz Literal
func NewPrecisionTimestampTzLiteral(value uint64, precision types.TimePrecision, n types.Nullability) Literal {
precisionType := types.NewPrecisionTimestampTzType(precision).WithNullability(n)
func NewPrecisionTimestampTzLiteral(value int64, precision types.TimePrecision, n types.Nullability) Literal {
return &ProtoLiteral{
Value: value,
Type: precisionType,
Type: &types.PrecisionTimestampTzType{
PrecisionTimestampType: types.PrecisionTimestampType{
Precision: precision,
Nullability: n,
},
},
}
}
4 changes: 2 additions & 2 deletions expr/proto_literals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ func TestLiteralFromProto(t *testing.T) {
}{
{"TimeStampType",
&proto.Expression_Literal{LiteralType: &proto.Expression_Literal_PrecisionTimestamp_{PrecisionTimestamp: &proto.Expression_Literal_PrecisionTimestamp{Precision: 4, Value: 12345678}}, Nullable: true},
&ProtoLiteral{Value: uint64(12345678), Type: types.NewPrecisionTimestampType(types.PrecisionEMinus4Seconds).WithNullability(types.NullabilityNullable)},
&ProtoLiteral{Value: int64(12345678), Type: types.NewPrecisionTimestampType(types.PrecisionEMinus4Seconds).WithNullability(types.NullabilityNullable)},
},
{"TimeStampTzType",
&proto.Expression_Literal{LiteralType: &proto.Expression_Literal_PrecisionTimestampTz{PrecisionTimestampTz: &proto.Expression_Literal_PrecisionTimestamp{Precision: 9, Value: 12345678}}, Nullable: true},
&ProtoLiteral{Value: uint64(12345678), Type: types.NewPrecisionTimestampTzType(types.PrecisionNanoSeconds).WithNullability(types.NullabilityNullable)},
&ProtoLiteral{Value: int64(12345678), Type: types.NewPrecisionTimestampTzType(types.PrecisionNanoSeconds).WithNullability(types.NullabilityNullable)},
},
} {
t.Run(tc.name, func(t *testing.T) {
Expand Down
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

module github.com/substrait-io/substrait-go

go 1.20
go 1.21

require (
github.com/alecthomas/participle/v2 v2.0.0
github.com/cockroachdb/apd/v3 v3.2.1
github.com/goccy/go-yaml v1.9.8
github.com/google/go-cmp v0.5.9
github.com/google/uuid v1.6.0
github.com/stretchr/testify v1.8.2
golang.org/x/exp v0.0.0-20230206171751-46f607a40771
google.golang.org/protobuf v1.33.0
Expand All @@ -19,8 +21,10 @@ require (
github.com/fatih/color v1.13.0 // indirect
github.com/go-playground/validator/v10 v10.11.1 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.7.0 // indirect
golang.org/x/sys v0.18.0 // indirect
Expand Down
13 changes: 13 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
github.com/alecthomas/assert/v2 v2.2.2 h1:Z/iVC0xZfWTaFNE6bA3z07T86hd45Xe2eLt6WVy2bbk=
github.com/alecthomas/assert/v2 v2.2.2/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ=
github.com/alecthomas/participle/v2 v2.0.0 h1:Fgrq+MbuSsJwIkw3fEj9h75vDP0Er5JzepJ0/HNHv0g=
github.com/alecthomas/participle/v2 v2.0.0/go.mod h1:rAKZdJldHu8084ojcWevWAL8KmEU+AT+Olodb+WoN2Y=
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg=
github.com/cockroachdb/apd/v3 v3.2.1/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
Expand All @@ -23,7 +29,10 @@ github.com/goccy/go-yaml v1.9.8 h1:5gMyLUeU1/6zl+WFfR1hN7D2kf+1/eRGa7DFtToiBvQ=
github.com/goccy/go-yaml v1.9.8/go.mod h1:JubOolP3gh0HpiBc4BLRD4YmjEjHAmIIB2aaXKkTfoE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
Expand All @@ -36,6 +45,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
Expand All @@ -46,6 +57,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng=
github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
Expand Down
84 changes: 84 additions & 0 deletions literal/decimal_util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package literal

import (
"fmt"
"regexp"
"strings"

"github.com/cockroachdb/apd/v3"
)

var decimalPattern = regexp.MustCompile(`^[+-]?\d{0,38}(\.\d{0,38})?([eE][+-]?\d{0,38})?$`)

// decimalStringToBytes converts a decimal string to a 16-byte byte array.
// 16-byte bytes represents a little-endian 128-bit integer, to be divided by 10^Scale to get the decimal value.
// This function also returns the precision and scale of the decimal value.
// The precision is the total number of digits in the decimal value. The precision is limited to 38 digits.
// The scale is the number of digits to the right of the decimal point. The scale is limited to the precision.
func decimalStringToBytes(decimalStr string) ([16]byte, int32, int32, error) {
var (
result [16]byte
precision int32
scale int32
)

strings.Trim(decimalStr, " ")
if !decimalPattern.MatchString(decimalStr) {
return result, 0, 0, fmt.Errorf("invalid decimal string")
}

// Parse the decimal string using apd
dec, cond, err := apd.NewFromString(decimalStr)
if err != nil || cond.Any() {
return result, 0, 0, fmt.Errorf("invalid decimal string: %v", err)
}

if dec.Exponent > 0 {
precision = int32(apd.NumDigits(&dec.Coeff)) + dec.Exponent
scale = 0
} else {
scale = -dec.Exponent
precision = max(int32(apd.NumDigits(&dec.Coeff)), scale+1)
}
if precision > 38 {
return result, precision, scale, fmt.Errorf("number exceeds maximum precision of 38")
}

coefficient := dec.Coeff
if dec.Exponent > 0 {
// multiple coefficient with 10^exponent
multiplier := apd.NewBigInt(1).Exp(apd.NewBigInt(10), apd.NewBigInt(int64(dec.Exponent)), nil)
coefficient.Mul(&dec.Coeff, multiplier)
}
// Convert the coefficient to a byte array
byteArray := coefficient.Bytes()
if len(byteArray) > 16 {
return result, 0, 0, fmt.Errorf("number exceeds 16 bytes")
}
copy(result[16-len(byteArray):], byteArray)

// Handle the sign and two's complement for negative numbers
if dec.Negative {
twosComplement(result[:])
}

// Reverse the byte array to little-endian
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
result[i], result[j] = result[j], result[i]
}

return result, precision, scale, nil
}

func twosComplement(bytes []byte) {
for i := range bytes {
bytes[i] = ^bytes[i]
}
carry := byte(1)
for i := len(bytes) - 1; i >= 0; i-- {
bytes[i] += carry
if bytes[i] != 0 {
break
}
}
}
122 changes: 122 additions & 0 deletions literal/decimal_util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package literal

import (
"fmt"
"math/big"
"strings"
"testing"

"github.com/cockroachdb/apd/v3"
"github.com/stretchr/testify/assert"
)

func Test_decimalStringToBytes(t *testing.T) {
tests := []struct {
input string
hexWant string
expPrecision int32
expScale int32
expected string
}{
{"12345", "39300000000000000000000000000000", 5, 0, ""},
{"+12345", "39300000000000000000000000000000", 5, 0, "12345"},
{"-12345", "c7cfffffffffffffffffffffffffffff", 5, 0, ""},
{"123.45", "39300000000000000000000000000000", 5, 2, ""},
{"-123.45", "c7cfffffffffffffffffffffffffffff", 5, 2, ""},
{"0.123", "7b000000000000000000000000000000", 4, 3, ""},
{"-0.123", "85ffffffffffffffffffffffffffffff", 4, 3, ""},
{"9223372036854775807", "ffffffffffffff7f0000000000000000", 19, 0, ""}, // Max int64
{"-9223372036854775808", "0000000000000080ffffffffffffffff", 19, 0, ""}, // Min int64
{"99999999999999999999999999999999999999", "ffffffff3f228a097ac4865aa84c3b4b", 38, 0, ""},
{"+99999999999999999999999999999999999999", "ffffffff3f228a097ac4865aa84c3b4b", 38, 0, ""},
{"-99999999999999999999999999999999999999", "01000000c0dd75f6853b79a557b3c4b4", 38, 0, ""},
{"0", "00000000000000000000000000000000", 1, 0, ""},
{"-0", "00000000000000000000000000000000", 1, 0, "0"},
{"0.0", "00000000000000000000000000000000", 2, 1, ""},
{"65535", "ffff0000000000000000000000000000", 5, 0, ""},
{"-65535", "0100ffffffffffffffffffffffffffff", 5, 0, ""},
{"18446744073709551615", "ffffffffffffffff0000000000000000", 20, 0, ""}, // Max uint64
{"-18446744073709551616", "0000000000000000ffffffffffffffff", 20, 0, ""}, // Min int64 - 1
{"12345.6789", "15cd5b07000000000000000000000000", 9, 4, ""},
{"1234567890123456", "c0ba8a3cd56204000000000000000000", 16, 0, ""},
{"1234567890123456.78901234", "f2af966ca0101f9b241a000000000000", 24, 8, ""},
{"1230000000000000", "00e012b1ad5e04000000000000000000", 16, 0, ""},
{"0.0012345678901234", "f22fce733a0b00000000000000000000", 17, 16, ""},
{"-0.0012345678901234", "0ed0318cc5f4ffffffffffffffffffff", 17, 16, ""},
{"123456789012345678901234567890.1234", "f2af967ed05c82de3297ff6fde3c0000", 34, 4, ""},
{"-1234567890.1234567890", "2ef5e0147356ab54ffffffffffffffff", 20, 10, ""},
{"1.23e-5", "7b000000000000000000000000000000", 8, 7, "0.0000123"},
{"1.23e15", "00e012b1ad5e04000000000000000000", 16, 0, "1230000000000000"},
{"1.23e20", "00000c6d51c8f7aa0600000000000000", 21, 0, "123000000000000000000"},
{"1.23e35", "00000000cebde644bc05f0425eb01700", 36, 0, "123000000000000000000000000000000000"},
{"1.23E35", "00000000cebde644bc05f0425eb01700", 36, 0, "123000000000000000000000000000000000"},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
testDecimalStringToBytes(t, tt.input, tt.hexWant, tt.expPrecision, tt.expScale, tt.expected)
})
}

badInputs := []struct{ input string }{
{"12345678901234567890123456789012345678901234"},
{"abc"},
{"12.34.56"},
{"199999999999999999999999999999999999999"},
{"1.23e45"},
{"1.23E300"},
}
for _, tt := range badInputs {
t.Run(tt.input, func(t *testing.T) {
_, _, _, err := decimalStringToBytes(tt.input)
assert.Error(t, err, "decimalStringToBytes(%v) expected error", tt.input)
})
}
}

func testDecimalStringToBytes(t *testing.T, input, hexWant string, expPrecision, expScale int32, expected string) {
got, precision, scale, err := decimalStringToBytes(input)
assert.NoError(t, err)
assert.Len(t, got, 16)
assert.Equal(t, hexToBytes(t, hexWant), got[:])
assert.Equal(t, expPrecision, precision)
assert.Equal(t, expScale, scale)
if err == nil {
// verify that the conversion is correct
decStr := decimalBytesToString(got, scale)
if expected == "" {
expected = strings.TrimPrefix(input, "+")
}
assert.Equal(t, expected, decStr)
}
}

func hexToBytes(t *testing.T, input string) []byte {
bytes := make([]byte, len(input)/2)
for i := 0; i < len(input); i += 2 {
_, err := fmt.Sscanf(input[i:i+2], "%02x", &bytes[i/2])
assert.NoError(t, err)
}
return bytes
}

func decimalBytesToString(decimalBytes [16]byte, scale int32) string {
// Reverse the byte array to big-endian
for i, j := 0, len(decimalBytes)-1; i < j; i, j = i+1, j-1 {
decimalBytes[i], decimalBytes[j] = decimalBytes[j], decimalBytes[i]
}

isNegative := decimalBytes[0]&0x80 != 0
// compute two's complement for negative numbers
if isNegative {
twosComplement(decimalBytes[:])
}

// Convert the byte array to a big.Int
intValue := new(big.Int).SetBytes(decimalBytes[:])
if isNegative {
intValue.Neg(intValue)
}
apdBigInt := apd.NewBigInt(0).SetMathBigInt(intValue)
return apd.NewWithBigInt(apdBigInt, -scale).String()
}
Loading

0 comments on commit 597afdb

Please sign in to comment.