160 changes: 71 additions & 89 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -30,6 +30,7 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from deprecated import deprecated
@@ -638,7 +639,7 @@ def __init__(
request_id: str | None = None,
delete_on_error: bool = True,
use_if_exists: bool = True,
retry: AsyncRetry | _MethodDefault = DEFAULT,
retry: AsyncRetry | _MethodDefault | Retry = DEFAULT,
timeout: float = 1 * 60 * 60,
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
@@ -1184,7 +1185,7 @@ def __init__(
project_id: str = PROVIDE_PROJECT_ID,
cluster_uuid: str | None = None,
request_id: str | None = None,
retry: AsyncRetry | _MethodDefault = DEFAULT,
retry: AsyncRetry | _MethodDefault | Retry = DEFAULT,
timeout: float = 1 * 60 * 60,
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
@@ -2712,7 +2713,7 @@ def __init__(
region: str,
request_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
retry: AsyncRetry | _MethodDefault = DEFAULT,
retry: AsyncRetry | _MethodDefault | Retry = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
@@ -2985,10 +2986,10 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
def __init__(
self,
*,
region: str | None = None,
region: str,
project_id: str = PROVIDE_PROJECT_ID,
batch: dict | Batch,
batch_id: str,
batch_id: str | None = None,
request_id: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
@@ -3021,20 +3022,20 @@ def __init__(
self.polling_interval_seconds = polling_interval_seconds

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# batch_id might not be set and will be generated
if self.batch_id:
link = DATAPROC_BATCH_LINK.format(
region=self.region, project_id=self.project_id, batch_id=self.batch_id
if self.asynchronous and self.deferrable:
raise AirflowException(
"Both asynchronous and deferrable parameters were passed. Please, provide only one."
)
self.log.info("Creating batch %s", self.batch_id)
self.log.info("Once started, the batch job will be available at %s", link)

batch_id: str = ""
if self.batch_id:
batch_id = self.batch_id
self.log.info("Starting batch %s", batch_id)
else:
self.log.info("Starting batch job. The batch ID will be generated since it was not provided.")
if self.region is None:
raise AirflowException("Region should be set here")
self.log.info("Starting batch. The batch ID will be generated since it was not provided.")

try:
self.operation = hook.create_batch(
self.operation = self.hook.create_batch(
region=self.region,
project_id=self.project_id,
batch=self.batch,
@@ -3044,85 +3045,62 @@ def execute(self, context: Context):
timeout=self.timeout,
metadata=self.metadata,
)
if self.operation is None:
raise RuntimeError("The operation should be set here!")

if not self.deferrable:
if not self.asynchronous:
result = hook.wait_for_operation(
timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
)
self.log.info("Batch %s created", self.batch_id)

else:
DataprocBatchLink.persist(
context=context,
operator=self,
project_id=self.project_id,
region=self.region,
batch_id=self.batch_id,
)
return self.operation.operation.name

else:
# processing ends in execute_complete
self.defer(
trigger=DataprocBatchTrigger(
batch_id=self.batch_id,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)

except AlreadyExists:
self.log.info("Batch with given id already exists")
# This is only likely to happen if batch_id was provided
# Could be running if Airflow was restarted after task started
# poll until a final state is reached

self.log.info("Attaching to the job %s if it is still running.", self.batch_id)
self.log.info("Batch with given id already exists.")
self.log.info("Attaching to the job %s if it is still running.", batch_id)
else:
batch_id = self.operation.metadata.batch.split("/")[-1]
self.log.info("The batch %s was created.", batch_id)

# deferrable handling of a batch_id that already exists - processing ends in execute_complete
if self.deferrable:
self.defer(
trigger=DataprocBatchTrigger(
batch_id=self.batch_id,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
DataprocBatchLink.persist(
context=context,
operator=self,
project_id=self.project_id,
region=self.region,
batch_id=batch_id,
)

# non-deferrable handling of a batch_id that already exists
result = hook.wait_for_batch(
batch_id=self.batch_id,
if self.asynchronous:
batch = self.hook.get_batch(
batch_id=batch_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
wait_check_interval=self.polling_interval_seconds,
)
batch_id = self.batch_id or result.name.split("/")[-1]
self.log.info("The batch %s was created asynchronously. Exiting.", batch_id)
return Batch.to_dict(batch)

self.handle_batch_status(context, result.state, batch_id)
project_id = self.project_id or hook.project_id
if project_id:
DataprocBatchLink.persist(
context=context,
operator=self,
project_id=project_id,
region=self.region,
batch_id=batch_id,
if self.deferrable:
self.defer(
trigger=DataprocBatchTrigger(
batch_id=batch_id,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
return Batch.to_dict(result)

self.log.info("Waiting for the completion of batch job %s", batch_id)
batch = self.hook.wait_for_batch(
batch_id=batch_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)

self.handle_batch_status(context, batch.state, batch_id, batch.state_message)
return Batch.to_dict(batch)

@cached_property
def hook(self) -> DataprocHook:
return DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)

def execute_complete(self, context, event=None) -> None:
"""
@@ -3135,23 +3113,27 @@ def execute_complete(self, context, event=None) -> None:
raise AirflowException("Batch failed.")
state = event["batch_state"]
batch_id = event["batch_id"]
self.handle_batch_status(context, state, batch_id)
self.handle_batch_status(context, state, batch_id, state_message=event["batch_state_message"])

def on_kill(self):
if self.operation:
self.operation.cancel()

def handle_batch_status(self, context: Context, state: Batch.State, batch_id: str) -> None:
def handle_batch_status(
self, context: Context, state: Batch.State, batch_id: str, state_message: str | None = None
) -> None:
# The existing batch may be in a number of states other than 'SUCCEEDED'.
# wait_for_operation doesn't fail if the job is cancelled, so we will check for it here, which also
# finds a cancelling|canceled|unspecified job from wait_for_batch or the deferred trigger
link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, batch_id=batch_id)
if state == Batch.State.FAILED:
raise AirflowException("Batch job %s failed. Driver Logs: %s", batch_id, link)
raise AirflowException(
f"Batch job {batch_id} failed with error: {state_message}\nDriver Logs: {link}"
)
if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
raise AirflowException("Batch job %s was cancelled. Driver logs: %s", batch_id, link)
raise AirflowException(f"Batch job {batch_id} was cancelled. Driver logs: {link}")
if state == Batch.State.STATE_UNSPECIFIED:
raise AirflowException("Batch job %s unspecified. Driver logs: %s", batch_id, link)
raise AirflowException(f"Batch job {batch_id} unspecified. Driver logs: {link}")
self.log.info("Batch job %s completed. Driver logs: %s", batch_id, link)


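A rough usage sketch of the refactored operator (the project, bucket, and region values below are assumed placeholders, not taken from this PR): batch_id may now be omitted, region is required, and asynchronous and deferrable must not be combined.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator

with DAG("dataproc_batch_example", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    create_batch = DataprocCreateBatchOperator(
        task_id="create_batch",
        project_id="my-project",  # assumed placeholder
        region="us-central1",  # region is now a required argument
        batch={"pyspark_batch": {"main_python_file_uri": "gs://my-bucket/job.py"}},
        # batch_id is optional now; when omitted, the server-generated ID is read back
        # from the operation metadata and used for links, waiting, and deferral.
        deferrable=True,  # cannot be combined with asynchronous=True
    )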
5 changes: 4 additions & 1 deletion airflow/providers/google/cloud/triggers/dataproc.py
@@ -371,7 +371,10 @@ async def run(self):
self.log.info("Current state is %s", state)
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})

yield TriggerEvent(
{"batch_id": self.batch_id, "batch_state": state, "batch_state_message": batch.state_message}
)


class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
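For reference, a minimal sketch (values assumed for illustration) of the event a deferred DataprocCreateBatchOperator.execute_complete now receives from this trigger; the new batch_state_message field is what surfaces in the failure message.

from google.cloud.dataproc_v1 import Batch

event = {
    "batch_id": "example-batch",  # assumed example ID
    "batch_state": Batch.State.FAILED,
    "batch_state_message": "Spark driver exited with a non-zero status",  # assumed example text
}
# execute_complete() forwards event["batch_state_message"] to handle_batch_status(),
# so the failure reason appears in the raised AirflowException.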
3 changes: 1 addition & 2 deletions scripts/ci/pre_commit/check_system_tests.py
@@ -35,7 +35,6 @@
errors: list[str] = []

WATCHER_APPEND_INSTRUCTION = "list(dag.tasks) >> watcher()"
WATCHER_APPEND_INSTRUCTION_SHORT = " >> watcher()"

PYTEST_FUNCTION = """
from tests.system.utils import get_test_run # noqa: E402
@@ -53,7 +52,7 @@
def _check_file(file: Path):
content = file.read_text()
if "from tests.system.utils.watcher import watcher" in content:
index = content.find(WATCHER_APPEND_INSTRUCTION_SHORT)
index = content.find(WATCHER_APPEND_INSTRUCTION)
if index == -1:
errors.append(
f"[red]The example {file} imports tests.system.utils.watcher "
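For context, the check now searches for the full watcher instruction rather than only the " >> watcher()" suffix; a conforming system test would end with something like this sketch (following the existing Airflow system-test convention):

from tests.system.utils.watcher import watcher

# ... example DAG and its tasks defined above ...

# This exact statement is what the pre-commit check now looks for; a bare
# ">> watcher()" appended to some other expression no longer satisfies it.
list(dag.tasks) >> watcher()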
8 changes: 4 additions & 4 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -2708,7 +2708,7 @@ def test_execute_batch_failed(self, mock_hook, to_dict_mock):
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.FAILED)
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.FAILED)
with pytest.raises(AirflowException):
op.execute(context=MagicMock())

@@ -2729,12 +2729,12 @@ def test_execute_batch_already_exists_succeeds(self, mock_hook):
)
mock_hook.return_value.wait_for_operation.side_effect = AlreadyExists("")
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
op.execute(context=MagicMock())
mock_hook.return_value.wait_for_batch.assert_called_once_with(
batch_id=BATCH_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
wait_check_interval=5,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
@@ -2757,13 +2757,13 @@ def test_execute_batch_already_exists_fails(self, mock_hook):
)
mock_hook.return_value.wait_for_operation.side_effect = AlreadyExists("")
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.FAILED)
mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
with pytest.raises(AirflowException):
op.execute(context=MagicMock())
mock_hook.return_value.wait_for_batch.assert_called_once_with(
batch_id=BATCH_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
wait_check_interval=5,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
@@ -2786,13 +2786,13 @@ def test_execute_batch_already_exists_cancelled(self, mock_hook):
)
mock_hook.return_value.wait_for_operation.side_effect = AlreadyExists("")
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.CANCELLED)
mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
with pytest.raises(AirflowException):
op.execute(context=MagicMock())
mock_hook.return_value.wait_for_batch.assert_called_once_with(
batch_id=BATCH_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
wait_check_interval=5,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
30 changes: 25 additions & 5 deletions tests/providers/google/cloud/triggers/test_dataproc.py
Expand Up @@ -38,6 +38,7 @@
TEST_PROJECT_ID = "project-id"
TEST_REGION = "region"
TEST_BATCH_ID = "batch-id"
TEST_BATCH_STATE_MESSAGE = "Test batch state message"
BATCH_CONFIG = {
"spark_batch": {
"jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
@@ -391,12 +392,15 @@ async def test_async_create_batch_trigger_triggers_on_success_should_execute_suc
Tests the DataprocBatchTrigger only fires once the batch execution reaches a successful state.
"""

mock_hook.return_value = async_get_batch(state=Batch.State.SUCCEEDED, batch_id=TEST_BATCH_ID)
mock_hook.return_value = async_get_batch(
state=Batch.State.SUCCEEDED, batch_id=TEST_BATCH_ID, state_message=TEST_BATCH_STATE_MESSAGE
)

expected_event = TriggerEvent(
{
"batch_id": TEST_BATCH_ID,
"batch_state": Batch.State.SUCCEEDED,
"batch_state_message": TEST_BATCH_STATE_MESSAGE,
}
)

@@ -409,9 +413,17 @@
async def test_async_create_batch_trigger_run_returns_failed_event(
self, mock_hook, batch_trigger, async_get_batch
):
mock_hook.return_value = async_get_batch(state=Batch.State.FAILED, batch_id=TEST_BATCH_ID)
mock_hook.return_value = async_get_batch(
state=Batch.State.FAILED, batch_id=TEST_BATCH_ID, state_message=TEST_BATCH_STATE_MESSAGE
)

expected_event = TriggerEvent({"batch_id": TEST_BATCH_ID, "batch_state": Batch.State.FAILED})
expected_event = TriggerEvent(
{
"batch_id": TEST_BATCH_ID,
"batch_state": Batch.State.FAILED,
"batch_state_message": TEST_BATCH_STATE_MESSAGE,
}
)

actual_event = await batch_trigger.run().asend(None)
await asyncio.sleep(0.5)
@@ -420,9 +432,17 @@ async def test_create_batch_run_returns_cancelled_event(self, mock_hook, batch_trigger, async_get_batch):
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
async def test_create_batch_run_returns_cancelled_event(self, mock_hook, batch_trigger, async_get_batch):
mock_hook.return_value = async_get_batch(state=Batch.State.CANCELLED, batch_id=TEST_BATCH_ID)
mock_hook.return_value = async_get_batch(
state=Batch.State.CANCELLED, batch_id=TEST_BATCH_ID, state_message=TEST_BATCH_STATE_MESSAGE
)

expected_event = TriggerEvent({"batch_id": TEST_BATCH_ID, "batch_state": Batch.State.CANCELLED})
expected_event = TriggerEvent(
{
"batch_id": TEST_BATCH_ID,
"batch_state": Batch.State.CANCELLED,
"batch_state_message": TEST_BATCH_STATE_MESSAGE,
}
)

actual_event = await batch_trigger.run().asend(None)
await asyncio.sleep(0.5)