diff --git a/rust/arrow/src/csv/reader.rs b/rust/arrow/src/csv/reader.rs index ffeffddc78c..cf28e38406b 100644 --- a/rust/arrow/src/csv/reader.rs +++ b/rust/arrow/src/csv/reader.rs @@ -217,6 +217,21 @@ impl Reader { ) } + /// Returns the schema of the reader, useful for getting the schema without reading + /// record batches + pub fn schema(&self) -> Arc { + match &self.projection { + Some(projection) => { + let fields = self.schema.fields(); + let projected_fields: Vec = + projection.iter().map(|i| fields[*i].clone()).collect(); + + Arc::new(Schema::new(projected_fields)) + } + None => self.schema.clone(), + } + } + /// Create a new CsvReader from a `BufReader /// /// This constructor allows you more flexibility in what records are processed by the @@ -536,7 +551,8 @@ mod tests { let file = File::open("test/data/uk_cities.csv").unwrap(); - let mut csv = Reader::new(file, Arc::new(schema), false, 1024, None); + let mut csv = Reader::new(file, Arc::new(schema.clone()), false, 1024, None); + assert_eq!(Arc::new(schema), csv.schema()); let batch = csv.next().unwrap().unwrap(); assert_eq!(37, batch.num_rows()); assert_eq!(3, batch.num_columns()); @@ -594,6 +610,12 @@ mod tests { let builder = ReaderBuilder::new().has_headers(true).infer_schema(None); let mut csv = builder.build(file).unwrap(); + let expected_schema = Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ]); + assert_eq!(Arc::new(expected_schema), csv.schema()); let batch = csv.next().unwrap().unwrap(); assert_eq!(37, batch.num_rows()); assert_eq!(3, batch.num_columns()); @@ -625,14 +647,16 @@ mod tests { let builder = ReaderBuilder::new().infer_schema(None); let mut csv = builder.build(file).unwrap(); - let batch = csv.next().unwrap().unwrap(); // csv field names should be 'column_{number}' - let schema = batch.schema(); + let schema = csv.schema(); assert_eq!("column_1", schema.field(0).name()); assert_eq!("column_2", schema.field(1).name()); assert_eq!("column_3", schema.field(2).name()); + let batch = csv.next().unwrap().unwrap(); + let batch_schema = batch.schema(); + assert_eq!(&schema, batch_schema); assert_eq!(37, batch.num_rows()); assert_eq!(3, batch.num_columns()); @@ -667,7 +691,13 @@ mod tests { let file = File::open("test/data/uk_cities.csv").unwrap(); let mut csv = Reader::new(file, Arc::new(schema), false, 1024, Some(vec![0, 1])); + let projected_schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + ])); + assert_eq!(projected_schema.clone(), csv.schema()); let batch = csv.next().unwrap().unwrap(); + assert_eq!(&projected_schema, batch.schema()); assert_eq!(37, batch.num_rows()); assert_eq!(2, batch.num_columns()); } diff --git a/rust/arrow/src/json/reader.rs b/rust/arrow/src/json/reader.rs index 8bdbf89ea15..467a89a92b9 100644 --- a/rust/arrow/src/json/reader.rs +++ b/rust/arrow/src/json/reader.rs @@ -345,6 +345,29 @@ impl Reader { } } + /// Returns the schema of the reader, useful for getting the schema without reading + /// record batches + pub fn schema(&self) -> Arc { + match &self.projection { + Some(projection) => { + let fields = self.schema.fields(); + let projected_fields: Vec = fields + .iter() + .filter_map(|field| { + if projection.contains(field.name()) { + Some(field.clone()) + } else { + None + } + }) + .collect(); + + Arc::new(Schema::new(projected_fields)) + } + None => self.schema.clone(), + } + } + /// Read the next batch of records pub fn next(&mut self) -> Result> { let mut rows: Vec = Vec::with_capacity(self.batch_size); @@ -742,7 +765,9 @@ mod tests { assert_eq!(4, batch.num_columns()); assert_eq!(12, batch.num_rows()); - let schema = batch.schema(); + let schema = reader.schema(); + let batch_schema = batch.schema(); + assert_eq!(&schema, batch_schema); let a = schema.column_with_name("a").unwrap(); assert_eq!(0, a.0); @@ -798,7 +823,9 @@ mod tests { assert_eq!(4, batch.num_columns()); assert_eq!(12, batch.num_rows()); - let schema = batch.schema(); + let schema = reader.schema(); + let batch_schema = batch.schema(); + assert_eq!(&schema, batch_schema); let a = schema.column_with_name("a").unwrap(); assert_eq!(&DataType::Int64, a.1.data_type()); @@ -855,10 +882,12 @@ mod tests { let mut reader: Reader = Reader::new( BufReader::new(File::open("test/data/basic.json").unwrap()), - Arc::new(schema), + Arc::new(schema.clone()), 1024, None, ); + let reader_schema = reader.schema(); + assert_eq!(reader_schema, Arc::new(schema)); let batch = reader.next().unwrap().unwrap(); assert_eq!(4, batch.num_columns()); @@ -909,6 +938,13 @@ mod tests { 1024, Some(vec!["a".to_string(), "c".to_string()]), ); + let reader_schema = reader.schema(); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("c", DataType::Boolean, false), + ])); + assert_eq!(reader_schema.clone(), expected_schema); + let batch = reader.next().unwrap().unwrap(); assert_eq!(2, batch.num_columns()); @@ -916,6 +952,7 @@ mod tests { assert_eq!(12, batch.num_rows()); let schema = batch.schema(); + assert_eq!(&reader_schema, schema); let a = schema.column_with_name("a").unwrap(); assert_eq!(0, a.0);