diff --git a/airflow/providers/airbyte/hooks/airbyte.py b/airflow/providers/airbyte/hooks/airbyte.py
index a9ce336022fff..b8ad957a9c6d8 100644
--- a/airflow/providers/airbyte/hooks/airbyte.py
+++ b/airflow/providers/airbyte/hooks/airbyte.py
@@ -17,12 +17,23 @@
 # under the License.
 from __future__ import annotations
 
+import base64
+import json
 import time
-from typing import Any
+from typing import TYPE_CHECKING, Any, TypeVar
+
+import aiohttp
+from aiohttp import ClientResponseError
+from asgiref.sync import sync_to_async
 
 from airflow.exceptions import AirflowException
 from airflow.providers.http.hooks.http import HttpHook
 
+if TYPE_CHECKING:
+    from airflow.models import Connection
+
+T = TypeVar("T", bound=Any)
+
 
 class AirbyteHook(HttpHook):
     """
@@ -50,6 +61,63 @@ def __init__(self, airbyte_conn_id: str = "airbyte_default", api_version: str =
         super().__init__(http_conn_id=airbyte_conn_id)
         self.api_version: str = api_version
 
+    async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]:
+        """
+        Build the HTTP Basic-auth headers and base URL from the Airflow connection.
+
+        The connection is fetched through ``sync_to_async`` because ``get_connection``
+        is a synchronous call being used from a coroutine.
+
+        :return: A tuple of the request headers and the connection's ``host`` value,
+            which is used verbatim as the base URL.
+        """
+        connection: Connection = await sync_to_async(self.get_connection)(self.http_conn_id)
+        base_url = connection.host
+
+        credentials = f"{connection.login}:{connection.password}"
+        credentials_base64 = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
+
+        authorized_headers = {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "authorization": f"Basic {credentials_base64}",
+        }
+
+        return authorized_headers, base_url
+
+    async def get_job_details(self, job_id: int) -> Any:
+        """
+        Use an async HTTP call to retrieve metadata for a specific job of an Airbyte Sync.
+
+        :param job_id: The ID of an Airbyte Sync Job.
+        :return: The decoded JSON body of the ``jobs/get`` response.
+        :raises AirflowException: If the API answers with an HTTP error status.
+        """
+        headers, base_url = await self.get_headers_tenants_from_connection()
+        url = f"{base_url}/api/{self.api_version}/jobs/get"
+        self.log.info("URL for api request: %s", url)
+        async with aiohttp.ClientSession(headers=headers) as session:
+            async with session.post(url=url, data=json.dumps({"id": job_id})) as response:
+                try:
+                    response.raise_for_status()
+                    return await response.json()
+                except ClientResponseError as e:
+                    msg = f"{e.status}: {e.message} - {e.request_info}"
+                    # Chain the aiohttp error so its traceback is kept as the cause.
+                    raise AirflowException(msg) from e
+
+    async def get_job_status(self, job_id: int) -> str:
+        """
+        Retrieve the status for a specific job of an Airbyte Sync.
+
+        :param job_id: The ID of an Airbyte Sync Job.
+        :return: The ``status`` field of the ``job`` object in the API response.
+        """
+        self.log.info("Getting the status of job run %s.", job_id)
+        response = await self.get_job_details(job_id=job_id)
+        job_run_status: str = response["job"]["status"]
+        return job_run_status
+
     def wait_for_job(self, job_id: str | int, wait_seconds: float = 3, timeout: float | None = 3600) -> None:
         """
         Poll a job to check if it finishes.
diff --git a/airflow/providers/airbyte/operators/airbyte.py b/airflow/providers/airbyte/operators/airbyte.py
index 6d101662db589..84a12dadfa2e5 100644
--- a/airflow/providers/airbyte/operators/airbyte.py
+++ b/airflow/providers/airbyte/operators/airbyte.py
@@ -17,10 +17,14 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+import time
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.providers.airbyte.triggers.airbyte import AirbyteSyncTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -40,6 +44,7 @@ class AirbyteTriggerSyncOperator(BaseOperator):
     :param asynchronous: Optional. Flag to get job_id after submitting the job to the Airbyte API.
         This is useful for submitting long running jobs and waiting on them asynchronously using the
         AirbyteJobSensor. Defaults to False.
+    :param deferrable: Run operator in the deferrable mode.
     :param api_version: Optional. Airbyte API version. Defaults to "v1".
     :param wait_seconds: Optional. Number of seconds between checks. Only used when ``asynchronous`` is False.
         Defaults to 3 seconds.
@@ -48,12 +53,14 @@ class AirbyteTriggerSyncOperator(BaseOperator):
     """
 
     template_fields: Sequence[str] = ("connection_id",)
+    ui_color = "#6C51FD"
 
     def __init__(
         self,
         connection_id: str,
         airbyte_conn_id: str = "airbyte_default",
         asynchronous: bool = False,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         api_version: str = "v1",
         wait_seconds: float = 3,
         timeout: float = 3600,
@@ -66,23 +73,74 @@ def __init__(
         self.api_version = api_version
         self.wait_seconds = wait_seconds
         self.asynchronous = asynchronous
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> None:
         """Create Airbyte Job and wait to finish."""
-        self.hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
-        job_object = self.hook.submit_sync_connection(connection_id=self.connection_id)
+        hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
+        job_object = hook.submit_sync_connection(connection_id=self.connection_id)
         self.job_id = job_object.json()["job"]["id"]
+        state = job_object.json()["job"]["status"]
+        # Absolute deadline (epoch seconds) handed to the trigger when deferring.
+        end_time = time.time() + self.timeout
 
         self.log.info("Job %s was submitted to Airbyte Server", self.job_id)
         if not self.asynchronous:
             self.log.info("Waiting for job %s to complete", self.job_id)
-            self.hook.wait_for_job(job_id=self.job_id, wait_seconds=self.wait_seconds, timeout=self.timeout)
+            if self.deferrable:
+                if state in (hook.RUNNING, hook.PENDING, hook.INCOMPLETE):
+                    # NOTE(review): the trigger polls every 60 seconds regardless of ``wait_seconds``;
+                    # confirm whether that is intentional.
+                    self.defer(
+                        timeout=self.execution_timeout,
+                        trigger=AirbyteSyncTrigger(
+                            conn_id=self.airbyte_conn_id,
+                            job_id=self.job_id,
+                            end_time=end_time,
+                            poll_interval=60,
+                        ),
+                        method_name="execute_complete",
+                    )
+                elif state == hook.SUCCEEDED:
+                    self.log.info("Job %s completed successfully", self.job_id)
+                    return self.job_id
+                elif state == hook.ERROR:
+                    raise AirflowException(f"Job failed:\n{self.job_id}")
+                elif state == hook.CANCELLED:
+                    raise AirflowException(f"Job was cancelled:\n{self.job_id}")
+                else:
+                    raise AirflowException(
+                        f"Encountered unexpected state `{state}` for job_id `{self.job_id}`"
+                    )
+            else:
+                hook.wait_for_job(job_id=self.job_id, wait_seconds=self.wait_seconds, timeout=self.timeout)
             self.log.info("Job %s completed successfully", self.job_id)
 
         return self.job_id
 
+    def execute_complete(self, context: Context, event: Any = None) -> None:
+        """
+        Callback for when the trigger fires - returns immediately.
+
+        The trigger reports the job outcome through ``event["status"]``; an ``error`` or
+        ``cancelled`` status fails the task, matching how ``execute`` treats those job states.
+
+        :param context: The task execution context.
+        :param event: Payload of the event emitted by the trigger.
+        :raises AirflowException: If no event was received or the job did not succeed.
+        """
+        if not event:
+            raise AirflowException("No event received from the trigger.")
+        if event["status"] in ("error", "cancelled"):
+            raise AirflowException(event["message"])
+
+        self.log.info("%s completed successfully.", self.task_id)
+        return None
+
     def on_kill(self):
         """Cancel the job if task is cancelled."""
+        # ``execute`` keeps its hook local, so build a new one with the same API version.
+        hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
         if self.job_id:
             self.log.info("on_kill: cancel the airbyte Job %s", self.job_id)
-            self.hook.cancel_job(self.job_id)
+            hook.cancel_job(self.job_id)
diff --git a/airflow/providers/airbyte/provider.yaml b/airflow/providers/airbyte/provider.yaml
index c973844dd7287..4f163eedfcf00 100644
--- a/airflow/providers/airbyte/provider.yaml
+++ b/airflow/providers/airbyte/provider.yaml
@@ -69,6 +69,11 @@ sensors:
     python-modules:
       - airflow.providers.airbyte.sensors.airbyte
 
+triggers:
+  - integration-name: Airbyte
+    python-modules:
+      - airflow.providers.airbyte.triggers.airbyte
+
 connection-types:
   - hook-class-name: airflow.providers.airbyte.hooks.airbyte.AirbyteHook
     connection-type: airbyte
diff --git a/airflow/providers/airbyte/sensors/airbyte.py b/airflow/providers/airbyte/sensors/airbyte.py
index f38206246e55a..4556d554304e1 100644
--- 
a/airflow/providers/airbyte/sensors/airbyte.py
+++ b/airflow/providers/airbyte/sensors/airbyte.py
@@ -18,10 +18,14 @@
 """This module contains a Airbyte Job sensor."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+import time
+import warnings
+from typing import TYPE_CHECKING, Any, Sequence
 
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.providers.airbyte.triggers.airbyte import AirbyteSyncTrigger
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -34,6 +38,7 @@ class AirbyteJobSensor(BaseSensorOperator):
 
     :param airbyte_job_id: Required. Id of the Airbyte job
     :param airbyte_conn_id: Optional. The name of the Airflow connection to get
         connection information for Airbyte. Defaults to "airbyte_default".
+    :param deferrable: Run sensor in the deferrable mode.
     :param api_version: Optional. Airbyte API version. Defaults to "v1".
     """
@@ -45,11 +50,30 @@ def __init__(
         self,
         *,
         airbyte_job_id: int,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         airbyte_conn_id: str = "airbyte_default",
         api_version: str = "v1",
         **kwargs,
     ) -> None:
+        if deferrable:
+            if "poke_interval" not in kwargs:
+                # TODO: Remove once deprecated
+                if "polling_interval" in kwargs:
+                    kwargs["poke_interval"] = kwargs.pop("polling_interval")
+                    warnings.warn(
+                        "Argument `polling_interval` is deprecated and will be removed "
+                        "in a future release. Please use `poke_interval` instead.",
+                        AirflowProviderDeprecationWarning,
+                        stacklevel=2,
+                    )
+                else:
+                    kwargs["poke_interval"] = 5
+
+            if "timeout" not in kwargs:
+                kwargs["timeout"] = 60 * 60 * 24 * 7
+
         super().__init__(**kwargs)
+        self.deferrable = deferrable
         self.airbyte_conn_id = airbyte_conn_id
         self.airbyte_job_id = airbyte_job_id
         self.api_version = api_version
@@ -79,3 +103,67 @@ def poke(self, context: Context) -> bool:
 
         self.log.info("Waiting for job %s to complete.", self.airbyte_job_id)
         return False
+
+    def execute(self, context: Context) -> Any:
+        """
+        Run the sensor, deferring to ``AirbyteSyncTrigger`` when ``deferrable`` is set.
+
+        In deferrable mode the current job status is fetched once: a job that is still
+        in progress defers the task, a finished job resolves the task immediately.
+
+        :param context: The task execution context.
+        :raises AirflowException: If the job has failed, was cancelled, or is in an unknown state.
+        """
+        if not self.deferrable:
+            return super().execute(context)
+        else:
+            hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
+            job = hook.get_job(job_id=(int(self.airbyte_job_id)))
+            state = job.json()["job"]["status"]
+            end_time = time.time() + self.timeout
+
+            self.log.info("Airbyte Job Id: Job %s", self.airbyte_job_id)
+
+            if state in (hook.RUNNING, hook.PENDING, hook.INCOMPLETE):
+                # NOTE(review): the trigger polls every 60 seconds, not every ``poke_interval``;
+                # confirm whether that is intentional.
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=AirbyteSyncTrigger(
+                        conn_id=self.airbyte_conn_id,
+                        job_id=self.airbyte_job_id,
+                        end_time=end_time,
+                        poll_interval=60,
+                    ),
+                    method_name="execute_complete",
+                )
+            elif state == hook.SUCCEEDED:
+                self.log.info("%s completed successfully.", self.task_id)
+                return
+            elif state == hook.ERROR:
+                raise AirflowException(f"Job failed:\n{job}")
+            elif state == hook.CANCELLED:
+                raise AirflowException(f"Job was cancelled:\n{job}")
+            else:
+                raise AirflowException(
+                    f"Encountered unexpected state `{state}` for job_id `{self.airbyte_job_id}`"
+                )
+
+    def execute_complete(self, context: Context, event: Any = None) -> None:
+        """
+        Callback for when the trigger fires - returns immediately.
+
+        The trigger reports the job outcome through ``event["status"]``; an ``error`` or
+        ``cancelled`` status fails the task, matching how ``execute`` treats those job states.
+
+        :param context: The task execution context.
+        :param event: Payload of the event emitted by the trigger.
+        :raises AirflowException: If no event was received or the job did not succeed.
+        """
+        if not event:
+            raise AirflowException("No event received from the trigger.")
+        if event["status"] in ("error", "cancelled"):
+            raise AirflowException(event["message"])
+
+        self.log.info("%s completed successfully.", self.task_id)
+        return None
diff --git a/airflow/providers/airbyte/triggers/__init__.py b/airflow/providers/airbyte/triggers/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/airbyte/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/airbyte/triggers/airbyte.py b/airflow/providers/airbyte/triggers/airbyte.py
new file mode 100644
index 0000000000000..06c926d6818b3
--- /dev/null
+++ b/airflow/providers/airbyte/triggers/airbyte.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import time
+from typing import Any, AsyncIterator
+
+from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AirbyteSyncTrigger(BaseTrigger):
+    """
+    Monitors an Airbyte Sync job by making asynchronous HTTP calls to get its status via a job ID.
+
+    This trigger does not submit a job itself; it polls the status of an existing job and
+    uses asynchronous communication to check the progress of the job run over time.
+
+    :param conn_id: The connection identifier for connecting to Airbyte.
+    :param job_id: The ID of an Airbyte Sync job.
+    :param end_time: Epoch timestamp (compared against ``time.time()``) after which a timeout is reported.
+    :param poll_interval: polling period in seconds to check for the status.
+    """
+
+    def __init__(
+        self,
+        job_id: int,
+        conn_id: str,
+        end_time: float,
+        poll_interval: float,
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.conn_id = conn_id
+        self.end_time = end_time
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize the AirbyteSyncTrigger classpath and constructor arguments."""
+        return (
+            "airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger",
+            {
+                "job_id": self.job_id,
+                "conn_id": self.conn_id,
+                "end_time": self.end_time,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Poll the job status until it is terminal or ``end_time`` has passed, then yield one event."""
+        hook = AirbyteHook(airbyte_conn_id=self.conn_id)
+        try:
+            while await self.is_still_running(hook):
+                if self.end_time < time.time():
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": f"Job run {self.job_id} has not reached a terminal status after "
+                            f"{self.end_time} seconds.",
+                            "job_id": self.job_id,
+                        }
+                    )
+                    # Stop here: without this the loop would keep polling and emit further events.
+                    return
+                await asyncio.sleep(self.poll_interval)
+            job_run_status = await hook.get_job_status(self.job_id)
+            if job_run_status == hook.SUCCEEDED:
+                yield TriggerEvent(
+                    {
+                        "status": "success",
+                        "message": f"Job run {self.job_id} has completed successfully.",
+                        "job_id": self.job_id,
+                    }
+                )
+            elif job_run_status == hook.CANCELLED:
+                yield TriggerEvent(
+                    {
+                        "status": "cancelled",
+                        "message": f"Job run {self.job_id} has been cancelled.",
+                        "job_id": self.job_id,
+                    }
+                )
+            else:
+                yield TriggerEvent(
+                    {
+                        "status": "error",
+                        "message": f"Job run {self.job_id} has failed.",
+                        "job_id": self.job_id,
+                    }
+                )
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e), "job_id": self.job_id})
+
+    async def is_still_running(self, hook: AirbyteHook) -> bool:
+        """
+        Check whether the job has not yet reached a terminal status.
+
+        Returns True while the status is RUNNING, PENDING or INCOMPLETE, otherwise False.
+        """
+        job_run_status = await hook.get_job_status(self.job_id)
+        return job_run_status in (AirbyteHook.RUNNING, AirbyteHook.PENDING, AirbyteHook.INCOMPLETE)
diff --git a/docs/apache-airflow-providers-airbyte/operators/airbyte.rst b/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
index 68fd8c44cb987..60f47955dd292 100644
--- a/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
+++ b/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
@@ -38,10 +38,9 @@ create in Airbyte between a source and destination synchronization job.
 Use the ``airbyte_conn_id`` parameter to specify the Airbyte connection to use to
 connect to your account.
 
-You can trigger a synchronization job in Airflow in two ways with the Operator. The first one
-is a synchronous process. This will trigger the Airbyte job and the Operator manage the status
-of the job. Another way is use the flag ``async = True`` so the Operator only trigger the job and
-return the ``job_id`` that should be pass to the AirbyteSensor.
+You can trigger a synchronization job in Airflow in two ways with the Operator. The first one is a synchronous process:
+the Operator initiates the Airbyte job and then manages the job status itself. Another way is to use the flag
+``asynchronous=True`` so the Operator only triggers the job and returns the ``job_id``, to be passed to the ``AirbyteJobSensor``.
An example using the synchronous way: diff --git a/tests/providers/airbyte/operators/test_airbyte.py b/tests/providers/airbyte/operators/test_airbyte.py index f8ecd15615c8d..2c0085f53db38 100644 --- a/tests/providers/airbyte/operators/test_airbyte.py +++ b/tests/providers/airbyte/operators/test_airbyte.py @@ -37,7 +37,7 @@ class TestAirbyteTriggerSyncOp: @mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.wait_for_job", return_value=None) def test_execute(self, mock_wait_for_job, mock_submit_sync_connection): mock_submit_sync_connection.return_value = mock.Mock( - **{"json.return_value": {"job": {"id": self.job_id}}} + **{"json.return_value": {"job": {"id": self.job_id, "status": "running"}}} ) op = AirbyteTriggerSyncOperator( diff --git a/tests/providers/airbyte/triggers/__init__.py b/tests/providers/airbyte/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/airbyte/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/providers/airbyte/triggers/test_airbyte.py b/tests/providers/airbyte/triggers/test_airbyte.py new file mode 100644 index 0000000000000..103df7cf00a66 --- /dev/null +++ b/tests/providers/airbyte/triggers/test_airbyte.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+import asyncio
+import time
+from unittest import mock
+
+import pytest
+
+from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.providers.airbyte.triggers.airbyte import AirbyteSyncTrigger
+from airflow.triggers.base import TriggerEvent
+
+
+class TestAirbyteSyncTrigger:
+    DAG_ID = "airbyte_sync_run"
+    TASK_ID = "airbyte_sync_run_task_op"
+    JOB_ID = 1234
+    CONN_ID = "airbyte_default"
+    END_TIME = time.time() + 60 * 60 * 24 * 7
+    POLL_INTERVAL = 3.0
+
+    def test_serialization(self):
+        """Assert AirbyteSyncTrigger correctly serializes its arguments and classpath."""
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL, end_time=self.END_TIME, job_id=self.JOB_ID
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == "airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger"
+        assert kwargs == {
+            "job_id": self.JOB_ID,
+            "conn_id": self.CONN_ID,
+            "end_time": self.END_TIME,
+            "poll_interval": self.POLL_INTERVAL,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    async def test_airbyte_run_sync_trigger(self, mocked_is_still_running):
+        """Assert AirbyteSyncTrigger yields no event while the job is reported as still running."""
+        mocked_is_still_running.return_value = True
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
+        )
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # TriggerEvent was not returned
+        assert task.done() is False
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_value, mock_status, mock_message",
+        [
+            (AirbyteHook.SUCCEEDED, "success", "Job run 1234 has completed successfully."),
+        ],
+    )
+    @mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    @mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_for_terminal_status_success(
+        self, mock_get_job_status, mocked_is_still_running, mock_value, mock_status, mock_message
+    ):
+        """Assert that run yields a success event when the job status is succeeded"""
+        mocked_is_still_running.return_value = False
+        mock_get_job_status.return_value = mock_value
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
+        )
+        expected_result = {
+            "status": mock_status,
+            "message": mock_message,
+            "job_id": self.JOB_ID,
+        }
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert TriggerEvent(expected_result) == task.result()
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_value, mock_status, mock_message",
+        [
+            (AirbyteHook.CANCELLED, "cancelled", "Job run 1234 has been cancelled."),
+        ],
+    )
+    @mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    @mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_for_terminal_status_cancelled(
+        self, mock_get_job_status, mocked_is_still_running, mock_value, mock_status, mock_message
+    ):
+        """Assert that run yields a cancelled event when the job status is cancelled"""
+        mocked_is_still_running.return_value = False
+        mock_get_job_status.return_value = mock_value
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL, end_time=self.END_TIME, job_id=self.JOB_ID
+        )
+        expected_result = {
+            "status": mock_status,
+            "message": mock_message,
+            "job_id": self.JOB_ID,
+        }
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert TriggerEvent(expected_result) == task.result()
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_value, mock_status, mock_message",
+        [
+            (AirbyteHook.ERROR, "error", "Job run 1234 has failed."),
+        ],
+    )
+    @mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    @mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_for_terminal_status_error(
+        self, mock_get_job_status, mocked_is_still_running, mock_value, mock_status, mock_message
+    ):
+        """Assert that run yields an error event when the job status is error"""
+        mocked_is_still_running.return_value = False
+        mock_get_job_status.return_value = mock_value
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
+        )
+        expected_result = {
+            "status": mock_status,
+            "message": mock_message,
+            "job_id": self.JOB_ID,
+        }
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert TriggerEvent(expected_result) == task.result()
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    @mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_exception(self, mock_get_job_status, mocked_is_still_running):
+        """Assert that run yields a single error event when the job status call raises"""
+        mocked_is_still_running.return_value = False
+        mock_get_job_status.side_effect = Exception("Test exception")
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
+        )
+        task = [i async for i in trigger.run()]
+        response = TriggerEvent(
+            {
+                "status": "error",
+                "message": "Test exception",
+                "job_id": self.JOB_ID,
+            }
+        )
+        assert len(task) == 1
+        assert response in task
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    @mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_timeout(self, mock_get_job_status, mocked_is_still_running):
+        """Assert that run yields a timeout error event once end_time has elapsed"""
+        mocked_is_still_running.return_value = True
+        mock_get_job_status.side_effect = Exception("Test exception")
+        end_time = time.time()
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=end_time,
+            job_id=self.JOB_ID,
+        )
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        expected = TriggerEvent(
+            {
+                "status": "error",
+                "message": f"Job run {self.JOB_ID} has not reached a terminal status "
+                f"after {end_time} seconds.",
+                "job_id": self.JOB_ID,
+            }
+        )
+        assert expected == actual
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_response, expected_status",
+        [
+            (AirbyteHook.SUCCEEDED, False),
+        ],
+    )
+    @mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_is_still_running_success(
+        self, mock_get_job_status, mock_response, expected_status
+    ):
+        """Test is_still_running with a mocked succeeded job status and assert
+        that it reports the job as no longer running"""
+        hook = mock.AsyncMock(AirbyteHook)
+        hook.get_job_status.return_value = mock_response
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
+        )
+        response = await trigger.is_still_running(hook)
+        assert response == expected_status
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_response, expected_status",
+        [
+            (AirbyteHook.RUNNING, True),
+        ],
+    )
+    @mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_sync_run_is_still_running(
+        self, mock_get_job_status, mock_response, expected_status
+    ):
+        """Test is_still_running with a mocked running job status and assert
+        that it reports the job as still running"""
+        airbyte_hook = mock.AsyncMock(AirbyteHook)
+        airbyte_hook.get_job_status.return_value = mock_response
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL, end_time=self.END_TIME, job_id=self.JOB_ID
+        )
+        response = await trigger.is_still_running(airbyte_hook)
+        assert response == expected_status