3 changes: 2 additions & 1 deletion rust/arrow/examples/read_csv.rs
@@ -35,7 +35,8 @@ fn main() -> Result<()> {
 
     let file = File::open("test/data/uk_cities.csv").unwrap();
 
-    let mut csv = csv::Reader::new(file, Arc::new(schema), false, None, 1024, None, None);
+    let mut csv =
+        csv::Reader::new(file, Arc::new(schema), false, None, false, 1024, None, None);
     let _batch = csv.next().unwrap().unwrap();
     #[cfg(feature = "prettyprint")]
     {
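The positional `Reader::new` constructor gains a fifth argument here, which is why every call site in this PR had to be touched. A builder-based equivalent sidesteps that churn; the following is a minimal sketch against the `uk_cities.csv` example, assuming the pre-existing `with_schema`, `with_batch_size`, and `build` methods on `ReaderBuilder` together with the `with_trim` method this PR adds below:

```rust
use std::fs::File;
use std::sync::Arc;

use arrow::csv;
use arrow::datatypes::{DataType, Field, Schema};

fn main() {
    // Same schema the example file declares for uk_cities.csv.
    let schema = Schema::new(vec![
        Field::new("city", DataType::Utf8, false),
        Field::new("lat", DataType::Float64, false),
        Field::new("lng", DataType::Float64, false),
    ]);

    let file = File::open("test/data/uk_cities.csv").unwrap();
    let mut reader = csv::ReaderBuilder::new()
        .with_schema(Arc::new(schema))
        .with_batch_size(1024)
        .with_trim(false) // new in this PR; false preserves the old behaviour
        .build(file)
        .unwrap();
    let _batch = reader.next().unwrap().unwrap();
}
```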
71 changes: 67 additions & 4 deletions rust/arrow/src/csv/reader.rs
@@ -36,7 +36,7 @@
 //!
 //! let file = File::open("test/data/uk_cities.csv").unwrap();
 //!
-//! let mut csv = csv::Reader::new(file, Arc::new(schema), false, None, 1024, None, None);
+//! let mut csv = csv::Reader::new(file, Arc::new(schema), false, None, false, 1024, None, None);
 //! let batch = csv.next().unwrap().unwrap();
 //! ```
 
@@ -56,7 +56,7 @@ use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 use crate::record_batch::RecordBatch;
 
-use self::csv_crate::{ByteRecord, StringRecord};
+use self::csv_crate::{ByteRecord, StringRecord, Trim};
 
 lazy_static! {
     static ref DECIMAL_RE: Regex = Regex::new(r"^-?(\d+\.\d+)$").unwrap();
@@ -102,11 +102,15 @@ fn infer_field_schema(string: &str) -> DataType {
 fn infer_file_schema<R: Read + Seek>(
     reader: &mut R,
     delimiter: u8,
+    trim: bool,
     max_read_records: Option<usize>,
     has_header: bool,
 ) -> Result<(Schema, usize)> {
+    let t = if trim { Trim::All } else { Trim::None };
+
     let mut csv_reader = csv_crate::ReaderBuilder::new()
         .delimiter(delimiter)
+        .trim(t)
         .from_reader(reader);
 
     // get or create header names
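The boolean flag is folded into the csv crate's `Trim` enum before the reader is built: `Trim::All` strips leading and trailing whitespace from both headers and fields, while `Trim::None` (the default) leaves records untouched. A standalone sketch of that underlying behaviour, using only the csv crate:

```rust
use csv::{ReaderBuilder, Trim};

fn main() {
    // " 1 " comes back as "1" once Trim::All is set on the reader.
    let data = "int\n 1 \n";
    let mut reader = ReaderBuilder::new()
        .trim(Trim::All)
        .from_reader(data.as_bytes());
    let record = reader.records().next().unwrap().unwrap();
    assert_eq!(record.get(0), Some("1"));
}
```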
@@ -199,6 +203,7 @@ fn infer_file_schema<R: Read + Seek>(
 pub fn infer_schema_from_files(
     files: &[String],
     delimiter: u8,
+    trim: bool,
     max_read_records: Option<usize>,
     has_header: bool,
 ) -> Result<Schema> {
@@ -209,6 +214,7 @@ pub fn infer_schema_from_files(
         let (schema, records_read) = infer_file_schema(
             &mut File::open(fname)?,
             delimiter,
+            trim,
             Some(records_to_read),
             has_header,
         )?;
@@ -270,12 +276,13 @@ impl<R: Read> Reader<R> {
         schema: SchemaRef,
         has_header: bool,
         delimiter: Option<u8>,
+        trim: bool,
         batch_size: usize,
         bounds: Bounds,
         projection: Option<Vec<usize>>,
     ) -> Self {
         Self::from_reader(
-            reader, schema, has_header, delimiter, batch_size, bounds, projection,
+            reader, schema, has_header, delimiter, trim, batch_size, bounds, projection,
         )
     }
 
@@ -303,6 +310,7 @@ impl<R: Read> Reader<R> {
         schema: SchemaRef,
         has_header: bool,
         delimiter: Option<u8>,
+        trim: bool,
         batch_size: usize,
         bounds: Bounds,
         projection: Option<Vec<usize>>,
@@ -314,6 +322,10 @@ impl<R: Read> Reader<R> {
             reader_builder.delimiter(c);
         }
 
+        if trim {
+            reader_builder.trim(Trim::All);
+        }
+
         let mut csv_reader = reader_builder.from_reader(reader);
 
         let (start, end) = match bounds {
@@ -635,6 +647,8 @@ pub struct ReaderBuilder {
     has_header: bool,
     /// An optional column delimiter. Defaults to `b','`
     delimiter: Option<u8>,
+    /// Whether to trim strings before parsing. Defaults to false.
+    trim: bool,
     /// Optional maximum number of records to read during schema inference
     ///
     /// If a number is not provided, all the records are read.
@@ -655,6 +669,7 @@ impl Default for ReaderBuilder {
             schema: None,
             has_header: false,
             delimiter: None,
+            trim: false,
             max_records: None,
             batch_size: 1024,
             bounds: None,
@@ -729,6 +744,12 @@ impl ReaderBuilder {
         self
     }
 
+    /// Set the reader's trim setting.
+    pub fn with_trim(mut self, trim: bool) -> Self {
+        self.trim = trim;
+        self
+    }
+
     /// Create a new `Reader` from the `ReaderBuilder`
     pub fn build<R: Read + Seek>(self, mut reader: R) -> Result<Reader<R>> {
         // check if schema should be inferred
@@ -739,6 +760,7 @@ impl ReaderBuilder {
                 let (inferred_schema, _) = infer_file_schema(
                     &mut reader,
                     delimiter,
+                    self.trim,
                     self.max_records,
                     self.has_header,
                 )?;
@@ -751,6 +773,7 @@ impl ReaderBuilder {
             schema,
             self.has_header,
             self.delimiter,
+            self.trim,
             self.batch_size,
             None,
             self.projection.clone(),
@@ -784,6 +807,7 @@ mod tests {
             Arc::new(schema.clone()),
             false,
             None,
+            false,
             1024,
             None,
             None,
@@ -830,6 +854,7 @@ mod tests {
             Arc::new(schema),
             true,
             None,
+            false,
             1024,
             None,
             None,
@@ -927,6 +952,7 @@ mod tests {
             Arc::new(schema),
             false,
             None,
+            false,
             1024,
             None,
             Some(vec![0, 1]),
@@ -952,7 +978,8 @@ mod tests {
 
         let file = File::open("test/data/null_test.csv").unwrap();
 
-        let mut csv = Reader::new(file, Arc::new(schema), true, None, 1024, None, None);
+        let mut csv =
+            Reader::new(file, Arc::new(schema), true, None, false, 1024, None, None);
         let batch = csv.next().unwrap().unwrap();
 
         assert_eq!(false, batch.column(1).is_null(0));
@@ -1119,6 +1146,7 @@ mod tests {
                 csv4.path().to_str().unwrap().to_string(),
             ],
             b',',
+            false,
             Some(3), // only csv1 and csv2 should be read
             true,
         )?;
@@ -1164,6 +1192,7 @@ mod tests {
             Arc::new(schema),
             false,
             None,
+            false,
             2,
             // starting at row 2 and up to row 6.
             Some((2, 6)),
@@ -1222,4 +1251,38 @@ mod tests {
         assert_eq!(None, parse_item::<Float64Type>("dd"));
         assert_eq!(None, parse_item::<Float64Type>("12.34.56"));
     }
+
+    #[test]
+    fn test_trim() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
+        // create data with deliberate spaces that will not parse without trim
+        let data = vec![vec!["0"], vec![" 1"], vec!["2 "], vec![" 3 "]];
+
+        let data = data
+            .iter()
+            .map(|x| x.join(","))
+            .collect::<Vec<_>>()
+            .join("\n");
+        let data = data.as_bytes();
+
+        let reader = std::io::Cursor::new(data);
+
+        let mut csv = Reader::new(
+            reader,
+            Arc::new(schema),
+            false,
+            None,
+            true,
+            1024,
+            None,
+            Some(vec![0]),
+        );
+
+        let batch = csv.next().unwrap().unwrap();
+        let a = batch.column(0);
+        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
+        assert_eq!(a, &UInt32Array::from(vec![0, 1, 2, 3]));
+
+        Ok(())
+    }
 }
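Because `build` now threads `self.trim` into `infer_file_schema`, schema inference also sees the trimmed values, so whitespace-padded numeric columns infer as numbers rather than strings. A minimal sketch combining inference with the new `with_trim` setting, assuming the existing `has_header` and `infer_schema` builder methods:

```rust
use std::io::Cursor;

use arrow::csv::ReaderBuilder;
use arrow::error::Result;

fn main() -> Result<()> {
    // Without trim, " 1" would fail integer parsing (see test_trim above).
    let data = Cursor::new("int\n 1\n2 \n 3 \n");
    let mut reader = ReaderBuilder::new()
        .has_header(true)
        .infer_schema(Some(10)) // inference runs over the trimmed values
        .with_trim(true)
        .build(data)?;
    let batch = reader.next().unwrap()?;
    assert_eq!(3, batch.num_rows());
    Ok(())
}
```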
3 changes: 3 additions & 0 deletions rust/datafusion/src/datasource/csv.rs
@@ -53,6 +53,7 @@ pub struct CsvFile {
     schema: SchemaRef,
     has_header: bool,
     delimiter: u8,
+    trim: bool,
     file_extension: String,
     statistics: Statistics,
 }
@@ -77,6 +78,7 @@ impl CsvFile {
             schema,
             has_header: options.has_header,
             delimiter: options.delimiter,
+            trim: options.trim,
             file_extension: String::from(options.file_extension),
             statistics: Statistics::default(),
         })
@@ -104,6 +106,7 @@ impl TableProvider for CsvFile {
                 .schema(&self.schema)
                 .has_header(self.has_header)
                 .delimiter(self.delimiter)
+                .trim(self.trim)
                 .file_extension(self.file_extension.as_str()),
             projection.clone(),
             batch_size,
17 changes: 17 additions & 0 deletions rust/datafusion/src/physical_plan/csv.rs
@@ -53,6 +53,9 @@ pub struct CsvReadOptions<'a> {
     /// File extension; only files with this extension are selected for data input.
     /// Defaults to ".csv".
     pub file_extension: &'a str,
+    /// Whether to trim string values before parsing into types.
+    /// Defaults to false.
+    pub trim: bool,
 }
 
 impl<'a> CsvReadOptions<'a> {
@@ -64,6 +67,7 @@ impl<'a> CsvReadOptions<'a> {
             schema_infer_max_records: 1000,
             delimiter: b',',
             file_extension: ".csv",
+            trim: false,
         }
     }
 
@@ -79,6 +83,12 @@ impl<'a> CsvReadOptions<'a> {
         self
     }
 
+    /// Configure whether to trim string values before parsing
+    pub fn trim(mut self, trim: bool) -> Self {
+        self.trim = trim;
+        self
+    }
+
     /// Specify the file extension for CSV file selection
     pub fn file_extension(mut self, file_extension: &'a str) -> Self {
         self.file_extension = file_extension;
@@ -119,6 +129,8 @@ pub struct CsvExec {
     has_header: bool,
     /// An optional column delimiter. Defaults to `b','`
     delimiter: Option<u8>,
+    /// Whether to trim string values before parsing into types
+    trim: bool,
     /// File extension
     file_extension: String,
     /// Optional projection for which columns to load
@@ -161,6 +173,7 @@ impl CsvExec {
             schema: Arc::new(schema),
             has_header: options.has_header,
             delimiter: Some(options.delimiter),
+            trim: options.trim,
             file_extension,
             projection,
             projected_schema: Arc::new(projected_schema),
@@ -176,6 +189,7 @@ impl CsvExec {
         Ok(csv::infer_schema_from_files(
             filenames,
             options.delimiter,
+            options.trim,
             Some(options.schema_infer_max_records),
             options.has_header,
         )?)
@@ -224,6 +238,7 @@ impl ExecutionPlan for CsvExec {
             self.schema.clone(),
             self.has_header,
             self.delimiter,
+            self.trim,
             &self.projection,
             self.batch_size,
         )?))
@@ -243,6 +258,7 @@ impl CsvStream {
         schema: SchemaRef,
         has_header: bool,
         delimiter: Option<u8>,
+        trim: bool,
         projection: &Option<Vec<usize>>,
         batch_size: usize,
     ) -> Result<Self> {
@@ -252,6 +268,7 @@ impl CsvStream {
             schema,
             has_header,
             delimiter,
+            trim,
             batch_size,
             None,
             projection.clone(),
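End to end, the new option flows from `CsvReadOptions` through `CsvFile` and `CsvExec` into the Arrow reader. A minimal DataFusion sketch, assuming `ExecutionContext::register_csv` accepts these options and given a hypothetical `data.csv` with whitespace-padded fields:

```rust
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::physical_plan::csv::CsvReadOptions;

fn main() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    // trim(true) lets values like " 42 " parse under a numeric schema.
    ctx.register_csv(
        "example",
        "data.csv", // hypothetical input file
        CsvReadOptions::new().has_header(true).trim(true),
    )?;
    Ok(())
}
```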