diff --git a/sdks/python/apache_beam/coders/row_coder_test.py b/sdks/python/apache_beam/coders/row_coder_test.py index 843d2a32ec10..b2da71b2a20e 100644 --- a/sdks/python/apache_beam/coders/row_coder_test.py +++ b/sdks/python/apache_beam/coders/row_coder_test.py @@ -58,7 +58,14 @@ ("favorite_time", typing.Optional[Timestamp]), ("one_more_field", typing.Optional[str])]) + +class People(typing.NamedTuple): + primary: Person + partner: typing.Optional[Person] + + coders_registry.register_coder(Person, RowCoder) +coders_registry.register_coder(People, RowCoder) class RowCoderTest(unittest.TestCase): @@ -121,6 +128,19 @@ def test_create_row_coder_from_named_tuple(self): self.assertEqual( test_case, real_coder.decode(real_coder.encode(test_case))) + def test_create_row_coder_from_nested_named_tuple(self): + expected_coder = RowCoder(typing_to_runner_api(People).row_type.schema) + real_coder = coders_registry.get_coder(People) + + for primary in self.PEOPLE: + for other in self.PEOPLE + [None]: + test_case = People(primary=primary, partner=other) + self.assertEqual( + expected_coder.encode(test_case), real_coder.encode(test_case)) + + self.assertEqual( + test_case, real_coder.decode(real_coder.encode(test_case))) + def test_create_row_coder_from_schema(self): schema = schema_pb2.Schema( id="person", diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index b1f6fd99d979..0a01fbc35f56 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -123,6 +123,14 @@ def from_fields( field_options=field_options, schema_registry=schema_registry) + def __call__(self, *args, **kwargs): + # We make RowTypeConstraint callable (defers to constructing the user type) + # so that Python will recognize it as a type. This allows RowTypeConstraint + # to be used in conjunction with native typehints, like Optional. + # CPython (prior to 3.11) considers anything callable to be a type: + # https://github.com/python/cpython/blob/d348afa15d5a997e7a8e51c0f789f41cb15cc651/Lib/typing.py#L137-L167 + return self._user_type(*args, **kwargs) + @property def user_type(self): return self._user_type diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py index 370b9c92cde7..5d9434345cbc 100644 --- a/sdks/python/apache_beam/typehints/schemas_test.py +++ b/sdks/python/apache_beam/typehints/schemas_test.py @@ -267,6 +267,43 @@ def get_test_beam_fieldtype_protos(): ]) for i, typ in enumerate(all_primitives) ]))), + schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + id='a-schema-with-optional-nested-struct', + fields=[ + schema_pb2.Field( + name='id', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.INT64)), + schema_pb2.Field( + name='nested_row', + type=schema_pb2.FieldType( + nullable=True, + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + id='the-nested-schema', + fields=[ + schema_pb2.Field( + name='name', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING) + ), + schema_pb2.Field( + name='optional_map', + type=schema_pb2.FieldType( + nullable=True, + map_type=schema_pb2.MapType( + key_type=schema_pb2. + FieldType( + atomic_type=schema_pb2 + .STRING), + value_type=schema_pb2. + FieldType( + atomic_type=schema_pb2 + .DOUBLE)))), + ])))) + ]))) ] return all_primitives + \ @@ -562,6 +599,13 @@ def test_schema_with_bad_field_raises_helpful_error(self): # bypass schema cache schema_registry=SchemaTypeRegistry())) + def test_row_type_is_callable(self): + simple_row_type = row_type.RowTypeConstraint.from_fields([('foo', np.int64), + ('bar', str)]) + instance = simple_row_type(np.int64(35), 'baz') + self.assertIsInstance(instance, simple_row_type.user_type) + self.assertEqual(instance, (np.int64(35), 'baz')) + @parameterized_class([ {