Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,17 @@ def get_run_output(self, run_id: int) -> dict:
run_output = self._do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json)
return run_output

async def a_get_run_output(self, run_id: int) -> dict:
    """
    Async counterpart of :meth:`get_run_output`.

    Fetches the output of a single task run from the Databricks
    ``get-output`` endpoint.

    :param run_id: id of the run
    :return: output of the run
    """
    payload = {"run_id": run_id}
    return await self._a_do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, payload)

def cancel_run(self, run_id: int) -> None:
"""
Cancel the run.
Expand Down
11 changes: 6 additions & 5 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,9 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:

if run_state.result_state == "FAILED":
task_run_id = None
if "tasks" in run_info:
for task in run_info["tasks"]:
if task.get("state", {}).get("result_state", "") == "FAILED":
task_run_id = task["run_id"]
for task in run_info.get("tasks", []):
if task.get("state", {}).get("result_state", "") == "FAILED":
task_run_id = task["run_id"]
if task_run_id is not None:
run_output = hook.get_run_output(task_run_id)
if "error" in run_output:
Expand Down Expand Up @@ -160,13 +159,15 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
validate_trigger_event(event)
run_state = RunState.from_json(event["run_state"])
run_page_url = event["run_page_url"]
errors = event["errors"]
log.info("View run status, Spark UI, and logs at %s", run_page_url)

if run_state.is_successful:
log.info("Job run completed successfully.")
return

error_message = f"Job run failed with terminal state: {run_state}"
error_message = f"Job run failed with terminal state: {run_state} and with the errors {errors}"

if event["repair_run"]:
log.warning(
"%s but since repair run is set, repairing the run with all failed tasks",
Expand Down
45 changes: 30 additions & 15 deletions airflow/providers/databricks/triggers/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,36 @@ async def run(self):
async with self.hook:
while True:
run_state = await self.hook.a_get_run_state(self.run_id)
if run_state.is_terminal:
yield TriggerEvent(
{
"run_id": self.run_id,
"run_page_url": self.run_page_url,
"run_state": run_state.to_json(),
"repair_run": self.repair_run,
}
if not run_state.is_terminal:
self.log.info(
"run-id %s in run state %s. sleeping for %s seconds",
self.run_id,
run_state,
self.polling_period_seconds,
)
return
await asyncio.sleep(self.polling_period_seconds)
continue

self.log.info(
"run-id %s in run state %s. sleeping for %s seconds",
self.run_id,
run_state,
self.polling_period_seconds,
failed_tasks = []
if run_state.result_state == "FAILED":
run_info = await self.hook.a_get_run(self.run_id)
for task in run_info.get("tasks", []):
if task.get("state", {}).get("result_state", "") == "FAILED":
task_run_id = task["run_id"]
task_key = task["task_key"]
run_output = await self.hook.a_get_run_output(task_run_id)
if "error" in run_output:
error = run_output["error"]
else:
error = run_state.state_message
failed_tasks.append({"task_key": task_key, "run_id": task_run_id, "error": error})
yield TriggerEvent(
{
"run_id": self.run_id,
"run_page_url": self.run_page_url,
"run_state": run_state.to_json(),
"repair_run": self.repair_run,
"errors": failed_tasks,
}
)
await asyncio.sleep(self.polling_period_seconds)
return
2 changes: 1 addition & 1 deletion airflow/providers/databricks/utils/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def validate_trigger_event(event: dict):

See: :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger`.
"""
keys_to_check = ["run_id", "run_page_url", "run_state"]
keys_to_check = ["run_id", "run_page_url", "run_state", "errors"]
for key in keys_to_check:
if key not in event:
raise AirflowException(f"Could not find `{key}` in the event: {event}")
Expand Down
17 changes: 17 additions & 0 deletions tests/providers/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,23 @@ async def test_get_cluster_state(self, mock_get):
timeout=self.hook.timeout_seconds,
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_get_run_output(self, mock_get):
    # Verify that a_get_run_output() makes exactly one GET request to the
    # run-output endpoint with the expected auth/headers/timeout, and that
    # the "error" field of the API response is passed through unchanged.
    mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_OUTPUT_RESPONSE)
    async with self.hook:
        run_output = await self.hook.a_get_run_output(RUN_ID)
        run_output_error = run_output.get("error")

    assert run_output_error == ERROR_MESSAGE
    mock_get.assert_called_once_with(
        get_run_output_endpoint(HOST),
        json={"run_id": RUN_ID},
        auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
        headers=self.hook.user_agent_header,
        timeout=self.hook.timeout_seconds,
    )


@pytest.mark.db_test
class TestDatabricksHookAsyncAadToken:
Expand Down
5 changes: 5 additions & 0 deletions tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,7 @@ def test_execute_complete_success(self):
"run_id": RUN_ID,
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"errors": [],
}

op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
Expand All @@ -1044,6 +1045,7 @@ def test_execute_complete_failure(self, db_mock_class):
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": False,
"errors": [],
}

op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
Expand Down Expand Up @@ -1594,6 +1596,7 @@ def test_execute_complete_success(self):
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"repair_run": False,
"errors": [],
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
Expand All @@ -1611,6 +1614,7 @@ def test_execute_complete_failure(self, db_mock_class):
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": False,
"errors": [],
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
Expand Down Expand Up @@ -1641,6 +1645,7 @@ def test_execute_complete_failure_and_repair_run(
"run_page_url": RUN_PAGE_URL,
"run_state": run_state_failed.to_json(),
"repair_run": True,
"errors": [],
}

op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
Expand Down
99 changes: 97 additions & 2 deletions tests/providers/databricks/triggers/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,22 @@
RETRY_DELAY = 10
RETRY_LIMIT = 3
RUN_ID = 1
TASK_RUN_ID1 = 11
TASK_RUN_ID1_KEY = "first_task"
TASK_RUN_ID2 = 22
TASK_RUN_ID2_KEY = "second_task"
TASK_RUN_ID3 = 33
TASK_RUN_ID3_KEY = "third_task"
JOB_ID = 42
RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1"
ERROR_MESSAGE = "error message from databricks API"
GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}}

RUN_LIFE_CYCLE_STATES = ["PENDING", "RUNNING", "TERMINATING", "TERMINATED", "SKIPPED", "INTERNAL_ERROR"]

LIFE_CYCLE_STATE_PENDING = "PENDING"
LIFE_CYCLE_STATE_TERMINATED = "TERMINATED"
LIFE_CYCLE_STATE_INTERNAL_ERROR = "INTERNAL_ERROR"

STATE_MESSAGE = "Waiting for cluster"

Expand All @@ -66,6 +75,44 @@
"result_state": "SUCCESS",
},
}
# Fixture: ``get_run`` payload for a run that ended in INTERNAL_ERROR with a
# FAILED result.  Tasks 1 and 3 FAILED while task 2 SUCCEEDED, so tests can
# assert that only the failed tasks contribute entries to the trigger
# event's "errors" list.
GET_RUN_RESPONSE_TERMINATED_WITH_FAILED = {
    "job_id": JOB_ID,
    "run_page_url": RUN_PAGE_URL,
    "state": {
        "life_cycle_state": LIFE_CYCLE_STATE_INTERNAL_ERROR,
        "state_message": None,
        "result_state": "FAILED",
    },
    "tasks": [
        {
            "run_id": TASK_RUN_ID1,
            "task_key": TASK_RUN_ID1_KEY,
            "state": {
                "life_cycle_state": "TERMINATED",
                "result_state": "FAILED",
                "state_message": "Workload failed, see run output for details",
            },
        },
        {
            "run_id": TASK_RUN_ID2,
            "task_key": TASK_RUN_ID2_KEY,
            "state": {
                "life_cycle_state": "TERMINATED",
                "result_state": "SUCCESS",
                "state_message": None,
            },
        },
        {
            "run_id": TASK_RUN_ID3,
            "task_key": TASK_RUN_ID3_KEY,
            "state": {
                "life_cycle_state": "TERMINATED",
                "result_state": "FAILED",
                "state_message": "Workload failed, see run output for details",
            },
        },
    ],
}


class TestDatabricksExecutionTrigger:
Expand Down Expand Up @@ -101,15 +148,21 @@ def test_serialize(self):
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_url):
async def test_run_return_success(
self, mock_get_run_state, mock_get_run_page_url, mock_get_run, mock_get_run_output
):
mock_get_run_page_url.return_value = RUN_PAGE_URL
mock_get_run_state.return_value = RunState(
life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
state_message="",
result_state="SUCCESS",
)
mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED
mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE

trigger_event = self.trigger.run()
async for event in trigger_event:
Expand All @@ -121,13 +174,52 @@ async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_ur
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"errors": [],
}
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
async def test_run_return_failure(
    self, mock_get_run_state, mock_get_run_page_url, mock_get_run, mock_get_run_output
):
    # On a terminal FAILED run the trigger must emit a single TriggerEvent
    # whose "errors" list has one {task_key, run_id, error} entry per FAILED
    # task (tasks 1 and 3); the SUCCESS task (TASK_RUN_ID2) is excluded.
    mock_get_run_page_url.return_value = RUN_PAGE_URL
    mock_get_run_state.return_value = RunState(
        life_cycle_state=LIFE_CYCLE_STATE_TERMINATED,
        state_message="",
        result_state="FAILED",
    )
    # Every failed task's get-output response carries an "error" field,
    # so ERROR_MESSAGE is the expected per-task error string.
    mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE
    mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED

    trigger_event = self.trigger.run()
    async for event in trigger_event:
        assert event == TriggerEvent(
            {
                "run_id": RUN_ID,
                "run_state": RunState(
                    life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="FAILED"
                ).to_json(),
                "run_page_url": RUN_PAGE_URL,
                "repair_run": False,
                "errors": [
                    {"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, "error": ERROR_MESSAGE},
                    {"task_key": TASK_RUN_ID3_KEY, "run_id": TASK_RUN_ID3, "error": ERROR_MESSAGE},
                ],
            }
        )

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run")
@mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
async def test_sleep_between_retries(
self, mock_get_run_state, mock_sleep, mock_get_run, mock_get_run_output
):
mock_get_run_state.side_effect = [
RunState(
life_cycle_state=LIFE_CYCLE_STATE_PENDING,
Expand All @@ -140,6 +232,8 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
result_state="SUCCESS",
),
]
mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED
mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE

trigger_event = self.trigger.run()
async for event in trigger_event:
Expand All @@ -151,6 +245,7 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
).to_json(),
"run_page_url": RUN_PAGE_URL,
"repair_run": False,
"errors": [],
}
)
mock_sleep.assert_called_once()
Expand Down
1 change: 1 addition & 0 deletions tests/providers/databricks/utils/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_validate_trigger_event_success(self):
"run_id": RUN_ID,
"run_page_url": RUN_PAGE_URL,
"run_state": RunState("TERMINATED", "SUCCESS", "").to_json(),
"errors": [],
}
assert validate_trigger_event(event) is None

Expand Down