From 04d6980f83ebf2fb7818e42b9be1102f831cfc8b Mon Sep 17 00:00:00 2001 From: kharoc Date: Mon, 2 Aug 2021 15:13:08 -0500 Subject: [PATCH 1/3] ARROW-13089:[Python]Allow creating RecordBatch from Python dict --- python/pyarrow/table.pxi | 47 ++++++++++++++++++++++++ python/pyarrow/tests/test_table.py | 57 ++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 65f1ba11dc9..a1dc80a8983 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -616,6 +616,53 @@ cdef class RecordBatch(_PandasConvertible): self.sp_batch = batch self.batch = batch.get() + @staticmethod + def from_pydict(mapping, schema=None, metadata=None): + """ + Construct a RecordBatch from Arrow arrays or columns. + + Parameters + ---------- + mapping : dict or Mapping + A mapping of strings to Arrays or Python lists. + schema : Schema, default None + If not passed, will be inferred from the Mapping values. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + RecordBatch + """ + arrays = [] + if schema is None: + names = [] + for k, v in mapping.items(): + names.append(k) + arrays.append(asarray(v)) + return RecordBatch.from_arrays(arrays, names, metadata=metadata) + elif isinstance(schema, Schema): + for field in schema: + try: + v = mapping[field.name] + except KeyError: + try: + v = mapping[tobytes(field.name)] + except KeyError: + present = mapping.keys() + missing = [n for n in schema.names if n not in present] + raise KeyError( + "The passed mapping doesn't contain the " + "following field(s) of the schema: {}". 
+ format(', '.join(missing)) + ) + arrays.append(asarray(v, type=field.type)) + # Will raise if metadata is not None + return RecordBatch.from_arrays( + arrays, schema=schema, metadata=metadata) + else: + raise TypeError('Schema must be an instance of pyarrow.Schema') + def __reduce__(self): return _reconstruct_record_batch, (self.columns, self.schema) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index 7ba844aa809..1290b99eacd 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -1685,3 +1685,60 @@ def test_table_select(): result = table.select(['f2']) expected = pa.table([a2], ['f2']) assert result.equals(expected) + + +def test_recordbatch_from_pydict(): + table = pa.RecordBatch.from_pydict({}) + assert table.num_columns == 0 + assert table.num_rows == 0 + assert table.schema == pa.schema([]) + assert table.to_pydict() == {} + + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float64())]) + + # With lists as values + data = OrderedDict([('strs', ['', 'foo', 'bar']), + ('floats', [4.5, 5, None])]) + table = pa.RecordBatch.from_pydict(data) + assert table.num_columns == 2 + assert table.num_rows == 3 + assert table.schema == schema + assert table.to_pydict() == data + + # With metadata and inferred schema + metadata = {b'foo': b'bar'} + schema = schema.with_metadata(metadata) + table = pa.RecordBatch.from_pydict(data, metadata=metadata) + assert table.schema == schema + assert table.schema.metadata == metadata + assert table.to_pydict() == data + + # With explicit schema + table = pa.RecordBatch.from_pydict(data, schema=schema) + assert table.schema == schema + assert table.schema.metadata == metadata + assert table.to_pydict() == data + + # Cannot pass both schema and metadata + with pytest.raises(ValueError): + pa.RecordBatch.from_pydict(data, schema=schema, metadata=metadata) + + # Non-convertible values given schema + with pytest.raises(TypeError): + 
pa.RecordBatch.from_pydict({'c0': [0, 1, 2]}, + schema=pa.schema([("c0", pa.string())])) + + # Missing schema fields from the passed mapping + with pytest.raises(KeyError, match="doesn\'t contain.* c, d"): + pa.RecordBatch.from_pydict( + {'a': [1, 2, 3], 'b': [3, 4, 5]}, + schema=pa.schema([ + ('a', pa.int64()), + ('c', pa.int32()), + ('d', pa.int16()) + ]) + ) + + # Passed wrong schema type + with pytest.raises(TypeError): + pa.RecordBatch.from_pydict({'a': [1, 2, 3]}, schema={}) From 163ff1159b46ebc65811b5ee2f137861f51071e2 Mon Sep 17 00:00:00 2001 From: kharoc Date: Tue, 3 Aug 2021 11:42:53 -0500 Subject: [PATCH 2/3] ARROW-13089:[Python]Allow creating RecordBatch from Python - review comments --- python/pyarrow/table.pxi | 17 +++--- python/pyarrow/tests/test_table.py | 84 ++++++------------------------ 2 files changed, 24 insertions(+), 77 deletions(-) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index a1dc80a8983..8e4b903f1dc 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -646,16 +646,13 @@ cdef class RecordBatch(_PandasConvertible): try: v = mapping[field.name] except KeyError: - try: - v = mapping[tobytes(field.name)] - except KeyError: - present = mapping.keys() - missing = [n for n in schema.names if n not in present] - raise KeyError( - "The passed mapping doesn't contain the " - "following field(s) of the schema: {}". - format(', '.join(missing)) - ) + present = mapping.keys() + missing = [n for n in schema.names if n not in present] + raise KeyError( + "The passed mapping doesn't contain the " + "following field(s) of the schema: {}". 
+ format(', '.join(missing)) + ) arrays.append(asarray(v, type=field.type)) # Will raise if metadata is not None return RecordBatch.from_arrays( diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index 1290b99eacd..72bb8ef2d99 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -1339,8 +1339,15 @@ def test_from_arrays_schema(data, klass): pa.Table.from_arrays(data, schema=schema, metadata={b'foo': b'bar'}) -def test_table_from_pydict(): - table = pa.Table.from_pydict({}) +@pytest.mark.parametrize( + ('cls'), + [ + (pa.Table), + (pa.RecordBatch) + ] +) +def test_table_from_pydict(cls): + table = cls.from_pydict({}) assert table.num_columns == 0 assert table.num_rows == 0 assert table.schema == pa.schema([]) @@ -1351,7 +1358,7 @@ def test_table_from_pydict(): # With lists as values data = OrderedDict([('strs', ['', 'foo', 'bar']), ('floats', [4.5, 5, None])]) - table = pa.Table.from_pydict(data) + table = cls.from_pydict(data) assert table.num_columns == 2 assert table.num_rows == 3 assert table.schema == schema @@ -1360,29 +1367,29 @@ def test_table_from_pydict(): # With metadata and inferred schema metadata = {b'foo': b'bar'} schema = schema.with_metadata(metadata) - table = pa.Table.from_pydict(data, metadata=metadata) + table = cls.from_pydict(data, metadata=metadata) assert table.schema == schema assert table.schema.metadata == metadata assert table.to_pydict() == data # With explicit schema - table = pa.Table.from_pydict(data, schema=schema) + table = cls.from_pydict(data, schema=schema) assert table.schema == schema assert table.schema.metadata == metadata assert table.to_pydict() == data # Cannot pass both schema and metadata with pytest.raises(ValueError): - pa.Table.from_pydict(data, schema=schema, metadata=metadata) + cls.from_pydict(data, schema=schema, metadata=metadata) # Non-convertible values given schema with pytest.raises(TypeError): - pa.Table.from_pydict({'c0': [0, 1, 2]}, - 
schema=pa.schema([("c0", pa.string())])) + cls.from_pydict({'c0': [0, 1, 2]}, + schema=pa.schema([("c0", pa.string())])) # Missing schema fields from the passed mapping with pytest.raises(KeyError, match="doesn\'t contain.* c, d"): - pa.Table.from_pydict( + cls.from_pydict( {'a': [1, 2, 3], 'b': [3, 4, 5]}, schema=pa.schema([ ('a', pa.int64()), @@ -1393,7 +1400,7 @@ def test_table_from_pydict(): # Passed wrong schema type with pytest.raises(TypeError): - pa.Table.from_pydict({'a': [1, 2, 3]}, schema={}) + cls.from_pydict({'a': [1, 2, 3]}, schema={}) @pytest.mark.parametrize('data, klass', [ @@ -1685,60 +1692,3 @@ def test_table_select(): result = table.select(['f2']) expected = pa.table([a2], ['f2']) assert result.equals(expected) - - -def test_recordbatch_from_pydict(): - table = pa.RecordBatch.from_pydict({}) - assert table.num_columns == 0 - assert table.num_rows == 0 - assert table.schema == pa.schema([]) - assert table.to_pydict() == {} - - schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float64())]) - - # With lists as values - data = OrderedDict([('strs', ['', 'foo', 'bar']), - ('floats', [4.5, 5, None])]) - table = pa.RecordBatch.from_pydict(data) - assert table.num_columns == 2 - assert table.num_rows == 3 - assert table.schema == schema - assert table.to_pydict() == data - - # With metadata and inferred schema - metadata = {b'foo': b'bar'} - schema = schema.with_metadata(metadata) - table = pa.RecordBatch.from_pydict(data, metadata=metadata) - assert table.schema == schema - assert table.schema.metadata == metadata - assert table.to_pydict() == data - - # With explicit schema - table = pa.RecordBatch.from_pydict(data, schema=schema) - assert table.schema == schema - assert table.schema.metadata == metadata - assert table.to_pydict() == data - - # Cannot pass both schema and metadata - with pytest.raises(ValueError): - pa.RecordBatch.from_pydict(data, schema=schema, metadata=metadata) - - # Non-convertible values given schema - with 
pytest.raises(TypeError): - pa.RecordBatch.from_pydict({'c0': [0, 1, 2]}, - schema=pa.schema([("c0", pa.string())])) - - # Missing schema fields from the passed mapping - with pytest.raises(KeyError, match="doesn\'t contain.* c, d"): - pa.RecordBatch.from_pydict( - {'a': [1, 2, 3], 'b': [3, 4, 5]}, - schema=pa.schema([ - ('a', pa.int64()), - ('c', pa.int32()), - ('d', pa.int16()) - ]) - ) - - # Passed wrong schema type - with pytest.raises(TypeError): - pa.RecordBatch.from_pydict({'a': [1, 2, 3]}, schema={}) From 229a31a1e600653a9933224baeb618b283afdb51 Mon Sep 17 00:00:00 2001 From: kharoc Date: Wed, 4 Aug 2021 10:54:12 -0500 Subject: [PATCH 3/3] ARROW-13089:[Python]Allow creating RecordBatch from Python - refactor _from_pydict function --- python/pyarrow/table.pxi | 110 +++++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 52 deletions(-) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 8e4b903f1dc..d92bdb2efa3 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -634,31 +634,11 @@ cdef class RecordBatch(_PandasConvertible): ------- RecordBatch """ - arrays = [] - if schema is None: - names = [] - for k, v in mapping.items(): - names.append(k) - arrays.append(asarray(v)) - return RecordBatch.from_arrays(arrays, names, metadata=metadata) - elif isinstance(schema, Schema): - for field in schema: - try: - v = mapping[field.name] - except KeyError: - present = mapping.keys() - missing = [n for n in schema.names if n not in present] - raise KeyError( - "The passed mapping doesn't contain the " - "following field(s) of the schema: {}". 
- format(', '.join(missing)) - ) - arrays.append(asarray(v, type=field.type)) - # Will raise if metadata is not None - return RecordBatch.from_arrays( - arrays, schema=schema, metadata=metadata) - else: - raise TypeError('Schema must be an instance of pyarrow.Schema') + + return _from_pydict(cls=RecordBatch, + mapping=mapping, + schema=schema, + metadata=metadata) def __reduce__(self): return _reconstruct_record_batch, (self.columns, self.schema) @@ -1675,33 +1655,11 @@ cdef class Table(_PandasConvertible): ------- Table """ - arrays = [] - if schema is None: - names = [] - for k, v in mapping.items(): - names.append(k) - arrays.append(asarray(v)) - return Table.from_arrays(arrays, names, metadata=metadata) - elif isinstance(schema, Schema): - for field in schema: - try: - v = mapping[field.name] - except KeyError: - try: - v = mapping[tobytes(field.name)] - except KeyError: - present = mapping.keys() - missing = [n for n in schema.names if n not in present] - raise KeyError( - "The passed mapping doesn't contain the " - "following field(s) of the schema: {}". - format(', '.join(missing)) - ) - arrays.append(asarray(v, type=field.type)) - # Will raise if metadata is not None - return Table.from_arrays(arrays, schema=schema, metadata=metadata) - else: - raise TypeError('Schema must be an instance of pyarrow.Schema') + + return _from_pydict(cls=Table, + mapping=mapping, + schema=schema, + metadata=metadata) @staticmethod def from_batches(batches, Schema schema=None): @@ -2316,3 +2274,51 @@ def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None): ConcatenateTables(c_tables, options, pool)) return pyarrow_wrap_table(c_result_table) + + +def _from_pydict(cls, mapping, schema, metadata): + """ + Construct a Table/RecordBatch from Arrow arrays or columns. + + Parameters + ---------- + cls : Class Table/RecordBatch + mapping : dict or Mapping + A mapping of strings to Arrays or Python lists. 
+ schema : Schema or None + If None, will be inferred from the Mapping values. + metadata : dict or Mapping or None + Optional metadata for the schema (if inferred). + + Returns + ------- + Table/RecordBatch + """ + + arrays = [] + if schema is None: + names = [] + for k, v in mapping.items(): + names.append(k) + arrays.append(asarray(v)) + return cls.from_arrays(arrays, names, metadata=metadata) + elif isinstance(schema, Schema): + for field in schema: + try: + v = mapping[field.name] + except KeyError: + try: + v = mapping[tobytes(field.name)] + except KeyError: + present = mapping.keys() + missing = [n for n in schema.names if n not in present] + raise KeyError( + "The passed mapping doesn't contain the " + "following field(s) of the schema: {}". + format(', '.join(missing)) + ) + arrays.append(asarray(v, type=field.type)) + # Will raise if metadata is not None + return cls.from_arrays(arrays, schema=schema, metadata=metadata) + else: + raise TypeError('Schema must be an instance of pyarrow.Schema')