diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py
index 8410d8d7077cd..61e555f4caa78 100644
--- a/airflow/providers/microsoft/azure/hooks/msgraph.py
+++ b/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -96,6 +96,8 @@ class KiotaRequestAdapterHook(BaseHook):
:param timeout: The HTTP timeout being used by the KiotaRequestAdapter (default is None).
When no timeout is specified or set to None then no HTTP timeout is applied on each request.
:param proxies: A Dict defining the HTTP proxies to be used (default is None).
+ :param host: The host to be used (default is "https://graph.microsoft.com").
+ :param scopes: The scopes to be used (default is ["https://graph.microsoft.com/.default"]).
:param api_version: The API version of the Microsoft Graph API to be used (default is v1).
You can pass an enum named APIVersion which has 2 possible members v1 and beta,
or you can pass a string as "v1.0" or "beta".
@@ -123,27 +125,22 @@ def __init__(
self._api_version = self.resolve_api_version_from_value(api_version)
@property
- def api_version(self) -> APIVersion:
+ def api_version(self) -> str | None:
self.get_conn() # Make sure config has been loaded through get_conn to have correct api version!
return self._api_version
@staticmethod
def resolve_api_version_from_value(
- api_version: APIVersion | str, default: APIVersion | None = None
- ) -> APIVersion:
+ api_version: APIVersion | str, default: str | None = None
+ ) -> str | None:
if isinstance(api_version, APIVersion):
- return api_version
- return next(
- filter(lambda version: version.value == api_version, APIVersion),
- default,
- )
+ return api_version.value
+ return api_version or default
- def get_api_version(self, config: dict) -> APIVersion:
- if self._api_version is None:
- return self.resolve_api_version_from_value(
- api_version=config.get("api_version"), default=APIVersion.v1
- )
- return self._api_version
+ def get_api_version(self, config: dict) -> str:
+ return self._api_version or self.resolve_api_version_from_value(
+ config.get("api_version"), APIVersion.v1.value
+ ) # type: ignore
def get_host(self, connection: Connection) -> str:
if connection.schema and connection.host:
@@ -169,15 +166,15 @@ def to_httpx_proxies(cls, proxies: dict) -> dict:
return proxies
def to_msal_proxies(self, authority: str | None, proxies: dict):
- self.log.info("authority: %s", authority)
+ self.log.debug("authority: %s", authority)
if authority:
no_proxies = proxies.get("no")
- self.log.info("no_proxies: %s", no_proxies)
+ self.log.debug("no_proxies: %s", no_proxies)
if no_proxies:
for url in no_proxies.split(","):
self.log.info("url: %s", url)
domain_name = urlparse(url).path.replace("*", "")
- self.log.info("domain_name: %s", domain_name)
+ self.log.debug("domain_name: %s", domain_name)
if authority.endswith(domain_name):
return None
return proxies
@@ -193,10 +190,10 @@ def get_conn(self) -> RequestAdapter:
client_id = connection.login
client_secret = connection.password
config = connection.extra_dejson if connection.extra else {}
- tenant_id = config.get("tenant_id")
+ tenant_id = config.get("tenant_id") or config.get("tenantId")
api_version = self.get_api_version(config)
host = self.get_host(connection)
- base_url = config.get("base_url", urljoin(host, api_version.value))
+ base_url = config.get("base_url", urljoin(host, api_version))
authority = config.get("authority")
proxies = self.proxies or config.get("proxies", {})
msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
@@ -209,7 +206,7 @@ def get_conn(self) -> RequestAdapter:
self.log.info(
"Creating Microsoft Graph SDK client %s for conn_id: %s",
- api_version.value,
+ api_version,
self.conn_id,
)
self.log.info("Host: %s", host)
@@ -217,7 +214,7 @@ def get_conn(self) -> RequestAdapter:
self.log.info("Tenant id: %s", tenant_id)
self.log.info("Client id: %s", client_id)
self.log.info("Client secret: %s", client_secret)
- self.log.info("API version: %s", api_version.value)
+ self.log.info("API version: %s", api_version)
self.log.info("Scope: %s", scopes)
self.log.info("Verify: %s", verify)
self.log.info("Timeout: %s", self.timeout)
@@ -238,17 +235,17 @@ def get_conn(self) -> RequestAdapter:
connection_verify=verify,
)
http_client = GraphClientFactory.create_with_default_middleware(
- api_version=api_version,
+ api_version=api_version, # type: ignore
client=httpx.AsyncClient(
proxies=httpx_proxies,
timeout=Timeout(timeout=self.timeout),
verify=verify,
trust_env=trust_env,
),
- host=host,
+ host=host, # type: ignore
)
auth_provider = AzureIdentityAuthenticationProvider(
- credentials=credentials,
+ credentials=credentials, # type: ignore
scopes=scopes,
allowed_hosts=allowed_hosts,
)
@@ -295,7 +292,7 @@ async def run(
error_map=self.error_mapping(),
)
- self.log.info("response: %s", response)
+ self.log.debug("response: %s", response)
return response
diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py
index cd387954737ab..74409f3600a1e 100644
--- a/airflow/providers/microsoft/azure/operators/msgraph.py
+++ b/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -99,7 +99,7 @@ def __init__(
key: str = XCOM_RETURN_KEY,
timeout: float | None = None,
proxies: dict | None = None,
- api_version: APIVersion | None = None,
+ api_version: APIVersion | str | None = None,
pagination_function: Callable[[MSGraphAsyncOperator, dict], tuple[str, dict]] | None = None,
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
serializer: type[ResponseSerializer] = ResponseSerializer,
diff --git a/airflow/providers/microsoft/azure/operators/powerbi.py b/airflow/providers/microsoft/azure/operators/powerbi.py
index e54ad250bde74..fc812e852d90b 100644
--- a/airflow/providers/microsoft/azure/operators/powerbi.py
+++ b/airflow/providers/microsoft/azure/operators/powerbi.py
@@ -76,7 +76,7 @@ def __init__(
conn_id: str = PowerBIHook.default_conn_name,
timeout: float = 60 * 60 * 24 * 7,
proxies: dict | None = None,
- api_version: APIVersion | None = None,
+ api_version: APIVersion | str | None = None,
check_interval: int = 60,
**kwargs,
) -> None:
@@ -89,6 +89,14 @@ def __init__(
self.timeout = timeout
self.check_interval = check_interval
+ @property
+ def proxies(self) -> dict | None:
+ return self.hook.proxies
+
+ @property
+ def api_version(self) -> str | None:
+ return self.hook.api_version
+
def execute(self, context: Context):
"""Refresh the Power BI Dataset."""
if self.wait_for_termination:
@@ -98,6 +106,8 @@ def execute(self, context: Context):
group_id=self.group_id,
dataset_id=self.dataset_id,
timeout=self.timeout,
+ proxies=self.proxies,
+ api_version=self.api_version,
check_interval=self.check_interval,
wait_for_termination=self.wait_for_termination,
),
diff --git a/airflow/providers/microsoft/azure/sensors/msgraph.py b/airflow/providers/microsoft/azure/sensors/msgraph.py
index 3e1b10cbeb1e6..6736ea59c918d 100644
--- a/airflow/providers/microsoft/azure/sensors/msgraph.py
+++ b/airflow/providers/microsoft/azure/sensors/msgraph.py
@@ -82,7 +82,7 @@ def __init__(
data: dict[str, Any] | str | BytesIO | None = None,
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
proxies: dict | None = None,
- api_version: APIVersion | None = None,
+ api_version: APIVersion | str | None = None,
event_processor: Callable[[Context, Any], bool] = lambda context, e: e.get("status") == "Succeeded",
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
serializer: type[ResponseSerializer] = ResponseSerializer,
diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py b/airflow/providers/microsoft/azure/triggers/msgraph.py
index 4b9ccb7a71716..0015964be86dd 100644
--- a/airflow/providers/microsoft/azure/triggers/msgraph.py
+++ b/airflow/providers/microsoft/azure/triggers/msgraph.py
@@ -122,7 +122,7 @@ def __init__(
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
timeout: float | None = None,
proxies: dict | None = None,
- api_version: APIVersion | None = None,
+ api_version: APIVersion | str | None = None,
serializer: type[ResponseSerializer] = ResponseSerializer,
):
super().__init__()
@@ -152,14 +152,13 @@ def resolve_type(cls, value: str | type, default) -> type:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize the HttpTrigger arguments and classpath."""
- api_version = self.api_version.value if self.api_version else None
return (
f"{self.__class__.__module__}.{self.__class__.__name__}",
{
"conn_id": self.conn_id,
"timeout": self.timeout,
"proxies": self.proxies,
- "api_version": api_version,
+ "api_version": self.api_version,
"serializer": f"{self.serializer.__class__.__module__}.{self.serializer.__class__.__name__}",
"url": self.url,
"path_parameters": self.path_parameters,
@@ -188,7 +187,7 @@ def proxies(self) -> dict | None:
return self.hook.proxies
@property
- def api_version(self) -> APIVersion:
+ def api_version(self) -> APIVersion | str:
return self.hook.api_version
async def run(self) -> AsyncIterator[TriggerEvent]:
diff --git a/airflow/providers/microsoft/azure/triggers/powerbi.py b/airflow/providers/microsoft/azure/triggers/powerbi.py
index d25802b84fb74..a74898f55f28a 100644
--- a/airflow/providers/microsoft/azure/triggers/powerbi.py
+++ b/airflow/providers/microsoft/azure/triggers/powerbi.py
@@ -58,7 +58,7 @@ def __init__(
group_id: str,
timeout: float = 60 * 60 * 24 * 7,
proxies: dict | None = None,
- api_version: APIVersion | None = None,
+ api_version: APIVersion | str | None = None,
check_interval: int = 60,
wait_for_termination: bool = True,
):
@@ -72,13 +72,12 @@ def __init__(
def serialize(self):
"""Serialize the trigger instance."""
- api_version = self.api_version.value if self.api_version else None
return (
"airflow.providers.microsoft.azure.triggers.powerbi.PowerBITrigger",
{
"conn_id": self.conn_id,
"proxies": self.proxies,
- "api_version": api_version,
+ "api_version": self.api_version,
"dataset_id": self.dataset_id,
"group_id": self.group_id,
"timeout": self.timeout,
@@ -96,7 +95,7 @@ def proxies(self) -> dict | None:
return self.hook.proxies
@property
- def api_version(self) -> APIVersion:
+ def api_version(self) -> APIVersion | str:
return self.hook.api_version
async def run(self) -> AsyncIterator[TriggerEvent]:
diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
index 342bf542762ac..56a9259f93143 100644
--- a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
@@ -72,6 +72,14 @@ Below is an example of using this operator to refresh PowerBI dataset.
:start-after: [START howto_operator_powerbi_refresh_dataset]
:end-before: [END howto_operator_powerbi_refresh_dataset]
+Below is an example of using this operator to create an item schedule in Fabric.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_msfabric.py
+ :language: python
+ :dedent: 0
+ :start-after: [START howto_operator_ms_fabric_create_item_schedule]
+ :end-before: [END howto_operator_ms_fabric_create_item_schedule]
+
Reference
---------
@@ -80,3 +88,4 @@ For further information, look at:
* `Use the Microsoft Graph API `__
* `Using the Power BI REST APIs `__
+* `Using the Fabric REST APIs `__
diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py
index 390be17ba7f35..04e85525616bc 100644
--- a/tests/providers/microsoft/azure/hooks/test_msgraph.py
+++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py
@@ -18,6 +18,7 @@
import asyncio
from json import JSONDecodeError
+from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
@@ -38,8 +39,20 @@
mock_response,
)
+if TYPE_CHECKING:
+ from kiota_abstractions.request_adapter import RequestAdapter
+
class TestKiotaRequestAdapterHook:
+ def setup_method(self):
+ KiotaRequestAdapterHook.cached_request_adapters.clear()
+
+ @staticmethod
+ def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id: str):
+ assert isinstance(request_adapter, HttpxRequestAdapter)
+ tenant_id = request_adapter._authentication_provider.access_token_provider._credentials._tenant_id
+ assert tenant_id == expected_tenant_id
+
def test_get_conn(self):
with patch(
"airflow.hooks.base.BaseHook.get_connection",
@@ -51,6 +64,23 @@ def test_get_conn(self):
assert isinstance(actual, HttpxRequestAdapter)
assert actual.base_url == "https://graph.microsoft.com/v1.0"
+ def test_get_conn_with_custom_base_url(self):
+ connection = lambda conn_id: get_airflow_connection(
+ conn_id=conn_id,
+ host="api.fabric.microsoft.com",
+ api_version="v1",
+ )
+
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ actual = hook.get_conn()
+
+ assert isinstance(actual, HttpxRequestAdapter)
+ assert actual.base_url == "https://api.fabric.microsoft.com/v1"
+
def test_api_version(self):
with patch(
"airflow.hooks.base.BaseHook.get_connection",
@@ -58,7 +88,7 @@ def test_api_version(self):
):
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
- assert hook.api_version == APIVersion.v1
+ assert hook.api_version == APIVersion.v1.value
def test_get_api_version_when_empty_config_dict(self):
with patch(
@@ -68,7 +98,7 @@ def test_get_api_version_when_empty_config_dict(self):
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
actual = hook.get_api_version({})
- assert actual == APIVersion.v1
+ assert actual == APIVersion.v1.value
def test_get_api_version_when_api_version_in_config_dict(self):
with patch(
@@ -78,21 +108,64 @@ def test_get_api_version_when_api_version_in_config_dict(self):
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
actual = hook.get_api_version({"api_version": "beta"})
- assert actual == APIVersion.beta
+ assert actual == APIVersion.beta.value
+
+ def test_get_api_version_when_custom_api_version_in_config_dict(self):
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api", api_version="v1")
+ actual = hook.get_api_version({})
+
+ assert actual == "v1"
def test_get_host_when_connection_has_scheme_and_host(self):
- connection = mock_connection(schema="https", host="graph.microsoft.de")
- hook = KiotaRequestAdapterHook()
- actual = hook.get_host(connection)
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ connection = mock_connection(schema="https", host="graph.microsoft.de")
+ actual = hook.get_host(connection)
- assert actual == NationalClouds.Germany.value
+ assert actual == NationalClouds.Germany.value
def test_get_host_when_connection_has_no_scheme_or_host(self):
- connection = mock_connection()
- hook = KiotaRequestAdapterHook()
- actual = hook.get_host(connection)
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ connection = mock_connection()
+ actual = hook.get_host(connection)
+
+ assert actual == NationalClouds.Global.value
+
+ def test_tenant_id(self):
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ actual = hook.get_conn()
+
+ self.assert_tenant_id(actual, "tenant-id")
+
+ def test_azure_tenant_id(self):
+ airflow_connection = lambda conn_id: get_airflow_connection(
+ conn_id=conn_id,
+ azure_tenant_id="azure-tenant-id",
+ )
+
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ actual = hook.get_conn()
- assert actual == NationalClouds.Global.value
+ self.assert_tenant_id(actual, "azure-tenant-id")
def test_encoded_query_parameters(self):
actual = KiotaRequestAdapterHook.encoded_query_parameters(
diff --git a/tests/providers/microsoft/azure/hooks/test_powerbi.py b/tests/providers/microsoft/azure/hooks/test_powerbi.py
index a3a521b45e820..f22116f6efb03 100644
--- a/tests/providers/microsoft/azure/hooks/test_powerbi.py
+++ b/tests/providers/microsoft/azure/hooks/test_powerbi.py
@@ -54,176 +54,162 @@
GROUP_ID = "group_id"
DATASET_ID = "dataset_id"
-CONFIG = {"conn_id": DEFAULT_CONNECTION_CLIENT_SECRET, "timeout": 3, "api_version": "v1.0"}
-
-@pytest.fixture
-def powerbi_hook():
- return PowerBIHook(**CONFIG)
-
-
-@pytest.mark.asyncio
-async def test_get_refresh_history(powerbi_hook):
- response_data = {"value": [{"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}]}
-
- with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
- mock_run.return_value = response_data
- result = await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID)
-
- expected = [{"request_id": "1234", "status": "Completed", "error": ""}]
- assert result == expected
-
-
-@pytest.mark.asyncio
-async def test_get_refresh_history_airflow_exception(powerbi_hook):
- """Test handling of AirflowException in get_refresh_history."""
-
- with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
- mock_run.side_effect = AirflowException("Test exception")
-
- with pytest.raises(PowerBIDatasetRefreshException, match="Failed to retrieve refresh history"):
- await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID)
-
-
-@pytest.mark.parametrize(
- "input_data, expected_output",
- [
- (
- {"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""},
- {
- PowerBIDatasetRefreshFields.REQUEST_ID.value: "1234",
- PowerBIDatasetRefreshFields.STATUS.value: "Completed",
- PowerBIDatasetRefreshFields.ERROR.value: "",
- },
- ),
- (
- {"requestId": "5678", "status": "Unknown", "serviceExceptionJson": "Some error"},
- {
- PowerBIDatasetRefreshFields.REQUEST_ID.value: "5678",
- PowerBIDatasetRefreshFields.STATUS.value: "In Progress",
- PowerBIDatasetRefreshFields.ERROR.value: "Some error",
- },
- ),
- (
- {"requestId": None, "status": None, "serviceExceptionJson": None},
- {
- PowerBIDatasetRefreshFields.REQUEST_ID.value: "None",
- PowerBIDatasetRefreshFields.STATUS.value: "None",
- PowerBIDatasetRefreshFields.ERROR.value: "None",
- },
- ),
- (
- {}, # Empty input dictionary
- {
- PowerBIDatasetRefreshFields.REQUEST_ID.value: "None",
- PowerBIDatasetRefreshFields.STATUS.value: "None",
- PowerBIDatasetRefreshFields.ERROR.value: "None",
- },
- ),
- ],
-)
-def test_raw_to_refresh_details(input_data, expected_output):
- """Test raw_to_refresh_details method."""
- result = PowerBIHook.raw_to_refresh_details(input_data)
- assert result == expected_output
-
-
-@pytest.mark.asyncio
-async def test_get_refresh_details_by_refresh_id(powerbi_hook):
- # Mock the get_refresh_history method to return a list of refresh histories
- refresh_histories = FORMATTED_RESPONSE
- powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=refresh_histories)
-
- # Call the function with a valid request ID
- refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
- result = await powerbi_hook.get_refresh_details_by_refresh_id(
- dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id
+class TestPowerBIHook:
+ @pytest.mark.asyncio
+ async def test_get_refresh_history(self, powerbi_hook):
+ response_data = {"value": [{"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}]}
+
+ with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
+ mock_run.return_value = response_data
+ result = await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID)
+
+ expected = [{"request_id": "1234", "status": "Completed", "error": ""}]
+ assert result == expected
+
+ @pytest.mark.asyncio
+ async def test_get_refresh_history_airflow_exception(self, powerbi_hook):
+ """Test handling of AirflowException in get_refresh_history."""
+
+ with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
+ mock_run.side_effect = AirflowException("Test exception")
+
+ with pytest.raises(PowerBIDatasetRefreshException, match="Failed to retrieve refresh history"):
+ await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID)
+
+ @pytest.mark.parametrize(
+ "input_data, expected_output",
+ [
+ (
+ {"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""},
+ {
+ PowerBIDatasetRefreshFields.REQUEST_ID.value: "1234",
+ PowerBIDatasetRefreshFields.STATUS.value: "Completed",
+ PowerBIDatasetRefreshFields.ERROR.value: "",
+ },
+ ),
+ (
+ {"requestId": "5678", "status": "Unknown", "serviceExceptionJson": "Some error"},
+ {
+ PowerBIDatasetRefreshFields.REQUEST_ID.value: "5678",
+ PowerBIDatasetRefreshFields.STATUS.value: "In Progress",
+ PowerBIDatasetRefreshFields.ERROR.value: "Some error",
+ },
+ ),
+ (
+ {"requestId": None, "status": None, "serviceExceptionJson": None},
+ {
+ PowerBIDatasetRefreshFields.REQUEST_ID.value: "None",
+ PowerBIDatasetRefreshFields.STATUS.value: "None",
+ PowerBIDatasetRefreshFields.ERROR.value: "None",
+ },
+ ),
+ (
+ {}, # Empty input dictionary
+ {
+ PowerBIDatasetRefreshFields.REQUEST_ID.value: "None",
+ PowerBIDatasetRefreshFields.STATUS.value: "None",
+ PowerBIDatasetRefreshFields.ERROR.value: "None",
+ },
+ ),
+ ],
)
-
- # Assert that the correct refresh details are returned
- assert result == {
- PowerBIDatasetRefreshFields.REQUEST_ID.value: "5e2d9921-e91b-491f-b7e1-e7d8db49194c",
- PowerBIDatasetRefreshFields.STATUS.value: "Completed",
- PowerBIDatasetRefreshFields.ERROR.value: "None",
- }
-
- # Call the function with an invalid request ID
- invalid_request_id = "invalid_request_id"
- with pytest.raises(PowerBIDatasetRefreshException):
- await powerbi_hook.get_refresh_details_by_refresh_id(
- dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id
- )
-
-
-@pytest.mark.asyncio
-async def test_get_refresh_details_by_refresh_id_empty_history(powerbi_hook):
- """Test exception when refresh history is empty."""
- # Mock the get_refresh_history method to return an empty list
- powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=[])
-
- # Call the function with a request ID
- refresh_id = "any_request_id"
- with pytest.raises(
- PowerBIDatasetRefreshException,
- match=f"Unable to fetch the details of dataset refresh with Request Id: {refresh_id}",
- ):
- await powerbi_hook.get_refresh_details_by_refresh_id(
+ def test_raw_to_refresh_details(self, input_data, expected_output):
+ """Test raw_to_refresh_details method."""
+ result = PowerBIHook.raw_to_refresh_details(input_data)
+ assert result == expected_output
+
+ @pytest.mark.asyncio
+ async def test_get_refresh_details_by_refresh_id(self, powerbi_hook):
+ # Mock the get_refresh_history method to return a list of refresh histories
+ refresh_histories = FORMATTED_RESPONSE
+ powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=refresh_histories)
+
+ # Call the function with a valid request ID
+ refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
+ result = await powerbi_hook.get_refresh_details_by_refresh_id(
dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id
)
-
-@pytest.mark.asyncio
-async def test_get_refresh_details_by_refresh_id_not_found(powerbi_hook):
- """Test exception when the refresh ID is not found in the refresh history."""
- # Mock the get_refresh_history method to return a list of refresh histories without the specified ID
- powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=FORMATTED_RESPONSE)
-
- # Call the function with an invalid request ID
- invalid_request_id = "invalid_request_id"
- with pytest.raises(
- PowerBIDatasetRefreshException,
- match=f"Unable to fetch the details of dataset refresh with Request Id: {invalid_request_id}",
- ):
- await powerbi_hook.get_refresh_details_by_refresh_id(
- dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id
+ # Assert that the correct refresh details are returned
+ assert result == {
+ PowerBIDatasetRefreshFields.REQUEST_ID.value: "5e2d9921-e91b-491f-b7e1-e7d8db49194c",
+ PowerBIDatasetRefreshFields.STATUS.value: "Completed",
+ PowerBIDatasetRefreshFields.ERROR.value: "None",
+ }
+
+ # Call the function with an invalid request ID
+ invalid_request_id = "invalid_request_id"
+ with pytest.raises(PowerBIDatasetRefreshException):
+ await powerbi_hook.get_refresh_details_by_refresh_id(
+ dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id
+ )
+
+ @pytest.mark.asyncio
+ async def test_get_refresh_details_by_refresh_id_empty_history(self, powerbi_hook):
+ """Test exception when refresh history is empty."""
+ # Mock the get_refresh_history method to return an empty list
+ powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=[])
+
+ # Call the function with a request ID
+ refresh_id = "any_request_id"
+ with pytest.raises(
+ PowerBIDatasetRefreshException,
+ match=f"Unable to fetch the details of dataset refresh with Request Id: {refresh_id}",
+ ):
+ await powerbi_hook.get_refresh_details_by_refresh_id(
+ dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id
+ )
+
+ @pytest.mark.asyncio
+ async def test_get_refresh_details_by_refresh_id_not_found(self, powerbi_hook):
+ """Test exception when the refresh ID is not found in the refresh history."""
+ # Mock the get_refresh_history method to return a list of refresh histories without the specified ID
+ powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=FORMATTED_RESPONSE)
+
+ # Call the function with an invalid request ID
+ invalid_request_id = "invalid_request_id"
+ with pytest.raises(
+ PowerBIDatasetRefreshException,
+ match=f"Unable to fetch the details of dataset refresh with Request Id: {invalid_request_id}",
+ ):
+ await powerbi_hook.get_refresh_details_by_refresh_id(
+ dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id
+ )
+
+ @pytest.mark.asyncio
+ async def test_trigger_dataset_refresh_success(self, powerbi_hook):
+ response_data = {"requestid": "5e2d9921-e91b-491f-b7e1-e7d8db49194c"}
+
+ with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
+ mock_run.return_value = response_data
+ result = await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID)
+
+ assert result == "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
+
+ @pytest.mark.asyncio
+ async def test_trigger_dataset_refresh_failure(self, powerbi_hook):
+ """Test failure to trigger dataset refresh due to AirflowException."""
+ with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
+ mock_run.side_effect = AirflowException("Test exception")
+
+ with pytest.raises(PowerBIDatasetRefreshException, match="Failed to trigger dataset refresh."):
+ await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID)
+
+ @pytest.mark.asyncio
+ async def test_cancel_dataset_refresh(self, powerbi_hook):
+ dataset_refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
+
+ with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
+ await powerbi_hook.cancel_dataset_refresh(DATASET_ID, GROUP_ID, dataset_refresh_id)
+
+ mock_run.assert_called_once_with(
+ url="myorg/groups/{group_id}/datasets/{dataset_id}/refreshes/{dataset_refresh_id}",
+ response_type=None,
+ path_parameters={
+ "group_id": GROUP_ID,
+ "dataset_id": DATASET_ID,
+ "dataset_refresh_id": dataset_refresh_id,
+ },
+ method="DELETE",
)
-
-
-@pytest.mark.asyncio
-async def test_trigger_dataset_refresh_success(powerbi_hook):
- response_data = {"requestid": "5e2d9921-e91b-491f-b7e1-e7d8db49194c"}
-
- with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
- mock_run.return_value = response_data
- result = await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID)
-
- assert result == "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
-
-
-@pytest.mark.asyncio
-async def test_trigger_dataset_refresh_failure(powerbi_hook):
- """Test failure to trigger dataset refresh due to AirflowException."""
- with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
- mock_run.side_effect = AirflowException("Test exception")
-
- with pytest.raises(PowerBIDatasetRefreshException, match="Failed to trigger dataset refresh."):
- await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID)
-
-
-@pytest.mark.asyncio
-async def test_cancel_dataset_refresh(powerbi_hook):
- dataset_refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c"
-
- with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run:
- await powerbi_hook.cancel_dataset_refresh(DATASET_ID, GROUP_ID, dataset_refresh_id)
-
- mock_run.assert_called_once_with(
- url="myorg/groups/{group_id}/datasets/{dataset_id}/refreshes/{dataset_refresh_id}",
- response_type=None,
- path_parameters={
- "group_id": GROUP_ID,
- "dataset_id": DATASET_ID,
- "dataset_refresh_id": dataset_refresh_id,
- },
- method="DELETE",
- )
diff --git a/tests/providers/microsoft/azure/operators/test_powerbi.py b/tests/providers/microsoft/azure/operators/test_powerbi.py
index 2ee5ee723d7a7..35bb76f782ce3 100644
--- a/tests/providers/microsoft/azure/operators/test_powerbi.py
+++ b/tests/providers/microsoft/azure/operators/test_powerbi.py
@@ -17,6 +17,7 @@
from __future__ import annotations
+from unittest import mock
from unittest.mock import MagicMock
import pytest
@@ -25,11 +26,12 @@
from airflow.providers.microsoft.azure.hooks.powerbi import (
PowerBIDatasetRefreshFields,
PowerBIDatasetRefreshStatus,
- PowerBIHook,
)
from airflow.providers.microsoft.azure.operators.powerbi import PowerBIDatasetRefreshOperator
from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger
from airflow.utils import timezone
+from tests.providers.microsoft.azure.base import Base
+from tests.providers.microsoft.conftest import get_airflow_connection, mock_context
DEFAULT_CONNECTION_CLIENT_SECRET = "powerbi_conn_id"
TASK_ID = "run_powerbi_operator"
@@ -72,86 +74,77 @@
}
-@pytest.fixture
-def mock_powerbi_hook():
- hook = PowerBIHook()
- return hook
-
-
-def test_execute_wait_for_termination_with_Deferrable(mock_powerbi_hook):
- operator = PowerBIDatasetRefreshOperator(
- **CONFIG,
- )
- operator.hook = mock_powerbi_hook
- context = {"ti": MagicMock()}
-
- with pytest.raises(TaskDeferred) as exc:
- operator.execute(context)
-
- assert isinstance(exc.value.trigger, PowerBITrigger)
+class TestPowerBIDatasetRefreshOperator(Base):
+ @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection)
+ def test_execute_wait_for_termination_with_deferrable(self, connection):
+ operator = PowerBIDatasetRefreshOperator(
+ **CONFIG,
+ )
+ context = mock_context(task=operator)
+ with pytest.raises(TaskDeferred) as exc:
+ operator.execute(context)
-def test_powerbi_operator_async_execute_complete_success():
- """Assert that execute_complete log success message"""
- operator = PowerBIDatasetRefreshOperator(
- **CONFIG,
- )
- context = {"ti": MagicMock()}
- operator.execute_complete(
- context=context,
- event=SUCCESS_TRIGGER_EVENT,
- )
- assert context["ti"].xcom_push.call_count == 2
+ assert isinstance(exc.value.trigger, PowerBITrigger)
+ def test_powerbi_operator_async_execute_complete_success(self):
+ """Assert that execute_complete log success message"""
+ operator = PowerBIDatasetRefreshOperator(
+ **CONFIG,
+ )
+ context = {"ti": MagicMock()}
+ operator.execute_complete(
+ context=context,
+ event=SUCCESS_TRIGGER_EVENT,
+ )
+ assert context["ti"].xcom_push.call_count == 2
-def test_powerbi_operator_async_execute_complete_fail():
- """Assert that execute_complete raise exception on error"""
- operator = PowerBIDatasetRefreshOperator(
- **CONFIG,
- )
- context = {"ti": MagicMock()}
- with pytest.raises(AirflowException):
+ def test_powerbi_operator_async_execute_complete_fail(self):
+ """Assert that execute_complete raise exception on error"""
+ operator = PowerBIDatasetRefreshOperator(
+ **CONFIG,
+ )
+ context = {"ti": MagicMock()}
+ with pytest.raises(AirflowException):
+ operator.execute_complete(
+ context=context,
+ event={"status": "error", "message": "error", "dataset_refresh_id": "1234"},
+ )
+ assert context["ti"].xcom_push.call_count == 0
+
+ def test_execute_complete_no_event(self):
+ """Test execute_complete when event is None or empty."""
+ operator = PowerBIDatasetRefreshOperator(
+ **CONFIG,
+ )
+ context = {"ti": MagicMock()}
operator.execute_complete(
context=context,
- event={"status": "error", "message": "error", "dataset_refresh_id": "1234"},
+ event=None,
+ )
+ assert context["ti"].xcom_push.call_count == 0
+
+ @pytest.mark.db_test
+ def test_powerbi_link(self, create_task_instance_of_operator):
+ """Assert Power BI Extra link matches the expected URL."""
+ ti = create_task_instance_of_operator(
+ PowerBIDatasetRefreshOperator,
+ dag_id="test_powerbi_refresh_op_link",
+ execution_date=DEFAULT_DATE,
+ task_id=TASK_ID,
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ group_id=GROUP_ID,
+ dataset_id=DATASET_ID,
+ check_interval=1,
+ timeout=3,
)
- assert context["ti"].xcom_push.call_count == 0
-
-
-def test_execute_complete_no_event():
- """Test execute_complete when event is None or empty."""
- operator = PowerBIDatasetRefreshOperator(
- **CONFIG,
- )
- context = {"ti": MagicMock()}
- operator.execute_complete(
- context=context,
- event=None,
- )
- assert context["ti"].xcom_push.call_count == 0
-
-
-@pytest.mark.db_test
-def test_powerbilink(create_task_instance_of_operator):
- """Assert Power BI Extra link matches the expected URL."""
- ti = create_task_instance_of_operator(
- PowerBIDatasetRefreshOperator,
- dag_id="test_powerbi_refresh_op_link",
- execution_date=DEFAULT_DATE,
- task_id=TASK_ID,
- conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
- group_id=GROUP_ID,
- dataset_id=DATASET_ID,
- check_interval=1,
- timeout=3,
- )
-
- ti.xcom_push(key="powerbi_dataset_refresh_id", value=NEW_REFRESH_REQUEST_ID)
- url = ti.task.get_extra_links(ti, "Monitor PowerBI Dataset")
- EXPECTED_ITEM_RUN_OP_EXTRA_LINK = (
- "https://app.powerbi.com" # type: ignore[attr-defined]
- f"/groups/{GROUP_ID}/datasets/{DATASET_ID}" # type: ignore[attr-defined]
- "/details?experience=power-bi"
- )
-
- assert url == EXPECTED_ITEM_RUN_OP_EXTRA_LINK
+
+ ti.xcom_push(key="powerbi_dataset_refresh_id", value=NEW_REFRESH_REQUEST_ID)
+ url = ti.task.get_extra_links(ti, "Monitor PowerBI Dataset")
+ EXPECTED_ITEM_RUN_OP_EXTRA_LINK = (
+ "https://app.powerbi.com" # type: ignore[attr-defined]
+ f"/groups/{GROUP_ID}/datasets/{DATASET_ID}" # type: ignore[attr-defined]
+ "/details?experience=power-bi"
+ )
+
+ assert url == EXPECTED_ITEM_RUN_OP_EXTRA_LINK
diff --git a/tests/providers/microsoft/azure/triggers/test_powerbi.py b/tests/providers/microsoft/azure/triggers/test_powerbi.py
index 5b44a84149501..c3276e258b3da 100644
--- a/tests/providers/microsoft/azure/triggers/test_powerbi.py
+++ b/tests/providers/microsoft/azure/triggers/test_powerbi.py
@@ -19,11 +19,10 @@
import asyncio
from unittest import mock
-from unittest.mock import patch
import pytest
-from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus, PowerBIHook
+from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus
from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger
from airflow.triggers.base import TriggerEvent
from tests.providers.microsoft.conftest import get_airflow_connection
@@ -54,19 +53,10 @@ def powerbi_trigger():
return trigger
-@pytest.fixture
-def mock_powerbi_hook():
- hook = PowerBIHook()
- return hook
-
-
-def test_powerbi_trigger_serialization():
- """Asserts that the PowerBI Trigger correctly serializes its arguments and classpath."""
-
- with patch(
- "airflow.hooks.base.BaseHook.get_connection",
- side_effect=get_airflow_connection,
- ):
+class TestPowerBITrigger:
+ @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection)
+ def test_powerbi_trigger_serialization(self, connection):
+ """Asserts that the PowerBI Trigger correctly serializes its arguments and classpath."""
powerbi_trigger = PowerBITrigger(
conn_id=POWERBI_CONN_ID,
proxies=None,
@@ -91,167 +81,170 @@ def test_powerbi_trigger_serialization():
"wait_for_termination": True,
}
-
-@pytest.mark.asyncio
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
-async def test_powerbi_trigger_run_inprogress(
- mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
-):
- """Assert task isn't completed until timeout if dataset refresh is in progress."""
- mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS}
- mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
- task = asyncio.create_task(powerbi_trigger.run().__anext__())
- await asyncio.sleep(0.5)
-
- # Assert TriggerEvent was not returned
- assert task.done() is False
- asyncio.get_event_loop().stop()
-
-
-@pytest.mark.asyncio
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
-async def test_powerbi_trigger_run_failed(
- mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
-):
- """Assert event is triggered upon failed dataset refresh."""
- mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.FAILED}
- mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
-
- generator = powerbi_trigger.run()
- actual = await generator.asend(None)
- expected = TriggerEvent(
- {
- "status": "Failed",
- "message": f"The dataset refresh {DATASET_REFRESH_ID} has "
- f"{PowerBIDatasetRefreshStatus.FAILED}.",
- "dataset_refresh_id": DATASET_REFRESH_ID,
- }
- )
- assert expected == actual
-
-
-@pytest.mark.asyncio
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
-async def test_powerbi_trigger_run_completed(
- mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
-):
- """Assert event is triggered upon successful dataset refresh."""
- mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.COMPLETED}
- mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
-
- generator = powerbi_trigger.run()
- actual = await generator.asend(None)
- expected = TriggerEvent(
- {
- "status": "Completed",
- "message": f"The dataset refresh {DATASET_REFRESH_ID} has "
- f"{PowerBIDatasetRefreshStatus.COMPLETED}.",
- "dataset_refresh_id": DATASET_REFRESH_ID,
- }
- )
- assert expected == actual
-
-
-@pytest.mark.asyncio
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
-async def test_powerbi_trigger_run_exception_during_refresh_check_loop(
- mock_trigger_dataset_refresh,
- mock_get_refresh_details_by_refresh_id,
- mock_cancel_dataset_refresh,
- powerbi_trigger,
-):
- """Assert that run catch exception if Power BI API throw exception"""
- mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception")
- mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
-
- task = [i async for i in powerbi_trigger.run()]
- response = TriggerEvent(
- {
- "status": "error",
- "message": "An error occurred: Test exception",
- "dataset_refresh_id": DATASET_REFRESH_ID,
- }
- )
- assert len(task) == 1
- assert response in task
- mock_cancel_dataset_refresh.assert_called_once()
-
-
-@pytest.mark.asyncio
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
-async def test_powerbi_trigger_run_exception_during_refresh_cancellation(
- mock_trigger_dataset_refresh,
- mock_get_refresh_details_by_refresh_id,
- mock_cancel_dataset_refresh,
- powerbi_trigger,
-):
- """Assert that run catch exception if Power BI API throw exception"""
- mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception")
- mock_cancel_dataset_refresh.side_effect = Exception("Exception caused by cancel_dataset_refresh")
- mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
-
- task = [i async for i in powerbi_trigger.run()]
- response = TriggerEvent(
- {
- "status": "error",
- "message": "An error occurred while canceling dataset: Exception caused by cancel_dataset_refresh",
- "dataset_refresh_id": DATASET_REFRESH_ID,
+ @pytest.mark.asyncio
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+ async def test_powerbi_trigger_run_inprogress(
+ self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
+ ):
+ """Assert task isn't completed until timeout if dataset refresh is in progress."""
+ mock_get_refresh_details_by_refresh_id.return_value = {
+ "status": PowerBIDatasetRefreshStatus.IN_PROGRESS
}
- )
-
- assert len(task) == 1
- assert response in task
- mock_cancel_dataset_refresh.assert_called_once()
-
-
-@pytest.mark.asyncio
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
-async def test_powerbi_trigger_run_exception_without_refresh_id(
- mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
-):
- """Assert handling of exception when there is no dataset_refresh_id"""
- powerbi_trigger.dataset_refresh_id = None
- mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception for no dataset_refresh_id")
- mock_trigger_dataset_refresh.return_value = None
+ mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+ task = asyncio.create_task(powerbi_trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # Assert TriggerEvent was not returned
+ assert task.done() is False
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+ async def test_powerbi_trigger_run_failed(
+ self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
+ ):
+ """Assert event is triggered upon failed dataset refresh."""
+ mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.FAILED}
+ mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+
+ generator = powerbi_trigger.run()
+ actual = await generator.asend(None)
+ expected = TriggerEvent(
+ {
+ "status": "Failed",
+ "message": f"The dataset refresh {DATASET_REFRESH_ID} has "
+ f"{PowerBIDatasetRefreshStatus.FAILED}.",
+ "dataset_refresh_id": DATASET_REFRESH_ID,
+ }
+ )
+ assert expected == actual
- task = [i async for i in powerbi_trigger.run()]
- response = TriggerEvent(
- {
- "status": "error",
- "message": "An error occurred: Test exception for no dataset_refresh_id",
- "dataset_refresh_id": None,
+ @pytest.mark.asyncio
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+ async def test_powerbi_trigger_run_completed(
+ self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
+ ):
+ """Assert event is triggered upon successful dataset refresh."""
+ mock_get_refresh_details_by_refresh_id.return_value = {
+ "status": PowerBIDatasetRefreshStatus.COMPLETED
}
- )
- assert len(task) == 1
- assert response in task
-
+ mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+
+ generator = powerbi_trigger.run()
+ actual = await generator.asend(None)
+ expected = TriggerEvent(
+ {
+ "status": "Completed",
+ "message": f"The dataset refresh {DATASET_REFRESH_ID} has "
+ f"{PowerBIDatasetRefreshStatus.COMPLETED}.",
+ "dataset_refresh_id": DATASET_REFRESH_ID,
+ }
+ )
+ assert expected == actual
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh")
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+ async def test_powerbi_trigger_run_exception_during_refresh_check_loop(
+ self,
+ mock_trigger_dataset_refresh,
+ mock_get_refresh_details_by_refresh_id,
+ mock_cancel_dataset_refresh,
+ powerbi_trigger,
+ ):
+ """Assert that run catch exception if Power BI API throw exception"""
+ mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception")
+ mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+
+ task = [i async for i in powerbi_trigger.run()]
+ response = TriggerEvent(
+ {
+ "status": "error",
+ "message": "An error occurred: Test exception",
+ "dataset_refresh_id": DATASET_REFRESH_ID,
+ }
+ )
+ assert len(task) == 1
+ assert response in task
+ mock_cancel_dataset_refresh.assert_called_once()
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh")
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+ async def test_powerbi_trigger_run_exception_during_refresh_cancellation(
+ self,
+ mock_trigger_dataset_refresh,
+ mock_get_refresh_details_by_refresh_id,
+ mock_cancel_dataset_refresh,
+ powerbi_trigger,
+ ):
+ """Assert that run catch exception if Power BI API throw exception"""
+ mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception")
+ mock_cancel_dataset_refresh.side_effect = Exception("Exception caused by cancel_dataset_refresh")
+ mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+
+ task = [i async for i in powerbi_trigger.run()]
+ response = TriggerEvent(
+ {
+ "status": "error",
+ "message": "An error occurred while canceling dataset: Exception caused by cancel_dataset_refresh",
+ "dataset_refresh_id": DATASET_REFRESH_ID,
+ }
+ )
-@pytest.mark.asyncio
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
-@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
-async def test_powerbi_trigger_run_timeout(
- mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
-):
- """Assert that powerbi run timesout after end_time elapses"""
- mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS}
- mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+ assert len(task) == 1
+ assert response in task
+ mock_cancel_dataset_refresh.assert_called_once()
- generator = powerbi_trigger.run()
- actual = await generator.asend(None)
- expected = TriggerEvent(
- {
- "status": "error",
- "message": f"Timeout occurred while waiting for dataset refresh to complete: The dataset refresh {DATASET_REFRESH_ID} has status In Progress.",
- "dataset_refresh_id": DATASET_REFRESH_ID,
+ @pytest.mark.asyncio
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+ async def test_powerbi_trigger_run_exception_without_refresh_id(
+ self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
+ ):
+ """Assert handling of exception when there is no dataset_refresh_id"""
+ powerbi_trigger.dataset_refresh_id = None
+ mock_get_refresh_details_by_refresh_id.side_effect = Exception(
+ "Test exception for no dataset_refresh_id"
+ )
+ mock_trigger_dataset_refresh.return_value = None
+
+ task = [i async for i in powerbi_trigger.run()]
+ response = TriggerEvent(
+ {
+ "status": "error",
+ "message": "An error occurred: Test exception for no dataset_refresh_id",
+ "dataset_refresh_id": None,
+ }
+ )
+ assert len(task) == 1
+ assert response in task
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
+ @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
+ async def test_powerbi_trigger_run_timeout(
+ self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
+ ):
+ """Assert that powerbi run timesout after end_time elapses"""
+ mock_get_refresh_details_by_refresh_id.return_value = {
+ "status": PowerBIDatasetRefreshStatus.IN_PROGRESS
}
- )
+ mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID
+
+ generator = powerbi_trigger.run()
+ actual = await generator.asend(None)
+ expected = TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Timeout occurred while waiting for dataset refresh to complete: The dataset refresh {DATASET_REFRESH_ID} has status In Progress.",
+ "dataset_refresh_id": DATASET_REFRESH_ID,
+ }
+ )
- assert expected == actual
+ assert expected == actual
diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py
index 8a258735291f1..de25d24fb05e1 100644
--- a/tests/providers/microsoft/conftest.py
+++ b/tests/providers/microsoft/conftest.py
@@ -32,6 +32,7 @@
from msgraph_core import APIVersion
from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIHook
from airflow.utils.context import Context
if TYPE_CHECKING:
@@ -183,21 +184,45 @@ def load_file(*args: str, mode="r", encoding="utf-8"):
def get_airflow_connection(
conn_id: str,
+ host: str = "graph.microsoft.com",
login: str = "client_id",
password: str = "client_secret",
tenant_id: str = "tenant-id",
+ azure_tenant_id: str | None = None,
proxies: dict | None = None,
- api_version: APIVersion = APIVersion.v1,
+ scopes: list[str] | None = None,
+ api_version: APIVersion | str | None = APIVersion.v1.value,
+ authority: str | None = None,
+ disable_instance_discovery: bool = False,
):
from airflow.models import Connection
+ extra = {
+ "api_version": api_version,
+ "proxies": proxies or {},
+ "verify": False,
+ "scopes": scopes or [],
+ "authority": authority,
+ "disable_instance_discovery": disable_instance_discovery,
+ }
+
+ if azure_tenant_id:
+ extra["tenantId"] = azure_tenant_id
+ else:
+ extra["tenant_id"] = tenant_id
+
return Connection(
schema="https",
conn_id=conn_id,
conn_type="http",
- host="graph.microsoft.com",
+ host=host,
port=80,
login=login,
password=password,
- extra={"tenant_id": tenant_id, "api_version": api_version.value, "proxies": proxies or {}},
+ extra=extra,
)
+
+
+@pytest.fixture
+def powerbi_hook():
+ return PowerBIHook(**{"conn_id": "powerbi_conn_id", "timeout": 3, "api_version": "v1.0"})
diff --git a/tests/system/providers/microsoft/azure/example_msfabric.py b/tests/system/providers/microsoft/azure/example_msfabric.py
new file mode 100644
index 0000000000000..7d62a49e0bc31
--- /dev/null
+++ b/tests/system/providers/microsoft/azure/example_msfabric.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import models
+from airflow.datasets import Dataset
+from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator
+
+DAG_ID = "example_msfabric"
+
+with models.DAG(
+ DAG_ID,
+ start_date=datetime(2021, 1, 1),
+ schedule=None,
+ tags=["example"],
+) as dag:
+ # [START howto_operator_ms_fabric_create_item_schedule]
+ # https://learn.microsoft.com/en-us/rest/api/fabric/core/job-scheduler/create-item-schedule?tabs=HTTP
+ workspaces_task = MSGraphAsyncOperator(
+ task_id="schedule_datapipeline",
+ conn_id="powerbi",
+ method="POST",
+ url="workspaces/{workspaceId}/items/{itemId}/jobs/instances",
+ path_parameters={
+ "workspaceId": "e90b2873-4812-4dfb-9246-593638165644",
+ "itemId": "65448530-e5ec-4aeb-a97e-7cebf5d67c18",
+ },
+ query_parameters={"jobType": "Pipeline"},
+ dag=dag,
+ outlets=[
+ Dataset(
+ "workspaces/e90b2873-4812-4dfb-9246-593638165644/items/65448530-e5ec-4aeb-a97e-7cebf5d67c18/jobs/instances?jobType=Pipeline"
+ )
+ ],
+ )
+ # [END howto_operator_ms_fabric_create_item_schedule]
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)