diff --git a/sdks/python/apache_beam/yaml/tests/create.yaml b/sdks/python/apache_beam/yaml/tests/create.yaml index bed364c17143..723d8a888c26 100644 --- a/sdks/python/apache_beam/yaml/tests/create.yaml +++ b/sdks/python/apache_beam/yaml/tests/create.yaml @@ -30,7 +30,7 @@ pipelines: - {element: 2} - {element: 3} - {element: 4} - - {element: 5} + - {element: 5} # Simple Create with more complex beam row - pipeline: @@ -64,3 +64,20 @@ pipelines: - {first: 0, second: [1,2,3]} - {first: 1, second: [4,5,6]} - {first: 2, second: [7,8,9]} + + # Simple Create with element list + - pipeline: + type: chain + transforms: + - type: Create + config: + elements: + - {sdk: MapReduce, year: 2004} + - {sdk: Flume} + - {sdk: MillWheel, year: 2008} + - type: AssertEqual + config: + elements: + - {sdk: MapReduce, year: 2004} + - {sdk: Flume} + - {sdk: MillWheel, year: 2008} diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index 7c8114b57706..42087565013c 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -785,9 +785,14 @@ def __init__(self, elements: Iterable[Any]): self._elements = elements def expand(self, pcoll): + def to_dict(row): + # filter None when comparing + temp_dict = {k: v for k, v in row._asdict().items() if v is not None} + return dict(temp_dict.items()) + return assert_that( - pcoll | beam.Map(lambda row: beam.Row(**row._asdict())), - equal_to(dicts_to_rows(self._elements))) + pcoll | beam.Map(to_dict), + equal_to([to_dict(e) for e in dicts_to_rows(self._elements)])) @staticmethod def create(elements: Iterable[Any], reshuffle: Optional[bool] = True): @@ -838,7 +843,32 @@ def create(elements: Iterable[Any], reshuffle: Optional[bool] = True): # not the intent. if not isinstance(elements, Iterable) or isinstance(elements, (dict, str)): raise TypeError('elements must be a list of elements') - return beam.Create([element_to_rows(e) for e in elements], + + # Check if elements have different keys + updated_elements = elements + if elements and all(isinstance(e, dict) for e in elements): + keys = [set(e.keys()) for e in elements] + if len(set.union(*keys)) > min(len(k) for k in keys): + # Merge all dictionaries to get all possible keys + all_keys = set() + for element in elements: + if isinstance(element, dict): + all_keys.update(element.keys()) + + # Create a merged dictionary with all keys + merged_dict = {} + for key in all_keys: + merged_dict[key] = None # Use None as a default value + + # Update each element with the merged dictionary + updated_elements = [] + for e in elements: + if isinstance(e, dict): + updated_elements.append({**merged_dict, **e}) + else: + updated_elements.append(e) + + return beam.Create([element_to_rows(e) for e in updated_elements], reshuffle=reshuffle is not False) # Or should this be posargs, args?