diff --git a/gcloud/bigtable/row_data.py b/gcloud/bigtable/row_data.py index 66044244442d..e64a242f8507 100644 --- a/gcloud/bigtable/row_data.py +++ b/gcloud/bigtable/row_data.py @@ -279,3 +279,45 @@ def rows(self): # NOTE: To avoid duplicating large objects, this is just the # mutable private data. return self._rows + + def cancel(self): + """Cancels the iterator, closing the stream.""" + self._response_iterator.cancel() + + def consume_next(self): + """Consumes the next ``ReadRowsResponse`` from the stream. + + Parses the response and stores it as a :class:`PartialRowData` + in a dictionary owned by this object. + + :raises: :class:`StopIteration <exceptions.StopIteration>` if the + response iterator has no more responses to stream. + """ + read_rows_response = self._response_iterator.next() + row_key = read_rows_response.row_key + partial_row = self._rows.get(row_key) + if partial_row is None: + partial_row = self._rows[row_key] = PartialRowData(row_key) + # NOTE: This is not atomic in the case of failures. + partial_row.update_from_read_rows(read_rows_response) + + def consume_all(self, max_loops=None): + """Consume the streamed responses until there are no more. + + This simply calls :meth:`consume_next` until there are no + more to consume. + + :type max_loops: int + :param max_loops: (Optional) Maximum number of times to try to consume + an additional ``ReadRowsResponse``. You can use this + to avoid long wait times. 
+ """ + curr_loop = 0 + if max_loops is None: + max_loops = float('inf') + while curr_loop < max_loops: + curr_loop += 1 + try: + self.consume_next() + except StopIteration: + break diff --git a/gcloud/bigtable/test_row_data.py b/gcloud/bigtable/test_row_data.py index 62698d19d5d0..56b1c15f0655 100644 --- a/gcloud/bigtable/test_row_data.py +++ b/gcloud/bigtable/test_row_data.py @@ -386,6 +386,22 @@ def _getTargetClass(self): from gcloud.bigtable.row_data import PartialRowsData return PartialRowsData + def _getDoNothingClass(self): + klass = self._getTargetClass() + + class FakePartialRowsData(klass): + + def __init__(self, *args, **kwargs): + super(FakePartialRowsData, self).__init__(*args, **kwargs) + self._consumed = [] + + def consume_next(self): + value = self._response_iterator.next() + self._consumed.append(value) + return value + + return FakePartialRowsData + def _makeOne(self, *args, **kwargs): return self._getTargetClass()(*args, **kwargs) @@ -425,3 +441,84 @@ def test_rows_getter(self): partial_rows_data = self._makeOne(None) partial_rows_data._rows = value = object() self.assertTrue(partial_rows_data.rows is value) + + def test_cancel(self): + response_iterator = _MockCancellableIterator() + partial_rows_data = self._makeOne(response_iterator) + self.assertEqual(response_iterator.cancel_calls, 0) + partial_rows_data.cancel() + self.assertEqual(response_iterator.cancel_calls, 1) + + def test_consume_next(self): + from gcloud.bigtable._generated import ( + bigtable_service_messages_pb2 as messages_pb2) + from gcloud.bigtable.row_data import PartialRowData + + row_key = b'row-key' + value_pb = messages_pb2.ReadRowsResponse(row_key=row_key) + response_iterator = _MockCancellableIterator(value_pb) + partial_rows_data = self._makeOne(response_iterator) + self.assertEqual(partial_rows_data.rows, {}) + partial_rows_data.consume_next() + expected_rows = {row_key: PartialRowData(row_key)} + self.assertEqual(partial_rows_data.rows, expected_rows) + + def 
test_consume_next_row_exists(self): + from gcloud.bigtable._generated import ( + bigtable_service_messages_pb2 as messages_pb2) + from gcloud.bigtable.row_data import PartialRowData + + row_key = b'row-key' + chunk = messages_pb2.ReadRowsResponse.Chunk(commit_row=True) + value_pb = messages_pb2.ReadRowsResponse(row_key=row_key, + chunks=[chunk]) + response_iterator = _MockCancellableIterator(value_pb) + partial_rows_data = self._makeOne(response_iterator) + existing_values = PartialRowData(row_key) + partial_rows_data._rows[row_key] = existing_values + self.assertFalse(existing_values.committed) + partial_rows_data.consume_next() + self.assertTrue(existing_values.committed) + self.assertEqual(existing_values.cells, {}) + + def test_consume_next_empty_iter(self): + response_iterator = _MockCancellableIterator() + partial_rows_data = self._makeOne(response_iterator) + with self.assertRaises(StopIteration): + partial_rows_data.consume_next() + + def test_consume_all(self): + klass = self._getDoNothingClass() + + value1, value2, value3 = object(), object(), object() + response_iterator = _MockCancellableIterator(value1, value2, value3) + partial_rows_data = klass(response_iterator) + self.assertEqual(partial_rows_data._consumed, []) + partial_rows_data.consume_all() + self.assertEqual(partial_rows_data._consumed, [value1, value2, value3]) + + def test_consume_all_with_max_loops(self): + klass = self._getDoNothingClass() + + value1, value2, value3 = object(), object(), object() + response_iterator = _MockCancellableIterator(value1, value2, value3) + partial_rows_data = klass(response_iterator) + self.assertEqual(partial_rows_data._consumed, []) + partial_rows_data.consume_all(max_loops=1) + self.assertEqual(partial_rows_data._consumed, [value1]) + # Make sure the iterator still has the remaining values. 
+ self.assertEqual(list(response_iterator.iter_values), [value2, value3]) + + +class _MockCancellableIterator(object): + + cancel_calls = 0 + + def __init__(self, *values): + self.iter_values = iter(values) + + def cancel(self): + self.cancel_calls += 1 + + def next(self): + return next(self.iter_values)