Skip to content

Commit

Permalink
Merge pull request #212 from Yamashou/fix-cliv2
Browse files Browse the repository at this point in the history
Fix MarshalJson for struct and map
  • Loading branch information
Yamashou authored Apr 17, 2024
2 parents 2f3fb12 + 37317a1 commit 0ca986d
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 120 deletions.
60 changes: 53 additions & 7 deletions clientv2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,12 +395,21 @@ func (c *Client) unmarshal(data []byte, res interface{}) error {
}

func MarshalJSON(v interface{}) ([]byte, error) {
if v == nil {
return []byte("null"), nil // Directly return "null" for nil interface{}
}

val := reflect.ValueOf(v)
if !val.IsValid() || (val.Kind() == reflect.Ptr && val.IsNil()) {
return []byte("null"), nil // Return "null" for nil pointer or invalid reflect value
}

encoderFunc := getTypeEncoder(reflect.TypeOf(v))
return encoderFunc(v)
}

// getTypeEncoder returns an appropriate encoder function for the provided type.
func getTypeEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
func getTypeEncoder(t reflect.Type) func(a any) ([]byte, error) {
if t.Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) {
return gqlMarshalerEncoder
}
Expand Down Expand Up @@ -523,14 +532,51 @@ func prepareFields(t reflect.Type) []fieldInfo {
}

func checkMarshalerFields(t reflect.Type) bool {
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Type.Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) {
switch t.Kind() {
case reflect.Ptr:
return checkMarshalerFields(t.Elem())

case reflect.Struct:
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if isMarshalerType(f.Type) {
return true
}
// Recursively check for nested structs
if checkMarshalerFields(f.Type) {
return true
}
}

case reflect.Map:
// Check both key and value types for Marshaler implementation; usually, value type is what matters
keyType, valueType := t.Key(), t.Elem()
if isMarshalerType(valueType) || isMarshalerType(keyType) {
return true
}
if reflect.PtrTo(f.Type).Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) {
// Recursively check the map value type
if checkMarshalerFields(valueType) {
return true
}

case reflect.Slice, reflect.Array:
// Recursively check the element type
return checkMarshalerFields(t.Elem())
case reflect.Interface, reflect.Invalid, reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return false
default:
return false
}

return false
}

func isMarshalerType(t reflect.Type) bool {
if t.Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) {
return true
}
if reflect.PtrTo(t).Implements(reflect.TypeOf((*graphql.Marshaler)(nil)).Elem()) {
return true
}
return false
}
Expand All @@ -539,7 +585,7 @@ func newStructEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
fields := prepareFields(t)
marshalerFieldExists := checkMarshalerFields(t)

return func(v interface{}) ([]byte, error) {
return func(v any) ([]byte, error) {
// If no field implements the MarshalerGQL interface, use standard JSON marshaling
if !marshalerFieldExists {
return json.Marshal(v)
Expand Down Expand Up @@ -591,7 +637,7 @@ func newMapEncoder(t reflect.Type) func(interface{}) ([]byte, error) {
if err != nil {
return nil, err
}
result[keyStr] = json.RawMessage(encodedValue) // Use json.RawMessage to avoid double encoding
result[keyStr] = encodedValue
}

return json.Marshal(result)
Expand Down
Loading

0 comments on commit 0ca986d

Please sign in to comment.