Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,12 @@ def user_agent_value(self) -> str:
return ua_string

@cached_property
def host(self) -> str:
def host(self) -> str | None:
host = None
if "host" in self.databricks_conn.extra_dejson:
host = self._parse_host(self.databricks_conn.extra_dejson["host"])
else:
elif self.databricks_conn.host:
host = self._parse_host(self.databricks_conn.host)

return host

async def __aenter__(self):
Expand Down Expand Up @@ -207,6 +207,11 @@ def _parse_host(host: str) -> str:
# In this case, host = xx.cloud.databricks.com
return host

def _get_connection_attr(self, attr_name: str) -> str:
if not (attr := getattr(self.databricks_conn, attr_name)):
raise ValueError(f"`{attr_name}` must be present in Connection")
return attr

def _get_retry_object(self) -> Retrying:
"""
Instantiate a retry object.
Expand Down Expand Up @@ -235,7 +240,7 @@ def _get_sp_token(self, resource: str) -> str:
with attempt:
resp = requests.post(
resource,
auth=HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password),
auth=HTTPBasicAuth(self._get_connection_attr("login"), self.databricks_conn.password),
data="grant_type=client_credentials&scope=all-apis",
headers={
**self.user_agent_header,
Expand Down Expand Up @@ -271,7 +276,9 @@ async def _a_get_sp_token(self, resource: str) -> str:
with attempt:
async with self._session.post(
resource,
auth=aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password),
auth=aiohttp.BasicAuth(
self._get_connection_attr("login"), self.databricks_conn.password
),
data="grant_type=client_credentials&scope=all-apis",
headers={
**self.user_agent_header,
Expand Down Expand Up @@ -316,7 +323,7 @@ def _get_aad_token(self, resource: str) -> str:
token = ManagedIdentityCredential().get_token(f"{resource}/.default")
else:
credential = ClientSecretCredential(
client_id=self.databricks_conn.login,
client_id=self._get_connection_attr("login"),
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
)
Expand Down Expand Up @@ -364,7 +371,7 @@ async def _a_get_aad_token(self, resource: str) -> str:
token = await credential.get_token(f"{resource}/.default")
else:
async with AsyncClientSecretCredential(
client_id=self.databricks_conn.login,
client_id=self._get_connection_attr("login"),
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
) as credential:
Expand Down Expand Up @@ -678,7 +685,7 @@ def _do_api_call(
auth = _TokenAuth(token)
else:
self.log.info("Using basic auth.")
auth = HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password)
auth = HTTPBasicAuth(self._get_connection_attr("login"), self.databricks_conn.password)

request_func: Any
if method == "GET":
Expand Down Expand Up @@ -745,7 +752,7 @@ async def _a_do_api_call(self, endpoint_info: tuple[str, str], json: dict[str, A
auth = BearerAuth(token)
else:
self.log.info("Using basic auth.")
auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)
auth = aiohttp.BasicAuth(self._get_connection_attr("login"), self.databricks_conn.password)

request_func: Any
if method == "GET":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,13 @@ def sqlalchemy_url(self) -> URL:
"catalog": self.catalog,
"schema": self.schema,
}
url_query = {k: v for k, v in url_query.items() if v is not None}
url_query_formatted: dict[str, str] = {k: v for k, v in url_query.items() if v is not None}
return URL.create(
drivername="databricks",
username="token",
password=self._get_token(raise_error=True),
host=self.host,
query=url_query,
query=url_query_formatted,
)

def get_uri(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _get_dagrun(dag, run_id: str, session: Session) -> DagRun:
if not session:
raise AirflowException("Session not provided.")

return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first()
return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).one()

@provide_session
def _clear_task_instances(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_get_dagrun_airflow2():

session = MagicMock()
dag = MagicMock(dag_id=DAG_ID)
session.query.return_value.filter.return_value.first.return_value = DagRun()
session.query.return_value.filter.return_value.one.return_value = DagRun()

result = _get_dagrun(dag, RUN_ID, session=session)

Expand Down