diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py
index 7b86a5653f..5107188324 100644
--- a/google/cloud/spanner_v1/_helpers.py
+++ b/google/cloud/spanner_v1/_helpers.py
@@ -19,6 +19,7 @@
 import math
 import time
 import base64
+import inspect
 import threading
 
 from google.protobuf.struct_pb2 import ListValue
@@ -33,7 +34,7 @@
 from google.cloud.spanner_v1 import ExecuteSqlRequest
 from google.cloud.spanner_v1 import JsonObject, Interval
 from google.cloud.spanner_v1 import TransactionOptions
-from google.cloud.spanner_v1.request_id_header import with_request_id
+from google.cloud.spanner_v1.request_id_header import REQ_ID_HEADER_KEY, with_request_id
+from google.api_core.exceptions import InternalServerError, ServiceUnavailable
 from google.rpc.error_details_pb2 import RetryInfo
 
 try:
@@ -45,6 +46,7 @@
 HAS_OPENTELEMETRY_INSTALLED = False
 
-from typing import List, Tuple
+from typing import Callable, List, Tuple
 import random
 
 # Validation error messages
 NUMERIC_MAX_SCALE_ERR_MSG = (
@@ -587,6 +590,20 @@ def _check_rst_stream_error(exc):
     raise
 
 
+def _check_unavailable(exc):
+    """Re-raise ``exc`` unless it is a resumable UNAVAILABLE/INTERNAL error."""
+    resumable_error = any(
+        resumable_message in exc.message
+        for resumable_message in (
+            "INTERNAL",
+            "Service unavailable",
+        )
+    )
+    if not resumable_error:
+        raise
+
+
 def _metadata_with_leader_aware_routing(value, **kw):
     """Create RPC metadata containing a leader aware routing header
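For context, `_check_unavailable` follows the same contract as the existing `_check_rst_stream_error`: it is a filter that `_retry` invokes after catching one of its `allowed_exceptions`, and it must re-raise when the error is not resumable. A minimal sketch of that contract (the real `_retry` in this module also tracks a deadline and backoff, elided here):

```python
def _retry_sketch(func, allowed_exceptions=None):
    # Sketch only: shows just the allowed_exceptions/filter contract.
    while True:
        try:
            return func()
        except Exception as exc:
            checker = (allowed_exceptions or {}).get(exc.__class__)
            if checker is None:
                raise  # not retryable at all
            # The checker re-raises (bare `raise`) unless the error is
            # resumable; bare raise works because we are inside `except`.
            checker(exc)
```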
@@ -749,3 +766,256 @@ def _merge_Transaction_Options(
 
     # Convert protobuf object back into a TransactionOptions instance
     return TransactionOptions(merged_pb)
+
+
+class InterceptingHeaderInjector:
+    """Placeholder wrapper intended to rewrite the request-id header of
+    ``original_callable`` on retries; not wired in yet."""
+
+    def __init__(self, original_callable: Callable):
+        self._original_callable = original_callable
+
+
+patched = {}
+patched_mu = threading.Lock()
+
+
+def inject_retry_header_control(api):
+    # Entry point used by testing/database_test.py. The experiments below
+    # (monkey_patch, alt_foo, foo) are alternative implementations, left
+    # disabled until one is settled on.
+    # monkey_patch(type(api))
+    # monkey_patch(api)
+    pass
+
+
+def monkey_patch(typ):
+    """Wrap every public method of ``typ`` that accepts ``metadata`` so the
+    attempt segment of the request-id header is bumped on re-invocation.
+    Currently limited to ``batch_create_sessions``."""
+    attempts = dict()
+    for key in dir(typ):
+        if key.startswith("_"):
+            continue
+
+        if key != "batch_create_sessions":
+            continue
+
+        fn = getattr(typ, key)
+
+        signature = inspect.signature(fn)
+        if signature.parameters.get("metadata", None) is None:
+            continue
+
+        # Bind fn and key eagerly: closures created in a loop are late-binding.
+        def as_proxy(db, *args, fn=fn, key=key, **kwargs):
+            metadata = kwargs.get("metadata", None)
+            if not metadata:
+                return fn(db, *args, **kwargs)
+
+            hash_key = hex(id(db)) + "." + key
+            attempts.setdefault(hash_key, 0)
+            attempts[hash_key] += 1
+            # Find the headers that match the target header key and increment
+            # the attempt (last) segment by our re-invocation count; all other
+            # headers are passed through untouched.
+            all_metadata = []
+            for mkey, value in metadata:
+                if mkey == REQ_ID_HEADER_KEY:
+                    splits = value.split(".")
+                    splits[-1] = str(int(splits[-1]) + attempts[hash_key])
+                    value = ".".join(splits)
+
+                all_metadata.append((mkey, value))
+
+            kwargs["metadata"] = all_metadata
+            return fn(db, *args, **kwargs)
+
+        setattr(typ, key, as_proxy)
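All three experiments in this block share one core operation: rewrite the final, attempt segment of the dotted request-id value `<version>.<rand_process_id>.<nth_client>.<channel>.<nth_request>.<attempt>`. Factored out (the helper name is hypothetical; the header-key literal matches the tests later in this diff):

```python
REQ_ID_HEADER_KEY = "x-goog-spanner-request-id"  # assumed value of the constant

def bump_attempt(value: str, by: int) -> str:
    """Increment the trailing attempt segment of a request-id header value."""
    splits = value.split(".")
    splits[-1] = str(int(splits[-1]) + by)
    return ".".join(splits)

assert bump_attempt("1.abc.4.1.7.1", 1) == "1.abc.4.1.7.2"
```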
+
+
+def alt_foo(obj):
+    """Alternative experiment: patch ``obj.__getattribute__`` and memoize the
+    wrapped methods, counting attempts per (method, instance) pair."""
+    memoize_map = dict()
+    orig_get_attr = getattr(obj, "__getattribute__")
+
+    def patched_getattribute(obj, key, *args, **kwargs):
+        # Skip private/mangled attributes, and everything except the one
+        # method under study.
+        if key.startswith("_"):
+            return orig_get_attr(obj, key, *args, **kwargs)
+
+        if key != "batch_create_sessions":
+            return orig_get_attr(obj, key, *args, **kwargs)
+
+        map_key = hex(id(key)) + hex(id(obj))
+        memoized = memoize_map.get(map_key, None)
+        if memoized:
+            return memoized
+
+        orig_value = orig_get_attr(obj, key, *args, **kwargs)
+        if not callable(orig_value):
+            return orig_value
+
+        signature = inspect.signature(orig_value)
+        if signature.parameters.get("metadata", None) is None:
+            return orig_value
+
+        counters = dict(attempt=0)
+
+        def patched_method(*aargs, **kkwargs):
+            counters["attempt"] += 1
+            metadata = kkwargs.get("metadata", None)
+            if not metadata:
+                return orig_value(*aargs, **kkwargs)
+
+            # Find the headers that match the target header key and, on
+            # re-invocation, increment the attempt segment accordingly.
+            all_metadata = []
+            for mkey, value in metadata:
+                if mkey == REQ_ID_HEADER_KEY:
+                    attempt = counters["attempt"]
+                    if attempt > 1:
+                        splits = value.split(".")
+                        splits[-1] = str(int(splits[-1]) + attempt)
+                        value = ".".join(splits)
+
+                all_metadata.append((mkey, value))
+
+            kkwargs["metadata"] = all_metadata
+
+            try:
+                return orig_value(*aargs, **kkwargs)
+            except (InternalServerError, ServiceUnavailable):
+                # Count the failed attempt, then let the retry machinery
+                # decide whether to re-invoke us.
+                counters["attempt"] += 1
+                raise
+
+        memoize_map[map_key] = patched_method
+        return patched_method
+
+    setattr(obj, "__getattribute__", patched_getattribute)
+
+
+def foo(api):
+    global patched
+    global patched_mu
+
+    # For each matching method, attach an _attempt counter that is then
+    # consulted on every retry.
+    # 1. Patch __getattribute__ once per concrete API type.
+    target = type(api)
+    hex_id = hex(id(target))
+    if patched.get(hex_id, None) is not None:
+        return
+
+    orig_getattribute = getattr(target, "__getattribute__")
+
+    def patched_getattribute(obj, key, *args, **kwargs):
+        # 1. Skip modifying private and mangled methods.
+        if key.startswith("_"):
+            return orig_getattribute(obj, key, *args, **kwargs)
+
+        attr = orig_getattribute(obj, key, *args, **kwargs)
+
+        # 2. Skip over non-methods.
+        if not callable(attr):
+            return attr
+
+        patched_key = hex(id(key)) + hex(id(obj))
+        with patched_mu:
+            already_patched = patched.get(patched_key, None)
+            if already_patched:
+                return already_patched
+
+            # 3. Wrap the callable attribute and then capture its metadata
+            #    keyword argument.
+            def wrapped_attr(*args, **kwargs):
+                metadata = kwargs.get("metadata", [])
+                if not metadata:
+                    # Increment the re-invocation count and pass through.
+                    wrapped_attr._attempt += 1
+                    return attr(*args, **kwargs)
+
+                # 4. Find all the headers that match the target header key.
+                all_metadata = []
+                for mkey, value in metadata:
+                    if mkey == REQ_ID_HEADER_KEY and wrapped_attr._attempt > 0:
+                        # 5. Increment the original attempt segment by our
+                        #    re-invocation count.
+                        splits = value.split(".")
+                        splits[-1] = str(int(splits[-1]) + wrapped_attr._attempt)
+                        value = ".".join(splits)
+
+                    all_metadata.append((mkey, value))
+
+                kwargs["metadata"] = all_metadata
+                wrapped_attr._attempt += 1
+                return attr(*args, **kwargs)
+
+            wrapped_attr._attempt = 0
+            patched[patched_key] = wrapped_attr
+            return wrapped_attr
+
+    setattr(target, "__getattribute__", patched_getattribute)
+    patched[hex_id] = True
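A sketch of how the `monkey_patch` experiment above would be exercised. `FakeSpannerApi` is a stand-in for illustration, not a class in this codebase, and the header literal assumes `REQ_ID_HEADER_KEY == "x-goog-spanner-request-id"`:

```python
class FakeSpannerApi:
    def batch_create_sessions(self, request=None, metadata=None):
        return metadata  # echo back so the header rewrite is visible

api = FakeSpannerApi()
monkey_patch(FakeSpannerApi)

# First call: the attempt (last) segment is bumped by the invocation count.
got = api.batch_create_sessions(
    request=None,
    metadata=[("x-goog-spanner-request-id", "1.abc.1.1.1.1")],
)
print(got)  # [('x-goog-spanner-request-id', '1.abc.1.1.1.2')]
```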
diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py
index 0cbf044672..61088dd7a5 100644
--- a/google/cloud/spanner_v1/batch.py
+++ b/google/cloud/spanner_v1/batch.py
@@ -26,6 +26,7 @@
     _metadata_with_prefix,
     _metadata_with_leader_aware_routing,
     _merge_Transaction_Options,
+    AtomicCounter,
 )
 from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
 from google.cloud.spanner_v1 import RequestOptions
@@ -388,7 +389,11 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
         method = functools.partial(
             api.batch_write,
             request=request,
-            metadata=metadata,
+            metadata=database.metadata_with_request_id(
+                database._next_nth_request,
+                1,
+                metadata,
+            ),
         )
         response = _retry(
             method,
diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py
index f2d570feb9..c635244154 100644
--- a/google/cloud/spanner_v1/database.py
+++ b/google/cloud/spanner_v1/database.py
@@ -51,6 +51,7 @@
 from google.cloud.spanner_v1 import SpannerClient
 from google.cloud.spanner_v1._helpers import _merge_query_options
 from google.cloud.spanner_v1._helpers import (
+    AtomicCounter,
     _metadata_with_prefix,
     _metadata_with_leader_aware_routing,
     _metadata_with_request_id,
diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py
index 0bc0135ba0..af4c7a1f7b 100644
--- a/google/cloud/spanner_v1/pool.py
+++ b/google/cloud/spanner_v1/pool.py
@@ -15,15 +15,22 @@
 """Pools managing shared Session objects."""
 
 import datetime
+import functools
 import queue
 import time
 
+from google.api_core.exceptions import InternalServerError
+from google.api_core.exceptions import ServiceUnavailable
 from google.cloud.exceptions import NotFound
 from google.cloud.spanner_v1 import BatchCreateSessionsRequest
 from google.cloud.spanner_v1 import Session
 from google.cloud.spanner_v1._helpers import (
+    _check_rst_stream_error,
+    _check_unavailable,
     _metadata_with_prefix,
     _metadata_with_leader_aware_routing,
+    _retry,
+    AtomicCounter,
 )
 from google.cloud.spanner_v1._opentelemetry_tracing import (
     add_span_event,
@@ -254,11 +261,27 @@
             f"Creating {request.session_count} sessions",
             span_event_attributes,
         )
-        resp = api.batch_create_sessions(
-            request=request,
-            metadata=database.metadata_with_request_id(
-                database._next_nth_request, 1, metadata
-            ),
+        attempt = AtomicCounter(0)
+        nth_request = database._next_nth_request
+
+        def wrapped_method(*args, **kwargs):
+            # Rebuild the partial on every attempt so the request-id header
+            # carries a fresh attempt number while reusing nth_request.
+            method = functools.partial(
+                api.batch_create_sessions,
+                request=request,
+                metadata=database.metadata_with_request_id(
+                    nth_request, attempt.increment(), metadata
+                ),
+            )
+            return method(*args, **kwargs)
+
+        resp = _retry(
+            wrapped_method,
+            allowed_exceptions={
+                InternalServerError: _check_rst_stream_error,
+                ServiceUnavailable: _check_unavailable,
+            },
         )
 
         add_span_event(
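The shape of `wrapped_method` matters: `metadata_with_request_id` must run inside the wrapper, not once outside it, so that each `_retry` attempt stamps a new attempt segment against the same captured `nth_request`. That relies on `AtomicCounter.increment()` returning the post-increment value; a minimal stand-in for the real class in `_helpers.py`, assuming that semantic:

```python
import threading

class AtomicCounterSketch:
    """Minimal stand-in for spanner_v1._helpers.AtomicCounter."""

    def __init__(self, start=0):
        self._lock = threading.Lock()
        self._value = start

    def increment(self, by=1):
        with self._lock:
            self._value += by
            return self._value  # post-increment value

attempt = AtomicCounterSketch(0)
assert attempt.increment() == 1  # first RPC attempt
assert attempt.increment() == 2  # first retry
```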
@@ -462,6 +485,48 @@ def clear(self):
 
             session.delete()
 
 
+class CallerInterceptor:
+    def __init__(self, api):
+        self.__api = api
+
+    def intercept(self):
+        # 1. For each callable method, intercept it so that on each call, if
+        #    we detect retryable UNAVAILABLE and INTERNAL gRPC status codes,
+        #    we automatically retry.
+        for name in dir(self.__api):
+            if name.startswith("_"):
+                continue
+
+            if name != "batch_create_sessions":
+                continue
+
+            attr = getattr(self.__api, name)
+            if not callable(attr):
+                continue
+
+            setattr(self.__api, name, rewrite_and_handle_retries(attr))
+
+
+"""
+Sketch of the per-call retry wrapper used above, kept commented out until
+rewrite_and_handle_retries is needed:
+
+from google.api_core.exceptions import (
+    Aborted,
+    InternalServerError,
+    ServiceUnavailable,
+)
+
+def rewrite_and_handle_retries(fn):
+    attempts = AtomicCounter(0)
+
+    def wrapper(*args, **kwargs):
+        while True:
+            try:
+                return fn(*args, **kwargs)
+            except (InternalServerError, ServiceUnavailable, Aborted):
+                attempts.increment()
+
+    return wrapper
+"""
+
+
 class PingingPool(AbstractSessionPool):
     """Concrete session pool implementation:
@@ -561,11 +626,23 @@ def bind(self, database):
         ) as span, MetricsCapture():
             returned_session_count = 0
             while returned_session_count < self.size:
-                resp = api.batch_create_sessions(
-                    request=request,
-                    metadata=database.metadata_with_request_id(
-                        database._next_nth_request, 1, metadata
-                    ),
+                attempt = AtomicCounter(0)
+                nth_request = database._next_nth_request
+
+                def wrapped_method(*args, **kwargs):
+                    return api.batch_create_sessions(
+                        request=request,
+                        metadata=database.metadata_with_request_id(
+                            nth_request, attempt.increment(), metadata
+                        ),
+                    )
+
+                resp = _retry(
+                    wrapped_method,
+                    allowed_exceptions={
+                        InternalServerError: _check_rst_stream_error,
+                        ServiceUnavailable: _check_unavailable,
+                    },
                 )
 
                 add_span_event(
diff --git a/google/cloud/spanner_v1/testing/database_test.py b/google/cloud/spanner_v1/testing/database_test.py
index 5af89fea42..2c8c651bfc 100644
--- a/google/cloud/spanner_v1/testing/database_test.py
+++ b/google/cloud/spanner_v1/testing/database_test.py
@@ -27,6 +27,7 @@
     MethodAbortInterceptor,
     XGoogRequestIDHeaderInterceptor,
 )
+from google.cloud.spanner_v1._helpers import inject_retry_header_control
 
 
 class TestDatabase(Database):
@@ -70,6 +71,14 @@
 
     @property
     def spanner_api(self):
+        api = self.__generate_spanner_api()
+        if not api:
+            return api
+
+        inject_retry_header_control(api)
+        return api
+
+    def __generate_spanner_api(self):
        """Helper for session-related API calls."""
         if self._spanner_api is None:
             client = self._instance._client
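Because `spanner_api` is a property, `inject_retry_header_control` runs on every attribute access, so any real implementation has to be idempotent. The `patched`/`patched_mu` pair in `_helpers.py` exists for exactly that; a minimal sketch of the guard, assuming per-type patching:

```python
import threading

patched = {}
patched_mu = threading.Lock()

def inject_once(api, do_patch):
    """Apply do_patch(type(api)) at most once per concrete API type."""
    key = hex(id(type(api)))
    with patched_mu:
        if key in patched:
            return
        patched[key] = True
    do_patch(type(api))
```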
diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py
index 71b77e4d16..4fe4ed147d 100644
--- a/google/cloud/spanner_v1/testing/interceptors.py
+++ b/google/cloud/spanner_v1/testing/interceptors.py
@@ -71,9 +71,6 @@ def reset(self):
 
 
 class XGoogRequestIDHeaderInterceptor(ClientInterceptor):
-    # TODO:(@odeke-em): delete this guard when PR #1367 is merged.
-    X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED = True
-
     def __init__(self):
         self._unary_req_segments = []
         self._stream_req_segments = []
@@ -87,24 +84,22 @@ def intercept(self, method, request_or_iterator, call_details):
                 x_goog_request_id = value
                 break
 
-        if self.X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED and not x_goog_request_id:
+        if not x_goog_request_id:
             raise Exception(
                 f"Missing {X_GOOG_REQUEST_ID} header in {call_details.method}"
             )
 
         response_or_iterator = method(request_or_iterator, call_details)
         streaming = getattr(response_or_iterator, "__iter__", None) is not None
-
-        if self.X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED:
-            with self.__lock:
-                if streaming:
-                    self._stream_req_segments.append(
-                        (call_details.method, parse_request_id(x_goog_request_id))
-                    )
-                else:
-                    self._unary_req_segments.append(
-                        (call_details.method, parse_request_id(x_goog_request_id))
-                    )
+        with self.__lock:
+            if streaming:
+                self._stream_req_segments.append(
+                    (call_details.method, parse_request_id(x_goog_request_id))
+                )
+            else:
+                self._unary_req_segments.append(
+                    (call_details.method, parse_request_id(x_goog_request_id))
+                )
 
         return response_or_iterator
diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py
index 7b4538d601..0c82075910 100644
--- a/tests/mockserver_tests/mock_server_test_base.py
+++ b/tests/mockserver_tests/mock_server_test_base.py
@@ -61,7 +61,7 @@ def aborted_status() -> _Status:
 def unavailable_status() -> _Status:
     error = status_pb2.Status(
         code=code_pb2.UNAVAILABLE,
-        message="Service unavailable.",
+        message="Received unexpected EOS on DATA frame from server",
     )
     retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1))
     status = _Status(
@@ -77,6 +77,25 @@
     )
     return status
 
+
+# Creates an INTERNAL status with the smallest possible retry delay.
+def internal_status() -> _Status:
+    error = status_pb2.Status(
+        code=code_pb2.INTERNAL,
+        message="Service unavailable.",
+    )
+    retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1))
+    status = _Status(
+        code=code_to_grpc_status_code(error.code),
+        details=error.message,
+        trailing_metadata=(
+            ("grpc-status-details-bin", error.SerializeToString()),
+            (
+                "google.rpc.retryinfo-bin",
+                retry_info.SerializeToString(),
+            ),
+        ),
+    )
+    return status
+
 
 def add_error(method: str, error: status_pb2.Status):
     MockServerTestBase.spanner_service.mock_spanner.add_error(method, error)
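The mock-server tests below compare the interceptor's recorded segments against six-element tuples. A sketch of the parsing assumed here (the real `parse_request_id` lives in `request_id_header.py`; the exact types are an assumption):

```python
def parse_request_id_sketch(value: str) -> tuple:
    # "1.<rand_process_id>.<nth_client>.<channel_id>.<nth_request>.<attempt>"
    version, rand_process_id, *rest = value.split(".")
    return (int(version), rand_process_id, *map(int, rest))

assert parse_request_id_sketch("1.abc.4.1.7.2") == (1, "abc", 4, 1, 7, 2)
```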
diff --git a/tests/mockserver_tests/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py
new file mode 100644
index 0000000000..805f422d0e
--- /dev/null
+++ b/tests/mockserver_tests/test_request_id_header.py
@@ -0,0 +1,362 @@
+# Copyright 2024 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import threading
+
+from google.cloud.spanner_v1 import (
+    BatchCreateSessionsRequest,
+    BeginTransactionRequest,
+    ExecuteSqlRequest,
+)
+from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID
+from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
+from tests.mockserver_tests.mock_server_test_base import (
+    MockServerTestBase,
+    add_select1_result,
+    aborted_status,
+    add_error,
+    internal_status,
+    unavailable_status,
+)
+
+
+class TestRequestIDHeader(MockServerTestBase):
+    def tearDown(self):
+        self.database._x_goog_request_id_interceptor.reset()
+
+    def test_snapshot_execute_sql(self):
+        add_select1_result()
+        if not getattr(self.database, "_interceptors", None):
+            self.database._interceptors = MockServerTestBase._interceptors
+        with self.database.snapshot() as snapshot:
+            results = snapshot.execute_sql("select 1")
+            result_list = []
+            for row in results:
+                result_list.append(row)
+                self.assertEqual(1, row[0])
+            self.assertEqual(1, len(result_list))
+
+        requests = self.spanner_service.requests
+        self.assertEqual(2, len(requests), msg=requests)
+        self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
+        self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
+
+        NTH_CLIENT = self.database._nth_client_id
+        CHANNEL_ID = self.database._channel_id
+        # Now ensure monotonicity of the received request-id segments.
+        got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
+        want_unary_segments = [
+            (
+                "/google.spanner.v1.Spanner/BatchCreateSessions",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1),
+            )
+        ]
+        want_stream_segments = [
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 2, 1),
+            )
+        ]
+
+        assert got_unary_segments == want_unary_segments
+        assert got_stream_segments == want_stream_segments
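The test expects BatchCreateSessions in the unary bucket and ExecuteStreamingSql in the stream bucket; the interceptor shown earlier classifies by whether the response iterates. A self-contained sketch of that classification:

```python
def classify(response_or_iterator) -> str:
    # Mirrors XGoogRequestIDHeaderInterceptor: streaming responses iterate.
    streaming = getattr(response_or_iterator, "__iter__", None) is not None
    return "stream" if streaming else "unary"

assert classify(iter([1, 2])) == "stream"
assert classify(object()) == "unary"
```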
+    def test_snapshot_read_concurrent(self):
+        db = self.database
+        # Trigger BatchCreateSessions first.
+        with db.snapshot() as snapshot:
+            rows = snapshot.execute_sql("select 1")
+            for row in rows:
+                _ = row
+
+        # The other requests can then proceed.
+        def select1():
+            with db.snapshot() as snapshot:
+                rows = snapshot.execute_sql("select 1")
+                res_list = []
+                for row in rows:
+                    self.assertEqual(1, row[0])
+                    res_list.append(row)
+                self.assertEqual(1, len(res_list))
+
+        n = 10
+        threads = []
+        for i in range(n):
+            th = threading.Thread(target=select1, name=f"snapshot-select1-{i}")
+            threads.append(th)
+            th.start()
+
+        for th in threads:
+            th.join()
+
+        requests = self.spanner_service.requests
+        self.assertEqual(2 + n * 2, len(requests), msg=requests)
+
+        client_id = db._nth_client_id
+        channel_id = db._channel_id
+        got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
+
+        want_unary_segments = [
+            (
+                "/google.spanner.v1.Spanner/BatchCreateSessions",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/GetSession",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 21, 1),
+            ),
+        ]
+        assert got_unary_segments == want_unary_segments
+
+        want_stream_segments = [
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 2, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 22, 1),
+            ),
+        ]
+        assert got_stream_segments == want_stream_segments
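The two want lists above are mechanical: after the warm-up call (requests 1 and 2), each worker contributes one GetSession (odd nth_request) and one ExecuteStreamingSql (even nth_request), all with attempt 1. A sketch that would generate them instead of spelling them out:

```python
def build_want_segments(n, rand, client_id, channel_id):
    unary = [
        ("/google.spanner.v1.Spanner/BatchCreateSessions",
         (1, rand, client_id, channel_id, 1, 1)),
    ]
    stream = [
        ("/google.spanner.v1.Spanner/ExecuteStreamingSql",
         (1, rand, client_id, channel_id, 2, 1)),
    ]
    for i in range(n):
        unary.append(
            ("/google.spanner.v1.Spanner/GetSession",
             (1, rand, client_id, channel_id, 3 + 2 * i, 1))
        )
        stream.append(
            ("/google.spanner.v1.Spanner/ExecuteStreamingSql",
             (1, rand, client_id, channel_id, 4 + 2 * i, 1))
        )
    return unary, stream
```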
+    def test_database_run_in_transaction_retries_on_abort(self):
+        counters = dict(aborted=0)
+        want_failed_attempts = 2
+
+        def select_in_txn(txn):
+            results = txn.execute_sql("select 1")
+            for row in results:
+                _ = row
+
+            if counters["aborted"] < want_failed_attempts:
+                counters["aborted"] += 1
+                add_error(SpannerServicer.Commit.__name__, aborted_status())
+
+        add_select1_result()
+        if not getattr(self.database, "_interceptors", None):
+            self.database._interceptors = MockServerTestBase._interceptors
+
+        self.database.run_in_transaction(select_in_txn)
+
+    def test_database_execute_partitioned_dml_request_id(self):
+        add_select1_result()
+        if not getattr(self.database, "_interceptors", None):
+            self.database._interceptors = MockServerTestBase._interceptors
+        _ = self.database.execute_partitioned_dml("select 1")
+
+        requests = self.spanner_service.requests
+        self.assertEqual(3, len(requests), msg=requests)
+        self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
+        self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
+        self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
+
+        # Now ensure monotonicity of the received request-id segments.
+        got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
+        NTH_CLIENT = self.database._nth_client_id
+        CHANNEL_ID = self.database._channel_id
+        want_unary_segments = [
+            (
+                "/google.spanner.v1.Spanner/BatchCreateSessions",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/BeginTransaction",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 2, 1),
+            ),
+        ]
+        want_stream_segments = [
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 3, 1),
+            )
+        ]
+
+        assert got_unary_segments == want_unary_segments
+        assert got_stream_segments == want_stream_segments
+
+    def test_unary_retryable_error(self):
+        add_select1_result()
+        add_error(SpannerServicer.BatchCreateSessions.__name__, internal_status())
+
+        if not getattr(self.database, "_interceptors", None):
+            self.database._interceptors = MockServerTestBase._interceptors
+        with self.database.snapshot() as snapshot:
+            results = snapshot.execute_sql("select 1")
+            result_list = []
+            for row in results:
+                result_list.append(row)
+                self.assertEqual(1, row[0])
+            self.assertEqual(1, len(result_list))
+
+        requests = self.spanner_service.requests
+        self.assertEqual(3, len(requests), msg=requests)
+        self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
+        self.assertTrue(isinstance(requests[1], BatchCreateSessionsRequest))
+        self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
+
+        NTH_CLIENT = self.database._nth_client_id
+        CHANNEL_ID = self.database._channel_id
+        # Now ensure monotonicity of the received request-id segments.
+        got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
+        want_unary_segments = [
+            (
+                "/google.spanner.v1.Spanner/BatchCreateSessions",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/BatchCreateSessions",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 2),
+            ),
+        ]
+        want_stream_segments = [
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 2, 1),
+            )
+        ]
+
+        assert got_unary_segments == want_unary_segments
+        assert got_stream_segments == want_stream_segments
+
+    def test_streaming_retryable_error(self):
+        add_select1_result()
+        add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status())
+
+        if not getattr(self.database, "_interceptors", None):
+            self.database._interceptors = MockServerTestBase._interceptors
+        with self.database.snapshot() as snapshot:
+            results = snapshot.execute_sql("select 1")
+            result_list = []
+            for row in results:
+                result_list.append(row)
+                self.assertEqual(1, row[0])
+            self.assertEqual(1, len(result_list))
+
+        requests = self.spanner_service.requests
+        self.assertEqual(3, len(requests), msg=requests)
+        self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
+        self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
+        self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
+
+        NTH_CLIENT = self.database._nth_client_id
+        CHANNEL_ID = self.database._channel_id
+        # Now ensure monotonicity of the received request-id segments.
+        got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
+        want_unary_segments = [
+            (
+                "/google.spanner.v1.Spanner/BatchCreateSessions",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1, 1),
+            ),
+        ]
+        want_stream_segments = [
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 2, 1),
+            ),
+            (
+                "/google.spanner.v1.Spanner/ExecuteStreamingSql",
+                (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 2, 2),
+            ),
+        ]
+
+        assert got_unary_segments == want_unary_segments
+        assert got_stream_segments == want_stream_segments
+
+    def canonicalize_request_id_headers(self):
+        src = self.database._x_goog_request_id_interceptor
+        return src._stream_req_segments, src._unary_req_segments
diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py
index 2014b60eb9..9a885a5a6f 100644
--- a/tests/unit/test_batch.py
+++ b/tests/unit/test_batch.py
@@ -596,6 +596,12 @@
             "traceparent is missing in metadata",
         )
 
+        expected_metadata.append(
+            (
+                "x-goog-spanner-request-id",
+                f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1",
+            )
+        )
         # Remove traceparent from actual metadata for comparison
         filtered_metadata = [item for item in metadata if item[0] != "traceparent"]
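The two retry tests above pin down the header semantics: a retried unary call and a resumed stream both reuse their nth_request and bump only the attempt segment. Every expected header in the unit-test diffs below follows the same six-segment layout; a hypothetical helper the assertions could share:

```python
def expected_request_id(rand, nth_client, channel, nth_request, attempt):
    # version 1, then process nonce, client ordinal, channel, request
    # ordinal within the client, and attempt number.
    return f"1.{rand}.{nth_client}.{channel}.{nth_request}.{attempt}"

assert expected_request_id("r", 5, 1, 3, 1) == "1.r.5.1.3.1"
```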
diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py
index 56ac22eab0..bef2373f01 100644
--- a/tests/unit/test_database.py
+++ b/tests/unit/test_database.py
@@ -120,7 +120,9 @@ def _make_database_admin_api():
     def _make_spanner_api():
         from google.cloud.spanner_v1 import SpannerClient
 
-        return mock.create_autospec(SpannerClient, instance=True)
+        api = mock.create_autospec(SpannerClient, instance=True)
+        api._transport = "transport"
+        return api
 
     def test_ctor_defaults(self):
         from google.cloud.spanner_v1.pool import BurstyPool
@@ -1300,6 +1302,19 @@ def _execute_partitioned_dml_helper(
                 ],
             )
             self.assertEqual(api.begin_transaction.call_count, 2)
+            api.begin_transaction.assert_called_with(
+                session=session.name,
+                options=txn_options,
+                metadata=[
+                    ("google-cloud-resource-prefix", database.name),
+                    ("x-goog-spanner-route-to-leader", "true"),
+                    (
+                        "x-goog-spanner-request-id",
+                        # Note: this retry was triggered by an abort, not by a
+                        # Service Unavailable error.
+                        f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.3.1",
+                    ),
+                ],
+            )
         else:
             api.begin_transaction.assert_called_with(
                 session=session.name,
                 options=txn_options,
                 metadata=[
                     ("google-cloud-resource-prefix", database.name),
                     ("x-goog-spanner-route-to-leader", "true"),
+                    (
+                        "x-goog-spanner-request-id",
+                        f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1",
+                    ),
                 ],
             )
             self.assertEqual(api.begin_transaction.call_count, 1)
 
         if params:
             expected_params = Struct(
@@ -3241,6 +3268,10 @@ def test_context_mgr_success(self):
             metadata=[
                 ("google-cloud-resource-prefix", database.name),
                 ("x-goog-spanner-route-to-leader", "true"),
+                (
+                    "x-goog-spanner-request-id",
+                    f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1",
+                ),
             ],
         )
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index b80d6bd18a..e7b7c3dee6 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -133,6 +133,7 @@ def _make_database(
         database.default_transaction_options = default_transaction_options
 
         inject_into_mock_database(database)
+        database.metadata_with_request_id = metadata_with_request_id
 
         return database
 
     @staticmethod
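test_session.py attaches the real `metadata_with_request_id` helper to its mock database so session code can stamp headers. Its shape, inferred from the call sites in this diff rather than copied from the source, would be roughly:

```python
REQ_ID_HEADER_KEY = "x-goog-spanner-request-id"  # assumed constant value

def metadata_with_request_id(self, nth_request, attempt, prior_metadata):
    # Assumption: returns the caller's metadata plus one request-id pair,
    # using the database's client and channel ordinals.
    return list(prior_metadata) + [
        (
            REQ_ID_HEADER_KEY,
            f"1.{REQ_RAND_PROCESS_ID}.{self._nth_client_id}."
            f"{self._channel_id}.{nth_request}.{attempt}",
        )
    ]
```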
diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py
index 7b3ad679a9..ec044b140b 100644
--- a/tests/unit/test_snapshot.py
+++ b/tests/unit/test_snapshot.py
@@ -30,6 +30,11 @@
 from google.cloud.spanner_v1.param_types import INT64
 from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID
 from google.api_core.retry import Retry
+from google.cloud.spanner_v1._helpers import (
+    AtomicCounter,
+    _metadata_with_request_id,
+)
 
 TABLE_NAME = "citizens"
 COLUMNS = ["email", "first_name", "last_name", "age"]
@@ -550,7 +555,7 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self):
             metadata=[
                 (
                     "x-goog-spanner-request-id",
-                    f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1",
+                    f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1",
                 )
             ],
         )
@@ -1089,7 +1094,7 @@ def _execute_sql_helper(
                 ("google-cloud-resource-prefix", database.name),
                 (
                     "x-goog-spanner-request-id",
-                    f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1",
+                    f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1",
                 ),
             ],
             timeout=timeout,
@@ -1266,7 +1271,7 @@ def _partition_read_helper(
                 ("x-goog-spanner-route-to-leader", "true"),
                 (
                     "x-goog-spanner-request-id",
-                    f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1",
+                    f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1",
                 ),
             ],
             retry=retry,
@@ -1449,7 +1454,7 @@ def _partition_query_helper(
                 ("x-goog-spanner-route-to-leader", "true"),
                 (
                     "x-goog-spanner-request-id",
-                    f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1",
+                    f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1",
                 ),
             ],
             retry=retry,
@@ -1906,10 +1911,18 @@ def test_begin_ok_exact_strong(self):
 
 
 class _Client(object):
+    NTH_CLIENT = AtomicCounter()
+
     def __init__(self):
         from google.cloud.spanner_v1 import ExecuteSqlRequest
 
         self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1")
+        self._nth_client_id = _Client.NTH_CLIENT.increment()
+        self._nth_request = AtomicCounter()
+
+    @property
+    def _next_nth_request(self):
+        return self._nth_request.increment()
 
 
 class _Instance(object):