From 13b878921e146b20c525d24c5d7bf6d36729e59a Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 30 Jul 2024 14:48:45 +0200 Subject: [PATCH 01/20] refactor: Allow custom api versions to be passed --- .../microsoft/azure/hooks/msgraph.py | 87 +++++++++++++------ .../microsoft/azure/operators/msgraph.py | 2 +- .../microsoft/azure/sensors/msgraph.py | 2 +- .../microsoft/azure/triggers/msgraph.py | 5 +- .../microsoft/azure/hooks/test_msgraph.py | 56 ++++++++++-- tests/providers/microsoft/conftest.py | 2 +- 6 files changed, 111 insertions(+), 43 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index 56abfa155da7c..95ccc72b1cb3a 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -86,6 +86,39 @@ async def handle_response_async( return value +class DefaultResponseHandler(ResponseHandler): + """DefaultResponseHandler returns JSON payload or content in bytes or response headers.""" + + @staticmethod + def get_value(response: NativeResponseType) -> Any: + with suppress(JSONDecodeError): + return response.json() + content = response.content + if not content: + return {key: value for key, value in response.headers.items()} + return content + + async def handle_response_async( + self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None + ) -> Any: + """ + Invoke this callback method when a response is received. + + param response: The type of the native response object. + param error_map: The error dict to use in case of a failed request. + """ + value = self.get_value(response) + if response.status_code not in {200, 201, 202, 204, 302}: + message = value or response.reason_phrase + status_code = HTTPStatus(response.status_code) + if status_code == HTTPStatus.BAD_REQUEST: + raise AirflowBadRequest(message) + elif status_code == HTTPStatus.NOT_FOUND: + raise AirflowNotFoundException(message) + raise AirflowException(message) + return value + + class KiotaRequestAdapterHook(BaseHook): """ A Microsoft Graph API interaction hook, a Wrapper around KiotaRequestAdapter. @@ -96,6 +129,8 @@ class KiotaRequestAdapterHook(BaseHook): :param timeout: The HTTP timeout being used by the KiotaRequestAdapter (default is None). When no timeout is specified or set to None then no HTTP timeout is applied on each request. :param proxies: A Dict defining the HTTP proxies to be used (default is None). + :param host: The host to be used (default is "https://graph.microsoft.com"). + :param scopes: The scopes to be used (default is ["https://graph.microsoft.com/.default"]). :param api_version: The API version of the Microsoft Graph API to be used (default is v1). You can pass an enum named APIVersion which has 2 possible members v1 and beta, or you can pass a string as "v1.0" or "beta". @@ -110,12 +145,16 @@ def __init__( conn_id: str = default_conn_name, timeout: float | None = None, proxies: dict | None = None, + host: str = NationalClouds.Global.value, + scopes: list[str] | None = None, api_version: APIVersion | str | None = None, ): super().__init__() self.conn_id = conn_id self.timeout = timeout self.proxies = proxies + self.host = host + self.scopes = scopes or ["https://graph.microsoft.com/.default"] self._api_version = self.resolve_api_version_from_value(api_version) @property @@ -124,28 +163,20 @@ def api_version(self) -> APIVersion: return self._api_version @staticmethod - def resolve_api_version_from_value( - api_version: APIVersion | str, default: APIVersion | None = None - ) -> APIVersion: + def resolve_api_version_from_value(api_version: APIVersion | str, default: str | None = None) -> str: if isinstance(api_version, APIVersion): - return api_version - return next( - filter(lambda version: version.value == api_version, APIVersion), - default, - ) + return api_version.value + return api_version or default - def get_api_version(self, config: dict) -> APIVersion: - if self._api_version is None: - return self.resolve_api_version_from_value( - api_version=config.get("api_version"), default=APIVersion.v1 - ) - return self._api_version + def get_api_version(self, config: dict) -> str: + return self._api_version or self.resolve_api_version_from_value( + config.get("api_version"), APIVersion.v1.value + ) - @staticmethod - def get_host(connection: Connection) -> str: + def get_host(self, connection: Connection) -> str: if connection.schema and connection.host: return f"{connection.schema}://{connection.host}" - return NationalClouds.Global.value + return self.host @staticmethod def format_no_proxy_url(url: str) -> str: @@ -166,15 +197,15 @@ def to_httpx_proxies(cls, proxies: dict) -> dict: return proxies def to_msal_proxies(self, authority: str | None, proxies: dict): - self.log.info("authority: %s", authority) + self.log.debug("authority: %s", authority) if authority: no_proxies = proxies.get("no") - self.log.info("no_proxies: %s", no_proxies) + self.log.debug("no_proxies: %s", no_proxies) if no_proxies: for url in no_proxies.split(","): self.log.info("url: %s", url) domain_name = urlparse(url).path.replace("*", "") - self.log.info("domain_name: %s", domain_name) + self.log.debug("domain_name: %s", domain_name) if authority.endswith(domain_name): return None return proxies @@ -193,12 +224,12 @@ def get_conn(self) -> RequestAdapter: tenant_id = config.get("tenant_id") api_version = self.get_api_version(config) host = self.get_host(connection) - base_url = config.get("base_url", urljoin(host, api_version.value)) + base_url = config.get("base_url", urljoin(host, api_version)) authority = config.get("authority") proxies = self.proxies or config.get("proxies", {}) msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies) httpx_proxies = self.to_httpx_proxies(proxies=proxies) - scopes = config.get("scopes", ["https://graph.microsoft.com/.default"]) + scopes = config.get("scopes", self.scopes) verify = config.get("verify", True) trust_env = config.get("trust_env", False) disable_instance_discovery = config.get("disable_instance_discovery", False) @@ -206,7 +237,7 @@ def get_conn(self) -> RequestAdapter: self.log.info( "Creating Microsoft Graph SDK client %s for conn_id: %s", - api_version.value, + api_version, self.conn_id, ) self.log.info("Host: %s", host) @@ -214,7 +245,7 @@ def get_conn(self) -> RequestAdapter: self.log.info("Tenant id: %s", tenant_id) self.log.info("Client id: %s", client_id) self.log.info("Client secret: %s", client_secret) - self.log.info("API version: %s", api_version.value) + self.log.info("API version: %s", api_version) self.log.info("Scope: %s", scopes) self.log.info("Verify: %s", verify) self.log.info("Timeout: %s", self.timeout) @@ -235,17 +266,17 @@ def get_conn(self) -> RequestAdapter: connection_verify=verify, ) http_client = GraphClientFactory.create_with_default_middleware( - api_version=api_version, + api_version=api_version, # type: ignore client=httpx.AsyncClient( proxies=httpx_proxies, timeout=Timeout(timeout=self.timeout), verify=verify, trust_env=trust_env, ), - host=host, + host=host, # type: ignore ) auth_provider = AzureIdentityAuthenticationProvider( - credentials=credentials, + credentials=credentials, # type: ignore scopes=scopes, allowed_hosts=allowed_hosts, ) @@ -292,7 +323,7 @@ async def run( error_map=self.error_mapping(), ) - self.log.info("response: %s", response) + self.log.debug("response: %s", response) return response diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py index cd387954737ab..74409f3600a1e 100644 --- a/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/airflow/providers/microsoft/azure/operators/msgraph.py @@ -99,7 +99,7 @@ def __init__( key: str = XCOM_RETURN_KEY, timeout: float | None = None, proxies: dict | None = None, - api_version: APIVersion | None = None, + api_version: APIVersion | str | None = None, pagination_function: Callable[[MSGraphAsyncOperator, dict], tuple[str, dict]] | None = None, result_processor: Callable[[Context, Any], Any] = lambda context, result: result, serializer: type[ResponseSerializer] = ResponseSerializer, diff --git a/airflow/providers/microsoft/azure/sensors/msgraph.py b/airflow/providers/microsoft/azure/sensors/msgraph.py index 3e1b10cbeb1e6..6736ea59c918d 100644 --- a/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -82,7 +82,7 @@ def __init__( data: dict[str, Any] | str | BytesIO | None = None, conn_id: str = KiotaRequestAdapterHook.default_conn_name, proxies: dict | None = None, - api_version: APIVersion | None = None, + api_version: APIVersion | str | None = None, event_processor: Callable[[Context, Any], bool] = lambda context, e: e.get("status") == "Succeeded", result_processor: Callable[[Context, Any], Any] = lambda context, result: result, serializer: type[ResponseSerializer] = ResponseSerializer, diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py b/airflow/providers/microsoft/azure/triggers/msgraph.py index 4b9ccb7a71716..32b4e2f142301 100644 --- a/airflow/providers/microsoft/azure/triggers/msgraph.py +++ b/airflow/providers/microsoft/azure/triggers/msgraph.py @@ -122,7 +122,7 @@ def __init__( conn_id: str = KiotaRequestAdapterHook.default_conn_name, timeout: float | None = None, proxies: dict | None = None, - api_version: APIVersion | None = None, + api_version: APIVersion | str | None = None, serializer: type[ResponseSerializer] = ResponseSerializer, ): super().__init__() @@ -152,14 +152,13 @@ def resolve_type(cls, value: str | type, default) -> type: def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize the HttpTrigger arguments and classpath.""" - api_version = self.api_version.value if self.api_version else None return ( f"{self.__class__.__module__}.{self.__class__.__name__}", { "conn_id": self.conn_id, "timeout": self.timeout, "proxies": self.proxies, - "api_version": api_version, + "api_version": self.api_version.value, "serializer": f"{self.serializer.__class__.__module__}.{self.serializer.__class__.__name__}", "url": self.url, "path_parameters": self.path_parameters, diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py index 71d280a1971da..1bf47f05a8816 100644 --- a/tests/providers/microsoft/azure/hooks/test_msgraph.py +++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py @@ -40,6 +40,9 @@ class TestKiotaRequestAdapterHook: + def setup_method(self): + KiotaRequestAdapterHook.cached_request_adapters.clear() + def test_get_conn(self): with patch( "airflow.hooks.base.BaseHook.get_connection", @@ -51,6 +54,21 @@ def test_get_conn(self): assert isinstance(actual, HttpxRequestAdapter) assert actual.base_url == "https://graph.microsoft.com/v1.0" + def test_get_conn_with_custom_base_url(self): + connection = lambda conn_id: get_airflow_connection( + conn_id=conn_id, host="api.fabric.microsoft.com", api_version="v1", + ) + + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + actual = hook.get_conn() + + assert isinstance(actual, HttpxRequestAdapter) + assert actual.base_url == "https://api.fabric.microsoft.com/v1" + def test_api_version(self): with patch( "airflow.hooks.base.BaseHook.get_connection", @@ -58,7 +76,7 @@ def test_api_version(self): ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") - assert hook.api_version == APIVersion.v1 + assert hook.api_version == APIVersion.v1.value def test_get_api_version_when_empty_config_dict(self): with patch( @@ -68,7 +86,7 @@ def test_get_api_version_when_empty_config_dict(self): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") actual = hook.get_api_version({}) - assert actual == APIVersion.v1 + assert actual == APIVersion.v1.value def test_get_api_version_when_api_version_in_config_dict(self): with patch( @@ -78,19 +96,39 @@ def test_get_api_version_when_api_version_in_config_dict(self): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") actual = hook.get_api_version({"api_version": "beta"}) - assert actual == APIVersion.beta + assert actual == APIVersion.beta.value + + def test_get_api_version_when_custom_api_version_in_config_dict(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", api_version="v1") + actual = hook.get_api_version({}) + + assert actual == "v1" def test_get_host_when_connection_has_scheme_and_host(self): - connection = mock_connection(schema="https", host="graph.microsoft.de") - actual = KiotaRequestAdapterHook.get_host(connection) + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + connection = mock_connection(schema="https", host="graph.microsoft.de") + actual = hook.get_host(connection) - assert actual == NationalClouds.Germany.value + assert actual == NationalClouds.Germany.value def test_get_host_when_connection_has_no_scheme_or_host(self): - connection = mock_connection() - actual = KiotaRequestAdapterHook.get_host(connection) + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + connection = mock_connection() + actual = hook.get_host(connection) - assert actual == NationalClouds.Global.value + assert actual == NationalClouds.Global.value def test_encoded_query_parameters(self): actual = KiotaRequestAdapterHook.encoded_query_parameters( diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index ecd19d8865c68..f7198437589a6 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -178,7 +178,7 @@ def get_airflow_connection( password: str = "client_secret", tenant_id: str = "tenant-id", proxies: dict | None = None, - api_version: APIVersion = APIVersion.v1, + api_version: str = APIVersion.v1.value, ): from airflow.models import Connection From e3becfb429108f63f3f2a41ecbb2fde59ca10d4c Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Aug 2024 13:39:24 +0200 Subject: [PATCH 02/20] refactored: Reformatted TestKiotaRequestAdapterHook --- tests/providers/microsoft/azure/hooks/test_msgraph.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py index 1bf47f05a8816..1611bb226faca 100644 --- a/tests/providers/microsoft/azure/hooks/test_msgraph.py +++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py @@ -56,7 +56,9 @@ def test_get_conn(self): def test_get_conn_with_custom_base_url(self): connection = lambda conn_id: get_airflow_connection( - conn_id=conn_id, host="api.fabric.microsoft.com", api_version="v1", + conn_id=conn_id, + host="api.fabric.microsoft.com", + api_version="v1", ) with patch( @@ -110,7 +112,7 @@ def test_get_api_version_when_custom_api_version_in_config_dict(self): def test_get_host_when_connection_has_scheme_and_host(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + "airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") From 8f46b85392383bdcbfadd3a587b333f3aa0c224c Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Aug 2024 13:41:11 +0200 Subject: [PATCH 03/20] refactored: Changed type for api_version in get_airflow_connection --- tests/providers/microsoft/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index f7198437589a6..a51b9bda05aac 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -178,7 +178,7 @@ def get_airflow_connection( password: str = "client_secret", tenant_id: str = "tenant-id", proxies: dict | None = None, - api_version: str = APIVersion.v1.value, + api_version: APIVersion | str | None = APIVersion.v1.value, ): from airflow.models import Connection From 925dff71b9c6e370ca0a45e78a3e23d7f54796b0 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Aug 2024 14:06:50 +0200 Subject: [PATCH 04/20] refactored: Reformatted TestKiotaRequestAdapterHook --- tests/providers/microsoft/azure/hooks/test_msgraph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py index 1611bb226faca..45b33f528e438 100644 --- a/tests/providers/microsoft/azure/hooks/test_msgraph.py +++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py @@ -112,7 +112,7 @@ def test_get_api_version_when_custom_api_version_in_config_dict(self): def test_get_host_when_connection_has_scheme_and_host(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", + "airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -123,8 +123,8 @@ def test_get_host_when_connection_has_scheme_and_host(self): def test_get_host_when_connection_has_no_scheme_or_host(self): with patch( - "airflow.hooks.base.BaseHook.get_connection", - side_effect=get_airflow_connection, + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") connection = mock_connection() From 28847f5859daade4a6329f4b7271fae08a5bda4c Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Aug 2024 14:09:50 +0200 Subject: [PATCH 05/20] refactored: No need to call value on api_version --- tests/providers/microsoft/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index a51b9bda05aac..72cb7a3e86494 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -190,5 +190,5 @@ def get_airflow_connection( port=80, login=login, password=password, - extra={"tenant_id": tenant_id, "api_version": api_version.value, "proxies": proxies or {}}, + extra={"tenant_id": tenant_id, "api_version": api_version, "proxies": proxies or {}}, ) From 62d058183a199f95c1001eafa2cf2f38b2682202 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Aug 2024 14:11:19 +0200 Subject: [PATCH 06/20] refactored: Removed duplicate DefaultResponseHandler --- .../microsoft/azure/hooks/msgraph.py | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index 95ccc72b1cb3a..c911927c132df 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -86,39 +86,6 @@ async def handle_response_async( return value -class DefaultResponseHandler(ResponseHandler): - """DefaultResponseHandler returns JSON payload or content in bytes or response headers.""" - - @staticmethod - def get_value(response: NativeResponseType) -> Any: - with suppress(JSONDecodeError): - return response.json() - content = response.content - if not content: - return {key: value for key, value in response.headers.items()} - return content - - async def handle_response_async( - self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None - ) -> Any: - """ - Invoke this callback method when a response is received. - - param response: The type of the native response object. - param error_map: The error dict to use in case of a failed request. - """ - value = self.get_value(response) - if response.status_code not in {200, 201, 202, 204, 302}: - message = value or response.reason_phrase - status_code = HTTPStatus(response.status_code) - if status_code == HTTPStatus.BAD_REQUEST: - raise AirflowBadRequest(message) - elif status_code == HTTPStatus.NOT_FOUND: - raise AirflowNotFoundException(message) - raise AirflowException(message) - return value - - class KiotaRequestAdapterHook(BaseHook): """ A Microsoft Graph API interaction hook, a Wrapper around KiotaRequestAdapter. From 1c18572ce9f5cb2baa9b8d9350da4678dd63c1fd Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Aug 2024 14:40:25 +0200 Subject: [PATCH 07/20] refactored: Changed return type of resolve_api_version_from_value --- airflow/providers/microsoft/azure/hooks/msgraph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index c911927c132df..f5469a4e7905b 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -130,7 +130,7 @@ def api_version(self) -> APIVersion: return self._api_version @staticmethod - def resolve_api_version_from_value(api_version: APIVersion | str, default: str | None = None) -> str: + def resolve_api_version_from_value(api_version: APIVersion | str, default: str | None = None) -> str | None: if isinstance(api_version, APIVersion): return api_version.value return api_version or default From fa6c5af38ac8e8bfa235aadaa0bca1b8071bfa17 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Aug 2024 15:08:04 +0200 Subject: [PATCH 08/20] refactored: Reformatted resolve_api_version_from_value method --- airflow/providers/microsoft/azure/hooks/msgraph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index f5469a4e7905b..87407673f04c2 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -130,7 +130,9 @@ def api_version(self) -> APIVersion: return self._api_version @staticmethod - def resolve_api_version_from_value(api_version: APIVersion | str, default: str | None = None) -> str | None: + def resolve_api_version_from_value( + api_version: APIVersion | str, default: str | None = None + ) -> str | None: if isinstance(api_version, APIVersion): return api_version.value return api_version or default From d16a100015ac69b069fbd78ce8838eb8315273e9 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Aug 2024 15:12:19 +0200 Subject: [PATCH 09/20] refactored: Try ignore type error in get_api_version method --- airflow/providers/microsoft/azure/hooks/msgraph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index 87407673f04c2..95ad8f6c9ae12 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -140,7 +140,7 @@ def resolve_api_version_from_value( def get_api_version(self, config: dict) -> str: return self._api_version or self.resolve_api_version_from_value( config.get("api_version"), APIVersion.v1.value - ) + ) # type: ignore def get_host(self, connection: Connection) -> str: if connection.schema and connection.host: From eaeab3a2e8cc95d87dd06f70f1edeebe3d4c061e Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Aug 2024 15:24:10 +0200 Subject: [PATCH 10/20] docs: Added example on how to use the MSGraphAsyncOperator to trigger pipelines in MS Fabric --- .../operators/msgraph.rst | 9 +++ .../microsoft/azure/example_msfabric.py | 65 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 tests/system/providers/microsoft/azure/example_msfabric.py diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst index 342bf542762ac..56a9259f93143 100644 --- a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst +++ b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst @@ -72,6 +72,14 @@ Below is an example of using this operator to refresh PowerBI dataset. :start-after: [START howto_operator_powerbi_refresh_dataset] :end-before: [END howto_operator_powerbi_refresh_dataset] +Below is an example of using this operator to create an item schedule in Fabric. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_msfabric.py + :language: python + :dedent: 0 + :start-after: [START howto_operator_ms_fabric_create_item_schedule] + :end-before: [END howto_operator_ms_fabric_create_item_schedule] + Reference --------- @@ -80,3 +88,4 @@ For further information, look at: * `Use the Microsoft Graph API `__ * `Using the Power BI REST APIs `__ +* `Using the Fabric REST APIs `__ diff --git a/tests/system/providers/microsoft/azure/example_msfabric.py b/tests/system/providers/microsoft/azure/example_msfabric.py new file mode 100644 index 0000000000000..ec05d0b9559ed --- /dev/null +++ b/tests/system/providers/microsoft/azure/example_msfabric.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow.datasets import Dataset + +from airflow import models +from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator +from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor + +DAG_ID = "example_msfabric" + +with models.DAG( + DAG_ID, + start_date=datetime(2021, 1, 1), + schedule=None, + tags=["example"], +) as dag: + # [START howto_operator_ms_fabric_create_item_schedule] + # https://learn.microsoft.com/en-us/rest/api/fabric/core/job-scheduler/create-item-schedule?tabs=HTTP + workspaces_task = MSGraphAsyncOperator( + task_id="schedule_datapipeline", + conn_id="powerbi", + method="POST", + url="workspaces/{workspaceId}/items/{itemId}/jobs/instances", + path_parameters={ + "workspaceId": "e90b2873-4812-4dfb-9246-593638165644", + "itemId": "65448530-e5ec-4aeb-a97e-7cebf5d67c18", + }, + query_parameters={"jobType": "Pipeline"}, + dag=dag, + outlets=[ + Dataset( + "workspaces/e90b2873-4812-4dfb-9246-593638165644/items/65448530-e5ec-4aeb-a97e-7cebf5d67c18/jobs/instances?jobType=Pipeline" + ) + ], + ) + # [END howto_operator_ms_fabric_create_item_schedule] + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) From 33534675af773ee0fdd2e91005365a4e48785e88 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Aug 2024 20:19:20 +0200 Subject: [PATCH 11/20] refactor: Removed unused imports --- tests/system/providers/microsoft/azure/example_msfabric.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/system/providers/microsoft/azure/example_msfabric.py b/tests/system/providers/microsoft/azure/example_msfabric.py index ec05d0b9559ed..7d62a49e0bc31 100644 --- a/tests/system/providers/microsoft/azure/example_msfabric.py +++ b/tests/system/providers/microsoft/azure/example_msfabric.py @@ -18,11 +18,9 @@ from datetime import datetime -from airflow.datasets import Dataset - from airflow import models +from airflow.datasets import Dataset from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator -from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor DAG_ID = "example_msfabric" From 28b6b0c20acc3f67ac283ea749f5352f85489b21 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 13 Aug 2024 14:43:16 +0200 Subject: [PATCH 12/20] fix: Fixed serialization of api_version in MSGraphTrigger --- airflow/providers/microsoft/azure/triggers/msgraph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py b/airflow/providers/microsoft/azure/triggers/msgraph.py index 32b4e2f142301..b7ad5d646bb08 100644 --- a/airflow/providers/microsoft/azure/triggers/msgraph.py +++ b/airflow/providers/microsoft/azure/triggers/msgraph.py @@ -158,7 +158,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "conn_id": self.conn_id, "timeout": self.timeout, "proxies": self.proxies, - "api_version": self.api_version.value, + "api_version": self.api_version, "serializer": f"{self.serializer.__class__.__module__}.{self.serializer.__class__.__name__}", "url": self.url, "path_parameters": self.path_parameters, From eeb9ce8117d8179d05a76bb4603cbb7ee2c5958f Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 13 Aug 2024 14:43:39 +0200 Subject: [PATCH 13/20] refactor: Refactored get_airflow_connection in conftest --- tests/providers/microsoft/conftest.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index cf2ccfc5bb2c0..0f2c7a997ec0a 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -194,11 +194,15 @@ def load_file(*args: str, mode="r", encoding="utf-8"): def get_airflow_connection( conn_id: str, + host: str = "graph.microsoft.com", login: str = "client_id", password: str = "client_secret", tenant_id: str = "tenant-id", proxies: dict | None = None, + scopes: list[str] | None = None, api_version: APIVersion | str | None = APIVersion.v1.value, + authority: str | None = None, + disable_instance_discovery: bool = False, ): from airflow.models import Connection @@ -206,9 +210,17 @@ def get_airflow_connection( schema="https", conn_id=conn_id, conn_type="http", - host="graph.microsoft.com", + host=host, port=80, login=login, password=password, - extra={"tenant_id": tenant_id, "api_version": api_version, "proxies": proxies or {}}, + extra={ + "tenant_id": tenant_id, + "api_version": api_version, + "proxies": proxies or {}, + "verify": False, + "scopes": scopes or [], + "authority": authority, + "disable_instance_discovery": disable_instance_discovery, + }, ) From 45a5b7335e29031f09d95effdf8886dda32501c9 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 14 Aug 2024 15:53:46 +0200 Subject: [PATCH 14/20] refactor: Updated api version types as well in PowerBI --- airflow/providers/microsoft/azure/hooks/msgraph.py | 2 +- .../providers/microsoft/azure/operators/powerbi.py | 12 +++++++++++- .../providers/microsoft/azure/triggers/msgraph.py | 2 +- .../providers/microsoft/azure/triggers/powerbi.py | 7 +++---- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index 95ad8f6c9ae12..5792c8528c0b8 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -125,7 +125,7 @@ def __init__( self._api_version = self.resolve_api_version_from_value(api_version) @property - def api_version(self) -> APIVersion: + def api_version(self) -> str: self.get_conn() # Make sure config has been loaded through get_conn to have correct api version! return self._api_version diff --git a/airflow/providers/microsoft/azure/operators/powerbi.py b/airflow/providers/microsoft/azure/operators/powerbi.py index e54ad250bde74..a1d8dd0a40b8e 100644 --- a/airflow/providers/microsoft/azure/operators/powerbi.py +++ b/airflow/providers/microsoft/azure/operators/powerbi.py @@ -76,7 +76,7 @@ def __init__( conn_id: str = PowerBIHook.default_conn_name, timeout: float = 60 * 60 * 24 * 7, proxies: dict | None = None, - api_version: APIVersion | None = None, + api_version: APIVersion | str | None = None, check_interval: int = 60, **kwargs, ) -> None: @@ -89,6 +89,14 @@ def __init__( self.timeout = timeout self.check_interval = check_interval + @property + def proxies(self) -> dict | None: + return self.hook.proxies + + @property + def api_version(self) -> str: + return self.hook.api_version + def execute(self, context: Context): """Refresh the Power BI Dataset.""" if self.wait_for_termination: @@ -98,6 +106,8 @@ def execute(self, context: Context): group_id=self.group_id, dataset_id=self.dataset_id, timeout=self.timeout, + proxies=self.proxies, + api_version=self.api_version, check_interval=self.check_interval, wait_for_termination=self.wait_for_termination, ), diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py b/airflow/providers/microsoft/azure/triggers/msgraph.py index b7ad5d646bb08..0015964be86dd 100644 --- a/airflow/providers/microsoft/azure/triggers/msgraph.py +++ b/airflow/providers/microsoft/azure/triggers/msgraph.py @@ -187,7 +187,7 @@ def proxies(self) -> dict | None: return self.hook.proxies @property - def api_version(self) -> APIVersion: + def api_version(self) -> APIVersion | str: return self.hook.api_version async def run(self) -> AsyncIterator[TriggerEvent]: diff --git a/airflow/providers/microsoft/azure/triggers/powerbi.py b/airflow/providers/microsoft/azure/triggers/powerbi.py index d25802b84fb74..a74898f55f28a 100644 --- a/airflow/providers/microsoft/azure/triggers/powerbi.py +++ b/airflow/providers/microsoft/azure/triggers/powerbi.py @@ -58,7 +58,7 @@ def __init__( group_id: str, timeout: float = 60 * 60 * 24 * 7, proxies: dict | None = None, - api_version: APIVersion | None = None, + api_version: APIVersion | str | None = None, check_interval: int = 60, wait_for_termination: bool = True, ): @@ -72,13 +72,12 @@ def __init__( def serialize(self): """Serialize the trigger instance.""" - api_version = self.api_version.value if self.api_version else None return ( "airflow.providers.microsoft.azure.triggers.powerbi.PowerBITrigger", { "conn_id": self.conn_id, "proxies": self.proxies, - "api_version": api_version, + "api_version": self.api_version, "dataset_id": self.dataset_id, "group_id": self.group_id, "timeout": self.timeout, @@ -96,7 +95,7 @@ def proxies(self) -> dict | None: return self.hook.proxies @property - def api_version(self) -> APIVersion: + def api_version(self) -> APIVersion | str: return self.hook.api_version async def run(self) -> AsyncIterator[TriggerEvent]: From a562908586317e7ee7d0f0805f72292518e18736 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 14 Aug 2024 16:51:27 +0200 Subject: [PATCH 15/20] refactor: Api version property could return None --- airflow/providers/microsoft/azure/hooks/msgraph.py | 2 +- airflow/providers/microsoft/azure/operators/powerbi.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index 5792c8528c0b8..6abc2ae004a9a 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -125,7 +125,7 @@ def __init__( self._api_version = self.resolve_api_version_from_value(api_version) @property - def api_version(self) -> str: + def api_version(self) -> str | None: self.get_conn() # Make sure config has been loaded through get_conn to have correct api version! return self._api_version diff --git a/airflow/providers/microsoft/azure/operators/powerbi.py b/airflow/providers/microsoft/azure/operators/powerbi.py index a1d8dd0a40b8e..fc812e852d90b 100644 --- a/airflow/providers/microsoft/azure/operators/powerbi.py +++ b/airflow/providers/microsoft/azure/operators/powerbi.py @@ -94,7 +94,7 @@ def proxies(self) -> dict | None: return self.hook.proxies @property - def api_version(self) -> str: + def api_version(self) -> str | None: return self.hook.api_version def execute(self, context: Context): From 6577f1a8354c71a3a92b46784b24aab3f168f0d6 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 4 Sep 2024 08:34:35 +0200 Subject: [PATCH 16/20] refactor: Refactored and fixed PowerBI related tests --- .../microsoft/azure/hooks/test_powerbi.py | 320 ++++++++-------- .../microsoft/azure/operators/test_powerbi.py | 147 ++++---- .../microsoft/azure/triggers/test_powerbi.py | 344 +++++++++--------- tests/providers/microsoft/conftest.py | 6 + 4 files changed, 397 insertions(+), 420 deletions(-) diff --git a/tests/providers/microsoft/azure/hooks/test_powerbi.py b/tests/providers/microsoft/azure/hooks/test_powerbi.py index a3a521b45e820..f22116f6efb03 100644 --- a/tests/providers/microsoft/azure/hooks/test_powerbi.py +++ b/tests/providers/microsoft/azure/hooks/test_powerbi.py @@ -54,176 +54,162 @@ GROUP_ID = "group_id" DATASET_ID = "dataset_id" -CONFIG = {"conn_id": DEFAULT_CONNECTION_CLIENT_SECRET, "timeout": 3, "api_version": "v1.0"} - -@pytest.fixture -def powerbi_hook(): - return PowerBIHook(**CONFIG) - - -@pytest.mark.asyncio -async def test_get_refresh_history(powerbi_hook): - response_data = {"value": [{"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}]} - - with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: - mock_run.return_value = response_data - result = await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID) - - expected = [{"request_id": "1234", "status": "Completed", "error": ""}] - assert result == expected - - -@pytest.mark.asyncio -async def test_get_refresh_history_airflow_exception(powerbi_hook): - """Test handling of AirflowException in get_refresh_history.""" - - with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: - mock_run.side_effect = AirflowException("Test exception") - - with pytest.raises(PowerBIDatasetRefreshException, match="Failed to retrieve refresh history"): - await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID) - - -@pytest.mark.parametrize( - "input_data, expected_output", - [ - ( - {"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}, - { - PowerBIDatasetRefreshFields.REQUEST_ID.value: "1234", - PowerBIDatasetRefreshFields.STATUS.value: "Completed", - PowerBIDatasetRefreshFields.ERROR.value: "", - }, - ), - ( - {"requestId": "5678", "status": "Unknown", "serviceExceptionJson": "Some error"}, - { - PowerBIDatasetRefreshFields.REQUEST_ID.value: "5678", - PowerBIDatasetRefreshFields.STATUS.value: "In Progress", - PowerBIDatasetRefreshFields.ERROR.value: "Some error", - }, - ), - ( - {"requestId": None, "status": None, "serviceExceptionJson": None}, - { - PowerBIDatasetRefreshFields.REQUEST_ID.value: "None", - PowerBIDatasetRefreshFields.STATUS.value: "None", - PowerBIDatasetRefreshFields.ERROR.value: "None", - }, - ), - ( - {}, # Empty input dictionary - { - PowerBIDatasetRefreshFields.REQUEST_ID.value: "None", - PowerBIDatasetRefreshFields.STATUS.value: "None", - PowerBIDatasetRefreshFields.ERROR.value: "None", - }, - ), - ], -) -def test_raw_to_refresh_details(input_data, expected_output): - """Test raw_to_refresh_details method.""" - result = PowerBIHook.raw_to_refresh_details(input_data) - assert result == expected_output - - -@pytest.mark.asyncio -async def test_get_refresh_details_by_refresh_id(powerbi_hook): - # Mock the get_refresh_history method to return a list of refresh histories - refresh_histories = FORMATTED_RESPONSE - powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=refresh_histories) - - # Call the function with a valid request ID - refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c" - result = await powerbi_hook.get_refresh_details_by_refresh_id( - dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id +class TestPowerBIHook: + @pytest.mark.asyncio + async def test_get_refresh_history(self, powerbi_hook): + response_data = {"value": [{"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}]} + + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + mock_run.return_value = response_data + result = await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID) + + expected = [{"request_id": "1234", "status": "Completed", "error": ""}] + assert result == expected + + @pytest.mark.asyncio + async def test_get_refresh_history_airflow_exception(self, powerbi_hook): + """Test handling of AirflowException in get_refresh_history.""" + + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + mock_run.side_effect = AirflowException("Test exception") + + with pytest.raises(PowerBIDatasetRefreshException, match="Failed to retrieve refresh history"): + await powerbi_hook.get_refresh_history(DATASET_ID, GROUP_ID) + + @pytest.mark.parametrize( + "input_data, expected_output", + [ + ( + {"requestId": "1234", "status": "Completed", "serviceExceptionJson": ""}, + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "1234", + PowerBIDatasetRefreshFields.STATUS.value: "Completed", + PowerBIDatasetRefreshFields.ERROR.value: "", + }, + ), + ( + {"requestId": "5678", "status": "Unknown", "serviceExceptionJson": "Some error"}, + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "5678", + PowerBIDatasetRefreshFields.STATUS.value: "In Progress", + PowerBIDatasetRefreshFields.ERROR.value: "Some error", + }, + ), + ( + {"requestId": None, "status": None, "serviceExceptionJson": None}, + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "None", + PowerBIDatasetRefreshFields.STATUS.value: "None", + PowerBIDatasetRefreshFields.ERROR.value: "None", + }, + ), + ( + {}, # Empty input dictionary + { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "None", + PowerBIDatasetRefreshFields.STATUS.value: "None", + PowerBIDatasetRefreshFields.ERROR.value: "None", + }, + ), + ], ) - - # Assert that the correct refresh details are returned - assert result == { - PowerBIDatasetRefreshFields.REQUEST_ID.value: "5e2d9921-e91b-491f-b7e1-e7d8db49194c", - PowerBIDatasetRefreshFields.STATUS.value: "Completed", - PowerBIDatasetRefreshFields.ERROR.value: "None", - } - - # Call the function with an invalid request ID - invalid_request_id = "invalid_request_id" - with pytest.raises(PowerBIDatasetRefreshException): - await powerbi_hook.get_refresh_details_by_refresh_id( - dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id - ) - - -@pytest.mark.asyncio -async def test_get_refresh_details_by_refresh_id_empty_history(powerbi_hook): - """Test exception when refresh history is empty.""" - # Mock the get_refresh_history method to return an empty list - powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=[]) - - # Call the function with a request ID - refresh_id = "any_request_id" - with pytest.raises( - PowerBIDatasetRefreshException, - match=f"Unable to fetch the details of dataset refresh with Request Id: {refresh_id}", - ): - await powerbi_hook.get_refresh_details_by_refresh_id( + def test_raw_to_refresh_details(self, input_data, expected_output): + """Test raw_to_refresh_details method.""" + result = PowerBIHook.raw_to_refresh_details(input_data) + assert result == expected_output + + @pytest.mark.asyncio + async def test_get_refresh_details_by_refresh_id(self, powerbi_hook): + # Mock the get_refresh_history method to return a list of refresh histories + refresh_histories = FORMATTED_RESPONSE + powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=refresh_histories) + + # Call the function with a valid request ID + refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c" + result = await powerbi_hook.get_refresh_details_by_refresh_id( dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id ) - -@pytest.mark.asyncio -async def test_get_refresh_details_by_refresh_id_not_found(powerbi_hook): - """Test exception when the refresh ID is not found in the refresh history.""" - # Mock the get_refresh_history method to return a list of refresh histories without the specified ID - powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=FORMATTED_RESPONSE) - - # Call the function with an invalid request ID - invalid_request_id = "invalid_request_id" - with pytest.raises( - PowerBIDatasetRefreshException, - match=f"Unable to fetch the details of dataset refresh with Request Id: {invalid_request_id}", - ): - await powerbi_hook.get_refresh_details_by_refresh_id( - dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id + # Assert that the correct refresh details are returned + assert result == { + PowerBIDatasetRefreshFields.REQUEST_ID.value: "5e2d9921-e91b-491f-b7e1-e7d8db49194c", + PowerBIDatasetRefreshFields.STATUS.value: "Completed", + PowerBIDatasetRefreshFields.ERROR.value: "None", + } + + # Call the function with an invalid request ID + invalid_request_id = "invalid_request_id" + with pytest.raises(PowerBIDatasetRefreshException): + await powerbi_hook.get_refresh_details_by_refresh_id( + dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id + ) + + @pytest.mark.asyncio + async def test_get_refresh_details_by_refresh_id_empty_history(self, powerbi_hook): + """Test exception when refresh history is empty.""" + # Mock the get_refresh_history method to return an empty list + powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=[]) + + # Call the function with a request ID + refresh_id = "any_request_id" + with pytest.raises( + PowerBIDatasetRefreshException, + match=f"Unable to fetch the details of dataset refresh with Request Id: {refresh_id}", + ): + await powerbi_hook.get_refresh_details_by_refresh_id( + dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=refresh_id + ) + + @pytest.mark.asyncio + async def test_get_refresh_details_by_refresh_id_not_found(self, powerbi_hook): + """Test exception when the refresh ID is not found in the refresh history.""" + # Mock the get_refresh_history method to return a list of refresh histories without the specified ID + powerbi_hook.get_refresh_history = mock.AsyncMock(return_value=FORMATTED_RESPONSE) + + # Call the function with an invalid request ID + invalid_request_id = "invalid_request_id" + with pytest.raises( + PowerBIDatasetRefreshException, + match=f"Unable to fetch the details of dataset refresh with Request Id: {invalid_request_id}", + ): + await powerbi_hook.get_refresh_details_by_refresh_id( + dataset_id=DATASET_ID, group_id=GROUP_ID, refresh_id=invalid_request_id + ) + + @pytest.mark.asyncio + async def test_trigger_dataset_refresh_success(self, powerbi_hook): + response_data = {"requestid": "5e2d9921-e91b-491f-b7e1-e7d8db49194c"} + + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + mock_run.return_value = response_data + result = await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID) + + assert result == "5e2d9921-e91b-491f-b7e1-e7d8db49194c" + + @pytest.mark.asyncio + async def test_trigger_dataset_refresh_failure(self, powerbi_hook): + """Test failure to trigger dataset refresh due to AirflowException.""" + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + mock_run.side_effect = AirflowException("Test exception") + + with pytest.raises(PowerBIDatasetRefreshException, match="Failed to trigger dataset refresh."): + await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID) + + @pytest.mark.asyncio + async def test_cancel_dataset_refresh(self, powerbi_hook): + dataset_refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c" + + with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: + await powerbi_hook.cancel_dataset_refresh(DATASET_ID, GROUP_ID, dataset_refresh_id) + + mock_run.assert_called_once_with( + url="myorg/groups/{group_id}/datasets/{dataset_id}/refreshes/{dataset_refresh_id}", + response_type=None, + path_parameters={ + "group_id": GROUP_ID, + "dataset_id": DATASET_ID, + "dataset_refresh_id": dataset_refresh_id, + }, + method="DELETE", ) - - -@pytest.mark.asyncio -async def test_trigger_dataset_refresh_success(powerbi_hook): - response_data = {"requestid": "5e2d9921-e91b-491f-b7e1-e7d8db49194c"} - - with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: - mock_run.return_value = response_data - result = await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID) - - assert result == "5e2d9921-e91b-491f-b7e1-e7d8db49194c" - - -@pytest.mark.asyncio -async def test_trigger_dataset_refresh_failure(powerbi_hook): - """Test failure to trigger dataset refresh due to AirflowException.""" - with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: - mock_run.side_effect = AirflowException("Test exception") - - with pytest.raises(PowerBIDatasetRefreshException, match="Failed to trigger dataset refresh."): - await powerbi_hook.trigger_dataset_refresh(dataset_id=DATASET_ID, group_id=GROUP_ID) - - -@pytest.mark.asyncio -async def test_cancel_dataset_refresh(powerbi_hook): - dataset_refresh_id = "5e2d9921-e91b-491f-b7e1-e7d8db49194c" - - with mock.patch.object(KiotaRequestAdapterHook, "run", new_callable=mock.AsyncMock) as mock_run: - await powerbi_hook.cancel_dataset_refresh(DATASET_ID, GROUP_ID, dataset_refresh_id) - - mock_run.assert_called_once_with( - url="myorg/groups/{group_id}/datasets/{dataset_id}/refreshes/{dataset_refresh_id}", - response_type=None, - path_parameters={ - "group_id": GROUP_ID, - "dataset_id": DATASET_ID, - "dataset_refresh_id": dataset_refresh_id, - }, - method="DELETE", - ) diff --git a/tests/providers/microsoft/azure/operators/test_powerbi.py b/tests/providers/microsoft/azure/operators/test_powerbi.py index 2ee5ee723d7a7..ad4f3694f7731 100644 --- a/tests/providers/microsoft/azure/operators/test_powerbi.py +++ b/tests/providers/microsoft/azure/operators/test_powerbi.py @@ -17,6 +17,7 @@ from __future__ import annotations +from unittest import mock from unittest.mock import MagicMock import pytest @@ -25,11 +26,12 @@ from airflow.providers.microsoft.azure.hooks.powerbi import ( PowerBIDatasetRefreshFields, PowerBIDatasetRefreshStatus, - PowerBIHook, ) from airflow.providers.microsoft.azure.operators.powerbi import PowerBIDatasetRefreshOperator from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger from airflow.utils import timezone +from tests.providers.microsoft.azure.base import Base +from tests.providers.microsoft.conftest import mock_context, get_airflow_connection DEFAULT_CONNECTION_CLIENT_SECRET = "powerbi_conn_id" TASK_ID = "run_powerbi_operator" @@ -72,86 +74,77 @@ } -@pytest.fixture -def mock_powerbi_hook(): - hook = PowerBIHook() - return hook - - -def test_execute_wait_for_termination_with_Deferrable(mock_powerbi_hook): - operator = PowerBIDatasetRefreshOperator( - **CONFIG, - ) - operator.hook = mock_powerbi_hook - context = {"ti": MagicMock()} - - with pytest.raises(TaskDeferred) as exc: - operator.execute(context) - - assert isinstance(exc.value.trigger, PowerBITrigger) +class TestPowerBIDatasetRefreshOperator(Base): + @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection) + def test_execute_wait_for_termination_with_deferrable(self, connection): + operator = PowerBIDatasetRefreshOperator( + **CONFIG, + ) + context = mock_context(task=operator) + with pytest.raises(TaskDeferred) as exc: + operator.execute(context) -def test_powerbi_operator_async_execute_complete_success(): - """Assert that execute_complete log success message""" - operator = PowerBIDatasetRefreshOperator( - **CONFIG, - ) - context = {"ti": MagicMock()} - operator.execute_complete( - context=context, - event=SUCCESS_TRIGGER_EVENT, - ) - assert context["ti"].xcom_push.call_count == 2 + assert isinstance(exc.value.trigger, PowerBITrigger) + def test_powerbi_operator_async_execute_complete_success(self): + """Assert that execute_complete log success message""" + operator = PowerBIDatasetRefreshOperator( + **CONFIG, + ) + context = {"ti": MagicMock()} + operator.execute_complete( + context=context, + event=SUCCESS_TRIGGER_EVENT, + ) + assert context["ti"].xcom_push.call_count == 2 -def test_powerbi_operator_async_execute_complete_fail(): - """Assert that execute_complete raise exception on error""" - operator = PowerBIDatasetRefreshOperator( - **CONFIG, - ) - context = {"ti": MagicMock()} - with pytest.raises(AirflowException): + def test_powerbi_operator_async_execute_complete_fail(self): + """Assert that execute_complete raise exception on error""" + operator = PowerBIDatasetRefreshOperator( + **CONFIG, + ) + context = {"ti": MagicMock()} + with pytest.raises(AirflowException): + operator.execute_complete( + context=context, + event={"status": "error", "message": "error", "dataset_refresh_id": "1234"}, + ) + assert context["ti"].xcom_push.call_count == 0 + + def test_execute_complete_no_event(self): + """Test execute_complete when event is None or empty.""" + operator = PowerBIDatasetRefreshOperator( + **CONFIG, + ) + context = {"ti": MagicMock()} operator.execute_complete( context=context, - event={"status": "error", "message": "error", "dataset_refresh_id": "1234"}, + event=None, + ) + assert context["ti"].xcom_push.call_count == 0 + + @pytest.mark.db_test + def test_powerbi_link(self, create_task_instance_of_operator): + """Assert Power BI Extra link matches the expected URL.""" + ti = create_task_instance_of_operator( + PowerBIDatasetRefreshOperator, + dag_id="test_powerbi_refresh_op_link", + execution_date=DEFAULT_DATE, + task_id=TASK_ID, + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + group_id=GROUP_ID, + dataset_id=DATASET_ID, + check_interval=1, + timeout=3, ) - assert context["ti"].xcom_push.call_count == 0 - - -def test_execute_complete_no_event(): - """Test execute_complete when event is None or empty.""" - operator = PowerBIDatasetRefreshOperator( - **CONFIG, - ) - context = {"ti": MagicMock()} - operator.execute_complete( - context=context, - event=None, - ) - assert context["ti"].xcom_push.call_count == 0 - - -@pytest.mark.db_test -def test_powerbilink(create_task_instance_of_operator): - """Assert Power BI Extra link matches the expected URL.""" - ti = create_task_instance_of_operator( - PowerBIDatasetRefreshOperator, - dag_id="test_powerbi_refresh_op_link", - execution_date=DEFAULT_DATE, - task_id=TASK_ID, - conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, - group_id=GROUP_ID, - dataset_id=DATASET_ID, - check_interval=1, - timeout=3, - ) - - ti.xcom_push(key="powerbi_dataset_refresh_id", value=NEW_REFRESH_REQUEST_ID) - url = ti.task.get_extra_links(ti, "Monitor PowerBI Dataset") - EXPECTED_ITEM_RUN_OP_EXTRA_LINK = ( - "https://app.powerbi.com" # type: ignore[attr-defined] - f"/groups/{GROUP_ID}/datasets/{DATASET_ID}" # type: ignore[attr-defined] - "/details?experience=power-bi" - ) - - assert url == EXPECTED_ITEM_RUN_OP_EXTRA_LINK + + ti.xcom_push(key="powerbi_dataset_refresh_id", value=NEW_REFRESH_REQUEST_ID) + url = ti.task.get_extra_links(ti, "Monitor PowerBI Dataset") + EXPECTED_ITEM_RUN_OP_EXTRA_LINK = ( + "https://app.powerbi.com" # type: ignore[attr-defined] + f"/groups/{GROUP_ID}/datasets/{DATASET_ID}" # type: ignore[attr-defined] + "/details?experience=power-bi" + ) + + assert url == EXPECTED_ITEM_RUN_OP_EXTRA_LINK diff --git a/tests/providers/microsoft/azure/triggers/test_powerbi.py b/tests/providers/microsoft/azure/triggers/test_powerbi.py index 5b44a84149501..b291b3e68d983 100644 --- a/tests/providers/microsoft/azure/triggers/test_powerbi.py +++ b/tests/providers/microsoft/azure/triggers/test_powerbi.py @@ -19,11 +19,10 @@ import asyncio from unittest import mock -from unittest.mock import patch import pytest -from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus, PowerBIHook +from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger from airflow.triggers.base import TriggerEvent from tests.providers.microsoft.conftest import get_airflow_connection @@ -54,19 +53,10 @@ def powerbi_trigger(): return trigger -@pytest.fixture -def mock_powerbi_hook(): - hook = PowerBIHook() - return hook - - -def test_powerbi_trigger_serialization(): - """Asserts that the PowerBI Trigger correctly serializes its arguments and classpath.""" - - with patch( - "airflow.hooks.base.BaseHook.get_connection", - side_effect=get_airflow_connection, - ): +class TestPowerBITrigger: + @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection) + def test_powerbi_trigger_serialization(self, connection): + """Asserts that the PowerBI Trigger correctly serializes its arguments and classpath.""" powerbi_trigger = PowerBITrigger( conn_id=POWERBI_CONN_ID, proxies=None, @@ -92,166 +82,168 @@ def test_powerbi_trigger_serialization(): } -@pytest.mark.asyncio -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") -async def test_powerbi_trigger_run_inprogress( - mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger -): - """Assert task isn't completed until timeout if dataset refresh is in progress.""" - mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS} - mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID - task = asyncio.create_task(powerbi_trigger.run().__anext__()) - await asyncio.sleep(0.5) - - # Assert TriggerEvent was not returned - assert task.done() is False - asyncio.get_event_loop().stop() - - -@pytest.mark.asyncio -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") -async def test_powerbi_trigger_run_failed( - mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger -): - """Assert event is triggered upon failed dataset refresh.""" - mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.FAILED} - mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID - - generator = powerbi_trigger.run() - actual = await generator.asend(None) - expected = TriggerEvent( - { - "status": "Failed", - "message": f"The dataset refresh {DATASET_REFRESH_ID} has " - f"{PowerBIDatasetRefreshStatus.FAILED}.", - "dataset_refresh_id": DATASET_REFRESH_ID, - } - ) - assert expected == actual - - -@pytest.mark.asyncio -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") -async def test_powerbi_trigger_run_completed( - mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger -): - """Assert event is triggered upon successful dataset refresh.""" - mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.COMPLETED} - mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID - - generator = powerbi_trigger.run() - actual = await generator.asend(None) - expected = TriggerEvent( - { - "status": "Completed", - "message": f"The dataset refresh {DATASET_REFRESH_ID} has " - f"{PowerBIDatasetRefreshStatus.COMPLETED}.", - "dataset_refresh_id": DATASET_REFRESH_ID, - } - ) - assert expected == actual - - -@pytest.mark.asyncio -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") -async def test_powerbi_trigger_run_exception_during_refresh_check_loop( - mock_trigger_dataset_refresh, - mock_get_refresh_details_by_refresh_id, - mock_cancel_dataset_refresh, - powerbi_trigger, -): - """Assert that run catch exception if Power BI API throw exception""" - mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception") - mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID - - task = [i async for i in powerbi_trigger.run()] - response = TriggerEvent( - { - "status": "error", - "message": "An error occurred: Test exception", - "dataset_refresh_id": DATASET_REFRESH_ID, - } - ) - assert len(task) == 1 - assert response in task - mock_cancel_dataset_refresh.assert_called_once() - - -@pytest.mark.asyncio -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") -async def test_powerbi_trigger_run_exception_during_refresh_cancellation( - mock_trigger_dataset_refresh, - mock_get_refresh_details_by_refresh_id, - mock_cancel_dataset_refresh, - powerbi_trigger, -): - """Assert that run catch exception if Power BI API throw exception""" - mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception") - mock_cancel_dataset_refresh.side_effect = Exception("Exception caused by cancel_dataset_refresh") - mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID - - task = [i async for i in powerbi_trigger.run()] - response = TriggerEvent( - { - "status": "error", - "message": "An error occurred while canceling dataset: Exception caused by cancel_dataset_refresh", - "dataset_refresh_id": DATASET_REFRESH_ID, - } - ) + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") + async def test_powerbi_trigger_run_inprogress( + self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger + ): + """Assert task isn't completed until timeout if dataset refresh is in progress.""" + mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS} + mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID + task = asyncio.create_task(powerbi_trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # Assert TriggerEvent was not returned + assert task.done() is False + asyncio.get_event_loop().stop() + + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") + async def test_powerbi_trigger_run_failed( + self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger + ): + """Assert event is triggered upon failed dataset refresh.""" + mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.FAILED} + mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID + + generator = powerbi_trigger.run() + actual = await generator.asend(None) + expected = TriggerEvent( + { + "status": "Failed", + "message": f"The dataset refresh {DATASET_REFRESH_ID} has " + f"{PowerBIDatasetRefreshStatus.FAILED}.", + "dataset_refresh_id": DATASET_REFRESH_ID, + } + ) + assert expected == actual - assert len(task) == 1 - assert response in task - mock_cancel_dataset_refresh.assert_called_once() - - -@pytest.mark.asyncio -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") -async def test_powerbi_trigger_run_exception_without_refresh_id( - mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger -): - """Assert handling of exception when there is no dataset_refresh_id""" - powerbi_trigger.dataset_refresh_id = None - mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception for no dataset_refresh_id") - mock_trigger_dataset_refresh.return_value = None - - task = [i async for i in powerbi_trigger.run()] - response = TriggerEvent( - { - "status": "error", - "message": "An error occurred: Test exception for no dataset_refresh_id", - "dataset_refresh_id": None, - } - ) - assert len(task) == 1 - assert response in task - - -@pytest.mark.asyncio -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") -@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") -async def test_powerbi_trigger_run_timeout( - mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger -): - """Assert that powerbi run timesout after end_time elapses""" - mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS} - mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID - - generator = powerbi_trigger.run() - actual = await generator.asend(None) - expected = TriggerEvent( - { - "status": "error", - "message": f"Timeout occurred while waiting for dataset refresh to complete: The dataset refresh {DATASET_REFRESH_ID} has status In Progress.", - "dataset_refresh_id": DATASET_REFRESH_ID, - } - ) - assert expected == actual + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") + async def test_powerbi_trigger_run_completed( + self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger + ): + """Assert event is triggered upon successful dataset refresh.""" + mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.COMPLETED} + mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID + + generator = powerbi_trigger.run() + actual = await generator.asend(None) + expected = TriggerEvent( + { + "status": "Completed", + "message": f"The dataset refresh {DATASET_REFRESH_ID} has " + f"{PowerBIDatasetRefreshStatus.COMPLETED}.", + "dataset_refresh_id": DATASET_REFRESH_ID, + } + ) + assert expected == actual + + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") + async def test_powerbi_trigger_run_exception_during_refresh_check_loop( + self, + mock_trigger_dataset_refresh, + mock_get_refresh_details_by_refresh_id, + mock_cancel_dataset_refresh, + powerbi_trigger, + ): + """Assert that run catch exception if Power BI API throw exception""" + mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception") + mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID + + task = [i async for i in powerbi_trigger.run()] + response = TriggerEvent( + { + "status": "error", + "message": "An error occurred: Test exception", + "dataset_refresh_id": DATASET_REFRESH_ID, + } + ) + assert len(task) == 1 + assert response in task + mock_cancel_dataset_refresh.assert_called_once() + + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") + async def test_powerbi_trigger_run_exception_during_refresh_cancellation( + self, + mock_trigger_dataset_refresh, + mock_get_refresh_details_by_refresh_id, + mock_cancel_dataset_refresh, + powerbi_trigger, + ): + """Assert that run catch exception if Power BI API throw exception""" + mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception") + mock_cancel_dataset_refresh.side_effect = Exception("Exception caused by cancel_dataset_refresh") + mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID + + task = [i async for i in powerbi_trigger.run()] + response = TriggerEvent( + { + "status": "error", + "message": "An error occurred while canceling dataset: Exception caused by cancel_dataset_refresh", + "dataset_refresh_id": DATASET_REFRESH_ID, + } + ) + + assert len(task) == 1 + assert response in task + mock_cancel_dataset_refresh.assert_called_once() + + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") + async def test_powerbi_trigger_run_exception_without_refresh_id( + self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger + ): + """Assert handling of exception when there is no dataset_refresh_id""" + powerbi_trigger.dataset_refresh_id = None + mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception for no dataset_refresh_id") + mock_trigger_dataset_refresh.return_value = None + + task = [i async for i in powerbi_trigger.run()] + response = TriggerEvent( + { + "status": "error", + "message": "An error occurred: Test exception for no dataset_refresh_id", + "dataset_refresh_id": None, + } + ) + assert len(task) == 1 + assert response in task + + + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") + async def test_powerbi_trigger_run_timeout( + self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger + ): + """Assert that powerbi run timesout after end_time elapses""" + mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS} + mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID + + generator = powerbi_trigger.run() + actual = await generator.asend(None) + expected = TriggerEvent( + { + "status": "error", + "message": f"Timeout occurred while waiting for dataset refresh to complete: The dataset refresh {DATASET_REFRESH_ID} has status In Progress.", + "dataset_refresh_id": DATASET_REFRESH_ID, + } + ) + + assert expected == actual diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index c4e4e41953dea..a2fe55aa43cd2 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -32,6 +32,7 @@ from msgraph_core import APIVersion from airflow.models import Connection +from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIHook from airflow.utils.context import Context if TYPE_CHECKING: @@ -213,3 +214,8 @@ def get_airflow_connection( "disable_instance_discovery": disable_instance_discovery, }, ) + + +@pytest.fixture +def powerbi_hook(): + return PowerBIHook(**{"conn_id": "powerbi_conn_id", "timeout": 3, "api_version": "v1.0"}) From 2255ad552568787dba9ca34bd31dc6771d280816 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 4 Sep 2024 11:23:23 +0200 Subject: [PATCH 17/20] refactor: Added support for tenantId beside tenant_id as extra config for KiotaRequestAdapterHook --- .../microsoft/azure/hooks/msgraph.py | 2 +- .../microsoft/azure/hooks/test_msgraph.py | 31 +++++++++++++++++++ tests/providers/microsoft/conftest.py | 26 ++++++++++------ 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index 6abc2ae004a9a..61e555f4caa78 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -190,7 +190,7 @@ def get_conn(self) -> RequestAdapter: client_id = connection.login client_secret = connection.password config = connection.extra_dejson if connection.extra else {} - tenant_id = config.get("tenant_id") + tenant_id = config.get("tenant_id") or config.get("tenantId") api_version = self.get_api_version(config) host = self.get_host(connection) base_url = config.get("base_url", urljoin(host, api_version)) diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py index 45b33f528e438..57c09f9b83546 100644 --- a/tests/providers/microsoft/azure/hooks/test_msgraph.py +++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py @@ -21,6 +21,7 @@ from unittest.mock import patch import pytest +from kiota_abstractions.request_adapter import RequestAdapter from kiota_http.httpx_request_adapter import HttpxRequestAdapter from msgraph_core import APIVersion, NationalClouds @@ -43,6 +44,12 @@ class TestKiotaRequestAdapterHook: def setup_method(self): KiotaRequestAdapterHook.cached_request_adapters.clear() + @staticmethod + def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id: str): + assert isinstance(request_adapter, HttpxRequestAdapter) + tenant_id = request_adapter._authentication_provider.access_token_provider._credentials._tenant_id + assert tenant_id == expected_tenant_id + def test_get_conn(self): with patch( "airflow.hooks.base.BaseHook.get_connection", @@ -132,6 +139,30 @@ def test_get_host_when_connection_has_no_scheme_or_host(self): assert actual == NationalClouds.Global.value + def test_tenant_id(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + actual = hook.get_conn() + + self.assert_tenant_id(actual, "tenant-id") + + def test_azure_tenant_id(self): + airflow_connection = lambda conn_id: get_airflow_connection( + conn_id=conn_id, azure_tenant_id="azure-tenant-id", + ) + + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + actual = hook.get_conn() + + self.assert_tenant_id(actual, "azure-tenant-id") + def test_encoded_query_parameters(self): actual = KiotaRequestAdapterHook.encoded_query_parameters( query_parameters={"$expand": "reports,users,datasets,dataflows,dashboards", "$top": 5000}, diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index a2fe55aa43cd2..4b7a93e6bce79 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -188,14 +188,30 @@ def get_airflow_connection( login: str = "client_id", password: str = "client_secret", tenant_id: str = "tenant-id", + azure_tenant_id: str | None = None, proxies: dict | None = None, scopes: list[str] | None = None, api_version: APIVersion | str | None = APIVersion.v1.value, authority: str | None = None, disable_instance_discovery: bool = False, + ): from airflow.models import Connection + extra = { + "api_version": api_version, + "proxies": proxies or {}, + "verify": False, + "scopes": scopes or [], + "authority": authority, + "disable_instance_discovery": disable_instance_discovery, + } + + if azure_tenant_id: + extra["tenantId"] = azure_tenant_id + else: + extra["tenant_id"] = tenant_id + return Connection( schema="https", conn_id=conn_id, @@ -204,15 +220,7 @@ def get_airflow_connection( port=80, login=login, password=password, - extra={ - "tenant_id": tenant_id, - "api_version": api_version, - "proxies": proxies or {}, - "verify": False, - "scopes": scopes or [], - "authority": authority, - "disable_instance_discovery": disable_instance_discovery, - }, + extra=extra, ) From 876da95fff18c6449c0f3a4220e74993b41fde6e Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 4 Sep 2024 13:40:49 +0200 Subject: [PATCH 18/20] refactor: Reformatted test related files to conform static checks --- .../microsoft/azure/hooks/test_msgraph.py | 3 ++- .../microsoft/azure/operators/test_powerbi.py | 2 +- .../microsoft/azure/triggers/test_powerbi.py | 27 ++++++++++--------- tests/providers/microsoft/conftest.py | 1 - 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py index 57c09f9b83546..5e49280d75480 100644 --- a/tests/providers/microsoft/azure/hooks/test_msgraph.py +++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py @@ -151,7 +151,8 @@ def test_tenant_id(self): def test_azure_tenant_id(self): airflow_connection = lambda conn_id: get_airflow_connection( - conn_id=conn_id, azure_tenant_id="azure-tenant-id", + conn_id=conn_id, + azure_tenant_id="azure-tenant-id", ) with patch( diff --git a/tests/providers/microsoft/azure/operators/test_powerbi.py b/tests/providers/microsoft/azure/operators/test_powerbi.py index ad4f3694f7731..35bb76f782ce3 100644 --- a/tests/providers/microsoft/azure/operators/test_powerbi.py +++ b/tests/providers/microsoft/azure/operators/test_powerbi.py @@ -31,7 +31,7 @@ from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger from airflow.utils import timezone from tests.providers.microsoft.azure.base import Base -from tests.providers.microsoft.conftest import mock_context, get_airflow_connection +from tests.providers.microsoft.conftest import get_airflow_connection, mock_context DEFAULT_CONNECTION_CLIENT_SECRET = "powerbi_conn_id" TASK_ID = "run_powerbi_operator" diff --git a/tests/providers/microsoft/azure/triggers/test_powerbi.py b/tests/providers/microsoft/azure/triggers/test_powerbi.py index b291b3e68d983..30cb8097625cd 100644 --- a/tests/providers/microsoft/azure/triggers/test_powerbi.py +++ b/tests/providers/microsoft/azure/triggers/test_powerbi.py @@ -81,7 +81,6 @@ def test_powerbi_trigger_serialization(self, connection): "wait_for_termination": True, } - @pytest.mark.asyncio @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") @@ -89,7 +88,9 @@ async def test_powerbi_trigger_run_inprogress( self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger ): """Assert task isn't completed until timeout if dataset refresh is in progress.""" - mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS} + mock_get_refresh_details_by_refresh_id.return_value = { + "status": PowerBIDatasetRefreshStatus.IN_PROGRESS + } mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID task = asyncio.create_task(powerbi_trigger.run().__anext__()) await asyncio.sleep(0.5) @@ -98,7 +99,6 @@ async def test_powerbi_trigger_run_inprogress( assert task.done() is False asyncio.get_event_loop().stop() - @pytest.mark.asyncio @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") @@ -106,7 +106,9 @@ async def test_powerbi_trigger_run_failed( self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger ): """Assert event is triggered upon failed dataset refresh.""" - mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.FAILED} + mock_get_refresh_details_by_refresh_id.return_value = { + "status": PowerBIDatasetRefreshStatus.FAILED + } mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID generator = powerbi_trigger.run() @@ -121,7 +123,6 @@ async def test_powerbi_trigger_run_failed( ) assert expected == actual - @pytest.mark.asyncio @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") @@ -129,7 +130,9 @@ async def test_powerbi_trigger_run_completed( self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger ): """Assert event is triggered upon successful dataset refresh.""" - mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.COMPLETED} + mock_get_refresh_details_by_refresh_id.return_value = { + "status": PowerBIDatasetRefreshStatus.COMPLETED + } mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID generator = powerbi_trigger.run() @@ -144,7 +147,6 @@ async def test_powerbi_trigger_run_completed( ) assert expected == actual - @pytest.mark.asyncio @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh") @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") @@ -172,7 +174,6 @@ async def test_powerbi_trigger_run_exception_during_refresh_check_loop( assert response in task mock_cancel_dataset_refresh.assert_called_once() - @pytest.mark.asyncio @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh") @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") @@ -202,7 +203,6 @@ async def test_powerbi_trigger_run_exception_during_refresh_cancellation( assert response in task mock_cancel_dataset_refresh.assert_called_once() - @pytest.mark.asyncio @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") @@ -211,7 +211,9 @@ async def test_powerbi_trigger_run_exception_without_refresh_id( ): """Assert handling of exception when there is no dataset_refresh_id""" powerbi_trigger.dataset_refresh_id = None - mock_get_refresh_details_by_refresh_id.side_effect = Exception("Test exception for no dataset_refresh_id") + mock_get_refresh_details_by_refresh_id.side_effect = Exception( + "Test exception for no dataset_refresh_id" + ) mock_trigger_dataset_refresh.return_value = None task = [i async for i in powerbi_trigger.run()] @@ -225,7 +227,6 @@ async def test_powerbi_trigger_run_exception_without_refresh_id( assert len(task) == 1 assert response in task - @pytest.mark.asyncio @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") @@ -233,7 +234,9 @@ async def test_powerbi_trigger_run_timeout( self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger ): """Assert that powerbi run timesout after end_time elapses""" - mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.IN_PROGRESS} + mock_get_refresh_details_by_refresh_id.return_value = { + "status": PowerBIDatasetRefreshStatus.IN_PROGRESS + } mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID generator = powerbi_trigger.run() diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index 4b7a93e6bce79..de25d24fb05e1 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -194,7 +194,6 @@ def get_airflow_connection( api_version: APIVersion | str | None = APIVersion.v1.value, authority: str | None = None, disable_instance_discovery: bool = False, - ): from airflow.models import Connection From 6f1856d9126963834a9172a162488f067025fb71 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 4 Sep 2024 15:48:51 +0200 Subject: [PATCH 19/20] refactor: Reformatted TestPowerBITrigger --- tests/providers/microsoft/azure/triggers/test_powerbi.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/providers/microsoft/azure/triggers/test_powerbi.py b/tests/providers/microsoft/azure/triggers/test_powerbi.py index 30cb8097625cd..c3276e258b3da 100644 --- a/tests/providers/microsoft/azure/triggers/test_powerbi.py +++ b/tests/providers/microsoft/azure/triggers/test_powerbi.py @@ -106,9 +106,7 @@ async def test_powerbi_trigger_run_failed( self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger ): """Assert event is triggered upon failed dataset refresh.""" - mock_get_refresh_details_by_refresh_id.return_value = { - "status": PowerBIDatasetRefreshStatus.FAILED - } + mock_get_refresh_details_by_refresh_id.return_value = {"status": PowerBIDatasetRefreshStatus.FAILED} mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID generator = powerbi_trigger.run() From 7f345aedb7ac2c45a431fb1c9208b2dd43853794 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 4 Sep 2024 16:03:44 +0200 Subject: [PATCH 20/20] refactor: Moved import of RequestAdapter into type checking block --- tests/providers/microsoft/azure/hooks/test_msgraph.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py index 5e49280d75480..04e85525616bc 100644 --- a/tests/providers/microsoft/azure/hooks/test_msgraph.py +++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py @@ -18,10 +18,10 @@ import asyncio from json import JSONDecodeError +from typing import TYPE_CHECKING from unittest.mock import patch import pytest -from kiota_abstractions.request_adapter import RequestAdapter from kiota_http.httpx_request_adapter import HttpxRequestAdapter from msgraph_core import APIVersion, NationalClouds @@ -39,6 +39,9 @@ mock_response, ) +if TYPE_CHECKING: + from kiota_abstractions.request_adapter import RequestAdapter + class TestKiotaRequestAdapterHook: def setup_method(self):