From 8a6e70a41127d5e10460b036fe1844935387f21d Mon Sep 17 00:00:00 2001 From: vincbeck Date: Mon, 3 Nov 2025 12:36:08 -0500 Subject: [PATCH] Fix mypy static errors in databricks provider --- .../databricks/hooks/databricks_base.py | 25 ++++++++++++------- .../databricks/hooks/databricks_sql.py | 4 +-- .../databricks/plugins/databricks_workflow.py | 2 +- .../plugins/test_databricks_workflow.py | 2 +- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py index 919e21c3287ca..6415740d90e75 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py @@ -166,12 +166,12 @@ def user_agent_value(self) -> str: return ua_string @cached_property - def host(self) -> str: + def host(self) -> str | None: + host = None if "host" in self.databricks_conn.extra_dejson: host = self._parse_host(self.databricks_conn.extra_dejson["host"]) - else: + elif self.databricks_conn.host: host = self._parse_host(self.databricks_conn.host) - return host async def __aenter__(self): @@ -207,6 +207,11 @@ def _parse_host(host: str) -> str: # In this case, host = xx.cloud.databricks.com return host + def _get_connection_attr(self, attr_name: str) -> str: + if not (attr := getattr(self.databricks_conn, attr_name)): + raise ValueError(f"`{attr_name}` must be present in Connection") + return attr + def _get_retry_object(self) -> Retrying: """ Instantiate a retry object. @@ -235,7 +240,7 @@ def _get_sp_token(self, resource: str) -> str: with attempt: resp = requests.post( resource, - auth=HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password), + auth=HTTPBasicAuth(self._get_connection_attr("login"), self.databricks_conn.password), data="grant_type=client_credentials&scope=all-apis", headers={ **self.user_agent_header, @@ -271,7 +276,9 @@ async def _a_get_sp_token(self, resource: str) -> str: with attempt: async with self._session.post( resource, - auth=aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password), + auth=aiohttp.BasicAuth( + self._get_connection_attr("login"), self.databricks_conn.password + ), data="grant_type=client_credentials&scope=all-apis", headers={ **self.user_agent_header, @@ -316,7 +323,7 @@ def _get_aad_token(self, resource: str) -> str: token = ManagedIdentityCredential().get_token(f"{resource}/.default") else: credential = ClientSecretCredential( - client_id=self.databricks_conn.login, + client_id=self._get_connection_attr("login"), client_secret=self.databricks_conn.password, tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"], ) @@ -364,7 +371,7 @@ async def _a_get_aad_token(self, resource: str) -> str: token = await credential.get_token(f"{resource}/.default") else: async with AsyncClientSecretCredential( - client_id=self.databricks_conn.login, + client_id=self._get_connection_attr("login"), client_secret=self.databricks_conn.password, tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"], ) as credential: @@ -678,7 +685,7 @@ def _do_api_call( auth = _TokenAuth(token) else: self.log.info("Using basic auth.") - auth = HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password) + auth = HTTPBasicAuth(self._get_connection_attr("login"), self.databricks_conn.password) request_func: Any if method == "GET": @@ -745,7 +752,7 @@ async def _a_do_api_call(self, endpoint_info: tuple[str, str], json: dict[str, A auth = BearerAuth(token) else: self.log.info("Using basic auth.") - auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password) + auth = aiohttp.BasicAuth(self._get_connection_attr("login"), self.databricks_conn.password) request_func: Any if method == "GET": diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py index d86d0453fa592..f7619bfbb2ef8 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py @@ -184,13 +184,13 @@ def sqlalchemy_url(self) -> URL: "catalog": self.catalog, "schema": self.schema, } - url_query = {k: v for k, v in url_query.items() if v is not None} + url_query_formatted: dict[str, str] = {k: v for k, v in url_query.items() if v is not None} return URL.create( drivername="databricks", username="token", password=self._get_token(raise_error=True), host=self.host, - query=url_query, + query=url_query_formatted, ) def get_uri(self) -> str: diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py index 8513ac27f1cfb..c53739634eeb7 100644 --- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py @@ -143,7 +143,7 @@ def _get_dagrun(dag, run_id: str, session: Session) -> DagRun: if not session: raise AirflowException("Session not provided.") - return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first() + return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).one() @provide_session def _clear_task_instances( diff --git a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py index c41f02b9239ab..1bb7974df8f67 100644 --- a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py @@ -82,7 +82,7 @@ def test_get_dagrun_airflow2(): session = MagicMock() dag = MagicMock(dag_id=DAG_ID) - session.query.return_value.filter.return_value.first.return_value = DagRun() + session.query.return_value.filter.return_value.one.return_value = DagRun() result = _get_dagrun(dag, RUN_ID, session=session)