diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index 86becd7c..113bc006 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -112,7 +112,7 @@ impl PUPRelation { pub fn differentially_private_aggregates( self, named_aggregates: Vec<(&str, AggregateColumn)>, - group_by: &[Expr], + group_by: &[Column], epsilon: f64, delta: f64, ) -> Result { @@ -135,11 +135,9 @@ impl PUPRelation { (input_builder, group_by_names) = group_by .into_iter() - .fold((input_builder, group_by_names), |(mut b, mut v), x| { - if let Expr::Column(c) = x { - b = b.with((c.last().unwrap(), x.clone())); - v.push(c.last().unwrap()); - } + .fold((input_builder, group_by_names), |(mut b, mut v), c| { + b = b.with((c.last().unwrap(), Expr::Column(c.clone()))); + v.push(c.last().unwrap()); (b, v) }); @@ -275,7 +273,7 @@ impl Reduce { } first_aggs.extend( - self.group_by_columns() + self.group_by() .into_iter() .map(|x| (x.to_string(), AggregateColumn::new(aggregate::Aggregate::First, x.clone()))) .collect::>() @@ -308,7 +306,7 @@ impl Reduce { let builder = Relation::reduce() .input(self.input().clone()); if let Some(identifier) = identifier { - let mut group_by = self.group_by_columns() + let mut group_by = self.group_by() .into_iter() .map(|c| c.clone()) .collect::>(); @@ -475,7 +473,7 @@ mod tests { let reduce = Reduce::new( "my_reduce".to_string(), vec![("my_sum_price".to_string(), AggregateColumn::sum("price"))], - vec![Expr::col("item")], + vec!["item".into()], pup_table.deref().clone().into(), ); let relation = Relation::from(reduce.clone()); @@ -627,7 +625,7 @@ mod tests { ("my_item".to_string(), AggregateColumn::first("item")), ("avg_price".to_string(), AggregateColumn::mean("price")), ], - vec![Expr::col("item")], + vec!["item".into()], pup_table.deref().clone().into(), ); let relation = Relation::from(reduce.clone()); diff --git a/src/differential_privacy/group_by.rs b/src/differential_privacy/group_by.rs index 01580047..c88c3189 100644 --- a/src/differential_privacy/group_by.rs +++ b/src/differential_privacy/group_by.rs @@ -22,7 +22,7 @@ impl Reduce { .with_iter( self.group_by() .into_iter() - .map(|x| (x.to_string(), x.clone())) + .map(|col| (col.to_string(), Expr::Column(col.clone()))) .collect::>(), ) .with(( diff --git a/src/expr/split.rs b/src/expr/split.rs index 66bcea60..69e3d8ff 100644 --- a/src/expr/split.rs +++ b/src/expr/split.rs @@ -1,7 +1,7 @@ //! The splits with some improvements //! Each split has named Expr and anonymous exprs use super::{ - aggregate, function, visitor::Acceptor, Aggregate, AggregateColumn, Column, Expr, Function, + aggregate, function, visitor::Acceptor, AggregateColumn, Column, Expr, Function, Identifier, Value, Visitor, }; use crate::{ @@ -36,7 +36,18 @@ impl Split { } pub fn group_by(expr: Expr) -> Reduce { - Reduce::new(vec![], vec![expr], None) + println!("x = {:?}", expr); + match expr { + Expr::Column(c) => Reduce::new(vec![], vec![c], None), + Expr::Value(_) => todo!(), + Expr::Function(_) => { + let name = namer::name_from_content(FIELD, &expr); + let map = Map::new(vec![(name.clone(), expr)], None, vec![], None); + Reduce::new(vec![], vec![name.into()], Some(map)) + }, + Expr::Aggregate(_) => todo!(), + Expr::Struct(_) => todo!(), + } } pub fn into_map(self) -> Map { @@ -416,17 +427,47 @@ impl And for Map { } } +impl And for Map { + type Product = (Map, Column); + + fn and(self, col: Column) -> Self::Product { + let Map { + named_exprs, + filter, + order_by, + reduce, + } = self; + // Add the expr to the next split if needed + let (reduce, col) = if let Some(r) = reduce { + let (r, expr) = r.and(col); + (Some(r), expr) + } else { + (None, col) + }; + // Add matched sub-expressions + ( + Map::new( + named_exprs.into_iter().chain(vec![(col, Expr::Column]).collect(), + filter, + order_by, + reduce, + ), + expr, + ) + } +} + #[derive(Clone, Default, Debug, Hash, PartialEq, Eq)] pub struct Reduce { pub named_aggregates: Vec<(String, AggregateColumn)>, - pub group_by: Vec, + pub group_by: Vec, pub map: Option>, } impl Reduce { pub fn new( named_aggregates: Vec<(String, AggregateColumn)>, - group_by: Vec, + group_by: Vec, map: Option, ) -> Self { Reduce { @@ -440,7 +481,7 @@ impl Reduce { &self.named_aggregates } - pub fn group_by(&self) -> &[Expr] { + pub fn group_by(&self) -> &[Column] { &self.group_by } @@ -581,9 +622,9 @@ impl And for Reduce { let (map, group_by) = self.group_by .into_iter() - .fold((map, vec![]), |(map, mut group_by), expr| { - let (map, expr) = map.and(expr); - group_by.push(expr); + .fold((map, vec![]), |(map, mut group_by), col| { + let (map, col) = map.and(Expr::Column(col)); + group_by.push(col); (map, group_by) }); Reduce::new( @@ -608,9 +649,9 @@ impl And for Reduce { other .group_by .into_iter() - .fold((map, vec![]), |(map, mut group_by), expr| { - let (map, expr) = map.and(expr); - group_by.push(expr); + .fold((map, vec![]), |(map, mut group_by), col| { + let (map, col) = map.and(Expr::Column(col)); + group_by.push(col); (map, group_by) }); Reduce::new( @@ -663,7 +704,7 @@ impl And for Reduce { group_by .clone() .into_iter() - .map(|e| (namer::name_from_content(FIELD, &e), e)), + .map(|col| (namer::name_from_content(FIELD, &col), Expr::Column(col))), ) .unique() .collect(); @@ -822,7 +863,15 @@ mod tests { None, ); println!("reduce = {reduce}"); - let reduce = reduce.and(Reduce::new(vec![], vec![Expr::col("z")], None)); + + let reduce = reduce.and(Reduce::new(vec![], vec!["z".into()], None)); + println!("reduce and group by = {}", reduce); + assert_eq!(reduce.len(), 1); + let map = reduce.clone().into_map(); + println!("reduce into map = {}", map); + assert_eq!(map.len(), 2); + + let reduce = reduce.and(Reduce::new(vec![], vec![expr!(3 * v)], None)); println!("reduce and group by = {}", reduce); assert_eq!(reduce.len(), 1); let map = reduce.into_map(); diff --git a/src/relation/builder.rs b/src/relation/builder.rs index b5a2b2f4..07417a49 100644 --- a/src/relation/builder.rs +++ b/src/relation/builder.rs @@ -445,7 +445,7 @@ impl ReduceBuilder { self } - pub fn group_by_iter>(self, iter: I) -> Self { + pub fn group_by_iter, I: IntoIterator>(self, iter: I) -> Self { iter.into_iter().fold(self, |w, i| w.group_by(i)) } diff --git a/src/relation/mod.rs b/src/relation/mod.rs index e0f7bdc8..d8caed90 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -444,7 +444,7 @@ pub struct Reduce { /// Aggregate expressions aggregate: Vec, /// Grouping expressions - group_by: Vec, + group_by: Vec, /// The schema description of the output schema: Schema, /// The size of the Reduce @@ -461,7 +461,7 @@ impl Reduce { pub fn new( name: String, named_aggregate: Vec<(String, AggregateColumn)>, - group_by: Vec, + group_by: Vec, input: Arc, ) -> Self { // assert!(Split::from_iter(named_exprs.clone()).len()==1); @@ -525,22 +525,22 @@ impl Reduce { &self.aggregate } /// Get group_by - pub fn group_by(&self) -> &[Expr] { + pub fn group_by(&self) -> &[Column] { &self.group_by } - /// Get group_by columns - pub fn group_by_columns(&self) -> Vec<&Column> { - self.group_by - .iter() - .filter_map(|e| { - if let Expr::Column(column) = e { - Some(column) - } else { - None - } - }) - .collect() - } + // /// Get group_by columns + // pub fn group_by_columns(&self) -> Vec<&Column> { + // self.group_by + // .iter() + // .filter_map(|e| { + // if let Expr::Column(column) = e { + // Some(column) + // } else { + // None + // } + // }) + // .collect() + // } /// Get the input pub fn input(&self) -> &Relation { &self.input @@ -565,10 +565,7 @@ impl Reduce { pub fn group_by_names(&self) -> Vec<&str> { self.group_by .iter() - .filter_map(|e| match e { - Expr::Column(col) => col.last().ok(), - _ => None, - }) + .filter_map(|col| col.last().ok()) .collect() } } diff --git a/src/relation/rewriting.rs b/src/relation/rewriting.rs index c4a8c843..8e3c4b49 100644 --- a/src/relation/rewriting.rs +++ b/src/relation/rewriting.rs @@ -1847,7 +1847,7 @@ mod tests { ("sum_a".to_string(), AggregateColumn::sum("a")), ("b".to_string(), AggregateColumn::first("b")), ], - vec![Expr::col("b")], + vec!["b".into()], Arc::new(table.clone()), ); let red_with_grouping_columns = red.clone().with_grouping_columns(); @@ -1866,7 +1866,7 @@ mod tests { let red = Reduce::new( "reduce_relation".to_string(), vec![("sum_a".to_string(), AggregateColumn::sum("a"))], - vec![Expr::col("b")], + vec!["b".into()], Arc::new(table.clone()), ); let red_with_grouping_columns = red.clone().with_grouping_columns(); @@ -1888,7 +1888,7 @@ mod tests { ("b".to_string(), AggregateColumn::first("b")), ("sum_a".to_string(), AggregateColumn::sum("a")), ], - vec![Expr::col("b"), Expr::col("c")], + vec!["b".into(), "c".into()], Arc::new(table.clone()), ); let red_with_grouping_columns = red.clone().with_grouping_columns(); @@ -1910,7 +1910,7 @@ mod tests { ("c".to_string(), AggregateColumn::first("c")), ("sum_a".to_string(), AggregateColumn::sum("a")), ], - vec![Expr::col("b"), Expr::col("c")], + vec!["b".into(), "c".into()], Arc::new(table.clone()), ); let red_with_grouping_columns = red.clone().with_grouping_columns(); diff --git a/src/relation/sql.rs b/src/relation/sql.rs index b996ab64..e668080d 100644 --- a/src/relation/sql.rs +++ b/src/relation/sql.rs @@ -324,7 +324,7 @@ impl<'a> Visitor<'a, ast::Query> for FromRelationVisitor { table_with_joins(table_factor(reduce.input.as_ref().into(), None), vec![]), None, ast::GroupByExpr::Expressions( - reduce.group_by.iter().map(ast::Expr::from).collect(), + reduce.group_by.iter().map(|col| ast::Expr::from(&Expr::Column(col.clone()))).collect(), ), vec![], None, diff --git a/src/sql/relation.rs b/src/sql/relation.rs index b03719a3..04d4a20e 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -610,7 +610,8 @@ mod tests { builder::Ready, data_type::{DataType, DataTyped, Variant}, display::Dot, - relation::{schema::Schema, Constraint}, + relation::schema::Schema, + io::{Database, postgresql} }; #[test] @@ -1200,4 +1201,37 @@ mod tests { DataType::structured(vec![("my_sum", DataType::float().try_empty().unwrap())]) ); } + + #[test] + fn test_group_by_exprs() { + let mut database = postgresql::test_database(); + let relations = database.relations(); + + let query = parse( + "SELECT CASE WHEN d < 5 THEN 5 ELSE 1 END AS case_d, COUNT(*) AS my_count FROM table_1 GROUP BY CASE WHEN d < 5 THEN 5 ELSE 1 END;" + ).unwrap(); + let relation = Relation::try_from(QueryWithRelations::new( + &query, + &relations + )) + .unwrap(); + relation.display_dot().unwrap(); + println!("relation = {relation}"); + assert_eq!( + relation.data_type(), + DataType::structured(vec![ + ("case_d", DataType::float_values([1., 5.])), + ("my_count", DataType::integer_interval(0, 10)), + ]) + ); + let query: &str = &ast::Query::from(&relation).to_string(); + println!("{query}"); + _ = database + .query(query) + .unwrap() + .iter() + .map(ToString::to_string); + + + } } diff --git a/tests/integration.rs b/tests/integration.rs index 9dc9d79e..192df726 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -56,6 +56,8 @@ const QUERIES: &[&str] = &[ "SELECT SUM(a), count(b) FROM table_1 WHERE a>4", // Some GROUP BY "SELECT 1+SUM(a), count(b) FROM table_1 GROUP BY d", + "SELECT count(b) FROM table_1 GROUP BY CEIL(d)", + "SELECT CEIL(d) AS d_ceiled, count(b) FROM table_1 GROUP BY d_ceiled", // Some WHERE and GROUP BY "SELECT 1+SUM(a), count(b) FROM table_1 WHERE d>4 GROUP BY d", "SELECT 1+SUM(a), count(b), d FROM table_1 GROUP BY d",