diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py index 42922fe31ae2..b240dec59190 100644 --- a/sdks/python/apache_beam/dataframe/schemas.py +++ b/sdks/python/apache_beam/dataframe/schemas.py @@ -68,6 +68,7 @@ from apache_beam.typehints.schemas import named_fields_to_schema from apache_beam.typehints.schemas import named_tuple_from_schema from apache_beam.typehints.schemas import named_tuple_to_schema +from apache_beam.typehints.typehints import normalize from apache_beam.utils import proto_utils __all__ = ( @@ -119,6 +120,12 @@ BEAM_TO_PANDAS[bytes] = 'bytes' +# Add shunts for normalized (Beam) typehints as well +BEAM_TO_PANDAS.update({ + normalize(typehint): pandas_dtype + for (typehint, pandas_dtype) in BEAM_TO_PANDAS.items() +}) + @typehints.with_input_types(T) @typehints.with_output_types(pd.DataFrame) diff --git a/sdks/python/apache_beam/dataframe/schemas_test.py b/sdks/python/apache_beam/dataframe/schemas_test.py index f019f82ddb3c..ec0c466fa859 100644 --- a/sdks/python/apache_beam/dataframe/schemas_test.py +++ b/sdks/python/apache_beam/dataframe/schemas_test.py @@ -180,6 +180,13 @@ def test_generate_proxy(self): pd.testing.assert_frame_equal(schemas.generate_proxy(Animal), expected) + def test_generate_proxy_beam_typehint(self): + expected = pd.Series(dtype=pd.Int32Dtype()) + + actual = schemas.generate_proxy(typehints.Optional[np.int32]) + + pd.testing.assert_series_equal(actual, expected) + def test_nice_types_proxy_roundtrip(self): roundtripped = schemas.generate_proxy( schemas.element_type_from_dataframe(NICE_TYPES_PROXY))