Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle NULL values in UniqueIdentifier.Scan() #163

Merged
merged 11 commits into from
Feb 21, 2024
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ Constrain the provider to an allowed list of key vaults by appending vault host
* Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas.
* Supports query notifications
* Supports Kerberos Authentication
* Supports handling the `uniqueidentifier` data type with the `UniqueIdentifier` and `NullUniqueIdentifier` go types
* Pluggable Dialer implementations through `msdsn.ProtocolParsers` and `msdsn.ProtocolDialers`
* A `namedpipe` package to support connections using named pipes (np:) on Windows
* A `sharedmemory` package to support connections using shared memory (lpc:) on Windows
Expand Down
2 changes: 2 additions & 0 deletions bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func TestBulkcopy(t *testing.T) {
{"test_intf32", float32(1234.56), 1234},
{"test_geom", geom, string(geom)},
{"test_uniqueidentifier", uid, string(uid)},
{"test_nulluniqueidentifier", nil, nil},
// {"test_smallmoney", 1234.56, nil},
// {"test_money", 1234.56, nil},
{"test_decimal_18_0", 1234.0001, "1234"},
Expand Down Expand Up @@ -270,6 +271,7 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_geog] [geography] NULL,
[text_xml] [xml] NULL,
[test_uniqueidentifier] [uniqueidentifier] NULL,
[test_nulluniqueidentifier] [uniqueidentifier] NULL,
[test_decimal_18_0] [decimal](18, 0) NULL,
[test_decimal_18_2] [decimal](18, 2) NULL,
[test_decimal_9_2] [decimal](9, 2) NULL,
Expand Down
65 changes: 65 additions & 0 deletions uniqueidentifier_null.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package mssql

import (
"database/sql/driver"
)

type NullUniqueIdentifier struct {
UUID UniqueIdentifier
Valid bool // Valid is true if UUID is not NULL
}

func (n *NullUniqueIdentifier) Scan(v interface{}) error {
if v == nil {
*n = NullUniqueIdentifier{
UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Valid: false,
}
return nil
}
u := n.UUID
err := u.Scan(v)
*n = NullUniqueIdentifier{
UUID: u,
Valid: true,
}
return err
}

func (n NullUniqueIdentifier) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.UUID.Value()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

turn n.UUID.Value()

Won't Value/String/Marshal etc need to handle !Valid?
Pls add tests for these. If there are existing functional tests that cover String() and Value() implementations of other types let's add a case to those too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn´t entirely sure what !Valid for String should be 34e68bc

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this type need to implement String() ?
I don't see the core nullable types implementing it
https://cs.opensource.google/go/go/+/master:src/database/sql/sql.go?q=nullint64&ss=go%2Fgo

}

func (n NullUniqueIdentifier) String() string {
if !n.Valid {
return "NULL"
}
return n.UUID.String()
}

func (n NullUniqueIdentifier) MarshalText() (text []byte, err error) {
if !n.Valid {
return []byte("null"), nil
}
return n.UUID.MarshalText()
}

func (n *NullUniqueIdentifier) UnmarshalJSON(b []byte) error {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

b []byte

i think json would use a literal "null" string

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

45eb78e should handle that case

u := n.UUID
if string(b) == "null" {
*n = NullUniqueIdentifier{
UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Valid: false,
}
return nil
}
err := u.UnmarshalJSON(b)
*n = NullUniqueIdentifier{
UUID: u,
Valid: true,
}
return err
}
215 changes: 215 additions & 0 deletions uniqueidentifier_null_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package mssql

import (
"bytes"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"testing"
)

func TestNullableUniqueIdentifierScanNull(t *testing.T) {
t.Parallel()
nullUUID := NullUniqueIdentifier{
UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Valid: false,
}

sut := NullUniqueIdentifier{
UUID: [16]byte{0x1},
Valid: true,
}
scanErr := sut.Scan(nil) // NULL in the DB
if scanErr != nil {
t.Fatal("NullUniqueIdentifier should not error out on Scan(nil)")
}
if sut != nullUUID {
t.Errorf("bytes not swapped correctly: got %q; want %q", sut, nullUUID)
}
}

func TestNullableUniqueIdentifierScanBytes(t *testing.T) {
t.Parallel()
dbUUID := [16]byte{0x67, 0x45, 0x23, 0x01, 0xAB, 0x89, 0xEF, 0xCD, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}
uuid := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}

var sut NullUniqueIdentifier
scanErr := sut.Scan(dbUUID[:])
if scanErr != nil {
t.Fatal(scanErr)
}
if sut != uuid {
t.Errorf("bytes not swapped correctly: got %q; want %q", sut, uuid)
}
}

func TestNullableUniqueIdentifierScanString(t *testing.T) {
t.Parallel()
uuid := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}

var sut NullUniqueIdentifier
scanErr := sut.Scan(uuid.String())
if scanErr != nil {
t.Fatal(scanErr)
}
if sut != uuid {
t.Errorf("string not scanned correctly: got %q; want %q", sut, uuid)
}
}

func TestNullableUniqueIdentifierScanUnexpectedType(t *testing.T) {
t.Parallel()
var sut NullUniqueIdentifier
scanErr := sut.Scan(int(1))
if scanErr == nil {
t.Fatal(scanErr)
}
}

func TestNullableUniqueIdentifierValue(t *testing.T) {
t.Parallel()
dbUUID := [16]byte{0x67, 0x45, 0x23, 0x01, 0xAB, 0x89, 0xEF, 0xCD, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}

uuid := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}

sut := uuid
v, valueErr := sut.Value()
if valueErr != nil {
t.Fatal(valueErr)
}

b, ok := v.([]byte)
if !ok {
t.Fatalf("(%T) is not []byte", v)
}

if !bytes.Equal(b, dbUUID[:]) {
t.Errorf("got %q; want %q", b, dbUUID)
}
}

func TestNullableUniqueIdentifierValueNull(t *testing.T) {
t.Parallel()
uuid := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: false,
}

sut := uuid
v, valueErr := sut.Value()
if valueErr != nil {
t.Errorf("unexpected error for invalid uuid: %s", valueErr)
}

if v != nil {
t.Errorf("expected non-nil value for invalid uuid: %s", v)
}
}

func TestNullableUniqueIdentifierString(t *testing.T) {
t.Parallel()
sut := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}
expected := "01234567-89AB-CDEF-0123-456789ABCDEF"
if actual := sut.String(); actual != expected {
t.Errorf("sut.String() = %s; want %s", sut, expected)
}
}

func TestNullableUniqueIdentifierStringNull(t *testing.T) {
t.Parallel()
sut := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: false,
}
expected := "NULL"
if actual := sut.String(); actual != expected {
t.Errorf("sut.String() = %s; want %s", sut, expected)
}
}

func TestNullableUniqueIdentifierMarshalText(t *testing.T) {
t.Parallel()
sut := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}
expected := []byte{48, 49, 50, 51, 52, 53, 54, 55, 45, 56, 57, 65, 66, 45, 67, 68, 69, 70, 45, 48, 49, 50, 51, 45, 52, 53, 54, 55, 56, 57, 65, 66, 67, 68, 69, 70}
text, marshalErr := sut.MarshalText()
if marshalErr != nil {
t.Errorf("unexpected error while marshalling: %s", marshalErr)
}
if actual := text; !reflect.DeepEqual(actual, expected) {
t.Errorf("sut.MarshalText() = %v; want %v", actual, expected)
}
}

func TestNullableUniqueIdentifierMarshalTextNull(t *testing.T) {
t.Parallel()
sut := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: false,
}
expected := []byte("null")
text, marshalErr := sut.MarshalText()
if marshalErr != nil {
t.Errorf("unexpected error while marshalling: %s", marshalErr)
}
if actual := text; !reflect.DeepEqual(actual, expected) {
t.Errorf("sut.MarshalText() = %v; want %v", actual, expected)
}
}

func TestNullableUniqueIdentifierUnmarshalJSON(t *testing.T) {
t.Parallel()
input := []byte("01234567-89AB-CDEF-0123-456789ABCDEF")
var u NullUniqueIdentifier

err := u.UnmarshalJSON(input)
if err != nil {
t.Fatal(err)
}
expected := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}
if u != expected {
t.Errorf("u.UnmarshalJSON() = %v; want %v", u, expected)
}
}

func TestNullableUniqueIdentifierUnmarshalJSONNull(t *testing.T) {
t.Parallel()
u := NullUniqueIdentifier{
UUID: [16]byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
Valid: true,
}

err := u.UnmarshalJSON([]byte("null"))
if err != nil {
t.Fatal(err)
}
expected := NullUniqueIdentifier{
UUID: [16]byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Valid: false,
}
if u != expected {
t.Errorf("u.UnmarshalJSON() = %v; want %v", u, expected)
}
}

var _ fmt.Stringer = NullUniqueIdentifier{}
var _ sql.Scanner = &NullUniqueIdentifier{}
var _ driver.Valuer = NullUniqueIdentifier{}
Loading