Skip to content

Commit

Permalink
Support for slice copying
Browse files Browse the repository at this point in the history
  • Loading branch information
glassonion1 committed Aug 17, 2024
1 parent 0b18a94 commit 648bf97
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 32 deletions.
96 changes: 64 additions & 32 deletions copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ func DeepCopy(srcModel interface{}, dstModel interface{}) error {
return errors.New("copy to value is unaddressable")
}

if src.Kind() == reflect.Slice {
if err := copySlice(src, dst); err != nil {
return fmt.Errorf("%v", err)
}
return nil
}

// What to do if the deepcopy destination model has a tag
var srcToDstTagMap = map[string]string{}
for i := 0; i < dst.NumField(); i++ {
Expand Down Expand Up @@ -65,7 +72,7 @@ func DeepCopy(srcModel interface{}, dstModel interface{}) error {

isSet, err := setTimeField(srcFieldValue, dstFieldValue)
if err != nil {
return err
return fmt.Errorf("%v", err)
}
if isSet {
continue
Expand All @@ -78,51 +85,76 @@ func DeepCopy(srcModel interface{}, dstModel interface{}) error {
continue
case reflect.Struct:
if !field.Anonymous {
// struct to struct
s := reflect.New(dstFieldType.Type)
v := func() reflect.Value { return reflect.Indirect(s) }
if dstFieldType.Type.Kind() == reflect.Ptr {
// struct to ptr
s, v = reflect.New(dstFieldType.Type.Elem()), func() reflect.Value { return s }
}

if err := DeepCopy(srcFieldValue.Interface(), s.Interface()); err != nil {
dv, vFunc := instantiate(dstFieldValue)
if err := DeepCopy(srcFieldValue.Interface(), dv.Interface()); err != nil {
return fmt.Errorf("%v", err)
}
dstFieldValue.Set(v())
dstFieldValue.Set(vFunc())
continue
}
dstFieldValue.SetInt(srcFieldValue.Int())

case reflect.Ptr:
if !srcFieldValue.IsNil() {
// copy to indirect
indirect := reflect.Indirect(srcFieldValue)
if indirect.Type().AssignableTo(dstFieldType.Type) && dstFieldType.Type.Kind() != reflect.Ptr {
dstFieldValue.Set(indirect)
continue
}

// ptr to struct
s := reflect.New(dstFieldType.Type)
v := func() reflect.Value { return reflect.Indirect(s) }
if dstFieldType.Type.Kind() == reflect.Ptr {
// ptr to ptr
s, v = reflect.New(dstFieldType.Type.Elem()), func() reflect.Value { return s }
}

if err := DeepCopy(srcFieldValue.Interface(), s.Interface()); err != nil {
return fmt.Errorf("%v", err)
}
dstFieldValue.Set(v())
if srcFieldValue.IsNil() {
continue
}
// copy to indirect
indirect := reflect.Indirect(srcFieldValue)
if indirect.Type().AssignableTo(dstFieldType.Type) && dstFieldType.Type.Kind() != reflect.Ptr {
dstFieldValue.Set(indirect)
continue
}
dv, vFunc := instantiate(dstFieldValue)
if err := DeepCopy(srcFieldValue.Interface(), dv.Interface()); err != nil {
return fmt.Errorf("%v", err)
}
dstFieldValue.Set(vFunc())
continue

case reflect.Slice:
if err := copySlice(srcFieldValue, dstFieldValue); err != nil {
return fmt.Errorf("%v", err)
}
continue
}
}

return nil
}

// Instantiates a value that can handle copying in both directions - from a pointer to a struct and from a struct to a pointer.
func instantiate(v reflect.Value) (reflect.Value, func() reflect.Value) {
// ptr
if v.Type().Kind() == reflect.Ptr {
rv := reflect.New(v.Type().Elem())
vFunc := func() reflect.Value { return rv }
return rv, vFunc
}

// struct
rv := reflect.New(v.Type())
vFunc := func() reflect.Value { return reflect.Indirect(rv) }
return rv, vFunc
}

func copySlice(src, dst reflect.Value) error {
if src.IsNil() {
return nil
}
slice := reflect.MakeSlice(reflect.SliceOf(dst.Type().Elem()), src.Len(), src.Cap())
dst.Set(slice)

for i := 0; i < src.Len(); i++ {
d := dst.Index(i)
dv, vFunc := instantiate(d)

if err := DeepCopy(src.Index(i).Interface(), dv.Interface()); err != nil {
return err
}
d.Set(vFunc())
}
return nil
}

func setTimeField(src, dst reflect.Value) (bool, error) {
switch t := src.Interface().(type) {
case time.Time:
Expand Down
200 changes: 200 additions & 0 deletions copier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,3 +627,203 @@ func TestDeepCopy_Private(t *testing.T) {
})
}
}

func TestDeepCopy_Slice(t *testing.T) {
type Model1 struct {
Foo string
Bar int
}
type Model2 struct {
Foo string
Bar int
}
type Example1 struct {
ID string
Name string
State int
Tests []string
StructPtrs []*Model1
Structs []Model1
}
type Example2 struct {
ID string
Name string
State int
Tests []string
StructPtrs []*Model2
Structs []Model2
}

type args struct {
src Example1
dest *Example2
}

tests := []struct {
name string
in args
want *Example2
err error
}{
{
name: "slice value",
in: args{
src: Example1{
ID: "id1",
Name: "hoge1",
State: 100,
Tests: []string{"test1", "test2"},
StructPtrs: []*Model1{{Foo: "foo1", Bar: 100}, {Foo: "foo2", Bar: 200}},
Structs: []Model1{{Foo: "foo1", Bar: 1000}, {Foo: "foo2", Bar: 2000}},
},
dest: &Example2{},
},
want: &Example2{
ID: "id1",
Name: "hoge1",
State: 100,
Tests: []string{"test1", "test2"},
StructPtrs: []*Model2{{Foo: "foo1", Bar: 100}, {Foo: "foo2", Bar: 200}},
Structs: []Model2{{Foo: "foo1", Bar: 1000}, {Foo: "foo2", Bar: 2000}},
},
err: nil,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := xgo.DeepCopy(tt.in.src, tt.in.dest)
got := tt.in.dest
if tt.err == nil && err != nil {
t.Errorf("testing %s: should not be error for %#v but: %v", tt.name, tt.in, err)
}
if tt.err != nil && err == nil {
t.Errorf("testing %s: should be error for %#v but not:", tt.name, tt.in)
}
if tt.err != nil && err != tt.err {
t.Errorf("testing %s: should be error of %v but got: %v", tt.name, tt.err, err)
}
if ok := reflect.DeepEqual(tt.want, got); !ok {
t.Errorf("testing %s mismatch (-want +got):\n%v\n%v", tt.name, tt.want, got)
}
})
}
}

func TestDeepCopy_Slice2(t *testing.T) {
type Model1 struct {
Foo string
Bar int
}
type Model2 struct {
Foo string
Bar int
}
type Example1 struct {
ID string
Name string
State int
Tests []string
StructPtrs []*Model1
Structs []Model1
}
type Example2 struct {
ID string
Name string
State int
Tests []string
StructPtrs []*Model2
Structs []Model2
}

type args struct {
src []*Example1
dest []*Example2
}

tests := []struct {
name string
in args
want []*Example2
err error
}{
{
name: "slice value",
in: args{
src: []*Example1{
{
ID: "id1",
Name: "hoge1",
State: 100,
Tests: []string{"test1", "test2"},
StructPtrs: []*Model1{{Foo: "foo1-1", Bar: 100}, {Foo: "foo1-2", Bar: 200}},
Structs: []Model1{{Foo: "foo1-1", Bar: 1000}, {Foo: "foo1-2", Bar: 2000}},
},
{
ID: "id2",
Name: "hoge2",
State: 200,
Tests: []string{"test1", "test2"},
Structs: []Model1{{Foo: "foo2-1", Bar: 1000}, {Foo: "foo2-2", Bar: 2000}},
},
{
ID: "id3",
Name: "hoge3",
State: 300,
Tests: []string{"test1", "test2"},
StructPtrs: []*Model1{{Foo: "foo3-1", Bar: 100}, {Foo: "foo3-2", Bar: 200}},
},
},
dest: []*Example2{},
},
want: []*Example2{
{
ID: "id1",
Name: "hoge1",
State: 100,
Tests: []string{"test1", "test2"},
StructPtrs: []*Model2{{Foo: "foo1-1", Bar: 100}, {Foo: "foo1-2", Bar: 200}},
Structs: []Model2{{Foo: "foo1-1", Bar: 1000}, {Foo: "foo1-2", Bar: 2000}},
},
{
ID: "id2",
Name: "hoge2",
State: 200,
Tests: []string{"test1", "test2"},
Structs: []Model2{{Foo: "foo2-1", Bar: 1000}, {Foo: "foo2-2", Bar: 2000}},
},
{
ID: "id3",
Name: "hoge3",
State: 300,
Tests: []string{"test1", "test2"},
StructPtrs: []*Model2{{Foo: "foo3-1", Bar: 100}, {Foo: "foo3-2", Bar: 200}},
},
},
err: nil,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := xgo.DeepCopy(tt.in.src, &tt.in.dest)
got := tt.in.dest
if tt.err == nil && err != nil {
t.Errorf("testing %s: should not be error for %#v but: %v", tt.name, tt.in, err)
}
if tt.err != nil && err == nil {
t.Errorf("testing %s: should be error for %#v but not:", tt.name, tt.in)
}
if tt.err != nil && err != tt.err {
t.Errorf("testing %s: should be error of %v but got: %v", tt.name, tt.err, err)
}
if ok := reflect.DeepEqual(tt.want, got); !ok {
t.Errorf("testing %s mismatch (-want +got):\n%v\n%v", tt.name, tt.want, got)
}
})
}
}

0 comments on commit 648bf97

Please sign in to comment.