Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion api_core/google/api_core/bidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,11 @@ def pending_requests(self):
return self._request_queue.qsize()


def _never_terminate(future_or_error):
"""By default, no errors cause BiDi termination."""
return False


class ResumableBidiRpc(BidiRpc):
"""A :class:`BidiRpc` that can automatically resume the stream on errors.

Expand Down Expand Up @@ -391,6 +396,9 @@ def should_recover(exc):
should_recover (Callable[[Exception], bool]): A function that returns
True if the stream should be recovered. This will be called
whenever an error is encountered on the stream.
should_terminate (Callable[[Exception], bool]): A function that returns
True if the stream should be terminated. This will be called
whenever an error is encountered on the stream.
metadata Sequence[Tuple(str, str)]: RPC metadata to include in
the request.
throttle_reopen (bool): If ``True``, throttling will be applied to
Expand All @@ -401,12 +409,14 @@ def __init__(
self,
start_rpc,
should_recover,
should_terminate=_never_terminate,
initial_request=None,
metadata=None,
throttle_reopen=False,
):
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata)
self._should_recover = should_recover
self._should_terminate = should_terminate
self._operational_lock = threading.RLock()
self._finalized = False
self._finalize_lock = threading.Lock()
Expand All @@ -433,7 +443,9 @@ def _on_call_done(self, future):
# error, not for errors that we can recover from. Note that grpc's
# "future" here is also a grpc.RpcError.
with self._operational_lock:
if not self._should_recover(future):
if self._should_terminate(future):
self._finalize(future)
elif not self._should_recover(future):
self._finalize(future)
else:
_LOGGER.debug("Re-opening stream from gRPC callback.")
Expand Down Expand Up @@ -496,6 +508,12 @@ def _recoverable(self, method, *args, **kwargs):
with self._operational_lock:
_LOGGER.debug("Call to retryable %r caused %s.", method, exc)

if self._should_terminate(exc):
self.close()
_LOGGER.debug("Terminating %r due to %s.", method, exc)
self._finalize(exc)
break

if not self._should_recover(exc):
self.close()
_LOGGER.debug("Not retrying %r due to %s.", method, exc)
Expand Down
110 changes: 104 additions & 6 deletions api_core/tests/unit/test_bidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,33 +370,111 @@ def cancel(self):


class TestResumableBidiRpc(object):
def test_initial_state(self):
callback = mock.Mock()
callback.return_value = True
bidi_rpc = bidi.ResumableBidiRpc(None, callback)
def test_ctor_defaults(self):
start_rpc = mock.Mock()
should_recover = mock.Mock()
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)

assert bidi_rpc.is_active is False
assert bidi_rpc._finalized is False
assert bidi_rpc._start_rpc is start_rpc
assert bidi_rpc._should_recover is should_recover
assert bidi_rpc._should_terminate is bidi._never_terminate
assert bidi_rpc._initial_request is None
assert bidi_rpc._rpc_metadata is None
assert bidi_rpc._reopen_throttle is None

def test_ctor_explicit(self):
start_rpc = mock.Mock()
should_recover = mock.Mock()
should_terminate = mock.Mock()
initial_request = mock.Mock()
metadata = {"x-foo": "bar"}
bidi_rpc = bidi.ResumableBidiRpc(
start_rpc,
should_recover,
should_terminate=should_terminate,
initial_request=initial_request,
metadata=metadata,
throttle_reopen=True,
)

assert bidi_rpc.is_active is False
assert bidi_rpc._finalized is False
assert bidi_rpc._should_recover is should_recover
assert bidi_rpc._should_terminate is should_terminate
assert bidi_rpc._initial_request is initial_request
assert bidi_rpc._rpc_metadata == metadata
assert isinstance(bidi_rpc._reopen_throttle, bidi._Throttle)

def test_done_callbacks_terminate(self):
cancellation = mock.Mock()
start_rpc = mock.Mock()
should_recover = mock.Mock(spec=["__call__"], return_value=True)
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
bidi_rpc = bidi.ResumableBidiRpc(
start_rpc, should_recover, should_terminate=should_terminate
)
callback = mock.Mock(spec=["__call__"])

bidi_rpc.add_done_callback(callback)
bidi_rpc._on_call_done(cancellation)

should_terminate.assert_called_once_with(cancellation)
should_recover.assert_not_called()
callback.assert_called_once_with(cancellation)
assert not bidi_rpc.is_active

def test_done_callbacks_recoverable(self):
start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)
should_recover = mock.Mock(spec=["__call__"], return_value=True)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
callback = mock.Mock(spec=["__call__"])

bidi_rpc.add_done_callback(callback)
bidi_rpc._on_call_done(mock.sentinel.future)

callback.assert_not_called()
start_rpc.assert_called_once()
should_recover.assert_called_once_with(mock.sentinel.future)
assert bidi_rpc.is_active

def test_done_callbacks_non_recoverable(self):
bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
should_recover = mock.Mock(spec=["__call__"], return_value=False)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
callback = mock.Mock(spec=["__call__"])

bidi_rpc.add_done_callback(callback)
bidi_rpc._on_call_done(mock.sentinel.future)

callback.assert_called_once_with(mock.sentinel.future)
should_recover.assert_called_once_with(mock.sentinel.future)
assert not bidi_rpc.is_active

def test_send_terminate(self):
cancellation = ValueError()
call_1 = CallStub([cancellation], active=False)
call_2 = CallStub([])
start_rpc = mock.create_autospec(
grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2]
)
should_recover = mock.Mock(spec=["__call__"], return_value=False)
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)

bidi_rpc.open()

bidi_rpc.send(mock.sentinel.request)

assert bidi_rpc.pending_requests == 1
assert bidi_rpc._request_queue.get() is None

should_recover.assert_not_called()
should_terminate.assert_called_once_with(cancellation)
assert bidi_rpc.call == call_1
assert bidi_rpc.is_active is False
assert call_1.cancelled is True

def test_send_recover(self):
error = ValueError()
Expand Down Expand Up @@ -441,6 +519,26 @@ def test_send_failure(self):
assert bidi_rpc.pending_requests == 1
assert bidi_rpc._request_queue.get() is None

def test_recv_terminate(self):
cancellation = ValueError()
call = CallStub([cancellation])
start_rpc = mock.create_autospec(
grpc.StreamStreamMultiCallable, instance=True, return_value=call
)
should_recover = mock.Mock(spec=["__call__"], return_value=False)
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)

bidi_rpc.open()

bidi_rpc.recv()

should_recover.assert_not_called()
should_terminate.assert_called_once_with(cancellation)
assert bidi_rpc.call == call
assert bidi_rpc.is_active is False
assert call.cancelled is True

def test_recv_recover(self):
error = ValueError()
call_1 = CallStub([1, error])
Expand Down
28 changes: 14 additions & 14 deletions firestore/google/cloud/firestore_v1/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,8 @@
"DO_NOT_USE": -1,
}
_RPC_ERROR_THREAD_NAME = "Thread-OnRpcTerminated"
_RETRYABLE_STREAM_ERRORS = (
exceptions.DeadlineExceeded,
exceptions.ServiceUnavailable,
exceptions.InternalServerError,
exceptions.Unknown,
exceptions.GatewayTimeout,
)
_RECOVERABLE_STREAM_EXCEPTIONS = (exceptions.ServiceUnavailable,)
_TERMINATING_STREAM_EXCEPTIONS = (exceptions.Cancelled,)

DocTreeEntry = collections.namedtuple("DocTreeEntry", ["value", "index"])

Expand Down Expand Up @@ -153,6 +148,16 @@ def document_watch_comparator(doc1, doc2):
return 0


def _should_recover(exception):
wrapped = _maybe_wrap_exception(exception)
return isinstance(wrapped, _RECOVERABLE_STREAM_EXCEPTIONS)


def _should_terminate(exception):
wrapped = _maybe_wrap_exception(exception)
return isinstance(wrapped, _TERMINATING_STREAM_EXCEPTIONS)


class Watch(object):

BackgroundConsumer = BackgroundConsumer # FBO unit tests
Expand Down Expand Up @@ -199,12 +204,6 @@ def __init__(
self._closing = threading.Lock()
self._closed = False

def should_recover(exc): # pragma: NO COVER
return (
isinstance(exc, grpc.RpcError)
and exc.code() == grpc.StatusCode.UNAVAILABLE
)

initial_request = firestore_pb2.ListenRequest(
database=self._firestore._database_string, add_target=self._targets
)
Expand All @@ -214,8 +213,9 @@ def should_recover(exc): # pragma: NO COVER

self._rpc = ResumableBidiRpc(
self._api.transport.listen,
should_recover=_should_recover,
should_terminate=_should_terminate,
initial_request=initial_request,
should_recover=should_recover,
metadata=self._firestore._rpc_metadata,
)

Expand Down
10 changes: 9 additions & 1 deletion firestore/tests/unit/v1/test_cross_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,18 @@ def convert_precondition(precond):


class DummyRpc(object): # pragma: NO COVER
def __init__(self, listen, initial_request, should_recover, metadata=None):
def __init__(
self,
listen,
should_recover,
should_terminate=None,
initial_request=None,
metadata=None,
):
self.listen = listen
self.initial_request = initial_request
self.should_recover = should_recover
self.should_terminate = should_terminate
self.closed = False
self.callbacks = []
self._metadata = metadata
Expand Down
67 changes: 61 additions & 6 deletions firestore/tests/unit/v1/test_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,44 @@ def test_diff_doc(self):
self.assertRaises(AssertionError, self._callFUT, 1, 2)


class Test_should_recover(unittest.TestCase):
def _callFUT(self, exception):
from google.cloud.firestore_v1.watch import _should_recover

return _should_recover(exception)

def test_w_unavailable(self):
from google.api_core.exceptions import ServiceUnavailable

exception = ServiceUnavailable("testing")

self.assertTrue(self._callFUT(exception))

def test_w_non_recoverable(self):
exception = ValueError("testing")

self.assertFalse(self._callFUT(exception))


class Test_should_terminate(unittest.TestCase):
def _callFUT(self, exception):
from google.cloud.firestore_v1.watch import _should_terminate

return _should_terminate(exception)

def test_w_unavailable(self):
from google.api_core.exceptions import Cancelled

exception = Cancelled("testing")

self.assertTrue(self._callFUT(exception))

def test_w_non_recoverable(self):
exception = ValueError("testing")

self.assertFalse(self._callFUT(exception))


class TestWatch(unittest.TestCase):
def _makeOne(
self,
Expand Down Expand Up @@ -161,17 +199,26 @@ def _snapshot_callback(self, docs, changes, read_time):
self.snapshotted = (docs, changes, read_time)

def test_ctor(self):
from google.cloud.firestore_v1.proto import firestore_pb2
from google.cloud.firestore_v1.watch import _should_recover
from google.cloud.firestore_v1.watch import _should_terminate

inst = self._makeOne()
self.assertTrue(inst._consumer.started)
self.assertTrue(inst._rpc.callbacks, [inst._on_rpc_done])
self.assertIs(inst._rpc.start_rpc, inst._api.transport.listen)
self.assertIs(inst._rpc.should_recover, _should_recover)
self.assertIs(inst._rpc.should_terminate, _should_terminate)
self.assertIsInstance(inst._rpc.initial_request, firestore_pb2.ListenRequest)
self.assertEqual(inst._rpc.metadata, DummyFirestore._rpc_metadata)

def test__on_rpc_done(self):
from google.cloud.firestore_v1.watch import _RPC_ERROR_THREAD_NAME

inst = self._makeOne()
threading = DummyThreading()
with mock.patch("google.cloud.firestore_v1.watch.threading", threading):
inst._on_rpc_done(True)
from google.cloud.firestore_v1.watch import _RPC_ERROR_THREAD_NAME

self.assertTrue(threading.threads[_RPC_ERROR_THREAD_NAME].started)

def test_close(self):
Expand Down Expand Up @@ -835,13 +882,21 @@ def Thread(self, name, target, kwargs):


class DummyRpc(object):
def __init__(self, listen, initial_request, should_recover, metadata=None):
self.listen = listen
self.initial_request = initial_request
def __init__(
self,
start_rpc,
should_recover,
should_terminate=None,
initial_request=None,
metadata=None,
):
self.start_rpc = start_rpc
self.should_recover = should_recover
self.should_terminate = should_terminate
self.initial_request = initial_request
self.metadata = metadata
self.closed = False
self.callbacks = []
self._metadata = metadata

def add_done_callback(self, callback):
self.callbacks.append(callback)
Expand Down