Skip to content

Commit

Permalink
Improve cast logic for nil map params + errors clarification
Browse files Browse the repository at this point in the history
  • Loading branch information
demdxx committed Dec 3, 2023
1 parent 31bc05e commit 6213783
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 36 deletions.
22 changes: 19 additions & 3 deletions cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ import (

// TryTo cast any input type into the target
func TryTo(v, to any, tags ...string) (any, error) {
return TryToContext(context.Background(), v, reflect.TypeOf(to), tags...)
return TryToContext(context.Background(), v, to, tags...)
}

// TryToContext cast any input type into the target
func TryToContext(ctx context.Context, v, to any, tags ...string) (any, error) {
if v == nil {
return nil, ErrInvalidParams
if vl := reflect.ValueOf(v); !vl.IsValid() || vl.IsNil() {
return nil, nil
}
return nil, wrapError(ErrInvalidParams, `TryToContext: "v" is nil`)
}
return TryToTypeContext(ctx, v, reflect.TypeOf(to), tags...)
}
Expand All @@ -52,7 +55,10 @@ func TryToType(v any, t reflect.Type, tags ...string) (any, error) {
// TryToTypeContext cast any input type into the target reflection
func TryToTypeContext(ctx context.Context, v any, t reflect.Type, tags ...string) (any, error) {
if v == nil {
return nil, ErrInvalidParams
if vl := reflect.ValueOf(v); !vl.IsValid() || vl.IsNil() {
return nil, nil
}
return nil, wrapError(ErrInvalidParams, `TryToTypeContext: "v" is nil`)
}
val := reflect.ValueOf(v)
if t == nil { // In case of type is ANY make a copy of data
Expand All @@ -78,6 +84,14 @@ func ReflectTryToType(v reflect.Value, t reflect.Type, recursive bool, tags ...s
// ReflectTryToTypeContext converts reflection value to reflection type or returns error
func ReflectTryToTypeContext(ctx context.Context, srcVal reflect.Value, t reflect.Type, recursive bool, tags ...string) (any, error) {
v := reflectTarget(srcVal)
if !v.IsValid() {
switch t.Kind() {
case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Array, reflect.Chan, reflect.Func:
return nil, nil
default:
return nil, wrapError(ErrInvalidParams, "ReflectTryToTypeContext: `srcVal` is invalid")
}
}
if v.Type() == t {
if k := t.Kind(); k != reflect.Struct &&
k != reflect.Map &&
Expand Down Expand Up @@ -187,6 +201,8 @@ func TryCastValueContext[R any, T any](ctx context.Context, v T, recursive bool,
switch nval := val.(type) {
case *R:
return *nval, nil
case nil:
return rVal, nil
default:
return val.(R), nil
}
Expand Down
3 changes: 3 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ func (w *errorWrapper) Error() string { return w.msg + ": " + w.err.Error() }
func (w *errorWrapper) Unwrap() error { return w.err }

func wrapError(err error, msg string) error {
if err == nil {
return nil
}
return &errorWrapper{err: err, msg: msg}
}

Expand Down
18 changes: 12 additions & 6 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ func TryMapCopy[K comparable, V any](dst map[K]V, src any, recursive bool, tags
// TryMapCopyContext converts source into destination or return error
func TryMapCopyContext[K comparable, V any](ctx context.Context, dst map[K]V, src any, recursive bool, tags ...string) error {
if dst == nil || src == nil {
return ErrInvalidParams
if dst == nil {
return wrapError(ErrInvalidParams, "TryMapCopyContext `destenation` parameter is nil")
}
return wrapError(ErrInvalidParams, "TryMapCopyContext `source` parameter is nil")
}
var (
srcVal = reflectTarget(reflect.ValueOf(src))
Expand All @@ -51,7 +54,7 @@ func TryMapCopyContext[K comparable, V any](ctx context.Context, dst map[K]V, sr
}
}
if err != nil {
return err
return wrapError(err, `"`+Str(k.Interface())+`" map key`)
}
}
case reflect.Struct:
Expand All @@ -71,7 +74,7 @@ func TryMapCopyContext[K comparable, V any](ctx context.Context, dst map[K]V, sr
dst[key], err = TryCastContext[V](ctx, fl, tags...)
}
if err != nil {
return err
return wrapError(err, "`"+name+"` struct key")
}
} // end if !omitempty || !IsEmpty(fl)
}
Expand All @@ -92,7 +95,10 @@ func ToMap(dst, src any, recursive bool, tags ...string) error {
// tag defines the tags name in the structure to map the keys
func ToMapContext(ctx context.Context, dst, src any, recursive bool, tags ...string) error {
if dst == nil || src == nil {
return ErrInvalidParams
if dst == nil {
return wrapError(ErrInvalidParams, "ToMapContext `destenation` parameter is nil")
}
return wrapError(ErrInvalidParams, "ToMapContext `source` parameter is nil")
}

var (
Expand All @@ -117,7 +123,7 @@ func ToMapContext(ctx context.Context, dst, src any, recursive bool, tags ...str
if recursive {
dest[k.Interface()], err = mapDestValue(field.Interface(), destType, recursive, tags...)
if err != nil {
return err
return wrapError(err, Str(k.Interface()))
}
} else {
dest[k.Interface()] = field.Interface()
Expand All @@ -133,7 +139,7 @@ func ToMapContext(ctx context.Context, dst, src any, recursive bool, tags ...str
if recursive {
dest[name], err = mapDestValue(fl, destType, recursive, tags...)
if err != nil {
return err
return wrapError(err, "`"+name+"` value")
}
} else {
dest[name] = fl
Expand Down
7 changes: 7 additions & 0 deletions map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ func TestMap(t *testing.T) {

target7 := MapRecursiveContext[string, map[string]any](context.TODO(), struct{ Foo struct{ Bar string } }{Foo: struct{ Bar string }{Bar: "boo"}})
assert.Equal(t, true, reflect.DeepEqual(map[string]map[string]any{"Foo": {"Bar": "boo"}}, target7))

// Nil check for map values
var nilMap = map[string]any{"default": nil, "sub1": []any{nil}, "sub2": map[string]any{"n1": nil, "n2": nil, "n3": nil}}
var target8 = map[string]any{}
err = ToMap(&target8, nilMap, true)
assert.NoError(t, err)
assert.Equal(t, true, reflect.DeepEqual(nilMap, target8))
}

func TestIsMap(t *testing.T) {
Expand Down
15 changes: 11 additions & 4 deletions slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,20 @@ func AnySliceContext[R any](ctx context.Context, src any, tags ...string) []R {
// TryToAnySliceContext converts any input slice into destination type slice
func TryToAnySliceContext(ctx context.Context, dst, src any, tags ...string) error {
if dst == nil || src == nil {
return ErrInvalidParams
if dst == nil {
return wrapError(ErrInvalidParams, "TryToAnySliceContext `destenation` parameter is nil")
}
return wrapError(ErrInvalidParams, "TryToAnySliceContext `source` parameter is nil")
}

dstSlice := reflectTarget(reflect.ValueOf(dst))
if k := dstSlice.Kind(); k != reflect.Slice && k != reflect.Array {
return ErrInvalidParams
return wrapError(ErrInvalidParams, "TryToAnySliceContext `destenation` parameter is not a slice or array")
}

srcSlice := reflectTarget(reflect.ValueOf(src))
if k := srcSlice.Kind(); k != reflect.Slice && k != reflect.Array {
return ErrInvalidParams
return wrapError(ErrInvalidParams, "TryToAnySliceContext `source` parameter is not a slice or array")
}

dstElemType := dstSlice.Type().Elem()
Expand Down Expand Up @@ -216,7 +219,11 @@ func TryToAnySliceContext(ctx context.Context, dst, src any, tags ...string) err
}
}
if v, err := ReflectTryToTypeContext(ctx, srcItem, dstElemType, true, tags...); err == nil {
dstItem.Set(reflect.ValueOf(v))
if v == nil {
dstItem.Set(reflect.Zero(dstElemType))
} else {
dstItem.Set(reflect.ValueOf(v))
}
} else {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func TestToSlice(t *testing.T) {
for _, test := range tests {
res, err := test.cfn(test.src)
if test.err != nil {
assert.EqualError(t, err, test.err.Error())
assert.ErrorContains(t, err, test.err.Error())
} else if assert.NoError(t, err) {
assert.ElementsMatch(t, test.trg, res)
}
Expand Down
53 changes: 31 additions & 22 deletions struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ func TryCopyStruct(dst, src any, tags ...string) (err error) {
// TryCopyStructContext convert any input type into the target structure
func TryCopyStructContext(ctx context.Context, dst, src any, tags ...string) (err error) {
if dst == nil || src == nil {
return ErrInvalidParams
if dst == nil {
return wrapError(ErrInvalidParams, "TryCopyStructContext `destenation` parameter is nil")
}
return wrapError(ErrInvalidParams, "TryCopyStructContext `source` parameter is nil")
}

if sintf, ok := dst.(CastSetter); ok {
Expand Down Expand Up @@ -84,33 +87,39 @@ func TryCopyStructContext(ctx context.Context, dst, src any, tags ...string) (er

// Set field value
if v == nil {
if err = setFieldValueReflect(ctx, field, reflect.Zero(field.Type())); err != nil {
break
}
continue
}

switch field.Kind() {
case reflect.Struct:
err = TryCopyStructContext(ctx, field.Addr().Interface(), v, tags...)
default:
var vl any
if vl, err = TryToTypeContext(ctx, v, field.Type(), tags...); err == nil {
val := reflect.ValueOf(vl)
if val.Kind() == reflect.Ptr && val.Kind() != field.Kind() {
val = val.Elem()
}
err = setFieldValueReflect(ctx, field, val)
} else if setter, _ := field.Interface().(CastSetter); setter != nil {
err = setter.CastSet(ctx, v)
} else if field.CanAddr() {
if setter, _ := field.Addr().Interface().(CastSetter); setter != nil {
err = setFieldValueReflect(ctx, field, reflect.Zero(field.Type()))
} else {
switch field.Kind() {
case reflect.Struct:
err = TryCopyStructContext(ctx, field.Addr().Interface(), v, tags...)
default:
var (
vl any
ok = false
)
if setter, _ := field.Interface().(CastSetter); setter != nil {
err = setter.CastSet(ctx, v)
ok = true
} else if field.CanAddr() {
if setter, _ := field.Addr().Interface().(CastSetter); setter != nil {
err = setter.CastSet(ctx, v)
ok = true
}
}
if !ok {
if vl, err = TryToTypeContext(ctx, v, field.Type(), tags...); err == nil {
val := reflect.ValueOf(vl)
if val.Kind() == reflect.Ptr && val.Kind() != field.Kind() {
val = val.Elem()
}
err = setFieldValueReflect(ctx, field, val)
}
}
}
}

if err != nil {
err = wrapError(err, ft.Name)
break
}
}
Expand Down

0 comments on commit 6213783

Please sign in to comment.