Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion sdks/python/apache_beam/testing/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
'matches_all',
# open_shards is internal and has no backwards compatibility guarantees.
'open_shards',
'row_namedtuple_equals_fn',
'TestWindowedValue',
]

Expand Down Expand Up @@ -167,7 +168,7 @@ def _equal(actual, equals_fn=equals_fn):
# collection. It can also raise false negatives for types that don't have
# a deterministic sort order, like pyarrow Tables as of 0.14.1
if not equals_fn:
equals_fn = lambda e, a: e == a
equals_fn = row_namedtuple_equals_fn
try:
sorted_expected = sorted(expected)
sorted_actual = sorted(actual)
Expand Down Expand Up @@ -202,6 +203,33 @@ def _equal(actual, equals_fn=equals_fn):
return _equal


def row_namedtuple_equals_fn(expected, actual, fallback_equals_fn=None):
    """Compare two values, treating Rows and NamedTuples as equivalent types.

    equals_fn which can be used by equal_to. This can be useful since Beam
    converts Rows to NamedTuples when they are sent across portability layers,
    so a Row may be converted to a NamedTuple automatically by Beam.

    Args:
      expected: The expected value.
      actual: The value to compare against ``expected``.
      fallback_equals_fn: Optional callable ``(expected, actual) -> bool`` used
        whenever either side is neither a Row nor a NamedTuple. Defaults to
        plain ``==``. It is also used for nested field values.

    Returns:
      The result of ``fallback_equals_fn`` if either value is not a
      Row/NamedTuple. Otherwise True only if both values have the same set of
      field names and every pair of field values compares equal under this
      same function; False otherwise.
    """
    if fallback_equals_fn is None:
        fallback_equals_fn = lambda e, a: e == a
    if type(expected) is not pvalue.Row and not _is_named_tuple(expected):
        return fallback_equals_fn(expected, actual)
    if type(actual) is not pvalue.Row and not _is_named_tuple(actual):
        return fallback_equals_fn(expected, actual)

    expected_dict = expected._asdict()
    actual_dict = actual._asdict()
    # Equal sizes plus "every expected key is present in actual" implies the
    # two key sets are identical.
    if len(expected_dict) != len(actual_dict):
        return False
    for k, v in expected_dict.items():
        if k not in actual_dict:
            return False
        # Forward the fallback so nested non-Row values are compared with the
        # caller's comparator rather than silently reverting to ``==``.
        if not row_namedtuple_equals_fn(v, actual_dict[k], fallback_equals_fn):
            return False

    return True


def matches_all(expected):
"""Matcher used by assert_that to check a set of matchers.

Expand Down Expand Up @@ -386,5 +414,11 @@ def _sort_lists(result):
return result


def _is_named_tuple(obj) -> bool:
    """Return True if obj is a tuple exposing ``_asdict`` and ``_fields``."""
    if not isinstance(obj, tuple):
        return False
    # Attribute names are checked in the same order as before.
    return all(hasattr(obj, attr) for attr in ('_asdict', '_fields'))


# A utility transform that recursively sorts lists for easier testing.
# It is built by wrapping the module-private _sort_lists helper in Map.
SortLists = Map(_sort_lists)
50 changes: 50 additions & 0 deletions sdks/python/apache_beam/testing/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# pytype: skip-file

import unittest
from typing import NamedTuple

import apache_beam as beam
from apache_beam import Create
Expand All @@ -32,6 +33,7 @@
from apache_beam.testing.util import equal_to_per_window
from apache_beam.testing.util import is_empty
from apache_beam.testing.util import is_not_empty
from apache_beam.testing.util import row_namedtuple_equals_fn
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.window import FixedWindows
Expand Down Expand Up @@ -254,6 +256,54 @@ def test_equal_to_per_window_fail_unexpected_element(self):
equal_to_per_window(expected),
reify_windows=True)

def test_row_namedtuple_equals(self):
    """Check row_namedtuple_equals_fn on Rows, NamedTuples and plain values."""
    class RowTuple(NamedTuple):
        a: str
        b: int

    # Pairs that must compare equal, in any Row/NamedTuple combination.
    matching_pairs = [
        (beam.Row(a='123', b=456), beam.Row(a='123', b=456)),
        (beam.Row(a='123', b=456), RowTuple(a='123', b=456)),
        (RowTuple(a='123', b=456), RowTuple(a='123', b=456)),
        (RowTuple(a='123', b=456), beam.Row(a='123', b=456)),
        ('foo', 'foo'),
    ]
    for left, right in matching_pairs:
        self.assertTrue(row_namedtuple_equals_fn(left, right))

    # Pairs that differ by value, by field count, or by kind of object.
    differing_pairs = [
        (beam.Row(a='123', b=456), beam.Row(a='123', b=4567)),
        (beam.Row(a='123', b=456), beam.Row(a='123', b=456, c='a')),
        (beam.Row(a='123', b=456), RowTuple(a='123', b=4567)),
        (beam.Row(a='123', b=456, c='foo'), RowTuple(a='123', b=4567)),
        (beam.Row(a='123'), RowTuple(a='123', b=4567)),
        (beam.Row(a='123'), '123'),
        ('123', RowTuple(a='123', b=456)),
    ]
    for left, right in differing_pairs:
        self.assertFalse(row_namedtuple_equals_fn(left, right))

    class NestedNamedTuple(NamedTuple):
        a: str
        b: RowTuple

    # Nested Rows / NamedTuples are compared field by field as well.
    nested_row = beam.Row(a='foo', b=beam.Row(a='123', b=456))
    nested_candidates = [
        NestedNamedTuple(a='foo', b=RowTuple(a='123', b=456)),
        beam.Row(a='foo', b=RowTuple(a='123', b=456)),
    ]
    for candidate in nested_candidates:
        self.assertTrue(row_namedtuple_equals_fn(nested_row, candidate))


if __name__ == '__main__':
unittest.main()
Loading