From 1707e35e3785983568afecd705a4c29d8e814aa6 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Tue, 5 Dec 2023 14:40:48 +0100 Subject: [PATCH 01/27] ok --- tests/integration.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/integration.rs b/tests/integration.rs index 206a1336..9dc9d79e 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -8,10 +8,8 @@ use itertools::Itertools; use qrlew::io::sqlite; use qrlew::{ ast, - display::Dot, expr, io::{postgresql, Database}, - privacy_unit_tracking::PrivacyUnit, relation::Variant as _, sql::parse, Relation, With, @@ -93,6 +91,11 @@ const QUERIES: &[&str] = &[ // Some string functions "SELECT UPPER(z) FROM table_2 LIMIT 5", "SELECT LOWER(z) FROM table_2 LIMIT 5", + // ORDER BY + "SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY d", + "SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY d DESC", + "SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY my_count", + "SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY my_count", ]; #[cfg(feature = "sqlite")] From 331368c36b15333935d62e5c45699a55f2e3d2ca Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Tue, 5 Dec 2023 16:07:37 +0100 Subject: [PATCH 02/27] ok --- src/differential_privacy/aggregates.rs | 18 +++---- src/differential_privacy/group_by.rs | 2 +- src/expr/split.rs | 75 +++++++++++++++++++++----- src/relation/builder.rs | 2 +- src/relation/mod.rs | 37 ++++++------- src/relation/rewriting.rs | 8 +-- src/relation/sql.rs | 2 +- src/sql/relation.rs | 36 ++++++++++++- tests/integration.rs | 2 + 9 files changed, 131 insertions(+), 51 deletions(-) diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index 86becd7c..113bc006 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -112,7 +112,7 @@ impl PUPRelation { pub fn differentially_private_aggregates( self, named_aggregates: Vec<(&str, AggregateColumn)>, - group_by: &[Expr], + group_by: &[Column], epsilon: f64, delta: f64, ) -> Result { @@ -135,11 +135,9 @@ impl PUPRelation { (input_builder, group_by_names) = group_by .into_iter() - .fold((input_builder, group_by_names), |(mut b, mut v), x| { - if let Expr::Column(c) = x { - b = b.with((c.last().unwrap(), x.clone())); - v.push(c.last().unwrap()); - } + .fold((input_builder, group_by_names), |(mut b, mut v), c| { + b = b.with((c.last().unwrap(), Expr::Column(c.clone()))); + v.push(c.last().unwrap()); (b, v) }); @@ -275,7 +273,7 @@ impl Reduce { } first_aggs.extend( - self.group_by_columns() + self.group_by() .into_iter() .map(|x| (x.to_string(), AggregateColumn::new(aggregate::Aggregate::First, x.clone()))) .collect::>() @@ -308,7 +306,7 @@ impl Reduce { let builder = Relation::reduce() .input(self.input().clone()); if let Some(identifier) = identifier { - let mut group_by = self.group_by_columns() + let mut group_by = self.group_by() .into_iter() .map(|c| c.clone()) .collect::>(); @@ -475,7 +473,7 @@ mod tests { let reduce = Reduce::new( "my_reduce".to_string(), vec![("my_sum_price".to_string(), AggregateColumn::sum("price"))], - vec![Expr::col("item")], + vec!["item".into()], pup_table.deref().clone().into(), ); let relation = Relation::from(reduce.clone()); @@ -627,7 +625,7 @@ mod tests { ("my_item".to_string(), AggregateColumn::first("item")), ("avg_price".to_string(), AggregateColumn::mean("price")), ], - vec![Expr::col("item")], + vec!["item".into()], pup_table.deref().clone().into(), ); let 
relation = Relation::from(reduce.clone()); diff --git a/src/differential_privacy/group_by.rs b/src/differential_privacy/group_by.rs index 01580047..c88c3189 100644 --- a/src/differential_privacy/group_by.rs +++ b/src/differential_privacy/group_by.rs @@ -22,7 +22,7 @@ impl Reduce { .with_iter( self.group_by() .into_iter() - .map(|x| (x.to_string(), x.clone())) + .map(|col| (col.to_string(), Expr::Column(col.clone()))) .collect::>(), ) .with(( diff --git a/src/expr/split.rs b/src/expr/split.rs index 66bcea60..69e3d8ff 100644 --- a/src/expr/split.rs +++ b/src/expr/split.rs @@ -1,7 +1,7 @@ //! The splits with some improvements //! Each split has named Expr and anonymous exprs use super::{ - aggregate, function, visitor::Acceptor, Aggregate, AggregateColumn, Column, Expr, Function, + aggregate, function, visitor::Acceptor, AggregateColumn, Column, Expr, Function, Identifier, Value, Visitor, }; use crate::{ @@ -36,7 +36,18 @@ impl Split { } pub fn group_by(expr: Expr) -> Reduce { - Reduce::new(vec![], vec![expr], None) + println!("x = {:?}", expr); + match expr { + Expr::Column(c) => Reduce::new(vec![], vec![c], None), + Expr::Value(_) => todo!(), + Expr::Function(_) => { + let name = namer::name_from_content(FIELD, &expr); + let map = Map::new(vec![(name.clone(), expr)], None, vec![], None); + Reduce::new(vec![], vec![name.into()], Some(map)) + }, + Expr::Aggregate(_) => todo!(), + Expr::Struct(_) => todo!(), + } } pub fn into_map(self) -> Map { @@ -416,17 +427,47 @@ impl And for Map { } } +impl And for Map { + type Product = (Map, Column); + + fn and(self, col: Column) -> Self::Product { + let Map { + named_exprs, + filter, + order_by, + reduce, + } = self; + // Add the expr to the next split if needed + let (reduce, col) = if let Some(r) = reduce { + let (r, expr) = r.and(col); + (Some(r), expr) + } else { + (None, col) + }; + // Add matched sub-expressions + ( + Map::new( + named_exprs.into_iter().chain(vec![(col, Expr::Column]).collect(), + filter, + order_by, + reduce, + ), + expr, + ) + } +} + #[derive(Clone, Default, Debug, Hash, PartialEq, Eq)] pub struct Reduce { pub named_aggregates: Vec<(String, AggregateColumn)>, - pub group_by: Vec, + pub group_by: Vec, pub map: Option>, } impl Reduce { pub fn new( named_aggregates: Vec<(String, AggregateColumn)>, - group_by: Vec, + group_by: Vec, map: Option, ) -> Self { Reduce { @@ -440,7 +481,7 @@ impl Reduce { &self.named_aggregates } - pub fn group_by(&self) -> &[Expr] { + pub fn group_by(&self) -> &[Column] { &self.group_by } @@ -581,9 +622,9 @@ impl And for Reduce { let (map, group_by) = self.group_by .into_iter() - .fold((map, vec![]), |(map, mut group_by), expr| { - let (map, expr) = map.and(expr); - group_by.push(expr); + .fold((map, vec![]), |(map, mut group_by), col| { + let (map, col) = map.and(Expr::Column(col)); + group_by.push(col); (map, group_by) }); Reduce::new( @@ -608,9 +649,9 @@ impl And for Reduce { other .group_by .into_iter() - .fold((map, vec![]), |(map, mut group_by), expr| { - let (map, expr) = map.and(expr); - group_by.push(expr); + .fold((map, vec![]), |(map, mut group_by), col| { + let (map, col) = map.and(Expr::Column(col)); + group_by.push(col); (map, group_by) }); Reduce::new( @@ -663,7 +704,7 @@ impl And for Reduce { group_by .clone() .into_iter() - .map(|e| (namer::name_from_content(FIELD, &e), e)), + .map(|col| (namer::name_from_content(FIELD, &col), Expr::Column(col))), ) .unique() .collect(); @@ -822,7 +863,15 @@ mod tests { None, ); println!("reduce = {reduce}"); - let reduce = reduce.and(Reduce::new(vec![], 
vec![Expr::col("z")], None)); + + let reduce = reduce.and(Reduce::new(vec![], vec!["z".into()], None)); + println!("reduce and group by = {}", reduce); + assert_eq!(reduce.len(), 1); + let map = reduce.clone().into_map(); + println!("reduce into map = {}", map); + assert_eq!(map.len(), 2); + + let reduce = reduce.and(Reduce::new(vec![], vec![expr!(3 * v)], None)); println!("reduce and group by = {}", reduce); assert_eq!(reduce.len(), 1); let map = reduce.into_map(); diff --git a/src/relation/builder.rs b/src/relation/builder.rs index b5a2b2f4..07417a49 100644 --- a/src/relation/builder.rs +++ b/src/relation/builder.rs @@ -445,7 +445,7 @@ impl ReduceBuilder { self } - pub fn group_by_iter>(self, iter: I) -> Self { + pub fn group_by_iter, I: IntoIterator>(self, iter: I) -> Self { iter.into_iter().fold(self, |w, i| w.group_by(i)) } diff --git a/src/relation/mod.rs b/src/relation/mod.rs index e0f7bdc8..d8caed90 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -444,7 +444,7 @@ pub struct Reduce { /// Aggregate expressions aggregate: Vec, /// Grouping expressions - group_by: Vec, + group_by: Vec, /// The schema description of the output schema: Schema, /// The size of the Reduce @@ -461,7 +461,7 @@ impl Reduce { pub fn new( name: String, named_aggregate: Vec<(String, AggregateColumn)>, - group_by: Vec, + group_by: Vec, input: Arc, ) -> Self { // assert!(Split::from_iter(named_exprs.clone()).len()==1); @@ -525,22 +525,22 @@ impl Reduce { &self.aggregate } /// Get group_by - pub fn group_by(&self) -> &[Expr] { + pub fn group_by(&self) -> &[Column] { &self.group_by } - /// Get group_by columns - pub fn group_by_columns(&self) -> Vec<&Column> { - self.group_by - .iter() - .filter_map(|e| { - if let Expr::Column(column) = e { - Some(column) - } else { - None - } - }) - .collect() - } + // /// Get group_by columns + // pub fn group_by_columns(&self) -> Vec<&Column> { + // self.group_by + // .iter() + // .filter_map(|e| { + // if let Expr::Column(column) = e { + // Some(column) + // } else { + // None + // } + // }) + // .collect() + // } /// Get the input pub fn input(&self) -> &Relation { &self.input @@ -565,10 +565,7 @@ impl Reduce { pub fn group_by_names(&self) -> Vec<&str> { self.group_by .iter() - .filter_map(|e| match e { - Expr::Column(col) => col.last().ok(), - _ => None, - }) + .filter_map(|col| col.last().ok()) .collect() } } diff --git a/src/relation/rewriting.rs b/src/relation/rewriting.rs index c4a8c843..8e3c4b49 100644 --- a/src/relation/rewriting.rs +++ b/src/relation/rewriting.rs @@ -1847,7 +1847,7 @@ mod tests { ("sum_a".to_string(), AggregateColumn::sum("a")), ("b".to_string(), AggregateColumn::first("b")), ], - vec![Expr::col("b")], + vec!["b".into()], Arc::new(table.clone()), ); let red_with_grouping_columns = red.clone().with_grouping_columns(); @@ -1866,7 +1866,7 @@ mod tests { let red = Reduce::new( "reduce_relation".to_string(), vec![("sum_a".to_string(), AggregateColumn::sum("a"))], - vec![Expr::col("b")], + vec!["b".into()], Arc::new(table.clone()), ); let red_with_grouping_columns = red.clone().with_grouping_columns(); @@ -1888,7 +1888,7 @@ mod tests { ("b".to_string(), AggregateColumn::first("b")), ("sum_a".to_string(), AggregateColumn::sum("a")), ], - vec![Expr::col("b"), Expr::col("c")], + vec!["b".into(), "c".into()], Arc::new(table.clone()), ); let red_with_grouping_columns = red.clone().with_grouping_columns(); @@ -1910,7 +1910,7 @@ mod tests { ("c".to_string(), AggregateColumn::first("c")), ("sum_a".to_string(), AggregateColumn::sum("a")), ], - 
vec![Expr::col("b"), Expr::col("c")], + vec!["b".into(), "c".into()], Arc::new(table.clone()), ); let red_with_grouping_columns = red.clone().with_grouping_columns(); diff --git a/src/relation/sql.rs b/src/relation/sql.rs index b996ab64..e668080d 100644 --- a/src/relation/sql.rs +++ b/src/relation/sql.rs @@ -324,7 +324,7 @@ impl<'a> Visitor<'a, ast::Query> for FromRelationVisitor { table_with_joins(table_factor(reduce.input.as_ref().into(), None), vec![]), None, ast::GroupByExpr::Expressions( - reduce.group_by.iter().map(ast::Expr::from).collect(), + reduce.group_by.iter().map(|col| ast::Expr::from(&Expr::Column(col.clone()))).collect(), ), vec![], None, diff --git a/src/sql/relation.rs b/src/sql/relation.rs index b03719a3..04d4a20e 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -610,7 +610,8 @@ mod tests { builder::Ready, data_type::{DataType, DataTyped, Variant}, display::Dot, - relation::{schema::Schema, Constraint}, + relation::schema::Schema, + io::{Database, postgresql} }; #[test] @@ -1200,4 +1201,37 @@ mod tests { DataType::structured(vec![("my_sum", DataType::float().try_empty().unwrap())]) ); } + + #[test] + fn test_group_by_exprs() { + let mut database = postgresql::test_database(); + let relations = database.relations(); + + let query = parse( + "SELECT CASE WHEN d < 5 THEN 5 ELSE 1 END AS case_d, COUNT(*) AS my_count FROM table_1 GROUP BY CASE WHEN d < 5 THEN 5 ELSE 1 END;" + ).unwrap(); + let relation = Relation::try_from(QueryWithRelations::new( + &query, + &relations + )) + .unwrap(); + relation.display_dot().unwrap(); + println!("relation = {relation}"); + assert_eq!( + relation.data_type(), + DataType::structured(vec![ + ("case_d", DataType::float_values([1., 5.])), + ("my_count", DataType::integer_interval(0, 10)), + ]) + ); + let query: &str = &ast::Query::from(&relation).to_string(); + println!("{query}"); + _ = database + .query(query) + .unwrap() + .iter() + .map(ToString::to_string); + + + } } diff --git a/tests/integration.rs b/tests/integration.rs index 9dc9d79e..192df726 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -56,6 +56,8 @@ const QUERIES: &[&str] = &[ "SELECT SUM(a), count(b) FROM table_1 WHERE a>4", // Some GROUP BY "SELECT 1+SUM(a), count(b) FROM table_1 GROUP BY d", + "SELECT count(b) FROM table_1 GROUP BY CEIL(d)", + "SELECT CEIL(d) AS d_ceiled, count(b) FROM table_1 GROUP BY d_ceiled", // Some WHERE and GROUP BY "SELECT 1+SUM(a), count(b) FROM table_1 WHERE d>4 GROUP BY d", "SELECT 1+SUM(a), count(b), d FROM table_1 GROUP BY d", From f773d3217bca79cdee48d2b77029687d26aede34 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Dec 2023 08:33:45 +0100 Subject: [PATCH 03/27] ok --- src/expr/split.rs | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/expr/split.rs b/src/expr/split.rs index 69e3d8ff..8dd7b264 100644 --- a/src/expr/split.rs +++ b/src/expr/split.rs @@ -447,12 +447,12 @@ impl And for Map { // Add matched sub-expressions ( Map::new( - named_exprs.into_iter().chain(vec![(col, Expr::Column]).collect(), + named_exprs.into_iter().chain(vec![(col.last().unwrap().to_string(), Expr::Column(col.clone()))]).collect(), filter, order_by, reduce, ), - expr, + col, ) } } @@ -623,7 +623,7 @@ impl And for Reduce { self.group_by .into_iter() .fold((map, vec![]), |(map, mut group_by), col| { - let (map, col) = map.and(Expr::Column(col)); + let (map, col) = map.and(col); group_by.push(col); (map, group_by) }); @@ -650,7 +650,7 @@ impl And for Reduce { 
.group_by .into_iter() .fold((map, vec![]), |(map, mut group_by), col| { - let (map, col) = map.and(Expr::Column(col)); + let (map, col) = map.and(col); group_by.push(col); (map, group_by) }); @@ -727,6 +727,33 @@ impl And for Reduce { } } +impl And for Reduce { + type Product = (Reduce, Column); + + fn and(self, col: Column) -> Self::Product { + let Reduce { + named_aggregates, + group_by, + map, + } = self; + // Add the expr to the next split if needed + let (map, col) = if let Some(m) = map { + let (m, expr) = m.and(col); + (Some(m), expr) + } else { + (None, col) + }; + ( + Reduce::new( + named_aggregates.into_iter().chain(vec![(col.last().unwrap().to_string(), AggregateColumn::first(col.last().unwrap().to_string()))]).collect(), + group_by.into_iter().chain(vec![col.clone()]).collect(), + map, + ), + col, + ) + } +} + #[derive(Clone, Debug)] pub struct SplitVisitor(String); From 7ccd5b9147bd5458668085f9d4644d43f7c8ce15 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Dec 2023 10:20:33 +0100 Subject: [PATCH 04/27] ok --- .vscode/settings.json | 1 + 1 file changed, 1 insertion(+) diff --git a/.vscode/settings.json b/.vscode/settings.json index 53e9e3ca..03e9856e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,4 +8,5 @@ ], "editor.codeActionsOnSave": {}, "rust-analyzer.cargo.buildScripts.overrideCommand": null, + "rust-analyzer.procMacro.enable": true } \ No newline at end of file From cd38ca167503e62a092d3fa1c33468598f5f0ad3 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Dec 2023 10:21:02 +0100 Subject: [PATCH 05/27] ok --- .vscode/settings.json | 1 - 1 file changed, 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 03e9856e..53e9e3ca 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,5 +8,4 @@ ], "editor.codeActionsOnSave": {}, "rust-analyzer.cargo.buildScripts.overrideCommand": null, - "rust-analyzer.procMacro.enable": true } \ No newline at end of file From d4f1286f516f475e6e4131bf8bff08c7ce5b3e11 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Dec 2023 12:42:43 +0100 Subject: [PATCH 06/27] ok --- src/differential_privacy/aggregates.rs | 4 ++-- src/differential_privacy/mod.rs | 6 +++--- src/expr/split.rs | 8 -------- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index 113bc006..8c5de43b 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -589,7 +589,7 @@ mod tests { ("sum_price".to_string(), AggregateColumn::sum("price")), ("avg_price".to_string(), AggregateColumn::mean("price")), ], - vec![expr!(item)], + vec!["item".into()], pup_table.deref().clone().into(), ); let relation = Relation::from(reduce.clone()); @@ -688,7 +688,7 @@ mod tests { ("sum_distinct_price".to_string(), AggregateColumn::sum_distinct("price")), ("item".to_string(), AggregateColumn::first("item")), ], - vec![expr!(item)], + vec!["item".into()], pup_table.deref().clone().into(), ); let relation = Relation::from(reduce.clone()); diff --git a/src/differential_privacy/mod.rs b/src/differential_privacy/mod.rs index 5b0032e6..d4c679a2 100644 --- a/src/differential_privacy/mod.rs +++ b/src/differential_privacy/mod.rs @@ -291,7 +291,7 @@ mod tests { let reduce = Reduce::new( "my_reduce".to_string(), vec![("sum_price".to_string(), AggregateColumn::sum("price"))], - vec![expr!(order_id)], + vec!["order_id".into()], pup_map.deref().clone().into(), ); 
let relation = Relation::from(reduce.clone()); @@ -365,7 +365,7 @@ mod tests { let reduce = Reduce::new( "my_reduce".to_string(), vec![("sum_price".to_string(), AggregateColumn::sum("price"))], - vec![expr!(order_id)], + vec!["order_id".into()], pup_map.deref().clone().into(), ); let relation = Relation::from(reduce.clone()); @@ -448,7 +448,7 @@ mod tests { ("order_id".to_string(), AggregateColumn::first("order_id")), ("sum_price".to_string(), AggregateColumn::sum("price")), ], - vec![expr!(order_id), expr!(item)], + vec!["order_id".into(), "item".into()], pup_map.deref().clone().into(), ); let relation = Relation::from(reduce.clone()); diff --git a/src/expr/split.rs b/src/expr/split.rs index 8dd7b264..6787002b 100644 --- a/src/expr/split.rs +++ b/src/expr/split.rs @@ -36,7 +36,6 @@ impl Split { } pub fn group_by(expr: Expr) -> Reduce { - println!("x = {:?}", expr); match expr { Expr::Column(c) => Reduce::new(vec![], vec![c], None), Expr::Value(_) => todo!(), @@ -897,13 +896,6 @@ mod tests { let map = reduce.clone().into_map(); println!("reduce into map = {}", map); assert_eq!(map.len(), 2); - - let reduce = reduce.and(Reduce::new(vec![], vec![expr!(3 * v)], None)); - println!("reduce and group by = {}", reduce); - assert_eq!(reduce.len(), 1); - let map = reduce.into_map(); - println!("reduce into map = {}", map); - assert_eq!(map.len(), 2); } #[test] From fb2018a161aec6c7ba2f3ba0ab4e86a6ed1787b2 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Dec 2023 15:38:44 +0100 Subject: [PATCH 07/27] ok --- src/expr/split.rs | 20 +++++--------- src/relation/builder.rs | 22 +++++++++++++--- src/relation/mod.rs | 35 +++++++++++++++++++------ src/sql/relation.rs | 58 ++++++++++++++++++++++++++++------------- 4 files changed, 92 insertions(+), 43 deletions(-) diff --git a/src/expr/split.rs b/src/expr/split.rs index 6787002b..ceea8ed7 100644 --- a/src/expr/split.rs +++ b/src/expr/split.rs @@ -36,17 +36,9 @@ impl Split { } pub fn group_by(expr: Expr) -> Reduce { - match expr { - Expr::Column(c) => Reduce::new(vec![], vec![c], None), - Expr::Value(_) => todo!(), - Expr::Function(_) => { - let name = namer::name_from_content(FIELD, &expr); - let map = Map::new(vec![(name.clone(), expr)], None, vec![], None); - Reduce::new(vec![], vec![name.into()], Some(map)) - }, - Expr::Aggregate(_) => todo!(), - Expr::Struct(_) => todo!(), - } + let name = namer::name_from_content(FIELD, &expr); + let map = Map::new(vec![(name.clone(), expr)], None, vec![], None); + Reduce::new(vec![], vec![name.into()], Some(map)) } pub fn into_map(self) -> Map { @@ -744,7 +736,9 @@ impl And for Reduce { }; ( Reduce::new( - named_aggregates.into_iter().chain(vec![(col.last().unwrap().to_string(), AggregateColumn::first(col.last().unwrap().to_string()))]).collect(), + named_aggregates + .into_iter() + .chain(vec![(col.last().unwrap().to_string(), AggregateColumn::first(col.last().unwrap().to_string()))]).collect(), group_by.into_iter().chain(vec![col.clone()]).collect(), map, ), @@ -856,8 +850,6 @@ impl> FromIterator<(S, Expr)> for Split { #[cfg(test)] mod tests { - use crate::expr::implementation::aggregate; - use super::*; #[test] diff --git a/src/relation/builder.rs b/src/relation/builder.rs index 07417a49..05d4a8cd 100644 --- a/src/relation/builder.rs +++ b/src/relation/builder.rs @@ -12,7 +12,7 @@ use crate::{ expr::{self, AggregateColumn, Expr, Identifier, Split}, hierarchy::Hierarchy, namer::{self, FIELD, JOIN, MAP, REDUCE, SET}, - And, + And, display::Dot, }; // A Table builder @@ -157,7 +157,8 @@ impl 
MapBuilder { /// Add a group by pub fn group_by(mut self, expr: Expr) -> Self { - self.split = self.split.and(Split::group_by(expr).into()); + let s = Split::group_by(expr.into()); + self.split = self.split.and(s.into()); self } @@ -441,7 +442,8 @@ impl ReduceBuilder { } pub fn group_by>(mut self, expr: E) -> Self { - self.split = self.split.and(Split::group_by(expr.into()).into()); + let s = Split::group_by(expr.into()); + self.split = self.split.and(s.into()); self } @@ -618,6 +620,20 @@ impl Ready for ReduceBuilder { ), None => self.input.0, }; + input.display_dot().unwrap(); + println!("{:?}", reduce.group_by); + // Check that the First aggregate columns are in the GROUP BY + reduce.named_aggregates.iter() + .filter_map(|(_, agg)| matches!(agg.aggregate(), expr::aggregate::Aggregate::First).then_some(agg.column())) + .map(|col: &Identifier| if !reduce.group_by.contains(col) { + Err(Error::InvalidRelation(format!( + "First aggregate columns must be in the GROUP BY. Got: {}", + col + ))) + } else { + Ok(col) + }) + .collect::>>()?; // Build the Relation Ok(Reduce::new( name, diff --git a/src/relation/mod.rs b/src/relation/mod.rs index d8caed90..63124cc6 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -1779,23 +1779,42 @@ mod tests { #[test] fn test_reduce_builder() { let schema: Schema = vec![ - ("a", DataType::float()), - ("b", DataType::float_interval(-2., 2.)), + ("a", DataType::integer_interval(0, 10)), + ("b", DataType::float_interval(0., 1.)), ("c", DataType::float()), ("d", DataType::float_interval(0., 1.)), ] .into_iter() .collect(); - let table: Relation = Relation::table().schema(schema).build(); + let table: Relation = Relation::table().schema(schema).size(100).build(); + let reduce: Relation = Relation::reduce() - .with(Expr::sum(Expr::col("a"))) + .with(("my_sum", Expr::sum(Expr::col("b")))) + .with(("a", Expr::first(Expr::col("a")))) .group_by(Expr::col("a")) + .input(table.clone()) + .build(); + assert_eq!( + reduce.data_type(), + DataType::structured([ + ("my_sum", DataType::float_interval(0., 100.)), + ("a", DataType::integer_interval(0, 10)), + ]) + ); + + let reduce: Relation = Relation::reduce() + .with(("my_sum", Expr::sum(Expr::col("b")))) + .with(("my_a", Expr::first(expr!(3 * a)))) + .group_by(expr!(3 * a)) .input(table) - // .with(Expr::count(Expr::col("b"))) .build(); - println!("reduce = {}", reduce); - println!("reduce.data_type() = {}", reduce.data_type()); - println!("reduce.schema() = {}", reduce.schema()); + assert_eq!( + reduce.data_type(), + DataType::structured([ + ("my_sum", DataType::float_interval(0., 100.)), + ("my_a", DataType::integer_interval(0, 30)), + ]) + ); } #[test] diff --git a/src/sql/relation.rs b/src/sql/relation.rs index 04d4a20e..97ad58ee 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -21,6 +21,7 @@ use crate::{ }, tokenizer::Tokenizer, visitor::{Acceptor, Dependencies, Visited}, + types::And }; use itertools::Itertools; use std::{ @@ -28,7 +29,7 @@ use std::{ iter::{once, Iterator}, result, str::FromStr, - sync::Arc, + sync::Arc, collections::HashMap, }; /* @@ -306,13 +307,31 @@ impl<'a> VisitedQueryRelations<'a> { } } // Prepare the GROUP BY - let group_by: Result> = match group_by { + let group_by = match group_by { ast::GroupByExpr::All => todo!(), ast::GroupByExpr::Expressions(group_by_exprs) => group_by_exprs .iter() .map(|e| e.with(columns).try_into()) - .collect(), + .collect::>>()?, }; + let exprs:HashMap = named_exprs + .iter() + .cloned() + .map(|(name, x)| (x, name)) + .collect(); + // let mut 
named_exprs = named_exprs.into_iter().chain( + // group_by.iter() + // .cloned() + // .map(|gx| ( + // exprs + // .get(&gx) + // .unwrap_or(&namer::name_from_content(FIELD, &gx)) + // .to_string(), + // gx) + // ) + // .collect::>() + // ).collect::>(); + // Add the having in named_exprs let having = if let Some(expr) = having { @@ -323,22 +342,26 @@ impl<'a> VisitedQueryRelations<'a> { .map(|(s, x)| (Expr::col(s.to_string()), x.clone())) .collect(); expr = expr.replace(columns).0; - if let Ok(g) = &group_by { - let columns = g - .iter() - .filter_map(|x| { - matches!(x, Expr::Column(_)).then_some((x.clone(), Expr::first(x.clone()))) - }) - .collect(); - expr = expr.replace(columns).0; - } + let columns = group_by + .iter() + .filter_map(|x| { + matches!(x, Expr::Column(_)).then_some((x.clone(), Expr::first(x.clone()))) + }) + .collect(); + expr = expr.replace(columns).0; named_exprs.push((having_name.clone(), expr)); Some(having_name) } else { None }; + // Build the Map or Reduce based on the type of split - let split = Split::from_iter(named_exprs); + let split = group_by.into_iter() + .fold( + Split::from_iter(named_exprs), + |s, expr| s.and(Split::Reduce(Split::group_by(expr))) + ); + println!("split = {}", split); // Prepare the WHERE let filter: Option = selection .as_ref() @@ -349,13 +372,13 @@ impl<'a> VisitedQueryRelations<'a> { Split::Map(map) => { let builder = Relation::map().split(map); let builder = filter.into_iter().fold(builder, |b, e| b.filter(e)); - let builder = group_by?.into_iter().fold(builder, |b, e| b.group_by(e)); + //let builder = group_by.into_iter().fold(builder, |b, e| b.group_by(e)); builder.input(from).build() } Split::Reduce(reduce) => { let builder = Relation::reduce().split(reduce); let builder = filter.into_iter().fold(builder, |b, e| b.filter(e)); - let builder = group_by?.into_iter().fold(builder, |b, e| b.group_by(e)); + //let builder = group_by.into_iter().fold(builder, |b, e| b.group_by(e)); builder.input(from).build() } }; @@ -1207,9 +1230,8 @@ mod tests { let mut database = postgresql::test_database(); let relations = database.relations(); - let query = parse( - "SELECT CASE WHEN d < 5 THEN 5 ELSE 1 END AS case_d, COUNT(*) AS my_count FROM table_1 GROUP BY CASE WHEN d < 5 THEN 5 ELSE 1 END;" - ).unwrap(); + let query_str = "SELECT 3*d, COUNT(*) AS my_count FROM table_1 GROUP BY 3*d;"; + let query = parse(query_str).unwrap(); let relation = Relation::try_from(QueryWithRelations::new( &query, &relations From c7b79bc912fd285e8f8319560ccfa0f2a0a555f1 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Mon, 11 Dec 2023 14:42:50 +0100 Subject: [PATCH 08/27] ok --- CHANGELOG.md | 2 ++ src/relation/rewriting.rs | 19 ++++++++++------ src/rewriting/mod.rs | 47 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12f675cd..6059c21a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
## [Unreleased] +### Fixed +- When the clipping factor is zero, multiply by zero instead of dividing by 1 / clipping_factor [#217](https://github.com/Qrlew/qrlew/issues/217) ## [0.5.5] - 2023-12-09 ## Added diff --git a/src/relation/rewriting.rs b/src/relation/rewriting.rs index bc56d123..a69520cc 100644 --- a/src/relation/rewriting.rs +++ b/src/relation/rewriting.rs @@ -453,20 +453,25 @@ impl Relation { let value_clippings: HashMap<&str, f64> = value_clippings.into_iter().collect(); // Compute the norm let norms = self.clone().l2_norms( - entities.clone(), + entities, groups.clone(), value_clippings.keys().cloned().collect(), ); // Compute the scaling factors let scaling_factors = norms.map_fields(|field_name, expr| { if value_clippings.contains_key(&field_name) { - Expr::divide( - Expr::val(1), - Expr::greatest( + let value_clipping = value_clippings[&field_name]; + if value_clipping == 0.0 { + Expr::val(value_clipping) + } else { + Expr::divide( Expr::val(1), - Expr::divide(expr.clone(), Expr::val(value_clippings[&field_name])), - ), - ) + Expr::greatest( + Expr::val(1), + Expr::divide(expr.clone(), Expr::val(value_clipping)), + ), + ) + } } else { expr } diff --git a/src/rewriting/mod.rs b/src/rewriting/mod.rs index cd0225f1..114ed076 100644 --- a/src/rewriting/mod.rs +++ b/src/rewriting/mod.rs @@ -405,4 +405,51 @@ mod tests { } } + + #[test] + fn test_census() { + let census: Relation = Relation::table() + .name("census") + .schema( + vec![ + ("capital_loss", DataType::integer()), + ("age", DataType::integer()), + ] + .into_iter() + .collect::() + ) + .size(1000) + .build(); + let relations: Hierarchy> = vec![census] + .iter() + .map(|t| (Identifier::from(t.name()), Arc::new(t.clone().into()))) + .collect(); + let synthetic_data = SyntheticData::new(Hierarchy::from([ + (vec!["census"], Identifier::from("census")), + ])); + let privacy_unit = PrivacyUnit::from(vec![ + ("census", vec![], "_PRIVACY_UNIT_ROW_"), + ]); + let budget = Budget::new(1., 1e-3); + + let queries = [ + "SELECT SUM(CAST(capital_loss AS float) / 100000.) AS my_sum FROM census WHERE capital_loss > 2231. AND capital_loss < 4356.;", + "SELECT SUM(capital_loss / 100000) AS my_sum FROM census WHERE capital_loss > 2231. 
AND capital_loss < 4356.;", + "SELECT SUM(CASE WHEN age > 90 THEN 1 ELSE 0 END) AS s1 FROM census WHERE age > 20 AND age < 90;" + ]; + for query_str in queries { + println!("\n{query_str}"); + let query = parse(query_str).unwrap(); + let relation = Relation::try_from(query.with(&relations)).unwrap(); + relation.display_dot().unwrap(); + let dp_relation = relation.rewrite_with_differential_privacy( + &relations, + synthetic_data.clone(), + privacy_unit.clone(), + budget.clone() + ).unwrap(); + dp_relation.relation().display_dot().unwrap(); + } + + } } From e667529bbf0ed7b00739bdad88944dc4b140ff96 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Mon, 11 Dec 2023 14:43:45 +0100 Subject: [PATCH 09/27] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6059c21a..a0eaca9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Fixed -- When the clipping factor is zero, multiply by zero instead of dividing by 1 / clipping_factor [#217](https://github.com/Qrlew/qrlew/issues/217) +- When the clipping factor is zero, multiply by zero instead of dividing by 1 / clipping_factor [#218](https://github.com/Qrlew/qrlew/issues/218) ## [0.5.5] - 2023-12-09 ## Added From 9f358f2648d55d124fba7d188beb805f5138d379 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Dec 2023 16:40:37 +0100 Subject: [PATCH 10/27] ok --- src/relation/mod.rs | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/relation/mod.rs b/src/relation/mod.rs index 63124cc6..7dbe6813 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -528,19 +528,6 @@ impl Reduce { pub fn group_by(&self) -> &[Column] { &self.group_by } - // /// Get group_by columns - // pub fn group_by_columns(&self) -> Vec<&Column> { - // self.group_by - // .iter() - // .filter_map(|e| { - // if let Expr::Column(column) = e { - // Some(column) - // } else { - // None - // } - // }) - // .collect() - // } /// Get the input pub fn input(&self) -> &Relation { &self.input From 72657ac1564c2fc933b502f1874c251ce3963e1f Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Dec 2023 18:14:57 +0100 Subject: [PATCH 11/27] Change how group bys are taken into account --- src/expr/split.rs | 8 ++++++++ src/sql/relation.rs | 39 +++++++++++++-------------------------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/src/expr/split.rs b/src/expr/split.rs index ceea8ed7..00880a12 100644 --- a/src/expr/split.rs +++ b/src/expr/split.rs @@ -985,4 +985,12 @@ mod tests { ]); println!("split = {split}"); } + + #[test] + fn test_split_map_reduce_map_group_by_expr() { + let split = Split::from(("b", expr!(2*count(1 + y)))); + let split = split.and(Split::group_by(expr!(x-y)).into()); + let split = split.and(Split::from(("a", expr!(x-y)))); + println!("split = {split}"); + } } diff --git a/src/sql/relation.rs b/src/sql/relation.rs index 97f78d93..ac91ecd7 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -11,7 +11,7 @@ use crate::{ ast, builder::{Ready, With, WithIterator, WithoutContext}, dialect::{Dialect, GenericDialect}, - expr::{Expr, Identifier, Split}, + expr::{Expr, Identifier, Split, Reduce}, hierarchy::{Hierarchy, Path}, namer::{self, FIELD}, parser::Parser, @@ -315,25 +315,6 @@ impl<'a> VisitedQueryRelations<'a> { .map(|e| e.with(columns).try_into()) .collect::>>()?, }; - let exprs:HashMap = named_exprs - .iter() - .cloned() - 
.map(|(name, x)| (x, name)) - .collect(); - // let mut named_exprs = named_exprs.into_iter().chain( - // group_by.iter() - // .cloned() - // .map(|gx| ( - // exprs - // .get(&gx) - // .unwrap_or(&namer::name_from_content(FIELD, &gx)) - // .to_string(), - // gx) - // ) - // .collect::>() - // ).collect::>(); - - // Add the having in named_exprs let having = if let Some(expr) = having { let having_name = namer::name_from_content(FIELD, &expr); @@ -355,14 +336,20 @@ impl<'a> VisitedQueryRelations<'a> { } else { None }; - // Build the Map or Reduce based on the type of split - let split = group_by.into_iter() - .fold( - Split::from_iter(named_exprs), - |s, expr| s.and(Split::Reduce(Split::group_by(expr))) + // If group_by is non-empty, start with them so that aggregations can take them into account + let split = if group_by.is_empty() { + Split::from_iter(named_exprs) + } else { + let group_by = group_by.into_iter() + .fold(Split::Reduce(Reduce::default()), + |s, expr| s.and(Split::Reduce(Split::group_by(expr))) ); - println!("split = {}", split); + named_exprs.into_iter() + .fold(group_by, + |s, named_expr| s.and(named_expr.into()) + ) + }; // Prepare the WHERE let filter: Option = selection .as_ref() From 96356693ce5debe34e5fc8724d7346bf6e39c36f Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Dec 2023 19:06:18 +0100 Subject: [PATCH 12/27] Fixed test --- src/relation/rewriting.rs | 2 +- src/sql/relation.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relation/rewriting.rs b/src/relation/rewriting.rs index cd27ba05..22bebd1a 100644 --- a/src/relation/rewriting.rs +++ b/src/relation/rewriting.rs @@ -1984,7 +1984,7 @@ mod tests { .with(expr!(count(a))) //.with_group_by_column("c") .with(("twice_c", expr!(first(2*c)))) - .group_by(expr!(c)) + .group_by(expr!(2*c)) .build(); let distinct_relation = relation.clone().distinct(); distinct_relation.display_dot(); diff --git a/src/sql/relation.rs b/src/sql/relation.rs index ac91ecd7..1943d089 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -1234,7 +1234,7 @@ mod tests { assert_eq!( relation.data_type(), DataType::structured(vec![ - ("case_d", DataType::float_values([1., 5.])), + ("field_fp0x", DataType::integer_interval(0, 30)), ("my_count", DataType::integer_interval(0, 10)), ]) ); From 0381bad421f550d0fb0c0fc92f82a94afb6c8bbd Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Dec 2023 22:25:46 +0100 Subject: [PATCH 13/27] Tests to be fixed --- src/expr/mod.rs | 1 + src/relation/builder.rs | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/expr/mod.rs b/src/expr/mod.rs index e8048e9f..e41f3840 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -1026,6 +1026,7 @@ pub struct SuperImageVisitor<'a>(&'a DataType); impl<'a> Visitor<'a, Result> for SuperImageVisitor<'a> { fn column(&self, column: &'a Column) -> Result { + println!("DEBUG COLUMN {:?}", column); Ok(self.0[column.clone()].clone()) } diff --git a/src/relation/builder.rs b/src/relation/builder.rs index 05d4a8cd..c38d975f 100644 --- a/src/relation/builder.rs +++ b/src/relation/builder.rs @@ -621,7 +621,6 @@ impl Ready for ReduceBuilder { None => self.input.0, }; input.display_dot().unwrap(); - println!("{:?}", reduce.group_by); // Check that the First aggregate columns are in the GROUP BY reduce.named_aggregates.iter() .filter_map(|(_, agg)| matches!(agg.aggregate(), expr::aggregate::Aggregate::First).then_some(agg.column())) From e7352636152b098371fae5f9a0d1416a22ebdf73 Mon Sep 17 00:00:00 2001 From: Nicolas 
Grislain Date: Mon, 11 Dec 2023 23:14:12 +0100 Subject: [PATCH 14/27] ok --- src/relation/mod.rs | 2 +- src/rewriting/rewriting_rule.rs | 2 +- src/sql/relation.rs | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/relation/mod.rs b/src/relation/mod.rs index 4b4aa9a4..9eed67c5 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -555,7 +555,7 @@ impl Reduce { pub fn group_by_names(&self) -> Vec<&str> { self.group_by .iter() - .filter_map(|col| col.last().ok()) + .filter_map(|col| col.last().ok())// We should fail if there is an ambiguity .collect() } } diff --git a/src/rewriting/rewriting_rule.rs b/src/rewriting/rewriting_rule.rs index 625d3a14..93daf963 100644 --- a/src/rewriting/rewriting_rule.rs +++ b/src/rewriting/rewriting_rule.rs @@ -14,7 +14,7 @@ use crate::{ relation::{Join, Map, Reduce, Relation, Set, Table, Values, Variant as _}, rewriting::relation_with_attributes::RelationWithAttributes, synthetic_data::SyntheticData, - visitor::{Acceptor, Visited, Visitor}, + visitor::{Acceptor, Visited, Visitor}, display::Dot, }; /// A simple Property object to tag Relations properties diff --git a/src/sql/relation.rs b/src/sql/relation.rs index 1943d089..023ed678 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -341,7 +341,7 @@ impl<'a> VisitedQueryRelations<'a> { let split = if group_by.is_empty() { Split::from_iter(named_exprs) } else { - let group_by = group_by.into_iter() + let group_by = group_by.clone().into_iter() .fold(Split::Reduce(Reduce::default()), |s, expr| s.and(Split::Reduce(Split::group_by(expr))) ); @@ -360,13 +360,13 @@ impl<'a> VisitedQueryRelations<'a> { Split::Map(map) => { let builder = Relation::map().split(map); let builder = filter.into_iter().fold(builder, |b, e| b.filter(e)); - //let builder = group_by.into_iter().fold(builder, |b, e| b.group_by(e)); + let builder = group_by.into_iter().fold(builder, |b, e| b.group_by(e)); builder.input(from).build() } Split::Reduce(reduce) => { let builder = Relation::reduce().split(reduce); let builder = filter.into_iter().fold(builder, |b, e| b.filter(e)); - //let builder = group_by.into_iter().fold(builder, |b, e| b.group_by(e)); + let builder = group_by.into_iter().fold(builder, |b, e| b.group_by(e)); builder.input(from).build() } }; From c409a14c431bf3a77b43a804a61fd2e16894ba3b Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Dec 2023 23:30:05 +0100 Subject: [PATCH 15/27] ok --- src/expr/split.rs | 67 +++-------------------------------------------- 1 file changed, 4 insertions(+), 63 deletions(-) diff --git a/src/expr/split.rs b/src/expr/split.rs index 00880a12..d71985bc 100644 --- a/src/expr/split.rs +++ b/src/expr/split.rs @@ -418,36 +418,6 @@ impl And for Map { } } -impl And for Map { - type Product = (Map, Column); - - fn and(self, col: Column) -> Self::Product { - let Map { - named_exprs, - filter, - order_by, - reduce, - } = self; - // Add the expr to the next split if needed - let (reduce, col) = if let Some(r) = reduce { - let (r, expr) = r.and(col); - (Some(r), expr) - } else { - (None, col) - }; - // Add matched sub-expressions - ( - Map::new( - named_exprs.into_iter().chain(vec![(col.last().unwrap().to_string(), Expr::Column(col.clone()))]).collect(), - filter, - order_by, - reduce, - ), - col, - ) - } -} - #[derive(Clone, Default, Debug, Hash, PartialEq, Eq)] pub struct Reduce { pub named_aggregates: Vec<(String, AggregateColumn)>, @@ -614,8 +584,8 @@ impl And for Reduce { self.group_by .into_iter() .fold((map, vec![]), |(map, mut group_by), col| { - let 
(map, col) = map.and(col); - group_by.push(col); + let (map, col) = map.and(Expr::from(col)); + group_by.push(col.try_into().unwrap()); (map, group_by) }); Reduce::new( @@ -641,8 +611,8 @@ impl And for Reduce { .group_by .into_iter() .fold((map, vec![]), |(map, mut group_by), col| { - let (map, col) = map.and(col); - group_by.push(col); + let (map, col) = map.and(Expr::from(col)); + group_by.push(col.try_into().unwrap()); (map, group_by) }); Reduce::new( @@ -718,35 +688,6 @@ impl And for Reduce { } } -impl And for Reduce { - type Product = (Reduce, Column); - - fn and(self, col: Column) -> Self::Product { - let Reduce { - named_aggregates, - group_by, - map, - } = self; - // Add the expr to the next split if needed - let (map, col) = if let Some(m) = map { - let (m, expr) = m.and(col); - (Some(m), expr) - } else { - (None, col) - }; - ( - Reduce::new( - named_aggregates - .into_iter() - .chain(vec![(col.last().unwrap().to_string(), AggregateColumn::first(col.last().unwrap().to_string()))]).collect(), - group_by.into_iter().chain(vec![col.clone()]).collect(), - map, - ), - col, - ) - } -} - #[derive(Clone, Debug)] pub struct SplitVisitor(String); From f27855d99bb0047716d00ef2e7757c22d49e92b6 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Dec 2023 23:48:30 +0100 Subject: [PATCH 16/27] ok --- src/differential_privacy/budget.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/differential_privacy/budget.rs b/src/differential_privacy/budget.rs index b930d100..c7aac990 100644 --- a/src/differential_privacy/budget.rs +++ b/src/differential_privacy/budget.rs @@ -1,5 +1,5 @@ use super::{DPRelation, Ready, Reduce, Relation, Result, With}; -use crate::privacy_unit_tracking::PUPRelation; +use crate::{privacy_unit_tracking::PUPRelation, relation::Variant}; use std::{cmp::Eq, hash::Hash}; /// Represent a simple privacy budget @@ -34,11 +34,12 @@ impl Budget { impl Budget { pub fn reduce(&self, reduce: &Reduce, input: PUPRelation) -> Result { + print!("DEBUG input {}", Relation::from(input.clone()).schema()); let reduce: Reduce = Relation::reduce() .with(reduce.clone()) .input(Relation::from(input)) .build(); - + print!("DEBUG reduce input {}", reduce.input().schema()); let (epsilon, delta, epsilon_tau_thresholding, delta_tau_thresholding) = if reduce.group_by().is_empty() { (self.epsilon, self.delta, 0., 0.) From 6dc6256ba99dd41f6dfa8b8345320a259663bf23 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Tue, 12 Dec 2023 11:20:36 +0100 Subject: [PATCH 17/27] Fixed problems --- src/differential_privacy/budget.rs | 2 -- src/expr/mod.rs | 1 - src/expr/split.rs | 10 +++++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/differential_privacy/budget.rs b/src/differential_privacy/budget.rs index c7aac990..37450246 100644 --- a/src/differential_privacy/budget.rs +++ b/src/differential_privacy/budget.rs @@ -34,12 +34,10 @@ impl Budget { impl Budget { pub fn reduce(&self, reduce: &Reduce, input: PUPRelation) -> Result { - print!("DEBUG input {}", Relation::from(input.clone()).schema()); let reduce: Reduce = Relation::reduce() .with(reduce.clone()) .input(Relation::from(input)) .build(); - print!("DEBUG reduce input {}", reduce.input().schema()); let (epsilon, delta, epsilon_tau_thresholding, delta_tau_thresholding) = if reduce.group_by().is_empty() { (self.epsilon, self.delta, 0., 0.) 
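Note for readers: the split.rs hunk a few diffs below settles on a GROUP BY convention in which grouping keys are always columns — a plain column is grouped on directly, while any computed expression is first materialized by a Map stage under a namer-generated field name that the Reduce then groups by. The snippet below is a minimal standalone sketch of that convention only; the Expr, Map and Reduce types and the "field_0" name are simplified illustrative stand-ins, not Qrlew's actual API.

// A minimal, standalone sketch (illustrative types only, not Qrlew's structs)
// of the GROUP BY handling the split.rs hunk below arrives at: a plain column
// is used directly as a grouping key of the Reduce, while a computed
// expression is first materialized by a Map stage under a generated name.

#[derive(Debug)]
enum Expr {
    Column(String),
    Value(i64),
    Function(String, Vec<Expr>), // e.g. 3 * a as Function("*", [Value(3), Column("a")])
}

#[derive(Debug)]
struct Map {
    named_exprs: Vec<(String, Expr)>, // fields computed before the aggregation
}

#[derive(Debug)]
struct Reduce {
    group_by: Vec<String>, // grouping keys are always column names
    map: Option<Map>,      // optional Map stage feeding the Reduce
}

fn group_by(expr: Expr) -> Reduce {
    match expr {
        // A plain column can be grouped on as-is.
        Expr::Column(name) => Reduce { group_by: vec![name], map: None },
        // Any other expression is computed in a Map under a fresh name first.
        other => {
            let name = String::from("field_0"); // stand-in for namer::name_from_content
            Reduce {
                group_by: vec![name.clone()],
                map: Some(Map { named_exprs: vec![(name, other)] }),
            }
        }
    }
}

fn main() {
    // GROUP BY a     -> Reduce groups by "a", no Map needed
    println!("{:?}", group_by(Expr::Column("a".to_string())));
    // GROUP BY 3 * a -> a Map computes the product, the Reduce groups by its name
    let product = Expr::Function(
        "*".to_string(),
        vec![Expr::Value(3), Expr::Column("a".to_string())],
    );
    println!("{:?}", group_by(product));
}

Keeping group-by keys as plain columns rather than arbitrary expressions is what allows the surrounding patches to type the Reduce's group_by as Vec<Column> and to check in the ReduceBuilder that First aggregate columns appear in the GROUP BY.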
diff --git a/src/expr/mod.rs b/src/expr/mod.rs index e41f3840..e8048e9f 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -1026,7 +1026,6 @@ pub struct SuperImageVisitor<'a>(&'a DataType); impl<'a> Visitor<'a, Result> for SuperImageVisitor<'a> { fn column(&self, column: &'a Column) -> Result { - println!("DEBUG COLUMN {:?}", column); Ok(self.0[column.clone()].clone()) } diff --git a/src/expr/split.rs b/src/expr/split.rs index d71985bc..d073e876 100644 --- a/src/expr/split.rs +++ b/src/expr/split.rs @@ -36,9 +36,13 @@ impl Split { } pub fn group_by(expr: Expr) -> Reduce { - let name = namer::name_from_content(FIELD, &expr); - let map = Map::new(vec![(name.clone(), expr)], None, vec![], None); - Reduce::new(vec![], vec![name.into()], Some(map)) + if let Expr::Column(col) = expr {// If the expression is a column + Reduce::new(vec![], vec![col], None) + } else {// If not + let name = namer::name_from_content(FIELD, &expr); + let map = Map::new(vec![(name.clone(), expr)], None, vec![], None); + Reduce::new(vec![], vec![name.into()], Some(map)) + } } pub fn into_map(self) -> Map { From 6f9bc5ee0d4f04556883866804d4e2538c9ec9dc Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Tue, 12 Dec 2023 11:31:05 +0100 Subject: [PATCH 18/27] Fixed --- tests/integration.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/integration.rs b/tests/integration.rs index a5969649..11e06ab9 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -57,7 +57,8 @@ const QUERIES: &[&str] = &[ // Some GROUP BY "SELECT 1+SUM(a), count(b) FROM table_1 GROUP BY d", "SELECT count(b) FROM table_1 GROUP BY CEIL(d)", - "SELECT CEIL(d) AS d_ceiled, count(b) FROM table_1 GROUP BY d_ceiled", + "SELECT CEIL(d) AS d_ceiled, count(b) FROM table_1 GROUP BY CEIL(d)", + // "SELECT CEIL(d) AS d_ceiled, count(b) FROM table_1 GROUP BY d_ceiled", // Some WHERE and GROUP BY "SELECT 1+SUM(a), count(b) FROM table_1 WHERE d>4 GROUP BY d", "SELECT 1+SUM(a), count(b), d FROM table_1 GROUP BY d", From 1b1c539d0bdb63e0d320259d2a865f55f9eba05b Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Tue, 12 Dec 2023 12:20:13 +0100 Subject: [PATCH 19/27] ok --- src/relation/builder.rs | 6 +- src/relation/dot.rs | 4 +- src/relation/mod.rs | 164 ++++++++++++++++++++++++++++++++++------ src/rewriting/mod.rs | 43 +++++++---- tests/integration.rs | 12 ++- 5 files changed, 184 insertions(+), 45 deletions(-) diff --git a/src/relation/builder.rs b/src/relation/builder.rs index b5a2b2f4..1dd877c4 100644 --- a/src/relation/builder.rs +++ b/src/relation/builder.rs @@ -852,6 +852,9 @@ impl Ready for JoinBuilder { .name .clone() .unwrap_or(namer::name_from_content(JOIN, &self)); + let operator = self + .operator + .unwrap_or(JoinOperator::Inner(JoinConstraint::Natural)); let left_names = self .left .0 @@ -884,9 +887,6 @@ impl Ready for JoinBuilder { .to_string() }) .collect(); - let operator = self - .operator - .unwrap_or(JoinOperator::Inner(JoinConstraint::Natural)); Ok(Join::new( name, left_names, diff --git a/src/relation/dot.rs b/src/relation/dot.rs index 6b16b7d8..98840a59 100644 --- a/src/relation/dot.rs +++ b/src/relation/dot.rs @@ -113,7 +113,9 @@ impl<'a> Visitor<'a, FieldDataTypes> for DotVisitor { join.right() .schema() .iter() - .map(|f| vec![Join::right_name(), f.name()]), + .filter_map(|f| (!join.operator().is_natural() || join.left().schema().field(&f.name()).is_err()) + .then_some(vec![Join::right_name(), f.name()]) + ) ) .zip(join.schema().iter()) .map(|(p, field)| { diff --git a/src/relation/mod.rs 
b/src/relation/mod.rs index 2bfe0f1c..fd83aa21 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -701,6 +701,17 @@ impl JoinOperator { JoinOperator::Cross => (false, false), } } + + // Returns true is the contraint is Natural + pub fn is_natural(&self) -> bool { + match self { + JoinOperator::Inner(c) + | JoinOperator::LeftOuter(c) + | JoinOperator::RightOuter(c) + | JoinOperator::FullOuter(c) if matches!(c, JoinConstraint::Natural) => true, + _ => false + } + } } impl DataType { @@ -951,36 +962,49 @@ impl Join { .into_iter() .zip(left_schema.iter()) .map(|(name, field)| { - Field::new( - name, - if transform_datatype_in_optional_left { - DataType::optional(field.data_type()) - } else { - field.data_type() - }, - if right_is_unique { - field.constraint() - } else { + let (data_type, constraint) = if ( + matches!(operator, JoinOperator::RightOuter(JoinConstraint::Natural)) + && right_schema.field(&field.name()).is_ok() + ) { // if the join is of type NATURAL RIGHT then the duplicate fields are the fields from the right Relation + let right_field = right_schema.field(&field.name()).unwrap(); + (right_field.data_type(), right_field.constraint()) + } else if ( + matches!(operator, JoinOperator::FullOuter(JoinConstraint::Natural)) + && right_schema.field(&field.name()).is_ok() + ) { // if the join is of type NATURAL FULL then the datatype of the duplicate fields is the union of the left and right datatypes + let right_field = right_schema.field(&field.name()).unwrap(); + ( + field.data_type().super_union(&right_field.data_type()).unwrap(), None - }, - ) + ) + } else { + ( + transform_datatype_in_optional_left.then_some(DataType::optional(field.data_type())) + .unwrap_or(field.data_type()), + right_is_unique.then_some(field.constraint()).unwrap_or(None) + ) + }; + Field::new(name, data_type, constraint) }); let right_fields = right_names .into_iter() .zip(right_schema.iter()) - .map(|(name, field)| { - Field::new( - name, - if transform_datatype_in_optional_right { - DataType::optional(field.data_type()) - } else { - field.data_type() - }, - if left_is_unique { - field.constraint() - } else { - None - }, + .filter_map(|(name, field)| { + (!operator.is_natural() || left_schema.field(&field.name()).is_err()) // we filter the duplicate fields + .then_some( + Field::new( + name, + if transform_datatype_in_optional_right { + DataType::optional(field.data_type()) + } else { + field.data_type() + }, + if left_is_unique { + field.constraint() + } else { + None + }, + ) ) }); left_fields.chain(right_fields).collect() @@ -2610,4 +2634,94 @@ mod tests { .collect(); assert_eq!(join.schema(), &correct_schema); } + + #[test] + fn test_natural_join_builder() { + let schema1: Schema = vec![ + ("a", DataType::integer_interval(-5, 5)), + ("b", DataType::integer_interval(-2, 2)), + ] + .into_iter() + .collect(); + let table1: Relation = Relation::table().name("table1").schema(schema1).build(); + let schema2: Schema = vec![ + ("a", DataType::integer_interval(0, 10)), + ("c", DataType::integer_interval(0, 20)), + ] + .into_iter() + .collect(); + let table2: Relation = Relation::table().name("table1").schema(schema2).build(); + + // natural inner join + let relation: Relation = Relation::join() + .left(table1.clone()) + .right(table2.clone()) + .left_names(vec!["a", "b"]) + .right_names(vec!["my_a", "c"]) + .inner() + .build(); + relation.display_dot().unwrap(); + assert_eq!( + relation.data_type(), + DataType::structured([ + ("a", DataType::integer_interval(0, 5)), + ("b", DataType::integer_interval(-2, 
2)), + ("c", DataType::integer_interval(0, 20)) + ]) + ); + + // natural left join + let relation: Relation = Relation::join() + .left(table1.clone()) + .right(table2.clone()) + .left_names(vec!["a", "b"]) + .right_names(vec!["my_a", "c"]) + .left_outer() + .build(); + relation.display_dot().unwrap(); + assert_eq!( + relation.data_type(), + DataType::structured([ + ("a", DataType::integer_interval(-5, 5)), + ("b", DataType::integer_interval(-2, 2)), + ("c", DataType::optional(DataType::integer_interval(0, 20))) + ]) + ); + + // natural right join + let relation: Relation = Relation::join() + .left(table1.clone()) + .right(table2.clone()) + .left_names(vec!["a", "b"]) + .right_names(vec!["my_a", "c"]) + .right_outer() + .build(); + relation.display_dot().unwrap(); + assert_eq!( + relation.data_type(), + DataType::structured([ + ("a", DataType::integer_interval(0, 10)), + ("b", DataType::optional(DataType::integer_interval(-2, 2))), + ("c", DataType::integer_interval(0, 20)) + ]) + ); + + // natural full join + let relation: Relation = Relation::join() + .left(table1.clone()) + .right(table2.clone()) + .left_names(vec!["a", "b"]) + .right_names(vec!["my_a", "c"]) + .full_outer() + .build(); + relation.display_dot().unwrap(); + assert_eq!( + relation.data_type(), + DataType::structured([ + ("a", DataType::integer_interval(-5, 10)), + ("b", DataType::optional(DataType::integer_interval(-2, 2))), + ("c", DataType::optional(DataType::integer_interval(0, 20))) + ]) + ); + } } diff --git a/src/rewriting/mod.rs b/src/rewriting/mod.rs index 114ed076..caebf59d 100644 --- a/src/rewriting/mod.rs +++ b/src/rewriting/mod.rs @@ -122,6 +122,7 @@ mod tests { use super::*; use crate::{ + ast, builder::{Ready, With}, display::Dot, expr::Identifier, @@ -159,9 +160,9 @@ mod tests { #[test] fn test_rewrite_with_differential_privacy() { - let database = postgresql::test_database(); + let mut database = postgresql::test_database(); let relations = database.relations(); - let query = parse("SELECT order_id, sum(price) FROM item_table GROUP BY order_id").unwrap(); + let synthetic_data = SyntheticData::new(Hierarchy::from([ (vec!["item_table"], Identifier::from("item_table")), (vec!["order_table"], Identifier::from("order_table")), @@ -180,18 +181,32 @@ mod tests { ("user_table", vec![], "name"), ]); let budget = Budget::new(1., 1e-3); - let relation = Relation::try_from(query.with(&relations)).unwrap(); - let relation_with_private_query = relation - .rewrite_with_differential_privacy(&relations, synthetic_data, privacy_unit, budget) - .unwrap(); - relation_with_private_query - .relation() - .display_dot() - .unwrap(); - println!( - "PrivateQuery = {}", - relation_with_private_query.private_query() - ); + + let queries = [ + "SELECT order_id, sum(price) FROM item_table GROUP BY order_id", + "SELECT sum(distinct price) FROM item_table GROUP BY order_id HAVING count(*) > 2", + ]; + + for q in queries { + println!("=================================\n{q}"); + let query = parse(q).unwrap(); + let relation = Relation::try_from(query.with(&relations)).unwrap(); + let relation_with_private_query = relation + .rewrite_with_differential_privacy(&relations, synthetic_data.clone(), privacy_unit.clone(), budget.clone()) + .unwrap(); + relation_with_private_query + .relation() + .display_dot() + .unwrap(); + let dp_query = ast::Query::from(&relation_with_private_query.relation().clone()).to_string(); + println!("\n{dp_query}"); + _ = database + .query(dp_query.as_str()) + .unwrap() + .iter() + .map(ToString::to_string) + 
.join("\n"); + } } #[test] diff --git a/tests/integration.rs b/tests/integration.rs index 9cebaf7a..9992ad39 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -95,8 +95,16 @@ const QUERIES: &[&str] = &[ "SELECT LOWER(z) FROM table_2 LIMIT 5", // distinct "SELECT DISTINCT COUNT(*) FROM table_1 GROUP BY d", - "SELECT DISTINCt c, d FROM table_1", - "SELECT c, COUNT(DISTINCT d) AS count_d, SUM(DISTINCT d) AS sum_d FROM table_1 GROUP BY c ORDER BY c" + "SELECT DISTINCT c, d FROM table_1", + "SELECT c, COUNT(DISTINCT d) AS count_d, SUM(DISTINCT d) AS sum_d FROM table_1 GROUP BY c ORDER BY c", + "SELECT SUM(DISTINCT a) AS s1 FROM table_1 GROUP BY c HAVING COUNT(*) > 5;", + // natural joins + "WITH t1 AS (SELECT a, b FROM table_1 WHERE a > 5), t2 AS (SELECT a, d FROM table_1 WHERE a < 7) SELECT * FROM t1 NATURAL INNER JOIN t2", + "WITH t1 AS (SELECT a, b FROM table_1 WHERE a > 5), t2 AS (SELECT a, d FROM table_1 WHERE a < 7) SELECT * FROM t1 NATURAL LEFT JOIN t2", + "WITH t1 AS (SELECT a, b FROM table_1 WHERE a > 5), t2 AS (SELECT a, d FROM table_1 WHERE a < 7) SELECT * FROM t1 NATURAL RIGHT JOIN t2", + "WITH t1 AS (SELECT a, b FROM table_1 WHERE a > 5), t2 AS (SELECT a, d FROM table_1 WHERE a < 7) SELECT * FROM t1 NATURAL FULL JOIN t2", + + ]; #[cfg(feature = "sqlite")] From 76d5d56d4fddbb2361efda2085363dc269d27dc8 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Tue, 12 Dec 2023 12:21:42 +0100 Subject: [PATCH 20/27] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0eaca9a..1a1a0a95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Fixed +- Natural joins [#220](https://github.com/Qrlew/qrlew/issues/220) - When the clipping factor is zero, multiply by zero instead of dividing by 1 / clipping_factor [#218](https://github.com/Qrlew/qrlew/issues/218) ## [0.5.5] - 2023-12-09 From a0cd1d9884bb3f89ada7816be8182d173ba81c9b Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Tue, 12 Dec 2023 12:22:14 +0100 Subject: [PATCH 21/27] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a1a0a95..fa786aae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Fixed -- Natural joins [#220](https://github.com/Qrlew/qrlew/issues/220) +- Natural joins [#221](https://github.com/Qrlew/qrlew/issues/221) - When the clipping factor is zero, multiply by zero instead of dividing by 1 / clipping_factor [#218](https://github.com/Qrlew/qrlew/issues/218) ## [0.5.5] - 2023-12-09 From 6ba6b7be43eea0b18735a16b8b8fe5f85b580f7b Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Tue, 12 Dec 2023 13:19:18 +0100 Subject: [PATCH 22/27] ok --- src/relation/mod.rs | 46 +++++++++++++++------------------------------ 1 file changed, 15 insertions(+), 31 deletions(-) diff --git a/src/relation/mod.rs b/src/relation/mod.rs index bbea90c6..f24b08aa 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -946,26 +946,16 @@ impl Join { .into_iter() .zip(left_schema.iter()) .map(|(name, field)| { - let (data_type, constraint) = if ( - matches!(operator, JoinOperator::RightOuter(JoinConstraint::Natural)) - && right_schema.field(&field.name()).is_ok() - ) { // if the join is of type NATURAL RIGHT then the duplicate fields are the fields from 
the right Relation + let (data_type, constraint) = if matches!(operator, JoinOperator::RightOuter(JoinConstraint::Natural)) && right_schema.field(&field.name()).is_ok() { let right_field = right_schema.field(&field.name()).unwrap(); (right_field.data_type(), right_field.constraint()) - } else if ( - matches!(operator, JoinOperator::FullOuter(JoinConstraint::Natural)) - && right_schema.field(&field.name()).is_ok() - ) { // if the join is of type NATURAL FULL then the datatype of the duplicate fields is the union of the left and right datatypes + } else if matches!(operator, JoinOperator::FullOuter(JoinConstraint::Natural)) && right_schema.field(&field.name()).is_ok() { let right_field = right_schema.field(&field.name()).unwrap(); - ( - field.data_type().super_union(&right_field.data_type()).unwrap(), - None - ) + (field.data_type().super_union(&right_field.data_type()).unwrap(), None) } else { ( - transform_datatype_in_optional_left.then_some(DataType::optional(field.data_type())) - .unwrap_or(field.data_type()), - right_is_unique.then_some(field.constraint()).unwrap_or(None) + transform_datatype_in_optional_left.then_some(DataType::optional(field.data_type())).unwrap_or(field.data_type()), + right_is_unique.then_some(field.constraint()).unwrap_or(None), ) }; Field::new(name, data_type, constraint) @@ -974,22 +964,16 @@ impl Join { .into_iter() .zip(right_schema.iter()) .filter_map(|(name, field)| { - (!operator.is_natural() || left_schema.field(&field.name()).is_err()) // we filter the duplicate fields - .then_some( - Field::new( - name, - if transform_datatype_in_optional_right { - DataType::optional(field.data_type()) - } else { - field.data_type() - }, - if left_is_unique { - field.constraint() - } else { - None - }, - ) - ) + (!operator.is_natural() || left_schema.field(&field.name()).is_err()) + .then_some(Field::new( + name, + if transform_datatype_in_optional_right { + DataType::optional(field.data_type()) + } else { + field.data_type() + }, + if left_is_unique { field.constraint() } else { None }, + )) }); left_fields.chain(right_fields).collect() } From 6c1d675ac666e9611d19bfdf9a61a1880470ea6f Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Tue, 12 Dec 2023 13:49:28 +0100 Subject: [PATCH 23/27] clean the code for building the Join Schema --- src/relation/mod.rs | 65 ++++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/src/relation/mod.rs b/src/relation/mod.rs index f24b08aa..ccbf87cd 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -934,46 +934,55 @@ impl Join { let (left_schema, right_schema) = operator.filtered_schemas(left, right); let (left_is_unique, right_is_unique) = operator.has_unique_constraint(left.schema(), right.schema()); - let transform_datatype_in_optional_left: bool = match operator { - JoinOperator::LeftOuter(_) | JoinOperator::Inner(_) | JoinOperator::Cross => false, - _ => true, - }; - let transform_datatype_in_optional_right = match operator { - JoinOperator::RightOuter(_) | JoinOperator::Inner(_) | JoinOperator::Cross => false, - _ => true, - }; let left_fields = left_names .into_iter() .zip(left_schema.iter()) .map(|(name, field)| { - let (data_type, constraint) = if matches!(operator, JoinOperator::RightOuter(JoinConstraint::Natural)) && right_schema.field(&field.name()).is_ok() { - let right_field = right_schema.field(&field.name()).unwrap(); - (right_field.data_type(), right_field.constraint()) - } else if matches!(operator, 
JoinOperator::FullOuter(JoinConstraint::Natural)) && right_schema.field(&field.name()).is_ok() { - let right_field = right_schema.field(&field.name()).unwrap(); - (field.data_type().super_union(&right_field.data_type()).unwrap(), None) - } else { - ( - transform_datatype_in_optional_left.then_some(DataType::optional(field.data_type())).unwrap_or(field.data_type()), - right_is_unique.then_some(field.constraint()).unwrap_or(None), + let (data_type, constraint) = match operator { + JoinOperator::RightOuter(JoinConstraint::Natural) if right_schema.field(&field.name()).is_ok() => { + // if `field` is present in both `left` and `right` and the operator is of type NATURAL RIGHT OUTER, the datatype of the field is the datatype of the right field + let right_field = right_schema.field(&field.name()).unwrap(); + (right_field.data_type(), right_field.constraint()) + }, + JoinOperator::FullOuter(JoinConstraint::Natural) if right_schema.field(&field.name()).is_ok() => { + // if `field` is present in both `left` and `right` and the operator is of type NATURAL FULL OUTER, the datatype of the field is the super union of the datatypes of the right and left field datatypes + let right_field = right_schema.field(&field.name()).unwrap(); + (field.data_type().super_union(&right_field.data_type()).unwrap(), None) + }, + JoinOperator::RightOuter(_) | JoinOperator::FullOuter(_) => ( + // if the operator if of type RIGHT or FULL OUTER (without NATURAL constraint), the current (left) field is an optional + DataType::optional(field.data_type()), + right_is_unique.then_some(field.constraint()).unwrap_or(None) + ), + _ => ( + field.data_type(), + right_is_unique.then_some(field.constraint()).unwrap_or(None) ) }; Field::new(name, data_type, constraint) }); + let right_fields = right_names .into_iter() .zip(right_schema.iter()) .filter_map(|(name, field)| { - (!operator.is_natural() || left_schema.field(&field.name()).is_err()) - .then_some(Field::new( - name, - if transform_datatype_in_optional_right { - DataType::optional(field.data_type()) - } else { - field.data_type() - }, - if left_is_unique { field.constraint() } else { None }, - )) + let data_type_constraint = match operator { + JoinOperator::Inner(JoinConstraint::Natural) + | JoinOperator::LeftOuter(JoinConstraint::Natural) + | JoinOperator::RightOuter(JoinConstraint::Natural) + | JoinOperator::FullOuter(JoinConstraint::Natural) if left_schema.field(&field.name()).is_ok() => None, // remove the duplicates when JoinConstaint is Natural + JoinOperator::LeftOuter(_) | JoinOperator::FullOuter(_) => Some(( + // if the operator if of type LEFT or FULL OUTER (without NATURAL constraint), the current (right) field is an optional + DataType::optional(field.data_type()), + left_is_unique.then_some(field.constraint()).unwrap_or(None) + )), + _ => Some(( + // if the operator if of type RIGHT or FULL OUTER (without NATURAL constraint), the current (right) field is an optional + field.data_type(), + left_is_unique.then_some(field.constraint()).unwrap_or(None) + )) + }; + data_type_constraint.map(|(data_type, constraint)| Field::new(name, data_type, constraint)) }); left_fields.chain(right_fields).collect() } From c22e3418649e0653f2508faca39c9ae6234d7b89 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Tue, 12 Dec 2023 14:20:09 +0100 Subject: [PATCH 24/27] clean the code for building the Join Schema --- src/relation/mod.rs | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/relation/mod.rs b/src/relation/mod.rs index 
ccbf87cd..1b2ca50b 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -938,18 +938,22 @@ impl Join { .into_iter() .zip(left_schema.iter()) .map(|(name, field)| { - let (data_type, constraint) = match operator { - JoinOperator::RightOuter(JoinConstraint::Natural) if right_schema.field(&field.name()).is_ok() => { - // if `field` is present in both `left` and `right` and the operator is of type NATURAL RIGHT OUTER, the datatype of the field is the datatype of the right field - let right_field = right_schema.field(&field.name()).unwrap(); + let (data_type, constraint) = match (operator, right_schema.field(&field.name())) { + (JoinOperator::RightOuter(JoinConstraint::Natural), Ok(right_field)) => { + // if + // - operator is of type NATURAL RIGHT OUTER and + // - `field` is present in both the `left` and `right` relations + // then the datatype of the corresponding field in the JOIN is the datatype of the right field (right_field.data_type(), right_field.constraint()) }, - JoinOperator::FullOuter(JoinConstraint::Natural) if right_schema.field(&field.name()).is_ok() => { - // if `field` is present in both `left` and `right` and the operator is of type NATURAL FULL OUTER, the datatype of the field is the super union of the datatypes of the right and left field datatypes - let right_field = right_schema.field(&field.name()).unwrap(); + (JoinOperator::FullOuter(JoinConstraint::Natural), Ok(right_field)) => { + // if + // - operator is of type NATURAL RIGHT OUTER and + // - `field` is present in both the `left` and `right` relations + // then the datatype of the corresponding field in the JOIN is the super union of the datatypes of the right and left field datatypes (field.data_type().super_union(&right_field.data_type()).unwrap(), None) }, - JoinOperator::RightOuter(_) | JoinOperator::FullOuter(_) => ( + (JoinOperator::RightOuter(_) | JoinOperator::FullOuter(_), _) => ( // if the operator if of type RIGHT or FULL OUTER (without NATURAL constraint), the current (left) field is an optional DataType::optional(field.data_type()), right_is_unique.then_some(field.constraint()).unwrap_or(None) @@ -972,12 +976,10 @@ impl Join { | JoinOperator::RightOuter(JoinConstraint::Natural) | JoinOperator::FullOuter(JoinConstraint::Natural) if left_schema.field(&field.name()).is_ok() => None, // remove the duplicates when JoinConstaint is Natural JoinOperator::LeftOuter(_) | JoinOperator::FullOuter(_) => Some(( - // if the operator if of type LEFT or FULL OUTER (without NATURAL constraint), the current (right) field is an optional DataType::optional(field.data_type()), left_is_unique.then_some(field.constraint()).unwrap_or(None) )), _ => Some(( - // if the operator if of type RIGHT or FULL OUTER (without NATURAL constraint), the current (right) field is an optional field.data_type(), left_is_unique.then_some(field.constraint()).unwrap_or(None) )) From 329df3f2f9983382f61a3943b5efaeb3b0fea050 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 14 Dec 2023 16:44:19 +0100 Subject: [PATCH 25/27] ok --- src/differential_privacy/group_by.rs | 4 +- src/privacy_unit_tracking/mod.rs | 7 +- src/relation/builder.rs | 94 +---- src/relation/dot.rs | 34 +- src/relation/mod.rs | 601 +++++---------------------- src/relation/rewriting.rs | 58 ++- src/relation/sql.rs | 36 +- src/rewriting/mod.rs | 2 +- src/sql/relation.rs | 293 ++++++++++++- tests/integration.rs | 17 +- 10 files changed, 476 insertions(+), 670 deletions(-) diff --git a/src/differential_privacy/group_by.rs 
b/src/differential_privacy/group_by.rs index c88c3189..d4117ea2 100644 --- a/src/differential_privacy/group_by.rs +++ b/src/differential_privacy/group_by.rs @@ -189,7 +189,7 @@ impl Relation { .right_names(right_names.clone()) .left(left) .left_names(left_names.clone()) - .left_outer() + .left_outer(Expr::val(true)) .on_iter(on) .build(); @@ -561,7 +561,7 @@ mod tests { .deref() .clone(); let join: Join = Join::builder() - .inner() + .inner(Expr::val(true)) .on_eq("order_id", "id") .left(left.clone()) .right(right.clone()) diff --git a/src/privacy_unit_tracking/mod.rs b/src/privacy_unit_tracking/mod.rs index d58dbb5f..08bed803 100644 --- a/src/privacy_unit_tracking/mod.rs +++ b/src/privacy_unit_tracking/mod.rs @@ -174,8 +174,7 @@ impl Relation { referred_relation }; let join: Relation = Relation::join() - .inner() - .on(Expr::eq( + .inner(Expr::eq( Expr::qcol(Join::right_name(), &referring_id), Expr::qcol(Join::left_name(), &referred_id), )) @@ -670,7 +669,7 @@ mod tests { .deref() .clone(); let join: Join = Join::builder() - .inner() + .inner(Expr::val(true)) .on_eq("order_id", "id") .left(left.clone()) .right(right.clone()) @@ -739,7 +738,7 @@ mod tests { .deref() .clone(); let join: Join = Join::builder() - .inner() + .inner(Expr::val(true)) .on_eq("item", "item") .left(table.clone()) .right(table.clone()) diff --git a/src/relation/builder.rs b/src/relation/builder.rs index 36f7bf3a..ae803b6d 100644 --- a/src/relation/builder.rs +++ b/src/relation/builder.rs @@ -3,7 +3,7 @@ use std::{hash::Hash, sync::Arc}; use itertools::Itertools; use super::{ - Error, Join, JoinConstraint, JoinOperator, Map, OrderBy, Reduce, Relation, Result, Schema, Set, + Error, Join, JoinOperator, Map, OrderBy, Reduce, Relation, Result, Schema, Set, SetOperator, SetQuantifier, Table, Values, Variant, }; use crate::{ @@ -692,23 +692,23 @@ impl JoinBuilder Self { - self.operator = Some(JoinOperator::Inner(JoinConstraint::Natural)); + pub fn inner(mut self, expr: Expr) -> Self { + self.operator = Some(JoinOperator::Inner(expr)); self } - pub fn left_outer(mut self) -> Self { - self.operator = Some(JoinOperator::LeftOuter(JoinConstraint::Natural)); + pub fn left_outer(mut self, expr: Expr) -> Self { + self.operator = Some(JoinOperator::LeftOuter(expr)); self } - pub fn right_outer(mut self) -> Self { - self.operator = Some(JoinOperator::RightOuter(JoinConstraint::Natural)); + pub fn right_outer(mut self, expr: Expr) -> Self { + self.operator = Some(JoinOperator::RightOuter(expr)); self } - pub fn full_outer(mut self) -> Self { - self.operator = Some(JoinOperator::FullOuter(JoinConstraint::Natural)); + pub fn full_outer(mut self, expr: Expr) -> Self { + self.operator = Some(JoinOperator::FullOuter(expr)); self } @@ -719,18 +719,12 @@ impl JoinBuilder Self { self.operator = match self.operator { - Some(JoinOperator::Inner(_)) => Some(JoinOperator::Inner(JoinConstraint::On(expr))), - Some(JoinOperator::LeftOuter(_)) => { - Some(JoinOperator::LeftOuter(JoinConstraint::On(expr))) - } - Some(JoinOperator::RightOuter(_)) => { - Some(JoinOperator::RightOuter(JoinConstraint::On(expr))) - } - Some(JoinOperator::FullOuter(_)) => { - Some(JoinOperator::FullOuter(JoinConstraint::On(expr))) - } + Some(JoinOperator::Inner(_)) => Some(JoinOperator::Inner(expr)), + Some(JoinOperator::LeftOuter(_)) => Some(JoinOperator::LeftOuter(expr)), + Some(JoinOperator::RightOuter(_)) => Some(JoinOperator::RightOuter(expr)), + Some(JoinOperator::FullOuter(_)) => Some(JoinOperator::FullOuter(expr)), Some(JoinOperator::Cross) => 
Some(JoinOperator::Cross), - None => Some(JoinOperator::Inner(JoinConstraint::On(expr))), + None => Some(JoinOperator::Inner(expr)), }; self } @@ -751,59 +745,14 @@ impl JoinBuilder Self { self.operator = match self.operator { - Some(JoinOperator::Inner(JoinConstraint::On(on))) => { - Some(JoinOperator::Inner(JoinConstraint::On(Expr::and(expr, on)))) - } - Some(JoinOperator::LeftOuter(JoinConstraint::On(on))) => Some(JoinOperator::LeftOuter( - JoinConstraint::On(Expr::and(expr, on)), - )), - Some(JoinOperator::RightOuter(JoinConstraint::On(on))) => Some( - JoinOperator::RightOuter(JoinConstraint::On(Expr::and(expr, on))), - ), - Some(JoinOperator::FullOuter(JoinConstraint::On(on))) => Some(JoinOperator::FullOuter( - JoinConstraint::On(Expr::and(expr, on)), - )), + Some(JoinOperator::Inner(x)) => Some(JoinOperator::Inner(Expr::and(expr, x))), + Some(JoinOperator::LeftOuter(x)) => Some(JoinOperator::LeftOuter(Expr::and(expr, x))), + Some(JoinOperator::RightOuter(x)) => Some(JoinOperator::RightOuter(Expr::and(expr, x))), + Some(JoinOperator::FullOuter(x)) => Some(JoinOperator::FullOuter(Expr::and(expr, x))), op => op, }; self } - /// Add a using condition - pub fn using>(mut self, using: I) -> Self { - let using: Identifier = using.into(); - self.operator = match self.operator { - Some(JoinOperator::Inner(JoinConstraint::Using(mut identifiers))) => { - identifiers.push(using); - Some(JoinOperator::Inner(JoinConstraint::Using(identifiers))) - } - Some(JoinOperator::LeftOuter(JoinConstraint::Using(mut identifiers))) => { - identifiers.push(using); - Some(JoinOperator::LeftOuter(JoinConstraint::Using(identifiers))) - } - Some(JoinOperator::RightOuter(JoinConstraint::Using(mut identifiers))) => { - identifiers.push(using); - Some(JoinOperator::RightOuter(JoinConstraint::Using(identifiers))) - } - Some(JoinOperator::FullOuter(JoinConstraint::Using(mut identifiers))) => { - identifiers.push(using); - Some(JoinOperator::FullOuter(JoinConstraint::Using(identifiers))) - } - Some(JoinOperator::Inner(_)) => { - Some(JoinOperator::Inner(JoinConstraint::Using(vec![using]))) - } - Some(JoinOperator::LeftOuter(_)) => { - Some(JoinOperator::LeftOuter(JoinConstraint::Using(vec![using]))) - } - Some(JoinOperator::RightOuter(_)) => { - Some(JoinOperator::RightOuter(JoinConstraint::Using(vec![using]))) - } - Some(JoinOperator::FullOuter(_)) => { - Some(JoinOperator::FullOuter(JoinConstraint::Using(vec![using]))) - } - Some(JoinOperator::Cross) => Some(JoinOperator::Cross), - None => Some(JoinOperator::Inner(JoinConstraint::Using(vec![using]))), - }; - self - } /// Set directly the full JOIN operator pub fn operator(mut self, operator: JoinOperator) -> Self { @@ -869,7 +818,7 @@ impl Ready for JoinBuilder { .unwrap_or(namer::name_from_content(JOIN, &self)); let operator = self .operator - .unwrap_or(JoinOperator::Inner(JoinConstraint::Natural)); + .unwrap_or(JoinOperator::Cross); let left_names = self .left .0 @@ -1262,8 +1211,7 @@ mod tests { let join: Relation = Relation::join() .left(table1) .right(table2) - .left_outer() - .on_iter(vec![Expr::eq(Expr::col("a"), Expr::col("c"))]) + .left_outer(Expr::eq(Expr::col("a"), Expr::col("c"))) .left_names(vec!["a1", "b1"]) //.on_iter(vec![Expr::eq(Expr::col("a"), Expr::col("c")), Expr::eq(Expr::col("b"), Expr::col("d"))]) .build(); @@ -1492,7 +1440,7 @@ mod tests { let join: Relation = Relation::join() .left(table1.clone()) .right(table1.clone()) - .inner() + .inner(Expr::val(true)) .on_eq("d", "d") .names(Hierarchy::::from_iter(vec![ ([Join::left_name(), "a"], 
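// --- [Editor's note] Illustrative sketch, not part of the patches above ---
// With this builder change the join kind now takes its ON expression directly and the
// old `using(...)`/`JoinConstraint` path is gone. `table1` and `table2` are placeholder
// relations; both forms below only rely on calls that appear in the diffs:
let join: Relation = Relation::join()
    .left(table1.clone())
    .right(table2.clone())
    .left_outer(Expr::eq(Expr::col("a"), Expr::col("c")))
    .build();
// or seed the kind with a trivial predicate and add equalities with `on_eq`:
let join: Relation = Relation::join()
    .left(table1)
    .right(table2)
    .inner(Expr::val(true))
    .on_eq("a", "c")
    .build();
// --- end of editor's note ---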
"a1".to_string()), diff --git a/src/relation/dot.rs b/src/relation/dot.rs index 98840a59..01267e2e 100644 --- a/src/relation/dot.rs +++ b/src/relation/dot.rs @@ -1,4 +1,4 @@ -use super::{Error, Field, JoinConstraint, JoinOperator, Relation, Variant as _, Visitor}; +use super::{Error, Field, JoinOperator, Relation, Variant as _, Visitor}; use crate::{ data_type::DataTyped, display::{self, colors}, @@ -113,9 +113,7 @@ impl<'a> Visitor<'a, FieldDataTypes> for DotVisitor { join.right() .schema() .iter() - .filter_map(|f| (!join.operator().is_natural() || join.left().schema().field(&f.name()).is_err()) - .then_some(vec![Join::right_name(), f.name()]) - ) + .map(|f| vec![Join::right_name(), f.name()]) ) .zip(join.schema().iter()) .map(|(p, field)| { @@ -244,31 +242,13 @@ impl<'a, T: Clone + fmt::Display, V: Visitor<'a, T>> dot::Labeller<'a, Node<'a, } Relation::Join(join) => { let operator = match &join.operator { - JoinOperator::Inner(JoinConstraint::On(expr)) - | JoinOperator::LeftOuter(JoinConstraint::On(expr)) - | JoinOperator::RightOuter(JoinConstraint::On(expr)) - | JoinOperator::FullOuter(JoinConstraint::On(expr)) => { + JoinOperator::Inner(expr) + | JoinOperator::LeftOuter(expr) + | JoinOperator::RightOuter(expr) + | JoinOperator::FullOuter(expr) => { format!("
{} ON {}", join.operator.to_string(), expr) } - JoinOperator::Inner(JoinConstraint::Using(identifiers)) - | JoinOperator::LeftOuter(JoinConstraint::Using(identifiers)) - | JoinOperator::RightOuter(JoinConstraint::Using(identifiers)) - | JoinOperator::FullOuter(JoinConstraint::Using(identifiers)) => format!( - "
{} USING ({})", - join.operator.to_string(), - identifiers.iter().join(", ") - ), - JoinOperator::Inner(JoinConstraint::Natural) - | JoinOperator::LeftOuter(JoinConstraint::Natural) - | JoinOperator::RightOuter(JoinConstraint::Natural) - | JoinOperator::FullOuter(JoinConstraint::Natural) => { - format!("
NATURAL {}", join.operator.to_string()) - } - JoinOperator::Inner(JoinConstraint::None) - | JoinOperator::LeftOuter(JoinConstraint::None) - | JoinOperator::RightOuter(JoinConstraint::None) - | JoinOperator::FullOuter(JoinConstraint::None) - | JoinOperator::Cross => format!("
{}", join.operator.to_string()), + JoinOperator::Cross => format!("
{}", join.operator.to_string()), }; format!( "{} size ∈ {}
{}{}", diff --git a/src/relation/mod.rs b/src/relation/mod.rs index 1b2ca50b..46a92f35 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -623,10 +623,10 @@ impl Variant for Reduce { /// Join type #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub enum JoinOperator { - Inner(JoinConstraint), - LeftOuter(JoinConstraint), - RightOuter(JoinConstraint), - FullOuter(JoinConstraint), + Inner(Expr), + LeftOuter(Expr), + RightOuter(Expr), + FullOuter(Expr), Cross, } @@ -678,22 +678,55 @@ impl JoinOperator { fn has_unique_constraint(&self, left_schema: &Schema, right_schema: &Schema) -> (bool, bool) { match self { - JoinOperator::Inner(c) - | JoinOperator::LeftOuter(c) - | JoinOperator::RightOuter(c) - | JoinOperator::FullOuter(c) => c.has_unique_constraint(left_schema, right_schema), - JoinOperator::Cross => (false, false), - } - } - - // Returns true is the contraint is Natural - pub fn is_natural(&self) -> bool { - match self { - JoinOperator::Inner(c) - | JoinOperator::LeftOuter(c) - | JoinOperator::RightOuter(c) - | JoinOperator::FullOuter(c) if matches!(c, JoinConstraint::Natural) => true, - _ => false + JoinOperator::Inner(Expr::Function(f)) + | JoinOperator::LeftOuter(Expr::Function(f)) + | JoinOperator::RightOuter(Expr::Function(f)) + | JoinOperator::FullOuter(Expr::Function(f)) if f.function() == function::Function::Eq => { + let fields_with_unique_or_primary_key_constraint = Hierarchy::from_iter( + left_schema + .iter() + .map(|f| { + ( + vec![Join::left_name(), f.name()], + f.has_unique_or_primary_key_constraint(), + ) + }) + .chain(right_schema.iter().map(|f| { + ( + vec![Join::right_name(), f.name()], + f.has_unique_or_primary_key_constraint(), + ) + })), + ); + let mut left = false; + let mut right = false; + if let Expr::Column(c) = &f.arguments()[0] { + if fields_with_unique_or_primary_key_constraint + .get_key_value(c) + .unwrap() + .0[0] + == Join::left_name() + { + left = fields_with_unique_or_primary_key_constraint[c.as_slice()] + } else { + right = fields_with_unique_or_primary_key_constraint[c.as_slice()] + } + } + if let Expr::Column(c) = &f.arguments()[1] { + if fields_with_unique_or_primary_key_constraint + .get_key_value(c) + .unwrap() + .0[0] + == Join::left_name() + { + left = fields_with_unique_or_primary_key_constraint[c.as_slice()] + } else { + right = fields_with_unique_or_primary_key_constraint[c.as_slice()] + } + } + (left, right) + }, + _ => (false, false), } } } @@ -703,12 +736,8 @@ impl DataType { /// filtered by the `Expr` equivalent to the `JoinOperator` fn filter_by_join_operator(&self, join_op: &JoinOperator) -> DataType { match join_op { - JoinOperator::Inner(c) => { - let x = Expr::from((c, self)); - self.filter(&x) - } - JoinOperator::LeftOuter(c) => { - let x = Expr::from((c, self)); + JoinOperator::Inner(x) => self.filter(&x), + JoinOperator::LeftOuter(x) => { let filtered_data_type = self.filter(&x); DataType::structured([ (Join::left_name(), self[Join::left_name()].clone()), @@ -718,8 +747,7 @@ impl DataType { ), ]) } - JoinOperator::RightOuter(c) => { - let x = Expr::from((c, self)); + JoinOperator::RightOuter(x) => { let filtered_data_type = self.filter(&x); DataType::structured([ ( @@ -750,135 +778,6 @@ impl fmt::Display for JoinOperator { } } -/// Join constraint -#[derive(Clone, Debug, Hash, PartialEq, Eq)] -pub enum JoinConstraint { - On(Expr), - Using(Vec), - Natural, - None, -} - -impl JoinConstraint { - /// Rename all exprs in the constraint - pub fn rename<'a>(&'a self, columns: &'a Hierarchy) -> Self { - match self { - 
JoinConstraint::On(expr) => JoinConstraint::On(expr.rename(columns)), - JoinConstraint::Using(identifiers) => JoinConstraint::Using( - identifiers - .iter() - .map(|i| columns.get(i).unwrap().clone()) - .collect(), - ), - JoinConstraint::Natural => JoinConstraint::Natural, - JoinConstraint::None => JoinConstraint::None, - } - } - - /// Returns a tuple of bool where - /// the first (resp. second) item is `true` if - /// - the current `JoinConstraint` is an `On` - /// - the wrapped expression if of type `(Column(_) = Column(_))` AND - /// - the field of column belonging to the left (resp. right) relation has a `Unique` or `PrimaryKey` constraint - /// - the current `JoinConstraint`` is a `Using`: the field of column belonging to the left (resp. right) relation has a `Unique` or `PrimaryKey` constraint - pub fn has_unique_constraint( - &self, - left_schema: &Schema, - right_schema: &Schema, - ) -> (bool, bool) { - match self { - JoinConstraint::On(x) => match x { - Expr::Function(f) if f.function() == function::Function::Eq => { - let fields_with_unique_or_primary_key_constraint = Hierarchy::from_iter( - left_schema - .iter() - .map(|f| { - ( - vec![Join::left_name(), f.name()], - f.has_unique_or_primary_key_constraint(), - ) - }) - .chain(right_schema.iter().map(|f| { - ( - vec![Join::right_name(), f.name()], - f.has_unique_or_primary_key_constraint(), - ) - })), - ); - let mut left = false; - let mut right = false; - if let Expr::Column(c) = &f.arguments()[0] { - if fields_with_unique_or_primary_key_constraint - .get_key_value(c) - .unwrap() - .0[0] - == Join::left_name() - { - left = fields_with_unique_or_primary_key_constraint[c.as_slice()] - } else { - right = fields_with_unique_or_primary_key_constraint[c.as_slice()] - } - } - if let Expr::Column(c) = &f.arguments()[1] { - if fields_with_unique_or_primary_key_constraint - .get_key_value(c) - .unwrap() - .0[0] - == Join::left_name() - { - left = fields_with_unique_or_primary_key_constraint[c.as_slice()] - } else { - right = fields_with_unique_or_primary_key_constraint[c.as_slice()] - } - } - (left, right) - } - _ => (false, false), - }, - JoinConstraint::Using(v) if v.len() == 1 => { - let left = left_schema - .field(v[0].last().unwrap()) - .map(|f| f.has_unique_or_primary_key_constraint()) - .unwrap_or(false); - let right = right_schema - .field(v[0].last().unwrap()) - .map(|f| f.has_unique_or_primary_key_constraint()) - .unwrap_or(false); - (left, right) - } - _ => (false, false), - } - } -} - -impl From<(&JoinConstraint, &DataType)> for Expr { - fn from(value: (&JoinConstraint, &DataType)) -> Self { - let (constraint, dt) = value; - match constraint { - JoinConstraint::On(x) => x.clone(), - JoinConstraint::Using(x) => x.iter().fold(Expr::val(true), |f, v| { - Expr::and( - f, - Expr::eq( - Expr::qcol(Join::left_name(), v.head().unwrap()), - Expr::qcol(Join::right_name(), v.head().unwrap()), - ), - ) - }), - JoinConstraint::Natural => { - let h = dt[Join::right_name()].hierarchy(); - let v = dt[Join::left_name()] - .hierarchy() - .into_iter() - .filter_map(|(s, _)| h.get(&s).map(|_| Identifier::from(s))) - .collect::>(); - (&JoinConstraint::Using(v), dt).into() - } - JoinConstraint::None => Expr::val(true), - } - } -} - /// Join two relations on one or more join columns #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Join { @@ -934,57 +833,36 @@ impl Join { let (left_schema, right_schema) = operator.filtered_schemas(left, right); let (left_is_unique, right_is_unique) = operator.has_unique_constraint(left.schema(), 
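// [Editor's note] Summary of the constraint logic kept above: `has_unique_constraint`
// now inspects the operator's ON expression directly. Only when that expression is a
// single equality between two columns does it report, per side, whether the referenced
// column carries a UNIQUE or PRIMARY KEY constraint; the schema construction below then
// lets the fields of one side keep their constraints only when the other side's join
// column is unique, i.e. when each of its rows can match at most one row.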
right.schema()); + let transform_datatype_in_optional_left: bool = match operator { + JoinOperator::LeftOuter(_) | JoinOperator::Inner(_) | JoinOperator::Cross => false, + _ => true, + }; + let transform_datatype_in_optional_right = match operator { + JoinOperator::RightOuter(_) | JoinOperator::Inner(_) | JoinOperator::Cross => false, + _ => true, + }; let left_fields = left_names .into_iter() .zip(left_schema.iter()) - .map(|(name, field)| { - let (data_type, constraint) = match (operator, right_schema.field(&field.name())) { - (JoinOperator::RightOuter(JoinConstraint::Natural), Ok(right_field)) => { - // if - // - operator is of type NATURAL RIGHT OUTER and - // - `field` is present in both the `left` and `right` relations - // then the datatype of the corresponding field in the JOIN is the datatype of the right field - (right_field.data_type(), right_field.constraint()) - }, - (JoinOperator::FullOuter(JoinConstraint::Natural), Ok(right_field)) => { - // if - // - operator is of type NATURAL RIGHT OUTER and - // - `field` is present in both the `left` and `right` relations - // then the datatype of the corresponding field in the JOIN is the super union of the datatypes of the right and left field datatypes - (field.data_type().super_union(&right_field.data_type()).unwrap(), None) - }, - (JoinOperator::RightOuter(_) | JoinOperator::FullOuter(_), _) => ( - // if the operator if of type RIGHT or FULL OUTER (without NATURAL constraint), the current (left) field is an optional - DataType::optional(field.data_type()), - right_is_unique.then_some(field.constraint()).unwrap_or(None) - ), - _ => ( - field.data_type(), - right_is_unique.then_some(field.constraint()).unwrap_or(None) - ) - }; - Field::new(name, data_type, constraint) - }); + .map(|(name, field)| + Field::new( + name, + transform_datatype_in_optional_left.then_some(DataType::optional(field.data_type())) + .unwrap_or(field.data_type()), + right_is_unique.then_some(field.constraint()).unwrap_or(None) + ) + ); let right_fields = right_names .into_iter() .zip(right_schema.iter()) - .filter_map(|(name, field)| { - let data_type_constraint = match operator { - JoinOperator::Inner(JoinConstraint::Natural) - | JoinOperator::LeftOuter(JoinConstraint::Natural) - | JoinOperator::RightOuter(JoinConstraint::Natural) - | JoinOperator::FullOuter(JoinConstraint::Natural) if left_schema.field(&field.name()).is_ok() => None, // remove the duplicates when JoinConstaint is Natural - JoinOperator::LeftOuter(_) | JoinOperator::FullOuter(_) => Some(( - DataType::optional(field.data_type()), - left_is_unique.then_some(field.constraint()).unwrap_or(None) - )), - _ => Some(( - field.data_type(), - left_is_unique.then_some(field.constraint()).unwrap_or(None) - )) - }; - data_type_constraint.map(|(data_type, constraint)| Field::new(name, data_type, constraint)) + .map(|(name, field)| { + Field::new( + name, + transform_datatype_in_optional_right.then_some(DataType::optional(field.data_type())) + .unwrap_or(field.data_type()), + left_is_unique.then_some(field.constraint()).unwrap_or(None) + ) }); left_fields.chain(right_fields).collect() } @@ -1067,19 +945,10 @@ impl fmt::Display for Join { .collect(); let operator = format!("{} {}", self.operator, "JOIN".to_string().bold().blue()); let constraint = match &self.operator { - JoinOperator::Inner(constraint) - | JoinOperator::LeftOuter(constraint) - | JoinOperator::RightOuter(constraint) - | JoinOperator::FullOuter(constraint) => match constraint { - JoinConstraint::On(expr) => format!("{} {}", 
"ON".to_string().bold().blue(), expr), - JoinConstraint::Using(identifiers) => format!( - "{} {}", - "USING".to_string().bold().blue(), - identifiers.iter().join(", ") - ), - JoinConstraint::Natural => todo!(), - JoinConstraint::None => todo!(), - }, + JoinOperator::Inner(expr) + | JoinOperator::LeftOuter(expr) + | JoinOperator::RightOuter(expr) + | JoinOperator::FullOuter(expr) => format!("{} {}", "ON".to_string().bold().blue(), expr), JoinOperator::Cross => format!(""), }; write!( @@ -1847,7 +1716,7 @@ mod tests { let left: Relation = Relation::table().name("left").schema(left_schema).build(); let right: Relation = Relation::table().name("right").schema(right_schema).build(); let join: Join = Relation::join() - .inner() + .inner(Expr::val(true)) .on(Expr::eq( Expr::qcol(LEFT_INPUT_NAME, "id"), Expr::qcol(RIGHT_INPUT_NAME, "id"), @@ -1877,8 +1746,7 @@ mod tests { let join: Join = Relation::join() .left(table.clone()) .right(table.clone()) - .left_outer() - .on(Expr::eq( + .left_outer(Expr::eq( Expr::qcol(LEFT_INPUT_NAME, "id"), Expr::qcol(RIGHT_INPUT_NAME, "id"), )) @@ -2064,66 +1932,6 @@ mod tests { println!("MAP: {}", map); } - #[test] - fn test_from_join_constraint() { - let table1 = DataType::structured([ - ("a", DataType::float_interval(-10., 10.)), - ("b", DataType::integer_interval(-8, 34)), - ("c", DataType::float_interval(0., 50.)), - ]); - let table2 = DataType::structured([ - ("a", DataType::float_interval(0., 20.)), - ("b", DataType::integer_interval(-1, 14)), - ("d", DataType::integer_interval(-10, 20)), - ]); - let data_type = - DataType::structured([(Join::left_name(), table1), (Join::right_name(), table2)]); - - // ON - let x = Expr::eq(Expr::qcol(Join::left_name(), "a"), Expr::col("d")); - let jc_x = Expr::from((&JoinConstraint::On(x.clone()), &data_type)); - assert_eq!(jc_x, x); - - // USING - let v = vec![Identifier::from_name("a"), Identifier::from_name("b")]; - let true_x = Expr::and( - Expr::and( - Expr::val(true), - Expr::eq( - Expr::qcol(Join::left_name(), "a"), - Expr::qcol(Join::right_name(), "a"), - ), - ), - Expr::eq( - Expr::qcol(Join::left_name(), "b"), - Expr::qcol(Join::right_name(), "b"), - ), - ); - let jc_x = Expr::from((&JoinConstraint::Using(v), &data_type)); - assert_eq!(jc_x, true_x); - - // NATURAL - let true_x = Expr::and( - Expr::and( - Expr::val(true), - Expr::eq( - Expr::qcol(Join::left_name(), "a"), - Expr::qcol(Join::right_name(), "a"), - ), - ), - Expr::eq( - Expr::qcol(Join::left_name(), "b"), - Expr::qcol(Join::right_name(), "b"), - ), - ); - let jc_x = Expr::from((&JoinConstraint::Natural, &data_type)); - assert_eq!(jc_x, true_x); - - // NONE - let jc_x = Expr::from((&JoinConstraint::None, &data_type)); - assert_eq!(jc_x, Expr::val(true)); - } - #[test] fn test_filter_data_type_inner_join() { let table1 = DataType::structured([ @@ -2144,7 +1952,7 @@ mod tests { Expr::qcol(Join::left_name(), "a"), Expr::qcol(Join::right_name(), "d"), ); - let join_op = JoinOperator::Inner(JoinConstraint::On(x.clone())); + let join_op = JoinOperator::Inner(x.clone()); let filtered_table1 = DataType::structured([ ("a", DataType::integer_interval(-2, 1)), ("b", DataType::integer_interval(-1, 3)), @@ -2163,53 +1971,6 @@ mod tests { data_type.filter_by_join_operator(&join_op), filtered_data_type ); - - // USING - let v = vec![Identifier::from_name("a")]; - let join_op = JoinOperator::Inner(JoinConstraint::Using(v)); - let filtered_table1 = DataType::structured([ - ("a", DataType::float_interval(0., 3.)), - ("b", DataType::integer_interval(-1, 3)), - ("c", 
DataType::float_interval(0., 5.)), - ]); - let filtered_table2 = DataType::structured([ - ("a", DataType::float_interval(0., 3.)), - ("b", DataType::integer_interval(-2, 2)), - ("d", DataType::integer_interval(-2, 1)), - ]); - let filtered_data_type = DataType::structured([ - (Join::left_name(), filtered_table1), - (Join::right_name(), filtered_table2), - ]); - assert_eq!( - data_type.filter_by_join_operator(&join_op), - filtered_data_type - ); - - // NATURAL - let join_op = JoinOperator::Inner(JoinConstraint::Natural); - let filtered_table1 = DataType::structured([ - ("a", DataType::float_interval(0., 3.)), - ("b", DataType::integer_interval(-1, 2)), - ("c", DataType::float_interval(0., 5.)), - ]); - let filtered_table2 = DataType::structured([ - ("a", DataType::float_interval(0., 3.)), - ("b", DataType::integer_interval(-1, 2)), - ("d", DataType::integer_interval(-2, 1)), - ]); - let filtered_data_type = DataType::structured([ - (Join::left_name(), filtered_table1), - (Join::right_name(), filtered_table2), - ]); - assert_eq!( - data_type.filter_by_join_operator(&join_op), - filtered_data_type - ); - - // NONE - let join_op = JoinOperator::Inner(JoinConstraint::None); - assert_eq!(data_type.filter_by_join_operator(&join_op), data_type); } #[test] @@ -2234,7 +1995,7 @@ mod tests { Expr::qcol(Join::left_name(), "a"), Expr::qcol(Join::right_name(), "d"), ); - let join_op = JoinOperator::LeftOuter(JoinConstraint::On(x.clone())); + let join_op = JoinOperator::LeftOuter(x.clone()); let filtered_table2 = DataType::structured([ ("a", DataType::float_interval(0., 20.)), ("b", DataType::integer_interval(-2, 2)), @@ -2248,43 +2009,6 @@ mod tests { data_type.filter_by_join_operator(&join_op), filtered_data_type ); - - // USING - let v = vec![Identifier::from_name("a")]; - let join_op = JoinOperator::LeftOuter(JoinConstraint::Using(v)); - let filtered_table2 = DataType::structured([ - ("a", DataType::float_interval(0., 3.)), - ("b", DataType::integer_interval(-2, 2)), - ("d", DataType::integer_interval(-2, 1)), - ]); - let filtered_data_type = DataType::structured([ - (Join::left_name(), table1.clone()), - (Join::right_name(), filtered_table2), - ]); - assert_eq!( - data_type.filter_by_join_operator(&join_op), - filtered_data_type - ); - - // NATURAL - let join_op = JoinOperator::LeftOuter(JoinConstraint::Natural); - let filtered_table2 = DataType::structured([ - ("a", DataType::float_interval(0., 3.)), - ("b", DataType::integer_interval(-1, 2)), - ("d", DataType::integer_interval(-2, 1)), - ]); - let filtered_data_type = DataType::structured([ - (Join::left_name(), table1.clone()), - (Join::right_name(), filtered_table2), - ]); - assert_eq!( - data_type.filter_by_join_operator(&join_op), - filtered_data_type - ); - - // NONE - let join_op = JoinOperator::LeftOuter(JoinConstraint::None); - assert_eq!(data_type.filter_by_join_operator(&join_op), data_type); } #[test] @@ -2309,7 +2033,7 @@ mod tests { Expr::qcol(Join::left_name(), "a"), Expr::qcol(Join::right_name(), "d"), ); - let join_op = JoinOperator::RightOuter(JoinConstraint::On(x.clone())); + let join_op = JoinOperator::RightOuter(x.clone()); let filtered_table1 = DataType::structured([ ("a", DataType::integer_interval(-2, 1)), ("b", DataType::integer_interval(-1, 3)), @@ -2323,43 +2047,6 @@ mod tests { data_type.filter_by_join_operator(&join_op), filtered_data_type ); - - // USING - let v = vec![Identifier::from_name("a")]; - let join_op = JoinOperator::RightOuter(JoinConstraint::Using(v)); - let filtered_table1 = DataType::structured([ - ("a", 
DataType::float_interval(0., 3.)), - ("b", DataType::integer_interval(-1, 3)), - ("c", DataType::float_interval(0., 5.)), - ]); - let filtered_data_type = DataType::structured([ - (Join::left_name(), filtered_table1), - (Join::right_name(), table2.clone()), - ]); - assert_eq!( - data_type.filter_by_join_operator(&join_op), - filtered_data_type - ); - - // NATURAL - let join_op = JoinOperator::RightOuter(JoinConstraint::Natural); - let filtered_table1 = DataType::structured([ - ("a", DataType::float_interval(0., 3.)), - ("b", DataType::integer_interval(-1, 2)), - ("c", DataType::float_interval(0., 5.)), - ]); - let filtered_data_type = DataType::structured([ - (Join::left_name(), filtered_table1), - (Join::right_name(), table2.clone()), - ]); - assert_eq!( - data_type.filter_by_join_operator(&join_op), - filtered_data_type - ); - - // NONE - let join_op = JoinOperator::RightOuter(JoinConstraint::None); - assert_eq!(data_type.filter_by_join_operator(&join_op), data_type); } #[test] @@ -2395,7 +2082,7 @@ mod tests { let join: Join = Relation::join() .name("join") - .inner() + .inner(Expr::val(true)) .on_eq("a", "a") .left(table1.clone()) .right(table2.clone()) @@ -2404,7 +2091,7 @@ mod tests { let join: Join = Relation::join() .name("join") - .inner() + .inner(Expr::val(true)) .on_eq("a", "a") .left(table2.clone()) .right(table1.clone()) @@ -2413,7 +2100,7 @@ mod tests { let join: Join = Relation::join() .name("join") - .left_outer() + .left_outer(Expr::val(true)) .on_eq("a", "a") .left(table2.clone()) .right(table1.clone()) @@ -2422,7 +2109,7 @@ mod tests { let join: Join = Relation::join() .name("join") - .right_outer() + .right_outer(Expr::val(true)) .on_eq("a", "a") .left(table2.clone()) .right(table1.clone()) @@ -2431,7 +2118,7 @@ mod tests { let join: Join = Relation::join() .name("join") - .full_outer() + .full_outer(Expr::val(true)) .on_eq("a", "a") .left(table2.clone()) .right(table1.clone()) @@ -2440,7 +2127,7 @@ mod tests { let join: Join = Relation::join() .name("join") - .full_outer() + .full_outer(Expr::val(true)) .on_eq("a", "b") .left(table2.clone()) .right(table1.clone()) @@ -2502,7 +2189,7 @@ mod tests { map.display_dot().unwrap(); let join: Relation = Relation::join() - .inner() + .inner(Expr::val(true)) .left(table1) .right(map) .on_eq("b", "my_b") @@ -2547,7 +2234,7 @@ mod tests { // the joining columns are not unique let join: Join = Relation::join() .name("join") - .inner() + .inner(Expr::val(true)) .on_eq("b", "a") .left(table1.clone()) .right(table2.clone()) @@ -2569,7 +2256,7 @@ mod tests { // the left joining column is unique let join: Join = Relation::join() .name("join") - .inner() + .inner(Expr::val(true)) .on_eq("a", "a") .left(table1.clone()) .right(table2.clone()) @@ -2591,7 +2278,7 @@ mod tests { // the right joining column is unique let join: Join = Relation::join() .name("join") - .inner() + .inner(Expr::val(true)) .on_eq("b", "d") .left(table1.clone()) .right(table2.clone()) @@ -2613,7 +2300,7 @@ mod tests { // the joining columns are unique let join: Join = Relation::join() .name("join") - .inner() + .inner(Expr::val(true)) .on_eq("a", "d") .left(table1.clone()) .right(table2.clone()) @@ -2632,94 +2319,4 @@ mod tests { .collect(); assert_eq!(join.schema(), &correct_schema); } - - #[test] - fn test_natural_join_builder() { - let schema1: Schema = vec![ - ("a", DataType::integer_interval(-5, 5)), - ("b", DataType::integer_interval(-2, 2)), - ] - .into_iter() - .collect(); - let table1: Relation = Relation::table().name("table1").schema(schema1).build(); - 
let schema2: Schema = vec![
-            ("a", DataType::integer_interval(0, 10)),
-            ("c", DataType::integer_interval(0, 20)),
-        ]
-        .into_iter()
-        .collect();
-        let table2: Relation = Relation::table().name("table1").schema(schema2).build();
-
-        // natural inner join
-        let relation: Relation = Relation::join()
-            .left(table1.clone())
-            .right(table2.clone())
-            .left_names(vec!["a", "b"])
-            .right_names(vec!["my_a", "c"])
-            .inner()
-            .build();
-        relation.display_dot().unwrap();
-        assert_eq!(
-            relation.data_type(),
-            DataType::structured([
-                ("a", DataType::integer_interval(0, 5)),
-                ("b", DataType::integer_interval(-2, 2)),
-                ("c", DataType::integer_interval(0, 20))
-            ])
-        );
-
-        // natural left join
-        let relation: Relation = Relation::join()
-            .left(table1.clone())
-            .right(table2.clone())
-            .left_names(vec!["a", "b"])
-            .right_names(vec!["my_a", "c"])
-            .left_outer()
-            .build();
-        relation.display_dot().unwrap();
-        assert_eq!(
-            relation.data_type(),
-            DataType::structured([
-                ("a", DataType::integer_interval(-5, 5)),
-                ("b", DataType::integer_interval(-2, 2)),
-                ("c", DataType::optional(DataType::integer_interval(0, 20)))
-            ])
-        );
-
-        // natural right join
-        let relation: Relation = Relation::join()
-            .left(table1.clone())
-            .right(table2.clone())
-            .left_names(vec!["a", "b"])
-            .right_names(vec!["my_a", "c"])
-            .right_outer()
-            .build();
-        relation.display_dot().unwrap();
-        assert_eq!(
-            relation.data_type(),
-            DataType::structured([
-                ("a", DataType::integer_interval(0, 10)),
-                ("b", DataType::optional(DataType::integer_interval(-2, 2))),
-                ("c", DataType::integer_interval(0, 20))
-            ])
-        );
-
-        // natural full join
-        let relation: Relation = Relation::join()
-            .left(table1.clone())
-            .right(table2.clone())
-            .left_names(vec!["a", "b"])
-            .right_names(vec!["my_a", "c"])
-            .full_outer()
-            .build();
-        relation.display_dot().unwrap();
-        assert_eq!(
-            relation.data_type(),
-            DataType::structured([
-                ("a", DataType::integer_interval(-5, 10)),
-                ("b", DataType::optional(DataType::integer_interval(-2, 2))),
-                ("c", DataType::optional(DataType::integer_interval(0, 20)))
-            ])
-        );
-    }
 }
diff --git a/src/relation/rewriting.rs b/src/relation/rewriting.rs
index 22bebd1a..fc0138c1 100644
--- a/src/relation/rewriting.rs
+++ b/src/relation/rewriting.rs
@@ -5,8 +5,9 @@ use super::{Join, Map, Reduce, Relation, Set, Table, Values, Variant as _};
 use crate::{
     builder::{Ready, With, WithIterator},
     data_type::{self, function::Function, DataType, DataTyped, Variant as _},
-    expr::{self, aggregate, Aggregate, Expr, Value},
-    io, namer, relation,
+    expr::{self, aggregate, Aggregate, Expr, Value, Identifier},
+    io, namer, relation::{self, LEFT_INPUT_NAME, RIGHT_INPUT_NAME},
+    hierarchy::Hierarchy,
 };
 use std::{
     collections::{BTreeMap, HashMap},
@@ -222,6 +223,38 @@ impl Join {
         self.name = name;
         self
     }
+
+    /// Replace the duplicate fields specified in `columns` by their coalesce expression
+    /// It mimics the behaviour of USING in SQL
+    pub fn remove_duplicates_and_coalesce(self, vec: Vec<String>, columns: &Hierarchy<Identifier>) -> Relation {
+        let fields = self.field_inputs()
+            .filter_map(|(name, id)| {
+                let col = 
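// --- [Editor's note] Illustrative sketch, not part of the patch ---
// Assuming a `join: Join` built from `... JOIN ... USING (a)` and a `new_columns`
// hierarchy mapping the join inputs (`_LEFT_.a`, `_RIGHT_.a`, ...) to its output field
// names, the helper introduced here wraps the join in a Map that keeps a single `a`
// column computed as the coalesce of the two sides and passes the other fields through:
let relation: Relation = join.remove_duplicates_and_coalesce(vec!["a".to_string()], &new_columns);
// --- end of editor's note ---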
id.as_ref().last().unwrap(); + (!vec.contains(col)).then_some((name.clone(), Expr::col(name))) + }) + ) + .collect::>(); + Relation::map() + .input(Relation::from(self)) + .with_iter(fields) + .build() + } } /* Set @@ -412,7 +445,7 @@ impl Relation { // TODO fix this // Join the two relations on the entity column let join: Relation = Relation::join() - .inner() + .inner(Expr::val(true)) .on_eq(entities, entities) .left_names( self.fields() @@ -782,10 +815,25 @@ impl Relation { names.push((col.clone(), Expr::col(col))); } } + let x = Expr::and_iter( + self.schema() + .iter() + .filter_map(|f| right.schema() + .field(f.name()) + .is_ok() + .then_some( + Expr::eq( + Expr::qcol(LEFT_INPUT_NAME, f.name()), + Expr::qcol(RIGHT_INPUT_NAME, f.name()), + ) + ) + ) + ); + let join: Relation = Relation::join() .left(self.clone()) .right(right.clone()) - .inner() + .inner(x) .left_names(left_names) .right_names(right_names) .build(); @@ -814,8 +862,6 @@ mod tests { relation::schema::Schema, sql::parse, }; - use colored::Colorize; - use itertools::Itertools; #[test] fn test_with_computed_field() { diff --git a/src/relation/sql.rs b/src/relation/sql.rs index e668080d..925e36d3 100644 --- a/src/relation/sql.rs +++ b/src/relation/sql.rs @@ -1,6 +1,6 @@ //! Methods to convert Relations to ast::Query use super::{ - Error, Join, JoinConstraint, JoinOperator, Map, OrderBy, Reduce, Relation, Result, Set, + Error, Join, JoinOperator, Map, OrderBy, Reduce, Relation, Result, Set, SetOperator, SetQuantifier, Table, Values, Variant as _, Visitor, }; use crate::{ @@ -33,36 +33,20 @@ impl From for ast::ObjectName { } } -impl From for ast::JoinConstraint { - fn from(value: JoinConstraint) -> Self { - match value { - JoinConstraint::On(expr) => ast::JoinConstraint::On(ast::Expr::from(&expr)), - JoinConstraint::Using(idents) => ast::JoinConstraint::Using( - idents - .into_iter() - .map(|ident| ident.try_into().unwrap()) - .collect(), - ), - JoinConstraint::Natural => ast::JoinConstraint::Natural, - JoinConstraint::None => ast::JoinConstraint::None, - } - } -} - impl From for ast::JoinOperator { fn from(value: JoinOperator) -> Self { match value { - JoinOperator::Inner(join_constraint) => { - ast::JoinOperator::Inner(join_constraint.into()) + JoinOperator::Inner(expr) => { + ast::JoinOperator::Inner(ast::JoinConstraint::On(ast::Expr::from(&expr))) } - JoinOperator::LeftOuter(join_constraint) => { - ast::JoinOperator::LeftOuter(join_constraint.into()) + JoinOperator::LeftOuter(expr) => { + ast::JoinOperator::LeftOuter(ast::JoinConstraint::On(ast::Expr::from(&expr))) } - JoinOperator::RightOuter(join_constraint) => { - ast::JoinOperator::RightOuter(join_constraint.into()) + JoinOperator::RightOuter(expr) => { + ast::JoinOperator::RightOuter(ast::JoinConstraint::On(ast::Expr::from(&expr))) } - JoinOperator::FullOuter(join_constraint) => { - ast::JoinOperator::FullOuter(join_constraint.into()) + JoinOperator::FullOuter(expr) => { + ast::JoinOperator::FullOuter(ast::JoinConstraint::On(ast::Expr::from(&expr))) } JoinOperator::Cross => ast::JoinOperator::CrossJoin, } @@ -690,7 +674,7 @@ mod tests { let join: Relation = Relation::join() .name("join") - .left_outer() + .left_outer(Expr::val(true)) //.using("a") .on_eq("b", "b") .left(left) diff --git a/src/rewriting/mod.rs b/src/rewriting/mod.rs index caebf59d..992b19f9 100644 --- a/src/rewriting/mod.rs +++ b/src/rewriting/mod.rs @@ -184,7 +184,7 @@ mod tests { let queries = [ "SELECT order_id, sum(price) FROM item_table GROUP BY order_id", - "SELECT sum(distinct price) FROM 
item_table GROUP BY order_id HAVING count(*) > 2", + "SELECT order_id, sum(price), sum(distinct price) FROM item_table GROUP BY order_id HAVING count(*) > 2", ]; for q in queries { diff --git a/src/sql/relation.rs b/src/sql/relation.rs index 023ed678..1f665487 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -16,12 +16,13 @@ use crate::{ namer::{self, FIELD}, parser::Parser, relation::{ - Join, JoinConstraint, JoinOperator, MapBuilder, Relation, SetOperator, SetQuantifier, + Join, JoinOperator, MapBuilder, Relation, SetOperator, SetQuantifier, Variant as _, WithInput, + LEFT_INPUT_NAME, RIGHT_INPUT_NAME }, tokenizer::Tokenizer, visitor::{Acceptor, Dependencies, Visited}, - types::And + types::And, display::Dot }; use itertools::Itertools; use std::{ @@ -29,7 +30,8 @@ use std::{ iter::{once, Iterator}, result, str::FromStr, - sync::Arc, collections::HashMap, + sync::Arc, + ops::Deref }; /* @@ -173,18 +175,47 @@ impl<'a> VisitedQueryRelations<'a> { &self, join_constraint: &ast::JoinConstraint, columns: &'a Hierarchy, - ) -> Result { - match join_constraint { - ast::JoinConstraint::On(expr) => Ok(JoinConstraint::On(expr.with(columns).try_into()?)), - ast::JoinConstraint::Using(idents) => Ok(JoinConstraint::Using( - idents - .into_iter() - .map(|ident| Identifier::from(ident.value.clone())) - .collect(), - )), - ast::JoinConstraint::Natural => Ok(JoinConstraint::Natural), - ast::JoinConstraint::None => Ok(JoinConstraint::None), - } + ) -> Result { + Ok(match join_constraint { + ast::JoinConstraint::On(expr) => expr.with(columns).try_into()?, + ast::JoinConstraint::Using(idents) => { // the "Using (id)" condition is equivalent to "ON _LEFT_.id = _RIGHT_.id" + Expr::and_iter( + idents.into_iter() + .map(|id| Expr::eq( + Expr::Column(Identifier::from(vec![LEFT_INPUT_NAME.to_string(), id.value.to_string()])), + Expr::Column(Identifier::from(vec![RIGHT_INPUT_NAME.to_string(), id.value.to_string()])), + )) + ) + }, + ast::JoinConstraint::Natural => { // When joining table_1 with table_2, the NATURAL condition is equivalent to a "ON table_1.col1 = table_2.col1 AND table_1.col2 = table_2.col2" where col1, col2... are the columns present in both table_1 and table_2 + let tables = columns.iter() + .map(|(k, _)| k.iter().take(k.len() - 1).map(|s| s.to_string()).collect::>()) + .dedup() + .collect::>(); + assert_eq!(tables.len(), 2); + let columns_1 = columns.filter(tables[0].as_slice()); + let columns_2 = columns.filter(tables[1].as_slice()); + let columns_1 = columns_1 + .iter() + .map(|(k, _)| k.last().unwrap()) + .collect::>(); + let columns_2 = columns_2 + .iter() + .map(|(k, _)| k.last().unwrap()) + .collect::>(); + + Expr::and_iter( + columns_1 + .iter() + .filter_map(|col| columns_2.contains(&col).then_some(col)) + .map(|id| Expr::eq( + Expr::Column(Identifier::from(vec![LEFT_INPUT_NAME.to_string(), id.to_string()])), + Expr::Column(Identifier::from(vec![RIGHT_INPUT_NAME.to_string(), id.to_string()])) + )) + ) + }, + ast::JoinConstraint::None => todo!(), + }) } fn try_from_join_operator_with_columns( @@ -194,7 +225,7 @@ impl<'a> VisitedQueryRelations<'a> { ) -> Result { match join_operator { ast::JoinOperator::Inner(join_constraint) => Ok(JoinOperator::Inner( - self.try_from_join_constraint_with_columns(join_constraint, columns)?, + self.try_from_join_constraint_with_columns(join_constraint, columns)? 
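// --- [Editor's note] Illustrative sketch, not part of the patch ---
// With the conversion above, a `USING (a, b)` constraint (or NATURAL, over the shared
// columns) is lowered to a conjunction of equalities over the two join sides, roughly:
let on_expr = Expr::and_iter(["a", "b"].into_iter().map(|c| Expr::eq(
    Expr::qcol(LEFT_INPUT_NAME, c),
    Expr::qcol(RIGHT_INPUT_NAME, c),
)));
// i.e. the same expression an explicit `ON _LEFT_.a = _RIGHT_.a AND _LEFT_.b = _RIGHT_.b` would give.
// --- end of editor's note ---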
)),
            ast::JoinOperator::LeftOuter(join_constraint) => Ok(JoinOperator::LeftOuter(
                self.try_from_join_constraint_with_columns(join_constraint, columns)?,
            )),
@@ -219,10 +250,10 @@ impl<'a> VisitedQueryRelations<'a> {
         // Then the JOIN if needed
         let result = table_with_joins.joins.iter().fold(
             self.try_from_table_factor(&table_with_joins.relation),
-            |left, join| {
+            |left, ast_join| {
                 let RelationWithColumns(left_relation, left_columns) = left?;
                 let RelationWithColumns(right_relation, right_columns) =
-                    self.try_from_table_factor(&join.relation)?;
+                    self.try_from_table_factor(&ast_join.relation)?;
                 let left_columns = left_columns.map(|i| {
                     let mut v = vec![Join::left_name().to_string()];
                     v.extend(i.to_vec());
                     v
                 });
@@ -235,7 +266,7 @@ impl<'a> VisitedQueryRelations<'a> {
                 });
                 let all_columns = left_columns.with(right_columns);
                 let operator = self.try_from_join_operator_with_columns(
-                    &join.join_operator,
+                    &ast_join.join_operator,
                     // &all_columns.filter_map(|i| Some(i.split_last().ok()?.0)),//TODO remove this
                     &all_columns,
                 )?;
@@ -245,13 +276,42 @@ impl<'a> VisitedQueryRelations<'a> {
                     .left(left_relation)
                     .right(right_relation)
                     .build();
+                // We collect the column mapping: inputs should map to new names (hence the inversion)
                 let new_columns: Hierarchy<Identifier> = join.field_inputs().map(|(f, i)| (i, f.into())).collect();
-                let composed_columns = all_columns.and_then(new_columns);
-                let relation = Arc::new(Relation::from(join));
+                let composed_columns = all_columns.and_then(new_columns.clone());
+
+                // If the join constraint is of type "USING" or "NATURAL", add a map for removing the duplicate columns
+                let relation = match &ast_join.join_operator {
+                    ast::JoinOperator::Inner(ast::JoinConstraint::Using(v))
+                    | ast::JoinOperator::LeftOuter(ast::JoinConstraint::Using(v))
+                    | ast::JoinOperator::RightOuter(ast::JoinConstraint::Using(v))
+                    | ast::JoinOperator::FullOuter(ast::JoinConstraint::Using(v)) => {
+                        join.remove_duplicates_and_coalesce(
+                            v.into_iter().map(|id| id.value.to_string()).collect(),
+                            &new_columns
+                        )
+                    },
+                    ast::JoinOperator::Inner(ast::JoinConstraint::Natural)
+                    | ast::JoinOperator::LeftOuter(ast::JoinConstraint::Natural)
+                    | ast::JoinOperator::RightOuter(ast::JoinConstraint::Natural)
+                    | ast::JoinOperator::FullOuter(ast::JoinConstraint::Natural) => {
+                        let v = join.left().fields()
+                            .into_iter()
+                            .filter_map(|f| join.right().schema().field(f.name()).is_ok().then_some(f.name().to_string()))
+                            .collect();
+                        join.remove_duplicates_and_coalesce(v,&new_columns)
+                    },
+                    ast::JoinOperator::LeftSemi(_) => todo!(),
+                    ast::JoinOperator::RightSemi(_) => todo!(),
+                    ast::JoinOperator::LeftAnti(_) => todo!(),
+                    ast::JoinOperator::RightAnti(_) => todo!(),
+                    _ => Relation::from(join),
+                };
+
                 // We should compose hierarchies
-                Ok(RelationWithColumns::new(relation, composed_columns))
+                Ok(RelationWithColumns::new(Arc::new(relation), composed_columns))
             },
         );
         result
@@ -1276,4 +1336,193 @@ mod tests {
             ])
         );
     }
+
+    #[test]
+    fn test_join_with_using() {
+        namer::reset();
+        let table_1: Relation = Relation::table()
+            .name("table_1")
+            .schema(
+                vec![
+                    ("a", DataType::integer_interval(0, 10)),
+                    ("b", DataType::float_interval(20., 50.)),
+                ].into_iter()
+                .collect::<Schema>()
+            )
+            .size(100)
+            .build();
+        let table_2: Relation = Relation::table()
+            .name("table_2")
+            .schema(
+                vec![
+                    ("a", DataType::integer_interval(-5, 5)),
+                    ("c", DataType::float()),
+                ].into_iter()
+                .collect::<Schema>()
+            )
+            .size(100)
+            .build();
+        let relations = Hierarchy::from([
+            (["schema", "table_1"], Arc::new(table_1)),
+            (["schema", "table_2"], 
Arc::new(table_2)), + ]); + + // INNER JOIN + let query = parse("SELECT * FROM table_1 INNER JOIN table_2 USING (a)").unwrap(); + let relation = Relation::try_from(QueryWithRelations::new( + &query, + &relations, + )) + .unwrap(); + relation.display_dot().unwrap(); + assert!(matches!(relation.data_type(), DataType::Struct(_))); + if let DataType::Struct(s) = relation.data_type() { + assert_eq!(s[0], Arc::new(DataType::integer_interval(0, 5))); + assert_eq!(s[1], Arc::new(DataType::float_interval(20., 50.))); + assert_eq!(s[2], Arc::new(DataType::float())); + } + + // LEFT JOIN + let query = parse("SELECT * FROM table_1 LEFT JOIN table_2 USING (a)").unwrap(); + let relation = Relation::try_from(QueryWithRelations::new( + &query, + &relations, + )) + .unwrap(); + relation.display_dot().unwrap(); + assert!(matches!(relation.data_type(), DataType::Struct(_))); + if let DataType::Struct(s) = relation.data_type() { + assert_eq!(s[0], Arc::new(DataType::integer_interval(0, 10))); + assert_eq!(s[1], Arc::new(DataType::float_interval(20., 50.))); + assert_eq!(s[2], Arc::new(DataType::optional(DataType::float()))); + } + + // RIGHT JOIN + let query = parse("SELECT * FROM table_1 RIGHT JOIN table_2 USING (a)").unwrap(); + let relation = Relation::try_from(QueryWithRelations::new( + &query, + &relations, + )) + .unwrap(); + relation.display_dot().unwrap(); + assert!(matches!(relation.data_type(), DataType::Struct(_))); + if let DataType::Struct(s) = relation.data_type() { + assert_eq!(s[0], Arc::new(DataType::integer_interval(-5, 5))); + assert_eq!(s[1], Arc::new(DataType::optional(DataType::float_interval(20., 50.)))); + assert_eq!(s[2], Arc::new(DataType::float())); + } + + // FULL JOIN + let query = parse("SELECT * FROM table_1 FULL JOIN table_2 USING (a)").unwrap(); + let relation = Relation::try_from(QueryWithRelations::new( + &query, + &relations, + )) + .unwrap(); + relation.display_dot().unwrap(); + assert!(matches!(relation.data_type(), DataType::Struct(_))); + if let DataType::Struct(s) = relation.data_type() { + assert_eq!(s[0], Arc::new(DataType::optional(DataType::integer_interval(-5, 10)))); + assert_eq!(s[1], Arc::new(DataType::optional(DataType::float_interval(20., 50.)))); + assert_eq!(s[2], Arc::new(DataType::optional(DataType::float()))); + } + } + + #[test] + fn test_join_with_natural() { + namer::reset(); + let table_1: Relation = Relation::table() + .name("table_1") + .schema( + vec![ + ("a", DataType::integer_interval(0, 10)), + ("b", DataType::float_interval(20., 50.)), + ("d", DataType::float_interval(-10., 50.)), + ].into_iter() + .collect::() + ) + .size(100) + .build(); + let table_2: Relation = Relation::table() + .name("table_2") + .schema( + vec![ + ("a", DataType::integer_interval(-5, 5)), + ("c", DataType::float()), + ("d", DataType::float_interval(10., 100.)), + ].into_iter() + .collect::() + ) + .size(100) + .build(); + let relations = Hierarchy::from([ + (["schema", "table_1"], Arc::new(table_1)), + (["schema", "table_2"], Arc::new(table_2)), + ]); + + // INNER JOIN + let query = parse("SELECT * FROM table_1 NATURAL INNER JOIN table_2").unwrap(); + let relation = Relation::try_from(QueryWithRelations::new( + &query, + &relations, + )) + .unwrap(); + relation.display_dot().unwrap(); + assert!(matches!(relation.data_type(), DataType::Struct(_))); + if let DataType::Struct(s) = relation.data_type() { + assert_eq!(s[0], Arc::new(DataType::integer_interval(0, 5))); + assert_eq!(s[1], Arc::new(DataType::float_interval(10., 50.))); + assert_eq!(s[2], 
Arc::new(DataType::float_interval(20., 50.))); + assert_eq!(s[3], Arc::new(DataType::float())); + } + + // LEFT JOIN + let query = parse("SELECT * FROM table_1 NATURAL LEFT JOIN table_2").unwrap(); + let relation = Relation::try_from(QueryWithRelations::new( + &query, + &relations, + )) + .unwrap(); + relation.display_dot().unwrap(); + assert!(matches!(relation.data_type(), DataType::Struct(_))); + if let DataType::Struct(s) = relation.data_type() { + assert_eq!(s[0], Arc::new(DataType::integer_interval(0, 10))); + assert_eq!(s[1], Arc::new(DataType::float_interval(-10., 50.))); + assert_eq!(s[2], Arc::new(DataType::float_interval(20., 50.))); + assert_eq!(s[3], Arc::new(DataType::optional(DataType::float()))); + } + + // RIGHT JOIN + let query = parse("SELECT * FROM table_1 NATURAL RIGHT JOIN table_2").unwrap(); + let relation = Relation::try_from(QueryWithRelations::new( + &query, + &relations, + )) + .unwrap(); + relation.display_dot().unwrap(); + assert!(matches!(relation.data_type(), DataType::Struct(_))); + if let DataType::Struct(s) = relation.data_type() { + assert_eq!(s[0], Arc::new(DataType::integer_interval(-5, 5))); + assert_eq!(s[1], Arc::new(DataType::float_interval(10., 100.))); + assert_eq!(s[2], Arc::new(DataType::optional(DataType::float_interval(20., 50.)))); + assert_eq!(s[3], Arc::new(DataType::float())); + + } + + // FULL JOIN + let query = parse("SELECT * FROM table_1 NATURAL FULL JOIN table_2").unwrap(); + let relation = Relation::try_from(QueryWithRelations::new( + &query, + &relations, + )) + .unwrap(); + relation.display_dot().unwrap(); + assert!(matches!(relation.data_type(), DataType::Struct(_))); + if let DataType::Struct(s) = relation.data_type() { + assert_eq!(s[0], Arc::new(DataType::optional(DataType::integer_interval(-5, 10)))); + assert_eq!(s[1], Arc::new(DataType::optional(DataType::float_interval(-10., 100.)))); + assert_eq!(s[2], Arc::new(DataType::optional(DataType::float_interval(20., 50.)))); + assert_eq!(s[3], Arc::new(DataType::optional(DataType::float()))); + } + } } diff --git a/tests/integration.rs b/tests/integration.rs index 3f0da10e..4f117dd5 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -104,13 +104,16 @@ const QUERIES: &[&str] = &[ "SELECT DISTINCT c, d FROM table_1", "SELECT c, COUNT(DISTINCT d) AS count_d, SUM(DISTINCT d) AS sum_d FROM table_1 GROUP BY c ORDER BY c", "SELECT SUM(DISTINCT a) AS s1 FROM table_1 GROUP BY c HAVING COUNT(*) > 5;", + // using joins + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 INNER JOIN t2 USING(a)", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 LEFT JOIN t2 USING(a)", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 RIGHT JOIN t2 USING(a)", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 FULL JOIN t2 USING(a)", // natural joins - "WITH t1 AS (SELECT a, b FROM table_1 WHERE a > 5), t2 AS (SELECT a, d FROM table_1 WHERE a < 7) SELECT * FROM t1 NATURAL INNER JOIN t2", - "WITH t1 AS (SELECT a, b FROM table_1 WHERE a > 5), t2 AS (SELECT a, d FROM table_1 WHERE a < 7) SELECT * FROM t1 NATURAL LEFT JOIN t2", - "WITH t1 AS (SELECT a, b FROM table_1 WHERE a > 5), t2 AS (SELECT a, d FROM table_1 WHERE a < 7) SELECT * FROM t1 NATURAL RIGHT JOIN t2", - "WITH t1 AS (SELECT a, b FROM table_1 WHERE a > 5), t2 
AS (SELECT a, d FROM table_1 WHERE a < 7) SELECT * FROM t1 NATURAL FULL JOIN t2", - - + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL INNER JOIN t2", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL LEFT JOIN t2", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL RIGHT JOIN t2", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL FULL JOIN t2", ]; #[cfg(feature = "sqlite")] @@ -147,7 +150,7 @@ fn test_on_postgresql() { for tab in database.tables() { println!("schema {} = {}", tab, tab.schema()); } - for &query in POSTGRESQL_QUERIES.iter().chain(QUERIES) { + for &query in QUERIES.iter().chain(POSTGRESQL_QUERIES) { assert!(test_rewritten_eq(&mut database, query)); } } From 4f4f0360c9048152ac4859ac13ab7d7f57e80de2 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 14 Dec 2023 16:51:40 +0100 Subject: [PATCH 26/27] typos --- src/relation/dot.rs | 2 -- src/relation/sql.rs | 1 - src/sql/relation.rs | 4 ++-- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/relation/dot.rs b/src/relation/dot.rs index 01267e2e..2dea6bd3 100644 --- a/src/relation/dot.rs +++ b/src/relation/dot.rs @@ -440,8 +440,6 @@ mod tests { let join: Relation = Relation::join() .name("join") .cross() - //.using("a") - //.on(Expr::eq(Expr::qcol("left", "b"), Expr::qcol("right", "b"))) .left(left) .right(right) .build(); diff --git a/src/relation/sql.rs b/src/relation/sql.rs index 925e36d3..c98c3ca2 100644 --- a/src/relation/sql.rs +++ b/src/relation/sql.rs @@ -675,7 +675,6 @@ mod tests { let join: Relation = Relation::join() .name("join") .left_outer(Expr::val(true)) - //.using("a") .on_eq("b", "b") .left(left) .right(right) diff --git a/src/sql/relation.rs b/src/sql/relation.rs index 1f665487..9761f141 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -187,7 +187,7 @@ impl<'a> VisitedQueryRelations<'a> { )) ) }, - ast::JoinConstraint::Natural => { // When joining table_1 with table_2, the NATURAL condition is equivalent to a "ON table_1.col1 = table_2.col1 AND table_1.col2 = table_2.col2" where col1, col2... are the columns present in both table_1 and table_2 + ast::JoinConstraint::Natural => { // the NATURAL condition is equivalent to a "ON _LEFT_.col1 = _RIGHT_.col1 AND _LEFT_.col2 = _RIGHT_.col2" where col1, col2... 
are the columns present in both tables
                let tables = columns.iter()
                    .map(|(k, _)| k.iter().take(k.len() - 1).map(|s| s.to_string()).collect::<Vec<String>>())
                    .dedup()
@@ -282,7 +282,7 @@ impl<'a> VisitedQueryRelations<'a> {
                         join.field_inputs().map(|(f, i)| (i, f.into())).collect();
                    let composed_columns = all_columns.and_then(new_columns.clone());

-                    // If the join contraint is of type "USING" or "NATURAL", add a map for removing du duplicate columns
+                    // If the join constraint is of type "USING" or "NATURAL", add a map to coalesce the duplicate columns
                    let relation = match &ast_join.join_operator {
                        ast::JoinOperator::Inner(ast::JoinConstraint::Using(v))
                        | ast::JoinOperator::LeftOuter(ast::JoinConstraint::Using(v))

From 8a4bb3403b0c66d1b3d63c84fca56f96362cfdd8 Mon Sep 17 00:00:00 2001
From: Nicolas Grislain
Date: Fri, 15 Dec 2023 11:05:52 +0100
Subject: [PATCH 27/27] Update ci.yml

---
 .github/workflows/ci.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 1afdcb06..2def36cd 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,6 +1,7 @@
 name: CI

 on:
+  workflow_dispatch:
   push:
     branches: ["main"]
   pull_request:
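
For reference, a minimal sketch of the semantics the USING/NATURAL handling in this series is meant to implement, using the table_1(a, b) / table_2(a, c) tables from the new unit tests. This is standard-SQL shorthand for the coalescing Map added on top of the Join, not necessarily the exact SQL qrlew emits:

-- "SELECT * FROM table_1 FULL JOIN table_2 USING (a)" is treated, conceptually, as:
SELECT COALESCE(table_1.a, table_2.a) AS a, table_1.b, table_2.c
FROM table_1 FULL JOIN table_2 ON table_1.a = table_2.a;
-- The duplicated join column "a" is kept once and coalesced across both sides;
-- a NATURAL join applies the same rule to every column name shared by the two inputs,
-- which is what the optional/unioned data types asserted in the new tests reflect.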