Skip to content

Commit

Permalink
Add gRPC aio stub and servicer generation (#489)
Browse files Browse the repository at this point in the history
* Add async usage test

* Fix broken mypy on tests dir

* Generate async-compatible stubs and servicers

* Update grpc-stubs to 1.24.12.1 with aio support

* Exclude generated code from Black check

* Use collections.abc instead of typing

* Fix shellcheck in run_test.sh
  • Loading branch information
RobinMcCorkell authored Mar 13, 2023
1 parent 947a7d7 commit ed77525
Show file tree
Hide file tree
Showing 10 changed files with 277 additions and 77 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ jobs:
- name: Run formatters and linters
run: |
pip3 install black isort flake8-pyi flake8-noqa flake8-bugbear
black --check .
black --check --extend-exclude '(_pb2_grpc|_pb2).pyi?$' .
isort --check . --diff
flake8 .
- name: run shellcheck
Expand Down
112 changes: 88 additions & 24 deletions mypy_protobuf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,30 +663,77 @@ def _map_key_value_types(

return ktype, vtype

def _callable_type(self, method: d.MethodDescriptorProto) -> str:
def _callable_type(self, method: d.MethodDescriptorProto, is_async: bool = False) -> str:
module = "grpc.aio" if is_async else "grpc"
if method.client_streaming:
if method.server_streaming:
return self._import("grpc", "StreamStreamMultiCallable")
return self._import(module, "StreamStreamMultiCallable")
else:
return self._import("grpc", "StreamUnaryMultiCallable")
return self._import(module, "StreamUnaryMultiCallable")
else:
if method.server_streaming:
return self._import("grpc", "UnaryStreamMultiCallable")
return self._import(module, "UnaryStreamMultiCallable")
else:
return self._import("grpc", "UnaryUnaryMultiCallable")
return self._import(module, "UnaryUnaryMultiCallable")

def _input_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
def _input_type(self, method: d.MethodDescriptorProto) -> str:
result = self._import_message(method.input_type)
if use_stream_iterator and method.client_streaming:
result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
return result

def _output_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
def _servicer_input_type(self, method: d.MethodDescriptorProto) -> str:
result = self._import_message(method.input_type)
if method.client_streaming:
# See write_grpc_async_hacks().
result = f"_MaybeAsyncIterator[{result}]"
return result

def _output_type(self, method: d.MethodDescriptorProto) -> str:
result = self._import_message(method.output_type)
if use_stream_iterator and method.server_streaming:
result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
return result

def _servicer_output_type(self, method: d.MethodDescriptorProto) -> str:
result = self._import_message(method.output_type)
if method.server_streaming:
# Union[Iterator[Resp], AsyncIterator[Resp]] is subtyped by Iterator[Resp] and AsyncIterator[Resp].
# So both can be used in the covariant function return position.
iterator = f"{self._import('collections.abc', 'Iterator')}[{result}]"
aiterator = f"{self._import('collections.abc', 'AsyncIterator')}[{result}]"
result = f"{self._import('typing', 'Union')}[{iterator}, {aiterator}]"
else:
# Union[Resp, Awaitable[Resp]] is subtyped by Resp and Awaitable[Resp].
# So both can be used in the covariant function return position.
# Awaitable[Resp] is equivalent to async def.
awaitable = f"{self._import('collections.abc', 'Awaitable')}[{result}]"
result = f"{self._import('typing', 'Union')}[{result}, {awaitable}]"
return result

def write_grpc_async_hacks(self) -> None:
wl = self._write_line
# _MaybeAsyncIterator[Req] is supertyped by Iterator[Req] and AsyncIterator[Req].
# So both can be used in the contravariant function parameter position.
wl("_T = {}('_T')", self._import("typing", "TypeVar"))
wl("")
wl(
"class _MaybeAsyncIterator({}[_T], {}[_T], metaclass={}):",
self._import("collections.abc", "AsyncIterator"),
self._import("collections.abc", "Iterator"),
self._import("abc", "ABCMeta"),
)
with self._indent():
wl("...")
wl("")

# _ServicerContext is supertyped by grpc.ServicerContext and grpc.aio.ServicerContext
# So both can be used in the contravariant function parameter position.
wl(
"class _ServicerContext({}, {}): # type: ignore",
self._import("grpc", "ServicerContext"),
self._import("grpc.aio", "ServicerContext"),
)
with self._indent():
wl("...")
wl("")

def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
Expand All @@ -701,20 +748,20 @@ def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: Sour
with self._indent():
wl("self,")
input_name = "request_iterator" if method.client_streaming else "request"
input_type = self._input_type(method)
input_type = self._servicer_input_type(method)
wl(f"{input_name}: {input_type},")
wl("context: {},", self._import("grpc", "ServicerContext"))
wl("context: _ServicerContext,")
wl(
") -> {}:{}",
self._output_type(method),
self._servicer_output_type(method),
" ..." if not self._has_comments(scl) else "",
)
if self._has_comments(scl):
with self._indent():
if not self._write_comments(scl):
wl("...")

def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation, is_async: bool = False) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
if not methods:
Expand All @@ -723,10 +770,10 @@ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix:
for i, method in methods:
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]

wl("{}: {}[", method.name, self._callable_type(method))
wl("{}: {}[", method.name, self._callable_type(method, is_async=is_async))
with self._indent():
wl("{},", self._input_type(method, False))
wl("{},", self._output_type(method, False))
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("]")
self._write_comments(scl)

Expand All @@ -743,17 +790,31 @@ def write_grpc_services(
scl = scl_prefix + [i]

# The stub client
wl(f"class {service.name}Stub:")
wl(
"class {}Stub:",
service.name,
)
with self._indent():
if self._write_comments(scl):
wl("")
wl(
"def __init__(self, channel: {}) -> None: ...",
self._import("grpc", "Channel"),
)
# To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]"
wl("def __init__(self, channel: {}) -> None: ...", channel)
self.write_grpc_stub_methods(service, scl)
wl("")

# The (fake) async stub client
wl(
"class {}AsyncStub:",
service.name,
)
with self._indent():
if self._write_comments(scl):
wl("")
# No __init__ since this isn't a real class (yet), and requires manual casting to work.
self.write_grpc_stub_methods(service, scl, is_async=True)
wl("")

# The service definition interface
wl(
"class {}Servicer(metaclass={}):",
Expand All @@ -765,11 +826,13 @@ def write_grpc_services(
wl("")
self.write_grpc_methods(service, scl)
wl("")
server = self._import("grpc", "Server")
aserver = self._import("grpc.aio", "Server")
wl(
"def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
service.name,
service.name,
self._import("grpc", "Server"),
f"{self._import('typing', 'Union')}[{server}, {aserver}]",
)
wl("")

Expand Down Expand Up @@ -960,6 +1023,7 @@ def generate_mypy_grpc_stubs(
relax_strict_optional_primitives,
grpc=True,
)
pkg_writer.write_grpc_async_hacks()
pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])

assert name == fd.name
Expand Down
42 changes: 21 additions & 21 deletions run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ RED="\033[0;31m"
NC='\033[0m'

PY_VER_MYPY_PROTOBUF=${PY_VER_MYPY_PROTOBUF:=3.10.6}
PY_VER_MYPY_PROTOBUF_SHORT=$(echo $PY_VER_MYPY_PROTOBUF | cut -d. -f1-2)
PY_VER_MYPY_PROTOBUF_SHORT=$(echo "$PY_VER_MYPY_PROTOBUF" | cut -d. -f1-2)
PY_VER_MYPY=${PY_VER_MYPY:=3.8.13}
PY_VER_UNIT_TESTS="${PY_VER_UNIT_TESTS:=3.8.13}"

Expand Down Expand Up @@ -45,16 +45,16 @@ MYPY_VENV=venv_$PY_VER_MYPY
(
eval "$(pyenv init --path)"
eval "$(pyenv init -)"
pyenv shell $PY_VER_MYPY
pyenv shell "$PY_VER_MYPY"

if [[ -z $SKIP_CLEAN ]] || [[ ! -e $MYPY_VENV ]]; then
python3 --version
python3 -m pip --version
python -m pip install virtualenv
python3 -m virtualenv $MYPY_VENV
$MYPY_VENV/bin/python3 -m pip install -r mypy_requirements.txt
python3 -m virtualenv "$MYPY_VENV"
"$MYPY_VENV"/bin/python3 -m pip install -r mypy_requirements.txt
fi
$MYPY_VENV/bin/mypy --version
"$MYPY_VENV"/bin/mypy --version
)

# Create unit tests venvs
Expand All @@ -63,14 +63,14 @@ for PY_VER in $PY_VER_UNIT_TESTS; do
UNIT_TESTS_VENV=venv_$PY_VER
eval "$(pyenv init --path)"
eval "$(pyenv init -)"
pyenv shell $PY_VER
pyenv shell "$PY_VER"

if [[ -z $SKIP_CLEAN ]] || [[ ! -e $UNIT_TESTS_VENV ]]; then
python -m pip install virtualenv
python -m virtualenv $UNIT_TESTS_VENV
$UNIT_TESTS_VENV/bin/python -m pip install -r test_requirements.txt
python -m virtualenv "$UNIT_TESTS_VENV"
"$UNIT_TESTS_VENV"/bin/python -m pip install -r test_requirements.txt
fi
$UNIT_TESTS_VENV/bin/py.test --version
"$UNIT_TESTS_VENV"/bin/py.test --version
)
done

Expand All @@ -79,19 +79,19 @@ MYPY_PROTOBUF_VENV=venv_$PY_VER_MYPY_PROTOBUF
(
eval "$(pyenv init --path)"
eval "$(pyenv init -)"
pyenv shell $PY_VER_MYPY_PROTOBUF
pyenv shell "$PY_VER_MYPY_PROTOBUF"

# Create virtualenv + Install requirements for mypy-protobuf
if [[ -z $SKIP_CLEAN ]] || [[ ! -e $MYPY_PROTOBUF_VENV ]]; then
python -m pip install virtualenv
python -m virtualenv $MYPY_PROTOBUF_VENV
$MYPY_PROTOBUF_VENV/bin/python -m pip install -e .
python -m virtualenv "$MYPY_PROTOBUF_VENV"
"$MYPY_PROTOBUF_VENV"/bin/python -m pip install -e .
fi
)

# Run mypy-protobuf
(
source $MYPY_PROTOBUF_VENV/bin/activate
source "$MYPY_PROTOBUF_VENV"/bin/activate

# Confirm version number
test "$(protoc-gen-mypy -V)" = "mypy-protobuf 3.4.0"
Expand Down Expand Up @@ -138,22 +138,22 @@ MYPY_PROTOBUF_VENV=venv_$PY_VER_MYPY_PROTOBUF

for PY_VER in $PY_VER_UNIT_TESTS; do
UNIT_TESTS_VENV=venv_$PY_VER
PY_VER_MYPY_TARGET=$(echo $PY_VER | cut -d. -f1-2)
PY_VER_MYPY_TARGET=$(echo "$PY_VER" | cut -d. -f1-2)

# Generate GRPC protos for mypy / tests
(
source $UNIT_TESTS_VENV/bin/activate
source "$UNIT_TESTS_VENV"/bin/activate
find proto/testproto/grpc -name "*.proto" -print0 | xargs -0 python -m grpc_tools.protoc "${PROTOC_ARGS[@]}" --grpc_python_out=test/generated
)

# Run mypy on unit tests / generated output
(
source $MYPY_VENV/bin/activate
source "$MYPY_VENV"/bin/activate
export MYPYPATH=$MYPYPATH:test/generated

# Run mypy
MODULES=( "-m" "test" )
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable=$UNIT_TESTS_VENV/bin/python --python-version="$PY_VER_MYPY_TARGET" "${MODULES[@]}"
MODULES=( -m test.test_generated_mypy -m test.test_grpc_usage -m test.test_grpc_async_usage )
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable="$UNIT_TESTS_VENV"/bin/python --python-version="$PY_VER_MYPY_TARGET" "${MODULES[@]}"

# Run stubtest. Stubtest does not work with python impl - only cpp impl
API_IMPL="$(python3 -c "import google.protobuf.internal.api_implementation as a ; print(a.Type())")"
Expand All @@ -173,12 +173,12 @@ for PY_VER in $PY_VER_UNIT_TESTS; do
cut -d: -f1,3- "$MYPY_OUTPUT/mypy_output" > "$MYPY_OUTPUT/mypy_output.omit_linenos"
}

call_mypy $PY_VER "${NEGATIVE_MODULES[@]}"
call_mypy "$PY_VER" "${NEGATIVE_MODULES[@]}"
if ! diff "$MYPY_OUTPUT/mypy_output" "test_negative/output.expected.$PY_VER_MYPY_TARGET" || ! diff "$MYPY_OUTPUT/mypy_output.omit_linenos" "test_negative/output.expected.$PY_VER_MYPY_TARGET.omit_linenos"; then
echo -e "${RED}test_negative/output.expected.$PY_VER_MYPY_TARGET didnt match. Copying over for you. Now rerun${NC}"

# Copy over all the mypy results for the developer.
call_mypy $PY_VER "${NEGATIVE_MODULES[@]}"
call_mypy "$PY_VER" "${NEGATIVE_MODULES[@]}"
cp "$MYPY_OUTPUT/mypy_output" test_negative/output.expected.3.8
cp "$MYPY_OUTPUT/mypy_output.omit_linenos" test_negative/output.expected.3.8.omit_linenos
exit 1
Expand All @@ -187,7 +187,7 @@ for PY_VER in $PY_VER_UNIT_TESTS; do

(
# Run unit tests.
source $UNIT_TESTS_VENV/bin/activate
source "$UNIT_TESTS_VENV"/bin/activate
PYTHONPATH=test/generated py.test --ignore=test/generated -v
)
done
4 changes: 4 additions & 0 deletions stubtest_allowlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ testproto.readme_enum_pb2._?MyEnum(EnumTypeWrapper)?
testproto.nested.nested_pb2.AnotherNested._?NestedEnum(EnumTypeWrapper)?
testproto.nested.nested_pb2.AnotherNested.NestedMessage._?NestedEnum2(EnumTypeWrapper)?

# Our fake async stubs are not there at runtime (yet)
testproto.grpc.dummy_pb2_grpc.DummyServiceAsyncStub
testproto.grpc.import_pb2_grpc.SimpleServiceAsyncStub

# Part of an "EXPERIMENTAL API" according to comment. Not documented.
testproto.grpc.dummy_pb2_grpc.DummyService
testproto.grpc.import_pb2_grpc.SimpleService
Expand Down
Loading

0 comments on commit ed77525

Please sign in to comment.