diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index 56abfa155da7c..8410d8d7077cd 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -110,12 +110,16 @@ def __init__( conn_id: str = default_conn_name, timeout: float | None = None, proxies: dict | None = None, + host: str = NationalClouds.Global.value, + scopes: list[str] | None = None, api_version: APIVersion | str | None = None, ): super().__init__() self.conn_id = conn_id self.timeout = timeout self.proxies = proxies + self.host = host + self.scopes = scopes or ["https://graph.microsoft.com/.default"] self._api_version = self.resolve_api_version_from_value(api_version) @property @@ -141,11 +145,10 @@ def get_api_version(self, config: dict) -> APIVersion: ) return self._api_version - @staticmethod - def get_host(connection: Connection) -> str: + def get_host(self, connection: Connection) -> str: if connection.schema and connection.host: return f"{connection.schema}://{connection.host}" - return NationalClouds.Global.value + return self.host @staticmethod def format_no_proxy_url(url: str) -> str: @@ -198,7 +201,7 @@ def get_conn(self) -> RequestAdapter: proxies = self.proxies or config.get("proxies", {}) msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies) httpx_proxies = self.to_httpx_proxies(proxies=proxies) - scopes = config.get("scopes", ["https://graph.microsoft.com/.default"]) + scopes = config.get("scopes", self.scopes) verify = config.get("verify", True) trust_env = config.get("trust_env", False) disable_instance_discovery = config.get("disable_instance_discovery", False) diff --git a/airflow/providers/microsoft/azure/hooks/powerbi.py b/airflow/providers/microsoft/azure/hooks/powerbi.py new file mode 100644 index 0000000000000..04326f4fecee2 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/powerbi.py @@ -0,0 +1,218 @@ +# Licensed to the 
Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Any + +from airflow.exceptions import AirflowException +from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook + +if TYPE_CHECKING: + from msgraph_core import APIVersion + + +class PowerBIDatasetRefreshFields(Enum): + """Power BI refresh dataset details.""" + + REQUEST_ID = "request_id" + STATUS = "status" + ERROR = "error" + + +class PowerBIDatasetRefreshStatus: + """Power BI refresh dataset statuses.""" + + IN_PROGRESS = "In Progress" + FAILED = "Failed" + COMPLETED = "Completed" + DISABLED = "Disabled" + + TERMINAL_STATUSES = {FAILED, COMPLETED} + + +class PowerBIDatasetRefreshException(AirflowException): + """An exception that indicates a dataset refresh failed to complete.""" + + +class PowerBIHook(KiotaRequestAdapterHook): + """ + A async hook to interact with Power BI. + + :param conn_id: The Power BI connection id. 
+ """ + + conn_type: str = "powerbi" + conn_name_attr: str = "conn_id" + default_conn_name: str = "powerbi_default" + hook_name: str = "Power BI" + + def __init__( + self, + conn_id: str = default_conn_name, + proxies: dict | None = None, + timeout: float = 60 * 60 * 24 * 7, + api_version: APIVersion | str | None = None, + ): + super().__init__( + conn_id=conn_id, + proxies=proxies, + timeout=timeout, + host="https://api.powerbi.com", + scopes=["https://analysis.windows.net/powerbi/api/.default"], + api_version=api_version, + ) + + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + """Return connection widgets to add to connection form.""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "tenant_id": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()), + } + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + """Return custom field behaviour.""" + return { + "hidden_fields": ["schema", "port", "host", "extra"], + "relabeling": { + "login": "Client ID", + "password": "Client Secret", + }, + } + + async def get_refresh_history( + self, + dataset_id: str, + group_id: str, + ) -> list[dict[str, str]]: + """ + Retrieve the refresh history of the specified dataset from the given group ID. + + :param dataset_id: The dataset ID. + :param group_id: The workspace ID. + + :return: Dictionary containing all the refresh histories of the dataset. 
+ """ + try: + response = await self.run( + url="myorg/groups/{group_id}/datasets/{dataset_id}/refreshes", + path_parameters={ + "group_id": group_id, + "dataset_id": dataset_id, + }, + ) + + refresh_histories = response.get("value") + return [self.raw_to_refresh_details(refresh_history) for refresh_history in refresh_histories] + + except AirflowException: + raise PowerBIDatasetRefreshException("Failed to retrieve refresh history") + + @classmethod + def raw_to_refresh_details(cls, refresh_details: dict) -> dict[str, str]: + """ + Convert raw refresh details into a dictionary containing required fields. + + :param refresh_details: Raw object of refresh details. + """ + return { + PowerBIDatasetRefreshFields.REQUEST_ID.value: str(refresh_details.get("requestId")), + PowerBIDatasetRefreshFields.STATUS.value: ( + "In Progress" + if str(refresh_details.get("status")) == "Unknown" + else str(refresh_details.get("status")) + ), + PowerBIDatasetRefreshFields.ERROR.value: str(refresh_details.get("serviceExceptionJson")), + } + + async def get_refresh_details_by_refresh_id( + self, dataset_id: str, group_id: str, refresh_id: str + ) -> dict[str, str]: + """ + Get the refresh details of the given request Id. + + :param refresh_id: Request Id of the Dataset refresh. 
+ """ + refresh_histories = await self.get_refresh_history(dataset_id=dataset_id, group_id=group_id) + + if len(refresh_histories) == 0: + raise PowerBIDatasetRefreshException( + f"Unable to fetch the details of dataset refresh with Request Id: {refresh_id}" + ) + + refresh_ids = [ + refresh_history.get(PowerBIDatasetRefreshFields.REQUEST_ID.value) + for refresh_history in refresh_histories + ] + + if refresh_id not in refresh_ids: + raise PowerBIDatasetRefreshException( + f"Unable to fetch the details of dataset refresh with Request Id: {refresh_id}" + ) + + refresh_details = refresh_histories[refresh_ids.index(refresh_id)] + + return refresh_details + + async def trigger_dataset_refresh(self, *, dataset_id: str, group_id: str) -> str: + """ + Triggers a refresh for the specified dataset from the given group id. + + :param dataset_id: The dataset id. + :param group_id: The workspace id. + + :return: Request id of the dataset refresh request. + """ + try: + response = await self.run( + url="myorg/groups/{group_id}/datasets/{dataset_id}/refreshes", + method="POST", + path_parameters={ + "group_id": group_id, + "dataset_id": dataset_id, + }, + ) + + request_id = response.get("requestid") + return request_id + except AirflowException: + raise PowerBIDatasetRefreshException("Failed to trigger dataset refresh.") + + async def cancel_dataset_refresh(self, dataset_id: str, group_id: str, dataset_refresh_id: str) -> None: + """ + Cancel the dataset refresh. + + :param dataset_id: The dataset Id. + :param group_id: The workspace Id. + :param dataset_refresh_id: The dataset refresh Id. 
+ """ + await self.run( + url="myorg/groups/{group_id}/datasets/{dataset_id}/refreshes/{dataset_refresh_id}", + response_type=None, + path_parameters={ + "group_id": group_id, + "dataset_id": dataset_id, + "dataset_refresh_id": dataset_refresh_id, + }, + method="DELETE", + ) diff --git a/airflow/providers/microsoft/azure/operators/powerbi.py b/airflow/providers/microsoft/azure/operators/powerbi.py new file mode 100644 index 0000000000000..e54ad250bde74 --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/powerbi.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.providers.microsoft.azure.hooks.powerbi import ( + PowerBIHook, +) +from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger + +if TYPE_CHECKING: + from msgraph_core import APIVersion + + from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.utils.context import Context + + +class PowerBILink(BaseOperatorLink): + """Construct a link to monitor a dataset in Power BI.""" + + name = "Monitor PowerBI Dataset" + + def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey): + url = ( + "https://app.powerbi.com" # type: ignore[attr-defined] + f"/groups/{operator.group_id}/datasets/{operator.dataset_id}" # type: ignore[attr-defined] + "/details?experience=power-bi" + ) + + return url + + +class PowerBIDatasetRefreshOperator(BaseOperator): + """ + Refreshes a Power BI dataset. + + :param dataset_id: The dataset id. + :param group_id: The workspace id. + :param conn_id: Airflow Connection ID that contains the connection information for the Power BI account used for authentication. + :param timeout: Time in seconds to wait for a dataset to reach a terminal status for asynchronous waits. Used only if ``wait_for_termination`` is True. + :param check_interval: Number of seconds to wait before rechecking the + refresh status. 
+ """ + + template_fields: Sequence[str] = ( + "dataset_id", + "group_id", + ) + template_fields_renderers = {"parameters": "json"} + + operator_extra_links = (PowerBILink(),) + + def __init__( + self, + *, + dataset_id: str, + group_id: str, + conn_id: str = PowerBIHook.default_conn_name, + timeout: float = 60 * 60 * 24 * 7, + proxies: dict | None = None, + api_version: APIVersion | None = None, + check_interval: int = 60, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.hook = PowerBIHook(conn_id=conn_id, proxies=proxies, api_version=api_version, timeout=timeout) + self.dataset_id = dataset_id + self.group_id = group_id + self.wait_for_termination = True + self.conn_id = conn_id + self.timeout = timeout + self.check_interval = check_interval + + def execute(self, context: Context): + """Refresh the Power BI Dataset.""" + if self.wait_for_termination: + self.defer( + trigger=PowerBITrigger( + conn_id=self.conn_id, + group_id=self.group_id, + dataset_id=self.dataset_id, + timeout=self.timeout, + check_interval=self.check_interval, + wait_for_termination=self.wait_for_termination, + ), + method_name=self.execute_complete.__name__, + ) + + def execute_complete(self, context: Context, event: dict[str, str]) -> Any: + """ + Return immediately - callback for when the trigger fires. + + Relies on trigger to throw an exception, otherwise it assumes execution was successful. 
+ """ + if event: + if event["status"] == "error": + raise AirflowException(event["message"]) + + self.xcom_push( + context=context, key="powerbi_dataset_refresh_Id", value=event["dataset_refresh_id"] + ) + self.xcom_push(context=context, key="powerbi_dataset_refresh_status", value=event["status"]) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index 04e7311b44b08..a2bc784173084 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -176,6 +176,9 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst tags: [azure] + - integration-name: Microsoft Power BI + external-doc-url: https://learn.microsoft.com/en-us/rest/api/power-bi/ + tags: [azure] operators: - integration-name: Microsoft Azure Data Lake Storage @@ -208,6 +211,9 @@ operators: - integration-name: Microsoft Graph API python-modules: - airflow.providers.microsoft.azure.operators.msgraph + - integration-name: Microsoft Power BI + python-modules: + - airflow.providers.microsoft.azure.operators.powerbi sensors: - integration-name: Microsoft Azure Cosmos DB @@ -268,6 +274,9 @@ hooks: - integration-name: Microsoft Graph API python-modules: - airflow.providers.microsoft.azure.hooks.msgraph + - integration-name: Microsoft Power BI + python-modules: + - airflow.providers.microsoft.azure.hooks.powerbi triggers: - integration-name: Microsoft Azure Data Factory @@ -279,6 +288,9 @@ triggers: - integration-name: Microsoft Graph API python-modules: - airflow.providers.microsoft.azure.triggers.msgraph + - integration-name: Microsoft Power BI + python-modules: + - airflow.providers.microsoft.azure.triggers.powerbi transfers: - source-integration-name: Local @@ -334,6 +346,8 @@ connection-types: connection-type: azure_synapse - hook-class-name: airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeStorageV2Hook connection-type: adls + - 
hook-class-name: airflow.providers.microsoft.azure.hooks.powerbi.PowerBIHook + connection-type: powerbi secrets-backends: - airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend @@ -344,6 +358,7 @@ logging: extra-links: - airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryPipelineRunLink - airflow.providers.microsoft.azure.operators.synapse.AzureSynapsePipelineRunLink + - airflow.providers.microsoft.azure.operators.powerbi.PowerBILink config: azure_remote_logging: diff --git a/airflow/providers/microsoft/azure/triggers/powerbi.py b/airflow/providers/microsoft/azure/triggers/powerbi.py new file mode 100644 index 0000000000000..d25802b84fb74 --- /dev/null +++ b/airflow/providers/microsoft/azure/triggers/powerbi.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING, AsyncIterator + +from airflow.providers.microsoft.azure.hooks.powerbi import ( + PowerBIDatasetRefreshStatus, + PowerBIHook, +) +from airflow.triggers.base import BaseTrigger, TriggerEvent + +if TYPE_CHECKING: + from msgraph_core import APIVersion + + +class PowerBITrigger(BaseTrigger): + """ + Triggers when Power BI dataset refresh is completed. + + Wait for termination will always be True. + + :param conn_id: The connection Id to connect to PowerBI. + :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None). + When no timeout is specified or set to None then there is no HTTP timeout on each request. + :param proxies: A dict defining the HTTP proxies to be used (default is None). + :param api_version: The API version of the Microsoft Graph API to be used (default is v1). + You can pass an enum named APIVersion which has 2 possible members v1 and beta, + or you can pass a string as `v1.0` or `beta`. + :param dataset_id: The dataset Id to refresh. + :param group_id: The workspace Id where dataset is located. + :param end_time: Time in seconds when trigger should stop polling. + :param check_interval: Time in seconds to wait between each poll. + :param wait_for_termination: Wait for the dataset refresh to complete or fail. 
+ """ + + def __init__( + self, + conn_id: str, + dataset_id: str, + group_id: str, + timeout: float = 60 * 60 * 24 * 7, + proxies: dict | None = None, + api_version: APIVersion | None = None, + check_interval: int = 60, + wait_for_termination: bool = True, + ): + super().__init__() + self.hook = PowerBIHook(conn_id=conn_id, proxies=proxies, api_version=api_version, timeout=timeout) + self.dataset_id = dataset_id + self.timeout = timeout + self.group_id = group_id + self.check_interval = check_interval + self.wait_for_termination = wait_for_termination + + def serialize(self): + """Serialize the trigger instance.""" + api_version = self.api_version.value if self.api_version else None + return ( + "airflow.providers.microsoft.azure.triggers.powerbi.PowerBITrigger", + { + "conn_id": self.conn_id, + "proxies": self.proxies, + "api_version": api_version, + "dataset_id": self.dataset_id, + "group_id": self.group_id, + "timeout": self.timeout, + "check_interval": self.check_interval, + "wait_for_termination": self.wait_for_termination, + }, + ) + + @property + def conn_id(self) -> str: + return self.hook.conn_id + + @property + def proxies(self) -> dict | None: + return self.hook.proxies + + @property + def api_version(self) -> APIVersion: + return self.hook.api_version + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Make async connection to the PowerBI and polls for the dataset refresh status.""" + self.dataset_refresh_id = await self.hook.trigger_dataset_refresh( + dataset_id=self.dataset_id, + group_id=self.group_id, + ) + try: + dataset_refresh_status = None + start_time = time.monotonic() + while start_time + self.timeout > time.monotonic(): + refresh_details = await self.hook.get_refresh_details_by_refresh_id( + dataset_id=self.dataset_id, + group_id=self.group_id, + refresh_id=self.dataset_refresh_id, + ) + dataset_refresh_status = refresh_details.get("status") + + if dataset_refresh_status == PowerBIDatasetRefreshStatus.COMPLETED: + yield 
TriggerEvent( + { + "status": dataset_refresh_status, + "message": f"The dataset refresh {self.dataset_refresh_id} has {dataset_refresh_status}.", + "dataset_refresh_id": self.dataset_refresh_id, + } + ) + return + elif dataset_refresh_status == PowerBIDatasetRefreshStatus.FAILED: + yield TriggerEvent( + { + "status": dataset_refresh_status, + "message": f"The dataset refresh {self.dataset_refresh_id} has {dataset_refresh_status}.", + "dataset_refresh_id": self.dataset_refresh_id, + } + ) + return + + self.log.info( + "Sleeping for %s. The dataset refresh status is %s.", + self.check_interval, + dataset_refresh_status, + ) + await asyncio.sleep(self.check_interval) + + yield TriggerEvent( + { + "status": "error", + "message": f"Timeout occurred while waiting for dataset refresh to complete: The dataset refresh {self.dataset_refresh_id} has status {dataset_refresh_status}.", + "dataset_refresh_id": self.dataset_refresh_id, + } + ) + return + except Exception as error: + if self.dataset_refresh_id: + try: + self.log.info( + "Unexpected error %s caught. 
Canceling dataset refresh %s", + error, + self.dataset_refresh_id, + ) + await self.hook.cancel_dataset_refresh( + dataset_id=self.dataset_id, + group_id=self.group_id, + dataset_refresh_id=self.dataset_refresh_id, + ) + except Exception as e: + yield TriggerEvent( + { + "status": "error", + "message": f"An error occurred while canceling dataset: {e}", + "dataset_refresh_id": self.dataset_refresh_id, + } + ) + return + yield TriggerEvent( + { + "status": "error", + "message": f"An error occurred: {error}", + "dataset_refresh_id": self.dataset_refresh_id, + } + ) diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py index 71d280a1971da..390be17ba7f35 100644 --- a/tests/providers/microsoft/azure/hooks/test_msgraph.py +++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py @@ -82,13 +82,15 @@ def test_get_api_version_when_api_version_in_config_dict(self): def test_get_host_when_connection_has_scheme_and_host(self): connection = mock_connection(schema="https", host="graph.microsoft.de") - actual = KiotaRequestAdapterHook.get_host(connection) + hook = KiotaRequestAdapterHook() + actual = hook.get_host(connection) assert actual == NationalClouds.Germany.value def test_get_host_when_connection_has_no_scheme_or_host(self): connection = mock_connection() - actual = KiotaRequestAdapterHook.get_host(connection) + hook = KiotaRequestAdapterHook() + actual = hook.get_host(connection) assert actual == NationalClouds.Global.value diff --git a/tests/providers/microsoft/azure/hooks/test_powerbi.py b/tests/providers/microsoft/azure/hooks/test_powerbi.py new file mode 100644 index 0000000000000..a3a521b45e820 --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_powerbi.py @@ -0,0 +1,229 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook +from airflow.providers.microsoft.azure.hooks.powerbi import ( + PowerBIDatasetRefreshException, + PowerBIDatasetRefreshFields, + PowerBIDatasetRefreshStatus, + PowerBIHook, +) + +FORMATTED_RESPONSE = [ + # Completed refresh + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "5e2d9921-e91b-491f-b7e1-e7d8db49194c", + PowerBIDatasetRefreshFields.STATUS.value: PowerBIDatasetRefreshStatus.COMPLETED, + PowerBIDatasetRefreshFields.ERROR.value: "None", + }, + # In-progress refresh + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "6b6536c1-cfcb-4148-9c21-402c3f5241e4", + PowerBIDatasetRefreshFields.STATUS.value: PowerBIDatasetRefreshStatus.IN_PROGRESS, + PowerBIDatasetRefreshFields.ERROR.value: "None", + }, + # Failed refresh + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "11bf290a-346b-48b7-8973-c5df149337ff", + PowerBIDatasetRefreshFields.STATUS.value: PowerBIDatasetRefreshStatus.FAILED, + PowerBIDatasetRefreshFields.ERROR.value: '{"errorCode":"ModelRefreshFailed_CredentialsNotSpecified"}', + }, +] + +DEFAULT_CONNECTION_CLIENT_SECRET = "powerbi_conn_id" +GROUP_ID = "group_id" +DATASET_ID = "dataset_id" + +CONFIG = {"conn_id": DEFAULT_CONNECTION_CLIENT_SECRET, "timeout": 3, 
"api_version": "v1.0"} + + +@pytest.fixture +def powerbi_hook(): + return PowerBIHook(**CONFIG) + + +@pytest.mark.asyncio +async def test_get_refresh_history(powerbi_hook): + response_data = {"value": [{"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}]} + + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + mock_run.return_value = response_data + result = await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID) + + expected = [{"request_id": "1234", "status": "Completed", "error": ""}] + assert result == expected + + +@pytest.mark.asyncio +async def test_get_refresh_history_airflow_exception(powerbi_hook): + """Test handling of AirflowException in get_refresh_history.""" + + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + mock_run.side_effect = AirflowException("Test exception") + + with pytest.raises(PowerBIDatasetRefreshException, match="Failed to retrieve refresh history"): + await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID) + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + ( + {"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}, + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "1234", + PowerBIDatasetRefreshFields.STATUS.value: "Completed", + PowerBIDatasetRefreshFields.ERROR.value: "", + }, + ), + ( + {"requestId": "5678", "status": "Unknown", "serviceExceptionJson": "Some error"}, + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "5678", + PowerBIDatasetRefreshFields.STATUS.value: "In Progress", + PowerBIDatasetRefreshFields.ERROR.value: "Some error", + }, + ), + ( + {"requestId": None, "status": None, "serviceExceptionJson": None}, + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "None", + PowerBIDatasetRefreshFields.STATUS.value: "None", + PowerBIDatasetRefreshFields.ERROR.value: "None", + }, + ), + ( + {}, # Empty input dictionary + { + 
PowerBIDatasetRefreshFields.REQUEST_ID.value: "None", + PowerBIDatasetRefreshFields.STATUS.value: "None", + PowerBIDatasetRefreshFields.ERROR.value: "None", + }, + ), + ], +) +def test_raw_to_refresh_details(input_data, expected_output): + """Test raw_to_refresh_details method.""" + result = PowerBIHook.raw_to_refresh_details(input_data) + assert result == expected_output + + +@pytest.mark.asyncio +async def test_get_refresh_details_by_refresh_id(powerbi_hook): + # Mock the get_refresh_history method to return a list of refresh histories + refresh_histories = FORMATTED_RESPONSE + powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=refresh_histories) + + # Call the function with a valid request ID + refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c" + result = await powerbi_hook.get_refresh_details_by_refresh_id( + dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id + ) + + # Assert that the correct refresh details are returned + assert result == { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "5e2d9921-e91b-491f-b7e1-e7d8db49194c", + PowerBIDatasetRefreshFields.STATUS.value: "Completed", + PowerBIDatasetRefreshFields.ERROR.value: "None", + } + + # Call the function with an invalid request ID + invalid_request_id = "invalid_request_id" + with pytest.raises(PowerBIDatasetRefreshException): + await powerbi_hook.get_refresh_details_by_refresh_id( + dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id + ) + + +@pytest.mark.asyncio +async def test_get_refresh_details_by_refresh_id_empty_history(powerbi_hook): + """Test exception when refresh history is empty.""" + # Mock the get_refresh_history method to return an empty list + powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=[]) + + # Call the function with a request ID + refresh_id = "any_request_id" + with pytest.raises( + PowerBIDatasetRefreshException, + match=f"Unable to fetch the details of dataset refresh with Request Id: {refresh_id}", + ): + await 
powerbi_hook.get_refresh_details_by_refresh_id( + dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id + ) + + +@pytest.mark.asyncio +async def test_get_refresh_details_by_refresh_id_not_found(powerbi_hook): + """Test exception when the refresh ID is not found in the refresh history.""" + # Mock the get_refresh_history method to return a list of refresh histories without the specified ID + powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=FORMATTED_RESPONSE) + + # Call the function with an invalid request ID + invalid_request_id = "invalid_request_id" + with pytest.raises( + PowerBIDatasetRefreshException, + match=f"Unable to fetch the details of dataset refresh with Request Id: {invalid_request_id}", + ): + await powerbi_hook.get_refresh_details_by_refresh_id( + dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id + ) + + +@pytest.mark.asyncio +async def test_trigger_dataset_refresh_success(powerbi_hook): + response_data = {"requestid": "5e2d9921-e91b-491f-b7e1-e7d8db49194c"} + + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + mock_run.return_value = response_data + result = await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID) + + assert result == "5e2d9921-e91b-491f-b7e1-e7d8db49194c" + + +@pytest.mark.asyncio +async def test_trigger_dataset_refresh_failure(powerbi_hook): + """Test failure to trigger dataset refresh due to AirflowException.""" + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + mock_run.side_effect = AirflowException("Test exception") + + with pytest.raises(PowerBIDatasetRefreshException, match="Failed to trigger dataset refresh."): + await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID) + + +@pytest.mark.asyncio +async def test_cancel_dataset_refresh(powerbi_hook): + dataset_refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c" + + with 
mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + await powerbi_hook.cancel_dataset_refresh(DATASET_ID, GROUP_ID, dataset_refresh_id) + + mock_run.assert_called_once_with( + url="myorg/groups/{group_id}/datasets/{dataset_id}/refreshes/{dataset_refresh_id}", + response_type=None, + path_parameters={ + "group_id": GROUP_ID, + "dataset_id": DATASET_ID, + "dataset_refresh_id": dataset_refresh_id, + }, + method="DELETE", + ) diff --git a/tests/providers/microsoft/azure/operators/test_powerbi.py b/tests/providers/microsoft/azure/operators/test_powerbi.py new file mode 100644 index 0000000000000..2ee5ee723d7a7 --- /dev/null +++ b/tests/providers/microsoft/azure/operators/test_powerbi.py @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from __future__ import annotations

from unittest.mock import MagicMock

import pytest

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.microsoft.azure.hooks.powerbi import (
    PowerBIDatasetRefreshFields,
    PowerBIDatasetRefreshStatus,
    PowerBIHook,
)
from airflow.providers.microsoft.azure.operators.powerbi import PowerBIDatasetRefreshOperator
from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger
from airflow.utils import timezone

# Identifiers shared by every test in this module.
DEFAULT_CONNECTION_CLIENT_SECRET = "powerbi_conn_id"
TASK_ID = "run_powerbi_operator"
GROUP_ID = "group_id"
DATASET_ID = "dataset_id"

# Keyword arguments used to build the operator under test.
CONFIG = dict(
    task_id=TASK_ID,
    conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
    group_id=GROUP_ID,
    dataset_id=DATASET_ID,
    check_interval=1,
    timeout=3,
)
NEW_REFRESH_REQUEST_ID = "5e2d9921-e91b-491f-b7e1-e7d8db49194c"

# Event payload handed to ``execute_complete`` in the success test.
SUCCESS_TRIGGER_EVENT = dict(
    status="success",
    message="success",
    dataset_refresh_id=NEW_REFRESH_REQUEST_ID,
)

DEFAULT_DATE = timezone.datetime(2021, 1, 1)


# Sample responses from PowerBI API
COMPLETED_REFRESH_DETAILS = {
    PowerBIDatasetRefreshFields.REQUEST_ID.value: NEW_REFRESH_REQUEST_ID,
    PowerBIDatasetRefreshFields.STATUS.value: PowerBIDatasetRefreshStatus.COMPLETED,
}

FAILED_REFRESH_DETAILS = {
    PowerBIDatasetRefreshFields.REQUEST_ID.value: NEW_REFRESH_REQUEST_ID,
    PowerBIDatasetRefreshFields.STATUS.value: PowerBIDatasetRefreshStatus.FAILED,
    PowerBIDatasetRefreshFields.ERROR.value: '{"errorCode":"ModelRefreshFailed_CredentialsNotSpecified"}',
}

IN_PROGRESS_REFRESH_DETAILS = {
    PowerBIDatasetRefreshFields.REQUEST_ID.value: NEW_REFRESH_REQUEST_ID,
    # endtime is not available.
    PowerBIDatasetRefreshFields.STATUS.value: PowerBIDatasetRefreshStatus.IN_PROGRESS,
}


@pytest.fixture
def mock_powerbi_hook():
    """Provide a ``PowerBIHook`` built with its default arguments."""
    return PowerBIHook()


def test_execute_wait_for_termination_with_Deferrable(mock_powerbi_hook):
    """Executing the operator raises ``TaskDeferred`` carrying a ``PowerBITrigger``."""
    refresh_op = PowerBIDatasetRefreshOperator(**CONFIG)
    refresh_op.hook = mock_powerbi_hook

    with pytest.raises(TaskDeferred) as deferral:
        refresh_op.execute({"ti": MagicMock()})

    assert isinstance(deferral.value.trigger, PowerBITrigger)


def test_powerbi_operator_async_execute_complete_success():
    """A success event makes ``execute_complete`` push exactly two XCom values."""
    refresh_op = PowerBIDatasetRefreshOperator(**CONFIG)
    task_instance = MagicMock()

    refresh_op.execute_complete(context={"ti": task_instance}, event=SUCCESS_TRIGGER_EVENT)

    assert task_instance.xcom_push.call_count == 2


def test_powerbi_operator_async_execute_complete_fail():
    """An error event makes ``execute_complete`` raise and push nothing to XCom."""
    refresh_op = PowerBIDatasetRefreshOperator(**CONFIG)
    task_instance = MagicMock()
    error_event = {"status": "error", "message": "error", "dataset_refresh_id": "1234"}

    with pytest.raises(AirflowException):
        refresh_op.execute_complete(context={"ti": task_instance}, event=error_event)

    # Checked after the ``raises`` block so the assertion is actually executed.
    assert task_instance.xcom_push.call_count == 0


def test_execute_complete_no_event():
    """``execute_complete`` called with ``event=None`` pushes nothing to XCom."""
    refresh_op = PowerBIDatasetRefreshOperator(**CONFIG)
    task_instance = MagicMock()

    refresh_op.execute_complete(context={"ti": task_instance}, event=None)

    assert task_instance.xcom_push.call_count == 0


@pytest.mark.db_test
def test_powerbilink(create_task_instance_of_operator):
    """The "Monitor PowerBI Dataset" extra link resolves to the dataset details URL."""
    ti = create_task_instance_of_operator(
        PowerBIDatasetRefreshOperator,
        dag_id="test_powerbi_refresh_op_link",
        execution_date=DEFAULT_DATE,
        task_id=TASK_ID,
        conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
        group_id=GROUP_ID,
        dataset_id=DATASET_ID,
        check_interval=1,
        timeout=3,
    )

    ti.xcom_push(key="powerbi_dataset_refresh_id", value=NEW_REFRESH_REQUEST_ID)
    actual_link = ti.task.get_extra_links(ti, "Monitor PowerBI Dataset")
    expected_link = (
        f"https://app.powerbi.com/groups/{GROUP_ID}/datasets/{DATASET_ID}/details?experience=power-bi"
    )

    assert actual_link == expected_link


# diff --git a/tests/providers/microsoft/azure/triggers/test_powerbi.py b/tests/providers/microsoft/azure/triggers/test_powerbi.py
# new file mode 100644
# index 0000000000000..5b44a84149501
# --- /dev/null
# +++ b/tests/providers/microsoft/azure/triggers/test_powerbi.py
# @@ -0,0 +1,257 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
from unittest import mock
from unittest.mock import patch

import pytest

from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus, PowerBIHook
from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger
from airflow.triggers.base import TriggerEvent
from tests.providers.microsoft.conftest import get_airflow_connection

# Arguments shared by every trigger built in this module.
POWERBI_CONN_ID = "powerbi_default"
DATASET_ID = "dataset_id"
GROUP_ID = "group_id"
DATASET_REFRESH_ID = "dataset_refresh_id"
TIMEOUT = 30
MODULE = "airflow.providers.microsoft.azure"
CHECK_INTERVAL = 10
API_VERSION = "v1.0"


@pytest.fixture
def powerbi_trigger():
    """Provide a ``PowerBITrigger`` configured with the module-level constants."""
    trigger = PowerBITrigger(
        conn_id=POWERBI_CONN_ID,
        proxies=None,
        api_version=API_VERSION,
        dataset_id=DATASET_ID,
        group_id=GROUP_ID,
        check_interval=CHECK_INTERVAL,
        wait_for_termination=True,
        timeout=TIMEOUT,
    )

    return trigger


@pytest.fixture
def mock_powerbi_hook():
    """Provide a ``PowerBIHook`` built with its default arguments."""
    hook = PowerBIHook()
    return hook


def test_powerbi_trigger_serialization():
    """Asserts that the PowerBI Trigger correctly serializes its arguments and classpath."""

    with patch(
        "airflow.hooks.base.BaseHook.get_connection",
        side_effect=get_airflow_connection,
    ):
        powerbi_trigger = PowerBITrigger(
            conn_id=POWERBI_CONN_ID,
            proxies=None,
            api_version=API_VERSION,
            dataset_id=DATASET_ID,
            group_id=GROUP_ID,
            check_interval=CHECK_INTERVAL,
            wait_for_termination=True,
            timeout=TIMEOUT,
        )

        classpath, kwargs = powerbi_trigger.serialize()
        assert classpath == f"{MODULE}.triggers.powerbi.PowerBITrigger"
        assert kwargs == {
            "conn_id": POWERBI_CONN_ID,
            "dataset_id": DATASET_ID,
            "timeout": TIMEOUT,
            "group_id": GROUP_ID,
            "proxies": None,
            "api_version": API_VERSION,
            "check_interval": CHECK_INTERVAL,
            "wait_for_termination": True,
        }


@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
async def test_powerbi_trigger_run_inprogress(
    mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
):
    """Assert task isn't completed until timeout if dataset refresh is in progress."""
    mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS}
    mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
    task = asyncio.create_task(powerbi_trigger.run().__anext__())
    try:
        await asyncio.sleep(0.5)

        # Assert TriggerEvent was not returned
        assert task.done() is False
    finally:
        # Cancel the still-pending poll rather than stopping the event loop:
        # the loop belongs to the test runner, and stopping it from inside the
        # running coroutine would abandon this task while it is still pending.
        task.cancel()
        # ``return_exceptions=True`` collects the resulting CancelledError
        # instead of re-raising it into the test.
        await asyncio.gather(task, return_exceptions=True)


@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
async def test_powerbi_trigger_run_failed(
    mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
):
    """Assert event is triggered upon failed dataset refresh."""
    mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.FAILED}
    mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID

    generator = powerbi_trigger.run()
    actual = await generator.asend(None)
    expected = TriggerEvent(
        {
            "status": "Failed",
            "message": f"The dataset refresh {DATASET_REFRESH_ID} has "
            f"{PowerBIDatasetRefreshStatus.FAILED}.",
            "dataset_refresh_id": DATASET_REFRESH_ID,
        }
    )
    assert expected == actual


@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
async def test_powerbi_trigger_run_completed(
    mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
):
    """Assert event is triggered upon successful dataset refresh."""
    mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.COMPLETED}
    mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID

    generator = powerbi_trigger.run()
    actual = await generator.asend(None)
    expected = TriggerEvent(
        {
            "status": "Completed",
            "message": f"The dataset refresh {DATASET_REFRESH_ID} has "
            f"{PowerBIDatasetRefreshStatus.COMPLETED}.",
            "dataset_refresh_id": DATASET_REFRESH_ID,
        }
    )
    assert expected == actual


@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
async def test_powerbi_trigger_run_exception_during_refresh_check_loop(
    mock_trigger_dataset_refresh,
    mock_get_refresh_details_by_refresh_id,
    mock_cancel_dataset_refresh,
    powerbi_trigger,
):
    """Assert that run yields one error event and cancels the refresh when the status check raises."""
    mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception")
    mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID

    task = [i async for i in powerbi_trigger.run()]
    response = TriggerEvent(
        {
            "status": "error",
            "message": "An error occurred: Test exception",
            "dataset_refresh_id": DATASET_REFRESH_ID,
        }
    )
    assert len(task) == 1
    assert response in task
    mock_cancel_dataset_refresh.assert_called_once()


@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
async def test_powerbi_trigger_run_exception_during_refresh_cancellation(
    mock_trigger_dataset_refresh,
    mock_get_refresh_details_by_refresh_id,
    mock_cancel_dataset_refresh,
    powerbi_trigger,
):
    """Assert that run reports the cancellation error when cancelling the refresh also raises."""
    mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception")
    mock_cancel_dataset_refresh.side_effect = Exception("Exception caused by cancel_dataset_refresh")
    mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID

    task = [i async for i in powerbi_trigger.run()]
    response = TriggerEvent(
        {
            "status": "error",
            "message": "An error occurred while canceling dataset: Exception caused by cancel_dataset_refresh",
            "dataset_refresh_id": DATASET_REFRESH_ID,
        }
    )

    assert len(task) == 1
    assert response in task
    mock_cancel_dataset_refresh.assert_called_once()


@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
async def test_powerbi_trigger_run_exception_without_refresh_id(
    mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
):
    """Assert handling of exception when there is no dataset_refresh_id"""
    powerbi_trigger.dataset_refresh_id = None
    mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception for no dataset_refresh_id")
    mock_trigger_dataset_refresh.return_value = None

    task = [i async for i in powerbi_trigger.run()]
    response = TriggerEvent(
        {
            "status": "error",
            "message": "An error occurred: Test exception for no dataset_refresh_id",
            "dataset_refresh_id": None,
        }
    )
    assert len(task) == 1
    assert response in task


@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
async def test_powerbi_trigger_run_timeout(
    mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
):
    """Assert that the trigger yields a timeout error event while the refresh stays in progress."""
    mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS}
    mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID

    generator = powerbi_trigger.run()
    actual = await generator.asend(None)
    expected = TriggerEvent(
        {
            "status": "error",
            "message": f"Timeout occurred while waiting for dataset refresh to complete: The dataset refresh {DATASET_REFRESH_ID} has status In Progress.",
            "dataset_refresh_id": DATASET_REFRESH_ID,
        }
    )

    assert expected == actual


# diff --git a/tests/system/providers/microsoft/azure/example_powerbi_dataset_refresh.py b/tests/system/providers/microsoft/azure/example_powerbi_dataset_refresh.py
# new file mode 100644
# index 0000000000000..52f1f001e9988
# --- /dev/null
# +++ b/tests/system/providers/microsoft/azure/example_powerbi_dataset_refresh.py
# @@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
from datetime import datetime

from airflow import DAG, settings
from airflow.decorators import task
from airflow.models import Connection
from airflow.models.baseoperator import chain
from airflow.providers.microsoft.azure.operators.powerbi import PowerBIDatasetRefreshOperator

DAG_ID = "example_refresh_powerbi_dataset"
CONN_ID = "powerbi_default"

# Before running this system test, you should set the following environment variables:
DATASET_ID = os.environ.get("DATASET_ID", "None")
GROUP_ID = os.environ.get("GROUP_ID", "None")
CLIENT_ID = os.environ.get("CLIENT_ID", None)
CLIENT_SECRET = os.environ.get("CLIENT_SECRET", None)
TENANT_ID = os.environ.get("TENANT_ID", None)


@task
def create_connection(conn_id_name: str):
    """Store the Power BI connection used by this example in the metadata database.

    The connection is built from the ``CLIENT_ID``, ``CLIENT_SECRET`` and
    ``TENANT_ID`` environment variables read at module import time. If a
    connection with the same id is already stored, it is left untouched so the
    example can be re-run against the same database.

    :param conn_id_name: The connection id to register.
    """
    session = settings.Session()
    try:
        already_registered = (
            session.query(Connection).filter(Connection.conn_id == conn_id_name).first() is not None
        )
        if already_registered:
            # NOTE(review): Airflow is expected to enforce unique connection ids, so a
            # second insert would fail on re-run; skip instead of crashing the setup task.
            return

        conn = Connection(
            conn_id=conn_id_name,
            conn_type="powerbi",
            login=CLIENT_ID,
            password=CLIENT_SECRET,
            extra={"tenant_id": TENANT_ID},
        )
        session.add(conn)
        session.commit()
    finally:
        # Always hand the session's resources back, including on the early return.
        session.close()


with DAG(
    dag_id=DAG_ID,
    start_date=datetime(2021, 1, 1),
    schedule=None,
    tags=["example"],
) as dag:
    set_up_connection = create_connection(CONN_ID)

    # [START howto_operator_powerbi_refresh_async]
    refresh_powerbi_dataset = PowerBIDatasetRefreshOperator(
        conn_id="powerbi_default",
        task_id="refresh_powerbi_dataset",
        dataset_id=DATASET_ID,
        group_id=GROUP_ID,
        check_interval=30,
        timeout=120,
    )
    # [END howto_operator_powerbi_refresh_async]

    chain(
        # TEST SETUP
        set_up_connection,
        # TEST BODY
        refresh_powerbi_dataset,
    )

    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)