From eed8c8d0b8f09f1ae744a7a64c88775040a15a71 Mon Sep 17 00:00:00 2001 From: Alex Ott Date: Tue, 8 Mar 2022 11:20:00 +0100 Subject: [PATCH 1/2] Add new options to DatabricksCopyIntoOperator This includes: * `encryption` - to specify encryption options for a given location * `credential` - to specify authentication options for a given location * `validate` - to control validation of schema & data also, make `files` templatized --- .../databricks/operators/databricks_sql.py | 67 ++++++++++++--- .../operators/copy_into.rst | 6 ++ .../operators/test_databricks_sql.py | 86 +++++++++++++++++++ 3 files changed, 145 insertions(+), 14 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py index 33652f21ce1ea..2ae1a2b1a8d4a 100644 --- a/airflow/providers/databricks/operators/databricks_sql.py +++ b/airflow/providers/databricks/operators/databricks_sql.py @@ -48,18 +48,18 @@ class DatabricksSqlOperator(BaseOperator): :param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified, ``http_path`` must be provided as described above. :param sql: the SQL code to be executed as a single string, or - a list of str (sql statements), or a reference to a template file. + a list of str (sql statements), or a reference to a template file. (templated) Template references are recognized by str ending in '.sql' :param parameters: (optional) the parameters to render the SQL query with. :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None. If not specified, it could be specified in the Databricks connection's extra parameters. - :param output_path: optional string specifying the file to which write selected data. + :param output_path: optional string specifying the file to which write selected data. (templated) :param output_format: format of output data if ``output_path` is specified. Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``. :param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data. """ - template_fields: Sequence[str] = ('sql',) + template_fields: Sequence[str] = ('sql', '_output_path') template_ext: Sequence[str] = ('.sql',) template_fields_renderers = {'sql': 'sql'} @@ -152,8 +152,8 @@ class DatabricksCopyIntoOperator(BaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:DatabricksSqlCopyIntoOperator` - :param table_name: Required name of the table. - :param file_location: Required location of files to import. + :param table_name: Required name of the table. (templated) + :param file_location: Required location of files to import. (templated) :param file_format: Required file format. Supported formats are ``CSV``, ``JSON``, ``AVRO``, ``ORC``, ``PARQUET``, ``TEXT``, ``BINARYFILE``. :param databricks_conn_id: Reference to @@ -163,18 +163,23 @@ class DatabricksCopyIntoOperator(BaseOperator): or ``sql_endpoint_name`` must be specified. :param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified, ``http_path`` must be provided as described above. - :param files: optional list of files to import. Can't be specified together with ``pattern``. + :param files: optional list of files to import. Can't be specified together with ``pattern``. (templated) :param pattern: optional regex string to match file names to import. Can't be specified together with ``files``. 
:param expression_list: optional string that will be used in the ``SELECT`` expression. + :param credential: optional credential configuration for authentication against a specified location. + :param encryption: optional encryption configuration for a specified location. :param format_options: optional dictionary with options specific for a given file format. :param force_copy: optional bool to control forcing of data import (could be also specified in ``copy_options``). + :param validate: optional configuration for schema & data validation. ``True`` forces validation + of all rows, integer number - validate only N first rows :param copy_options: optional dictionary of copy options. Right now only ``force`` option is supported. """ template_fields: Sequence[str] = ( '_file_location', + '_files', '_table_name', ) @@ -191,9 +196,12 @@ def __init__( files: Optional[List[str]] = None, pattern: Optional[str] = None, expression_list: Optional[str] = None, + credential: Optional[Dict[str, str]] = None, + encryption: Optional[Dict[str, str]] = None, format_options: Optional[Dict[str, str]] = None, force_copy: Optional[bool] = None, copy_options: Optional[Dict[str, str]] = None, + validate: Optional[Union[bool, int]] = None, **kwargs, ) -> None: """Creates a new ``DatabricksSqlOperator``.""" @@ -216,8 +224,11 @@ def __init__( self._table_name = table_name self._file_location = file_location self._expression_list = expression_list + self._credential = credential + self._encryption = encryption self._format_options = format_options self._copy_options = copy_options or {} + self._validate = validate if force_copy is not None: self._copy_options["force"] = 'true' if force_copy else 'false' @@ -231,18 +242,33 @@ def _get_hook(self) -> DatabricksSqlHook: @staticmethod def _generate_options( - name: str, escaper: ParamEscaper, opts: Optional[Dict[str, str]] = None - ) -> Optional[str]: + name: str, + escaper: ParamEscaper, + opts: Optional[Dict[str, str]] = None, + escape_key: bool = True, + ) -> str: formatted_opts = "" if opts is not None and len(opts) > 0: - pairs = [f"{escaper.escape_item(k)} = {escaper.escape_item(v)}" for k, v in opts.items()] - formatted_opts = f"{name} ({', '.join(pairs)})\n" + pairs = [ + f"{escaper.escape_item(k) if escape_key else k} = {escaper.escape_item(v)}" + for k, v in opts.items() + ] + formatted_opts = f"{name} ({', '.join(pairs)})" return formatted_opts def _create_sql_query(self) -> str: escaper = ParamEscaper() - location = escaper.escape_item(self._file_location) + maybe_with = "" + if self._encryption is not None or self._credential is not None: + maybe_encryption = "" + if self._encryption is not None: + maybe_encryption = self._generate_options("ENCRYPTION", escaper, self._encryption, False) + maybe_credential = "" + if self._credential is not None: + maybe_credential = self._generate_options("CREDENTIAL", escaper, self._credential, False) + maybe_with = f" WITH ({maybe_credential} {maybe_encryption})" + location = escaper.escape_item(self._file_location) + maybe_with if self._expression_list is not None: location = f"(SELECT {self._expression_list} FROM {location})" files_or_pattern = "" @@ -250,13 +276,26 @@ def _create_sql_query(self) -> str: files_or_pattern = f"PATTERN = {escaper.escape_item(self._pattern)}\n" elif self._files is not None: files_or_pattern = f"FILES = {escaper.escape_item(self._files)}\n" - format_options = self._generate_options("FORMAT_OPTIONS", escaper, self._format_options) - copy_options = self._generate_options("COPY_OPTIONS", 
escaper, self._copy_options) + format_options = self._generate_options("FORMAT_OPTIONS", escaper, self._format_options) + "\n" + copy_options = self._generate_options("COPY_OPTIONS", escaper, self._copy_options) + "\n" + validation = "" + if self._validate is not None: + if isinstance(self._validate, bool): + if self._validate: + validation = "VALIDATE ALL\n" + elif isinstance(self._validate, int): + if self._validate < 0: + raise AirflowException( + "Number of rows for validation should be positive, got: " + str(self._validate) + ) + validation = f"VALIDATE {self._validate} ROWS\n" + else: + raise AirflowException("Incorrect data type for validate parameter: " + type(self._validate)) # TODO: think on how to make sure that table_name and expression_list aren't used for SQL injection sql = f"""COPY INTO {self._table_name} FROM {location} FILEFORMAT = {self._file_format} -{files_or_pattern}{format_options}{copy_options} +{validation}{files_or_pattern}{format_options}{copy_options} """ return sql.strip() diff --git a/docs/apache-airflow-providers-databricks/operators/copy_into.rst b/docs/apache-airflow-providers-databricks/operators/copy_into.rst index 58c97be7852a3..c2db25992fb8b 100644 --- a/docs/apache-airflow-providers-databricks/operators/copy_into.rst +++ b/docs/apache-airflow-providers-databricks/operators/copy_into.rst @@ -55,12 +55,18 @@ Operator loads data from a specified location into a table using a configured en - optional regex string to match file names to import. Can't be specified together with ``files``. * - expression_list: Optional[str] - optional string that will be used in the ``SELECT`` expression. + * - credential: Optional[Dict[str, str]] + - optional credential configuration for authentication against a specified location + * - encryption: Optional[Dict[str, str]] + - optional encryption configuration for a specified location * - format_options: Optional[Dict[str, str]] - optional dictionary with options specific for a given file format. * - force_copy: Optional[bool] - optional bool to control forcing of data import (could be also specified in ``copy_options``). * - copy_options: Optional[Dict[str, str]] - optional dictionary of copy options. Right now only ``force`` option is supported. + * - validate: Optional[Union[bool, int]] + - optional validation configuration. ``True`` forces validation of all rows, positive number - only N first rows. 
(requires Preview channel) Examples -------- diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py index bebf50bf97481..6b9fb43701ae3 100644 --- a/tests/providers/databricks/operators/test_databricks_sql.py +++ b/tests/providers/databricks/operators/test_databricks_sql.py @@ -123,6 +123,92 @@ def test_copy_with_expression(self): PATTERN = 'folder1/file_[a-g].csv' FORMAT_OPTIONS ('header' = 'true') COPY_OPTIONS ('force' = 'true') +""".strip() + ) + + def test_copy_with_credential(self): + expression = "col1, col2" + op = DatabricksCopyIntoOperator( + file_location=COPY_FILE_LOCATION, + file_format='CSV', + table_name='test', + task_id=TASK_ID, + expression_list=expression, + credential={'AZURE_SAS_TOKEN': 'abc'}, + ) + assert ( + op._create_sql_query() + == f"""COPY INTO test +FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') )) +FILEFORMAT = CSV +""".strip() + ) + + def test_copy_with_encryption(self): + op = DatabricksCopyIntoOperator( + file_location=COPY_FILE_LOCATION, + file_format='CSV', + table_name='test', + task_id=TASK_ID, + encryption={'TYPE': 'AWS_SSE_C', 'MASTER_KEY': 'abc'}, + ) + assert ( + op._create_sql_query() + == f"""COPY INTO test +FROM '{COPY_FILE_LOCATION}' WITH ( ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc')) +FILEFORMAT = CSV +""".strip() + ) + + def test_copy_with_encryption_and_credential(self): + op = DatabricksCopyIntoOperator( + file_location=COPY_FILE_LOCATION, + file_format='CSV', + table_name='test', + task_id=TASK_ID, + encryption={'TYPE': 'AWS_SSE_C', 'MASTER_KEY': 'abc'}, + credential={'AZURE_SAS_TOKEN': 'abc'}, + ) + assert ( + op._create_sql_query() + == f"""COPY INTO test +FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') """ + """ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc')) +FILEFORMAT = CSV +""".strip() + ) + + def test_copy_with_validate_all(self): + op = DatabricksCopyIntoOperator( + file_location=COPY_FILE_LOCATION, + file_format='JSON', + table_name='test', + task_id=TASK_ID, + validate=True, + ) + assert ( + op._create_sql_query() + == f"""COPY INTO test +FROM '{COPY_FILE_LOCATION}' +FILEFORMAT = JSON +VALIDATE ALL +""".strip() + ) + + def test_copy_with_validate_N_rows(self): + op = DatabricksCopyIntoOperator( + file_location=COPY_FILE_LOCATION, + file_format='JSON', + table_name='test', + task_id=TASK_ID, + validate=10, + ) + assert ( + op._create_sql_query() + == f"""COPY INTO test +FROM '{COPY_FILE_LOCATION}' +FILEFORMAT = JSON +VALIDATE 10 ROWS """.strip() ) From c3a49f927a887621c2e06e358c90760a45f956ff Mon Sep 17 00:00:00 2001 From: Alex Ott Date: Tue, 8 Mar 2022 13:12:44 +0100 Subject: [PATCH 2/2] Add a link to documentation --- airflow/providers/databricks/operators/databricks_sql.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py index 2ae1a2b1a8d4a..48a6713517b46 100644 --- a/airflow/providers/databricks/operators/databricks_sql.py +++ b/airflow/providers/databricks/operators/databricks_sql.py @@ -146,7 +146,9 @@ def execute(self, context: 'Context') -> Any: class DatabricksCopyIntoOperator(BaseOperator): """ - Executes COPY INTO command in a Databricks SQL endpoint or a Databricks cluster + Executes COPY INTO command in a Databricks SQL endpoint or a Databricks cluster. 
+    COPY INTO command is constructed from individual pieces that are described in the
+    `documentation <https://docs.databricks.com/sql/language-manual/delta-copy-into.html>`_.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
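
For illustration, a minimal usage sketch of the options this patch introduces. It is not part of the patch itself; the connection id, endpoint name, table, file location, and token below are placeholders, and the credential/validate values mirror the new unit tests above.

    from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator

    # Load CSV files into a table through a Databricks SQL endpoint, passing a
    # per-location credential and validating the first 10 rows before the data
    # is committed (validate=True would emit VALIDATE ALL instead).
    copy_task = DatabricksCopyIntoOperator(
        task_id="copy_data_into_table",
        databricks_conn_id="databricks_default",        # placeholder connection id
        sql_endpoint_name="my_endpoint",                # placeholder SQL endpoint name
        table_name="my_schema.my_table",                # placeholder target table
        file_location="abfss://container@account.dfs.core.windows.net/incoming/",  # placeholder
        files=["file1.csv", "file2.csv"],               # templated as of this patch
        file_format="CSV",
        format_options={"header": "true"},
        credential={"AZURE_SAS_TOKEN": "<sas-token>"},  # same shape as in the new tests
        validate=10,
        force_copy=True,
    )

With these arguments the operator would render a COPY INTO statement containing a ``WITH (CREDENTIAL (...))`` clause on the file location and a ``VALIDATE 10 ROWS`` line, matching the behaviour exercised by the new test cases.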