diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index a3ff23bbdea07..8628c592584be 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -487,7 +487,7 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection: return conn except AirflowRuntimeError as e: if e.error.error == ErrorType.CONNECTION_NOT_FOUND: - raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") + raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") from None raise # check cache first diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 9b9afd69cf09a..571e4ce42a886 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -354,7 +354,7 @@ def dag_in_a_fn(): assert result is not None assert result.import_errors != {} if result.import_errors: - assert "CONNECTION_NOT_FOUND" in next(iter(result.import_errors.values())) + assert "The conn_id `my_conn` isn't defined" in next(iter(result.import_errors.values())) def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path, inprocess_client): tmp_path.joinpath("util.py").write_text("NAME = 'dag_name'") diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py index 89344cf967d3b..e7918b2f07055 100644 --- a/task-sdk/src/airflow/sdk/definitions/connection.py +++ b/task-sdk/src/airflow/sdk/definitions/connection.py @@ -24,7 +24,8 @@ import attrs -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowNotFoundException +from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType log = logging.getLogger(__name__) @@ -149,7 +150,12 @@ def get_hook(self, *, hook_params=None): def get(cls, conn_id: str) -> Any: from airflow.sdk.execution_time.context import _get_connection - return _get_connection(conn_id) + try: + return _get_connection(conn_id) + except AirflowRuntimeError as e: + if e.error.error == ErrorType.CONNECTION_NOT_FOUND: + raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") from None + raise @property def extra_dejson(self) -> dict: diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index c76994995ebab..251d6861d63ca 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -272,7 +272,9 @@ class ConnectionAccessor: """Wrapper to access Connection entries in template.""" def __getattr__(self, conn_id: str) -> Any: - return _get_connection(conn_id) + from airflow.sdk.definitions.connection import Connection + + return Connection.get(conn_id) def __repr__(self) -> str: return "" diff --git a/task-sdk/tests/task_sdk/bases/test_hook.py b/task-sdk/tests/task_sdk/bases/test_hook.py index f17ce8a12fcbb..63fa85395150a 100644 --- a/task-sdk/tests/task_sdk/bases/test_hook.py +++ b/task-sdk/tests/task_sdk/bases/test_hook.py @@ -19,8 +19,9 @@ import pytest +from airflow.exceptions import AirflowNotFoundException from airflow.sdk import BaseHook -from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, GetConnection from tests_common.test_utils.config import conf_vars @@ -64,7 +65,7 @@ def test_get_connection_not_found(self, mock_supervisor_comms): hook = BaseHook() mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) - with pytest.raises(AirflowRuntimeError, match="CONNECTION_NOT_FOUND"): + with pytest.raises(AirflowNotFoundException, match="The conn_id `test_conn` isn't defined"): hook.get_connection(conn_id=conn_id) def test_get_connection_secrets_backend_configured(self, mock_supervisor_comms, tmp_path): diff --git a/task-sdk/tests/task_sdk/definitions/test_connections.py b/task-sdk/tests/task_sdk/definitions/test_connections.py index 3bbb63a769788..6e4d977c6591b 100644 --- a/task-sdk/tests/task_sdk/definitions/test_connections.py +++ b/task-sdk/tests/task_sdk/definitions/test_connections.py @@ -23,9 +23,10 @@ import pytest from airflow.configuration import initialize_secrets_backends -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.sdk import Connection -from airflow.sdk.execution_time.comms import ConnectionResult +from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS from tests_common.test_utils.config import conf_vars @@ -121,6 +122,13 @@ def test_conn_get(self, mock_supervisor_comms): extra=None, ) + def test_conn_get_not_found(self, mock_supervisor_comms): + error_response = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + mock_supervisor_comms.send.return_value = error_response + + with pytest.raises(AirflowNotFoundException, match="The conn_id `mysql_conn` isn't defined"): + _ = Connection.get(conn_id="mysql_conn") + class TestConnectionsFromSecrets: def test_get_connection_secrets_backend(self, mock_supervisor_comms, tmp_path):