Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions sdks/python/apache_beam/dataframe/schemas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,20 @@ def check_df_pcoll_equal(actual):
PD_VERSION = tuple(int(n) for n in pd.__version__.split('.'))


def test_name_func(testcase_func, param_num, params):
df_or_series, _, _ = params.args
if isinstance(df_or_series, pd.Series):
return f"{testcase_func.__name__}_Series[{df_or_series.dtype}]"
elif isinstance(df_or_series, pd.DataFrame):
return (
f"{testcase_func.__name__}_DataFrame"
f"[{','.join(str(dtype) for dtype in df_or_series.dtypes)}]")
else:
raise ValueError(
f"Encountered unsupported param in {testcase_func.__name__}. "
"Expected Series or DataFrame, got:\n" + str(df_or_series))


class SchemasTest(unittest.TestCase):
def test_simple_df(self):
expected = pd.DataFrame({
Expand Down Expand Up @@ -230,7 +244,8 @@ def assert_typehints_equal(self, left, right):
else:
self.assertEqual(left, right)

@parameterized.expand(SERIES_TESTS + NOINDEX_DF_TESTS)
@parameterized.expand(
SERIES_TESTS + NOINDEX_DF_TESTS, name_func=test_name_func)
def test_unbatch_no_index(self, df_or_series, rows, beam_type):
proxy = df_or_series[:0]

Expand All @@ -247,7 +262,7 @@ def test_unbatch_no_index(self, df_or_series, rows, beam_type):

assert_that(res, equal_to(rows))

@parameterized.expand(SERIES_TESTS + INDEX_DF_TESTS)
@parameterized.expand(SERIES_TESTS + INDEX_DF_TESTS, name_func=test_name_func)
def test_unbatch_with_index(self, df_or_series, rows, _):
proxy = df_or_series[:0]

Expand Down