Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing rpc_metadata as part of creating a bidi #7514

Merged
merged 5 commits into from
Mar 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions api_core/google/api_core/bidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ class BidiRpc(object):

initial_request = example_pb2.StreamingRpcRequest(
setting='example')
rpc = BidiRpc(stub.StreamingRpc, initial_request=initial_request)
rpc = BidiRpc(
stub.StreamingRpc,
initial_request=initial_request,
metadata=[('name', 'value')]
)

rpc.open()

Expand All @@ -165,11 +169,14 @@ class BidiRpc(object):
Callable[None, protobuf.Message]]): The initial request to
yield. This is useful if an initial request is needed to start the
stream.
metadata (Sequence[Tuple(str, str)]): RPC metadata to include in
the request.
"""

def __init__(self, start_rpc, initial_request=None):
def __init__(self, start_rpc, initial_request=None, metadata=None):
self._start_rpc = start_rpc
self._initial_request = initial_request
self._rpc_metadata = metadata
self._request_queue = queue.Queue()
self._request_generator = None
self._is_active = False
Expand Down Expand Up @@ -200,7 +207,7 @@ def open(self):
request_generator = _RequestQueueGenerator(
self._request_queue, initial_request=self._initial_request
)
call = self._start_rpc(iter(request_generator))
call = self._start_rpc(iter(request_generator), metadata=self._rpc_metadata)
crwilcox marked this conversation as resolved.
Show resolved Hide resolved

request_generator.call = call

Expand Down Expand Up @@ -288,10 +295,14 @@ def should_recover(exc):
initial_request = example_pb2.StreamingRpcRequest(
setting='example')

rpc = ResumeableBidiRpc(
metadata = [('header_name', 'value')]

rpc = ResumableBidiRpc(
stub.StreamingRpc,
should_recover=should_recover,
initial_request=initial_request,
should_recover=should_recover)
metadata=metadata
)

rpc.open()

Expand All @@ -310,10 +321,12 @@ def should_recover(exc):
should_recover (Callable[[Exception], bool]): A function that returns
True if the stream should be recovered. This will be called
whenever an error is encountered on the stream.
metadata Sequence[Tuple(str, str)]: RPC metadata to include in
the request.
"""

def __init__(self, start_rpc, should_recover, initial_request=None):
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request)
def __init__(self, start_rpc, should_recover, initial_request=None, metadata=None):
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata)
self._should_recover = should_recover
self._operational_lock = threading.RLock()
self._finalized = False
Expand Down
2 changes: 1 addition & 1 deletion api_core/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def lint_setup_py(session):
def pytype(session):
"""Run type-checking."""
session.install(
".", "grpcio >= 1.8.2", "grpcio-gcp >= 0.2.2", "pytype >= 2018.9.26"
".", "grpcio >= 1.8.2", "grpcio-gcp >= 0.2.2", "pytype >= 2019.3.21"
)
session.run("pytype")

Expand Down
12 changes: 11 additions & 1 deletion api_core/tests/unit/test_bidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,10 @@ def make_rpc():
call = mock.create_autospec(_CallAndFuture, instance=True)
rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)

def rpc_side_effect(request):
def rpc_side_effect(request, metadata=None):
call.is_active.return_value = True
call.request = request
call.metadata = metadata
crwilcox marked this conversation as resolved.
Show resolved Hide resolved
return call

rpc.side_effect = rpc_side_effect
Expand Down Expand Up @@ -172,6 +173,15 @@ def test_done_callbacks(self):

callback.assert_called_once_with(mock.sentinel.future)

def test_metadata(self):
rpc, call = make_rpc()
bidi_rpc = bidi.BidiRpc(rpc, metadata=mock.sentinel.A)
assert bidi_rpc._rpc_metadata == mock.sentinel.A

bidi_rpc.open()
assert bidi_rpc.call == call
assert bidi_rpc.call.metadata == mock.sentinel.A

def test_open(self):
rpc, call = make_rpc()
bidi_rpc = bidi.BidiRpc(rpc)
Expand Down
3 changes: 2 additions & 1 deletion firestore/google/cloud/firestore_v1beta1/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,10 @@ def should_recover(exc): # pragma: NO COVER
ResumableBidiRpc = self.ResumableBidiRpc # FBO unit tests

self._rpc = ResumableBidiRpc(
self._api.transport._stubs["firestore_stub"].Listen,
self._api.transport.listen,
initial_request=initial_request,
should_recover=should_recover,
rpc_metadata=self._firestore._rpc_metadata,
)

self._rpc.add_done_callback(self._on_rpc_done)
Expand Down
3 changes: 2 additions & 1 deletion firestore/tests/unit/v1beta1/test_cross_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,13 @@ def convert_precondition(precond):


class DummyRpc(object): # pragma: NO COVER
def __init__(self, listen, initial_request, should_recover):
def __init__(self, listen, initial_request, should_recover, rpc_metadata=None):
self.listen = listen
self.initial_request = initial_request
self.should_recover = should_recover
self.closed = False
self.callbacks = []
self._rpc_metadata = rpc_metadata

def add_done_callback(self, callback):
self.callbacks.append(callback)
Expand Down
4 changes: 3 additions & 1 deletion firestore/tests/unit/v1beta1/test_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ def _to_protobuf(self):
class DummyFirestore(object):
_firestore_api = DummyFirestoreClient()
_database_string = "abc://bar/"
_rpc_metadata = None

def document(self, *document_path): # pragma: NO COVER
if len(document_path) == 1:
Expand Down Expand Up @@ -781,12 +782,13 @@ def Thread(self, name, target, kwargs):


class DummyRpc(object):
def __init__(self, listen, initial_request, should_recover):
def __init__(self, listen, initial_request, should_recover, rpc_metadata=None):
self.listen = listen
self.initial_request = initial_request
self.should_recover = should_recover
self.closed = False
self.callbacks = []
self._rpc_metadata = rpc_metadata

def add_done_callback(self, callback):
self.callbacks.append(callback)
Expand Down