diff --git a/copier.go b/copier.go index 88ecfca..1da4849 100644 --- a/copier.go +++ b/copier.go @@ -22,6 +22,13 @@ func DeepCopy(srcModel interface{}, dstModel interface{}) error { return errors.New("copy to value is unaddressable") } + if src.Kind() == reflect.Slice { + if err := copySlice(src, dst); err != nil { + return fmt.Errorf("%v", err) + } + return nil + } + // What to do if the deepcopy destination model has a tag var srcToDstTagMap = map[string]string{} for i := 0; i < dst.NumField(); i++ { @@ -65,7 +72,7 @@ func DeepCopy(srcModel interface{}, dstModel interface{}) error { isSet, err := setTimeField(srcFieldValue, dstFieldValue) if err != nil { - return err + return fmt.Errorf("%v", err) } if isSet { continue @@ -78,51 +85,76 @@ func DeepCopy(srcModel interface{}, dstModel interface{}) error { continue case reflect.Struct: if !field.Anonymous { - // struct to struct - s := reflect.New(dstFieldType.Type) - v := func() reflect.Value { return reflect.Indirect(s) } - if dstFieldType.Type.Kind() == reflect.Ptr { - // struct to ptr - s, v = reflect.New(dstFieldType.Type.Elem()), func() reflect.Value { return s } - } - - if err := DeepCopy(srcFieldValue.Interface(), s.Interface()); err != nil { + dv, vFunc := instantiate(dstFieldValue) + if err := DeepCopy(srcFieldValue.Interface(), dv.Interface()); err != nil { return fmt.Errorf("%v", err) } - dstFieldValue.Set(v()) + dstFieldValue.Set(vFunc()) continue } dstFieldValue.SetInt(srcFieldValue.Int()) - case reflect.Ptr: - if !srcFieldValue.IsNil() { - // copy to indirect - indirect := reflect.Indirect(srcFieldValue) - if indirect.Type().AssignableTo(dstFieldType.Type) && dstFieldType.Type.Kind() != reflect.Ptr { - dstFieldValue.Set(indirect) - continue - } - - // ptr to struct - s := reflect.New(dstFieldType.Type) - v := func() reflect.Value { return reflect.Indirect(s) } - if dstFieldType.Type.Kind() == reflect.Ptr { - // ptr to ptr - s, v = reflect.New(dstFieldType.Type.Elem()), func() reflect.Value { return s } - } - - if err := DeepCopy(srcFieldValue.Interface(), s.Interface()); err != nil { - return fmt.Errorf("%v", err) - } - dstFieldValue.Set(v()) + if srcFieldValue.IsNil() { continue } + // copy to indirect + indirect := reflect.Indirect(srcFieldValue) + if indirect.Type().AssignableTo(dstFieldType.Type) && dstFieldType.Type.Kind() != reflect.Ptr { + dstFieldValue.Set(indirect) + continue + } + dv, vFunc := instantiate(dstFieldValue) + if err := DeepCopy(srcFieldValue.Interface(), dv.Interface()); err != nil { + return fmt.Errorf("%v", err) + } + dstFieldValue.Set(vFunc()) + continue + + case reflect.Slice: + if err := copySlice(srcFieldValue, dstFieldValue); err != nil { + return fmt.Errorf("%v", err) + } + continue } } return nil } +// Instantiates a value that can handle copying in both directions - from a pointer to a struct and from a struct to a pointer. +func instantiate(v reflect.Value) (reflect.Value, func() reflect.Value) { + // ptr + if v.Type().Kind() == reflect.Ptr { + rv := reflect.New(v.Type().Elem()) + vFunc := func() reflect.Value { return rv } + return rv, vFunc + } + + // struct + rv := reflect.New(v.Type()) + vFunc := func() reflect.Value { return reflect.Indirect(rv) } + return rv, vFunc +} + +func copySlice(src, dst reflect.Value) error { + if src.IsNil() { + return nil + } + slice := reflect.MakeSlice(reflect.SliceOf(dst.Type().Elem()), src.Len(), src.Cap()) + dst.Set(slice) + + for i := 0; i < src.Len(); i++ { + d := dst.Index(i) + dv, vFunc := instantiate(d) + + if err := DeepCopy(src.Index(i).Interface(), dv.Interface()); err != nil { + return err + } + d.Set(vFunc()) + } + return nil +} + func setTimeField(src, dst reflect.Value) (bool, error) { switch t := src.Interface().(type) { case time.Time: diff --git a/copier_test.go b/copier_test.go index 6c764c8..5a24dbc 100644 --- a/copier_test.go +++ b/copier_test.go @@ -627,3 +627,203 @@ func TestDeepCopy_Private(t *testing.T) { }) } } + +func TestDeepCopy_Slice(t *testing.T) { + type Model1 struct { + Foo string + Bar int + } + type Model2 struct { + Foo string + Bar int + } + type Example1 struct { + ID string + Name string + State int + Tests []string + StructPtrs []*Model1 + Structs []Model1 + } + type Example2 struct { + ID string + Name string + State int + Tests []string + StructPtrs []*Model2 + Structs []Model2 + } + + type args struct { + src Example1 + dest *Example2 + } + + tests := []struct { + name string + in args + want *Example2 + err error + }{ + { + name: "slice value", + in: args{ + src: Example1{ + ID: "id1", + Name: "hoge1", + State: 100, + Tests: []string{"test1", "test2"}, + StructPtrs: []*Model1{{Foo: "foo1", Bar: 100}, {Foo: "foo2", Bar: 200}}, + Structs: []Model1{{Foo: "foo1", Bar: 1000}, {Foo: "foo2", Bar: 2000}}, + }, + dest: &Example2{}, + }, + want: &Example2{ + ID: "id1", + Name: "hoge1", + State: 100, + Tests: []string{"test1", "test2"}, + StructPtrs: []*Model2{{Foo: "foo1", Bar: 100}, {Foo: "foo2", Bar: 200}}, + Structs: []Model2{{Foo: "foo1", Bar: 1000}, {Foo: "foo2", Bar: 2000}}, + }, + err: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := xgo.DeepCopy(tt.in.src, tt.in.dest) + got := tt.in.dest + if tt.err == nil && err != nil { + t.Errorf("testing %s: should not be error for %#v but: %v", tt.name, tt.in, err) + } + if tt.err != nil && err == nil { + t.Errorf("testing %s: should be error for %#v but not:", tt.name, tt.in) + } + if tt.err != nil && err != tt.err { + t.Errorf("testing %s: should be error of %v but got: %v", tt.name, tt.err, err) + } + if ok := reflect.DeepEqual(tt.want, got); !ok { + t.Errorf("testing %s mismatch (-want +got):\n%v\n%v", tt.name, tt.want, got) + } + }) + } +} + +func TestDeepCopy_Slice2(t *testing.T) { + type Model1 struct { + Foo string + Bar int + } + type Model2 struct { + Foo string + Bar int + } + type Example1 struct { + ID string + Name string + State int + Tests []string + StructPtrs []*Model1 + Structs []Model1 + } + type Example2 struct { + ID string + Name string + State int + Tests []string + StructPtrs []*Model2 + Structs []Model2 + } + + type args struct { + src []*Example1 + dest []*Example2 + } + + tests := []struct { + name string + in args + want []*Example2 + err error + }{ + { + name: "slice value", + in: args{ + src: []*Example1{ + { + ID: "id1", + Name: "hoge1", + State: 100, + Tests: []string{"test1", "test2"}, + StructPtrs: []*Model1{{Foo: "foo1-1", Bar: 100}, {Foo: "foo1-2", Bar: 200}}, + Structs: []Model1{{Foo: "foo1-1", Bar: 1000}, {Foo: "foo1-2", Bar: 2000}}, + }, + { + ID: "id2", + Name: "hoge2", + State: 200, + Tests: []string{"test1", "test2"}, + Structs: []Model1{{Foo: "foo2-1", Bar: 1000}, {Foo: "foo2-2", Bar: 2000}}, + }, + { + ID: "id3", + Name: "hoge3", + State: 300, + Tests: []string{"test1", "test2"}, + StructPtrs: []*Model1{{Foo: "foo3-1", Bar: 100}, {Foo: "foo3-2", Bar: 200}}, + }, + }, + dest: []*Example2{}, + }, + want: []*Example2{ + { + ID: "id1", + Name: "hoge1", + State: 100, + Tests: []string{"test1", "test2"}, + StructPtrs: []*Model2{{Foo: "foo1-1", Bar: 100}, {Foo: "foo1-2", Bar: 200}}, + Structs: []Model2{{Foo: "foo1-1", Bar: 1000}, {Foo: "foo1-2", Bar: 2000}}, + }, + { + ID: "id2", + Name: "hoge2", + State: 200, + Tests: []string{"test1", "test2"}, + Structs: []Model2{{Foo: "foo2-1", Bar: 1000}, {Foo: "foo2-2", Bar: 2000}}, + }, + { + ID: "id3", + Name: "hoge3", + State: 300, + Tests: []string{"test1", "test2"}, + StructPtrs: []*Model2{{Foo: "foo3-1", Bar: 100}, {Foo: "foo3-2", Bar: 200}}, + }, + }, + err: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := xgo.DeepCopy(tt.in.src, &tt.in.dest) + got := tt.in.dest + if tt.err == nil && err != nil { + t.Errorf("testing %s: should not be error for %#v but: %v", tt.name, tt.in, err) + } + if tt.err != nil && err == nil { + t.Errorf("testing %s: should be error for %#v but not:", tt.name, tt.in) + } + if tt.err != nil && err != tt.err { + t.Errorf("testing %s: should be error of %v but got: %v", tt.name, tt.err, err) + } + if ok := reflect.DeepEqual(tt.want, got); !ok { + t.Errorf("testing %s mismatch (-want +got):\n%v\n%v", tt.name, tt.want, got) + } + }) + } +}