Skip to content

Commit

Permalink
Implement OneOf Input Objects via @OneOf directive
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Sep 7, 2024
1 parent 6e6d5be commit b7a18ed
Show file tree
Hide file tree
Showing 23 changed files with 720 additions and 13 deletions.
2 changes: 2 additions & 0 deletions src/graphql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@
GraphQLStreamDirective,
GraphQLDeprecatedDirective,
GraphQLSpecifiedByDirective,
GraphQLOneOfDirective,
# "Enum" of Type Kinds
TypeKind,
# Constant Deprecation Reason
Expand Down Expand Up @@ -504,6 +505,7 @@
"GraphQLStreamDirective",
"GraphQLDeprecatedDirective",
"GraphQLSpecifiedByDirective",
"GraphQLOneOfDirective",
"TypeKind",
"DEFAULT_DEPRECATION_REASON",
"introspection_types",
Expand Down
15 changes: 10 additions & 5 deletions src/graphql/execution/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,20 @@ def coerce_variable_values(
continue

def on_input_value_error(
path: list[str | int], invalid_value: Any, error: GraphQLError
path: list[str | int],
invalid_value: Any,
error: GraphQLError,
var_name: str = var_name,
var_def_node: VariableDefinitionNode = var_def_node,
) -> None:
invalid_str = inspect(invalid_value)
prefix = f"Variable '${var_name}' got invalid value {invalid_str}" # noqa: B023
prefix = f"Variable '${var_name}' got invalid value {invalid_str}"
if path:
prefix += f" at '{var_name}{print_path_list(path)}'" # noqa: B023
prefix += f" at '{var_name}{print_path_list(path)}'"
on_error(
GraphQLError(
prefix + "; " + error.message,
var_def_node, # noqa: B023
var_def_node,
original_error=error,
)
)
Expand Down Expand Up @@ -193,7 +197,8 @@ def get_argument_values(
)
raise GraphQLError(msg, value_node)
continue # pragma: no cover
is_null = variable_values[variable_name] is None
variable_value = variable_values[variable_name]
is_null = variable_value is None or variable_value is Undefined

if is_null and is_non_null_type(arg_type):
msg = f"Argument '{name}' of non-null type '{arg_type}' must not be null."
Expand Down
2 changes: 2 additions & 0 deletions src/graphql/type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
GraphQLStreamDirective,
GraphQLDeprecatedDirective,
GraphQLSpecifiedByDirective,
GraphQLOneOfDirective,
# Keyword Args
GraphQLDirectiveKwargs,
# Constant Deprecation Reason
Expand Down Expand Up @@ -286,6 +287,7 @@
"GraphQLStreamDirective",
"GraphQLDeprecatedDirective",
"GraphQLSpecifiedByDirective",
"GraphQLOneOfDirective",
"GraphQLDirectiveKwargs",
"DEFAULT_DEPRECATION_REASON",
"is_specified_scalar_type",
Expand Down
5 changes: 5 additions & 0 deletions src/graphql/type/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,7 @@ class GraphQLInputObjectTypeKwargs(GraphQLNamedTypeKwargs, total=False):

fields: GraphQLInputFieldMap
out_type: GraphQLInputFieldOutType | None
is_one_of: bool


class GraphQLInputObjectType(GraphQLNamedType):
Expand Down Expand Up @@ -1301,6 +1302,7 @@ class GeoPoint(GraphQLInputObjectType):

ast_node: InputObjectTypeDefinitionNode | None
extension_ast_nodes: tuple[InputObjectTypeExtensionNode, ...]
is_one_of: bool

def __init__(
self,
Expand All @@ -1311,6 +1313,7 @@ def __init__(
extensions: dict[str, Any] | None = None,
ast_node: InputObjectTypeDefinitionNode | None = None,
extension_ast_nodes: Collection[InputObjectTypeExtensionNode] | None = None,
is_one_of: bool = False,
) -> None:
super().__init__(
name=name,
Expand All @@ -1322,6 +1325,7 @@ def __init__(
self._fields = fields
if out_type is not None:
self.out_type = out_type # type: ignore
self.is_one_of = is_one_of

@staticmethod
def out_type(value: dict[str, Any]) -> Any:
Expand All @@ -1340,6 +1344,7 @@ def to_kwargs(self) -> GraphQLInputObjectTypeKwargs:
out_type=None
if self.out_type is GraphQLInputObjectType.out_type
else self.out_type,
is_one_of=self.is_one_of,
)

def __copy__(self) -> GraphQLInputObjectType: # pragma: no cover
Expand Down
9 changes: 9 additions & 0 deletions src/graphql/type/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,20 @@ def assert_directive(directive: Any) -> GraphQLDirective:
description="Exposes a URL that specifies the behaviour of this scalar.",
)

# Used to declare an Input Object as a OneOf Input Objects.
GraphQLOneOfDirective = GraphQLDirective(
name="oneOf",
locations=[DirectiveLocation.INPUT_OBJECT],
args={},
description="Indicates an Input Object is a OneOf Input Object.",
)

specified_directives: tuple[GraphQLDirective, ...] = (
GraphQLIncludeDirective,
GraphQLSkipDirective,
GraphQLDeprecatedDirective,
GraphQLSpecifiedByDirective,
GraphQLOneOfDirective,
)
"""A tuple with all directives from the GraphQL specification"""

Expand Down
5 changes: 5 additions & 0 deletions src/graphql/type/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def __new__(cls):
resolve=cls.input_fields,
),
"ofType": GraphQLField(_Type, resolve=cls.of_type),
"isOneOf": GraphQLField(GraphQLBoolean, resolve=cls.is_one_of),
}

@staticmethod
Expand Down Expand Up @@ -396,6 +397,10 @@ def input_fields(type_, _info, includeDeprecated=False):
def of_type(type_, _info):
return getattr(type_, "of_type", None)

@staticmethod
def is_one_of(type_, _info):
return type_.is_one_of if is_input_object_type(type_) else None


_Type: GraphQLObjectType = GraphQLObjectType(
name="__Type",
Expand Down
24 changes: 23 additions & 1 deletion src/graphql/type/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
SchemaDefinitionNode,
SchemaExtensionNode,
)
from ..pyutils import and_list, inspect
from ..pyutils import Undefined, and_list, inspect
from ..utilities.type_comparators import is_equal_type, is_type_sub_type_of
from .definition import (
GraphQLEnumType,
Expand Down Expand Up @@ -482,6 +482,28 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
],
)

if input_obj.is_one_of:
self.validate_one_of_input_object_field(input_obj, field_name, field)

def validate_one_of_input_object_field(
self,
type_: GraphQLInputObjectType,
field_name: str,
field: GraphQLInputField,
) -> None:
if is_non_null_type(field.type):
self.report_error(
f"OneOf input field {type_.name}.{field_name} must be nullable.",
field.ast_node and field.ast_node.type,
)

if field.default_value is not Undefined:
self.report_error(
f"OneOf input field {type_.name}.{field_name}"
" cannot have a default value.",
field.ast_node,
)


def get_operation_type_node(
schema: GraphQLSchema, operation: OperationType
Expand Down
24 changes: 24 additions & 0 deletions src/graphql/utilities/coerce_input_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,30 @@ def coerce_input_value(
+ did_you_mean(suggestions)
),
)

if type_.is_one_of:
keys = list(coerced_dict)
if len(keys) != 1:
on_error(
path.as_list() if path else [],
input_value,
GraphQLError(
"Exactly one key must be specified"
f" for OneOf type '{type_.name}'.",
),
)
else:
key = keys[0]
value = coerced_dict[key]
if value is None:
on_error(
(path.as_list() if path else []) + [key],
value,
GraphQLError(
f"Field '{key}' must be non-null.",
),
)

return type_.out_type(coerced_dict)

if is_leaf_type(type_):
Expand Down
9 changes: 9 additions & 0 deletions src/graphql/utilities/extend_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
GraphQLNullableType,
GraphQLObjectType,
GraphQLObjectTypeKwargs,
GraphQLOneOfDirective,
GraphQLOutputType,
GraphQLScalarType,
GraphQLSchema,
Expand Down Expand Up @@ -777,6 +778,7 @@ def build_input_object_type(
fields=partial(self.build_input_field_map, all_nodes),
ast_node=ast_node,
extension_ast_nodes=extension_nodes,
is_one_of=is_one_of(ast_node),
)

def build_type(self, ast_node: TypeDefinitionNode) -> GraphQLNamedType:
Expand Down Expand Up @@ -822,3 +824,10 @@ def get_specified_by_url(

specified_by_url = get_directive_values(GraphQLSpecifiedByDirective, node)
return specified_by_url["url"] if specified_by_url else None


def is_one_of(node: InputObjectTypeDefinitionNode) -> bool:
"""Given an input object node, returns if the node should be OneOf."""
from ..execution import get_directive_values

return get_directive_values(GraphQLOneOfDirective, node) is not None
8 changes: 8 additions & 0 deletions src/graphql/utilities/value_from_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def value_from_ast(
return Undefined
coerced_obj[field.out_name or field_name] = field_value

if type_.is_one_of:
keys = list(coerced_obj)
if len(keys) != 1:
return Undefined

if coerced_obj[keys[0]] is None:
return Undefined

return type_.out_type(coerced_obj)

if is_leaf_type(type_):
Expand Down
72 changes: 70 additions & 2 deletions src/graphql/validation/rules/values_of_correct_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, cast
from typing import Any, Mapping, cast

from ...error import GraphQLError
from ...language import (
Expand All @@ -12,16 +12,20 @@
FloatValueNode,
IntValueNode,
ListValueNode,
NonNullTypeNode,
NullValueNode,
ObjectFieldNode,
ObjectValueNode,
StringValueNode,
ValueNode,
VariableDefinitionNode,
VariableNode,
VisitorAction,
print_ast,
)
from ...pyutils import Undefined, did_you_mean, suggestion_list
from ...type import (
GraphQLInputObjectType,
GraphQLScalarType,
get_named_type,
get_nullable_type,
Expand All @@ -31,7 +35,7 @@
is_non_null_type,
is_required_input_field,
)
from . import ValidationRule
from . import ValidationContext, ValidationRule

__all__ = ["ValuesOfCorrectTypeRule"]

Expand All @@ -45,6 +49,18 @@ class ValuesOfCorrectTypeRule(ValidationRule):
See https://spec.graphql.org/draft/#sec-Values-of-Correct-Type
"""

def __init__(self, context: ValidationContext) -> None:
super().__init__(context)
self.variable_definitions: dict[str, VariableDefinitionNode] = {}

def enter_operation_definition(self, *_args: Any) -> None:
self.variable_definitions.clear()

def enter_variable_definition(
self, definition: VariableDefinitionNode, *_args: Any
) -> None:
self.variable_definitions[definition.variable.name.value] = definition

def enter_list_value(self, node: ListValueNode, *_args: Any) -> VisitorAction:
# Note: TypeInfo will traverse into a list's item type, so look to the parent
# input type to check if it is a list.
Expand Down Expand Up @@ -72,6 +88,10 @@ def enter_object_value(self, node: ObjectValueNode, *_args: Any) -> VisitorActio
node,
)
)
if type_.is_one_of:
validate_one_of_input_object(
self.context, node, type_, field_node_map, self.variable_definitions
)
return None

def enter_object_field(self, node: ObjectFieldNode, *_args: Any) -> None:
Expand Down Expand Up @@ -162,3 +182,51 @@ def is_valid_value_node(self, node: ValueNode) -> None:
)

return


def validate_one_of_input_object(
context: ValidationContext,
node: ObjectValueNode,
type_: GraphQLInputObjectType,
field_node_map: Mapping[str, ObjectFieldNode],
variable_definitions: dict[str, VariableDefinitionNode],
) -> None:
keys = list(field_node_map)
is_not_exactly_one_filed = len(keys) != 1

if is_not_exactly_one_filed:
context.report_error(
GraphQLError(
f"OneOf Input Object '{type_.name}' must specify exactly one key.",
node,
)
)
return

object_field_node = field_node_map.get(keys[0])
value = object_field_node.value if object_field_node else None
is_null_literal = not value or isinstance(value, NullValueNode)

if is_null_literal:
context.report_error(
GraphQLError(
f"Field '{type_.name}.{keys[0]}' must be non-null.",
node,
)
)
return

is_variable = value and isinstance(value, VariableNode)
if is_variable:
variable_name = cast(VariableNode, value).name.value
definition = variable_definitions[variable_name]
is_nullable_variable = not isinstance(definition.type, NonNullTypeNode)

if is_nullable_variable:
context.report_error(
GraphQLError(
f"Variable '{variable_name}' must be non-nullable"
f" to be used for OneOf Input Object '{type_.name}'.",
node,
)
)
Loading

0 comments on commit b7a18ed

Please sign in to comment.