From dcb9afffbea31431997b9057f5d6ca736852e23c Mon Sep 17 00:00:00 2001 From: voith Date: Mon, 12 Aug 2019 01:49:08 +0530 Subject: [PATCH 1/6] add support for `await graphql_async(schema, query)` --- graphql/__init__.py | 3 +- graphql/backend/base.py | 6 + graphql/backend/cache.py | 5 + graphql/backend/core.py | 57 ++++++-- graphql/backend/decider.py | 5 + graphql/backend/quiver_cloud.py | 5 + graphql/backend/tests/test_decider.py | 5 + graphql/execution/__init__.py | 3 +- graphql/execution/executor.py | 108 +++++++++++++-- graphql/execution/executors/asyncio.py | 10 +- graphql/execution/executors/base.py | 16 +++ graphql/execution/executors/gevent.py | 6 +- graphql/execution/executors/process.py | 6 +- graphql/execution/executors/sync.py | 7 +- graphql/execution/executors/thread.py | 7 +- .../execution/tests/test_executor_asyncio.py | 78 ++++++++++- graphql/graphql.py | 34 +++++ tests/starwars/test_query.py | 128 ++++++++++-------- 18 files changed, 398 insertions(+), 91 deletions(-) create mode 100644 graphql/execution/executors/base.py diff --git a/graphql/__init__.py b/graphql/__init__.py index 902800b3..b941dff0 100644 --- a/graphql/__init__.py +++ b/graphql/__init__.py @@ -24,7 +24,7 @@ from .pyutils.version import get_version # The primary entry point into fulfilling a GraphQL request. -from .graphql import graphql +from .graphql import graphql, graphql_async # Create and operate on GraphQL type definitions and schema. from .type import ( # no import order @@ -175,6 +175,7 @@ __all__ = ( "__version__", "graphql", + "graphql_async", "GraphQLBoolean", "GraphQLEnumType", "GraphQLEnumValue", diff --git a/graphql/backend/base.py b/graphql/backend/base.py index 6573d3a7..39f2ad07 100644 --- a/graphql/backend/base.py +++ b/graphql/backend/base.py @@ -19,6 +19,12 @@ def document_from_string(self, schema, request_string): "document_from_string method not implemented in {}.".format(self.__class__) ) + @abstractmethod + def document_from_string_async(self, schema, request_string): + raise NotImplementedError( + "document_from_string method not implemented in {}.".format(self.__class__) + ) + class GraphQLDocument(object): def __init__(self, schema, document_string, document_ast, execute): diff --git a/graphql/backend/cache.py b/graphql/backend/cache.py index f35a45ca..1694f9a3 100644 --- a/graphql/backend/cache.py +++ b/graphql/backend/cache.py @@ -78,3 +78,8 @@ def document_from_string(self, schema, request_string): ) return self.cache_map[key] + + def document_from_string_async(self, schema, request_string): + raise NotImplementedError( + "document_from_string method not implemented in {}.".format(self.__class__) + ) diff --git a/graphql/backend/core.py b/graphql/backend/core.py index 41aea2d6..f59fe730 100644 --- a/graphql/backend/core.py +++ b/graphql/backend/core.py @@ -1,7 +1,7 @@ from functools import partial from six import string_types -from ..execution import execute, ExecutionResult +from ..execution import execute, execute_async, ExecutionResult from ..language.base import parse, print_ast from ..language import ast from ..validation import validate @@ -16,6 +16,19 @@ from rx import Observable +def _validate_document_ast( + schema, # type: GraphQLSchema + document_ast, # type: Document + **kwargs # type: Any +): + # type: (...) -> Union[ExecutionResult, None] + do_validation = kwargs.get("validate", True) + if do_validation: + validation_errors = validate(schema, document_ast) + if validation_errors: + return ExecutionResult(errors=validation_errors, invalid=True) + + def execute_and_validate( schema, # type: GraphQLSchema document_ast, # type: Document @@ -23,15 +36,26 @@ def execute_and_validate( **kwargs # type: Any ): # type: (...) -> Union[ExecutionResult, Observable] - do_validation = kwargs.get("validate", True) - if do_validation: - validation_errors = validate(schema, document_ast) - if validation_errors: - return ExecutionResult(errors=validation_errors, invalid=True) + execution_result = _validate_document_ast(schema, document_ast, **kwargs) + if execution_result: + return execution_result return execute(schema, document_ast, *args, **kwargs) +async def execute_and_validate_async( + schema, # type: GraphQLSchema + document_ast, # type: Document + *args, # type: Any + **kwargs # type: Any +): + # type: (...) -> Union[ExecutionResult, Observable] + execution_result = _validate_document_ast(schema, document_ast, **kwargs) + if execution_result: + return execution_result + return await execute_async(schema, document_ast, *args, **kwargs) + + class GraphQLCoreBackend(GraphQLBackend): """GraphQLCoreBackend will return a document using the default graphql executor""" @@ -40,8 +64,9 @@ def __init__(self, executor=None): # type: (Optional[Any]) -> None self.execute_params = {"executor": executor} - def document_from_string(self, schema, document_string): - # type: (GraphQLSchema, Union[Document, str]) -> GraphQLDocument + @staticmethod + def _get_doc_str_and_ast(document_string): + # type: (Union[Document, str] -> (str, ast.Document) if isinstance(document_string, ast.Document): document_ast = document_string document_string = print_ast(document_ast) @@ -50,6 +75,11 @@ def document_from_string(self, schema, document_string): document_string, string_types ), "The query must be a string" document_ast = parse(document_string) + return document_string, document_ast + + def document_from_string(self, schema, document_string): + # type: (GraphQLSchema, Union[Document, str]) -> GraphQLDocument + document_string, document_ast = self._get_doc_str_and_ast(document_string) return GraphQLDocument( schema=schema, document_string=document_string, @@ -58,3 +88,14 @@ def document_from_string(self, schema, document_string): execute_and_validate, schema, document_ast, **self.execute_params ), ) + + def document_from_string_async(self, schema, document_string): + document_string, document_ast = self._get_doc_str_and_ast(document_string) + return GraphQLDocument( + schema=schema, + document_string=document_string, + document_ast=document_ast, + execute=partial( + execute_and_validate_async, schema, document_ast, **self.execute_params + ), + ) diff --git a/graphql/backend/decider.py b/graphql/backend/decider.py index 5fdd39bb..f6940ee7 100644 --- a/graphql/backend/decider.py +++ b/graphql/backend/decider.py @@ -209,3 +209,8 @@ def document_from_string(self, schema, request_string): self.get_worker().queue(self.queue_backend, key, schema, request_string) return self.cache_map[key] + + def document_from_string_async(self, schema, request_string): + raise NotImplementedError( + "document_from_string method not implemented in {}.".format(self.__class__) + ) diff --git a/graphql/backend/quiver_cloud.py b/graphql/backend/quiver_cloud.py index 2c6796b6..e2e5452d 100644 --- a/graphql/backend/quiver_cloud.py +++ b/graphql/backend/quiver_cloud.py @@ -103,3 +103,8 @@ def uptodate(): schema, code, uptodate, self.extra_namespace ) return document + + def document_from_string_async(self, schema, request_string): + raise NotImplementedError( + "document_from_string method not implemented in {}.".format(self.__class__) + ) diff --git a/graphql/backend/tests/test_decider.py b/graphql/backend/tests/test_decider.py index 32dccaee..51c176e4 100644 --- a/graphql/backend/tests/test_decider.py +++ b/graphql/backend/tests/test_decider.py @@ -34,6 +34,11 @@ def document_from_string(self, *args, **kwargs): raise Exception("Backend failed") return self.name + def document_from_string_async(self, schema, request_string): + raise NotImplementedError( + "document_from_string method not implemented in {}.".format(self.__class__) + ) + def wait(self): return self.event.wait() diff --git a/graphql/execution/__init__.py b/graphql/execution/__init__.py index d6c2a7f7..8edf0afe 100644 --- a/graphql/execution/__init__.py +++ b/graphql/execution/__init__.py @@ -18,13 +18,14 @@ 2) fragment "spreads" e.g. "...c" 3) inline fragment "spreads" e.g. "...on Type { a }" """ -from .executor import execute, subscribe +from .executor import execute, execute_async, subscribe from .base import ExecutionResult, ResolveInfo from .middleware import middlewares, MiddlewareManager __all__ = [ "execute", + "execute_async", "subscribe", "ExecutionResult", "ResolveInfo", diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index e77050e2..7014b59a 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -37,6 +37,7 @@ get_operation_root_type, SubscriberExecutionContext, ) +from .executors.asyncio import AsyncioExecutor from .executors.sync import SyncExecutor from .middleware import MiddlewareManager @@ -56,21 +57,18 @@ def subscribe(*args, **kwargs): ) -def execute( +def prepare_execution_context( schema, # type: GraphQLSchema document_ast, # type: Document - root=None, # type: Any - context=None, # type: Optional[Any] - variables=None, # type: Optional[Any] - operation_name=None, # type: Optional[str] - executor=None, # type: Any - return_promise=False, # type: bool - middleware=None, # type: Optional[Any] - allow_subscriptions=False, # type: bool + root, # type: Any + context, # type: Optional[Any] + variables, # type: Optional[Any] + operation_name, # type: Optional[str] + executor, # type: Any + middleware, # type: Optional[Any] + allow_subscriptions, # type: bool **options # type: Any ): - # type: (...) -> Union[ExecutionResult, Promise[ExecutionResult]] - if root is None and "root_value" in options: warnings.warn( "root_value has been deprecated. Please use root=... instead.", @@ -107,10 +105,7 @@ def execute( ' of MiddlewareManager. Received "{}".'.format(middleware) ) - if executor is None: - executor = SyncExecutor() - - exe_context = ExecutionContext( + return ExecutionContext( schema, document_ast, root, @@ -122,15 +117,23 @@ def execute( allow_subscriptions, ) + +def get_promise_executor(exe_context, root): def promise_executor(v): # type: (Optional[Any]) -> Union[Dict, Promise[Dict], Observable] return execute_operation(exe_context, exe_context.operation, root) + return promise_executor + +def get_on_rejected(exe_context): def on_rejected(error): # type: (Exception) -> None exe_context.errors.append(error) return None + return on_rejected + +def get_on_resolve(exe_context): def on_resolve(data): # type: (Union[None, Dict, Observable]) -> Union[ExecutionResult, Observable] if isinstance(data, Observable): @@ -140,7 +143,43 @@ def on_resolve(data): return ExecutionResult(data=data) return ExecutionResult(data=data, errors=exe_context.errors) + return on_resolve + +def execute( + schema, # type: GraphQLSchema + document_ast, # type: Document + root=None, # type: Any + context=None, # type: Optional[Any] + variables=None, # type: Optional[Any] + operation_name=None, # type: Optional[str] + executor=None, # type: Any + return_promise=False, # type: bool + middleware=None, # type: Optional[Any] + allow_subscriptions=False, # type: bool + **options # type: Any +): + # type: (...) -> Union[ExecutionResult, Promise[ExecutionResult]] + + if executor is None: + executor = SyncExecutor() + + exe_context = prepare_execution_context( + schema, + document_ast, + root, + context, + variables, + operation_name, + executor, + middleware, + allow_subscriptions, + **options + ) + + promise_executor = get_promise_executor(exe_context, root) + on_rejected = get_on_rejected(exe_context) + on_resolve = get_on_resolve(exe_context) promise = ( Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve) ) @@ -156,6 +195,45 @@ def on_resolve(data): return promise +async def execute_async( + schema, # type: GraphQLSchema + document_ast, # type: Document + root=None, # type: Any + context=None, # type: Optional[Any] + variables=None, # type: Optional[Any] + operation_name=None, # type: Optional[str] + executor=None, # type: Any + middleware=None, # type: Optional[Any] + allow_subscriptions=False, # type: bool + **options # type: Any +): + # type: (...) -> Union[ExecutionResult] + if executor is None: + executor = AsyncioExecutor() + exe_context = prepare_execution_context( + schema, + document_ast, + root, + context, + variables, + operation_name, + executor, + middleware, + allow_subscriptions, + **options + ) + + promise_executor = get_promise_executor(exe_context, root) + on_rejected = get_on_rejected(exe_context) + on_resolve = get_on_resolve(exe_context) + promise = ( + Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve) + ) + + await exe_context.executor.wait_until_finished_async() + return promise.get() + + def execute_operation( exe_context, # type: ExecutionContext operation, # type: OperationDefinition diff --git a/graphql/execution/executors/asyncio.py b/graphql/execution/executors/asyncio.py index 7e014030..ffa876f8 100644 --- a/graphql/execution/executors/asyncio.py +++ b/graphql/execution/executors/asyncio.py @@ -4,6 +4,8 @@ from promise import Promise +from .base import BaseExecutor + # Necessary for static type checking if False: # flake8: noqa from asyncio.unix_events import _UnixSelectorEventLoop @@ -44,7 +46,7 @@ def asyncgen_to_observable(asyncgen, loop=None): pass -class AsyncioExecutor(object): +class AsyncioExecutor(BaseExecutor): def __init__(self, loop=None): # type: (Optional[_UnixSelectorEventLoop]) -> None if loop is None: @@ -53,13 +55,17 @@ def __init__(self, loop=None): self.futures = [] # type: List[Future] def wait_until_finished(self): + # type: () -> None + self.loop.run_until_complete(self.wait_until_finished_async()) + + async def wait_until_finished_async(self): # type: () -> None # if there are futures to wait for while self.futures: # wait for the futures to finish futures = self.futures self.futures = [] - self.loop.run_until_complete(wait(futures)) + await wait(futures) def clean(self): self.futures = [] diff --git a/graphql/execution/executors/base.py b/graphql/execution/executors/base.py new file mode 100644 index 00000000..1c604d7f --- /dev/null +++ b/graphql/execution/executors/base.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod + + +class BaseExecutor(ABC): + + @abstractmethod + def wait_until_finished(self): + pass + + @abstractmethod + async def wait_until_finished_async(self): + pass + + @abstractmethod + def clean(self): + pass diff --git a/graphql/execution/executors/gevent.py b/graphql/execution/executors/gevent.py index 4bc5ac4e..0af9973c 100644 --- a/graphql/execution/executors/gevent.py +++ b/graphql/execution/executors/gevent.py @@ -3,10 +3,11 @@ import gevent from promise import Promise +from .base import BaseExecutor from .utils import process -class GeventExecutor(object): +class GeventExecutor(BaseExecutor): def __init__(self): self.jobs = [] @@ -17,6 +18,9 @@ def wait_until_finished(self): self.jobs = [] [j.join() for j in jobs] + async def wait_until_finished_async(self): + raise NotImplementedError + def clean(self): self.jobs = [] diff --git a/graphql/execution/executors/process.py b/graphql/execution/executors/process.py index 948279ae..6481ec6d 100644 --- a/graphql/execution/executors/process.py +++ b/graphql/execution/executors/process.py @@ -2,6 +2,7 @@ from promise import Promise +from .base import BaseExecutor from .utils import process @@ -10,7 +11,7 @@ def queue_process(q): process(promise, fn, args, kwargs) -class ProcessExecutor(object): +class ProcessExecutor(BaseExecutor): def __init__(self): self.processes = [] self.q = Queue() @@ -23,6 +24,9 @@ def wait_until_finished(self): self.q.close() self.q.join_thread() + async def wait_until_finished_async(self): + raise NotImplementedError + def clean(self): self.processes = [] diff --git a/graphql/execution/executors/sync.py b/graphql/execution/executors/sync.py index c45d8a8f..fbe8c970 100644 --- a/graphql/execution/executors/sync.py +++ b/graphql/execution/executors/sync.py @@ -1,13 +1,18 @@ +from .base import BaseExecutor + # Necessary for static type checking if False: # flake8: noqa from typing import Any, Callable -class SyncExecutor(object): +class SyncExecutor(BaseExecutor): def wait_until_finished(self): # type: () -> None pass + async def wait_until_finished_async(self): + raise NotImplementedError + def clean(self): pass diff --git a/graphql/execution/executors/thread.py b/graphql/execution/executors/thread.py index f540a1a0..c57d4078 100644 --- a/graphql/execution/executors/thread.py +++ b/graphql/execution/executors/thread.py @@ -2,6 +2,8 @@ from threading import Thread from promise import Promise + +from .base import BaseExecutor from .utils import process # Necessary for static type checking @@ -9,7 +11,7 @@ from typing import Any, Callable, List -class ThreadExecutor(object): +class ThreadExecutor(BaseExecutor): pool = None @@ -30,6 +32,9 @@ def wait_until_finished(self): for thread in threads: thread.join() + async def wait_until_finished_async(self): + raise NotImplementedError + def clean(self): self.threads = [] diff --git a/graphql/execution/tests/test_executor_asyncio.py b/graphql/execution/tests/test_executor_asyncio.py index 714f59ee..a6000c7c 100644 --- a/graphql/execution/tests/test_executor_asyncio.py +++ b/graphql/execution/tests/test_executor_asyncio.py @@ -9,7 +9,7 @@ asyncio = pytest.importorskip("asyncio") from graphql.error import format_error -from graphql.execution import execute +from graphql.execution import execute, execute_async from graphql.language.parser import parse from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString @@ -49,6 +49,39 @@ def resolver_3(contest, *_): assert result.data == {"a": "hey", "b": "hey2", "c": "hey3"} +@pytest.mark.asyncio +async def test_asyncio_executor_exc_async(): + # type: () -> None + def resolver(context, *_): + # type: (Optional[Any], *ResolveInfo) -> str + asyncio.sleep(0.001) + return "hey" + + @asyncio.coroutine + def resolver_2(context, *_): + # type: (Optional[Any], *ResolveInfo) -> str + asyncio.sleep(0.003) + return "hey2" + + def resolver_3(contest, *_): + # type: (Optional[Any], *ResolveInfo) -> str + return "hey3" + + Type = GraphQLObjectType( + "Type", + { + "a": GraphQLField(GraphQLString, resolver=resolver), + "b": GraphQLField(GraphQLString, resolver=resolver_2), + "c": GraphQLField(GraphQLString, resolver=resolver_3), + }, + ) + + ast = parse("{ a b c }") + result = await execute_async(GraphQLSchema(Type), ast, executor=AsyncioExecutor()) + assert not result.errors + assert result.data == {"a": "hey", "b": "hey2", "c": "hey3"} + + def test_asyncio_executor_custom_loop(): # type: () -> None loop = asyncio.get_event_loop() @@ -87,14 +120,14 @@ def test_asyncio_executor_with_error(): # type: () -> None ast = parse("query Example { a, b }") - def resolver(context, *_): + async def resolver(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - asyncio.sleep(0.001) + await asyncio.sleep(0.001) return "hey" - def resolver_2(context, *_): + async def resolver_2(context, *_): # type: (Optional[Any], *ResolveInfo) -> NoReturn - asyncio.sleep(0.003) + await asyncio.sleep(0.003) raise Exception("resolver_2 failed!") Type = GraphQLObjectType( @@ -117,6 +150,41 @@ def resolver_2(context, *_): assert result.data == {"a": "hey", "b": None} +@pytest.mark.asyncio +async def test_asyncio_executor_with_error_exc_async(): + # type: () -> None + ast = parse("query Example { a, b }") + + async def resolver(context, *_): + # type: (Optional[Any], *ResolveInfo) -> str + await asyncio.sleep(0.001) + return "hey" + + async def resolver_2(context, *_): + # type: (Optional[Any], *ResolveInfo) -> NoReturn + await asyncio.sleep(0.003) + raise Exception("resolver_2 failed!") + + Type = GraphQLObjectType( + "Type", + { + "a": GraphQLField(GraphQLString, resolver=resolver), + "b": GraphQLField(GraphQLString, resolver=resolver_2), + }, + ) + + result = await execute_async(GraphQLSchema(Type), ast, executor=AsyncioExecutor()) + formatted_errors = list(map(format_error, result.errors)) + assert formatted_errors == [ + { + "locations": [{"line": 1, "column": 20}], + "path": ["b"], + "message": "resolver_2 failed!", + } + ] + assert result.data == {"a": "hey", "b": None} + + def test_evaluates_mutations_serially(): # type: () -> None assert_evaluate_mutations_serially(executor=AsyncioExecutor()) diff --git a/graphql/graphql.py b/graphql/graphql.py index 89ccf386..ab3a6eac 100644 --- a/graphql/graphql.py +++ b/graphql/graphql.py @@ -44,6 +44,11 @@ def graphql(*args, **kwargs): return execute_graphql(*args, **kwargs) +async def graphql_async(*args, **kwargs): + # type: (*Any, **Any) -> Union[ExecutionResult, Observable] + return await execute_graphql_async(*args, **kwargs) + + def execute_graphql( schema, # type: GraphQLSchema request_string="", # type: Union[Document, str] @@ -73,6 +78,35 @@ def execute_graphql( return ExecutionResult(errors=[e], invalid=True) +async def execute_graphql_async( + schema, # type: GraphQLSchema + request_string="", # type: Union[Document, str] + root=None, # type: Any + context=None, # type: Optional[Any] + variables=None, # type: Optional[Any] + operation_name=None, # type: Optional[Any] + middleware=None, # type: Optional[Any] + backend=None, # type: Optional[Any] + **execute_options # type: Any +): + # type: (...) -> Union[ExecutionResult, Observable, Promise[ExecutionResult]] + try: + if backend is None: + backend = get_default_backend() + + document = backend.document_from_string_async(schema, request_string) + return await document.execute( + root=root, + context=context, + operation_name=operation_name, + variables=variables, + middleware=middleware, + **execute_options + ) + except Exception as e: + return ExecutionResult(errors=[e], invalid=True) + + @promisify def execute_graphql_as_promise(*args, **kwargs): return execute_graphql(*args, **kwargs) diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index 5e3c2cd0..02b4fe28 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -1,10 +1,41 @@ -from graphql import graphql +import pytest + +from graphql import graphql, graphql_async from graphql.error import format_error from .starwars_schema import StarWarsSchema -def test_hero_name_query(): +@pytest.fixture(params=['sync', 'async']) +def execute_graphql(request): + async def _execute( + schema, + query, + variable_values=None + ): + if request.param == 'sync': + return graphql(schema, query, variable_values=variable_values) + else: + return await graphql_async(schema, query, variable_values=variable_values) + return _execute + + +@pytest.fixture +def execute_and_validate_result(execute_graphql): + async def _execute_and_validate( + schema, + query, + expected, + variable_values=None + ): + result = await execute_graphql(schema, query, variable_values=variable_values) + assert not result.errors + assert result.data == expected + return _execute_and_validate + + +@pytest.mark.asyncio +async def test_hero_name_query(execute_and_validate_result): query = """ query HeroNameQuery { hero { @@ -13,12 +44,11 @@ def test_hero_name_query(): } """ expected = {"hero": {"name": "R2-D2"}} - result = graphql(StarWarsSchema, query) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected) -def test_hero_name_and_friends_query(): +@pytest.mark.asyncio +async def test_hero_name_and_friends_query(execute_and_validate_result): query = """ query HeroNameAndFriendsQuery { hero { @@ -41,12 +71,11 @@ def test_hero_name_and_friends_query(): ], } } - result = graphql(StarWarsSchema, query) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected) -def test_nested_query(): +@pytest.mark.asyncio +async def test_nested_query(execute_and_validate_result): query = """ query NestedQuery { hero { @@ -97,12 +126,11 @@ def test_nested_query(): ], } } - result = graphql(StarWarsSchema, query) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected) -def test_fetch_luke_query(): +@pytest.mark.asyncio +async def test_fetch_luke_query(execute_and_validate_result): query = """ query FetchLukeQuery { human(id: "1000") { @@ -111,12 +139,11 @@ def test_fetch_luke_query(): } """ expected = {"human": {"name": "Luke Skywalker"}} - result = graphql(StarWarsSchema, query) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected) -def test_fetch_some_id_query(): +@pytest.mark.asyncio +async def test_fetch_some_id_query(execute_and_validate_result): query = """ query FetchSomeIDQuery($someId: String!) { human(id: $someId) { @@ -126,12 +153,11 @@ def test_fetch_some_id_query(): """ params = {"someId": "1000"} expected = {"human": {"name": "Luke Skywalker"}} - result = graphql(StarWarsSchema, query, variable_values=params) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) -def test_fetch_some_id_query2(): +@pytest.mark.asyncio +async def test_fetch_some_id_query2(execute_and_validate_result): query = """ query FetchSomeIDQuery($someId: String!) { human(id: $someId) { @@ -141,12 +167,11 @@ def test_fetch_some_id_query2(): """ params = {"someId": "1002"} expected = {"human": {"name": "Han Solo"}} - result = graphql(StarWarsSchema, query, variable_values=params) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) -def test_invalid_id_query(): +@pytest.mark.asyncio +async def test_invalid_id_query(execute_and_validate_result): query = """ query humanQuery($id: String!) { human(id: $id) { @@ -156,12 +181,11 @@ def test_invalid_id_query(): """ params = {"id": "not a valid id"} expected = {"human": None} - result = graphql(StarWarsSchema, query, variable_values=params) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) -def test_fetch_luke_aliased(): +@pytest.mark.asyncio +async def test_fetch_luke_aliased(execute_and_validate_result): query = """ query FetchLukeAliased { luke: human(id: "1000") { @@ -170,12 +194,11 @@ def test_fetch_luke_aliased(): } """ expected = {"luke": {"name": "Luke Skywalker"}} - result = graphql(StarWarsSchema, query) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected) -def test_fetch_luke_and_leia_aliased(): +@pytest.mark.asyncio +async def test_fetch_luke_and_leia_aliased(execute_and_validate_result): query = """ query FetchLukeAndLeiaAliased { luke: human(id: "1000") { @@ -187,12 +210,11 @@ def test_fetch_luke_and_leia_aliased(): } """ expected = {"luke": {"name": "Luke Skywalker"}, "leia": {"name": "Leia Organa"}} - result = graphql(StarWarsSchema, query) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected) -def test_duplicate_fields(): +@pytest.mark.asyncio +async def test_duplicate_fields(execute_and_validate_result): query = """ query DuplicateFields { luke: human(id: "1000") { @@ -209,12 +231,11 @@ def test_duplicate_fields(): "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"}, } - result = graphql(StarWarsSchema, query) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected) -def test_use_fragment(): +@pytest.mark.asyncio +async def test_use_fragment(execute_and_validate_result): query = """ query UseFragment { luke: human(id: "1000") { @@ -233,12 +254,11 @@ def test_use_fragment(): "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"}, } - result = graphql(StarWarsSchema, query) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected) -def test_check_type_of_r2(): +@pytest.mark.asyncio +async def test_check_type_of_r2(execute_and_validate_result): query = """ query CheckTypeOfR2 { hero { @@ -248,12 +268,11 @@ def test_check_type_of_r2(): } """ expected = {"hero": {"__typename": "Droid", "name": "R2-D2"}} - result = graphql(StarWarsSchema, query) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected) -def test_check_type_of_luke(): +@pytest.mark.asyncio +async def test_check_type_of_luke(execute_and_validate_result): query = """ query CheckTypeOfLuke { hero(episode: EMPIRE) { @@ -263,16 +282,15 @@ def test_check_type_of_luke(): } """ expected = {"hero": {"__typename": "Human", "name": "Luke Skywalker"}} - result = graphql(StarWarsSchema, query) - assert not result.errors - assert result.data == expected + await execute_and_validate_result(StarWarsSchema, query, expected) -def test_parse_error(): +@pytest.mark.asyncio +async def test_parse_error(execute_graphql): query = """ qeury """ - result = graphql(StarWarsSchema, query) + result = await execute_graphql(StarWarsSchema, query) assert result.invalid formatted_error = format_error(result.errors[0]) assert formatted_error["locations"] == [{"column": 9, "line": 2}] From da44c2c595995eeeb69ea30fd37cd4a277a09c8d Mon Sep 17 00:00:00 2001 From: voith Date: Mon, 12 Aug 2019 06:13:05 +0530 Subject: [PATCH 2/6] replace await by yield from --- graphql/backend/core.py | 10 ++- graphql/execution/executor.py | 6 +- graphql/execution/executors/asyncio.py | 7 +- graphql/execution/executors/base.py | 2 +- graphql/execution/executors/gevent.py | 2 +- graphql/execution/executors/process.py | 2 +- graphql/execution/executors/sync.py | 2 +- graphql/execution/executors/thread.py | 2 +- .../execution/tests/test_executor_asyncio.py | 47 ++++++---- graphql/graphql.py | 14 ++- tests/starwars/test_query.py | 87 ++++++++++++------- tox.ini | 1 + 12 files changed, 115 insertions(+), 67 deletions(-) diff --git a/graphql/backend/core.py b/graphql/backend/core.py index f59fe730..ca87828a 100644 --- a/graphql/backend/core.py +++ b/graphql/backend/core.py @@ -1,3 +1,5 @@ +import asyncio + from functools import partial from six import string_types @@ -43,7 +45,8 @@ def execute_and_validate( return execute(schema, document_ast, *args, **kwargs) -async def execute_and_validate_async( +@asyncio.coroutine +def execute_and_validate_async( schema, # type: GraphQLSchema document_ast, # type: Document *args, # type: Any @@ -53,7 +56,8 @@ async def execute_and_validate_async( execution_result = _validate_document_ast(schema, document_ast, **kwargs) if execution_result: return execution_result - return await execute_async(schema, document_ast, *args, **kwargs) + result = yield from execute_async(schema, document_ast, *args, **kwargs) + return result class GraphQLCoreBackend(GraphQLBackend): @@ -66,7 +70,7 @@ def __init__(self, executor=None): @staticmethod def _get_doc_str_and_ast(document_string): - # type: (Union[Document, str] -> (str, ast.Document) + # type: (Union[Document, str]) -> (str, ast.Document) if isinstance(document_string, ast.Document): document_ast = document_string document_string = print_ast(document_ast) diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 7014b59a..a6290400 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -1,3 +1,4 @@ +import asyncio import collections try: @@ -195,7 +196,8 @@ def execute( return promise -async def execute_async( +@asyncio.coroutine +def execute_async( schema, # type: GraphQLSchema document_ast, # type: Document root=None, # type: Any @@ -230,7 +232,7 @@ async def execute_async( Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve) ) - await exe_context.executor.wait_until_finished_async() + yield from exe_context.executor.wait_until_finished_async() return promise.get() diff --git a/graphql/execution/executors/asyncio.py b/graphql/execution/executors/asyncio.py index ffa876f8..53986771 100644 --- a/graphql/execution/executors/asyncio.py +++ b/graphql/execution/executors/asyncio.py @@ -1,6 +1,6 @@ from __future__ import absolute_import -from asyncio import Future, get_event_loop, iscoroutine, wait +from asyncio import Future, get_event_loop, iscoroutine, wait, coroutine from promise import Promise @@ -58,14 +58,15 @@ def wait_until_finished(self): # type: () -> None self.loop.run_until_complete(self.wait_until_finished_async()) - async def wait_until_finished_async(self): + @coroutine + def wait_until_finished_async(self): # type: () -> None # if there are futures to wait for while self.futures: # wait for the futures to finish futures = self.futures self.futures = [] - await wait(futures) + yield from wait(futures) def clean(self): self.futures = [] diff --git a/graphql/execution/executors/base.py b/graphql/execution/executors/base.py index 1c604d7f..ef3dd82d 100644 --- a/graphql/execution/executors/base.py +++ b/graphql/execution/executors/base.py @@ -8,7 +8,7 @@ def wait_until_finished(self): pass @abstractmethod - async def wait_until_finished_async(self): + def wait_until_finished_async(self): pass @abstractmethod diff --git a/graphql/execution/executors/gevent.py b/graphql/execution/executors/gevent.py index 0af9973c..cb5ab836 100644 --- a/graphql/execution/executors/gevent.py +++ b/graphql/execution/executors/gevent.py @@ -18,7 +18,7 @@ def wait_until_finished(self): self.jobs = [] [j.join() for j in jobs] - async def wait_until_finished_async(self): + def wait_until_finished_async(self): raise NotImplementedError def clean(self): diff --git a/graphql/execution/executors/process.py b/graphql/execution/executors/process.py index 6481ec6d..f3daaaa8 100644 --- a/graphql/execution/executors/process.py +++ b/graphql/execution/executors/process.py @@ -24,7 +24,7 @@ def wait_until_finished(self): self.q.close() self.q.join_thread() - async def wait_until_finished_async(self): + def wait_until_finished_async(self): raise NotImplementedError def clean(self): diff --git a/graphql/execution/executors/sync.py b/graphql/execution/executors/sync.py index fbe8c970..0cc27067 100644 --- a/graphql/execution/executors/sync.py +++ b/graphql/execution/executors/sync.py @@ -10,7 +10,7 @@ def wait_until_finished(self): # type: () -> None pass - async def wait_until_finished_async(self): + def wait_until_finished_async(self): raise NotImplementedError def clean(self): diff --git a/graphql/execution/executors/thread.py b/graphql/execution/executors/thread.py index c57d4078..c15ccc5f 100644 --- a/graphql/execution/executors/thread.py +++ b/graphql/execution/executors/thread.py @@ -32,7 +32,7 @@ def wait_until_finished(self): for thread in threads: thread.join() - async def wait_until_finished_async(self): + def wait_until_finished_async(self): raise NotImplementedError def clean(self): diff --git a/graphql/execution/tests/test_executor_asyncio.py b/graphql/execution/tests/test_executor_asyncio.py index a6000c7c..c2379300 100644 --- a/graphql/execution/tests/test_executor_asyncio.py +++ b/graphql/execution/tests/test_executor_asyncio.py @@ -19,15 +19,17 @@ def test_asyncio_executor(): # type: () -> None + + @asyncio.coroutine def resolver(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - asyncio.sleep(0.001) + yield from asyncio.sleep(0.001) return "hey" @asyncio.coroutine def resolver_2(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - asyncio.sleep(0.003) + yield from asyncio.sleep(0.003) return "hey2" def resolver_3(contest, *_): @@ -50,17 +52,20 @@ def resolver_3(contest, *_): @pytest.mark.asyncio -async def test_asyncio_executor_exc_async(): +@asyncio.coroutine +def test_asyncio_executor_exc_async(): # type: () -> None + + @asyncio.coroutine def resolver(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - asyncio.sleep(0.001) + yield from asyncio.sleep(0.001) return "hey" @asyncio.coroutine def resolver_2(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - asyncio.sleep(0.003) + yield from asyncio.sleep(0.003) return "hey2" def resolver_3(contest, *_): @@ -77,7 +82,7 @@ def resolver_3(contest, *_): ) ast = parse("{ a b c }") - result = await execute_async(GraphQLSchema(Type), ast, executor=AsyncioExecutor()) + result = yield from execute_async(GraphQLSchema(Type), ast, executor=AsyncioExecutor()) assert not result.errors assert result.data == {"a": "hey", "b": "hey2", "c": "hey3"} @@ -86,15 +91,16 @@ def test_asyncio_executor_custom_loop(): # type: () -> None loop = asyncio.get_event_loop() + @asyncio.coroutine def resolver(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - asyncio.sleep(0.001, loop=loop) + yield from asyncio.sleep(0.001, loop=loop) return "hey" @asyncio.coroutine def resolver_2(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - asyncio.sleep(0.003, loop=loop) + yield from asyncio.sleep(0.003, loop=loop) return "hey2" def resolver_3(contest, *_): @@ -120,14 +126,16 @@ def test_asyncio_executor_with_error(): # type: () -> None ast = parse("query Example { a, b }") - async def resolver(context, *_): + @asyncio.coroutine + def resolver(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - await asyncio.sleep(0.001) + yield from asyncio.sleep(0.001) return "hey" - async def resolver_2(context, *_): + @asyncio.coroutine + def resolver_2(context, *_): # type: (Optional[Any], *ResolveInfo) -> NoReturn - await asyncio.sleep(0.003) + yield from asyncio.sleep(0.003) raise Exception("resolver_2 failed!") Type = GraphQLObjectType( @@ -151,18 +159,21 @@ async def resolver_2(context, *_): @pytest.mark.asyncio -async def test_asyncio_executor_with_error_exc_async(): +@asyncio.coroutine +def test_asyncio_executor_with_error_exc_async(): # type: () -> None ast = parse("query Example { a, b }") - async def resolver(context, *_): + @asyncio.coroutine + def resolver(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - await asyncio.sleep(0.001) + yield from asyncio.sleep(0.001) return "hey" - async def resolver_2(context, *_): + @asyncio.coroutine + def resolver_2(context, *_): # type: (Optional[Any], *ResolveInfo) -> NoReturn - await asyncio.sleep(0.003) + yield from asyncio.sleep(0.003) raise Exception("resolver_2 failed!") Type = GraphQLObjectType( @@ -173,7 +184,7 @@ async def resolver_2(context, *_): }, ) - result = await execute_async(GraphQLSchema(Type), ast, executor=AsyncioExecutor()) + result = yield from execute_async(GraphQLSchema(Type), ast, executor=AsyncioExecutor()) formatted_errors = list(map(format_error, result.errors)) assert formatted_errors == [ { diff --git a/graphql/graphql.py b/graphql/graphql.py index ab3a6eac..0fcc41c3 100644 --- a/graphql/graphql.py +++ b/graphql/graphql.py @@ -1,3 +1,5 @@ +import asyncio + from .execution import ExecutionResult from .backend import get_default_backend @@ -44,9 +46,11 @@ def graphql(*args, **kwargs): return execute_graphql(*args, **kwargs) -async def graphql_async(*args, **kwargs): +@asyncio.coroutine +def graphql_async(*args, **kwargs): # type: (*Any, **Any) -> Union[ExecutionResult, Observable] - return await execute_graphql_async(*args, **kwargs) + result = yield from execute_graphql_async(*args, **kwargs) + return result def execute_graphql( @@ -78,7 +82,8 @@ def execute_graphql( return ExecutionResult(errors=[e], invalid=True) -async def execute_graphql_async( +@asyncio.coroutine +def execute_graphql_async( schema, # type: GraphQLSchema request_string="", # type: Union[Document, str] root=None, # type: Any @@ -95,7 +100,7 @@ async def execute_graphql_async( backend = get_default_backend() document = backend.document_from_string_async(schema, request_string) - return await document.execute( + result = yield from document.execute( root=root, context=context, operation_name=operation_name, @@ -103,6 +108,7 @@ async def execute_graphql_async( middleware=middleware, **execute_options ) + return result except Exception as e: return ExecutionResult(errors=[e], invalid=True) diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index 02b4fe28..a1f3e2c0 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -1,5 +1,10 @@ +# type: ignore +# flake8: noqa + import pytest +asyncio = pytest.importorskip("asyncio") + from graphql import graphql, graphql_async from graphql.error import format_error @@ -8,7 +13,8 @@ @pytest.fixture(params=['sync', 'async']) def execute_graphql(request): - async def _execute( + @asyncio.coroutine + def _execute( schema, query, variable_values=None @@ -16,26 +22,29 @@ async def _execute( if request.param == 'sync': return graphql(schema, query, variable_values=variable_values) else: - return await graphql_async(schema, query, variable_values=variable_values) + result = yield from graphql_async(schema, query, variable_values=variable_values) + return result return _execute @pytest.fixture def execute_and_validate_result(execute_graphql): - async def _execute_and_validate( + @asyncio.coroutine + def _execute_and_validate( schema, query, expected, variable_values=None ): - result = await execute_graphql(schema, query, variable_values=variable_values) + result = yield from execute_graphql(schema, query, variable_values=variable_values) assert not result.errors assert result.data == expected return _execute_and_validate @pytest.mark.asyncio -async def test_hero_name_query(execute_and_validate_result): +@asyncio.coroutine +def test_hero_name_query(execute_and_validate_result): query = """ query HeroNameQuery { hero { @@ -44,11 +53,12 @@ async def test_hero_name_query(execute_and_validate_result): } """ expected = {"hero": {"name": "R2-D2"}} - await execute_and_validate_result(StarWarsSchema, query, expected) + yield from execute_and_validate_result(StarWarsSchema, query, expected) @pytest.mark.asyncio -async def test_hero_name_and_friends_query(execute_and_validate_result): +@asyncio.coroutine +def test_hero_name_and_friends_query(execute_and_validate_result): query = """ query HeroNameAndFriendsQuery { hero { @@ -71,11 +81,12 @@ async def test_hero_name_and_friends_query(execute_and_validate_result): ], } } - await execute_and_validate_result(StarWarsSchema, query, expected) + yield from execute_and_validate_result(StarWarsSchema, query, expected) @pytest.mark.asyncio -async def test_nested_query(execute_and_validate_result): +@asyncio.coroutine +def test_nested_query(execute_and_validate_result): query = """ query NestedQuery { hero { @@ -126,11 +137,12 @@ async def test_nested_query(execute_and_validate_result): ], } } - await execute_and_validate_result(StarWarsSchema, query, expected) + yield from execute_and_validate_result(StarWarsSchema, query, expected) @pytest.mark.asyncio -async def test_fetch_luke_query(execute_and_validate_result): +@asyncio.coroutine +def test_fetch_luke_query(execute_and_validate_result): query = """ query FetchLukeQuery { human(id: "1000") { @@ -139,11 +151,12 @@ async def test_fetch_luke_query(execute_and_validate_result): } """ expected = {"human": {"name": "Luke Skywalker"}} - await execute_and_validate_result(StarWarsSchema, query, expected) + yield from execute_and_validate_result(StarWarsSchema, query, expected) @pytest.mark.asyncio -async def test_fetch_some_id_query(execute_and_validate_result): +@asyncio.coroutine +def test_fetch_some_id_query(execute_and_validate_result): query = """ query FetchSomeIDQuery($someId: String!) { human(id: $someId) { @@ -153,11 +166,12 @@ async def test_fetch_some_id_query(execute_and_validate_result): """ params = {"someId": "1000"} expected = {"human": {"name": "Luke Skywalker"}} - await execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) + yield from execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) @pytest.mark.asyncio -async def test_fetch_some_id_query2(execute_and_validate_result): +@asyncio.coroutine +def test_fetch_some_id_query2(execute_and_validate_result): query = """ query FetchSomeIDQuery($someId: String!) { human(id: $someId) { @@ -167,11 +181,12 @@ async def test_fetch_some_id_query2(execute_and_validate_result): """ params = {"someId": "1002"} expected = {"human": {"name": "Han Solo"}} - await execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) + yield from execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) @pytest.mark.asyncio -async def test_invalid_id_query(execute_and_validate_result): +@asyncio.coroutine +def test_invalid_id_query(execute_and_validate_result): query = """ query humanQuery($id: String!) { human(id: $id) { @@ -181,11 +196,12 @@ async def test_invalid_id_query(execute_and_validate_result): """ params = {"id": "not a valid id"} expected = {"human": None} - await execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) + yield from execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) @pytest.mark.asyncio -async def test_fetch_luke_aliased(execute_and_validate_result): +@asyncio.coroutine +def test_fetch_luke_aliased(execute_and_validate_result): query = """ query FetchLukeAliased { luke: human(id: "1000") { @@ -194,11 +210,12 @@ async def test_fetch_luke_aliased(execute_and_validate_result): } """ expected = {"luke": {"name": "Luke Skywalker"}} - await execute_and_validate_result(StarWarsSchema, query, expected) + yield from execute_and_validate_result(StarWarsSchema, query, expected) @pytest.mark.asyncio -async def test_fetch_luke_and_leia_aliased(execute_and_validate_result): +@asyncio.coroutine +def test_fetch_luke_and_leia_aliased(execute_and_validate_result): query = """ query FetchLukeAndLeiaAliased { luke: human(id: "1000") { @@ -210,11 +227,12 @@ async def test_fetch_luke_and_leia_aliased(execute_and_validate_result): } """ expected = {"luke": {"name": "Luke Skywalker"}, "leia": {"name": "Leia Organa"}} - await execute_and_validate_result(StarWarsSchema, query, expected) + yield from execute_and_validate_result(StarWarsSchema, query, expected) @pytest.mark.asyncio -async def test_duplicate_fields(execute_and_validate_result): +@asyncio.coroutine +def test_duplicate_fields(execute_and_validate_result): query = """ query DuplicateFields { luke: human(id: "1000") { @@ -231,11 +249,12 @@ async def test_duplicate_fields(execute_and_validate_result): "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"}, } - await execute_and_validate_result(StarWarsSchema, query, expected) + yield from execute_and_validate_result(StarWarsSchema, query, expected) @pytest.mark.asyncio -async def test_use_fragment(execute_and_validate_result): +@asyncio.coroutine +def test_use_fragment(execute_and_validate_result): query = """ query UseFragment { luke: human(id: "1000") { @@ -254,11 +273,12 @@ async def test_use_fragment(execute_and_validate_result): "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"}, } - await execute_and_validate_result(StarWarsSchema, query, expected) + yield from execute_and_validate_result(StarWarsSchema, query, expected) @pytest.mark.asyncio -async def test_check_type_of_r2(execute_and_validate_result): +@asyncio.coroutine +def test_check_type_of_r2(execute_and_validate_result): query = """ query CheckTypeOfR2 { hero { @@ -268,11 +288,12 @@ async def test_check_type_of_r2(execute_and_validate_result): } """ expected = {"hero": {"__typename": "Droid", "name": "R2-D2"}} - await execute_and_validate_result(StarWarsSchema, query, expected) + yield from execute_and_validate_result(StarWarsSchema, query, expected) @pytest.mark.asyncio -async def test_check_type_of_luke(execute_and_validate_result): +@asyncio.coroutine +def test_check_type_of_luke(execute_and_validate_result): query = """ query CheckTypeOfLuke { hero(episode: EMPIRE) { @@ -282,15 +303,16 @@ async def test_check_type_of_luke(execute_and_validate_result): } """ expected = {"hero": {"__typename": "Human", "name": "Luke Skywalker"}} - await execute_and_validate_result(StarWarsSchema, query, expected) + yield from execute_and_validate_result(StarWarsSchema, query, expected) @pytest.mark.asyncio -async def test_parse_error(execute_graphql): +@asyncio.coroutine +def test_parse_error(execute_graphql): query = """ qeury """ - result = await execute_graphql(StarWarsSchema, query) + result = yield from execute_graphql(StarWarsSchema, query) assert result.invalid formatted_error = format_error(result.errors[0]) assert formatted_error["locations"] == [{"column": 9, "line": 2}] @@ -299,3 +321,4 @@ async def test_parse_error(execute_graphql): in formatted_error["message"] ) assert result.data is None + diff --git a/tox.ini b/tox.ini index 91c49c26..7bc27ddb 100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,7 @@ deps = six>=1.10.0 pytest-mock pytest-benchmark + py{35,36,37}: pytest-asyncio commands = py{27,34,py}: py.test graphql tests {posargs} py{35,36,37}: py.test graphql tests tests_py35 {posargs} From 38ebbd9475bcc4e2a7978a9e20a8bb1581569e36 Mon Sep 17 00:00:00 2001 From: voith Date: Mon, 12 Aug 2019 06:35:12 +0530 Subject: [PATCH 3/6] fixed mypy errors --- graphql/backend/core.py | 8 +++++--- graphql/execution/executor.py | 4 ++-- graphql/execution/executors/asyncio.py | 4 ++-- graphql/graphql.py | 6 +++--- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/graphql/backend/core.py b/graphql/backend/core.py index ca87828a..1f745e24 100644 --- a/graphql/backend/core.py +++ b/graphql/backend/core.py @@ -12,7 +12,7 @@ # Necessary for static type checking if False: # flake8: noqa - from typing import Any, Optional, Union + from typing import Any, Optional, Union, Tuple from ..language.ast import Document from ..type.schema import GraphQLSchema from rx import Observable @@ -29,6 +29,7 @@ def _validate_document_ast( validation_errors = validate(schema, document_ast) if validation_errors: return ExecutionResult(errors=validation_errors, invalid=True) + return None def execute_and_validate( @@ -70,7 +71,7 @@ def __init__(self, executor=None): @staticmethod def _get_doc_str_and_ast(document_string): - # type: (Union[Document, str]) -> (str, ast.Document) + # type: (Union[ast.Document, str]) -> Tuple[str, ast.Document] if isinstance(document_string, ast.Document): document_ast = document_string document_string = print_ast(document_ast) @@ -82,7 +83,7 @@ def _get_doc_str_and_ast(document_string): return document_string, document_ast def document_from_string(self, schema, document_string): - # type: (GraphQLSchema, Union[Document, str]) -> GraphQLDocument + # type: (GraphQLSchema, Union[ast.Document, str]) -> GraphQLDocument document_string, document_ast = self._get_doc_str_and_ast(document_string) return GraphQLDocument( schema=schema, @@ -94,6 +95,7 @@ def document_from_string(self, schema, document_string): ) def document_from_string_async(self, schema, document_string): + # type: (GraphQLSchema, Union[ast.Document, str]) -> GraphQLDocument document_string, document_ast = self._get_doc_str_and_ast(document_string) return GraphQLDocument( schema=schema, diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index a6290400..9bbc7a64 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -44,7 +44,7 @@ # Necessary for static type checking if False: # flake8: noqa - from typing import Any, Optional, Union, Dict, List, Callable + from typing import Any, Optional, Union, Dict, List, Callable, Generator from ..language.ast import Document, OperationDefinition, Field logger = logging.getLogger(__name__) @@ -209,7 +209,7 @@ def execute_async( allow_subscriptions=False, # type: bool **options # type: Any ): - # type: (...) -> Union[ExecutionResult] + # type: (...) -> Generator if executor is None: executor = AsyncioExecutor() exe_context = prepare_execution_context( diff --git a/graphql/execution/executors/asyncio.py b/graphql/execution/executors/asyncio.py index 53986771..7f0abbb4 100644 --- a/graphql/execution/executors/asyncio.py +++ b/graphql/execution/executors/asyncio.py @@ -9,7 +9,7 @@ # Necessary for static type checking if False: # flake8: noqa from asyncio.unix_events import _UnixSelectorEventLoop - from typing import Optional, Any, Callable, List + from typing import Optional, Any, Callable, List, Generator try: from asyncio import ensure_future @@ -60,7 +60,7 @@ def wait_until_finished(self): @coroutine def wait_until_finished_async(self): - # type: () -> None + # type: () -> Generator # if there are futures to wait for while self.futures: # wait for the futures to finish diff --git a/graphql/graphql.py b/graphql/graphql.py index 0fcc41c3..fe7c57dd 100644 --- a/graphql/graphql.py +++ b/graphql/graphql.py @@ -9,7 +9,7 @@ if False: # flake8: noqa from promise import Promise from rx import Observable - from typing import Any, Union, Optional + from typing import Any, Union, Optional, Generator from .language.ast import Document from .type.schema import GraphQLSchema @@ -48,7 +48,7 @@ def graphql(*args, **kwargs): @asyncio.coroutine def graphql_async(*args, **kwargs): - # type: (*Any, **Any) -> Union[ExecutionResult, Observable] + # type: (*Any, **Any) -> Generator result = yield from execute_graphql_async(*args, **kwargs) return result @@ -94,7 +94,7 @@ def execute_graphql_async( backend=None, # type: Optional[Any] **execute_options # type: Any ): - # type: (...) -> Union[ExecutionResult, Observable, Promise[ExecutionResult]] + # type: (...) -> Generator try: if backend is None: backend = get_default_backend() From a5cebfc5590aa83bd29771f497711b4e0668896d Mon Sep 17 00:00:00 2001 From: voith Date: Mon, 12 Aug 2019 06:42:03 +0530 Subject: [PATCH 4/6] fix pre-commit errors --- graphql/backend/core.py | 6 +-- graphql/execution/executor.py | 3 ++ graphql/execution/executors/base.py | 1 - .../execution/tests/test_executor_asyncio.py | 8 +++- tests/starwars/test_query.py | 40 ++++++++++--------- 5 files changed, 33 insertions(+), 25 deletions(-) diff --git a/graphql/backend/core.py b/graphql/backend/core.py index 1f745e24..4ee2d07f 100644 --- a/graphql/backend/core.py +++ b/graphql/backend/core.py @@ -19,9 +19,9 @@ def _validate_document_ast( - schema, # type: GraphQLSchema - document_ast, # type: Document - **kwargs # type: Any + schema, # type: GraphQLSchema + document_ast, # type: Document + **kwargs # type: Any ): # type: (...) -> Union[ExecutionResult, None] do_validation = kwargs.get("validate", True) diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 9bbc7a64..67ae9ce8 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -123,6 +123,7 @@ def get_promise_executor(exe_context, root): def promise_executor(v): # type: (Optional[Any]) -> Union[Dict, Promise[Dict], Observable] return execute_operation(exe_context, exe_context.operation, root) + return promise_executor @@ -131,6 +132,7 @@ def on_rejected(error): # type: (Exception) -> None exe_context.errors.append(error) return None + return on_rejected @@ -144,6 +146,7 @@ def on_resolve(data): return ExecutionResult(data=data) return ExecutionResult(data=data, errors=exe_context.errors) + return on_resolve diff --git a/graphql/execution/executors/base.py b/graphql/execution/executors/base.py index ef3dd82d..dc33f2b5 100644 --- a/graphql/execution/executors/base.py +++ b/graphql/execution/executors/base.py @@ -2,7 +2,6 @@ class BaseExecutor(ABC): - @abstractmethod def wait_until_finished(self): pass diff --git a/graphql/execution/tests/test_executor_asyncio.py b/graphql/execution/tests/test_executor_asyncio.py index c2379300..8dbc6a3e 100644 --- a/graphql/execution/tests/test_executor_asyncio.py +++ b/graphql/execution/tests/test_executor_asyncio.py @@ -82,7 +82,9 @@ def resolver_3(contest, *_): ) ast = parse("{ a b c }") - result = yield from execute_async(GraphQLSchema(Type), ast, executor=AsyncioExecutor()) + result = yield from execute_async( + GraphQLSchema(Type), ast, executor=AsyncioExecutor() + ) assert not result.errors assert result.data == {"a": "hey", "b": "hey2", "c": "hey3"} @@ -184,7 +186,9 @@ def resolver_2(context, *_): }, ) - result = yield from execute_async(GraphQLSchema(Type), ast, executor=AsyncioExecutor()) + result = yield from execute_async( + GraphQLSchema(Type), ast, executor=AsyncioExecutor() + ) formatted_errors = list(map(format_error, result.errors)) assert formatted_errors == [ { diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index a1f3e2c0..2d72b224 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -11,34 +11,31 @@ from .starwars_schema import StarWarsSchema -@pytest.fixture(params=['sync', 'async']) +@pytest.fixture(params=["sync", "async"]) def execute_graphql(request): @asyncio.coroutine - def _execute( - schema, - query, - variable_values=None - ): - if request.param == 'sync': + def _execute(schema, query, variable_values=None): + if request.param == "sync": return graphql(schema, query, variable_values=variable_values) else: - result = yield from graphql_async(schema, query, variable_values=variable_values) + result = yield from graphql_async( + schema, query, variable_values=variable_values + ) return result + return _execute @pytest.fixture def execute_and_validate_result(execute_graphql): @asyncio.coroutine - def _execute_and_validate( - schema, - query, - expected, - variable_values=None - ): - result = yield from execute_graphql(schema, query, variable_values=variable_values) + def _execute_and_validate(schema, query, expected, variable_values=None): + result = yield from execute_graphql( + schema, query, variable_values=variable_values + ) assert not result.errors assert result.data == expected + return _execute_and_validate @@ -166,7 +163,9 @@ def test_fetch_some_id_query(execute_and_validate_result): """ params = {"someId": "1000"} expected = {"human": {"name": "Luke Skywalker"}} - yield from execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) + yield from execute_and_validate_result( + StarWarsSchema, query, expected, variable_values=params + ) @pytest.mark.asyncio @@ -181,7 +180,9 @@ def test_fetch_some_id_query2(execute_and_validate_result): """ params = {"someId": "1002"} expected = {"human": {"name": "Han Solo"}} - yield from execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) + yield from execute_and_validate_result( + StarWarsSchema, query, expected, variable_values=params + ) @pytest.mark.asyncio @@ -196,7 +197,9 @@ def test_invalid_id_query(execute_and_validate_result): """ params = {"id": "not a valid id"} expected = {"human": None} - yield from execute_and_validate_result(StarWarsSchema, query, expected, variable_values=params) + yield from execute_and_validate_result( + StarWarsSchema, query, expected, variable_values=params + ) @pytest.mark.asyncio @@ -321,4 +324,3 @@ def test_parse_error(execute_graphql): in formatted_error["message"] ) assert result.data is None - From 31eef8b8c92be1c10af2fb1fca6011b91f9c2817 Mon Sep 17 00:00:00 2001 From: voith Date: Mon, 12 Aug 2019 16:03:32 +0530 Subject: [PATCH 5/6] make code compatible with python2.7 --- graphql/__init__.py | 9 +- graphql/backend/async_util.py | 22 + graphql/backend/core.py | 45 +- graphql/backend/utils.py | 23 + graphql/execution/__init__.py | 9 +- graphql/execution/common.py | 723 +++++++++++++++++ graphql/execution/executor.py | 757 +----------------- graphql/execution/executor_async.py | 64 ++ graphql/execution/executors/base.py | 5 +- graphql/execution/tests/test_benchmark.py | 2 +- graphql/execution/tests/test_executor.py | 2 +- .../execution/tests/test_executor_asyncio.py | 97 +-- graphql/graphql.py | 41 - graphql/graphql_async.py | 51 ++ tests/starwars/test_query.py | 153 ++-- 15 files changed, 987 insertions(+), 1016 deletions(-) create mode 100644 graphql/backend/async_util.py create mode 100644 graphql/backend/utils.py create mode 100644 graphql/execution/common.py create mode 100644 graphql/execution/executor_async.py create mode 100644 graphql/graphql_async.py diff --git a/graphql/__init__.py b/graphql/__init__.py index b941dff0..63ac0da4 100644 --- a/graphql/__init__.py +++ b/graphql/__init__.py @@ -21,10 +21,17 @@ from graphql import parse from graphql.language.base import parse """ +import sys + from .pyutils.version import get_version # The primary entry point into fulfilling a GraphQL request. -from .graphql import graphql, graphql_async +from .graphql import graphql +if sys.version_info > (3, 3): + from .graphql_async import graphql_async +else: + def graphql_async(*args, **kwargs): + raise ImportError('graphql_async needs python>=3.4') # Create and operate on GraphQL type definitions and schema. from .type import ( # no import order diff --git a/graphql/backend/async_util.py b/graphql/backend/async_util.py new file mode 100644 index 00000000..c30631d1 --- /dev/null +++ b/graphql/backend/async_util.py @@ -0,0 +1,22 @@ +try: + import asyncio +except ImportError: + asyncio = None + +from .utils import validate_document_ast +from ..execution import execute_async + + +@asyncio.coroutine +def execute_and_validate_async( + schema, # type: GraphQLSchema + document_ast, # type: Document + *args, # type: Any + **kwargs # type: Any +): + # type: (...) -> Union[ExecutionResult, Observable] + execution_result = validate_document_ast(schema, document_ast, **kwargs) + if execution_result: + return execution_result + result = yield from execute_async(schema, document_ast, *args, **kwargs) + return result diff --git a/graphql/backend/core.py b/graphql/backend/core.py index 4ee2d07f..6ce4c277 100644 --- a/graphql/backend/core.py +++ b/graphql/backend/core.py @@ -1,15 +1,21 @@ -import asyncio - +import sys from functools import partial from six import string_types -from ..execution import execute, execute_async, ExecutionResult +from .utils import validate_document_ast +from ..execution import execute, ExecutionResult from ..language.base import parse, print_ast from ..language import ast -from ..validation import validate + from .base import GraphQLBackend, GraphQLDocument +if sys.version_info > (3, 3): + from .async_util import execute_and_validate_async +else: + def execute_and_validate_async(*args, **kwargs): + raise ImportError('execute_and_validate_async needs python>=3.4') + # Necessary for static type checking if False: # flake8: noqa from typing import Any, Optional, Union, Tuple @@ -18,20 +24,6 @@ from rx import Observable -def _validate_document_ast( - schema, # type: GraphQLSchema - document_ast, # type: Document - **kwargs # type: Any -): - # type: (...) -> Union[ExecutionResult, None] - do_validation = kwargs.get("validate", True) - if do_validation: - validation_errors = validate(schema, document_ast) - if validation_errors: - return ExecutionResult(errors=validation_errors, invalid=True) - return None - - def execute_and_validate( schema, # type: GraphQLSchema document_ast, # type: Document @@ -39,28 +31,13 @@ def execute_and_validate( **kwargs # type: Any ): # type: (...) -> Union[ExecutionResult, Observable] - execution_result = _validate_document_ast(schema, document_ast, **kwargs) + execution_result = validate_document_ast(schema, document_ast, **kwargs) if execution_result: return execution_result return execute(schema, document_ast, *args, **kwargs) -@asyncio.coroutine -def execute_and_validate_async( - schema, # type: GraphQLSchema - document_ast, # type: Document - *args, # type: Any - **kwargs # type: Any -): - # type: (...) -> Union[ExecutionResult, Observable] - execution_result = _validate_document_ast(schema, document_ast, **kwargs) - if execution_result: - return execution_result - result = yield from execute_async(schema, document_ast, *args, **kwargs) - return result - - class GraphQLCoreBackend(GraphQLBackend): """GraphQLCoreBackend will return a document using the default graphql executor""" diff --git a/graphql/backend/utils.py b/graphql/backend/utils.py new file mode 100644 index 00000000..6d8368cd --- /dev/null +++ b/graphql/backend/utils.py @@ -0,0 +1,23 @@ +from ..execution import ExecutionResult +from ..validation import validate + + +# Necessary for static type checking +if False: # flake8: noqa + from typing import Any, Union + from ..language.ast import Document + from ..type.schema import GraphQLSchema + + +def validate_document_ast( + schema, # type: GraphQLSchema + document_ast, # type: Document + **kwargs # type: Any +): + # type: (...) -> Union[ExecutionResult, None] + do_validation = kwargs.get("validate", True) + if do_validation: + validation_errors = validate(schema, document_ast) + if validation_errors: + return ExecutionResult(errors=validation_errors, invalid=True) + return None diff --git a/graphql/execution/__init__.py b/graphql/execution/__init__.py index 8edf0afe..180df2e9 100644 --- a/graphql/execution/__init__.py +++ b/graphql/execution/__init__.py @@ -18,10 +18,17 @@ 2) fragment "spreads" e.g. "...c" 3) inline fragment "spreads" e.g. "...on Type { a }" """ -from .executor import execute, execute_async, subscribe +import sys + +from .executor import execute, subscribe from .base import ExecutionResult, ResolveInfo from .middleware import middlewares, MiddlewareManager +if sys.version_info > (3, 3): + from .executor_async import execute_async +else: + def execute_async(*args, **kwargs): + raise ImportError('execute_async needs python>=3.4') __all__ = [ "execute", diff --git a/graphql/execution/common.py b/graphql/execution/common.py new file mode 100644 index 00000000..fa4b382a --- /dev/null +++ b/graphql/execution/common.py @@ -0,0 +1,723 @@ +import collections + +try: + from collections.abc import Iterable +except ImportError: # Python < 3.3 + from collections import Iterable +import functools +import logging +import sys +import warnings +from rx import Observable + +from six import string_types +from promise import Promise, promise_for_dict, is_thenable + +from ..error import GraphQLError, GraphQLLocatedError +from ..pyutils.default_ordered_dict import DefaultOrderedDict +from ..pyutils.ordereddict import OrderedDict +from ..utils.undefined import Undefined +from ..type import ( + GraphQLEnumType, + GraphQLInterfaceType, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, + GraphQLUnionType, +) +from .base import ( + ExecutionContext, + ExecutionResult, + ResolveInfo, + collect_fields, + default_resolve_fn, + get_field_def, + get_operation_root_type, + SubscriberExecutionContext, +) +from .middleware import MiddlewareManager + +# Necessary for static type checking +if False: # flake8: noqa + from typing import Any, Optional, Union, Dict, List, Callable, Generator + from ..language.ast import Document, OperationDefinition, Field + +logger = logging.getLogger(__name__) + + +def prepare_execution_context( + schema, # type: GraphQLSchema + document_ast, # type: Document + root, # type: Any + context, # type: Optional[Any] + variables, # type: Optional[Any] + operation_name, # type: Optional[str] + executor, # type: Any + middleware, # type: Optional[Any] + allow_subscriptions, # type: bool + **options # type: Any +): + if root is None and "root_value" in options: + warnings.warn( + "root_value has been deprecated. Please use root=... instead.", + category=DeprecationWarning, + stacklevel=2, + ) + root = options["root_value"] + if context is None and "context_value" in options: + warnings.warn( + "context_value has been deprecated. Please use context=... instead.", + category=DeprecationWarning, + stacklevel=2, + ) + context = options["context_value"] + if variables is None and "variable_values" in options: + warnings.warn( + "variable_values has been deprecated. Please use variables=... instead.", + category=DeprecationWarning, + stacklevel=2, + ) + variables = options["variable_values"] + assert schema, "Must provide schema" + assert isinstance(schema, GraphQLSchema), ( + "Schema must be an instance of GraphQLSchema. Also ensure that there are " + + "not multiple versions of GraphQL installed in your node_modules directory." + ) + + if middleware: + if not isinstance(middleware, MiddlewareManager): + middleware = MiddlewareManager(*middleware) + + assert isinstance(middleware, MiddlewareManager), ( + "middlewares have to be an instance" + ' of MiddlewareManager. Received "{}".'.format(middleware) + ) + + return ExecutionContext( + schema, + document_ast, + root, + context, + variables or {}, + operation_name, + executor, + middleware, + allow_subscriptions, + ) + + +def get_promise_executor(exe_context, root): + def promise_executor(v): + # type: (Optional[Any]) -> Union[Dict, Promise[Dict], Observable] + return execute_operation(exe_context, exe_context.operation, root) + + return promise_executor + + +def get_on_rejected(exe_context): + def on_rejected(error): + # type: (Exception) -> None + exe_context.errors.append(error) + return None + + return on_rejected + + +def get_on_resolve(exe_context): + def on_resolve(data): + # type: (Union[None, Dict, Observable]) -> Union[ExecutionResult, Observable] + if isinstance(data, Observable): + return data + + if not exe_context.errors: + return ExecutionResult(data=data) + + return ExecutionResult(data=data, errors=exe_context.errors) + + return on_resolve + +def execute_operation( + exe_context, # type: ExecutionContext + operation, # type: OperationDefinition + root_value, # type: Any +): + # type: (...) -> Union[Dict, Promise[Dict]] + type = get_operation_root_type(exe_context.schema, operation) + fields = collect_fields( + exe_context, type, operation.selection_set, DefaultOrderedDict(list), set() + ) + + if operation.operation == "mutation": + return execute_fields_serially(exe_context, type, root_value, [], fields) + + if operation.operation == "subscription": + if not exe_context.allow_subscriptions: + raise Exception( + "Subscriptions are not allowed. " + "You will need to either use the subscribe function " + "or pass allow_subscriptions=True" + ) + return subscribe_fields(exe_context, type, root_value, fields) + + return execute_fields(exe_context, type, root_value, fields, [], None) + + +def execute_fields_serially( + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source_value, # type: Any + path, # type: List + fields, # type: DefaultOrderedDict +): + # type: (...) -> Promise + def execute_field_callback(results, response_name): + # type: (Dict, str) -> Union[Dict, Promise[Dict]] + field_asts = fields[response_name] + result = resolve_field( + exe_context, + parent_type, + source_value, + field_asts, + None, + path + [response_name], + ) + if result is Undefined: + return results + + if is_thenable(result): + + def collect_result(resolved_result): + # type: (Dict) -> Dict + results[response_name] = resolved_result + return results + + return result.then(collect_result, None) + + results[response_name] = result + return results + + def execute_field(prev_promise, response_name): + # type: (Promise, str) -> Promise + return prev_promise.then( + lambda results: execute_field_callback(results, response_name) + ) + + return functools.reduce( + execute_field, fields.keys(), Promise.resolve(collections.OrderedDict()) + ) + + +def execute_fields( + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source_value, # type: Any + fields, # type: DefaultOrderedDict + path, # type: List[Union[int, str]] + info, # type: Optional[ResolveInfo] +): + # type: (...) -> Union[Dict, Promise[Dict]] + contains_promise = False + + final_results = OrderedDict() + + for response_name, field_asts in fields.items(): + result = resolve_field( + exe_context, + parent_type, + source_value, + field_asts, + info, + path + [response_name], + ) + if result is Undefined: + continue + + final_results[response_name] = result + if is_thenable(result): + contains_promise = True + + if not contains_promise: + return final_results + + return promise_for_dict(final_results) + + +def subscribe_fields( + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source_value, # type: Any + fields, # type: DefaultOrderedDict +): + # type: (...) -> Observable + subscriber_exe_context = SubscriberExecutionContext(exe_context) + + def on_error(error): + subscriber_exe_context.report_error(error) + + def map_result(data): + # type: (Dict[str, Any]) -> ExecutionResult + if subscriber_exe_context.errors: + result = ExecutionResult(data=data, errors=subscriber_exe_context.errors) + else: + result = ExecutionResult(data=data) + subscriber_exe_context.reset() + return result + + observables = [] # type: List[Observable] + + # assert len(fields) == 1, "Can only subscribe one element at a time." + + for response_name, field_asts in fields.items(): + result = subscribe_field( + subscriber_exe_context, + parent_type, + source_value, + field_asts, + [response_name], + ) + if result is Undefined: + continue + + def catch_error(error): + subscriber_exe_context.errors.append(error) + return Observable.just(None) + + # Map observable results + observable = result.catch_exception(catch_error).map( + lambda data: map_result({response_name: data}) + ) + return observable + observables.append(observable) + + return Observable.merge(observables) + + +def resolve_field( + exe_context, # type: ExecutionContext + parent_type, # type: GraphQLObjectType + source, # type: Any + field_asts, # type: List[Field] + parent_info, # type: Optional[ResolveInfo] + field_path, # type: List[Union[int, str]] +): + # type: (...) -> Any + field_ast = field_asts[0] + field_name = field_ast.name.value + + field_def = get_field_def(exe_context.schema, parent_type, field_name) + if not field_def: + return Undefined + + return_type = field_def.type + resolve_fn = field_def.resolver or default_resolve_fn + + # We wrap the resolve_fn from the middleware + resolve_fn_middleware = exe_context.get_field_resolver(resolve_fn) + + # Build a dict of arguments from the field.arguments AST, using the variables scope to + # fulfill any variable references. + args = exe_context.get_argument_values(field_def, field_ast) + + # The resolve function's optional third argument is a context value that + # is provided to every resolve function within an execution. It is commonly + # used to represent an authenticated user, or request-specific caches. + context = exe_context.context_value + + # The resolve function's optional third argument is a collection of + # information about the current execution state. + info = ResolveInfo( + field_name, + field_asts, + return_type, + parent_type, + schema=exe_context.schema, + fragments=exe_context.fragments, + root_value=exe_context.root_value, + operation=exe_context.operation, + variable_values=exe_context.variable_values, + context=context, + path=field_path, + ) + + executor = exe_context.executor + result = resolve_or_error(resolve_fn_middleware, source, info, args, executor) + + return complete_value_catching_error( + exe_context, return_type, field_asts, info, field_path, result + ) + + +def subscribe_field( + exe_context, # type: SubscriberExecutionContext + parent_type, # type: GraphQLObjectType + source, # type: Any + field_asts, # type: List[Field] + path, # type: List[str] +): + # type: (...) -> Observable + field_ast = field_asts[0] + field_name = field_ast.name.value + + field_def = get_field_def(exe_context.schema, parent_type, field_name) + if not field_def: + return Undefined + + return_type = field_def.type + resolve_fn = field_def.resolver or default_resolve_fn + + # We wrap the resolve_fn from the middleware + resolve_fn_middleware = exe_context.get_field_resolver(resolve_fn) + + # Build a dict of arguments from the field.arguments AST, using the variables scope to + # fulfill any variable references. + args = exe_context.get_argument_values(field_def, field_ast) + + # The resolve function's optional third argument is a context value that + # is provided to every resolve function within an execution. It is commonly + # used to represent an authenticated user, or request-specific caches. + context = exe_context.context_value + + # The resolve function's optional third argument is a collection of + # information about the current execution state. + info = ResolveInfo( + field_name, + field_asts, + return_type, + parent_type, + schema=exe_context.schema, + fragments=exe_context.fragments, + root_value=exe_context.root_value, + operation=exe_context.operation, + variable_values=exe_context.variable_values, + context=context, + path=path, + ) + + executor = exe_context.executor + result = resolve_or_error(resolve_fn_middleware, source, info, args, executor) + + if isinstance(result, Exception): + raise result + + if not isinstance(result, Observable): + raise GraphQLError( + "Subscription must return Async Iterable or Observable. Received: {}".format( + repr(result) + ) + ) + + return result.map( + functools.partial( + complete_value_catching_error, + exe_context, + return_type, + field_asts, + info, + path, + ) + ) + + +def resolve_or_error( + resolve_fn, # type: Callable + source, # type: Any + info, # type: ResolveInfo + args, # type: Dict + executor, # type: Any +): + # type: (...) -> Any + try: + return executor.execute(resolve_fn, source, info, **args) + except Exception as e: + logger.exception( + "An error occurred while resolving field {}.{}".format( + info.parent_type.name, info.field_name + ) + ) + e.stack = sys.exc_info()[2] # type: ignore + return e + + +def complete_value_catching_error( + exe_context, # type: ExecutionContext + return_type, # type: Any + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any +): + # type: (...) -> Any + # If the field type is non-nullable, then it is resolved without any + # protection from errors. + if isinstance(return_type, GraphQLNonNull): + return complete_value(exe_context, return_type, field_asts, info, path, result) + + # Otherwise, error protection is applied, logging the error and + # resolving a null value for this field if one is encountered. + try: + completed = complete_value( + exe_context, return_type, field_asts, info, path, result + ) + if is_thenable(completed): + + def handle_error(error): + # type: (Union[GraphQLError, GraphQLLocatedError]) -> Optional[Any] + traceback = completed._traceback # type: ignore + exe_context.report_error(error, traceback) + return None + + return completed.catch(handle_error) + + return completed + except Exception as e: + traceback = sys.exc_info()[2] + exe_context.report_error(e, traceback) + return None + + +def complete_value( + exe_context, # type: ExecutionContext + return_type, # type: Any + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any +): + # type: (...) -> Any + """ + Implements the instructions for completeValue as defined in the + "Field entries" section of the spec. + + If the field type is Non-Null, then this recursively completes the value for the inner type. It throws a field + error if that completion returns null, as per the "Nullability" section of the spec. + + If the field type is a List, then this recursively completes the value for the inner type on each item in the + list. + + If the field type is a Scalar or Enum, ensures the completed value is a legal value of the type by calling the + `serialize` method of GraphQL type definition. + + If the field is an abstract type, determine the runtime type of the value and then complete based on that type. + + Otherwise, the field type expects a sub-selection set, and will complete the value by evaluating all + sub-selections. + """ + # If field type is NonNull, complete for inner type, and throw field error + # if result is null. + if is_thenable(result): + return Promise.resolve(result).then( + lambda resolved: complete_value( + exe_context, return_type, field_asts, info, path, resolved + ), + lambda error: Promise.rejected( + GraphQLLocatedError(field_asts, original_error=error, path=path) + ), + ) + + # print return_type, type(result) + if isinstance(result, Exception): + raise GraphQLLocatedError(field_asts, original_error=result, path=path) + + if isinstance(return_type, GraphQLNonNull): + return complete_nonnull_value( + exe_context, return_type, field_asts, info, path, result + ) + + # If result is null-like, return null. + if result is None: + return None + + # If field type is List, complete each item in the list with the inner type + if isinstance(return_type, GraphQLList): + return complete_list_value( + exe_context, return_type, field_asts, info, path, result + ) + + # If field type is Scalar or Enum, serialize to a valid value, returning + # null if coercion is not possible. + if isinstance(return_type, (GraphQLScalarType, GraphQLEnumType)): + return complete_leaf_value(return_type, path, result) + + if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): + return complete_abstract_value( + exe_context, return_type, field_asts, info, path, result + ) + + if isinstance(return_type, GraphQLObjectType): + return complete_object_value( + exe_context, return_type, field_asts, info, path, result + ) + + assert False, u'Cannot complete value of unexpected type "{}".'.format(return_type) + + +def complete_list_value( + exe_context, # type: ExecutionContext + return_type, # type: GraphQLList + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any +): + # type: (...) -> List[Any] + """ + Complete a list value by completing each item in the list with the inner type + """ + assert isinstance(result, Iterable), ( + "User Error: expected iterable, but did not find one " + "for field {}.{}." + ).format(info.parent_type, info.field_name) + + item_type = return_type.of_type + completed_results = [] + contains_promise = False + + index = 0 + for item in result: + completed_item = complete_value_catching_error( + exe_context, item_type, field_asts, info, path + [index], item + ) + if not contains_promise and is_thenable(completed_item): + contains_promise = True + + completed_results.append(completed_item) + index += 1 + + return Promise.all(completed_results) if contains_promise else completed_results + + +def complete_leaf_value( + return_type, # type: Union[GraphQLEnumType, GraphQLScalarType] + path, # type: List[Union[int, str]] + result, # type: Any +): + # type: (...) -> Union[int, str, float, bool] + """ + Complete a Scalar or Enum by serializing to a valid value, returning null if serialization is not possible. + """ + assert hasattr(return_type, "serialize"), "Missing serialize method on type" + serialized_result = return_type.serialize(result) + + if serialized_result is None: + raise GraphQLError( + ('Expected a value of type "{}" but ' + "received: {}").format( + return_type, result + ), + path=path, + ) + return serialized_result + + +def complete_abstract_value( + exe_context, # type: ExecutionContext + return_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any +): + # type: (...) -> Dict[str, Any] + """ + Complete an value of an abstract type by determining the runtime type of that value, then completing based + on that type. + """ + runtime_type = None # type: Union[str, GraphQLObjectType, None] + + # Field type must be Object, Interface or Union and expect sub-selections. + if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): + if return_type.resolve_type: + runtime_type = return_type.resolve_type(result, info) + else: + runtime_type = get_default_resolve_type_fn(result, info, return_type) + + if isinstance(runtime_type, string_types): + runtime_type = info.schema.get_type(runtime_type) # type: ignore + + if not isinstance(runtime_type, GraphQLObjectType): + raise GraphQLError( + ( + "Abstract type {} must resolve to an Object type at runtime " + + 'for field {}.{} with value "{}", received "{}".' + ).format( + return_type, info.parent_type, info.field_name, result, runtime_type + ), + field_asts, + ) + + if not exe_context.schema.is_possible_type(return_type, runtime_type): + raise GraphQLError( + u'Runtime Object type "{}" is not a possible type for "{}".'.format( + runtime_type, return_type + ), + field_asts, + ) + + return complete_object_value( + exe_context, runtime_type, field_asts, info, path, result + ) + + +def get_default_resolve_type_fn( + value, # type: Any + info, # type: ResolveInfo + abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] +): + # type: (...) -> Optional[GraphQLObjectType] + possible_types = info.schema.get_possible_types(abstract_type) + for type in possible_types: + if callable(type.is_type_of) and type.is_type_of(value, info): + return type + return None + + +def complete_object_value( + exe_context, # type: ExecutionContext + return_type, # type: GraphQLObjectType + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any +): + # type: (...) -> Dict[str, Any] + """ + Complete an Object value by evaluating all sub-selections. + """ + if return_type.is_type_of and not return_type.is_type_of(result, info): + raise GraphQLError( + u'Expected value of type "{}" but got: {}.'.format( + return_type, type(result).__name__ + ), + field_asts, + ) + + # Collect sub-fields to execute to complete this value. + subfield_asts = exe_context.get_sub_fields(return_type, field_asts) + return execute_fields(exe_context, return_type, result, subfield_asts, path, info) + + +def complete_nonnull_value( + exe_context, # type: ExecutionContext + return_type, # type: GraphQLNonNull + field_asts, # type: List[Field] + info, # type: ResolveInfo + path, # type: List[Union[int, str]] + result, # type: Any +): + # type: (...) -> Any + """ + Complete a NonNull value by completing the inner type + """ + completed = complete_value( + exe_context, return_type.of_type, field_asts, info, path, result + ) + if completed is None: + raise GraphQLError( + "Cannot return null for non-nullable field {}.{}.".format( + info.parent_type, info.field_name + ), + field_asts, + path=path, + ) + + return completed diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 67ae9ce8..960d7075 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -1,51 +1,31 @@ -import asyncio -import collections - try: from collections.abc import Iterable except ImportError: # Python < 3.3 from collections import Iterable -import functools import logging -import sys -import warnings + from rx import Observable -from six import string_types -from promise import Promise, promise_for_dict, is_thenable +from promise import Promise -from ..error import GraphQLError, GraphQLLocatedError -from ..pyutils.default_ordered_dict import DefaultOrderedDict -from ..pyutils.ordereddict import OrderedDict -from ..utils.undefined import Undefined from ..type import ( - GraphQLEnumType, - GraphQLInterfaceType, - GraphQLList, - GraphQLNonNull, - GraphQLObjectType, - GraphQLScalarType, GraphQLSchema, - GraphQLUnionType, ) from .base import ( - ExecutionContext, ExecutionResult, - ResolveInfo, - collect_fields, - default_resolve_fn, - get_field_def, - get_operation_root_type, - SubscriberExecutionContext, ) -from .executors.asyncio import AsyncioExecutor from .executors.sync import SyncExecutor -from .middleware import MiddlewareManager +from .common import ( + prepare_execution_context, + get_promise_executor, + get_on_rejected, + get_on_resolve, +) # Necessary for static type checking if False: # flake8: noqa - from typing import Any, Optional, Union, Dict, List, Callable, Generator - from ..language.ast import Document, OperationDefinition, Field + from typing import Any, Optional, Union + from ..language.ast import Document logger = logging.getLogger(__name__) @@ -58,98 +38,6 @@ def subscribe(*args, **kwargs): ) -def prepare_execution_context( - schema, # type: GraphQLSchema - document_ast, # type: Document - root, # type: Any - context, # type: Optional[Any] - variables, # type: Optional[Any] - operation_name, # type: Optional[str] - executor, # type: Any - middleware, # type: Optional[Any] - allow_subscriptions, # type: bool - **options # type: Any -): - if root is None and "root_value" in options: - warnings.warn( - "root_value has been deprecated. Please use root=... instead.", - category=DeprecationWarning, - stacklevel=2, - ) - root = options["root_value"] - if context is None and "context_value" in options: - warnings.warn( - "context_value has been deprecated. Please use context=... instead.", - category=DeprecationWarning, - stacklevel=2, - ) - context = options["context_value"] - if variables is None and "variable_values" in options: - warnings.warn( - "variable_values has been deprecated. Please use variables=... instead.", - category=DeprecationWarning, - stacklevel=2, - ) - variables = options["variable_values"] - assert schema, "Must provide schema" - assert isinstance(schema, GraphQLSchema), ( - "Schema must be an instance of GraphQLSchema. Also ensure that there are " - + "not multiple versions of GraphQL installed in your node_modules directory." - ) - - if middleware: - if not isinstance(middleware, MiddlewareManager): - middleware = MiddlewareManager(*middleware) - - assert isinstance(middleware, MiddlewareManager), ( - "middlewares have to be an instance" - ' of MiddlewareManager. Received "{}".'.format(middleware) - ) - - return ExecutionContext( - schema, - document_ast, - root, - context, - variables or {}, - operation_name, - executor, - middleware, - allow_subscriptions, - ) - - -def get_promise_executor(exe_context, root): - def promise_executor(v): - # type: (Optional[Any]) -> Union[Dict, Promise[Dict], Observable] - return execute_operation(exe_context, exe_context.operation, root) - - return promise_executor - - -def get_on_rejected(exe_context): - def on_rejected(error): - # type: (Exception) -> None - exe_context.errors.append(error) - return None - - return on_rejected - - -def get_on_resolve(exe_context): - def on_resolve(data): - # type: (Union[None, Dict, Observable]) -> Union[ExecutionResult, Observable] - if isinstance(data, Observable): - return data - - if not exe_context.errors: - return ExecutionResult(data=data) - - return ExecutionResult(data=data, errors=exe_context.errors) - - return on_resolve - - def execute( schema, # type: GraphQLSchema document_ast, # type: Document @@ -197,628 +85,3 @@ def execute( clean() return promise - - -@asyncio.coroutine -def execute_async( - schema, # type: GraphQLSchema - document_ast, # type: Document - root=None, # type: Any - context=None, # type: Optional[Any] - variables=None, # type: Optional[Any] - operation_name=None, # type: Optional[str] - executor=None, # type: Any - middleware=None, # type: Optional[Any] - allow_subscriptions=False, # type: bool - **options # type: Any -): - # type: (...) -> Generator - if executor is None: - executor = AsyncioExecutor() - exe_context = prepare_execution_context( - schema, - document_ast, - root, - context, - variables, - operation_name, - executor, - middleware, - allow_subscriptions, - **options - ) - - promise_executor = get_promise_executor(exe_context, root) - on_rejected = get_on_rejected(exe_context) - on_resolve = get_on_resolve(exe_context) - promise = ( - Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve) - ) - - yield from exe_context.executor.wait_until_finished_async() - return promise.get() - - -def execute_operation( - exe_context, # type: ExecutionContext - operation, # type: OperationDefinition - root_value, # type: Any -): - # type: (...) -> Union[Dict, Promise[Dict]] - type = get_operation_root_type(exe_context.schema, operation) - fields = collect_fields( - exe_context, type, operation.selection_set, DefaultOrderedDict(list), set() - ) - - if operation.operation == "mutation": - return execute_fields_serially(exe_context, type, root_value, [], fields) - - if operation.operation == "subscription": - if not exe_context.allow_subscriptions: - raise Exception( - "Subscriptions are not allowed. " - "You will need to either use the subscribe function " - "or pass allow_subscriptions=True" - ) - return subscribe_fields(exe_context, type, root_value, fields) - - return execute_fields(exe_context, type, root_value, fields, [], None) - - -def execute_fields_serially( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source_value, # type: Any - path, # type: List - fields, # type: DefaultOrderedDict -): - # type: (...) -> Promise - def execute_field_callback(results, response_name): - # type: (Dict, str) -> Union[Dict, Promise[Dict]] - field_asts = fields[response_name] - result = resolve_field( - exe_context, - parent_type, - source_value, - field_asts, - None, - path + [response_name], - ) - if result is Undefined: - return results - - if is_thenable(result): - - def collect_result(resolved_result): - # type: (Dict) -> Dict - results[response_name] = resolved_result - return results - - return result.then(collect_result, None) - - results[response_name] = result - return results - - def execute_field(prev_promise, response_name): - # type: (Promise, str) -> Promise - return prev_promise.then( - lambda results: execute_field_callback(results, response_name) - ) - - return functools.reduce( - execute_field, fields.keys(), Promise.resolve(collections.OrderedDict()) - ) - - -def execute_fields( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source_value, # type: Any - fields, # type: DefaultOrderedDict - path, # type: List[Union[int, str]] - info, # type: Optional[ResolveInfo] -): - # type: (...) -> Union[Dict, Promise[Dict]] - contains_promise = False - - final_results = OrderedDict() - - for response_name, field_asts in fields.items(): - result = resolve_field( - exe_context, - parent_type, - source_value, - field_asts, - info, - path + [response_name], - ) - if result is Undefined: - continue - - final_results[response_name] = result - if is_thenable(result): - contains_promise = True - - if not contains_promise: - return final_results - - return promise_for_dict(final_results) - - -def subscribe_fields( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source_value, # type: Any - fields, # type: DefaultOrderedDict -): - # type: (...) -> Observable - subscriber_exe_context = SubscriberExecutionContext(exe_context) - - def on_error(error): - subscriber_exe_context.report_error(error) - - def map_result(data): - # type: (Dict[str, Any]) -> ExecutionResult - if subscriber_exe_context.errors: - result = ExecutionResult(data=data, errors=subscriber_exe_context.errors) - else: - result = ExecutionResult(data=data) - subscriber_exe_context.reset() - return result - - observables = [] # type: List[Observable] - - # assert len(fields) == 1, "Can only subscribe one element at a time." - - for response_name, field_asts in fields.items(): - result = subscribe_field( - subscriber_exe_context, - parent_type, - source_value, - field_asts, - [response_name], - ) - if result is Undefined: - continue - - def catch_error(error): - subscriber_exe_context.errors.append(error) - return Observable.just(None) - - # Map observable results - observable = result.catch_exception(catch_error).map( - lambda data: map_result({response_name: data}) - ) - return observable - observables.append(observable) - - return Observable.merge(observables) - - -def resolve_field( - exe_context, # type: ExecutionContext - parent_type, # type: GraphQLObjectType - source, # type: Any - field_asts, # type: List[Field] - parent_info, # type: Optional[ResolveInfo] - field_path, # type: List[Union[int, str]] -): - # type: (...) -> Any - field_ast = field_asts[0] - field_name = field_ast.name.value - - field_def = get_field_def(exe_context.schema, parent_type, field_name) - if not field_def: - return Undefined - - return_type = field_def.type - resolve_fn = field_def.resolver or default_resolve_fn - - # We wrap the resolve_fn from the middleware - resolve_fn_middleware = exe_context.get_field_resolver(resolve_fn) - - # Build a dict of arguments from the field.arguments AST, using the variables scope to - # fulfill any variable references. - args = exe_context.get_argument_values(field_def, field_ast) - - # The resolve function's optional third argument is a context value that - # is provided to every resolve function within an execution. It is commonly - # used to represent an authenticated user, or request-specific caches. - context = exe_context.context_value - - # The resolve function's optional third argument is a collection of - # information about the current execution state. - info = ResolveInfo( - field_name, - field_asts, - return_type, - parent_type, - schema=exe_context.schema, - fragments=exe_context.fragments, - root_value=exe_context.root_value, - operation=exe_context.operation, - variable_values=exe_context.variable_values, - context=context, - path=field_path, - ) - - executor = exe_context.executor - result = resolve_or_error(resolve_fn_middleware, source, info, args, executor) - - return complete_value_catching_error( - exe_context, return_type, field_asts, info, field_path, result - ) - - -def subscribe_field( - exe_context, # type: SubscriberExecutionContext - parent_type, # type: GraphQLObjectType - source, # type: Any - field_asts, # type: List[Field] - path, # type: List[str] -): - # type: (...) -> Observable - field_ast = field_asts[0] - field_name = field_ast.name.value - - field_def = get_field_def(exe_context.schema, parent_type, field_name) - if not field_def: - return Undefined - - return_type = field_def.type - resolve_fn = field_def.resolver or default_resolve_fn - - # We wrap the resolve_fn from the middleware - resolve_fn_middleware = exe_context.get_field_resolver(resolve_fn) - - # Build a dict of arguments from the field.arguments AST, using the variables scope to - # fulfill any variable references. - args = exe_context.get_argument_values(field_def, field_ast) - - # The resolve function's optional third argument is a context value that - # is provided to every resolve function within an execution. It is commonly - # used to represent an authenticated user, or request-specific caches. - context = exe_context.context_value - - # The resolve function's optional third argument is a collection of - # information about the current execution state. - info = ResolveInfo( - field_name, - field_asts, - return_type, - parent_type, - schema=exe_context.schema, - fragments=exe_context.fragments, - root_value=exe_context.root_value, - operation=exe_context.operation, - variable_values=exe_context.variable_values, - context=context, - path=path, - ) - - executor = exe_context.executor - result = resolve_or_error(resolve_fn_middleware, source, info, args, executor) - - if isinstance(result, Exception): - raise result - - if not isinstance(result, Observable): - raise GraphQLError( - "Subscription must return Async Iterable or Observable. Received: {}".format( - repr(result) - ) - ) - - return result.map( - functools.partial( - complete_value_catching_error, - exe_context, - return_type, - field_asts, - info, - path, - ) - ) - - -def resolve_or_error( - resolve_fn, # type: Callable - source, # type: Any - info, # type: ResolveInfo - args, # type: Dict - executor, # type: Any -): - # type: (...) -> Any - try: - return executor.execute(resolve_fn, source, info, **args) - except Exception as e: - logger.exception( - "An error occurred while resolving field {}.{}".format( - info.parent_type.name, info.field_name - ) - ) - e.stack = sys.exc_info()[2] # type: ignore - return e - - -def complete_value_catching_error( - exe_context, # type: ExecutionContext - return_type, # type: Any - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any -): - # type: (...) -> Any - # If the field type is non-nullable, then it is resolved without any - # protection from errors. - if isinstance(return_type, GraphQLNonNull): - return complete_value(exe_context, return_type, field_asts, info, path, result) - - # Otherwise, error protection is applied, logging the error and - # resolving a null value for this field if one is encountered. - try: - completed = complete_value( - exe_context, return_type, field_asts, info, path, result - ) - if is_thenable(completed): - - def handle_error(error): - # type: (Union[GraphQLError, GraphQLLocatedError]) -> Optional[Any] - traceback = completed._traceback # type: ignore - exe_context.report_error(error, traceback) - return None - - return completed.catch(handle_error) - - return completed - except Exception as e: - traceback = sys.exc_info()[2] - exe_context.report_error(e, traceback) - return None - - -def complete_value( - exe_context, # type: ExecutionContext - return_type, # type: Any - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any -): - # type: (...) -> Any - """ - Implements the instructions for completeValue as defined in the - "Field entries" section of the spec. - - If the field type is Non-Null, then this recursively completes the value for the inner type. It throws a field - error if that completion returns null, as per the "Nullability" section of the spec. - - If the field type is a List, then this recursively completes the value for the inner type on each item in the - list. - - If the field type is a Scalar or Enum, ensures the completed value is a legal value of the type by calling the - `serialize` method of GraphQL type definition. - - If the field is an abstract type, determine the runtime type of the value and then complete based on that type. - - Otherwise, the field type expects a sub-selection set, and will complete the value by evaluating all - sub-selections. - """ - # If field type is NonNull, complete for inner type, and throw field error - # if result is null. - if is_thenable(result): - return Promise.resolve(result).then( - lambda resolved: complete_value( - exe_context, return_type, field_asts, info, path, resolved - ), - lambda error: Promise.rejected( - GraphQLLocatedError(field_asts, original_error=error, path=path) - ), - ) - - # print return_type, type(result) - if isinstance(result, Exception): - raise GraphQLLocatedError(field_asts, original_error=result, path=path) - - if isinstance(return_type, GraphQLNonNull): - return complete_nonnull_value( - exe_context, return_type, field_asts, info, path, result - ) - - # If result is null-like, return null. - if result is None: - return None - - # If field type is List, complete each item in the list with the inner type - if isinstance(return_type, GraphQLList): - return complete_list_value( - exe_context, return_type, field_asts, info, path, result - ) - - # If field type is Scalar or Enum, serialize to a valid value, returning - # null if coercion is not possible. - if isinstance(return_type, (GraphQLScalarType, GraphQLEnumType)): - return complete_leaf_value(return_type, path, result) - - if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): - return complete_abstract_value( - exe_context, return_type, field_asts, info, path, result - ) - - if isinstance(return_type, GraphQLObjectType): - return complete_object_value( - exe_context, return_type, field_asts, info, path, result - ) - - assert False, u'Cannot complete value of unexpected type "{}".'.format(return_type) - - -def complete_list_value( - exe_context, # type: ExecutionContext - return_type, # type: GraphQLList - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any -): - # type: (...) -> List[Any] - """ - Complete a list value by completing each item in the list with the inner type - """ - assert isinstance(result, Iterable), ( - "User Error: expected iterable, but did not find one " + "for field {}.{}." - ).format(info.parent_type, info.field_name) - - item_type = return_type.of_type - completed_results = [] - contains_promise = False - - index = 0 - for item in result: - completed_item = complete_value_catching_error( - exe_context, item_type, field_asts, info, path + [index], item - ) - if not contains_promise and is_thenable(completed_item): - contains_promise = True - - completed_results.append(completed_item) - index += 1 - - return Promise.all(completed_results) if contains_promise else completed_results - - -def complete_leaf_value( - return_type, # type: Union[GraphQLEnumType, GraphQLScalarType] - path, # type: List[Union[int, str]] - result, # type: Any -): - # type: (...) -> Union[int, str, float, bool] - """ - Complete a Scalar or Enum by serializing to a valid value, returning null if serialization is not possible. - """ - assert hasattr(return_type, "serialize"), "Missing serialize method on type" - serialized_result = return_type.serialize(result) - - if serialized_result is None: - raise GraphQLError( - ('Expected a value of type "{}" but ' + "received: {}").format( - return_type, result - ), - path=path, - ) - return serialized_result - - -def complete_abstract_value( - exe_context, # type: ExecutionContext - return_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any -): - # type: (...) -> Dict[str, Any] - """ - Complete an value of an abstract type by determining the runtime type of that value, then completing based - on that type. - """ - runtime_type = None # type: Union[str, GraphQLObjectType, None] - - # Field type must be Object, Interface or Union and expect sub-selections. - if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): - if return_type.resolve_type: - runtime_type = return_type.resolve_type(result, info) - else: - runtime_type = get_default_resolve_type_fn(result, info, return_type) - - if isinstance(runtime_type, string_types): - runtime_type = info.schema.get_type(runtime_type) # type: ignore - - if not isinstance(runtime_type, GraphQLObjectType): - raise GraphQLError( - ( - "Abstract type {} must resolve to an Object type at runtime " - + 'for field {}.{} with value "{}", received "{}".' - ).format( - return_type, info.parent_type, info.field_name, result, runtime_type - ), - field_asts, - ) - - if not exe_context.schema.is_possible_type(return_type, runtime_type): - raise GraphQLError( - u'Runtime Object type "{}" is not a possible type for "{}".'.format( - runtime_type, return_type - ), - field_asts, - ) - - return complete_object_value( - exe_context, runtime_type, field_asts, info, path, result - ) - - -def get_default_resolve_type_fn( - value, # type: Any - info, # type: ResolveInfo - abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] -): - # type: (...) -> Optional[GraphQLObjectType] - possible_types = info.schema.get_possible_types(abstract_type) - for type in possible_types: - if callable(type.is_type_of) and type.is_type_of(value, info): - return type - return None - - -def complete_object_value( - exe_context, # type: ExecutionContext - return_type, # type: GraphQLObjectType - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any -): - # type: (...) -> Dict[str, Any] - """ - Complete an Object value by evaluating all sub-selections. - """ - if return_type.is_type_of and not return_type.is_type_of(result, info): - raise GraphQLError( - u'Expected value of type "{}" but got: {}.'.format( - return_type, type(result).__name__ - ), - field_asts, - ) - - # Collect sub-fields to execute to complete this value. - subfield_asts = exe_context.get_sub_fields(return_type, field_asts) - return execute_fields(exe_context, return_type, result, subfield_asts, path, info) - - -def complete_nonnull_value( - exe_context, # type: ExecutionContext - return_type, # type: GraphQLNonNull - field_asts, # type: List[Field] - info, # type: ResolveInfo - path, # type: List[Union[int, str]] - result, # type: Any -): - # type: (...) -> Any - """ - Complete a NonNull value by completing the inner type - """ - completed = complete_value( - exe_context, return_type.of_type, field_asts, info, path, result - ) - if completed is None: - raise GraphQLError( - "Cannot return null for non-nullable field {}.{}.".format( - info.parent_type, info.field_name - ), - field_asts, - path=path, - ) - - return completed diff --git a/graphql/execution/executor_async.py b/graphql/execution/executor_async.py new file mode 100644 index 00000000..970785be --- /dev/null +++ b/graphql/execution/executor_async.py @@ -0,0 +1,64 @@ +import asyncio + +try: + from collections.abc import Iterable +except ImportError: # Python < 3.3 + from collections import Iterable + +from promise import Promise + +from ..type import GraphQLSchema + +from .executors.asyncio import AsyncioExecutor +from .common import ( + prepare_execution_context, + get_promise_executor, + get_on_rejected, + get_on_resolve, +) + + +# Necessary for static type checking +if False: # flake8: noqa + from typing import Any, Optional, Generator + from ..language.ast import Document + + +@asyncio.coroutine +def execute_async( + schema, # type: GraphQLSchema + document_ast, # type: Document + root=None, # type: Any + context=None, # type: Optional[Any] + variables=None, # type: Optional[Any] + operation_name=None, # type: Optional[str] + executor=None, # type: Any + middleware=None, # type: Optional[Any] + allow_subscriptions=False, # type: bool + **options # type: Any +): + # type: (...) -> Generator + if executor is None: + executor = AsyncioExecutor() + exe_context = prepare_execution_context( + schema, + document_ast, + root, + context, + variables, + operation_name, + executor, + middleware, + allow_subscriptions, + **options + ) + + promise_executor = get_promise_executor(exe_context, root) + on_rejected = get_on_rejected(exe_context) + on_resolve = get_on_resolve(exe_context) + promise = ( + Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve) + ) + + yield from exe_context.executor.wait_until_finished_async() + return promise.get() diff --git a/graphql/execution/executors/base.py b/graphql/execution/executors/base.py index dc33f2b5..5af0bb41 100644 --- a/graphql/execution/executors/base.py +++ b/graphql/execution/executors/base.py @@ -1,7 +1,8 @@ -from abc import ABC, abstractmethod +from abc import ABCMeta, abstractmethod +import six -class BaseExecutor(ABC): +class BaseExecutor(six.with_metaclass(ABCMeta)): @abstractmethod def wait_until_finished(self): pass diff --git a/graphql/execution/tests/test_benchmark.py b/graphql/execution/tests/test_benchmark.py index 04b51089..4c50d75d 100644 --- a/graphql/execution/tests/test_benchmark.py +++ b/graphql/execution/tests/test_benchmark.py @@ -59,7 +59,7 @@ def b(): def test_big_list_of_ints_serialize(benchmark): - from ..executor import complete_leaf_value + from ..common import complete_leaf_value @benchmark def serialize(): diff --git a/graphql/execution/tests/test_executor.py b/graphql/execution/tests/test_executor.py index b4298e8a..0c74965a 100644 --- a/graphql/execution/tests/test_executor.py +++ b/graphql/execution/tests/test_executor.py @@ -562,7 +562,7 @@ def test_fails_to_execute_a_query_containing_a_type_definition(): def test_exceptions_are_reraised_if_specified(mocker): # type: (MockFixture) -> None - logger = mocker.patch("graphql.execution.executor.logger") + logger = mocker.patch("graphql.execution.common.logger") query = parse( """ diff --git a/graphql/execution/tests/test_executor_asyncio.py b/graphql/execution/tests/test_executor_asyncio.py index 8dbc6a3e..714f59ee 100644 --- a/graphql/execution/tests/test_executor_asyncio.py +++ b/graphql/execution/tests/test_executor_asyncio.py @@ -9,7 +9,7 @@ asyncio = pytest.importorskip("asyncio") from graphql.error import format_error -from graphql.execution import execute, execute_async +from graphql.execution import execute from graphql.language.parser import parse from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString @@ -19,17 +19,15 @@ def test_asyncio_executor(): # type: () -> None - - @asyncio.coroutine def resolver(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - yield from asyncio.sleep(0.001) + asyncio.sleep(0.001) return "hey" @asyncio.coroutine def resolver_2(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - yield from asyncio.sleep(0.003) + asyncio.sleep(0.003) return "hey2" def resolver_3(contest, *_): @@ -51,58 +49,19 @@ def resolver_3(contest, *_): assert result.data == {"a": "hey", "b": "hey2", "c": "hey3"} -@pytest.mark.asyncio -@asyncio.coroutine -def test_asyncio_executor_exc_async(): - # type: () -> None - - @asyncio.coroutine - def resolver(context, *_): - # type: (Optional[Any], *ResolveInfo) -> str - yield from asyncio.sleep(0.001) - return "hey" - - @asyncio.coroutine - def resolver_2(context, *_): - # type: (Optional[Any], *ResolveInfo) -> str - yield from asyncio.sleep(0.003) - return "hey2" - - def resolver_3(contest, *_): - # type: (Optional[Any], *ResolveInfo) -> str - return "hey3" - - Type = GraphQLObjectType( - "Type", - { - "a": GraphQLField(GraphQLString, resolver=resolver), - "b": GraphQLField(GraphQLString, resolver=resolver_2), - "c": GraphQLField(GraphQLString, resolver=resolver_3), - }, - ) - - ast = parse("{ a b c }") - result = yield from execute_async( - GraphQLSchema(Type), ast, executor=AsyncioExecutor() - ) - assert not result.errors - assert result.data == {"a": "hey", "b": "hey2", "c": "hey3"} - - def test_asyncio_executor_custom_loop(): # type: () -> None loop = asyncio.get_event_loop() - @asyncio.coroutine def resolver(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - yield from asyncio.sleep(0.001, loop=loop) + asyncio.sleep(0.001, loop=loop) return "hey" @asyncio.coroutine def resolver_2(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - yield from asyncio.sleep(0.003, loop=loop) + asyncio.sleep(0.003, loop=loop) return "hey2" def resolver_3(contest, *_): @@ -128,16 +87,14 @@ def test_asyncio_executor_with_error(): # type: () -> None ast = parse("query Example { a, b }") - @asyncio.coroutine def resolver(context, *_): # type: (Optional[Any], *ResolveInfo) -> str - yield from asyncio.sleep(0.001) + asyncio.sleep(0.001) return "hey" - @asyncio.coroutine def resolver_2(context, *_): # type: (Optional[Any], *ResolveInfo) -> NoReturn - yield from asyncio.sleep(0.003) + asyncio.sleep(0.003) raise Exception("resolver_2 failed!") Type = GraphQLObjectType( @@ -160,46 +117,6 @@ def resolver_2(context, *_): assert result.data == {"a": "hey", "b": None} -@pytest.mark.asyncio -@asyncio.coroutine -def test_asyncio_executor_with_error_exc_async(): - # type: () -> None - ast = parse("query Example { a, b }") - - @asyncio.coroutine - def resolver(context, *_): - # type: (Optional[Any], *ResolveInfo) -> str - yield from asyncio.sleep(0.001) - return "hey" - - @asyncio.coroutine - def resolver_2(context, *_): - # type: (Optional[Any], *ResolveInfo) -> NoReturn - yield from asyncio.sleep(0.003) - raise Exception("resolver_2 failed!") - - Type = GraphQLObjectType( - "Type", - { - "a": GraphQLField(GraphQLString, resolver=resolver), - "b": GraphQLField(GraphQLString, resolver=resolver_2), - }, - ) - - result = yield from execute_async( - GraphQLSchema(Type), ast, executor=AsyncioExecutor() - ) - formatted_errors = list(map(format_error, result.errors)) - assert formatted_errors == [ - { - "locations": [{"line": 1, "column": 20}], - "path": ["b"], - "message": "resolver_2 failed!", - } - ] - assert result.data == {"a": "hey", "b": None} - - def test_evaluates_mutations_serially(): # type: () -> None assert_evaluate_mutations_serially(executor=AsyncioExecutor()) diff --git a/graphql/graphql.py b/graphql/graphql.py index fe7c57dd..2228802e 100644 --- a/graphql/graphql.py +++ b/graphql/graphql.py @@ -1,5 +1,3 @@ -import asyncio - from .execution import ExecutionResult from .backend import get_default_backend @@ -46,13 +44,6 @@ def graphql(*args, **kwargs): return execute_graphql(*args, **kwargs) -@asyncio.coroutine -def graphql_async(*args, **kwargs): - # type: (*Any, **Any) -> Generator - result = yield from execute_graphql_async(*args, **kwargs) - return result - - def execute_graphql( schema, # type: GraphQLSchema request_string="", # type: Union[Document, str] @@ -81,38 +72,6 @@ def execute_graphql( except Exception as e: return ExecutionResult(errors=[e], invalid=True) - -@asyncio.coroutine -def execute_graphql_async( - schema, # type: GraphQLSchema - request_string="", # type: Union[Document, str] - root=None, # type: Any - context=None, # type: Optional[Any] - variables=None, # type: Optional[Any] - operation_name=None, # type: Optional[Any] - middleware=None, # type: Optional[Any] - backend=None, # type: Optional[Any] - **execute_options # type: Any -): - # type: (...) -> Generator - try: - if backend is None: - backend = get_default_backend() - - document = backend.document_from_string_async(schema, request_string) - result = yield from document.execute( - root=root, - context=context, - operation_name=operation_name, - variables=variables, - middleware=middleware, - **execute_options - ) - return result - except Exception as e: - return ExecutionResult(errors=[e], invalid=True) - - @promisify def execute_graphql_as_promise(*args, **kwargs): return execute_graphql(*args, **kwargs) diff --git a/graphql/graphql_async.py b/graphql/graphql_async.py new file mode 100644 index 00000000..f42680c4 --- /dev/null +++ b/graphql/graphql_async.py @@ -0,0 +1,51 @@ +try: + import asyncio +except ImportError: + asyncio = None + +from .execution import ExecutionResult +from .backend import get_default_backend + +# Necessary for static type checking +if False: # flake8: noqa + from typing import Any, Union, Optional, Generator + from .language.ast import Document + from .type.schema import GraphQLSchema + + +@asyncio.coroutine +def graphql_async(*args, **kwargs): + # type: (*Any, **Any) -> Generator + result = yield from execute_graphql_async(*args, **kwargs) + return result + + +@asyncio.coroutine +def execute_graphql_async( + schema, # type: GraphQLSchema + request_string="", # type: Union[Document, str] + root=None, # type: Any + context=None, # type: Optional[Any] + variables=None, # type: Optional[Any] + operation_name=None, # type: Optional[Any] + middleware=None, # type: Optional[Any] + backend=None, # type: Optional[Any] + **execute_options # type: Any +): + # type: (...) -> Generator + try: + if backend is None: + backend = get_default_backend() + + document = backend.document_from_string_async(schema, request_string) + result = yield from document.execute( + root=root, + context=context, + operation_name=operation_name, + variables=variables, + middleware=middleware, + **execute_options + ) + return result + except Exception as e: + return ExecutionResult(errors=[e], invalid=True) diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index 2d72b224..5e3c2cd0 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -1,47 +1,10 @@ -# type: ignore -# flake8: noqa - -import pytest - -asyncio = pytest.importorskip("asyncio") - -from graphql import graphql, graphql_async +from graphql import graphql from graphql.error import format_error from .starwars_schema import StarWarsSchema -@pytest.fixture(params=["sync", "async"]) -def execute_graphql(request): - @asyncio.coroutine - def _execute(schema, query, variable_values=None): - if request.param == "sync": - return graphql(schema, query, variable_values=variable_values) - else: - result = yield from graphql_async( - schema, query, variable_values=variable_values - ) - return result - - return _execute - - -@pytest.fixture -def execute_and_validate_result(execute_graphql): - @asyncio.coroutine - def _execute_and_validate(schema, query, expected, variable_values=None): - result = yield from execute_graphql( - schema, query, variable_values=variable_values - ) - assert not result.errors - assert result.data == expected - - return _execute_and_validate - - -@pytest.mark.asyncio -@asyncio.coroutine -def test_hero_name_query(execute_and_validate_result): +def test_hero_name_query(): query = """ query HeroNameQuery { hero { @@ -50,12 +13,12 @@ def test_hero_name_query(execute_and_validate_result): } """ expected = {"hero": {"name": "R2-D2"}} - yield from execute_and_validate_result(StarWarsSchema, query, expected) + result = graphql(StarWarsSchema, query) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_hero_name_and_friends_query(execute_and_validate_result): +def test_hero_name_and_friends_query(): query = """ query HeroNameAndFriendsQuery { hero { @@ -78,12 +41,12 @@ def test_hero_name_and_friends_query(execute_and_validate_result): ], } } - yield from execute_and_validate_result(StarWarsSchema, query, expected) + result = graphql(StarWarsSchema, query) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_nested_query(execute_and_validate_result): +def test_nested_query(): query = """ query NestedQuery { hero { @@ -134,12 +97,12 @@ def test_nested_query(execute_and_validate_result): ], } } - yield from execute_and_validate_result(StarWarsSchema, query, expected) + result = graphql(StarWarsSchema, query) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_fetch_luke_query(execute_and_validate_result): +def test_fetch_luke_query(): query = """ query FetchLukeQuery { human(id: "1000") { @@ -148,12 +111,12 @@ def test_fetch_luke_query(execute_and_validate_result): } """ expected = {"human": {"name": "Luke Skywalker"}} - yield from execute_and_validate_result(StarWarsSchema, query, expected) + result = graphql(StarWarsSchema, query) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_fetch_some_id_query(execute_and_validate_result): +def test_fetch_some_id_query(): query = """ query FetchSomeIDQuery($someId: String!) { human(id: $someId) { @@ -163,14 +126,12 @@ def test_fetch_some_id_query(execute_and_validate_result): """ params = {"someId": "1000"} expected = {"human": {"name": "Luke Skywalker"}} - yield from execute_and_validate_result( - StarWarsSchema, query, expected, variable_values=params - ) + result = graphql(StarWarsSchema, query, variable_values=params) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_fetch_some_id_query2(execute_and_validate_result): +def test_fetch_some_id_query2(): query = """ query FetchSomeIDQuery($someId: String!) { human(id: $someId) { @@ -180,14 +141,12 @@ def test_fetch_some_id_query2(execute_and_validate_result): """ params = {"someId": "1002"} expected = {"human": {"name": "Han Solo"}} - yield from execute_and_validate_result( - StarWarsSchema, query, expected, variable_values=params - ) + result = graphql(StarWarsSchema, query, variable_values=params) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_invalid_id_query(execute_and_validate_result): +def test_invalid_id_query(): query = """ query humanQuery($id: String!) { human(id: $id) { @@ -197,14 +156,12 @@ def test_invalid_id_query(execute_and_validate_result): """ params = {"id": "not a valid id"} expected = {"human": None} - yield from execute_and_validate_result( - StarWarsSchema, query, expected, variable_values=params - ) + result = graphql(StarWarsSchema, query, variable_values=params) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_fetch_luke_aliased(execute_and_validate_result): +def test_fetch_luke_aliased(): query = """ query FetchLukeAliased { luke: human(id: "1000") { @@ -213,12 +170,12 @@ def test_fetch_luke_aliased(execute_and_validate_result): } """ expected = {"luke": {"name": "Luke Skywalker"}} - yield from execute_and_validate_result(StarWarsSchema, query, expected) + result = graphql(StarWarsSchema, query) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_fetch_luke_and_leia_aliased(execute_and_validate_result): +def test_fetch_luke_and_leia_aliased(): query = """ query FetchLukeAndLeiaAliased { luke: human(id: "1000") { @@ -230,12 +187,12 @@ def test_fetch_luke_and_leia_aliased(execute_and_validate_result): } """ expected = {"luke": {"name": "Luke Skywalker"}, "leia": {"name": "Leia Organa"}} - yield from execute_and_validate_result(StarWarsSchema, query, expected) + result = graphql(StarWarsSchema, query) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_duplicate_fields(execute_and_validate_result): +def test_duplicate_fields(): query = """ query DuplicateFields { luke: human(id: "1000") { @@ -252,12 +209,12 @@ def test_duplicate_fields(execute_and_validate_result): "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"}, } - yield from execute_and_validate_result(StarWarsSchema, query, expected) + result = graphql(StarWarsSchema, query) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_use_fragment(execute_and_validate_result): +def test_use_fragment(): query = """ query UseFragment { luke: human(id: "1000") { @@ -276,12 +233,12 @@ def test_use_fragment(execute_and_validate_result): "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"}, } - yield from execute_and_validate_result(StarWarsSchema, query, expected) + result = graphql(StarWarsSchema, query) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_check_type_of_r2(execute_and_validate_result): +def test_check_type_of_r2(): query = """ query CheckTypeOfR2 { hero { @@ -291,12 +248,12 @@ def test_check_type_of_r2(execute_and_validate_result): } """ expected = {"hero": {"__typename": "Droid", "name": "R2-D2"}} - yield from execute_and_validate_result(StarWarsSchema, query, expected) + result = graphql(StarWarsSchema, query) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_check_type_of_luke(execute_and_validate_result): +def test_check_type_of_luke(): query = """ query CheckTypeOfLuke { hero(episode: EMPIRE) { @@ -306,16 +263,16 @@ def test_check_type_of_luke(execute_and_validate_result): } """ expected = {"hero": {"__typename": "Human", "name": "Luke Skywalker"}} - yield from execute_and_validate_result(StarWarsSchema, query, expected) + result = graphql(StarWarsSchema, query) + assert not result.errors + assert result.data == expected -@pytest.mark.asyncio -@asyncio.coroutine -def test_parse_error(execute_graphql): +def test_parse_error(): query = """ qeury """ - result = yield from execute_graphql(StarWarsSchema, query) + result = graphql(StarWarsSchema, query) assert result.invalid formatted_error = format_error(result.errors[0]) assert formatted_error["locations"] == [{"column": 9, "line": 2}] From 03133bdb7089147e2e655d9f2edc6ea83bb5962d Mon Sep 17 00:00:00 2001 From: voith Date: Mon, 12 Aug 2019 16:42:53 +0530 Subject: [PATCH 6/6] fix linting errors --- graphql/__init__.py | 13 ++++++++----- graphql/backend/async_util.py | 13 +++++++++---- graphql/backend/core.py | 4 +++- graphql/execution/__init__.py | 4 +++- graphql/execution/common.py | 3 ++- graphql/execution/executor.py | 12 ++---------- graphql/execution/executor_async.py | 5 ----- graphql/graphql.py | 3 ++- graphql/graphql_async.py | 5 +---- 9 files changed, 30 insertions(+), 32 deletions(-) diff --git a/graphql/__init__.py b/graphql/__init__.py index 63ac0da4..1cdf0864 100644 --- a/graphql/__init__.py +++ b/graphql/__init__.py @@ -27,11 +27,6 @@ # The primary entry point into fulfilling a GraphQL request. from .graphql import graphql -if sys.version_info > (3, 3): - from .graphql_async import graphql_async -else: - def graphql_async(*args, **kwargs): - raise ImportError('graphql_async needs python>=3.4') # Create and operate on GraphQL type definitions and schema. from .type import ( # no import order @@ -175,6 +170,14 @@ def graphql_async(*args, **kwargs): set_default_backend, ) +if sys.version_info > (3, 3): + from .graphql_async import graphql_async +else: + + def graphql_async(*args, **kwargs): + raise ImportError("graphql_async needs python>=3.4") + + VERSION = (2, 2, 1, "final", 0) __version__ = get_version(VERSION) diff --git a/graphql/backend/async_util.py b/graphql/backend/async_util.py index c30631d1..d27b43fb 100644 --- a/graphql/backend/async_util.py +++ b/graphql/backend/async_util.py @@ -1,11 +1,16 @@ -try: - import asyncio -except ImportError: - asyncio = None +import asyncio from .utils import validate_document_ast from ..execution import execute_async +# Necessary for static type checking +if False: # flake8: noqa + from typing import Any, Union + from ..execution import ExecutionResult + from ..language.ast import Document + from ..type.schema import GraphQLSchema + from rx import Observable + @asyncio.coroutine def execute_and_validate_async( diff --git a/graphql/backend/core.py b/graphql/backend/core.py index 6ce4c277..1314f80c 100644 --- a/graphql/backend/core.py +++ b/graphql/backend/core.py @@ -13,8 +13,10 @@ if sys.version_info > (3, 3): from .async_util import execute_and_validate_async else: + def execute_and_validate_async(*args, **kwargs): - raise ImportError('execute_and_validate_async needs python>=3.4') + raise ImportError("execute_and_validate_async needs python>=3.4") + # Necessary for static type checking if False: # flake8: noqa diff --git a/graphql/execution/__init__.py b/graphql/execution/__init__.py index 180df2e9..88ef44a5 100644 --- a/graphql/execution/__init__.py +++ b/graphql/execution/__init__.py @@ -27,8 +27,10 @@ if sys.version_info > (3, 3): from .executor_async import execute_async else: + def execute_async(*args, **kwargs): - raise ImportError('execute_async needs python>=3.4') + raise ImportError("execute_async needs python>=3.4") + __all__ = [ "execute", diff --git a/graphql/execution/common.py b/graphql/execution/common.py index fa4b382a..2a7093a8 100644 --- a/graphql/execution/common.py +++ b/graphql/execution/common.py @@ -41,7 +41,7 @@ # Necessary for static type checking if False: # flake8: noqa - from typing import Any, Optional, Union, Dict, List, Callable, Generator + from typing import Any, Optional, Union, Dict, List, Callable from ..language.ast import Document, OperationDefinition, Field logger = logging.getLogger(__name__) @@ -138,6 +138,7 @@ def on_resolve(data): return on_resolve + def execute_operation( exe_context, # type: ExecutionContext operation, # type: OperationDefinition diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 960d7075..ea9da349 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -1,19 +1,11 @@ -try: - from collections.abc import Iterable -except ImportError: # Python < 3.3 - from collections import Iterable import logging from rx import Observable from promise import Promise -from ..type import ( - GraphQLSchema, -) -from .base import ( - ExecutionResult, -) +from ..type import GraphQLSchema +from .base import ExecutionResult from .executors.sync import SyncExecutor from .common import ( prepare_execution_context, diff --git a/graphql/execution/executor_async.py b/graphql/execution/executor_async.py index 970785be..a392239d 100644 --- a/graphql/execution/executor_async.py +++ b/graphql/execution/executor_async.py @@ -1,10 +1,5 @@ import asyncio -try: - from collections.abc import Iterable -except ImportError: # Python < 3.3 - from collections import Iterable - from promise import Promise from ..type import GraphQLSchema diff --git a/graphql/graphql.py b/graphql/graphql.py index 2228802e..89ccf386 100644 --- a/graphql/graphql.py +++ b/graphql/graphql.py @@ -7,7 +7,7 @@ if False: # flake8: noqa from promise import Promise from rx import Observable - from typing import Any, Union, Optional, Generator + from typing import Any, Union, Optional from .language.ast import Document from .type.schema import GraphQLSchema @@ -72,6 +72,7 @@ def execute_graphql( except Exception as e: return ExecutionResult(errors=[e], invalid=True) + @promisify def execute_graphql_as_promise(*args, **kwargs): return execute_graphql(*args, **kwargs) diff --git a/graphql/graphql_async.py b/graphql/graphql_async.py index f42680c4..1dff9672 100644 --- a/graphql/graphql_async.py +++ b/graphql/graphql_async.py @@ -1,7 +1,4 @@ -try: - import asyncio -except ImportError: - asyncio = None +import asyncio from .execution import ExecutionResult from .backend import get_default_backend