diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index d9e4e893b1d0e..b160a2e2b2cff 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -642,8 +642,10 @@ def __init__( *, project: str | None = None, session: ClientSession | None = None, + scopes: Sequence[str] | None = None, ) -> None: - super().__init__(session=cast(Session, session)) + _scopes: list[str] | None = list(scopes) if scopes else None + super().__init__(session=cast(Session, session), scopes=_scopes) self.credentials = credentials self.project = project @@ -659,6 +661,7 @@ async def from_hook( credentials=credentials, project=project, session=session, + scopes=hook.scopes, ) async def get_project(self) -> str | None: diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py index f4b71d7449ffe..fd53930c5c1ff 100644 --- a/tests/providers/google/common/hooks/test_base_google.py +++ b/tests/providers/google/common/hooks/test_base_google.py @@ -50,6 +50,7 @@ MODULE_NAME = "airflow.providers.google.common.hooks.base_google" PROJECT_ID = "PROJECT_ID" ENV_VALUE = "/tmp/a" +SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] class NoForbiddenAfterCount: @@ -881,14 +882,14 @@ class TestCredentialsToken: @pytest.mark.asyncio async def test_get_project(self): mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) - token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID) + token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID, scopes=SCOPES) assert await token.get_project() == PROJECT_ID @pytest.mark.asyncio async def test_get(self): mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) mock_credentials.token = "ACCESS_TOKEN" - token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID) + token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID, scopes=SCOPES) assert await token.get() == "ACCESS_TOKEN" mock_credentials.refresh.assert_called_once()