diff --git a/docs/spanner/snapshot-usage.rst b/docs/spanner/snapshot-usage.rst index ba31425a54b4..90c6d0322b4d 100644 --- a/docs/spanner/snapshot-usage.rst +++ b/docs/spanner/snapshot-usage.rst @@ -62,26 +62,6 @@ fails if the result set is too large, manually, perform all iteration within the context of the ``with database.snapshot()`` block. -.. note:: - - If streaming a chunk raises an exception, the application can - retry the ``read``, passing the ``resume_token`` from ``StreamingResultSet`` - which raised the error. E.g.: - - .. code:: python - - result = snapshot.read(table, columns, keys) - while True: - try: - for row in result.rows: - print row - except Exception: - result = snapshot.read( - table, columns, keys, resume_token=result.resume_token) - continue - else: - break - Execute a SQL Select Statement @@ -112,26 +92,6 @@ fails if the result set is too large, manually, perform all iteration within the context of the ``with database.snapshot()`` block. -.. note:: - - If streaming a chunk raises an exception, the application can - retry the query, passing the ``resume_token`` from ``StreamingResultSet`` - which raised the error. E.g.: - - .. code:: python - - result = snapshot.execute_sql(QUERY) - while True: - try: - for row in result.rows: - print row - except Exception: - result = snapshot.execute_sql( - QUERY, resume_token=result.resume_token) - continue - else: - break - Next Step --------- diff --git a/docs/spanner/transaction-usage.rst b/docs/spanner/transaction-usage.rst index 0577bc2093b8..5c2e4a9bb5a2 100644 --- a/docs/spanner/transaction-usage.rst +++ b/docs/spanner/transaction-usage.rst @@ -32,12 +32,6 @@ fails if the result set is too large, for row in result.rows: print(row) -.. note:: - - If streaming a chunk fails due to a "resumable" error, - :meth:`Session.read` retries the ``StreamingRead`` API reqeust, - passing the ``resume_token`` from the last partial result streamed. - Execute a SQL Select Statement ------------------------------ diff --git a/spanner/google/cloud/spanner/session.py b/spanner/google/cloud/spanner/session.py index d513889053a7..94fd0f092366 100644 --- a/spanner/google/cloud/spanner/session.py +++ b/spanner/google/cloud/spanner/session.py @@ -165,8 +165,7 @@ def snapshot(self, **kw): return Snapshot(self, **kw) - def read(self, table, columns, keyset, index='', limit=0, - resume_token=b''): + def read(self, table, columns, keyset, index='', limit=0): """Perform a ``StreamingRead`` API request for rows in a table. :type table: str @@ -185,17 +184,12 @@ def read(self, table, columns, keyset, index='', limit=0, :type limit: int :param limit: (Optional) maxiumn number of rows to return - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted read - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - return self.snapshot().read( - table, columns, keyset, index, limit, resume_token) + return self.snapshot().read(table, columns, keyset, index, limit) - def execute_sql(self, sql, params=None, param_types=None, query_mode=None, - resume_token=b''): + def execute_sql(self, sql, params=None, param_types=None, query_mode=None): """Perform an ``ExecuteStreamingSql`` API request. :type sql: str @@ -216,14 +210,11 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, :param query_mode: Mode governing return of results / query plan. 
See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted query - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ return self.snapshot().execute_sql( - sql, params, param_types, query_mode, resume_token) + sql, params, param_types, query_mode) def batch(self): """Factory to create a batch for this session. diff --git a/spanner/google/cloud/spanner/snapshot.py b/spanner/google/cloud/spanner/snapshot.py index 89bd840000dc..0eb705cd3270 100644 --- a/spanner/google/cloud/spanner/snapshot.py +++ b/spanner/google/cloud/spanner/snapshot.py @@ -14,6 +14,8 @@ """Model a set of read-only queries to a database as a snapshot.""" +import functools + from google.protobuf.struct_pb2 import Struct from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionOptions from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionSelector @@ -49,8 +51,7 @@ def _make_txn_selector(self): # pylint: disable=redundant-returns-doc """ raise NotImplementedError - def read(self, table, columns, keyset, index='', limit=0, - resume_token=b''): + def read(self, table, columns, keyset, index='', limit=0): """Perform a ``StreamingRead`` API request for rows in a table. :type table: str @@ -69,9 +70,6 @@ def read(self, table, columns, keyset, index='', limit=0, :type limit: int :param limit: (Optional) maxiumn number of rows to return - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted read - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. :raises ValueError: @@ -91,18 +89,22 @@ def read(self, table, columns, keyset, index='', limit=0, iterator = api.streaming_read( self._session.name, table, columns, keyset.to_pb(), - transaction=transaction, index=index, limit=limit, - resume_token=resume_token, options=options) + index=index, limit=limit, + transaction=transaction, options=options) self._read_request_count += 1 + restart = functools.partial( + api.streaming_read, self._session.name, table, columns, keyset, + index=index, limit=limit, + transaction=transaction, options=options) + if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet(iterator, restart, source=self) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, restart) - def execute_sql(self, sql, params=None, param_types=None, query_mode=None, - resume_token=b''): + def execute_sql(self, sql, params=None, param_types=None, query_mode=None): """Perform an ``ExecuteStreamingSql`` API request for rows in a table. :type sql: str @@ -122,9 +124,6 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, :param query_mode: Mode governing return of results / query plan. See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted query - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. 
:raises ValueError: @@ -152,15 +151,20 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, api = database.spanner_api iterator = api.execute_streaming_sql( self._session.name, sql, - transaction=transaction, params=params_pb, param_types=param_types, - query_mode=query_mode, resume_token=resume_token, options=options) + params=params_pb, param_types=param_types, query_mode=query_mode, + transaction=transaction, options=options) self._read_request_count += 1 + restart = functools.partial( + api.execute_streaming_sql, self._session.name, sql, + params=params, param_types=param_types, query_mode=query_mode, + transaction=transaction, options=options) + if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet(iterator, restart, source=self) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, restart) class Snapshot(_SnapshotBase): diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index c7d950d766d7..3120ced43803 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -34,12 +34,20 @@ class StreamedResultSet(object): :class:`google.cloud.proto.spanner.v1.result_set_pb2.PartialResultSet` instances. + :type restart: callable + :param restart: + Function (typically curried via :func:`functools.partial`) used to + restart the initial request if a retriable error is raised during + streaming. + :type source: :class:`~google.cloud.spanner.snapshot.Snapshot` :param source: Snapshot from which the result set was fetched. """ - def __init__(self, response_iterator, source=None): + def __init__(self, response_iterator, restart, source=None): self._response_iterator = response_iterator - self._rows = [] # Fully-processed rows + self._restart = restart + self._pending_rows = [] # Rows pending new token / EOT + self._complete_rows = [] # Fully-processed rows self._counter = 0 # Counter for processed responses self._metadata = None # Until set from first PRS self._stats = None # Until set from last PRS @@ -55,7 +63,7 @@ def rows(self): :rtype: list of row-data lists. :returns: list of completed row data, from proceesd PRS responses. """ - return self._rows + return self._complete_rows @property def fields(self): @@ -122,17 +130,50 @@ def _merge_values(self, values): field = self.fields[index] self._current_row.append(_parse_value_pb(value, field.type)) if len(self._current_row) == width: - self._rows.append(self._current_row) + self._pending_rows.append(self._current_row) self._current_row = [] + def _flush_pending_rows(self): + """Helper for :meth:`consume_next`.""" + flushed = self._pending_rows[:] + self._pending_rows[:] = () + self._complete_rows.extend(flushed) + + def _do_restart(self): + """Helper for :meth:`consume_next`.""" + self._pending_chunk = None + self._pending_rows[:] = () + self._current_row[:] = () + + if self._resume_token: + self._response_iterator = self._restart(self._resume_token) + else: + self._response_iterator = self._restart() + def consume_next(self): """Consume the next partial result set from the stream. - Parse the result set into new/existing rows in :attr:`_rows` + Parse the result set into new/existing rows in :attr:`_complete_rows` + + :raises StopIteration: if the iterator is empty. 
""" - response = six.next(self._response_iterator) + try: + response = six.next(self._response_iterator) + except exceptions.ServiceUnavailable: + self._do_restart() + return + except StopIteration: + self._flush_pending_rows() + raise + self._counter += 1 - self._resume_token = response.resume_token + if response.resume_token: + self._flush_pending_rows() + self._resume_token = response.resume_token + + if response.HasField('stats'): # last response + self._flush_pending_rows() + self._stats = response.stats if self._metadata is None: # first response metadata = self._metadata = response.metadata @@ -141,9 +182,6 @@ def consume_next(self): if source is not None and source._transaction_id is None: source._transaction_id = metadata.transaction.id - if response.HasField('stats'): # last response - self._stats = response.stats - values = list(response.values) if self._pending_chunk is not None: values[0] = self._merge_chunk(values[0]) @@ -162,11 +200,15 @@ def consume_all(self): break def __iter__(self): - iter_rows, self._rows[:] = self._rows[:], () + iter_rows, self._complete_rows[:] = self._complete_rows[:], () while True: if not iter_rows: - self.consume_next() # raises StopIteration - iter_rows, self._rows[:] = self._rows[:], () + try: + self.consume_next() # raises StopIteration + except StopIteration: + if not self._complete_rows: + raise + iter_rows, self._complete_rows[:] = self._complete_rows[:], () while iter_rows: yield iter_rows.pop(0) @@ -179,6 +221,7 @@ def one(self): in whole or in part. """ answer = self.one_or_none() + if answer is None: raise exceptions.NotFound('No rows matched the given query.') return answer @@ -196,21 +239,13 @@ def one_or_none(self): raise RuntimeError('Can not call `.one` or `.one_or_none` after ' 'stream consumption has already started.') - # Consume the first result of the stream. - # If there is no first result, then return None. - iterator = iter(self) - try: - answer = next(iterator) - except StopIteration: - return None + self.consume_all() - # Attempt to consume more. This should no-op; if we get additional - # rows, then this is an error case. 
- try: - next(iterator) + if len(self._complete_rows) > 1: raise ValueError('Expected one result; got more.') - except StopIteration: - return answer + + if self._complete_rows: + return self._complete_rows[0] class Unmergeable(ValueError): diff --git a/spanner/tests/unit/test_session.py b/spanner/tests/unit/test_session.py index 826369079d29..3c9d9e74af47 100644 --- a/spanner/tests/unit/test_session.py +++ b/spanner/tests/unit/test_session.py @@ -265,7 +265,6 @@ def test_read(self): KEYSET = KeySet(keys=KEYS) INDEX = 'email-address-index' LIMIT = 20 - TOKEN = b'DEADBEEF' database = _Database(self.DATABASE_NAME) session = self._make_one(database) session._session_id = 'DEADBEEF' @@ -279,28 +278,26 @@ def __init__(self, session, **kwargs): self._session = session self._kwargs = kwargs.copy() - def read(self, table, columns, keyset, index='', limit=0, - resume_token=b''): + def read(self, table, columns, keyset, index='', limit=0): _read_with.append( - (table, columns, keyset, index, limit, resume_token)) + (table, columns, keyset, index, limit)) return expected with _Monkey(MUT, Snapshot=_Snapshot): found = session.read( TABLE_NAME, COLUMNS, KEYSET, - index=INDEX, limit=LIMIT, resume_token=TOKEN) + index=INDEX, limit=LIMIT) self.assertIs(found, expected) self.assertEqual(len(_read_with), 1) - (table, columns, key_set, index, limit, resume_token) = _read_with[0] + (table, columns, key_set, index, limit) = _read_with[0] self.assertEqual(table, TABLE_NAME) self.assertEqual(columns, COLUMNS) self.assertEqual(key_set, KEYSET) self.assertEqual(index, INDEX) self.assertEqual(limit, LIMIT) - self.assertEqual(resume_token, TOKEN) def test_execute_sql_not_created(self): SQL = 'SELECT first_name, age FROM citizens' @@ -315,7 +312,6 @@ def test_execute_sql_defaults(self): from google.cloud._testing import _Monkey SQL = 'SELECT first_name, age FROM citizens' - TOKEN = b'DEADBEEF' database = _Database(self.DATABASE_NAME) session = self._make_one(database) session._session_id = 'DEADBEEF' @@ -330,25 +326,23 @@ def __init__(self, session, **kwargs): self._kwargs = kwargs.copy() def execute_sql( - self, sql, params=None, param_types=None, query_mode=None, - resume_token=None): + self, sql, params=None, param_types=None, query_mode=None): _executed_sql_with.append( - (sql, params, param_types, query_mode, resume_token)) + (sql, params, param_types, query_mode)) return expected with _Monkey(MUT, Snapshot=_Snapshot): - found = session.execute_sql(SQL, resume_token=TOKEN) + found = session.execute_sql(SQL) self.assertIs(found, expected) self.assertEqual(len(_executed_sql_with), 1) - sql, params, param_types, query_mode, token = _executed_sql_with[0] + sql, params, param_types, query_mode = _executed_sql_with[0] self.assertEqual(sql, SQL) self.assertEqual(params, None) self.assertEqual(param_types, None) self.assertEqual(query_mode, None) - self.assertEqual(token, TOKEN) def test_batch_not_created(self): database = _Database(self.DATABASE_NAME) diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index 4717a14c2f24..b10368676750 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -15,6 +15,8 @@ import unittest +import mock + from google.cloud._testing import _GAXBaseAPI @@ -149,7 +151,6 @@ def _read_helper(self, multi_use, first=True, count=0): KEYSET = KeySet(keys=KEYS) INDEX = 'email-address-index' LIMIT = 20 - TOKEN = b'DEADBEEF' database = _Database() api = database.spanner_api = _FauxSpannerAPI( 
_streaming_read_response=_MockCancellableIterator(*result_sets)) @@ -160,9 +161,24 @@ def _read_helper(self, multi_use, first=True, count=0): if not first: derived._transaction_id = TXN_ID - result_set = derived.read( - TABLE_NAME, COLUMNS, KEYSET, - index=INDEX, limit=LIMIT, resume_token=TOKEN) + partial_patch = mock.patch('functools.partial') + owp_patch = mock.patch( + 'google.cloud.spanner.snapshot._options_with_prefix') + + with partial_patch as patch_partial: + with owp_patch as patch_owp: + result_set = derived.read( + TABLE_NAME, COLUMNS, KEYSET, + index=INDEX, limit=LIMIT) + + self.assertIs(result_set._restart, patch_partial.return_value) + patch_partial.assert_called_once_with( + api.streaming_read, session.name, TABLE_NAME, COLUMNS, KEYSET, + transaction=derived._make_txn_selector(), + index=INDEX, limit=LIMIT, + options=patch_owp.return_value, + ) + patch_owp.assert_called_once_with(database.name) self.assertEqual(derived._read_request_count, count + 1) @@ -193,9 +209,8 @@ def _read_helper(self, multi_use, first=True, count=0): self.assertTrue(transaction.single_use.read_only.strong) self.assertEqual(index, INDEX) self.assertEqual(limit, LIMIT) - self.assertEqual(resume_token, TOKEN) - self.assertEqual(options.kwargs['metadata'], - [('google-cloud-resource-prefix', database.name)]) + self.assertEqual(resume_token, b'') + self.assertIs(options, patch_owp.return_value) def test_read_wo_multi_use(self): self._read_helper(multi_use=False) @@ -262,6 +277,7 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): from google.cloud.proto.spanner.v1.type_pb2 import Type, StructType from google.cloud.proto.spanner.v1.type_pb2 import STRING, INT64 from google.cloud.spanner._helpers import _make_value_pb + from google.cloud.spanner._helpers import _options_with_prefix TXN_ID = b'DEADBEEF' VALUES = [ @@ -273,7 +289,6 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): for row in VALUES ] MODE = 2 # PROFILE - TOKEN = b'DEADBEEF' struct_type_pb = StructType(fields=[ StructType.Field(name='first_name', type=Type(code=STRING)), StructType.Field(name='last_name', type=Type(code=STRING)), @@ -299,9 +314,24 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): if not first: derived._transaction_id = TXN_ID - result_set = derived.execute_sql( - SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, - query_mode=MODE, resume_token=TOKEN) + partial_patch = mock.patch('functools.partial') + owp_patch = mock.patch( + 'google.cloud.spanner.snapshot._options_with_prefix') + + with partial_patch as patch_partial: + with owp_patch as patch_owp: + result_set = derived.execute_sql( + SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, + query_mode=MODE) + + self.assertIs(result_set._restart, patch_partial.return_value) + patch_partial.assert_called_once_with( + api.execute_streaming_sql, session.name, SQL_QUERY_WITH_PARAM, + params=PARAMS, param_types=PARAM_TYPES, query_mode=MODE, + transaction=derived._make_txn_selector(), + options=patch_owp.return_value, + ) + patch_owp.assert_called_once_with(database.name) self.assertEqual(derived._read_request_count, count + 1) @@ -333,9 +363,8 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): self.assertEqual(params, expected_params) self.assertEqual(param_types, PARAM_TYPES) self.assertEqual(query_mode, MODE) - self.assertEqual(resume_token, TOKEN) - self.assertEqual(options.kwargs['metadata'], - [('google-cloud-resource-prefix', database.name)]) + self.assertEqual(resume_token, b'') + self.assertIs(options, patch_owp.return_value) def 
test_execute_sql_wo_multi_use(self): self._execute_sql_helper(multi_use=False) @@ -360,8 +389,6 @@ def test_execute_sql_w_multi_use_w_first_w_count_gt_0(self): class _MockCancellableIterator(object): - cancel_calls = 0 - def __init__(self, *values): self.iter_values = iter(values) @@ -725,7 +752,7 @@ def begin_transaction(self, session, options_, options=None): # pylint: disable=too-many-arguments def streaming_read(self, session, table, columns, key_set, transaction=None, index='', limit=0, - resume_token='', options=None): + resume_token=b'', options=None): from google.gax.errors import GaxError self._streaming_read_with = ( @@ -738,7 +765,7 @@ def streaming_read(self, session, table, columns, key_set, def execute_streaming_sql(self, session, sql, transaction=None, params=None, param_types=None, - resume_token='', query_mode=None, options=None): + resume_token=b'', query_mode=None, options=None): from google.gax.errors import GaxError self._executed_streaming_sql_with = ( diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 0e0bcb7aff6b..cb266e17d4a0 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -25,24 +25,30 @@ def _getTargetClass(self): return StreamedResultSet - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) + def _make_one(self, response_iterator, restart=object(), source=None): + return self._getTargetClass()( + response_iterator, restart, source=source) def test_ctor_defaults(self): iterator = _MockCancellableIterator() - streamed = self._make_one(iterator) + restart = object() + streamed = self._make_one(iterator, restart) self.assertIs(streamed._response_iterator, iterator) + self.assertIs(streamed._restart, restart) self.assertIsNone(streamed._source) self.assertEqual(streamed.rows, []) self.assertIsNone(streamed.metadata) self.assertIsNone(streamed.stats) self.assertIsNone(streamed.resume_token) + self.assertIsNone(streamed.resume_token) def test_ctor_w_source(self): iterator = _MockCancellableIterator() + restart = object() source = object() - streamed = self._make_one(iterator, source=source) + streamed = self._make_one(iterator, restart, source=source) self.assertIs(streamed._response_iterator, iterator) + self.assertIs(streamed._restart, restart) self.assertIs(streamed._source, source) self.assertEqual(streamed.rows, []) self.assertIsNone(streamed.metadata) @@ -128,7 +134,7 @@ def _make_result_set_stats(query_plan=None, **kw): @staticmethod def _make_partial_result_set( - values, metadata=None, stats=None, chunked_value=False): + values, metadata=None, stats=None, chunked_value=False, **kw): from google.cloud.proto.spanner.v1.result_set_pb2 import ( PartialResultSet) return PartialResultSet( @@ -136,6 +142,7 @@ def _make_partial_result_set( metadata=metadata, stats=stats, chunked_value=chunked_value, + **kw ) def test_properties_set(self): @@ -470,7 +477,7 @@ def test_merge_values_empty_and_empty(self): streamed._metadata = self._make_result_set_metadata(FIELDS) streamed._current_row = [] streamed._merge_values([]) - self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._pending_rows, []) self.assertEqual(streamed._current_row, []) def test_merge_values_empty_and_partial(self): @@ -486,7 +493,7 @@ def test_merge_values_empty_and_partial(self): VALUES = [self._make_value(bare) for bare in BARE] streamed._current_row = [] streamed._merge_values(VALUES) - self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._pending_rows, []) 
self.assertEqual(streamed._current_row, BARE) def test_merge_values_empty_and_filled(self): @@ -502,7 +509,7 @@ def test_merge_values_empty_and_filled(self): VALUES = [self._make_value(bare) for bare in BARE] streamed._current_row = [] streamed._merge_values(VALUES) - self.assertEqual(streamed.rows, [BARE]) + self.assertEqual(streamed._pending_rows, [BARE]) self.assertEqual(streamed._current_row, []) def test_merge_values_empty_and_filled_plus(self): @@ -522,7 +529,7 @@ def test_merge_values_empty_and_filled_plus(self): VALUES = [self._make_value(bare) for bare in BARE] streamed._current_row = [] streamed._merge_values(VALUES) - self.assertEqual(streamed.rows, [BARE[0:3], BARE[3:6]]) + self.assertEqual(streamed._pending_rows, [BARE[0:3], BARE[3:6]]) self.assertEqual(streamed._current_row, BARE[6:]) def test_merge_values_partial_and_empty(self): @@ -539,7 +546,7 @@ def test_merge_values_partial_and_empty(self): ] streamed._current_row[:] = BEFORE streamed._merge_values([]) - self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._pending_rows, []) self.assertEqual(streamed._current_row, BEFORE) def test_merge_values_partial_and_partial(self): @@ -556,7 +563,7 @@ def test_merge_values_partial_and_partial(self): MERGED = [42] TO_MERGE = [self._make_value(item) for item in MERGED] streamed._merge_values(TO_MERGE) - self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._pending_rows, []) self.assertEqual(streamed._current_row, BEFORE + MERGED) def test_merge_values_partial_and_filled(self): @@ -575,7 +582,7 @@ def test_merge_values_partial_and_filled(self): MERGED = [42, True] TO_MERGE = [self._make_value(item) for item in MERGED] streamed._merge_values(TO_MERGE) - self.assertEqual(streamed.rows, [BEFORE + MERGED]) + self.assertEqual(streamed._pending_rows, [BEFORE + MERGED]) self.assertEqual(streamed._current_row, []) def test_merge_values_partial_and_filled_plus(self): @@ -599,57 +606,121 @@ def test_merge_values_partial_and_filled_plus(self): TO_MERGE = [self._make_value(item) for item in MERGED] VALUES = BEFORE + MERGED streamed._merge_values(TO_MERGE) - self.assertEqual(streamed.rows, [VALUES[0:3], VALUES[3:6]]) + self.assertEqual(streamed._pending_rows, [VALUES[0:3], VALUES[3:6]]) self.assertEqual(streamed._current_row, VALUES[6:]) - def test_one_or_none_no_value(self): - streamed = self._make_one(_MockCancellableIterator()) - with mock.patch.object(streamed, 'consume_next') as consume_next: - consume_next.side_effect = StopIteration - self.assertIsNone(streamed.one_or_none()) + def test__flush_pending_rows_both_empty(self): + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator) + streamed._flush_pending_rows() + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, []) - def test_one_or_none_single_value(self): - streamed = self._make_one(_MockCancellableIterator()) - streamed._rows = ['foo'] - with mock.patch.object(streamed, 'consume_next') as consume_next: - consume_next.side_effect = StopIteration - self.assertEqual(streamed.one_or_none(), 'foo') + def test__flush_pending_rows_complete_empty(self): + row1, row2 = ('foo',), ('bar',) + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator) + streamed._pending_rows = [row1, row2] + streamed._flush_pending_rows() + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, [row1, row2]) - def test_one_or_none_multiple_values(self): - streamed = self._make_one(_MockCancellableIterator()) - streamed._rows = ['foo', 'bar'] - with 
self.assertRaises(ValueError): - streamed.one_or_none() + def test__flush_pending_rows_pending_empty(self): + row1, row2 = ('foo',), ('bar',) + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator) + streamed._complete_rows = [row1, row2] + streamed._flush_pending_rows() + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, [row1, row2]) - def test_one_or_none_consumed_stream(self): - streamed = self._make_one(_MockCancellableIterator()) - streamed._metadata = object() - with self.assertRaises(RuntimeError): - streamed.one_or_none() + def test__flush_pending_rows_neither_empty(self): + row1, row2, row3, row4 = ('foo',), ('bar',), ('baz',), ('spam',) + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator) + streamed._complete_rows = [row1, row2] + streamed._pending_rows = [row3, row4] + streamed._flush_pending_rows() + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, [row1, row2, row3, row4]) + + def test__do_restart_no_token(self): + restart = mock.Mock() + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator, restart=restart) - def test_one_single_value(self): - streamed = self._make_one(_MockCancellableIterator()) - streamed._rows = ['foo'] - with mock.patch.object(streamed, 'consume_next') as consume_next: - consume_next.side_effect = StopIteration - self.assertEqual(streamed.one(), 'foo') + streamed._do_restart() - def test_one_no_value(self): - from google.cloud import exceptions + restart.assert_called_once_with() - iterator = _MockCancellableIterator(['foo']) - streamed = self._make_one(iterator) - with mock.patch.object(streamed, 'consume_next') as consume_next: - consume_next.side_effect = StopIteration - with self.assertRaises(exceptions.NotFound): - streamed.one() + def test__do_restart_empty(self): + TOKEN = b'DEADBEEF' + restart = mock.Mock() + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator, restart=restart) + streamed._resume_token = TOKEN + + streamed._do_restart() + + self.assertIs(streamed._response_iterator, restart.return_value) + self.assertIsNone(streamed._pending_chunk) + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed._current_row, []) + restart.assert_called_once_with(TOKEN) + + def test__do_restart_non_empty(self): + TOKEN = b'DEADBEEF' + restart = mock.Mock() + row1, row2, row3 = ('foo',), ('bar',), ('baz',) + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator, restart=restart) + streamed._resume_token = TOKEN + streamed._pending_chunk = self._make_value(u'Phred ') + streamed._current_row = [row1] + streamed._pending_rows = [row2, row3] + + streamed._do_restart() + + self.assertIs(streamed._response_iterator, restart.return_value) + self.assertIsNone(streamed._pending_chunk) + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed._current_row, []) + restart.assert_called_once_with(TOKEN) + + def test_consume_next_w_service_no_token_yet(self): + restart = mock.Mock() + iterator = _ServiceUnavailableIterator() + streamed = self._make_one(iterator, restart=restart) + + streamed.consume_next() + + self.assertIs(streamed._response_iterator, restart.return_value) + restart.assert_called_once_with() + + def test_consume_next_w_service_unavailable(self): + TOKEN = b'DEADBEEF' + restart = mock.Mock() + iterator = _ServiceUnavailableIterator() + streamed = self._make_one(iterator, restart=restart) + streamed._resume_token = TOKEN + + streamed.consume_next() + 
+ self.assertIs(streamed._response_iterator, restart.return_value) + restart.assert_called_once_with(TOKEN) def test_consume_next_empty(self): + row1, row2 = ('foo',), ('bar',) iterator = _MockCancellableIterator() streamed = self._make_one(iterator) + streamed._pending_rows = [row1, row2] + with self.assertRaises(StopIteration): streamed.consume_next() + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, [row1, row2]) + def test_consume_next_first_set_partial(self): TXN_ID = b'DEADBEEF' FIELDS = [ @@ -665,36 +736,48 @@ def test_consume_next_first_set_partial(self): iterator = _MockCancellableIterator(result_set) source = mock.Mock(_transaction_id=None, spec=['_transaction_id']) streamed = self._make_one(iterator, source=source) + streamed.consume_next() + self.assertEqual(streamed.rows, []) self.assertEqual(streamed._current_row, BARE) self.assertEqual(streamed.metadata, metadata) - self.assertEqual(streamed.resume_token, result_set.resume_token) + self.assertEqual(streamed.resume_token, None) self.assertEqual(source._transaction_id, TXN_ID) def test_consume_next_first_set_partial_existing_txn_id(self): TXN_ID = b'DEADBEEF' + PRIOR_TOKEN = b'B4' + NEW_TOKEN = b'AF2' FIELDS = [ self._make_scalar_field('full_name', 'STRING'), self._make_scalar_field('age', 'INT64'), self._make_scalar_field('married', 'BOOL'), ] + row1, row2 = ('foo',), ('bar',) metadata = self._make_result_set_metadata( FIELDS, transaction_id=b'') BARE = [u'Phred Phlyntstone', 42] VALUES = [self._make_value(bare) for bare in BARE] - result_set = self._make_partial_result_set(VALUES, metadata=metadata) + result_set = self._make_partial_result_set( + VALUES, metadata=metadata, resume_token=NEW_TOKEN) iterator = _MockCancellableIterator(result_set) source = mock.Mock(_transaction_id=TXN_ID, spec=['_transaction_id']) streamed = self._make_one(iterator, source=source) + streamed._resume_token = PRIOR_TOKEN + streamed._pending_rows = [row1, row2] + streamed.consume_next() - self.assertEqual(streamed.rows, []) + + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, [row1, row2]) self.assertEqual(streamed._current_row, BARE) self.assertEqual(streamed.metadata, metadata) - self.assertEqual(streamed.resume_token, result_set.resume_token) + self.assertEqual(streamed.resume_token, NEW_TOKEN) self.assertEqual(source._transaction_id, TXN_ID) - def test_consume_next_w_partial_result(self): + def test_consume_next_w_partial_result_no_resume_token(self): + PRIOR_TOKEN = b'DEADBEEF' FIELDS = [ self._make_scalar_field('full_name', 'STRING'), self._make_scalar_field('age', 'INT64'), @@ -706,12 +789,15 @@ def test_consume_next_w_partial_result(self): result_set = self._make_partial_result_set(VALUES, chunked_value=True) iterator = _MockCancellableIterator(result_set) streamed = self._make_one(iterator) + streamed._resume_token = PRIOR_TOKEN streamed._metadata = self._make_result_set_metadata(FIELDS) + streamed.consume_next() + self.assertEqual(streamed.rows, []) self.assertEqual(streamed._current_row, []) self.assertEqual(streamed._pending_chunk, VALUES[0]) - self.assertEqual(streamed.resume_token, result_set.resume_token) + self.assertEqual(streamed.resume_token, PRIOR_TOKEN) def test_consume_next_w_pending_chunk(self): FIELDS = [ @@ -730,14 +816,16 @@ def test_consume_next_w_pending_chunk(self): streamed = self._make_one(iterator) streamed._metadata = self._make_result_set_metadata(FIELDS) streamed._pending_chunk = self._make_value(u'Phred ') + streamed.consume_next() - 
self.assertEqual(streamed.rows, [ + + self.assertEqual(streamed._pending_rows, [ [u'Phred Phlyntstone', BARE[1], BARE[2]], [BARE[3], BARE[4], BARE[5]], ]) self.assertEqual(streamed._current_row, [BARE[6]]) self.assertIsNone(streamed._pending_chunk) - self.assertEqual(streamed.resume_token, result_set.resume_token) + self.assertEqual(streamed.resume_token, None) def test_consume_next_last_set(self): FIELDS = [ @@ -745,6 +833,7 @@ def test_consume_next_last_set(self): self._make_scalar_field('age', 'INT64'), self._make_scalar_field('married', 'BOOL'), ] + row1, row2 = ('foo',), ('bar',) metadata = self._make_result_set_metadata(FIELDS) stats = self._make_result_set_stats( rows_returned="1", @@ -757,11 +846,15 @@ def test_consume_next_last_set(self): iterator = _MockCancellableIterator(result_set) streamed = self._make_one(iterator) streamed._metadata = metadata + streamed._pending_rows = [row1, row2] + streamed.consume_next() - self.assertEqual(streamed.rows, [BARE]) + + self.assertEqual(streamed.rows, [row1, row2]) + self.assertEqual(streamed._pending_rows, [BARE]) self.assertEqual(streamed._current_row, []) self.assertEqual(streamed._stats, stats) - self.assertEqual(streamed.resume_token, result_set.resume_token) + self.assertEqual(streamed.resume_token, None) def test_consume_all_empty(self): iterator = _MockCancellableIterator() @@ -886,7 +979,7 @@ def test___iter___w_existing_rows_read(self): result_set2 = self._make_partial_result_set(VALUES[4:]) iterator = _MockCancellableIterator(result_set1, result_set2) streamed = self._make_one(iterator) - streamed._rows[:] = ALREADY + streamed._complete_rows[:] = ALREADY found = list(streamed) self.assertEqual(found, ALREADY + [ [BARE[0], BARE[1], BARE[2]], @@ -897,10 +990,63 @@ def test___iter___w_existing_rows_read(self): self.assertEqual(streamed._current_row, []) self.assertIsNone(streamed._pending_chunk) + def test_one_single_value(self): + streamed = self._make_one(_MockCancellableIterator()) + streamed._complete_rows = ['foo'] + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + self.assertEqual(streamed.one(), 'foo') -class _MockCancellableIterator(object): + def test_one_no_value(self): + from google.cloud import exceptions - cancel_calls = 0 + iterator = _MockCancellableIterator(['foo']) + streamed = self._make_one(iterator) + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + with self.assertRaises(exceptions.NotFound): + streamed.one() + + def test_one_or_none_no_value(self): + streamed = self._make_one(_MockCancellableIterator()) + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + self.assertIsNone(streamed.one_or_none()) + + def test_one_or_none_single_value(self): + row = ('foo',) + streamed = self._make_one(_MockCancellableIterator()) + streamed._complete_rows = [row] + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + self.assertEqual(streamed.one_or_none(), row) + + def test_one_or_none_multiple_values(self): + row1, row2 = ('foo',), ('bar',) + streamed = self._make_one(_MockCancellableIterator()) + streamed._complete_rows = [row1, row2] + with self.assertRaises(ValueError): + streamed.one_or_none() + + def test_one_or_none_consumed_stream(self): + streamed = self._make_one(_MockCancellableIterator()) + streamed._metadata = object() + with self.assertRaises(RuntimeError): + 
streamed.one_or_none() + + +class _ServiceUnavailableIterator(object): + + def next(self): + from google.cloud import exceptions + + raise exceptions.ServiceUnavailable('testing') + + def __next__(self): # pragma: NO COVER Py3k + return self.next() + + +class _MockCancellableIterator(object): def __init__(self, *values): self.iter_values = iter(values) @@ -921,8 +1067,9 @@ def _getTargetClass(self): return StreamedResultSet - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) + def _make_one(self, response_iterator, restart=object(), source=None): + return self._getTargetClass()( + response_iterator, restart, source=source) def _load_json_test(self, test_name): import os
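
Note on the mechanism this patch introduces: ``StreamedResultSet`` now owns resume-on-failure itself. Rows parsed from the stream are held in ``_pending_rows`` until the server supplies a ``resume_token`` (or the stream ends), at which point they are flushed into ``_complete_rows``; if ``consume_next`` hits a retriable error, it discards the pending state and re-issues the request through the ``restart`` partial built in ``snapshot.py``, passing the last token seen. The standalone sketch below mirrors that bookkeeping against a toy iterator. ``ResumableStream``, ``TransientError``, and ``flaky_then_ok`` are illustrative stand-ins, not part of the library.

```python
import functools


class TransientError(Exception):
    """Stands in for a retriable error such as exceptions.ServiceUnavailable."""


class ResumableStream(object):
    """Toy analogue of StreamedResultSet's restart bookkeeping."""

    def __init__(self, response_iterator, restart):
        self._response_iterator = response_iterator
        self._restart = restart          # e.g. functools.partial(api.streaming_read, ...)
        self._resume_token = None
        self._pending_rows = []          # parsed, not yet covered by a resume token
        self._complete_rows = []         # safe to hand back to the caller

    def _flush_pending_rows(self):
        self._complete_rows.extend(self._pending_rows)
        self._pending_rows = []

    def _do_restart(self):
        # Drop rows the server will re-send, then re-issue the request,
        # resuming from the last token if one has been seen.
        self._pending_rows = []
        if self._resume_token is not None:
            self._response_iterator = self._restart(self._resume_token)
        else:
            self._response_iterator = self._restart()

    def consume_next(self):
        try:
            response = next(self._response_iterator)
        except TransientError:
            self._do_restart()
            return
        except StopIteration:
            self._flush_pending_rows()   # end of stream: everything is final
            raise
        if response['resume_token']:
            self._flush_pending_rows()
            self._resume_token = response['resume_token']
        self._pending_rows.extend(response['rows'])

    def consume_all(self):
        while True:
            try:
                self.consume_next()
            except StopIteration:
                return self._complete_rows


def flaky_then_ok(resume_token=None):
    """Hypothetical server: drops the stream once, replays after a restart."""
    if resume_token is None:
        def first_attempt():
            yield {'rows': [['row-1']], 'resume_token': b''}
            yield {'rows': [['row-2']], 'resume_token': b'T1'}
            raise TransientError('stream dropped')
        return first_attempt()
    # Resuming with the token replays every row not yet flushed under it.
    return iter([{'rows': [['row-2'], ['row-3']], 'resume_token': b''}])


stream = ResumableStream(flaky_then_ok(), functools.partial(flaky_then_ok))
print(stream.consume_all())    # [['row-1'], ['row-2'], ['row-3']]
```

Pending rows are deliberately dropped on restart: anything not yet covered by a token will be re-sent by the resumed stream, so flushing only on token arrival (or end of stream) is what keeps the exposed rows free of duplicates and gaps.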
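
For callers, the ``.. note::`` blocks removed from ``snapshot-usage.rst`` and ``transaction-usage.rst`` reduce to plain iteration, since transient stream errors are now retried inside the result set. A minimal usage sketch, assuming ``database`` is an already-configured ``google.cloud.spanner`` ``Database``:

```python
# No manual resume_token loop is needed any more; the StreamedResultSet
# restarts the underlying stream itself on a retriable error.
with database.snapshot() as snapshot:
    result = snapshot.execute_sql(
        'SELECT first_name, age FROM citizens')
    for row in result:
        print(row)
```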