Skip to content

Commit

Permalink
Fix bug of copier
Browse files Browse the repository at this point in the history
  • Loading branch information
glassonion1 committed Aug 20, 2024
1 parent d0d6ae7 commit 1e2cd6f
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 20 deletions.
60 changes: 40 additions & 20 deletions copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ func DeepCopyWithCustomSetter(

for i := 0; i < src.NumField(); i++ {
field := src.Type().Field(i)
srcFieldType, ok := src.Type().FieldByName(field.Name)
srcFieldValue := src.FieldByName(field.Name)
if !ok {
continue
}

dstFieldName := field.Name
if tag, ok := field.Tag.Lookup(tagCopier); ok {
Expand All @@ -64,7 +60,7 @@ func DeepCopyWithCustomSetter(
dstFieldName = tag
}

dstFieldType, ok := dst.Type().FieldByName(dstFieldName)
_, ok := dst.Type().FieldByName(dstFieldName)
dstFieldValue := dst.FieldByName(dstFieldName)
if !ok {
continue
Expand All @@ -76,14 +72,35 @@ func DeepCopyWithCustomSetter(
}

// string, int, float
if srcFieldType.Type.ConvertibleTo(dstFieldType.Type) {
dstFieldValue.Set(srcFieldValue.Convert(dstFieldType.Type))
if srcFieldValue.Type().ConvertibleTo(dstFieldValue.Type()) {
dstFieldValue.Set(srcFieldValue.Convert(dstFieldValue.Type()))
continue
}

// *string, *int, *float to string, int, float
if srcFieldValue.Type().Kind() == reflect.Ptr {
if srcFieldValue.IsNil() {
continue
}
if srcFieldValue.Type().Elem().ConvertibleTo(dstFieldValue.Type()) {
dstFieldValue.Set(srcFieldValue.Elem().Convert(dstFieldValue.Type()))
continue
}
}

// string, int, float to *string, *int, *float
if dstFieldValue.Type().Kind() == reflect.Ptr {
if srcFieldValue.Type().ConvertibleTo(dstFieldValue.Type().Elem()) {
rv := reflect.New(dstFieldValue.Type().Elem())
rv.Elem().Set(srcFieldValue.Convert(dstFieldValue.Type().Elem()))
dstFieldValue.Set(rv)
continue
}
}

isSet, err := customSetter(srcFieldValue, dstFieldValue)
if err != nil {
return fmt.Errorf("%v", err)
return fmt.Errorf("%s: %v", field.Name, err)
}
if isSet {
continue
Expand All @@ -92,7 +109,7 @@ func DeepCopyWithCustomSetter(
// set the time.Time field
isSet, err = setTimeField(srcFieldValue, dstFieldValue)
if err != nil {
return fmt.Errorf("%v", err)
return fmt.Errorf("%s: %v", field.Name, err)
}
if isSet {
continue
Expand All @@ -104,31 +121,33 @@ func DeepCopyWithCustomSetter(
if !field.Anonymous {
dv, vFunc := instantiate(dstFieldValue)
if err := DeepCopy(srcFieldValue.Interface(), dv.Interface()); err != nil {
return fmt.Errorf("%v", err)
return fmt.Errorf("%s: %v", field.Name, err)
}
dstFieldValue.Set(vFunc())
continue
}
dstFieldValue.SetInt(srcFieldValue.Int())
case reflect.Ptr:

if srcFieldValue.IsNil() {
continue
}
// copy to indirect
indirect := reflect.Indirect(srcFieldValue)
if indirect.Type().AssignableTo(dstFieldType.Type) && dstFieldType.Type.Kind() != reflect.Ptr {
if indirect.Type().AssignableTo(dstFieldValue.Type()) && dstFieldValue.Type().Kind() != reflect.Ptr {
dstFieldValue.Set(indirect)
continue
}

dv, vFunc := instantiate(dstFieldValue)
if err := DeepCopy(srcFieldValue.Interface(), dv.Interface()); err != nil {
return fmt.Errorf("%v", err)
return fmt.Errorf("%s: %v", field.Name, err)
}
dstFieldValue.Set(vFunc())
continue
case reflect.Slice:
if err := copySlice(srcFieldValue, dstFieldValue); err != nil {
return fmt.Errorf("%v", err)
return fmt.Errorf("%s: %v", field.Name, err)
}
continue
}
Expand Down Expand Up @@ -162,8 +181,15 @@ func copySlice(src, dst reflect.Value) error {

for i := 0; i < src.Len(); i++ {
d := dst.Index(i)
dv, vFunc := instantiate(d)

// Other than pointer and struct
if src.Index(i).Type().ConvertibleTo(d.Type()) {
d.Set(src.Index(i).Convert(d.Type()))
continue
}

// pointer or struct
dv, vFunc := instantiate(d)
if err := DeepCopy(src.Index(i).Interface(), dv.Interface()); err != nil {
return err
}
Expand All @@ -180,9 +206,6 @@ func setTimeField(src, dst reflect.Value) (bool, error) {
case int64:
dst.Set(reflect.ValueOf(t.Unix()))
return true, nil
case *time.Time:
dst.Set(reflect.ValueOf(&t))
return true, nil
}

case *time.Time:
Expand All @@ -194,9 +217,6 @@ func setTimeField(src, dst reflect.Value) (bool, error) {
case int64:
dst.Set(reflect.ValueOf(t.Unix()))
return true, nil
case time.Time:
dst.Set(reflect.ValueOf(*t))
return true, nil
}

case int64:
Expand Down
98 changes: 98 additions & 0 deletions copier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,3 +586,101 @@ func TestDeepCopy_slice2(t *testing.T) {
})
}
}

func TestDeepCopy_ptr(t *testing.T) {
type Foo string
type Bar string
type FooNum int
type BarNum int

type Example1 struct {
Tests []Foo
Test2s []*Foo
Test *Foo
Test2 Foo
Nums []FooNum
NumPtrs []*FooNum
}
type Example2 struct {
Tests []Bar
Test2s []Bar
Test Bar
Test2 *Bar
Nums []BarNum
NumPtrs []*BarNum
}

type args struct {
src Example1
dest *Example2
}

tests := []struct {
name string
in args
want *Example2
err error
}{
{
name: "slice value",
in: args{
src: Example1{
Tests: []Foo{"test1", "test2"},
//Test2s: []*Foo{xgo.ToPtr(Foo("test1")), xgo.ToPtr(Foo("test2"))},
Test: xgo.ToPtr(Foo("test3")),
Test2: "test4",
Nums: []FooNum{1, 2, 3, 4},
NumPtrs: []*FooNum{xgo.ToPtr(FooNum(1)), xgo.ToPtr(FooNum(2))},
},
dest: &Example2{},
},
want: &Example2{
Tests: []Bar{"test1", "test2"},
//Test2s: []Bar{"test1", "test2"},
Test: "test3",
Test2: xgo.ToPtr(Bar("test4")),
Nums: []BarNum{1, 2, 3, 4},
NumPtrs: []*BarNum{xgo.ToPtr(BarNum(1)), xgo.ToPtr(BarNum(2))},
},
err: nil,
},
{
name: "nil or zero value",
in: args{
src: Example1{
Tests: []Foo{"test1", "test2"},
Nums: []FooNum{1, 2, 3, 4},
},
dest: &Example2{},
},
want: &Example2{
Tests: []Bar{"test1", "test2"},
Test: "",
Test2: xgo.ToPtr(Bar("")),
Nums: []BarNum{1, 2, 3, 4},
},
err: nil,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := xgo.DeepCopy(tt.in.src, tt.in.dest)
got := tt.in.dest
if tt.err == nil && err != nil {
t.Errorf("testing %s: should not be error for %#v but: %v", tt.name, tt.in, err)
}
if tt.err != nil && err == nil {
t.Errorf("testing %s: should be error for %#v but not:", tt.name, tt.in)
}
if tt.err != nil && err != tt.err {
t.Errorf("testing %s: should be error of %v but got: %v", tt.name, tt.err, err)
}
if ok := reflect.DeepEqual(tt.want, got); !ok {
t.Errorf("testing %s mismatch (-want +got):\n%v\n%v", tt.name, tt.want, got)
}
})
}
}

0 comments on commit 1e2cd6f

Please sign in to comment.