From de6a4e0e68f5464ae2a921f8501dce2440f52beb Mon Sep 17 00:00:00 2001 From: Shahar Epstein Date: Sat, 26 Aug 2023 16:54:34 +0300 Subject: [PATCH 1/2] Implement validation of missing template fields --- .pre-commit-config.yaml | 25 ++ STATIC_CODE_CHECKS.rst | 2 + .../src/airflow_breeze/pre_commit_ids.py | 1 + docs/apache-airflow/howto/custom-operator.rst | 82 ++++++ .../pre_commit_validate_operators_init.py | 236 ++++++++++++++++++ 5 files changed, 346 insertions(+) create mode 100755 scripts/ci/pre_commit/pre_commit_validate_operators_init.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 676e8b0e4f0a3..cf1f8218cfabd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -293,6 +293,31 @@ repos: # changes quickly - especially when we want the early modifications from the first local group # to be applied before the non-local pre-commits are run hooks: + - id: validate-operators-init + name: Prevent templated field logic checks in operators' __init__ + language: python + entry: ./scripts/ci/pre_commit/pre_commit_validate_operators_init.py + pass_filenames: true + files: ^airflow/providers/.*/(operators|transfers|sensors)/.*\.py$ + additional_dependencies: [ 'rich>=12.4.4' ] + # TODO: Handle the provider-specific exclusions and remove them from the list, see: + # https://github.com/apache/airflow/issues/36484 + exclude: | + (?x)^( + ^.*__init__\.py$| + ^airflow\/providers\/google\/cloud\/operators\/bigquery\.py$| + ^airflow\/providers\/amazon\/aws\/transfers\/gcs_to_s3\.py$| + ^airflow\/providers\/databricks\/operators\/databricks\.py$| + ^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service\.py$| + ^airflow\/providers\/google\/cloud\/transfers\/bigquery_to_mysql\.py$| + ^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/auto_ml\.py$| + ^airflow\/providers\/amazon\/aws\/transfers\/redshift_to_s3\.py$| + ^airflow\/providers\/google\/cloud\/operators\/compute\.py$| + ^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/custom_job\.py$| + ^airflow\/providers\/cncf\/kubernetes\/operators\/pod\.py$| + ^airflow\/providers\/amazon\/aws\/operators\/emr\.py$| + ^airflow\/providers\/amazon\/aws\/operators\/eks\.py$ + )$ - id: ruff name: Run 'ruff' for extremely fast Python linting description: "Run 'ruff' for extremely fast Python linting" diff --git a/STATIC_CODE_CHECKS.rst b/STATIC_CODE_CHECKS.rst index b03b14332120e..dd56039f2da3f 100644 --- a/STATIC_CODE_CHECKS.rst +++ b/STATIC_CODE_CHECKS.rst @@ -414,6 +414,8 @@ require Breeze Docker image to be built locally. +-----------------------------------------------------------+--------------------------------------------------------------+---------+ | update-version | Update version to the latest version in the documentation | | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ +| validate-operators-init | Prevent templated field logic checks in operators' __init__ | | ++-----------------------------------------------------------+--------------------------------------------------------------+---------+ | yamllint | Check YAML files with yamllint | | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ diff --git a/dev/breeze/src/airflow_breeze/pre_commit_ids.py b/dev/breeze/src/airflow_breeze/pre_commit_ids.py index e27068f5c988a..ad6614d4520a4 100644 --- a/dev/breeze/src/airflow_breeze/pre_commit_ids.py +++ b/dev/breeze/src/airflow_breeze/pre_commit_ids.py @@ -133,5 +133,6 @@ "update-supported-versions", "update-vendored-in-k8s-json-schema", "update-version", + "validate-operators-init", "yamllint", ] diff --git a/docs/apache-airflow/howto/custom-operator.rst b/docs/apache-airflow/howto/custom-operator.rst index 012409ec74b84..ce32654a6b690 100644 --- a/docs/apache-airflow/howto/custom-operator.rst +++ b/docs/apache-airflow/howto/custom-operator.rst @@ -269,6 +269,88 @@ Currently available lexers: If you use a non-existing lexer then the value of the template field will be rendered as a pretty-printed object. +Limitations +^^^^^^^^^^^ +To prevent misuse, the following limitations must be observed when defining and assigning templated fields in the +operator's constructor (when such exists, otherwise - see below): + +1. Templated fields' corresponding parameters passed into the constructor must be named exactly +as the fields. The following example is invalid, as the parameter passed into the constructor is not the same as the +templated field: + +.. code-block:: python + + class HelloOperator(BaseOperator): + template_fields = "field_a" + + def __init__(field_a_id) -> None: # <- should be def __init__(field_a)-> None + self.field_a = field_a_id # <- should be self.field_a = field_a + +2. Templated fields' instance members must be assigned with their corresponding parameter from the constructor, +either by a direct assignment or by calling the parent's constructor (in which these fields are +defined as ``template_fields``) with explicit an assignment of the parameter. +The following example is invalid, as the instance member ``self.field_a`` is not assigned at all, despite being a +templated field: + +.. code-block:: python + + class HelloOperator(BaseOperator): + template_fields = ("field_a", "field_b") + + def __init__(field_a, field_b) -> None: + self.field_b = field_b + + +The following example is also invalid, as the instance member ``self.field_a`` of ``MyHelloOperator`` is initialized +implicitly as part of the ``kwargs`` passed to its parent constructor: + +.. code-block:: python + + class HelloOperator(BaseOperator): + template_fields = "field_a" + + def __init__(field_a) -> None: + self.field_a = field_a + + + class MyHelloOperator(HelloOperator): + template_fields = ("field_a", "field_b") + + def __init__(field_b, **kwargs) -> None: # <- should be def __init__(field_a, field_b, **kwargs) + super().__init__(**kwargs) # <- should be super().__init__(field_a=field_a, **kwargs) + self.field_b = field_b + +3. Applying actions on the parameter during the assignment in the constructor is not allowed. +Any action on the value should be applied in the ``execute()`` method. +Therefore, the following example is invalid: + +.. code-block:: python + + class HelloOperator(BaseOperator): + template_fields = "field_a" + + def __init__(field_a) -> None: + self.field_a = field_a.lower() # <- assignment should be only self.field_a = field_a + +When an operator inherits from a base operator and does not have a constructor defined on its own, the limitations above +do not apply. However, the templated fields must be set properly in the parent according to those limitations. + +Thus, the following example is valid: + +.. code-block:: python + + class HelloOperator(BaseOperator): + template_fields = "field_a" + + def __init__(field_a) -> None: + self.field_a = field_a + + + class MyHelloOperator(HelloOperator): + template_fields = "field_a" + +The limitations above are enforced by a pre-commit named 'validate-operators-init'. + Add template fields with subclassing ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ A common use case for creating a custom operator is for simply augmenting existing ``template_fields``. diff --git a/scripts/ci/pre_commit/pre_commit_validate_operators_init.py b/scripts/ci/pre_commit/pre_commit_validate_operators_init.py new file mode 100755 index 0000000000000..aab3a4d47bae4 --- /dev/null +++ b/scripts/ci/pre_commit/pre_commit_validate_operators_init.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import ast +import sys +from typing import Any + +from rich.console import Console + +console = Console(color_system="standard", width=200) +BASE_OPERATOR_CLASS_NAME = "BaseOperator" + + +def _is_operator(class_node: ast.ClassDef) -> bool: + """ + Check if a given class node is an operator, based of the string suffix of the base IDs + (ends with "BaseOperator"). + TODO: Enhance this function to work with nested inheritance trees through dynamic imports. + + :param class_node: The class node to check. + :return: True if the class definition is of an operator, False otherwise. + """ + for base in class_node.bases: + if isinstance(base, ast.Name) and base.id.endswith(BASE_OPERATOR_CLASS_NAME): + return True + return False + + +def _extract_template_fields(class_node: ast.ClassDef) -> list[str]: + """ + This method takes a class node as input and extracts the template fields from it. + Template fields are identified by an assignment statement where the target is a variable + named "template_fields" and the value is a tuple of constants. + + :param class_node: The class node representing the class for which template fields need to be extracted. + :return: A list of template fields extracted from the class node. + """ + for class_item in class_node.body: + if isinstance(class_item, ast.Assign): + for target in class_item.targets: + if ( + isinstance(target, ast.Name) + and target.id == "template_fields" + and isinstance(class_item.value, ast.Tuple) + ): + return [elt.value for elt in class_item.value.elts if isinstance(elt, ast.Constant)] + elif isinstance(class_item, ast.AnnAssign): + if ( + isinstance(class_item.target, ast.Name) + and class_item.target.id == "template_fields" + and isinstance(class_item.value, ast.Tuple) + ): + return [elt.value for elt in class_item.value.elts if isinstance(elt, ast.Constant)] + return [] + + +def _handle_parent_constructor_kwargs( + template_fields: list[str], + ctor_stmt: ast.stmt, + missing_assignments: list[str], + invalid_assignments: list[str], +) -> list[str]: + """ + This method checks if template fields are correctly assigned in a call to class parent's + constructor call. + It handles both the detection of missing assignments and invalid assignments. + It assumes that if the call is valid - the parent class will correctly assign the template + field. + TODO: Enhance this function to work with nested inheritance trees through dynamic imports. + + :param missing_assignments: List[str] - List of template fields that have not been assigned a value. + :param ctor_stmt: ast.Expr - AST node representing the constructor statement. + :param invalid_assignments: List[str] - List of template fields that have been assigned incorrectly. + :param template_fields: List[str] - List of template fields to be assigned. + + :return: List[str] - List of template fields that are still missing assignments. + """ + if isinstance(ctor_stmt, ast.Expr): + if ( + isinstance(ctor_stmt.value, ast.Call) + and isinstance(ctor_stmt.value.func, ast.Attribute) + and isinstance(ctor_stmt.value.func.value, ast.Call) + and isinstance(ctor_stmt.value.func.value.func, ast.Name) + and ctor_stmt.value.func.value.func.id == "super" + ): + for arg in ctor_stmt.value.keywords: + if arg.arg is not None and arg.arg in template_fields: + if not isinstance(arg.value, ast.Name) or arg.arg != arg.value.id: + invalid_assignments.append(arg.arg) + assigned_targets = [arg.arg for arg in ctor_stmt.value.keywords if arg.arg is not None] + return list(set(missing_assignments) - set(assigned_targets)) + return missing_assignments + + +def _handle_constructor_statement( + template_fields: list[str], + ctor_stmt: ast.stmt, + missing_assignments: list[str], + invalid_assignments: list[str], +) -> list[str]: + """ + This method handles a single constructor statement by doing the following actions: + 1. Removing assigned fields of template_fields from missing_assignments. + 2. Detecting invalid assignments of template fields and adding them to invalid_assignments. + + :param template_fields: Tuple of template fields. + :param ctor_stmt: Constructor statement (for example, self.field_name = param_name) + :param missing_assignments: List of missing assignments. + :param invalid_assignments: List of invalid assignments. + :return: List of missing assignments after handling the assigned targets. + """ + assigned_template_fields: list[str] = [] + if isinstance(ctor_stmt, ast.Assign): + if isinstance(ctor_stmt.targets[0], ast.Attribute): + for target in ctor_stmt.targets: + if isinstance(target, ast.Attribute) and target.attr in template_fields: + if isinstance(ctor_stmt.value, ast.BoolOp) and isinstance(ctor_stmt.value.op, ast.Or): + _handle_assigned_field( + assigned_template_fields, invalid_assignments, target, ctor_stmt.value.values[0] + ) + else: + _handle_assigned_field( + assigned_template_fields, invalid_assignments, target, ctor_stmt.value + ) + elif isinstance(ctor_stmt.targets[0], ast.Tuple) and isinstance(ctor_stmt.value, ast.Tuple): + for target, value in zip(ctor_stmt.targets[0].elts, ctor_stmt.value.elts): + if isinstance(target, ast.Attribute): + _handle_assigned_field(assigned_template_fields, invalid_assignments, target, value) + elif isinstance(ctor_stmt, ast.AnnAssign): + if isinstance(ctor_stmt.target, ast.Attribute) and ctor_stmt.target.attr in template_fields: + _handle_assigned_field( + assigned_template_fields, invalid_assignments, ctor_stmt.target, ctor_stmt.value + ) + return list(set(missing_assignments) - set(assigned_template_fields)) + + +def _handle_assigned_field( + assigned_template_fields: list[str], invalid_assignments: list[str], target: ast.Attribute, value: Any +) -> None: + """ + Handle an assigned field by its value. + + :param assigned_template_fields: A list to store the valid assigned fields. + :param invalid_assignments: A list to store the invalid assignments. + :param target: The target field. + :param value: The value of the field. + """ + if not isinstance(value, ast.Name): + invalid_assignments.append(target.attr) + else: + assigned_template_fields.append(target.attr) + + +def _check_constructor_template_fields(class_node: ast.ClassDef, template_fields: list[str]) -> int: + """ + This method checks a class's constructor for missing or invalid assignments of template fields. + When there isn't a constructor - it assumes that the template fields are defined in the parent's + constructor correctly. + TODO: Enhance this function to work with nested inheritance trees through dynamic imports. + + :param class_node: the AST node representing the class definition + :param template_fields: a tuple of template fields + :return: the number of invalid template fields found + """ + count = 0 + class_name = class_node.name + missing_assignments = template_fields.copy() + invalid_assignments: list[str] = [] + init_flag: bool = False + for class_item in class_node.body: + if isinstance(class_item, ast.FunctionDef) and class_item.name == "__init__": + init_flag = True + for ctor_stmt in class_item.body: + missing_assignments = _handle_parent_constructor_kwargs( + template_fields, ctor_stmt, missing_assignments, invalid_assignments + ) + missing_assignments = _handle_constructor_statement( + template_fields, ctor_stmt, missing_assignments, invalid_assignments + ) + + if init_flag and missing_assignments: + count += len(missing_assignments) + console.print( + f"{class_name}'s constructor lacks direct assignments for " + f"instance members corresponding to the following template fields " + f"(i.e., self.field_name = field_name or super.__init__(field_name=field_name, ...) ):" + ) + console.print(f"[red]{missing_assignments}[/red]") + + if invalid_assignments: + count += len(invalid_assignments) + console.print( + f"{class_name}'s constructor contains invalid assignments to the following instance " + f"members that should be corresponding to template fields " + f"(i.e., self.field_name = field_name):" + ) + console.print(f"[red]{[f'self.{entry}' for entry in invalid_assignments]}[/red]") + return count + + +def main(): + """ + Check missing or invalid template fields in constructors of providers' operators. + + :return: The total number of errors found. + """ + err = 0 + for path in sys.argv[1:]: + console.print(f"[yellow]{path}[/yellow]") + tree = ast.parse(open(path).read()) + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and _is_operator(class_node=node): + template_fields = _extract_template_fields(node) or [] + err += _check_constructor_template_fields(node, template_fields) + return err + + +if __name__ == "__main__": + sys.exit(main()) From f51ebe175ead624ec73f71f5ce8be86177846662 Mon Sep 17 00:00:00 2001 From: Shahar Epstein Date: Sat, 20 Jan 2024 15:34:18 +0200 Subject: [PATCH 2/2] Update output static checks --- images/breeze/output_static-checks.svg | 74 ++++++++++++++------------ images/breeze/output_static-checks.txt | 2 +- 2 files changed, 40 insertions(+), 36 deletions(-) diff --git a/images/breeze/output_static-checks.svg b/images/breeze/output_static-checks.svg index 68665eb5318cc..f05cd5435b353 100644 --- a/images/breeze/output_static-checks.svg +++ b/images/breeze/output_static-checks.svg @@ -1,4 +1,4 @@ - +