Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@

if TYPE_CHECKING:
from kubernetes.client.models import V1Job, V1Pod
from pendulum import DateTime

from airflow.utils.context import Context

Expand Down Expand Up @@ -773,16 +774,16 @@ def fetch_cluster_info(self) -> tuple[str, str | None]:
self._ssl_ca_cert = cluster.master_auth.cluster_ca_certificate
return self._cluster_url, self._ssl_ca_cert

def invoke_defer_method(self):
def invoke_defer_method(self, last_log_time: DateTime | None = None):
"""Redefine triggers which are being used in child classes."""
trigger_start_time = utcnow()
self.defer(
trigger=GKEStartPodTrigger(
pod_name=self.pod.metadata.name,
pod_namespace=self.pod.metadata.namespace,
pod_name=self.pod.metadata.name, # type: ignore[union-attr]
pod_namespace=self.pod.metadata.namespace, # type: ignore[union-attr]
trigger_start_time=trigger_start_time,
cluster_url=self._cluster_url,
ssl_ca_cert=self._ssl_ca_cert,
cluster_url=self._cluster_url, # type: ignore[arg-type]
ssl_ca_cert=self._ssl_ca_cert, # type: ignore[arg-type]
get_logs=self.get_logs,
startup_timeout=self.startup_timeout_seconds,
cluster_context=self.cluster_context,
Expand All @@ -792,6 +793,8 @@ def invoke_defer_method(self):
on_finish_action=self.on_finish_action,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
logging_interval=self.logging_interval,
last_log_time=last_log_time,
),
method_name="execute_complete",
kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert},
Expand All @@ -802,7 +805,7 @@ def execute_complete(self, context: Context, event: dict, **kwargs):
self._cluster_url = kwargs["cluster_url"]
self._ssl_ca_cert = kwargs["ssl_ca_cert"]

return super().execute_complete(context, event, **kwargs)
return super().trigger_reentry(context, event)


class GKEStartJobOperator(KubernetesJobOperator):
Expand Down
2 changes: 2 additions & 0 deletions airflow/providers/google/cloud/triggers/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"on_finish_action": self.on_finish_action.value,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"logging_interval": self.logging_interval,
"last_log_time": self.last_log_time,
},
)

Expand Down
54 changes: 54 additions & 0 deletions tests/providers/google/cloud/operators/test_kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ def setup_method(self):
namespace=NAMESPACE,
image=IMAGE,
deferrable=True,
on_finish_action="delete_pod",
)
self.gke_op.pod = mock.MagicMock(
name=TASK_NAME,
Expand Down Expand Up @@ -703,6 +704,59 @@ def test_async_create_pod_should_execute_successfully(
fetch_cluster_info_mock.assert_called_once()
assert isinstance(exc.value.trigger, GKEStartPodTrigger)

@pytest.mark.parametrize("status", ["error", "failed", "timeout"])
@mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_pod")
@mock.patch(KUB_OP_PATH.format("_clean"))
@mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator.hook")
@mock.patch(KUB_OP_PATH.format("_write_logs"))
def test_execute_complete_failure(self, mock_write_logs, mock_gke_hook, mock_clean, mock_get_pod, status):
self.gke_op._cluster_url = CLUSTER_URL
self.gke_op._ssl_ca_cert = SSL_CA_CERT
with pytest.raises(AirflowException):
self.gke_op.execute_complete(
context=mock.MagicMock(),
event={"name": "test", "status": status, "namespace": "default", "message": ""},
cluster_url=self.gke_op._cluster_url,
ssl_ca_cert=self.gke_op._ssl_ca_cert,
)
mock_write_logs.assert_called_once()

@mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator.hook")
@mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_pod")
@mock.patch(KUB_OP_PATH.format("_clean"))
@mock.patch(KUB_OP_PATH.format("_write_logs"))
def test_execute_complete_success(self, mock_write_logs, mock_clean, mock_get_pod, mock_gke_hook):
self.gke_op._cluster_url = CLUSTER_URL
self.gke_op._ssl_ca_cert = SSL_CA_CERT
self.gke_op.execute_complete(
context=mock.MagicMock(),
event={"name": "test", "status": "success", "namespace": "default"},
cluster_url=self.gke_op._cluster_url,
ssl_ca_cert=self.gke_op._ssl_ca_cert,
)
mock_write_logs.assert_called_once()

@mock.patch(KUB_OP_PATH.format("pod_manager"))
@mock.patch(
"airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator.invoke_defer_method"
)
@mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_pod")
@mock.patch(KUB_OP_PATH.format("_clean"))
@mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator.hook")
def test_execute_complete_running(
self, mock_gke_hook, mock_clean, mock_get_pod, mock_invoke_defer_method, mock_pod_manager
):
self.gke_op._cluster_url = CLUSTER_URL
self.gke_op._ssl_ca_cert = SSL_CA_CERT
self.gke_op.execute_complete(
context=mock.MagicMock(),
event={"name": "test", "status": "running", "namespace": "default"},
cluster_url=self.gke_op._cluster_url,
ssl_ca_cert=self.gke_op._ssl_ca_cert,
)
mock_pod_manager.fetch_container_logs.assert_called_once()
mock_invoke_defer_method.assert_called_once()


class TestGKEStartJobOperator:
def setup_method(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def test_serialize_should_execute_successfully(self, trigger):
"should_delete_pod": SHOULD_DELETE_POD,
"gcp_conn_id": GCP_CONN_ID,
"impersonation_chain": IMPERSONATION_CHAIN,
"last_log_time": None,
"logging_interval": None,
}

@pytest.mark.asyncio
Expand Down