diff --git a/airflow/providers/microsoft/azure/operators/adls.py b/airflow/providers/microsoft/azure/operators/adls.py
index 345336b2c4cb2..6afb495077963 100644
--- a/airflow/providers/microsoft/azure/operators/adls.py
+++ b/airflow/providers/microsoft/azure/operators/adls.py
@@ -16,22 +16,74 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import IO, TYPE_CHECKING, Any, AnyStr, Iterable, Sequence
 
 from airflow.models import BaseOperator
-from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook, AzureDataLakeStorageV2Hook
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
+DEFAULT_AZURE_DATA_LAKE_CONN_ID = "azure_data_lake_default"
+
+
+class ADLSCreateObjectOperator(BaseOperator):
+    """
+    Creates a new object from passed data to Azure Data Lake on the specified file.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:ADLSCreateObjectOperator`
+
+    :param file_system_name: Name of the file system or instance of FileSystemProperties.
+    :param file_name: Name of the file which needs to be created in the file system.
+    :param data: The data that will be uploaded.
+    :param length: Size of the data in bytes (optional).
+    :param replace: Whether to forcibly overwrite existing files/directories.
+        If False and remote path is a directory, will quit regardless if any files
+        would be overwritten or not. If True, only matching filenames are actually
+        overwritten.
+    :param azure_data_lake_conn_id: Reference to the :ref:`Azure Data Lake connection`.
+    """
+
+    template_fields: Sequence[str] = ("file_system_name", "file_name", "data")
+    ui_color = "#e4f0e8"
+
+    def __init__(
+        self,
+        *,
+        file_system_name: str,
+        file_name: str,
+        data: bytes | str | Iterable[AnyStr] | IO[AnyStr],
+        length: int | None = None,
+        replace: bool = False,
+        azure_data_lake_conn_id: str = DEFAULT_AZURE_DATA_LAKE_CONN_ID,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.file_system_name = file_system_name
+        self.file_name = file_name
+        self.replace = replace
+        self.data = data  # type: ignore[var-annotated]
+        self.length = length
+        self.azure_data_lake_conn_id = azure_data_lake_conn_id
+
+    def execute(self, context: Context) -> dict[str, Any]:
+        self.log.debug("Uploading %s to %s", self.data, self.file_name)
+        hook = AzureDataLakeStorageV2Hook(adls_conn_id=self.azure_data_lake_conn_id)
+        return hook.create_file(file_system_name=self.file_system_name, file_name=self.file_name).upload_data(
+            data=self.data, length=self.length, overwrite=self.replace
+        )
+
 
 class ADLSDeleteOperator(BaseOperator):
     """
     Delete files in the specified path.
 
-    .. seealso::
-        For more information on how to use this operator, take a look at the guide:
-        :ref:`howto/operator:ADLSDeleteOperator`
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:ADLSDeleteOperator`
 
     :param path: A directory or file to remove
     :param recursive: Whether to loop into directories in the location and remove the files
@@ -48,7 +100,7 @@ def __init__(
         path: str,
         recursive: bool = False,
         ignore_not_found: bool = True,
-        azure_data_lake_conn_id: str = "azure_data_lake_default",
+        azure_data_lake_conn_id: str = DEFAULT_AZURE_DATA_LAKE_CONN_ID,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -69,26 +121,19 @@ class ADLSListOperator(BaseOperator):
     This operator returns a python list with the names of files which can be
     used by `xcom` in the downstream tasks.
 
-    :param path: The Azure Data Lake path to find the objects. Supports glob
-        strings (templated)
-    :param azure_data_lake_conn_id: Reference to the :ref:`Azure Data Lake connection`.
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:ADLSListOperator`
 
-    **Example**:
-        The following Operator would list all the Parquet files from ``folder/output/``
-        folder in the specified ADLS account ::
-
-            adls_files = ADLSListOperator(
-                task_id="adls_files",
-                path="folder/output/*.parquet",
-                azure_data_lake_conn_id="azure_data_lake_default",
-            )
+    :param path: The Azure Data Lake path to find the objects. Supports glob strings (templated)
+    :param azure_data_lake_conn_id: Reference to the :ref:`Azure Data Lake connection`.
     """
 
     template_fields: Sequence[str] = ("path",)
     ui_color = "#901dd2"
 
     def __init__(
-        self, *, path: str, azure_data_lake_conn_id: str = "azure_data_lake_default", **kwargs
+        self, *, path: str, azure_data_lake_conn_id: str = DEFAULT_AZURE_DATA_LAKE_CONN_ID, **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.path = path
diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/adls.rst b/docs/apache-airflow-providers-microsoft-azure/operators/adls.rst
index 54578991458d0..24b6a3ac5603e 100644
--- a/docs/apache-airflow-providers-microsoft-azure/operators/adls.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/operators/adls.rst
@@ -24,12 +24,29 @@ Prerequisite Tasks
 
 .. include:: /operators/_partials/prerequisite_tasks.rst
 
+.. _howto/operator:ADLSCreateObjectOperator:
+
+ADLSCreateObjectOperator
+----------------------------------
+
+:class:`~airflow.providers.microsoft.azure.operators.adls.ADLSCreateObjectOperator` allows you to
+upload data to Azure DataLake Storage
+
+
+Below is an example of using this operator to upload data to ADL.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_adls_create.py
+    :language: python
+    :dedent: 0
+    :start-after: [START howto_operator_adls_create]
+    :end-before: [END howto_operator_adls_create]
+
 .. _howto/operator:ADLSDeleteOperator:
 
 ADLSDeleteOperator
 ----------------------------------
 Use the
-:class:`~airflow.providers.microsoft.azure.operators.adls_delete.ADLSDeleteOperator` to remove
+:class:`~airflow.providers.microsoft.azure.operators.adls.ADLSDeleteOperator` to remove
 file(s) from Azure DataLake Storage
 
 
@@ -41,6 +58,23 @@ Below is an example of using this operator to delete a file from ADL.
     :start-after: [START howto_operator_adls_delete]
     :end-before: [END howto_operator_adls_delete]
 
+.. _howto/operator:ADLSListOperator:
+
+ADLSListOperator
+----------------------------------
+Use the
+:class:`~airflow.providers.microsoft.azure.operators.adls.ADLSListOperator` to list all
+file(s) from Azure DataLake Storage
+
+
+Below is an example of using this operator to list files from ADL.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_adls_list.py
+    :language: python
+    :dedent: 0
+    :start-after: [START howto_operator_adls_list]
+    :end-before: [END howto_operator_adls_list]
+
 Reference
 ---------
 
diff --git a/tests/providers/microsoft/azure/operators/test_adls_create.py b/tests/providers/microsoft/azure/operators/test_adls_create.py
new file mode 100644
index 0000000000000..90d14b229fc51
--- /dev/null
+++ b/tests/providers/microsoft/azure/operators/test_adls_create.py
@@ -0,0 +1,47 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from unittest import mock
+
+from airflow.providers.microsoft.azure.operators.adls import ADLSCreateObjectOperator
+
+TASK_ID = "test-adls-upload-operator"
+FILE_SYSTEM_NAME = "Fabric"
+REMOTE_PATH = "TEST-DIR"
+DATA = json.dumps({"name": "David", "surname": "Blain", "gender": "M"}).encode("utf-8")
+
+
+class TestADLSUploadOperator:
+    @mock.patch("airflow.providers.microsoft.azure.operators.adls.AzureDataLakeStorageV2Hook")
+    def test_execute_success_when_local_data(self, mock_hook):
+        operator = ADLSCreateObjectOperator(
+            task_id=TASK_ID,
+            file_system_name=FILE_SYSTEM_NAME,
+            file_name=REMOTE_PATH,
+            data=DATA,
+            replace=True,
+        )
+        operator.execute(None)
+        data_lake_file_client_mock = mock_hook.return_value.create_file
+        data_lake_file_client_mock.assert_called_once_with(
+            file_system_name=FILE_SYSTEM_NAME, file_name=REMOTE_PATH
+        )
+        upload_data_mock = data_lake_file_client_mock.return_value.upload_data
+        upload_data_mock.assert_called_once_with(data=DATA, length=None, overwrite=True)
diff --git a/tests/providers/microsoft/azure/transfers/test_local_to_adls.py b/tests/providers/microsoft/azure/transfers/test_local_to_adls.py
index 50020b16920f9..1ef388eef7597 100644
--- a/tests/providers/microsoft/azure/transfers/test_local_to_adls.py
+++ b/tests/providers/microsoft/azure/transfers/test_local_to_adls.py
@@ -17,22 +17,17 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from unittest import mock
 
 import pytest
 
 from airflow.exceptions import AirflowException
-from airflow.providers.microsoft.azure.transfers.local_to_adls import (
-    LocalFilesystemToADLSOperator,
-)
+from airflow.providers.microsoft.azure.transfers.local_to_adls import LocalFilesystemToADLSOperator
 
 TASK_ID = "test-adls-upload-operator"
-FILE_SYSTEM_NAME = "Fabric"
 LOCAL_PATH = "test/*"
 BAD_LOCAL_PATH = "test/**"
 REMOTE_PATH = "TEST-DIR"
-DATA = json.dumps({"name": "David", "surname": "Blain", "gender": "M"}).encode("utf-8")
 
 
 class TestADLSUploadOperator:
diff --git a/tests/system/providers/microsoft/azure/example_adls_create.py b/tests/system/providers/microsoft/azure/example_adls_create.py
new file mode 100644
index 0000000000000..726e9eba76ae8
--- /dev/null
+++ b/tests/system/providers/microsoft/azure/example_adls_create.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.microsoft.azure.operators.adls import ADLSCreateObjectOperator, ADLSDeleteOperator
+
+REMOTE_FILE_PATH = os.environ.get("REMOTE_LOCAL_PATH", "remote.txt")
+DAG_ID = "example_adls_create"
+
+with models.DAG(
+    DAG_ID,
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    schedule=None,
+    tags=["example"],
+) as dag:
+    # [START howto_operator_adls_create]
+    upload_data = ADLSCreateObjectOperator(
+        task_id="upload_data",
+        file_system_name="Fabric",
+        file_name=REMOTE_FILE_PATH,
+        data="Hello world",
+        replace=True,
+    )
+    # [END howto_operator_adls_create]
+
+    delete_file = ADLSDeleteOperator(task_id="remove_task", path=REMOTE_FILE_PATH, recursive=True)
+
+    upload_data >> delete_file
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/microsoft/azure/example_adls_list.py b/tests/system/providers/microsoft/azure/example_adls_list.py
new file mode 100644
index 0000000000000..a6bd2d7bd6652
--- /dev/null
+++ b/tests/system/providers/microsoft/azure/example_adls_list.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator
+
+LOCAL_FILE_PATH = os.environ.get("LOCAL_FILE_PATH", "localfile.txt")
+REMOTE_FILE_PATH = os.environ.get("REMOTE_LOCAL_PATH", "remote.txt")
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "example_adls_list"
+
+with models.DAG(
+    DAG_ID,
+    start_date=datetime(2021, 1, 1),
+    schedule=None,
+    tags=["example"],
+) as dag:
+    # [START howto_operator_adls_list]
+    adls_files = ADLSListOperator(
+        task_id="adls_files",
+        path="folder/output/*.parquet",
+        azure_data_lake_conn_id="azure_data_lake_default",
+    )
+    # [END howto_operator_adls_list]
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
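
For context on the new operator, ADLSCreateObjectOperator.execute() is a thin wrapper over the
azure-storage-file-datalake SDK that AzureDataLakeStorageV2Hook uses. The sketch below shows the
roughly equivalent raw-SDK calls, assuming the v12-style client API; the account URL, credential,
file system, and file name are illustrative placeholders, not values taken from this change (the
hook resolves the real values from the azure_data_lake_default connection).

from azure.storage.filedatalake import DataLakeServiceClient

# Placeholder account URL and key; AzureDataLakeStorageV2Hook builds its client from the
# Airflow connection instead of hard-coding credentials like this.
service_client = DataLakeServiceClient(
    account_url="https://<storage-account>.dfs.core.windows.net",
    credential="<storage-account-key>",
)

# hook.create_file(file_system_name=..., file_name=...) corresponds roughly to these two calls
# and returns a DataLakeFileClient.
file_system_client = service_client.get_file_system_client(file_system="Fabric")
file_client = file_system_client.create_file("remote.txt")

# upload_data(...) is the call the operator chains onto the returned file client;
# overwrite maps to the operator's `replace` argument and length to `length`.
result = file_client.upload_data(data="Hello world", length=None, overwrite=True)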