From 6a0e6abcb919df3ec59c58dcb1ebf784e6177eef Mon Sep 17 00:00:00 2001 From: Cedrik Neumann Date: Wed, 17 Jan 2024 16:11:35 +0100 Subject: [PATCH 1/5] feat: full support for google credentials in gcloud-aio clients The class CredentialsToken implements the ability to generate access tokens to be used in gcloud-aio clients from Google credentials objects provided by instances of Google Cloud hooks. With this change we provide all credentials based capabilities of Google Cloud hooks (for exmaple impersonation) to gcloud-aio clients. --- .../providers/google/cloud/hooks/bigquery.py | 25 ++++--- airflow/providers/google/cloud/hooks/gcs.py | 7 +- .../google/common/hooks/base_google.py | 55 +++++++++++++++- .../google/cloud/hooks/test_bigquery.py | 12 +++- .../google/common/hooks/test_base_google.py | 65 +++++++++++++++++++ 5 files changed, 148 insertions(+), 16 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 3ad29c66f05cf..bb9e75e26e13a 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -3268,8 +3268,13 @@ async def get_job_instance( self, project_id: str | None, job_id: str | None, session: ClientSession ) -> Job: """Get the specified job resource by job ID and project ID.""" - with await self.service_file_as_context() as f: - return Job(job_id=job_id, project=project_id, service_file=f, session=cast(Session, session)) + token = await self.get_token(session=session) + return Job( + job_id=job_id, + project=project_id, + token=token, + session=cast(Session, session), + ) async def get_job_status(self, job_id: str | None, project_id: str | None = None) -> dict[str, str]: async with ClientSession() as s: @@ -3513,11 +3518,11 @@ async def get_table_client( access to the specified project. :param session: aiohttp ClientSession """ - with await self.service_file_as_context() as file: - return Table_async( - dataset_name=dataset, - table_name=table_id, - project=project_id, - service_file=file, - session=cast(Session, session), - ) + token = await self.get_token(session=session) + return Table_async( + dataset_name=dataset, + table_name=table_id, + project=project_id, + token=token, + session=cast(Session, session), + ) diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index 02055583ce15a..42c876fbb7968 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -1391,5 +1391,8 @@ class GCSAsyncHook(GoogleBaseAsyncHook): async def get_storage_client(self, session: ClientSession) -> Storage: """Returns a Google Cloud Storage service object.""" - with await self.service_file_as_context() as file: - return Storage(service_file=file, session=cast(Session, session)) + token = await self.get_token(session=session) + return Storage( + token=token, + session=cast(Session, session), + ) diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 99120820e598c..71eaa9938100a 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -18,6 +18,7 @@ """This module contains a Google Cloud API base hook.""" from __future__ import annotations +import datetime import functools import json import logging @@ -35,6 +36,7 @@ import requests import tenacity from asgiref.sync import sync_to_async +from gcloud.aio.auth.token import Token from google.api_core.exceptions import Forbidden, ResourceExhausted, TooManyRequests from google.auth import _cloud_sdk, compute_engine # type: ignore[attr-defined] from google.auth.environment_vars import CLOUD_SDK_CONFIG_DIR, CREDENTIALS @@ -43,6 +45,7 @@ from googleapiclient import discovery from googleapiclient.errors import HttpError from googleapiclient.http import MediaIoBaseDownload, build_http, set_user_agent +from requests import Session from airflow import version from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning @@ -56,7 +59,9 @@ from airflow.utils.process_utils import patch_environ if TYPE_CHECKING: + from aiohttp import ClientSession from google.api_core.gapic_v1.client_info import ClientInfo + from google.auth.credentials import Credentials log = logging.getLogger(__name__) @@ -623,6 +628,51 @@ def test_connection(self): return status, message +class CredentialsToken(Token): + """A token implementation which makes Google credentials objects accessible to [gcloud-aio](https://talkiq.github.io/gcloud-aio/) clients. + + This class allows us to create token instances from credentials objects and thus supports a variety of use cases for Google + credentials in Airflow (i.e. impersonation chain). By relying on a existing credentials object we leverage functionality provided by the GoogleBaseHook + for generating credentials objects. + """ + + def __init__( + self, + credentials: Credentials, + *, + project: str | None = None, + session: ClientSession | None = None, + ) -> None: + super().__init__(session=cast(Session, session)) + self.credentials = credentials + self.project = project + + @classmethod + async def from_hook( + cls, + hook: GoogleBaseHook, + *, + session: ClientSession | None = None, + ) -> CredentialsToken: + credentials, project = hook.get_credentials_and_project_id() + return cls( + credentials=credentials, + project=project, + session=session, + ) + + async def get_project(self) -> str | None: + return self.project + + async def acquire_access_token(self, timeout: int = 10) -> None: + await sync_to_async(self.credentials.refresh)(google.auth.transport.requests.Request()) + + self.access_token = cast(str, self.credentials.token) + self.access_token_duration = 3600 + self.access_token_acquired_at = datetime.datetime.utcnow() + self.acquiring = None + + class GoogleBaseAsyncHook(BaseHook): """GoogleBaseAsyncHook inherits from BaseHook class, run on the trigger worker.""" @@ -639,6 +689,7 @@ async def get_sync_hook(self) -> Any: self._sync_hook = await sync_to_async(self.sync_hook_class)(**self._hook_kwargs) return self._sync_hook - async def service_file_as_context(self) -> Any: + async def get_token(self, *, session: ClientSession | None = None) -> CredentialsToken: + """Returns a Token instance for use in [gcloud-aio](https://talkiq.github.io/gcloud-aio/) clients.""" sync_hook = await self.get_sync_hook() - return await sync_to_async(sync_hook.provide_gcp_credential_file_as_context)() + return await CredentialsToken.from_hook(sync_hook, session=session) diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index 47fc20464749e..730fa6734ad9e 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -22,6 +22,7 @@ from unittest import mock from unittest.mock import AsyncMock +import google.auth import pytest from gcloud.aio.bigquery import Job, Table as Table_async from google.api_core import page_iterator @@ -2143,8 +2144,12 @@ def get_credentials_and_project_id(self): class TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass): @pytest.mark.db_test @pytest.mark.asyncio + @mock.patch("google.auth.default") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession") - async def test_get_job_instance(self, mock_session): + async def test_get_job_instance(self, mock_session, mock_auth_default): + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_credentials.token = "ACCESS_TOKEN" + mock_auth_default.return_value = (mock_credentials, PROJECT_ID) hook = BigQueryAsyncHook() result = await hook.get_job_instance(project_id=PROJECT_ID, job_id=JOB_ID, session=mock_session) assert isinstance(result, Job) @@ -2315,10 +2320,13 @@ def test_convert_to_float_if_possible(self, test_input, expected): @pytest.mark.db_test @pytest.mark.asyncio + @mock.patch("google.auth.default") @mock.patch("aiohttp.client.ClientSession") - async def test_get_table_client(self, mock_session): + async def test_get_table_client(self, mock_session, mock_auth_default): """Test get_table_client async function and check whether the return value is a Table instance object""" + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_auth_default.return_value = (mock_credentials, PROJECT_ID) hook = BigQueryTableAsyncHook() result = await hook.get_table_client( dataset=DATASET_ID, project_id=PROJECT_ID, table_id=TABLE_ID, session=mock_session diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py index bd4342ec66d78..c50062fc40ed9 100644 --- a/tests/providers/google/common/hooks/test_base_google.py +++ b/tests/providers/google/common/hooks/test_base_google.py @@ -26,6 +26,7 @@ from unittest.mock import patch import google.auth +import google.auth.compute_engine import pytest import tenacity from google.auth.environment_vars import CREDENTIALS @@ -874,3 +875,67 @@ def test_should_fallback_when_empty_string_in_env_var(self): instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default") assert isinstance(instance.num_retries, int) assert 5 == instance.num_retries + + +class TestGoogleBaseAsyncHook: + @pytest.mark.asyncio + @mock.patch("google.auth.default") + async def test_get_token(self, mock_auth_default, monkeypatch) -> None: + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_credentials.token = "ACCESS_TOKEN" + mock_auth_default.return_value = (mock_credentials, "PROJECT_ID") + monkeypatch.setenv( + "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT", + "google-cloud-platform://?project=CONN_PROJECT_ID", + ) + + instance = hook.GoogleBaseAsyncHook(gcp_conn_id="google_cloud_default") + instance.sync_hook_class = hook.GoogleBaseHook + token = await instance.get_token() + assert await token.get_project() == "CONN_PROJECT_ID" + assert await token.get() == "ACCESS_TOKEN" + mock_credentials.refresh.assert_called_once() + + @pytest.mark.asyncio + @mock.patch("google.auth.default") + async def test_get_token_impersonation(self, mock_auth_default, monkeypatch, requests_mock) -> None: + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_credentials.token = "ACCESS_TOKEN" + mock_auth_default.return_value = (mock_credentials, "PROJECT_ID") + monkeypatch.setenv( + "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT", + "google-cloud-platform://?project=CONN_PROJECT_ID", + ) + requests_mock.post( + "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com:generateAccessToken", + text='{"accessToken": "IMPERSONATED_ACCESS_TOKEN", "expireTime": "2014-10-02T15:01:23Z"}', + ) + + instance = hook.GoogleBaseAsyncHook( + gcp_conn_id="google_cloud_default", + impersonation_chain="SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com", + ) + instance.sync_hook_class = hook.GoogleBaseHook + token = await instance.get_token() + assert await token.get_project() == "CONN_PROJECT_ID" + assert await token.get() == "IMPERSONATED_ACCESS_TOKEN" + + @pytest.mark.asyncio + @mock.patch("google.auth.default") + async def test_get_token_impersonation_conn(self, mock_auth_default, monkeypatch, requests_mock) -> None: + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_auth_default.return_value = (mock_credentials, "PROJECT_ID") + monkeypatch.setenv( + "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT", + "google-cloud-platform://?project=CONN_PROJECT_ID&impersonation_chain=SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com", + ) + requests_mock.post( + "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com:generateAccessToken", + text='{"accessToken": "IMPERSONATED_ACCESS_TOKEN", "expireTime": "2014-10-02T15:01:23Z"}', + ) + + instance = hook.GoogleBaseAsyncHook(gcp_conn_id="google_cloud_default") + instance.sync_hook_class = hook.GoogleBaseHook + token = await instance.get_token() + assert await token.get_project() == "CONN_PROJECT_ID" + assert await token.get() == "IMPERSONATED_ACCESS_TOKEN" From 4043ba2e9c68093bc591bdcb4608a31919e63d2d Mon Sep 17 00:00:00 2001 From: Cedrik Neumann Date: Thu, 18 Jan 2024 09:16:51 +0100 Subject: [PATCH 2/5] test: add tests for CredentialsToken --- .../google/common/hooks/test_base_google.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py index c50062fc40ed9..5315aac3cb710 100644 --- a/tests/providers/google/common/hooks/test_base_google.py +++ b/tests/providers/google/common/hooks/test_base_google.py @@ -877,6 +877,34 @@ def test_should_fallback_when_empty_string_in_env_var(self): assert 5 == instance.num_retries +class TestCredentialsToken: + @pytest.mark.asyncio + async def test_get_project(self): + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + token = hook.CredentialsToken(mock_credentials, project=PROJECT_ID) + assert await token.get_project() == PROJECT_ID + + @pytest.mark.asyncio + async def test_get(self): + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_credentials.token = "ACCESS_TOKEN" + token = hook.CredentialsToken(mock_credentials, project=PROJECT_ID) + assert await token.get() == "ACCESS_TOKEN" + mock_credentials.refresh.assert_called_once() + + @pytest.mark.asyncio + @mock.patch(MODULE_NAME + ".get_credentials_and_project_id", return_value=("CREDENTIALS", "PROJECT_ID")) + async def test_from_hook(self, get_creds_and_project, monkeypatch): + monkeypatch.setenv( + "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT", + "google-cloud-platform://", + ) + instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default") + token = await hook.CredentialsToken.from_hook(instance) + assert token.credentials == "CREDENTIALS" + assert token.project == "PROJECT_ID" + + class TestGoogleBaseAsyncHook: @pytest.mark.asyncio @mock.patch("google.auth.default") From a8d3a26eef513839179749b82941cc64d3dd39d2 Mon Sep 17 00:00:00 2001 From: Cedrik Neumann <7921017+m1racoli@users.noreply.github.com> Date: Fri, 19 Jan 2024 09:05:37 +0100 Subject: [PATCH 3/5] Update tests/providers/google/common/hooks/test_base_google.py Co-authored-by: Wei Lee --- tests/providers/google/common/hooks/test_base_google.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py index 5315aac3cb710..16ff061282f34 100644 --- a/tests/providers/google/common/hooks/test_base_google.py +++ b/tests/providers/google/common/hooks/test_base_google.py @@ -893,7 +893,7 @@ async def test_get(self): mock_credentials.refresh.assert_called_once() @pytest.mark.asyncio - @mock.patch(MODULE_NAME + ".get_credentials_and_project_id", return_value=("CREDENTIALS", "PROJECT_ID")) + @mock.patch(f"{MODULE_NAME}.get_credentials_and_project_id", return_value=("CREDENTIALS", "PROJECT_ID")) async def test_from_hook(self, get_creds_and_project, monkeypatch): monkeypatch.setenv( "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT", From 9a797319809153e8e2937f6f1ff8469859091555 Mon Sep 17 00:00:00 2001 From: Cedrik Neumann Date: Tue, 23 Jan 2024 11:42:31 +0100 Subject: [PATCH 4/5] refactor: make CredentialsToken private This class is only intended to be used within the Google provider and might need to change in the future. Making it private in order to avoid a potential breaking change in the future. --- airflow/providers/google/common/hooks/base_google.py | 8 ++++---- tests/providers/google/common/hooks/test_base_google.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 71eaa9938100a..439684368cf13 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -628,7 +628,7 @@ def test_connection(self): return status, message -class CredentialsToken(Token): +class _CredentialsToken(Token): """A token implementation which makes Google credentials objects accessible to [gcloud-aio](https://talkiq.github.io/gcloud-aio/) clients. This class allows us to create token instances from credentials objects and thus supports a variety of use cases for Google @@ -653,7 +653,7 @@ async def from_hook( hook: GoogleBaseHook, *, session: ClientSession | None = None, - ) -> CredentialsToken: + ) -> _CredentialsToken: credentials, project = hook.get_credentials_and_project_id() return cls( credentials=credentials, @@ -689,7 +689,7 @@ async def get_sync_hook(self) -> Any: self._sync_hook = await sync_to_async(self.sync_hook_class)(**self._hook_kwargs) return self._sync_hook - async def get_token(self, *, session: ClientSession | None = None) -> CredentialsToken: + async def get_token(self, *, session: ClientSession | None = None) -> _CredentialsToken: """Returns a Token instance for use in [gcloud-aio](https://talkiq.github.io/gcloud-aio/) clients.""" sync_hook = await self.get_sync_hook() - return await CredentialsToken.from_hook(sync_hook, session=session) + return await _CredentialsToken.from_hook(sync_hook, session=session) diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py index 16ff061282f34..f4b71d7449ffe 100644 --- a/tests/providers/google/common/hooks/test_base_google.py +++ b/tests/providers/google/common/hooks/test_base_google.py @@ -881,14 +881,14 @@ class TestCredentialsToken: @pytest.mark.asyncio async def test_get_project(self): mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) - token = hook.CredentialsToken(mock_credentials, project=PROJECT_ID) + token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID) assert await token.get_project() == PROJECT_ID @pytest.mark.asyncio async def test_get(self): mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) mock_credentials.token = "ACCESS_TOKEN" - token = hook.CredentialsToken(mock_credentials, project=PROJECT_ID) + token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID) assert await token.get() == "ACCESS_TOKEN" mock_credentials.refresh.assert_called_once() @@ -900,7 +900,7 @@ async def test_from_hook(self, get_creds_and_project, monkeypatch): "google-cloud-platform://", ) instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default") - token = await hook.CredentialsToken.from_hook(instance) + token = await hook._CredentialsToken.from_hook(instance) assert token.credentials == "CREDENTIALS" assert token.project == "PROJECT_ID" From 524540077694c861c64e8928fa7a9103b614b9bd Mon Sep 17 00:00:00 2001 From: Cedrik Neumann Date: Tue, 23 Jan 2024 11:49:41 +0100 Subject: [PATCH 5/5] Revert removal of service_file_as_context The method `service_file_as_context` not being used anymore in the airflow, but it is public and removing would imply a breaking changes for users for the Google provider. Therefore we keep it. --- airflow/providers/google/common/hooks/base_google.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 439684368cf13..d9e4e893b1d0e 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -693,3 +693,8 @@ async def get_token(self, *, session: ClientSession | None = None) -> _Credentia """Returns a Token instance for use in [gcloud-aio](https://talkiq.github.io/gcloud-aio/) clients.""" sync_hook = await self.get_sync_hook() return await _CredentialsToken.from_hook(sync_hook, session=session) + + async def service_file_as_context(self) -> Any: + """This is the async equivalent of the non-async GoogleBaseHook's `provide_gcp_credential_file_as_context` method.""" + sync_hook = await self.get_sync_hook() + return await sync_to_async(sync_hook.provide_gcp_credential_file_as_context)()