90 changes: 47 additions & 43 deletions airflow/executors/kubernetes_executor.py
@@ -58,10 +58,10 @@
# TaskInstance key, command, configuration, pod_template_file
KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]

# key, pod state, pod_id, namespace, resource_version
# key, pod state, pod_name, namespace, resource_version
KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]

# pod_id, namespace, pod state, annotations, resource_version
# pod_name, namespace, pod state, annotations, resource_version
KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]

ALL_NAMESPACES = "ALL_NAMESPACES"
@@ -180,7 +180,7 @@ def _run(
task_instance_related_annotations["map_index"] = map_index

self.process_status(
pod_id=task.metadata.name,
pod_name=task.metadata.name,
namespace=task.metadata.namespace,
status=task.status.phase,
annotations=task_instance_related_annotations,
@@ -208,7 +208,7 @@ def process_error(self, event: Any) -> str:

def process_status(
self,
pod_id: str,
pod_name: str,
namespace: str,
status: str,
annotations: dict[str, str],
@@ -218,28 +218,28 @@ def process_status(
"""Process status response."""
if status == "Pending":
if event["type"] == "DELETED":
self.log.info("Event: Failed to start pod %s", pod_id)
self.watcher_queue.put((pod_id, namespace, State.FAILED, annotations, resource_version))
self.log.info("Event: Failed to start pod %s", pod_name)
self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
else:
self.log.debug("Event: %s Pending", pod_id)
self.log.debug("Event: %s Pending", pod_name)
elif status == "Failed":
self.log.error("Event: %s Failed", pod_id)
self.watcher_queue.put((pod_id, namespace, State.FAILED, annotations, resource_version))
self.log.error("Event: %s Failed", pod_name)
self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
elif status == "Succeeded":
self.log.info("Event: %s Succeeded", pod_id)
self.watcher_queue.put((pod_id, namespace, State.SUCCESS, annotations, resource_version))
self.log.info("Event: %s Succeeded", pod_name)
self.watcher_queue.put((pod_name, namespace, State.SUCCESS, annotations, resource_version))
elif status == "Running":
if event["type"] == "DELETED":
self.log.info("Event: Pod %s deleted before it could complete", pod_id)
self.watcher_queue.put((pod_id, namespace, State.FAILED, annotations, resource_version))
self.log.info("Event: Pod %s deleted before it could complete", pod_name)
self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
else:
self.log.info("Event: %s is Running", pod_id)
self.log.info("Event: %s is Running", pod_name)
else:
self.log.warning(
"Event: Invalid state: %s on pod: %s in namespace %s with annotations: %s with "
"resource_version: %s",
status,
pod_id,
pod_name,
namespace,
annotations,
resource_version,
@@ -368,12 +368,12 @@ def run_next(self, next_job: KubernetesJobType) -> None:
self.run_pod_async(pod, **self.kube_config.kube_client_request_args)
self.log.debug("Kubernetes Job created!")

def delete_pod(self, pod_id: str, namespace: str) -> None:
"""Deletes POD."""
def delete_pod(self, pod_name: str, namespace: str) -> None:
"""Deletes Pod from a namespace. Does not raise if it does not exist."""
try:
self.log.debug("Deleting pod %s in namespace %s", pod_id, namespace)
self.log.debug("Deleting pod %s in namespace %s", pod_name, namespace)
self.kube_client.delete_namespaced_pod(
pod_id,
pod_name,
namespace,
body=client.V1DeleteOptions(**self.kube_config.delete_option_kwargs),
**self.kube_config.kube_client_request_args,
@@ -419,14 +419,14 @@ def sync(self) -> None:

def process_watcher_task(self, task: KubernetesWatchType) -> None:
"""Process the task by watcher."""
pod_id, namespace, state, annotations, resource_version = task
pod_name, namespace, state, annotations, resource_version = task
self.log.debug(
"Attempting to finish pod; pod_id: %s; state: %s; annotations: %s", pod_id, state, annotations
"Attempting to finish pod; pod_name: %s; state: %s; annotations: %s", pod_name, state, annotations
)
key = annotations_to_key(annotations=annotations)
if key:
self.log.debug("finishing job %s - %s (%s)", key, state, pod_id)
self.result_queue.put((key, state, pod_id, namespace, resource_version))
self.log.debug("finishing job %s - %s (%s)", key, state, pod_name)
self.result_queue.put((key, state, pod_name, namespace, resource_version))

def _flush_watcher_queue(self) -> None:
self.log.debug("Executor shutting down, watcher_queue approx. size=%d", self.watcher_queue.qsize())
@@ -658,11 +658,11 @@ def sync(self) -> None:
try:
results = self.result_queue.get_nowait()
try:
key, state, pod_id, namespace, resource_version = results
key, state, pod_name, namespace, resource_version = results
last_resource_version[namespace] = resource_version
self.log.info("Changing state of %s to %s", results, state)
try:
self._change_state(key, state, pod_id, namespace)
self._change_state(key, state, pod_name, namespace)
except Exception as e:
self.log.exception(
"Exception: %s when attempting to change state of %s to %s, re-queueing.",
@@ -725,7 +725,7 @@ def sync(self) -> None:
next_event = self.event_scheduler.run(blocking=False)
self.log.debug("Next timed event is in %f", next_event)

def _change_state(self, key: TaskInstanceKey, state: str | None, pod_id: str, namespace: str) -> None:
def _change_state(self, key: TaskInstanceKey, state: str | None, pod_name: str, namespace: str) -> None:
if TYPE_CHECKING:
assert self.kube_scheduler

@@ -735,10 +735,10 @@ def _change_state(self, key: TaskInstanceKey, state: str | None, pod_id: str, na

if self.kube_config.delete_worker_pods:
if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure:
self.kube_scheduler.delete_pod(pod_id, namespace)
self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace)
self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace))
else:
self.kube_scheduler.patch_pod_executor_done(pod_name=pod_id, namespace=namespace)
self.kube_scheduler.patch_pod_executor_done(pod_name=pod_name, namespace=namespace)
self.log.info("Patched pod %s in namespace %s to mark it as done", str(key), str(namespace))

try:
@@ -801,9 +801,10 @@ def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], li
return messages, ["\n".join(log)]

def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
# Always flush TIs without queued_by_job_id
tis_to_flush = [ti for ti in tis if not ti.queued_by_job_id]
scheduler_job_ids = {ti.queued_by_job_id for ti in tis}
pod_ids = {ti.key: ti for ti in tis if ti.queued_by_job_id}
tis_to_flush_by_key = {ti.key: ti for ti in tis if ti.queued_by_job_id}
kube_client: client.CoreV1Api = self.kube_client
for scheduler_job_id in scheduler_job_ids:
scheduler_job_id = pod_generator.make_safe_label_value(str(scheduler_job_id))
@@ -821,9 +822,9 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
}
pod_list = self._list_pods(query_kwargs)
for pod in pod_list:
self.adopt_launched_task(kube_client, pod, pod_ids)
self.adopt_launched_task(kube_client, pod, tis_to_flush_by_key)
self._adopt_completed_pods(kube_client)
tis_to_flush.extend(pod_ids.values())
tis_to_flush.extend(tis_to_flush_by_key.values())
return tis_to_flush

def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
@@ -861,26 +862,29 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
self.log.warning("Found multiple pods for ti %s: %s", ti, pod_list)
continue
readable_tis.append(repr(ti))
self.kube_scheduler.delete_pod(pod_id=pod_list[0].metadata.name, namespace=namespace)
self.kube_scheduler.delete_pod(pod_name=pod_list[0].metadata.name, namespace=namespace)
return readable_tis

def adopt_launched_task(
self, kube_client: client.CoreV1Api, pod: k8s.V1Pod, pod_ids: dict[TaskInstanceKey, k8s.V1Pod]
self,
kube_client: client.CoreV1Api,
pod: k8s.V1Pod,
tis_to_flush_by_key: dict[TaskInstanceKey, k8s.V1Pod],
) -> None:
"""
Patch existing pod so that the current KubernetesJobWatcher can monitor it via label selectors.

:param kube_client: kubernetes client for speaking to kube API
:param pod: V1Pod spec that we will patch with new label
:param pod_ids: pod_ids we expect to patch.
:param tis_to_flush_by_key: TIs that will be flushed if they aren't adopted
"""
if TYPE_CHECKING:
assert self.scheduler_job_id

self.log.info("attempting to adopt pod %s", pod.metadata.name)
pod_id = annotations_to_key(pod.metadata.annotations)
if pod_id not in pod_ids:
self.log.error("attempting to adopt taskinstance which was not specified by database: %s", pod_id)
ti_key = annotations_to_key(pod.metadata.annotations)
if ti_key not in tis_to_flush_by_key:
self.log.error("attempting to adopt taskinstance which was not specified by database: %s", ti_key)
return

new_worker_id_label = pod_generator.make_safe_label_value(self.scheduler_job_id)
@@ -894,8 +898,8 @@ def adopt_launched_task(
self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)
return

del pod_ids[pod_id]
self.running.add(pod_id)
del tis_to_flush_by_key[ti_key]
self.running.add(ti_key)

def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None:
"""
@@ -925,8 +929,8 @@ def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None:
)
except ApiException as e:
self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)
pod_id = annotations_to_key(pod.metadata.annotations)
self.running.add(pod_id)
ti_id = annotations_to_key(pod.metadata.annotations)
self.running.add(ti_id)

def _flush_task_queue(self) -> None:
if TYPE_CHECKING:
@@ -952,12 +956,12 @@ def _flush_result_queue(self) -> None:
results = self.result_queue.get_nowait()
self.log.warning("Executor shutting down, flushing results=%s", results)
try:
key, state, pod_id, namespace, resource_version = results
key, state, pod_name, namespace, resource_version = results
self.log.info(
"Changing state of %s to %s : resource_version=%d", results, state, resource_version
)
try:
self._change_state(key, state, pod_id, namespace)
self._change_state(key, state, pod_name, namespace)
except Exception as e:
self.log.exception(
"Ignoring exception: %s when attempting to change state of %s to %s.",