Skip to content

Commit

Permalink
Add Schema::project and RecordBatch::project functions (#1033) (#1077)
Browse files Browse the repository at this point in the history
* Allow Schema and RecordBatch to project schemas on specific columns returning a new schema with those columns only

* Addressing PR updates and adding a test for out of range projection

* switch to &[usize]

* fix: clippy and fmt

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

Co-authored-by: Stephen Carman <hntd187@users.noreply.github.com>
  • Loading branch information
alamb and hntd187 authored Dec 22, 2021
1 parent 31911a4 commit e0abdb9
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
65 changes: 65 additions & 0 deletions arrow/src/datatypes/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ impl Schema {
Self { fields, metadata }
}

/// Returns a new schema with only the specified columns in the new schema
/// This carries metadata from the parent schema over as well
pub fn project(&self, indices: &[usize]) -> Result<Schema> {
let new_fields = indices
.iter()
.map(|i| {
self.fields.get(*i).cloned().ok_or_else(|| {
ArrowError::SchemaError(format!(
"project index {} out of bounds, max field {}",
i,
self.fields().len()
))
})
})
.collect::<Result<Vec<_>>>()?;
Ok(Self::new_with_metadata(new_fields, self.metadata.clone()))
}

/// Merge schema into self if it is compatible. Struct fields will be merged recursively.
///
/// Example:
Expand Down Expand Up @@ -369,4 +387,51 @@ mod tests {

assert_eq!(schema, de_schema);
}

#[test]
fn test_projection() {
let mut metadata = HashMap::new();
metadata.insert("meta".to_string(), "data".to_string());

let schema = Schema::new_with_metadata(
vec![
Field::new("name", DataType::Utf8, false),
Field::new("address", DataType::Utf8, false),
Field::new("priority", DataType::UInt8, false),
],
metadata,
);

let projected: Schema = schema.project(&[0, 2]).unwrap();

assert_eq!(projected.fields().len(), 2);
assert_eq!(projected.fields()[0].name(), "name");
assert_eq!(projected.fields()[1].name(), "priority");
assert_eq!(projected.metadata.get("meta").unwrap(), "data")
}

#[test]
fn test_oob_projection() {
let mut metadata = HashMap::new();
metadata.insert("meta".to_string(), "data".to_string());

let schema = Schema::new_with_metadata(
vec![
Field::new("name", DataType::Utf8, false),
Field::new("address", DataType::Utf8, false),
Field::new("priority", DataType::UInt8, false),
],
metadata,
);

let projected: Result<Schema> = schema.project(&[0, 3]);

assert!(projected.is_err());
if let Err(e) = projected {
assert_eq!(
e.to_string(),
"Schema error: project index 3 out of bounds, max field 3".to_string()
)
}
}
}
38 changes: 38 additions & 0 deletions arrow/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,25 @@ impl RecordBatch {
self.schema.clone()
}

/// Projects the schema onto the specified columns
pub fn project(&self, indices: &[usize]) -> Result<RecordBatch> {
let projected_schema = self.schema.project(indices)?;
let batch_fields = indices
.iter()
.map(|f| {
self.columns.get(*f).cloned().ok_or_else(|| {
ArrowError::SchemaError(format!(
"project index {} out of bounds, max field {}",
f,
self.columns.len()
))
})
})
.collect::<Result<Vec<_>>>()?;

RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields)
}

/// Returns the number of columns in the record batch.
///
/// # Example
Expand Down Expand Up @@ -900,4 +919,23 @@ mod tests {

assert_ne!(batch1, batch2);
}

#[test]
fn project() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

let record_batch = RecordBatch::try_from_iter(vec![
("a", a.clone()),
("b", b.clone()),
("c", c.clone()),
])
.expect("valid conversion");

let expected = RecordBatch::try_from_iter(vec![("a", a), ("c", c)])
.expect("valid conversion");

assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
}
}

0 comments on commit e0abdb9

Please sign in to comment.