diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 index 9688ca7501..1d14fbd642 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 @@ -369,6 +369,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): method=rpc, request=request, response=response, + metadata=metadata, ) {%- endif %} {%- if not method.void %} diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/pagers.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/pagers.py.j2 index 0e7ef018a7..da4dc87581 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/pagers.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/pagers.py.j2 @@ -6,7 +6,7 @@ {# This lives within the loop in order to ensure that this template is empty if there are no paged methods. -#} -from typing import Any, Callable, Iterable +from typing import Any, Callable, Iterable, Sequence, Tuple {% filter sort_lines -%} {% for method in service.methods.values() | selectattr('paged_result_field') -%} @@ -35,10 +35,10 @@ class {{ method.name }}Pager: the most recent response is retained, and thus used for attribute lookup. """ def __init__(self, - method: Callable[[{{ method.input.ident }}], - {{ method.output.ident }}], + method: Callable[..., {{ method.output.ident }}], request: {{ method.input.ident }}, - response: {{ method.output.ident }}): + response: {{ method.output.ident }}, + metadata: Sequence[Tuple[str, str]] = ())): """Instantiate the pager. Args: @@ -48,10 +48,13 @@ class {{ method.name }}Pager: The initial request object. response (:class:`{{ method.output.ident.sphinx }}`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = {{ method.input.ident }}(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -61,7 +64,7 @@ class {{ method.name }}Pager: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}: diff --git a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 7a83fd1122..e4583db132 100644 --- a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -24,6 +24,9 @@ from google.api_core import future from google.api_core import operations_v1 from google.longrunning import operations_pb2 {% endif -%} +{% if service.has_pagers -%} +from google.api_core import gapic_v1 +{% endif -%} {% for method in service.methods.values() -%} {% for ref_type in method.ref_types if not ((ref_type.ident.python_import.package == ('google', 'api_core') and ref_type.ident.python_import.module == 'operation') @@ -442,9 +445,24 @@ def test_{{ method.name|snake_case }}_pager(): ), RuntimeError, ) - results = [i for i in client.{{ method.name|snake_case }}( - request={}, - )] + + metadata = () + {% if method.field_headers -%} + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + {%- for field_header in method.field_headers %} + {%- if not method.client_streaming %} + ('{{ field_header }}', ''), + {%- endif %} + {%- endfor %} + )), + ) + {% endif -%} + pager = client.{{ method.name|snake_case }}(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] assert len(results) == 6 assert all(isinstance(i, {{ method.paged_result_field.message.ident }}) for i in results)