diff --git a/cpp/src/arrow/table.cc b/cpp/src/arrow/table.cc index 30982ec2226a..55c38b8e7d47 100644 --- a/cpp/src/arrow/table.cc +++ b/cpp/src/arrow/table.cc @@ -362,6 +362,26 @@ Result> Table::RenameColumns( return Table::Make(::arrow::schema(std::move(fields)), std::move(columns), num_rows()); } +Result> Table::SelectColumns( + const std::vector& indices) const { + int n = static_cast(indices.size()); + + std::vector> columns(n); + std::vector> fields(n); + for (int i = 0; i < n; i++) { + int pos = indices[i]; + if (pos < 0 || pos > num_columns() - 1) { + return Status::Invalid("Invalid column index ", pos, " to select columns."); + } + columns[i] = column(pos); + fields[i] = field(pos); + } + + auto new_schema = + std::make_shared(std::move(fields), schema()->metadata()); + return Table::Make(new_schema, std::move(columns), num_rows()); +} + std::string Table::ToString() const { std::stringstream ss; ARROW_CHECK_OK(PrettyPrint(*this, 0, &ss)); diff --git a/cpp/src/arrow/table.h b/cpp/src/arrow/table.h index 20f5c684a717..c547019c9899 100644 --- a/cpp/src/arrow/table.h +++ b/cpp/src/arrow/table.h @@ -148,6 +148,9 @@ class ARROW_EXPORT Table { Result> RenameColumns( const std::vector& names) const; + /// \brief Return new table with specified columns + Result> SelectColumns(const std::vector& indices) const; + /// \brief Replace schema key-value metadata with new metadata (EXPERIMENTAL) /// \since 0.5.0 /// diff --git a/cpp/src/arrow/table_test.cc b/cpp/src/arrow/table_test.cc index 38bc20597556..f4deac4934ee 100644 --- a/cpp/src/arrow/table_test.cc +++ b/cpp/src/arrow/table_test.cc @@ -559,6 +559,22 @@ TEST_F(TestTable, RenameColumns) { ASSERT_RAISES(Invalid, table->RenameColumns({"hello", "world"})); } +TEST_F(TestTable, SelectColumns) { + MakeExample1(10); + auto table = Table::Make(schema_, columns_); + + ASSERT_OK_AND_ASSIGN(auto subset, table->SelectColumns({0, 2})); + ASSERT_OK(subset->ValidateFull()); + + auto expexted_schema = ::arrow::schema({schema_->field(0), schema_->field(2)}); + auto expected = Table::Make(expexted_schema, {table->column(0), table->column(2)}); + ASSERT_TRUE(subset->Equals(*expected)); + + // Out of bounds indices + ASSERT_RAISES(Invalid, table->SelectColumns({0, 3})); + ASSERT_RAISES(Invalid, table->SelectColumns({-1})); +} + TEST_F(TestTable, RemoveColumnEmpty) { // ARROW-1865 const int64_t length = 10; diff --git a/python/pyarrow/feather.py b/python/pyarrow/feather.py index 7b813af38689..8830838be371 100644 --- a/python/pyarrow/feather.py +++ b/python/pyarrow/feather.py @@ -258,7 +258,4 @@ def read_table(source, columns=None, memory_map=True): return table else: # follow exact order / selection of names - new_fields = [table.schema.field(c) for c in columns] - new_schema = schema(new_fields, metadata=table.schema.metadata) - new_columns = [table.column(c) for c in columns] - return Table.from_arrays(new_columns, schema=new_schema) + return table.select(columns) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 41b1abe1fb90..d704af400234 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -766,6 +766,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: vector[c_string] ColumnNames() CResult[shared_ptr[CTable]] RenameColumns(const vector[c_string]&) + CResult[shared_ptr[CTable]] SelectColumns(const vector[int]&) CResult[shared_ptr[CTable]] Flatten(CMemoryPool* pool) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 688d668cd75b..37064a51dc55 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -1144,6 +1144,37 @@ cdef class Table(_PandasConvertible): """ return _pc().take(self, indices) + def select(self, object columns): + """ + Select columns of the Table. + + Returns a new Table with the specified columns, and metadata + preserved. + + Parameters + ---------- + columns : list-like + The column names or integer indices to select. + + Returns + ------- + Table + + """ + cdef: + shared_ptr[CTable] c_table + vector[int] c_indices + + for idx in columns: + idx = self._ensure_integer_index(idx) + idx = _normalize_index(idx, self.num_columns) + c_indices.push_back( idx) + + with nogil: + c_table = GetResultValue(self.table.SelectColumns(move(c_indices))) + + return pyarrow_wrap_table(c_table) + def replace_schema_metadata(self, metadata=None): """ EXPERIMENTAL: Create shallow copy of table by replacing schema @@ -1583,18 +1614,9 @@ cdef class Table(_PandasConvertible): """ return self.schema.field(i) - def column(self, i): + def _ensure_integer_index(self, i): """ - Select a column by its column name, or numeric index. - - Parameters - ---------- - i : int or string - The index or name of the column to retrieve. - - Returns - ------- - pyarrow.ChunkedArray + Ensure integer index (convert string column name to integer if needed). """ if isinstance(i, (bytes, str)): field_indices = self.schema.get_all_field_indices(i) @@ -1606,12 +1628,27 @@ cdef class Table(_PandasConvertible): raise KeyError("Field \"{}\" exists {} times in table schema" .format(i, len(field_indices))) else: - return self._column(field_indices[0]) + return field_indices[0] elif isinstance(i, int): - return self._column(i) + return i else: raise TypeError("Index must either be string or integer") + def column(self, i): + """ + Select a column by its column name, or numeric index. + + Parameters + ---------- + i : int or string + The index or name of the column to retrieve. + + Returns + ------- + pyarrow.ChunkedArray + """ + return self._column(self._ensure_integer_index(i)) + def _column(self, int i): """ Select a column by its numeric index. diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 2e8174baee88..75b6d5dbf22f 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -758,8 +758,8 @@ def assert_yields_projected(fragment, row_slice, column_names = columns if columns else table.column_names assert actual.column_names == column_names - expected = table.slice(*row_slice).to_pandas()[[*column_names]] - assert actual.equals(pa.Table.from_pandas(expected)) + expected = table.slice(*row_slice).select(column_names) + assert actual.equals(expected) fragment = list(dataset.get_fragments())[0] parquet_format = fragment.format diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index 387bbe35ce15..58747b031cc2 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -1413,3 +1413,54 @@ def test_table_take_non_consecutive(): ['f1', 'f2']) assert table.take(pa.array([1, 3])).equals(result_non_consecutive) + + +def test_table_select(): + a1 = pa.array([1, 2, 3, None, 5]) + a2 = pa.array(['a', 'b', 'c', 'd', 'e']) + a3 = pa.array([[1, 2], [3, 4], [5, 6], None, [9, 10]]) + table = pa.table([a1, a2, a3], ['f1', 'f2', 'f3']) + + # selecting with string names + result = table.select(['f1']) + expected = pa.table([a1], ['f1']) + assert result.equals(expected) + + result = table.select(['f3', 'f2']) + expected = pa.table([a3, a2], ['f3', 'f2']) + assert result.equals(expected) + + # selecting with integer indices + result = table.select([0]) + expected = pa.table([a1], ['f1']) + assert result.equals(expected) + + result = table.select([2, 1]) + expected = pa.table([a3, a2], ['f3', 'f2']) + assert result.equals(expected) + + # preserve metadata + table2 = table.replace_schema_metadata({"a": "test"}) + result = table2.select(["f1", "f2"]) + assert b"a" in result.schema.metadata + + # selecting non-existing column raises + with pytest.raises(KeyError, match='Field "f5" does not exist'): + table.select(['f5']) + + with pytest.raises(IndexError, match="index out of bounds"): + table.select([5]) + + # duplicate selection gives duplicated names in resulting table + result = table.select(['f2', 'f2']) + expected = pa.table([a2, a2], ['f2', 'f2']) + assert result.equals(expected) + + # selection duplicated column raises + table = pa.table([a1, a2, a3], ['f1', 'f2', 'f1']) + with pytest.raises(KeyError, match='Field "f1" exists 2 times'): + table.select(['f1']) + + result = table.select(['f2']) + expected = pa.table([a2], ['f2']) + assert result.equals(expected)