diff --git a/spanner/value.go b/spanner/value.go index 25f2f5802e80..402f4cdfeb43 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -313,7 +313,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return err } *p = x - case *NullString: + case *NullString, **string: if p == nil { return errNilDst(p) } @@ -321,16 +321,26 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = NullString{} + switch sp := ptr.(type) { + case *NullString: + *sp = NullString{} + case **string: + *sp = nil + } break } x, err := getStringValue(v) if err != nil { return err } - p.Valid = true - p.StringVal = x - case *[]NullString: + switch sp := ptr.(type) { + case *NullString: + sp.Valid = true + sp.StringVal = x + case **string: + *sp = &x + } + case *[]NullString, *[]*string: if p == nil { return errNilDst(p) } @@ -338,18 +348,32 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = nil + switch sp := ptr.(type) { + case *[]NullString: + *sp = nil + case *[]*string: + *sp = nil + } break } x, err := getListValue(v) if err != nil { return err } - y, err := decodeNullStringArray(x) - if err != nil { - return err + switch sp := ptr.(type) { + case *[]NullString: + y, err := decodeNullStringArray(x) + if err != nil { + return err + } + *sp = y + case *[]*string: + y, err := decodeStringPointerArray(x) + if err != nil { + return err + } + *sp = y } - *p = y case *[]string: if p == nil { return errNilDst(p) @@ -429,7 +453,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errBadEncoding(v, err) } *p = y - case *NullInt64: + case *NullInt64, **int64: if p == nil { return errNilDst(p) } @@ -437,7 +461,12 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = NullInt64{} + switch sp := ptr.(type) { + case *NullInt64: + *sp = NullInt64{} + case **int64: + *sp = nil + } break } x, err := getStringValue(v) @@ -448,9 +477,14 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { if err != nil { return errBadEncoding(v, err) } - p.Valid = true - p.Int64 = y - case *[]NullInt64: + switch sp := ptr.(type) { + case *NullInt64: + sp.Valid = true + sp.Int64 = y + case **int64: + *sp = &y + } + case *[]NullInt64, *[]*int64: if p == nil { return errNilDst(p) } @@ -458,18 +492,32 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = nil + switch sp := ptr.(type) { + case *[]NullInt64: + *sp = nil + case *[]*int64: + *sp = nil + } break } x, err := getListValue(v) if err != nil { return err } - y, err := decodeNullInt64Array(x) - if err != nil { - return err + switch sp := ptr.(type) { + case *[]NullInt64: + y, err := decodeNullInt64Array(x) + if err != nil { + return err + } + *sp = y + case *[]*int64: + y, err := decodeInt64PointerArray(x) + if err != nil { + return err + } + *sp = y } - *p = y case *[]int64: if p == nil { return errNilDst(p) @@ -505,7 +553,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return err } *p = x - case *NullBool: + case *NullBool, **bool: if p == nil { return errNilDst(p) } @@ -513,16 +561,26 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = NullBool{} + switch sp := ptr.(type) { + case *NullBool: + *sp = NullBool{} + case **bool: + *sp = nil + } break } x, err := getBoolValue(v) if err != nil { return err } - p.Valid = true - p.Bool = x - case *[]NullBool: + switch sp := ptr.(type) { + case *NullBool: + sp.Valid = true + sp.Bool = x + case **bool: + *sp = &x + } + case *[]NullBool, *[]*bool: if p == nil { return errNilDst(p) } @@ -530,18 +588,32 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = nil + switch sp := ptr.(type) { + case *[]NullBool: + *sp = nil + case *[]*bool: + *sp = nil + } break } x, err := getListValue(v) if err != nil { return err } - y, err := decodeNullBoolArray(x) - if err != nil { - return err + switch sp := ptr.(type) { + case *[]NullBool: + y, err := decodeNullBoolArray(x) + if err != nil { + return err + } + *sp = y + case *[]*bool: + y, err := decodeBoolPointerArray(x) + if err != nil { + return err + } + *sp = y } - *p = y case *[]bool: if p == nil { return errNilDst(p) @@ -577,7 +649,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return err } *p = x - case *NullFloat64: + case *NullFloat64, **float64: if p == nil { return errNilDst(p) } @@ -585,16 +657,26 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = NullFloat64{} + switch sp := ptr.(type) { + case *NullFloat64: + *sp = NullFloat64{} + case **float64: + *sp = nil + } break } x, err := getFloat64Value(v) if err != nil { return err } - p.Valid = true - p.Float64 = x - case *[]NullFloat64: + switch sp := ptr.(type) { + case *NullFloat64: + sp.Valid = true + sp.Float64 = x + case **float64: + *sp = &x + } + case *[]NullFloat64, *[]*float64: if p == nil { return errNilDst(p) } @@ -602,18 +684,32 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = nil + switch sp := ptr.(type) { + case *[]NullFloat64: + *sp = nil + case *[]*float64: + *sp = nil + } break } x, err := getListValue(v) if err != nil { return err } - y, err := decodeNullFloat64Array(x) - if err != nil { - return err + switch sp := ptr.(type) { + case *[]NullFloat64: + y, err := decodeNullFloat64Array(x) + if err != nil { + return err + } + *sp = y + case *[]*float64: + y, err := decodeFloat64PointerArray(x) + if err != nil { + return err + } + *sp = y } - *p = y case *[]float64: if p == nil { return errNilDst(p) @@ -649,7 +745,18 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { if err != nil { return err } - case *[]NullTime: + case **time.Time: + var nt NullTime + if isNull { + *p = nil + break + } + err := parseNullTime(v, &nt, code, isNull) + if err != nil { + return err + } + *p = &nt.Time + case *[]NullTime, *[]*time.Time: if p == nil { return errNilDst(p) } @@ -657,18 +764,32 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = nil + switch sp := ptr.(type) { + case *[]NullTime: + *sp = nil + case *[]*time.Time: + *sp = nil + } break } x, err := getListValue(v) if err != nil { return err } - y, err := decodeNullTimeArray(x) - if err != nil { - return err + switch sp := ptr.(type) { + case *[]NullTime: + y, err := decodeNullTimeArray(x) + if err != nil { + return err + } + *sp = y + case *[]*time.Time: + y, err := decodeTimePointerArray(x) + if err != nil { + return err + } + *sp = y } - *p = y case *[]time.Time: if p == nil { return errNilDst(p) @@ -708,7 +829,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errBadEncoding(v, err) } *p = y - case *NullDate: + case *NullDate, **civil.Date: if p == nil { return errNilDst(p) } @@ -716,7 +837,12 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = NullDate{} + switch sp := ptr.(type) { + case *NullDate: + *sp = NullDate{} + case **civil.Date: + *sp = nil + } break } x, err := getStringValue(v) @@ -727,9 +853,14 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { if err != nil { return errBadEncoding(v, err) } - p.Valid = true - p.Date = y - case *[]NullDate: + switch sp := ptr.(type) { + case *NullDate: + sp.Valid = true + sp.Date = y + case **civil.Date: + *sp = &y + } + case *[]NullDate, *[]*civil.Date: if p == nil { return errNilDst(p) } @@ -737,18 +868,32 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errTypeMismatch(code, acode, ptr) } if isNull { - *p = nil + switch sp := ptr.(type) { + case *[]NullDate: + *sp = nil + case *[]*civil.Date: + *sp = nil + } break } x, err := getListValue(v) if err != nil { return err } - y, err := decodeNullDateArray(x) - if err != nil { - return err + switch sp := ptr.(type) { + case *[]NullDate: + y, err := decodeNullDateArray(x) + if err != nil { + return err + } + *sp = y + case *[]*civil.Date: + y, err := decodeDatePointerArray(x) + if err != nil { + return err + } + *sp = y } - *p = y case *[]civil.Date: if p == nil { return errNilDst(p) @@ -1372,6 +1517,20 @@ func decodeNullStringArray(pb *proto3.ListValue) ([]NullString, error) { return a, nil } +// decodeStringPointerArray decodes proto3.ListValue pb into a *string slice. +func decodeStringPointerArray(pb *proto3.ListValue) ([]*string, error) { + if pb == nil { + return nil, errNilListValue("STRING") + } + a := make([]*string, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, stringType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "STRING", err) + } + } + return a, nil +} + // decodeStringArray decodes proto3.ListValue pb into a string slice. func decodeStringArray(pb *proto3.ListValue) ([]string, error) { if pb == nil { @@ -1401,6 +1560,20 @@ func decodeNullInt64Array(pb *proto3.ListValue) ([]NullInt64, error) { return a, nil } +// decodeInt64PointerArray decodes proto3.ListValue pb into a *int64 slice. +func decodeInt64PointerArray(pb *proto3.ListValue) ([]*int64, error) { + if pb == nil { + return nil, errNilListValue("INT64") + } + a := make([]*int64, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, intType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "INT64", err) + } + } + return a, nil +} + // decodeInt64Array decodes proto3.ListValue pb into a int64 slice. func decodeInt64Array(pb *proto3.ListValue) ([]int64, error) { if pb == nil { @@ -1429,6 +1602,20 @@ func decodeNullBoolArray(pb *proto3.ListValue) ([]NullBool, error) { return a, nil } +// decodeBoolPointerArray decodes proto3.ListValue pb into a *bool slice. +func decodeBoolPointerArray(pb *proto3.ListValue) ([]*bool, error) { + if pb == nil { + return nil, errNilListValue("BOOL") + } + a := make([]*bool, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, boolType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "BOOL", err) + } + } + return a, nil +} + // decodeBoolArray decodes proto3.ListValue pb into a bool slice. func decodeBoolArray(pb *proto3.ListValue) ([]bool, error) { if pb == nil { @@ -1457,6 +1644,20 @@ func decodeNullFloat64Array(pb *proto3.ListValue) ([]NullFloat64, error) { return a, nil } +// decodeFloat64PointerArray decodes proto3.ListValue pb into a NullFloat64 slice. +func decodeFloat64PointerArray(pb *proto3.ListValue) ([]*float64, error) { + if pb == nil { + return nil, errNilListValue("FLOAT64") + } + a := make([]*float64, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, floatType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "FLOAT64", err) + } + } + return a, nil +} + // decodeFloat64Array decodes proto3.ListValue pb into a float64 slice. func decodeFloat64Array(pb *proto3.ListValue) ([]float64, error) { if pb == nil { @@ -1499,6 +1700,20 @@ func decodeNullTimeArray(pb *proto3.ListValue) ([]NullTime, error) { return a, nil } +// decodeTimePointerArray decodes proto3.ListValue pb into a NullTime slice. +func decodeTimePointerArray(pb *proto3.ListValue) ([]*time.Time, error) { + if pb == nil { + return nil, errNilListValue("TIMESTAMP") + } + a := make([]*time.Time, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, timeType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "TIMESTAMP", err) + } + } + return a, nil +} + // decodeTimeArray decodes proto3.ListValue pb into a time.Time slice. func decodeTimeArray(pb *proto3.ListValue) ([]time.Time, error) { if pb == nil { @@ -1527,6 +1742,20 @@ func decodeNullDateArray(pb *proto3.ListValue) ([]NullDate, error) { return a, nil } +// decodeDatePointerArray decodes proto3.ListValue pb into a *civil.Date slice. +func decodeDatePointerArray(pb *proto3.ListValue) ([]*civil.Date, error) { + if pb == nil { + return nil, errNilListValue("DATE") + } + a := make([]*civil.Date, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, dateType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "DATE", err) + } + } + return a, nil +} + // decodeDateArray decodes proto3.ListValue pb into a civil.Date slice. func decodeDateArray(pb *proto3.ListValue) ([]civil.Date, error) { if pb == nil { @@ -1743,6 +1972,19 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } } pt = listType(stringType()) + case *string: + if v != nil { + return encodeValue(*v) + } + pt = stringType() + case []*string: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(stringType()) case []byte: if v != nil { pb.Kind = stringKind(base64.StdEncoding.EncodeToString(v)) @@ -1791,6 +2033,19 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } } pt = listType(intType()) + case *int64: + if v != nil { + return encodeValue(*v) + } + pt = intType() + case []*int64: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(intType()) case bool: pb.Kind = &proto3.Value_BoolValue{BoolValue: v} pt = boolType() @@ -1815,6 +2070,19 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } } pt = listType(boolType()) + case *bool: + if v != nil { + return encodeValue(*v) + } + pt = boolType() + case []*bool: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(boolType()) case float64: pb.Kind = &proto3.Value_NumberValue{NumberValue: v} pt = floatType() @@ -1839,6 +2107,19 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } } pt = listType(floatType()) + case *float64: + if v != nil { + return encodeValue(*v) + } + pt = floatType() + case []*float64: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(floatType()) case time.Time: if v == commitTimestamp { pb.Kind = stringKind(commitTimestampPlaceholderString) @@ -1867,6 +2148,19 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } } pt = listType(timeType()) + case *time.Time: + if v != nil { + return encodeValue(*v) + } + pt = timeType() + case []*time.Time: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(timeType()) case civil.Date: pb.Kind = stringKind(v.String()) pt = dateType() @@ -1891,6 +2185,19 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } } pt = listType(dateType()) + case *civil.Date: + if v != nil { + return encodeValue(*v) + } + pt = dateType() + case []*civil.Date: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(dateType()) case GenericColumnValue: // Deep clone to ensure subsequent changes to v before // transmission don't affect our encoded value. diff --git a/spanner/value_test.go b/spanner/value_test.go index 955e682cec43..1feebc6cc907 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -74,6 +74,19 @@ func TestEncodeValue(t *testing.T) { type CustomNullTime NullTime type CustomNullDate NullDate + sValue := "abc" + var sNilPtr *string + iValue := int64(7) + var iNilPtr *int64 + bValue := true + var bNilPtr *bool + fValue := 3.14 + var fNilPtr *float64 + tValue := t1 + var tNilPtr *time.Time + dValue := d1 + var dNilPtr *civil.Date + var ( tString = stringType() tInt = intType() @@ -93,9 +106,12 @@ func TestEncodeValue(t *testing.T) { {"abc", stringProto("abc"), tString, "string"}, {NullString{"abc", true}, stringProto("abc"), tString, "NullString with value"}, {NullString{"abc", false}, nullProto(), tString, "NullString with null"}, + {&sValue, stringProto("abc"), tString, "*string with value"}, + {sNilPtr, nullProto(), tString, "*string with null"}, {[]string(nil), nullProto(), listType(tString), "null []string"}, {[]string{"abc", "bcd"}, listProto(stringProto("abc"), stringProto("bcd")), listType(tString), "[]string"}, {[]NullString{{"abcd", true}, {"xyz", false}}, listProto(stringProto("abcd"), nullProto()), listType(tString), "[]NullString"}, + {[]*string{&sValue, sNilPtr}, listProto(stringProto("abc"), nullProto()), listType(tString), "[]*string"}, // BYTES / BYTES ARRAY {[]byte("foo"), bytesProto([]byte("foo")), tBytes, "[]byte with value"}, {[]byte(nil), nullProto(), tBytes, "null []byte"}, @@ -110,35 +126,50 @@ func TestEncodeValue(t *testing.T) { {[]int64{33, 129}, listProto(intProto(33), intProto(129)), listType(tInt), "[]int64"}, {NullInt64{11, true}, intProto(11), tInt, "NullInt64 with value"}, {NullInt64{11, false}, nullProto(), tInt, "NullInt64 with null"}, + {&iValue, intProto(7), tInt, "*int64 with value"}, + {iNilPtr, nullProto(), tInt, "*int64 with null"}, {[]NullInt64{{35, true}, {131, false}}, listProto(intProto(35), nullProto()), listType(tInt), "[]NullInt64"}, + {[]*int64{&iValue, iNilPtr}, listProto(intProto(7), nullProto()), listType(tInt), "[]*int64"}, // BOOL / BOOL ARRAY {true, boolProto(true), tBool, "bool"}, {NullBool{true, true}, boolProto(true), tBool, "NullBool with value"}, {NullBool{true, false}, nullProto(), tBool, "NullBool with null"}, + {&bValue, boolProto(true), tBool, "*bool with value"}, + {bNilPtr, nullProto(), tBool, "*bool with null"}, {[]bool{true, false}, listProto(boolProto(true), boolProto(false)), listType(tBool), "[]bool"}, {[]NullBool{{true, true}, {true, false}}, listProto(boolProto(true), nullProto()), listType(tBool), "[]NullBool"}, + {[]*bool{&bValue, bNilPtr}, listProto(boolProto(true), nullProto()), listType(tBool), "[]*bool"}, // FLOAT64 / FLOAT64 ARRAY {3.14, floatProto(3.14), tFloat, "float"}, {NullFloat64{3.1415, true}, floatProto(3.1415), tFloat, "NullFloat64 with value"}, {NullFloat64{math.Inf(1), true}, floatProto(math.Inf(1)), tFloat, "NullFloat64 with infinity"}, {NullFloat64{3.14159, false}, nullProto(), tFloat, "NullFloat64 with null"}, + {&fValue, floatProto(3.14), tFloat, "*float64 with value"}, + {fNilPtr, nullProto(), tFloat, "*float64 with null"}, {[]float64(nil), nullProto(), listType(tFloat), "null []float64"}, {[]float64{3.141, 0.618, math.Inf(-1)}, listProto(floatProto(3.141), floatProto(0.618), floatProto(math.Inf(-1))), listType(tFloat), "[]float64"}, {[]NullFloat64{{3.141, true}, {0.618, false}}, listProto(floatProto(3.141), nullProto()), listType(tFloat), "[]NullFloat64"}, + {[]*float64{&fValue, fNilPtr}, listProto(floatProto(3.14), nullProto()), listType(tFloat), "[]NullFloat64"}, // TIMESTAMP / TIMESTAMP ARRAY {t1, timeProto(t1), tTime, "time"}, {NullTime{t1, true}, timeProto(t1), tTime, "NullTime with value"}, {NullTime{t1, false}, nullProto(), tTime, "NullTime with null"}, + {&tValue, timeProto(t1), tTime, "*time.Time with value"}, + {tNilPtr, nullProto(), tTime, "*time.Time with null"}, {[]time.Time(nil), nullProto(), listType(tTime), "null []time"}, {[]time.Time{t1, t2, t3, t4}, listProto(timeProto(t1), timeProto(t2), timeProto(t3), timeProto(t4)), listType(tTime), "[]time"}, {[]NullTime{{t1, true}, {t1, false}}, listProto(timeProto(t1), nullProto()), listType(tTime), "[]NullTime"}, + {[]*time.Time{&tValue, tNilPtr}, listProto(timeProto(t1), nullProto()), listType(tTime), "[]*time.Time"}, // DATE / DATE ARRAY {d1, dateProto(d1), tDate, "date"}, {NullDate{d1, true}, dateProto(d1), tDate, "NullDate with value"}, {NullDate{civil.Date{}, false}, nullProto(), tDate, "NullDate with null"}, + {&dValue, dateProto(d1), tDate, "*civil.Date with value"}, + {dNilPtr, nullProto(), tDate, "*civil.Date with null"}, {[]civil.Date(nil), nullProto(), listType(tDate), "null []date"}, {[]civil.Date{d1, d2}, listProto(dateProto(d1), dateProto(d2)), listType(tDate), "[]date"}, {[]NullDate{{d1, true}, {civil.Date{}, false}}, listProto(dateProto(d1), nullProto()), listType(tDate), "[]NullDate"}, + {[]*civil.Date{&dValue, dNilPtr}, listProto(dateProto(d1), nullProto()), listType(tDate), "[]*civil.Date"}, // GenericColumnValue {GenericColumnValue{tString, stringProto("abc")}, stringProto("abc"), tString, "GenericColumnValue with value"}, {GenericColumnValue{tString, nullProto()}, nullProto(), tString, "GenericColumnValue with null"}, @@ -603,6 +634,13 @@ func TestEncodeStructValueBasicFields(t *testing.T) { type CustomNullTime NullTime type CustomNullDate NullDate + sValue := "abc" + iValue := int64(300) + bValue := false + fValue := 3.45 + tValue := t1 + dValue := d1 + StructTypeProto := structType( mkField("Stringf", stringType()), mkField("Intf", intType()), @@ -634,6 +672,48 @@ func TestEncodeStructValueBasicFields(t *testing.T) { dateProto(d1)), StructTypeProto, }, + { + "Pointers to basic types.", + struct { + Stringf *string + Intf *int64 + Boolf *bool + Floatf *float64 + Bytef []byte + Timef *time.Time + Datef *civil.Date + }{&sValue, &iValue, &bValue, &fValue, []byte("foo"), &tValue, &dValue}, + listProto( + stringProto("abc"), + intProto(300), + boolProto(false), + floatProto(3.45), + bytesProto([]byte("foo")), + timeProto(t1), + dateProto(d1)), + StructTypeProto, + }, + { + "Pointers to basic types with null values.", + struct { + Stringf *string + Intf *int64 + Boolf *bool + Floatf *float64 + Bytef []byte + Timef *time.Time + Datef *civil.Date + }{nil, nil, nil, nil, nil, nil, nil}, + listProto( + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto()), + StructTypeProto, + }, { "Basic custom types.", struct { @@ -734,6 +814,19 @@ func TestEncodeStructValueArrayFields(t *testing.T) { type CustomNullTime NullTime type CustomNullDate NullDate + sValue := "def" + var sNilPtr *string + iValue := int64(68) + var iNilPtr *int64 + bValue := true + var bNilPtr *bool + fValue := 3.14 + var fNilPtr *float64 + tValue := t1 + var tNilPtr *time.Time + dValue := d1 + var dNilPtr *civil.Date + StructTypeProto := structType( mkField("Stringf", listType(stringType())), mkField("Intf", listType(intType())), @@ -841,6 +934,38 @@ func TestEncodeStructValueArrayFields(t *testing.T) { listProto(nullProto(), dateProto(d2))), StructTypeProto, }, + { + "Arrays of pointers to basic types with nullable elements.", + struct { + Stringf []*string + Intf []*int64 + Int64f []*int64 + Boolf []*bool + Floatf []*float64 + Bytef [][]byte + Timef []*time.Time + Datef []*civil.Date + }{ + []*string{sNilPtr, &sValue}, + []*int64{iNilPtr, &iValue}, + []*int64{iNilPtr, &iValue}, + []*bool{bNilPtr, &bValue}, + []*float64{fNilPtr, &fValue}, + [][]byte{[]byte("foo"), nil}, + []*time.Time{tNilPtr, &tValue}, + []*civil.Date{dNilPtr, &dValue}, + }, + listProto( + listProto(nullProto(), stringProto("def")), + listProto(nullProto(), intProto(68)), + listProto(nullProto(), intProto(68)), + listProto(nullProto(), boolProto(true)), + listProto(nullProto(), floatProto(3.14)), + listProto(bytesProto([]byte("foo")), nullProto()), + listProto(nullProto(), timeProto(t1)), + listProto(nullProto(), dateProto(d1))), + StructTypeProto, + }, { "Arrays of basic custom types with nullable elements.", struct { @@ -959,6 +1084,32 @@ func TestDecodeValue(t *testing.T) { type CustomNullTime NullTime type CustomNullDate NullDate + // Pointer values. + sValue := "abc" + var sNilPtr *string + s2Value := "bcd" + + iValue := int64(15) + var iNilPtr *int64 + i1Value := int64(91) + i2Value := int64(87) + + bValue := true + var bNilPtr *bool + b2Value := false + + fValue := 3.14 + var fNilPtr *float64 + f2Value := 6.626 + + tValue := t1 + var tNilPtr *time.Time + t2Value := t2 + + dValue := d1 + var dNilPtr *civil.Date + d2Value := d2 + for _, test := range []struct { desc string proto *proto3.Value @@ -969,6 +1120,8 @@ func TestDecodeValue(t *testing.T) { // STRING {desc: "decode STRING to string", proto: stringProto("abc"), protoType: stringType(), want: "abc"}, {desc: "decode NULL to string", proto: nullProto(), protoType: stringType(), want: "abc", wantErr: true}, + {desc: "decode STRING to *string", proto: stringProto("abc"), protoType: stringType(), want: &sValue}, + {desc: "decode NULL to *string", proto: nullProto(), protoType: stringType(), want: sNilPtr}, {desc: "decode STRING to NullString", proto: stringProto("abc"), protoType: stringType(), want: NullString{"abc", true}}, {desc: "decode NULL to NullString", proto: nullProto(), protoType: stringType(), want: NullString{}}, // STRING ARRAY with []NullString @@ -976,6 +1129,9 @@ func TestDecodeValue(t *testing.T) { {desc: "decode NULL to []NullString", proto: nullProto(), protoType: listType(stringType()), want: []NullString(nil)}, // STRING ARRAY with []string {desc: "decode ARRAY to []string", proto: listProto(stringProto("abc"), stringProto("bcd")), protoType: listType(stringType()), want: []string{"abc", "bcd"}}, + // STRING ARRAY with []*string + {desc: "decode ARRAY to []*string", proto: listProto(stringProto("abc"), nullProto(), stringProto("bcd")), protoType: listType(stringType()), want: []*string{&sValue, sNilPtr, &s2Value}}, + {desc: "decode NULL to []*string", proto: nullProto(), protoType: listType(stringType()), want: []*string(nil)}, // BYTES {desc: "decode BYTES to []byte", proto: bytesProto([]byte("ab")), protoType: bytesType(), want: []byte("ab")}, {desc: "decode NULL to []byte", proto: nullProto(), protoType: bytesType(), want: []byte(nil)}, @@ -985,6 +1141,8 @@ func TestDecodeValue(t *testing.T) { //INT64 {desc: "decode INT64 to int64", proto: intProto(15), protoType: intType(), want: int64(15)}, {desc: "decode NULL to int64", proto: nullProto(), protoType: intType(), want: int64(0), wantErr: true}, + {desc: "decode INT64 to *int64", proto: intProto(15), protoType: intType(), want: &iValue}, + {desc: "decode NULL to *int64", proto: nullProto(), protoType: intType(), want: iNilPtr}, {desc: "decode INT64 to NullInt64", proto: intProto(15), protoType: intType(), want: NullInt64{15, true}}, {desc: "decode NULL to NullInt64", proto: nullProto(), protoType: intType(), want: NullInt64{}}, // INT64 ARRAY with []NullInt64 @@ -992,9 +1150,14 @@ func TestDecodeValue(t *testing.T) { {desc: "decode NULL to []NullInt64", proto: nullProto(), protoType: listType(intType()), want: []NullInt64(nil)}, // INT64 ARRAY with []int64 {desc: "decode ARRAY to []int64", proto: listProto(intProto(91), intProto(87)), protoType: listType(intType()), want: []int64{91, 87}}, + // INT64 ARRAY with []*int64 + {desc: "decode ARRAY to []*int64", proto: listProto(intProto(91), nullProto(), intProto(87)), protoType: listType(intType()), want: []*int64{&i1Value, nil, &i2Value}}, + {desc: "decode NULL to []*int64", proto: nullProto(), protoType: listType(intType()), want: []*int64(nil)}, // BOOL {desc: "decode BOOL to bool", proto: boolProto(true), protoType: boolType(), want: true}, {desc: "decode NULL to bool", proto: nullProto(), protoType: boolType(), want: true, wantErr: true}, + {desc: "decode BOOL to *bool", proto: boolProto(true), protoType: boolType(), want: &bValue}, + {desc: "decode NULL to *bool", proto: nullProto(), protoType: boolType(), want: bNilPtr}, {desc: "decode BOOL to NullBool", proto: boolProto(true), protoType: boolType(), want: NullBool{true, true}}, {desc: "decode BOOL to NullBool", proto: nullProto(), protoType: boolType(), want: NullBool{}}, // BOOL ARRAY with []NullBool @@ -1002,9 +1165,14 @@ func TestDecodeValue(t *testing.T) { {desc: "decode NULL to []NullBool", proto: nullProto(), protoType: listType(boolType()), want: []NullBool(nil)}, // BOOL ARRAY with []bool {desc: "decode ARRAY to []bool", proto: listProto(boolProto(true), boolProto(false)), protoType: listType(boolType()), want: []bool{true, false}}, + // BOOL ARRAY with []*bool + {desc: "decode ARRAY to []*bool", proto: listProto(boolProto(true), nullProto(), boolProto(false)), protoType: listType(boolType()), want: []*bool{&bValue, bNilPtr, &b2Value}}, + {desc: "decode NULL to []*bool", proto: nullProto(), protoType: listType(boolType()), want: []*bool(nil)}, // FLOAT64 {desc: "decode FLOAT64 to float64", proto: floatProto(3.14), protoType: floatType(), want: 3.14}, {desc: "decode NULL to float64", proto: nullProto(), protoType: floatType(), want: 0.00, wantErr: true}, + {desc: "decode FLOAT64 to *float64", proto: floatProto(3.14), protoType: floatType(), want: &fValue}, + {desc: "decode NULL to *float64", proto: nullProto(), protoType: floatType(), want: fNilPtr}, {desc: "decode FLOAT64 to NullFloat64", proto: floatProto(3.14), protoType: floatType(), want: NullFloat64{3.14, true}}, {desc: "decode NULL to NullFloat64", proto: nullProto(), protoType: floatType(), want: NullFloat64{}}, // FLOAT64 ARRAY with []NullFloat64 @@ -1012,25 +1180,38 @@ func TestDecodeValue(t *testing.T) { {desc: "decode NULL to []NullFloat64", proto: nullProto(), protoType: listType(floatType()), want: []NullFloat64(nil)}, // FLOAT64 ARRAY with []float64 {desc: "decode ARRAY to []float64", proto: listProto(floatProto(math.Inf(1)), floatProto(math.Inf(-1)), floatProto(3.1)), protoType: listType(floatType()), want: []float64{math.Inf(1), math.Inf(-1), 3.1}}, + // FLOAT64 ARRAY with []NullFloat64 + {desc: "decode ARRAY to []*float64", proto: listProto(floatProto(fValue), nullProto(), floatProto(f2Value)), protoType: listType(floatType()), want: []*float64{&fValue, nil, &f2Value}}, + {desc: "decode NULL to []*float64", proto: nullProto(), protoType: listType(floatType()), want: []*float64(nil)}, // TIMESTAMP {desc: "decode TIMESTAMP to time.Time", proto: timeProto(t1), protoType: timeType(), want: t1}, {desc: "decode TIMESTAMP to NullTime", proto: timeProto(t1), protoType: timeType(), want: NullTime{t1, true}}, {desc: "decode NULL to NullTime", proto: nullProto(), protoType: timeType(), want: NullTime{}}, + {desc: "decode TIMESTAMP to *time.Time", proto: timeProto(t1), protoType: timeType(), want: &tValue}, + {desc: "decode NULL to *time.Time", proto: nullProto(), protoType: timeType(), want: tNilPtr}, {desc: "decode INT64 to time.Time", proto: intProto(7), protoType: timeType(), want: time.Time{}, wantErr: true}, // TIMESTAMP ARRAY with []NullTime {desc: "decode ARRAY to []NullTime", proto: listProto(timeProto(t1), timeProto(t2), timeProto(t3), nullProto()), protoType: listType(timeType()), want: []NullTime{{t1, true}, {t2, true}, {t3, true}, {}}}, {desc: "decode NULL to []NullTime", proto: nullProto(), protoType: listType(timeType()), want: []NullTime(nil)}, // TIMESTAMP ARRAY with []time.Time {desc: "decode ARRAY to []time.Time", proto: listProto(timeProto(t1), timeProto(t2), timeProto(t3)), protoType: listType(timeType()), want: []time.Time{t1, t2, t3}}, + // TIMESTAMP ARRAY with []*time.Time + {desc: "decode ARRAY to []*time.Time", proto: listProto(timeProto(t1), nullProto(), timeProto(t2)), protoType: listType(timeType()), want: []*time.Time{&tValue, nil, &t2Value}}, + {desc: "decode NULL to []*time.Time", proto: nullProto(), protoType: listType(timeType()), want: []*time.Time(nil)}, // DATE {desc: "decode DATE to civil.Date", proto: dateProto(d1), protoType: dateType(), want: d1}, {desc: "decode DATE to NullDate", proto: dateProto(d1), protoType: dateType(), want: NullDate{d1, true}}, {desc: "decode NULL to NullDate", proto: nullProto(), protoType: dateType(), want: NullDate{}}, + {desc: "decode DATE to *civil.Date", proto: dateProto(d1), protoType: dateType(), want: &dValue}, + {desc: "decode NULL to *civil.Date", proto: nullProto(), protoType: dateType(), want: dNilPtr}, // DATE ARRAY with []NullDate {desc: "decode ARRAY to []NullDate", proto: listProto(dateProto(d1), dateProto(d2), nullProto()), protoType: listType(dateType()), want: []NullDate{{d1, true}, {d2, true}, {}}}, {desc: "decode NULL to []NullDate", proto: nullProto(), protoType: listType(dateType()), want: []NullDate(nil)}, // DATE ARRAY with []civil.Date {desc: "decode ARRAY to []civil.Date", proto: listProto(dateProto(d1), dateProto(d2)), protoType: listType(dateType()), want: []civil.Date{d1, d2}}, + // DATE ARRAY with []NullDate + {desc: "decode ARRAY to []*civil.Date", proto: listProto(dateProto(d1), nullProto(), dateProto(d2)), protoType: listType(dateType()), want: []*civil.Date{&dValue, nil, &d2Value}}, + {desc: "decode NULL to []*civil.Date", proto: nullProto(), protoType: listType(dateType()), want: []*civil.Date(nil)}, // STRUCT ARRAY // STRUCT schema is equal to the following Go struct: // type s struct { @@ -1555,6 +1736,121 @@ func TestDecodeStruct(t *testing.T) { } } +func TestDecodeStructWithPointers(t *testing.T) { + stype := &sppb.StructType{Fields: []*sppb.StructType_Field{ + {Name: "Str", Type: stringType()}, + {Name: "Int", Type: intType()}, + {Name: "Bool", Type: boolType()}, + {Name: "Float", Type: floatType()}, + {Name: "Time", Type: timeType()}, + {Name: "Date", Type: dateType()}, + {Name: "StrArray", Type: listType(stringType())}, + {Name: "IntArray", Type: listType(intType())}, + {Name: "BoolArray", Type: listType(boolType())}, + {Name: "FloatArray", Type: listType(floatType())}, + {Name: "TimeArray", Type: listType(timeType())}, + {Name: "DateArray", Type: listType(dateType())}, + }} + lv := []*proto3.ListValue{ + listValueProto( + stringProto("id"), + intProto(15), + boolProto(true), + floatProto(3.14), + timeProto(t1), + dateProto(d1), + listProto(stringProto("id1"), nullProto(), stringProto("id2")), + listProto(intProto(16), nullProto(), intProto(17)), + listProto(boolProto(true), nullProto(), boolProto(false)), + listProto(floatProto(3.14), nullProto(), floatProto(6.626)), + listProto(timeProto(t1), nullProto(), timeProto(t2)), + listProto(dateProto(d1), nullProto(), dateProto(d2)), + ), + listValueProto( + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto(), + nullProto(), + ), + } + + type S1 struct { + Str *string + Int *int64 + Bool *bool + Float *float64 + Time *time.Time + Date *civil.Date + StrArray []*string + IntArray []*int64 + BoolArray []*bool + FloatArray []*float64 + TimeArray []*time.Time + DateArray []*civil.Date + } + var s1 S1 + sValue := "id" + iValue := int64(15) + bValue := true + fValue := 3.14 + tValue := t1 + dValue := d1 + sArrayValue1 := "id1" + sArrayValue2 := "id2" + sArrayValue := []*string{&sArrayValue1, nil, &sArrayValue2} + iArrayValue1 := int64(16) + iArrayValue2 := int64(17) + iArrayValue := []*int64{&iArrayValue1, nil, &iArrayValue2} + bArrayValue1 := true + bArrayValue2 := false + bArrayValue := []*bool{&bArrayValue1, nil, &bArrayValue2} + f1Value := 3.14 + f2Value := 6.626 + fArrayValue := []*float64{&f1Value, nil, &f2Value} + t1Value := t1 + t2Value := t2 + tArrayValue := []*time.Time{&t1Value, nil, &t2Value} + d1Value := d1 + d2Value := d2 + dArrayValue := []*civil.Date{&d1Value, nil, &d2Value} + + for i, test := range []struct { + desc string + ptr *S1 + want *S1 + fail bool + }{ + { + desc: "decode values to S1", + ptr: &s1, + want: &S1{Str: &sValue, Int: &iValue, Bool: &bValue, Float: &fValue, Time: &tValue, Date: &dValue, StrArray: sArrayValue, IntArray: iArrayValue, BoolArray: bArrayValue, FloatArray: fArrayValue, TimeArray: tArrayValue, DateArray: dArrayValue}, + }, + { + desc: "decode nulls to S1", + ptr: &s1, + want: &S1{Str: nil, Int: nil, Bool: nil, Float: nil, Time: nil, Date: nil, StrArray: nil, IntArray: nil, BoolArray: nil, FloatArray: nil, TimeArray: nil, DateArray: nil}, + }, + } { + err := decodeStruct(stype, lv[i], test.ptr) + if (err != nil) != test.fail { + t.Errorf("%s: got error %v, wanted fail: %v", test.desc, err, test.fail) + } + if err == nil { + if !testutil.Equal(test.ptr, test.want) { + t.Errorf("%s: got %+v, want %+v", test.desc, test.ptr, test.want) + } + } + } +} + func TestEncodeStructValueDynamicStructs(t *testing.T) { dynStructType := reflect.StructOf([]reflect.StructField{ {Name: "A", Type: reflect.TypeOf(0), Tag: `spanner:"a"`},