diff --git a/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py b/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
index d62931a76510d..9af93e212b394 100644
--- a/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
@@ -23,7 +23,13 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+
+try:
+ from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+except ModuleNotFoundError as e:
+ from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+ raise AirflowOptionalProviderFeatureException(e)
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git a/airflow/providers/google/cloud/transfers/adls_to_gcs.py b/airflow/providers/google/cloud/transfers/adls_to_gcs.py
index 7abbd9a9c3142..f11b6aa881b80 100644
--- a/airflow/providers/google/cloud/transfers/adls_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/adls_to_gcs.py
@@ -24,8 +24,14 @@
from typing import TYPE_CHECKING, Sequence
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
-from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
-from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator
+
+try:
+ from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+ from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator
+except ModuleNotFoundError as e:
+ from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+ raise AirflowOptionalProviderFeatureException(e)
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git a/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py b/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
index 8ba6f2d6eb079..1da9e82c09247 100644
--- a/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
@@ -22,7 +22,13 @@
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
-from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+
+try:
+ from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+except ModuleNotFoundError as e:
+ from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+ raise AirflowOptionalProviderFeatureException(e)
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
index 9ba612979164c..cca318001c779 100644
--- a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
@@ -24,7 +24,13 @@
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url, gcs_object_is_directory
-from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook
+
+try:
+ from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook
+except ModuleNotFoundError as e:
+ from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+ raise AirflowOptionalProviderFeatureException(e)
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py
new file mode 100644
index 0000000000000..7fcc328f8670a
--- /dev/null
+++ b/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -0,0 +1,208 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from urllib.parse import urljoin, urlparse
+
+import httpx
+from azure.identity import ClientSecretCredential
+from httpx import Timeout
+from kiota_authentication_azure.azure_identity_authentication_provider import (
+ AzureIdentityAuthenticationProvider,
+)
+from kiota_http.httpx_request_adapter import HttpxRequestAdapter
+from msgraph_core import GraphClientFactory
+from msgraph_core._enums import APIVersion, NationalClouds
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+
+if TYPE_CHECKING:
+ from kiota_abstractions.request_adapter import RequestAdapter
+
+ from airflow.models import Connection
+
+
+class KiotaRequestAdapterHook(BaseHook):
+ """
+ A Microsoft Graph API interaction hook, a Wrapper around KiotaRequestAdapter.
+
+ https://github.com/microsoftgraph/msgraph-sdk-python-core
+
+ :param conn_id: The HTTP Connection ID to run the trigger against.
+ :param timeout: The HTTP timeout being used by the KiotaRequestAdapter (default is None).
+ When no timeout is specified or set to None then no HTTP timeout is applied on each request.
+ :param proxies: A Dict defining the HTTP proxies to be used (default is None).
+ :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
+ You can pass an enum named APIVersion which has 2 possible members v1 and beta,
+ or you can pass a string as "v1.0" or "beta".
+ """
+
+ cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {}
+ default_conn_name: str = "msgraph_default"
+
+ def __init__(
+ self,
+ conn_id: str = default_conn_name,
+ timeout: float | None = None,
+ proxies: dict | None = None,
+ api_version: APIVersion | str | None = None,
+ ):
+ super().__init__()
+ self.conn_id = conn_id
+ self.timeout = timeout
+ self.proxies = proxies
+ self._api_version = self.resolve_api_version_from_value(api_version)
+
+ @property
+ def api_version(self) -> APIVersion:
+ self.get_conn() # Make sure config has been loaded through get_conn to have correct api version!
+ return self._api_version
+
+ @staticmethod
+    def resolve_api_version_from_value(
+        api_version: APIVersion | str | None, default: APIVersion | None = None
+    ) -> APIVersion | None:
+ if isinstance(api_version, APIVersion):
+ return api_version
+ return next(
+ filter(lambda version: version.value == api_version, APIVersion),
+ default,
+ )
+
+ def get_api_version(self, config: dict) -> APIVersion:
+ if self._api_version is None:
+ return self.resolve_api_version_from_value(
+ api_version=config.get("api_version"), default=APIVersion.v1
+ )
+ return self._api_version
+
+ @staticmethod
+ def get_host(connection: Connection) -> str:
+ if connection.schema and connection.host:
+ return f"{connection.schema}://{connection.host}"
+ return NationalClouds.Global.value
+
+ @staticmethod
+ def format_no_proxy_url(url: str) -> str:
+ if "://" not in url:
+ url = f"all://{url}"
+ return url
+
+ @classmethod
+ def to_httpx_proxies(cls, proxies: dict) -> dict:
+ proxies = proxies.copy()
+ if proxies.get("http"):
+ proxies["http://"] = proxies.pop("http")
+ if proxies.get("https"):
+ proxies["https://"] = proxies.pop("https")
+ if proxies.get("no"):
+ for url in proxies.pop("no", "").split(","):
+ proxies[cls.format_no_proxy_url(url.strip())] = None
+ return proxies
+
+ @classmethod
+ def to_msal_proxies(cls, authority: str | None, proxies: dict):
+ if authority:
+ no_proxies = proxies.get("no")
+ if no_proxies:
+ for url in no_proxies.split(","):
+ domain_name = urlparse(url).path.replace("*", "")
+ if authority.endswith(domain_name):
+ return None
+ return proxies
+
+ def get_conn(self) -> RequestAdapter:
+ if not self.conn_id:
+ raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!")
+
+ api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None))
+
+ if not request_adapter:
+ connection = self.get_connection(conn_id=self.conn_id)
+ client_id = connection.login
+ client_secret = connection.password
+ config = connection.extra_dejson if connection.extra else {}
+ tenant_id = config.get("tenant_id")
+ api_version = self.get_api_version(config)
+ host = self.get_host(connection)
+ base_url = config.get("base_url", urljoin(host, api_version.value))
+ authority = config.get("authority")
+ proxies = self.proxies or config.get("proxies", {})
+ msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
+ httpx_proxies = self.to_httpx_proxies(proxies=proxies)
+ scopes = config.get("scopes", ["https://graph.microsoft.com/.default"])
+ verify = config.get("verify", True)
+ trust_env = config.get("trust_env", False)
+ disable_instance_discovery = config.get("disable_instance_discovery", False)
+ allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",")
+
+ self.log.info(
+ "Creating Microsoft Graph SDK client %s for conn_id: %s",
+ api_version.value,
+ self.conn_id,
+ )
+ self.log.info("Host: %s", host)
+ self.log.info("Base URL: %s", base_url)
+ self.log.info("Tenant id: %s", tenant_id)
+ self.log.info("Client id: %s", client_id)
+            self.log.info("Client secret: %s", "***" if client_secret else None)
+ self.log.info("API version: %s", api_version.value)
+ self.log.info("Scope: %s", scopes)
+ self.log.info("Verify: %s", verify)
+ self.log.info("Timeout: %s", self.timeout)
+ self.log.info("Trust env: %s", trust_env)
+ self.log.info("Authority: %s", authority)
+ self.log.info("Disable instance discovery: %s", disable_instance_discovery)
+ self.log.info("Allowed hosts: %s", allowed_hosts)
+ self.log.info("Proxies: %s", proxies)
+ self.log.info("MSAL Proxies: %s", msal_proxies)
+ self.log.info("HTTPX Proxies: %s", httpx_proxies)
+ credentials = ClientSecretCredential(
+ tenant_id=tenant_id, # type: ignore
+                client_id=client_id,
+                client_secret=client_secret,
+ authority=authority,
+ proxies=msal_proxies,
+ disable_instance_discovery=disable_instance_discovery,
+ connection_verify=verify,
+ )
+ http_client = GraphClientFactory.create_with_default_middleware(
+ api_version=api_version,
+ client=httpx.AsyncClient(
+ proxies=httpx_proxies,
+ timeout=Timeout(timeout=self.timeout),
+ verify=verify,
+ trust_env=trust_env,
+ ),
+ host=host,
+ )
+ auth_provider = AzureIdentityAuthenticationProvider(
+ credentials=credentials,
+ scopes=scopes,
+ allowed_hosts=allowed_hosts,
+ )
+ request_adapter = HttpxRequestAdapter(
+ authentication_provider=auth_provider,
+ http_client=http_client,
+ base_url=base_url,
+ )
+ self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
+ self._api_version = api_version
+ return request_adapter
diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py
new file mode 100644
index 0000000000000..6411f9cc4ac2d
--- /dev/null
+++ b/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -0,0 +1,292 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Sequence,
+)
+
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.models import BaseOperator
+from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
+from airflow.providers.microsoft.azure.triggers.msgraph import (
+ MSGraphTrigger,
+ ResponseSerializer,
+)
+from airflow.utils.xcom import XCOM_RETURN_KEY
+
+if TYPE_CHECKING:
+ from io import BytesIO
+
+ from kiota_abstractions.request_adapter import ResponseType
+ from kiota_abstractions.request_information import QueryParams
+ from kiota_abstractions.response_handler import NativeResponseType
+ from kiota_abstractions.serialization import ParsableFactory
+ from msgraph_core import APIVersion
+
+ from airflow.utils.context import Context
+
+
+class MSGraphAsyncOperator(BaseOperator):
+ """
+ A Microsoft Graph API operator which allows you to execute REST call to the Microsoft Graph API.
+
+ https://learn.microsoft.com/en-us/graph/use-the-api
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:MSGraphAsyncOperator`
+
+ :param url: The url being executed on the Microsoft Graph API (templated).
+    :param response_type: The expected return type of the response as a string. Possible values are: `bytes`,
+ `str`, `int`, `float`, `bool` and `datetime` (default is None).
+ :param response_handler: Function to convert the native HTTPX response returned by the hook (default is
+ lambda response, error_map: response.json()). The default expression will convert the native response
+ to JSON. If response_type parameter is specified, then the response_handler will be ignored.
+ :param method: The HTTP method being used to do the REST call (default is GET).
+ :param conn_id: The HTTP Connection ID to run the operator against (templated).
+ :param key: The key that will be used to store `XCom's` ("return_value" is default).
+ :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None).
+ When no timeout is specified or set to None then there is no HTTP timeout on each request.
+ :param proxies: A dict defining the HTTP proxies to be used (default is None).
+ :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
+ You can pass an enum named APIVersion which has 2 possible members v1 and beta,
+ or you can pass a string as `v1.0` or `beta`.
+ :param result_processor: Function to further process the response from MS Graph API
+        (default is ``lambda context, result: result``). When the response returned by the
+ `KiotaRequestAdapterHook` are bytes, then those will be base64 encoded into a string.
+ :param serializer: Class which handles response serialization (default is ResponseSerializer).
+ Bytes will be base64 encoded into a string, so it can be stored as an XCom.
+ """
+
+ template_fields: Sequence[str] = (
+ "url",
+ "response_type",
+ "path_parameters",
+ "url_template",
+ "query_parameters",
+ "headers",
+ "data",
+ "conn_id",
+ )
+
+ def __init__(
+ self,
+ *,
+ url: str,
+ response_type: ResponseType | None = None,
+ response_handler: Callable[
+ [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
+ ] = lambda response, error_map: response.json(),
+ path_parameters: dict[str, Any] | None = None,
+ url_template: str | None = None,
+ method: str = "GET",
+ query_parameters: dict[str, QueryParams] | None = None,
+ headers: dict[str, str] | None = None,
+ data: dict[str, Any] | str | BytesIO | None = None,
+ conn_id: str = KiotaRequestAdapterHook.default_conn_name,
+ key: str = XCOM_RETURN_KEY,
+ timeout: float | None = None,
+ proxies: dict | None = None,
+ api_version: APIVersion | None = None,
+ pagination_function: Callable[[MSGraphAsyncOperator, dict], tuple[str, dict]] | None = None,
+ result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
+ serializer: type[ResponseSerializer] = ResponseSerializer,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+ self.url = url
+ self.response_type = response_type
+ self.response_handler = response_handler
+ self.path_parameters = path_parameters
+ self.url_template = url_template
+ self.method = method
+ self.query_parameters = query_parameters
+ self.headers = headers
+ self.data = data
+ self.conn_id = conn_id
+ self.key = key
+ self.timeout = timeout
+ self.proxies = proxies
+ self.api_version = api_version
+ self.pagination_function = pagination_function or self.paginate
+ self.result_processor = result_processor
+ self.serializer: ResponseSerializer = serializer()
+ self.results: list[Any] | None = None
+
+ def execute(self, context: Context) -> None:
+ self.log.info("Executing url '%s' as '%s'", self.url, self.method)
+ self.defer(
+ trigger=MSGraphTrigger(
+ url=self.url,
+ response_type=self.response_type,
+ path_parameters=self.path_parameters,
+ url_template=self.url_template,
+ method=self.method,
+ query_parameters=self.query_parameters,
+ headers=self.headers,
+ data=self.data,
+ conn_id=self.conn_id,
+ timeout=self.timeout,
+ proxies=self.proxies,
+ api_version=self.api_version,
+ serializer=type(self.serializer),
+ ),
+ method_name=self.execute_complete.__name__,
+ )
+
+ def execute_complete(
+ self,
+ context: Context,
+ event: dict[Any, Any] | None = None,
+ ) -> Any:
+ """
+ Execute callback when MSGraphTrigger finishes execution.
+
+ This method gets executed automatically when MSGraphTrigger completes its execution.
+ """
+ self.log.debug("context: %s", context)
+
+ if event:
+ self.log.info("%s completed with %s: %s", self.task_id, event.get("status"), event)
+
+ if event.get("status") == "failure":
+ raise AirflowException(event.get("message"))
+
+ response = event.get("response")
+
+ self.log.info("response: %s", response)
+
+ if response:
+ response = self.serializer.deserialize(response)
+
+ self.log.debug("deserialize response: %s", response)
+
+ result = self.result_processor(context, response)
+
+ self.log.debug("processed response: %s", result)
+
+ event["response"] = result
+
+ try:
+ self.trigger_next_link(response, method_name=self.pull_execute_complete.__name__)
+ except TaskDeferred as exception:
+ self.append_result(
+ result=result,
+ append_result_as_list_if_absent=True,
+ )
+ self.push_xcom(context=context, value=self.results)
+ raise exception
+
+ self.append_result(result=result)
+ self.log.debug("results: %s", self.results)
+
+ return self.results
+ return None
+
+ def append_result(
+ self,
+ result: Any,
+ append_result_as_list_if_absent: bool = False,
+ ):
+ self.log.debug("value: %s", result)
+
+ if isinstance(self.results, list):
+ if isinstance(result, list):
+ self.results.extend(result)
+ else:
+ self.results.append(result)
+ else:
+ if append_result_as_list_if_absent:
+ if isinstance(result, list):
+ self.results = result
+ else:
+ self.results = [result]
+ else:
+ self.results = result
+
+ def push_xcom(self, context: Context, value) -> None:
+ self.log.debug("do_xcom_push: %s", self.do_xcom_push)
+ if self.do_xcom_push:
+ self.log.info("Pushing XCom with key '%s': %s", self.key, value)
+ self.xcom_push(context=context, key=self.key, value=value)
+
+ def pull_execute_complete(self, context: Context, event: dict[Any, Any] | None = None) -> Any:
+ self.results = list(
+ self.xcom_pull(
+ context=context,
+ task_ids=self.task_id,
+ dag_id=self.dag_id,
+ key=self.key,
+ )
+ or []
+ )
+ self.log.info(
+ "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s",
+ self.task_id,
+ self.dag_id,
+ self.key,
+ self.results,
+ )
+ return self.execute_complete(context, event)
+
+ @staticmethod
+ def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]:
+ odata_count = response.get("@odata.count")
+ if odata_count and operator.query_parameters:
+ query_parameters = deepcopy(operator.query_parameters)
+ top = query_parameters.get("$top")
+            # odata_count was already read above; reuse it for the page-size check
+
+ if top and odata_count:
+ if len(response.get("value", [])) == top:
+ skip = (
+ sum(map(lambda result: len(result["value"]), operator.results)) + top
+ if operator.results
+ else top
+ )
+ query_parameters["$skip"] = skip
+ return operator.url, query_parameters
+ return response.get("@odata.nextLink"), operator.query_parameters
+
+ def trigger_next_link(self, response, method_name="execute_complete") -> None:
+ if isinstance(response, dict):
+ url, query_parameters = self.pagination_function(self, response)
+
+ self.log.debug("url: %s", url)
+ self.log.debug("query_parameters: %s", query_parameters)
+
+ if url:
+ self.defer(
+ trigger=MSGraphTrigger(
+ url=url,
+ query_parameters=query_parameters,
+ response_type=self.response_type,
+ response_handler=self.response_handler,
+ conn_id=self.conn_id,
+ timeout=self.timeout,
+ proxies=self.proxies,
+ api_version=self.api_version,
+ serializer=type(self.serializer),
+ ),
+ method_name=method_name,
+ )
diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml
index 2ddc479b63c0a..f1fa058f86618 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -76,7 +76,7 @@ versions:
- 1.0.0
dependencies:
- - apache-airflow>=2.6.0
+ - apache-airflow>=2.7.0
- adlfs>=2023.10.0
- azure-batch>=8.0.0
- azure-cosmos>=4.6.0
@@ -98,6 +98,7 @@ dependencies:
- azure-mgmt-datafactory>=2.0.0
- azure-mgmt-containerregistry>=8.0.0
- azure-mgmt-containerinstance>=9.0.0
+ - msgraph-core>=1.0.0
devel-dependencies:
- pywinrm
@@ -164,6 +165,12 @@ integrations:
external-doc-url: https://azure.microsoft.com/en-us/products/storage/data-lake-storage/
logo: /integration-logos/azure/Data Lake Storage.svg
tags: [azure]
+ - integration-name: Microsoft Graph API
+ external-doc-url: https://learn.microsoft.com/en-us/graph/use-the-api/
+ logo: /integration-logos/azure/Microsoft-Graph-API.png
+ how-to-guide:
+ - /docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
+ tags: [azure]
operators:
- integration-name: Microsoft Azure Data Lake Storage
@@ -193,6 +200,9 @@ operators:
- integration-name: Microsoft Azure Synapse
python-modules:
- airflow.providers.microsoft.azure.operators.synapse
+ - integration-name: Microsoft Graph API
+ python-modules:
+ - airflow.providers.microsoft.azure.operators.msgraph
sensors:
- integration-name: Microsoft Azure Cosmos DB
@@ -204,6 +214,9 @@ sensors:
- integration-name: Microsoft Azure Data Factory
python-modules:
- airflow.providers.microsoft.azure.sensors.data_factory
+ - integration-name: Microsoft Graph API
+ python-modules:
+ - airflow.providers.microsoft.azure.sensors.msgraph
filesystems:
- airflow.providers.microsoft.azure.fs.adls
@@ -247,6 +260,9 @@ hooks:
- integration-name: Microsoft Azure Synapse
python-modules:
- airflow.providers.microsoft.azure.hooks.synapse
+ - integration-name: Microsoft Graph API
+ python-modules:
+ - airflow.providers.microsoft.azure.hooks.msgraph
triggers:
- integration-name: Microsoft Azure Data Factory
@@ -255,6 +271,9 @@ triggers:
- integration-name: Microsoft Azure Blob Storage
python-modules:
- airflow.providers.microsoft.azure.triggers.wasb
+ - integration-name: Microsoft Graph API
+ python-modules:
+ - airflow.providers.microsoft.azure.triggers.msgraph
transfers:
- source-integration-name: Local
diff --git a/airflow/providers/microsoft/azure/sensors/msgraph.py b/airflow/providers/microsoft/azure/sensors/msgraph.py
new file mode 100644
index 0000000000000..ffbf244dbe88c
--- /dev/null
+++ b/airflow/providers/microsoft/azure/sensors/msgraph.py
@@ -0,0 +1,163 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+from typing import TYPE_CHECKING, Any, Callable, Sequence
+
+from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
+from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer
+from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
+
+if TYPE_CHECKING:
+ from io import BytesIO
+
+ from kiota_abstractions.request_information import QueryParams
+ from kiota_abstractions.response_handler import NativeResponseType
+ from kiota_abstractions.serialization import ParsableFactory
+ from kiota_http.httpx_request_adapter import ResponseType
+ from msgraph_core import APIVersion
+
+ from airflow.triggers.base import TriggerEvent
+ from airflow.utils.context import Context
+
+
+def default_event_processor(context: Context, event: TriggerEvent) -> bool:
+ if event.payload["status"] == "success":
+ return json.loads(event.payload["response"])["status"] == "Succeeded"
+ return False
+
+
+class MSGraphSensor(BaseSensorOperator):
+ """
+ A Microsoft Graph API sensor which allows you to poll an async REST call to the Microsoft Graph API.
+
+ :param url: The url being executed on the Microsoft Graph API (templated).
+    :param response_type: The expected return type of the response as a string. Possible values are: `bytes`,
+ `str`, `int`, `float`, `bool` and `datetime` (default is None).
+ :param response_handler: Function to convert the native HTTPX response returned by the hook (default is
+ lambda response, error_map: response.json()). The default expression will convert the native response
+ to JSON. If response_type parameter is specified, then the response_handler will be ignored.
+ :param method: The HTTP method being used to do the REST call (default is GET).
+ :param conn_id: The HTTP Connection ID to run the operator against (templated).
+ :param proxies: A dict defining the HTTP proxies to be used (default is None).
+ :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
+ You can pass an enum named APIVersion which has 2 possible members v1 and beta,
+ or you can pass a string as `v1.0` or `beta`.
+ :param event_processor: Function which checks the response from MS Graph API (default is the
+ `default_event_processor` method) and returns a boolean. When the result is True, the sensor
+ will stop poking, otherwise it will continue until it's True or times out.
+ :param result_processor: Function to further process the response from MS Graph API
+        (default is ``lambda context, result: result``). When the response returned by the
+ `KiotaRequestAdapterHook` are bytes, then those will be base64 encoded into a string.
+ :param serializer: Class which handles response serialization (default is ResponseSerializer).
+ Bytes will be base64 encoded into a string, so it can be stored as an XCom.
+ """
+
+ template_fields: Sequence[str] = (
+ "url",
+ "response_type",
+ "path_parameters",
+ "url_template",
+ "query_parameters",
+ "headers",
+ "data",
+ "conn_id",
+ )
+
+ def __init__(
+ self,
+ url: str,
+ response_type: ResponseType | None = None,
+ response_handler: Callable[
+ [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
+ ] = lambda response, error_map: response.json(),
+ path_parameters: dict[str, Any] | None = None,
+ url_template: str | None = None,
+ method: str = "GET",
+ query_parameters: dict[str, QueryParams] | None = None,
+ headers: dict[str, str] | None = None,
+ data: dict[str, Any] | str | BytesIO | None = None,
+ conn_id: str = KiotaRequestAdapterHook.default_conn_name,
+ proxies: dict | None = None,
+ api_version: APIVersion | None = None,
+ event_processor: Callable[[Context, TriggerEvent], bool] = default_event_processor,
+ result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
+ serializer: type[ResponseSerializer] = ResponseSerializer,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.url = url
+ self.response_type = response_type
+ self.response_handler = response_handler
+ self.path_parameters = path_parameters
+ self.url_template = url_template
+ self.method = method
+ self.query_parameters = query_parameters
+ self.headers = headers
+ self.data = data
+ self.conn_id = conn_id
+ self.proxies = proxies
+ self.api_version = api_version
+ self.event_processor = event_processor
+ self.result_processor = result_processor
+ self.serializer = serializer()
+
+ @property
+ def trigger(self):
+ return MSGraphTrigger(
+ url=self.url,
+ response_type=self.response_type,
+ response_handler=self.response_handler,
+ path_parameters=self.path_parameters,
+ url_template=self.url_template,
+ method=self.method,
+ query_parameters=self.query_parameters,
+ headers=self.headers,
+ data=self.data,
+ conn_id=self.conn_id,
+ timeout=self.timeout,
+ proxies=self.proxies,
+ api_version=self.api_version,
+ serializer=type(self.serializer),
+ )
+
+ async def async_poke(self, context: Context) -> bool | PokeReturnValue:
+ self.log.info("Sensor triggered")
+
+ async for event in self.trigger.run():
+ self.log.debug("event: %s", event)
+
+ is_done = self.event_processor(context, event)
+
+ self.log.debug("is_done: %s", is_done)
+
+ response = self.serializer.deserialize(event.payload["response"])
+
+ self.log.debug("deserialize event: %s", response)
+
+ result = self.result_processor(context, response)
+
+ self.log.debug("result: %s", result)
+
+ return PokeReturnValue(is_done=is_done, xcom_value=result)
+ return PokeReturnValue(is_done=True)
+
+ def poke(self, context) -> bool | PokeReturnValue:
+ return asyncio.run(self.async_poke(context))
diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py b/airflow/providers/microsoft/azure/triggers/msgraph.py
new file mode 100644
index 0000000000000..c0e5ee85a0c4c
--- /dev/null
+++ b/airflow/providers/microsoft/azure/triggers/msgraph.py
@@ -0,0 +1,316 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import locale
+from base64 import b64encode
+from contextlib import suppress
+from datetime import datetime
+from io import BytesIO
+from json import JSONDecodeError
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncIterator,
+ Callable,
+ Sequence,
+)
+from urllib.parse import quote
+from uuid import UUID
+
+import pendulum
+from kiota_abstractions.api_error import APIError
+from kiota_abstractions.method import Method
+from kiota_abstractions.request_information import RequestInformation
+from kiota_abstractions.response_handler import ResponseHandler
+from kiota_http.middleware.options import ResponseHandlerOption
+
+from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.module_loading import import_string
+
+if TYPE_CHECKING:
+ from kiota_abstractions.request_adapter import RequestAdapter
+ from kiota_abstractions.request_information import QueryParams
+ from kiota_abstractions.response_handler import NativeResponseType
+ from kiota_abstractions.serialization import ParsableFactory
+ from kiota_http.httpx_request_adapter import ResponseType
+ from msgraph_core import APIVersion
+
+
+class ResponseSerializer:
+ """ResponseSerializer serializes the response as a string."""
+
+ def __init__(self, encoding: str | None = None):
+ self.encoding = encoding or locale.getpreferredencoding()
+
+ def serialize(self, response) -> str | None:
+ def convert(value) -> str | None:
+ if value is not None:
+ if isinstance(value, UUID):
+ return str(value)
+ if isinstance(value, datetime):
+ return value.isoformat()
+ if isinstance(value, pendulum.DateTime):
+ return value.to_iso8601_string() # Adjust the format as needed
+ raise TypeError(f"Object of type {type(value)} is not JSON serializable!")
+ return None
+
+ if response is not None:
+ if isinstance(response, bytes):
+ return b64encode(response).decode(self.encoding)
+ with suppress(JSONDecodeError):
+ return json.dumps(response, default=convert)
+ return response
+ return None
+
+ def deserialize(self, response) -> Any:
+ if isinstance(response, str):
+ with suppress(JSONDecodeError):
+ response = json.loads(response)
+ return response
+
+
+class CallableResponseHandler(ResponseHandler):
+ """
+ CallableResponseHandler executes the passed callable_function with response as parameter.
+
+ param callable_function: Function that is applied to the response.
+ """
+
+ def __init__(
+ self,
+ callable_function: Callable[[NativeResponseType, dict[str, ParsableFactory | None] | None], Any],
+ ):
+ self.callable_function = callable_function
+
+ async def handle_response_async(
+ self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None
+ ) -> Any:
+ """
+ Invoke this callback method when a response is received.
+
+ param response: The type of the native response object.
+ param error_map: The error dict to use in case of a failed request.
+ """
+ return self.callable_function(response, error_map)
+
+
+class MSGraphTrigger(BaseTrigger):
+ """
+ A Microsoft Graph API trigger which allows you to execute an async REST call to the Microsoft Graph API.
+
+ :param url: The url being executed on the Microsoft Graph API (templated).
+ :param response_type: The expected return type of the response as a string. Possible value are: `bytes`,
+ `str`, `int`, `float`, `bool` and `datetime` (default is None).
+ :param response_handler: Function to convert the native HTTPX response returned by the hook (default is
+ lambda response, error_map: response.json()). The default expression will convert the native response
+ to JSON. If response_type parameter is specified, then the response_handler will be ignored.
+ :param method: The HTTP method being used to do the REST call (default is GET).
+ :param conn_id: The HTTP Connection ID to run the operator against (templated).
+ :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None).
+ When no timeout is specified or set to None then there is no HTTP timeout on each request.
+ :param proxies: A dict defining the HTTP proxies to be used (default is None).
+ :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
+ You can pass an enum named APIVersion which has 2 possible members v1 and beta,
+ or you can pass a string as `v1.0` or `beta`.
+ :param serializer: Class which handles response serialization (default is ResponseSerializer).
+ Bytes will be base64 encoded into a string, so it can be stored as an XCom.
+ """
+
+ DEFAULT_HEADERS = {"Accept": "application/json;q=1"}
+ template_fields: Sequence[str] = (
+ "url",
+ "response_type",
+ "path_parameters",
+ "url_template",
+ "query_parameters",
+ "headers",
+ "data",
+ "conn_id",
+ )
+
+ def __init__(
+ self,
+ url: str,
+ response_type: ResponseType | None = None,
+ response_handler: Callable[
+ [NativeResponseType, dict[str, ParsableFactory | None] | None], Any
+ ] = lambda response, error_map: response.json(),
+ path_parameters: dict[str, Any] | None = None,
+ url_template: str | None = None,
+ method: str = "GET",
+ query_parameters: dict[str, QueryParams] | None = None,
+ headers: dict[str, str] | None = None,
+ data: dict[str, Any] | str | BytesIO | None = None,
+ conn_id: str = KiotaRequestAdapterHook.default_conn_name,
+ timeout: float | None = None,
+ proxies: dict | None = None,
+ api_version: APIVersion | None = None,
+ serializer: type[ResponseSerializer] = ResponseSerializer,
+ ):
+ super().__init__()
+ self.hook = KiotaRequestAdapterHook(
+ conn_id=conn_id,
+ timeout=timeout,
+ proxies=proxies,
+ api_version=api_version,
+ )
+ self.url = url
+ self.response_type = response_type
+ self.response_handler = response_handler
+ self.path_parameters = path_parameters
+ self.url_template = url_template
+ self.method = method
+ self.query_parameters = query_parameters
+ self.headers = headers
+ self.data = data
+ self.serializer: ResponseSerializer = self.resolve_type(serializer, default=ResponseSerializer)()
+
+ @classmethod
+ def resolve_type(cls, value: str | type, default) -> type:
+ if isinstance(value, str):
+ with suppress(ImportError):
+ return import_string(value)
+ return default
+ return value or default
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize the HttpTrigger arguments and classpath."""
+ api_version = self.api_version.value if self.api_version else None
+ return (
+ f"{self.__class__.__module__}.{self.__class__.__name__}",
+ {
+ "conn_id": self.conn_id,
+ "timeout": self.timeout,
+ "proxies": self.proxies,
+ "api_version": api_version,
+ "serializer": f"{self.serializer.__class__.__module__}.{self.serializer.__class__.__name__}",
+ "url": self.url,
+ "path_parameters": self.path_parameters,
+ "url_template": self.url_template,
+ "method": self.method,
+ "query_parameters": self.query_parameters,
+ "headers": self.headers,
+ "data": self.data,
+ "response_type": self.response_type,
+ },
+ )
+
+ def get_conn(self) -> RequestAdapter:
+ return self.hook.get_conn()
+
+ @property
+ def conn_id(self) -> str:
+ return self.hook.conn_id
+
+ @property
+ def timeout(self) -> float | None:
+ return self.hook.timeout
+
+ @property
+ def proxies(self) -> dict | None:
+ return self.hook.proxies
+
+ @property
+ def api_version(self) -> APIVersion:
+ return self.hook.api_version
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ """Make a series of asynchronous HTTP calls via a KiotaRequestAdapterHook."""
+ try:
+ response = await self.execute()
+
+ self.log.debug("response: %s", response)
+
+ if response:
+ response_type = type(response)
+
+ self.log.debug("response type: %s", response_type)
+
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "type": f"{response_type.__module__}.{response_type.__name__}",
+ "response": self.serializer.serialize(response),
+ }
+ )
+ else:
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "type": None,
+ "response": None,
+ }
+ )
+ except Exception as e:
+ self.log.exception("An error occurred: %s", e)
+ yield TriggerEvent({"status": "failure", "message": str(e)})
+
+ def normalize_url(self) -> str | None:
+ if self.url.startswith("/"):
+ return self.url.replace("/", "", 1)
+ return self.url
+
+ def encoded_query_parameters(self) -> dict:
+ if self.query_parameters:
+ return {quote(key): quote(str(value)) for key, value in self.query_parameters.items()}
+ return {}
+
+ def request_information(self) -> RequestInformation:
+ request_information = RequestInformation()
+ request_information.path_parameters = self.path_parameters or {}
+ request_information.http_method = Method(self.method.strip().upper())
+ request_information.query_parameters = self.encoded_query_parameters()
+ if self.url.startswith("http"):
+ request_information.url = self.url
+ elif request_information.query_parameters.keys():
+ query = ",".join(request_information.query_parameters.keys())
+ request_information.url_template = f"{{+baseurl}}/{self.normalize_url()}{{?{query}}}"
+ else:
+ request_information.url_template = f"{{+baseurl}}/{self.normalize_url()}"
+ if not self.response_type:
+ request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption(
+ response_handler=CallableResponseHandler(self.response_handler)
+ )
+ headers = {**self.DEFAULT_HEADERS, **self.headers} if self.headers else self.DEFAULT_HEADERS
+ for header_name, header_value in headers.items():
+ request_information.headers.try_add(header_name=header_name, header_value=header_value)
+ if isinstance(self.data, BytesIO) or isinstance(self.data, bytes) or isinstance(self.data, str):
+ request_information.content = self.data
+ elif self.data:
+ request_information.headers.try_add(
+ header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json"
+ )
+ request_information.content = json.dumps(self.data).encode("utf-8")
+ return request_information
+
+ @staticmethod
+ def error_mapping() -> dict[str, ParsableFactory | None]:
+ return {
+ "4XX": APIError,
+ "5XX": APIError,
+ }
+
+ async def execute(self) -> AsyncIterator[TriggerEvent]:
+ return await self.get_conn().send_primitive_async(
+ request_info=self.request_information(),
+ response_type=self.response_type,
+ error_map=self.error_mapping(),
+ )
diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py
index b527cafe3cd6e..efc01b58858cf 100644
--- a/dev/breeze/src/airflow_breeze/global_constants.py
+++ b/dev/breeze/src/airflow_breeze/global_constants.py
@@ -473,7 +473,9 @@ def _exclusion(providers: Iterable[str]) -> str:
{
"python-version": "3.8",
"airflow-version": "2.6.0",
- "remove-providers": _exclusion(["openlineage", "common.io", "cohere", "fab", "qdrant"]),
+ "remove-providers": _exclusion(
+ ["openlineage", "common.io", "cohere", "fab", "qdrant", "microsoft.azure"]
+ ),
},
{
"python-version": "3.8",
diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
new file mode 100644
index 0000000000000..817b14f783142
--- /dev/null
+++ b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst
@@ -0,0 +1,74 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+Microsoft Graph API Operators
+=============================
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+.. include:: /operators/_partials/prerequisite_tasks.rst
+
+.. _howto/operator:MSGraphAsyncOperator:
+
+MSGraphAsyncOperator
+----------------------------------
+Use the
+:class:`~airflow.providers.microsoft.azure.operators.msgraph.MSGraphAsyncOperator` to call Microsoft Graph API.
+
+
+Below is an example of using this operator to get a Sharepoint site.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_msgraph.py
+ :language: python
+ :dedent: 0
+ :start-after: [START howto_operator_graph_site]
+ :end-before: [END howto_operator_graph_site]
+
+Below is an example of using this operator to get a Sharepoint site pages.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_msgraph.py
+ :language: python
+ :dedent: 0
+ :start-after: [START howto_operator_graph_site_pages]
+ :end-before: [END howto_operator_graph_site_pages]
+
+Below is an example of using this operator to get PowerBI workspaces.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_powerbi.py
+ :language: python
+ :dedent: 0
+ :start-after: [START howto_operator_powerbi_workspaces]
+ :end-before: [END howto_operator_powerbi_workspaces]
+
+Below is an example of using this operator to get PowerBI workspaces info.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_powerbi.py
+ :language: python
+ :dedent: 0
+ :start-after: [START howto_operator_powerbi_workspaces_info]
+ :end-before: [END howto_operator_powerbi_workspaces_info]
+
+
+Reference
+---------
+
+For further information, look at:
+
+* `Use the Microsoft Graph API <https://learn.microsoft.com/en-us/graph/use-the-api>`__
+* `Using the Power BI REST APIs <https://learn.microsoft.com/en-us/rest/api/power-bi/>`__
diff --git a/docs/apache-airflow-providers-microsoft-azure/sensors/msgraph.rst b/docs/apache-airflow-providers-microsoft-azure/sensors/msgraph.rst
new file mode 100644
index 0000000000000..4ddad88f19fa1
--- /dev/null
+++ b/docs/apache-airflow-providers-microsoft-azure/sensors/msgraph.rst
@@ -0,0 +1,42 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+Microsoft Graph API Sensors
+=============================
+
+MSGraphSensor
+-------------
+Use the
+:class:`~airflow.providers.microsoft.azure.sensors.msgraph.MSGraphSensor` to poll a Power BI API.
+
+
+Below is an example of using this sensor to poll the status of a PowerBI workspace.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_powerbi.py
+ :language: python
+ :dedent: 0
+ :start-after: [START howto_sensor_powerbi_scan_status]
+ :end-before: [END howto_sensor_powerbi_scan_status]
+
+
+Reference
+---------
+
+For further information, look at:
+
+* `Using the Power BI REST APIs <https://learn.microsoft.com/en-us/rest/api/power-bi/>`__
diff --git a/docs/integration-logos/azure/Microsoft-Graph-API.png b/docs/integration-logos/azure/Microsoft-Graph-API.png
new file mode 100644
index 0000000000000..0724a1e09b495
Binary files /dev/null and b/docs/integration-logos/azure/Microsoft-Graph-API.png differ
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 48b4189e1a43d..dcd8641d8061b 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1444,6 +1444,7 @@ setted
sftp
SFTPClient
sharded
+Sharepoint
shellcheck
shellcmd
shm
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 9315766f81926..841f3764674e3 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -680,7 +680,7 @@
"deps": [
"adal>=1.2.7",
"adlfs>=2023.10.0",
- "apache-airflow>=2.6.0",
+ "apache-airflow>=2.7.0",
"azure-batch>=8.0.0",
"azure-cosmos>=4.6.0",
"azure-datalake-store>=0.0.45",
@@ -699,7 +699,8 @@
"azure-storage-file-datalake>=12.9.1",
"azure-storage-file-share",
"azure-synapse-artifacts>=0.17.0",
- "azure-synapse-spark"
+ "azure-synapse-spark",
+ "msgraph-core>=1.0.0"
],
"devel-deps": [
"pywinrm"
diff --git a/tests/providers/microsoft/azure/base.py b/tests/providers/microsoft/azure/base.py
new file mode 100644
index 0000000000000..4cda62858e815
--- /dev/null
+++ b/tests/providers/microsoft/azure/base.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from contextlib import contextmanager
+from copy import deepcopy
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Iterable
+from unittest.mock import patch
+
+from kiota_http.httpx_request_adapter import HttpxRequestAdapter
+
+from airflow.exceptions import TaskDeferred
+from airflow.models import Operator, TaskInstance
+from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
+from airflow.utils.session import NEW_SESSION
+from airflow.utils.xcom import XCOM_RETURN_KEY
+from tests.providers.microsoft.conftest import get_airflow_connection, mock_context
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+ from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class MockedTaskInstance(TaskInstance):
+ values = {}
+
+ def xcom_pull(
+ self,
+ task_ids: Iterable[str] | str | None = None,
+ dag_id: str | None = None,
+ key: str = XCOM_RETURN_KEY,
+ include_prior_dates: bool = False,
+ session: Session = NEW_SESSION,
+ *,
+ map_indexes: Iterable[int] | int | None = None,
+ default: Any | None = None,
+ ) -> Any:
+ self.task_id = task_ids
+ self.dag_id = dag_id
+ return self.values.get(f"{task_ids}_{dag_id}_{key}")
+
+ def xcom_push(
+ self,
+ key: str,
+ value: Any,
+ execution_date: datetime | None = None,
+ session: Session = NEW_SESSION,
+ ) -> None:
+ self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value
+
+
+class Base:
+ def teardown_method(self, method):
+ KiotaRequestAdapterHook.cached_request_adapters.clear()
+ MockedTaskInstance.values.clear()
+
+ @contextmanager
+ def patch_hook_and_request_adapter(self, response):
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection
+ ), patch.object(HttpxRequestAdapter, "get_http_response_message") as mock_get_http_response:
+ if isinstance(response, Exception):
+ mock_get_http_response.side_effect = response
+ else:
+ mock_get_http_response.return_value = response
+ yield
+
+    @staticmethod
+    async def _run_trigger(trigger: BaseTrigger) -> list[TriggerEvent]:
+        events = []
+        async for event in trigger.run():
+            events.append(event)
+        return events
+
+    def run_trigger(self, trigger: BaseTrigger) -> list[TriggerEvent]:
+        return asyncio.run(self._run_trigger(trigger))
+
+    def execute_operator(self, operator: Operator) -> tuple[Any, Any]:
+        context = mock_context(task=operator)
+        return asyncio.run(self.deferrable_operator(context, operator))
+
+    async def deferrable_operator(self, context, operator):
+        result = None
+        triggered_events = []
+        try:
+            result = operator.execute(context=context)
+        except TaskDeferred as deferred:
+            task = deferred
+
+            while task:
+                events = await self._run_trigger(task.trigger)
+
+                if not events:
+                    break
+
+                triggered_events.extend(deepcopy(events))
+
+                try:
+                    method = getattr(operator, task.method_name)
+                    result = method(context=context, event=next(iter(events)).payload)
+                    task = None
+                except TaskDeferred as exception:
+                    task = exception
+        return result, triggered_events
diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py
new file mode 100644
index 0000000000000..1c1046e1fa4f3
--- /dev/null
+++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import patch
+
+from kiota_http.httpx_request_adapter import HttpxRequestAdapter
+from msgraph_core import APIVersion, NationalClouds
+
+from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
+from tests.providers.microsoft.conftest import get_airflow_connection, mock_connection
+
+
+class TestKiotaRequestAdapterHook:
+ def test_get_conn(self):
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ actual = hook.get_conn()
+
+ assert isinstance(actual, HttpxRequestAdapter)
+ assert actual.base_url == "https://graph.microsoft.com/v1.0"
+
+ def test_api_version(self):
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+
+ assert hook.api_version == APIVersion.v1
+
+ def test_get_api_version_when_empty_config_dict(self):
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ actual = hook.get_api_version({})
+
+ assert actual == APIVersion.v1
+
+ def test_get_api_version_when_api_version_in_config_dict(self):
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+ actual = hook.get_api_version({"api_version": "beta"})
+
+ assert actual == APIVersion.beta
+
+ def test_get_host_when_connection_has_scheme_and_host(self):
+ connection = mock_connection(schema="https", host="graph.microsoft.de")
+ actual = KiotaRequestAdapterHook.get_host(connection)
+
+ assert actual == NationalClouds.Germany.value
+
+ def test_get_host_when_connection_has_no_scheme_or_host(self):
+ connection = mock_connection()
+ actual = KiotaRequestAdapterHook.get_host(connection)
+
+ assert actual == NationalClouds.Global.value
diff --git a/tests/providers/microsoft/azure/operators/test_msgraph.py b/tests/providers/microsoft/azure/operators/test_msgraph.py
new file mode 100644
index 0000000000000..b7520d731544c
--- /dev/null
+++ b/tests/providers/microsoft/azure/operators/test_msgraph.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import locale
+from base64 import b64encode
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator
+from airflow.triggers.base import TriggerEvent
+from tests.providers.microsoft.azure.base import Base
+from tests.providers.microsoft.conftest import load_file, load_json, mock_json_response, mock_response
+
+
+class TestMSGraphAsyncOperator(Base):
+ @pytest.mark.db_test
+ def test_execute(self):
+ users = load_json("resources", "users.json")
+ next_users = load_json("resources", "next_users.json")
+ response = mock_json_response(200, users, next_users)
+
+ with self.patch_hook_and_request_adapter(response):
+ operator = MSGraphAsyncOperator(
+ task_id="users_delta",
+ conn_id="msgraph_api",
+ url="users",
+ result_processor=lambda context, result: result.get("value"),
+ )
+
+ results, events = self.execute_operator(operator)
+
+ assert len(results) == 30
+ assert results == users.get("value") + next_users.get("value")
+ assert len(events) == 2
+ assert isinstance(events[0], TriggerEvent)
+ assert events[0].payload["status"] == "success"
+ assert events[0].payload["type"] == "builtins.dict"
+ assert events[0].payload["response"] == json.dumps(users)
+ assert isinstance(events[1], TriggerEvent)
+ assert events[1].payload["status"] == "success"
+ assert events[1].payload["type"] == "builtins.dict"
+ assert events[1].payload["response"] == json.dumps(next_users)
+
+ @pytest.mark.db_test
+ def test_execute_when_do_xcom_push_is_false(self):
+ users = load_json("resources", "users.json")
+ users.pop("@odata.nextLink")
+ response = mock_json_response(200, users)
+
+ with self.patch_hook_and_request_adapter(response):
+ operator = MSGraphAsyncOperator(
+ task_id="users_delta",
+ conn_id="msgraph_api",
+ url="users/delta",
+ do_xcom_push=False,
+ )
+
+ results, events = self.execute_operator(operator)
+
+ assert isinstance(results, dict)
+ assert len(events) == 1
+ assert isinstance(events[0], TriggerEvent)
+ assert events[0].payload["status"] == "success"
+ assert events[0].payload["type"] == "builtins.dict"
+ assert events[0].payload["response"] == json.dumps(users)
+
+ @pytest.mark.db_test
+ def test_execute_when_an_exception_occurs(self):
+ with self.patch_hook_and_request_adapter(AirflowException()):
+ operator = MSGraphAsyncOperator(
+ task_id="users_delta",
+ conn_id="msgraph_api",
+ url="users/delta",
+ do_xcom_push=False,
+ )
+
+ with pytest.raises(AirflowException):
+ self.execute_operator(operator)
+
+ @pytest.mark.db_test
+ def test_execute_when_response_is_bytes(self):
+ content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+ base64_encoded_content = b64encode(content).decode(locale.getpreferredencoding())
+ drive_id = "82f9d24d-6891-4790-8b6d-f1b2a1d0ca22"
+ response = mock_response(200, content)
+
+ with self.patch_hook_and_request_adapter(response):
+ operator = MSGraphAsyncOperator(
+ task_id="drive_item_content",
+ conn_id="msgraph_api",
+ response_type="bytes",
+ url=f"/drives/{drive_id}/root/content",
+ )
+
+ results, events = self.execute_operator(operator)
+
+ assert results == base64_encoded_content
+ assert len(events) == 1
+ assert isinstance(events[0], TriggerEvent)
+ assert events[0].payload["status"] == "success"
+ assert events[0].payload["type"] == "builtins.bytes"
+ assert events[0].payload["response"] == base64_encoded_content
+
+ def test_template_fields(self):
+ operator = MSGraphAsyncOperator(
+ task_id="drive_item_content",
+ conn_id="msgraph_api",
+ url="users/delta",
+ )
+
+ for template_field in MSGraphAsyncOperator.template_fields:
+ getattr(operator, template_field)
diff --git a/airflow/providers/microsoft/azure/serialization/__init__.py b/tests/providers/microsoft/azure/resources/__init__.py
similarity index 100%
rename from airflow/providers/microsoft/azure/serialization/__init__.py
rename to tests/providers/microsoft/azure/resources/__init__.py
diff --git a/tests/providers/microsoft/azure/resources/dummy.pdf b/tests/providers/microsoft/azure/resources/dummy.pdf
new file mode 100644
index 0000000000000..774c2ea70c551
Binary files /dev/null and b/tests/providers/microsoft/azure/resources/dummy.pdf differ
diff --git a/tests/providers/microsoft/azure/resources/next_users.json b/tests/providers/microsoft/azure/resources/next_users.json
new file mode 100644
index 0000000000000..3a88cf08b2c34
--- /dev/null
+++ b/tests/providers/microsoft/azure/resources/next_users.json
@@ -0,0 +1 @@
+{"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#users(displayName,description,mailNickname)", "value": [{"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/c52a9941-e5cb-49cb-9972-6f45cd7cd447/Microsoft.DirectoryServices.User", "displayName": "Leonardo DiCaprio", "mailNickname": "LeoD"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/abee3661-4525-4491-8e0f-e1589383aea7/Microsoft.DirectoryServices.User", "displayName": "Meryl Streep", "mailNickname": "MerylS"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/6794f45f-949e-4e8b-b01c-03d777e1cbf8/Microsoft.DirectoryServices.User", "displayName": "Tom Hanks", "mailNickname": "TomH"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/ad31af2e-b69a-4ff7-8aee-b962cc739210/Microsoft.DirectoryServices.User", "displayName": "Jennifer Lawrence", "mailNickname": "JenniferL"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/6aed18fc-f39b-4762-85d1-1525ccdf4823/Microsoft.DirectoryServices.User", "displayName": "Denzel Washington", "mailNickname": "DenzelW"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/44c76cb8-7118-4211-99d0-dd6651ce2fe6/Microsoft.DirectoryServices.User", "displayName": "Angelina Jolie", "mailNickname": "AngelinaJ"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/d1db76e3-06e7-430a-a97e-7b6572f5fb15/Microsoft.DirectoryServices.User", "displayName": "Johnny Depp", 
"mailNickname": "JohnnyD"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/2b62c4a9-3c0c-4d60-8e3f-cb698d8ba9fc/Microsoft.DirectoryServices.User", "displayName": "Scarlett Johansson", "mailNickname": "ScarlettJ"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/d982d3c5-4ea8-47bb-81b1-77ed0b107f30/Microsoft.DirectoryServices.User", "displayName": "Daniel Craig", "mailNickname": "DanielC"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/0388216c-76ca-4031-a620-3af3df529485/Microsoft.DirectoryServices.User", "displayName": "Charlize Theron", "mailNickname": "CharlizeT"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/bbde9cbb-9688-4f07-af76-660244830541/Microsoft.DirectoryServices.User", "displayName": "George Clooney", "mailNickname": "GeorgeC"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/149136c3-62f0-4d27-94f1-8a27bfc4cf73/Microsoft.DirectoryServices.User", "displayName": "Julia Roberts", "mailNickname": "JuliaR"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/08f9d41c-a653-4de3-85c2-324eb53bcbff/Microsoft.DirectoryServices.User", "displayName": "Ryan Reynolds", "mailNickname": "RyanR"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/cf5ce2ea-132b-4699-bd8f-6bfcac619fc3/Microsoft.DirectoryServices.User", "displayName": "Nicole Kidman", "mailNickname": "NicoleK"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": 
"https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/e255041c-28e5-4abc-bb03-09205096cd87/Microsoft.DirectoryServices.User", "displayName": "Sean Connery", "mailNickname": "SeanC"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/52f8faad-6b39-4cbb-8614-e968f5af9e9e/Microsoft.DirectoryServices.User", "displayName": "Emma Stone", "mailNickname": "EmmaS"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/e9bf2db3-4149-4396-aafd-278a0212179a/Microsoft.DirectoryServices.User", "displayName": "Robert Downey Jr.", "mailNickname": "RobertDJr"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/f66aba7d-5284-418d-a025-6d6f0639350b/Microsoft.DirectoryServices.User", "displayName": "Cate Blanchett", "mailNickname": "CateB"}]}
diff --git a/tests/providers/microsoft/azure/resources/status.json b/tests/providers/microsoft/azure/resources/status.json
new file mode 100644
index 0000000000000..6bff9e29afb41
--- /dev/null
+++ b/tests/providers/microsoft/azure/resources/status.json
@@ -0,0 +1 @@
+{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "Succeeded"}
diff --git a/tests/providers/microsoft/azure/resources/users.json b/tests/providers/microsoft/azure/resources/users.json
new file mode 100644
index 0000000000000..617e6f8420c62
--- /dev/null
+++ b/tests/providers/microsoft/azure/resources/users.json
@@ -0,0 +1 @@
+{"@odata.nextLink": "https://graph.microsoft.com/v1.0/users/delta()?$skiptoken=qLMhmnoTon81CQ1VVWyx9MNESqEIMKNpEWZWAfnn5F7tBNFuSgWh_pXZOweu67nEThGR0yQewi_a3Ixe75S6PoB8pdllphCEev0fMe5Uc1lWMtn3byOS8_OPTzPGZIZ17x-dVyxaE_4I55YyLJ0cgBxg8wsBrkYgaNE9vy5Su2HeCKxJODDQk4zRgP8QGo0pZatReTpqisVbrW5Gl1H_Xgy4lhenv1SmoRcBQtWBa5iAh-MURoaTo7i0kQjFhH6SCrkjBkfkRFVy9dafOOt2Owbxfn5hKGfEnfmG0RBmgdUsZPgX-ap0mjjf7PjExoxMek4CDnb8Yv737oGkh9C_G0XTJGeGxPBbkD-w4SaQookde4yxOzceAw1MuamBy63uJdbXt1ul61tDvfPwrJVHq99FxGU1n-i_RfHh65nRCHju3N3ApOFKrAi6933l3VupyaXsvmm5pCPh0T70dYK01CYktBce8Mc1HaVqB7SR-R9-X4PHYfozPWeqv4hng4YEvqA41XjRPUzOaS1VTH08k-HhR9ENpqw5UTFnimAu5RbMT5fTbUAMTQC2XcWF_5aDgjuw8D2VQvr5VsB7qmu4mgGb_dNrHM47QyJCKY2QcgLmmsTnr5Z-7Qe2AGGy1b5DREBVoLnSL-aJ_6m9TAlYD9oGityZbDJ1ssVxS0XsGYxAwSa5z_E_lgedr_ikHc0zzDAdj3TgLGD5gWIjwvxnkNEXa0onk-jqGfANEDFh0vGuDz1mlgwiHGIKN-QjmfTX-Equ_uY2lbuPcIXTVpdrgQob5BaXONZ78uVh-AIlm96PPqbwNj1QMf3EE0DfgGaMIloFdByjSD5FjwQ_6COxueUJw2iUtKh8l4fpja7yOs4QCP5_tPvUUZT26ylLHjbPZO7I1TOaJb9OTBM-kplBJitW3SAm7DJM9jocA89Iqbm_ap6mlvEET2cyw6Dn7j2AP_0NsdoP1MJn6H0JE6HOJddbUVuzhba58QBujgHsEHo.WPJMTkIZstDbZSnvkZm1eX2ASvooRj4BgY-GW_pQEZo", "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#users(displayName,description,mailNickname)", "value": [{"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/c52a9941-e5cb-49cb-9972-6f45cd7cd447/Microsoft.DirectoryServices.User", "displayName": "Sylvester Stallone", "mailNickname": "SylvesterS"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/abee3661-4525-4491-8e0f-e1589383aea7/Microsoft.DirectoryServices.User", "displayName": "Jason Statham", "mailNickname": "JasonS"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": 
"https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/6794f45f-949e-4e8b-b01c-03d777e1cbf8/Microsoft.DirectoryServices.User", "displayName": "Jet Li", "mailNickname": "JetL"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/ad31af2e-b69a-4ff7-8aee-b962cc739210/Microsoft.DirectoryServices.User", "displayName": "Dolph Lundgren", "mailNickname": "DolphL"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/6aed18fc-f39b-4762-85d1-1525ccdf4823/Microsoft.DirectoryServices.User", "displayName": "Randy Couture", "mailNickname": "RandyC"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/44c76cb8-7118-4211-99d0-dd6651ce2fe6/Microsoft.DirectoryServices.User", "displayName": "Terry Crews", "mailNickname": "TerryC"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/d1db76e3-06e7-430a-a97e-7b6572f5fb15/Microsoft.DirectoryServices.User", "displayName": "Arnold Schwarzenegger", "mailNickname": "ArnoldS"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/2b62c4a9-3c0c-4d60-8e3f-cb698d8ba9fc/Microsoft.DirectoryServices.User", "displayName": "Wesley Snipes", "mailNickname": "WesleyS"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/d982d3c5-4ea8-47bb-81b1-77ed0b107f30/Microsoft.DirectoryServices.User", "displayName": "Mel Gibson", "mailNickname": "MelG"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": 
"https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/0388216c-76ca-4031-a620-3af3df529485/Microsoft.DirectoryServices.User", "displayName": "Harrison Ford", "mailNickname": "HarrisonF"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/bbde9cbb-9688-4f07-af76-660244830541/Microsoft.DirectoryServices.User", "displayName": "Antonio Banderas", "mailNickname": "AntonioB"}, {"@odata.type": "#microsoft.graph.user", "@odata.id": "https://graph.microsoft.com/v2/c34aa217-8e78-4c5f-91a7-29e01e0d97d8/directoryObjects/34c76cb8-7118-4211-99d0-dd6651ce2fe6/Microsoft.DirectoryServices.User", "displayName": "Chuck Norris", "mailNickname": "ChuckN"}]}
diff --git a/tests/providers/microsoft/azure/sensors/test_msgraph.py b/tests/providers/microsoft/azure/sensors/test_msgraph.py
new file mode 100644
index 0000000000000..50fd2474ab454
--- /dev/null
+++ b/tests/providers/microsoft/azure/sensors/test_msgraph.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor
+from tests.providers.microsoft.azure.base import Base
+from tests.providers.microsoft.conftest import load_json, mock_context, mock_json_response
+
+
+class TestMSGraphSensor(Base):
+ def test_execute(self):
+ status = load_json("resources", "status.json")
+ response = mock_json_response(200, status)
+
+ with self.patch_hook_and_request_adapter(response):
+ sensor = MSGraphSensor(
+ task_id="check_workspaces_status",
+ conn_id="powerbi",
+ url="myorg/admin/workspaces/scanStatus/{scanId}",
+ path_parameters={"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"},
+ result_processor=lambda context, result: result["id"],
+ timeout=350.0,
+ )
+ actual = sensor.execute(context=mock_context(task=sensor))
+
+ assert isinstance(actual, str)
+ assert actual == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
+
+ def test_template_fields(self):
+ sensor = MSGraphSensor(
+ task_id="check_workspaces_status",
+ conn_id="powerbi",
+ url="myorg/admin/workspaces/scanStatus/{scanId}",
+ )
+
+ for template_field in MSGraphSensor.template_fields:
+ getattr(sensor, template_field)
diff --git a/tests/providers/microsoft/azure/triggers/test_msgraph.py b/tests/providers/microsoft/azure/triggers/test_msgraph.py
new file mode 100644
index 0000000000000..900d0875cd0f2
--- /dev/null
+++ b/tests/providers/microsoft/azure/triggers/test_msgraph.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+import locale
+from base64 import b64decode, b64encode
+from datetime import datetime
+from unittest.mock import patch
+from uuid import uuid4
+
+import pendulum
+
+from airflow.exceptions import AirflowException
+from airflow.providers.microsoft.azure.triggers.msgraph import (
+ CallableResponseHandler,
+ MSGraphTrigger,
+ ResponseSerializer,
+)
+from airflow.triggers.base import TriggerEvent
+from tests.providers.microsoft.azure.base import Base
+from tests.providers.microsoft.conftest import (
+ get_airflow_connection,
+ load_file,
+ load_json,
+ mock_json_response,
+ mock_response,
+)
+
+
+class TestMSGraphTrigger(Base):
+ def test_run_when_valid_response(self):
+ users = load_json("resources", "users.json")
+ response = mock_json_response(200, users)
+
+ with self.patch_hook_and_request_adapter(response):
+ trigger = MSGraphTrigger("users/delta", conn_id="msgraph_api")
+ actual = self.run_trigger(trigger)
+
+ assert len(actual) == 1
+ assert isinstance(actual[0], TriggerEvent)
+ assert actual[0].payload["status"] == "success"
+ assert actual[0].payload["type"] == "builtins.dict"
+ assert actual[0].payload["response"] == json.dumps(users)
+
+ def test_run_when_response_is_none(self):
+ response = mock_json_response(200)
+
+ with self.patch_hook_and_request_adapter(response):
+ trigger = MSGraphTrigger("users/delta", conn_id="msgraph_api")
+ actual = self.run_trigger(trigger)
+
+ assert len(actual) == 1
+ assert isinstance(actual[0], TriggerEvent)
+ assert actual[0].payload["status"] == "success"
+ assert actual[0].payload["type"] is None
+ assert actual[0].payload["response"] is None
+
+ def test_run_when_response_cannot_be_converted_to_json(self):
+ with self.patch_hook_and_request_adapter(AirflowException()):
+ trigger = MSGraphTrigger("users/delta", conn_id="msgraph_api")
+ actual = next(iter(self.run_trigger(trigger)))
+
+ assert isinstance(actual, TriggerEvent)
+ assert actual.payload["status"] == "failure"
+ assert actual.payload["message"] == ""
+
+ def test_run_when_response_is_bytes(self):
+ content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+ base64_encoded_content = b64encode(content).decode(locale.getpreferredencoding())
+ response = mock_response(200, content)
+
+ with self.patch_hook_and_request_adapter(response):
+ url = (
+ "https://graph.microsoft.com/v1.0/me/drive/items/1b30fecf-4330-4899-b249-104c2afaf9ed/content"
+ )
+ trigger = MSGraphTrigger(url, response_type="bytes", conn_id="msgraph_api")
+ actual = next(iter(self.run_trigger(trigger)))
+
+ assert isinstance(actual, TriggerEvent)
+ assert actual.payload["status"] == "success"
+ assert actual.payload["type"] == "builtins.bytes"
+ assert isinstance(actual.payload["response"], str)
+ assert actual.payload["response"] == base64_encoded_content
+
+ def test_serialize(self):
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_airflow_connection,
+ ):
+ url = "https://graph.microsoft.com/v1.0/me/drive/items"
+ trigger = MSGraphTrigger(url, response_type="bytes", conn_id="msgraph_api")
+
+ actual = trigger.serialize()
+
+ assert isinstance(actual, tuple)
+ assert actual[0] == "airflow.providers.microsoft.azure.triggers.msgraph.MSGraphTrigger"
+ assert actual[1] == {
+ "url": "https://graph.microsoft.com/v1.0/me/drive/items",
+ "path_parameters": None,
+ "url_template": None,
+ "method": "GET",
+ "query_parameters": None,
+ "headers": None,
+ "data": None,
+ "response_type": "bytes",
+ "conn_id": "msgraph_api",
+ "timeout": None,
+ "proxies": None,
+ "api_version": "v1.0",
+ "serializer": "airflow.providers.microsoft.azure.triggers.msgraph.ResponseSerializer",
+ }
+
+ def test_template_fields(self):
+ trigger = MSGraphTrigger("users/delta", response_type="bytes", conn_id="msgraph_api")
+
+ for template_field in MSGraphTrigger.template_fields:
+ getattr(trigger, template_field)
+
+
+class TestResponseHandler:
+ def test_handle_response_async(self):
+ users = load_json("resources", "users.json")
+ response = mock_json_response(200, users)
+
+ actual = asyncio.run(
+ CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async(
+ response, None
+ )
+ )
+
+ assert isinstance(actual, dict)
+ assert actual == users
+
+
+class TestResponseSerializer:
+ def test_serialize_when_bytes_then_base64_encoded(self):
+ response = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+ content = b64encode(response).decode(locale.getpreferredencoding())
+
+ actual = ResponseSerializer().serialize(response)
+
+ assert isinstance(actual, str)
+ assert actual == content
+
+ def test_serialize_when_dict_with_uuid_datetime_and_pendulum_then_json(self):
+ id = uuid4()
+ response = {
+ "id": id,
+ "creationDate": datetime(2024, 2, 5),
+ "modificationTime": pendulum.datetime(2024, 2, 5),
+ }
+
+ actual = ResponseSerializer().serialize(response)
+
+ assert isinstance(actual, str)
+ assert (
+ actual
+ == f'{{"id": "{id}", "creationDate": "2024-02-05T00:00:00", "modificationTime": "2024-02-05T00:00:00+00:00"}}'
+ )
+
+ def test_deserialize_when_json(self):
+ response = load_file("resources", "users.json")
+
+ actual = ResponseSerializer().deserialize(response)
+
+ assert isinstance(actual, dict)
+ assert actual == load_json("resources", "users.json")
+
+ def test_deserialize_when_base64_encoded_string(self):
+ content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+ response = b64encode(content).decode(locale.getpreferredencoding())
+
+ actual = ResponseSerializer().deserialize(response)
+
+ assert actual == response
+ assert b64decode(actual) == content
diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py
index bcf5aa65fe6eb..78d8748a89e04 100644
--- a/tests/providers/microsoft/conftest.py
+++ b/tests/providers/microsoft/conftest.py
@@ -17,14 +17,22 @@
from __future__ import annotations
+import json
import random
import string
-from typing import TypeVar
+from os.path import dirname, join
+from typing import TYPE_CHECKING, Any, Iterable, TypeVar
+from unittest.mock import MagicMock
import pytest
+from httpx import Response
+from msgraph_core import APIVersion
from airflow.models import Connection
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
T = TypeVar("T", dict, str, Connection)
@@ -68,3 +76,100 @@ def wrapper(*conns: T):
def mocked_connection(request, create_mock_connection):
"""Helper indirect fixture for create test connection."""
return create_mock_connection(request.param)
+
+
+def mock_connection(schema: str | None = None, host: str | None = None) -> Connection:
+ connection = MagicMock(spec=Connection)
+ connection.schema = schema
+ connection.host = host
+ return connection
+
+
+def mock_json_response(status_code, *contents) -> Response:
+ response = MagicMock(spec=Response)
+ response.status_code = status_code
+ if contents:
+ contents = list(contents)
+ response.json.side_effect = lambda: contents.pop(0)
+ else:
+ response.json.return_value = None
+ return response
+
+
+def mock_response(status_code, content: Any = None) -> Response:
+ response = MagicMock(spec=Response)
+ response.status_code = status_code
+ response.content = content
+ return response
+
+
+def mock_context(task):
+ from datetime import datetime
+
+ from airflow.models import TaskInstance
+ from airflow.utils.session import NEW_SESSION
+ from airflow.utils.state import TaskInstanceState
+ from airflow.utils.xcom import XCOM_RETURN_KEY
+
+ class MockedTaskInstance(TaskInstance):
+ def __init__(self):
+ super().__init__(task=task, run_id="run_id", state=TaskInstanceState.RUNNING)
+ self.values = {}
+
+ def xcom_pull(
+ self,
+ task_ids: Iterable[str] | str | None = None,
+ dag_id: str | None = None,
+ key: str = XCOM_RETURN_KEY,
+ include_prior_dates: bool = False,
+ session: Session = NEW_SESSION,
+ *,
+ map_indexes: Iterable[int] | int | None = None,
+ default: Any | None = None,
+ ) -> Any:
+ self.task_id = task_ids
+ self.dag_id = dag_id
+ return self.values.get(f"{task_ids}_{dag_id}_{key}")
+
+ def xcom_push(
+ self,
+ key: str,
+ value: Any,
+ execution_date: datetime | None = None,
+ session: Session = NEW_SESSION,
+ ) -> None:
+ self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value
+
+ return {"ti": MockedTaskInstance()}
+
+
+def load_json(*locations: Iterable[str]):
+ with open(join(dirname(__file__), "azure", join(*locations)), encoding="utf-8") as file:
+ return json.load(file)
+
+
+def load_file(*locations: Iterable[str], mode="r", encoding="utf-8"):
+ with open(join(dirname(__file__), "azure", join(*locations)), mode=mode, encoding=encoding) as file:
+ return file.read()
+
+
+def get_airflow_connection(
+ conn_id: str,
+ login: str = "client_id",
+ password: str = "client_secret",
+ tenant_id: str = "tenant-id",
+ proxies: dict | None = None,
+ api_version: APIVersion = APIVersion.v1,
+):
+ from airflow.models import Connection
+
+ return Connection(
+ schema="https",
+ conn_id=conn_id,
+ conn_type="http",
+ host="graph.microsoft.com",
+ port="80",
+ login=login,
+ password=password,
+ extra={"tenant_id": tenant_id, "api_version": api_version.value, "proxies": proxies or {}},
+ )
diff --git a/tests/system/providers/microsoft/azure/example_msgraph.py b/tests/system/providers/microsoft/azure/example_msgraph.py
new file mode 100644
index 0000000000000..5ff7ba6f88835
--- /dev/null
+++ b/tests/system/providers/microsoft/azure/example_msgraph.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator
+
+DAG_ID = "example_sharepoint_site"
+
+with models.DAG(
+ DAG_ID,
+ start_date=datetime(2021, 1, 1),
+ schedule=None,
+ tags=["example"],
+) as dag:
+ # [START howto_operator_graph_site]
+ site_task = MSGraphAsyncOperator(
+ task_id="news_site",
+ conn_id="msgraph_api",
+ url="sites/850v1v.sharepoint.com:/sites/news",
+ result_processor=lambda context, response: response["id"].split(",")[1], # only keep site_id
+ )
+ # [END howto_operator_graph_site]
+
+ # [START howto_operator_graph_site_pages]
+ site_pages_task = MSGraphAsyncOperator(
+ task_id="news_pages",
+ conn_id="msgraph_api",
+ api_version="beta",
+ url=("sites/%s/pages" % "{{ ti.xcom_pull(task_ids='news_site') }}"), # noqa: UP031
+ )
+ # [END howto_operator_graph_site_pages]
+
+ site_task >> site_pages_task
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/microsoft/azure/example_powerbi.py b/tests/system/providers/microsoft/azure/example_powerbi.py
new file mode 100644
index 0000000000000..cbee9a62af0c4
--- /dev/null
+++ b/tests/system/providers/microsoft/azure/example_powerbi.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator
+from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor
+
+DAG_ID = "example_powerbi"
+
+with models.DAG(
+ DAG_ID,
+ start_date=datetime(2021, 1, 1),
+ schedule=None,
+ tags=["example"],
+) as dag:
+ # [START howto_operator_powerbi_workspaces]
+ workspaces_task = MSGraphAsyncOperator(
+ task_id="workspaces",
+ conn_id="powerbi",
+ url="myorg/admin/workspaces/modified",
+ result_processor=lambda context, response: list(map(lambda workspace: workspace["id"], response)),
+ )
+ # [END howto_operator_powerbi_workspaces]
+
+ # [START howto_operator_powerbi_workspaces_info]
+ workspaces_info_task = MSGraphAsyncOperator(
+ task_id="get_workspace_info",
+ conn_id="powerbi",
+ url="myorg/admin/workspaces/getInfo",
+ method="POST",
+ query_parameters={
+ "lineage": True,
+ "datasourceDetails": True,
+ "datasetSchema": True,
+ "datasetExpressions": True,
+ "getArtifactUsers": True,
+ },
+ data={"workspaces": workspaces_task.output},
+ result_processor=lambda context, response: {"scanId": response["id"]},
+ )
+ # [END howto_operator_powerbi_workspaces_info]
+
+ # [START howto_sensor_powerbi_scan_status]
+ check_workspace_status_task = MSGraphSensor.partial(
+ task_id="check_workspaces_status",
+ conn_id="powerbi",
+ url="myorg/admin/workspaces/scanStatus/{scanId}",
+ timeout=350.0,
+ ).expand(path_parameters=workspaces_info_task.output)
+ # [END howto_sensor_powerbi_scan_status]
+
+ workspaces_task >> workspaces_info_task >> check_workspace_status_task
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)