From 487d17860b718b6af916da90cfa280b28a4af024 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Mon, 28 Aug 2017 11:57:16 -0400 Subject: [PATCH 01/14] Cleanups. --- spanner/tests/unit/test_snapshot.py | 2 -- spanner/tests/unit/test_streamed.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index 4717a14c2f24..27cfb1380703 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -360,8 +360,6 @@ def test_execute_sql_w_multi_use_w_first_w_count_gt_0(self): class _MockCancellableIterator(object): - cancel_calls = 0 - def __init__(self, *values): self.iter_values = iter(values) diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 0e0bcb7aff6b..19e0415a227c 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -900,8 +900,6 @@ def test___iter___w_existing_rows_read(self): class _MockCancellableIterator(object): - cancel_calls = 0 - def __init__(self, *values): self.iter_values = iter(values) From 2ce70407b9e62428f4a3db74637661f15a96fda5 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Mon, 28 Aug 2017 11:17:08 -0400 Subject: [PATCH 02/14] Add required 'retry' arg to 'StreamedResultSet'. Represents the curried function allowing retrying the request on a suitable error during streaming. --- spanner/google/cloud/spanner/streamed.py | 9 ++++++++- spanner/tests/unit/test_streamed.py | 17 +++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index c7d950d766d7..7038be3b737b 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -34,11 +34,18 @@ class StreamedResultSet(object): :class:`google.cloud.proto.spanner.v1.result_set_pb2.PartialResultSet` instances. + :type retry: callable + :param retry: + Function (typically curried via :func:`functools.partial`) used to + retry the initial request if a retriable error is raised during + streaming. + :type source: :class:`~google.cloud.spanner.snapshot.Snapshot` :param source: Snapshot from which the result set was fetched. 
""" - def __init__(self, response_iterator, source=None): + def __init__(self, response_iterator, retry, source=None): self._response_iterator = response_iterator + self._retry = retry self._rows = [] # Fully-processed rows self._counter = 0 # Counter for processed responses self._metadata = None # Until set from first PRS diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 19e0415a227c..43c3c69c1226 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -25,24 +25,29 @@ def _getTargetClass(self): return StreamedResultSet - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) + def _make_one(self, response_iterator, retry=object(), source=None): + return self._getTargetClass()(response_iterator, retry, source=source) def test_ctor_defaults(self): iterator = _MockCancellableIterator() - streamed = self._make_one(iterator) + retry = object() + streamed = self._make_one(iterator, retry) self.assertIs(streamed._response_iterator, iterator) + self.assertIs(streamed._retry, retry) self.assertIsNone(streamed._source) self.assertEqual(streamed.rows, []) self.assertIsNone(streamed.metadata) self.assertIsNone(streamed.stats) self.assertIsNone(streamed.resume_token) + self.assertIsNone(streamed.resume_token) def test_ctor_w_source(self): iterator = _MockCancellableIterator() + retry = object() source = object() - streamed = self._make_one(iterator, source=source) + streamed = self._make_one(iterator, retry, source=source) self.assertIs(streamed._response_iterator, iterator) + self.assertIs(streamed._retry, retry) self.assertIs(streamed._source, source) self.assertEqual(streamed.rows, []) self.assertIsNone(streamed.metadata) @@ -919,8 +924,8 @@ def _getTargetClass(self): return StreamedResultSet - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) + def _make_one(self, response_iterator, retry=object(), source=None): + return self._getTargetClass()(response_iterator, retry, source=source) def _load_json_test(self, test_name): import os From 8817a3240bf914a1d0331438cd061d854f381a79 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Mon, 28 Aug 2017 11:17:48 -0400 Subject: [PATCH 03/14] Curry arguments to retry request on a resumable error. 
--- spanner/google/cloud/spanner/snapshot.py | 19 ++++++++++++--- spanner/tests/unit/test_snapshot.py | 31 +++++++++++++++++++----- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/spanner/google/cloud/spanner/snapshot.py b/spanner/google/cloud/spanner/snapshot.py index 89bd840000dc..a10d23bd92d1 100644 --- a/spanner/google/cloud/spanner/snapshot.py +++ b/spanner/google/cloud/spanner/snapshot.py @@ -14,6 +14,8 @@ """Model a set of read-only queries to a database as a snapshot.""" +import functools + from google.protobuf.struct_pb2 import Struct from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionOptions from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionSelector @@ -96,10 +98,14 @@ def read(self, table, columns, keyset, index='', limit=0, self._read_request_count += 1 + retry = functools.partial( + api.streaming_read, self._session.name, table, columns, keyset, + index=index, limit=limit, resume_token=resume_token) + if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet(iterator, retry, source=self) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, retry) def execute_sql(self, sql, params=None, param_types=None, query_mode=None, resume_token=b''): @@ -157,10 +163,15 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, self._read_request_count += 1 + retry = functools.partial( + api.execute_streaming_sql, self._session.name, sql, + params=params, param_types=param_types, query_mode=query_mode, + resume_token=resume_token) + if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet(iterator, retry, source=self) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, retry) class Snapshot(_SnapshotBase): diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index 27cfb1380703..267d0b77af61 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -15,6 +15,8 @@ import unittest +import mock + from google.cloud._testing import _GAXBaseAPI @@ -160,9 +162,17 @@ def _read_helper(self, multi_use, first=True, count=0): if not first: derived._transaction_id = TXN_ID - result_set = derived.read( - TABLE_NAME, COLUMNS, KEYSET, - index=INDEX, limit=LIMIT, resume_token=TOKEN) + partial_patch = mock.patch('functools.partial') + + with partial_patch as patch: + result_set = derived.read( + TABLE_NAME, COLUMNS, KEYSET, + index=INDEX, limit=LIMIT, resume_token=TOKEN) + + self.assertIs(result_set._retry, patch.return_value) + patch.assert_called_once_with( + api.streaming_read, session.name, TABLE_NAME, COLUMNS, KEYSET, + index=INDEX, limit=LIMIT, resume_token=TOKEN) self.assertEqual(derived._read_request_count, count + 1) @@ -299,9 +309,18 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): if not first: derived._transaction_id = TXN_ID - result_set = derived.execute_sql( - SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, - query_mode=MODE, resume_token=TOKEN) + partial_patch = mock.patch('functools.partial') + + with partial_patch as patch: + result_set = derived.execute_sql( + SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, + query_mode=MODE, resume_token=TOKEN) + + self.assertIs(result_set._retry, patch.return_value) + patch.assert_called_once_with( + api.execute_streaming_sql, session.name, SQL_QUERY_WITH_PARAM, + params=PARAMS, param_types=PARAM_TYPES, query_mode=MODE, + resume_token=TOKEN) 
self.assertEqual(derived._read_request_count, count + 1) From 8334354b32b21816b9e7141e13bb1e9eb662cd02 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Mon, 28 Aug 2017 11:51:03 -0400 Subject: [PATCH 04/14] Remove 'resume_token' from 'Snapshot.{read,execute_sql}'. The token will be managed internally by the 'StreamedResultSet'. --- docs/spanner/snapshot-usage.rst | 40 ------------------------ docs/spanner/transaction-usage.rst | 6 ---- spanner/google/cloud/spanner/session.py | 17 +++------- spanner/google/cloud/spanner/snapshot.py | 21 ++++--------- spanner/tests/unit/test_session.py | 22 +++++-------- spanner/tests/unit/test_snapshot.py | 19 +++++------ 6 files changed, 26 insertions(+), 99 deletions(-) diff --git a/docs/spanner/snapshot-usage.rst b/docs/spanner/snapshot-usage.rst index ba31425a54b4..90c6d0322b4d 100644 --- a/docs/spanner/snapshot-usage.rst +++ b/docs/spanner/snapshot-usage.rst @@ -62,26 +62,6 @@ fails if the result set is too large, manually, perform all iteration within the context of the ``with database.snapshot()`` block. -.. note:: - - If streaming a chunk raises an exception, the application can - retry the ``read``, passing the ``resume_token`` from ``StreamingResultSet`` - which raised the error. E.g.: - - .. code:: python - - result = snapshot.read(table, columns, keys) - while True: - try: - for row in result.rows: - print row - except Exception: - result = snapshot.read( - table, columns, keys, resume_token=result.resume_token) - continue - else: - break - Execute a SQL Select Statement @@ -112,26 +92,6 @@ fails if the result set is too large, manually, perform all iteration within the context of the ``with database.snapshot()`` block. -.. note:: - - If streaming a chunk raises an exception, the application can - retry the query, passing the ``resume_token`` from ``StreamingResultSet`` - which raised the error. E.g.: - - .. code:: python - - result = snapshot.execute_sql(QUERY) - while True: - try: - for row in result.rows: - print row - except Exception: - result = snapshot.execute_sql( - QUERY, resume_token=result.resume_token) - continue - else: - break - Next Step --------- diff --git a/docs/spanner/transaction-usage.rst b/docs/spanner/transaction-usage.rst index 0577bc2093b8..5c2e4a9bb5a2 100644 --- a/docs/spanner/transaction-usage.rst +++ b/docs/spanner/transaction-usage.rst @@ -32,12 +32,6 @@ fails if the result set is too large, for row in result.rows: print(row) -.. note:: - - If streaming a chunk fails due to a "resumable" error, - :meth:`Session.read` retries the ``StreamingRead`` API reqeust, - passing the ``resume_token`` from the last partial result streamed. - Execute a SQL Select Statement ------------------------------ diff --git a/spanner/google/cloud/spanner/session.py b/spanner/google/cloud/spanner/session.py index d513889053a7..94fd0f092366 100644 --- a/spanner/google/cloud/spanner/session.py +++ b/spanner/google/cloud/spanner/session.py @@ -165,8 +165,7 @@ def snapshot(self, **kw): return Snapshot(self, **kw) - def read(self, table, columns, keyset, index='', limit=0, - resume_token=b''): + def read(self, table, columns, keyset, index='', limit=0): """Perform a ``StreamingRead`` API request for rows in a table. 
:type table: str @@ -185,17 +184,12 @@ def read(self, table, columns, keyset, index='', limit=0, :type limit: int :param limit: (Optional) maxiumn number of rows to return - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted read - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - return self.snapshot().read( - table, columns, keyset, index, limit, resume_token) + return self.snapshot().read(table, columns, keyset, index, limit) - def execute_sql(self, sql, params=None, param_types=None, query_mode=None, - resume_token=b''): + def execute_sql(self, sql, params=None, param_types=None, query_mode=None): """Perform an ``ExecuteStreamingSql`` API request. :type sql: str @@ -216,14 +210,11 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, :param query_mode: Mode governing return of results / query plan. See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted query - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ return self.snapshot().execute_sql( - sql, params, param_types, query_mode, resume_token) + sql, params, param_types, query_mode) def batch(self): """Factory to create a batch for this session. diff --git a/spanner/google/cloud/spanner/snapshot.py b/spanner/google/cloud/spanner/snapshot.py index a10d23bd92d1..8582033a076d 100644 --- a/spanner/google/cloud/spanner/snapshot.py +++ b/spanner/google/cloud/spanner/snapshot.py @@ -51,8 +51,7 @@ def _make_txn_selector(self): # pylint: disable=redundant-returns-doc """ raise NotImplementedError - def read(self, table, columns, keyset, index='', limit=0, - resume_token=b''): + def read(self, table, columns, keyset, index='', limit=0): """Perform a ``StreamingRead`` API request for rows in a table. :type table: str @@ -71,9 +70,6 @@ def read(self, table, columns, keyset, index='', limit=0, :type limit: int :param limit: (Optional) maxiumn number of rows to return - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted read - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. :raises ValueError: @@ -94,21 +90,20 @@ def read(self, table, columns, keyset, index='', limit=0, iterator = api.streaming_read( self._session.name, table, columns, keyset.to_pb(), transaction=transaction, index=index, limit=limit, - resume_token=resume_token, options=options) + options=options) self._read_request_count += 1 retry = functools.partial( api.streaming_read, self._session.name, table, columns, keyset, - index=index, limit=limit, resume_token=resume_token) + index=index, limit=limit) if self._multi_use: return StreamedResultSet(iterator, retry, source=self) else: return StreamedResultSet(iterator, retry) - def execute_sql(self, sql, params=None, param_types=None, query_mode=None, - resume_token=b''): + def execute_sql(self, sql, params=None, param_types=None, query_mode=None): """Perform an ``ExecuteStreamingSql`` API request for rows in a table. :type sql: str @@ -128,9 +123,6 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, :param query_mode: Mode governing return of results / query plan. 
See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted query - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. :raises ValueError: @@ -159,14 +151,13 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, iterator = api.execute_streaming_sql( self._session.name, sql, transaction=transaction, params=params_pb, param_types=param_types, - query_mode=query_mode, resume_token=resume_token, options=options) + query_mode=query_mode, options=options) self._read_request_count += 1 retry = functools.partial( api.execute_streaming_sql, self._session.name, sql, - params=params, param_types=param_types, query_mode=query_mode, - resume_token=resume_token) + params=params, param_types=param_types, query_mode=query_mode) if self._multi_use: return StreamedResultSet(iterator, retry, source=self) diff --git a/spanner/tests/unit/test_session.py b/spanner/tests/unit/test_session.py index 826369079d29..3c9d9e74af47 100644 --- a/spanner/tests/unit/test_session.py +++ b/spanner/tests/unit/test_session.py @@ -265,7 +265,6 @@ def test_read(self): KEYSET = KeySet(keys=KEYS) INDEX = 'email-address-index' LIMIT = 20 - TOKEN = b'DEADBEEF' database = _Database(self.DATABASE_NAME) session = self._make_one(database) session._session_id = 'DEADBEEF' @@ -279,28 +278,26 @@ def __init__(self, session, **kwargs): self._session = session self._kwargs = kwargs.copy() - def read(self, table, columns, keyset, index='', limit=0, - resume_token=b''): + def read(self, table, columns, keyset, index='', limit=0): _read_with.append( - (table, columns, keyset, index, limit, resume_token)) + (table, columns, keyset, index, limit)) return expected with _Monkey(MUT, Snapshot=_Snapshot): found = session.read( TABLE_NAME, COLUMNS, KEYSET, - index=INDEX, limit=LIMIT, resume_token=TOKEN) + index=INDEX, limit=LIMIT) self.assertIs(found, expected) self.assertEqual(len(_read_with), 1) - (table, columns, key_set, index, limit, resume_token) = _read_with[0] + (table, columns, key_set, index, limit) = _read_with[0] self.assertEqual(table, TABLE_NAME) self.assertEqual(columns, COLUMNS) self.assertEqual(key_set, KEYSET) self.assertEqual(index, INDEX) self.assertEqual(limit, LIMIT) - self.assertEqual(resume_token, TOKEN) def test_execute_sql_not_created(self): SQL = 'SELECT first_name, age FROM citizens' @@ -315,7 +312,6 @@ def test_execute_sql_defaults(self): from google.cloud._testing import _Monkey SQL = 'SELECT first_name, age FROM citizens' - TOKEN = b'DEADBEEF' database = _Database(self.DATABASE_NAME) session = self._make_one(database) session._session_id = 'DEADBEEF' @@ -330,25 +326,23 @@ def __init__(self, session, **kwargs): self._kwargs = kwargs.copy() def execute_sql( - self, sql, params=None, param_types=None, query_mode=None, - resume_token=None): + self, sql, params=None, param_types=None, query_mode=None): _executed_sql_with.append( - (sql, params, param_types, query_mode, resume_token)) + (sql, params, param_types, query_mode)) return expected with _Monkey(MUT, Snapshot=_Snapshot): - found = session.execute_sql(SQL, resume_token=TOKEN) + found = session.execute_sql(SQL) self.assertIs(found, expected) self.assertEqual(len(_executed_sql_with), 1) - sql, params, param_types, query_mode, token = _executed_sql_with[0] + sql, params, param_types, query_mode = _executed_sql_with[0] 
self.assertEqual(sql, SQL) self.assertEqual(params, None) self.assertEqual(param_types, None) self.assertEqual(query_mode, None) - self.assertEqual(token, TOKEN) def test_batch_not_created(self): database = _Database(self.DATABASE_NAME) diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index 267d0b77af61..931cc7d52127 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -151,7 +151,6 @@ def _read_helper(self, multi_use, first=True, count=0): KEYSET = KeySet(keys=KEYS) INDEX = 'email-address-index' LIMIT = 20 - TOKEN = b'DEADBEEF' database = _Database() api = database.spanner_api = _FauxSpannerAPI( _streaming_read_response=_MockCancellableIterator(*result_sets)) @@ -167,12 +166,12 @@ def _read_helper(self, multi_use, first=True, count=0): with partial_patch as patch: result_set = derived.read( TABLE_NAME, COLUMNS, KEYSET, - index=INDEX, limit=LIMIT, resume_token=TOKEN) + index=INDEX, limit=LIMIT) self.assertIs(result_set._retry, patch.return_value) patch.assert_called_once_with( api.streaming_read, session.name, TABLE_NAME, COLUMNS, KEYSET, - index=INDEX, limit=LIMIT, resume_token=TOKEN) + index=INDEX, limit=LIMIT) self.assertEqual(derived._read_request_count, count + 1) @@ -203,7 +202,7 @@ def _read_helper(self, multi_use, first=True, count=0): self.assertTrue(transaction.single_use.read_only.strong) self.assertEqual(index, INDEX) self.assertEqual(limit, LIMIT) - self.assertEqual(resume_token, TOKEN) + self.assertEqual(resume_token, b'') self.assertEqual(options.kwargs['metadata'], [('google-cloud-resource-prefix', database.name)]) @@ -283,7 +282,6 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): for row in VALUES ] MODE = 2 # PROFILE - TOKEN = b'DEADBEEF' struct_type_pb = StructType(fields=[ StructType.Field(name='first_name', type=Type(code=STRING)), StructType.Field(name='last_name', type=Type(code=STRING)), @@ -314,13 +312,12 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): with partial_patch as patch: result_set = derived.execute_sql( SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, - query_mode=MODE, resume_token=TOKEN) + query_mode=MODE) self.assertIs(result_set._retry, patch.return_value) patch.assert_called_once_with( api.execute_streaming_sql, session.name, SQL_QUERY_WITH_PARAM, - params=PARAMS, param_types=PARAM_TYPES, query_mode=MODE, - resume_token=TOKEN) + params=PARAMS, param_types=PARAM_TYPES, query_mode=MODE) self.assertEqual(derived._read_request_count, count + 1) @@ -352,7 +349,7 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): self.assertEqual(params, expected_params) self.assertEqual(param_types, PARAM_TYPES) self.assertEqual(query_mode, MODE) - self.assertEqual(resume_token, TOKEN) + self.assertEqual(resume_token, b'') self.assertEqual(options.kwargs['metadata'], [('google-cloud-resource-prefix', database.name)]) @@ -742,7 +739,7 @@ def begin_transaction(self, session, options_, options=None): # pylint: disable=too-many-arguments def streaming_read(self, session, table, columns, key_set, transaction=None, index='', limit=0, - resume_token='', options=None): + resume_token=b'', options=None): from google.gax.errors import GaxError self._streaming_read_with = ( @@ -755,7 +752,7 @@ def streaming_read(self, session, table, columns, key_set, def execute_streaming_sql(self, session, sql, transaction=None, params=None, param_types=None, - resume_token='', query_mode=None, options=None): + resume_token=b'', query_mode=None, options=None): from 
google.gax.errors import GaxError self._executed_streaming_sql_with = ( From 2abaccfee8ad12a3e2be3fb341ad34a68776211a Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Mon, 28 Aug 2017 14:33:43 -0400 Subject: [PATCH 05/14] Rename 'SRS.retry' -> 'restart' for clarity. --- spanner/google/cloud/spanner/snapshot.py | 12 ++++++------ spanner/google/cloud/spanner/streamed.py | 10 +++++----- spanner/tests/unit/test_snapshot.py | 4 ++-- spanner/tests/unit/test_streamed.py | 22 ++++++++++++---------- 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/spanner/google/cloud/spanner/snapshot.py b/spanner/google/cloud/spanner/snapshot.py index 8582033a076d..9636bd9762e3 100644 --- a/spanner/google/cloud/spanner/snapshot.py +++ b/spanner/google/cloud/spanner/snapshot.py @@ -94,14 +94,14 @@ def read(self, table, columns, keyset, index='', limit=0): self._read_request_count += 1 - retry = functools.partial( + restart = functools.partial( api.streaming_read, self._session.name, table, columns, keyset, index=index, limit=limit) if self._multi_use: - return StreamedResultSet(iterator, retry, source=self) + return StreamedResultSet(iterator, restart, source=self) else: - return StreamedResultSet(iterator, retry) + return StreamedResultSet(iterator, restart) def execute_sql(self, sql, params=None, param_types=None, query_mode=None): """Perform an ``ExecuteStreamingSql`` API request for rows in a table. @@ -155,14 +155,14 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None): self._read_request_count += 1 - retry = functools.partial( + restart = functools.partial( api.execute_streaming_sql, self._session.name, sql, params=params, param_types=param_types, query_mode=query_mode) if self._multi_use: - return StreamedResultSet(iterator, retry, source=self) + return StreamedResultSet(iterator, restart, source=self) else: - return StreamedResultSet(iterator, retry) + return StreamedResultSet(iterator, restart) class Snapshot(_SnapshotBase): diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index 7038be3b737b..d279b73283ee 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -34,18 +34,18 @@ class StreamedResultSet(object): :class:`google.cloud.proto.spanner.v1.result_set_pb2.PartialResultSet` instances. - :type retry: callable - :param retry: + :type restart: callable + :param restart: Function (typically curried via :func:`functools.partial`) used to - retry the initial request if a retriable error is raised during + restart the initial request if a retriable error is raised during streaming. :type source: :class:`~google.cloud.spanner.snapshot.Snapshot` :param source: Snapshot from which the result set was fetched. 
""" - def __init__(self, response_iterator, retry, source=None): + def __init__(self, response_iterator, restart, source=None): self._response_iterator = response_iterator - self._retry = retry + self._restart = restart self._rows = [] # Fully-processed rows self._counter = 0 # Counter for processed responses self._metadata = None # Until set from first PRS diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index 931cc7d52127..dee8fdecefd7 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -168,7 +168,7 @@ def _read_helper(self, multi_use, first=True, count=0): TABLE_NAME, COLUMNS, KEYSET, index=INDEX, limit=LIMIT) - self.assertIs(result_set._retry, patch.return_value) + self.assertIs(result_set._restart, patch.return_value) patch.assert_called_once_with( api.streaming_read, session.name, TABLE_NAME, COLUMNS, KEYSET, index=INDEX, limit=LIMIT) @@ -314,7 +314,7 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, query_mode=MODE) - self.assertIs(result_set._retry, patch.return_value) + self.assertIs(result_set._restart, patch.return_value) patch.assert_called_once_with( api.execute_streaming_sql, session.name, SQL_QUERY_WITH_PARAM, params=PARAMS, param_types=PARAM_TYPES, query_mode=MODE) diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 43c3c69c1226..6e5e855be24c 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -25,15 +25,16 @@ def _getTargetClass(self): return StreamedResultSet - def _make_one(self, response_iterator, retry=object(), source=None): - return self._getTargetClass()(response_iterator, retry, source=source) + def _make_one(self, response_iterator, restart=object(), source=None): + return self._getTargetClass()( + response_iterator, restart, source=source) def test_ctor_defaults(self): iterator = _MockCancellableIterator() - retry = object() - streamed = self._make_one(iterator, retry) + restart = object() + streamed = self._make_one(iterator, restart) self.assertIs(streamed._response_iterator, iterator) - self.assertIs(streamed._retry, retry) + self.assertIs(streamed._restart, restart) self.assertIsNone(streamed._source) self.assertEqual(streamed.rows, []) self.assertIsNone(streamed.metadata) @@ -43,11 +44,11 @@ def test_ctor_defaults(self): def test_ctor_w_source(self): iterator = _MockCancellableIterator() - retry = object() + restart = object() source = object() - streamed = self._make_one(iterator, retry, source=source) + streamed = self._make_one(iterator, restart, source=source) self.assertIs(streamed._response_iterator, iterator) - self.assertIs(streamed._retry, retry) + self.assertIs(streamed._restart, restart) self.assertIs(streamed._source, source) self.assertEqual(streamed.rows, []) self.assertIsNone(streamed.metadata) @@ -924,8 +925,9 @@ def _getTargetClass(self): return StreamedResultSet - def _make_one(self, response_iterator, retry=object(), source=None): - return self._getTargetClass()(response_iterator, retry, source=source) + def _make_one(self, response_iterator, restart=object(), source=None): + return self._getTargetClass()( + response_iterator, restart, source=source) def _load_json_test(self, test_name): import os From 7b7221e45106d221c62f8b53dfdd6c81603e211a Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Tue, 5 Sep 2017 12:22:30 -0400 Subject: [PATCH 06/14] Tests match order of methods. 
--- spanner/tests/unit/test_streamed.py | 84 ++++++++++++++--------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 6e5e855be24c..1b032982ddc0 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -608,48 +608,6 @@ def test_merge_values_partial_and_filled_plus(self): self.assertEqual(streamed.rows, [VALUES[0:3], VALUES[3:6]]) self.assertEqual(streamed._current_row, VALUES[6:]) - def test_one_or_none_no_value(self): - streamed = self._make_one(_MockCancellableIterator()) - with mock.patch.object(streamed, 'consume_next') as consume_next: - consume_next.side_effect = StopIteration - self.assertIsNone(streamed.one_or_none()) - - def test_one_or_none_single_value(self): - streamed = self._make_one(_MockCancellableIterator()) - streamed._rows = ['foo'] - with mock.patch.object(streamed, 'consume_next') as consume_next: - consume_next.side_effect = StopIteration - self.assertEqual(streamed.one_or_none(), 'foo') - - def test_one_or_none_multiple_values(self): - streamed = self._make_one(_MockCancellableIterator()) - streamed._rows = ['foo', 'bar'] - with self.assertRaises(ValueError): - streamed.one_or_none() - - def test_one_or_none_consumed_stream(self): - streamed = self._make_one(_MockCancellableIterator()) - streamed._metadata = object() - with self.assertRaises(RuntimeError): - streamed.one_or_none() - - def test_one_single_value(self): - streamed = self._make_one(_MockCancellableIterator()) - streamed._rows = ['foo'] - with mock.patch.object(streamed, 'consume_next') as consume_next: - consume_next.side_effect = StopIteration - self.assertEqual(streamed.one(), 'foo') - - def test_one_no_value(self): - from google.cloud import exceptions - - iterator = _MockCancellableIterator(['foo']) - streamed = self._make_one(iterator) - with mock.patch.object(streamed, 'consume_next') as consume_next: - consume_next.side_effect = StopIteration - with self.assertRaises(exceptions.NotFound): - streamed.one() - def test_consume_next_empty(self): iterator = _MockCancellableIterator() streamed = self._make_one(iterator) @@ -903,6 +861,48 @@ def test___iter___w_existing_rows_read(self): self.assertEqual(streamed._current_row, []) self.assertIsNone(streamed._pending_chunk) + def test_one_single_value(self): + streamed = self._make_one(_MockCancellableIterator()) + streamed._rows = ['foo'] + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + self.assertEqual(streamed.one(), 'foo') + + def test_one_no_value(self): + from google.cloud import exceptions + + iterator = _MockCancellableIterator(['foo']) + streamed = self._make_one(iterator) + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + with self.assertRaises(exceptions.NotFound): + streamed.one() + + def test_one_or_none_no_value(self): + streamed = self._make_one(_MockCancellableIterator()) + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + self.assertIsNone(streamed.one_or_none()) + + def test_one_or_none_single_value(self): + streamed = self._make_one(_MockCancellableIterator()) + streamed._rows = ['foo'] + with mock.patch.object(streamed, 'consume_next') as consume_next: + consume_next.side_effect = StopIteration + self.assertEqual(streamed.one_or_none(), 'foo') + + def test_one_or_none_multiple_values(self): + streamed = 
self._make_one(_MockCancellableIterator()) + streamed._rows = ['foo', 'bar'] + with self.assertRaises(ValueError): + streamed.one_or_none() + + def test_one_or_none_consumed_stream(self): + streamed = self._make_one(_MockCancellableIterator()) + streamed._metadata = object() + with self.assertRaises(RuntimeError): + streamed.one_or_none() + class _MockCancellableIterator(object): From fc25964f4361c3ec2d2b28647129874e9f55bc80 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Tue, 5 Sep 2017 12:32:23 -0400 Subject: [PATCH 07/14] Document/test that 'SRS.consume_next' propagates 'ServiceUnavailable'. --- spanner/google/cloud/spanner/streamed.py | 3 +++ spanner/tests/unit/test_streamed.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index d279b73283ee..b050ac90a574 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -136,6 +136,9 @@ def consume_next(self): """Consume the next partial result set from the stream. Parse the result set into new/existing rows in :attr:`_rows` + + :raises :class:`~google.api.core.exceptions.ServiceUnavailable`: + if the iterator must be restarted. """ response = six.next(self._response_iterator) self._counter += 1 diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 1b032982ddc0..792040ce174e 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -608,6 +608,14 @@ def test_merge_values_partial_and_filled_plus(self): self.assertEqual(streamed.rows, [VALUES[0:3], VALUES[3:6]]) self.assertEqual(streamed._current_row, VALUES[6:]) + def test_consume_next_propagates_unavailable(self): + from google.cloud import exceptions + + iterator = _ServiceUnavailableIterator() + streamed = self._make_one(iterator) + with self.assertRaises(exceptions.ServiceUnavailable): + streamed.consume_next() + def test_consume_next_empty(self): iterator = _MockCancellableIterator() streamed = self._make_one(iterator) @@ -904,6 +912,17 @@ def test_one_or_none_consumed_stream(self): streamed.one_or_none() +class _ServiceUnavailableIterator(object): + + def next(self): + from google.cloud import exceptions + + raise exceptions.ServiceUnavailable('testing') + + def __next__(self): # pragma: NO COVER Py3k + return self.next() + + class _MockCancellableIterator(object): def __init__(self, *values): From 1f688a6c82101cb1360a76fc4f5dbead3dfa616a Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Tue, 5 Sep 2017 12:42:02 -0400 Subject: [PATCH 08/14] In 'SRS.consume_next' don't copy empty 'resume_token'. 
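Only a non-empty token marks a checkpoint: the backend sends an empty
'resume_token' on most partial result sets, and copying it blindly
would discard the last good checkpoint right before a restart needs
it. A toy illustration of the guard (the class below is a stand-in,
not the real PartialResultSet):

    class _FakePRS(object):
        """Stand-in for PartialResultSet; only 'resume_token' matters."""
        def __init__(self, resume_token=b''):
            self.resume_token = resume_token

    token = b'B4'  # checkpoint from an earlier response
    for response in [_FakePRS(), _FakePRS(b'AF2'), _FakePRS()]:
        if response.resume_token:  # empty bytes are falsey
            token = response.resume_token

    assert token == b'AF2'  # empty tokens did not clobber the checkpoint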
--- spanner/google/cloud/spanner/streamed.py | 3 ++- spanner/tests/unit/test_streamed.py | 23 +++++++++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index b050ac90a574..6030e51afc01 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -142,7 +142,8 @@ def consume_next(self): """ response = six.next(self._response_iterator) self._counter += 1 - self._resume_token = response.resume_token + if response.resume_token: + self._resume_token = response.resume_token if self._metadata is None: # first response metadata = self._metadata = response.metadata diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 792040ce174e..24f4ca89bc68 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -134,7 +134,7 @@ def _make_result_set_stats(query_plan=None, **kw): @staticmethod def _make_partial_result_set( - values, metadata=None, stats=None, chunked_value=False): + values, metadata=None, stats=None, chunked_value=False, **kw): from google.cloud.proto.spanner.v1.result_set_pb2 import ( PartialResultSet) return PartialResultSet( @@ -142,6 +142,7 @@ def _make_partial_result_set( metadata=metadata, stats=stats, chunked_value=chunked_value, + **kw ) def test_properties_set(self): @@ -641,11 +642,13 @@ def test_consume_next_first_set_partial(self): self.assertEqual(streamed.rows, []) self.assertEqual(streamed._current_row, BARE) self.assertEqual(streamed.metadata, metadata) - self.assertEqual(streamed.resume_token, result_set.resume_token) + self.assertEqual(streamed.resume_token, None) self.assertEqual(source._transaction_id, TXN_ID) def test_consume_next_first_set_partial_existing_txn_id(self): TXN_ID = b'DEADBEEF' + PRIOR_TOKEN = b'B4' + NEW_TOKEN = b'AF2' FIELDS = [ self._make_scalar_field('full_name', 'STRING'), self._make_scalar_field('age', 'INT64'), @@ -655,18 +658,21 @@ def test_consume_next_first_set_partial_existing_txn_id(self): FIELDS, transaction_id=b'') BARE = [u'Phred Phlyntstone', 42] VALUES = [self._make_value(bare) for bare in BARE] - result_set = self._make_partial_result_set(VALUES, metadata=metadata) + result_set = self._make_partial_result_set( + VALUES, metadata=metadata, resume_token=NEW_TOKEN) iterator = _MockCancellableIterator(result_set) source = mock.Mock(_transaction_id=TXN_ID, spec=['_transaction_id']) streamed = self._make_one(iterator, source=source) + streamed._resume_token = PRIOR_TOKEN streamed.consume_next() self.assertEqual(streamed.rows, []) self.assertEqual(streamed._current_row, BARE) self.assertEqual(streamed.metadata, metadata) - self.assertEqual(streamed.resume_token, result_set.resume_token) + self.assertEqual(streamed.resume_token, NEW_TOKEN) self.assertEqual(source._transaction_id, TXN_ID) - def test_consume_next_w_partial_result(self): + def test_consume_next_w_partial_result_no_resume_token(self): + PRIOR_TOKEN = b'DEADBEEF' FIELDS = [ self._make_scalar_field('full_name', 'STRING'), self._make_scalar_field('age', 'INT64'), @@ -678,12 +684,13 @@ def test_consume_next_w_partial_result(self): result_set = self._make_partial_result_set(VALUES, chunked_value=True) iterator = _MockCancellableIterator(result_set) streamed = self._make_one(iterator) + streamed._resume_token = PRIOR_TOKEN streamed._metadata = self._make_result_set_metadata(FIELDS) streamed.consume_next() self.assertEqual(streamed.rows, []) 
self.assertEqual(streamed._current_row, []) self.assertEqual(streamed._pending_chunk, VALUES[0]) - self.assertEqual(streamed.resume_token, result_set.resume_token) + self.assertEqual(streamed.resume_token, PRIOR_TOKEN) def test_consume_next_w_pending_chunk(self): FIELDS = [ @@ -709,7 +716,7 @@ def test_consume_next_w_pending_chunk(self): ]) self.assertEqual(streamed._current_row, [BARE[6]]) self.assertIsNone(streamed._pending_chunk) - self.assertEqual(streamed.resume_token, result_set.resume_token) + self.assertEqual(streamed.resume_token, None) def test_consume_next_last_set(self): FIELDS = [ @@ -733,7 +740,7 @@ def test_consume_next_last_set(self): self.assertEqual(streamed.rows, [BARE]) self.assertEqual(streamed._current_row, []) self.assertEqual(streamed._stats, stats) - self.assertEqual(streamed.resume_token, result_set.resume_token) + self.assertEqual(streamed.resume_token, None) def test_consume_all_empty(self): iterator = _MockCancellableIterator() From db9a2abf32d4b9a0450df36df4185346f1f3742b Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Thu, 7 Sep 2017 09:16:01 -0400 Subject: [PATCH 09/14] 'SRS.one_or_none': use 'consume_all'. Rationale: the result set is only supposed to return a single row, so the complicated dance to avoid processing extra rows is not worthwhile. Also, we'd like to have the extra buffer handling in only one place. --- spanner/google/cloud/spanner/streamed.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index 6030e51afc01..fd3cf7317766 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -190,6 +190,7 @@ def one(self): in whole or in part. """ answer = self.one_or_none() + if answer is None: raise exceptions.NotFound('No rows matched the given query.') return answer @@ -207,21 +208,13 @@ def one_or_none(self): raise RuntimeError('Can not call `.one` or `.one_or_none` after ' 'stream consumption has already started.') - # Consume the first result of the stream. - # If there is no first result, then return None. - iterator = iter(self) - try: - answer = next(iterator) - except StopIteration: - return None + self.consume_all() - # Attempt to consume more. This should no-op; if we get additional - # rows, then this is an error case. - try: - next(iterator) + if len(self._rows) > 1: raise ValueError('Expected one result; got more.') - except StopIteration: - return answer + + if self._rows: + return self._rows[0] class Unmergeable(ValueError): From 6ce9864db3419039e1c74106d9215b23ecaf51f3 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Thu, 7 Sep 2017 10:00:10 -0400 Subject: [PATCH 10/14] Buffer accumulated rows until new token/last set/EOT. 
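Rows parsed since the last checkpoint are staged in a pending list and
promoted to the caller-visible rows only at a safe point: a new resume
token, the final result set (the one carrying stats), or end of
stream. A restart can then simply discard the pending list, so no row
is ever surfaced twice. A minimal sketch of the two-buffer scheme
(free functions stand in for the real methods):

    pending_rows = []   # parsed since the last checkpoint
    complete_rows = []  # promoted rows, safe to hand to the caller

    def flush_pending():
        complete_rows.extend(pending_rows)
        del pending_rows[:]

    def on_response(rows, resume_token=b'', is_last=False):
        if resume_token or is_last:  # safe point reached
            flush_pending()
        pending_rows.extend(rows)

    on_response([('row-1',)])                       # buffered only
    on_response([('row-2',)], resume_token=b'AF2')  # 'row-1' promoted
    assert complete_rows == [('row-1',)]
    assert pending_rows == [('row-2',)]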
--- spanner/google/cloud/spanner/streamed.py | 47 +++++++++----- spanner/tests/unit/test_streamed.py | 79 +++++++++++++++++++----- 2 files changed, 96 insertions(+), 30 deletions(-) diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index fd3cf7317766..7ac5eb53acc8 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -46,7 +46,8 @@ class StreamedResultSet(object): def __init__(self, response_iterator, restart, source=None): self._response_iterator = response_iterator self._restart = restart - self._rows = [] # Fully-processed rows + self._pending_rows = [] # Rows pending new token / EOT + self._complete_rows = [] # Fully-processed rows self._counter = 0 # Counter for processed responses self._metadata = None # Until set from first PRS self._stats = None # Until set from last PRS @@ -62,7 +63,7 @@ def rows(self): :rtype: list of row-data lists. :returns: list of completed row data, from proceesd PRS responses. """ - return self._rows + return self._complete_rows @property def fields(self): @@ -129,22 +130,39 @@ def _merge_values(self, values): field = self.fields[index] self._current_row.append(_parse_value_pb(value, field.type)) if len(self._current_row) == width: - self._rows.append(self._current_row) + self._pending_rows.append(self._current_row) self._current_row = [] + def _flush_pending_rows(self): + """Helper for :meth:`consume_next`.""" + flushed = self._pending_rows[:] + self._pending_rows[:] = () + self._complete_rows.extend(flushed) + def consume_next(self): """Consume the next partial result set from the stream. - Parse the result set into new/existing rows in :attr:`_rows` + Parse the result set into new/existing rows in :attr:`_complete_rows` :raises :class:`~google.api.core.exceptions.ServiceUnavailable`: if the iterator must be restarted. + :raises StopIteration: if the iterator is empty. 
""" - response = six.next(self._response_iterator) + try: + response = six.next(self._response_iterator) + except StopIteration: + self._flush_pending_rows() + raise + self._counter += 1 if response.resume_token: + self._flush_pending_rows() self._resume_token = response.resume_token + if response.HasField('stats'): # last response + self._flush_pending_rows() + self._stats = response.stats + if self._metadata is None: # first response metadata = self._metadata = response.metadata @@ -152,9 +170,6 @@ def consume_next(self): if source is not None and source._transaction_id is None: source._transaction_id = metadata.transaction.id - if response.HasField('stats'): # last response - self._stats = response.stats - values = list(response.values) if self._pending_chunk is not None: values[0] = self._merge_chunk(values[0]) @@ -173,11 +188,15 @@ def consume_all(self): break def __iter__(self): - iter_rows, self._rows[:] = self._rows[:], () + iter_rows, self._complete_rows[:] = self._complete_rows[:], () while True: if not iter_rows: - self.consume_next() # raises StopIteration - iter_rows, self._rows[:] = self._rows[:], () + try: + self.consume_next() # raises StopIteration + except StopIteration: + if not self._complete_rows: + raise + iter_rows, self._complete_rows[:] = self._complete_rows[:], () while iter_rows: yield iter_rows.pop(0) @@ -210,11 +229,11 @@ def one_or_none(self): self.consume_all() - if len(self._rows) > 1: + if len(self._complete_rows) > 1: raise ValueError('Expected one result; got more.') - if self._rows: - return self._rows[0] + if self._complete_rows: + return self._complete_rows[0] class Unmergeable(ValueError): diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 24f4ca89bc68..9e2e79a13354 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -477,7 +477,7 @@ def test_merge_values_empty_and_empty(self): streamed._metadata = self._make_result_set_metadata(FIELDS) streamed._current_row = [] streamed._merge_values([]) - self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._pending_rows, []) self.assertEqual(streamed._current_row, []) def test_merge_values_empty_and_partial(self): @@ -493,7 +493,7 @@ def test_merge_values_empty_and_partial(self): VALUES = [self._make_value(bare) for bare in BARE] streamed._current_row = [] streamed._merge_values(VALUES) - self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._pending_rows, []) self.assertEqual(streamed._current_row, BARE) def test_merge_values_empty_and_filled(self): @@ -509,7 +509,7 @@ def test_merge_values_empty_and_filled(self): VALUES = [self._make_value(bare) for bare in BARE] streamed._current_row = [] streamed._merge_values(VALUES) - self.assertEqual(streamed.rows, [BARE]) + self.assertEqual(streamed._pending_rows, [BARE]) self.assertEqual(streamed._current_row, []) def test_merge_values_empty_and_filled_plus(self): @@ -529,7 +529,7 @@ def test_merge_values_empty_and_filled_plus(self): VALUES = [self._make_value(bare) for bare in BARE] streamed._current_row = [] streamed._merge_values(VALUES) - self.assertEqual(streamed.rows, [BARE[0:3], BARE[3:6]]) + self.assertEqual(streamed._pending_rows, [BARE[0:3], BARE[3:6]]) self.assertEqual(streamed._current_row, BARE[6:]) def test_merge_values_partial_and_empty(self): @@ -546,7 +546,7 @@ def test_merge_values_partial_and_empty(self): ] streamed._current_row[:] = BEFORE streamed._merge_values([]) - self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._pending_rows, []) 
self.assertEqual(streamed._current_row, BEFORE) def test_merge_values_partial_and_partial(self): @@ -563,7 +563,7 @@ def test_merge_values_partial_and_partial(self): MERGED = [42] TO_MERGE = [self._make_value(item) for item in MERGED] streamed._merge_values(TO_MERGE) - self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._pending_rows, []) self.assertEqual(streamed._current_row, BEFORE + MERGED) def test_merge_values_partial_and_filled(self): @@ -582,7 +582,7 @@ def test_merge_values_partial_and_filled(self): MERGED = [42, True] TO_MERGE = [self._make_value(item) for item in MERGED] streamed._merge_values(TO_MERGE) - self.assertEqual(streamed.rows, [BEFORE + MERGED]) + self.assertEqual(streamed._pending_rows, [BEFORE + MERGED]) self.assertEqual(streamed._current_row, []) def test_merge_values_partial_and_filled_plus(self): @@ -606,9 +606,44 @@ def test_merge_values_partial_and_filled_plus(self): TO_MERGE = [self._make_value(item) for item in MERGED] VALUES = BEFORE + MERGED streamed._merge_values(TO_MERGE) - self.assertEqual(streamed.rows, [VALUES[0:3], VALUES[3:6]]) + self.assertEqual(streamed._pending_rows, [VALUES[0:3], VALUES[3:6]]) self.assertEqual(streamed._current_row, VALUES[6:]) + def test__flush_pending_rows_both_empty(self): + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator) + streamed._flush_pending_rows() + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, []) + + def test__flush_pending_rows_complete_empty(self): + row1, row2 = ('foo',), ('bar',) + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator) + streamed._pending_rows = [row1, row2] + streamed._flush_pending_rows() + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, [row1, row2]) + + def test__flush_pending_rows_pending_empty(self): + row1, row2 = ('foo',), ('bar',) + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator) + streamed._complete_rows = [row1, row2] + streamed._flush_pending_rows() + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, [row1, row2]) + + def test__flush_pending_rows_neither_empty(self): + row1, row2, row3, row4 = ('foo',), ('bar',), ('baz',), ('spam',) + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator) + streamed._complete_rows = [row1, row2] + streamed._pending_rows = [row3, row4] + streamed._flush_pending_rows() + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, [row1, row2, row3, row4]) + def test_consume_next_propagates_unavailable(self): from google.cloud import exceptions @@ -618,10 +653,14 @@ def test_consume_next_propagates_unavailable(self): streamed.consume_next() def test_consume_next_empty(self): + row1, row2 = ('foo',), ('bar',) iterator = _MockCancellableIterator() streamed = self._make_one(iterator) + streamed._pending_rows = [row1, row2] with self.assertRaises(StopIteration): streamed.consume_next() + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, [row1, row2]) def test_consume_next_first_set_partial(self): TXN_ID = b'DEADBEEF' @@ -654,6 +693,7 @@ def test_consume_next_first_set_partial_existing_txn_id(self): self._make_scalar_field('age', 'INT64'), self._make_scalar_field('married', 'BOOL'), ] + row1, row2 = ('foo',), ('bar',) metadata = self._make_result_set_metadata( FIELDS, transaction_id=b'') BARE = [u'Phred Phlyntstone', 42] @@ -664,8 +704,10 @@ def test_consume_next_first_set_partial_existing_txn_id(self): source = 
mock.Mock(_transaction_id=TXN_ID, spec=['_transaction_id']) streamed = self._make_one(iterator, source=source) streamed._resume_token = PRIOR_TOKEN + streamed._pending_rows = [row1, row2] streamed.consume_next() - self.assertEqual(streamed.rows, []) + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed.rows, [row1, row2]) self.assertEqual(streamed._current_row, BARE) self.assertEqual(streamed.metadata, metadata) self.assertEqual(streamed.resume_token, NEW_TOKEN) @@ -710,7 +752,7 @@ def test_consume_next_w_pending_chunk(self): streamed._metadata = self._make_result_set_metadata(FIELDS) streamed._pending_chunk = self._make_value(u'Phred ') streamed.consume_next() - self.assertEqual(streamed.rows, [ + self.assertEqual(streamed._pending_rows, [ [u'Phred Phlyntstone', BARE[1], BARE[2]], [BARE[3], BARE[4], BARE[5]], ]) @@ -724,6 +766,7 @@ def test_consume_next_last_set(self): self._make_scalar_field('age', 'INT64'), self._make_scalar_field('married', 'BOOL'), ] + row1, row2 = ('foo',), ('bar',) metadata = self._make_result_set_metadata(FIELDS) stats = self._make_result_set_stats( rows_returned="1", @@ -736,8 +779,10 @@ def test_consume_next_last_set(self): iterator = _MockCancellableIterator(result_set) streamed = self._make_one(iterator) streamed._metadata = metadata + streamed._pending_rows = [row1, row2] streamed.consume_next() - self.assertEqual(streamed.rows, [BARE]) + self.assertEqual(streamed.rows, [row1, row2]) + self.assertEqual(streamed._pending_rows, [BARE]) self.assertEqual(streamed._current_row, []) self.assertEqual(streamed._stats, stats) self.assertEqual(streamed.resume_token, None) @@ -865,7 +910,7 @@ def test___iter___w_existing_rows_read(self): result_set2 = self._make_partial_result_set(VALUES[4:]) iterator = _MockCancellableIterator(result_set1, result_set2) streamed = self._make_one(iterator) - streamed._rows[:] = ALREADY + streamed._complete_rows[:] = ALREADY found = list(streamed) self.assertEqual(found, ALREADY + [ [BARE[0], BARE[1], BARE[2]], @@ -878,7 +923,7 @@ def test___iter___w_existing_rows_read(self): def test_one_single_value(self): streamed = self._make_one(_MockCancellableIterator()) - streamed._rows = ['foo'] + streamed._complete_rows = ['foo'] with mock.patch.object(streamed, 'consume_next') as consume_next: consume_next.side_effect = StopIteration self.assertEqual(streamed.one(), 'foo') @@ -900,15 +945,17 @@ def test_one_or_none_no_value(self): self.assertIsNone(streamed.one_or_none()) def test_one_or_none_single_value(self): + row = ('foo',) streamed = self._make_one(_MockCancellableIterator()) - streamed._rows = ['foo'] + streamed._complete_rows = [row] with mock.patch.object(streamed, 'consume_next') as consume_next: consume_next.side_effect = StopIteration - self.assertEqual(streamed.one_or_none(), 'foo') + self.assertEqual(streamed.one_or_none(), row) def test_one_or_none_multiple_values(self): + row1, row2 = ('foo',), ('bar',) streamed = self._make_one(_MockCancellableIterator()) - streamed._rows = ['foo', 'bar'] + streamed._complete_rows = [row1, row2] with self.assertRaises(ValueError): streamed.one_or_none() From c2356b8920312abf7f314af145d178dbbe92161c Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Thu, 7 Sep 2017 10:34:59 -0400 Subject: [PATCH 11/14] Add 'SRS._do_restart' helper method. Resets any 'pending' state and then calls 'self._restart', passing most recent token; resets 'self._reponse_iterator' to the result of that call. 
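In effect, state that was checkpointed before the last resume token
survives a restart, while anything accumulated past it is discarded
because the re-opened stream will resend it. A toy stand-in (not the
real class) making that split explicit:

    class _FakeSRS(object):
        """Toy showing which state a restart keeps vs. discards."""
        def __init__(self, restart):
            self._restart = restart
            self._resume_token = b'AF2'
            self._complete_rows = [('kept',)]    # already checkpointed
            self._pending_rows = [('dropped',)]  # server resends these
            self._pending_chunk = object()
            self._current_row = ['partial']

        def _do_restart(self):
            self._pending_chunk = None
            self._pending_rows[:] = ()
            self._current_row[:] = ()
            self._response_iterator = self._restart(self._resume_token)

    srs = _FakeSRS(restart=lambda token: iter([]))
    srs._do_restart()
    assert srs._complete_rows == [('kept',)]
    assert srs._pending_rows == []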
--- spanner/google/cloud/spanner/streamed.py | 10 ++++++ spanner/tests/unit/test_streamed.py | 44 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index 7ac5eb53acc8..afcacab53b49 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -139,6 +139,16 @@ def _flush_pending_rows(self): self._pending_rows[:] = () self._complete_rows.extend(flushed) + def _do_restart(self): + """Helper for :meth:`consume_next`.""" + if not self._resume_token: + raise ValueError("No resume token") + + self._pending_chunk = None + self._pending_rows[:] = () + self._current_row[:] = () + self._response_iterator = self._restart(self._resume_token) + def consume_next(self): """Consume the next partial result set from the stream. diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 9e2e79a13354..bfb741537c0d 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -644,6 +644,50 @@ def test__flush_pending_rows_neither_empty(self): self.assertEqual(streamed._pending_rows, []) self.assertEqual(streamed.rows, [row1, row2, row3, row4]) + def test__do_restart_no_token(self): + restart = mock.Mock() + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator, restart=restart) + + with self.assertRaises(ValueError): + streamed._do_restart() + + restart.assert_not_called() + + def test__do_restart_empty(self): + TOKEN = b'DEADBEEF' + restart = mock.Mock() + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator, restart=restart) + streamed._resume_token = TOKEN + + streamed._do_restart() + + self.assertIs(streamed._response_iterator, restart.return_value) + self.assertIsNone(streamed._pending_chunk) + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed._current_row, []) + restart.assert_called_once_with(TOKEN) + + def test__do_restart_non_empty(self): + TOKEN = b'DEADBEEF' + restart = mock.Mock() + row1, row2, row3 = ('foo',), ('bar',), ('baz',) + iterator = _MockCancellableIterator() + streamed = self._make_one(iterator, restart=restart) + streamed._resume_token = TOKEN + streamed._pending_chunk = self._make_value(u'Phred ') + streamed._current_row = [row1] + streamed._pending_rows = [row2, row3] + + streamed._do_restart() + + self.assertIs(streamed._response_iterator, restart.return_value) + self.assertIsNone(streamed._pending_chunk) + self.assertEqual(streamed._pending_rows, []) + self.assertEqual(streamed._current_row, []) + restart.assert_called_once_with(TOKEN) + def test_consume_next_propagates_unavailable(self): from google.cloud import exceptions From a99285672fc929775dd4c967d08bf4027148776e Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Thu, 7 Sep 2017 10:44:12 -0400 Subject: [PATCH 12/14] 'SRS.consume_next': invoke 'self._do_restart' on ServiceUnavailable. Fixes #3775. 
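With the restart wired in, 'consume_next' swaps in a fresh iterator
and returns without yielding anything; the caller's next call simply
reads from the re-opened stream. This is what makes the manual retry
loops removed from the docs in PATCH 04 unnecessary. A self-contained
toy of the control flow (stand-ins throughout, not the library's
types):

    class _Unavailable(Exception):
        """Stand-in for google.cloud.exceptions.ServiceUnavailable."""

    def consume_next(state):
        try:
            response = next(state['iterator'])
        except _Unavailable:
            state['iterator'] = state['restart']()  # fresh stream
            return                                  # caller just retries
        state['rows'].append(response)

    def _flaky():
        yield 'r1'
        raise _Unavailable()

    state = {'iterator': _flaky(),
             'restart': lambda: iter(['r2']),
             'rows': []}
    for _ in range(3):
        consume_next(state)

    assert state['rows'] == ['r1', 'r2']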
--- spanner/google/cloud/spanner/streamed.py | 5 +++-- spanner/tests/unit/test_streamed.py | 28 +++++++++++++++++++----- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index afcacab53b49..e4dad5a67c55 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -154,12 +154,13 @@ def consume_next(self): Parse the result set into new/existing rows in :attr:`_complete_rows` - :raises :class:`~google.api.core.exceptions.ServiceUnavailable`: - if the iterator must be restarted. :raises StopIteration: if the iterator is empty. """ try: response = six.next(self._response_iterator) + except exceptions.ServiceUnavailable: + self._do_restart() + return except StopIteration: self._flush_pending_rows() raise diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index bfb741537c0d..f80634309ff1 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -688,21 +688,27 @@ def test__do_restart_non_empty(self): self.assertEqual(streamed._current_row, []) restart.assert_called_once_with(TOKEN) - def test_consume_next_propagates_unavailable(self): - from google.cloud import exceptions - + def test_consume_next_w_service_unavailable(self): + TOKEN = b'DEADBEEF' + restart = mock.Mock() iterator = _ServiceUnavailableIterator() - streamed = self._make_one(iterator) - with self.assertRaises(exceptions.ServiceUnavailable): - streamed.consume_next() + streamed = self._make_one(iterator, restart=restart) + streamed._resume_token = TOKEN + + streamed.consume_next() + + self.assertIs(streamed._response_iterator, restart.return_value) + restart.assert_called_once_with(TOKEN) def test_consume_next_empty(self): row1, row2 = ('foo',), ('bar',) iterator = _MockCancellableIterator() streamed = self._make_one(iterator) streamed._pending_rows = [row1, row2] + with self.assertRaises(StopIteration): streamed.consume_next() + self.assertEqual(streamed._pending_rows, []) self.assertEqual(streamed.rows, [row1, row2]) @@ -721,7 +727,9 @@ def test_consume_next_first_set_partial(self): iterator = _MockCancellableIterator(result_set) source = mock.Mock(_transaction_id=None, spec=['_transaction_id']) streamed = self._make_one(iterator, source=source) + streamed.consume_next() + self.assertEqual(streamed.rows, []) self.assertEqual(streamed._current_row, BARE) self.assertEqual(streamed.metadata, metadata) @@ -749,7 +757,9 @@ def test_consume_next_first_set_partial_existing_txn_id(self): streamed = self._make_one(iterator, source=source) streamed._resume_token = PRIOR_TOKEN streamed._pending_rows = [row1, row2] + streamed.consume_next() + self.assertEqual(streamed._pending_rows, []) self.assertEqual(streamed.rows, [row1, row2]) self.assertEqual(streamed._current_row, BARE) @@ -772,7 +782,9 @@ def test_consume_next_w_partial_result_no_resume_token(self): streamed = self._make_one(iterator) streamed._resume_token = PRIOR_TOKEN streamed._metadata = self._make_result_set_metadata(FIELDS) + streamed.consume_next() + self.assertEqual(streamed.rows, []) self.assertEqual(streamed._current_row, []) self.assertEqual(streamed._pending_chunk, VALUES[0]) @@ -795,7 +807,9 @@ def test_consume_next_w_pending_chunk(self): streamed = self._make_one(iterator) streamed._metadata = self._make_result_set_metadata(FIELDS) streamed._pending_chunk = self._make_value(u'Phred ') + streamed.consume_next() + self.assertEqual(streamed._pending_rows, [ 
[u'Phred Phlyntstone', BARE[1], BARE[2]], [BARE[3], BARE[4], BARE[5]], @@ -824,7 +838,9 @@ def test_consume_next_last_set(self): streamed = self._make_one(iterator) streamed._metadata = metadata streamed._pending_rows = [row1, row2] + streamed.consume_next() + self.assertEqual(streamed.rows, [row1, row2]) self.assertEqual(streamed._pending_rows, [BARE]) self.assertEqual(streamed._current_row, []) From 7d83816b224874171249b2faef597333f2fabb2f Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Mon, 11 Sep 2017 10:38:07 -0400 Subject: [PATCH 13/14] On UNAVAILABLE w/o resume token, restart from scratch. Addresses: https://github.com/GoogleCloudPlatform/google-cloud-python/pull/3930#discussion_r137673641. --- spanner/google/cloud/spanner/streamed.py | 9 +++++---- spanner/tests/unit/test_streamed.py | 15 ++++++++++++--- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index e4dad5a67c55..3120ced43803 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -141,13 +141,14 @@ def _flush_pending_rows(self): def _do_restart(self): """Helper for :meth:`consume_next`.""" - if not self._resume_token: - raise ValueError("No resume token") - self._pending_chunk = None self._pending_rows[:] = () self._current_row[:] = () - self._response_iterator = self._restart(self._resume_token) + + if self._resume_token: + self._response_iterator = self._restart(self._resume_token) + else: + self._response_iterator = self._restart() def consume_next(self): """Consume the next partial result set from the stream. diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index f80634309ff1..cb266e17d4a0 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -649,10 +649,9 @@ def test__do_restart_no_token(self): iterator = _MockCancellableIterator() streamed = self._make_one(iterator, restart=restart) - with self.assertRaises(ValueError): - streamed._do_restart() + streamed._do_restart() - restart.assert_not_called() + restart.assert_called_once_with() def test__do_restart_empty(self): TOKEN = b'DEADBEEF' @@ -688,6 +687,16 @@ def test__do_restart_non_empty(self): self.assertEqual(streamed._current_row, []) restart.assert_called_once_with(TOKEN) + def test_consume_next_w_service_no_token_yet(self): + restart = mock.Mock() + iterator = _ServiceUnavailableIterator() + streamed = self._make_one(iterator, restart=restart) + + streamed.consume_next() + + self.assertIs(streamed._response_iterator, restart.return_value) + restart.assert_called_once_with() + def test_consume_next_w_service_unavailable(self): TOKEN = b'DEADBEEF' restart = mock.Mock() From 8f47fe2d49d6e3b790cdbf59310ec9476e81d538 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Mon, 11 Sep 2017 15:26:36 -0400 Subject: [PATCH 14/14] Curry transaction selector / options for restart of 'read'/'execute_sql'. 
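Before this change the curried restart replayed the request without a
transaction selector or call options, so a resumed 'read' or
'execute_sql' could run outside the original transaction and without
the resource-prefix metadata. Currying both keeps a restarted stream
identical to the first attempt. A rough sketch of the idea (the
'streaming_read' below is a hypothetical stand-in for the gapic
surface; only the currying is the point):

    import functools

    def streaming_read(session, table, columns, keyset,
                       transaction=None, index='', limit=0,
                       resume_token=b'', options=None):
        # Stand-in for 'api.streaming_read'; echoes its arguments so
        # the curried values are easy to inspect.
        return dict(transaction=transaction, options=options,
                    resume_token=resume_token)

    transaction = object()   # as built by '_make_txn_selector()'
    options = object()       # as built by '_options_with_prefix()'

    restart = functools.partial(
        streaming_read, 'session-name', 'citizens', ['email'], object(),
        transaction=transaction, index='', limit=0, options=options)

    # A restarted request replays with the same selector and options:
    replayed = restart(resume_token=b'DEADBEEF')
    assert replayed['transaction'] is transaction
    assert replayed['options'] is options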
Addresses: https://github.com/GoogleCloudPlatform/google-cloud-python/pull/3930/files/7d83816b224874171249b2faef597333f2fabb2f#r138150848 https://github.com/GoogleCloudPlatform/google-cloud-python/pull/3930/files/7d83816b224874171249b2faef597333f2fabb2f#r138150912 --- spanner/google/cloud/spanner/snapshot.py | 14 ++++--- spanner/tests/unit/test_snapshot.py | 49 +++++++++++++++--------- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/spanner/google/cloud/spanner/snapshot.py b/spanner/google/cloud/spanner/snapshot.py index 9636bd9762e3..0eb705cd3270 100644 --- a/spanner/google/cloud/spanner/snapshot.py +++ b/spanner/google/cloud/spanner/snapshot.py @@ -89,14 +89,15 @@ def read(self, table, columns, keyset, index='', limit=0): iterator = api.streaming_read( self._session.name, table, columns, keyset.to_pb(), - transaction=transaction, index=index, limit=limit, - options=options) + index=index, limit=limit, + transaction=transaction, options=options) self._read_request_count += 1 restart = functools.partial( api.streaming_read, self._session.name, table, columns, keyset, - index=index, limit=limit) + index=index, limit=limit, + transaction=transaction, options=options) if self._multi_use: return StreamedResultSet(iterator, restart, source=self) @@ -150,14 +151,15 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None): api = database.spanner_api iterator = api.execute_streaming_sql( self._session.name, sql, - transaction=transaction, params=params_pb, param_types=param_types, - query_mode=query_mode, options=options) + params=params_pb, param_types=param_types, query_mode=query_mode, + transaction=transaction, options=options) self._read_request_count += 1 restart = functools.partial( api.execute_streaming_sql, self._session.name, sql, - params=params, param_types=param_types, query_mode=query_mode) + params=params, param_types=param_types, query_mode=query_mode, + transaction=transaction, options=options) if self._multi_use: return StreamedResultSet(iterator, restart, source=self) diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index dee8fdecefd7..b10368676750 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -162,16 +162,23 @@ def _read_helper(self, multi_use, first=True, count=0): derived._transaction_id = TXN_ID partial_patch = mock.patch('functools.partial') + owp_patch = mock.patch( + 'google.cloud.spanner.snapshot._options_with_prefix') - with partial_patch as patch: - result_set = derived.read( - TABLE_NAME, COLUMNS, KEYSET, - index=INDEX, limit=LIMIT) + with partial_patch as patch_partial: + with owp_patch as patch_owp: + result_set = derived.read( + TABLE_NAME, COLUMNS, KEYSET, + index=INDEX, limit=LIMIT) - self.assertIs(result_set._restart, patch.return_value) - patch.assert_called_once_with( + self.assertIs(result_set._restart, patch_partial.return_value) + patch_partial.assert_called_once_with( api.streaming_read, session.name, TABLE_NAME, COLUMNS, KEYSET, - index=INDEX, limit=LIMIT) + transaction=derived._make_txn_selector(), + index=INDEX, limit=LIMIT, + options=patch_owp.return_value, + ) + patch_owp.assert_called_once_with(database.name) self.assertEqual(derived._read_request_count, count + 1) @@ -203,8 +210,7 @@ def _read_helper(self, multi_use, first=True, count=0): self.assertEqual(index, INDEX) self.assertEqual(limit, LIMIT) self.assertEqual(resume_token, b'') - self.assertEqual(options.kwargs['metadata'], - [('google-cloud-resource-prefix', 
database.name)]) + self.assertIs(options, patch_owp.return_value) def test_read_wo_multi_use(self): self._read_helper(multi_use=False) @@ -271,6 +277,7 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): from google.cloud.proto.spanner.v1.type_pb2 import Type, StructType from google.cloud.proto.spanner.v1.type_pb2 import STRING, INT64 from google.cloud.spanner._helpers import _make_value_pb + from google.cloud.spanner._helpers import _options_with_prefix TXN_ID = b'DEADBEEF' VALUES = [ @@ -308,16 +315,23 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): derived._transaction_id = TXN_ID partial_patch = mock.patch('functools.partial') + owp_patch = mock.patch( + 'google.cloud.spanner.snapshot._options_with_prefix') - with partial_patch as patch: - result_set = derived.execute_sql( - SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, - query_mode=MODE) + with partial_patch as patch_partial: + with owp_patch as patch_owp: + result_set = derived.execute_sql( + SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, + query_mode=MODE) - self.assertIs(result_set._restart, patch.return_value) - patch.assert_called_once_with( + self.assertIs(result_set._restart, patch_partial.return_value) + patch_partial.assert_called_once_with( api.execute_streaming_sql, session.name, SQL_QUERY_WITH_PARAM, - params=PARAMS, param_types=PARAM_TYPES, query_mode=MODE) + params=PARAMS, param_types=PARAM_TYPES, query_mode=MODE, + transaction=derived._make_txn_selector(), + options=patch_owp.return_value, + ) + patch_owp.assert_called_once_with(database.name) self.assertEqual(derived._read_request_count, count + 1) @@ -350,8 +364,7 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): self.assertEqual(param_types, PARAM_TYPES) self.assertEqual(query_mode, MODE) self.assertEqual(resume_token, b'') - self.assertEqual(options.kwargs['metadata'], - [('google-cloud-resource-prefix', database.name)]) + self.assertIs(options, patch_owp.return_value) def test_execute_sql_wo_multi_use(self): self._execute_sql_helper(multi_use=False)
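The '_ServiceUnavailableIterator' helper exercised by the
'consume_next' tests is defined elsewhere in test_streamed.py. A
plausible minimal shape for such a test double (a sketch, not
necessarily the file's actual code):

    from google.cloud import exceptions

    class _ServiceUnavailableIterator(object):
        """Test double: raises UNAVAILABLE on the first 'next()'."""
        # Hypothetical reconstruction; the real helper may differ.

        def __next__(self):
            raise exceptions.ServiceUnavailable('testing')

        next = __next__      # Python 2 spelling, used via 'six.next'

Pairing it with a 'mock.Mock()' restart, as the tests do, lets
assertions such as 'restart.assert_called_once_with(TOKEN)' verify
that 'consume_next' swapped in the restarted stream.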