diff --git a/python/pyarrow/feather.py b/python/pyarrow/feather.py index 3754aec7372..34783a71e3d 100644 --- a/python/pyarrow/feather.py +++ b/python/pyarrow/feather.py @@ -37,7 +37,7 @@ def __init__(self, source): self.source = source self.open(source) - def read(self, columns=None): + def read(self, columns=None, nthreads=1): if columns is not None: column_set = set(columns) else: @@ -53,7 +53,7 @@ def read(self, columns=None): names.append(name) table = Table.from_arrays(columns, names=names) - return table.to_pandas() + return table.to_pandas(nthreads=nthreads) class FeatherWriter(object): @@ -118,7 +118,7 @@ def write_feather(df, dest): raise -def read_feather(source, columns=None): +def read_feather(source, columns=None, nthreads=1): """ Read a pandas.DataFrame from Feather format @@ -128,10 +128,12 @@ def read_feather(source, columns=None): columns : sequence, optional Only read a specific set of columns. If not provided, all columns are read + nthreads : int, default 1 + Number of CPU threads to use when reading to pandas.DataFrame Returns ------- df : pandas.DataFrame """ reader = FeatherReader(source) - return reader.read(columns=columns) + return reader.read(columns=columns, nthreads=nthreads) diff --git a/python/pyarrow/tests/test_feather.py b/python/pyarrow/tests/test_feather.py index 69c32be5f3d..287e0da2f55 100644 --- a/python/pyarrow/tests/test_feather.py +++ b/python/pyarrow/tests/test_feather.py @@ -61,7 +61,8 @@ def _get_null_counts(self, path, columns=None): return counts def _check_pandas_roundtrip(self, df, expected=None, path=None, - columns=None, null_counts=None): + columns=None, null_counts=None, + nthreads=1): if path is None: path = random_path() @@ -70,7 +71,7 @@ def _check_pandas_roundtrip(self, df, expected=None, path=None, if not os.path.exists(path): raise Exception('file not written') - result = read_feather(path, columns) + result = read_feather(path, columns, nthreads=nthreads) if expected is None: expected = df @@ -293,6 +294,12 @@ def test_empty_strings(self): df = pd.DataFrame({'strings': [''] * 10}) self._check_pandas_roundtrip(df) + def test_multithreaded_read(self): + data = {'c{0}'.format(i): [''] * 10 + for i in range(100)} + df = pd.DataFrame(data) + self._check_pandas_roundtrip(df, nthreads=4) + def test_nan_as_null(self): # Create a nan that is not numpy.nan values = np.array(['foo', np.nan, np.nan * 2, 'bar'] * 10)