diff --git a/codegen/directive_build.go b/codegen/directive_build.go index 3282884112f..af77dc441f6 100644 --- a/codegen/directive_build.go +++ b/codegen/directive_build.go @@ -32,7 +32,6 @@ func (cfg *Config) buildDirectives(types NamedTypes) ([]*Directive, error) { if err != nil { return nil, errors.Errorf("default value for directive argument %s(%s) is not valid: %s", dir.Name, arg.Name, err.Error()) } - newArg.StripPtr() } args = append(args, newArg) } diff --git a/codegen/object_build.go b/codegen/object_build.go index ee2b2f1c2b9..328d5adcdbe 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -162,7 +162,6 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition, imports *I if err != nil { return nil, errors.Errorf("default value for %s.%s is not valid: %s", typ.Name, field.Name, err.Error()) } - newArg.StripPtr() } args = append(args, newArg) } diff --git a/codegen/testserver/element.go b/codegen/testserver/element.go deleted file mode 100644 index 3ddb0ba459e..00000000000 --- a/codegen/testserver/element.go +++ /dev/null @@ -1,31 +0,0 @@ -package testserver - -import ( - "context" - "errors" - "time" -) - -type Element struct { - ID int -} - -type ElementResolver struct{} - -func (r *ElementResolver) Query_path(ctx context.Context) ([]Element, error) { - return []Element{{1}, {2}, {3}, {4}}, nil -} - -func (r *ElementResolver) Element_child(ctx context.Context, obj *Element) (Element, error) { - return Element{obj.ID * 10}, nil -} - -func (r *ElementResolver) Element_error(ctx context.Context, obj *Element, message *string) (bool, error) { - // A silly hack to make the result order stable - time.Sleep(time.Duration(obj.ID) * 10 * time.Millisecond) - - if message != nil { - return true, errors.New(*message) - } - return false, nil -} diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index e98a1dc7638..c8e7721c028 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -92,6 +92,7 @@ type ComplexityRoot struct { ErrorBubble func(childComplexity int) int Valid func(childComplexity int) int User func(childComplexity int, id int) int + NullableArg func(childComplexity int, arg *int) int KeywordArgs func(childComplexity int, breakArg string, defaultArg string, funcArg string, interfaceArg string, selectArg string, caseArg string, deferArg string, goArg string, mapArg string, structArg string, chanArg string, elseArg string, gotoArg string, packageArg string, switchArg string, constArg string, fallthroughArg string, ifArg string, rangeArg string, typeArg string, continueArg string, forArg string, importArg string, returnArg string, varArg string) int } @@ -127,6 +128,7 @@ type QueryResolver interface { ErrorBubble(ctx context.Context) (*Error, error) Valid(ctx context.Context) (string, error) User(ctx context.Context, id int) (User, error) + NullableArg(ctx context.Context, arg *int) (*string, error) KeywordArgs(ctx context.Context, breakArg string, defaultArg string, funcArg string, interfaceArg string, selectArg string, caseArg string, deferArg string, goArg string, mapArg string, structArg string, chanArg string, elseArg string, gotoArg string, packageArg string, switchArg string, constArg string, fallthroughArg string, ifArg string, rangeArg string, typeArg string, continueArg string, forArg string, importArg string, returnArg string, varArg string) (bool, error) } type SubscriptionResolver interface { @@ -253,6 +255,26 @@ func field_Query_user_args(rawArgs map[string]interface{}) (map[string]interface } +func field_Query_nullableArg_args(rawArgs map[string]interface{}) (map[string]interface{}, error) { + args := map[string]interface{}{} + var arg0 *int + if tmp, ok := rawArgs["arg"]; ok { + var err error + var ptr1 int + if tmp != nil { + ptr1, err = graphql.UnmarshalInt(tmp) + arg0 = &ptr1 + } + + if err != nil { + return nil, err + } + } + args["arg"] = arg0 + return args, nil + +} + func field_Query_keywordArgs_args(rawArgs map[string]interface{}) (map[string]interface{}, error) { args := map[string]interface{}{} var arg0 string @@ -735,6 +757,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.User(childComplexity, args["id"].(int)), true + case "Query.nullableArg": + if e.complexity.Query.NullableArg == nil { + break + } + + args, err := field_Query_nullableArg_args(rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.NullableArg(childComplexity, args["arg"].(*int)), true + case "Query.keywordArgs": if e.complexity.Query.KeywordArgs == nil { break @@ -964,7 +998,8 @@ func (ec *executionContext) _EmbeddedPointer_ID(ctx context.Context, field graph Field: field, } ctx = graphql.WithResolverContext(ctx, rctx) - resTmp := ec.FieldMiddleware(ctx, obj, func(ctx context.Context) (interface{}, error) { + resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children return obj.ID, nil }) if resTmp == nil { @@ -983,7 +1018,8 @@ func (ec *executionContext) _EmbeddedPointer_Title(ctx context.Context, field gr Field: field, } ctx = graphql.WithResolverContext(ctx, rctx) - resTmp := ec.FieldMiddleware(ctx, obj, func(ctx context.Context) (interface{}, error) { + resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children return obj.Title, nil }) if resTmp == nil { @@ -1496,6 +1532,12 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr } wg.Done() }(i, field) + case "nullableArg": + wg.Add(1) + go func(i int, field graphql.CollectedField) { + out.Values[i] = ec._Query_nullableArg(ctx, field) + wg.Done() + }(i, field) case "keywordArgs": wg.Add(1) go func(i int, field graphql.CollectedField) { @@ -1912,6 +1954,36 @@ func (ec *executionContext) _Query_user(ctx context.Context, field graphql.Colle return ec._User(ctx, field.Selections, &res) } +// nolint: vetshadow +func (ec *executionContext) _Query_nullableArg(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + rawArgs := field.ArgumentMap(ec.Variables) + args, err := field_Query_nullableArg_args(rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + rctx := &graphql.ResolverContext{ + Object: "Query", + Args: args, + Field: field, + } + ctx = graphql.WithResolverContext(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().NullableArg(rctx, args["arg"].(*int)) + }) + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + rctx.Result = res + + if res == nil { + return graphql.Null + } + return graphql.MarshalString(*res) +} + // nolint: vetshadow func (ec *executionContext) _Query_keywordArgs(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { rawArgs := field.ArgumentMap(ec.Variables) @@ -3870,6 +3942,7 @@ var parsedSchema = gqlparser.MustLoadSchema( errorBubble: Error valid: String! user(id: Int!): User! + nullableArg(arg: Int = 123): String } type Subscription { diff --git a/codegen/testserver/generated_test.go b/codegen/testserver/generated_test.go index 0d6e0c674ba..6d133636641 100644 --- a/codegen/testserver/generated_test.go +++ b/codegen/testserver/generated_test.go @@ -175,6 +175,15 @@ func TestGeneratedServer(t *testing.T) { sub.Close() }) }) + + t.Run("null args", func(t *testing.T) { + var resp struct { + NullableArg *string + } + err := c.Post(`query { nullableArg(arg: null) }`, &resp) + require.Nil(t, err) + require.Equal(t, "Ok", *resp.NullableArg) + }) } func TestResponseExtension(t *testing.T) { @@ -227,6 +236,11 @@ func (r *testQueryResolver) User(ctx context.Context, id int) (User, error) { return User{ID: 1}, nil } +func (r *testQueryResolver) NullableArg(ctx context.Context, arg *int) (*string, error) { + s := "Ok" + return &s, nil +} + func (r *testResolver) Subscription() SubscriptionResolver { return &testSubscriptionResolver{r} } diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index 4d4ca217fcd..4406ed374a5 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -65,6 +65,9 @@ func (r *queryResolver) Valid(ctx context.Context) (string, error) { func (r *queryResolver) User(ctx context.Context, id int) (User, error) { panic("not implemented") } +func (r *queryResolver) NullableArg(ctx context.Context, arg *int) (*string, error) { + panic("not implemented") +} func (r *queryResolver) KeywordArgs(ctx context.Context, breakArg string, defaultArg string, funcArg string, interfaceArg string, selectArg string, caseArg string, deferArg string, goArg string, mapArg string, structArg string, chanArg string, elseArg string, gotoArg string, packageArg string, switchArg string, constArg string, fallthroughArg string, ifArg string, rangeArg string, typeArg string, continueArg string, forArg string, importArg string, returnArg string, varArg string) (bool, error) { panic("not implemented") } diff --git a/codegen/testserver/schema.graphql b/codegen/testserver/schema.graphql index 207bd971c5a..02a69212482 100644 --- a/codegen/testserver/schema.graphql +++ b/codegen/testserver/schema.graphql @@ -10,6 +10,7 @@ type Query { errorBubble: Error valid: String! user(id: Int!): User! + nullableArg(arg: Int = 123): String } type Subscription { diff --git a/example/scalars/generated.go b/example/scalars/generated.go index 6937f665716..58c21d09621 100644 --- a/example/scalars/generated.go +++ b/example/scalars/generated.go @@ -48,7 +48,7 @@ type ComplexityRoot struct { Query struct { User func(childComplexity int, id external.ObjectID) int - Search func(childComplexity int, input model.SearchArgs) int + Search func(childComplexity int, input *model.SearchArgs) int } User struct { @@ -65,7 +65,7 @@ type ComplexityRoot struct { type QueryResolver interface { User(ctx context.Context, id external.ObjectID) (*model.User, error) - Search(ctx context.Context, input model.SearchArgs) ([]model.User, error) + Search(ctx context.Context, input *model.SearchArgs) ([]model.User, error) } type UserResolver interface { PrimitiveResolver(ctx context.Context, obj *model.User) (string, error) @@ -89,10 +89,15 @@ func field_Query_user_args(rawArgs map[string]interface{}) (map[string]interface func field_Query_search_args(rawArgs map[string]interface{}) (map[string]interface{}, error) { args := map[string]interface{}{} - var arg0 model.SearchArgs + var arg0 *model.SearchArgs if tmp, ok := rawArgs["input"]; ok { var err error - arg0, err = UnmarshalSearchArgs(tmp) + var ptr1 model.SearchArgs + if tmp != nil { + ptr1, err = UnmarshalSearchArgs(tmp) + arg0 = &ptr1 + } + if err != nil { return nil, err } @@ -196,7 +201,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Search(childComplexity, args["input"].(model.SearchArgs)), true + return e.complexity.Query.Search(childComplexity, args["input"].(*model.SearchArgs)), true case "User.id": if e.complexity.User.Id == nil { @@ -462,7 +467,7 @@ func (ec *executionContext) _Query_search(ctx context.Context, field graphql.Col ctx = graphql.WithResolverContext(ctx, rctx) resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Search(rctx, args["input"].(model.SearchArgs)) + return ec.resolvers.Query().Search(rctx, args["input"].(*model.SearchArgs)) }) if resTmp == nil { if !ec.HasError(rctx) { diff --git a/example/scalars/resolvers.go b/example/scalars/resolvers.go index 752a2c8ab4e..af7c62122e8 100644 --- a/example/scalars/resolvers.go +++ b/example/scalars/resolvers.go @@ -34,7 +34,7 @@ func (r *queryResolver) User(ctx context.Context, id external.ObjectID) (*model. }, nil } -func (r *queryResolver) Search(ctx context.Context, input model.SearchArgs) ([]model.User, error) { +func (r *queryResolver) Search(ctx context.Context, input *model.SearchArgs) ([]model.User, error) { location := model.Point{X: 1, Y: 2} if input.Location != nil { location = *input.Location diff --git a/example/starwars/generated.go b/example/starwars/generated.go index c2fb5d57b52..fdde36622ff 100644 --- a/example/starwars/generated.go +++ b/example/starwars/generated.go @@ -87,7 +87,7 @@ type ComplexityRoot struct { } Query struct { - Hero func(childComplexity int, episode Episode) int + Hero func(childComplexity int, episode *Episode) int Reviews func(childComplexity int, episode Episode, since *time.Time) int Search func(childComplexity int, text string) int Character func(childComplexity int, id string) int @@ -105,7 +105,7 @@ type ComplexityRoot struct { Starship struct { Id func(childComplexity int) int Name func(childComplexity int) int - Length func(childComplexity int, unit LengthUnit) int + Length func(childComplexity int, unit *LengthUnit) int History func(childComplexity int) int } } @@ -128,7 +128,7 @@ type MutationResolver interface { CreateReview(ctx context.Context, episode Episode, review Review) (*Review, error) } type QueryResolver interface { - Hero(ctx context.Context, episode Episode) (Character, error) + Hero(ctx context.Context, episode *Episode) (Character, error) Reviews(ctx context.Context, episode Episode, since *time.Time) ([]Review, error) Search(ctx context.Context, text string) ([]SearchResult, error) Character(ctx context.Context, id string) (Character, error) @@ -137,7 +137,7 @@ type QueryResolver interface { Starship(ctx context.Context, id string) (*Starship, error) } type StarshipResolver interface { - Length(ctx context.Context, obj *Starship, unit LengthUnit) (float64, error) + Length(ctx context.Context, obj *Starship, unit *LengthUnit) (float64, error) } func field_Droid_friendsConnection_args(rawArgs map[string]interface{}) (map[string]interface{}, error) { @@ -249,10 +249,15 @@ func field_Mutation_createReview_args(rawArgs map[string]interface{}) (map[strin func field_Query_hero_args(rawArgs map[string]interface{}) (map[string]interface{}, error) { args := map[string]interface{}{} - var arg0 Episode + var arg0 *Episode if tmp, ok := rawArgs["episode"]; ok { var err error - err = (&arg0).UnmarshalGQL(tmp) + var ptr1 Episode + if tmp != nil { + err = (&ptr1).UnmarshalGQL(tmp) + arg0 = &ptr1 + } + if err != nil { return nil, err } @@ -383,10 +388,15 @@ func field_Query___type_args(rawArgs map[string]interface{}) (map[string]interfa func field_Starship_length_args(rawArgs map[string]interface{}) (map[string]interface{}, error) { args := map[string]interface{}{} - var arg0 LengthUnit + var arg0 *LengthUnit if tmp, ok := rawArgs["unit"]; ok { var err error - err = (&arg0).UnmarshalGQL(tmp) + var ptr1 LengthUnit + if tmp != nil { + err = (&ptr1).UnmarshalGQL(tmp) + arg0 = &ptr1 + } + if err != nil { return nil, err } @@ -637,7 +647,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Hero(childComplexity, args["episode"].(Episode)), true + return e.complexity.Query.Hero(childComplexity, args["episode"].(*Episode)), true case "Query.reviews": if e.complexity.Query.Reviews == nil { @@ -756,7 +766,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Starship.Length(childComplexity, args["unit"].(LengthUnit)), true + return e.complexity.Starship.Length(childComplexity, args["unit"].(*LengthUnit)), true case "Starship.history": if e.complexity.Starship.History == nil { @@ -1928,7 +1938,7 @@ func (ec *executionContext) _Query_hero(ctx context.Context, field graphql.Colle ctx = graphql.WithResolverContext(ctx, rctx) resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Hero(rctx, args["episode"].(Episode)) + return ec.resolvers.Query().Hero(rctx, args["episode"].(*Episode)) }) if resTmp == nil { return graphql.Null @@ -2452,7 +2462,7 @@ func (ec *executionContext) _Starship_length(ctx context.Context, field graphql. ctx = graphql.WithResolverContext(ctx, rctx) resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Starship().Length(rctx, obj, args["unit"].(LengthUnit)) + return ec.resolvers.Starship().Length(rctx, obj, args["unit"].(*LengthUnit)) }) if resTmp == nil { if !ec.HasError(rctx) { diff --git a/example/starwars/resolvers.go b/example/starwars/resolvers.go index 50a1625f197..55c4c3e019c 100644 --- a/example/starwars/resolvers.go +++ b/example/starwars/resolvers.go @@ -119,8 +119,8 @@ func (r *mutationResolver) CreateReview(ctx context.Context, episode Episode, re type queryResolver struct{ *Resolver } -func (r *queryResolver) Hero(ctx context.Context, episode Episode) (Character, error) { - if episode == EpisodeEmpire { +func (r *queryResolver) Hero(ctx context.Context, episode *Episode) (Character, error) { + if *episode == EpisodeEmpire { return r.humans["1000"], nil } return r.droid["2001"], nil @@ -193,8 +193,8 @@ func (r *queryResolver) Starship(ctx context.Context, id string) (*Starship, err type starshipResolver struct{ *Resolver } -func (r *starshipResolver) Length(ctx context.Context, obj *Starship, unit LengthUnit) (float64, error) { - switch unit { +func (r *starshipResolver) Length(ctx context.Context, obj *Starship, unit *LengthUnit) (float64, error) { + switch *unit { case LengthUnitMeter, "": return obj.Length, nil case LengthUnitFoot: diff --git a/integration/generated.go b/integration/generated.go index c598e8ad47c..f2bd6a787f2 100644 --- a/integration/generated.go +++ b/integration/generated.go @@ -53,7 +53,7 @@ type ComplexityRoot struct { Date func(childComplexity int, filter models.DateFilter) int Viewer func(childComplexity int) int JsonEncoding func(childComplexity int) int - Error func(childComplexity int, typeArg models.ErrorType) int + Error func(childComplexity int, typeArg *models.ErrorType) int } User struct { @@ -76,7 +76,7 @@ type QueryResolver interface { Date(ctx context.Context, filter models.DateFilter) (bool, error) Viewer(ctx context.Context) (*models.Viewer, error) JSONEncoding(ctx context.Context) (string, error) - Error(ctx context.Context, typeArg models.ErrorType) (bool, error) + Error(ctx context.Context, typeArg *models.ErrorType) (bool, error) } type UserResolver interface { Likes(ctx context.Context, obj *remote_api.User) ([]string, error) @@ -99,10 +99,15 @@ func field_Query_date_args(rawArgs map[string]interface{}) (map[string]interface func field_Query_error_args(rawArgs map[string]interface{}) (map[string]interface{}, error) { args := map[string]interface{}{} - var arg0 models.ErrorType + var arg0 *models.ErrorType if tmp, ok := rawArgs["type"]; ok { var err error - err = (&arg0).UnmarshalGQL(tmp) + var ptr1 models.ErrorType + if tmp != nil { + err = (&ptr1).UnmarshalGQL(tmp) + arg0 = &ptr1 + } + if err != nil { return nil, err } @@ -254,7 +259,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Error(childComplexity, args["type"].(models.ErrorType)), true + return e.complexity.Query.Error(childComplexity, args["type"].(*models.ErrorType)), true case "User.name": if e.complexity.User.Name == nil { @@ -659,7 +664,7 @@ func (ec *executionContext) _Query_error(ctx context.Context, field graphql.Coll ctx = graphql.WithResolverContext(ctx, rctx) resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Error(rctx, args["type"].(models.ErrorType)) + return ec.resolvers.Query().Error(rctx, args["type"].(*models.ErrorType)) }) if resTmp == nil { if !ec.HasError(rctx) { diff --git a/integration/resolver.go b/integration/resolver.go index 6aeda9784b2..c0486f7b87a 100644 --- a/integration/resolver.go +++ b/integration/resolver.go @@ -53,8 +53,8 @@ func (r *elementResolver) Child(ctx context.Context, obj *models.Element) (model type queryResolver struct{ *Resolver } -func (r *queryResolver) Error(ctx context.Context, typeArg models.ErrorType) (bool, error) { - if typeArg == models.ErrorTypeCustom { +func (r *queryResolver) Error(ctx context.Context, typeArg *models.ErrorType) (bool, error) { + if *typeArg == models.ErrorTypeCustom { return false, &CustomError{"User message", "Internal Message"} }