From cb548a268bd06753d8971420c3a7cc889d14109c Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 2 Dec 2024 02:20:04 -0800 Subject: [PATCH 01/29] feat(x-goog-spanner-request-id): implement request_id generation and propagation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generates a request_id that is then injected inside metadata that's sent over to the Cloud Spanner backend. Officially inject the first set of x-goog-spanner-request-id values into header metadata Add request-id interceptor to use in asserting tests Wrap Snapshot methods with x-goog-request-id metadata injector Setup scaffolding for XGoogRequestIdHeader checks Wire up XGoogSpannerRequestIdInterceptor for TestDatabase checks Inject header in more Session using spots plus more tests Base for tests with retries on abort More plumbing for Transaction and Database Update unit tests for Transaction Wrap more in Transaction + update tests Update tests Plumb in more tests Update TestDatabase Fixes #1261 --- google/cloud/spanner_v1/client.py | 9 + google/cloud/spanner_v1/database.py | 66 ++- google/cloud/spanner_v1/instance.py | 1 + google/cloud/spanner_v1/pool.py | 6 +- google/cloud/spanner_v1/session.py | 21 +- google/cloud/spanner_v1/snapshot.py | 137 +++++-- .../cloud/spanner_v1/testing/database_test.py | 5 + .../cloud/spanner_v1/testing/interceptors.py | 74 ++++ google/cloud/spanner_v1/transaction.py | 120 ++++-- .../mockserver_tests/mock_server_test_base.py | 4 +- tests/unit/test_atomic_counter.py | 1 + tests/unit/test_database.py | 175 ++++++-- tests/unit/test_request_id_header.py | 387 ++++++++++++++++++ tests/unit/test_snapshot.py | 78 +++- tests/unit/test_spanner.py | 149 +++++++ tests/unit/test_transaction.py | 59 +++ 16 files changed, 1165 insertions(+), 127 deletions(-) create mode 100644 tests/unit/test_request_id_header.py diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index e201f93e9b..ef64d1933e 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -70,6 +70,7 @@ except ImportError: # pragma: NO COVER HAS_GOOGLE_CLOUD_MONITORING_INSTALLED = False +from google.cloud.spanner_v1._helpers import AtomicCounter _CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST" @@ -182,6 +183,8 @@ class Client(ClientWithProject): SCOPE = (SPANNER_ADMIN_SCOPE,) """The scopes required for Google Cloud Spanner.""" + NTH_CLIENT = AtomicCounter() + def __init__( self, project=None, @@ -261,6 +264,12 @@ def __init__( "default_transaction_options must be an instance of DefaultTransactionOptions" ) self._default_transaction_options = default_transaction_options + self._nth_client_id = Client.NTH_CLIENT.increment() + self._nth_request = AtomicCounter() + + @property + def _next_nth_request(self): + return self._nth_request.increment() @property def credentials(self): diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 03c6e5119f..76a92448d7 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -51,8 +51,10 @@ from google.cloud.spanner_v1 import SpannerClient from google.cloud.spanner_v1._helpers import _merge_query_options from google.cloud.spanner_v1._helpers import ( + AtomicCounter, _metadata_with_prefix, _metadata_with_leader_aware_routing, + _metadata_with_request_id, ) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups 
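The client-side counters above (NTH_CLIENT, _nth_request) rely on AtomicCounter from google/cloud/spanner_v1/_helpers.py, which this patch imports but does not show. A minimal sketch of the contract the new code assumes — increment() returns the post-increment value and .value reads the current count; the shipped helper may differ:

    import threading

    class AtomicCounter:
        # Illustrative sketch only; the real helper lives in
        # google/cloud/spanner_v1/_helpers.py and may differ.
        def __init__(self, start_value=0):
            self.__lock = threading.Lock()
            self.__value = start_value

        @property
        def value(self):
            with self.__lock:
                return self.__value

        def increment(self, step=1):
            # Must return the post-increment value: the patch relies on
            # `self._nth_client_id = Client.NTH_CLIENT.increment()`.
            with self.__lock:
                self.__value += step
                return self.__value
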
@@ -151,6 +153,9 @@ class Database(object): _spanner_api: SpannerClient = None + __transport_lock = threading.Lock() + __transports_to_channel_id = dict() + def __init__( self, database_id, @@ -448,6 +453,31 @@ def spanner_api(self): ) return self._spanner_api + @property + def _channel_id(self): + """ + Helper to retrieve the associated channelID for the spanner_api. + This property is paramount to x-goog-spanner-request-id. + """ + with self.__transport_lock: + api = self.spanner_api + channel_id = self.__transports_to_channel_id.get(api._transport, None) + if channel_id is None: + channel_id = len(self.__transports_to_channel_id) + 1 + self.__transports_to_channel_id[api._transport] = channel_id + + return channel_id + + def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): + client_id = self._nth_client_id + return _metadata_with_request_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + ) + def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented @@ -703,6 +733,12 @@ def execute_partitioned_dml( _metadata_with_leader_aware_routing(self._route_to_leader_enabled) ) + # Attempt will be incremented inside _restart_on_unavailable. + begin_txn_nth_request = self._next_nth_request + begin_txn_attempt = AtomicCounter(1) + partial_nth_request = self._next_nth_request + partial_attempt = AtomicCounter(0) + def execute_pdml(): with trace_call( "CloudSpanner.Database.execute_partitioned_pdml", @@ -711,7 +747,10 @@ def execute_pdml(): with SessionCheckout(self._pool) as session: add_span_event(span, "Starting BeginTransaction") txn = api.begin_transaction( - session=session.name, options=txn_options, metadata=metadata + session=session.name, options=txn_options, + metadata=self.metadata_with_request_id( + begin_txn_nth_request, begin_txn_attempt.value, metadata + ), ) txn_selector = TransactionSelector(id=txn.id) @@ -724,18 +763,25 @@ def execute_pdml(): query_options=query_options, request_options=request_options, ) - method = functools.partial( - api.execute_streaming_sql, - metadata=metadata, - ) + + def wrapped_method(*args, **kwargs): + partial_attempt.increment() + method = functools.partial( + api.execute_streaming_sql, + metadata=self.metadata_with_request_id( + partial_nth_request, partial_attempt.value, metadata + ), + ) + return method(*args, **kwargs) iterator = _restart_on_unavailable( - method=method, + method=wrapped_method, trace_name="CloudSpanner.ExecuteStreamingSql", request=request, metadata=metadata, transaction_selector=txn_selector, observability_options=self.observability_options, + attempt=begin_txn_attempt, ) result_set = StreamedResultSet(iterator) @@ -745,6 +791,14 @@ def execute_pdml(): return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() + @property + def _next_nth_request(self): + return self._instance._client._next_nth_request + + @property + def _nth_client_id(self): + return self._instance._client._nth_client_id + def session(self, labels=None, database_role=None): """Factory to create a session for this database. 
diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py
index 0c4dd5a63b..16fd4ec8fb 100644
--- a/google/cloud/spanner_v1/pool.py
+++ b/google/cloud/spanner_v1/pool.py
@@ -246,6 +246,7 @@ def bind(self, database):
             observability_options=observability_options,
             metadata=metadata,
         ) as span, MetricsCapture():
+            attempt = 1
             returned_session_count = 0
             while not self._sessions.full():
                 request.session_count = requested_session_count - self._sessions.qsize()
@@ -254,9 +255,12 @@
                     f"Creating {request.session_count} sessions",
                     span_event_attributes,
                 )
+                all_metadata = database.metadata_with_request_id(
+                    database._next_nth_request, attempt, metadata
+                )
                 resp = api.batch_create_sessions(
                     request=request,
-                    metadata=metadata,
+                    metadata=all_metadata,
                 )
 
                 add_span_event(
diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py
index f18ba57582..68bc8e65ab 100644
--- a/google/cloud/spanner_v1/session.py
+++ b/google/cloud/spanner_v1/session.py
@@ -195,7 +195,8 @@ def exists(self):
             current_span, "Checking if Session exists", {"session.id": self._session_id}
         )
 
-        api = self._database.spanner_api
+        database = self._database
+        api = database.spanner_api
         metadata = _metadata_with_prefix(self._database.name)
         if self._database._route_to_leader_enabled:
             metadata.append(
@@ -204,6 +205,10 @@
                 )
             )
 
+        all_metadata = database.metadata_with_request_id(
+            database._next_nth_request, 1, metadata
+        )
+
         observability_options = getattr(self._database, "observability_options", None)
         with trace_call(
             "CloudSpanner.GetSession",
@@ -212,7 +217,7 @@
             metadata=metadata,
         ) as span, MetricsCapture():
             try:
-                api.get_session(name=self.name, metadata=metadata)
+                api.get_session(name=self.name, metadata=all_metadata)
                 if span:
                     span.set_attribute("session_found", True)
             except NotFound:
@@ -242,8 +247,11 @@ def delete(self):
             current_span, "Deleting Session", {"session.id": self._session_id}
         )
 
-        api = self._database.spanner_api
-        metadata = _metadata_with_prefix(self._database.name)
+        database = self._database
+        api = database.spanner_api
+        metadata = database.metadata_with_request_id(
+            database._next_nth_request, 1, _metadata_with_prefix(database.name)
+        )
         observability_options = getattr(self._database, "observability_options", None)
         with trace_call(
             "CloudSpanner.DeleteSession",
@@ -265,7 +273,10 @@ def ping(self):
         if self._session_id is None:
             raise ValueError("Session ID not set by back-end")
         api = self._database.spanner_api
-        metadata = _metadata_with_prefix(self._database.name)
+        database = self._database
+        metadata = database.metadata_with_request_id(
+            database._next_nth_request, 1, _metadata_with_prefix(database.name)
+        )
         request = ExecuteSqlRequest(session=self.name, sql="SELECT 1")
         api.execute_sql(request=request, metadata=metadata)
         self._last_use_time = datetime.now()
diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py
index 3b18d2c855..060e5fc91f 100644
--- a/google/cloud/spanner_v1/snapshot.py
+++ b/google/cloud/spanner_v1/snapshot.py
@@ -35,9 +35,11 @@
     _merge_query_options,
     _metadata_with_prefix,
     _metadata_with_leader_aware_routing,
+ _metadata_with_request_id, _retry, _check_rst_stream_error, _SessionWrapper, + AtomicCounter, ) from google.cloud.spanner_v1._opentelemetry_tracing import trace_call from google.cloud.spanner_v1.streamed import StreamedResultSet @@ -61,6 +63,7 @@ def _restart_on_unavailable( transaction=None, transaction_selector=None, observability_options=None, + attempt=0, ): """Restart iteration after :exc:`.ServiceUnavailable`. @@ -92,6 +95,7 @@ def _restart_on_unavailable( iterator = None while True: + attempt += 1 try: if iterator is None: with trace_call( @@ -329,13 +333,24 @@ def read( data_boost_enabled=data_boost_enabled, directed_read_options=directed_read_options, ) - restart = functools.partial( - api.streaming_read, - request=request, - metadata=metadata, - retry=retry, - timeout=timeout, - ) + + nth_request = getattr(database, "_next_nth_request", 0) + attempt = AtomicCounter(0) + + def wrapped_restart(*args, **kwargs): + attempt.increment() + all_metadata = database.metadata_with_request_id( + nth_request, attempt.value, metadata + ) + + restart = functools.partial( + api.streaming_read, + request=request, + metadata=all_metadata, + retry=retry, + timeout=timeout, + ) + return restart(*args, **kwargs) trace_attributes = {"table_id": table, "columns": columns} observability_options = getattr(database, "observability_options", None) @@ -344,7 +359,7 @@ def read( # lock is added to handle the inline begin for first rpc with self._lock: iterator = _restart_on_unavailable( - restart, + wrapped_restart, request, metadata, f"CloudSpanner.{type(self).__name__}.read", @@ -367,7 +382,7 @@ def read( ) else: iterator = _restart_on_unavailable( - restart, + wrapped_restart, request, metadata, f"CloudSpanner.{type(self).__name__}.read", @@ -562,13 +577,23 @@ def execute_sql( data_boost_enabled=data_boost_enabled, directed_read_options=directed_read_options, ) - restart = functools.partial( - api.execute_streaming_sql, - request=request, - metadata=metadata, - retry=retry, - timeout=timeout, - ) + + nth_request = getattr(database, "_next_nth_request", 0) + attempt = AtomicCounter(0) + + def wrapped_restart(*args, **kwargs): + attempt.increment() + restart = functools.partial( + api.execute_streaming_sql, + request=request, + metadata=database.metadata_with_request_id( + nth_request, attempt.value, metadata + ), + retry=retry, + timeout=timeout, + ) + + return restart(*args, **kwargs) trace_attributes = {"db.statement": sql} observability_options = getattr(database, "observability_options", None) @@ -577,7 +602,7 @@ def execute_sql( # lock is added to handle the inline begin for first rpc with self._lock: return self._get_streamed_result_set( - restart, + wrapped_restart, request, metadata, trace_attributes, @@ -587,7 +612,7 @@ def execute_sql( ) else: return self._get_streamed_result_set( - restart, + wrapped_restart, request, metadata, trace_attributes, @@ -718,15 +743,25 @@ def partition_read( observability_options=getattr(database, "observability_options", None), metadata=metadata, ), MetricsCapture(): - method = functools.partial( - api.partition_read, - request=request, - metadata=metadata, - retry=retry, - timeout=timeout, - ) + nth_request = getattr(database, "_next_nth_request", 0) + attempt = AtomicCounter(0) + + def wrapped_method(*args, **kwargs): + attempt.increment() + all_metadata = database.metadata_with_request_id( + nth_request, attempt.value, metadata + ) + method = functools.partial( + api.partition_read, + request=request, + metadata=all_metadata, + retry=retry, + timeout=timeout, + 
) + return method(*args, **kwargs) + response = _retry( - method, + wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) @@ -822,15 +857,25 @@ def partition_query( observability_options=getattr(database, "observability_options", None), metadata=metadata, ), MetricsCapture(): - method = functools.partial( - api.partition_query, - request=request, - metadata=metadata, - retry=retry, - timeout=timeout, - ) + nth_request = getattr(database, "_next_nth_request", 0) + attempt = AtomicCounter(0) + + def wrapped_method(*args, **kwargs): + attempt.increment() + all_metadata = database.metadata_with_request_id( + nth_request, attempt.value, metadata + ) + method = functools.partial( + api.partition_query, + request=request, + metadata=all_metadata, + retry=retry, + timeout=timeout, + ) + return method(*args, **kwargs) + response = _retry( - method, + wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) @@ -969,14 +1014,24 @@ def begin(self): observability_options=getattr(database, "observability_options", None), metadata=metadata, ), MetricsCapture(): - method = functools.partial( - api.begin_transaction, - session=self._session.name, - options=txn_selector.begin, - metadata=metadata, - ) + nth_request = getattr(database, "_next_nth_request", 0) + attempt = AtomicCounter(0) + + def wrapped_method(*args, **kwargs): + attempt.increment() + all_metadata = database.metadata_with_request_id( + nth_request, attempt.value, metadata + ) + method = functools.partial( + api.begin_transaction, + session=self._session.name, + options=txn_selector.begin, + metadata=all_metadata, + ) + return method(*args, **kwargs) + response = _retry( - method, + wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) self._transaction_id = response.id diff --git a/google/cloud/spanner_v1/testing/database_test.py b/google/cloud/spanner_v1/testing/database_test.py index 54afda11e0..80f040d7e0 100644 --- a/google/cloud/spanner_v1/testing/database_test.py +++ b/google/cloud/spanner_v1/testing/database_test.py @@ -25,6 +25,7 @@ from google.cloud.spanner_v1.testing.interceptors import ( MethodCountInterceptor, MethodAbortInterceptor, + XGoogRequestIDHeaderInterceptor, ) @@ -34,6 +35,8 @@ class TestDatabase(Database): currently, and we don't want to make changes in the Database class for testing purpose as this is a hack to use interceptors in tests.""" + _interceptors = [] + def __init__( self, database_id, @@ -74,6 +77,8 @@ def spanner_api(self): client_options = client._client_options if self._instance.emulator_host is not None: channel = grpc.insecure_channel(self._instance.emulator_host) + self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor() + self._interceptors.append(self._x_goog_request_id_interceptor) channel = grpc.intercept_channel(channel, *self._interceptors) transport = SpannerGrpcTransport(channel=channel) self._spanner_api = SpannerClient( diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py index a8b015a87d..918c7d76b9 100644 --- a/google/cloud/spanner_v1/testing/interceptors.py +++ b/google/cloud/spanner_v1/testing/interceptors.py @@ -13,6 +13,8 @@ # limitations under the License. 
 from collections import defaultdict
+import threading
+
 from grpc_interceptor import ClientInterceptor
 from google.api_core.exceptions import Aborted
@@ -63,3 +65,75 @@ def reset(self):
         self._method_to_abort = None
         self._count = 0
         self._connection = None
+
+
+X_GOOG_REQUEST_ID = "x-goog-spanner-request-id"
+
+
+class XGoogRequestIDHeaderInterceptor(ClientInterceptor):
+    def __init__(self):
+        self._unary_req_segments = []
+        self._stream_req_segments = []
+        self.__lock = threading.Lock()
+
+    def intercept(self, method, request_or_iterator, call_details):
+        metadata = call_details.metadata
+        x_goog_request_id = None
+        for key, value in metadata:
+            if key == X_GOOG_REQUEST_ID:
+                x_goog_request_id = value
+                break
+
+        if not x_goog_request_id:
+            raise Exception(
+                f"Missing {X_GOOG_REQUEST_ID} header in {call_details.method}"
+            )
+
+        response_or_iterator = method(request_or_iterator, call_details)
+        streaming = getattr(response_or_iterator, "__iter__", None) is not None
+        with self.__lock:
+            if streaming:
+                self._stream_req_segments.append(
+                    (call_details.method, parse_request_id(x_goog_request_id))
+                )
+            else:
+                self._unary_req_segments.append(
+                    (call_details.method, parse_request_id(x_goog_request_id))
+                )
+
+        return response_or_iterator
+
+    @property
+    def unary_request_ids(self):
+        return self._unary_req_segments
+
+    @property
+    def stream_request_ids(self):
+        return self._stream_req_segments
+
+    def reset(self):
+        self._stream_req_segments.clear()
+        self._unary_req_segments.clear()
+
+
+def parse_request_id(request_id_str):
+    # Split "version.rand_process_id.client_id.channel_id.nth_request.nth_attempt"
+    # into a tuple of six integers.
+    splits = request_id_str.split(".")
+    version, rand_process_id, client_id, channel_id, nth_request, nth_attempt = map(
+        int, splits
+    )
+    return (
+        version,
+        rand_process_id,
+        client_id,
+        channel_id,
+        nth_request,
+        nth_attempt,
+    )
diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py
index 2f52aaa144..8adbf8e108 100644
--- a/google/cloud/spanner_v1/transaction.py
+++ b/google/cloud/spanner_v1/transaction.py
@@ -32,6 +32,7 @@
 from google.cloud.spanner_v1 import ExecuteSqlRequest
 from google.cloud.spanner_v1 import TransactionSelector
 from google.cloud.spanner_v1 import TransactionOptions
+from google.cloud.spanner_v1._helpers import AtomicCounter
 from google.cloud.spanner_v1.snapshot import _SnapshotBase
 from google.cloud.spanner_v1.batch import _BatchBase
 from google.cloud.spanner_v1._opentelemetry_tracing import add_span_event, trace_call
@@ -181,12 +182,20 @@ def begin(self):
             observability_options=observability_options,
             metadata=metadata,
         ) as span, MetricsCapture():
-            method = functools.partial(
-                api.begin_transaction,
-                session=self._session.name,
-                options=txn_options,
-                metadata=metadata,
-            )
+            attempt = AtomicCounter(0)
+            nth_request = database._next_nth_request
+
+            def wrapped_method(*args, **kwargs):
+                attempt.increment()
+                method = functools.partial(
+                    api.begin_transaction,
+                    session=self._session.name,
+                    options=txn_options,
+                    metadata=database.metadata_with_request_id(
+                        nth_request, attempt.value, metadata
+                    ),
+                )
+                return method(*args, **kwargs)
 
             def beforeNextRetry(nthRetry, delayInSeconds):
                 add_span_event(
@@ -196,7 +205,7 @@ def beforeNextRetry(nthRetry, delayInSeconds):
                 )
 
             response = _retry(
-                method,
+                wrapped_method,
                 allowed_exceptions={InternalServerError: _check_rst_stream_error},
                 beforeNextRetry=beforeNextRetry,
             )
@@ -217,6 +226,7 @@ def rollback(self):
database._route_to_leader_enabled ) ) + observability_options = getattr(database, "observability_options", None) with trace_call( f"CloudSpanner.{type(self).__name__}.rollback", @@ -224,16 +234,26 @@ def rollback(self): observability_options=observability_options, metadata=metadata, ), MetricsCapture(): - method = functools.partial( - api.rollback, - session=self._session.name, - transaction_id=self._transaction_id, - metadata=metadata, - ) + attempt = AtomicCounter(0) + nth_request = database._next_nth_request + + def wrapped_method(*args, **kwargs): + attempt.increment() + method = functools.partial( + api.rollback, + session=self._session.name, + transaction_id=self._transaction_id, + metadata=database.metadata_with_request_id( + nth_request, attempt.value, metadata + ), + ) + return method(*args, **kwargs) + _retry( - method, + wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) + self.rolled_back = True del self._session._transaction @@ -306,11 +326,19 @@ def commit( add_span_event(span, "Starting Commit") - method = functools.partial( - api.commit, - request=request, - metadata=metadata, - ) + attempt = AtomicCounter(0) + nth_request = database._next_nth_request + + def wrapped_method(*args, **kwargs): + attempt.increment() + method = functools.partial( + api.commit, + request=request, + metadata=database.metadata_with_request_id( + nth_request, attempt.value, metadata + ), + ) + return method(*args, **kwargs) def beforeNextRetry(nthRetry, delayInSeconds): add_span_event( @@ -320,7 +348,7 @@ def beforeNextRetry(nthRetry, delayInSeconds): ) response = _retry( - method, + wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, beforeNextRetry=beforeNextRetry, ) @@ -469,19 +497,27 @@ def execute_update( last_statement=last_statement, ) - method = functools.partial( - api.execute_sql, - request=request, - metadata=metadata, - retry=retry, - timeout=timeout, - ) + nth_request = database._next_nth_request + attempt = AtomicCounter(0) + + def wrapped_method(*args, **kwargs): + attempt.increment() + method = functools.partial( + api.execute_sql, + request=request, + metadata=database.metadata_with_request_id( + nth_request, attempt.value, metadata + ), + retry=retry, + timeout=timeout, + ) + return method(*args, **kwargs) if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: response = self._execute_request( - method, + wrapped_method, request, metadata, f"CloudSpanner.{type(self).__name__}.execute_update", @@ -499,7 +535,7 @@ def execute_update( self._transaction_id = response.metadata.transaction.id else: response = self._execute_request( - method, + wrapped_method, request, metadata, f"CloudSpanner.{type(self).__name__}.execute_update", @@ -611,19 +647,27 @@ def batch_update( last_statements=last_statement, ) - method = functools.partial( - api.execute_batch_dml, - request=request, - metadata=metadata, - retry=retry, - timeout=timeout, - ) + nth_request = database._next_nth_request + attempt = AtomicCounter(0) + + def wrapped_method(*args, **kwargs): + attempt.increment() + method = functools.partial( + api.execute_batch_dml, + request=request, + metadata=database.metadata_with_request_id( + nth_request, attempt.value, metadata + ), + retry=retry, + timeout=timeout, + ) + return method(*args, **kwargs) if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: response = self._execute_request( - method, + wrapped_method, request, 
metadata, "CloudSpanner.DMLTransaction", @@ -642,7 +686,7 @@ def batch_update( break else: response = self._execute_request( - method, + wrapped_method, request, metadata, "CloudSpanner.DMLTransaction", diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index b332c88d7c..3a826720c1 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -186,6 +186,8 @@ def instance(self) -> Instance: def database(self) -> Database: if self._database is None: self._database = self.instance.database( - "test-database", pool=FixedSizePool(size=10) + "test-database", + pool=FixedSizePool(size=10), + enable_interceptors_in_tests=True, ) return self._database diff --git a/tests/unit/test_atomic_counter.py b/tests/unit/test_atomic_counter.py index 92d10cac79..d1b22cf841 100644 --- a/tests/unit/test_atomic_counter.py +++ b/tests/unit/test_atomic_counter.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time import random import threading import unittest diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 1afda7f850..a19ff60ccd 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -30,6 +30,11 @@ DirectedReadOptions, DefaultTransactionOptions, ) +from google.cloud.spanner_v1._helpers import ( + AtomicCounter, + _metadata_with_request_id, +) +from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID DML_WO_PARAM = """ DELETE FROM citizens @@ -115,7 +120,9 @@ def _make_database_admin_api(): def _make_spanner_api(): from google.cloud.spanner_v1 import SpannerClient - return mock.create_autospec(SpannerClient, instance=True) + api = mock.create_autospec(SpannerClient, instance=True) + api._transport = "transport" + return api def test_ctor_defaults(self): from google.cloud.spanner_v1.pool import BurstyPool @@ -549,7 +556,9 @@ def test_create_grpc_error(self): api.create_database.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_create_already_exists(self): @@ -576,7 +585,9 @@ def test_create_already_exists(self): api.create_database.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_create_instance_not_found(self): @@ -602,7 +613,9 @@ def test_create_instance_not_found(self): api.create_database.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_create_success(self): @@ -638,7 +651,9 @@ def test_create_success(self): api.create_database.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_create_success_w_encryption_config_dict(self): @@ -675,7 +690,9 @@ def test_create_success_w_encryption_config_dict(self): api.create_database.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_create_success_w_proto_descriptors(self): @@ -710,7 +727,9 @@ def 
test_create_success_w_proto_descriptors(self): api.create_database.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_exists_grpc_error(self): @@ -728,7 +747,9 @@ def test_exists_grpc_error(self): api.get_database_ddl.assert_called_once_with( database=self.DATABASE_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_exists_not_found(self): @@ -745,7 +766,9 @@ def test_exists_not_found(self): api.get_database_ddl.assert_called_once_with( database=self.DATABASE_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_exists_success(self): @@ -764,7 +787,9 @@ def test_exists_success(self): api.get_database_ddl.assert_called_once_with( database=self.DATABASE_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_reload_grpc_error(self): @@ -782,7 +807,9 @@ def test_reload_grpc_error(self): api.get_database_ddl.assert_called_once_with( database=self.DATABASE_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_reload_not_found(self): @@ -800,7 +827,9 @@ def test_reload_not_found(self): api.get_database_ddl.assert_called_once_with( database=self.DATABASE_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_reload_success(self): @@ -859,11 +888,15 @@ def test_reload_success(self): api.get_database_ddl.assert_called_once_with( database=self.DATABASE_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) api.get_database.assert_called_once_with( name=self.DATABASE_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_update_ddl_grpc_error(self): @@ -889,7 +922,9 @@ def test_update_ddl_grpc_error(self): api.update_database_ddl.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_update_ddl_not_found(self): @@ -915,7 +950,9 @@ def test_update_ddl_not_found(self): api.update_database_ddl.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_update_ddl(self): @@ -942,7 +979,9 @@ def test_update_ddl(self): api.update_database_ddl.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_update_ddl_w_operation_id(self): @@ -969,7 +1008,9 @@ def test_update_ddl_w_operation_id(self): api.update_database_ddl.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_update_success(self): @@ -995,7 +1036,9 @@ def test_update_success(self): api.update_database.assert_called_once_with( 
database=expected_database, update_mask=field_mask, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_update_ddl_w_proto_descriptors(self): @@ -1023,7 +1066,9 @@ def test_update_ddl_w_proto_descriptors(self): api.update_database_ddl.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_drop_grpc_error(self): @@ -1041,7 +1086,9 @@ def test_drop_grpc_error(self): api.drop_database.assert_called_once_with( database=self.DATABASE_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_drop_not_found(self): @@ -1059,7 +1106,9 @@ def test_drop_not_found(self): api.drop_database.assert_called_once_with( database=self.DATABASE_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_drop_success(self): @@ -1076,7 +1125,9 @@ def test_drop_success(self): api.drop_database.assert_called_once_with( database=self.DATABASE_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def _execute_partitioned_dml_helper( @@ -1155,6 +1206,10 @@ def _execute_partitioned_dml_helper( metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) if retried: @@ -1196,6 +1251,10 @@ def _execute_partitioned_dml_helper( metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) if retried: @@ -1490,7 +1549,9 @@ def test_restore_grpc_error(self): api.restore_database.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_restore_not_found(self): @@ -1516,7 +1577,9 @@ def test_restore_not_found(self): api.restore_database.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_restore_success(self): @@ -1553,7 +1616,9 @@ def test_restore_success(self): api.restore_database.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_restore_success_w_encryption_config_dict(self): @@ -1594,7 +1659,9 @@ def test_restore_success_w_encryption_config_dict(self): api.restore_database.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_restore_w_invalid_encryption_config_dict(self): @@ -1741,7 +1808,9 @@ def test_list_database_roles_grpc_error(self): api.list_database_roles.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) def test_list_database_roles_defaults(self): @@ 
-1762,7 +1831,9 @@ def test_list_database_roles_defaults(self): api.list_database_roles.assert_called_once_with( request=expected_request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ], ) self.assertIsNotNone(resp) @@ -1849,6 +1920,10 @@ def test_context_mgr_success(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -1896,6 +1971,10 @@ def test_context_mgr_w_commit_stats_success(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -1943,6 +2022,10 @@ def test_context_mgr_w_aborted_commit_status(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -3017,6 +3100,10 @@ def test_context_mgr_success(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -3115,6 +3202,8 @@ def _make_database_admin_api(): class _Client(object): + NTH_CLIENT = AtomicCounter() + def __init__( self, project=TestDatabase.PROJECT_ID, @@ -3135,6 +3224,12 @@ def __init__( self.route_to_leader_enabled = route_to_leader_enabled self.directed_read_options = directed_read_options self.default_transaction_options = default_transaction_options + self._nth_client_id = _Client.NTH_CLIENT.increment() + self._nth_request = AtomicCounter() + + @property + def _next_nth_request(self): + return self._nth_request.increment() class _Instance(object): @@ -3164,6 +3259,28 @@ def __init__(self, name, instance=None): self._directed_read_options = None self.default_transaction_options = DefaultTransactionOptions() + @property + def _next_nth_request(self): + return self._instance._client._next_nth_request + + @property + def _nth_client_id(self): + return self._instance._client._nth_client_id + + def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): + client_id = self._nth_client_id + return _metadata_with_request_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + ) + + @property + def _channel_id(self): + return 1 + class _Pool(object): _bound = None diff --git a/tests/unit/test_request_id_header.py b/tests/unit/test_request_id_header.py new file mode 100644 index 0000000000..35c7de7410 --- /dev/null +++ b/tests/unit/test_request_id_header.py @@ -0,0 +1,387 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import random
+import threading
+
+from tests.mockserver_tests.mock_server_test_base import (
+    MockServerTestBase,
+    add_select1_result,
+)
+from google.cloud.spanner_v1 import (
+    BatchCreateSessionsRequest,
+    BeginTransactionRequest,
+    ExecuteSqlRequest,
+)
+from google.api_core.exceptions import Aborted
+from google.rpc import code_pb2
+from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID
+
+
+class TestRequestIDHeader(MockServerTestBase):
+    def tearDown(self):
+        self.database._x_goog_request_id_interceptor.reset()
+
+    def test_snapshot_read(self):
+        add_select1_result()
+        if not getattr(self.database, "_interceptors", None):
+            self.database._interceptors = MockServerTestBase._interceptors
+        with self.database.snapshot() as snapshot:
+            results = snapshot.execute_sql("select 1")
+            result_list = []
+            for row in results:
+                result_list.append(row)
+                self.assertEqual(1, row[0])
+            self.assertEqual(1, len(result_list))
+
+        requests = self.spanner_service.requests
+        self.assertEqual(2, len(requests), msg=requests)
+        self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
+        self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
+
+        # Now ensure monotonicity of the received request-id segments.
+        got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
+        want_unary_segments = [
+            (
+                "/google.spanner.v1.Spanner/BatchCreateSessions",
+                (1, REQ_RAND_PROCESS_ID, 1, 1, 1, 1),
+            )
+        ]
+        want_stream_segments = [
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, 1, 1, 2, 1),
+            )
+        ]
+
+        assert got_unary_segments == want_unary_segments
+        assert got_stream_segments == want_stream_segments
+
+    def test_snapshot_read_concurrent(self):
+        def select1():
+            with self.database.snapshot() as snapshot:
+                rows = snapshot.execute_sql("select 1")
+                res_list = []
+                for row in rows:
+                    self.assertEqual(1, row[0])
+                    res_list.append(row)
+                self.assertEqual(1, len(res_list))
+
+        n = 10
+        threads = []
+        for i in range(n):
+            th = threading.Thread(target=select1, name=f"snapshot-select1-{i}")
+            th.start()
+            threads.append(th)
+
+        random.shuffle(threads)
+
+        # join() blocks until each thread completes, so no
+        # is_alive()/sleep polling loop is needed here.
+        for thread in threads:
+            thread.join()
+
+        requests = self.spanner_service.requests
+        self.assertEqual(n * 2, len(requests), msg=requests)
+
+        client_id = self.database._nth_client_id
+        channel_id = self.database._channel_id
+        got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
+
+        want_unary_segments = [
+            (
+                "/google.spanner.v1.Spanner/BatchCreateSessions",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1,
REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1), + ), + ] + assert got_unary_segments == want_unary_segments + + want_stream_segments = [ + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 2, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1), + ), + ] + assert got_stream_segments == want_stream_segments + + def test_database_run_in_transaction_retries_on_abort(self): + counters = dict(aborted=0) + want_failed_attempts = 2 + + def select_in_txn(txn): + results = txn.execute_sql("select 1") + for row in results: + _ = row + + if counters["aborted"] < want_failed_attempts: + counters["aborted"] += 1 + raise Aborted( + "Thrown from ClientInterceptor for testing", + errors=[FauxCall(code_pb2.ABORTED)], + ) + + add_select1_result() + if not getattr(self.database, "_interceptors", None): + self.database._interceptors = MockServerTestBase._interceptors + + self.database.run_in_transaction(select_in_txn) + + def test_database_execute_partitioned_dml_request_id(self): + add_select1_result() + if not getattr(self.database, "_interceptors", None): + self.database._interceptors = MockServerTestBase._interceptors + _ = self.database.execute_partitioned_dml("select 1") + + requests = self.spanner_service.requests + self.assertEqual(3, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + + # Now ensure monotonicity of the received request-id segments. 
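+        # Each expected element pairs the RPC method with the parsed header tuple
+        #   (version, rand_process_id, nth_client_id, channel_id, nth_request, attempt),
+        # matching parse_request_id in google/cloud/spanner_v1/testing/interceptors.py.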
+        got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
+        want_unary_segments = [
+            (
+                "/google.spanner.v1.Spanner/BatchCreateSessions",
+                (1, REQ_RAND_PROCESS_ID, 1, 1, 1, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/BeginTransaction",
+                (1, REQ_RAND_PROCESS_ID, 1, 1, 2, 1),
+            ),
+        ]
+        want_stream_segments = [
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, 1, 1, 3, 1),
+            )
+        ]
+
+        assert got_unary_segments == want_unary_segments
+        assert got_stream_segments == want_stream_segments
+
+    def canonicalize_request_id_headers(self):
+        src = self.database._x_goog_request_id_interceptor
+        return src._stream_req_segments, src._unary_req_segments
+
+
+class FauxCall:
+    def __init__(self, code, details="FauxCall"):
+        self._code = code
+        self._details = details
+
+    def initial_metadata(self):
+        return {}
+
+    def trailing_metadata(self):
+        return {}
+
+    def code(self):
+        return self._code
+
+    def details(self):
+        return self._details
diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py
index 11fc0135d1..e5254801ce 100644
--- a/tests/unit/test_snapshot.py
+++ b/tests/unit/test_snapshot.py
@@ -26,6 +26,11 @@
 )
 from google.cloud.spanner_v1.param_types import INT64
 from google.api_core.retry import Retry
+from google.cloud.spanner_v1._helpers import (
+    AtomicCounter,
+    _metadata_with_request_id,
+)
+from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID
 
 TABLE_NAME = "citizens"
 COLUMNS = ["email", "first_name", "last_name", "age"]
@@ -295,7 +300,7 @@ def test_iteration_w_raw_raising_retryable_internal_error(self):
             fail_after=True,
             error=InternalServerError(
                 "Received unexpected EOS on DATA frame from server"
-            )
+            ),
         )
         after = _MockIterator(*LAST)
         request = mock.Mock(test="test", spec=["test", "resume_token"])
@@ -467,7 +472,7 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self):
             fail_after=True,
             error=InternalServerError(
                 "Received unexpected EOS on DATA frame from server"
-            )
+            ),
         )
         after = _MockIterator(*SECOND)
         request = mock.Mock(test="test", spec=["test", "resume_token"])
@@ -777,7 +782,13 @@ def _read_helper(
         )
         api.streaming_read.assert_called_once_with(
             request=expected_request,
-            metadata=[("google-cloud-resource-prefix", database.name)],
+            metadata=[
+                ("google-cloud-resource-prefix", database.name),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
+                ),
+            ],
             retry=retry,
             timeout=timeout,
         )
@@ -1026,7 +1037,13 @@ def _execute_sql_helper(
         )
         api.execute_streaming_sql.assert_called_once_with(
             request=expected_request,
-            metadata=[("google-cloud-resource-prefix", database.name)],
+            metadata=[
+                ("google-cloud-resource-prefix", database.name),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
+                ),
+            ],
             timeout=timeout,
             retry=retry,
         )
@@ -1199,6 +1216,10 @@ def _partition_read_helper(
             metadata=[
                 ("google-cloud-resource-prefix", database.name),
                 ("x-goog-spanner-route-to-leader", "true"),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
+                ),
             ],
             retry=retry,
             timeout=timeout,
@@ -1378,6 +1399,10 @@ def _partition_query_helper(
             metadata=[
                 ("google-cloud-resource-prefix", database.name),
                 ("x-goog-spanner-route-to-leader", "true"),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
+                ),
             ],
             retry=retry,
             timeout=timeout,
@@ -1774,7 +1799,13 @@ def test_begin_ok_exact_staleness(self):
         api.begin_transaction.assert_called_once_with(
             session=session.name,
             options=expected_txn_options,
-            metadata=[("google-cloud-resource-prefix", database.name)],
+            metadata=[
+                ("google-cloud-resource-prefix",
database.name),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
+                ),
+            ],
         )
 
         self.assertSpanAttributes(
@@ -1810,7 +1841,13 @@ def test_begin_ok_exact_strong(self):
         api.begin_transaction.assert_called_once_with(
             session=session.name,
             options=expected_txn_options,
-            metadata=[("google-cloud-resource-prefix", database.name)],
+            metadata=[
+                ("google-cloud-resource-prefix", database.name),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
+                ),
+            ],
         )
 
         self.assertSpanAttributes(
@@ -1821,10 +1858,18 @@
 
 
 class _Client(object):
+    NTH_CLIENT = AtomicCounter()
+
     def __init__(self):
         from google.cloud.spanner_v1 import ExecuteSqlRequest
 
         self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1")
+        self._nth_client_id = _Client.NTH_CLIENT.increment()
+        self._nth_request = AtomicCounter()
+
+    @property
+    def _next_nth_request(self):
+        return self._nth_request.increment()
 
 
 class _Instance(object):
@@ -1843,6 +1888,27 @@ def __init__(self, directed_read_options=None):
     def observability_options(self):
         return dict(db_name=self.name)
 
+    @property
+    def _next_nth_request(self):
+        return self._instance._client._next_nth_request
+
+    @property
+    def _nth_client_id(self):
+        return self._instance._client._nth_client_id
+
+    def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]):
+        client_id = self._nth_client_id
+        return _metadata_with_request_id(
+            self._nth_client_id,
+            self._channel_id,
+            nth_request,
+            nth_attempt,
+            prior_metadata,
+        )
+
+    @property
+    def _channel_id(self):
+        return 1
+
 
 class _Session(object):
     def __init__(self, database=None, name=TestSnapshot.SESSION_NAME):
diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py
index 8bd95c7228..a67f5ad55c 100644
--- a/tests/unit/test_spanner.py
+++ b/tests/unit/test_spanner.py
@@ -38,9 +38,12 @@
 from google.cloud.spanner_v1.keyset import KeySet
 from google.cloud.spanner_v1._helpers import (
+    AtomicCounter,
     _make_value_pb,
     _merge_query_options,
+    _metadata_with_request_id,
 )
+from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID
 
 import mock
@@ -522,6 +525,10 @@ def test_transaction_should_include_begin_with_first_update(self):
             metadata=[
                 ("google-cloud-resource-prefix", database.name),
                 ("x-goog-spanner-route-to-leader", "true"),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
+                ),
             ],
         )
 
@@ -537,6 +544,10 @@ def test_transaction_should_include_begin_with_first_query(self):
             metadata=[
                 ("google-cloud-resource-prefix", database.name),
                 ("x-goog-spanner-route-to-leader", "true"),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
+                ),
             ],
             timeout=TIMEOUT,
             retry=RETRY,
@@ -554,6 +565,10 @@ def test_transaction_should_include_begin_with_first_read(self):
             metadata=[
                 ("google-cloud-resource-prefix", database.name),
                 ("x-goog-spanner-route-to-leader", "true"),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
+                ),
             ],
             retry=RETRY,
             timeout=TIMEOUT,
@@ -570,6 +585,10 @@ def test_transaction_should_include_begin_with_first_batch_update(self):
             metadata=[
                 ("google-cloud-resource-prefix", database.name),
                 ("x-goog-spanner-route-to-leader", "true"),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
+                ),
             ],
             retry=RETRY,
             timeout=TIMEOUT,
@@ -595,6 +614,10 @@ def
test_transaction_should_include_begin_w_exclude_txn_from_change_streams_with metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -639,6 +662,10 @@ def test_transaction_should_use_transaction_id_if_error_with_first_batch_update( metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -653,6 +680,10 @@ def test_transaction_should_use_transaction_id_if_error_with_first_batch_update( metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -669,6 +700,10 @@ def test_transaction_should_use_transaction_id_returned_by_first_query(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -682,6 +717,10 @@ def test_transaction_should_use_transaction_id_returned_by_first_query(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -698,6 +737,10 @@ def test_transaction_should_use_transaction_id_returned_by_first_update(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -711,6 +754,10 @@ def test_transaction_should_use_transaction_id_returned_by_first_update(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -732,6 +779,10 @@ def test_transaction_execute_sql_w_directed_read_options(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, @@ -755,6 +806,10 @@ def test_transaction_streaming_read_w_directed_read_options(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -771,6 +826,10 @@ def test_transaction_should_use_transaction_id_returned_by_first_read(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -782,6 +841,10 @@ def test_transaction_should_use_transaction_id_returned_by_first_read(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, 
@@ -798,6 +861,10 @@ def test_transaction_should_use_transaction_id_returned_by_first_batch_update(se metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -810,6 +877,10 @@ def test_transaction_should_use_transaction_id_returned_by_first_batch_update(se metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -850,6 +921,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -860,6 +935,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -868,6 +947,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -911,6 +994,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -919,6 +1006,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -929,6 +1020,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -977,6 +1072,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -985,6 +1084,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -995,6 +1098,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", 
+ f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -1043,6 +1150,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) req = self._execute_sql_expected_request(database) @@ -1051,6 +1162,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -1061,6 +1176,10 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=RETRY, timeout=TIMEOUT, @@ -1086,12 +1205,20 @@ def test_transaction_should_execute_sql_with_route_to_leader_disabled(self): class _Client(object): + NTH_CLIENT = AtomicCounter() + def __init__(self): from google.cloud.spanner_v1 import ExecuteSqlRequest self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") self.directed_read_options = None self.default_transaction_options = DefaultTransactionOptions() + self._nth_client_id = _Client.NTH_CLIENT.increment() + self._nth_request = AtomicCounter() + + @property + def _next_nth_request(self): + return self._nth_request.increment() class _Instance(object): @@ -1107,6 +1234,28 @@ def __init__(self): self._directed_read_options = None self.default_transaction_options = DefaultTransactionOptions() + @property + def _next_nth_request(self): + return self._instance._client._next_nth_request + + @property + def _nth_client_id(self): + return self._instance._client._nth_client_id + + def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): + client_id = self._nth_client_id + return _metadata_with_request_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + ) + + @property + def _channel_id(self): + return 1 + class _Session(object): _transaction = None diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index ddc91ea522..b6d9df5368 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -21,6 +21,11 @@ from google.cloud.spanner_v1 import TypeCode from google.api_core.retry import Retry from google.api_core import gapic_v1 +from google.cloud.spanner_v1._helpers import ( + AtomicCounter, + _metadata_with_request_id, +) +from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID from tests._helpers import ( HAS_OPENTELEMETRY_INSTALLED, @@ -197,6 +202,10 @@ def test_begin_ok(self): [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -301,6 +310,10 @@ def test_rollback_ok(self): [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -492,6 +505,10 @@ def _commit_helper( [ 
("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) self.assertEqual(actual_request_options, expected_request_options) @@ -666,6 +683,10 @@ def _execute_update_helper( metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -859,6 +880,10 @@ def _batch_update_helper( metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], retry=retry, timeout=timeout, @@ -974,6 +999,10 @@ def test_context_mgr_success(self): [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", + ), ], ) @@ -1004,11 +1033,19 @@ def test_context_mgr_failure(self): class _Client(object): + NTH_CLIENT = AtomicCounter() + def __init__(self): from google.cloud.spanner_v1 import ExecuteSqlRequest self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") self.directed_read_options = None + self._nth_client_id = _Client.NTH_CLIENT.increment() + self._nth_request = AtomicCounter() + + @property + def _next_nth_request(self): + return self._nth_request.increment() class _Instance(object): @@ -1024,6 +1061,28 @@ def __init__(self): self._directed_read_options = None self.default_transaction_options = DefaultTransactionOptions() + @property + def _next_nth_request(self): + return self._instance._client._next_nth_request + + @property + def _nth_client_id(self): + return self._instance._client._nth_client_id + + def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): + client_id = self._nth_client_id + return _metadata_with_request_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + ) + + @property + def _channel_id(self): + return 1 + class _Session(object): _transaction = None From 105c397ddbf152852e157ed57ba9f1c637534e70 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 20 Dec 2024 00:03:18 -0800 Subject: [PATCH 02/29] More plumbing for Database DDL methods --- google/cloud/spanner_v1/database.py | 53 ++++++++++++++++----- google/cloud/spanner_v1/pool.py | 4 +- google/cloud/spanner_v1/session.py | 26 ++++++++--- tests/unit/test_pool.py | 30 ++++++++++++ tests/unit/test_session.py | 71 +++++++++++++++++++++++++++-- 5 files changed, 161 insertions(+), 23 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 76a92448d7..0e953a8f2d 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -520,7 +520,10 @@ def create(self): database_dialect=self._database_dialect, proto_descriptors=self._proto_descriptors, ) - future = api.create_database(request=request, metadata=metadata) + future = api.create_database( + request=request, + metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + ) return future def exists(self): @@ -536,7 +539,12 @@ def exists(self): metadata = _metadata_with_prefix(self.name) try: - api.get_database_ddl(database=self.name, metadata=metadata) + api.get_database_ddl( + database=self.name, + 
metadata=self.metadata_with_request_id( + self._next_nth_request, 1, metadata + ), + ) except NotFound: return False return True @@ -553,10 +561,16 @@ def reload(self): """ api = self._instance._client.database_admin_api metadata = _metadata_with_prefix(self.name) - response = api.get_database_ddl(database=self.name, metadata=metadata) + response = api.get_database_ddl( + database=self.name, + metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + ) self._ddl_statements = tuple(response.statements) self._proto_descriptors = response.proto_descriptors - response = api.get_database(name=self.name, metadata=metadata) + response = api.get_database( + name=self.name, + metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + ) self._state = DatabasePB.State(response.state) self._create_time = response.create_time self._restore_info = response.restore_info @@ -601,7 +615,10 @@ def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None): proto_descriptors=proto_descriptors, ) - future = api.update_database_ddl(request=request, metadata=metadata) + future = api.update_database_ddl( + request=request, + metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + ) return future def update(self, fields): @@ -639,7 +656,9 @@ def update(self, fields): metadata = _metadata_with_prefix(self.name) future = api.update_database( - database=database_pb, update_mask=field_mask, metadata=metadata + database=database_pb, + update_mask=field_mask, + metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) return future @@ -652,7 +671,10 @@ def drop(self): """ api = self._instance._client.database_admin_api metadata = _metadata_with_prefix(self.name) - api.drop_database(database=self.name, metadata=metadata) + api.drop_database( + database=self.name, + metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + ) def execute_partitioned_dml( self, @@ -1019,7 +1041,7 @@ def restore(self, source): ) future = api.restore_database( request=request, - metadata=metadata, + metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) return future @@ -1088,7 +1110,10 @@ def list_database_roles(self, page_size=None): parent=self.name, page_size=page_size, ) - return api.list_database_roles(request=request, metadata=metadata) + return api.list_database_roles( + request=request, + metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + ) def table(self, table_id): """Factory to create a table object within this database. 
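Every database-admin RPC above repeats the same wrapping step: draw a fresh request ordinal from the client and stamp the metadata with attempt 1, since these unary admin calls are not re-stamped on retry at this layer. A hedged sketch of that shared pattern, where _admin_metadata is a hypothetical helper and not part of this patch:

    def _admin_metadata(database, prior_metadata):
        # One monotonically increasing request number per RPC, attempt fixed at 1.
        return database.metadata_with_request_id(
            database._next_nth_request, 1, prior_metadata
        )

    # Usage mirroring drop() above:
    #   api.drop_database(database=self.name, metadata=_admin_metadata(self, metadata))
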
@@ -1172,7 +1197,10 @@ def get_iam_policy(self, policy_version=None): requested_policy_version=policy_version ), ) - response = api.get_iam_policy(request=request, metadata=metadata) + response = api.get_iam_policy( + request=request, + metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + ) return response def set_iam_policy(self, policy): @@ -1194,7 +1222,10 @@ def set_iam_policy(self, policy): resource=self.name, policy=policy, ) - response = api.set_iam_policy(request=request, metadata=metadata) + response = api.set_iam_policy( + request=request, + metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), + ) return response @property diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 16fd4ec8fb..516b953276 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -565,7 +565,9 @@ def bind(self, database): while returned_session_count < self.size: resp = api.batch_create_sessions( request=request, - metadata=metadata, + metadata=database.metadata_with_request_id( + database._next_nth_request, 1, metadata + ), ) add_span_event( diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 68bc8e65ab..fd2e1be2b3 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -170,7 +170,9 @@ def create(self): ), MetricsCapture(): session_pb = api.create_session( request=request, - metadata=metadata, + metadata=self._database.metadata_with_request_id( + self._database._next_nth_request, 1, metadata + ), ) self._session_id = session_pb.name.split("/")[-1] @@ -263,7 +265,12 @@ def delete(self): observability_options=observability_options, metadata=metadata, ), MetricsCapture(): - api.delete_session(name=self.name, metadata=metadata) + api.delete_session( + name=self.name, + metadata=database.metadata_with_request_id( + database._next_nth_request, 1, metadata + ), + ) def ping(self): """Ping the session to keep it alive by executing "SELECT 1". 
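The pool and session plumbing here leans on the AtomicCounter type: the call sites assume increment() bumps the count atomically and returns the new value, while .value reads the current count. A minimal sketch of a counter with those semantics, assuming the real _helpers.AtomicCounter behaves this way:

    import threading

    class CounterSketch:
        def __init__(self, start=0):
            self._lock = threading.Lock()
            self._value = start

        @property
        def value(self):
            with self._lock:
                return self._value

        def increment(self, step=1):
            # Return the post-increment value; this becomes the nth_request
            # or attempt field in the x-goog-spanner-request-id header.
            with self._lock:
                self._value += step
                return self._value
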
@@ -272,13 +279,18 @@ def ping(self): """ if self._session_id is None: raise ValueError("Session ID not set by back-end") - api = self._database.spanner_api database = self._database - metadata = database.metadata_with_request_id( - database._next_nth_request, 1, _metadata_with_prefix(database.name) - ) + api = database.spanner_api + database = self._database request = ExecuteSqlRequest(session=self.name, sql="SELECT 1") - api.execute_sql(request=request, metadata=metadata) + api.execute_sql( + request=request, + metadata=database.metadata_with_request_id( + database._next_nth_request, + 1, + _metadata_with_prefix(database.name), + ), + ) self._last_use_time = datetime.now() def snapshot(self, **kw): diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index a9593b3651..69529da7ab 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -19,6 +19,11 @@ from datetime import datetime, timedelta import mock +from google.cloud.spanner_v1._helpers import ( + _metadata_with_request_id, + AtomicCounter, +) + from google.cloud.spanner_v1._opentelemetry_tracing import trace_call from tests._helpers import ( OpenTelemetryBase, @@ -1193,6 +1198,9 @@ def session_id(self): class _Database(object): + NTH_REQUEST = AtomicCounter() + NTH_CLIENT_ID = AtomicCounter() + def __init__(self, name): self.name = name self._sessions = [] @@ -1247,6 +1255,28 @@ def session(self, **kwargs): def observability_options(self): return dict(db_name=self.name) + @property + def _next_nth_request(self): + return self.NTH_REQUEST.increment() + + @property + def _nth_client_id(self): + return self.NTH_CLIENT_ID.increment() + + def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): + client_id = self._nth_client_id + return _metadata_with_request_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + ) + + @property + def _channel_id(self): + return 1 + class _Queue(object): _size = 1 diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8f5f7039b9..1ba58af3e3 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -49,6 +49,11 @@ from google.protobuf.struct_pb2 import Struct, Value from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1 import DefaultTransactionOptions +from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID +from google.cloud.spanner_v1._helpers import ( + _metadata_with_request_id, + AtomicCounter, +) def _make_rpc_error(error_cls, trailing_metadata=None): @@ -95,7 +100,7 @@ def _make_database( database.database_role = database_role database._route_to_leader_enabled = True database.default_transaction_options = default_transaction_options - + database.NTH_CLIENT = AtomicCounter() return database @staticmethod @@ -191,6 +196,10 @@ def test_create_w_database_role(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -226,6 +235,10 @@ def test_create_session_span_annotations(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -253,6 +266,10 @@ def test_create_wo_database_role(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + 
"x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -281,6 +298,10 @@ def test_create_ok(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -311,6 +332,10 @@ def test_create_w_labels(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -486,7 +511,13 @@ def test_ping_hit(self): gax_api.execute_sql.assert_called_once_with( request=request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), + ], ) def test_ping_miss(self): @@ -507,7 +538,13 @@ def test_ping_miss(self): gax_api.execute_sql.assert_called_once_with( request=request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), + ], ) def test_ping_error(self): @@ -528,7 +565,13 @@ def test_ping_error(self): gax_api.execute_sql.assert_called_once_with( request=request, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), + ], ) def test_delete_wo_session_id(self): @@ -1579,6 +1622,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -1631,6 +1678,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) request = CommitRequest( @@ -1644,6 +1695,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -1717,6 +1772,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) ] @@ -1736,6 +1795,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) ] From 60f1b71894dac4398cf2d39c6393e8279ec8fa5d Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 20 Dec 2024 05:15:47 -0800 Subject: [PATCH 03/29] Update test_spannner.test_transaction tests --- tests/unit/test_spanner.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index 
a67f5ad55c..143d88b41e 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -682,7 +682,7 @@ def test_transaction_should_use_transaction_id_if_error_with_first_batch_update( ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -719,7 +719,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_query(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -756,7 +756,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_update(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -843,7 +843,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_read(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", ), ], retry=RETRY, @@ -879,7 +879,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_batch_update(se ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", ), ], retry=RETRY, @@ -937,7 +937,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -949,7 +949,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.3.1", ), ], retry=RETRY, @@ -996,7 +996,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.3.1", ), ], ) @@ -1022,7 +1022,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", ), ], retry=RETRY, @@ -1074,7 +1074,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.3.1", ), ], ) @@ -1100,7 +1100,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", ), ], 
retry=RETRY, @@ -1152,7 +1152,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.3.1", ), ], ) @@ -1178,7 +1178,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", ), ], retry=RETRY, @@ -1198,7 +1198,13 @@ def test_transaction_should_execute_sql_with_route_to_leader_disabled(self): api.execute_streaming_sql.assert_called_once_with( request=self._execute_sql_expected_request(database=database), - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ) + ], timeout=TIMEOUT, retry=RETRY, ) From 53653d14c79dfee889dc0774882809de71fb4d2c Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 20 Dec 2024 06:23:50 -0800 Subject: [PATCH 04/29] Update test_session tests --- google/cloud/spanner_v1/database.py | 1 - google/cloud/spanner_v1/transaction.py | 1 + tests/unit/test_session.py | 94 +++++++++++++++++++++++--- tests/unit/test_spanner.py | 2 +- 4 files changed, 86 insertions(+), 12 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 0e953a8f2d..d0354d0240 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -469,7 +469,6 @@ def _channel_id(self): return channel_id def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): - client_id = self._nth_client_id return _metadata_with_request_id( self._nth_client_id, self._channel_id, diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 8adbf8e108..c7e286d2c4 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -23,6 +23,7 @@ _merge_query_options, _metadata_with_prefix, _metadata_with_leader_aware_routing, + _metadata_with_request_id, _retry, _check_rst_stream_error, _merge_Transaction_Options, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1ba58af3e3..b07fccf1c4 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -100,7 +100,20 @@ def _make_database( database.database_role = database_role database._route_to_leader_enabled = True database.default_transaction_options = default_transaction_options - database.NTH_CLIENT = AtomicCounter() + nth_client_id = AtomicCounter(1) + database.NTH_CLIENT = nth_client_id + next_nth_request = AtomicCounter(1) + + def metadata_with_request_id(nth_request, nth_attempt, prior_metadata=[]): + return _metadata_with_request_id( + nth_client_id.value, + 1, + next_nth_request.increment(), + nth_attempt, + prior_metadata, + ) + + database.metadata_with_request_id = metadata_with_request_id return database @staticmethod @@ -1043,6 +1056,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + ), ], ) @@ -1093,10 +1110,25 @@ def unit_of_work(txn, *args, 
**kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + ), ], - ) - ] - * 2, + ), + mock.call( + session=self.SESSION_NAME, + options=expected_options, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.2", + ), + ], + ), + ], ) request = CommitRequest( session=self.SESSION_NAME, @@ -1112,10 +1144,24 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + ), ], - ) - ] - * 2, + ), + mock.call( + request=request, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.2", + ), + ], + ), + ], ) def test_run_in_transaction_w_abort_w_retry_metadata(self): @@ -1410,6 +1456,10 @@ def _time(_results=[1, 2, 4, 8]): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + ), ], ) ] @@ -1429,6 +1479,10 @@ def _time(_results=[1, 2, 4, 8]): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + ), ], ) ] @@ -1481,6 +1535,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + ), ], ) request = CommitRequest( @@ -1495,6 +1553,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + ), ], ) database.logger.info.assert_called_once_with( @@ -1543,6 +1605,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + ), ], ) request = CommitRequest( @@ -1557,6 +1623,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + ), ], ) database.logger.info.assert_not_called() @@ -1609,6 +1679,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + ), ], ) request = CommitRequest( @@ -1624,7 +1698,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", ), ], ) @@ -1680,7 
+1754,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -1697,7 +1771,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", ), ], ) diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index 143d88b41e..b978791847 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -1203,7 +1203,7 @@ def test_transaction_should_execute_sql_with_route_to_leader_disabled(self): ( "x-goog-spanner-request-id", f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", - ) + ), ], timeout=TIMEOUT, retry=RETRY, ) From 4afa2cecf3a0910b9a480b718a529092020e232f Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Tue, 24 Dec 2024 17:55:56 -0800 Subject: [PATCH 05/29] Update tests --- google/cloud/spanner_v1/database.py | 17 +++++++++++++++-- google/cloud/spanner_v1/instance.py | 1 - google/cloud/spanner_v1/pool.py | 2 +- google/cloud/spanner_v1/snapshot.py | 3 --- google/cloud/spanner_v1/testing/interceptors.py | 7 ------- google/cloud/spanner_v1/testing/mock_spanner.py | 6 +----- google/cloud/spanner_v1/transaction.py | 1 - tests/unit/test_atomic_counter.py | 1 - tests/unit/test_database.py | 1 - tests/unit/test_pool.py | 1 - tests/unit/test_request_id_header.py | 9 +++------ tests/unit/test_session.py | 16 ++++++++++++++++ tests/unit/test_snapshot.py | 1 - tests/unit/test_spanner.py | 1 - tests/unit/test_transaction.py | 1 - 15 files changed, 36 insertions(+), 32 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index d0354d0240..2b914f4d58 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -754,10 +754,10 @@ def execute_partitioned_dml( _metadata_with_leader_aware_routing(self._route_to_leader_enabled) ) - # Attempt will be incremented inside _restart_on_unavailable. begin_txn_nth_request = self._next_nth_request - begin_txn_attempt = AtomicCounter(1) + begin_txn_attempt = AtomicCounter(0) partial_nth_request = self._next_nth_request + # partial_attempt will be incremented inside _restart_on_unavailable. partial_attempt = AtomicCounter(0) def execute_pdml(): @@ -767,6 +767,7 @@ def execute_pdml(): ) as span, MetricsCapture(): with SessionCheckout(self._pool) as session: add_span_event(span, "Starting BeginTransaction") + begin_txn_attempt.increment() txn = api.begin_transaction( session=session.name, options=txn_options, metadata=self.metadata_with_request_id( @@ -804,6 +805,6 @@ def wrapped_method(*args, **kwargs): observability_options=self.observability_options, attempt=begin_txn_attempt, )
result_set = StreamedResultSet(iterator) list(result_set) # consume all partials diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 06ac1c4593..a67e0e630b 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -501,7 +501,6 @@ def database( proto_descriptors=proto_descriptors, ) else: - print("enabled interceptors") return TestDatabase( database_id, self, diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 516b953276..4c4451d6fd 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -256,7 +256,7 @@ def bind(self, database): span_event_attributes, ) all_metadata = database.metadata_with_request_id( - database._next_nth_request, attempt, metadata + database._next_nth_request, 1, metadata ) resp = api.batch_create_sessions( request=request, diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 060e5fc91f..ca809f1509 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -35,7 +35,6 @@ _merge_query_options, _metadata_with_prefix, _metadata_with_leader_aware_routing, - _metadata_with_request_id, _retry, _check_rst_stream_error, _SessionWrapper, @@ -63,7 +62,6 @@ def _restart_on_unavailable( transaction=None, transaction_selector=None, observability_options=None, - attempt=0, ): """Restart iteration after :exc:`.ServiceUnavailable`. @@ -95,7 +93,6 @@ iterator = None while True: - attempt += 1 try: if iterator is None: with trace_call( diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py index 918c7d76b9..4cd3abd306 100644 --- a/google/cloud/spanner_v1/testing/interceptors.py +++ b/google/cloud/spanner_v1/testing/interceptors.py @@ -91,13 +91,6 @@ def intercept(self, method, request_or_iterator, call_details): response_or_iterator = method(request_or_iterator, call_details) streaming = getattr(response_or_iterator, "__iter__", None) is not None - print( - "intercept got", - x_goog_request_id, - call_details.method, - "streaming", - streaming, - ) with self.__lock: if streaming: self._stream_req_segments.append( diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py index f60dbbe72a..13ead85b13 100644 --- a/google/cloud/spanner_v1/testing/mock_spanner.py +++ b/google/cloud/spanner_v1/testing/mock_spanner.py @@ -22,8 +22,6 @@ from google.cloud.spanner_v1 import ( TransactionOptions, ResultSetMetadata, - ExecuteSqlRequest, - ExecuteBatchDmlRequest, ) from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc @@ -186,9 +184,7 @@ def BeginTransaction(self, request, context): self._requests.append(request) return self.__create_transaction(request.session, request.options) - def __maybe_create_transaction( - self, request: ExecuteSqlRequest | ExecuteBatchDmlRequest - ): + def __maybe_create_transaction(self, request): started_transaction = None if not request.transaction.begin == TransactionOptions(): started_transaction = self.__create_transaction( diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index c7e286d2c4..8adbf8e108 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -23,7 +23,6 @@ _merge_query_options, 
_metadata_with_prefix, _metadata_with_leader_aware_routing, - _metadata_with_request_id, _retry, _check_rst_stream_error, _merge_Transaction_Options, diff --git a/tests/unit/test_atomic_counter.py b/tests/unit/test_atomic_counter.py index d1b22cf841..92d10cac79 100644 --- a/tests/unit/test_atomic_counter.py +++ b/tests/unit/test_atomic_counter.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time import random import threading import unittest diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index a19ff60ccd..8bd4ec0d92 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -3268,7 +3268,6 @@ def _nth_client_id(self): return self._instance._client._nth_client_id def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): - client_id = self._nth_client_id return _metadata_with_request_id( self._nth_client_id, self._channel_id, diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 69529da7ab..98bfef8f6e 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -1264,7 +1264,6 @@ def _nth_client_id(self): return self.NTH_CLIENT_ID.increment() def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): - client_id = self._nth_client_id return _metadata_with_request_id( self._nth_client_id, self._channel_id, diff --git a/tests/unit/test_request_id_header.py b/tests/unit/test_request_id_header.py index 35c7de7410..e399593d9c 100644 --- a/tests/unit/test_request_id_header.py +++ b/tests/unit/test_request_id_header.py @@ -19,7 +19,6 @@ MockServerTestBase, add_select1_result, ) -from google.cloud.spanner_v1.testing.interceptors import XGoogRequestIDHeaderInterceptor from google.cloud.spanner_v1 import ( BatchCreateSessionsRequest, BeginTransactionRequest, @@ -99,8 +98,6 @@ def select1(): if n_finished == len(threads): break - time.sleep(1) - requests = self.spanner_service.requests self.assertEqual(n * 2, len(requests), msg=requests) @@ -252,7 +249,7 @@ def test_database_execute_partitioned_dml_request_id(self): assert got_unary_segments == want_unary_segments assert got_stream_segments == want_stream_segments - def test_snapshot_read(self): + def test_snapshot_read_with_request_ids(self): add_select1_result() if not getattr(self.database, "_interceptors", None): self.database._interceptors = MockServerTestBase._interceptors @@ -269,8 +266,8 @@ def test_snapshot_read(self): self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) - requests = self.spanner_service.requests - self.assertEqual(n * 2, len(requests), msg=requests) + # requests = self.spanner_service.requests + # self.assertEqual(n * 2, len(requests), msg=requests) client_id = self.database._nth_client_id channel_id = self.database._channel_id diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index b07fccf1c4..1360491dc1 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -992,6 +992,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + ), ], ) request = CommitRequest( @@ -1005,6 +1009,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + 
"x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + ), ], ) @@ -1308,6 +1316,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + ), ], ) request = CommitRequest( @@ -1321,6 +1333,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + ), ], ) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index e5254801ce..c30d831782 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -1896,7 +1896,6 @@ def _nth_client_id(self): return self._instance._client._nth_client_id def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): - client_id = self._nth_client_id return _metadata_with_request_id( self._nth_client_id, self._channel_id, diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index b978791847..6834d8022d 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -1249,7 +1249,6 @@ def _nth_client_id(self): return self._instance._client._nth_client_id def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): - client_id = self._nth_client_id return _metadata_with_request_id( self._nth_client_id, self._channel_id, diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index b6d9df5368..64fafcae46 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -1070,7 +1070,6 @@ def _nth_client_id(self): return self._instance._client._nth_client_id def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): - client_id = self._nth_client_id return _metadata_with_request_id( self._nth_client_id, self._channel_id, From 3402751a4e7da60d6f554252d257199479e382cc Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Tue, 24 Dec 2024 19:33:56 -0800 Subject: [PATCH 06/29] Update propagation of changes --- google/cloud/spanner_v1/client.py | 4 +- tests/unit/test_session.py | 225 ++++++++++++++++++++++++------ 2 files changed, 186 insertions(+), 43 deletions(-) diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index ef64d1933e..f536484c12 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -25,6 +25,7 @@ """ import grpc import os +import sys import warnings from google.api_core.gapic_v1 import client_info @@ -265,10 +266,11 @@ def __init__( ) self._default_transaction_options = default_transaction_options self._nth_client_id = Client.NTH_CLIENT.increment() - self._nth_request = AtomicCounter() + self._nth_request = AtomicCounter(0) @property def _next_nth_request(self): + print("next_nth_request called by", sys._getframe().f_back.f_code.co_name) return self._nth_request.increment() @property diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1360491dc1..8f4d379032 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -102,7 +102,7 @@ def _make_database( database.default_transaction_options = default_transaction_options nth_client_id = AtomicCounter(1) database.NTH_CLIENT = nth_client_id - next_nth_request = AtomicCounter(1) + next_nth_request = AtomicCounter(0) def 
metadata_with_request_id(nth_request, nth_attempt, prior_metadata=[]): return _metadata_with_request_id( @@ -396,6 +396,10 @@ def test_exists_hit(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -424,6 +428,10 @@ def test_exists_hit_wo_span(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -444,6 +452,10 @@ def test_exists_miss(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -471,6 +483,10 @@ def test_exists_miss_wo_span(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -492,6 +508,10 @@ def test_exists_error(self): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -608,7 +628,13 @@ def test_delete_hit(self): gax_api.delete_session.assert_called_once_with( name=self.SESSION_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), + ], ) attrs = {"session.id": session._session_id, "session.name": session.name} @@ -631,7 +657,13 @@ def test_delete_miss(self): gax_api.delete_session.assert_called_once_with( name=self.SESSION_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), + ], ) attrs = {"session.id": session._session_id, "session.name": session.name} @@ -656,7 +688,13 @@ def test_delete_error(self): gax_api.delete_session.assert_called_once_with( name=self.SESSION_NAME, - metadata=[("google-cloud-resource-prefix", database.name)], + metadata=[ + ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), + ], ) attrs = {"session.id": session._session_id, "session.name": session.name} @@ -994,7 +1032,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", ), ], ) @@ -1011,7 +1049,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -1066,7 +1104,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", ), ], ) @@ -1120,7 +1158,7 @@ 
def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", ), ], ), @@ -1132,7 +1170,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.2", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", ), ], ), @@ -1154,7 +1192,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", ), ], ), @@ -1165,7 +1203,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.2", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.4.1", ), ], ), @@ -1232,10 +1270,25 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), + ], + ), + mock.call( + session=self.SESSION_NAME, + options=expected_options, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + ), ], - ) - ] - * 2, + ), + ], ) request = CommitRequest( session=self.SESSION_NAME, @@ -1251,10 +1304,24 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + ), + ], + ), + mock.call( + request=request, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.4.1", + ), ], - ) - ] - * 2, + ), + ], ) def test_run_in_transaction_w_callback_raises_abort_wo_metadata(self): @@ -1318,7 +1385,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", ), ], ) @@ -1335,7 +1402,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -1400,6 +1467,10 @@ def _time(_results=[1, 1.5]): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) request = CommitRequest( @@ -1413,6 +1484,10 @@ def _time(_results=[1, 1.5]): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + ), ], ) @@ -1463,6 +1538,7 @@ def _time(_results=[1, 2, 4, 8]): self.assertEqual(kw, {}) expected_options = 
TransactionOptions(read_write=TransactionOptions.ReadWrite()) + print("gax_api", gax_api.begin_transaction.call_args_list[2]) self.assertEqual( gax_api.begin_transaction.call_args_list, [ @@ -1474,12 +1550,35 @@ def _time(_results=[1, 2, 4, 8]): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), + ], + ), + mock.call( + session=self.SESSION_NAME, + options=expected_options, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", ), ], - ) - ] - * 3, + ), + mock.call( + session=self.SESSION_NAME, + options=expected_options, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.5.1", + ), + ], + ), + ], ) request = CommitRequest( session=self.SESSION_NAME, @@ -1497,12 +1596,33 @@ def _time(_results=[1, 2, 4, 8]): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", ), ], - ) - ] - * 3, + ), + mock.call( + request=request, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.4.1", + ), + ], + ), + mock.call( + request=request, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.6.1", + ), + ], + ), + ], ) def test_run_in_transaction_w_commit_stats_success(self): @@ -1553,7 +1673,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", ), ], ) @@ -1571,7 +1691,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -1623,7 +1743,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", ), ], ) @@ -1641,7 +1761,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -1697,7 +1817,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", ), ], ) @@ -1714,7 +1834,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + 
f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -1770,7 +1890,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", ), ], ) @@ -1787,7 +1907,7 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", ), ], ) @@ -1867,9 +1987,20 @@ def unit_of_work(txn, *args, **kw): f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", ), ], - ) - ] - * 2, + ), + mock.call( + session=self.SESSION_NAME, + options=expected_options, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.3.1", + ), + ], + ), + ], ) request = CommitRequest( session=self.SESSION_NAME, @@ -1887,12 +2018,22 @@ def unit_of_work(txn, *args, **kw): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.2.1", ), ], - ) - ] - * 2, + ), + mock.call( + request=request, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.4.1", + ), + ], + ), + ], ) def test_run_in_transaction_w_isolation_level_at_request(self): From 2c4d32d51fe61b6d9e6b231fac84ff3a72534ba0 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Wed, 25 Dec 2024 04:15:03 -0800 Subject: [PATCH 07/29] Plumb test_context_manager --- google/cloud/spanner_v1/_helpers.py | 4 ++ google/cloud/spanner_v1/batch.py | 43 ++++++++++++++----- google/cloud/spanner_v1/client.py | 2 - google/cloud/spanner_v1/database.py | 4 +- .../mockserver_tests/mock_server_test_base.py | 2 + tests/unit/test_batch.py | 35 +++++++++++++++ tests/unit/test_database.py | 8 +++- 7 files changed, 83 insertions(+), 15 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index d1f64db2d8..0e72bdaad4 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -688,6 +688,10 @@ def __radd__(self, n): """ return self.__add__(n) + def reset(self): + with self.__lock: + self.__value = 0 + def _metadata_with_request_id(*args, **kwargs): return with_request_id(*args, **kwargs) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 39e29d4d41..9ef0b1f45c 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -26,6 +26,7 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, _merge_Transaction_Options, + AtomicCounter, ) from google.cloud.spanner_v1._opentelemetry_tracing import trace_call from google.cloud.spanner_v1 import RequestOptions @@ -249,11 +250,22 @@ def commit( observability_options=observability_options, metadata=metadata, ), MetricsCapture(): - method = functools.partial( - api.commit, - request=request, - metadata=metadata, - ) + attempt = AtomicCounter(0) + next_nth_request = database._next_nth_request + + def wrapped_method(*args, **kwargs): + all_metadata = database.metadata_with_request_id( + next_nth_request, + 
attempt.increment(), + metadata, + ) + method = functools.partial( + api.commit, + request=request, + metadata=all_metadata, + ) + return method(*args, **kwargs) + deadline = time.time() + kwargs.get( "timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS ) @@ -372,11 +384,22 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals observability_options=observability_options, metadata=metadata, ), MetricsCapture(): - method = functools.partial( - api.batch_write, - request=request, - metadata=metadata, - ) + attempt = AtomicCounter(0) + next_nth_request = database._next_nth_request + + def wrapped_method(*args, **kwargs): + all_metadata = database.metadata_with_request_id( + next_nth_request, + attempt.increment(), + metadata, + ) + method = functools.partial( + api.batch_write, + request=request, + metadata=all_metadata, + ) + return method(*args, **kwargs) + response = _retry( method, allowed_exceptions={ diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index f536484c12..ac176ba6f2 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -25,7 +25,6 @@ """ import grpc import os -import sys import warnings from google.api_core.gapic_v1 import client_info @@ -270,7 +269,6 @@ def __init__( @property def _next_nth_request(self): - print("next_nth_request called by", sys._getframe().f_back.f_code.co_name) return self._nth_request.increment() @property diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 2b914f4d58..047b4c403a 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -827,7 +827,9 @@ def wrapped_method(*args, **kwargs): @property def _next_nth_request(self): - return self._instance._client._next_nth_request + if self._instance and self._instance._client: + return self._instance._client._next_nth_request + return 1 @property def _nth_client_id(self): diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 3a826720c1..474d078052 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -20,6 +20,7 @@ start_mock_server, SpannerServicer, ) +from google.cloud.spanner_v1.client import Client import google.cloud.spanner_v1.types.type as spanner_type import google.cloud.spanner_v1.types.result_set as result_set from google.api_core.client_options import ClientOptions @@ -154,6 +155,7 @@ def teardown_class(cls): if MockServerTestBase.server is not None: MockServerTestBase.server.stop(grace=None) MockServerTestBase.server = None + Client.NTH_CLIENT.reset() def setup_method(self, *args, **kwargs): self._client = None diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 2cea740ab6..5dfc0e9ac7 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -37,6 +37,10 @@ from google.cloud.spanner_v1.keyset import KeySet from google.rpc.status_pb2 import Status +from google.cloud.spanner_v1._helpers import ( + _metadata_with_request_id, +) +from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] @@ -457,6 +461,10 @@ def test_context_mgr_success(self): [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Database.NTH_CLIENT}.1.1.1", + ), ], ) self.assertEqual(request_options, 
RequestOptions()) @@ -639,12 +647,39 @@ def session_id(self): class _Database(object): + name = "testing" + _route_to_leader_enabled = True + NTH_CLIENT = 1 + def __init__(self, enable_end_to_end_tracing=False): self.name = "testing" self._route_to_leader_enabled = True if enable_end_to_end_tracing: self.observability_options = dict(enable_end_to_end_tracing=True) self.default_transaction_options = DefaultTransactionOptions() + self._nth_request = 0 + + @property + def _next_nth_request(self): + self._nth_request += 1 + return self._nth_request + + @property + def _nth_client_id(self): + return 1 + + def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): + return _metadata_with_request_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + ) + + @property + def _channel_id(self): + return 1 class _FauxSpannerAPI: diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 8bd4ec0d92..6750989d4b 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -3261,11 +3261,15 @@ def __init__(self, name, instance=None): @property def _next_nth_request(self): - return self._instance._client._next_nth_request + if self._instance and self._instance._client: + return self._instance._client._next_nth_request + return 1 @property def _nth_client_id(self): - return self._instance._client._nth_client_id + if self._instance and self._instance._client: + return self._instance._client._nth_client_id + return 1 def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): return _metadata_with_request_id( From 7ea1b79cfec9366066f2893813973cace998cd21 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Wed, 25 Dec 2024 05:31:08 -0800 Subject: [PATCH 08/29] Fix more tests --- tests/mockserver_tests/mock_server_test_base.py | 3 +++ tests/unit/test_database.py | 8 ++++---- tests/unit/test_request_id_header.py | 14 +++++++++----- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 474d078052..c07907f476 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -155,6 +155,9 @@ def teardown_class(cls): if MockServerTestBase.server is not None: MockServerTestBase.server.stop(grace=None) MockServerTestBase.server = None + self.reset() + + def reset(self): Client.NTH_CLIENT.reset() def setup_method(self, *args, **kwargs): diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 6750989d4b..699ccdd0a2 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1922,7 +1922,7 @@ def test_context_mgr_success(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], ) @@ -1973,7 +1973,7 @@ def test_context_mgr_w_commit_stats_success(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], ) @@ -2024,7 +2024,7 @@ def test_context_mgr_w_aborted_commit_status(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], ) @@ -3102,7 
+3102,7 @@ def test_context_mgr_success(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], ) diff --git a/tests/unit/test_request_id_header.py b/tests/unit/test_request_id_header.py index e399593d9c..280b8b24cf 100644 --- a/tests/unit/test_request_id_header.py +++ b/tests/unit/test_request_id_header.py @@ -50,18 +50,20 @@ def test_snapshot_read(self): self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) + NTH_CLIENT = self.database._nth_client_id + CHANNEL_ID = self.database._channel_id # Now ensure monotonicity of the received request-id segments. got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() want_unary_segments = [ ( "/google.spanner.v1.Spanner/BatchCreateSessions", - (1, REQ_RAND_PROCESS_ID, 1, 1, 1, 1), + (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1), ) ] want_stream_segments = [ ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, 1, 1, 2, 1), + (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 2, 1), ) ] @@ -229,20 +231,22 @@ def test_database_execute_partitioned_dml_request_id(self): # Now ensure monotonicity of the received request-id segments. got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() + NTH_CLIENT = self.database._nth_client_id + CHANNEL_ID = self.database._channel_id want_unary_segments = [ ( "/google.spanner.v1.Spanner/BatchCreateSessions", - (1, REQ_RAND_PROCESS_ID, 1, 1, 1, 1), + (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1), ), ( "/google.spanner.v1.Spanner/BeginTransaction", - (1, REQ_RAND_PROCESS_ID, 1, 1, 2, 1), + (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 2, 1), ), ] want_stream_segments = [ ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, 1, 1, 3, 1), + (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 3, 1), ) ] From 0f1ecd3204e1d0b47c7d3abfe6fba6270d3c0423 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Wed, 25 Dec 2024 05:58:47 -0800 Subject: [PATCH 09/29] More test plumbing --- tests/unit/test_atomic_counter.py | 1 + tests/unit/test_batch.py | 9 ++ tests/unit/test_database.py | 3 +- tests/unit/test_request_id_header.py | 127 ++------------------------- 4 files changed, 20 insertions(+), 120 deletions(-) diff --git a/tests/unit/test_atomic_counter.py b/tests/unit/test_atomic_counter.py index 92d10cac79..e8d8b6b7ce 100644 --- a/tests/unit/test_atomic_counter.py +++ b/tests/unit/test_atomic_counter.py @@ -15,6 +15,7 @@ import random import threading import unittest + from google.cloud.spanner_v1._helpers import AtomicCounter diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 5dfc0e9ac7..56aae12854 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -253,6 +253,10 @@ def test_commit_ok(self): [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", + ), ], ) self.assertEqual(request_options, RequestOptions()) @@ -582,6 +586,10 @@ def _test_batch_write_with_request_options( expected_metadata = [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", + ), ] if 
enable_end_to_end_tracing and ot_helpers.HAS_OPENTELEMETRY_INSTALLED: @@ -595,6 +603,7 @@ def _test_batch_write_with_request_options( filtered_metadata = [item for item in metadata if item[0] != "traceparent"] self.assertEqual(filtered_metadata, expected_metadata) + if request_options is None: expected_request_options = RequestOptions() elif type(request_options) is dict: diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 699ccdd0a2..66cd763cab 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1208,7 +1208,7 @@ def _execute_partitioned_dml_helper( ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", ), ], ) @@ -3226,6 +3226,7 @@ def __init__( self.default_transaction_options = default_transaction_options self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() + self.credentials = {} @property def _next_nth_request(self): diff --git a/tests/unit/test_request_id_header.py b/tests/unit/test_request_id_header.py index 280b8b24cf..a49d0521ce 100644 --- a/tests/unit/test_request_id_header.py +++ b/tests/unit/test_request_id_header.py @@ -15,25 +15,26 @@ import random import threading -from tests.mockserver_tests.mock_server_test_base import ( - MockServerTestBase, - add_select1_result, -) +from google.api_core.exceptions import Aborted +from google.rpc import code_pb2 + from google.cloud.spanner_v1 import ( BatchCreateSessionsRequest, BeginTransactionRequest, ExecuteSqlRequest, ) -from google.api_core.exceptions import Aborted -from google.rpc import code_pb2 from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID +from tests.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_select1_result, +) class TestRequestIDHeader(MockServerTestBase): def tearDown(self): self.database._x_goog_request_id_interceptor.reset() - def test_snapshot_read(self): + def test_snapshot_execute_sql(self): add_select1_result() if not getattr(self.database, "_interceptors", None): self.database._interceptors = MockServerTestBase._interceptors @@ -253,118 +254,6 @@ def test_database_execute_partitioned_dml_request_id(self): assert got_unary_segments == want_unary_segments assert got_stream_segments == want_stream_segments - def test_snapshot_read_with_request_ids(self): - add_select1_result() - if not getattr(self.database, "_interceptors", None): - self.database._interceptors = MockServerTestBase._interceptors - with self.database.snapshot() as snapshot: - results = snapshot.read("select 1") - result_list = [] - for row in results: - result_list.append(row) - self.assertEqual(1, row[0]) - self.assertEqual(1, len(result_list)) - - requests = self.spanner_service.requests - self.assertEqual(2, len(requests), msg=requests) - self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) - self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) - - # requests = self.spanner_service.requests - # self.assertEqual(n * 2, len(requests), msg=requests) - - client_id = self.database._nth_client_id - channel_id = self.database._channel_id - got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() - - want_unary_segments = [ - ( - "/google.spanner.v1.Spanner/BatchCreateSessions", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, 
REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1), - ), - ] - assert got_unary_segments == want_unary_segments - - want_stream_segments = [ - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 2, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1), - ), - ] - assert got_stream_segments == want_stream_segments - def canonicalize_request_id_headers(self): src = self.database._x_goog_request_id_interceptor return src._stream_req_segments, src._unary_req_segments From 3532936722ddbc30cfab1466ba71158bdccfe118 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 27 Dec 2024 06:56:35 -0800 Subject: [PATCH 10/29] Infer database._channel_id only once along with spanner_api --- google/cloud/spanner_v1/database.py | 25 ++- .../mockserver_tests/mock_server_test_base.py | 5 +- tests/unit/test_database.py | 154 ++++++++++++++++-- 3 files changed, 152 insertions(+), 32 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 047b4c403a..e6f655d7b1 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -193,6 +193,7 @@ def __init__( self._instance._client.default_transaction_options ) self._proto_descriptors = proto_descriptors + self._channel_id = 0 # It'll be created when _spanner_api is created. if pool is None: pool = BurstyPool(database_role=database_role) @@ -451,22 +452,16 @@ def spanner_api(self): client_info=client_info, client_options=client_options, ) - return self._spanner_api - @property - def _channel_id(self): - """ - Helper to retrieve the associated channelID for the spanner_api. 
- This property is paramount to x-goog-spanner-request-id. - """ - with self.__transport_lock: - api = self.spanner_api - channel_id = self.__transports_to_channel_id.get(api._transport, None) - if channel_id is None: - channel_id = len(self.__transports_to_channel_id) + 1 - self.__transports_to_channel_id[api._transport] = channel_id - - return channel_id + with self.__transport_lock: + transport = self._spanner_api._transport + channel_id = self.__transports_to_channel_id.get(transport, None) + if channel_id is None: + channel_id = len(self.__transports_to_channel_id) + 1 + self.__transports_to_channel_id[transport] = channel_id + self._channel_id = channel_id + + return self._spanner_api def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): return _metadata_with_request_id( diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index c07907f476..fec056b523 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -154,11 +154,8 @@ def setup_class(cls): def teardown_class(cls): if MockServerTestBase.server is not None: MockServerTestBase.server.stop(grace=None) + Client.NTH_CLIENT.reset() MockServerTestBase.server = None - self.reset() - - def reset(self): - Client.NTH_CLIENT.reset() def setup_method(self, *args, **kwargs): self._client = None diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 66cd763cab..c38996044e 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -558,6 +558,10 @@ def test_create_grpc_error(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -587,6 +591,10 @@ def test_create_already_exists(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -615,6 +623,10 @@ def test_create_instance_not_found(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -653,6 +665,10 @@ def test_create_success(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -692,6 +708,10 @@ def test_create_success_w_encryption_config_dict(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -729,6 +749,10 @@ def test_create_success_w_proto_descriptors(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -749,6 +773,10 @@ def test_exists_grpc_error(self): database=self.DATABASE_NAME, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -768,6 +796,10 @@ def 
test_exists_not_found(self): database=self.DATABASE_NAME, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -789,6 +821,10 @@ def test_exists_success(self): database=self.DATABASE_NAME, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -809,6 +845,10 @@ def test_reload_grpc_error(self): database=self.DATABASE_NAME, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -829,6 +869,10 @@ def test_reload_not_found(self): database=self.DATABASE_NAME, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -890,12 +934,20 @@ def test_reload_success(self): database=self.DATABASE_NAME, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) api.get_database.assert_called_once_with( name=self.DATABASE_NAME, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.2.1", + ), ], ) @@ -924,6 +976,10 @@ def test_update_ddl_grpc_error(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -952,6 +1008,10 @@ def test_update_ddl_not_found(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -981,6 +1041,10 @@ def test_update_ddl(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1010,6 +1074,10 @@ def test_update_ddl_w_operation_id(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1038,6 +1106,10 @@ def test_update_success(self): update_mask=field_mask, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1068,6 +1140,10 @@ def test_update_ddl_w_proto_descriptors(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1088,6 +1164,10 @@ def test_drop_grpc_error(self): database=self.DATABASE_NAME, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1108,6 +1188,10 @@ def test_drop_not_found(self): 
database=self.DATABASE_NAME, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1127,6 +1211,10 @@ def test_drop_success(self): database=self.DATABASE_NAME, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1200,22 +1288,34 @@ def _execute_partitioned_dml_helper( exclude_txn_from_change_streams=exclude_txn_from_change_streams, ) - api.begin_transaction.assert_called_with( - session=session.name, - options=txn_options, - metadata=[ - ("google-cloud-resource-prefix", database.name), - ("x-goog-spanner-route-to-leader", "true"), - ( - "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", - ), - ], - ) if retried: self.assertEqual(api.begin_transaction.call_count, 2) + api.begin_transaction.assert_called_with( + session=session.name, + options=txn_options, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.2", + ), + ], + ) else: self.assertEqual(api.begin_transaction.call_count, 1) + api.begin_transaction.assert_called_with( + session=session.name, + options=txn_options, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), + ], + ) if params: expected_params = Struct( @@ -1253,7 +1353,7 @@ def _execute_partitioned_dml_helper( ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.{database._channel_id}.2.1", ), ], ) @@ -1275,6 +1375,10 @@ def _execute_partitioned_dml_helper( metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.2.2", + ), ], ) self.assertEqual(api.execute_streaming_sql.call_count, 2) @@ -1551,6 +1655,10 @@ def test_restore_grpc_error(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1579,6 +1687,10 @@ def test_restore_not_found(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1618,6 +1730,10 @@ def test_restore_success(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1661,6 +1777,10 @@ def test_restore_success_w_encryption_config_dict(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1810,6 +1930,10 @@ def 
test_list_database_roles_grpc_error(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) @@ -1833,6 +1957,10 @@ def test_list_database_roles_defaults(self): request=expected_request, metadata=[ ("google-cloud-resource-prefix", database.name), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) self.assertIsNotNone(resp) From 24f1bd47d89dee1b37e2938c148b29851fbb7378 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 27 Dec 2024 07:11:46 -0800 Subject: [PATCH 11/29] Update batch tests --- tests/unit/test_batch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 56aae12854..5888f3a2ae 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -355,6 +355,10 @@ def _test_commit_with_options( [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + ), ], ) self.assertEqual(actual_request_options, expected_request_options) From ba39d46b08fe35b50995a5862a957b099525e9e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 27 Dec 2024 20:25:48 +0100 Subject: [PATCH 12/29] test: add tests for retries and move mock server test to correct directory This commit contains a couple of changes: 1. Moves the request-ID tests using the mock server to the mockserver tests directory. 2. Adds tests for retries of RPCs. These currently fail, as the request ID is not updated correctly when a retry is executed for a unary RPC. Also, the retry test for streaming RPCs fails, but this is due to an existing bug in the client lib. --- .../cloud/spanner_v1/testing/mock_spanner.py | 1 + noxfile.py | 2 +- .../mockserver_tests/mock_server_test_base.py | 20 ++ .../test_request_id_header.py | 171 +++++++++++------- 4 files changed, 124 insertions(+), 70 deletions(-) rename tests/{unit => mockserver_tests}/test_request_id_header.py (66%) diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py index 13ead85b13..f8971a6098 100644 --- a/google/cloud/spanner_v1/testing/mock_spanner.py +++ b/google/cloud/spanner_v1/testing/mock_spanner.py @@ -105,6 +105,7 @@ def CreateSession(self, request, context): def BatchCreateSessions(self, request, context): self._requests.append(request) + self.mock_spanner.pop_error(context) sessions = [] for i in range(request.session_count): sessions.append( diff --git a/noxfile.py b/noxfile.py index cb683afd7e..0d6aad5bbd 100644 --- a/noxfile.py +++ b/noxfile.py @@ -32,7 +32,7 @@ ISORT_VERSION = "isort==5.11.0" LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION = "3.8" +DEFAULT_PYTHON_VERSION = "3.12" DEFAULT_MOCK_SERVER_TESTS_PYTHON_VERSION = "3.12" UNIT_TEST_PYTHON_VERSIONS: List[str] = [ diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index fec056b523..24bbac0861 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -57,6 +57,26 @@ def aborted_status() -> _Status: ) return status +# Creates an UNAVAILABLE status with the smallest possible retry delay. 
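+# The Status carries a RetryInfo detail with a near-zero delay
+# (1 nanosecond) so that client-side retries fire immediately in tests.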
+def unavailable_status() -> _Status: + error = status_pb2.Status( + code=code_pb2.UNAVAILABLE, + message="Service unavailable.", + ) + retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1)) + status = _Status( + code=code_to_grpc_status_code(error.code), + details=error.message, + trailing_metadata=( + ("grpc-status-details-bin", error.SerializeToString()), + ( + "google.rpc.retryinfo-bin", + retry_info.SerializeToString(), + ), + ), + ) + return status + # Creates an UNAVAILABLE status with the smallest possible retry delay. def unavailable_status() -> _Status: diff --git a/tests/unit/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py similarity index 66% rename from tests/unit/test_request_id_header.py rename to tests/mockserver_tests/test_request_id_header.py index a49d0521ce..494a4f3879 100644 --- a/tests/unit/test_request_id_header.py +++ b/tests/mockserver_tests/test_request_id_header.py @@ -15,18 +15,16 @@ import random import threading -from google.api_core.exceptions import Aborted -from google.rpc import code_pb2 - from google.cloud.spanner_v1 import ( BatchCreateSessionsRequest, BeginTransactionRequest, ExecuteSqlRequest, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID +from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer from tests.mockserver_tests.mock_server_test_base import ( MockServerTestBase, - add_select1_result, + add_select1_result, aborted_status, add_error, unavailable_status, ) @@ -102,7 +100,7 @@ def select1(): break requests = self.spanner_service.requests - self.assertEqual(n * 2, len(requests), msg=requests) + self.assertEqual(n + 1, len(requests), msg=requests) client_id = self.database._nth_client_id channel_id = self.database._channel_id @@ -113,42 +111,6 @@ def select1(): "/google.spanner.v1.Spanner/BatchCreateSessions", (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1), ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1), - ), ] assert got_unary_segments == want_unary_segments @@ -159,39 +121,39 @@ def select1(): ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1), ), ( 
"/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1), ), ] assert got_stream_segments == want_stream_segments @@ -207,10 +169,7 @@ def select_in_txn(txn): if counters["aborted"] < want_failed_attempts: counters["aborted"] += 1 - raise Aborted( - "Thrown from ClientInterceptor for testing", - errors=[FauxCall(code_pb2.ABORTED)], - ) + add_error(SpannerServicer.Commit.__name__, aborted_status()) add_select1_result() if not getattr(self.database, "_interceptors", None): @@ -254,24 +213,98 @@ def test_database_execute_partitioned_dml_request_id(self): assert got_unary_segments == want_unary_segments assert got_stream_segments == want_stream_segments - def canonicalize_request_id_headers(self): - src = self.database._x_goog_request_id_interceptor - return src._stream_req_segments, src._unary_req_segments + def test_unary_retryable_error(self): + add_select1_result() + add_error(SpannerServicer.BatchCreateSessions.__name__, unavailable_status()) + + if not getattr(self.database, "_interceptors", None): + self.database._interceptors = MockServerTestBase._interceptors + with self.database.snapshot() as snapshot: + results = snapshot.execute_sql("select 1") + result_list = [] + for row in results: + result_list.append(row) + self.assertEqual(1, row[0]) + self.assertEqual(1, len(result_list)) + + requests = self.spanner_service.requests + self.assertEqual(3, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + + NTH_CLIENT = self.database._nth_client_id + CHANNEL_ID = self.database._channel_id + # Now ensure monotonicity of the received request-id segments. 
+ got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() + want_unary_segments = [ + ( + "/google.spanner.v1.Spanner/BatchCreateSessions", + (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1), + ), + ( + "/google.spanner.v1.Spanner/BatchCreateSessions", + (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 2), + ), + ] + want_stream_segments = [ + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 2, 1), + ) + ] + assert got_unary_segments == want_unary_segments + assert got_stream_segments == want_stream_segments -class FauxCall: - def __init__(self, code, details="FauxCall"): - self._code = code - self._details = details + def test_streaming_retryable_error(self): + add_select1_result() + # TODO: UNAVAILABLE errors are not correctly handled by the client lib. + # This is probably the reason behind + # https://github.com/googleapis/python-spanner/issues/1150. + # The fix + add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status()) - def initial_metadata(self): - return {} + if not getattr(self.database, "_interceptors", None): + self.database._interceptors = MockServerTestBase._interceptors + with self.database.snapshot() as snapshot: + results = snapshot.execute_sql("select 1") + result_list = [] + for row in results: + result_list.append(row) + self.assertEqual(1, row[0]) + self.assertEqual(1, len(result_list)) - def trailing_metadata(self): - return {} + requests = self.spanner_service.requests + self.assertEqual(3, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) - def code(self): - return self._code + NTH_CLIENT = self.database._nth_client_id + CHANNEL_ID = self.database._channel_id + # Now ensure monotonicity of the received request-id segments. 
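+        # As in the unary case, the retried ExecuteStreamingSql call should
+        # differ only in its trailing attempt component.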
+        got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
+        want_unary_segments = [
+            (
+                "/google.spanner.v1.Spanner/BatchCreateSessions",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1),
+            ),
+        ]
+        want_stream_segments = [
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 2, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 2, 2),
+            ),
+        ]
+
+        assert got_unary_segments == want_unary_segments
+        assert got_stream_segments == want_stream_segments

-    def details(self):
-        return self._details
+    def canonicalize_request_id_headers(self):
+        src = self.database._x_goog_request_id_interceptor
+        return src._stream_req_segments, src._unary_req_segments

From 04ec941da11f890dcdb1350906e2a5f4c4bcc474 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?=
Date: Fri, 27 Dec 2024 20:28:25 +0100
Subject: [PATCH 13/29] fix: revert default Python version

---
 noxfile.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/noxfile.py b/noxfile.py
index 0d6aad5bbd..cb683afd7e 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -32,7 +32,7 @@
 ISORT_VERSION = "isort==5.11.0"
 LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"]
 
-DEFAULT_PYTHON_VERSION = "3.12"
+DEFAULT_PYTHON_VERSION = "3.8"
 DEFAULT_MOCK_SERVER_TESTS_PYTHON_VERSION = "3.12"
 
 UNIT_TEST_PYTHON_VERSIONS: List[str] = [

From d4bf747e0dfb4a1528495ecc8c5f92ac4e8761d7 Mon Sep 17 00:00:00 2001
From: Emmanuel T Odeke
Date: Fri, 3 Jan 2025 22:04:58 -0800
Subject: [PATCH 14/29] Fix discrepancy where api.batch_create_sessions
 retries automatically, assuming idempotency, without updating the
 x-goog-spanner-request-id headers

---
 google/cloud/spanner_v1/pool.py               | 79 +++++++++++++++----
 google/cloud/spanner_v1/snapshot.py           |  5 +-
 .../cloud/spanner_v1/testing/database_test.py |  5 ++
 .../cloud/spanner_v1/testing/interceptors.py  |  3 +-
 .../mockserver_tests/mock_server_test_base.py |  1 +
 .../test_request_id_header.py                 | 21 +++--
 6 files changed, 88 insertions(+), 26 deletions(-)

diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py
index 4c4451d6fd..c35442e28f 100644
--- a/google/cloud/spanner_v1/pool.py
+++ b/google/cloud/spanner_v1/pool.py
@@ -15,10 +15,12 @@
 """Pools managing shared Session objects."""
 
 import datetime
+import random
 import queue
 import time
 
 from google.cloud.exceptions import NotFound
+from google.api_core.exceptions import ServiceUnavailable
 from google.cloud.spanner_v1 import BatchCreateSessionsRequest
 from google.cloud.spanner_v1 import Session
 from google.cloud.spanner_v1._helpers import (
@@ -255,13 +257,26 @@ def bind(self, database):
                 f"Creating {request.session_count} sessions",
                 span_event_attributes,
             )
-            all_metadata = database.metadata_with_request_id(
-                database._next_nth_request, 1, metadata
-            )
-            resp = api.batch_create_sessions(
-                request=request,
-                metadata=all_metadata,
-            )
+            nth_req = database._next_nth_request
+
+            def create_sessions(attempt):
+                all_metadata = database.metadata_with_request_id(
+                    nth_req, attempt, metadata
+                )
+                return api.batch_create_sessions(
+                    request=request,
+                    metadata=all_metadata,
+                    # Manually pass retry=None: otherwise an UNAVAILABLE error
+                    # would be retried internally without the request-id
+                    # metadata being updated, while retry_on_unavailable lets
+                    # us refresh the metadata on every attempt.
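+                    # retry_on_unavailable re-invokes create_sessions with the
+                    # next attempt number, so the attempt segment of the
+                    # request-id header stays in sync across retries.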
+                    retry=None,
+                )
+
+            resp = retry_on_unavailable(create_sessions)
 
             add_span_event(
                 span,
@@ -561,14 +573,27 @@ def bind(self, database):
             observability_options=observability_options,
             metadata=metadata,
         ) as span, MetricsCapture():
-            returned_session_count = 0
-            while returned_session_count < self.size:
-                resp = api.batch_create_sessions(
-                    request=request,
-                    metadata=database.metadata_with_request_id(
-                        database._next_nth_request, 1, metadata
-                    ),
-                )
+            created_session_count = 0
+            while created_session_count < self.size:
+                nth_req = database._next_nth_request
+
+                def create_sessions(attempt):
+                    all_metadata = database.metadata_with_request_id(
+                        nth_req, attempt, metadata
+                    )
+                    return api.batch_create_sessions(
+                        request=request,
+                        metadata=all_metadata,
+                        # Manually pass retry=None: otherwise an UNAVAILABLE error
+                        # would be retried internally without the request-id
+                        # metadata being updated, while retry_on_unavailable lets
+                        # us refresh the metadata on every attempt.
+                        # TODO: Figure out how to intercept and monkey patch the internals
+                        # of the gRPC transport.
+                        retry=None,
+                    )
+
+                resp = retry_on_unavailable(create_sessions)
 
                 add_span_event(
                     span,
@@ -577,13 +602,14 @@ def bind(self, database):
 
                 for session_pb in resp.session:
                     session = self._new_session()
-                    returned_session_count += 1
                     session._session_id = session_pb.name.split("/")[-1]
                     self.put(session)
 
+                created_session_count += len(resp.session)
+
                 add_span_event(
                     span,
-                    f"Requested for {requested_session_count} sessions, returned {returned_session_count}",
+                    f"Requested for {requested_session_count} sessions, returned {created_session_count}",
                     span_event_attributes,
                 )
 
@@ -806,3 +832,24 @@ def __enter__(self):
 
     def __exit__(self, *ignored):
         self._pool.put(self._session)
+
+
+def retry_on_unavailable(fn, max=6):
+    """
+    Retries `fn` up to `max` times on encountering UNAVAILABLE errors,
+    each time passing in the attempt's ordinal number (starting at 1)
+    to signal the nth attempt; retries back off exponentially with jitter.
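+    Each failed attempt sleeps for the square of the attempt index plus
+    random jitter before the next call.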
+ """ + last_exc = None + for i in range(max): + try: + return fn(i + 1) + except ServiceUnavailable as exc: + last_exc = exc + time.sleep(i**2 + random.random()) + except: + raise + + raise last_exc diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index ca809f1509..90d1ed306c 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -364,6 +364,7 @@ def wrapped_restart(*args, **kwargs): trace_attributes, transaction=self, observability_options=observability_options, + attempt=attempt, ) self._read_request_count += 1 if self._multi_use: @@ -387,6 +388,7 @@ def wrapped_restart(*args, **kwargs): trace_attributes, transaction=self, observability_options=observability_options, + attempt=attempt, ) self._read_request_count += 1 @@ -579,12 +581,11 @@ def execute_sql( attempt = AtomicCounter(0) def wrapped_restart(*args, **kwargs): - attempt.increment() restart = functools.partial( api.execute_streaming_sql, request=request, metadata=database.metadata_with_request_id( - nth_request, attempt.value, metadata + nth_request, attempt.increment(), metadata ), retry=retry, timeout=timeout, diff --git a/google/cloud/spanner_v1/testing/database_test.py b/google/cloud/spanner_v1/testing/database_test.py index 80f040d7e0..4a6e94c88b 100644 --- a/google/cloud/spanner_v1/testing/database_test.py +++ b/google/cloud/spanner_v1/testing/database_test.py @@ -79,6 +79,7 @@ def spanner_api(self): channel = grpc.insecure_channel(self._instance.emulator_host) self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor() self._interceptors.append(self._x_goog_request_id_interceptor) + # print("self._interceptors", self._interceptors) channel = grpc.intercept_channel(channel, *self._interceptors) transport = SpannerGrpcTransport(channel=channel) self._spanner_api = SpannerClient( @@ -115,3 +116,7 @@ def _create_spanner_client_for_tests(self, client_options, credentials): client_options=client_options, transport=transport, ) + + def reset(self): + if self._x_goog_request_id_interceptor: + self._x_goog_request_id_interceptor.reset() diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py index 4cd3abd306..a1ab53af40 100644 --- a/google/cloud/spanner_v1/testing/interceptors.py +++ b/google/cloud/spanner_v1/testing/interceptors.py @@ -90,7 +90,9 @@ def intercept(self, method, request_or_iterator, call_details): ) response_or_iterator = method(request_or_iterator, call_details) + print("call_details", call_details, "\n", response_or_iterator) streaming = getattr(response_or_iterator, "__iter__", None) is not None + print("x_append", call_details.method, x_goog_request_id) with self.__lock: if streaming: self._stream_req_segments.append( @@ -114,7 +116,6 @@ def stream_request_ids(self): def reset(self): self._stream_req_segments.clear() self._unary_req_segments.clear() - pass def parse_request_id(request_id_str): diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 24bbac0861..2f89415b55 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -57,6 +57,7 @@ def aborted_status() -> _Status: ) return status + # Creates an UNAVAILABLE status with the smallest possible retry delay. 
def unavailable_status() -> _Status: error = status_pb2.Status( diff --git a/tests/mockserver_tests/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py index 494a4f3879..306d2cbf93 100644 --- a/tests/mockserver_tests/test_request_id_header.py +++ b/tests/mockserver_tests/test_request_id_header.py @@ -24,7 +24,10 @@ from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer from tests.mockserver_tests.mock_server_test_base import ( MockServerTestBase, - add_select1_result, aborted_status, add_error, unavailable_status, + add_select1_result, + aborted_status, + add_error, + unavailable_status, ) @@ -70,6 +73,13 @@ def test_snapshot_execute_sql(self): assert got_stream_segments == want_stream_segments def test_snapshot_read_concurrent(self): + # Trigger BatchCreateSessions firstly. + with self.database.snapshot() as snapshot: + rows = snapshot.execute_sql("select 1") + for row in rows: + _ = row + + # The other requests can then proceed. def select1(): with self.database.snapshot() as snapshot: rows = snapshot.execute_sql("select 1") @@ -100,7 +110,7 @@ def select1(): break requests = self.spanner_service.requests - self.assertEqual(n + 1, len(requests), msg=requests) + self.assertEqual(2 + n * 2, len(requests), msg=requests) client_id = self.database._nth_client_id channel_id = self.database._channel_id @@ -112,6 +122,7 @@ def select1(): (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1), ), ] + print("got_unary", got_unary_segments) assert got_unary_segments == want_unary_segments want_stream_segments = [ @@ -254,15 +265,13 @@ def test_unary_retryable_error(self): ) ] + print("got_unary_segments", got_unary_segments) assert got_unary_segments == want_unary_segments assert got_stream_segments == want_stream_segments def test_streaming_retryable_error(self): add_select1_result() - # TODO: UNAVAILABLE errors are not correctly handled by the client lib. - # This is probably the reason behind - # https://github.com/googleapis/python-spanner/issues/1150. 
- # The fix + add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status()) add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status()) if not getattr(self.database, "_interceptors", None): From 9ce98c3e8d7c940cf8faf1667a6040bf7644aba9 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 3 Jan 2025 22:58:56 -0800 Subject: [PATCH 15/29] Take into account current behavior of /GetSession /BatchCreateSession in tests --- .../cloud/spanner_v1/testing/database_test.py | 1 - .../cloud/spanner_v1/testing/interceptors.py | 2 - .../test_request_id_header.py | 73 +++++++++++++++---- 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/google/cloud/spanner_v1/testing/database_test.py b/google/cloud/spanner_v1/testing/database_test.py index 4a6e94c88b..5af89fea42 100644 --- a/google/cloud/spanner_v1/testing/database_test.py +++ b/google/cloud/spanner_v1/testing/database_test.py @@ -79,7 +79,6 @@ def spanner_api(self): channel = grpc.insecure_channel(self._instance.emulator_host) self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor() self._interceptors.append(self._x_goog_request_id_interceptor) - # print("self._interceptors", self._interceptors) channel = grpc.intercept_channel(channel, *self._interceptors) transport = SpannerGrpcTransport(channel=channel) self._spanner_api = SpannerClient( diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py index a1ab53af40..4fe4ed147d 100644 --- a/google/cloud/spanner_v1/testing/interceptors.py +++ b/google/cloud/spanner_v1/testing/interceptors.py @@ -90,9 +90,7 @@ def intercept(self, method, request_or_iterator, call_details): ) response_or_iterator = method(request_or_iterator, call_details) - print("call_details", call_details, "\n", response_or_iterator) streaming = getattr(response_or_iterator, "__iter__", None) is not None - print("x_append", call_details.method, x_goog_request_id) with self.__lock: if streaming: self._stream_req_segments.append( diff --git a/tests/mockserver_tests/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py index 306d2cbf93..d0b4ae1ec3 100644 --- a/tests/mockserver_tests/test_request_id_header.py +++ b/tests/mockserver_tests/test_request_id_header.py @@ -73,15 +73,16 @@ def test_snapshot_execute_sql(self): assert got_stream_segments == want_stream_segments def test_snapshot_read_concurrent(self): + db = self.database # Trigger BatchCreateSessions firstly. - with self.database.snapshot() as snapshot: + with db.snapshot() as snapshot: rows = snapshot.execute_sql("select 1") for row in rows: _ = row # The other requests can then proceed. 
def select1(): - with self.database.snapshot() as snapshot: + with db.snapshot() as snapshot: rows = snapshot.execute_sql("select 1") res_list = [] for row in rows: @@ -112,8 +113,8 @@ def select1(): requests = self.spanner_service.requests self.assertEqual(2 + n * 2, len(requests), msg=requests) - client_id = self.database._nth_client_id - channel_id = self.database._channel_id + client_id = db._nth_client_id + channel_id = db._channel_id got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() want_unary_segments = [ @@ -121,8 +122,47 @@ def select1(): "/google.spanner.v1.Spanner/BatchCreateSessions", (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1), ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 21, 1), + ), ] - print("got_unary", got_unary_segments) assert got_unary_segments == want_unary_segments want_stream_segments = [ @@ -132,39 +172,43 @@ def select1(): ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1), ), ( "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1), + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, 
REQ_RAND_PROCESS_ID, client_id, channel_id, 22, 1),
+            ),
         ]
         assert got_stream_segments == want_stream_segments

@@ -265,7 +309,6 @@ def test_unary_retryable_error(self):
             )
         ]

-        print("got_unary_segments", got_unary_segments)
         assert got_unary_segments == want_unary_segments
         assert got_stream_segments == want_stream_segments

From 049508216379e543fd1d51a94a068a8008761d6f Mon Sep 17 00:00:00 2001
From: Emmanuel T Odeke
Date: Tue, 14 Jan 2025 23:32:00 -0800
Subject: [PATCH 16/29] Implement interceptor to wrap and increase
 x-goog-spanner-request-id attempts per retry

This monkey patches SpannerClient methods to have an interceptor that
increments the attempt count on each retry. The precondition, though, is
that every caller must pass in an initial attempt value of 0 so that each
pass through correctly increments the attempt field's value.
---
 google/cloud/spanner_v1/_helpers.py |  33 ++++++-
 google/cloud/spanner_v1/batch.py    |  46 ++++-----
 google/cloud/spanner_v1/database.py |  74 +++++++++------
 google/cloud/spanner_v1/pool.py     |  12 ---
 google/cloud/spanner_v1/snapshot.py | 140 ++++++++++++----------------
 tests/unit/test_snapshot.py         |   1 +
 6 files changed, 159 insertions(+), 147 deletions(-)

diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py
index 0e72bdaad4..e76c9c19e3 100644
--- a/google/cloud/spanner_v1/_helpers.py
+++ b/google/cloud/spanner_v1/_helpers.py
@@ -33,7 +33,7 @@
 from google.cloud.spanner_v1 import ExecuteSqlRequest
 from google.cloud.spanner_v1 import JsonObject
 from google.cloud.spanner_v1 import TransactionOptions
-from google.cloud.spanner_v1.request_id_header import with_request_id
+from google.cloud.spanner_v1.request_id_header import REQ_ID_HEADER_KEY, with_request_id
 from google.rpc.error_details_pb2 import RetryInfo

 try:
@@ -45,6 +45,7 @@
     HAS_OPENTELEMETRY_INSTALLED = False
 from typing import List, Tuple
 import random
+from typing import Callable

 # Validation error messages
 NUMERIC_MAX_SCALE_ERR_MSG = (
@@ -730,3 +731,33 @@ def _merge_Transaction_Options(

     # Convert protobuf object back into a TransactionOptions instance
     return TransactionOptions(merged_pb)
+
+
+class InterceptingHeaderInjector:
+    def __init__(self, original_callable: Callable):
+        self._original_callable = original_callable
+
+    def __call__(self, *args, **kwargs):
+        metadata = kwargs.get("metadata", [])
+        # Find all the headers that match the x-goog-spanner-request-id
+        # header and on each retry increment the value.
+        all_metadata = []
+        for key, value in metadata:
+            if key is REQ_ID_HEADER_KEY:
+                # Otherwise now increment the count for the attempt number.
+ splits = value.split(".") + attempt_plus_one = int(splits[-1]) + 1 + splits[-1] = str(attempt_plus_one) + value_before = value + value = ".".join(splits) + print("incrementing value on retry from", value_before, "to", value) + + all_metadata.append( + ( + key, + value, + ) + ) + + kwargs["metadata"] = all_metadata + return self._original_callable(*args, **kwargs) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 9ef0b1f45c..811055a256 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -253,18 +253,16 @@ def commit( attempt = AtomicCounter(0) next_nth_request = database._next_nth_request - def wrapped_method(*args, **kwargs): - all_metadata = database.metadata_with_request_id( - next_nth_request, - attempt.increment(), - metadata, - ) - method = functools.partial( - api.commit, - request=request, - metadata=all_metadata, - ) - return method(*args, **kwargs) + all_metadata = database.metadata_with_request_id( + next_nth_request, + attempt.increment(), + metadata, + ) + method = functools.partial( + api.commit, + request=request, + metadata=all_metadata, + ) deadline = time.time() + kwargs.get( "timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS @@ -384,21 +382,17 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals observability_options=observability_options, metadata=metadata, ), MetricsCapture(): - attempt = AtomicCounter(0) next_nth_request = database._next_nth_request - - def wrapped_method(*args, **kwargs): - all_metadata = database.metadata_with_request_id( - next_nth_request, - attempt.increment(), - metadata, - ) - method = functools.partial( - api.batch_write, - request=request, - metadata=all_metadata, - ) - return method(*args, **kwargs) + all_metadata = database.metadata_with_request_id( + next_nth_request, + 0, + metadata, + ) + method = functools.partial( + api.batch_write, + request=request, + metadata=all_metadata, + ) response = _retry( method, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index e6f655d7b1..71eaa547e3 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -55,6 +55,7 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, _metadata_with_request_id, + InterceptingHeaderInjector, ) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups @@ -432,6 +433,43 @@ def logger(self): @property def spanner_api(self): + """Helper for session-related API calls.""" + api = self._generate_spanner_api() + if not api: + return api + + # Now wrap each method's __call__ method with our wrapped one. + # This is how to deal with the fact that there are no proper gRPC + # interceptors for Python hence the remedy is to replace callables + # with our custom wrapper. 
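As a standalone illustration of what each injected wrapper does to the header:
the rewrite is a bump of the trailing attempt segment. A minimal sketch, where
bump_attempt is an illustrative name rather than a library function:

    def bump_attempt(header_value):
        # x-goog-spanner-request-id ends with the attempt number; a retry
        # rewrites only that final dot-separated segment.
        parts = header_value.split(".")
        parts[-1] = str(int(parts[-1]) + 1)
        return ".".join(parts)

    assert bump_attempt("1.5678.1.1.7.1") == "1.5678.1.1.7.2"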
+ attrs = dir(api) + for attr_name in attrs: + mangled = attr_name.startswith("__") + if mangled: + continue + + non_public = attr_name.startswith("_") + if non_public: + continue + + attr = getattr(api, attr_name) + callable_attr = callable(attr) + if callable_attr is None: + continue + + # We should only be looking at bound methods to SpannerClient + # as those are the RPC invoking methods that need to be wrapped + + is_method = type(attr).__name__ == "method" + if not is_method: + continue + + print("attr_name", attr_name, "callable_attr", attr) + setattr(api, attr_name, InterceptingHeaderInjector(attr)) + + return api + + def _generate_spanner_api(self): """Helper for session-related API calls.""" if self._spanner_api is None: client_info = self._instance._client._client_info @@ -762,11 +800,11 @@ def execute_pdml(): ) as span, MetricsCapture(): with SessionCheckout(self._pool) as session: add_span_event(span, "Starting BeginTransaction") - begin_txn_attempt.increment() txn = api.begin_transaction( - session=session.name, options=txn_options, + session=session.name, + options=txn_options, metadata=self.metadata_with_request_id( - begin_txn_nth_request, begin_txn_attempt.value, metadata + begin_txn_nth_request, begin_txn_attempt.increment(), metadata ), ) @@ -781,37 +819,21 @@ def execute_pdml(): request_options=request_options, ) - def wrapped_method(*args, **kwargs): - partial_attempt.increment() - method = functools.partial( - api.execute_streaming_sql, - metadata=self.metadata_with_request_id( - partial_nth_request, partial_attempt.value, metadata - ), - ) - return method(*args, **kwargs) + method = functools.partial( + api.execute_streaming_sql, + metadata=self.metadata_with_request_id( + partial_nth_request, partial_attempt.increment(), metadata + ), + ) iterator = _restart_on_unavailable( - method=wrapped_method, + method=method, trace_name="CloudSpanner.ExecuteStreamingSql", request=request, metadata=metadata, transaction_selector=txn_selector, observability_options=self.observability_options, - attempt=begin_txn_attempt, ) -<<<<<<< HEAD -======= - return method(*args, **kwargs) - - iterator = _restart_on_unavailable( - method=wrapped_method, - trace_name="CloudSpanner.ExecuteStreamingSql", - request=request, - transaction_selector=txn_selector, - observability_options=self.observability_options, - ) ->>>>>>> 54df502... Update tests result_set = StreamedResultSet(iterator) list(result_set) # consume all partials diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index c35442e28f..ae95bef646 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -266,11 +266,6 @@ def create_sessions(attempt): return api.batch_create_sessions( request=request, metadata=all_metadata, - # Manually passing retry=None because otherwise any - # UNAVAILABLE retry will be retried without replenishing - # the metadata, hence this allows us to manually update - # the metadata using retry_on_unavailable. - retry=None, ) resp = retry_on_unavailable(create_sessions) @@ -584,13 +579,6 @@ def create_sessions(attempt): return api.batch_create_sessions( request=request, metadata=all_metadata, - # Manually passing retry=None because otherwise any - # UNAVAILABLE retry will be retried without replenishing - # the metadata, hence this allows us to manually update - # the metadata using retry_on_unavailable. - # TODO: Figure out how to intercept and monkey patch the internals - # of the gRPC transport. 
- retry=None, ) resp = retry_on_unavailable(create_sessions) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 90d1ed306c..ab516dfa03 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -332,22 +332,17 @@ def read( ) nth_request = getattr(database, "_next_nth_request", 0) - attempt = AtomicCounter(0) - - def wrapped_restart(*args, **kwargs): - attempt.increment() - all_metadata = database.metadata_with_request_id( - nth_request, attempt.value, metadata - ) + all_metadata = database.metadata_with_request_id( + nth_request, 1, metadata + ) - restart = functools.partial( - api.streaming_read, - request=request, - metadata=all_metadata, - retry=retry, - timeout=timeout, - ) - return restart(*args, **kwargs) + restart = functools.partial( + api.streaming_read, + request=request, + metadata=all_metadata, + retry=retry, + timeout=timeout, + ) trace_attributes = {"table_id": table, "columns": columns} observability_options = getattr(database, "observability_options", None) @@ -356,7 +351,7 @@ def wrapped_restart(*args, **kwargs): # lock is added to handle the inline begin for first rpc with self._lock: iterator = _restart_on_unavailable( - wrapped_restart, + restart, request, metadata, f"CloudSpanner.{type(self).__name__}.read", @@ -364,7 +359,6 @@ def wrapped_restart(*args, **kwargs): trace_attributes, transaction=self, observability_options=observability_options, - attempt=attempt, ) self._read_request_count += 1 if self._multi_use: @@ -380,7 +374,7 @@ def wrapped_restart(*args, **kwargs): ) else: iterator = _restart_on_unavailable( - wrapped_restart, + restart, request, metadata, f"CloudSpanner.{type(self).__name__}.read", @@ -388,7 +382,6 @@ def wrapped_restart(*args, **kwargs): trace_attributes, transaction=self, observability_options=observability_options, - attempt=attempt, ) self._read_request_count += 1 @@ -578,20 +571,18 @@ def execute_sql( ) nth_request = getattr(database, "_next_nth_request", 0) - attempt = AtomicCounter(0) - - def wrapped_restart(*args, **kwargs): - restart = functools.partial( - api.execute_streaming_sql, - request=request, - metadata=database.metadata_with_request_id( - nth_request, attempt.increment(), metadata - ), - retry=retry, - timeout=timeout, - ) - - return restart(*args, **kwargs) + if not isinstance(nth_request, int): + raise Exception(f"failed to get an integer back: {nth_request}") + + restart = functools.partial( + api.execute_streaming_sql, + request=request, + metadata=database.metadata_with_request_id( + nth_request, 1, metadata + ), + retry=retry, + timeout=timeout, + ) trace_attributes = {"db.statement": sql} observability_options = getattr(database, "observability_options", None) @@ -600,7 +591,7 @@ def wrapped_restart(*args, **kwargs): # lock is added to handle the inline begin for first rpc with self._lock: return self._get_streamed_result_set( - wrapped_restart, + restart, request, metadata, trace_attributes, @@ -610,7 +601,7 @@ def wrapped_restart(*args, **kwargs): ) else: return self._get_streamed_result_set( - wrapped_restart, + restart, request, metadata, trace_attributes, @@ -742,24 +733,19 @@ def partition_read( metadata=metadata, ), MetricsCapture(): nth_request = getattr(database, "_next_nth_request", 0) - attempt = AtomicCounter(0) - - def wrapped_method(*args, **kwargs): - attempt.increment() - all_metadata = database.metadata_with_request_id( - nth_request, attempt.value, metadata - ) - method = functools.partial( - api.partition_read, - 
request=request, - metadata=all_metadata, - retry=retry, - timeout=timeout, - ) - return method(*args, **kwargs) + all_metadata = database.metadata_with_request_id( + nth_request, 1, metadata + ) + method = functools.partial( + api.partition_read, + request=request, + metadata=all_metadata, + retry=retry, + timeout=timeout, + ) response = _retry( - wrapped_method, + method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) @@ -856,24 +842,19 @@ def partition_query( metadata=metadata, ), MetricsCapture(): nth_request = getattr(database, "_next_nth_request", 0) - attempt = AtomicCounter(0) - - def wrapped_method(*args, **kwargs): - attempt.increment() - all_metadata = database.metadata_with_request_id( - nth_request, attempt.value, metadata - ) - method = functools.partial( - api.partition_query, - request=request, - metadata=all_metadata, - retry=retry, - timeout=timeout, - ) - return method(*args, **kwargs) + all_metadata = database.metadata_with_request_id( + nth_request, 1, metadata + ) + method = functools.partial( + api.partition_query, + request=request, + metadata=all_metadata, + retry=retry, + timeout=timeout, + ) response = _retry( - wrapped_method, + method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) @@ -1013,23 +994,18 @@ def begin(self): metadata=metadata, ), MetricsCapture(): nth_request = getattr(database, "_next_nth_request", 0) - attempt = AtomicCounter(0) - - def wrapped_method(*args, **kwargs): - attempt.increment() - all_metadata = database.metadata_with_request_id( - nth_request, attempt.value, metadata - ) - method = functools.partial( - api.begin_transaction, - session=self._session.name, - options=txn_selector.begin, - metadata=all_metadata, - ) - return method(*args, **kwargs) + all_metadata = database.metadata_with_request_id( + nth_request, 1, metadata + ) + method = functools.partial( + api.begin_transaction, + session=self._session.name, + options=txn_selector.begin, + metadata=all_metadata, + ) response = _retry( - wrapped_method, + method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) self._transaction_id = response.id diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index c30d831782..5c55b631e2 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -1888,6 +1888,7 @@ def __init__(self, directed_read_options=None): def observability_options(self): return dict(db_name=self.name) + @property def _next_nth_request(self): return self._instance._client._next_nth_request From df8f81faa8e0a2030ac4119ed4ef132cd2b0a45e Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 17 Jan 2025 05:05:40 -0800 Subject: [PATCH 17/29] Correctly handle wrapping by class for api objects --- google/cloud/spanner_v1/_helpers.py | 92 ++++++++++++++----- google/cloud/spanner_v1/database.py | 40 ++------ google/cloud/spanner_v1/pool.py | 13 ++- .../services/spanner/transports/grpc.py | 4 +- google/cloud/spanner_v1/snapshot.py | 42 ++++----- .../cloud/spanner_v1/testing/database_test.py | 9 ++ .../cloud/spanner_v1/testing/mock_spanner.py | 1 + .../test_request_id_header.py | 1 + 8 files changed, 118 insertions(+), 84 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index e76c9c19e3..4262e9c6c6 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -737,27 +737,73 @@ class InterceptingHeaderInjector: def __init__(self, original_callable: Callable): self._original_callable = 
original_callable

-    def __call__(self, *args, **kwargs):
-        metadata = kwargs.get("metadata", [])
-        # Find all the headers that match the x-goog-spanner-request-id
-        # header and on each retry increment the value.
-        all_metadata = []
-        for key, value in metadata:
-            if key is REQ_ID_HEADER_KEY:
-                # Otherwise now increment the count for the attempt number.
-                splits = value.split(".")
-                attempt_plus_one = int(splits[-1]) + 1
-                splits[-1] = str(attempt_plus_one)
-                value_before = value
-                value = ".".join(splits)
-                print("incrementing value on retry from", value_before, "to", value)
-
-            all_metadata.append(
-                (
-                    key,
-                    value,
-                )
-            )
-
-        kwargs["metadata"] = all_metadata
-        return self._original_callable(*args, **kwargs)
+
+patched = {}
+
+
+def inject_retry_header_control(api):
+    # For each method, add an _attempt value that'll then be
+    # retrieved for each retry.
+    # 1. Patch the __getattribute__ method to match items in our manifest.
+    target = type(api)
+    hex_id = hex(id(target))
+    if patched.get(hex_id, None) is not None:
+        return
+
+    orig_getattribute = getattr(target, "__getattribute__")
+
+    def patched_getattribute(*args, **kwargs):
+        attr = orig_getattribute(*args, **kwargs)
+
+        # 0. If we already patched it, we can return immediately.
+        if getattr(attr, "_patched", None) is not None:
+            return attr
+
+        # 1. Skip over non-methods.
+        if not callable(attr):
+            return attr
+
+        # 2. Skip modifying private and mangled methods.
+        mangled_or_private = attr.__name__.startswith("_")
+        if mangled_or_private:
+            return attr
+
+        print("\033[35mattr", attr, "hex_id", hex(id(attr)), "\033[00m")
+
+        # 3. Wrap the callable attribute and then capture its metadata keyed argument.
+        def wrapped_attr(*args, **kwargs):
+            metadata = kwargs.get("metadata", [])
+            if not metadata:
+                # Increment the reinvocation count.
+                wrapped_attr._attempt += 1
+                return attr(*args, **kwargs)
+
+            # 4. Find all the headers that match the target header key.
+            all_metadata = []
+            for key, value in metadata:
+                if key is REQ_ID_HEADER_KEY:
+                    print("key", key, "value", value, "attempt", wrapped_attr._attempt)
+                    # 5. Increment the original_attempt with that of our re-invocation count.
+                    splits = value.split(".")
+                    hdr_attempt_plus_reinvocation = (
+                        int(splits[-1]) + wrapped_attr._attempt
+                    )
+                    splits[-1] = str(hdr_attempt_plus_reinvocation)
+                    value = ".".join(splits)
+
+                all_metadata.append((key, value))
+
+            # Increment the reinvocation count.
+ wrapped_attr._attempt += 1 + + kwargs["metadata"] = all_metadata + print("\033[34mwrap_callable", hex(id(attr)), attr.__name__, "\033[00m") + return attr(*args, **kwargs) + + wrapped_attr._attempt = 0 + wrapped_attr._patched = True + return wrapped_attr + + setattr(target, "__getattribute__", patched_getattribute) + patched[hex_id] = True diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 71eaa547e3..c03daddb3d 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -55,7 +55,7 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, _metadata_with_request_id, - InterceptingHeaderInjector, + inject_retry_header_control, ) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups @@ -434,42 +434,14 @@ def logger(self): @property def spanner_api(self): """Helper for session-related API calls.""" - api = self._generate_spanner_api() + api = self.__generate_spanner_api() if not api: return api - # Now wrap each method's __call__ method with our wrapped one. - # This is how to deal with the fact that there are no proper gRPC - # interceptors for Python hence the remedy is to replace callables - # with our custom wrapper. - attrs = dir(api) - for attr_name in attrs: - mangled = attr_name.startswith("__") - if mangled: - continue - - non_public = attr_name.startswith("_") - if non_public: - continue - - attr = getattr(api, attr_name) - callable_attr = callable(attr) - if callable_attr is None: - continue - - # We should only be looking at bound methods to SpannerClient - # as those are the RPC invoking methods that need to be wrapped - - is_method = type(attr).__name__ == "method" - if not is_method: - continue - - print("attr_name", attr_name, "callable_attr", attr) - setattr(api, attr_name, InterceptingHeaderInjector(attr)) - + inject_retry_header_control(api) return api - def _generate_spanner_api(self): + def __generate_spanner_api(self): """Helper for session-related API calls.""" if self._spanner_api is None: client_info = self._instance._client._client_info @@ -804,7 +776,9 @@ def execute_pdml(): session=session.name, options=txn_options, metadata=self.metadata_with_request_id( - begin_txn_nth_request, begin_txn_attempt.increment(), metadata + begin_txn_nth_request, + begin_txn_attempt.increment(), + metadata, ), ) diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index ae95bef646..28524758b6 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -268,7 +268,8 @@ def create_sessions(attempt): metadata=all_metadata, ) - resp = retry_on_unavailable(create_sessions) + resp = retry_on_unavailable(create_sessions, "fixedpool") + # print("resp.FixedPool", resp) add_span_event( span, @@ -581,7 +582,8 @@ def create_sessions(attempt): metadata=all_metadata, ) - resp = retry_on_unavailable(create_sessions) + resp = retry_on_unavailable(create_sessions, "pingpool") + print("resp.PingingPool", resp) add_span_event( span, @@ -822,7 +824,7 @@ def __exit__(self, *ignored): self._pool.put(self._session) -def retry_on_unavailable(fn, max=6): +def retry_on_unavailable(fn, kind, max=6): """ Retries `fn` to a maximum of `max` times on encountering UNAVAILABLE exceptions, each time passing in the iteration's ordinal number to signal @@ -830,12 +832,15 @@ def retry_on_unavailable(fn, max=6): """ last_exc = None for i in range(max): + print("retry_on_unavailable", kind, i) try: return fn(i + 1) except ServiceUnavailable 
as exc: + print("exc", exc) last_exc = exc time.sleep(i**2 + random.random()) - except: + except Exception as e: + print("got exception", e) raise raise last_exc diff --git a/google/cloud/spanner_v1/services/spanner/transports/grpc.py b/google/cloud/spanner_v1/services/spanner/transports/grpc.py index d325442dc9..14868c5f04 100644 --- a/google/cloud/spanner_v1/services/spanner/transports/grpc.py +++ b/google/cloud/spanner_v1/services/spanner/transports/grpc.py @@ -413,7 +413,9 @@ def batch_create_sessions( request_serializer=spanner.BatchCreateSessionsRequest.serialize, response_deserializer=spanner.BatchCreateSessionsResponse.deserialize, ) - return self._stubs["batch_create_sessions"] + fn = self._stubs["batch_create_sessions"] + print("\033[32minvoking batch_create_sessionhex_id", hex(id(fn)), "\033[00m") + return fn @property def get_session(self) -> Callable[[spanner.GetSessionRequest], spanner.Session]: diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index ab516dfa03..522a623941 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -332,9 +332,7 @@ def read( ) nth_request = getattr(database, "_next_nth_request", 0) - all_metadata = database.metadata_with_request_id( - nth_request, 1, metadata - ) + all_metadata = database.metadata_with_request_id(nth_request, 1, metadata) restart = functools.partial( api.streaming_read, @@ -574,15 +572,19 @@ def execute_sql( if not isinstance(nth_request, int): raise Exception(f"failed to get an integer back: {nth_request}") - restart = functools.partial( - api.execute_streaming_sql, - request=request, - metadata=database.metadata_with_request_id( - nth_request, 1, metadata - ), - retry=retry, - timeout=timeout, - ) + attempt = AtomicCounter(0) + + def wrapped_restart(*args, **kwargs): + restart = functools.partial( + api.execute_streaming_sql, + request=request, + metadata=database.metadata_with_request_id( + nth_request, attempt.increment(), metadata + ), + retry=retry, + timeout=timeout, + ) + return restart(*args, **kwargs) trace_attributes = {"db.statement": sql} observability_options = getattr(database, "observability_options", None) @@ -591,7 +593,7 @@ def execute_sql( # lock is added to handle the inline begin for first rpc with self._lock: return self._get_streamed_result_set( - restart, + wrapped_restart, request, metadata, trace_attributes, @@ -601,7 +603,7 @@ def execute_sql( ) else: return self._get_streamed_result_set( - restart, + wrapped_restart, request, metadata, trace_attributes, @@ -733,9 +735,7 @@ def partition_read( metadata=metadata, ), MetricsCapture(): nth_request = getattr(database, "_next_nth_request", 0) - all_metadata = database.metadata_with_request_id( - nth_request, 1, metadata - ) + all_metadata = database.metadata_with_request_id(nth_request, 1, metadata) method = functools.partial( api.partition_read, request=request, @@ -842,9 +842,7 @@ def partition_query( metadata=metadata, ), MetricsCapture(): nth_request = getattr(database, "_next_nth_request", 0) - all_metadata = database.metadata_with_request_id( - nth_request, 1, metadata - ) + all_metadata = database.metadata_with_request_id(nth_request, 1, metadata) method = functools.partial( api.partition_query, request=request, @@ -994,9 +992,7 @@ def begin(self): metadata=metadata, ), MetricsCapture(): nth_request = getattr(database, "_next_nth_request", 0) - all_metadata = database.metadata_with_request_id( - nth_request, 1, metadata - ) + all_metadata = 
database.metadata_with_request_id(nth_request, 1, metadata)
             method = functools.partial(
                 api.begin_transaction,
                 session=self._session.name,
diff --git a/google/cloud/spanner_v1/testing/database_test.py b/google/cloud/spanner_v1/testing/database_test.py
index 5af89fea42..2c8c651bfc 100644
--- a/google/cloud/spanner_v1/testing/database_test.py
+++ b/google/cloud/spanner_v1/testing/database_test.py
@@ -27,6 +27,7 @@
     MethodAbortInterceptor,
     XGoogRequestIDHeaderInterceptor,
 )
+from google.cloud.spanner_v1._helpers import inject_retry_header_control


 class TestDatabase(Database):
@@ -70,6 +71,14 @@ def __init__(

     @property
     def spanner_api(self):
+        api = self.__generate_spanner_api()
+        if not api:
+            return api
+
+        inject_retry_header_control(api)
+        return api
+
+    def __generate_spanner_api(self):
         """Helper for session-related API calls."""
         if self._spanner_api is None:
             client = self._instance._client
diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py
index f8971a6098..e2ac14e976 100644
--- a/google/cloud/spanner_v1/testing/mock_spanner.py
+++ b/google/cloud/spanner_v1/testing/mock_spanner.py
@@ -53,6 +53,7 @@ def pop_error(self, context):
         name = inspect.currentframe().f_back.f_code.co_name
         error: _Status | None = self.errors.pop(name, None)
         if error:
+            print("context.abort_with_status", error)
             context.abort_with_status(error)

     def get_result_as_partial_result_sets(
diff --git a/tests/mockserver_tests/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py
index d0b4ae1ec3..29d2df29b2 100644
--- a/tests/mockserver_tests/test_request_id_header.py
+++ b/tests/mockserver_tests/test_request_id_header.py
@@ -309,6 +309,7 @@ def test_unary_retryable_error(self):
             )
         ]

+        print("got_unaries", got_unary_segments)
         assert got_unary_segments == want_unary_segments
         assert got_stream_segments == want_stream_segments

From f435747a2354ab3278ce8e63084967b8cfe8ec1d Mon Sep 17 00:00:00 2001
From: Emmanuel T Odeke
Date: Fri, 17 Jan 2025 22:31:06 -0800
Subject: [PATCH 18/29] Revert pool creation interception attempts

---
 google/cloud/spanner_v1/pool.py | 68 ++++++++------------------------
 1 file changed, 16 insertions(+), 52 deletions(-)

diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py
index 28524758b6..6ab944fc44 100644
--- a/google/cloud/spanner_v1/pool.py
+++ b/google/cloud/spanner_v1/pool.py
@@ -257,19 +257,13 @@ def bind(self, database):
             f"Creating {request.session_count} sessions",
             span_event_attributes,
         )
-        nth_req = database._next_nth_request

-        def create_sessions(attempt):
-            all_metadata = database.metadata_with_request_id(
-                nth_req, attempt, metadata
-            )
-            return api.batch_create_sessions(
-                request=request,
-                metadata=all_metadata,
-            )
-
-        resp = retry_on_unavailable(create_sessions, "fixedpool")
-        # print("resp.FixedPool", resp)
+        resp = api.batch_create_sessions(
+            request=request,
+            metadata=database.metadata_with_request_id(
+                database._next_nth_request, 1, metadata
+            ),
+        )

         add_span_event(
             span,
@@ -569,21 +563,14 @@ def bind(self, database):
             observability_options=observability_options,
             metadata=metadata,
         ) as span, MetricsCapture():
-            created_session_count = 0
-            while created_session_count < self.size:
-                nth_req = database._next_nth_request
-
-                def create_sessions(attempt):
-                    all_metadata = database.metadata_with_request_id(
-                        nth_req, attempt, metadata
-                    )
-                    return api.batch_create_sessions(
-                        request=request,
-                        metadata=all_metadata,
-                    )
-
-                resp = retry_on_unavailable(create_sessions, 
"pingpool") - print("resp.PingingPool", resp) + returned_session_count = 0 + while returned_session_count < self.size: + resp = api.batch_create_sessions( + request=request, + metadata=database.metadata_with_request_id( + database._next_nth_request, 1, metadata + ), + ) add_span_event( span, @@ -592,14 +579,13 @@ def create_sessions(attempt): for session_pb in resp.session: session = self._new_session() + returned_session_count += 1 session._session_id = session_pb.name.split("/")[-1] self.put(session) - created_session_count += len(resp.session) - add_span_event( span, - f"Requested for {requested_session_count} sessions, returned {created_session_count}", + f"Requested for {requested_session_count} sessions, returned {returned_session_count}", span_event_attributes, ) @@ -822,25 +808,3 @@ def __enter__(self): def __exit__(self, *ignored): self._pool.put(self._session) - - -def retry_on_unavailable(fn, kind, max=6): - """ - Retries `fn` to a maximum of `max` times on encountering UNAVAILABLE exceptions, - each time passing in the iteration's ordinal number to signal - the nth attempt. It retries with exponential backoff with jitter. - """ - last_exc = None - for i in range(max): - print("retry_on_unavailable", kind, i) - try: - return fn(i + 1) - except ServiceUnavailable as exc: - print("exc", exc) - last_exc = exc - time.sleep(i**2 + random.random()) - except Exception as e: - print("got exception", e) - raise - - raise last_exc From ea0823f2d9dbe9fb1d06bf6030c14d8af3192822 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 17 Jan 2025 22:39:23 -0800 Subject: [PATCH 19/29] Wire up and revert some prints --- google/cloud/spanner_v1/_helpers.py | 8 ++++-- google/cloud/spanner_v1/batch.py | 28 +++++++------------ google/cloud/spanner_v1/database.py | 1 - google/cloud/spanner_v1/pool.py | 3 -- .../services/spanner/transports/grpc.py | 4 +-- google/cloud/spanner_v1/session.py | 11 ++++---- 6 files changed, 23 insertions(+), 32 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 4262e9c6c6..5be76056bc 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -752,8 +752,12 @@ def inject_retry_header_control(api): orig_getattribute = getattr(target, "__getattribute__") - def patched_getattribute(*args, **kwargs): - attr = orig_getattribute(*args, **kwargs) + def patched_getattribute(obj, key, *args, **kwargs): + if key.startswith("_"): + return orig_getattribute(obj, key, *args, **kwargs) + + attr = orig_getattribute(obj, key, *args, **kwargs) + print("args", args, "attr.dir", dir(attr)) # 0. If we already patched it, we can return immediately. 
if getattr(attr, "_patched", None) is not None: diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 811055a256..9cad7cafdb 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -250,20 +250,15 @@ def commit( observability_options=observability_options, metadata=metadata, ), MetricsCapture(): - attempt = AtomicCounter(0) - next_nth_request = database._next_nth_request - - all_metadata = database.metadata_with_request_id( - next_nth_request, - attempt.increment(), - metadata, - ) method = functools.partial( api.commit, request=request, - metadata=all_metadata, + metadata=database.metadata_with_request_id( + database._next_nth_request, + 1, + metadata, + ), ) - deadline = time.time() + kwargs.get( "timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS ) @@ -382,18 +377,15 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals observability_options=observability_options, metadata=metadata, ), MetricsCapture(): - next_nth_request = database._next_nth_request - all_metadata = database.metadata_with_request_id( - next_nth_request, - 0, - metadata, - ) method = functools.partial( api.batch_write, request=request, - metadata=all_metadata, + metadata=database.metadata_with_request_id( + database._next_nth_request, + 1, + metadata, + ), ) - response = _retry( method, allowed_exceptions={ diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index c03daddb3d..e02fdac078 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -792,7 +792,6 @@ def execute_pdml(): query_options=query_options, request_options=request_options, ) - method = functools.partial( api.execute_streaming_sql, metadata=self.metadata_with_request_id( diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 6ab944fc44..4e8d2a42dc 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -15,12 +15,10 @@ """Pools managing shared Session objects.""" import datetime -import random import queue import time from google.cloud.exceptions import NotFound -from google.api_core.exceptions import ServiceUnavailable from google.cloud.spanner_v1 import BatchCreateSessionsRequest from google.cloud.spanner_v1 import Session from google.cloud.spanner_v1._helpers import ( @@ -257,7 +255,6 @@ def bind(self, database): f"Creating {request.session_count} sessions", span_event_attributes, ) - resp = api.batch_create_sessions( request=request, metadata=database.metadata_with_request_id( diff --git a/google/cloud/spanner_v1/services/spanner/transports/grpc.py b/google/cloud/spanner_v1/services/spanner/transports/grpc.py index 14868c5f04..d325442dc9 100644 --- a/google/cloud/spanner_v1/services/spanner/transports/grpc.py +++ b/google/cloud/spanner_v1/services/spanner/transports/grpc.py @@ -413,9 +413,7 @@ def batch_create_sessions( request_serializer=spanner.BatchCreateSessionsRequest.serialize, response_deserializer=spanner.BatchCreateSessionsResponse.deserialize, ) - fn = self._stubs["batch_create_sessions"] - print("\033[32minvoking batch_create_sessionhex_id", hex(id(fn)), "\033[00m") - return fn + return self._stubs["batch_create_sessions"] @property def get_session(self) -> Callable[[spanner.GetSessionRequest], spanner.Session]: diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index fd2e1be2b3..994c13b3cf 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -207,10 
+207,6 @@ def exists(self): ) ) - all_metadata = database.metadata_with_request_id( - database._next_nth_request, 1, metadata - ) - observability_options = getattr(self._database, "observability_options", None) with trace_call( "CloudSpanner.GetSession", @@ -219,7 +215,12 @@ def exists(self): metadata=metadata, ) as span, MetricsCapture(): try: - api.get_session(name=self.name, metadata=all_metadata) + api.get_session( + name=self.name, + metadata=database.metadata_with_request_id( + database._next_nth_request, 1, metadata + ), + ) if span: span.set_attribute("session_found", True) except NotFound: From f8ad94f2d39601c589efd8308142ead6b1f8cbf6 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 28 Mar 2025 14:21:32 +0300 Subject: [PATCH 20/29] Initial ExecuteStreamingSql request in snapshot should have the header --- google/cloud/spanner_v1/_helpers.py | 6 ------ google/cloud/spanner_v1/database.py | 20 ++++++++++++------- google/cloud/spanner_v1/snapshot.py | 8 ++++++-- .../test_request_id_header.py | 1 - 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 5be76056bc..48b3bbe374 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -757,7 +757,6 @@ def patched_getattribute(obj, key, *args, **kwargs): return orig_getattribute(obj, key, *args, **kwargs) attr = orig_getattribute(obj, key, *args, **kwargs) - print("args", args, "attr.dir", dir(attr)) # 0. If we already patched it, we can return immediately. if getattr(attr, "_patched", None) is not None: @@ -772,14 +771,11 @@ def patched_getattribute(obj, key, *args, **kwargs): if mangled_or_private: return attr - print("\033[35mattr", attr, "hex_id", hex(id(attr)), "\033[00m") - # 3. Wrap the callable attribute and then capture its metadata keyed argument. def wrapped_attr(*args, **kwargs): metadata = kwargs.get("metadata", []) if not metadata: # Increment the reinvocation count. - print("not metatadata", attr.__name__) wrapped_attr._attempt += 1 return attr(*args, **kwargs) @@ -787,7 +783,6 @@ def wrapped_attr(*args, **kwargs): all_metadata = [] for key, value in metadata: if key is REQ_ID_HEADER_KEY: - print("key", key, "value", value, "attempt", wrapped_attr._attempt) # 5. Increment the original_attempt with that of our re-invocation count. 
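                 # (For instance, a header ending in ".7.1" arriving when
                 # wrapped_attr._attempt is already 2 is rewritten to ".7.3".)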
splits = value.split(".") hdr_attempt_plus_reinvocation = ( @@ -802,7 +797,6 @@ def wrapped_attr(*args, **kwargs): wrapped_attr._attempt += 1 kwargs["metadata"] = all_metadata - print("\033[34mwrap_callable", hex(id(attr)), attr.__name__, "\033[00m") return attr(*args, **kwargs) wrapped_attr._attempt = 0 diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index e02fdac078..662c553495 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -792,15 +792,21 @@ def execute_pdml(): query_options=query_options, request_options=request_options, ) - method = functools.partial( - api.execute_streaming_sql, - metadata=self.metadata_with_request_id( - partial_nth_request, partial_attempt.increment(), metadata - ), - ) + + def wrapped_method(*args, **kwargs): + print("\033[34mwrapped_method\033[00m") + method = functools.partial( + api.execute_streaming_sql, + metadata=self.metadata_with_request_id( + partial_nth_request, + partial_attempt.increment(), + metadata, + ), + ) + return method(*args, **kwargs) iterator = _restart_on_unavailable( - method=method, + method=wrapped_method, trace_name="CloudSpanner.ExecuteStreamingSql", request=request, metadata=metadata, diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 522a623941..ff3433dc21 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -586,6 +586,10 @@ def wrapped_restart(*args, **kwargs): ) return restart(*args, **kwargs) + # The initial request should contain the request-id. + augmented_metadata = database.metadata_with_request_id( + nth_request, attempt.increment(), metadata + ) trace_attributes = {"db.statement": sql} observability_options = getattr(database, "observability_options", None) @@ -595,7 +599,7 @@ def wrapped_restart(*args, **kwargs): return self._get_streamed_result_set( wrapped_restart, request, - metadata, + augmented_metadata, trace_attributes, column_info, observability_options, @@ -605,7 +609,7 @@ def wrapped_restart(*args, **kwargs): return self._get_streamed_result_set( wrapped_restart, request, - metadata, + augmented_metadata, trace_attributes, column_info, observability_options, diff --git a/tests/mockserver_tests/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py index 29d2df29b2..9c131dc3a8 100644 --- a/tests/mockserver_tests/test_request_id_header.py +++ b/tests/mockserver_tests/test_request_id_header.py @@ -316,7 +316,6 @@ def test_unary_retryable_error(self): def test_streaming_retryable_error(self): add_select1_result() add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status()) - add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status()) if not getattr(self.database, "_interceptors", None): self.database._interceptors = MockServerTestBase._interceptors From 15984cb3728e06fde99a44cfc4ab472f45be6de5 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 28 Mar 2025 16:05:37 +0300 Subject: [PATCH 21/29] Consolidate retries for _restart_on_unavailable --- google/cloud/spanner_v1/database.py | 19 +++----- google/cloud/spanner_v1/snapshot.py | 67 +++++++++++++++++------------ 2 files changed, 46 insertions(+), 40 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 662c553495..8dcd2a66a1 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -793,25 +793,19 @@ def execute_pdml(): request_options=request_options, 
) - def wrapped_method(*args, **kwargs): - print("\033[34mwrapped_method\033[00m") - method = functools.partial( - api.execute_streaming_sql, - metadata=self.metadata_with_request_id( - partial_nth_request, - partial_attempt.increment(), - metadata, - ), - ) - return method(*args, **kwargs) + method = functools.partial( + api.execute_streaming_sql, + metadata=metadata, + ) iterator = _restart_on_unavailable( - method=wrapped_method, + method=method, trace_name="CloudSpanner.ExecuteStreamingSql", request=request, metadata=metadata, transaction_selector=txn_selector, observability_options=self.observability_options, + request_id_manager=self, ) result_set = StreamedResultSet(iterator) @@ -1051,6 +1045,7 @@ def restore(self, source): ) future = api.restore_database( request=request, + # TODO: Infer the channel_id being used. metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata), ) return future diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index ff3433dc21..15a3b6bc17 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -62,6 +62,7 @@ def _restart_on_unavailable( transaction=None, transaction_selector=None, observability_options=None, + request_id_manager=None, ): """Restart iteration after :exc:`.ServiceUnavailable`. @@ -81,6 +82,8 @@ def _restart_on_unavailable( resume_token = b"" item_buffer = [] + next_nth_request = lambda: getattr(request_id_manager, "_next_nth_request", 0) + nth_request = next_nth_request() if transaction is not None: transaction_selector = transaction._make_txn_selector() @@ -102,7 +105,13 @@ def _restart_on_unavailable( observability_options=observability_options, metadata=metadata, ), MetricsCapture(): - iterator = method(request=request, metadata=metadata) + attempt += 1 + iterator = method( + request=request, + metadata=request_id_manager.metadata_with_request_id( + nth_request, attempt, metadata + ), + ) for item in iterator: item_buffer.append(item) # Setting the transaction id because the transaction begin was inlined for first rpc. 
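Stripped of the diff context, the consolidated control flow reads roughly as
below. This is a condensed sketch only: tracing, metrics, item buffering, and
resume-token bookkeeping from _restart_on_unavailable are omitted, and
restart_with_request_id is an illustrative name.

    from google.api_core.exceptions import ServiceUnavailable

    def restart_with_request_id(method, request, metadata, request_id_manager):
        # One nth_request value per logical request; retries of that request
        # bump only the attempt component of x-goog-spanner-request-id.
        nth_request = getattr(request_id_manager, "_next_nth_request", 0)
        attempt = 0
        while True:
            attempt += 1
            try:
                return method(
                    request=request,
                    metadata=request_id_manager.metadata_with_request_id(
                        nth_request, attempt, metadata
                    ),
                )
            except ServiceUnavailable:
                # A restarted stream counts as a fresh request: draw a new
                # nth_request; attempt then restarts from 1 on the next pass.
                nth_request = getattr(request_id_manager, "_next_nth_request", 0)
                attempt = 0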
@@ -130,7 +139,14 @@ def _restart_on_unavailable( if transaction is not None: transaction_selector = transaction._make_txn_selector() request.transaction = transaction_selector - iterator = method(request=request) + nth_request = next_nth_request() + attempt = 1 + iterator = method( + request=request, + metadata=request_id_manager.metadata_with_request_id( + nth_request, attempt, metadata + ), + ) continue except InternalServerError as exc: resumable_error = any( @@ -150,8 +166,15 @@ def _restart_on_unavailable( request.resume_token = resume_token if transaction is not None: transaction_selector = transaction._make_txn_selector() + nth_request = next_nth_request() + attempt = 1 request.transaction = transaction_selector - iterator = method(request=request) + iterator = method( + request=request, + metadata=request_id_manager.metadata_with_request_id( + nth_request, attempt, metadata + ), + ) continue if len(item_buffer) == 0: @@ -357,6 +380,7 @@ def read( trace_attributes, transaction=self, observability_options=observability_options, + request_id_manager=self._session._database, ) self._read_request_count += 1 if self._multi_use: @@ -380,6 +404,7 @@ def read( trace_attributes, transaction=self, observability_options=observability_options, + request_id_manager=self._session._database, ) self._read_request_count += 1 @@ -568,27 +593,12 @@ def execute_sql( directed_read_options=directed_read_options, ) - nth_request = getattr(database, "_next_nth_request", 0) - if not isinstance(nth_request, int): - raise Exception(f"failed to get an integer back: {nth_request}") - - attempt = AtomicCounter(0) - - def wrapped_restart(*args, **kwargs): - restart = functools.partial( - api.execute_streaming_sql, - request=request, - metadata=database.metadata_with_request_id( - nth_request, attempt.increment(), metadata - ), - retry=retry, - timeout=timeout, - ) - return restart(*args, **kwargs) - - # The initial request should contain the request-id. 
- augmented_metadata = database.metadata_with_request_id( - nth_request, attempt.increment(), metadata + restart = functools.partial( + api.execute_streaming_sql, + request=request, + metadata=metadata, + retry=retry, + timeout=timeout, ) trace_attributes = {"db.statement": sql} observability_options = getattr(database, "observability_options", None) @@ -597,9 +607,9 @@ def wrapped_restart(*args, **kwargs): # lock is added to handle the inline begin for first rpc with self._lock: return self._get_streamed_result_set( - wrapped_restart, + restart, request, - augmented_metadata, + metadata, trace_attributes, column_info, observability_options, @@ -607,9 +617,9 @@ def wrapped_restart(*args, **kwargs): ) else: return self._get_streamed_result_set( - wrapped_restart, + restart, request, - augmented_metadata, + metadata, trace_attributes, column_info, observability_options, @@ -635,6 +645,7 @@ def _get_streamed_result_set( trace_attributes, transaction=self, observability_options=observability_options, + request_id_manager=self._session._database, ) self._read_request_count += 1 self._execute_sql_count += 1 From a5fdebd9c56ac93190aaa6ffd713e0a2d4e54102 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sun, 30 Mar 2025 21:38:03 +0300 Subject: [PATCH 22/29] Fix missing variable declaration --- google/cloud/spanner_v1/snapshot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 15a3b6bc17..8e64ceee9e 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -94,6 +94,7 @@ def _restart_on_unavailable( request.transaction = transaction_selector iterator = None + attempt = 0 while True: try: From 80faca08427408f64f19d234139cc6bea0490643 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 31 Mar 2025 02:51:13 +0300 Subject: [PATCH 23/29] Adjust with updates --- google/cloud/spanner_v1/database.py | 8 +- google/cloud/spanner_v1/request_id_header.py | 2 +- google/cloud/spanner_v1/session.py | 4 +- google/cloud/spanner_v1/snapshot.py | 78 ++++++++++------- tests/unit/test_batch.py | 10 ++- tests/unit/test_session.py | 12 +++ tests/unit/test_snapshot.py | 92 +++++++++++++++----- tests/unit/test_spanner.py | 14 ++- 8 files changed, 154 insertions(+), 66 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 8dcd2a66a1..8c07a01bb4 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -759,12 +759,6 @@ def execute_partitioned_dml( _metadata_with_leader_aware_routing(self._route_to_leader_enabled) ) - begin_txn_nth_request = self._next_nth_request - begin_txn_attempt = AtomicCounter(0) - partial_nth_request = self._next_nth_request - # partial_attempt will be incremented inside _restart_on_unavailable. 
- partial_attempt = AtomicCounter(0) - def execute_pdml(): with trace_call( "CloudSpanner.Database.execute_partitioned_pdml", @@ -772,6 +766,8 @@ def execute_pdml(): ) as span, MetricsCapture(): with SessionCheckout(self._pool) as session: add_span_event(span, "Starting BeginTransaction") + begin_txn_nth_request = self._next_nth_request + begin_txn_attempt = AtomicCounter(0) txn = api.begin_transaction( session=session.name, options=txn_options, diff --git a/google/cloud/spanner_v1/request_id_header.py b/google/cloud/spanner_v1/request_id_header.py index 8376778273..74a5bb1253 100644 --- a/google/cloud/spanner_v1/request_id_header.py +++ b/google/cloud/spanner_v1/request_id_header.py @@ -37,6 +37,6 @@ def generate_rand_uint64(): def with_request_id(client_id, channel_id, nth_request, attempt, other_metadata=[]): req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}" - all_metadata = other_metadata.copy() + all_metadata = (other_metadata or []).copy() all_metadata.append((REQ_ID_HEADER_KEY, req_id)) return all_metadata diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 994c13b3cf..7e34206de9 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -268,9 +268,7 @@ def delete(self): ), MetricsCapture(): api.delete_session( name=self.name, - metadata=database.metadata_with_request_id( - database._next_nth_request, 1, metadata - ), + metadata=metadata, ) def ping(self): diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 8e64ceee9e..8921704c42 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -355,13 +355,10 @@ def read( directed_read_options=directed_read_options, ) - nth_request = getattr(database, "_next_nth_request", 0) - all_metadata = database.metadata_with_request_id(nth_request, 1, metadata) - restart = functools.partial( api.streaming_read, request=request, - metadata=all_metadata, + metadata=metadata, retry=retry, timeout=timeout, ) @@ -751,17 +748,24 @@ def partition_read( metadata=metadata, ), MetricsCapture(): nth_request = getattr(database, "_next_nth_request", 0) - all_metadata = database.metadata_with_request_id(nth_request, 1, metadata) - method = functools.partial( - api.partition_read, - request=request, - metadata=all_metadata, - retry=retry, - timeout=timeout, - ) + counters = dict(attempt=0) + + def attempt_tracking_method(): + counters["attempt"] += 1 + all_metadata = database.metadata_with_request_id( + nth_request, counters["attempt"], metadata + ) + method = functools.partial( + api.partition_read, + request=request, + metadata=all_metadata, + retry=retry, + timeout=timeout, + ) + return method() response = _retry( - method, + attempt_tracking_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) @@ -858,17 +862,24 @@ def partition_query( metadata=metadata, ), MetricsCapture(): nth_request = getattr(database, "_next_nth_request", 0) - all_metadata = database.metadata_with_request_id(nth_request, 1, metadata) - method = functools.partial( - api.partition_query, - request=request, - metadata=all_metadata, - retry=retry, - timeout=timeout, - ) + counters = dict(attempt=0) + + def attempt_tracking_method(): + counters["attempt"] += 1 + all_metadata = database.metadata_with_request_id( + nth_request, counters["attempt"], metadata + ) + method = functools.partial( + api.partition_query, + request=request, + metadata=all_metadata, + retry=retry, + 
timeout=timeout, + ) + return method() response = _retry( - method, + attempt_tracking_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) @@ -1008,16 +1019,23 @@ def begin(self): metadata=metadata, ), MetricsCapture(): nth_request = getattr(database, "_next_nth_request", 0) - all_metadata = database.metadata_with_request_id(nth_request, 1, metadata) - method = functools.partial( - api.begin_transaction, - session=self._session.name, - options=txn_selector.begin, - metadata=all_metadata, - ) + counters = dict(attempt=0) + + def attempt_tracking_method(): + counters["attempt"] += 1 + all_metadata = database.metadata_with_request_id( + nth_request, counters["attempt"], metadata + ) + method = functools.partial( + api.begin_transaction, + session=self._session.name, + options=txn_selector.begin, + metadata=all_metadata, + ) + return method() response = _retry( - method, + attempt_tracking_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) self._transaction_id = response.id diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 5888f3a2ae..93ed62454a 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -590,10 +590,6 @@ def _test_batch_write_with_request_options( expected_metadata = [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), - ( - "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", - ), ] if enable_end_to_end_tracing and ot_helpers.HAS_OPENTELEMETRY_INSTALLED: @@ -603,6 +599,12 @@ def _test_batch_write_with_request_options( "traceparent is missing in metadata", ) + expected_metadata.append( + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", + ) + ) # Remove traceparent from actual metadata for comparison filtered_metadata = [item for item in metadata if item[0] != "traceparent"] diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8f4d379032..23c8c02bd3 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2065,6 +2065,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -2099,6 +2103,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -2137,6 +2145,10 @@ def unit_of_work(txn, *args, **kw): metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database.NTH_CLIENT.value}.1.1.1", + ), ], ) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 5c55b631e2..c1fe8ec1b5 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -140,6 +140,7 @@ def _call_fut( session, attributes, transaction=derived, + request_id_manager=session._database, ) def _make_item(self, value, resume_token=b"", metadata=None): @@ -158,9 +159,18 @@ def test_iteration_w_empty_raw(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, 
session=session) self.assertEqual(list(resumable), []) - restart.assert_called_once_with(request=request, metadata=None) + restart.assert_called_once_with( + request=request, + metadata=[ + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ) + ], + ) + self.assertNoSpans() def test_iteration_w_non_empty_raw(self): @@ -172,9 +182,17 @@ def test_iteration_w_non_empty_raw(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(ITEMS)) - restart.assert_called_once_with(request=request, metadata=None) + restart.assert_called_once_with( + request=request, + metadata=[ + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ) + ], + ) self.assertNoSpans() def test_iteration_w_raw_w_resume_tken(self): @@ -191,9 +209,17 @@ def test_iteration_w_raw_w_resume_tken(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(ITEMS)) - restart.assert_called_once_with(request=request, metadata=None) + restart.assert_called_once_with( + request=request, + metadata=[ + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ) + ], + ) self.assertNoSpans() def test_iteration_w_raw_raising_unavailable_no_token(self): @@ -212,7 +238,7 @@ def test_iteration_w_raw_raising_unavailable_no_token(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(ITEMS)) self.assertEqual(len(restart.mock_calls), 2) self.assertEqual(request.resume_token, b"") @@ -239,7 +265,7 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(ITEMS)) self.assertEqual(len(restart.mock_calls), 2) self.assertEqual(request.resume_token, b"") @@ -261,10 +287,18 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) with self.assertRaises(InternalServerError): list(resumable) - restart.assert_called_once_with(request=request, metadata=None) + restart.assert_called_once_with( + request=request, + metadata=[ + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ) + ], + ) self.assertNoSpans() def test_iteration_w_raw_raising_unavailable(self): @@ -283,7 +317,7 @@ def test_iteration_w_raw_raising_unavailable(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - 
resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST + LAST)) self.assertEqual(len(restart.mock_calls), 2) self.assertEqual(request.resume_token, RESUME_TOKEN) @@ -309,7 +343,7 @@ def test_iteration_w_raw_raising_retryable_internal_error(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST + LAST)) self.assertEqual(len(restart.mock_calls), 2) self.assertEqual(request.resume_token, RESUME_TOKEN) @@ -331,10 +365,18 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) with self.assertRaises(InternalServerError): list(resumable) - restart.assert_called_once_with(request=request, metadata=None) + restart.assert_called_once_with( + request=request, + metadata=[ + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ) + ], + ) self.assertNoSpans() def test_iteration_w_raw_raising_unavailable_after_token(self): @@ -352,7 +394,7 @@ def test_iteration_w_raw_raising_unavailable_after_token(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST + SECOND)) self.assertEqual(len(restart.mock_calls), 2) self.assertEqual(request.resume_token, RESUME_TOKEN) @@ -375,7 +417,7 @@ def test_iteration_w_raw_w_multiuse(self): session = _Session(database) derived = self._makeDerived(session) derived._multi_use = True - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST)) self.assertEqual(len(restart.mock_calls), 1) begin_count = sum( @@ -406,7 +448,7 @@ def test_iteration_w_raw_raising_unavailable_w_multiuse(self): session = _Session(database) derived = self._makeDerived(session) derived._multi_use = True - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(SECOND)) self.assertEqual(len(restart.mock_calls), 2) begin_count = sum( @@ -445,7 +487,7 @@ def test_iteration_w_raw_raising_unavailable_after_token_w_multiuse(self): derived = self._makeDerived(session) derived._multi_use = True - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST + SECOND)) self.assertEqual(len(restart.mock_calls), 2) @@ -481,7 +523,7 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST + 
SECOND)) self.assertEqual(len(restart.mock_calls), 2) self.assertEqual(request.resume_token, RESUME_TOKEN) @@ -502,10 +544,18 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): database.spanner_api = self._make_spanner_api() session = _Session(database) derived = self._makeDerived(session) - resumable = self._call_fut(derived, restart, request) + resumable = self._call_fut(derived, restart, request, session=session) with self.assertRaises(InternalServerError): list(resumable) - restart.assert_called_once_with(request=request, metadata=None) + restart.assert_called_once_with( + request=request, + metadata=[ + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ) + ], + ) self.assertNoSpans() def test_iteration_w_span_creation(self): diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index 6834d8022d..3a31d9b7d8 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -13,6 +13,8 @@ # limitations under the License. +import sys +import traceback import threading from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( @@ -644,6 +646,10 @@ def test_transaction_should_include_begin_w_isolation_level_with_first_update( metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + ), ], ) @@ -1226,6 +1232,12 @@ def __init__(self): def _next_nth_request(self): return self._nth_request.increment() + def get_next_request(self): + # This method exists because somehow Python isn't able to + # call the property method "_next_nth_request" and that's + # needlessly stalled progress. + return self._nth_request.increment() + class _Instance(object): def __init__(self): @@ -1242,7 +1254,7 @@ def __init__(self): @property def _next_nth_request(self): - return self._instance._client._next_nth_request + return self._instance._client.get_next_request() @property def _nth_client_id(self): From 06c12a25be51dd3b695b05e54bed00ed038a481e Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Tue, 1 Apr 2025 20:25:09 +0300 Subject: [PATCH 24/29] Update _execute_partitioned_dml_helper --- google/cloud/spanner_v1/database.py | 6 ++---- tests/unit/test_database.py | 23 +++++++++++++++++++++-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 8c07a01bb4..dc4cae0252 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -766,14 +766,12 @@ def execute_pdml(): ) as span, MetricsCapture(): with SessionCheckout(self._pool) as session: add_span_event(span, "Starting BeginTransaction") - begin_txn_nth_request = self._next_nth_request - begin_txn_attempt = AtomicCounter(0) txn = api.begin_transaction( session=session.name, options=txn_options, metadata=self.metadata_with_request_id( - begin_txn_nth_request, - begin_txn_attempt.increment(), + self._next_nth_request, + 1, metadata, ), ) diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index c38996044e..95128647bf 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1298,7 +1298,8 @@ def _execute_partitioned_dml_helper( ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.2", + # Please note that this try was by an abort and 
not from service unavailable. + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.3.1", ), ], ) @@ -1370,6 +1371,22 @@ def _execute_partitioned_dml_helper( query_options=expected_query_options, request_options=expected_request_options, ) + + api.begin_transaction.assert_called_with( + session=self.SESSION_NAME, + options=txn_options, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + # Retrying on an aborted response involves creating the transaction afresh + # and also re-invoking execute_streaming_sql, hence the fresh request 4.1. + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.3.1", + ), + ], + ) + api.execute_streaming_sql.assert_called_with( request=expected_request, metadata=[ @@ -1377,7 +1394,9 @@ def _execute_partitioned_dml_helper( ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.2.2", + # Retrying on an aborted response involves creating the transaction afresh + # and also re-invoking execute_streaming_sql, hence the fresh request 4.1. + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.4.1", ), ], ) From 9d4d9420b169b5965f90e5439b53392ad3db6576 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Tue, 8 Apr 2025 22:02:42 +0300 Subject: [PATCH 25/29] Monkey patch updates --- google/cloud/spanner_v1/__init__.py | 3 + google/cloud/spanner_v1/_helpers.py | 157 +++++++++++++++--- google/cloud/spanner_v1/database.py | 5 +- google/cloud/spanner_v1/pool.py | 2 + google/cloud/spanner_v1/snapshot.py | 21 ++- .../test_request_id_header.py | 1 + 6 files changed, 151 insertions(+), 38 deletions(-) diff --git a/google/cloud/spanner_v1/__init__.py b/google/cloud/spanner_v1/__init__.py index beeed1dacf..240eb838da 100644 --- a/google/cloud/spanner_v1/__init__.py +++ b/google/cloud/spanner_v1/__init__.py @@ -75,6 +75,7 @@ from google.cloud.spanner_v1.pool import FixedSizePool from google.cloud.spanner_v1.pool import PingingPool from google.cloud.spanner_v1.pool import TransactionPingingPool +from google.cloud.spanner_v1._helpers import monkey_patch COMMIT_TIMESTAMP = "spanner.commit_timestamp()" @@ -83,6 +84,8 @@ ``(allow_commit_timestamp=true)`` in the schema. """ +monkey_patch(Transaction) + __all__ = ( # google.cloud.spanner_v1 diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 48b3bbe374..90c12a3070 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -19,6 +19,7 @@ import math import time import base64 +import inspect import threading from google.protobuf.struct_pb2 import ListValue @@ -739,9 +740,100 @@ def __init__(self, original_callable: Callable): patched = {} +patched_mu = threading.Lock() def inject_retry_header_control(api): + return + monkey_patch(type(api)) + +memoize_map = dict() + +def monkey_patch(obj): + return + + """ + klass = obj + attrs = dir(klass) + for attr_key in attrs: + if attr_key.startswith('_'): + continue + + attr_value = getattr(obj, attr_key) + if not callable(attr_value): + continue + + signature = inspect.signature(attr_value) + print(attr_key, signature.parameters) + + call = attr_value + # Our goal is to replace the runtime pass through. 
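+        # NB: monkey_patch returns immediately at the top, so this block and
+        # everything after it is unreachable scaffolding kept for reference;
+        # the wrapper below only logs the call and delegates to the original
+        # callable.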
+ def wrapped(*args, **kwargs): + print(attr_key, 'called') + return call(*args, **kwargs) + + setattr(klass, attr_key, wrapped) + + return + """ + + orig_get_attr = getattr(obj, "__getattribute__") + def patched_getattribute(obj, key, *args, **kwargs): + if key.startswith('_'): + return orig_get_attr(obj, key, *args, **kwargs) + + orig_value = orig_get_attr(obj, key, *args, **kwargs) + if not callable(orig_value): + return orig_value + + map_key = hex(id(key)) + hex(id(obj)) + memoized = memoize_map.get(map_key, None) + if memoized: + print("memoized_hit", key, '\033[35m', inspect.getsource(orig_value), '\033[00m') + return memoized + + signature = inspect.signature(orig_value) + if signature.parameters.get('metadata', None) is None: + return orig_value + + print(key, '\033[34m', map_key, '\033[00m', signature, signature.parameters.get('metadata', None)) + counters = dict(attempt=0) + def patched_method(*aargs, **kkwargs): + counters['attempt'] += 1 + metadata = kkwargs.get('metadata', None) + if not metadata: + return orig_value(*aargs, **kkwargs) + + # 4. Find all the headers that match the target header key. + all_metadata = [] + for mkey, value in metadata: + if mkey is REQ_ID_HEADER_KEY: + attempt = counters['attempt'] + if attempt > 1: + # 5. Increment the original_attempt with that of our re-invocation count. + splits = value.split(".") + print('\033[34mkey', mkey, '\033[00m', splits) + hdr_attempt_plus_reinvocation = ( + int(splits[-1]) + attempt + ) + splits[-1] = str(hdr_attempt_plus_reinvocation) + value = ".".join(splits) + + all_metadata.append((mkey, value)) + + kwargs["metadata"] = all_metadata + return orig_value(*aargs, **kkwargs) + + memoize_map[map_key] = patched_method + return patched_method + + setattr(obj, '__getattribute__', patched_getattribute) + + +def foo(api): + global patched + global patched_mu + # For each method, add an _attempt value that'll then be # retrieved for each retry. # 1. Patch the __getattribute__ method to match items in our manifest. @@ -753,55 +845,66 @@ def inject_retry_header_control(api): orig_getattribute = getattr(target, "__getattribute__") def patched_getattribute(obj, key, *args, **kwargs): + # 1. Skip modifying private and mangled methods. if key.startswith("_"): return orig_getattribute(obj, key, *args, **kwargs) attr = orig_getattribute(obj, key, *args, **kwargs) - # 0. If we already patched it, we can return immediately. - if getattr(attr, "_patched", None) is not None: - return attr - - # 1. Skip over non-methods. + # 2. Skip over non-methods. if not callable(attr): + patched_mu.release() return attr - # 2. Skip modifying private and mangled methods. - mangled_or_private = attr.__name__.startswith("_") - if mangled_or_private: - return attr - + patched_key = hex(id(key)) + hex(id(obj)) + patched_mu.acquire() + already_patched = patched.get(patched_key, None) + + other_attempts = dict(attempts=0) # 3. Wrap the callable attribute and then capture its metadata keyed argument. def wrapped_attr(*args, **kwargs): + print("\033[31m", key, "attempt", other_attempts['attempts'], "\033[00m") + other_attempts['attempts'] += 1 + metadata = kwargs.get("metadata", []) if not metadata: # Increment the reinvocation count. wrapped_attr._attempt += 1 return attr(*args, **kwargs) + print("\033[35mwrapped_attr", key, args, kwargs, 'attempt', wrapped_attr._attempt, "\033[00m") + # 4. Find all the headers that match the target header key. all_metadata = [] - for key, value in metadata: - if key is REQ_ID_HEADER_KEY: - # 5. 
Increment the original_attempt with that of our re-invocation count. - splits = value.split(".") - hdr_attempt_plus_reinvocation = ( - int(splits[-1]) + wrapped_attr._attempt - ) - splits[-1] = str(hdr_attempt_plus_reinvocation) - value = ".".join(splits) - - all_metadata.append((key, value)) - - # Increment the reinvocation count. - wrapped_attr._attempt += 1 + for mkey, value in metadata: + if mkey is REQ_ID_HEADER_KEY: + if wrapped_attr._attempt > 0: + # 5. Increment the original_attempt with that of our re-invocation count. + splits = value.split(".") + print('\033[34mkey', mkey, '\033[00m', splits) + hdr_attempt_plus_reinvocation = ( + int(splits[-1]) + wrapped_attr._attempt + ) + splits[-1] = str(hdr_attempt_plus_reinvocation) + value = ".".join(splits) + + all_metadata.append((mkey, value)) kwargs["metadata"] = all_metadata + wrapped_attr._attempt += 1 + print(key, "\033[36mreplaced_all_metadata", all_metadata, "\033[00m") return attr(*args, **kwargs) - wrapped_attr._attempt = 0 - wrapped_attr._patched = True + if already_patched: + print("patched_key \033[32m", patched_key, key, "\033[00m", already_patched) + setattr(attr, 'patched', True) + # Increment the reinvocation count. + patched_mu.release() + return already_patched + + patched[patched_key] = wrapped_attr + setattr(wrapped_attr, '_attempt', 0) + patched_mu.release() return wrapped_attr setattr(target, "__getattribute__", patched_getattribute) - patched[hex_id] = True diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index dc4cae0252..4eb30d646c 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -55,7 +55,7 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, _metadata_with_request_id, - inject_retry_header_control, + monkey_patch, ) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups @@ -438,7 +438,7 @@ def spanner_api(self): if not api: return api - inject_retry_header_control(api) + monkey_patch(api) return api def __generate_spanner_api(self): @@ -813,6 +813,7 @@ def execute_pdml(): def _next_nth_request(self): if self._instance and self._instance._client: return self._instance._client._next_nth_request + raise Exception("returning 1 for next_nth_request") return 1 @property diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 4e8d2a42dc..da19f6a6d2 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -249,6 +249,7 @@ def bind(self, database): attempt = 1 returned_session_count = 0 while not self._sessions.full(): + print("fixedPool.batchCreateSessions") request.session_count = requested_session_count - self._sessions.qsize() add_span_event( span, @@ -562,6 +563,7 @@ def bind(self, database): ) as span, MetricsCapture(): returned_session_count = 0 while returned_session_count < self.size: + print("pingingPool.batchCreateSessions") resp = api.batch_create_sessions( request=request, metadata=database.metadata_with_request_id( diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 8921704c42..e0c7cca516 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -591,13 +591,16 @@ def execute_sql( directed_read_options=directed_read_options, ) - restart = functools.partial( - api.execute_streaming_sql, - request=request, - metadata=metadata, - retry=retry, - timeout=timeout, - ) + def wrapped_restart(*args, **kwargs): + restart = functools.partial( 
+ api.execute_streaming_sql, + request=request, + metadata=kwargs.get('metadata', metadata), + retry=retry, + timeout=timeout, + ) + return restart(*args, **kwargs) + trace_attributes = {"db.statement": sql} observability_options = getattr(database, "observability_options", None) @@ -605,7 +608,7 @@ def execute_sql( # lock is added to handle the inline begin for first rpc with self._lock: return self._get_streamed_result_set( - restart, + wrapped_restart, request, metadata, trace_attributes, @@ -615,7 +618,7 @@ def execute_sql( ) else: return self._get_streamed_result_set( - restart, + wrapped_restart, request, metadata, trace_attributes, diff --git a/tests/mockserver_tests/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py index 9c131dc3a8..24af837728 100644 --- a/tests/mockserver_tests/test_request_id_header.py +++ b/tests/mockserver_tests/test_request_id_header.py @@ -310,6 +310,7 @@ def test_unary_retryable_error(self): ] print("got_unaries", got_unary_segments) + print("got_stream", got_stream_segments) assert got_unary_segments == want_unary_segments assert got_stream_segments == want_stream_segments From c31c03fcf7437ff1fa6b3d9b293597539c0a2a6a Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sat, 17 May 2025 13:32:00 -0700 Subject: [PATCH 26/29] Reduce unnecessary changes --- google/cloud/spanner_v1/__init__.py | 3 --- google/cloud/spanner_v1/pool.py | 3 --- google/cloud/spanner_v1/snapshot.py | 2 -- .../cloud/spanner_v1/testing/mock_spanner.py | 1 - tests/unit/test_atomic_counter.py | 1 - tests/unit/test_database.py | 25 ------------------- tests/unit/test_session.py | 1 - tests/unit/test_spanner.py | 6 ----- 8 files changed, 42 deletions(-) diff --git a/google/cloud/spanner_v1/__init__.py b/google/cloud/spanner_v1/__init__.py index 11776eec45..48b11d9342 100644 --- a/google/cloud/spanner_v1/__init__.py +++ b/google/cloud/spanner_v1/__init__.py @@ -75,7 +75,6 @@ from google.cloud.spanner_v1.pool import FixedSizePool from google.cloud.spanner_v1.pool import PingingPool from google.cloud.spanner_v1.pool import TransactionPingingPool -from google.cloud.spanner_v1._helpers import monkey_patch COMMIT_TIMESTAMP = "spanner.commit_timestamp()" @@ -84,8 +83,6 @@ ``(allow_commit_timestamp=true)`` in the schema. 
""" -monkey_patch(Transaction) - __all__ = ( # google.cloud.spanner_v1 diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index da19f6a6d2..0bc0135ba0 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -246,10 +246,8 @@ def bind(self, database): observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(): - attempt = 1 returned_session_count = 0 while not self._sessions.full(): - print("fixedPool.batchCreateSessions") request.session_count = requested_session_count - self._sessions.qsize() add_span_event( span, @@ -563,7 +561,6 @@ def bind(self, database): ) as span, MetricsCapture(): returned_session_count = 0 while returned_session_count < self.size: - print("pingingPool.batchCreateSessions") resp = api.batch_create_sessions( request=request, metadata=database.metadata_with_request_id( diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 7a6fd5d150..badc23026e 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -82,8 +82,6 @@ def _restart_on_unavailable( resume_token = b"" item_buffer = [] - next_nth_request = lambda: getattr(request_id_manager, "_next_nth_request", 0) - nth_request = next_nth_request() if transaction is not None: transaction_selector = transaction._make_txn_selector() diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py index e2ac14e976..f8971a6098 100644 --- a/google/cloud/spanner_v1/testing/mock_spanner.py +++ b/google/cloud/spanner_v1/testing/mock_spanner.py @@ -53,7 +53,6 @@ def pop_error(self, context): name = inspect.currentframe().f_back.f_code.co_name error: _Status | None = self.errors.pop(name, None) if error: - print("context.abort_with_status", error) context.abort_with_status(error) def get_result_as_partial_result_sets( diff --git a/tests/unit/test_atomic_counter.py b/tests/unit/test_atomic_counter.py index e8d8b6b7ce..92d10cac79 100644 --- a/tests/unit/test_atomic_counter.py +++ b/tests/unit/test_atomic_counter.py @@ -15,7 +15,6 @@ import random import threading import unittest - from google.cloud.spanner_v1._helpers import AtomicCounter diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 3838b879b9..bef2373f01 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -3449,31 +3449,6 @@ def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): def _channel_id(self): return 1 - @property - def _next_nth_request(self): - if self._instance and self._instance._client: - return self._instance._client._next_nth_request - return 1 - - @property - def _nth_client_id(self): - if self._instance and self._instance._client: - return self._instance._client._nth_client_id - return 1 - - def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): - return _metadata_with_request_id( - self._nth_client_id, - self._channel_id, - nth_request, - nth_attempt, - prior_metadata, - ) - - @property - def _channel_id(self): - return 1 - class _Pool(object): _bound = None diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4f4f53c78c..e7b7c3dee6 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1560,7 +1560,6 @@ def _time(_results=[1, 2, 4, 8]): self.assertEqual(kw, {}) expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) - print("gax_api", gax_api.begin_transaction.call_args_list[2]) 
self.assertEqual( gax_api.begin_transaction.call_args_list, [ diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index 261703d434..e30448cd01 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -13,8 +13,6 @@ # limitations under the License. -import sys -import traceback import threading from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( @@ -45,10 +43,6 @@ _merge_query_options, _metadata_with_request_id, ) -from google.cloud.spanner_v1._helpers import ( - AtomicCounter, - _metadata_with_request_id, -) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID import mock From e67cc9e79ae8f39231fb88e03ac752de23439eb9 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sat, 17 May 2025 23:48:55 -0700 Subject: [PATCH 27/29] Experiment with wrapping for gapic retries --- google/cloud/spanner_v1/_helpers.py | 175 +++++++++++++----- google/cloud/spanner_v1/batch.py | 1 + google/cloud/spanner_v1/database.py | 9 - google/cloud/spanner_v1/pool.py | 95 +++++++++- .../mockserver_tests/mock_server_test_base.py | 1 + 5 files changed, 216 insertions(+), 65 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 884e9b7c66..01f6d181ca 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -576,6 +576,7 @@ def _retry( def _check_rst_stream_error(exc): + print("\033[31mrst_", exc, "\033[00m") resumable_error = ( any( resumable_message in exc.message @@ -589,6 +590,11 @@ def _check_rst_stream_error(exc): raise +def _check_unavailable(exc): + print("\033[31mcheck_unavailable", exc, "\033[00m") + raise + + def _metadata_with_leader_aware_routing(value, **kw): """Create RPC metadata containing a leader aware routing header @@ -763,63 +769,125 @@ def __init__(self, original_callable: Callable): def inject_retry_header_control(api): - return - monkey_patch(type(api)) + # monkey_patch(type(api)) + # monkey_patch(api) + pass -memoize_map = dict() -def monkey_patch(obj): - return +def monkey_patch(typ): + keys = dir(typ) + attempts = dict() + for key in keys: + if key.startswith("_"): + continue - """ - klass = obj - attrs = dir(klass) - for attr_key in attrs: - if attr_key.startswith('_'): + if key != "batch_create_sessions": continue - attr_value = getattr(obj, attr_key) - if not callable(attr_value): + fn = getattr(typ, key) + + signature = inspect.signature(fn) + if signature.parameters.get("metadata", None) is None: continue - signature = inspect.signature(attr_value) - print(attr_key, signature.parameters) + print("fn.__call__", inspect.getsource(fn)) - call = attr_value - # Our goal is to replace the runtime pass through. - def wrapped(*args, **kwargs): - print(attr_key, 'called') - return call(*args, **kwargs) + def as_proxy(db, *args, **kwargs): + print("db_key", hex(id(db))) + print("as_proxy", args, kwargs) + metadata = kwargs.get("metadata", None) + if not metadata: + return fn(db, *args, **kwargs) - setattr(klass, attr_key, wrapped) + hash_key = hex(id(db)) + "." + hex(id(key)) + attempts.setdefault(hash_key, 0) + attempts[hash_key] += 1 + # 4. Find all the headers that match the target header key. + all_metadata = [] + for mkey, value in metadata: + if mkey is not REQ_ID_HEADER_KEY: + continue - return - """ + splits = value.split(".") + # 5. Increment the original_attempt with that of our re-invocation count. 
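+                # e.g. with attempts[hash_key] == 2, a value ending in ".5.1"
+                # becomes ".5.3": only the trailing attempt component changes.
+                # Note that the `continue` above skips appending any header
+                # whose key is not REQ_ID_HEADER_KEY, so this experimental
+                # proxy drops every other metadata entry.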
+ print("\033[34mkey", mkey, "\033[00m", splits) + hdr_attempt_plus_reinvocation = int(splits[-1]) + attempts[hash_key] + splits[-1] = str(hdr_attempt_plus_reinvocation) + value = ".".join(splits) + all_metadata.append((mkey, value)) + + kwargs["metadata"] = all_metadata + return fn(db, *args, **kwargs) + + setattr(typ, key, as_proxy) + + +def alt_foo(): + memoize_map = dict() orig_get_attr = getattr(obj, "__getattribute__") + hex_orig = hex(id(orig_get_attr)) + hex_patched = None + def patched_getattribute(obj, key, *args, **kwargs): - if key.startswith('_'): + if key.startswith("_"): return orig_get_attr(obj, key, *args, **kwargs) - orig_value = orig_get_attr(obj, key, *args, **kwargs) - if not callable(orig_value): - return orig_value + if key != "batch_create_sessions": + return orig_get_attr(obj, key, *args, **kwargs) map_key = hex(id(key)) + hex(id(obj)) memoized = memoize_map.get(map_key, None) if memoized: - print("memoized_hit", key, '\033[35m', inspect.getsource(orig_value), '\033[00m') + if False: + print( + "memoized_hit", + key, + "\033[35m", + inspect.getsource(orig_value), + "\033[00m", + ) + print("memoized_hit", key, "\033[35m", map_key, "\033[00m") return memoized + orig_value = orig_get_attr(obj, key, *args, **kwargs) + if not callable(orig_value): + return orig_value + signature = inspect.signature(orig_value) - if signature.parameters.get('metadata', None) is None: + if signature.parameters.get("metadata", None) is None: return orig_value - print(key, '\033[34m', map_key, '\033[00m', signature, signature.parameters.get('metadata', None)) + if False: + print( + key, + "\033[34m", + map_key, + "\033[00m", + signature, + signature.parameters.get("metadata", None), + ) + + if False: + stack = inspect.stack() + ends = stack[-50:-20] + for i, st in enumerate(ends): + print(i, st.filename, st.lineno) + + print( + "\033[33mmonkey patching now\033[00m", + key, + "hex_orig", + hex_orig, + "hex_patched", + hex_patched, + ) counters = dict(attempt=0) + def patched_method(*aargs, **kkwargs): - counters['attempt'] += 1 - metadata = kkwargs.get('metadata', None) + counters["attempt"] += 1 + print("counters", counters) + metadata = kkwargs.get("metadata", None) if not metadata: return orig_value(*aargs, **kkwargs) @@ -827,32 +895,38 @@ def patched_method(*aargs, **kkwargs): all_metadata = [] for mkey, value in metadata: if mkey is REQ_ID_HEADER_KEY: - attempt = counters['attempt'] + attempt = counters["attempt"] if attempt > 1: # 5. Increment the original_attempt with that of our re-invocation count. splits = value.split(".") - print('\033[34mkey', mkey, '\033[00m', splits) - hdr_attempt_plus_reinvocation = ( - int(splits[-1]) + attempt - ) + print("\033[34mkey", mkey, "\033[00m", splits) + hdr_attempt_plus_reinvocation = int(splits[-1]) + attempt splits[-1] = str(hdr_attempt_plus_reinvocation) value = ".".join(splits) all_metadata.append((mkey, value)) kwargs["metadata"] = all_metadata - return orig_value(*aargs, **kkwargs) + + try: + return orig_value(*aargs, **kkwargs) + + except (InternalServerError, ServiceUnavailable) as exc: + print("caught this exception, incrementing", exc) + counters["attempt"] += 1 + raise exc memoize_map[map_key] = patched_method return patched_method - setattr(obj, '__getattribute__', patched_getattribute) + hex_patched = hex(id(patched_getattribute)) + setattr(obj, "__getattribute__", patched_getattribute) def foo(api): global patched global patched_mu - + # For each method, add an _attempt value that'll then be # retrieved for each retry. # 1. 
Patch the __getattribute__ method to match items in our manifest. @@ -878,12 +952,13 @@ def patched_getattribute(obj, key, *args, **kwargs): patched_key = hex(id(key)) + hex(id(obj)) patched_mu.acquire() already_patched = patched.get(patched_key, None) - + other_attempts = dict(attempts=0) + # 3. Wrap the callable attribute and then capture its metadata keyed argument. def wrapped_attr(*args, **kwargs): - print("\033[31m", key, "attempt", other_attempts['attempts'], "\033[00m") - other_attempts['attempts'] += 1 + print("\033[31m", key, "attempt", other_attempts["attempts"], "\033[00m") + other_attempts["attempts"] += 1 metadata = kwargs.get("metadata", []) if not metadata: @@ -891,7 +966,15 @@ def wrapped_attr(*args, **kwargs): wrapped_attr._attempt += 1 return attr(*args, **kwargs) - print("\033[35mwrapped_attr", key, args, kwargs, 'attempt', wrapped_attr._attempt, "\033[00m") + print( + "\033[35mwrapped_attr", + key, + args, + kwargs, + "attempt", + wrapped_attr._attempt, + "\033[00m", + ) # 4. Find all the headers that match the target header key. all_metadata = [] @@ -900,7 +983,7 @@ def wrapped_attr(*args, **kwargs): if wrapped_attr._attempt > 0: # 5. Increment the original_attempt with that of our re-invocation count. splits = value.split(".") - print('\033[34mkey', mkey, '\033[00m', splits) + print("\033[34mkey", mkey, "\033[00m", splits) hdr_attempt_plus_reinvocation = ( int(splits[-1]) + wrapped_attr._attempt ) @@ -916,13 +999,13 @@ def wrapped_attr(*args, **kwargs): if already_patched: print("patched_key \033[32m", patched_key, key, "\033[00m", already_patched) - setattr(attr, 'patched', True) + setattr(attr, "patched", True) # Increment the reinvocation count. patched_mu.release() return already_patched patched[patched_key] = wrapped_attr - setattr(wrapped_attr, '_attempt', 0) + setattr(wrapped_attr, "_attempt", 0) patched_mu.release() return wrapped_attr diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 8bef3bb00a..61088dd7a5 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -250,6 +250,7 @@ def commit( observability_options=observability_options, metadata=metadata, ), MetricsCapture(): + def wrapped_method(*args, **kwargs): method = functools.partial( api.commit, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index d591eb45e0..c635244154 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -432,15 +432,6 @@ def logger(self): @property def spanner_api(self): - """Helper for session-related API calls.""" - api = self.__generate_spanner_api() - if not api: - return api - - monkey_patch(api) - return api - - def __generate_spanner_api(self): """Helper for session-related API calls.""" if self._spanner_api is None: client_info = self._instance._client._client_info diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 0bc0135ba0..cab536ee34 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -18,12 +18,18 @@ import queue import time +from google.api_core.exceptions import InternalServerError +from google.api_core.exceptions import ServiceUnavailable from google.cloud.exceptions import NotFound from google.cloud.spanner_v1 import BatchCreateSessionsRequest from google.cloud.spanner_v1 import Session from google.cloud.spanner_v1._helpers import ( + _check_rst_stream_error, + _check_unavailable, _metadata_with_prefix, _metadata_with_leader_aware_routing, + _retry, + 
AtomicCounter, ) from google.cloud.spanner_v1._opentelemetry_tracing import ( add_span_event, @@ -254,11 +260,25 @@ def bind(self, database): f"Creating {request.session_count} sessions", span_event_attributes, ) - resp = api.batch_create_sessions( - request=request, - metadata=database.metadata_with_request_id( - database._next_nth_request, 1, metadata - ), + attempt = AtomicCounter(0) + nth_request = database._next_nth_request + print("attempt", attempt.value) + + def wrapped_method(*args, **kwargs): + print("\033[33mwrapped_method", *args, "kwargs", kwargs, "\033[00m") + return api.batch_create_sessions( + request=request, + metadata=database.metadata_with_request_id( + database._next_nth_request, attempt.increment(), metadata + ), + ) + + resp = _retry( + wrapped_method, + allowed_exceptions={ + InternalServerError: _check_rst_stream_error, + ServiceUnavailable: _check_unavailable, + }, ) add_span_event( @@ -462,6 +482,48 @@ def clear(self): session.delete() +class CallerInterceptor: + def __init__(self, api): + pass + + def intercept(self): + # 1. For each callable method, intercept it and on each call, if we detect + # retryable UNAVAILABLE and INTERNAL grpc status codes, automatically retry. + elems = dir(self.__api) + for name in elems: + if name.startswith("_"): + continue + + if name != "batch_create_sessions": + continue + + attr = getattr(self.__api, name) + if not callable(attr): + continue + + setattr(self.__api, name, rewrite_and_handle_retries(attr)) + + +""" +from google.api_core.exceptions import ( + InternalServerError, + ServiceUnavailable, +) + +def rewrite_and_handle_retries(fn): + attempts = AtomicCounter(0) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except InternalServerError, ServiceUnavailable Aborted as exc: + attempts.increment() + finally: + pass + + return wrapper +""" + + class PingingPool(AbstractSessionPool): """Concrete session pool implementation: @@ -561,11 +623,24 @@ def bind(self, database): ) as span, MetricsCapture(): returned_session_count = 0 while returned_session_count < self.size: - resp = api.batch_create_sessions( - request=request, - metadata=database.metadata_with_request_id( - database._next_nth_request, 1, metadata - ), + attempt = AtomicCounter(0) + print("attempt", attempt.value) + nth_request = database._next_nth_request + + def wrapped_method(*args, **kwargs): + return api.batch_create_sessions( + request=request, + metadata=database.metadata_with_request_id( + database._next_nth_request, attempt.increment(), metadata + ), + ) + + resp = _retry( + wrapped_method, + allowed_exceptions={ + InternalServerError: _check_rst_stream_error, + ServiceUnavailable: _check_unavailable, + }, ) add_span_event( diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 2f89415b55..d359ded559 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -213,4 +213,5 @@ def database(self) -> Database: pool=FixedSizePool(size=10), enable_interceptors_in_tests=True, ) + print("self._database", self._database) return self._database From dab92ddcad4aba24a782f609235f21d0c78f9b20 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sun, 18 May 2025 18:33:02 -0700 Subject: [PATCH 28/29] Remove duplicate code --- .../mockserver_tests/mock_server_test_base.py | 23 --------- tests/unit/test_spanner.py | 48 ++++++++++--------- 2 files changed, 25 insertions(+), 46 deletions(-) diff --git 
a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index d359ded559..7b4538d601 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -20,7 +20,6 @@ start_mock_server, SpannerServicer, ) -from google.cloud.spanner_v1.client import Client import google.cloud.spanner_v1.types.type as spanner_type import google.cloud.spanner_v1.types.result_set as result_set from google.api_core.client_options import ClientOptions @@ -79,27 +78,6 @@ def unavailable_status() -> _Status: return status -# Creates an UNAVAILABLE status with the smallest possible retry delay. -def unavailable_status() -> _Status: - error = status_pb2.Status( - code=code_pb2.UNAVAILABLE, - message="Service unavailable.", - ) - retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1)) - status = _Status( - code=code_to_grpc_status_code(error.code), - details=error.message, - trailing_metadata=( - ("grpc-status-details-bin", error.SerializeToString()), - ( - "google.rpc.retryinfo-bin", - retry_info.SerializeToString(), - ), - ), - ) - return status - - def add_error(method: str, error: status_pb2.Status): MockServerTestBase.spanner_service.mock_spanner.add_error(method, error) @@ -213,5 +191,4 @@ def database(self) -> Database: pool=FixedSizePool(size=10), enable_interceptors_in_tests=True, ) - print("self._database", self._database) return self._database diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index e30448cd01..b3b24ad6c8 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -38,9 +38,11 @@ from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1._helpers import ( - AtomicCounter, _make_value_pb, _merge_query_options, +) +from google.cloud.spanner_v1._helpers import ( + AtomicCounter, _metadata_with_request_id, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID @@ -545,7 +547,7 @@ def test_transaction_should_include_begin_with_first_query(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], timeout=TIMEOUT, @@ -566,7 +568,7 @@ def test_transaction_should_include_begin_with_first_read(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], retry=RETRY, @@ -586,7 +588,7 @@ def test_transaction_should_include_begin_with_first_batch_update(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], retry=RETRY, @@ -615,7 +617,7 @@ def test_transaction_should_include_begin_w_exclude_txn_from_change_streams_with ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], ) @@ -645,7 +647,7 @@ def test_transaction_should_include_begin_w_isolation_level_with_first_update( ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + 
f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], ) @@ -667,7 +669,7 @@ def test_transaction_should_use_transaction_id_if_error_with_first_batch_update( ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], retry=RETRY, @@ -685,7 +687,7 @@ def test_transaction_should_use_transaction_id_if_error_with_first_batch_update( ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.2.1", ), ], ) @@ -705,7 +707,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_query(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], ) @@ -722,7 +724,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_query(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.2.1", ), ], ) @@ -742,7 +744,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_update(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], ) @@ -759,7 +761,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_update(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.2.1", ), ], ) @@ -784,7 +786,7 @@ def test_transaction_execute_sql_w_directed_read_options(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], retry=gapic_v1.method.DEFAULT, @@ -811,7 +813,7 @@ def test_transaction_streaming_read_w_directed_read_options(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], retry=RETRY, @@ -831,7 +833,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_read(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], retry=RETRY, @@ -846,7 +848,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_read(self): ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.2.1", ), ], retry=RETRY, @@ -866,7 +868,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_batch_update(se ("x-goog-spanner-route-to-leader", "true"), ( 
"x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], retry=RETRY, @@ -882,7 +884,7 @@ def test_transaction_should_use_transaction_id_returned_by_first_batch_update(se ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.2.1", ), ], retry=RETRY, @@ -926,7 +928,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], ) @@ -940,7 +942,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.2.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.2.1", ), ], ) @@ -952,7 +954,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.3.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.3.1", ), ], retry=RETRY, @@ -1000,7 +1002,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_ ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.3.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.3.1", ), ], ) @@ -1216,7 +1218,7 @@ def test_transaction_should_execute_sql_with_route_to_leader_disabled(self): ("google-cloud-resource-prefix", database.name), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", ), ], timeout=TIMEOUT, From 4334dd4a186820ecfac9732d8b90ec3f8d781f1a Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sun, 18 May 2025 19:18:26 -0700 Subject: [PATCH 29/29] Complete mockserver tests --- google/cloud/spanner_v1/_helpers.py | 13 ++++++++++-- google/cloud/spanner_v1/pool.py | 8 ++++--- .../mockserver_tests/mock_server_test_base.py | 21 ++++++++++++++++++- .../test_request_id_header.py | 5 ++--- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 01f6d181ca..5107188324 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -591,8 +591,17 @@ def _check_rst_stream_error(exc): def _check_unavailable(exc): - print("\033[31mcheck_unavailable", exc, "\033[00m") - raise + resumable_error = ( + any( + resumable_message in exc.message + for resumable_message in ( + "INTERNAL", + "Service unavailable", + ) + ), + ) + if not resumable_error: + raise def _metadata_with_leader_aware_routing(value, **kw): diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index cab536ee34..af4c7a1f7b 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -15,6 +15,7 @@ """Pools managing shared Session objects.""" import datetime +import functools import queue import time @@ 
-266,12 +267,14 @@ def bind(self, database): def wrapped_method(*args, **kwargs): print("\033[33mwrapped_method", *args, "kwargs", kwargs, "\033[00m") - return api.batch_create_sessions( + method = functools.partial( + api.batch_create_sessions, request=request, metadata=database.metadata_with_request_id( - database._next_nth_request, attempt.increment(), metadata + nth_request, attempt.increment(), metadata ), ) + return method(*args, **kwargs) resp = _retry( wrapped_method, @@ -624,7 +627,6 @@ def bind(self, database): returned_session_count = 0 while returned_session_count < self.size: attempt = AtomicCounter(0) - print("attempt", attempt.value) nth_request = database._next_nth_request def wrapped_method(*args, **kwargs): diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 7b4538d601..0c82075910 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -61,7 +61,7 @@ def aborted_status() -> _Status: def unavailable_status() -> _Status: error = status_pb2.Status( code=code_pb2.UNAVAILABLE, - message="Service unavailable.", + message="Received unexpected EOS on DATA frame from server", ) retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1)) status = _Status( @@ -77,6 +77,25 @@ def unavailable_status() -> _Status: ) return status +# Creates an INTERNAL status with the smallest possible retry delay. +def internal_status() -> _Status: + error = status_pb2.Status( + code=code_pb2.INTERNAL, + message="Service unavailable.", + ) + retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1)) + status = _Status( + code=code_to_grpc_status_code(error.code), + details=error.message, + trailing_metadata=( + ("grpc-status-details-bin", error.SerializeToString()), + ( + "google.rpc.retryinfo-bin", + retry_info.SerializeToString(), + ), + ), + ) + return status def add_error(method: str, error: status_pb2.Status): MockServerTestBase.spanner_service.mock_spanner.add_error(method, error) diff --git a/tests/mockserver_tests/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py index 24af837728..805f422d0e 100644 --- a/tests/mockserver_tests/test_request_id_header.py +++ b/tests/mockserver_tests/test_request_id_header.py @@ -27,6 +27,7 @@ add_select1_result, aborted_status, add_error, + internal_status, unavailable_status, ) @@ -270,7 +271,7 @@ def test_database_execute_partitioned_dml_request_id(self): def test_unary_retryable_error(self): add_select1_result() - add_error(SpannerServicer.BatchCreateSessions.__name__, unavailable_status()) + add_error(SpannerServicer.BatchCreateSessions.__name__, internal_status()) if not getattr(self.database, "_interceptors", None): self.database._interceptors = MockServerTestBase._interceptors @@ -309,8 +310,6 @@ def test_unary_retryable_error(self): ) ] - print("got_unaries", got_unary_segments) - print("got_stream", got_stream_segments) assert got_unary_segments == want_unary_segments assert got_stream_segments == want_stream_segments
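
For readers following the header values asserted throughout these tests, here is a minimal sketch of how a value of the shape 1.<rand_process_id>.<nth_client_id>.<channel_id>.<nth_request>.<attempt> can be assembled. The helper name build_request_id_metadata is illustrative only; the library's own _metadata_with_request_id takes the same five inputs (client id, channel id, nth request, attempt, prior metadata), but its exact body is not reproduced here.

import random

REQ_ID_HEADER_KEY = "x-goog-spanner-request-id"
# Stand-in for the library's REQ_RAND_PROCESS_ID, a per-process random value.
RAND_PROCESS_ID = random.randint(0, 2**64 - 1)


def build_request_id_metadata(
    client_id, channel_id, nth_request, attempt, prior_metadata=None
):
    """Returns prior_metadata plus one (key, value) pair carrying the request id."""
    metadata = list(prior_metadata or [])
    metadata.append(
        (
            REQ_ID_HEADER_KEY,
            f"1.{RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}",
        )
    )
    return metadata

Copying prior_metadata instead of mutating it keeps one logical request's retries from leaking headers into another's, which matters once the wrapped_method closures above rebuild the list on every attempt.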
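
The attempt bookkeeping in these patches assumes that AtomicCounter(n).increment() returns the post-increment value and that .value reads the current count; a compatible sketch under just those assumptions:

import threading


class AtomicCounter:
    # Thread-safe counter; increment() returns the new value, so the first
    # call on AtomicCounter(0) yields 1, matching the attempt numbering above.
    def __init__(self, start=0):
        self._lock = threading.Lock()
        self._value = start

    @property
    def value(self):
        with self._lock:
            return self._value

    def increment(self, step=1):
        with self._lock:
            self._value += step
            return self._value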
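
Tying the two sketches above together, this is roughly the retry shape pool.bind settles on in PATCH 29: nth_request is fixed once per logical call, every re-invocation rebuilds the metadata with a fresh attempt, and a predicate map in the style of _retry's allowed_exceptions decides which errors are resumable. RetryableRpcError and the helper names are local stand-ins, not library API.

class RetryableRpcError(Exception):
    """Local stand-in for google.api_core exceptions such as InternalServerError."""

    def __init__(self, message):
        super().__init__(message)
        self.message = message


def check_resumable(exc):
    # Mirrors the _check_* helpers: return to keep retrying, raise to give up.
    if not ("Service unavailable" in exc.message or "INTERNAL" in exc.message):
        raise exc


def retry_with_request_id(
    rpc, client_id, channel_id, nth_request, allowed_exceptions, max_attempts=5
):
    attempt = AtomicCounter(0)
    while True:
        # Same nth_request on every try; only the trailing attempt moves.
        metadata = build_request_id_metadata(
            client_id, channel_id, nth_request, attempt.increment()
        )
        try:
            return rpc(metadata=metadata)
        except tuple(allowed_exceptions) as exc:
            allowed_exceptions[type(exc)](exc)  # raises when not resumable
            if attempt.value >= max_attempts:
                raise

A toy invocation that fails once and then succeeds; the second call goes out with a trailing attempt component of 2:

responses = iter([RetryableRpcError("Service unavailable"), "ok"])


def fake_rpc(metadata):
    item = next(responses)
    if isinstance(item, Exception):
        raise item
    return item


assert retry_with_request_id(
    fake_rpc,
    client_id=1,
    channel_id=1,
    nth_request=1,
    allowed_exceptions={RetryableRpcError: check_resumable},
) == "ok"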