diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 65f1ba11dc9..d92bdb2efa3 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -616,6 +616,30 @@ cdef class RecordBatch(_PandasConvertible): self.sp_batch = batch self.batch = batch.get() + @staticmethod + def from_pydict(mapping, schema=None, metadata=None): + """ + Construct a RecordBatch from Arrow arrays or columns. + + Parameters + ---------- + mapping : dict or Mapping + A mapping of strings to Arrays or Python lists. + schema : Schema, default None + If not passed, will be inferred from the Mapping values. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + RecordBatch + """ + + return _from_pydict(cls=RecordBatch, + mapping=mapping, + schema=schema, + metadata=metadata) + def __reduce__(self): return _reconstruct_record_batch, (self.columns, self.schema) @@ -1631,33 +1655,11 @@ cdef class Table(_PandasConvertible): ------- Table """ - arrays = [] - if schema is None: - names = [] - for k, v in mapping.items(): - names.append(k) - arrays.append(asarray(v)) - return Table.from_arrays(arrays, names, metadata=metadata) - elif isinstance(schema, Schema): - for field in schema: - try: - v = mapping[field.name] - except KeyError: - try: - v = mapping[tobytes(field.name)] - except KeyError: - present = mapping.keys() - missing = [n for n in schema.names if n not in present] - raise KeyError( - "The passed mapping doesn't contain the " - "following field(s) of the schema: {}". - format(', '.join(missing)) - ) - arrays.append(asarray(v, type=field.type)) - # Will raise if metadata is not None - return Table.from_arrays(arrays, schema=schema, metadata=metadata) - else: - raise TypeError('Schema must be an instance of pyarrow.Schema') + + return _from_pydict(cls=Table, + mapping=mapping, + schema=schema, + metadata=metadata) @staticmethod def from_batches(batches, Schema schema=None): @@ -2272,3 +2274,51 @@ def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None): ConcatenateTables(c_tables, options, pool)) return pyarrow_wrap_table(c_result_table) + + +def _from_pydict(cls, mapping, schema, metadata): + """ + Construct a Table/RecordBatch from Arrow arrays or columns. + + Parameters + ---------- + cls : Class Table/RecordBatch + mapping : dict or Mapping + A mapping of strings to Arrays or Python lists. + schema : Schema, default None + If not passed, will be inferred from the Mapping values. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + Table/RecordBatch + """ + + arrays = [] + if schema is None: + names = [] + for k, v in mapping.items(): + names.append(k) + arrays.append(asarray(v)) + return cls.from_arrays(arrays, names, metadata=metadata) + elif isinstance(schema, Schema): + for field in schema: + try: + v = mapping[field.name] + except KeyError: + try: + v = mapping[tobytes(field.name)] + except KeyError: + present = mapping.keys() + missing = [n for n in schema.names if n not in present] + raise KeyError( + "The passed mapping doesn't contain the " + "following field(s) of the schema: {}". + format(', '.join(missing)) + ) + arrays.append(asarray(v, type=field.type)) + # Will raise if metadata is not None + return cls.from_arrays(arrays, schema=schema, metadata=metadata) + else: + raise TypeError('Schema must be an instance of pyarrow.Schema') diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index 7ba844aa809..72bb8ef2d99 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -1339,8 +1339,15 @@ def test_from_arrays_schema(data, klass): pa.Table.from_arrays(data, schema=schema, metadata={b'foo': b'bar'}) -def test_table_from_pydict(): - table = pa.Table.from_pydict({}) +@pytest.mark.parametrize( + ('cls'), + [ + (pa.Table), + (pa.RecordBatch) + ] +) +def test_table_from_pydict(cls): + table = cls.from_pydict({}) assert table.num_columns == 0 assert table.num_rows == 0 assert table.schema == pa.schema([]) @@ -1351,7 +1358,7 @@ def test_table_from_pydict(): # With lists as values data = OrderedDict([('strs', ['', 'foo', 'bar']), ('floats', [4.5, 5, None])]) - table = pa.Table.from_pydict(data) + table = cls.from_pydict(data) assert table.num_columns == 2 assert table.num_rows == 3 assert table.schema == schema @@ -1360,29 +1367,29 @@ def test_table_from_pydict(): # With metadata and inferred schema metadata = {b'foo': b'bar'} schema = schema.with_metadata(metadata) - table = pa.Table.from_pydict(data, metadata=metadata) + table = cls.from_pydict(data, metadata=metadata) assert table.schema == schema assert table.schema.metadata == metadata assert table.to_pydict() == data # With explicit schema - table = pa.Table.from_pydict(data, schema=schema) + table = cls.from_pydict(data, schema=schema) assert table.schema == schema assert table.schema.metadata == metadata assert table.to_pydict() == data # Cannot pass both schema and metadata with pytest.raises(ValueError): - pa.Table.from_pydict(data, schema=schema, metadata=metadata) + cls.from_pydict(data, schema=schema, metadata=metadata) # Non-convertible values given schema with pytest.raises(TypeError): - pa.Table.from_pydict({'c0': [0, 1, 2]}, - schema=pa.schema([("c0", pa.string())])) + cls.from_pydict({'c0': [0, 1, 2]}, + schema=pa.schema([("c0", pa.string())])) # Missing schema fields from the passed mapping with pytest.raises(KeyError, match="doesn\'t contain.* c, d"): - pa.Table.from_pydict( + cls.from_pydict( {'a': [1, 2, 3], 'b': [3, 4, 5]}, schema=pa.schema([ ('a', pa.int64()), @@ -1393,7 +1400,7 @@ def test_table_from_pydict(): # Passed wrong schema type with pytest.raises(TypeError): - pa.Table.from_pydict({'a': [1, 2, 3]}, schema={}) + cls.from_pydict({'a': [1, 2, 3]}, schema={}) @pytest.mark.parametrize('data, klass', [