From e1a2beb0b9883f64c8655958895ed82270731fb8 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 21 Sep 2021 11:47:00 -0400 Subject: [PATCH 1/2] Add two new Snowflake operators and update old ones Add SnowflakeThresholdCheckOperator and BranchSnowflakeOperator based on corresponding SQL operators. These operators round out the functionality provided by SQL operators for Snowflake. Additionally, the SnowflakeIntervalCheckOperator parameter list is updated. --- .../snowflake/operators/snowflake.py | 199 +++++++++++++++++- 1 file changed, 196 insertions(+), 3 deletions(-) diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py index 290c323119ab7..0218b50f4c2c4 100644 --- a/airflow/providers/snowflake/operators/snowflake.py +++ b/airflow/providers/snowflake/operators/snowflake.py @@ -15,10 +15,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Optional, SupportsAbs +from typing import Any, Dict, Iterable, List, Mapping, Optional, SupportsAbs, Union from airflow.models import BaseOperator -from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator +from airflow.operators.sql import ( + BranchSQLOperator, + SQLCheckOperator, + SQLIntervalCheckOperator, + SQLThresholdCheckOperator, + SQLValueCheckOperator, +) from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -332,6 +338,19 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator): :param days_back: number of days between ds and the ds we want to check against. Defaults to 7 days :type days_back: int + :param date_filter_column: The column name for the dates to filter on. Defaults to 'ds' + :type date_filter_column: Optional[str] + :param ratio_formula: which formula to use to compute the ratio between + the two metrics. 
Assuming cur is the metric of today and ref is + the metric of today - days_back. + + max_over_min: computes max(cur, ref) / min(cur, ref) + relative_diff: computes abs(cur-ref) / ref + + Default: 'max_over_min' + :type ratio_formula: str + :param ignore_zero: whether we should ignore zero metrics + :type ignore_zero: bool :param metrics_thresholds: a dictionary of ratios indexed by metrics, for example 'COUNT(*)': 1.5 would require a 50 percent or less difference between the current day, and the prior days_back. @@ -373,9 +392,11 @@ def __init__( self, *, table: str, - metrics_thresholds: dict, + metrics_thresholds: Dict[str, int], date_filter_column: str = 'ds', days_back: SupportsAbs[int] = -7, + ratio_formula: Optional[str] = 'max_over_min', + ignore_zero: bool = True, snowflake_conn_id: str = 'snowflake_default', parameters: Optional[dict] = None, autocommit: bool = True, @@ -409,3 +430,175 @@ def __init__( def get_db_hook(self) -> SnowflakeHook: return get_db_hook(self) + + +class SnowflakeThresholdCheckOperator(SQLThresholdCheckOperator): + """ + Performs a value check using sql code against a minimum threshold + and a maximum threshold. Thresholds can be in the form of a numeric + value OR a sql statement that results in a numeric. + + :param sql: the sql to be executed. (templated) + :type sql: str + :param snowflake_conn_id: Reference to + :ref:`Snowflake connection id<howto/connection:snowflake>` + :type snowflake_conn_id: str + :param min_threshold: numerical value or min threshold sql to be executed (templated) + :type min_threshold: numeric or str + :param max_threshold: numerical value or max threshold sql to be executed (templated) + :type max_threshold: numeric or str + :param autocommit: if True, each command is automatically committed.
+ (default value: True) + :type autocommit: bool + :param warehouse: name of warehouse (will overwrite any warehouse + defined in the connection's extra JSON) + :type warehouse: str + :param database: name of database (will overwrite database defined + in connection) + :type database: str + :param schema: name of schema (will overwrite schema defined in + connection) + :type schema: str + :param role: name of role (will overwrite any role defined in + connection's extra JSON) + :type role: str + :param authenticator: authenticator for Snowflake. + 'snowflake' (default) to use the internal Snowflake authenticator + 'externalbrowser' to authenticate using your web browser and + Okta, ADFS or any other SAML 2.0-compliant identity provider + (IdP) that has been defined for your account + 'https://<your_okta_account_name>.okta.com' to authenticate + through native Okta. + :type authenticator: str + :param session_parameters: You can set session-level parameters at + the time you connect to Snowflake + :type session_parameters: dict + """ + + def __init__( + self, + *, + sql: str, + min_threshold: Any, + max_threshold: Any, + snowflake_conn_id: str = 'snowflake_default', + parameters: Optional[dict] = None, + autocommit: bool = True, + do_xcom_push: bool = True, + warehouse: Optional[str] = None, + database: Optional[str] = None, + role: Optional[str] = None, + schema: Optional[str] = None, + authenticator: Optional[str] = None, + session_parameters: Optional[dict] = None, + **kwargs, + ) -> None: + super().__init__( + sql=sql, + min_threshold=min_threshold, + max_threshold=max_threshold, + **kwargs, + ) + + self.snowflake_conn_id = snowflake_conn_id + self.autocommit = autocommit + self.do_xcom_push = do_xcom_push + self.parameters = parameters + self.warehouse = warehouse + self.database = database + self.role = role + self.schema = schema + self.authenticator = authenticator + self.session_parameters = session_parameters + self.query_ids = [] + + def get_db_hook(self) -> SnowflakeHook: +
return get_db_hook(self) + + +class BranchSnowflakeOperator(BranchSQLOperator): + """ + Executes sql code in a specific database + + :param sql: the sql code to be executed. (templated) + :type sql: Can receive a str representing a sql statement or reference to a template file. + Template references are recognized by str ending in '.sql'. + Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1) + or string (true/y/yes/1/on/false/n/no/0/off). + :param follow_task_ids_if_true: task id or task ids to follow if query returns true + :type follow_task_ids_if_true: str or list + :param follow_task_ids_if_false: task id or task ids to follow if query returns false + :type follow_task_ids_if_false: str or list + :param snowflake_conn_id: Reference to + :ref:`Snowflake connection id<howto/connection:snowflake>` + :type snowflake_conn_id: str + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: mapping or iterable + :param autocommit: if True, each command is automatically committed. + (default value: True) + :type autocommit: bool + :param warehouse: name of warehouse (will overwrite any warehouse + defined in the connection's extra JSON) + :type warehouse: str + :param database: name of database (will overwrite database defined + in connection) + :type database: str + :param schema: name of schema (will overwrite schema defined in + connection) + :type schema: str + :param role: name of role (will overwrite any role defined in + connection's extra JSON) + :type role: str + :param authenticator: authenticator for Snowflake. + 'snowflake' (default) to use the internal Snowflake authenticator + 'externalbrowser' to authenticate using your web browser and + Okta, ADFS or any other SAML 2.0-compliant identity provider + (IdP) that has been defined for your account + 'https://<your_okta_account_name>.okta.com' to authenticate + through native Okta.
+ :type authenticator: str + :param session_parameters: You can set session-level parameters at + the time you connect to Snowflake + :type session_parameters: dict + """ + + def __init__( + self, + *, + sql: str, + follow_task_ids_if_true: List[str], + follow_task_ids_if_false: List[str], + snowflake_conn_id: str = 'snowflake_default', + autocommit: bool = True, + do_xcom_push: bool = True, + warehouse: Optional[str] = None, + database: Optional[str] = None, + role: Optional[str] = None, + schema: Optional[str] = None, + authenticator: Optional[str] = None, + session_parameters: Optional[dict] = None, + parameters: Optional[Union[Mapping, Iterable]] = None, + **kwargs, + ) -> None: + super().__init__( + sql=sql, + follow_task_ids_if_true=follow_task_ids_if_true, + follow_task_ids_if_false=follow_task_ids_if_false, + parameters=parameters, + **kwargs, + ) + + self.snowflake_conn_id = snowflake_conn_id + self.autocommit = autocommit + self.do_xcom_push = do_xcom_push + self.parameters = parameters + self.warehouse = warehouse + self.database = database + self.role = role + self.schema = schema + self.authenticator = authenticator + self.session_parameters = session_parameters + self.query_ids = [] + + def get_db_hook(self) -> SnowflakeHook: + return get_db_hook(self) From fe30abd5cc074585ea88ec45aead77c1e50cbee5 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 21 Sep 2021 14:28:45 -0400 Subject: [PATCH 2/2] Add tests for new Snowflake Operators --- .../snowflake/operators/test_snowflake.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py index 8cdcc0936d898..8ee3f093f81be 100644 --- a/tests/providers/snowflake/operators/test_snowflake.py +++ b/tests/providers/snowflake/operators/test_snowflake.py @@ -23,9 +23,11 @@ from airflow.models.dag import DAG from airflow.providers.snowflake.operators.snowflake import ( + 
BranchSnowflakeOperator, SnowflakeCheckOperator, SnowflakeIntervalCheckOperator, SnowflakeOperator, + SnowflakeThresholdCheckOperator, SnowflakeValueCheckOperator, ) from airflow.utils import timezone @@ -57,12 +59,38 @@ def test_snowflake_operator(self, mock_get_db_hook): operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) +@pytest.mark.parametrize( + "operator_class, kwargs", + [ + ( + BranchSnowflakeOperator, + dict(sql="SELECT 1", follow_task_ids_if_true="branch_1", follow_task_ids_if_false="branch_2"), + ) + ], +) +class TestBranchSnowflakeOperator: + @mock.patch("airflow.providers.snowflake.operators.snowflake.get_db_hook") + def test_get_db_hook( + self, + mock_get_db_hook, + operator_class, + kwargs, + ): + operator = operator_class(task_id='branch_snowflake', snowflake_conn_id='snowflake_default', **kwargs) + operator.get_db_hook() + mock_get_db_hook.assert_called_once() + + @pytest.mark.parametrize( "operator_class, kwargs", [ (SnowflakeCheckOperator, dict(sql='Select * from test_table')), (SnowflakeValueCheckOperator, dict(sql='Select * from test_table', pass_value=95)), (SnowflakeIntervalCheckOperator, dict(table='test-table-id', metrics_thresholds={'COUNT(*)': 1.5})), + ( + SnowflakeThresholdCheckOperator, + dict(sql='Select * from test_table', min_threshold=0, max_threshold=10), + ), ], ) class TestSnowflakeCheckOperators: