diff --git a/sdks/python/apache_beam/yaml/integration_tests.py b/sdks/python/apache_beam/yaml/integration_tests.py index f6c60ae5f121..9036e6a3d5c4 100644 --- a/sdks/python/apache_beam/yaml/integration_tests.py +++ b/sdks/python/apache_beam/yaml/integration_tests.py @@ -361,11 +361,11 @@ def temp_sqlserver_database(): class OracleTestContainer(DockerContainer): """ - OracleTestContainer is an updated version of OracleDBContainer that goes - ahead and sets the oracle password, waits for logs to establish that the - container is ready before calling get_exposed_port, and uses a more modern - oracle driver. - """ + OracleTestContainer is an updated version of OracleDBContainer that goes + ahead and sets the oracle password, waits for logs to establish that the + container is ready before calling get_exposed_port, and uses a more modern + oracle driver. + """ def __init__(self): super().__init__("gvenzl/oracle-xe:21-slim") self.with_env("ORACLE_PASSWORD", "oracle") @@ -483,6 +483,22 @@ def temp_pubsub_emulator(project_id="apache-beam-testing"): def replace_recursive(spec, vars): + """Recursively replaces string placeholders in a spec with values from vars. + + Traverses a nested structure (dicts, lists, or other types). If a string + is encountered and contains placeholders in the format '{key}', it attempts + to replace them using the `vars` dictionary. + + Args: + spec: The (potentially nested) structure to process. + vars: A dictionary of variable names to their replacement values. + + Returns: + The spec with placeholders replaced. + + Raises: + ValueError: If a string formatting error occurs. + """ if isinstance(spec, dict): return { key: replace_recursive(value, vars) @@ -503,6 +519,20 @@ def replace_recursive(spec, vars): def transform_types(spec): + """Recursively extracts all transform types from a pipeline specification. + + This generator function traverses a nested pipeline specification (likely + parsed from YAML). It identifies and yields the 'type' string for each + transform defined within the specification, including those within + 'composite' or 'chain' structures. + + Args: + spec (dict): A dictionary representing a pipeline or transform + specification. + + Yields: + str: The 'type' of each transform found in the specification. + """ if spec.get('type', None) in (None, 'composite', 'chain'): if 'source' in spec: yield from transform_types(spec['source']) @@ -515,8 +545,30 @@ def transform_types(spec): def provider_sets(spec, require_available=False): - """For transforms that are vended by multiple providers, yields all possible - combinations of providers to use. + """ + Generates all relevant combinations of providers for a given pipeline spec. + + This function analyzes a pipeline specification to identify transforms that + can be implemented by multiple underlying providers (e.g., a generic + transform vs. a SQL-backed one). It then yields different "provider sets," + each representing a unique combination of choices for these multi-provider + transforms. + + If no transforms have multiple available providers, it yields a single + provider set using the standard defaults. + + Args: + spec (dict): The pipeline specification, typically loaded from YAML. + require_available (bool): If True, raises an error if a provider + needed for a transform is not available. If False (default), + unavailable providers are skipped, potentially reducing the number + of yielded combinations. + + Yields: + tuple: A tuple where the first element is a string suffix uniquely + identifying the provider combination (e.g., "MyTransform_SqlProvider_0"), + and the second element is a dictionary mapping transform types to a list + containing the selected provider(s) for that combination. """ try: for p in spec['pipelines']: @@ -566,6 +618,39 @@ def filter_to_available(t, providers): def create_test_methods(spec): + """Dynamically creates test methods based on a YAML specification. + + This function takes a YAML specification (`spec`) which defines pipelines, + fixtures, and potentially options. It iterates through different + combinations of "providers" (which determine how YAML transforms are + implemented, e.g., using Python or SQL). + + For each combination of providers: + 1. It constructs a unique test method name (e.g., `test_only`). + 2. It defines a test method that: + a. Sets up any specified fixtures, making their values available as + variables. + b. Mocks the standard YAML providers to use the current combination + of providers for this test run. + c. For each pipeline defined in the `spec`: + i. Creates a `beam.Pipeline` instance with specified options. + ii. Expands the YAML pipeline definition using + `yaml_transform.expand_pipeline`, substituting any fixture + variables. + iii. Runs the Beam pipeline. + + The function yields tuples of (test_method_name, test_method_function), + which can then be used to populate a `unittest.TestCase` class. + + Args: + spec (dict): A dictionary parsed from a YAML test specification file. + It's expected to have keys like 'fixtures' (optional) and 'pipelines'. + + Yields: + tuple: A tuple containing: + - str: The generated name for the test method (e.g., "test_only"). + - function: The dynamically generated test method. + """ for suffix, providers in provider_sets(spec): def test(self, providers=providers): # default arg to capture loop value @@ -593,6 +678,23 @@ def test(self, providers=providers): # default arg to capture loop value def parse_test_files(filepattern): + """Parses YAML test files and dynamically creates test cases. + + This function iterates through all files matching the given glob pattern. + For each YAML file found, it: + 1. Reads the file content. + 2. Determines a test suite name based on the file name. + 3. Calls `create_test_methods` to generate test methods from the + YAML specification. + 4. Dynamically creates a new TestCase class (inheriting from + `unittest.TestCase`) and populates it with the generated test methods. + 5. Adds this newly created TestCase class to the global scope, making it + discoverable by the unittest framework. + + Args: + filepattern (str): A glob pattern specifying the YAML test files to parse. + For example, 'path/to/tests/*.yaml'. + """ for path in glob.glob(filepattern): with open(path) as fin: suite_name = os.path.splitext(os.path.basename(path))[0].title().replace(