Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 77 additions & 27 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,30 @@ cdef class RecordBatch(_PandasConvertible):
self.sp_batch = batch
self.batch = batch.get()

@staticmethod
def from_pydict(mapping, schema=None, metadata=None):
    """
    Construct a RecordBatch from a mapping of strings to arrays or lists.

    Thin wrapper that delegates to the shared ``_from_pydict`` helper,
    binding the target class to RecordBatch.

    Parameters
    ----------
    mapping : dict or Mapping
        A mapping of strings to Arrays or Python lists.
    schema : Schema, default None
        If not passed, will be inferred from the Mapping values.
    metadata : dict or Mapping, default None
        Optional metadata for the schema (if inferred).

    Returns
    -------
    RecordBatch
    """
    return _from_pydict(
        cls=RecordBatch, mapping=mapping, schema=schema, metadata=metadata)

def __reduce__(self):
return _reconstruct_record_batch, (self.columns, self.schema)

Expand Down Expand Up @@ -1631,33 +1655,11 @@ cdef class Table(_PandasConvertible):
-------
Table
"""
arrays = []
if schema is None:
names = []
for k, v in mapping.items():
names.append(k)
arrays.append(asarray(v))
return Table.from_arrays(arrays, names, metadata=metadata)
elif isinstance(schema, Schema):
for field in schema:
try:
v = mapping[field.name]
except KeyError:
try:
v = mapping[tobytes(field.name)]
except KeyError:
present = mapping.keys()
missing = [n for n in schema.names if n not in present]
raise KeyError(
"The passed mapping doesn't contain the "
"following field(s) of the schema: {}".
format(', '.join(missing))
)
arrays.append(asarray(v, type=field.type))
# Will raise if metadata is not None
return Table.from_arrays(arrays, schema=schema, metadata=metadata)
else:
raise TypeError('Schema must be an instance of pyarrow.Schema')

return _from_pydict(cls=Table,
mapping=mapping,
schema=schema,
metadata=metadata)

@staticmethod
def from_batches(batches, Schema schema=None):
Expand Down Expand Up @@ -2272,3 +2274,51 @@ def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None):
ConcatenateTables(c_tables, options, pool))

return pyarrow_wrap_table(c_result_table)


def _from_pydict(cls, mapping, schema, metadata):
    """
    Construct a Table/RecordBatch from Arrow arrays or columns.

    Parameters
    ----------
    cls : class
        pyarrow.Table or pyarrow.RecordBatch; must provide a
        ``from_arrays`` constructor.
    mapping : dict or Mapping
        A mapping of strings to Arrays or Python lists.
    schema : Schema, default None
        If not passed, will be inferred from the Mapping values.
    metadata : dict or Mapping, default None
        Optional metadata for the schema (if inferred).

    Returns
    -------
    Table/RecordBatch

    Raises
    ------
    KeyError
        If *schema* is passed and a schema field is absent from *mapping*.
    TypeError
        If *schema* is neither None nor a pyarrow.Schema.
    """
    if schema is None:
        # Infer names and types from the mapping, preserving insertion order.
        names = list(mapping.keys())
        arrays = [asarray(v) for v in mapping.values()]
        return cls.from_arrays(arrays, names, metadata=metadata)
    elif isinstance(schema, Schema):
        arrays = []
        for field in schema:
            # Accept both str and bytes keys for each schema field name.
            try:
                v = mapping[field.name]
            except KeyError:
                try:
                    v = mapping[tobytes(field.name)]
                except KeyError:
                    present = mapping.keys()
                    # Check both the str and bytes form of each name so a
                    # mapping keyed with bytes does not wrongly report
                    # present fields as missing.
                    missing = [n for n in schema.names
                               if n not in present and
                               tobytes(n) not in present]
                    raise KeyError(
                        "The passed mapping doesn't contain the "
                        "following field(s) of the schema: {}".
                        format(', '.join(missing))
                    )
            arrays.append(asarray(v, type=field.type))
        # from_arrays will raise if metadata is passed alongside a schema.
        return cls.from_arrays(arrays, schema=schema, metadata=metadata)
    else:
        raise TypeError('Schema must be an instance of pyarrow.Schema')
27 changes: 17 additions & 10 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,8 +1339,15 @@ def test_from_arrays_schema(data, klass):
pa.Table.from_arrays(data, schema=schema, metadata={b'foo': b'bar'})


def test_table_from_pydict():
table = pa.Table.from_pydict({})
@pytest.mark.parametrize(
('cls'),
[
(pa.Table),
(pa.RecordBatch)
]
)
def test_table_from_pydict(cls):
table = cls.from_pydict({})
assert table.num_columns == 0
assert table.num_rows == 0
assert table.schema == pa.schema([])
Expand All @@ -1351,7 +1358,7 @@ def test_table_from_pydict():
# With lists as values
data = OrderedDict([('strs', ['', 'foo', 'bar']),
('floats', [4.5, 5, None])])
table = pa.Table.from_pydict(data)
table = cls.from_pydict(data)
assert table.num_columns == 2
assert table.num_rows == 3
assert table.schema == schema
Expand All @@ -1360,29 +1367,29 @@ def test_table_from_pydict():
# With metadata and inferred schema
metadata = {b'foo': b'bar'}
schema = schema.with_metadata(metadata)
table = pa.Table.from_pydict(data, metadata=metadata)
table = cls.from_pydict(data, metadata=metadata)
assert table.schema == schema
assert table.schema.metadata == metadata
assert table.to_pydict() == data

# With explicit schema
table = pa.Table.from_pydict(data, schema=schema)
table = cls.from_pydict(data, schema=schema)
assert table.schema == schema
assert table.schema.metadata == metadata
assert table.to_pydict() == data

# Cannot pass both schema and metadata
with pytest.raises(ValueError):
pa.Table.from_pydict(data, schema=schema, metadata=metadata)
cls.from_pydict(data, schema=schema, metadata=metadata)

# Non-convertible values given schema
with pytest.raises(TypeError):
pa.Table.from_pydict({'c0': [0, 1, 2]},
schema=pa.schema([("c0", pa.string())]))
cls.from_pydict({'c0': [0, 1, 2]},
schema=pa.schema([("c0", pa.string())]))

# Missing schema fields from the passed mapping
with pytest.raises(KeyError, match="doesn\'t contain.* c, d"):
pa.Table.from_pydict(
cls.from_pydict(
{'a': [1, 2, 3], 'b': [3, 4, 5]},
schema=pa.schema([
('a', pa.int64()),
Expand All @@ -1393,7 +1400,7 @@ def test_table_from_pydict():

# Passed wrong schema type
with pytest.raises(TypeError):
pa.Table.from_pydict({'a': [1, 2, 3]}, schema={})
cls.from_pydict({'a': [1, 2, 3]}, schema={})


@pytest.mark.parametrize('data, klass', [
Expand Down