From 942d5f71abedcf42de00de504f866586f38d0d16 Mon Sep 17 00:00:00 2001
From: Nate Clark
Date: Wed, 9 Jun 2021 14:31:31 -0400
Subject: [PATCH] ARROW-12995: [C++] Add validation to CSV options

---
 cpp/src/arrow/csv/options.cc         | 43 ++++++++++++++++
 cpp/src/arrow/csv/options.h          | 14 ++++++
 cpp/src/arrow/csv/reader.cc          |  8 +++
 cpp/src/arrow/csv/writer.cc          |  2 +
 python/pyarrow/_csv.pyx              | 13 +++++
 python/pyarrow/includes/libarrow.pxd |  8 +++
 python/pyarrow/tests/test_csv.py     | 74 ++++++++++++++++++++++++++++
 7 files changed, 162 insertions(+)

diff --git a/cpp/src/arrow/csv/options.cc b/cpp/src/arrow/csv/options.cc
index a515abf2cf4..c71cfdaf295 100644
--- a/cpp/src/arrow/csv/options.cc
+++ b/cpp/src/arrow/csv/options.cc
@@ -22,6 +22,19 @@ namespace csv {
 
 ParseOptions ParseOptions::Defaults() { return ParseOptions(); }
 
+Status ParseOptions::Validate() const {
+  if (ARROW_PREDICT_FALSE(delimiter == '\n' || delimiter == '\r')) {
+    return Status::Invalid("ParseOptions: delimiter cannot be \\r or \\n");
+  }
+  if (ARROW_PREDICT_FALSE(quoting && (quote_char == '\n' || quote_char == '\r'))) {
+    return Status::Invalid("ParseOptions: quote_char cannot be \\r or \\n");
+  }
+  if (ARROW_PREDICT_FALSE(escaping && (escape_char == '\n' || escape_char == '\r'))) {
+    return Status::Invalid("ParseOptions: escape_char cannot be \\r or \\n");
+  }
+  return Status::OK();
+}
+
 ConvertOptions ConvertOptions::Defaults() {
   auto options = ConvertOptions();
   // Same default null / true / false spellings as in Pandas.
@@ -33,8 +46,38 @@ ConvertOptions ConvertOptions::Defaults() {
   return options;
 }
 
+Status ConvertOptions::Validate() const { return Status::OK(); }
+
 ReadOptions ReadOptions::Defaults() { return ReadOptions(); }
 
+Status ReadOptions::Validate() const {
+  if (ARROW_PREDICT_FALSE(block_size < 1)) {
+    // Min is 1 because some tests use really small block sizes
+    return Status::Invalid("ReadOptions: block_size must be at least 1: ", block_size);
+  }
+  if (ARROW_PREDICT_FALSE(skip_rows < 0)) {
+    return Status::Invalid("ReadOptions: skip_rows cannot be negative: ", skip_rows);
+  }
+  if (ARROW_PREDICT_FALSE(skip_rows_after_names < 0)) {
+    return Status::Invalid("ReadOptions: skip_rows_after_names cannot be negative: ",
+                           skip_rows_after_names);
+  }
+  if (ARROW_PREDICT_FALSE(autogenerate_column_names && !column_names.empty())) {
+    return Status::Invalid(
+        "ReadOptions: autogenerate_column_names cannot be true when column_names are "
+        "provided");
+  }
+  return Status::OK();
+}
+
 WriteOptions WriteOptions::Defaults() { return WriteOptions(); }
 
+Status WriteOptions::Validate() const {
+  if (ARROW_PREDICT_FALSE(batch_size < 1)) {
+    return Status::Invalid("WriteOptions: batch_size must be at least 1: ", batch_size);
+  }
+  return Status::OK();
+}
+
 }  // namespace csv
 }  // namespace arrow
diff --git a/cpp/src/arrow/csv/options.h b/cpp/src/arrow/csv/options.h
index d9c94a03f86..790c47fc3f4 100644
--- a/cpp/src/arrow/csv/options.h
+++ b/cpp/src/arrow/csv/options.h
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "arrow/csv/type_fwd.h"
+#include "arrow/status.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -59,6 +60,9 @@ struct ARROW_EXPORT ParseOptions {
 
   /// Create parsing options with default values
   static ParseOptions Defaults();
+
+  /// \brief Test that all set options are valid
+  Status Validate() const;
 };
 
 struct ARROW_EXPORT ConvertOptions {
@@ -112,6 +116,9 @@ struct ARROW_EXPORT ConvertOptions {
   /// Create conversion options with default values, including conventional
   /// values for `null_values`, `true_values` and `false_values`
   static ConvertOptions Defaults();
+
+  /// \brief Test that all set options are valid
+  Status Validate() const;
 };
 
 struct ARROW_EXPORT ReadOptions {
@@ -124,6 +131,7 @@
   ///
   /// This will determine multi-threading granularity as well as
   /// the size of individual record batches.
+  /// Minimum valid value for block size is 1
   int32_t block_size = 1 << 20;  // 1 MB
 
   /// Number of header rows to skip (not including the row of column names, if any)
@@ -143,6 +151,9 @@
 
   /// Create read options with default values
   static ReadOptions Defaults();
+
+  /// \brief Test that all set options are valid
+  Status Validate() const;
 };
 
 /// Experimental
@@ -158,6 +169,9 @@ struct ARROW_EXPORT WriteOptions {
 
   /// Create write options with default values
   static WriteOptions Defaults();
+
+  /// \brief Test that all set options are valid
+  Status Validate() const;
 };
 
 }  // namespace csv
diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc
index 068e06178c8..f221ffcadd9 100644
--- a/cpp/src/arrow/csv/reader.cc
+++ b/cpp/src/arrow/csv/reader.cc
@@ -1033,6 +1033,9 @@ Result<std::shared_ptr<TableReader>> MakeTableReader(
     MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input,
     const ReadOptions& read_options, const ParseOptions& parse_options,
     const ConvertOptions& convert_options) {
+  RETURN_NOT_OK(parse_options.Validate());
+  RETURN_NOT_OK(read_options.Validate());
+  RETURN_NOT_OK(convert_options.Validate());
   std::shared_ptr<BaseTableReader> reader;
   if (read_options.use_threads) {
     auto cpu_executor = internal::GetCpuThreadPool();
@@ -1051,6 +1054,9 @@ Future<std::shared_ptr<StreamingReader>> MakeStreamingReader(
     io::IOContext io_context, std::shared_ptr<io::InputStream> input,
     internal::Executor* cpu_executor, const ReadOptions& read_options,
     const ParseOptions& parse_options, const ConvertOptions& convert_options) {
+  RETURN_NOT_OK(parse_options.Validate());
+  RETURN_NOT_OK(read_options.Validate());
+  RETURN_NOT_OK(convert_options.Validate());
   std::shared_ptr<StreamingReaderImpl> reader;
   reader = std::make_shared<StreamingReaderImpl>(
       io_context, cpu_executor, input, read_options, parse_options, convert_options,
@@ -1182,6 +1188,8 @@ Future<int64_t> CountRowsAsync(io::IOContext io_context,
                                internal::Executor* cpu_executor,
                                const ReadOptions& read_options,
                                const ParseOptions& parse_options) {
+  RETURN_NOT_OK(parse_options.Validate());
+  RETURN_NOT_OK(read_options.Validate());
   auto counter = std::make_shared<CSVRowCounter>(
       io_context, cpu_executor, std::move(input), read_options, parse_options);
   return counter->Count();
diff --git a/cpp/src/arrow/csv/writer.cc b/cpp/src/arrow/csv/writer.cc
index ddd59b46fc1..e1c34a77ae9 100644
--- a/cpp/src/arrow/csv/writer.cc
+++ b/cpp/src/arrow/csv/writer.cc
@@ -414,6 +414,7 @@ class CSVConverter {
 
 Status WriteCSV(const Table& table, const WriteOptions& options, MemoryPool* pool,
                 arrow::io::OutputStream* output) {
+  RETURN_NOT_OK(options.Validate());
   if (pool == nullptr) {
     pool = default_memory_pool();
   }
@@ -424,6 +425,7 @@ Status WriteCSV(const Table& table, const WriteOptions& options, MemoryPool* poo
 
 Status WriteCSV(const RecordBatch& batch, const WriteOptions& options, MemoryPool* pool,
                 arrow::io::OutputStream* output) {
+  RETURN_NOT_OK(options.Validate());
   if (pool == nullptr) {
     pool = default_memory_pool();
   }
diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx
index e7dda3fb953..8ede8272c07 100644
--- a/python/pyarrow/_csv.pyx
+++ b/python/pyarrow/_csv.pyx
@@ -58,6 +58,7 @@ cdef class ReadOptions(_Weakrefable):
         How much bytes to process at a time from the input stream.
         This will determine multi-threading granularity as well as
         the size of individual record batches or table chunks.
+        Minimum valid value for block size is 1
     skip_rows: int, optional (default 0)
         The number of rows to skip before the column names (if any)
         and the CSV data.
@@ -189,6 +190,9 @@ cdef class ReadOptions(_Weakrefable):
     def skip_rows_after_names(self, value):
         deref(self.options).skip_rows_after_names = value
 
+    def validate(self):
+        check_status(deref(self.options).Validate())
+
     def equals(self, ReadOptions other):
         return (
             self.use_threads == other.use_threads and
@@ -359,6 +363,9 @@ cdef class ParseOptions(_Weakrefable):
     def ignore_empty_lines(self, value):
         deref(self.options).ignore_empty_lines = value
 
+    def validate(self):
+        check_status(deref(self.options).Validate())
+
     def equals(self, ParseOptions other):
         return (
             self.delimiter == other.delimiter and
@@ -680,6 +687,9 @@ cdef class ConvertOptions(_Weakrefable):
         out.options.reset(new CCSVConvertOptions(move(options)))
         return out
 
+    def validate(self):
+        check_status(deref(self.options).Validate())
+
     def equals(self, ConvertOptions other):
         return (
             self.check_utf8 == other.check_utf8 and
@@ -941,6 +951,9 @@ cdef class WriteOptions(_Weakrefable):
     def batch_size(self, value):
         self.options.batch_size = value
 
+    def validate(self):
+        check_status(self.options.Validate())
+
 
 cdef _get_write_options(WriteOptions write_options, CCSVWriteOptions* out):
     if write_options is None:
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 35a2034eba4..b1fb04a1f8e 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1592,6 +1592,8 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil:
         @staticmethod
         CCSVParseOptions Defaults()
 
+        CStatus Validate()
+
     cdef cppclass CCSVConvertOptions" arrow::csv::ConvertOptions":
         c_bool check_utf8
         unordered_map[c_string, shared_ptr[CDataType]] column_types
@@ -1613,6 +1615,8 @@
         @staticmethod
         CCSVConvertOptions Defaults()
 
+        CStatus Validate()
+
     cdef cppclass CCSVReadOptions" arrow::csv::ReadOptions":
         c_bool use_threads
         int32_t block_size
@@ -1627,6 +1631,8 @@
         @staticmethod
         CCSVReadOptions Defaults()
 
+        CStatus Validate()
+
     cdef cppclass CCSVWriteOptions" arrow::csv::WriteOptions":
         c_bool include_header
         int32_t batch_size
@@ -1634,6 +1640,8 @@
         @staticmethod
         CCSVWriteOptions Defaults()
 
+        CStatus Validate()
+
     cdef cppclass CCSVReader" arrow::csv::TableReader":
         @staticmethod
         CResult[shared_ptr[CCSVReader]] Make(
diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py
index 32c0353fada..48cdff75f97 100644
--- a/python/pyarrow/tests/test_csv.py
+++ b/python/pyarrow/tests/test_csv.py
@@ -132,6 +132,34 @@ def test_read_options():
     opts = cls(block_size=1234)
     assert opts.block_size == 1234
 
+    opts.validate()
+
+    match = "ReadOptions: block_size must be at least 1: 0"
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.block_size = 0
+        opts.validate()
+
+    match = "ReadOptions: skip_rows cannot be negative: -1"
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.skip_rows = -1
+        opts.validate()
+
+    match = "ReadOptions: skip_rows_after_names cannot be negative: -1"
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.skip_rows_after_names = -1
+        opts.validate()
+
+    match = "ReadOptions: autogenerate_column_names cannot be true when" \
+            " column_names are provided"
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.autogenerate_column_names = True
+        opts.column_names = ('a', 'b')
+        opts.validate()
+
 
 def test_parse_options():
     cls = ParseOptions
@@ -150,6 +178,44 @@
                                  newlines_in_values=True,
                                  ignore_empty_lines=False)
 
+    cls().validate()
+    opts = cls()
+    opts.delimiter = "\t"
+    opts.validate()
+
+    match = "ParseOptions: delimiter cannot be \\\\r or \\\\n"
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.delimiter = "\n"
+        opts.validate()
+
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.delimiter = "\r"
+        opts.validate()
+
+    match = "ParseOptions: quote_char cannot be \\\\r or \\\\n"
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.quote_char = "\n"
+        opts.validate()
+
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.quote_char = "\r"
+        opts.validate()
+
+    match = "ParseOptions: escape_char cannot be \\\\r or \\\\n"
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.escape_char = "\n"
+        opts.validate()
+
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.escape_char = "\r"
+        opts.validate()
+
 
 def test_convert_options():
     cls = ConvertOptions
@@ -238,6 +304,14 @@ def test_write_options():
     opts = cls(batch_size=9876)
     assert opts.batch_size == 9876
 
+    opts.validate()
+
+    match = "WriteOptions: batch_size must be at least 1: 0"
+    with pytest.raises(pa.ArrowInvalid, match=match):
+        opts = cls()
+        opts.batch_size = 0
+        opts.validate()
+
 
 class BaseTestCSVRead:
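
For reviewers, here is a rough sketch of how the new checks are expected to surface on the
Python side once the patch above is applied. It is illustrative only and not part of the patch;
it assumes the pyarrow.csv option classes and read_csv() touched in this diff, and the quoted
error text comes from the Status messages added in options.cc.

    import io
    import pyarrow as pa
    from pyarrow import csv

    # Explicit, early validation of an options object.
    opts = csv.ReadOptions()
    opts.block_size = 0
    try:
        opts.validate()
    except pa.ArrowInvalid as exc:
        print(exc)  # ReadOptions: block_size must be at least 1: 0

    # The C++ readers now call Validate() before any parsing starts, so an
    # invalid option fails up front instead of producing confusing output.
    bad_parse = csv.ParseOptions(delimiter="\n")
    try:
        csv.read_csv(io.BytesIO(b"a,b\n1,2\n"), parse_options=bad_parse)
    except pa.ArrowInvalid as exc:
        print(exc)  # ParseOptions: delimiter cannot be \r or \n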