Skip to content

Commit

Permalink
fix: Assigning nil value to *uuid.UUID field in Updates
Browse files Browse the repository at this point in the history
This PR adds handling for the assignment of a nil value to a *uuid.UUID field (resolved as a reflect.Ptr to a reflect.Array), to ensure that the model object reflects the correct value after Updates() has completed. This PR also adds few supporting test cases for Updates() with a map and uuid.UUID column.
  • Loading branch information
omkar-foss committed Jul 5, 2024
1 parent 4a50b36 commit 8b2a8a9
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 7 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ require (
github.com/jinzhu/now v1.1.5
golang.org/x/text v0.14.0
)

require github.com/google/uuid v1.6.0
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
Expand Down
4 changes: 3 additions & 1 deletion schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,9 @@ func (field *Field) setupValuerAndSetter() {
if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
if field.FieldType.Elem().Kind() == reflect.Array {
field.ReflectValueOf(ctx, value).Set(reflectV)
}
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
Expand Down
6 changes: 3 additions & 3 deletions tests/connpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ func TestConnPoolWrapper(t *testing.T) {
db: nativeDB,
expect: []string{
"SELECT VERSION()",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`,`user_uuid`) VALUES (?,?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`,`user_uuid`) VALUES (?,?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`,`user_uuid`) VALUES (?,?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
},
Expand Down
6 changes: 3 additions & 3 deletions tests/sql_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func TestDryRun(t *testing.T) {
dryRunDB := DB.Session(&gorm.Session{DryRun: true})

stmt := dryRunDB.Create(&user).Statement
if stmt.SQL.String() == "" || len(stmt.Vars) != 9 {
if stmt.SQL.String() == "" || len(stmt.Vars) != 10 {
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
}

Expand Down Expand Up @@ -403,7 +403,7 @@ func TestToSQL(t *testing.T) {
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Model(&User{}).Create(user)
})
assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false) RETURNING "id"`, sql)
assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active","user_uuid") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false,NULL) RETURNING "id"`, sql)

// save
user = &User{Name: "foo", Age: 20}
Expand All @@ -412,7 +412,7 @@ func TestToSQL(t *testing.T) {
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Model(&User{}).Save(user)
})
assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false) RETURNING "id"`, sql)
assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active","user_uuid") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false,NULL) RETURNING "id"`, sql)

// updates
user = &User{Name: "bar", Age: 22}
Expand Down
36 changes: 36 additions & 0 deletions tests/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

"github.com/google/uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils"
Expand Down Expand Up @@ -183,6 +184,41 @@ func TestUpdates(t *testing.T) {

user3.Age += 100
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")

// Updates() with map and uuid.UUID - Case 1 - Update with UUID value
uuidVal, uuidErr := uuid.NewUUID()
if uuidErr != nil {
t.Errorf("No error should occur while generating UUID, but got %v", uuidErr)
}
tx := DB.Model(&user4)
uuidErr = tx.Updates(map[string]interface{}{"user_uuid": uuidVal}).Error
if uuidErr != nil {
t.Errorf("No error should occur while updating with UUID value, but got %v", uuidErr)
}
// Expecting the model object (user4) to reflect the UUID value assignment.
AssertEqual(t, user4.UserUUID, uuidVal)

// Updates() with map and uuid.UUID - Case 2 - Update with UUID nil pointer
var nilUUIDPtr *uuid.UUID = nil
uuidErr = tx.Updates(map[string]interface{}{"user_uuid": nilUUIDPtr}).Error
if uuidErr != nil {
t.Errorf("No error should occur while updating with nil UUID pointer, but got %v", uuidErr)
}
// Expecting the model object (user4) to reflect the UUID nil pointer assignment.
AssertEqual(t, user4.UserUUID, nilUUIDPtr)

// Updates() with map and uuid.UUID - Case 3 - Update with a non-nil UUID pointer
uuidVal2, uuidErr := uuid.NewUUID()
if uuidErr != nil {
t.Errorf("No error should occur while generating UUID, but got %v", uuidErr)
}
var nonNilUUIDPtr *uuid.UUID = &uuidVal2
uuidErr = tx.Updates(map[string]interface{}{"user_uuid": nonNilUUIDPtr}).Error
if uuidErr != nil {
t.Errorf("No error should occur while updating with non-nil UUID pointer, but got %v", uuidErr)
}
// Expecting the model object (user4) to reflect the non-nil UUID pointer assignment.
AssertEqual(t, user4.UserUUID, nonNilUUIDPtr)
}

func TestUpdateColumn(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions utils/tests/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"time"

"github.com/google/uuid"
"gorm.io/gorm"
)

Expand All @@ -30,6 +31,7 @@ type User struct {
Languages []Language `gorm:"many2many:UserSpeak;"`
Friends []*User `gorm:"many2many:user_friends;"`
Active bool
UserUUID *uuid.UUID
}

type Account struct {
Expand Down

0 comments on commit 8b2a8a9

Please sign in to comment.