Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions rust/arrow/src/csv/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,21 @@ impl<R: Read> Reader<R> {
)
}

/// Reports the schema of the batches this reader yields, without reading any
/// record batches.
///
/// When a column projection is set, the result contains only the projected
/// fields, in the order the projection lists their indices. Without a
/// projection the reader's stored schema is returned as-is.
///
/// # Panics
///
/// Panics if a projection index is out of bounds for the schema's fields.
pub fn schema(&self) -> Arc<Schema> {
    let indices = match &self.projection {
        Some(indices) => indices,
        // No projection: share the stored schema rather than rebuilding it.
        None => return Arc::clone(&self.schema),
    };

    let all_fields = self.schema.fields();
    let mut selected: Vec<Field> = Vec::with_capacity(indices.len());
    for &idx in indices.iter() {
        selected.push(all_fields[idx].clone());
    }

    Arc::new(Schema::new(selected))
}

/// Create a new CsvReader from a `BufReader<R: Read>`
///
/// This constructor allows you more flexibility in what records are processed by the
Expand Down Expand Up @@ -536,7 +551,8 @@ mod tests {

let file = File::open("test/data/uk_cities.csv").unwrap();

let mut csv = Reader::new(file, Arc::new(schema), false, 1024, None);
let mut csv = Reader::new(file, Arc::new(schema.clone()), false, 1024, None);
assert_eq!(Arc::new(schema), csv.schema());
let batch = csv.next().unwrap().unwrap();
assert_eq!(37, batch.num_rows());
assert_eq!(3, batch.num_columns());
Expand Down Expand Up @@ -594,6 +610,12 @@ mod tests {
let builder = ReaderBuilder::new().has_headers(true).infer_schema(None);

let mut csv = builder.build(file).unwrap();
let expected_schema = Schema::new(vec![
Field::new("city", DataType::Utf8, false),
Field::new("lat", DataType::Float64, false),
Field::new("lng", DataType::Float64, false),
]);
assert_eq!(Arc::new(expected_schema), csv.schema());
let batch = csv.next().unwrap().unwrap();
assert_eq!(37, batch.num_rows());
assert_eq!(3, batch.num_columns());
Expand Down Expand Up @@ -625,14 +647,16 @@ mod tests {
let builder = ReaderBuilder::new().infer_schema(None);

let mut csv = builder.build(file).unwrap();
let batch = csv.next().unwrap().unwrap();

// csv field names should be 'column_{number}'
let schema = batch.schema();
let schema = csv.schema();
assert_eq!("column_1", schema.field(0).name());
assert_eq!("column_2", schema.field(1).name());
assert_eq!("column_3", schema.field(2).name());
let batch = csv.next().unwrap().unwrap();
let batch_schema = batch.schema();

assert_eq!(&schema, batch_schema);
assert_eq!(37, batch.num_rows());
assert_eq!(3, batch.num_columns());

Expand Down Expand Up @@ -667,7 +691,13 @@ mod tests {
let file = File::open("test/data/uk_cities.csv").unwrap();

let mut csv = Reader::new(file, Arc::new(schema), false, 1024, Some(vec![0, 1]));
let projected_schema = Arc::new(Schema::new(vec![
Field::new("city", DataType::Utf8, false),
Field::new("lat", DataType::Float64, false),
]));
assert_eq!(projected_schema.clone(), csv.schema());
let batch = csv.next().unwrap().unwrap();
assert_eq!(&projected_schema, batch.schema());
assert_eq!(37, batch.num_rows());
assert_eq!(2, batch.num_columns());
}
Expand Down
43 changes: 40 additions & 3 deletions rust/arrow/src/json/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,29 @@ impl<R: Read> Reader<R> {
}
}

/// Reports the schema of the batches this reader yields, without reading any
/// record batches.
///
/// When a projection (a list of field names) is set, the result keeps only the
/// schema fields whose name appears in the projection. Fields stay in schema
/// order, not projection order. Without a projection the reader's stored
/// schema is returned as-is.
pub fn schema(&self) -> Arc<Schema> {
    let wanted = match &self.projection {
        Some(names) => names,
        // No projection: share the stored schema rather than rebuilding it.
        None => return Arc::clone(&self.schema),
    };

    let selected: Vec<Field> = self
        .schema
        .fields()
        .iter()
        .filter(|field| wanted.contains(field.name()))
        .cloned()
        .collect();

    Arc::new(Schema::new(selected))
}

/// Read the next batch of records
pub fn next(&mut self) -> Result<Option<RecordBatch>> {
let mut rows: Vec<Value> = Vec::with_capacity(self.batch_size);
Expand Down Expand Up @@ -742,7 +765,9 @@ mod tests {
assert_eq!(4, batch.num_columns());
assert_eq!(12, batch.num_rows());

let schema = batch.schema();
let schema = reader.schema();
let batch_schema = batch.schema();
assert_eq!(&schema, batch_schema);

let a = schema.column_with_name("a").unwrap();
assert_eq!(0, a.0);
Expand Down Expand Up @@ -798,7 +823,9 @@ mod tests {
assert_eq!(4, batch.num_columns());
assert_eq!(12, batch.num_rows());

let schema = batch.schema();
let schema = reader.schema();
let batch_schema = batch.schema();
assert_eq!(&schema, batch_schema);

let a = schema.column_with_name("a").unwrap();
assert_eq!(&DataType::Int64, a.1.data_type());
Expand Down Expand Up @@ -855,10 +882,12 @@ mod tests {

let mut reader: Reader<File> = Reader::new(
BufReader::new(File::open("test/data/basic.json").unwrap()),
Arc::new(schema),
Arc::new(schema.clone()),
1024,
None,
);
let reader_schema = reader.schema();
assert_eq!(reader_schema, Arc::new(schema));
let batch = reader.next().unwrap().unwrap();

assert_eq!(4, batch.num_columns());
Expand Down Expand Up @@ -909,13 +938,21 @@ mod tests {
1024,
Some(vec!["a".to_string(), "c".to_string()]),
);
let reader_schema = reader.schema();
let expected_schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("c", DataType::Boolean, false),
]));
assert_eq!(reader_schema.clone(), expected_schema);

let batch = reader.next().unwrap().unwrap();

assert_eq!(2, batch.num_columns());
assert_eq!(2, batch.schema().fields().len());
assert_eq!(12, batch.num_rows());

let schema = batch.schema();
assert_eq!(&reader_schema, schema);

let a = schema.column_with_name("a").unwrap();
assert_eq!(0, a.0);
Expand Down