diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc
index d4e0af7808c..780f845429b 100644
--- a/cpp/src/arrow/dataset/file_csv.cc
+++ b/cpp/src/arrow/dataset/file_csv.cc
@@ -183,9 +183,16 @@ static inline Future<std::shared_ptr<csv::StreamingReader>> OpenReaderAsync(
   auto tracer = arrow::internal::tracing::GetTracer();
   auto span = tracer->StartSpan("arrow::dataset::CsvFileFormat::OpenReaderAsync");
 #endif
+  ARROW_ASSIGN_OR_RAISE(
+      auto fragment_scan_options,
+      GetFragmentScanOptions<CsvFragmentScanOptions>(
+          kCsvTypeName, scan_options.get(), format.default_fragment_scan_options));
   ARROW_ASSIGN_OR_RAISE(auto reader_options, GetReadOptions(format, scan_options));
-  ARROW_ASSIGN_OR_RAISE(auto input, source.OpenCompressed());
+  ARROW_ASSIGN_OR_RAISE(auto input, source.OpenCompressed());
+  if (fragment_scan_options->stream_transform_func) {
+    ARROW_ASSIGN_OR_RAISE(input, fragment_scan_options->stream_transform_func(input));
+  }
   const auto& path = source.path();
   ARROW_ASSIGN_OR_RAISE(
       input, io::BufferedInputStream::Create(reader_options.block_size,
                                              default_memory_pool(), std::move(input)));
@@ -289,8 +296,15 @@ Future<util::optional<int64_t>> CsvFileFormat::CountRows(
     return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
   }
   auto self = checked_pointer_cast<CsvFileFormat>(shared_from_this());
-  ARROW_ASSIGN_OR_RAISE(auto input, file->source().OpenCompressed());
+  ARROW_ASSIGN_OR_RAISE(
+      auto fragment_scan_options,
+      GetFragmentScanOptions<CsvFragmentScanOptions>(
+          kCsvTypeName, options.get(), self->default_fragment_scan_options));
   ARROW_ASSIGN_OR_RAISE(auto read_options, GetReadOptions(*self, options));
+  ARROW_ASSIGN_OR_RAISE(auto input, file->source().OpenCompressed());
+  if (fragment_scan_options->stream_transform_func) {
+    ARROW_ASSIGN_OR_RAISE(input, fragment_scan_options->stream_transform_func(input));
+  }
   return csv::CountRowsAsync(options->io_context, std::move(input),
                              ::arrow::internal::GetCpuThreadPool(), read_options,
                              self->parse_options)
diff --git a/cpp/src/arrow/dataset/file_csv.h b/cpp/src/arrow/dataset/file_csv.h
index 83dbb88b85f..84bcf94abe3 100644
--- a/cpp/src/arrow/dataset/file_csv.h
+++ b/cpp/src/arrow/dataset/file_csv.h
@@ -73,6 +73,9 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat {
 struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions {
   std::string type_name() const override { return kCsvTypeName; }
 
+  using StreamWrapFunc = std::function<Result<std::shared_ptr<io::InputStream>>(
+      std::shared_ptr<io::InputStream>)>;
+
   /// CSV conversion options
   csv::ConvertOptions convert_options = csv::ConvertOptions::Defaults();
 
@@ -80,6 +83,13 @@ struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions {
   ///
   /// Note that use_threads is always ignored.
   csv::ReadOptions read_options = csv::ReadOptions::Defaults();
+
+  /// Optional stream wrapping function
+  ///
+  /// If defined, all open dataset file fragments will be passed
+  /// through this function. One possible use case is to transparently
+  /// transcode all input files from a given character set to utf8.
+  StreamWrapFunc stream_transform_func{};
 };
 
 class ARROW_DS_EXPORT CsvFileWriteOptions : public FileWriteOptions {
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 68833a5350e..c22d992ff18 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -21,6 +21,7 @@
 
 from cython.operator cimport dereference as deref
 
+import codecs
 import collections
 import os
 import warnings
@@ -831,8 +832,14 @@ cdef class FileFormat(_Weakrefable):
 
     @property
     def default_fragment_scan_options(self):
-        return FragmentScanOptions.wrap(
+        dfso = FragmentScanOptions.wrap(
             self.wrapped.get().default_fragment_scan_options)
+        # CsvFileFormat stores a Python-specific encoding field that needs
+        # to be restored because it does not exist in the C++ struct
+        if isinstance(self, CsvFileFormat):
+            if self._read_options_py is not None:
+                dfso.read_options = self._read_options_py
+        return dfso
 
     @default_fragment_scan_options.setter
     def default_fragment_scan_options(self, FragmentScanOptions options):
@@ -1171,6 +1178,10 @@ cdef class CsvFileFormat(FileFormat):
     """
     cdef:
         CCsvFileFormat* csv_format
+        # The encoding field in ReadOptions does not exist in the C++ struct.
+        # We need to store it here and override it when reading
+        # default_fragment_scan_options.read_options
+        public ReadOptions _read_options_py
 
     # Avoid mistakingly creating attributes
     __slots__ = ()
@@ -1198,6 +1209,8 @@ cdef class CsvFileFormat(FileFormat):
             raise TypeError('`default_fragment_scan_options` must be either '
                             'a dictionary or an instance of '
                             'CsvFragmentScanOptions')
+        if read_options is not None:
+            self._read_options_py = read_options
 
     cdef void init(self, const shared_ptr[CFileFormat]& sp):
         FileFormat.init(self, sp)
@@ -1220,6 +1233,8 @@ cdef class CsvFileFormat(FileFormat):
 
     cdef _set_default_fragment_scan_options(self, FragmentScanOptions options):
         if options.type_name == 'csv':
             self.csv_format.default_fragment_scan_options = options.wrapped
+            self.default_fragment_scan_options.read_options = options.read_options
+            self._read_options_py = options.read_options
         else:
             super()._set_default_fragment_scan_options(options)
@@ -1251,6 +1266,9 @@ cdef class CsvFragmentScanOptions(FragmentScanOptions):
 
     cdef:
         CCsvFragmentScanOptions* csv_options
+        # The encoding field in ReadOptions does not exist in the C++ struct.
+        # We need to store it here and override it when reading read_options
+        ReadOptions _read_options_py
 
     # Avoid mistakingly creating attributes
     __slots__ = ()
@@ -1263,6 +1281,7 @@ cdef class CsvFragmentScanOptions(FragmentScanOptions):
             self.convert_options = convert_options
         if read_options is not None:
             self.read_options = read_options
+            self._read_options_py = read_options
 
     cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp):
         FragmentScanOptions.init(self, sp)
@@ -1278,11 +1297,18 @@ cdef class CsvFragmentScanOptions(FragmentScanOptions):
 
     @property
     def read_options(self):
-        return ReadOptions.wrap(self.csv_options.read_options)
+        read_options = ReadOptions.wrap(self.csv_options.read_options)
+        if self._read_options_py is not None:
+            read_options.encoding = self._read_options_py.encoding
+        return read_options
 
     @read_options.setter
     def read_options(self, ReadOptions read_options not None):
         self.csv_options.read_options = deref(read_options.options)
+        self._read_options_py = read_options
+        if codecs.lookup(read_options.encoding).name != 'utf-8':
+            self.csv_options.stream_transform_func = deref(
+                make_streamwrap_func(read_options.encoding, 'utf-8'))
 
     def equals(self, CsvFragmentScanOptions other):
         return (
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 781d2ce7ad6..b3fada56680 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1212,6 +1212,9 @@ cdef extern from "arrow/builder.h" namespace "arrow" nogil:
 ctypedef void CallbackTransform(object, const shared_ptr[CBuffer]& src,
                                 shared_ptr[CBuffer]* dest)
 
+ctypedef CResult[shared_ptr[CInputStream]] StreamWrapFunc(
+    shared_ptr[CInputStream])
+
 
 cdef extern from "arrow/util/cancel.h" namespace "arrow" nogil:
     cdef cppclass CStopToken "arrow::StopToken":
@@ -1379,6 +1382,11 @@ cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil:
         shared_ptr[CInputStream] wrapped, CTransformInputStreamVTable vtable,
         object method_arg)
 
+    shared_ptr[function[StreamWrapFunc]] MakeStreamTransformFunc \
+        "arrow::py::MakeStreamTransformFunc"(
+            CTransformInputStreamVTable vtable,
+            object method_arg)
+
 
 # ----------------------------------------------------------------------
 # HDFS
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index bd8fbd1b56a..ad1bbbc5442 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -277,6 +277,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
             "arrow::dataset::CsvFragmentScanOptions"(CFragmentScanOptions):
         CCSVConvertOptions convert_options
         CCSVReadOptions read_options
+        function[StreamWrapFunc] stream_transform_func
 
     cdef cppclass CPartitioning "arrow::dataset::Partitioning":
         c_string type_name() const
diff --git a/python/pyarrow/io.pxi b/python/pyarrow/io.pxi
index d1d3feb3c17..d7f5acfff76 100644
--- a/python/pyarrow/io.pxi
+++ b/python/pyarrow/io.pxi
@@ -1583,6 +1583,33 @@ class Transcoder:
         return self._encoder.encode(self._decoder.decode(buf, final), final)
 
 
+cdef shared_ptr[function[StreamWrapFunc]] make_streamwrap_func(
+        src_encoding, dest_encoding) except *:
+    """
+    Create a function that will add a transcoding transformation to a stream.
+    Data from that stream will be decoded according to ``src_encoding`` and
+    then re-encoded according to ``dest_encoding``.
+    The created function can be used to wrap streams.
+
+    Parameters
+    ----------
+    src_encoding : str
+        The codec to use when reading data.
+    dest_encoding : str
+        The codec to use for emitted data.
+    """
+    cdef:
+        shared_ptr[function[StreamWrapFunc]] empty_func
+        CTransformInputStreamVTable vtable
+
+    vtable.transform = _cb_transform
+    src_codec = codecs.lookup(src_encoding)
+    dest_codec = codecs.lookup(dest_encoding)
+    return MakeStreamTransformFunc(move(vtable),
+                                   Transcoder(src_codec.incrementaldecoder(),
+                                              dest_codec.incrementalencoder()))
+
+
 def transcoding_input_stream(stream, src_encoding, dest_encoding):
     """
     Add a transcoding transformation to the stream.
@@ -1594,7 +1621,7 @@ def transcoding_input_stream(stream, src_encoding, dest_encoding):
     stream : NativeFile
         The stream to which the transformation should be applied.
     src_encoding : str
-        The codec to use when reading data data.
+        The codec to use when reading data.
     dest_encoding : str
         The codec to use for emitted data.
     """
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index 953b0e7b518..67db3d2ffb8 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -536,6 +536,9 @@ cdef shared_ptr[CInputStream] native_transcoding_input_stream(
     shared_ptr[CInputStream] stream, src_encoding,
     dest_encoding) except *
 
+cdef shared_ptr[function[StreamWrapFunc]] make_streamwrap_func(
+    src_encoding, dest_encoding) except *
+
 
 # Default is allow_none=False
 cpdef DataType ensure_type(object type, bint allow_none=*)
diff --git a/python/pyarrow/src/io.cc b/python/pyarrow/src/io.cc
index 173d84ff567..0aa2c85939f 100644
--- a/python/pyarrow/src/io.cc
+++ b/python/pyarrow/src/io.cc
@@ -370,5 +370,15 @@ std::shared_ptr<::arrow::io::InputStream> MakeTransformInputStream(
   return std::make_shared<TransformInputStream>(std::move(wrapped), std::move(transform));
 }
 
+std::shared_ptr<StreamWrapFunc> MakeStreamTransformFunc(TransformInputStreamVTable vtable,
+                                                        PyObject* handler) {
+  TransformInputStream::TransformFunc transform(
+      TransformFunctionWrapper{std::move(vtable.transform), handler});
+  StreamWrapFunc func = [transform](std::shared_ptr<::arrow::io::InputStream> wrapped) {
+    return std::make_shared<TransformInputStream>(wrapped, transform);
+  };
+  return std::make_shared<StreamWrapFunc>(func);
+}
+
 }  // namespace py
 }  // namespace arrow
diff --git a/python/pyarrow/src/io.h b/python/pyarrow/src/io.h
index 53b15434ea6..9d79d566efe 100644
--- a/python/pyarrow/src/io.h
+++ b/python/pyarrow/src/io.h
@@ -112,5 +112,10 @@ std::shared_ptr<::arrow::io::InputStream> MakeTransformInputStream(
     std::shared_ptr<::arrow::io::InputStream> wrapped, TransformInputStreamVTable vtable,
     PyObject* arg);
 
+using StreamWrapFunc = std::function<Result<std::shared_ptr<::arrow::io::InputStream>>(
+    std::shared_ptr<::arrow::io::InputStream>)>;
+ARROW_PYTHON_EXPORT
+std::shared_ptr<StreamWrapFunc> MakeStreamTransformFunc(TransformInputStreamVTable vtable,
+                                                        PyObject* handler);
+
 }  // namespace py
 }  // namespace arrow
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 3dc9c3beb6e..e6aa789e792 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -3130,6 +3130,55 @@ def test_csv_fragment_options(tempdir, dataset_reader):
         pa.table({'col0': pa.array(['foo', 'spam', 'MYNULL'])}))
 
 
+def test_encoding(tempdir, dataset_reader):
+    path = str(tempdir / 'test.csv')
+
+    for encoding, input_rows in [
+        ('latin-1', b"a,b\nun,\xe9l\xe9phant"),
+        ('utf16', b'\xff\xfea\x00,\x00b\x00\n\x00u\x00n\x00,'
+         b'\x00\xe9\x00l\x00\xe9\x00p\x00h\x00a\x00n\x00t\x00'),
+    ]:
+
+        with open(path, 'wb') as sink:
+            sink.write(input_rows)
+
+        # Interpret as utf8:
+        expected_schema = pa.schema([("a", pa.string()), ("b", pa.string())])
+        expected_table = pa.table({'a': ["un"],
+                                   'b': ["éléphant"]}, schema=expected_schema)
+
+        read_options = pa.csv.ReadOptions(encoding=encoding)
+        file_format = ds.CsvFileFormat(read_options=read_options)
+        dataset_transcoded = ds.dataset(path, format=file_format)
+        assert dataset_transcoded.schema.equals(expected_schema)
+        assert dataset_transcoded.to_table().equals(expected_table)
+
+
+# Test if a dataset with non-utf8 chars in the column names is properly handled
+def test_column_names_encoding(tempdir, dataset_reader):
+    path = str(tempdir / 'test.csv')
+
+    with open(path, 'wb') as sink:
+        sink.write(b"\xe9,b\nun,\xe9l\xe9phant")
+
+    # Interpret as utf8:
+    expected_schema = pa.schema([("é", pa.string()), ("b", pa.string())])
+    expected_table = pa.table({'é': ["un"],
+                               'b': ["éléphant"]}, schema=expected_schema)
+
+    # Reading as string without specifying encoding should produce an error
+    dataset = ds.dataset(path, format='csv', schema=expected_schema)
+    with pytest.raises(pyarrow.lib.ArrowInvalid, match="invalid UTF8"):
+        dataset_reader.to_table(dataset)
+
+    # Setting the encoding in the read_options should transcode the data
+    read_options = pa.csv.ReadOptions(encoding='latin-1')
+    file_format = ds.CsvFileFormat(read_options=read_options)
+    dataset_transcoded = ds.dataset(path, format=file_format)
+    assert dataset_transcoded.schema.equals(expected_schema)
+    assert dataset_transcoded.to_table().equals(expected_table)
+
+
 def test_feather_format(tempdir, dataset_reader):
     from pyarrow.feather import write_feather
 