From 34beb4a395d0386be33d1ef9f2f3717a58658186 Mon Sep 17 00:00:00 2001 From: Josh Fell Date: Thu, 9 Dec 2021 15:14:01 -0500 Subject: [PATCH 1/2] Remove db call from DatabricksHook.__init__ --- airflow/providers/databricks/hooks/databricks.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index ac8d9511e0676..5b120cfff9fc7 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -135,11 +135,7 @@ def __init__( ) -> None: super().__init__() self.databricks_conn_id = databricks_conn_id - self.databricks_conn = self.get_connection(databricks_conn_id) - if 'host' in self.databricks_conn.extra_dejson: - self.host = self._parse_host(self.databricks_conn.extra_dejson['host']) - else: - self.host = self._parse_host(self.databricks_conn.host) + self.databricks_conn = None self.timeout_seconds = timeout_seconds if retry_limit < 1: raise ValueError('Retry limit must be greater than equal to 1') @@ -303,6 +299,14 @@ def _do_api_call(self, endpoint_info, json): :rtype: dict """ method, endpoint = endpoint_info + + self.databricks_conn = self.get_connection(self.databricks_conn_id) + + if 'host' in self.databricks_conn.extra_dejson: + self.host = self._parse_host(self.databricks_conn.extra_dejson['host']) + else: + self.host = self._parse_host(self.databricks_conn.host) + url = f'https://{self.host}/{endpoint}' aad_headers = self._get_aad_headers() From cdb54db55fd51c79113e48f1d21b5582b24ba9f6 Mon Sep 17 00:00:00 2001 From: Josh Fell <48934154+josh-fell@users.noreply.github.com> Date: Thu, 9 Dec 2021 16:17:51 -0500 Subject: [PATCH 2/2] Add check before calling get_connection() Co-authored-by: Ash Berlin-Taylor --- airflow/providers/databricks/hooks/databricks.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 5b120cfff9fc7..e56b15a99b33f 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -300,12 +300,13 @@ def _do_api_call(self, endpoint_info, json): """ method, endpoint = endpoint_info - self.databricks_conn = self.get_connection(self.databricks_conn_id) + if self.databricks_conn is None: + self.databricks_conn = self.get_connection(self.databricks_conn_id) - if 'host' in self.databricks_conn.extra_dejson: - self.host = self._parse_host(self.databricks_conn.extra_dejson['host']) - else: - self.host = self._parse_host(self.databricks_conn.host) + if 'host' in self.databricks_conn.extra_dejson: + self.host = self._parse_host(self.databricks_conn.extra_dejson['host']) + else: + self.host = self._parse_host(self.databricks_conn.host) url = f'https://{self.host}/{endpoint}'