Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions sdks/python/apache_beam/coders/row_coder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,14 @@
("favorite_time", typing.Optional[Timestamp]),
("one_more_field", typing.Optional[str])])


# A row type whose fields are themselves rows: `primary` is a required
# Person and `partner` is an optional (nullable) Person.
People = typing.NamedTuple(
    "People", [
        ("primary", Person),
        ("partner", typing.Optional[Person]),
    ])


# Register RowCoder for both NamedTuple types so that
# coders_registry.get_coder(Person) / get_coder(People) in the tests below
# resolve through the registry.
coders_registry.register_coder(Person, RowCoder)
coders_registry.register_coder(People, RowCoder)


class RowCoderTest(unittest.TestCase):
Expand Down Expand Up @@ -121,6 +128,19 @@ def test_create_row_coder_from_named_tuple(self):
self.assertEqual(
test_case, real_coder.decode(real_coder.encode(test_case)))

def test_create_row_coder_from_nested_named_tuple(self):
  # The coder resolved through the registry for People must agree with a
  # RowCoder built directly from the schema derived from the type.
  people_schema = typing_to_runner_api(People).row_type.schema
  expected_coder = RowCoder(people_schema)
  real_coder = coders_registry.get_coder(People)

  # Every known person is tried as `partner`, plus None to cover the
  # optional (absent) nested row.
  partner_candidates = self.PEOPLE + [None]

  for primary_person in self.PEOPLE:
    for partner_person in partner_candidates:
      test_case = People(primary=primary_person, partner=partner_person)

      # Both coders must produce the same encoding.
      self.assertEqual(
          expected_coder.encode(test_case), real_coder.encode(test_case))

      # Encoding followed by decoding must give back the original value.
      round_tripped = real_coder.decode(real_coder.encode(test_case))
      self.assertEqual(test_case, round_tripped)

def test_create_row_coder_from_schema(self):
schema = schema_pb2.Schema(
id="person",
Expand Down
8 changes: 8 additions & 0 deletions sdks/python/apache_beam/typehints/row_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ def from_fields(
field_options=field_options,
schema_registry=schema_registry)

def __call__(self, *args, **kwargs):
# We make RowTypeConstraint callable (defers to constructing the user type)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify why we need to make this callable for Python to recognize it as a type (and maybe also update the comment here)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if there's a better reference on this, but the CPython implementation explicitly considers anything that's callable a type:
https://github.com/python/cpython/blob/d348afa15d5a997e7a8e51c0f789f41cb15cc651/Lib/typing.py#L137-L167

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this is actually being relaxed in Python 3.11: python/cpython@870b22b#diff-ddb987fca5f5df0c9a2f5521ed687919d70bb3d64eaeb8021f98833a2a716887

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Optionally, add a comment that this can be removed after we are fully in Python 3.11 (whenever that is).

# so that Python will recognize it as a type. This allows RowTypeConstraint
# to be used in conjunction with native typehints, like Optional.
# CPython (prior to 3.11) considers anything callable to be a type:
# https://github.com/python/cpython/blob/d348afa15d5a997e7a8e51c0f789f41cb15cc651/Lib/typing.py#L137-L167
return self._user_type(*args, **kwargs)

@property
def user_type(self):
  """Return the user type stored on this constraint (``self._user_type``)."""
  return self._user_type
Expand Down
44 changes: 44 additions & 0 deletions sdks/python/apache_beam/typehints/schemas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,43 @@ def get_test_beam_fieldtype_protos():
]) for i,
typ in enumerate(all_primitives)
]))),
schema_pb2.FieldType(
row_type=schema_pb2.RowType(
schema=schema_pb2.Schema(
id='a-schema-with-optional-nested-struct',
fields=[
schema_pb2.Field(
name='id',
type=schema_pb2.FieldType(
atomic_type=schema_pb2.INT64)),
schema_pb2.Field(
name='nested_row',
type=schema_pb2.FieldType(
nullable=True,
row_type=schema_pb2.RowType(
schema=schema_pb2.Schema(
id='the-nested-schema',
fields=[
schema_pb2.Field(
name='name',
type=schema_pb2.FieldType(
atomic_type=schema_pb2.STRING)
),
schema_pb2.Field(
name='optional_map',
type=schema_pb2.FieldType(
nullable=True,
map_type=schema_pb2.MapType(
key_type=schema_pb2.
FieldType(
atomic_type=schema_pb2
.STRING),
value_type=schema_pb2.
FieldType(
atomic_type=schema_pb2
.DOUBLE)))),
]))))
])))
]

return all_primitives + \
Expand Down Expand Up @@ -562,6 +599,13 @@ def test_schema_with_bad_field_raises_helpful_error(self):
# bypass schema cache
schema_registry=SchemaTypeRegistry()))

def test_row_type_is_callable(self):
  # Calling a RowTypeConstraint should construct an instance of its
  # underlying user type.
  field_specs = [('foo', np.int64), ('bar', str)]
  simple_row_type = row_type.RowTypeConstraint.from_fields(field_specs)

  constructed = simple_row_type(np.int64(35), 'baz')

  self.assertIsInstance(constructed, simple_row_type.user_type)
  # The constructed row compares equal to a plain tuple of its field values.
  self.assertEqual(constructed, (np.int64(35), 'baz'))


@parameterized_class([
{
Expand Down