diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py
index 8410d8d7077cd..61e555f4caa78 100644
--- a/airflow/providers/microsoft/azure/hooks/msgraph.py
+++ b/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -96,6 +96,8 @@ class KiotaRequestAdapterHook(BaseHook):
     :param timeout: The HTTP timeout being used by the KiotaRequestAdapter (default is None).
         When no timeout is specified or set to None then no HTTP timeout is applied on each request.
     :param proxies: A Dict defining the HTTP proxies to be used (default is None).
+    :param host: The host to be used (default is "https://graph.microsoft.com").
+    :param scopes: The scopes to be used (default is ["https://graph.microsoft.com/.default"]).
     :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
         You can pass an enum named APIVersion which has 2 possible members v1 and beta,
         or you can pass a string as "v1.0" or "beta".
@@ -123,27 +125,22 @@ def __init__(
         self._api_version = self.resolve_api_version_from_value(api_version)
 
     @property
-    def api_version(self) -> APIVersion:
+    def api_version(self) -> str | None:
        self.get_conn()  # Make sure config has been loaded through get_conn to have correct api version!
        return self._api_version
 
     @staticmethod
     def resolve_api_version_from_value(
-        api_version: APIVersion | str, default: APIVersion | None = None
-    ) -> APIVersion:
+        api_version: APIVersion | str, default: str | None = None
+    ) -> str | None:
         if isinstance(api_version, APIVersion):
-            return api_version
-        return next(
-            filter(lambda version: version.value == api_version, APIVersion),
-            default,
-        )
+            return api_version.value
+        return api_version or default
 
-    def get_api_version(self, config: dict) -> APIVersion:
-        if self._api_version is None:
-            return self.resolve_api_version_from_value(
-                api_version=config.get("api_version"), default=APIVersion.v1
-            )
-        return self._api_version
+    def get_api_version(self, config: dict) -> str:
+        return self._api_version or self.resolve_api_version_from_value(
+            config.get("api_version"), APIVersion.v1.value
+        )  # type: ignore
 
     def get_host(self, connection: Connection) -> str:
         if connection.schema and connection.host:
@@ -169,15 +166,15 @@ def to_httpx_proxies(cls, proxies: dict) -> dict:
         return proxies
 
     def to_msal_proxies(self, authority: str | None, proxies: dict):
-        self.log.info("authority: %s", authority)
+        self.log.debug("authority: %s", authority)
         if authority:
             no_proxies = proxies.get("no")
-            self.log.info("no_proxies: %s", no_proxies)
+            self.log.debug("no_proxies: %s", no_proxies)
             if no_proxies:
                 for url in no_proxies.split(","):
                     self.log.info("url: %s", url)
                     domain_name = urlparse(url).path.replace("*", "")
-                    self.log.info("domain_name: %s", domain_name)
+                    self.log.debug("domain_name: %s", domain_name)
                     if authority.endswith(domain_name):
                         return None
         return proxies
@@ -193,10 +190,10 @@ def get_conn(self) -> RequestAdapter:
             client_id = connection.login
             client_secret = connection.password
             config = connection.extra_dejson if connection.extra else {}
-            tenant_id = config.get("tenant_id")
+            tenant_id = config.get("tenant_id") or config.get("tenantId")
             api_version = self.get_api_version(config)
             host = self.get_host(connection)
-            base_url = config.get("base_url", urljoin(host, api_version.value))
+            base_url = config.get("base_url", urljoin(host, api_version))
             authority = config.get("authority")
             proxies = self.proxies or config.get("proxies", {})
             msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
@@ -209,7 +206,7 @@ def get_conn(self) -> RequestAdapter:
 
             self.log.info(
                 "Creating Microsoft Graph SDK client %s for conn_id: %s",
-                api_version.value,
+                api_version,
                 self.conn_id,
             )
             self.log.info("Host: %s", host)
@@ -217,7 +214,7 @@ def get_conn(self) -> RequestAdapter:
             self.log.info("Tenant id: %s", tenant_id)
             self.log.info("Client id: %s", client_id)
             self.log.info("Client secret: %s", client_secret)
-            self.log.info("API version: %s", api_version.value)
+            self.log.info("API version: %s", api_version)
             self.log.info("Scope: %s", scopes)
             self.log.info("Verify: %s", verify)
             self.log.info("Timeout: %s", self.timeout)
@@ -238,17 +235,17 @@ def get_conn(self) -> RequestAdapter:
                 connection_verify=verify,
             )
             http_client = GraphClientFactory.create_with_default_middleware(
-                api_version=api_version,
+                api_version=api_version,  # type: ignore
                 client=httpx.AsyncClient(
                     proxies=httpx_proxies,
                     timeout=Timeout(timeout=self.timeout),
                     verify=verify,
                     trust_env=trust_env,
                 ),
-                host=host,
+                host=host,  # type: ignore
             )
             auth_provider = AzureIdentityAuthenticationProvider(
-                credentials=credentials,
+                credentials=credentials,  # type: ignore
                 scopes=scopes,
                 allowed_hosts=allowed_hosts,
             )
@@ -295,7 +292,7 @@ async def run(
             error_map=self.error_mapping(),
         )
 
-        self.log.info("response: %s", response)
+        self.log.debug("response: %s", response)
 
         return response
 
diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py
index cd387954737ab..74409f3600a1e 100644
--- a/airflow/providers/microsoft/azure/operators/msgraph.py
+++ b/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -99,7 +99,7 @@ def __init__(
         key: str = XCOM_RETURN_KEY,
         timeout: float | None = None,
         proxies: dict | None = None,
-        api_version: APIVersion | None = None,
+        api_version: APIVersion | str | None = None,
         pagination_function: Callable[[MSGraphAsyncOperator, dict], tuple[str, dict]] | None = None,
         result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
         serializer: type[ResponseSerializer] = ResponseSerializer,
diff --git a/airflow/providers/microsoft/azure/operators/powerbi.py b/airflow/providers/microsoft/azure/operators/powerbi.py
index e54ad250bde74..fc812e852d90b 100644
--- a/airflow/providers/microsoft/azure/operators/powerbi.py
+++ b/airflow/providers/microsoft/azure/operators/powerbi.py
@@ -76,7 +76,7 @@ def __init__(
         conn_id: str = PowerBIHook.default_conn_name,
         timeout: float = 60 * 60 * 24 * 7,
         proxies: dict | None = None,
-        api_version: APIVersion | None = None,
+        api_version: APIVersion | str | None = None,
         check_interval: int = 60,
         **kwargs,
     ) -> None:
@@ -89,6 +89,14 @@ def __init__(
         self.timeout = timeout
         self.check_interval = check_interval
 
+    @property
+    def proxies(self) -> dict | None:
+        return self.hook.proxies
+
+    @property
+    def api_version(self) -> str | None:
+        return self.hook.api_version
+
     def execute(self, context: Context):
         """Refresh the Power BI Dataset."""
         if self.wait_for_termination:
@@ -98,6 +106,8 @@ def execute(self, context: Context):
                     group_id=self.group_id,
                     dataset_id=self.dataset_id,
                     timeout=self.timeout,
+                    proxies=self.proxies,
+                    api_version=self.api_version,
                     check_interval=self.check_interval,
                     wait_for_termination=self.wait_for_termination,
                 ),
diff --git a/airflow/providers/microsoft/azure/sensors/msgraph.py b/airflow/providers/microsoft/azure/sensors/msgraph.py
index 3e1b10cbeb1e6..6736ea59c918d 100644
--- a/airflow/providers/microsoft/azure/sensors/msgraph.py
+++ b/airflow/providers/microsoft/azure/sensors/msgraph.py
@@ -82,7 +82,7 @@ def __init__(
         data: dict[str, Any] | str | BytesIO | None = None,
         conn_id: str = KiotaRequestAdapterHook.default_conn_name,
         proxies: dict | None = None,
-        api_version: APIVersion | None = None,
+        api_version: APIVersion | str | None = None,
         event_processor: Callable[[Context, Any], bool] = lambda context, e: e.get("status") == "Succeeded",
         result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
         serializer: type[ResponseSerializer] = ResponseSerializer,
diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py b/airflow/providers/microsoft/azure/triggers/msgraph.py
index 4b9ccb7a71716..0015964be86dd 100644
--- a/airflow/providers/microsoft/azure/triggers/msgraph.py
+++ b/airflow/providers/microsoft/azure/triggers/msgraph.py
@@ -122,7 +122,7 @@ def __init__(
         conn_id: str = KiotaRequestAdapterHook.default_conn_name,
         timeout: float | None = None,
         proxies: dict | None = None,
-        api_version: APIVersion | None = None,
+        api_version: APIVersion | str | None = None,
         serializer: type[ResponseSerializer] = ResponseSerializer,
     ):
         super().__init__()
@@ -152,14 +152,13 @@ def resolve_type(cls, value: str | type, default) -> type:
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize the HttpTrigger arguments and classpath."""
-        api_version = self.api_version.value if self.api_version else None
         return (
             f"{self.__class__.__module__}.{self.__class__.__name__}",
             {
                 "conn_id": self.conn_id,
                 "timeout": self.timeout,
                 "proxies": self.proxies,
-                "api_version": api_version,
+                "api_version": self.api_version,
                 "serializer": f"{self.serializer.__class__.__module__}.{self.serializer.__class__.__name__}",
                 "url": self.url,
                 "path_parameters": self.path_parameters,
@@ -188,7 +187,7 @@ def proxies(self) -> dict | None:
         return self.hook.proxies
 
     @property
-    def api_version(self) -> APIVersion:
+    def api_version(self) -> APIVersion | str:
         return self.hook.api_version
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
diff --git a/airflow/providers/microsoft/azure/triggers/powerbi.py b/airflow/providers/microsoft/azure/triggers/powerbi.py
index d25802b84fb74..a74898f55f28a 100644
--- a/airflow/providers/microsoft/azure/triggers/powerbi.py
+++ b/airflow/providers/microsoft/azure/triggers/powerbi.py
@@ -58,7 +58,7 @@ def __init__(
         group_id: str,
         timeout: float = 60 * 60 * 24 * 7,
         proxies: dict | None = None,
-        api_version: APIVersion | None = None,
+        api_version: APIVersion | str | None = None,
         check_interval: int = 60,
         wait_for_termination: bool = True,
     ):
@@ -72,13 +72,12 @@ def __init__(
 
     def serialize(self):
         """Serialize the trigger instance."""
-        api_version = self.api_version.value if self.api_version else None
         return (
             "airflow.providers.microsoft.azure.triggers.powerbi.PowerBITrigger",
             {
                 "conn_id": self.conn_id,
                 "proxies": self.proxies,
-                "api_version": api_version,
+                "api_version": self.api_version,
                 "dataset_id": self.dataset_id,
                 "group_id": self.group_id,
                 "timeout": self.timeout,
@@ -96,7 +95,7 @@ def proxies(self) -> dict | None:
         return self.hook.proxies
 
     @property
-    def api_version(self) -> APIVersion:
+    def api_version(self) -> APIVersion | str:
         return self.hook.api_version
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
index 342bf542762ac..56a9259f93143 100644
--- a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
@@ -72,6 +72,14 @@ Below is an example of using this operator to refresh PowerBI dataset.
     :start-after: [START howto_operator_powerbi_refresh_dataset]
     :end-before: [END howto_operator_powerbi_refresh_dataset]
 
+Below is an example of using this operator to create an item schedule in Fabric.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_msfabric.py
+    :language: python
+    :dedent: 0
+    :start-after: [START howto_operator_ms_fabric_create_item_schedule]
+    :end-before: [END howto_operator_ms_fabric_create_item_schedule]
+
 Reference
 ---------
 
@@ -80,3 +88,4 @@ For further information, look at:
 
 * `Use the Microsoft Graph API <https://learn.microsoft.com/en-us/graph/use-the-api>`__
 * `Using the Power BI REST APIs <https://learn.microsoft.com/en-us/rest/api/power-bi/>`__
+* `Using the Fabric REST APIs <https://learn.microsoft.com/en-us/rest/api/fabric/articles/using-fabric-apis>`__
diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py
index 390be17ba7f35..04e85525616bc 100644
--- a/tests/providers/microsoft/azure/hooks/test_msgraph.py
+++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py
@@ -18,6 +18,7 @@
 
 import asyncio
 from json import JSONDecodeError
+from typing import TYPE_CHECKING
 from unittest.mock import patch
 
 import pytest
@@ -38,8 +39,20 @@
     mock_response,
 )
 
+if TYPE_CHECKING:
+    from kiota_abstractions.request_adapter import RequestAdapter
+
 
 class TestKiotaRequestAdapterHook:
+    def setup_method(self):
+        KiotaRequestAdapterHook.cached_request_adapters.clear()
+
+    @staticmethod
+    def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id: str):
+        assert isinstance(request_adapter, HttpxRequestAdapter)
+        tenant_id = request_adapter._authentication_provider.access_token_provider._credentials._tenant_id
+        assert tenant_id == expected_tenant_id
+
     def test_get_conn(self):
         with patch(
             "airflow.hooks.base.BaseHook.get_connection",
@@ -51,6 +64,23 @@ def test_get_conn(self):
             assert isinstance(actual, HttpxRequestAdapter)
             assert actual.base_url == "https://graph.microsoft.com/v1.0"
 
+    def test_get_conn_with_custom_base_url(self):
+        connection = lambda conn_id: get_airflow_connection(
+            conn_id=conn_id,
+            host="api.fabric.microsoft.com",
+            api_version="v1",
+        )
+
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+            actual = hook.get_conn()
+
+            assert isinstance(actual, HttpxRequestAdapter)
+            assert actual.base_url == "https://api.fabric.microsoft.com/v1"
+
     def test_api_version(self):
         with patch(
             "airflow.hooks.base.BaseHook.get_connection",
@@ -58,7 +88,7 @@ def test_api_version(self):
         ):
             hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
 
-            assert hook.api_version == APIVersion.v1
+            assert hook.api_version == APIVersion.v1.value
 
     def test_get_api_version_when_empty_config_dict(self):
         with patch(
@@ -68,7 +98,7 @@ def test_get_api_version_when_empty_config_dict(self):
             hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
             actual = hook.get_api_version({})
 
-            assert actual == APIVersion.v1
+            assert actual == APIVersion.v1.value
 
     def test_get_api_version_when_api_version_in_config_dict(self):
         with patch(
@@ -78,21 +108,64 @@ def test_get_api_version_when_api_version_in_config_dict(self):
             hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
             actual = hook.get_api_version({"api_version": "beta"})
 
-            assert actual == APIVersion.beta
+            assert actual == APIVersion.beta.value
+
+    def test_get_api_version_when_custom_api_version_in_config_dict(self):
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=get_airflow_connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api", api_version="v1")
+            actual = hook.get_api_version({})
+
+            assert actual == "v1"
 
     def test_get_host_when_connection_has_scheme_and_host(self):
-        connection = mock_connection(schema="https", host="graph.microsoft.de")
-        hook = KiotaRequestAdapterHook()
-        actual = hook.get_host(connection)
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=get_airflow_connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+            connection = mock_connection(schema="https", host="graph.microsoft.de")
+            actual = hook.get_host(connection)
 
-        assert actual == NationalClouds.Germany.value
+            assert actual == NationalClouds.Germany.value
 
     def test_get_host_when_connection_has_no_scheme_or_host(self):
-        connection = mock_connection()
-        hook = KiotaRequestAdapterHook()
-        actual = hook.get_host(connection)
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=get_airflow_connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+            connection = mock_connection()
+            actual = hook.get_host(connection)
+
+            assert actual == NationalClouds.Global.value
+
+    def test_tenant_id(self):
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=get_airflow_connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+            actual = hook.get_conn()
+
+            self.assert_tenant_id(actual, "tenant-id")
+
+    def test_azure_tenant_id(self):
+        airflow_connection = lambda conn_id: get_airflow_connection(
+            conn_id=conn_id,
+            azure_tenant_id="azure-tenant-id",
+        )
+
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=airflow_connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+            actual = hook.get_conn()
 
-        assert actual == NationalClouds.Global.value
+            self.assert_tenant_id(actual, "azure-tenant-id")
 
     def test_encoded_query_parameters(self):
         actual = KiotaRequestAdapterHook.encoded_query_parameters(
diff --git a/tests/providers/microsoft/azure/hooks/test_powerbi.py b/tests/providers/microsoft/azure/hooks/test_powerbi.py
index a3a521b45e820..f22116f6efb03 100644
--- a/tests/providers/microsoft/azure/hooks/test_powerbi.py
+++ b/tests/providers/microsoft/azure/hooks/test_powerbi.py
@@ -54,176 +54,162 @@
 GROUP_ID = "group_id"
 DATASET_ID = "dataset_id"
-CONFIG = {"conn_id": DEFAULT_CONNECTION_CLIENT_SECRET, "timeout": 3, "api_version": "v1.0"}
-
-
-@pytest.fixture
-def powerbi_hook():
-    return PowerBIHook(**CONFIG)
-
-
-@pytest.mark.asyncio
-async def test_get_refresh_history(powerbi_hook):
-    response_data = {"value": [{"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}]}
-
-    with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
-        mock_run.return_value = response_data
-        result = await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID)
-
-    expected = [{"request_id": "1234", "status": "Completed", "error": ""}]
-    assert result == expected
-
-
-@pytest.mark.asyncio
-async def test_get_refresh_history_airflow_exception(powerbi_hook):
-    """Test handling of AirflowException in get_refresh_history."""
-
-    with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
-        mock_run.side_effect = AirflowException("Test exception")
-
-        with pytest.raises(PowerBIDatasetRefreshException, match="Failed to retrieve refresh history"):
-            await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID)
-
-
-@pytest.mark.parametrize(
"input_data, expected_output", - [ - ( - {"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}, - { - PowerBIDatasetRefreshFields.REQUEST_ID.value: "1234", - PowerBIDatasetRefreshFields.STATUS.value: "Completed", - PowerBIDatasetRefreshFields.ERROR.value: "", - }, - ), - ( - {"requestId": "5678", "status": "Unknown", "serviceExceptionJson": "Some error"}, - { - PowerBIDatasetRefreshFields.REQUEST_ID.value: "5678", - PowerBIDatasetRefreshFields.STATUS.value: "In Progress", - PowerBIDatasetRefreshFields.ERROR.value: "Some error", - }, - ), - ( - {"requestId": None, "status": None, "serviceExceptionJson": None}, - { - PowerBIDatasetRefreshFields.REQUEST_ID.value: "None", - PowerBIDatasetRefreshFields.STATUS.value: "None", - PowerBIDatasetRefreshFields.ERROR.value: "None", - }, - ), - ( - {}, # Empty input dictionary - { - PowerBIDatasetRefreshFields.REQUEST_ID.value: "None", - PowerBIDatasetRefreshFields.STATUS.value: "None", - PowerBIDatasetRefreshFields.ERROR.value: "None", - }, - ), - ], -) -def test_raw_to_refresh_details(input_data, expected_output): - """Test raw_to_refresh_details method.""" - result = PowerBIHook.raw_to_refresh_details(input_data) - assert result == expected_output - - -@pytest.mark.asyncio -async def test_get_refresh_details_by_refresh_id(powerbi_hook): - # Mock the get_refresh_history method to return a list of refresh histories - refresh_histories = FORMATTED_RESPONSE - powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=refresh_histories) - - # Call the function with a valid request ID - refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c" - result = await powerbi_hook.get_refresh_details_by_refresh_id( - dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id +class TestPowerBIHook: + @pytest.mark.asyncio + async def test_get_refresh_history(self, powerbi_hook): + response_data = {"value": [{"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}]} + + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + mock_run.return_value = response_data + result = await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID) + + expected = [{"request_id": "1234", "status": "Completed", "error": ""}] + assert result == expected + + @pytest.mark.asyncio + async def test_get_refresh_history_airflow_exception(self, powerbi_hook): + """Test handling of AirflowException in get_refresh_history.""" + + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + mock_run.side_effect = AirflowException("Test exception") + + with pytest.raises(PowerBIDatasetRefreshException, match="Failed to retrieve refresh history"): + await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID) + + @pytest.mark.parametrize( + "input_data, expected_output", + [ + ( + {"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}, + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "1234", + PowerBIDatasetRefreshFields.STATUS.value: "Completed", + PowerBIDatasetRefreshFields.ERROR.value: "", + }, + ), + ( + {"requestId": "5678", "status": "Unknown", "serviceExceptionJson": "Some error"}, + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "5678", + PowerBIDatasetRefreshFields.STATUS.value: "In Progress", + PowerBIDatasetRefreshFields.ERROR.value: "Some error", + }, + ), + ( + {"requestId": None, "status": None, "serviceExceptionJson": None}, + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "None", + PowerBIDatasetRefreshFields.STATUS.value: "None", + 
-
-
-@pytest.mark.asyncio
-async def test_get_refresh_details_by_refresh_id_empty_history(powerbi_hook):
-    """Test exception when refresh history is empty."""
-    # Mock the get_refresh_history method to return an empty list
-    powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=[])
-
-    # Call the function with a request ID
-    refresh_id = "any_request_id"
-    with pytest.raises(
-        PowerBIDatasetRefreshException,
-        match=f"Unable to fetch the details of dataset refresh with Request Id: {refresh_id}",
-    ):
-        await powerbi_hook.get_refresh_details_by_refresh_id(
-            dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id
-        )
-
-
-@pytest.mark.asyncio
-async def test_get_refresh_details_by_refresh_id_not_found(powerbi_hook):
-    """Test exception when the refresh ID is not found in the refresh history."""
-    # Mock the get_refresh_history method to return a list of refresh histories without the specified ID
-    powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=FORMATTED_RESPONSE)
-
-    # Call the function with an invalid request ID
-    invalid_request_id = "invalid_request_id"
-    with pytest.raises(
-        PowerBIDatasetRefreshException,
-        match=f"Unable to fetch the details of dataset refresh with Request Id: {invalid_request_id}",
-    ):
-        await powerbi_hook.get_refresh_details_by_refresh_id(
-            dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id
-        )
-
-
-@pytest.mark.asyncio
-async def test_trigger_dataset_refresh_success(powerbi_hook):
-    response_data = {"requestid": "5e2d9921-e91b-491f-b7e1-e7d8db49194c"}
-
-    with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
-        mock_run.return_value = response_data
-        result = await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID)
-
-    assert result == "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
-
-
-@pytest.mark.asyncio
-async def test_trigger_dataset_refresh_failure(powerbi_hook):
-    """Test failure to trigger dataset refresh due to AirflowException."""
-    with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
-        mock_run.side_effect = AirflowException("Test exception")
-
-        with pytest.raises(PowerBIDatasetRefreshException, match="Failed to trigger dataset refresh."):
-            await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID)
-
-
-@pytest.mark.asyncio
-async def test_cancel_dataset_refresh(powerbi_hook):
-    dataset_refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
-
-    with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
-        await powerbi_hook.cancel_dataset_refresh(DATASET_ID, GROUP_ID, dataset_refresh_id)
-
-    mock_run.assert_called_once_with(
-        url="myorg/groups/{group_id}/datasets/{dataset_id}/refreshes/{dataset_refresh_id}",
-        response_type=None,
-        path_parameters={
-            "group_id": GROUP_ID,
-            "dataset_id": DATASET_ID,
-            "dataset_refresh_id": dataset_refresh_id,
-        },
-        method="DELETE",
-    )
+class TestPowerBIHook:
+    @pytest.mark.asyncio
+    async def test_get_refresh_history(self, powerbi_hook):
+        response_data = {"value": [{"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}]}
+
+        with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
+            mock_run.return_value = response_data
+            result = await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID)
+
+        expected = [{"request_id": "1234", "status": "Completed", "error": ""}]
+        assert result == expected
+
+    @pytest.mark.asyncio
+    async def test_get_refresh_history_airflow_exception(self, powerbi_hook):
+        """Test handling of AirflowException in get_refresh_history."""
+
+        with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
+            mock_run.side_effect = AirflowException("Test exception")
+
+            with pytest.raises(PowerBIDatasetRefreshException, match="Failed to retrieve refresh history"):
+                await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID)
+
+    @pytest.mark.parametrize(
+        "input_data, expected_output",
+        [
+            (
+                {"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""},
+                {
+                    PowerBIDatasetRefreshFields.REQUEST_ID.value: "1234",
+                    PowerBIDatasetRefreshFields.STATUS.value: "Completed",
+                    PowerBIDatasetRefreshFields.ERROR.value: "",
+                },
+            ),
+            (
+                {"requestId": "5678", "status": "Unknown", "serviceExceptionJson": "Some error"},
+                {
+                    PowerBIDatasetRefreshFields.REQUEST_ID.value: "5678",
+                    PowerBIDatasetRefreshFields.STATUS.value: "In Progress",
+                    PowerBIDatasetRefreshFields.ERROR.value: "Some error",
+                },
+            ),
+            (
+                {"requestId": None, "status": None, "serviceExceptionJson": None},
+                {
+                    PowerBIDatasetRefreshFields.REQUEST_ID.value: "None",
+                    PowerBIDatasetRefreshFields.STATUS.value: "None",
+                    PowerBIDatasetRefreshFields.ERROR.value: "None",
+                },
+            ),
+            (
+                {},  # Empty input dictionary
+                {
+                    PowerBIDatasetRefreshFields.REQUEST_ID.value: "None",
+                    PowerBIDatasetRefreshFields.STATUS.value: "None",
+                    PowerBIDatasetRefreshFields.ERROR.value: "None",
+                },
+            ),
+        ],
+    )
+    def test_raw_to_refresh_details(self, input_data, expected_output):
+        """Test raw_to_refresh_details method."""
+        result = PowerBIHook.raw_to_refresh_details(input_data)
+        assert result == expected_output
+
+    @pytest.mark.asyncio
+    async def test_get_refresh_details_by_refresh_id(self, powerbi_hook):
+        # Mock the get_refresh_history method to return a list of refresh histories
+        refresh_histories = FORMATTED_RESPONSE
+        powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=refresh_histories)
+
+        # Call the function with a valid request ID
+        refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
+        result = await powerbi_hook.get_refresh_details_by_refresh_id(
+            dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id
+        )
+
+        # Assert that the correct refresh details are returned
+        assert result == {
+            PowerBIDatasetRefreshFields.REQUEST_ID.value: "5e2d9921-e91b-491f-b7e1-e7d8db49194c",
+            PowerBIDatasetRefreshFields.STATUS.value: "Completed",
+            PowerBIDatasetRefreshFields.ERROR.value: "None",
+        }
+
+        # Call the function with an invalid request ID
+        invalid_request_id = "invalid_request_id"
+        with pytest.raises(PowerBIDatasetRefreshException):
+            await powerbi_hook.get_refresh_details_by_refresh_id(
+                dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id
+            )
+    @pytest.mark.asyncio
+    async def test_get_refresh_details_by_refresh_id_empty_history(self, powerbi_hook):
+        """Test exception when refresh history is empty."""
+        # Mock the get_refresh_history method to return an empty list
+        powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=[])
+
+        # Call the function with a request ID
+        refresh_id = "any_request_id"
+        with pytest.raises(
+            PowerBIDatasetRefreshException,
+            match=f"Unable to fetch the details of dataset refresh with Request Id: {refresh_id}",
+        ):
+            await powerbi_hook.get_refresh_details_by_refresh_id(
+                dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id
+            )
+
+    @pytest.mark.asyncio
+    async def test_get_refresh_details_by_refresh_id_not_found(self, powerbi_hook):
+        """Test exception when the refresh ID is not found in the refresh history."""
+        # Mock the get_refresh_history method to return a list of refresh histories without the specified ID
+        powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=FORMATTED_RESPONSE)
+
+        # Call the function with an invalid request ID
+        invalid_request_id = "invalid_request_id"
+        with pytest.raises(
+            PowerBIDatasetRefreshException,
+            match=f"Unable to fetch the details of dataset refresh with Request Id: {invalid_request_id}",
+        ):
+            await powerbi_hook.get_refresh_details_by_refresh_id(
+                dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id
+            )
+
+    @pytest.mark.asyncio
+    async def test_trigger_dataset_refresh_success(self, powerbi_hook):
+        response_data = {"requestid": "5e2d9921-e91b-491f-b7e1-e7d8db49194c"}
+
+        with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
+            mock_run.return_value = response_data
+            result = await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID)
+
+        assert result == "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
+
+    @pytest.mark.asyncio
+    async def test_trigger_dataset_refresh_failure(self, powerbi_hook):
+        """Test failure to trigger dataset refresh due to AirflowException."""
+        with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
+            mock_run.side_effect = AirflowException("Test exception")
+
+            with pytest.raises(PowerBIDatasetRefreshException, match="Failed to trigger dataset refresh."):
+                await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID)
+
+    @pytest.mark.asyncio
+    async def test_cancel_dataset_refresh(self, powerbi_hook):
+        dataset_refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
+
+        with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
+            await powerbi_hook.cancel_dataset_refresh(DATASET_ID, GROUP_ID, dataset_refresh_id)
+
+        mock_run.assert_called_once_with(
+            url="myorg/groups/{group_id}/datasets/{dataset_id}/refreshes/{dataset_refresh_id}",
+            response_type=None,
+            path_parameters={
+                "group_id": GROUP_ID,
+                "dataset_id": DATASET_ID,
+                "dataset_refresh_id": dataset_refresh_id,
+            },
+            method="DELETE",
+        )
diff --git a/tests/providers/microsoft/azure/operators/test_powerbi.py b/tests/providers/microsoft/azure/operators/test_powerbi.py
index 2ee5ee723d7a7..35bb76f782ce3 100644
--- a/tests/providers/microsoft/azure/operators/test_powerbi.py
+++ b/tests/providers/microsoft/azure/operators/test_powerbi.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+from unittest import mock
 from unittest.mock import MagicMock
 
 import pytest
@@ -25,11 +26,12 @@
 from airflow.providers.microsoft.azure.hooks.powerbi import (
     PowerBIDatasetRefreshFields,
     PowerBIDatasetRefreshStatus,
-    PowerBIHook,
 )
 from airflow.providers.microsoft.azure.operators.powerbi import PowerBIDatasetRefreshOperator
 from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger
 from airflow.utils import timezone
+from tests.providers.microsoft.azure.base import Base
+from tests.providers.microsoft.conftest import get_airflow_connection, mock_context
 
 DEFAULT_CONNECTION_CLIENT_SECRET = "powerbi_conn_id"
 TASK_ID = "run_powerbi_operator"
@@ -72,86 +74,77 @@
 }
 
 
-@pytest.fixture
-def mock_powerbi_hook():
-    hook = PowerBIHook()
-    return hook
-
-
-def test_execute_wait_for_termination_with_Deferrable(mock_powerbi_hook):
-    operator = PowerBIDatasetRefreshOperator(
-        **CONFIG,
-    )
-    operator.hook = mock_powerbi_hook
-    context = {"ti": MagicMock()}
-
-    with pytest.raises(TaskDeferred) as exc:
-        operator.execute(context)
-
-    assert isinstance(exc.value.trigger, PowerBITrigger)
+class TestPowerBIDatasetRefreshOperator(Base):
+    @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection)
+    def test_execute_wait_for_termination_with_deferrable(self, connection):
+        operator = PowerBIDatasetRefreshOperator(
+            **CONFIG,
+        )
+        context = mock_context(task=operator)
 
+        with pytest.raises(TaskDeferred) as exc:
+            operator.execute(context)
 
-def test_powerbi_operator_async_execute_complete_success():
-    """Assert that execute_complete log success message"""
-    operator = PowerBIDatasetRefreshOperator(
-        **CONFIG,
-    )
-    context = {"ti": MagicMock()}
-    operator.execute_complete(
-        context=context,
-        event=SUCCESS_TRIGGER_EVENT,
-    )
-    assert context["ti"].xcom_push.call_count == 2
+        assert isinstance(exc.value.trigger, PowerBITrigger)
 
+    def test_powerbi_operator_async_execute_complete_success(self):
+        """Assert that execute_complete logs a success message."""
+        operator = PowerBIDatasetRefreshOperator(
+            **CONFIG,
+        )
+        context = {"ti": MagicMock()}
+        operator.execute_complete(
+            context=context,
+            event=SUCCESS_TRIGGER_EVENT,
+        )
+ assert context["ti"].xcom_push.call_count == 2 -def test_powerbi_operator_async_execute_complete_fail(): - """Assert that execute_complete raise exception on error""" - operator = PowerBIDatasetRefreshOperator( - **CONFIG, - ) - context = {"ti": MagicMock()} - with pytest.raises(AirflowException): + def test_powerbi_operator_async_execute_complete_fail(self): + """Assert that execute_complete raise exception on error""" + operator = PowerBIDatasetRefreshOperator( + **CONFIG, + ) + context = {"ti": MagicMock()} + with pytest.raises(AirflowException): + operator.execute_complete( + context=context, + event={"status": "error", "message": "error", "dataset_refresh_id": "1234"}, + ) + assert context["ti"].xcom_push.call_count == 0 + + def test_execute_complete_no_event(self): + """Test execute_complete when event is None or empty.""" + operator = PowerBIDatasetRefreshOperator( + **CONFIG, + ) + context = {"ti": MagicMock()} operator.execute_complete( context=context, - event={"status": "error", "message": "error", "dataset_refresh_id": "1234"}, + event=None, + ) + assert context["ti"].xcom_push.call_count == 0 + + @pytest.mark.db_test + def test_powerbi_link(self, create_task_instance_of_operator): + """Assert Power BI Extra link matches the expected URL.""" + ti = create_task_instance_of_operator( + PowerBIDatasetRefreshOperator, + dag_id="test_powerbi_refresh_op_link", + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + group_id=GROUP_ID, + dataset_id=DATASET_ID, + check_interval=1, + timeout=3, ) - assert context["ti"].xcom_push.call_count == 0 - - -def test_execute_complete_no_event(): - """Test execute_complete when event is None or empty.""" - operator = PowerBIDatasetRefreshOperator( - **CONFIG, - ) - context = {"ti": MagicMock()} - operator.execute_complete( - context=context, - event=None, - ) - assert context["ti"].xcom_push.call_count == 0 - - -@pytest.mark.db_test -def test_powerbilink(create_task_instance_of_operator): - """Assert Power BI Extra link matches the expected URL.""" - ti = create_task_instance_of_operator( - PowerBIDatasetRefreshOperator, - dag_id="test_powerbi_refresh_op_link", - execution_date=DEFAULT_DATE, - task_id=TASK_ID, - conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, - group_id=GROUP_ID, - dataset_id=DATASET_ID, - check_interval=1, - timeout=3, - ) - - ti.xcom_push(key="powerbi_dataset_refresh_id", value=NEW_REFRESH_REQUEST_ID) - url = ti.task.get_extra_links(ti, "Monitor PowerBI Dataset") - EXPECTED_ITEM_RUN_OP_EXTRA_LINK = ( - "https://app.powerbi.com" # type: ignore[attr-defined] - f"/groups/{GROUP_ID}/datasets/{DATASET_ID}" # type: ignore[attr-defined] - "/details?experience=power-bi" - ) - - assert url == EXPECTED_ITEM_RUN_OP_EXTRA_LINK + + ti.xcom_push(key="powerbi_dataset_refresh_id", value=NEW_REFRESH_REQUEST_ID) + url = ti.task.get_extra_links(ti, "Monitor PowerBI Dataset") + EXPECTED_ITEM_RUN_OP_EXTRA_LINK = ( + "https://app.powerbi.com" # type: ignore[attr-defined] + f"/groups/{GROUP_ID}/datasets/{DATASET_ID}" # type: ignore[attr-defined] + "/details?experience=power-bi" + ) + + assert url == EXPECTED_ITEM_RUN_OP_EXTRA_LINK diff --git a/tests/providers/microsoft/azure/triggers/test_powerbi.py b/tests/providers/microsoft/azure/triggers/test_powerbi.py index 5b44a84149501..c3276e258b3da 100644 --- a/tests/providers/microsoft/azure/triggers/test_powerbi.py +++ b/tests/providers/microsoft/azure/triggers/test_powerbi.py @@ -19,11 +19,10 @@ import asyncio from unittest import mock -from unittest.mock import patch 
 
 import pytest
 
-from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus, PowerBIHook
+from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus
 from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger
 from airflow.triggers.base import TriggerEvent
 from tests.providers.microsoft.conftest import get_airflow_connection
@@ -54,19 +53,10 @@ def powerbi_trigger():
     return trigger
 
 
-@pytest.fixture
-def mock_powerbi_hook():
-    hook = PowerBIHook()
-    return hook
-
-
-def test_powerbi_trigger_serialization():
-    """Asserts that the PowerBI Trigger correctly serializes its arguments and classpath."""
-
-    with patch(
-        "airflow.hooks.base.BaseHook.get_connection",
-        side_effect=get_airflow_connection,
-    ):
+class TestPowerBITrigger:
+    @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection)
+    def test_powerbi_trigger_serialization(self, connection):
+        """Asserts that the PowerBI Trigger correctly serializes its arguments and classpath."""
         powerbi_trigger = PowerBITrigger(
             conn_id=POWERBI_CONN_ID,
             proxies=None,
@@ -91,167 +81,170 @@ def test_powerbi_trigger_serialization():
             "wait_for_termination": True,
         }
 
-
-@pytest.mark.asyncio
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
-async def test_powerbi_trigger_run_inprogress(
-    mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
-):
-    """Assert task isn't completed until timeout if dataset refresh is in progress."""
-    mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS}
-    mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
-    task = asyncio.create_task(powerbi_trigger.run().__anext__())
-    await asyncio.sleep(0.5)
-
-    # Assert TriggerEvent was not returned
-    assert task.done() is False
-    asyncio.get_event_loop().stop()
-
-
-@pytest.mark.asyncio
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
-async def test_powerbi_trigger_run_failed(
-    mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
-):
-    """Assert event is triggered upon failed dataset refresh."""
-    mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.FAILED}
-    mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
-
-    generator = powerbi_trigger.run()
-    actual = await generator.asend(None)
-    expected = TriggerEvent(
-        {
-            "status": "Failed",
-            "message": f"The dataset refresh {DATASET_REFRESH_ID} has "
-            f"{PowerBIDatasetRefreshStatus.FAILED}.",
-            "dataset_refresh_id": DATASET_REFRESH_ID,
-        }
-    )
-    assert expected == actual
-
-
-@pytest.mark.asyncio
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
-async def test_powerbi_trigger_run_completed(
-    mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
-):
-    """Assert event is triggered upon successful dataset refresh."""
-    mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.COMPLETED}
-    mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
-
-    generator = powerbi_trigger.run()
-    actual = await generator.asend(None)
-    expected = TriggerEvent(
-        {
"status": "Completed", - "message": f"The dataset refresh {DATASET_REFRESH_ID} has " - f"{PowerBIDatasetRefreshStatus.COMPLETED}.", - "dataset_refresh_id": DATASET_REFRESH_ID, - } - ) - assert expected == actual - - -@pytest.mark.asyncio -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") -async def test_powerbi_trigger_run_exception_during_refresh_check_loop( - mock_trigger_dataset_refresh, - mock_get_refresh_details_by_refresh_id, - mock_cancel_dataset_refresh, - powerbi_trigger, -): - """Assert that run catch exception if Power BI API throw exception""" - mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception") - mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID - - task = [i async for i in powerbi_trigger.run()] - response = TriggerEvent( - { - "status": "error", - "message": "An error occurred: Test exception", - "dataset_refresh_id": DATASET_REFRESH_ID, - } - ) - assert len(task) == 1 - assert response in task - mock_cancel_dataset_refresh.assert_called_once() - - -@pytest.mark.asyncio -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") -async def test_powerbi_trigger_run_exception_during_refresh_cancellation( - mock_trigger_dataset_refresh, - mock_get_refresh_details_by_refresh_id, - mock_cancel_dataset_refresh, - powerbi_trigger, -): - """Assert that run catch exception if Power BI API throw exception""" - mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception") - mock_cancel_dataset_refresh.side_effect = Exception("Exception caused by cancel_dataset_refresh") - mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID - - task = [i async for i in powerbi_trigger.run()] - response = TriggerEvent( - { - "status": "error", - "message": "An error occurred while canceling dataset: Exception caused by cancel_dataset_refresh", - "dataset_refresh_id": DATASET_REFRESH_ID, + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") + async def test_powerbi_trigger_run_inprogress( + self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger + ): + """Assert task isn't completed until timeout if dataset refresh is in progress.""" + mock_get_refresh_details_by_refresh_id.return_value = { + "status": PowerBIDatasetRefreshStatus.IN_PROGRESS } - ) - - assert len(task) == 1 - assert response in task - mock_cancel_dataset_refresh.assert_called_once() - - -@pytest.mark.asyncio -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") -async def test_powerbi_trigger_run_exception_without_refresh_id( - mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger -): - """Assert handling of exception when there is no dataset_refresh_id""" - powerbi_trigger.dataset_refresh_id = None - mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception for no dataset_refresh_id") - mock_trigger_dataset_refresh.return_value = None + mock_trigger_dataset_refresh.return_value = 
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+    async def test_powerbi_trigger_run_inprogress(
+        self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
+    ):
+        """Assert task isn't completed until timeout if dataset refresh is in progress."""
+        mock_get_refresh_details_by_refresh_id.return_value = {
+            "status": PowerBIDatasetRefreshStatus.IN_PROGRESS
+        }
+        mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+        task = asyncio.create_task(powerbi_trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # Assert TriggerEvent was not returned
+        assert task.done() is False
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+    async def test_powerbi_trigger_run_failed(
+        self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
+    ):
+        """Assert event is triggered upon failed dataset refresh."""
+        mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.FAILED}
+        mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+
+        generator = powerbi_trigger.run()
+        actual = await generator.asend(None)
+        expected = TriggerEvent(
+            {
+                "status": "Failed",
+                "message": f"The dataset refresh {DATASET_REFRESH_ID} has "
+                f"{PowerBIDatasetRefreshStatus.FAILED}.",
+                "dataset_refresh_id": DATASET_REFRESH_ID,
+            }
+        )
+        assert expected == actual
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+    async def test_powerbi_trigger_run_completed(
+        self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
+    ):
+        """Assert event is triggered upon successful dataset refresh."""
+        mock_get_refresh_details_by_refresh_id.return_value = {
+            "status": PowerBIDatasetRefreshStatus.COMPLETED
+        }
+        mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+
+        generator = powerbi_trigger.run()
+        actual = await generator.asend(None)
+        expected = TriggerEvent(
+            {
+                "status": "Completed",
+                "message": f"The dataset refresh {DATASET_REFRESH_ID} has "
+                f"{PowerBIDatasetRefreshStatus.COMPLETED}.",
+                "dataset_refresh_id": DATASET_REFRESH_ID,
+            }
+        )
+        assert expected == actual
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh")
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+    async def test_powerbi_trigger_run_exception_during_refresh_check_loop(
+        self,
+        mock_trigger_dataset_refresh,
+        mock_get_refresh_details_by_refresh_id,
+        mock_cancel_dataset_refresh,
+        powerbi_trigger,
+    ):
+        """Assert that run catches the exception if the Power BI API throws an exception."""
+        mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception")
+        mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+
+        task = [i async for i in powerbi_trigger.run()]
+        response = TriggerEvent(
+            {
+                "status": "error",
+                "message": "An error occurred: Test exception",
+                "dataset_refresh_id": DATASET_REFRESH_ID,
+            }
+        )
+        assert len(task) == 1
+        assert response in task
+        mock_cancel_dataset_refresh.assert_called_once()
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh")
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+    async def test_powerbi_trigger_run_exception_during_refresh_cancellation(
+        self,
+        mock_trigger_dataset_refresh,
+        mock_get_refresh_details_by_refresh_id,
+        mock_cancel_dataset_refresh,
+        powerbi_trigger,
+    ):
+        """Assert that run catches the exception if the Power BI API throws an exception during cancellation."""
+        mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception")
+        mock_cancel_dataset_refresh.side_effect = Exception("Exception caused by cancel_dataset_refresh")
+        mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+
+        task = [i async for i in powerbi_trigger.run()]
+        response = TriggerEvent(
+            {
+                "status": "error",
+                "message": "An error occurred while canceling dataset: Exception caused by cancel_dataset_refresh",
+                "dataset_refresh_id": DATASET_REFRESH_ID,
+            }
+        )
+
+        assert len(task) == 1
+        assert response in task
+        mock_cancel_dataset_refresh.assert_called_once()
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+    async def test_powerbi_trigger_run_exception_without_refresh_id(
+        self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
+    ):
+        """Assert handling of exception when there is no dataset_refresh_id."""
+        powerbi_trigger.dataset_refresh_id = None
+        mock_get_refresh_details_by_refresh_id.side_effect = Exception(
+            "Test exception for no dataset_refresh_id"
+        )
+        mock_trigger_dataset_refresh.return_value = None
+
+        task = [i async for i in powerbi_trigger.run()]
+        response = TriggerEvent(
+            {
+                "status": "error",
+                "message": "An error occurred: Test exception for no dataset_refresh_id",
+                "dataset_refresh_id": None,
+            }
+        )
+        assert len(task) == 1
+        assert response in task
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+    @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+    async def test_powerbi_trigger_run_timeout(
+        self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
+    ):
+        """Assert that the Power BI trigger times out after end_time elapses."""
+        mock_get_refresh_details_by_refresh_id.return_value = {
+            "status": PowerBIDatasetRefreshStatus.IN_PROGRESS
+        }
+        mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+
+        generator = powerbi_trigger.run()
+        actual = await generator.asend(None)
+        expected = TriggerEvent(
+            {
+                "status": "error",
+                "message": f"Timeout occurred while waiting for dataset refresh to complete: The dataset refresh {DATASET_REFRESH_ID} has status In Progress.",
+                "dataset_refresh_id": DATASET_REFRESH_ID,
+            }
+        )
+
+        assert expected == actual
diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py
index 8a258735291f1..de25d24fb05e1 100644
--- a/tests/providers/microsoft/conftest.py
+++ b/tests/providers/microsoft/conftest.py
@@ -32,6 +32,7 @@
 from msgraph_core import APIVersion
 
 from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIHook
 from airflow.utils.context import Context
 
 if TYPE_CHECKING:
@@ -183,21 +184,45 @@ def load_file(*args: str, mode="r", encoding="utf-8"):
 
 def get_airflow_connection(
     conn_id: str,
+    host: str = "graph.microsoft.com",
     login: str = "client_id",
     password: str = "client_secret",
     tenant_id: str = "tenant-id",
+    azure_tenant_id: str | None = None,
     proxies: dict | None = None,
-    api_version: APIVersion = APIVersion.v1,
+    scopes: list[str] | None = None,
+    api_version: APIVersion | str | None = APIVersion.v1.value,
+    authority: str | None = None,
+    disable_instance_discovery: bool = False,
 ):
     from airflow.models import Connection
 
+    extra = {
+        "api_version": api_version,
+        "proxies": proxies or {},
+        "verify": False,
+        "scopes": scopes or [],
+        "authority": authority,
+        "disable_instance_discovery": disable_instance_discovery,
+    }
+
+    if azure_tenant_id:
+        extra["tenantId"] = azure_tenant_id
+    else:
+        extra["tenant_id"] = tenant_id
+
     return Connection(
         schema="https",
         conn_id=conn_id,
         conn_type="http",
-        host="graph.microsoft.com",
+        host=host,
         port=80,
         login=login,
         password=password,
-        extra={"tenant_id": tenant_id, "api_version": api_version.value, "proxies": proxies or {}},
+        extra=extra,
     )
+
+
+@pytest.fixture
+def powerbi_hook():
+    return PowerBIHook(**{"conn_id": "powerbi_conn_id", "timeout": 3, "api_version": "v1.0"})
diff --git a/tests/system/providers/microsoft/azure/example_msfabric.py b/tests/system/providers/microsoft/azure/example_msfabric.py
new file mode 100644
index 0000000000000..7d62a49e0bc31
--- /dev/null
+++ b/tests/system/providers/microsoft/azure/example_msfabric.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import models
+from airflow.datasets import Dataset
+from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator
+
+DAG_ID = "example_msfabric"
+
+with models.DAG(
+    DAG_ID,
+    start_date=datetime(2021, 1, 1),
+    schedule=None,
+    tags=["example"],
+) as dag:
+    # [START howto_operator_ms_fabric_create_item_schedule]
+    # https://learn.microsoft.com/en-us/rest/api/fabric/core/job-scheduler/create-item-schedule?tabs=HTTP
+    workspaces_task = MSGraphAsyncOperator(
+        task_id="schedule_datapipeline",
+        conn_id="powerbi",
+        method="POST",
+        url="workspaces/{workspaceId}/items/{itemId}/jobs/instances",
+        path_parameters={
+            "workspaceId": "e90b2873-4812-4dfb-9246-593638165644",
+            "itemId": "65448530-e5ec-4aeb-a97e-7cebf5d67c18",
+        },
+        query_parameters={"jobType": "Pipeline"},
+        dag=dag,
+        outlets=[
+            Dataset(
+                "workspaces/e90b2873-4812-4dfb-9246-593638165644/items/65448530-e5ec-4aeb-a97e-7cebf5d67c18/jobs/instances?jobType=Pipeline"
+            )
+        ],
+    )
+    # [END howto_operator_ms_fabric_create_item_schedule]
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
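
A minimal sketch of the behavior this patch introduces, for reviewers who want to try it locally; it is not part of the patch itself. It assumes a provider build containing these changes, and the `fabric_api` conn_id is hypothetical. After the change, `resolve_api_version_from_value` normalizes APIVersion enum members to their string value and passes arbitrary strings through, which is what lets the same hook target non-Graph hosts such as the Fabric REST API:

    from msgraph_core import APIVersion

    from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook

    # Enum members are normalized to their string value ("v1.0" for APIVersion.v1)...
    assert KiotaRequestAdapterHook.resolve_api_version_from_value(APIVersion.v1) == "v1.0"

    # ...while arbitrary strings, such as the Fabric REST API's "v1", pass through unchanged.
    assert KiotaRequestAdapterHook.resolve_api_version_from_value("v1") == "v1"

    # When no value is configured, the supplied default wins (as in get_api_version).
    assert KiotaRequestAdapterHook.resolve_api_version_from_value(None, default=APIVersion.v1.value) == "v1.0"

    # A hook pointed at Fabric: assuming a connection whose host is "api.fabric.microsoft.com",
    # get_conn() builds the base URL "https://api.fabric.microsoft.com/v1", as exercised by
    # test_get_conn_with_custom_base_url above.
    hook = KiotaRequestAdapterHook(conn_id="fabric_api", api_version="v1")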