From a94b806c2b1acf7a89d23291f1a52136f708548f Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Mon, 9 Jun 2025 16:18:19 -0400 Subject: [PATCH 1/2] Evaluate namedTuples as equivalent to rows --- sdks/python/apache_beam/testing/util.py | 36 +++++++++++++- sdks/python/apache_beam/testing/util_test.py | 50 ++++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/testing/util.py b/sdks/python/apache_beam/testing/util.py index cbb2119b83f6..c9745abf9499 100644 --- a/sdks/python/apache_beam/testing/util.py +++ b/sdks/python/apache_beam/testing/util.py @@ -50,6 +50,7 @@ 'matches_all', # open_shards is internal and has no backwards compatibility guarantees. 'open_shards', + 'row_namedtuple_equals_fn', 'TestWindowedValue', ] @@ -167,7 +168,7 @@ def _equal(actual, equals_fn=equals_fn): # collection. It can also raise false negatives for types that don't have # a deterministic sort order, like pyarrow Tables as of 0.14.1 if not equals_fn: - equals_fn = lambda e, a: e == a + equals_fn = row_namedtuple_equals_fn try: sorted_expected = sorted(expected) sorted_actual = sorted(actual) @@ -202,6 +203,33 @@ def _equal(actual, equals_fn=equals_fn): return _equal +def row_namedtuple_equals_fn(expected, actual, fallback_equals_fn=None): + """ + equals_fn which can be used by equal_to which treats Rows and + NamedTuples as equivalent types. This can be useful since Beam converts + Rows to NamedTuples when they are sent across portability layers, so a Row + may be converted to a NamedTuple automatically by Beam. + """ + if fallback_equals_fn is None: + fallback_equals_fn = lambda e, a: e == a + if type(expected) is not pvalue.Row and not _is_named_tuple(expected): + return fallback_equals_fn(expected, actual) + if type(actual) is not pvalue.Row and not _is_named_tuple(actual): + return fallback_equals_fn(expected, actual) + + expected_dict = expected._asdict() + actual_dict = actual._asdict() + if len(expected_dict) != len(actual_dict): + return False + for k, v in expected_dict.items(): + if k not in actual_dict: + return False + if not row_namedtuple_equals_fn(v, actual_dict[k]): + return False + + return True + + def matches_all(expected): """Matcher used by assert_that to check a set of matchers. @@ -386,5 +414,11 @@ def _sort_lists(result): return result +def _is_named_tuple(obj) -> bool: + return ( + isinstance(obj, tuple) and hasattr(obj, '_asdict') and + hasattr(obj, '_fields')) + + # A utility transform that recursively sorts lists for easier testing. SortLists = Map(_sort_lists) diff --git a/sdks/python/apache_beam/testing/util_test.py b/sdks/python/apache_beam/testing/util_test.py index ba3c743c03f3..a236543f6c97 100644 --- a/sdks/python/apache_beam/testing/util_test.py +++ b/sdks/python/apache_beam/testing/util_test.py @@ -32,12 +32,14 @@ from apache_beam.testing.util import equal_to_per_window from apache_beam.testing.util import is_empty from apache_beam.testing.util import is_not_empty +from apache_beam.testing.util import row_namedtuple_equals_fn from apache_beam.transforms import trigger from apache_beam.transforms import window from apache_beam.transforms.window import FixedWindows from apache_beam.transforms.window import GlobalWindow from apache_beam.transforms.window import IntervalWindow from apache_beam.utils.timestamp import MIN_TIMESTAMP +from typing import NamedTuple class UtilTest(unittest.TestCase): @@ -254,6 +256,54 @@ def test_equal_to_per_window_fail_unexpected_element(self): equal_to_per_window(expected), reify_windows=True) + def test_row_namedtuple_equals(self): + class RowTuple(NamedTuple): + a: str + b: int + + self.assertTrue( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456), beam.Row(a='123', b=456))) + self.assertTrue( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456), RowTuple(a='123', b=456))) + self.assertTrue( + row_namedtuple_equals_fn( + RowTuple(a='123', b=456), RowTuple(a='123', b=456))) + self.assertTrue( + row_namedtuple_equals_fn( + RowTuple(a='123', b=456), beam.Row(a='123', b=456))) + self.assertTrue(row_namedtuple_equals_fn('foo', 'foo')) + self.assertFalse( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456), beam.Row(a='123', b=4567))) + self.assertFalse( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456), beam.Row(a='123', b=456, c='a'))) + self.assertFalse( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456), RowTuple(a='123', b=4567))) + self.assertFalse( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456, c='foo'), RowTuple(a='123', b=4567))) + self.assertFalse( + row_namedtuple_equals_fn(beam.Row(a='123'), RowTuple(a='123', b=4567))) + self.assertFalse(row_namedtuple_equals_fn(beam.Row(a='123'), '123')) + self.assertFalse(row_namedtuple_equals_fn('123', RowTuple(a='123', b=4567))) + + class NestedNamedTuple(NamedTuple): + a: str + b: RowTuple + + self.assertTrue( + row_namedtuple_equals_fn( + beam.Row(a='foo', b=beam.Row(a='123', b=456)), + NestedNamedTuple(a='foo', b=RowTuple(a='123', b=456)))) + self.assertTrue( + row_namedtuple_equals_fn( + beam.Row(a='foo', b=beam.Row(a='123', b=456)), + beam.Row(a='foo', b=RowTuple(a='123', b=456)))) + if __name__ == '__main__': unittest.main() From 8bba5b7410fdec5670e5aec14632e19aa358ae9e Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Mon, 9 Jun 2025 16:58:07 -0400 Subject: [PATCH 2/2] lint --- sdks/python/apache_beam/testing/util_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/testing/util_test.py b/sdks/python/apache_beam/testing/util_test.py index a236543f6c97..dbb5d0fd37a5 100644 --- a/sdks/python/apache_beam/testing/util_test.py +++ b/sdks/python/apache_beam/testing/util_test.py @@ -20,6 +20,7 @@ # pytype: skip-file import unittest +from typing import NamedTuple import apache_beam as beam from apache_beam import Create @@ -39,7 +40,6 @@ from apache_beam.transforms.window import GlobalWindow from apache_beam.transforms.window import IntervalWindow from apache_beam.utils.timestamp import MIN_TIMESTAMP -from typing import NamedTuple class UtilTest(unittest.TestCase): @@ -289,7 +289,7 @@ class RowTuple(NamedTuple): self.assertFalse( row_namedtuple_equals_fn(beam.Row(a='123'), RowTuple(a='123', b=4567))) self.assertFalse(row_namedtuple_equals_fn(beam.Row(a='123'), '123')) - self.assertFalse(row_namedtuple_equals_fn('123', RowTuple(a='123', b=4567))) + self.assertFalse(row_namedtuple_equals_fn('123', RowTuple(a='123', b=456))) class NestedNamedTuple(NamedTuple): a: str