Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 109 additions & 7 deletions sdks/python/apache_beam/yaml/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,11 @@ def temp_sqlserver_database():

class OracleTestContainer(DockerContainer):
"""
OracleTestContainer is an updated version of OracleDBContainer that goes
ahead and sets the oracle password, waits for logs to establish that the
container is ready before calling get_exposed_port, and uses a more modern
oracle driver.
"""
OracleTestContainer is an updated version of OracleDBContainer that goes
ahead and sets the oracle password, waits for logs to establish that the
container is ready before calling get_exposed_port, and uses a more modern
oracle driver.
"""
def __init__(self):
super().__init__("gvenzl/oracle-xe:21-slim")
self.with_env("ORACLE_PASSWORD", "oracle")
Expand Down Expand Up @@ -483,6 +483,22 @@ def temp_pubsub_emulator(project_id="apache-beam-testing"):


def replace_recursive(spec, vars):
"""Recursively replaces string placeholders in a spec with values from vars.

Traverses a nested structure (dicts, lists, or other types). If a string
is encountered and contains placeholders in the format '{key}', it attempts
to replace them using the `vars` dictionary.

Args:
spec: The (potentially nested) structure to process.
vars: A dictionary of variable names to their replacement values.

Returns:
The spec with placeholders replaced.

Raises:
ValueError: If a string formatting error occurs.
"""
if isinstance(spec, dict):
return {
key: replace_recursive(value, vars)
Expand All @@ -503,6 +519,20 @@ def replace_recursive(spec, vars):


def transform_types(spec):
"""Recursively extracts all transform types from a pipeline specification.

This generator function traverses a nested pipeline specification (likely
parsed from YAML). It identifies and yields the 'type' string for each
transform defined within the specification, including those within
'composite' or 'chain' structures.

Args:
spec (dict): A dictionary representing a pipeline or transform
specification.

Yields:
str: The 'type' of each transform found in the specification.
"""
if spec.get('type', None) in (None, 'composite', 'chain'):
if 'source' in spec:
yield from transform_types(spec['source'])
Expand All @@ -515,8 +545,30 @@ def transform_types(spec):


def provider_sets(spec, require_available=False):
"""For transforms that are vended by multiple providers, yields all possible
combinations of providers to use.
"""
Generates all relevant combinations of providers for a given pipeline spec.

This function analyzes a pipeline specification to identify transforms that
can be implemented by multiple underlying providers (e.g., a generic
transform vs. a SQL-backed one). It then yields different "provider sets,"
each representing a unique combination of choices for these multi-provider
transforms.

If no transforms have multiple available providers, it yields a single
provider set using the standard defaults.

Args:
spec (dict): The pipeline specification, typically loaded from YAML.
require_available (bool): If True, raises an error if a provider
needed for a transform is not available. If False (default),
unavailable providers are skipped, potentially reducing the number
of yielded combinations.

Yields:
tuple: A tuple where the first element is a string suffix uniquely
identifying the provider combination (e.g., "MyTransform_SqlProvider_0"),
and the second element is a dictionary mapping transform types to a list
containing the selected provider(s) for that combination.
"""
try:
for p in spec['pipelines']:
Expand Down Expand Up @@ -566,6 +618,39 @@ def filter_to_available(t, providers):


def create_test_methods(spec):
"""Dynamically creates test methods based on a YAML specification.

This function takes a YAML specification (`spec`) which defines pipelines,
fixtures, and potentially options. It iterates through different
combinations of "providers" (which determine how YAML transforms are
implemented, e.g., using Python or SQL).

For each combination of providers:
1. It constructs a unique test method name (e.g., `test_only`).
2. It defines a test method that:
a. Sets up any specified fixtures, making their values available as
variables.
b. Mocks the standard YAML providers to use the current combination
of providers for this test run.
c. For each pipeline defined in the `spec`:
i. Creates a `beam.Pipeline` instance with specified options.
ii. Expands the YAML pipeline definition using
`yaml_transform.expand_pipeline`, substituting any fixture
variables.
iii. Runs the Beam pipeline.

The function yields tuples of (test_method_name, test_method_function),
which can then be used to populate a `unittest.TestCase` class.

Args:
spec (dict): A dictionary parsed from a YAML test specification file.
It's expected to have keys like 'fixtures' (optional) and 'pipelines'.

Yields:
tuple: A tuple containing:
- str: The generated name for the test method (e.g., "test_only").
- function: The dynamically generated test method.
"""
for suffix, providers in provider_sets(spec):

def test(self, providers=providers): # default arg to capture loop value
Expand Down Expand Up @@ -593,6 +678,23 @@ def test(self, providers=providers): # default arg to capture loop value


def parse_test_files(filepattern):
"""Parses YAML test files and dynamically creates test cases.

This function iterates through all files matching the given glob pattern.
For each YAML file found, it:
1. Reads the file content.
2. Determines a test suite name based on the file name.
3. Calls `create_test_methods` to generate test methods from the
YAML specification.
4. Dynamically creates a new TestCase class (inheriting from
`unittest.TestCase`) and populates it with the generated test methods.
5. Adds this newly created TestCase class to the global scope, making it
discoverable by the unittest framework.

Args:
filepattern (str): A glob pattern specifying the YAML test files to parse.
For example, 'path/to/tests/*.yaml'.
"""
for path in glob.glob(filepattern):
with open(path) as fin:
suite_name = os.path.splitext(os.path.basename(path))[0].title().replace(
Expand Down
Loading