diff --git a/sdks/python/apache_beam/testing/util.py b/sdks/python/apache_beam/testing/util.py index cbb2119b83f6..c9745abf9499 100644 --- a/sdks/python/apache_beam/testing/util.py +++ b/sdks/python/apache_beam/testing/util.py @@ -50,6 +50,7 @@ 'matches_all', # open_shards is internal and has no backwards compatibility guarantees. 'open_shards', + 'row_namedtuple_equals_fn', 'TestWindowedValue', ] @@ -167,7 +168,7 @@ def _equal(actual, equals_fn=equals_fn): # collection. It can also raise false negatives for types that don't have # a deterministic sort order, like pyarrow Tables as of 0.14.1 if not equals_fn: - equals_fn = lambda e, a: e == a + equals_fn = row_namedtuple_equals_fn try: sorted_expected = sorted(expected) sorted_actual = sorted(actual) @@ -202,6 +203,33 @@ def _equal(actual, equals_fn=equals_fn): return _equal +def row_namedtuple_equals_fn(expected, actual, fallback_equals_fn=None): + """ + equals_fn which can be used by equal_to which treats Rows and + NamedTuples as equivalent types. This can be useful since Beam converts + Rows to NamedTuples when they are sent across portability layers, so a Row + may be converted to a NamedTuple automatically by Beam. + """ + if fallback_equals_fn is None: + fallback_equals_fn = lambda e, a: e == a + if type(expected) is not pvalue.Row and not _is_named_tuple(expected): + return fallback_equals_fn(expected, actual) + if type(actual) is not pvalue.Row and not _is_named_tuple(actual): + return fallback_equals_fn(expected, actual) + + expected_dict = expected._asdict() + actual_dict = actual._asdict() + if len(expected_dict) != len(actual_dict): + return False + for k, v in expected_dict.items(): + if k not in actual_dict: + return False + if not row_namedtuple_equals_fn(v, actual_dict[k]): + return False + + return True + + def matches_all(expected): """Matcher used by assert_that to check a set of matchers. @@ -386,5 +414,11 @@ def _sort_lists(result): return result +def _is_named_tuple(obj) -> bool: + return ( + isinstance(obj, tuple) and hasattr(obj, '_asdict') and + hasattr(obj, '_fields')) + + # A utility transform that recursively sorts lists for easier testing. SortLists = Map(_sort_lists) diff --git a/sdks/python/apache_beam/testing/util_test.py b/sdks/python/apache_beam/testing/util_test.py index ba3c743c03f3..dbb5d0fd37a5 100644 --- a/sdks/python/apache_beam/testing/util_test.py +++ b/sdks/python/apache_beam/testing/util_test.py @@ -20,6 +20,7 @@ # pytype: skip-file import unittest +from typing import NamedTuple import apache_beam as beam from apache_beam import Create @@ -32,6 +33,7 @@ from apache_beam.testing.util import equal_to_per_window from apache_beam.testing.util import is_empty from apache_beam.testing.util import is_not_empty +from apache_beam.testing.util import row_namedtuple_equals_fn from apache_beam.transforms import trigger from apache_beam.transforms import window from apache_beam.transforms.window import FixedWindows @@ -254,6 +256,54 @@ def test_equal_to_per_window_fail_unexpected_element(self): equal_to_per_window(expected), reify_windows=True) + def test_row_namedtuple_equals(self): + class RowTuple(NamedTuple): + a: str + b: int + + self.assertTrue( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456), beam.Row(a='123', b=456))) + self.assertTrue( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456), RowTuple(a='123', b=456))) + self.assertTrue( + row_namedtuple_equals_fn( + RowTuple(a='123', b=456), RowTuple(a='123', b=456))) + self.assertTrue( + row_namedtuple_equals_fn( + RowTuple(a='123', b=456), beam.Row(a='123', b=456))) + self.assertTrue(row_namedtuple_equals_fn('foo', 'foo')) + self.assertFalse( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456), beam.Row(a='123', b=4567))) + self.assertFalse( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456), beam.Row(a='123', b=456, c='a'))) + self.assertFalse( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456), RowTuple(a='123', b=4567))) + self.assertFalse( + row_namedtuple_equals_fn( + beam.Row(a='123', b=456, c='foo'), RowTuple(a='123', b=4567))) + self.assertFalse( + row_namedtuple_equals_fn(beam.Row(a='123'), RowTuple(a='123', b=4567))) + self.assertFalse(row_namedtuple_equals_fn(beam.Row(a='123'), '123')) + self.assertFalse(row_namedtuple_equals_fn('123', RowTuple(a='123', b=456))) + + class NestedNamedTuple(NamedTuple): + a: str + b: RowTuple + + self.assertTrue( + row_namedtuple_equals_fn( + beam.Row(a='foo', b=beam.Row(a='123', b=456)), + NestedNamedTuple(a='foo', b=RowTuple(a='123', b=456)))) + self.assertTrue( + row_namedtuple_equals_fn( + beam.Row(a='foo', b=beam.Row(a='123', b=456)), + beam.Row(a='foo', b=RowTuple(a='123', b=456)))) + if __name__ == '__main__': unittest.main()