Skip to content

Commit

Permalink
Validate unbreakable cycles of input types.
Browse files Browse the repository at this point in the history
This is based on a spec proposal (see:
graphql/graphql-spec#701) and may change in the
future.
  • Loading branch information
lirsacc committed Nov 1, 2020
1 parent 321e126 commit bc706f3
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 1 deletion.
86 changes: 85 additions & 1 deletion src/py_gql/schema/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,20 @@
Schema validation utility
"""

import collections
import re
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Callable, List, Sequence, Set, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Sequence,
Set,
Tuple,
Union,
)

from .._string_utils import quoted_options_list
from ..exc import SchemaError, SchemaValidationError
Expand All @@ -18,6 +29,7 @@
EnumValue,
InputObjectType,
InterfaceType,
ListType,
NamedType,
NonNullType,
ObjectType,
Expand Down Expand Up @@ -167,6 +179,7 @@ def __call__(self) -> None:
self.add_error("%s is not a valid schema type" % type_)

self.validate_directives()
self.validate_cyclic_input_types()

def validate_root_types(self) -> None:
query = self.schema.query_type
Expand Down Expand Up @@ -507,3 +520,74 @@ def validate_input_fields(self, input_object: InputObjectType) -> None:
)

fieldnames.add(field.name)

def validate_cyclic_input_types(self) -> None:
"""
Detect unbroken chains of input types.
Generally input types can refer to themselves as long as it is through a
nullable type or a list, non nullable cycles are not supported.
This is currently (2020-10-31) `in the process of stabilising
<https://github.com/graphql/graphql-spec/pull/701/>`_ and may change in
the future.
"""
# TODO: Add link to spec / RFC in errors when stabilised.
input_types = [
t
for t in self.schema.types.values()
if isinstance(t, InputObjectType)
]

direct_references = collections.defaultdict(set)

# Collect any non breakable reference to any input object type.
for t in input_types:
for f in t.fields:
real_type = f.type

# Non null types are breakable by default, wrapped types are not.
breakable = not isinstance(real_type, (ListType, NonNullType))
while isinstance(real_type, (ListType, NonNullType)):
# List can break the chain.
if isinstance(real_type, ListType):
breakable = True
real_type = real_type.type

if (not breakable) and isinstance(real_type, InputObjectType):
direct_references[t].add(real_type)

chains = [] # type: List[Tuple[str, Dict[str, List[str]]]]

def _search(outer, acc=None, path=None):
acc, path = acc or set(), path or ()

for inner in direct_references[outer]:
if inner.name in path:
break

if (inner.name, path) in acc:
break

acc.add((inner.name, path))
_search(inner, acc, (*path, inner.name))

return acc

all_chains = [
(t.name, _search(t)) for t in list(direct_references.keys())
]

# TODO: This will contain multiple rotated versions of any given cycle.
# This is fine for now, but would be nice to avoid duplicate data.
for typename, chains in all_chains:
for final, path in chains:
if final == typename:
self.add_error(
"Non breakable input chain found: %s"
% quoted_options_list(
[typename, *path, typename],
separator=" > ",
final_separator=" > ",
)
)
81 changes: 81 additions & 0 deletions tests/test_schema/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,3 +896,84 @@ def default_resolver(root, info, ctx, **args):
schema.default_resolver = default_resolver

validate_schema(schema)


class TestInputTypeCycles:
def _schema(self, input_types):
return Schema(
ObjectType(
"query",
[
Field(
"field",
Int,
args=[Argument(t.name, t) for t in input_types],
)
],
)
)

def test_no_cycle(self):
A = InputObjectType("A", [InputField("f", Int)])
B = InputObjectType("B", [InputField("f", Int)])
C = InputObjectType("C", [InputField("a", lambda: NonNullType(A))])
schema = self._schema([A, B, C])
assert validate_schema(schema)

def test_simple_cycles(self):
A = InputObjectType("A", [InputField("b", lambda: NonNullType(B))])
B = InputObjectType("B", [InputField("c", lambda: NonNullType(C))])
C = InputObjectType("C", [InputField("a", lambda: NonNullType(A))])
schema = self._schema([A, B, C])

with pytest.raises(SchemaError) as exc_info:
validate_schema(schema)

assert set([str(e) for e in exc_info.value.errors]) == set(
[
'Non breakable input chain found: "B" > "C" > "A" > "B"',
'Non breakable input chain found: "A" > "B" > "C" > "A"',
'Non breakable input chain found: "C" > "A" > "B" > "C"',
]
)

def test_multiple_cycles(self):
A = InputObjectType(
"A",
[
InputField("b", lambda: NonNullType(B)),
InputField("c", lambda: NonNullType(C)),
],
)
B = InputObjectType("B", [InputField("a", lambda: NonNullType(A))])
C = InputObjectType("C", [InputField("a", lambda: NonNullType(A))])
schema = self._schema([A, B, C])

with pytest.raises(SchemaError) as exc_info:
validate_schema(schema)

assert set([str(e) for e in exc_info.value.errors]) == set(
[
'Non breakable input chain found: "C" > "A" > "C"',
'Non breakable input chain found: "A" > "C" > "A"',
'Non breakable input chain found: "A" > "B" > "A"',
'Non breakable input chain found: "B" > "A" > "B"',
]
)

def test_simple_breakable_cycle(self):
A = InputObjectType("A", [InputField("b", lambda: NonNullType(B))])
B = InputObjectType("B", [InputField("c", lambda: C)])
C = InputObjectType("C", [InputField("a", lambda: NonNullType(A))])
schema = self._schema([A, B, C])
assert validate_schema(schema)

def test_list_breaks_cycle(self):
A = InputObjectType("A", [InputField("b", lambda: NonNullType(B))])
B = InputObjectType(
"B",
[InputField("c", lambda: NonNullType(ListType(NonNullType(C))))],
)
C = InputObjectType("C", [InputField("a", lambda: NonNullType(A))])
schema = self._schema([A, B, C])
assert validate_schema(schema)

0 comments on commit bc706f3

Please sign in to comment.