Skip to content

Commit

Permalink
ok
Browse files Browse the repository at this point in the history
  • Loading branch information
victoria de sainte agathe committed Dec 5, 2023
1 parent 1707e35 commit 331368c
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 51 deletions.
18 changes: 8 additions & 10 deletions src/differential_privacy/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl PUPRelation {
pub fn differentially_private_aggregates(
self,
named_aggregates: Vec<(&str, AggregateColumn)>,
group_by: &[Expr],
group_by: &[Column],
epsilon: f64,
delta: f64,
) -> Result<DPRelation> {
Expand All @@ -135,11 +135,9 @@ impl PUPRelation {
(input_builder, group_by_names) =
group_by
.into_iter()
.fold((input_builder, group_by_names), |(mut b, mut v), x| {
if let Expr::Column(c) = x {
b = b.with((c.last().unwrap(), x.clone()));
v.push(c.last().unwrap());
}
.fold((input_builder, group_by_names), |(mut b, mut v), c| {
b = b.with((c.last().unwrap(), Expr::Column(c.clone())));
v.push(c.last().unwrap());
(b, v)
});

Expand Down Expand Up @@ -275,7 +273,7 @@ impl Reduce {
}

first_aggs.extend(
self.group_by_columns()
self.group_by()
.into_iter()
.map(|x| (x.to_string(), AggregateColumn::new(aggregate::Aggregate::First, x.clone())))
.collect::<Vec<_>>()
Expand Down Expand Up @@ -308,7 +306,7 @@ impl Reduce {
let builder = Relation::reduce()
.input(self.input().clone());
if let Some(identifier) = identifier {
let mut group_by = self.group_by_columns()
let mut group_by = self.group_by()
.into_iter()
.map(|c| c.clone())
.collect::<Vec<_>>();
Expand Down Expand Up @@ -475,7 +473,7 @@ mod tests {
let reduce = Reduce::new(
"my_reduce".to_string(),
vec![("my_sum_price".to_string(), AggregateColumn::sum("price"))],
vec![Expr::col("item")],
vec!["item".into()],
pup_table.deref().clone().into(),
);
let relation = Relation::from(reduce.clone());
Expand Down Expand Up @@ -627,7 +625,7 @@ mod tests {
("my_item".to_string(), AggregateColumn::first("item")),
("avg_price".to_string(), AggregateColumn::mean("price")),
],
vec![Expr::col("item")],
vec!["item".into()],
pup_table.deref().clone().into(),
);
let relation = Relation::from(reduce.clone());
Expand Down
2 changes: 1 addition & 1 deletion src/differential_privacy/group_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl Reduce {
.with_iter(
self.group_by()
.into_iter()
.map(|x| (x.to_string(), x.clone()))
.map(|col| (col.to_string(), Expr::Column(col.clone())))
.collect::<Vec<_>>(),
)
.with((
Expand Down
75 changes: 62 additions & 13 deletions src/expr/split.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! The splits with some improvements
//! Each split has named Expr and anonymous exprs
use super::{
aggregate, function, visitor::Acceptor, Aggregate, AggregateColumn, Column, Expr, Function,
aggregate, function, visitor::Acceptor, AggregateColumn, Column, Expr, Function,
Identifier, Value, Visitor,
};
use crate::{
Expand Down Expand Up @@ -36,7 +36,18 @@ impl Split {
}

pub fn group_by(expr: Expr) -> Reduce {
Reduce::new(vec![], vec![expr], None)
println!("x = {:?}", expr);
match expr {
Expr::Column(c) => Reduce::new(vec![], vec![c], None),
Expr::Value(_) => todo!(),
Expr::Function(_) => {
let name = namer::name_from_content(FIELD, &expr);
let map = Map::new(vec![(name.clone(), expr)], None, vec![], None);
Reduce::new(vec![], vec![name.into()], Some(map))
},
Expr::Aggregate(_) => todo!(),
Expr::Struct(_) => todo!(),
}
}

pub fn into_map(self) -> Map {
Expand Down Expand Up @@ -416,17 +427,47 @@ impl And<Expr> for Map {
}
}

impl And<Column> for Map {
type Product = (Map, Column);

fn and(self, col: Column) -> Self::Product {
let Map {
named_exprs,
filter,
order_by,
reduce,
} = self;
// Add the expr to the next split if needed
let (reduce, col) = if let Some(r) = reduce {
let (r, expr) = r.and(col);
(Some(r), expr)
} else {
(None, col)
};
// Add matched sub-expressions
(
Map::new(
named_exprs.into_iter().chain(vec![(col, Expr::Column]).collect(),
filter,
order_by,
reduce,
),
expr,
)
}
}

#[derive(Clone, Default, Debug, Hash, PartialEq, Eq)]
pub struct Reduce {
pub named_aggregates: Vec<(String, AggregateColumn)>,
pub group_by: Vec<Expr>,
pub group_by: Vec<Column>,
pub map: Option<Box<Map>>,
}

impl Reduce {
pub fn new(
named_aggregates: Vec<(String, AggregateColumn)>,
group_by: Vec<Expr>,
group_by: Vec<Column>,
map: Option<Map>,
) -> Self {
Reduce {
Expand All @@ -440,7 +481,7 @@ impl Reduce {
&self.named_aggregates
}

pub fn group_by(&self) -> &[Expr] {
pub fn group_by(&self) -> &[Column] {
&self.group_by
}

Expand Down Expand Up @@ -581,9 +622,9 @@ impl And<Self> for Reduce {
let (map, group_by) =
self.group_by
.into_iter()
.fold((map, vec![]), |(map, mut group_by), expr| {
let (map, expr) = map.and(expr);
group_by.push(expr);
.fold((map, vec![]), |(map, mut group_by), col| {
let (map, col) = map.and(Expr::Column(col));
group_by.push(col);
(map, group_by)
});
Reduce::new(
Expand All @@ -608,9 +649,9 @@ impl And<Self> for Reduce {
other
.group_by
.into_iter()
.fold((map, vec![]), |(map, mut group_by), expr| {
let (map, expr) = map.and(expr);
group_by.push(expr);
.fold((map, vec![]), |(map, mut group_by), col| {
let (map, col) = map.and(Expr::Column(col));
group_by.push(col);
(map, group_by)
});
Reduce::new(
Expand Down Expand Up @@ -663,7 +704,7 @@ impl And<Expr> for Reduce {
group_by
.clone()
.into_iter()
.map(|e| (namer::name_from_content(FIELD, &e), e)),
.map(|col| (namer::name_from_content(FIELD, &col), Expr::Column(col))),
)
.unique()
.collect();
Expand Down Expand Up @@ -822,7 +863,15 @@ mod tests {
None,
);
println!("reduce = {reduce}");
let reduce = reduce.and(Reduce::new(vec![], vec![Expr::col("z")], None));

let reduce = reduce.and(Reduce::new(vec![], vec!["z".into()], None));
println!("reduce and group by = {}", reduce);
assert_eq!(reduce.len(), 1);
let map = reduce.clone().into_map();
println!("reduce into map = {}", map);
assert_eq!(map.len(), 2);

let reduce = reduce.and(Reduce::new(vec![], vec![expr!(3 * v)], None));
println!("reduce and group by = {}", reduce);
assert_eq!(reduce.len(), 1);
let map = reduce.into_map();
Expand Down
2 changes: 1 addition & 1 deletion src/relation/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ impl<RequireInput> ReduceBuilder<RequireInput> {
self
}

pub fn group_by_iter<I: IntoIterator<Item = Expr>>(self, iter: I) -> Self {
pub fn group_by_iter<E: Into<Expr>, I: IntoIterator<Item=E>>(self, iter: I) -> Self {
iter.into_iter().fold(self, |w, i| w.group_by(i))
}

Expand Down
37 changes: 17 additions & 20 deletions src/relation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ pub struct Reduce {
/// Aggregate expressions
aggregate: Vec<AggregateColumn>,
/// Grouping expressions
group_by: Vec<Expr>,
group_by: Vec<Column>,
/// The schema description of the output
schema: Schema,
/// The size of the Reduce
Expand All @@ -461,7 +461,7 @@ impl Reduce {
pub fn new(
name: String,
named_aggregate: Vec<(String, AggregateColumn)>,
group_by: Vec<Expr>,
group_by: Vec<Column>,
input: Arc<Relation>,
) -> Self {
// assert!(Split::from_iter(named_exprs.clone()).len()==1);
Expand Down Expand Up @@ -525,22 +525,22 @@ impl Reduce {
&self.aggregate
}
/// Get group_by
pub fn group_by(&self) -> &[Expr] {
pub fn group_by(&self) -> &[Column] {
&self.group_by
}
/// Get group_by columns
pub fn group_by_columns(&self) -> Vec<&Column> {
self.group_by
.iter()
.filter_map(|e| {
if let Expr::Column(column) = e {
Some(column)
} else {
None
}
})
.collect()
}
// /// Get group_by columns
// pub fn group_by_columns(&self) -> Vec<&Column> {
// self.group_by
// .iter()
// .filter_map(|e| {
// if let Expr::Column(column) = e {
// Some(column)
// } else {
// None
// }
// })
// .collect()
// }
/// Get the input
pub fn input(&self) -> &Relation {
&self.input
Expand All @@ -565,10 +565,7 @@ impl Reduce {
pub fn group_by_names(&self) -> Vec<&str> {
self.group_by
.iter()
.filter_map(|e| match e {
Expr::Column(col) => col.last().ok(),
_ => None,
})
.filter_map(|col| col.last().ok())
.collect()
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/relation/rewriting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1847,7 +1847,7 @@ mod tests {
("sum_a".to_string(), AggregateColumn::sum("a")),
("b".to_string(), AggregateColumn::first("b")),
],
vec![Expr::col("b")],
vec!["b".into()],
Arc::new(table.clone()),
);
let red_with_grouping_columns = red.clone().with_grouping_columns();
Expand All @@ -1866,7 +1866,7 @@ mod tests {
let red = Reduce::new(
"reduce_relation".to_string(),
vec![("sum_a".to_string(), AggregateColumn::sum("a"))],
vec![Expr::col("b")],
vec!["b".into()],
Arc::new(table.clone()),
);
let red_with_grouping_columns = red.clone().with_grouping_columns();
Expand All @@ -1888,7 +1888,7 @@ mod tests {
("b".to_string(), AggregateColumn::first("b")),
("sum_a".to_string(), AggregateColumn::sum("a")),
],
vec![Expr::col("b"), Expr::col("c")],
vec!["b".into(), "c".into()],
Arc::new(table.clone()),
);
let red_with_grouping_columns = red.clone().with_grouping_columns();
Expand All @@ -1910,7 +1910,7 @@ mod tests {
("c".to_string(), AggregateColumn::first("c")),
("sum_a".to_string(), AggregateColumn::sum("a")),
],
vec![Expr::col("b"), Expr::col("c")],
vec!["b".into(), "c".into()],
Arc::new(table.clone()),
);
let red_with_grouping_columns = red.clone().with_grouping_columns();
Expand Down
2 changes: 1 addition & 1 deletion src/relation/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ impl<'a> Visitor<'a, ast::Query> for FromRelationVisitor {
table_with_joins(table_factor(reduce.input.as_ref().into(), None), vec![]),
None,
ast::GroupByExpr::Expressions(
reduce.group_by.iter().map(ast::Expr::from).collect(),
reduce.group_by.iter().map(|col| ast::Expr::from(&Expr::Column(col.clone()))).collect(),
),
vec![],
None,
Expand Down
36 changes: 35 additions & 1 deletion src/sql/relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,8 @@ mod tests {
builder::Ready,
data_type::{DataType, DataTyped, Variant},
display::Dot,
relation::{schema::Schema, Constraint},
relation::schema::Schema,
io::{Database, postgresql}
};

#[test]
Expand Down Expand Up @@ -1200,4 +1201,37 @@ mod tests {
DataType::structured(vec![("my_sum", DataType::float().try_empty().unwrap())])
);
}

#[test]
fn test_group_by_exprs() {
let mut database = postgresql::test_database();
let relations = database.relations();

let query = parse(
"SELECT CASE WHEN d < 5 THEN 5 ELSE 1 END AS case_d, COUNT(*) AS my_count FROM table_1 GROUP BY CASE WHEN d < 5 THEN 5 ELSE 1 END;"
).unwrap();
let relation = Relation::try_from(QueryWithRelations::new(
&query,
&relations
))
.unwrap();
relation.display_dot().unwrap();
println!("relation = {relation}");
assert_eq!(
relation.data_type(),
DataType::structured(vec![
("case_d", DataType::float_values([1., 5.])),
("my_count", DataType::integer_interval(0, 10)),
])
);
let query: &str = &ast::Query::from(&relation).to_string();
println!("{query}");
_ = database
.query(query)
.unwrap()
.iter()
.map(ToString::to_string);


}
}
2 changes: 2 additions & 0 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ const QUERIES: &[&str] = &[
"SELECT SUM(a), count(b) FROM table_1 WHERE a>4",
// Some GROUP BY
"SELECT 1+SUM(a), count(b) FROM table_1 GROUP BY d",
"SELECT count(b) FROM table_1 GROUP BY CEIL(d)",
"SELECT CEIL(d) AS d_ceiled, count(b) FROM table_1 GROUP BY d_ceiled",
// Some WHERE and GROUP BY
"SELECT 1+SUM(a), count(b) FROM table_1 WHERE d>4 GROUP BY d",
"SELECT 1+SUM(a), count(b), d FROM table_1 GROUP BY d",
Expand Down

0 comments on commit 331368c

Please sign in to comment.