71 changes: 56 additions & 15 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -48,18 +48,18 @@ class DatabricksSqlOperator(BaseOperator):
:param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified, ``http_path`` must
be provided as described above.
:param sql: the SQL code to be executed as a single string, or
- a list of str (sql statements), or a reference to a template file.
+ a list of str (sql statements), or a reference to a template file. (templated)
Template references are recognized by str ending in '.sql'
:param parameters: (optional) the parameters to render the SQL query with.
:param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
If not specified, it could be specified in the Databricks connection's extra parameters.
- :param output_path: optional string specifying the file to which write selected data.
+ :param output_path: optional string specifying the file to which write selected data. (templated)
:param output_format: format of output data if ``output_path`` is specified.
Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``.
:param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
"""

- template_fields: Sequence[str] = ('sql',)
+ template_fields: Sequence[str] = ('sql', '_output_path')
template_ext: Sequence[str] = ('.sql',)
template_fields_renderers = {'sql': 'sql'}
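Since ``_output_path`` is now part of ``template_fields``, the output file name can embed Jinja macros. A minimal usage sketch, not part of this PR — the connection id, endpoint name, and paths are illustrative placeholders:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

with DAG(
    dag_id="example_databricks_sql",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    select_into_file = DatabricksSqlOperator(
        task_id="select_data_into_file",
        databricks_conn_id="databricks_default",  # assumed connection id
        sql_endpoint_name="test-endpoint",  # assumed endpoint name
        sql="SELECT * FROM default.my_airflow_table",
        output_path="/tmp/exports/{{ ds }}.csv",  # now rendered per run via templating
        output_format="csv",
    )
```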

@@ -146,14 +146,16 @@ def execute(self, context: 'Context') -> Any:

class DatabricksCopyIntoOperator(BaseOperator):
"""
- Executes COPY INTO command in a Databricks SQL endpoint or a Databricks cluster
+ Executes COPY INTO command in a Databricks SQL endpoint or a Databricks cluster.
The COPY INTO command is constructed from individual pieces that are described in the
`documentation <https://docs.databricks.com/sql/language-manual/delta-copy-into.html>`_.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:DatabricksSqlCopyIntoOperator`

- :param table_name: Required name of the table.
- :param file_location: Required location of files to import.
+ :param table_name: Required name of the table. (templated)
+ :param file_location: Required location of files to import. (templated)
:param file_format: Required file format. Supported formats are
``CSV``, ``JSON``, ``AVRO``, ``ORC``, ``PARQUET``, ``TEXT``, ``BINARYFILE``.
:param databricks_conn_id: Reference to
@@ -163,18 +165,23 @@ class DatabricksCopyIntoOperator(BaseOperator):
or ``sql_endpoint_name`` must be specified.
:param sql_endpoint_name: Optional name of Databricks SQL Endpoint.
If not specified, ``http_path`` must be provided as described above.
- :param files: optional list of files to import. Can't be specified together with ``pattern``.
+ :param files: optional list of files to import. Can't be specified together with ``pattern``. (templated)
:param pattern: optional regex string to match file names to import.
Can't be specified together with ``files``.
:param expression_list: optional string that will be used in the ``SELECT`` expression.
:param credential: optional credential configuration for authentication against a specified location.
:param encryption: optional encryption configuration for a specified location.
:param format_options: optional dictionary with options specific for a given file format.
:param force_copy: optional bool to control forcing of data import
(can also be specified in ``copy_options``).
:param validate: optional configuration for schema & data validation. ``True`` forces validation
of all rows; an integer value validates only the first N rows.
:param copy_options: optional dictionary of copy options. Currently only the ``force`` option is supported.
"""

template_fields: Sequence[str] = (
'_file_location',
'_files',
'_table_name',
)
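For orientation, a hedged usage sketch exercising the new parameters; every name, URI, and secret below is a placeholder, not a value from this PR:

```python
import_data = DatabricksCopyIntoOperator(
    task_id="import_csv_files",
    table_name="my_table",  # templated
    file_location="abfss://container@account.dfs.core.windows.net/data/",  # templated
    file_format="CSV",
    files=["file1.csv", "file2.csv"],  # templated; mutually exclusive with `pattern`
    credential={"AZURE_SAS_TOKEN": "<sas-token>"},  # new: per-location credential
    format_options={"header": "true"},
    force_copy=True,  # shorthand for copy_options={"force": "true"}
    validate=10,  # new: validate only the first 10 rows
)
```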

@@ -191,9 +198,12 @@ def __init__(
        files: Optional[List[str]] = None,
        pattern: Optional[str] = None,
        expression_list: Optional[str] = None,
+       credential: Optional[Dict[str, str]] = None,
+       encryption: Optional[Dict[str, str]] = None,
        format_options: Optional[Dict[str, str]] = None,
        force_copy: Optional[bool] = None,
        copy_options: Optional[Dict[str, str]] = None,
+       validate: Optional[Union[bool, int]] = None,
        **kwargs,
    ) -> None:
        """Creates a new ``DatabricksCopyIntoOperator``."""
@@ -216,8 +226,11 @@ def __init__(
        self._table_name = table_name
        self._file_location = file_location
        self._expression_list = expression_list
+       self._credential = credential
+       self._encryption = encryption
        self._format_options = format_options
        self._copy_options = copy_options or {}
+       self._validate = validate
        if force_copy is not None:
            self._copy_options["force"] = 'true' if force_copy else 'false'

@@ -231,32 +244,60 @@ def _get_hook(self) -> DatabricksSqlHook:

    @staticmethod
    def _generate_options(
-       name: str, escaper: ParamEscaper, opts: Optional[Dict[str, str]] = None
-   ) -> Optional[str]:
+       name: str,
+       escaper: ParamEscaper,
+       opts: Optional[Dict[str, str]] = None,
+       escape_key: bool = True,
+   ) -> str:
        formatted_opts = ""
        if opts is not None and len(opts) > 0:
-           pairs = [f"{escaper.escape_item(k)} = {escaper.escape_item(v)}" for k, v in opts.items()]
-           formatted_opts = f"{name} ({', '.join(pairs)})\n"
+           pairs = [
+               f"{escaper.escape_item(k) if escape_key else k} = {escaper.escape_item(v)}"
+               for k, v in opts.items()
+           ]
+           formatted_opts = f"{name} ({', '.join(pairs)})"

        return formatted_opts
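To make the helper's behavior concrete, a rough illustration of what it renders (``ParamEscaper`` is the escaper this module already uses, shipped with the ``databricks-sql-connector`` package):

```python
escaper = ParamEscaper()

# Keys are escaped by default, producing quoted option names:
DatabricksCopyIntoOperator._generate_options("FORMAT_OPTIONS", escaper, {"header": "true"})
# -> "FORMAT_OPTIONS ('header' = 'true')"

# escape_key=False leaves keys bare, as the CREDENTIAL/ENCRYPTION clauses require:
DatabricksCopyIntoOperator._generate_options("ENCRYPTION", escaper, {"TYPE": "AWS_SSE_C"}, escape_key=False)
# -> "ENCRYPTION (TYPE = 'AWS_SSE_C')"

# An empty or missing dict renders as an empty string:
DatabricksCopyIntoOperator._generate_options("COPY_OPTIONS", escaper, None)
# -> ""
```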

    def _create_sql_query(self) -> str:
        escaper = ParamEscaper()
-       location = escaper.escape_item(self._file_location)
+       maybe_with = ""
+       if self._encryption is not None or self._credential is not None:
+           maybe_encryption = ""
+           if self._encryption is not None:
+               maybe_encryption = self._generate_options("ENCRYPTION", escaper, self._encryption, False)
+           maybe_credential = ""
+           if self._credential is not None:
+               maybe_credential = self._generate_options("CREDENTIAL", escaper, self._credential, False)
+           maybe_with = f" WITH ({maybe_credential} {maybe_encryption})"
+       location = escaper.escape_item(self._file_location) + maybe_with
        if self._expression_list is not None:
            location = f"(SELECT {self._expression_list} FROM {location})"
        files_or_pattern = ""
        if self._pattern is not None:
            files_or_pattern = f"PATTERN = {escaper.escape_item(self._pattern)}\n"
        elif self._files is not None:
            files_or_pattern = f"FILES = {escaper.escape_item(self._files)}\n"
-       format_options = self._generate_options("FORMAT_OPTIONS", escaper, self._format_options)
-       copy_options = self._generate_options("COPY_OPTIONS", escaper, self._copy_options)
+       format_options = self._generate_options("FORMAT_OPTIONS", escaper, self._format_options) + "\n"
+       copy_options = self._generate_options("COPY_OPTIONS", escaper, self._copy_options) + "\n"
+       validation = ""
+       if self._validate is not None:
+           if isinstance(self._validate, bool):
+               if self._validate:
+                   validation = "VALIDATE ALL\n"
+           elif isinstance(self._validate, int):
+               if self._validate < 0:
+                   raise AirflowException(
+                       "Number of rows for validation should be positive, got: " + str(self._validate)
+                   )
+               validation = f"VALIDATE {self._validate} ROWS\n"
+           else:
+               raise AirflowException(f"Incorrect data type for validate parameter: {type(self._validate)}")
        # TODO: think about how to ensure that table_name and expression_list can't be used for SQL injection
        sql = f"""COPY INTO {self._table_name}
FROM {location}
FILEFORMAT = {self._file_format}
- {files_or_pattern}{format_options}{copy_options}
+ {validation}{files_or_pattern}{format_options}{copy_options}
"""
        return sql.strip()
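Putting the pieces together, a configuration like the one below should render roughly the statement shown in the comments (values are illustrative; note that ``VALIDATE`` precedes ``PATTERN``/``FILES`` per the f-string above):

```python
op = DatabricksCopyIntoOperator(
    task_id="copy_into_test",
    table_name="test",
    file_location="s3://my-bucket/data/",
    file_format="CSV",
    pattern="folder1/file_[a-g].csv",
    format_options={"header": "true"},
    force_copy=True,
    validate=5,
)
print(op._create_sql_query())
# COPY INTO test
# FROM 's3://my-bucket/data/'
# FILEFORMAT = CSV
# VALIDATE 5 ROWS
# PATTERN = 'folder1/file_[a-g].csv'
# FORMAT_OPTIONS ('header' = 'true')
# COPY_OPTIONS ('force' = 'true')
```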

@@ -55,12 +55,18 @@ Operator loads data from a specified location into a table using a configured en
- optional regex string to match file names to import. Can't be specified together with ``files``.
* - expression_list: Optional[str]
- optional string that will be used in the ``SELECT`` expression.
* - credential: Optional[Dict[str, str]]
- optional credential configuration for authentication against a specified location.
* - encryption: Optional[Dict[str, str]]
- optional encryption configuration for a specified location.
* - format_options: Optional[Dict[str, str]]
- optional dictionary with options specific for a given file format.
* - force_copy: Optional[bool]
- optional bool to control forcing of data import (can also be specified in ``copy_options``).
* - copy_options: Optional[Dict[str, str]]
- optional dictionary of copy options. Currently only the ``force`` option is supported.
* - validate: Optional[Union[bool, int]]
- optional validation configuration. ``True`` forces validation of all rows; a positive number validates only the first N rows (requires the Preview channel).

Examples
--------
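An illustrative snippet, with placeholder table and location values, importing JSON data with full row validation:

.. code-block:: python

    copy_with_validation = DatabricksCopyIntoOperator(
        task_id="copy_into_table",
        table_name="my_table",
        file_location="s3://my-bucket/data/",
        file_format="JSON",
        encryption={"TYPE": "AWS_SSE_C", "MASTER_KEY": "<master-key>"},
        validate=True,
    )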
86 changes: 86 additions & 0 deletions tests/providers/databricks/operators/test_databricks_sql.py
@@ -123,6 +123,92 @@ def test_copy_with_expression(self):
PATTERN = 'folder1/file_[a-g].csv'
FORMAT_OPTIONS ('header' = 'true')
COPY_OPTIONS ('force' = 'true')
""".strip()
)

def test_copy_with_credential(self):
expression = "col1, col2"
op = DatabricksCopyIntoOperator(
file_location=COPY_FILE_LOCATION,
file_format='CSV',
table_name='test',
task_id=TASK_ID,
expression_list=expression,
credential={'AZURE_SAS_TOKEN': 'abc'},
)
assert (
op._create_sql_query()
== f"""COPY INTO test
FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') ))
FILEFORMAT = CSV
""".strip()
)

def test_copy_with_encryption(self):
op = DatabricksCopyIntoOperator(
file_location=COPY_FILE_LOCATION,
file_format='CSV',
table_name='test',
task_id=TASK_ID,
encryption={'TYPE': 'AWS_SSE_C', 'MASTER_KEY': 'abc'},
)
assert (
op._create_sql_query()
== f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}' WITH ( ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc'))
FILEFORMAT = CSV
""".strip()
)

def test_copy_with_encryption_and_credential(self):
op = DatabricksCopyIntoOperator(
file_location=COPY_FILE_LOCATION,
file_format='CSV',
table_name='test',
task_id=TASK_ID,
encryption={'TYPE': 'AWS_SSE_C', 'MASTER_KEY': 'abc'},
credential={'AZURE_SAS_TOKEN': 'abc'},
)
assert (
op._create_sql_query()
== f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') """
"""ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc'))
FILEFORMAT = CSV
""".strip()
)

def test_copy_with_validate_all(self):
op = DatabricksCopyIntoOperator(
file_location=COPY_FILE_LOCATION,
file_format='JSON',
table_name='test',
task_id=TASK_ID,
validate=True,
)
assert (
op._create_sql_query()
== f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}'
FILEFORMAT = JSON
VALIDATE ALL
""".strip()
)

def test_copy_with_validate_N_rows(self):
op = DatabricksCopyIntoOperator(
file_location=COPY_FILE_LOCATION,
file_format='JSON',
table_name='test',
task_id=TASK_ID,
validate=10,
)
assert (
op._create_sql_query()
== f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}'
FILEFORMAT = JSON
VALIDATE 10 ROWS
""".strip()
)
