diff --git a/rust/arrow/src/csv/mod.rs b/rust/arrow/src/csv/mod.rs index ffe82f33580..d35ec69d6e4 100644 --- a/rust/arrow/src/csv/mod.rs +++ b/rust/arrow/src/csv/mod.rs @@ -20,8 +20,8 @@ pub mod reader; pub mod writer; -pub use self::reader::infer_schema_from_files; pub use self::reader::Reader; pub use self::reader::ReaderBuilder; +pub use self::reader::{build_array, build_batch, infer_schema_from_files}; pub use self::writer::Writer; pub use self::writer::WriterBuilder; diff --git a/rust/arrow/src/csv/reader.rs b/rust/arrow/src/csv/reader.rs index c9f97cdc6d0..a5d84d6005d 100644 --- a/rust/arrow/src/csv/reader.rs +++ b/rust/arrow/src/csv/reader.rs @@ -380,8 +380,7 @@ impl Iterator for Reader { return None; } - // parse the batches into a RecordBatch - let result = parse( + let result = build_batch( &self.batch_records[..read_records], &self.schema.fields(), &self.projection, @@ -394,8 +393,92 @@ impl Iterator for Reader { } } -/// parses a slice of [csv_crate::StringRecord] into a [array::record_batch::RecordBatch]. -fn parse( +/// Tries to create an [array::Array] from a slice of [csv_crate::StringRecord] by interpreting its +/// values at column `column_index` to be of `data_type`. +/// `line_number` is the line at which the set of rows starts, and is only used to report the line number in case of errors. +/// # Errors +/// This function errors iff: +/// * _any_ entry from `rows` at `column_index` cannot be parsed into the DataType. +/// * The [array::datatypes::DataType] is not supported. 
+pub fn build_array( + rows: &[StringRecord], + data_type: &DataType, + line_number: usize, + column_index: usize, +) -> Result { + match data_type { + DataType::Boolean => build_boolean_array(line_number, rows, column_index), + DataType::Int8 => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::Int16 => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::Int32 => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::Int64 => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::UInt8 => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::UInt16 => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::UInt32 => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::UInt64 => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::Float32 => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::Float64 => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::Date32(_) => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::Date64(_) => { + build_primitive_array::(line_number, rows, column_index) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + build_primitive_array::( + line_number, + rows, + column_index, + ) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => build_primitive_array::< + TimestampNanosecondType, + >( + line_number, rows, column_index + ), + DataType::Utf8 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(column_index)) + .collect::(), + ) as ArrayRef), + other => Err(ArrowError::ParseError(format!( + "Unsupported data type {:?}", + other + ))), + } +} + +/// Tries to create an [array::record_batch::RecordBatch] from a slice of [csv_crate::StringRecord] by interpreting +/// each of its columns according to `fields`. 
When `projection` is not `None`, it is used to select a subset of `fields` to +/// parse. +/// `line_number` is the line at which the set of rows starts, and is only used to report the line number in case of errors. +/// # Errors +/// This function errors iff: +/// * _any_ entry from `rows` cannot be parsed into its corresponding field's `DataType`. +/// * Any of the fields' [array::datatypes::DataType] is not supported. +/// # Panics +/// This function panics if any index in `projection` is greater than or equal to `fields.len()`. +pub fn build_batch( rows: &[StringRecord], fields: &[Field], projection: &Option>, @@ -403,79 +486,23 @@ fn parse( ) -> Result { let projection: Vec = match projection { Some(ref v) => v.clone(), - None => fields.iter().enumerate().map(|(i, _)| i).collect(), + None => (0..fields.len()).collect(), }; - let arrays: Result> = projection + let columns = projection .iter() .map(|i| { let i = *i; - let field = &fields[i]; - match field.data_type() { - &DataType::Boolean => build_boolean_array(line_number, rows, i), - &DataType::Int8 => { - build_primitive_array::(line_number, rows, i) - } - &DataType::Int16 => { - build_primitive_array::(line_number, rows, i) - } - &DataType::Int32 => { - build_primitive_array::(line_number, rows, i) - } - &DataType::Int64 => { - build_primitive_array::(line_number, rows, i) - } - &DataType::UInt8 => { - build_primitive_array::(line_number, rows, i) - } - &DataType::UInt16 => { - build_primitive_array::(line_number, rows, i) - } - &DataType::UInt32 => { - build_primitive_array::(line_number, rows, i) - } - &DataType::UInt64 => { - build_primitive_array::(line_number, rows, i) - } - &DataType::Float32 => { - build_primitive_array::(line_number, rows, i) - } - &DataType::Float64 => { - build_primitive_array::(line_number, rows, i) - } - &DataType::Date32(_) => { - build_primitive_array::(line_number, rows, i) - } - &DataType::Date64(_) => { - build_primitive_array::(line_number, rows, i) - } - &DataType::Timestamp(TimeUnit::Microsecond, 
_) => { - build_primitive_array::( - line_number, - rows, - i, - ) - } - &DataType::Timestamp(TimeUnit::Nanosecond, _) => { - build_primitive_array::(line_number, rows, i) - } - &DataType::Utf8 => Ok(Arc::new( - rows.iter().map(|row| row.get(i)).collect::(), - ) as ArrayRef), - other => Err(ArrowError::ParseError(format!( - "Unsupported data type {:?}", - other - ))), - } + build_array(rows, fields[i].data_type(), line_number, i) }) - .collect(); + .collect::>()?; let projected_fields: Vec = projection.iter().map(|i| fields[*i].clone()).collect(); let projected_schema = Arc::new(Schema::new(projected_fields)); - arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr)) + RecordBatch::try_new(projected_schema, columns) } /// Specialized parsing implementations