From 38759eb2c66debb1c8100db5d763e69c4885c123 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 5 Sep 2022 18:50:31 +0000 Subject: [PATCH 1/5] stricter casting for table with new schema --- python/pyarrow/table.pxi | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 931677f9848..9b4f9a0d330 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -3395,11 +3395,15 @@ cdef class Table(_PandasConvertible): Field field list newcols = [] - if self.schema.names != target_schema.names: + field_names = self.schema.names + if field_names != target_schema.names: raise ValueError("Target schema's field names are not matching " "the table's field names: {!r}, {!r}" .format(self.schema.names, target_schema.names)) + for name in field_names: + if self.schema.field(name).nullable and not target_schema.field(name).nullable: + raise RuntimeError("Casting a nullable field {!r} to non-nullable".format(name)) for column, field in zip(self.itercolumns(), target_schema): casted = column.cast(field.type, safe=safe, options=options) newcols.append(casted) From a641a8348746fdd80a49a07791df5f5c26f9f1ed Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 5 Sep 2022 18:55:04 +0000 Subject: [PATCH 2/5] add test and run linter --- python/pyarrow/table.pxi | 3 ++- python/pyarrow/tests/test_table.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 9b4f9a0d330..693c6e20aba 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -3403,7 +3403,8 @@ cdef class Table(_PandasConvertible): for name in field_names: if self.schema.field(name).nullable and not target_schema.field(name).nullable: - raise RuntimeError("Casting a nullable field {!r} to non-nullable".format(name)) + raise RuntimeError( + "Casting a nullable field {!r} to non-nullable".format(name)) for column, field in zip(self.itercolumns(), target_schema): casted = column.cast(field.type, safe=safe, options=options) newcols.append(casted) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index c0c60da6272..33328917b91 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -2192,3 +2192,12 @@ def test_table_join_many_columns(): "col6": ["A", "B", None, "Z"], "col7": ["A", "B", None, "Z"], }) + + +def test_table_cast_invalid(): + # Casting a nullable field to non-nullable should be invalid! + table = pa.table({'a': [None, 1], 'b': [None, True]}) + new_schema = pa.schema([pa.field("a", "int64", nullable=True), + pa.field("b", "bool", nullable=False)]) + with pytest.raises(RuntimeError): + table.cast(new_schema) From f8f86359e7729c5734687d5bcab7c064c1723dce Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 6 Sep 2022 16:01:44 +0000 Subject: [PATCH 3/5] check for null_count on the chunked array --- python/pyarrow/table.pxi | 10 ++++------ python/pyarrow/tests/test_table.py | 3 +++ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 693c6e20aba..f0087ac002b 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -3395,17 +3395,15 @@ cdef class Table(_PandasConvertible): Field field list newcols = [] - field_names = self.schema.names - if field_names != target_schema.names: + if self.schema.names != target_schema.names: raise ValueError("Target schema's field names are not matching " "the table's field names: {!r}, {!r}" .format(self.schema.names, target_schema.names)) - for name in field_names: - if self.schema.field(name).nullable and not target_schema.field(name).nullable: - raise RuntimeError( - "Casting a nullable field {!r} to non-nullable".format(name)) for column, field in zip(self.itercolumns(), target_schema): + if column.null_count > 0 and not field.nullable: + raise RuntimeError("Casting field {!r} with null values to non-nullable" + .format(field.name)) casted = column.cast(field.type, safe=safe, options=options) newcols.append(casted) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index 33328917b91..e1a3603096d 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -2201,3 +2201,6 @@ def test_table_cast_invalid(): pa.field("b", "bool", nullable=False)]) with pytest.raises(RuntimeError): table.cast(new_schema) + + table = pa.table({'a': [None, 1], 'b': [False, True]}) + assert table.cast(new_schema).schema == new_schema From 4049b9d8d29b0cf7a3b9d6db8168f9e6eaa8700b Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 8 Sep 2022 07:56:58 +0000 Subject: [PATCH 4/5] address review --- python/pyarrow/table.pxi | 4 ++-- python/pyarrow/tests/test_table.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index f0087ac002b..0f8a2942868 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -3401,8 +3401,8 @@ cdef class Table(_PandasConvertible): .format(self.schema.names, target_schema.names)) for column, field in zip(self.itercolumns(), target_schema): - if column.null_count > 0 and not field.nullable: - raise RuntimeError("Casting field {!r} with null values to non-nullable" + if not field.nullable and column.null_count > 0: + raise ValueError("Casting field {!r} with null values to non-nullable" .format(field.name)) casted = column.cast(field.type, safe=safe, options=options) newcols.append(casted) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index e1a3603096d..fad1c0acb24 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -2199,7 +2199,7 @@ def test_table_cast_invalid(): table = pa.table({'a': [None, 1], 'b': [None, True]}) new_schema = pa.schema([pa.field("a", "int64", nullable=True), pa.field("b", "bool", nullable=False)]) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): table.cast(new_schema) table = pa.table({'a': [None, 1], 'b': [False, True]}) From bd9a60fe2744942ff9d926371602fc3731a29d62 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 8 Sep 2022 09:10:28 +0000 Subject: [PATCH 5/5] make linter happy --- python/pyarrow/table.pxi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 0f8a2942868..30352bf3950 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -3403,7 +3403,7 @@ cdef class Table(_PandasConvertible): for column, field in zip(self.itercolumns(), target_schema): if not field.nullable and column.null_count > 0: raise ValueError("Casting field {!r} with null values to non-nullable" - .format(field.name)) + .format(field.name)) casted = column.cast(field.type, safe=safe, options=options) newcols.append(casted)