From 5180f433398499891d585db986d8fe238d697167 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Tue, 5 Sep 2023 07:00:19 +0200 Subject: [PATCH 01/36] refacto --- src/data_type/function.rs | 8 ++--- src/differential_privacy/mod.rs | 20 ++++++------ src/relation/transforms.rs | 54 ++++++++++++++++----------------- 3 files changed, 40 insertions(+), 42 deletions(-) diff --git a/src/data_type/function.rs b/src/data_type/function.rs index 5e6127f9..8ca42b2c 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -365,10 +365,10 @@ impl Function for Pointwise { } } -/// Partitionned monotonic function (plus some complex periodic cases) -/// The domain is a (cartesian) product of Intervals types -/// P and T are convenient representations of the product and elements of the product -/// The partition function maps a product into a vector of products where the value function is supposed to be monotonic +/// Partitionned monotonic function (plus some complex periodic cases). +/// The domain is a (cartesian) product of `Intervals` types. +/// `P` and `T` are convenient representations of the product and elements of the product. +/// The partition function maps a product into a vector of products where the value function is supposed to be monotonic. #[derive(Clone)] pub struct PartitionnedMonotonic where diff --git a/src/differential_privacy/mod.rs b/src/differential_privacy/mod.rs index 88cdce6c..335ae63b 100644 --- a/src/differential_privacy/mod.rs +++ b/src/differential_privacy/mod.rs @@ -8,7 +8,6 @@ pub mod protect_grouping_keys; use crate::data_type::DataTyped; use crate::{ - data_type::intervals::Bound, expr::{aggregate, Expr}, hierarchy::Hierarchy, protected::PE_ID, @@ -147,16 +146,15 @@ mod tests { fn test_table_with_noise() { let mut database = postgresql::test_database(); let relations = database.relations(); - // // CReate a relation to add noise to - // let relation = Relation::try_from( - // parse("SELECT sum(price) FROM item_table GROUP BY order_id") - // .unwrap() - // .with(&relations), - // ) - // .unwrap(); - // println!("Schema = {}", relation.schema()); - // relation.display_dot().unwrap(); - + // CReate a relation to add noise to + let relation = Relation::try_from( + parse("SELECT sum(price) FROM item_table GROUP BY order_id") + .unwrap() + .with(&relations), + ) + .unwrap(); + println!("Schema = {}", relation.schema()); + relation.display_dot().unwrap(); // Add noise directly for row in database .query("SELECT random(), sum(price) FROM item_table GROUP BY order_id") diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index c84ee8e3..bf537789 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -92,7 +92,7 @@ pub type Result = result::Result; */ impl Table { - /// Rename a Table + /// Create a new Table with a new name pub fn with_name(mut self, name: String) -> Table { self.name = name; self @@ -103,12 +103,12 @@ impl Table { */ impl Map { - /// Rename a Map + /// Create a new Map with a new name pub fn with_name(mut self, name: String) -> Map { self.name = name; self } - /// Prepend a field to a Map + /// Create a new Map with a prepended field pub fn with_field(self, name: &str, expr: Expr) -> Map { Relation::map().with((name, expr)).with(self).build() } @@ -507,7 +507,6 @@ impl Relation { .input(join) .build() } - /// Add a field designated with a "field path" pub fn with_field_path<'a>( self, @@ -546,7 +545,7 @@ impl Relation { ) } } - + /// Filter fields pub fn filter_fields bool>(self, predicate: P) -> Relation { match self { Relation::Map(map) => map.filter_fields(predicate).into(), @@ -560,7 +559,7 @@ impl Relation { } } } - + /// Map fields pub fn map_fields Expr>(self, f: F) -> Relation { match self { Relation::Map(map) => map.map_fields(f).into(), @@ -575,7 +574,7 @@ impl Relation { .build(), } } - + /// Rename fields pub fn rename_fields String>(self, f: F) -> Relation { match self { Relation::Map(map) => map.rename_fields(f).into(), @@ -591,37 +590,38 @@ impl Relation { .build(), } } - - pub fn sum_by(self, base: Vec<&str>, coordinates: Vec<&str>) -> Self { + /// Sum values for each group. + /// Groups form the basis of a vector space, the sums are the coordinates. + pub fn sum_by_group(self, groups: Vec<&str>, values: Vec<&str>) -> Self { let mut reduce = Relation::reduce().input(self.clone()); - reduce = base - .iter() - .fold(reduce, |acc, s| acc.with_group_by_column(s.to_string())); + reduce = groups + .into_iter() + .fold(reduce, |acc, s| acc.with_group_by_column(s)); reduce = reduce.with_iter( - coordinates - .iter() - .map(|c| (*c, Expr::sum(Expr::col(c.to_string())))), + values + .into_iter() + .map(|c| (c, Expr::sum(Expr::col(c.to_string())))), ); reduce.build() } - - pub fn l1_norm(self, vectors: &str, base: Vec<&str>, coordinates: Vec<&str>) -> Self { - let mut vectors_base = vec![vectors]; - vectors_base.extend(base.clone()); - let first = self.sum_by(vectors_base, coordinates.clone()); + /// Compute L1 norms of the vectors formed by the group values for each entities + pub fn l1_norm(self, entities: &str, groups: Vec<&str>, values: Vec<&str>) -> Self { + let mut vectors_base = vec![entities]; + vectors_base.extend(groups.clone()); + let first = self.sum_by_group(vectors_base, values.clone()); let map_rel = first.map_fields(|n, e| { - if coordinates.contains(&n) { + if values.contains(&n) { Expr::abs(e) } else { e } }); - if base.is_empty() { + if groups.is_empty() { map_rel } else { - map_rel.sum_by(vec![vectors], coordinates) + map_rel.sum_by_group(vec![entities], values) } } @@ -631,7 +631,7 @@ impl Relation { } else { let mut vectors_base = vec![vectors]; vectors_base.extend(base.clone()); - let first = self.sum_by(vectors_base, coordinates.clone()); + let first = self.sum_by_group(vectors_base, coordinates.clone()); let map_rel = first.map_fields(|n, e| { if coordinates.contains(&n) { @@ -640,7 +640,7 @@ impl Relation { e } }); - let reduce_rel = map_rel.sum_by(vec![vectors], coordinates.clone()); + let reduce_rel = map_rel.sum_by_group(vec![vectors], coordinates.clone()); reduce_rel.map_fields(|n, e| { if coordinates.contains(&n) { Expr::sqrt(e) @@ -771,13 +771,13 @@ impl Relation { } else { let mut vectors_base = vec![vectors]; vectors_base.extend(base.clone()); - self.sum_by(vectors_base, coordinates.clone()) + self.sum_by_group(vectors_base, coordinates.clone()) }; let weighted_relation = aggregated_relation.renormalize(weights, vectors, base.clone(), coordinates.clone()); - weighted_relation.sum_by(base, coordinates) + weighted_relation.sum_by_group(base, coordinates) } pub fn clip_aggregates(self, vectors: &str, clipping_values: Vec<(&str, f64)>) -> Result { From a634630f7799a8d3db2a7a0c6b3fe197fb1da90c Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Tue, 5 Sep 2023 18:23:59 +0200 Subject: [PATCH 02/36] Fixed l1 and l2 norms --- src/protected/mod.rs | 2 +- src/relation/transforms.rs | 170 +++++++++++++++++++++++++------------ 2 files changed, 118 insertions(+), 54 deletions(-) diff --git a/src/protected/mod.rs b/src/protected/mod.rs index 9790afd6..e9375a0b 100644 --- a/src/protected/mod.rs +++ b/src/protected/mod.rs @@ -440,7 +440,7 @@ mod tests { let vector = PE_ID.clone(); let base = vec!["item"]; let coordinates = vec!["price"]; - let norm = relation.l2_norm(vector, base, coordinates); + let norm = relation.l2_norms(vector, base, coordinates); norm.display_dot().unwrap(); // Print query let query: &str = &ast::Query::from(&norm).to_string(); diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index bf537789..dee7db92 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -592,7 +592,7 @@ impl Relation { } /// Sum values for each group. /// Groups form the basis of a vector space, the sums are the coordinates. - pub fn sum_by_group(self, groups: Vec<&str>, values: Vec<&str>) -> Self { + pub fn sums_by_group(self, groups: Vec<&str>, values: Vec<&str>) -> Self { let mut reduce = Relation::reduce().input(self.clone()); reduce = groups .into_iter() @@ -605,50 +605,41 @@ impl Relation { reduce.build() } /// Compute L1 norms of the vectors formed by the group values for each entities - pub fn l1_norm(self, entities: &str, groups: Vec<&str>, values: Vec<&str>) -> Self { - let mut vectors_base = vec![entities]; - vectors_base.extend(groups.clone()); - let first = self.sum_by_group(vectors_base, values.clone()); - - let map_rel = first.map_fields(|n, e| { - if values.contains(&n) { - Expr::abs(e) - } else { - e - } - }); - - if groups.is_empty() { - map_rel - } else { - map_rel.sum_by_group(vec![entities], values) - } - } - - pub fn l2_norm(self, vectors: &str, base: Vec<&str>, coordinates: Vec<&str>) -> Self { - if base.is_empty() { - self.l1_norm(vectors, base, coordinates) - } else { - let mut vectors_base = vec![vectors]; - vectors_base.extend(base.clone()); - let first = self.sum_by_group(vectors_base, coordinates.clone()); - - let map_rel = first.map_fields(|n, e| { - if coordinates.contains(&n) { - Expr::pow(e, Expr::val(2)) - } else { - e - } - }); - let reduce_rel = map_rel.sum_by_group(vec![vectors], coordinates.clone()); - reduce_rel.map_fields(|n, e| { - if coordinates.contains(&n) { - Expr::sqrt(e) + pub fn l1_norms(self, entities: &str, groups: Vec<&str>, values: Vec<&str>) -> Self { + let mut entities_groups = vec![entities]; + entities_groups.extend(groups.clone()); + self + .sums_by_group(entities_groups, values.clone()) + .map_fields(|field, expr| { + if values.contains(&field) { + Expr::abs(expr) } else { - e + expr } }) - } + .sums_by_group(vec![entities], values) + } + + pub fn l2_norms(self, entities: &str, groups: Vec<&str>, values: Vec<&str>) -> Self { + let mut entities_groups = vec![entities]; + entities_groups.extend(groups.clone()); + self + .sums_by_group(entities_groups, values.clone()) + .map_fields(|field, expr| { + if values.contains(&field) { + Expr::pow(expr, Expr::val(2)) + } else { + expr + } + }) + .sums_by_group(vec![entities], values.clone()) + .map_fields(|field, expr| { + if values.contains(&field) { + Expr::sqrt(expr) + } else { + expr + } + }) } /// This transform multiplies the coordinates in `self` relation by their corresponding weights in `weight_relation`. @@ -734,7 +725,7 @@ impl Relation { ) -> Self { let norm = self .clone() - .l2_norm(vectors.clone(), base.clone(), coordinates.clone()); + .l2_norms(vectors.clone(), base.clone(), coordinates.clone()); let map_clipping_values: HashMap<&str, f64> = clipping_values.into_iter().collect(); @@ -771,13 +762,13 @@ impl Relation { } else { let mut vectors_base = vec![vectors]; vectors_base.extend(base.clone()); - self.sum_by_group(vectors_base, coordinates.clone()) + self.sums_by_group(vectors_base, coordinates.clone()) }; let weighted_relation = aggregated_relation.renormalize(weights, vectors, base.clone(), coordinates.clone()); - weighted_relation.sum_by_group(base, coordinates) + weighted_relation.sums_by_group(base, coordinates) } pub fn clip_aggregates(self, vectors: &str, clipping_values: Vec<(&str, f64)>) -> Result { @@ -1175,6 +1166,79 @@ mod tests { sorted_results } + #[test] + fn test_sums_by_group() { + let mut database = postgresql::test_database(); + let relations = database.relations(); + let mut relation = relations + .get(&["item_table".into()]) + .unwrap() + .as_ref() + .clone(); + // Print query before + println!("Before: {}", &ast::Query::from(&relation)); + relation.display_dot().unwrap(); + // Sum by group + relation = relation.sums_by_group(vec!["order_id"], vec!["price"]); + // Print query after + println!("After: {}", &ast::Query::from(&relation)); + relation.display_dot().unwrap(); + } + + #[test] + fn test_l1_norms() { + let mut database = postgresql::test_database(); + let relations = database.relations(); + let mut relation = relations + .get(&["user_table".into()]) + .unwrap() + .as_ref() + .clone(); + // Compute l1 norm + relation = relation.l1_norms("id", vec!["city"], vec!["age"]); + // Print query + let query = &ast::Query::from(&relation); + println!("After: {}", query); + relation.display_dot().unwrap(); + let expected_query = "SELECT id, SUM(ABS(age)) FROM (SELECT id, city, SUM(age) AS age FROM user_table GROUP BY id, city) AS sums GROUP BY id"; + assert_eq!( + database.query(&query.to_string()).unwrap(), + database.query(expected_query).unwrap() + ); + // To double check + for row in database.query("SELECT id, SUM(ABS(age)) FROM (SELECT id, city, SUM(age) AS age FROM user_table GROUP BY id, city) AS sums GROUP BY id ORDER BY id").unwrap() { + println!("{row}"); + } + for row in database.query("SELECT id, count(id) FROM user_table GROUP BY id ORDER BY id").unwrap() { + println!("{row}"); + } + for row in database.query("SELECT id, age FROM user_table ORDER BY id").unwrap() { + println!("{row}"); + } + } + + #[test] + fn test_l2_norms() { + let mut database = postgresql::test_database(); + let relations = database.relations(); + let mut relation = relations + .get(&["user_table".into()]) + .unwrap() + .as_ref() + .clone(); + // Compute l1 norm + relation = relation.l2_norms("id", vec!["city"], vec!["age"]); + // Print query + let query = &ast::Query::from(&relation); + println!("After: {}", query); + relation.display_dot().unwrap(); + let expected_query = "SELECT id, SQRT(SUM(age*age)) FROM (SELECT id, city, SUM(age) AS age FROM user_table GROUP BY id, city) AS sums GROUP BY id"; + assert_eq!( + database.query(&query.to_string()).unwrap(), + database.query(expected_query).unwrap() + ); + } + #[test] fn test_compute_norm_for_table() { let mut database = postgresql::test_database(); @@ -1188,7 +1252,7 @@ mod tests { // L1 Norm let amount_norm = table .clone() - .l1_norm("order_id", vec!["item"], vec!["price"]); + .l1_norms("order_id", vec!["item"], vec!["price"]); // amount_norm.display_dot().unwrap(); let query: &str = &ast::Query::from(&amount_norm).to_string(); println!("Query = {}", query); @@ -1198,7 +1262,7 @@ mod tests { database.query(valid_query).unwrap() ); // L2 Norm - let amount_norm = table.l2_norm("order_id", vec!["item"], vec!["price"]); + let amount_norm = table.l2_norms("order_id", vec!["item"], vec!["price"]); amount_norm.display_dot().unwrap(); let query: &str = &ast::Query::from(&amount_norm).to_string(); let valid_query = "SELECT order_id, SQRT(SUM(sum_by_group)) FROM (SELECT order_id, item, POWER(SUM(price), 2) AS sum_by_group FROM item_table GROUP BY order_id, item) AS subquery GROUP BY order_id"; @@ -1219,7 +1283,7 @@ mod tests { .as_ref() .clone(); // L1 Norm - let amount_norm = table.clone().l1_norm("order_id", vec![], vec!["price"]); + let amount_norm = table.clone().l1_norms("order_id", vec![], vec!["price"]); amount_norm.display_dot().unwrap(); let query: &str = &ast::Query::from(&amount_norm).to_string(); println!("Query = {}", query); @@ -1231,7 +1295,7 @@ mod tests { ); // L2 Norm - let amount_norm = table.l2_norm("order_id", vec![], vec!["price"]); + let amount_norm = table.l2_norms("order_id", vec![], vec!["price"]); amount_norm.display_dot().unwrap(); let query: &str = &ast::Query::from(&amount_norm).to_string(); let valid_query = @@ -1259,7 +1323,7 @@ mod tests { let relation_norm = relation .clone() - .l1_norm("order_id", vec!["item"], vec!["price", "std_price"]); + .l1_norms("order_id", vec!["item"], vec!["price", "std_price"]); relation_norm.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation_norm).to_string(); //println!("Query = {}", query); @@ -1269,7 +1333,7 @@ mod tests { database.query(valid_query).unwrap() ); // L2 Norm - let relation_norm = relation.l2_norm("order_id", vec!["item"], vec!["price", "std_price"]); + let relation_norm = relation.l2_norms("order_id", vec!["item"], vec!["price", "std_price"]); relation_norm.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation_norm).to_string(); let valid_query = "SELECT order_id, SQRT(SUM(sum_1)), SQRT(SUM(sum_2)) FROM (SELECT order_id, item, POWER(SUM(price), 2) AS sum_1, POWER(SUM(std_price), 2) AS sum_2 FROM ( SELECT price - 25 AS std_price, * FROM item_table ) AS intermediate_table GROUP BY order_id, item) AS subquery GROUP BY order_id"; @@ -1311,7 +1375,7 @@ mod tests { // L1 Norm let relation_norm = relation .clone() - .l1_norm(user_id, vec![item, date], vec![price]); + .l1_norms(user_id, vec![item, date], vec![price]); relation_norm.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation_norm).to_string(); println!("Query = {}", query); @@ -1322,7 +1386,7 @@ mod tests { database.query(valid_query).unwrap() ); // L2 Norm - let relation_norm = relation.l2_norm(user_id, vec![item, date], vec![price]); + let relation_norm = relation.l2_norms(user_id, vec![item, date], vec![price]); relation_norm.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation_norm).to_string(); let valid_query = "SELECT user_id, SQRT(SUM(sum_1)) FROM (SELECT user_id, item, date, POWER(SUM(price), 2) AS sum_1 FROM item_table JOIN order_table ON item_table.order_id = order_table.id GROUP BY user_id, item, date) AS subquery GROUP BY user_id"; From 8c4f8be6f6443329679bf6aab926d57f0098cc1d Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Tue, 5 Sep 2023 18:29:39 +0200 Subject: [PATCH 03/36] ok --- src/relation/transforms.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index dee7db92..2b879044 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -647,7 +647,7 @@ impl Relation { /// `self` contains the coordinates, the base and vectors columns pub fn renormalize( self, - weight_relation: Self, + weight_relation: Relation, vectors: &str, base: Vec<&str>, coordinates: Vec<&str>, From 42337b65d071d4e815393303dac19acccc760806 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Tue, 5 Sep 2023 18:39:51 +0200 Subject: [PATCH 04/36] L1 and L2 work --- src/relation/transforms.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 2b879044..2887eefb 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -1273,7 +1273,7 @@ mod tests { } #[test] - fn test_compute_norm_for_empty_base() { + fn test_compute_norm_for_empty_groups() { let mut database = postgresql::test_database(); let relations = database.relations(); @@ -1285,10 +1285,9 @@ mod tests { // L1 Norm let amount_norm = table.clone().l1_norms("order_id", vec![], vec!["price"]); amount_norm.display_dot().unwrap(); - let query: &str = &ast::Query::from(&amount_norm).to_string(); + let query: &str = &format!("{} ORDER BY order_id", ast::Query::from(&amount_norm)); println!("Query = {}", query); - let valid_query = "SELECT order_id, ABS(SUM(price)) FROM item_table GROUP BY order_id"; - database.query(query).unwrap(); + let valid_query = "SELECT order_id, ABS(SUM(price)) FROM item_table GROUP BY order_id ORDER BY order_id"; assert_eq!( database.query(query).unwrap(), database.query(valid_query).unwrap() @@ -1297,10 +1296,9 @@ mod tests { // L2 Norm let amount_norm = table.l2_norms("order_id", vec![], vec!["price"]); amount_norm.display_dot().unwrap(); - let query: &str = &ast::Query::from(&amount_norm).to_string(); + let query: &str = &format!("{} ORDER BY order_id", ast::Query::from(&amount_norm)); let valid_query = - "SELECT order_id, SQRT(POWER(SUM(price), 2)) FROM item_table GROUP BY order_id"; - database.query(query).unwrap(); + "SELECT order_id, SQRT(POWER(SUM(price), 2)) FROM item_table GROUP BY order_id ORDER BY order_id"; assert_eq!( database.query(query).unwrap(), database.query(valid_query).unwrap() From f0cdad405bc24398f8fa341e46bd5c56fb26d304 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Wed, 6 Sep 2023 23:23:17 +0200 Subject: [PATCH 05/36] Clipped values --- src/relation/transforms.rs | 183 +++++++++++++++---------------------- 1 file changed, 73 insertions(+), 110 deletions(-) diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 2887eefb..491ed528 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -619,22 +619,22 @@ impl Relation { }) .sums_by_group(vec![entities], values) } - + /// Compute L2 norms of the vectors formed by the group values for each entities pub fn l2_norms(self, entities: &str, groups: Vec<&str>, values: Vec<&str>) -> Self { let mut entities_groups = vec![entities]; entities_groups.extend(groups.clone()); self .sums_by_group(entities_groups, values.clone()) - .map_fields(|field, expr| { - if values.contains(&field) { + .map_fields(|field_name, expr| { + if values.contains(&field_name) { Expr::pow(expr, Expr::val(2)) } else { expr } }) .sums_by_group(vec![entities], values.clone()) - .map_fields(|field, expr| { - if values.contains(&field) { + .map_fields(|field_name, expr| { + if values.contains(&field_name) { Expr::sqrt(expr) } else { expr @@ -642,133 +642,73 @@ impl Relation { }) } - /// This transform multiplies the coordinates in `self` relation by their corresponding weights in `weight_relation`. - /// `weight_relation` contains the coordinates weights and the vectors columns + /// This transform multiplies the values in `self` relation by their corresponding `scale_factors`. + /// `scale_factors` contains the entities scaling factors and the vectors columns /// `self` contains the coordinates, the base and vectors columns - pub fn renormalize( + pub fn scale( self, - weight_relation: Relation, - vectors: &str, - base: Vec<&str>, - coordinates: Vec<&str>, - ) -> Self { - // Join the two relations on the peid column + entities: &str, + values: Vec<&str>, + scale_factors: Relation, + ) -> Self {// TODO fix this + // Join the two relations on the entity column let join: Relation = Relation::join() - .left(self.clone()) - .right(weight_relation.clone()) - .inner() .on(Expr::eq( - Expr::qcol(self.name(), vectors), - Expr::qcol(weight_relation.name(), vectors), + Expr::qcol(self.name(), entities), + Expr::qcol(scale_factors.name(), entities), )) + .left_names(self.fields().into_iter().map(|field| field.name()).collect()) + .right_names(scale_factors.fields().into_iter().map(|field| format!("_SCALE_FACTOR_{}", field.name())).collect()) + .left(self) + .right(scale_factors) + .inner() .build(); - - // Multiply by weights - let mut grouping_cols: Vec = vec![]; - let mut weighted_agg: Vec = vec![]; - let left_len = if base.is_empty() { - self.schema().len() + 1 - } else { - self.schema().len() - }; - let join_len = join.schema().len(); - let out_fields = join.schema().fields(); - let in_fields = join.input_fields(); - for i in 0..left_len { - // length + 1 - if coordinates.contains(&in_fields[i].name()) { - let mut pos = i + 1; - while &in_fields[i].name() != &in_fields[pos].name() { - pos += 1; - if pos > join_len { - panic!() - } - } - - weighted_agg.push(Expr::multiply( - Expr::col(out_fields[i].name()), - Expr::col(out_fields[pos].name()), - )); + // Multiply the values by the factors + join.map_fields(|field_name, expr| { + if values.contains(&field_name) { + Expr::multiply(expr, Expr::col(format!("_SCALE_FACTOR_{}", field_name))) } else { - grouping_cols.push(Expr::col(out_fields[i].name())); + expr } - } - - let mut vectors_base = vec![vectors]; - vectors_base.extend(base.clone()); - Relation::map() - .input(join) - .with_iter( - vectors_base - .iter() - .zip(grouping_cols.iter()) - .map(|(s, e)| (s.to_string(), e.clone())), - ) - .with_iter( - coordinates - .iter() - .zip(weighted_agg.iter()) - .map(|(s, e)| (s.to_string(), e.clone())), - ) - .build() + }) } - /// For each coordinate, rescale the columns by 1 / max(c, norm_l2(coordinate)) + /// For each coordinate, rescale the columns by 1 / greatest(1, norm_l2/C) /// where the l2 norm is computed for each elecment of `vectors` /// The `self` relation must contain the vectors, base and coordinates columns pub fn clipped_sum( self, - vectors: &str, - base: Vec<&str>, - coordinates: Vec<&str>, + entities: &str, + groups: Vec<&str>, + values: Vec<&str>, clipping_values: Vec<(&str, f64)>, ) -> Self { - let norm = self + // Compute the norm + let norms = self .clone() - .l2_norms(vectors.clone(), base.clone(), coordinates.clone()); - - let map_clipping_values: HashMap<&str, f64> = clipping_values.into_iter().collect(); - - let weights = norm.map_fields(|n, e| { - if coordinates.contains(&n) { + .l2_norms(entities.clone(), groups.clone(), values.clone()); + // Put the `clipping_values`in the right shape + let clipping_values: HashMap<&str, f64> = clipping_values.into_iter().collect(); + // Compute the scaling factors + let scaling_factors = norms.map_fields(|field_name, expr| { + if values.contains(&field_name) { Expr::divide( - Expr::val(2), - Expr::plus( - Expr::abs(Expr::minus( - Expr::divide(e.clone(), Expr::val(map_clipping_values[&n])), - Expr::val(1), - )), - Expr::plus( - Expr::divide(e, Expr::val(map_clipping_values[&n])), - Expr::val(1), - ), + Expr::val(1), + Expr::greatest( + Expr::val(1), + Expr::divide(expr.clone(), Expr::val(clipping_values[&field_name])), ), ) } else { - Expr::col(n) + Expr::val(1) } }); - - let aggregated_relation: Relation = if base.is_empty() { - Relation::map() - .input(self) - .with((vectors, Expr::col(vectors))) - .with_iter( - coordinates - .iter() - .map(|s| (s.to_string(), Expr::col(s.to_string()))), - ) - .build() - } else { - let mut vectors_base = vec![vectors]; - vectors_base.extend(base.clone()); - self.sums_by_group(vectors_base, coordinates.clone()) - }; - - let weighted_relation = - aggregated_relation.renormalize(weights, vectors, base.clone(), coordinates.clone()); - - weighted_relation.sums_by_group(base, coordinates) + let clipped_relation = self.scale( + entities, + values.clone(), + scaling_factors, + ); + clipped_relation.sums_by_group(groups, values) } pub fn clip_aggregates(self, vectors: &str, clipping_values: Vec<(&str, f64)>) -> Result { @@ -1226,7 +1166,7 @@ mod tests { .unwrap() .as_ref() .clone(); - // Compute l1 norm + // Compute l2 norm relation = relation.l2_norms("id", vec!["city"], vec!["age"]); // Print query let query = &ast::Query::from(&relation); @@ -1398,6 +1338,29 @@ mod tests { } } + #[test] + fn test_l2_scaling() { + let mut database = postgresql::test_database(); + let relations = database.relations(); + let mut relation = relations + .get(&["user_table".into()]) + .unwrap() + .as_ref() + .clone(); + // Compute l1 norm + let norms = relation.clone().l2_norms("id", vec!["city"], vec!["age"]); + relation = relation.scale("id", vec!["age"], norms); + // Print query + let query = &ast::Query::from(&relation); + println!("After: {}", query); + relation.display_dot().unwrap(); + // let expected_query = "SELECT id, SQRT(SUM(age*age)) FROM (SELECT id, city, SUM(age) AS age FROM user_table GROUP BY id, city) AS sums GROUP BY id"; + // assert_eq!( + // database.query(&query.to_string()).unwrap(), + // database.query(expected_query).unwrap() + // ); + } + #[test] fn test_clipped_sum_for_table() { let mut database = postgresql::test_database(); From ad8e7cd2676ba76639cbf9535f0adde8af0dc352 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Wed, 6 Sep 2023 23:30:11 +0200 Subject: [PATCH 06/36] Fix clipping --- src/relation/transforms.rs | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 491ed528..f780be8c 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -243,7 +243,7 @@ impl Reduce { Got {len_clipping_values} clipping values for {len_coordinates} output fields" ))); } - let clipped_relation = self.input.as_ref().clone().clipped_sum( + let clipped_relation = self.input.as_ref().clone().l2_clipped_sums( vectors?.as_str(), base.iter().map(|s| s.as_str()).collect(), coordinates.iter().map(|s| s.as_str()).collect(), @@ -676,7 +676,7 @@ impl Relation { /// For each coordinate, rescale the columns by 1 / greatest(1, norm_l2/C) /// where the l2 norm is computed for each elecment of `vectors` /// The `self` relation must contain the vectors, base and coordinates columns - pub fn clipped_sum( + pub fn l2_clipped_sums( self, entities: &str, groups: Vec<&str>, @@ -1339,7 +1339,7 @@ mod tests { } #[test] - fn test_l2_scaling() { + fn test_l2_clipped_sums() { let mut database = postgresql::test_database(); let relations = database.relations(); let mut relation = relations @@ -1348,8 +1348,7 @@ mod tests { .as_ref() .clone(); // Compute l1 norm - let norms = relation.clone().l2_norms("id", vec!["city"], vec!["age"]); - relation = relation.scale("id", vec!["age"], norms); + relation = relation.clone().l2_clipped_sums("id", vec!["city"], vec!["age"], vec![("age", 20.)]); // Print query let query = &ast::Query::from(&relation); println!("After: {}", query); @@ -1371,7 +1370,7 @@ mod tests { .unwrap() .as_ref() .clone(); - let clipped_relation = table.clone().clipped_sum( + let clipped_relation = table.clone().l2_clipped_sums( "order_id", vec!["item"], vec!["price"], @@ -1405,7 +1404,7 @@ mod tests { let clipped_relation = table .clone() - .clipped_sum("order_id", vec![], vec!["price"], vec![("price", 45.)]); + .l2_clipped_sums("order_id", vec![], vec!["price"], vec![("price", 45.)]); clipped_relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&clipped_relation).to_string(); println!("Query: {}", query); @@ -1436,7 +1435,7 @@ mod tests { relation.display_dot().unwrap(); // L2 Norm - let clipped_relation = relation.clone().clipped_sum( + let clipped_relation = relation.clone().l2_clipped_sums( "order_id", vec!["item"], vec!["price", "std_price"], @@ -1492,7 +1491,7 @@ mod tests { let date = schema.field_from_index(6).unwrap().name(); let clipped_relation = - relation.clipped_sum(user_id, vec![item, date], vec![price], vec![(price, 50.)]); + relation.l2_clipped_sums(user_id, vec![item, date], vec![price], vec![(price, 50.)]); clipped_relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&clipped_relation).to_string(); let valid_query = r#" From de1c0f34c51b7be12b308b237da3c1b0795f46f3 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Wed, 6 Sep 2023 23:35:20 +0200 Subject: [PATCH 07/36] To fix --- src/relation/transforms.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index f780be8c..7eec2243 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -687,6 +687,8 @@ impl Relation { let norms = self .clone() .l2_norms(entities.clone(), groups.clone(), values.clone()); + // TODO REMOVE DEBUG + norms.display_dot().unwrap(); // Put the `clipping_values`in the right shape let clipping_values: HashMap<&str, f64> = clipping_values.into_iter().collect(); // Compute the scaling factors @@ -703,6 +705,8 @@ impl Relation { Expr::val(1) } }); + // TODO REMOVE DEBUG + scaling_factors.display_dot().unwrap(); let clipped_relation = self.scale( entities, values.clone(), From 6b9fe6bd55818a4f1402d6a314d310cdfd74619b Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Thu, 7 Sep 2023 18:03:38 +0200 Subject: [PATCH 08/36] ok --- src/relation/transforms.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 9381bb51..4e861221 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -694,11 +694,11 @@ impl Relation { // Compute the scaling factors let scaling_factors = norms.map_fields(|field_name, expr| { if values.contains(&field_name) { - Expr::divide( + Expr::multiply( Expr::val(1), Expr::greatest( Expr::val(1), - Expr::divide(expr.clone(), Expr::val(clipping_values[&field_name])), + Expr::multiply(expr.clone(), Expr::val(clipping_values[&field_name])), ), ) } else { From 738833172d12d94bbcde0e78fe3e73c2afb2c8d8 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Sep 2023 12:21:01 +0200 Subject: [PATCH 09/36] fixed And op for struct of structs` --- src/data_type/mod.rs | 143 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 137 insertions(+), 6 deletions(-) diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index 4c1ea9e6..fffab413 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -912,13 +912,27 @@ impl, T: Into>> And<(S, T)> for Struct { fn and(self, other: (S, T)) -> Self::Product { let field: String = other.0.into(); let data_type: Rc = other.1.into(); + let mut push_other = true; // Remove existing elements with the same name let mut fields: Vec<(String, Rc)> = self .fields .iter() - .filter_map(|(f, t)| (&field != f).then_some((f.clone(), t.clone()))) + .filter_map( + |(f, t)| {//(&field != f).then_some((f.clone(), t.clone())) + if &field != f { + Some((f.clone(), t.clone())) + } else if let (&DataType::Struct(_), &DataType::Struct(_)) = (data_type.as_ref(), t.as_ref()){ + push_other = false; + Some((f.clone(), Rc::new(data_type.as_ref().clone().and(t.as_ref().clone())))) + }else { + None + } + } + ) .collect(); - fields.push((field, data_type)); + if push_other { + fields.push((field, data_type)) + } Struct::new(fields.into()) } } @@ -3334,13 +3348,130 @@ mod tests { .and(("c", DataType::float())) .and(("d", DataType::float())) .and(("d", DataType::float())); + println!("a = {a}"); + println!("b = {b}"); + + // a and b let c = a.clone().and(b.clone()); + println!("\na and b = {c}"); + assert_eq!( + c, + Struct::default() + .and(("0", DataType::float())) + .and(("a", DataType::integer())) + .and(("1", DataType::float())) + .and(("2", DataType::float())) + .and(("3", DataType::float())) + .and(("b", DataType::integer())) + .and(("c", DataType::float())) + .and(("d", DataType::float())) + ); + + // a and unit let d = a.clone().and(DataType::unit()); - let e = a.and(DataType::Struct(b)); - println!("{c}"); - println!("{d}"); - println!("{e}"); + println!("\na and unit = {d}"); + assert_eq!( + d, + Struct::default() + .and(("0", DataType::float())) + .and(("a", DataType::integer())) + .and(("1", DataType::float())) + .and(("2", DataType::float())) + .and(("3", DataType::float())) + .and(("4", DataType::unit())) + ); + + // a and DataType(b) + let e = a.clone().and(DataType::Struct(b.clone())); + println!("\na and b = {e}"); assert_eq!(e.fields().len(), 8); + assert_eq!( + e, + Struct::default() + .and(("0", DataType::float())) + .and(("a", DataType::integer())) + .and(("1", DataType::float())) + .and(("2", DataType::float())) + .and(("3", DataType::float())) + .and(("b", DataType::integer())) + .and(("c", DataType::float())) + .and(("d", DataType::float())) + ); + + //struct(table1: a) and b + let f = DataType::structured([("table1", DataType::Struct(a.clone()))]).and(DataType::Struct(b.clone())); + println!("\na and struct(table1: b) = {f}"); + assert_eq!( + f, + DataType::structured([ + ( + "table1", + DataType::structured([ + ("0", DataType::float()), + ("a", DataType::integer()), + ("1", DataType::float()), + ("2", DataType::float()), + ("3", DataType::float()) + ]) + ), + ("b", DataType::integer()), + ("c", DataType::float()), + ("d", DataType::float()), + ]) + ); + + //struct(table1: a) and struct(table1: b) + let g = DataType::structured([("table1", DataType::Struct(a.clone()))]).and( + DataType::structured([("table1", DataType::Struct(b.clone()))]) + ); + println!("\nstruct(table1: a) and struct(table1: b) = {g}"); + assert_eq!( + g, + DataType::structured([ + ( + "table1", + DataType::structured([ + ("0", DataType::float()), + ("a", DataType::integer()), + ("1", DataType::float()), + ("2", DataType::float()), + ("3", DataType::float()), + ("b", DataType::integer()), + ("c", DataType::float()), + ("d", DataType::float()), + ]) + ) + ]) + ); + + // struct(table1: a) and struct(table2: b) + let h = DataType::structured([("table1", DataType::Struct(a))]).and( + DataType::structured([("table2", DataType::Struct(b))]) + ); + println!("\nstruct(table1: a) and struct(table2: b) = {h}"); + assert_eq!( + h, + DataType::structured([ + ( + "table1", + DataType::structured([ + ("0", DataType::float()), + ("a", DataType::integer()), + ("1", DataType::float()), + ("2", DataType::float()), + ("3", DataType::float()) + ]) + ), + ( + "table2", + DataType::structured([ + ("b", DataType::integer()), + ("c", DataType::float()), + ("d", DataType::float()), + ]) + ) + ]) + ); } #[test] From df6bfaac49d56c5b751f32283e745f5b6ede84d0 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Sep 2023 14:39:57 +0200 Subject: [PATCH 10/36] fmt --- CHANGELOG.md | 4 +- src/data_type/mod.rs | 139 +++++++++++++++++++++++++------------------ 2 files changed, 85 insertions(+), 58 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 881a7f2e..676111ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Renamed bivariate_min and bivariate_max to least and greatest - Cast to string before MD5 for protection - Implemented `least` and `greatest` (untested) +### Fixed +- `And` for struct of structs [MR100](https://github.com/Qrlew/qrlew/pull/100) ## [0.2.2] - 2023-08-29 ### Changed @@ -21,7 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Updated sqlparser version - Deactivate graphviz display by default - Deactivate multiplicity testing by default -- +- ### Added - In `sampling_adjustments` added differenciated sampling and adjustments [MR77](https://github.com/Qrlew/qrlew/pull/77) - Updated sqlparser version diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index fffab413..a706ef3e 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -917,18 +917,26 @@ impl, T: Into>> And<(S, T)> for Struct { let mut fields: Vec<(String, Rc)> = self .fields .iter() - .filter_map( - |(f, t)| {//(&field != f).then_some((f.clone(), t.clone())) - if &field != f { - Some((f.clone(), t.clone())) - } else if let (&DataType::Struct(_), &DataType::Struct(_)) = (data_type.as_ref(), t.as_ref()){ - push_other = false; - Some((f.clone(), Rc::new(data_type.as_ref().clone().and(t.as_ref().clone())))) - }else { - None - } + .map(|(f, t)| { + //(&field != f).then_some((f.clone(), t.clone())) + if &field != f { + (f.clone(), t.clone()) + } else if let (&DataType::Struct(_), &DataType::Struct(_)) = + (data_type.as_ref(), t.as_ref()) + { + push_other = false; + ( + f.clone(), + Rc::new(data_type.as_ref().clone().and(t.as_ref().clone())), + ) + } else { + push_other = false; + ( + f.clone(), + Rc::new(data_type.as_ref().super_intersection(t.as_ref()).unwrap()), + ) } - ) + }) .collect(); if push_other { fields.push((field, data_type)) @@ -3337,9 +3345,18 @@ mod tests { #[test] fn test_struct_and() { + let a = Struct::default() + .and(("a", DataType::float_interval(1., 3.))) + .and(("a", DataType::integer_interval(-10, 10))); + println!("a = {a}"); + assert_eq!( + a, + Struct::from_field("a", DataType::float_values([1., 2., 3.])) + ); + let a = Struct::default() .and(DataType::float()) - .and(("a", DataType::integer())) + .and(("a", DataType::integer_interval(-10, 10))) .and(DataType::float()) .and(DataType::float()) .and(DataType::float()); @@ -3347,7 +3364,8 @@ mod tests { .and(("b", DataType::integer())) .and(("c", DataType::float())) .and(("d", DataType::float())) - .and(("d", DataType::float())); + .and(("d", DataType::float())) + .and(("a", DataType::float_interval(1., 3.))); println!("a = {a}"); println!("b = {b}"); @@ -3357,14 +3375,14 @@ mod tests { assert_eq!( c, Struct::default() - .and(("0", DataType::float())) - .and(("a", DataType::integer())) - .and(("1", DataType::float())) - .and(("2", DataType::float())) - .and(("3", DataType::float())) - .and(("b", DataType::integer())) - .and(("c", DataType::float())) - .and(("d", DataType::float())) + .and(("0", DataType::float())) + .and(("a", DataType::float_values([1., 2., 3.]))) + .and(("1", DataType::float())) + .and(("2", DataType::float())) + .and(("3", DataType::float())) + .and(("b", DataType::integer())) + .and(("c", DataType::float())) + .and(("d", DataType::float())) ); // a and unit @@ -3373,12 +3391,12 @@ mod tests { assert_eq!( d, Struct::default() - .and(("0", DataType::float())) - .and(("a", DataType::integer())) - .and(("1", DataType::float())) - .and(("2", DataType::float())) - .and(("3", DataType::float())) - .and(("4", DataType::unit())) + .and(("0", DataType::float())) + .and(("a", DataType::integer_interval(-10, 10))) + .and(("1", DataType::float())) + .and(("2", DataType::float())) + .and(("3", DataType::float())) + .and(("4", DataType::unit())) ); // a and DataType(b) @@ -3388,18 +3406,19 @@ mod tests { assert_eq!( e, Struct::default() - .and(("0", DataType::float())) - .and(("a", DataType::integer())) - .and(("1", DataType::float())) - .and(("2", DataType::float())) - .and(("3", DataType::float())) - .and(("b", DataType::integer())) - .and(("c", DataType::float())) - .and(("d", DataType::float())) + .and(("0", DataType::float())) + .and(("a", DataType::integer_interval(1, 3))) + .and(("1", DataType::float())) + .and(("2", DataType::float())) + .and(("3", DataType::float())) + .and(("b", DataType::integer())) + .and(("c", DataType::float())) + .and(("d", DataType::float())) ); //struct(table1: a) and b - let f = DataType::structured([("table1", DataType::Struct(a.clone()))]).and(DataType::Struct(b.clone())); + let f = DataType::structured([("table1", DataType::Struct(a.clone()))]) + .and(DataType::Struct(b.clone())); println!("\na and struct(table1: b) = {f}"); assert_eq!( f, @@ -3408,7 +3427,7 @@ mod tests { "table1", DataType::structured([ ("0", DataType::float()), - ("a", DataType::integer()), + ("a", DataType::integer_interval(-10, 10)), ("1", DataType::float()), ("2", DataType::float()), ("3", DataType::float()) @@ -3417,37 +3436,35 @@ mod tests { ("b", DataType::integer()), ("c", DataType::float()), ("d", DataType::float()), + ("a", DataType::float_values([1., 2., 3.])) ]) ); //struct(table1: a) and struct(table1: b) let g = DataType::structured([("table1", DataType::Struct(a.clone()))]).and( - DataType::structured([("table1", DataType::Struct(b.clone()))]) + DataType::structured([("table1", DataType::Struct(b.clone()))]), ); println!("\nstruct(table1: a) and struct(table1: b) = {g}"); assert_eq!( g, - DataType::structured([ - ( - "table1", - DataType::structured([ - ("0", DataType::float()), - ("a", DataType::integer()), - ("1", DataType::float()), - ("2", DataType::float()), - ("3", DataType::float()), - ("b", DataType::integer()), - ("c", DataType::float()), - ("d", DataType::float()), - ]) - ) - ]) + DataType::structured([( + "table1", + DataType::structured([ + ("0", DataType::float()), + ("a", DataType::float_values([1., 2., 3.])), + ("1", DataType::float()), + ("2", DataType::float()), + ("3", DataType::float()), + ("b", DataType::integer()), + ("c", DataType::float()), + ("d", DataType::float()), + ]) + )]) ); // struct(table1: a) and struct(table2: b) - let h = DataType::structured([("table1", DataType::Struct(a))]).and( - DataType::structured([("table2", DataType::Struct(b))]) - ); + let h = DataType::structured([("table1", DataType::Struct(a))]) + .and(DataType::structured([("table2", DataType::Struct(b))])); println!("\nstruct(table1: a) and struct(table2: b) = {h}"); assert_eq!( h, @@ -3456,7 +3473,7 @@ mod tests { "table1", DataType::structured([ ("0", DataType::float()), - ("a", DataType::integer()), + ("a", DataType::integer_interval(-10, 10)), ("1", DataType::float()), ("2", DataType::float()), ("3", DataType::float()) @@ -3468,6 +3485,7 @@ mod tests { ("b", DataType::integer()), ("c", DataType::float()), ("d", DataType::float()), + ("a", DataType::float_values([1., 2., 3.])) ]) ) ]) @@ -3629,6 +3647,13 @@ mod tests { #[test] fn test_intersection() { + let left = DataType::float_interval(1., 3.); + let right = DataType::integer_interval(-10, 10); + let inter = left.super_intersection(&right).unwrap(); + println!("{left} ∩ {right} = {inter}"); + assert_eq!(inter, DataType::integer_interval(1, 3)); + assert_eq!(inter, right.super_intersection(&left).unwrap()); + let left = DataType::integer_interval(0, 10); let right = DataType::float_interval(5., 12.); From a5e3743a199338928799b99e20876730068ca802 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Sep 2023 14:46:33 +0200 Subject: [PATCH 11/36] fixed tests --- src/data_type/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index a706ef3e..19179bc5 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -3436,7 +3436,7 @@ mod tests { ("b", DataType::integer()), ("c", DataType::float()), ("d", DataType::float()), - ("a", DataType::float_values([1., 2., 3.])) + ("a", DataType::float_interval(1., 3.)) ]) ); @@ -3485,7 +3485,7 @@ mod tests { ("b", DataType::integer()), ("c", DataType::float()), ("d", DataType::float()), - ("a", DataType::float_values([1., 2., 3.])) + ("a", DataType::float_interval(1., 3.)) ]) ) ]) From dec2384fa7e6924977afa61efcc0564706f5bae5 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Sep 2023 16:21:31 +0200 Subject: [PATCH 12/36] fixed tests --- src/data_type/mod.rs | 57 +++++++++++++++++++------------------------- src/expr/mod.rs | 7 +++--- 2 files changed, 28 insertions(+), 36 deletions(-) diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index 19179bc5..b08d993f 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -918,22 +918,13 @@ impl, T: Into>> And<(S, T)> for Struct { .fields .iter() .map(|(f, t)| { - //(&field != f).then_some((f.clone(), t.clone())) if &field != f { (f.clone(), t.clone()) - } else if let (&DataType::Struct(_), &DataType::Struct(_)) = - (data_type.as_ref(), t.as_ref()) - { - push_other = false; - ( - f.clone(), - Rc::new(data_type.as_ref().clone().and(t.as_ref().clone())), - ) } else { push_other = false; ( f.clone(), - Rc::new(data_type.as_ref().super_intersection(t.as_ref()).unwrap()), + Rc::new(t.as_ref().clone().and(data_type.as_ref().clone())), ) } }) @@ -2930,7 +2921,8 @@ impl DataType { pub fn product>(data_types: I) -> DataType { data_types .into_iter() - .fold(DataType::unit(), |s, t| s.and(t)) + .reduce(|s, t| s.and(t)) + .unwrap_or(DataType::unit()) } } @@ -3346,12 +3338,12 @@ mod tests { #[test] fn test_struct_and() { let a = Struct::default() - .and(("a", DataType::float_interval(1., 3.))) - .and(("a", DataType::integer_interval(-10, 10))); - println!("a = {a}"); + .and(("a", DataType::integer_interval(-10, 10))) + .and(("a", DataType::float_interval(1., 3.))); + println!("{a}"); assert_eq!( a, - Struct::from_field("a", DataType::float_values([1., 2., 3.])) + Struct::from_field("a", DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.)) ); let a = Struct::default() @@ -3371,19 +3363,18 @@ mod tests { // a and b let c = a.clone().and(b.clone()); + let true_c = Struct::default() + .and(("0", DataType::float())) + .and(("a", DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.))) + .and(("1", DataType::float())) + .and(("2", DataType::float())) + .and(("3", DataType::float())) + .and(("b", DataType::integer())) + .and(("c", DataType::float())) + .and(("d", DataType::float() & DataType::float())); println!("\na and b = {c}"); - assert_eq!( - c, - Struct::default() - .and(("0", DataType::float())) - .and(("a", DataType::float_values([1., 2., 3.]))) - .and(("1", DataType::float())) - .and(("2", DataType::float())) - .and(("3", DataType::float())) - .and(("b", DataType::integer())) - .and(("c", DataType::float())) - .and(("d", DataType::float())) - ); + println!("\na and b = {true_c}"); + assert_eq!(c,true_c); // a and unit let d = a.clone().and(DataType::unit()); @@ -3407,13 +3398,13 @@ mod tests { e, Struct::default() .and(("0", DataType::float())) - .and(("a", DataType::integer_interval(1, 3))) + .and(("a", DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.))) .and(("1", DataType::float())) .and(("2", DataType::float())) .and(("3", DataType::float())) .and(("b", DataType::integer())) .and(("c", DataType::float())) - .and(("d", DataType::float())) + .and(("d", DataType::float() & DataType::float())) ); //struct(table1: a) and b @@ -3435,7 +3426,7 @@ mod tests { ), ("b", DataType::integer()), ("c", DataType::float()), - ("d", DataType::float()), + ("d", DataType::float() & DataType::float()), ("a", DataType::float_interval(1., 3.)) ]) ); @@ -3451,13 +3442,13 @@ mod tests { "table1", DataType::structured([ ("0", DataType::float()), - ("a", DataType::float_values([1., 2., 3.])), + ("a", DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.)), ("1", DataType::float()), ("2", DataType::float()), ("3", DataType::float()), ("b", DataType::integer()), ("c", DataType::float()), - ("d", DataType::float()), + ("d", DataType::float() & DataType::float()), ]) )]) ); @@ -3484,7 +3475,7 @@ mod tests { DataType::structured([ ("b", DataType::integer()), ("c", DataType::float()), - ("d", DataType::float()), + ("d", DataType::float() & DataType::float()), ("a", DataType::float_interval(1., 3.)) ]) ) diff --git a/src/expr/mod.rs b/src/expr/mod.rs index f336817d..571218dd 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -1260,9 +1260,10 @@ mod tests { #[test] fn test_bin_op() { - let dict = DataType::unit() - & ("0", DataType::float_interval(-5., 2.)) - & ("1", DataType::float_interval(-1., 2.)); + let dict = DataType::structured([ + ("0", DataType::float_interval(-5., 2.)), + ("1", DataType::float_interval(-1., 2.)) + ]); let left = Expr::col("0"); let right = Expr::col("1"); // Sum From 0e48542402bb089b51d2007570864399880a0256 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Sep 2023 16:21:48 +0200 Subject: [PATCH 13/36] format --- src/data_type/mod.rs | 30 +++++++++++++++++++++--------- src/expr/mod.rs | 2 +- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index b08d993f..668c7972 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -3338,12 +3338,15 @@ mod tests { #[test] fn test_struct_and() { let a = Struct::default() - .and(("a", DataType::integer_interval(-10, 10))) - .and(("a", DataType::float_interval(1., 3.))); + .and(("a", DataType::integer_interval(-10, 10))) + .and(("a", DataType::float_interval(1., 3.))); println!("{a}"); assert_eq!( a, - Struct::from_field("a", DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.)) + Struct::from_field( + "a", + DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.) + ) ); let a = Struct::default() @@ -3365,7 +3368,10 @@ mod tests { let c = a.clone().and(b.clone()); let true_c = Struct::default() .and(("0", DataType::float())) - .and(("a", DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.))) + .and(( + "a", + DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.), + )) .and(("1", DataType::float())) .and(("2", DataType::float())) .and(("3", DataType::float())) @@ -3374,7 +3380,7 @@ mod tests { .and(("d", DataType::float() & DataType::float())); println!("\na and b = {c}"); println!("\na and b = {true_c}"); - assert_eq!(c,true_c); + assert_eq!(c, true_c); // a and unit let d = a.clone().and(DataType::unit()); @@ -3398,7 +3404,10 @@ mod tests { e, Struct::default() .and(("0", DataType::float())) - .and(("a", DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.))) + .and(( + "a", + DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.) + )) .and(("1", DataType::float())) .and(("2", DataType::float())) .and(("3", DataType::float())) @@ -3426,7 +3435,7 @@ mod tests { ), ("b", DataType::integer()), ("c", DataType::float()), - ("d", DataType::float() & DataType::float()), + ("d", DataType::float() & DataType::float()), ("a", DataType::float_interval(1., 3.)) ]) ); @@ -3442,7 +3451,10 @@ mod tests { "table1", DataType::structured([ ("0", DataType::float()), - ("a", DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.)), + ( + "a", + DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.) + ), ("1", DataType::float()), ("2", DataType::float()), ("3", DataType::float()), @@ -3475,7 +3487,7 @@ mod tests { DataType::structured([ ("b", DataType::integer()), ("c", DataType::float()), - ("d", DataType::float() & DataType::float()), + ("d", DataType::float() & DataType::float()), ("a", DataType::float_interval(1., 3.)) ]) ) diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 571218dd..1848ebd6 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -1262,7 +1262,7 @@ mod tests { fn test_bin_op() { let dict = DataType::structured([ ("0", DataType::float_interval(-5., 2.)), - ("1", DataType::float_interval(-1., 2.)) + ("1", DataType::float_interval(-1., 2.)), ]); let left = Expr::col("0"); let right = Expr::col("1"); From 830cb7bc8c6d6918723187d212dbdf5a6a24ebd4 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Sep 2023 17:04:52 +0200 Subject: [PATCH 14/36] fixed test --- src/data_type/mod.rs | 56 ++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index 668c7972..c666f9c1 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -924,7 +924,7 @@ impl, T: Into>> And<(S, T)> for Struct { push_other = false; ( f.clone(), - Rc::new(t.as_ref().clone().and(data_type.as_ref().clone())), + Rc::new(t.as_ref().clone().super_intersection(data_type.as_ref()).unwrap()), ) } }) @@ -948,11 +948,7 @@ impl>> And<(T,)> for Struct { impl And for Struct { type Product = Struct; fn and(self, other: Struct) -> Self::Product { - let mut result = self; - for field in other.fields() { - result = result.and(field.clone()) - } - result + self.super_intersection(&other).unwrap() } } @@ -962,7 +958,7 @@ impl And for Struct { // Simplify in the case of struct and Unit match other { //DataType::Unit(_u) => self, // TODO remove that ? - DataType::Struct(s) => self.and(s), + DataType::Struct(s) => self.super_intersection(&s).unwrap(),//self.and(s), other => self.and((other,)), } } @@ -3338,17 +3334,15 @@ mod tests { #[test] fn test_struct_and() { let a = Struct::default() - .and(("a", DataType::integer_interval(-10, 10))) - .and(("a", DataType::float_interval(1., 3.))); - println!("{a}"); + .and(("a", DataType::integer_interval(-10, 10))) + .and(("a", DataType::float_interval(1., 3.))); + println!("a = {a}"); assert_eq!( a, - Struct::from_field( - "a", - DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.) - ) + Struct::default().and(("a", DataType::float_values([1., 2., 3.]))) ); + let a = Struct::default() .and(DataType::float()) .and(("a", DataType::integer_interval(-10, 10))) @@ -3368,16 +3362,16 @@ mod tests { let c = a.clone().and(b.clone()); let true_c = Struct::default() .and(("0", DataType::float())) - .and(( - "a", - DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.), - )) .and(("1", DataType::float())) .and(("2", DataType::float())) .and(("3", DataType::float())) + .and(( + "a", + DataType::float_values([1., 2., 3.]), + )) .and(("b", DataType::integer())) .and(("c", DataType::float())) - .and(("d", DataType::float() & DataType::float())); + .and(("d", DataType::float())); println!("\na and b = {c}"); println!("\na and b = {true_c}"); assert_eq!(c, true_c); @@ -3404,16 +3398,16 @@ mod tests { e, Struct::default() .and(("0", DataType::float())) - .and(( - "a", - DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.) - )) .and(("1", DataType::float())) .and(("2", DataType::float())) .and(("3", DataType::float())) + .and(( + "a", + DataType::float_values([1., 2., 3.]) + )) .and(("b", DataType::integer())) .and(("c", DataType::float())) - .and(("d", DataType::float() & DataType::float())) + .and(("d", DataType::float())) ); //struct(table1: a) and b @@ -3435,7 +3429,7 @@ mod tests { ), ("b", DataType::integer()), ("c", DataType::float()), - ("d", DataType::float() & DataType::float()), + ("d", DataType::float()), ("a", DataType::float_interval(1., 3.)) ]) ); @@ -3451,16 +3445,16 @@ mod tests { "table1", DataType::structured([ ("0", DataType::float()), - ( - "a", - DataType::integer_interval(-10, 10) & DataType::float_interval(1., 3.) - ), ("1", DataType::float()), ("2", DataType::float()), ("3", DataType::float()), + ( + "a", + DataType::float_values([1., 2., 3.]) + ), ("b", DataType::integer()), ("c", DataType::float()), - ("d", DataType::float() & DataType::float()), + ("d", DataType::float()), ]) )]) ); @@ -3487,7 +3481,7 @@ mod tests { DataType::structured([ ("b", DataType::integer()), ("c", DataType::float()), - ("d", DataType::float() & DataType::float()), + ("d", DataType::float()), ("a", DataType::float_interval(1., 3.)) ]) ) From 8bacdfd7bec2639ca13db9ea1a4aa93b7b11bb1b Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Sep 2023 17:05:38 +0200 Subject: [PATCH 15/36] fmt --- src/data_type/mod.rs | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index c666f9c1..00b89b7d 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -924,7 +924,12 @@ impl, T: Into>> And<(S, T)> for Struct { push_other = false; ( f.clone(), - Rc::new(t.as_ref().clone().super_intersection(data_type.as_ref()).unwrap()), + Rc::new( + t.as_ref() + .clone() + .super_intersection(data_type.as_ref()) + .unwrap(), + ), ) } }) @@ -958,7 +963,7 @@ impl And for Struct { // Simplify in the case of struct and Unit match other { //DataType::Unit(_u) => self, // TODO remove that ? - DataType::Struct(s) => self.super_intersection(&s).unwrap(),//self.and(s), + DataType::Struct(s) => self.super_intersection(&s).unwrap(), //self.and(s), other => self.and((other,)), } } @@ -3334,15 +3339,14 @@ mod tests { #[test] fn test_struct_and() { let a = Struct::default() - .and(("a", DataType::integer_interval(-10, 10))) - .and(("a", DataType::float_interval(1., 3.))); + .and(("a", DataType::integer_interval(-10, 10))) + .and(("a", DataType::float_interval(1., 3.))); println!("a = {a}"); assert_eq!( a, Struct::default().and(("a", DataType::float_values([1., 2., 3.]))) ); - let a = Struct::default() .and(DataType::float()) .and(("a", DataType::integer_interval(-10, 10))) @@ -3365,10 +3369,7 @@ mod tests { .and(("1", DataType::float())) .and(("2", DataType::float())) .and(("3", DataType::float())) - .and(( - "a", - DataType::float_values([1., 2., 3.]), - )) + .and(("a", DataType::float_values([1., 2., 3.]))) .and(("b", DataType::integer())) .and(("c", DataType::float())) .and(("d", DataType::float())); @@ -3401,10 +3402,7 @@ mod tests { .and(("1", DataType::float())) .and(("2", DataType::float())) .and(("3", DataType::float())) - .and(( - "a", - DataType::float_values([1., 2., 3.]) - )) + .and(("a", DataType::float_values([1., 2., 3.]))) .and(("b", DataType::integer())) .and(("c", DataType::float())) .and(("d", DataType::float())) @@ -3448,10 +3446,7 @@ mod tests { ("1", DataType::float()), ("2", DataType::float()), ("3", DataType::float()), - ( - "a", - DataType::float_values([1., 2., 3.]) - ), + ("a", DataType::float_values([1., 2., 3.])), ("b", DataType::integer()), ("c", DataType::float()), ("d", DataType::float()), From bc6ba0467fa1a2e6cfd252ce720f925416997ceb Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Sep 2023 17:54:30 +0200 Subject: [PATCH 16/36] fixed test --- src/data_type/mod.rs | 29 +++++++++++++++-------------- src/expr/mod.rs | 7 +++---- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index 00b89b7d..1ade9753 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -912,17 +912,20 @@ impl, T: Into>> And<(S, T)> for Struct { fn and(self, other: (S, T)) -> Self::Product { let field: String = other.0.into(); let data_type: Rc = other.1.into(); - let mut push_other = true; // Remove existing elements with the same name - let mut fields: Vec<(String, Rc)> = self + let (mut fields, push_other): (Vec<_>, _) = self .fields .iter() - .map(|(f, t)| { + .fold( + (vec![], true), + |(v, b), (f, t)| { + let mut v = v; + let mut b = b; if &field != f { - (f.clone(), t.clone()) + v.push((f.clone(), t.clone())); } else { - push_other = false; - ( + b = false; + v.push(( f.clone(), Rc::new( t.as_ref() @@ -930,10 +933,10 @@ impl, T: Into>> And<(S, T)> for Struct { .super_intersection(data_type.as_ref()) .unwrap(), ), - ) + )); } - }) - .collect(); + (v, b) + }); if push_other { fields.push((field, data_type)) } @@ -962,7 +965,6 @@ impl And for Struct { fn and(self, other: DataType) -> Self::Product { // Simplify in the case of struct and Unit match other { - //DataType::Unit(_u) => self, // TODO remove that ? DataType::Struct(s) => self.super_intersection(&s).unwrap(), //self.and(s), other => self.and((other,)), } @@ -2855,7 +2857,7 @@ impl And for DataType { // Simplify in the case of struct and Unit match self { DataType::Null => DataType::Null, - //DataType::Unit(_u) => other, // TODO: reactivate ? + DataType::Unit(_) => other, DataType::Struct(s) => s.and(other).into(), s => Struct::from_data_type(s).and(other).into(), } @@ -2922,8 +2924,7 @@ impl DataType { pub fn product>(data_types: I) -> DataType { data_types .into_iter() - .reduce(|s, t| s.and(t)) - .unwrap_or(DataType::unit()) + .fold(DataType::unit(), |s, t| s.and(t)) } } @@ -3498,7 +3499,7 @@ mod tests { & ("c", DataType::boolean()) & ("d", DataType::float()); println!("b = {b}"); - assert_eq!(Struct::try_from(b).unwrap().fields.len(), 7); + assert_eq!(Struct::try_from(b).unwrap().fields.len(), 6); } #[test] diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 1848ebd6..f336817d 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -1260,10 +1260,9 @@ mod tests { #[test] fn test_bin_op() { - let dict = DataType::structured([ - ("0", DataType::float_interval(-5., 2.)), - ("1", DataType::float_interval(-1., 2.)), - ]); + let dict = DataType::unit() + & ("0", DataType::float_interval(-5., 2.)) + & ("1", DataType::float_interval(-1., 2.)); let left = Expr::col("0"); let right = Expr::col("1"); // Sum From 27e3f1ec3b61e2807648e7d9260abefde7b50d0d Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Thu, 7 Sep 2023 17:54:42 +0200 Subject: [PATCH 17/36] fmt --- src/data_type/mod.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index 1ade9753..9caf8722 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -913,12 +913,8 @@ impl, T: Into>> And<(S, T)> for Struct { let field: String = other.0.into(); let data_type: Rc = other.1.into(); // Remove existing elements with the same name - let (mut fields, push_other): (Vec<_>, _) = self - .fields - .iter() - .fold( - (vec![], true), - |(v, b), (f, t)| { + let (mut fields, push_other): (Vec<_>, _) = + self.fields.iter().fold((vec![], true), |(v, b), (f, t)| { let mut v = v; let mut b = b; if &field != f { From d34e13401e0c89db9a5933ab961eb62e36f97518 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Thu, 7 Sep 2023 18:23:57 +0200 Subject: [PATCH 18/36] fix extend --- src/data_type/function.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/data_type/function.rs b/src/data_type/function.rs index dc1562d6..eb11aea5 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -681,10 +681,12 @@ impl Function for Extended { } /// A function is extensible if it can be extended to a different domain -pub trait Extensible { +pub trait Extensible: Function { type Extended; - fn extend(self, domain: DataType) -> Self::Extended; + fn extend(self, domain: DataType) -> Self::Extended { + + } } // Implement extensible for all borrowed function From f63cc5e7045e436edc0b7324906f4abb585a5067 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Fri, 8 Sep 2023 14:00:55 +0200 Subject: [PATCH 19/36] ok --- src/data_type/function.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/data_type/function.rs b/src/data_type/function.rs index eb11aea5..dc1562d6 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -681,12 +681,10 @@ impl Function for Extended { } /// A function is extensible if it can be extended to a different domain -pub trait Extensible: Function { +pub trait Extensible { type Extended; - fn extend(self, domain: DataType) -> Self::Extended { - - } + fn extend(self, domain: DataType) -> Self::Extended; } // Implement extensible for all borrowed function From 47df6cf67e045e615d3f397b39ee48dc6094d30f Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Fri, 8 Sep 2023 15:36:20 +0200 Subject: [PATCH 20/36] ok --- src/data_type/function.rs | 232 +++++++++++++++++++------------------ src/data_type/mod.rs | 40 +++++++ src/expr/implementation.rs | 10 +- 3 files changed, 163 insertions(+), 119 deletions(-) diff --git a/src/data_type/function.rs b/src/data_type/function.rs index b996b8bd..07334fe9 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -620,6 +620,47 @@ where } } +/// # Extended function +/// Functions can be wrapped with this `Extended` object. +/// The co_domain is usually Option unless the domain is included in the original domain. +#[derive(Debug)] +pub struct Optional(F); + +impl Optional { + pub fn new(function: F) -> Optional { + Optional(function) + } +} + +impl fmt::Display for Optional { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "optional{{{} -> {}}}", self.domain(), self.co_domain()) + } +} + +impl Function for Optional { + fn domain(&self) -> DataType { + DataType::optional(self.0.domain()) + } + + fn super_image(&self, set: &DataType) -> Result { + match set { + DataType::Optional(optional_set) => self.0.super_image(optional_set.data_type()).map(|dt| DataType::optional(dt)), + set => self.0.super_image(set), + } + } + + fn value(&self, arg: &Value) -> Result { + match arg { + Value::Optional(optional_arg) => match optional_arg.as_deref() { + Some(arg) => self.0.value(arg).map(Value::some), + None => Ok(Value::none()), + }, + arg => self.0.value(arg), + } + } +} + /// # Extended function /// Functions can be wrapped with this `Extended` object. /// The co_domain is usually Option unless the domain is included in the original domain. @@ -641,12 +682,6 @@ impl fmt::Display for Extended { } } -impl Clone for Extended { - fn clone(&self) -> Self { - Extended::new(self.function.clone(), self.domain.clone()) - } -} - impl Function for Extended { fn domain(&self) -> DataType { self.domain.clone() @@ -680,30 +715,6 @@ impl Function for Extended { } } -/// A function is extensible if it can be extended to a different domain -pub trait Extensible { - type Extended; - - fn extend(self, domain: DataType) -> Self::Extended; -} - -// Implement extensible for all borrowed function -impl<'a, F: Function + Clone> Extensible for &'a F { - type Extended = Extended; - - fn extend(self: &'a F, domain: DataType) -> Self::Extended { - Extended::new(self.clone(), domain) - } -} - -impl Extensible for Extended { - type Extended = Extended; - - fn extend(self: Extended, domain: DataType) -> Extended { - Extended::new(self.function, domain) - } -} - /// A function defined pointwise without any other particular properties #[derive(Clone)] pub struct Aggregate @@ -817,7 +828,7 @@ where // domain: Domain, // value: ValueFunction, // } -#[derive(Clone, Debug, Default)] +#[derive(Debug, Default)] pub struct Polymorphic(Vec>); impl Polymorphic { @@ -1023,7 +1034,7 @@ We list here all the functions to expose */ // Invalid function -pub fn null() -> impl Function + Clone { +pub fn null() -> impl Function { PartitionnedMonotonic::univariate(data_type::Text::default(), |_x| "null".to_string()) } @@ -1032,7 +1043,7 @@ Conversion function */ /// Builds the cast operator -pub fn cast(into: DataType) -> impl Function + Clone { +pub fn cast(into: DataType) -> impl Function { // TODO Only cast as text is working for now match into { DataType::Text(t) if t == data_type::Text::full() => { @@ -1045,18 +1056,18 @@ pub fn cast(into: DataType) -> impl Function + Clone { // Unary operators /// Builds the minus `Function` -pub fn opposite() -> impl Function + Clone { +pub fn opposite() -> impl Function { PartitionnedMonotonic::univariate(data_type::Float::default(), |x| -x) } /// Builds the minus `Function` -pub fn not() -> impl Function + Clone { +pub fn not() -> impl Function { PartitionnedMonotonic::univariate(data_type::Boolean::default(), |x| !x) } // Arithmetic binary operators /// The sum (polymorphic) -pub fn plus() -> impl Function + Clone { +pub fn plus() -> impl Function { Polymorphic::from(( PartitionnedMonotonic::bivariate( (data_type::Integer::default(), data_type::Integer::default()), @@ -1070,7 +1081,7 @@ pub fn plus() -> impl Function + Clone { } /// The difference -pub fn minus() -> impl Function + Clone { +pub fn minus() -> impl Function { Polymorphic::from(( PartitionnedMonotonic::bivariate( (data_type::Integer::default(), data_type::Integer::default()), @@ -1084,7 +1095,7 @@ pub fn minus() -> impl Function + Clone { } /// The product (the domain is partitionned) -pub fn multiply() -> impl Function + Clone { +pub fn multiply() -> impl Function { Polymorphic::from(( // Integer implementation PartitionnedMonotonic::piecewise_bivariate( @@ -1134,7 +1145,7 @@ pub fn multiply() -> impl Function + Clone { } /// The division (the domain is partitionned) -pub fn divide() -> impl Function + Clone { +pub fn divide() -> impl Function { Polymorphic::from(( // Integer implementation PartitionnedMonotonic::piecewise_bivariate( @@ -1184,7 +1195,7 @@ pub fn divide() -> impl Function + Clone { } /// The modulo -pub fn modulo() -> impl Function + Clone { +pub fn modulo() -> impl Function { Pointwise::bivariate( (data_type::Integer::default(), data_type::Integer::default()), data_type::Integer::default(), @@ -1192,7 +1203,7 @@ pub fn modulo() -> impl Function + Clone { ) } -pub fn string_concat() -> impl Function + Clone { +pub fn string_concat() -> impl Function { Pointwise::bivariate( (data_type::Text::default(), data_type::Text::default()), data_type::Text::default(), @@ -1200,13 +1211,13 @@ pub fn string_concat() -> impl Function + Clone { ) } -pub fn concat(n: usize) -> impl Function + Clone { +pub fn concat(n: usize) -> impl Function { Pointwise::variadic(vec![DataType::Any; n], data_type::Text::default(), |v| { v.into_iter().map(|v| v.to_string()).join("") }) } -pub fn md5() -> impl Function + Clone { +pub fn md5() -> impl Function { Stateful::new( DataType::text(), DataType::text(), @@ -1218,7 +1229,7 @@ pub fn md5() -> impl Function + Clone { ) } -pub fn random(mut rng: R) -> impl Function + Clone { +pub fn random(mut rng: R) -> impl Function { Stateful::new( DataType::unit(), DataType::float_interval(0., 1.), @@ -1226,7 +1237,7 @@ pub fn random(mut rng: R) -> impl Function + Clone { ) } -pub fn gt() -> impl Function + Clone { +pub fn gt() -> impl Function { Polymorphic::default() .with(Pointwise::bivariate( (data_type::Integer::default(), data_type::Integer::default()), @@ -1263,7 +1274,7 @@ pub fn gt() -> impl Function + Clone { )) } -pub fn lt() -> impl Function + Clone { +pub fn lt() -> impl Function { Polymorphic::default() .with(Pointwise::bivariate( (data_type::Integer::default(), data_type::Integer::default()), @@ -1300,7 +1311,7 @@ pub fn lt() -> impl Function + Clone { )) } -pub fn gt_eq() -> impl Function + Clone { +pub fn gt_eq() -> impl Function { Polymorphic::default() .with(Pointwise::bivariate( (data_type::Integer::default(), data_type::Integer::default()), @@ -1337,7 +1348,7 @@ pub fn gt_eq() -> impl Function + Clone { )) } -pub fn lt_eq() -> impl Function + Clone { +pub fn lt_eq() -> impl Function { Polymorphic::default() .with(Pointwise::bivariate( (data_type::Integer::default(), data_type::Integer::default()), @@ -1374,7 +1385,7 @@ pub fn lt_eq() -> impl Function + Clone { )) } -pub fn eq() -> impl Function + Clone { +pub fn eq() -> impl Function { Pointwise::bivariate( (DataType::Any, DataType::Any), data_type::Boolean::default(), @@ -1382,7 +1393,7 @@ pub fn eq() -> impl Function + Clone { ) } -pub fn not_eq() -> impl Function + Clone { +pub fn not_eq() -> impl Function { Pointwise::bivariate( (DataType::Any, DataType::Any), data_type::Boolean::default(), @@ -1393,21 +1404,21 @@ pub fn not_eq() -> impl Function + Clone { // Boolean binary operators /// The conjunction -pub fn and() -> impl Function + Clone { +pub fn and() -> impl Function { PartitionnedMonotonic::bivariate( (data_type::Boolean::default(), data_type::Boolean::default()), |x, y| x && y, ) } /// The disjunction -pub fn or() -> impl Function + Clone { +pub fn or() -> impl Function { PartitionnedMonotonic::bivariate( (data_type::Boolean::default(), data_type::Boolean::default()), |x, y| x || y, ) } /// The exclusive or -pub fn xor() -> impl Function + Clone { +pub fn xor() -> impl Function { PartitionnedMonotonic::bivariate( (data_type::Boolean::default(), data_type::Boolean::default()), |x, y| x ^ y, @@ -1416,7 +1427,7 @@ pub fn xor() -> impl Function + Clone { // Bitwise binary operators -pub fn bitwise_or() -> impl Function + Clone { +pub fn bitwise_or() -> impl Function { Pointwise::bivariate( (data_type::Boolean::default(), data_type::Boolean::default()), data_type::Boolean::default(), @@ -1424,7 +1435,7 @@ pub fn bitwise_or() -> impl Function + Clone { ) } -pub fn bitwise_and() -> impl Function + Clone { +pub fn bitwise_and() -> impl Function { Pointwise::bivariate( (data_type::Boolean::default(), data_type::Boolean::default()), data_type::Boolean::default(), @@ -1432,7 +1443,7 @@ pub fn bitwise_and() -> impl Function + Clone { ) } -pub fn bitwise_xor() -> impl Function + Clone { +pub fn bitwise_xor() -> impl Function { Pointwise::bivariate( (data_type::Boolean::default(), data_type::Boolean::default()), data_type::Boolean::default(), @@ -1443,21 +1454,21 @@ pub fn bitwise_xor() -> impl Function + Clone { // Real functions /// Builds the exponential `Function` -pub fn exp() -> impl Function + Clone { +pub fn exp() -> impl Function { PartitionnedMonotonic::univariate(data_type::Float::default(), |x| { x.exp().clamp(0.0, ::max()) }) } /// Builds the logarithm `Function` -pub fn ln() -> impl Function + Clone { +pub fn ln() -> impl Function { PartitionnedMonotonic::univariate(data_type::Float::from(0.0..), |x| { x.ln().clamp(::min(), ::max()) }) } /// Builds the decimal logarithm `Function` -pub fn log() -> impl Function + Clone { +pub fn log() -> impl Function { PartitionnedMonotonic::univariate(data_type::Float::from(0.0..), |x| { x.log(10.) .clamp(::min(), ::max()) @@ -1465,14 +1476,14 @@ pub fn log() -> impl Function + Clone { } /// Builds the sqrt `Function` -pub fn sqrt() -> impl Function + Clone { +pub fn sqrt() -> impl Function { PartitionnedMonotonic::univariate(data_type::Float::from(0.0..), |x| { x.sqrt().clamp(::min(), ::max()) }) } /// The pow function -pub fn pow() -> impl Function + Clone { +pub fn pow() -> impl Function { PartitionnedMonotonic::piecewise_bivariate( [ ( @@ -1492,7 +1503,7 @@ pub fn pow() -> impl Function + Clone { } /// Builds the abs `Function`, a piecewise monotonic function -pub fn abs() -> impl Function + Clone { +pub fn abs() -> impl Function { PartitionnedMonotonic::piecewise_univariate( [ data_type::Float::from(..=0.0), @@ -1503,7 +1514,7 @@ pub fn abs() -> impl Function + Clone { } /// sine -pub fn sin() -> impl Function + Clone { +pub fn sin() -> impl Function { PartitionnedMonotonic::periodic_univariate( [ data_type::Float::from(-0.5 * std::f64::consts::PI..=0.5 * std::f64::consts::PI), @@ -1514,7 +1525,7 @@ pub fn sin() -> impl Function + Clone { } /// cosine -pub fn cos() -> impl Function + Clone { +pub fn cos() -> impl Function { PartitionnedMonotonic::periodic_univariate( [ data_type::Float::from(0.0..=std::f64::consts::PI), @@ -1524,7 +1535,7 @@ pub fn cos() -> impl Function + Clone { ) } -pub fn least() -> impl Function + Clone { +pub fn least() -> impl Function { Polymorphic::from(( PartitionnedMonotonic::bivariate( (data_type::Integer::default(), data_type::Integer::default()), @@ -1537,7 +1548,7 @@ pub fn least() -> impl Function + Clone { )) } -pub fn greatest() -> impl Function + Clone { +pub fn greatest() -> impl Function { Polymorphic::from(( PartitionnedMonotonic::bivariate( (data_type::Integer::default(), data_type::Integer::default()), @@ -1553,17 +1564,17 @@ pub fn greatest() -> impl Function + Clone { // String functions /// Builds the lower `Function` -pub fn lower() -> impl Function + Clone { +pub fn lower() -> impl Function { PartitionnedMonotonic::univariate(data_type::Text::default(), |x| x.to_lowercase()) } /// Builds the upper `Function` -pub fn upper() -> impl Function + Clone { +pub fn upper() -> impl Function { PartitionnedMonotonic::univariate(data_type::Text::default(), |x| x.to_uppercase()) } /// Builds the char_length `Function` -pub fn char_length() -> impl Function + Clone { +pub fn char_length() -> impl Function { Pointwise::univariate( data_type::Text::default(), data_type::Integer::default(), @@ -1572,7 +1583,7 @@ pub fn char_length() -> impl Function + Clone { } /// Builds the position `Function` -pub fn position() -> impl Function + Clone { +pub fn position() -> impl Function { Pointwise::bivariate( (data_type::Text::default(), data_type::Text::default()), DataType::optional(DataType::integer()), @@ -1586,12 +1597,12 @@ pub fn position() -> impl Function + Clone { } // Case function -pub fn case() -> impl Function + Clone { +pub fn case() -> impl Function { Case } // In operator -pub fn in_list() -> impl Function + Clone { +pub fn in_list() -> impl Function { Polymorphic::from(( InList(data_type::Integer::default().into()), InList(data_type::Float::default().into()), @@ -1604,16 +1615,16 @@ Aggregation functions */ /// Median aggregation -pub fn median() -> impl Function + Clone { +pub fn median() -> impl Function { null() } -pub fn n_unique() -> impl Function + Clone { +pub fn n_unique() -> impl Function { null() } /// First element in group -pub fn first() -> impl Function + Clone { +pub fn first() -> impl Function { Aggregate::from( DataType::Any, |values| values.first().unwrap().clone(), @@ -1625,7 +1636,7 @@ pub fn first() -> impl Function + Clone { } /// Last element in group -pub fn last() -> impl Function + Clone { +pub fn last() -> impl Function { Aggregate::from( DataType::Any, |values| values.last().unwrap().clone(), @@ -1637,7 +1648,7 @@ pub fn last() -> impl Function + Clone { } /// Mean aggregation -pub fn mean() -> impl Function + Clone { +pub fn mean() -> impl Function { // Only works on types that can be converted to floats Aggregate::from( data_type::Float::full(), @@ -1652,12 +1663,12 @@ pub fn mean() -> impl Function + Clone { } /// Aggregate as a list -pub fn list() -> impl Function + Clone { +pub fn list() -> impl Function { null() } /// Count aggregation -pub fn count() -> impl Function + Clone { +pub fn count() -> impl Function { Polymorphic::from(( // Any implementation Aggregate::from( @@ -1681,7 +1692,7 @@ pub fn count() -> impl Function + Clone { } /// Min aggregation -pub fn min() -> impl Function + Clone { +pub fn min() -> impl Function { Polymorphic::from(( // Integer implementation Aggregate::from( @@ -1713,7 +1724,7 @@ pub fn min() -> impl Function + Clone { } /// Max aggregation -pub fn max() -> impl Function + Clone { +pub fn max() -> impl Function { Polymorphic::from(( // Integer implementation Aggregate::from( @@ -1745,17 +1756,17 @@ pub fn max() -> impl Function + Clone { } /// Quantile aggregation -pub fn quantile(_p: f64) -> impl Function + Clone { +pub fn quantile(_p: f64) -> impl Function { null() } /// Multi-quantileq aggregation -pub fn quantiles(_p: Vec) -> impl Function + Clone { +pub fn quantiles(_p: Vec) -> impl Function { null() } /// Sum aggregation -pub fn sum() -> impl Function + Clone { +pub fn sum() -> impl Function { Polymorphic::from(( // Integer implementation Aggregate::from( @@ -1781,12 +1792,12 @@ pub fn sum() -> impl Function + Clone { } /// Agg groups aggregation -pub fn agg_groups() -> impl Function + Clone { +pub fn agg_groups() -> impl Function { null() } /// Standard deviation aggregation -pub fn std() -> impl Function + Clone { +pub fn std() -> impl Function { // Only works on types that can be converted to floats Aggregate::from( data_type::Float::full(), @@ -1812,7 +1823,7 @@ pub fn std() -> impl Function + Clone { } /// Variance aggregation -pub fn var() -> impl Function + Clone { +pub fn var() -> impl Function { // Only works on types that can be converted to floats Aggregate::from( data_type::Float::full(), @@ -2156,26 +2167,29 @@ mod tests { ); } + #[test] + fn test_optional() { + println!("Test optional"); + let optional_greatest = Optional::new(greatest()); + println!("greatest = {}", greatest()); + println!("optional greatest = {}", optional_greatest); + println!("super_image([0,1] & [-5,2]) = {}", optional_greatest.super_image(&(DataType::float_interval(0.,1.) & DataType::float_interval(-5.,2.))).unwrap()); + println!("super_image(optional([0,1] & [-5,2])) = {}", optional_greatest.super_image(&DataType::optional((DataType::float_interval(0.,1.) & DataType::float_interval(-5.,2.)))).unwrap()); + println!("super_image(optional([0,1]) & [-5,2]) = {}", optional_greatest.super_image(&(DataType::optional(DataType::float_interval(0.,1.)) & DataType::float_interval(-5.,2.))).unwrap()); + // assert_eq!( + // optional_greatest.co_domain(), + // DataType::optional(DataType::float_range(-1.0..=1.0)) + // ); + } + #[test] fn test_extended() { println!("Test extended"); - let extended_cos = cos().extend(DataType::Any); + let extended_cos = Extended::new(cos(), DataType::Any); println!("cos = {}", cos()); println!("extended cos = {}", extended_cos); - println!( - "extended extended cos = {}", - extended_cos.clone().extend(DataType::integer()) - ); - println!( - "extended extended cos = {}", - extended_cos.clone().extend(DataType::Any) - ); assert_eq!( - extended_cos.clone().extend(DataType::integer()).co_domain(), - DataType::float_range(-1.0..=1.0) - ); - assert_eq!( - extended_cos.extend(DataType::Any).co_domain(), + extended_cos.co_domain(), DataType::optional(DataType::float_range(-1.0..=1.0)) ); } @@ -2184,26 +2198,16 @@ mod tests { fn test_extended_binary() { println!("Test extended"); // Test a bivariate monotonic function - let extended_add = plus().extend(DataType::Any & DataType::Any); + let extended_add = Extended::new(plus(), DataType::Any & DataType::Any); println!("add = {}", plus()); println!("extended add = {}", extended_add); - println!( - "extended extended add = {}", - extended_add - .clone() - .extend(DataType::integer() & DataType::integer()) - ); - println!( - "extended extended add = {}", - extended_add.extend(DataType::Any & DataType::Any) - ); } #[test] fn test_extended_plus() { println!("Test extended"); // Test a bivariate monotonic function - let extended_plus = plus().extend(DataType::Any & DataType::Any); + let extended_plus = Extended::new(plus(), DataType::Any & DataType::Any); println!("plus = {}", plus()); println!("extended plus = {}", extended_plus); println!( @@ -2286,7 +2290,7 @@ mod tests { fn test_extended_aggregate_sum() { println!("Test extended"); // Test a bivariate monotonic function - let extended_sum = sum().extend(DataType::Any); + let extended_sum = Extended::new(sum(), DataType::Any); println!("sum = {}", sum()); println!("sum domain = {}", sum().domain()); println!("extended sum = {}", extended_sum); diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index 6468ae29..083e8255 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -2955,6 +2955,46 @@ impl<'a> Acceptor<'a> for DataType { } } +// Visitors + +/// A Visitor for the type Expr +pub trait Visitor<'a, T: Clone> { + // Composed types + fn structured(&self, fields: Vec<(String, T)>) -> T; + fn union(&self, fields: Vec<(String, T)>) -> T; + fn optional(&self, data_type: T) -> T; + fn list(&self, data_type: T, size: &'a Integer) -> T; + fn set(&self, data_type: T, size: &'a Integer) -> T; + fn array(&self, data_type: T, shape: &'a [usize]) -> T; + fn function(&self, domain: T, co_domain: T) -> T; + fn primitive(&self, acceptor: &'a DataType) -> T; +} + +/// Implement a specific visitor to dispatch the dependencies more easily +impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, DataType, T> for V { + fn visit(&self, acceptor: &'a DataType, dependencies: visitor::Visited<'a, DataType, T>) -> T { + match acceptor { + DataType::Struct(s) => self.structured(s.fields.iter().map(|(s, t)| (s.clone(), dependencies.get(t.as_ref()).clone())).collect()), + DataType::Union(u) => self.union(s.fields.iter().map(|(s, t)| (s.clone(), dependencies.get(t.as_ref()).clone())).collect()), + DataType::Optional(o) => self.optional(dependencies.get(o.data_type()).clone()), + DataType::List(l) => self.list(dependencies.get(l.data_type()).clone(), l.size()), + DataType::Set(s) => self.set(dependencies.get(s.data_type()).clone(), l.size()), + DataType::Array(a) => self.array(dependencies.get(a.data_type()).clone(), a.shape()), + DataType::Function(f) => self.function(dependencies.get(f.domain()).clone(), dependencies.get(f.co_domain()).clone()), + primitive => self.primitive(primitive), + } + } +} + +/// Implement a LiftOptionalVisitor +struct LiftOptionalVisitor; + +impl<'a> visitor::Visitor<'a, DataType, DataType> for LiftOptionalVisitor { + fn visit(&self, acceptor: &'a DataType, dependencies: visitor::Visited<'a, DataType, DataType>) -> DataType { + todo!() + } +} + // TODO Write tests for all types #[cfg(test)] mod tests { diff --git a/src/expr/implementation.rs b/src/expr/implementation.rs index 08e4ef76..95198306 100644 --- a/src/expr/implementation.rs +++ b/src/expr/implementation.rs @@ -1,6 +1,6 @@ use super::{aggregate::Aggregate, function::Function}; use crate::data_type::{ - function::{self, Extensible}, + function::{self, Extended}, DataType, }; use paste::paste; @@ -13,9 +13,9 @@ macro_rules! function_implementations { // A (thread local) global map thread_local! { static FUNCTION_IMPLEMENTATIONS: FunctionImplementations = FunctionImplementations { - $([< $unary:snake >]: Rc::new(function::[< $unary:snake >]().extend(DataType::Any)),)* - $([< $binary:snake >]: Rc::new(function::[< $binary:snake >]().extend(DataType::Any & DataType::Any)),)* - $([< $ternary:snake >]: Rc::new(function::[< $ternary:snake >]().extend(DataType::Any & DataType::Any & DataType::Any)),)* + $([< $unary:snake >]: Rc::new(Extended::new(function::[< $unary:snake >](), DataType::Any)),)* + $([< $binary:snake >]: Rc::new(Extended::new(function::[< $binary:snake >](), DataType::Any & DataType::Any)),)* + $([< $ternary:snake >]: Rc::new(Extended::new(function::[< $ternary:snake >](), DataType::Any & DataType::Any & DataType::Any)),)* }; } @@ -94,7 +94,7 @@ macro_rules! aggregate_implementations { // A (thread local) global map thread_local! { static AGGREGATE_IMPLEMENTATIONS: AggregateImplementations = AggregateImplementations { - $([< $implementation:snake >]: Rc::new(function::[< $implementation:snake >]().extend(DataType::Any)),)* + $([< $implementation:snake >]: Rc::new(Extended::new(function::[< $implementation:snake >](), DataType::Any)),)* }; } From 1d7a077ba823556478608cadfd087b65d842b2ba Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Fri, 8 Sep 2023 17:49:26 +0200 Subject: [PATCH 21/36] Optionnal --- src/data_type/function.rs | 13 +++---- src/data_type/mod.rs | 74 ++++++++++++++++++++++++++++++++++----- 2 files changed, 73 insertions(+), 14 deletions(-) diff --git a/src/data_type/function.rs b/src/data_type/function.rs index 07334fe9..5f3cc1e9 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -640,13 +640,14 @@ impl fmt::Display for Optional { impl Function for Optional { fn domain(&self) -> DataType { - DataType::optional(self.0.domain()) + DataType::optional(self.0.domain()).flatten_optional() } fn super_image(&self, set: &DataType) -> Result { + let set = set.flatten_optional(); match set { DataType::Optional(optional_set) => self.0.super_image(optional_set.data_type()).map(|dt| DataType::optional(dt)), - set => self.0.super_image(set), + set => self.0.super_image(&set), } } @@ -2176,10 +2177,10 @@ mod tests { println!("super_image([0,1] & [-5,2]) = {}", optional_greatest.super_image(&(DataType::float_interval(0.,1.) & DataType::float_interval(-5.,2.))).unwrap()); println!("super_image(optional([0,1] & [-5,2])) = {}", optional_greatest.super_image(&DataType::optional((DataType::float_interval(0.,1.) & DataType::float_interval(-5.,2.)))).unwrap()); println!("super_image(optional([0,1]) & [-5,2]) = {}", optional_greatest.super_image(&(DataType::optional(DataType::float_interval(0.,1.)) & DataType::float_interval(-5.,2.))).unwrap()); - // assert_eq!( - // optional_greatest.co_domain(), - // DataType::optional(DataType::float_range(-1.0..=1.0)) - // ); + assert_eq!( + optional_greatest.super_image(&DataType::optional((DataType::float_interval(0.,1.) & DataType::float_interval(-5.,2.)))).unwrap(), + optional_greatest.super_image(&(DataType::optional(DataType::float_interval(0.,1.)) & DataType::float_interval(-5.,2.))).unwrap(), + ); } #[test] diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index 083e8255..fe9e281f 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -2736,8 +2736,8 @@ impl DataType { ))) } - pub fn array>(data_type: DataType, shape: &[usize]) -> DataType { - DataType::from(Array::from((data_type, shape))) + pub fn array>(data_type: DataType, shape: S) -> DataType { + DataType::from(Array::from((data_type, shape.as_ref()))) } pub fn function(domain: DataType, co_domain: DataType) -> DataType { @@ -2975,10 +2975,10 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, DataType, T> for V { fn visit(&self, acceptor: &'a DataType, dependencies: visitor::Visited<'a, DataType, T>) -> T { match acceptor { DataType::Struct(s) => self.structured(s.fields.iter().map(|(s, t)| (s.clone(), dependencies.get(t.as_ref()).clone())).collect()), - DataType::Union(u) => self.union(s.fields.iter().map(|(s, t)| (s.clone(), dependencies.get(t.as_ref()).clone())).collect()), + DataType::Union(u) => self.union(u.fields.iter().map(|(s, t)| (s.clone(), dependencies.get(t.as_ref()).clone())).collect()), DataType::Optional(o) => self.optional(dependencies.get(o.data_type()).clone()), DataType::List(l) => self.list(dependencies.get(l.data_type()).clone(), l.size()), - DataType::Set(s) => self.set(dependencies.get(s.data_type()).clone(), l.size()), + DataType::Set(s) => self.set(dependencies.get(s.data_type()).clone(), s.size()), DataType::Array(a) => self.array(dependencies.get(a.data_type()).clone(), a.shape()), DataType::Function(f) => self.function(dependencies.get(f.domain()).clone(), dependencies.get(f.co_domain()).clone()), primitive => self.primitive(primitive), @@ -2987,11 +2987,57 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, DataType, T> for V { } /// Implement a LiftOptionalVisitor -struct LiftOptionalVisitor; +struct FlattenOptionalVisitor; -impl<'a> visitor::Visitor<'a, DataType, DataType> for LiftOptionalVisitor { - fn visit(&self, acceptor: &'a DataType, dependencies: visitor::Visited<'a, DataType, DataType>) -> DataType { - todo!() +impl<'a> Visitor<'a, (bool, DataType)> for FlattenOptionalVisitor { + fn structured(&self, fields: Vec<(String, (bool, DataType))>) -> (bool, DataType) { + fields.into_iter().fold( + (false, DataType::unit()), + |a, (s, (o, d))| (a.0 || o, a.1 & (s, d)) + ) + } + + fn union(&self, fields: Vec<(String, (bool, DataType))>) -> (bool, DataType) { + fields.into_iter().fold( + (false, DataType::Null), + |a, (s, (o, d))| (a.0 || o, a.1 | (s, d)) + ) + } + + fn optional(&self, data_type: (bool, DataType)) -> (bool, DataType) { + (true, data_type.1) + } + + fn list(&self, data_type: (bool, DataType), size: &'a Integer) -> (bool, DataType) { + (data_type.0, List::new(Rc::new(data_type.1), size.clone()).into()) + } + + fn set(&self, data_type: (bool, DataType), size: &'a Integer) -> (bool, DataType) { + (data_type.0, Set::new(Rc::new(data_type.1), size.clone()).into()) + } + + fn array(&self, data_type: (bool, DataType), shape: &'a [usize]) -> (bool, DataType) { + (data_type.0, DataType::array(data_type.1, shape)) + } + + fn function(&self, domain: (bool, DataType), co_domain: (bool, DataType)) -> (bool, DataType) { + (domain.0 || co_domain.0, DataType::function(domain.1, co_domain.1)) + } + + fn primitive(&self, acceptor: &'a DataType) -> (bool, DataType) { + (false, acceptor.clone()) + } +} + +impl DataType { + /// Return a type with non-optional subtypes, it may be optional if one of the + pub fn flatten_optional(&self) -> DataType { + let (is_optional, flat) = self.accept(FlattenOptionalVisitor); + if is_optional { + DataType::optional(flat) + } else { + flat + } } } @@ -4072,4 +4118,16 @@ mod tests { println!("{}", h); assert_eq!(h, correct_hierarchy); } + + #[test] + fn test_flatten_optional() { + let a = DataType::unit() & DataType::float() & DataType::optional(DataType::integer_interval(0, 10)); + println!("a = {a}"); + println!("flat opt a = {}", a.flatten_optional()); + assert_eq!(a.flatten_optional(), DataType::optional(DataType::unit() & DataType::float() & DataType::integer_interval(0, 10))); + let b = DataType::unit() & DataType::float() & DataType::integer_interval(0, 10); + println!("b = {b}"); + println!("flat opt b = {}", b.flatten_optional()); + assert_eq!(b.flatten_optional(), DataType::unit() & DataType::float() & DataType::integer_interval(0, 10)); + } } From cd6ec967652dd99e20b644a880f2463fd06e9e55 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Fri, 8 Sep 2023 17:55:05 +0200 Subject: [PATCH 22/36] impl --- src/expr/implementation.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/expr/implementation.rs b/src/expr/implementation.rs index 95198306..fbc24ba6 100644 --- a/src/expr/implementation.rs +++ b/src/expr/implementation.rs @@ -1,6 +1,6 @@ use super::{aggregate::Aggregate, function::Function}; use crate::data_type::{ - function::{self, Extended}, + function::{self, Optional, Extended}, DataType, }; use paste::paste; @@ -13,9 +13,9 @@ macro_rules! function_implementations { // A (thread local) global map thread_local! { static FUNCTION_IMPLEMENTATIONS: FunctionImplementations = FunctionImplementations { - $([< $unary:snake >]: Rc::new(Extended::new(function::[< $unary:snake >](), DataType::Any)),)* - $([< $binary:snake >]: Rc::new(Extended::new(function::[< $binary:snake >](), DataType::Any & DataType::Any)),)* - $([< $ternary:snake >]: Rc::new(Extended::new(function::[< $ternary:snake >](), DataType::Any & DataType::Any & DataType::Any)),)* + $([< $unary:snake >]: Rc::new(Extended::new(Optional::new(function::[< $unary:snake >]()), DataType::Any)),)* + $([< $binary:snake >]: Rc::new(Extended::new(Optional::new(function::[< $binary:snake >]()), DataType::Any & DataType::Any)),)* + $([< $ternary:snake >]: Rc::new(Extended::new(Optional::new(function::[< $ternary:snake >]()), DataType::Any & DataType::Any & DataType::Any)),)* }; } From 35ad8cede34bc02dc9105dc44f05859930a72528 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Fri, 8 Sep 2023 19:36:43 +0200 Subject: [PATCH 23/36] L2 clipping working --- src/data_type/function.rs | 37 ++++++- src/expr/implementation.rs | 8 +- src/relation/transforms.rs | 205 +++++++------------------------------ 3 files changed, 76 insertions(+), 174 deletions(-) diff --git a/src/data_type/function.rs b/src/data_type/function.rs index 5f3cc1e9..f3eaf508 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -640,7 +640,11 @@ impl fmt::Display for Optional { impl Function for Optional { fn domain(&self) -> DataType { - DataType::optional(self.0.domain()).flatten_optional() + DataType::Any + } + + fn co_domain(&self) -> DataType { + DataType::optional(self.0.co_domain()).flatten_optional() } fn super_image(&self, set: &DataType) -> Result { @@ -648,7 +652,7 @@ impl Function for Optional { match set { DataType::Optional(optional_set) => self.0.super_image(optional_set.data_type()).map(|dt| DataType::optional(dt)), set => self.0.super_image(&set), - } + }.or_else(|err| Ok(self.co_domain())) } fn value(&self, arg: &Value) -> Result { @@ -658,7 +662,7 @@ impl Function for Optional { None => Ok(Value::none()), }, arg => self.0.value(arg), - } + }.or_else(|err| Ok(Value::none())) } } @@ -2181,6 +2185,7 @@ mod tests { optional_greatest.super_image(&DataType::optional((DataType::float_interval(0.,1.) & DataType::float_interval(-5.,2.)))).unwrap(), optional_greatest.super_image(&(DataType::optional(DataType::float_interval(0.,1.)) & DataType::float_interval(-5.,2.))).unwrap(), ); + println!("super_image(text) = {}", optional_greatest.super_image(&DataType::text()).unwrap()); } #[test] @@ -2195,6 +2200,32 @@ mod tests { ); } + #[test] + fn test_optional_aggregate_sum() { + println!("Test sum aggregate"); + // Test an aggregate function + let sum = sum(); + println!("sum = {}", sum); + let list = DataType::list(DataType::float_interval(-1., 2.), 2, 20); + println!("sum({}) = {}", list, sum.super_image(&list).unwrap()); + assert_eq!( + sum.super_image(&list).unwrap(), + DataType::float_interval(-20., 40.) + ); + let opt_sum = Optional::new(sum); + println!("opt_sum = {}", opt_sum); + let list = DataType::optional(DataType::list(DataType::float_interval(-1., 2.), 2, 20)); + println!("\n{} is_subset_of {} = {}", list, opt_sum.domain(), list.is_subset_of(&opt_sum.domain())); + println!("\nopt_sum({}) = {}", list, opt_sum.super_image(&list).unwrap()); + let list = DataType::list(DataType::optional(DataType::float_interval(-1., 2.)), 2, 20); + println!("\n{} is_subset_of {} = {}", list, opt_sum.domain(), list.is_subset_of(&opt_sum.domain())); + println!("\nopt_sum({}) = {}", list, opt_sum.super_image(&list).unwrap()); + let list = DataType::list(DataType::float_interval(-1., 2.), 2, 20); + println!("\n{} is_subset_of {} = {}", list, opt_sum.domain(), list.is_subset_of(&opt_sum.domain())); + println!("\nopt_sum({}) = {}", list, opt_sum.super_image(&list).unwrap()); + + } + #[test] fn test_extended_binary() { println!("Test extended"); diff --git a/src/expr/implementation.rs b/src/expr/implementation.rs index fbc24ba6..bbd3e65c 100644 --- a/src/expr/implementation.rs +++ b/src/expr/implementation.rs @@ -13,9 +13,9 @@ macro_rules! function_implementations { // A (thread local) global map thread_local! { static FUNCTION_IMPLEMENTATIONS: FunctionImplementations = FunctionImplementations { - $([< $unary:snake >]: Rc::new(Extended::new(Optional::new(function::[< $unary:snake >]()), DataType::Any)),)* - $([< $binary:snake >]: Rc::new(Extended::new(Optional::new(function::[< $binary:snake >]()), DataType::Any & DataType::Any)),)* - $([< $ternary:snake >]: Rc::new(Extended::new(Optional::new(function::[< $ternary:snake >]()), DataType::Any & DataType::Any & DataType::Any)),)* + $([< $unary:snake >]: Rc::new(Optional::new(function::[< $unary:snake >]())),)* + $([< $binary:snake >]: Rc::new(Optional::new(function::[< $binary:snake >]())),)* + $([< $ternary:snake >]: Rc::new(Optional::new(function::[< $ternary:snake >]())),)* }; } @@ -94,7 +94,7 @@ macro_rules! aggregate_implementations { // A (thread local) global map thread_local! { static AGGREGATE_IMPLEMENTATIONS: AggregateImplementations = AggregateImplementations { - $([< $implementation:snake >]: Rc::new(Extended::new(function::[< $implementation:snake >](), DataType::Any)),)* + $([< $implementation:snake >]: Rc::new(Optional::new(function::[< $implementation:snake >]())),)* }; } diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 4e861221..baa1638d 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -653,6 +653,7 @@ impl Relation { ) -> Self {// TODO fix this // Join the two relations on the entity column let join: Relation = Relation::join() + .inner() .on(Expr::eq( Expr::qcol(self.name(), entities), Expr::qcol(scale_factors.name(), entities), @@ -661,7 +662,6 @@ impl Relation { .right_names(scale_factors.fields().into_iter().map(|field| format!("_SCALE_FACTOR_{}", field.name())).collect()) .left(self) .right(scale_factors) - .inner() .build(); // Multiply the values by the factors join.map_fields(|field_name, expr| { @@ -687,26 +687,22 @@ impl Relation { let norms = self .clone() .l2_norms(entities.clone(), groups.clone(), values.clone()); - // TODO REMOVE DEBUG - norms.display_dot().unwrap(); // Put the `clipping_values`in the right shape let clipping_values: HashMap<&str, f64> = clipping_values.into_iter().collect(); // Compute the scaling factors let scaling_factors = norms.map_fields(|field_name, expr| { if values.contains(&field_name) { - Expr::multiply( + Expr::divide( Expr::val(1), Expr::greatest( Expr::val(1), - Expr::multiply(expr.clone(), Expr::val(clipping_values[&field_name])), + Expr::divide(expr.clone(), Expr::val(clipping_values[&field_name])), ), ) } else { - Expr::val(1) + expr } }); - // TODO REMOVE DEBUG - scaling_factors.display_dot().unwrap(); let clipped_relation = self.scale( entities, values.clone(), @@ -1343,166 +1339,39 @@ mod tests { .unwrap() .as_ref() .clone(); - // Compute l1 norm - relation = relation.clone().l2_clipped_sums("id", vec!["city"], vec!["age"], vec![("age", 20.)]); + // Compute l2 norm + let clipped_relation = relation.clone().l2_clipped_sums("id", vec!["city"], vec!["age"], vec![("age", 20.)]); + clipped_relation.display_dot().unwrap(); // Print query - let query = &ast::Query::from(&relation); + let query = &ast::Query::from(&clipped_relation).to_string(); println!("After: {}", query); - relation.display_dot().unwrap(); - // let expected_query = "SELECT id, SQRT(SUM(age*age)) FROM (SELECT id, city, SUM(age) AS age FROM user_table GROUP BY id, city) AS sums GROUP BY id"; - // assert_eq!( - // database.query(&query.to_string()).unwrap(), - // database.query(expected_query).unwrap() - // ); - } - - #[test] - fn test_clipped_sum_for_table() { - let mut database = postgresql::test_database(); - let relations = database.relations(); - - let table = relations - .get(&["item_table".into()]) - .unwrap() - .as_ref() - .clone(); - let clipped_relation = table.clone().l2_clipped_sums( - "order_id", - vec!["item"], - vec!["price"], - vec![("price", 45.)], - ); - clipped_relation.display_dot().unwrap(); - let query: &str = &ast::Query::from(&clipped_relation).to_string(); - let valid_query = r#" - WITH norms AS ( - SELECT order_id, SQRT(SUM(sum_by_group)) AS norm FROM ( - SELECT order_id, item, POWER(SUM(price), 2) AS sum_by_group FROM item_table GROUP BY order_id, item - ) AS subquery GROUP BY order_id - ), weights AS (SELECT order_id, CASE WHEN 45 / norm < 1 THEN 45 / norm ELSE 1 END AS weight FROM norms) - SELECT item, SUM(price*weight) FROM item_table LEFT JOIN weights USING (order_id) GROUP BY item; - "#; - let my_res = database.query(query).unwrap(); - let true_res = database.query(valid_query).unwrap(); - assert_eq!(refacto_results(my_res, 2), refacto_results(true_res, 2)); - } - - #[test] - fn test_clipped_sum_with_empty_base() { - let mut database = postgresql::test_database(); - let relations = database.relations(); - - let table = relations - .get(&["item_table".into()]) - .unwrap() - .as_ref() - .clone(); - let clipped_relation = - table - .clone() - .l2_clipped_sums("order_id", vec![], vec!["price"], vec![("price", 45.)]); - clipped_relation.display_dot().unwrap(); - let query: &str = &ast::Query::from(&clipped_relation).to_string(); - println!("Query: {}", query); - let valid_query = r#" - WITH norms AS ( - SELECT order_id, ABS(SUM(price)) AS norm FROM item_table GROUP BY order_id - ), weights AS ( - SELECT order_id, CASE WHEN 45 / norm < 1 THEN 45 / norm ELSE 1 END AS weight FROM norms - ) - SELECT SUM(price*weight) FROM item_table LEFT JOIN weights USING (order_id); - "#; - let my_res = refacto_results(database.query(query).unwrap(), 1); - let true_res = refacto_results(database.query(valid_query).unwrap(), 1); - assert_eq!(my_res, true_res); - } - - #[test] - fn test_clipped_sum_for_map() { - let mut database = postgresql::test_database(); - let relations = database.relations(); - - let relation = Relation::try_from( - parse("SELECT price * 25 AS std_price, * FROM item_table") - .unwrap() - .with(&relations), - ) - .unwrap(); - relation.display_dot().unwrap(); - - // L2 Norm - let clipped_relation = relation.clone().l2_clipped_sums( - "order_id", - vec!["item"], - vec!["price", "std_price"], - vec![("std_price", 45.), ("price", 50.)], - ); - clipped_relation.display_dot().unwrap(); - - let query: &str = &ast::Query::from(&clipped_relation).to_string(); - let valid_query = r#" - WITH my_table AS ( - SELECT price * 25 AS std_price, * FROM item_table - ), norms AS ( - SELECT order_id, SQRT(SUM(sum_by_group)) AS norm1, SQRT(SUM(sum_by_group2)) AS norm2 FROM ( - SELECT order_id, item, POWER(SUM(price), 2) AS sum_by_group, POWER(SUM(std_price), 2) AS sum_by_group2 FROM my_table GROUP BY order_id, item - ) AS subquery GROUP BY order_id - ), weights AS (SELECT order_id, CASE WHEN 50 / norm1 < 1 THEN 50 / norm1 ELSE 1 END AS weight1, CASE WHEN 45 / norm2 < 1 THEN 45 / norm2 ELSE 1 END AS weight2 FROM norms) - SELECT item, SUM(price*weight1), SUM(std_price*weight2) FROM my_table LEFT JOIN weights USING (order_id) GROUP BY item; - "#; - - let my_res = refacto_results(database.query(query).unwrap(), 3); - let true_res = refacto_results(database.query(valid_query).unwrap(), 3); - assert_eq!(my_res, true_res); - } - - #[test] - fn test_clipped_sum_for_join() { - let mut database = postgresql::test_database(); - let relations = database.relations(); - - let left: Relation = relations - .get(&["item_table".into()]) - .unwrap() - .as_ref() - .clone(); - let right: Relation = relations - .get(&["order_table".into()]) - .unwrap() - .as_ref() - .clone(); - let relation: Relation = Relation::join() - .left(left) - .right(right) - .on(Expr::eq( - Expr::qcol("items", "order_id"), - Expr::qcol("orders", "id"), - )) - .build(); - relation.display_dot().unwrap(); - let schema = relation.schema().clone(); - let item = schema.field_from_index(1).unwrap().name(); - let price = schema.field_from_index(2).unwrap().name(); - let user_id = schema.field_from_index(4).unwrap().name(); - let date = schema.field_from_index(6).unwrap().name(); - - let clipped_relation = - relation.l2_clipped_sums(user_id, vec![item, date], vec![price], vec![(price, 50.)]); - clipped_relation.display_dot().unwrap(); - let query: &str = &ast::Query::from(&clipped_relation).to_string(); - let valid_query = r#" - WITH join_table AS ( - SELECT * FROM item_table JOIN order_table ON item_table.order_id = order_table.id - ), norms AS ( - SELECT user_id, SQRT(SUM(sum_1)) AS norm FROM (SELECT user_id, item, date, POWER(SUM(price), 2) AS sum_1 FROM join_table GROUP BY user_id, item, date) As subq GROUP BY user_id - ), weights AS ( - SELECT user_id, CASE WHEN 50 / norm < 1 THEN 50 / norm ELSE 1 END AS weight FROM norms - ) SELECT item, date, SUM(price*weight) FROM join_table LEFT JOIN weights USING (user_id) GROUP BY item, date; - "#; - - let my_res = refacto_results(database.query(query).unwrap(), 3); - let true_res = refacto_results(database.query(valid_query).unwrap(), 3); - assert_eq!(my_res, true_res); + for row in database.query(query).unwrap() { + println!("{row}"); + } + // 100 + let norm = 100.; + let clipped_relation_100 = relation.clone().l2_clipped_sums("id", vec!["city"], vec!["age"], vec![("age", norm)]); + for row in database.query(&ast::Query::from(&clipped_relation_100).to_string()).unwrap() { + println!("{row}"); + } + // 1000 + let norm = 1000.; + let clipped_relation_1000 = relation.clone().l2_clipped_sums("id", vec!["city"], vec!["age"], vec![("age", norm)]); + for row in database.query(&ast::Query::from(&clipped_relation_1000).to_string()).unwrap() { + println!("{row}"); + } + assert!(database.query(&ast::Query::from(&clipped_relation_100).to_string()).unwrap()!=database.query(&ast::Query::from(&clipped_relation_1000).to_string()).unwrap()); + // 10000 + let norm = 10000.; + let clipped_relation_10000 = relation.clone().l2_clipped_sums("id", vec!["city"], vec!["age"], vec![("age", norm)]); + for row in database.query(&ast::Query::from(&clipped_relation_10000).to_string()).unwrap() { + println!("{row}"); + } + assert!(database.query(&ast::Query::from(&clipped_relation_1000).to_string()).unwrap()==database.query(&ast::Query::from(&clipped_relation_10000).to_string()).unwrap()); + for row in database.query("SELECT city, sum(age) FROM user_table GROUP BY city").unwrap() { + println!("{row}"); + } + assert!(database.query(&ast::Query::from(&clipped_relation_1000).to_string()).unwrap()==database.query("SELECT city, sum(age) FROM user_table GROUP BY city").unwrap()); } #[test] @@ -1801,7 +1670,9 @@ mod tests { ); } - fn test_possion_sampling() { + #[ignore] + #[test] + fn test_poisson_sampling() { let mut database = postgresql::test_database(); let relations = database.relations(); From 661dc8d8e431bf6506cf31e2ddd24fac18a2e1b5 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Fri, 8 Sep 2023 23:35:37 +0200 Subject: [PATCH 24/36] ok --- src/relation/transforms.rs | 65 -------------------------------------- 1 file changed, 65 deletions(-) diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index baa1638d..2d109a5a 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -939,42 +939,6 @@ impl Relation { .right_names(right_names) .build()) } - - /// Returns the left join between `self` and `right` where - /// the output names of the fields are conserved. - /// This fails if one column name is contained in both relations - pub fn left_join(self, right: Self, on: Vec<(&str, &str)>) -> Result { - if on.is_empty() { - return Err(Error::InvalidArguments( - "Vector `on` cannot be empty.".into(), - )); - } - let left_names: Vec = self.schema().iter().map(|f| f.name().to_string()).collect(); - let right_names: Vec = right - .schema() - .iter() - .map(|f| f.name().to_string()) - .collect(); - let on: Vec = on - .into_iter() - .map(|(l, r)| Expr::eq(Expr::qcol(self.name(), l), Expr::qcol(right.name(), r))) - .collect(); - if left_names.iter().any(|item| right_names.contains(item)) { - return Err( - Error::InvalidArguments( - "Cannot use `left_join` method for joining two relations containing fields with the same names.".to_string() - ) - ); - } - Ok(Relation::join() - .left(self.clone()) - .right(right.clone()) - .left_outer() - .on_iter(on) - .left_names(left_names) - .right_names(right_names) - .build()) - } } impl With<(&str, Expr)> for Relation { @@ -2064,35 +2028,6 @@ mod tests { rel.display_dot(); } - #[test] - fn test_left_join() { - let table1: Relation = Relation::table() - .name("table") - .schema( - Schema::builder() - .with(("a", DataType::integer_range(1..=10))) - .with(("b", DataType::integer_values([1, 2, 5, 6, 7, 8]))) - .build(), - ) - .build(); - - let table2: Relation = Relation::table() - .name("table") - .schema( - Schema::builder() - .with(("c", DataType::integer_range(5..=20))) - .with(("d", DataType::integer_range(1..=100))) - .build(), - ) - .build(); - - let joined_rel = table1 - .clone() - .left_join(table2.clone(), vec![("a", "c")]) - .unwrap(); - _ = joined_rel.display_dot(); - } - #[test] fn test_cross_join() { let table_1: Relation = Relation::table() From 037c2a3fde4c36ceb92c69bb3f8aac18cc431894 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Sat, 9 Sep 2023 11:03:52 +0200 Subject: [PATCH 25/36] In progress --- src/expr/mod.rs | 6 +++ src/relation/mod.rs | 102 ++++++++++++++++++++++++------------- src/relation/transforms.rs | 15 ++++++ 3 files changed, 88 insertions(+), 35 deletions(-) diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 2751dbad..5390c3fb 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -490,10 +490,16 @@ impl Aggregate { } } + /// Get aggregate pub fn aggregate(&self) -> aggregate::Aggregate { self.aggregate } + /// Get argument + pub fn argument(&self) -> &Expr { + self.argument.as_ref() + } + pub fn argument_name(&self) -> Result<&String> { match self.argument.as_ref() { Expr::Column(col) => Ok(col.last().unwrap()), diff --git a/src/relation/mod.rs b/src/relation/mod.rs index a6f72b85..6e8ecf93 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -152,13 +152,13 @@ pub trait Variant: #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Table { /// The name of the table - pub name: String, + pub(self) name: String, /// The path to the actual table - pub path: Identifier, + pub(self) path: Identifier, /// The schema description of the output - pub schema: Schema, + pub(self) schema: Schema, /// The size of the table - pub size: Integer, + pub(self) size: Integer, } impl Table { @@ -244,21 +244,21 @@ impl OrderBy { #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Map { /// The name of the output - pub name: String, + pub(self) name: String, /// The list of expressions (SELECT items) - pub projection: Vec, + pub(self) projection: Vec, /// The predicate expression, which must have Boolean type (WHERE clause). It is applied on the input columns. - pub filter: Option, + pub(self) filter: Option, /// The sort expressions (SORT) - pub order_by: Vec, + pub(self) order_by: Vec, /// The limit (LIMIT value) - pub limit: Option, + pub(self) limit: Option, /// The schema description of the output - pub schema: Schema, + pub(self) schema: Schema, /// The size of the Map - pub size: Integer, + pub(self) size: Integer, /// The incoming logical plan - pub input: Rc, + pub(self) input: Rc, } impl Map { @@ -331,9 +331,25 @@ impl Map { ) } + /// Return a new builder pub fn builder() -> MapBuilder { MapBuilder::new() } + + /// Get projections + pub fn projection(&self) -> &[Expr] { + &self.projection + } + + /// Get names and expressions + pub fn field_exprs(&self) -> Vec<(&Field, &Expr)> { + self.schema.iter().zip(self.projection.iter()).collect() + } + + /// Get names and expressions + pub fn named_exprs(&self) -> Vec<(&str, &Expr)> { + self.schema.iter().map(|f| f.name()).zip(self.projection.iter()).collect() + } } impl fmt::Display for Map { @@ -409,17 +425,17 @@ impl Variant for Map { #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Reduce { /// The name of the output - pub name: String, + pub(self) name: String, /// Aggregate expressions - pub aggregate: Vec, + pub(self) aggregate: Vec, /// Grouping expressions - pub group_by: Vec, + pub(self) group_by: Vec, /// The schema description of the output - pub schema: Schema, + pub(self) schema: Schema, /// The size of the Reduce - pub size: Integer, + pub(self) size: Integer, /// The incoming relation - pub input: Rc, + pub(self) input: Rc, } impl Reduce { @@ -477,9 +493,25 @@ impl Reduce { ) } + /// Return a new builder pub fn builder() -> ReduceBuilder { ReduceBuilder::new() } + + /// Get aggregates + pub fn aggregate(&self) -> &[Expr] { + &self.aggregate + } + + /// Get names and expressions + pub fn field_exprs(&self) -> Vec<(&Field, &Expr)> { + self.schema.iter().zip(self.aggregate.iter()).collect() + } + + /// Get names and expressions + pub fn named_exprs(&self) -> Vec<(&str, &Expr)> { + self.schema.iter().map(|f| f.name()).zip(self.aggregate.iter()).collect() + } } impl fmt::Display for Reduce { @@ -613,17 +645,17 @@ impl JoinConstraint { #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Join { /// The name of the output - pub name: String, + pub(self) name: String, /// Join constraint - pub operator: JoinOperator, + pub(self) operator: JoinOperator, /// The schema description of the output - pub schema: Schema, + pub(self) schema: Schema, /// The size of the Join - pub size: Integer, + pub(self) size: Integer, /// Left input - pub left: Rc, + pub(self) left: Rc, /// Right input - pub right: Rc, + pub(self) right: Rc, } impl Join { @@ -855,19 +887,19 @@ impl fmt::Display for SetQuantifier { #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Set { /// The name of the output - pub name: String, + pub(self) name: String, /// Set operator - pub operator: SetOperator, + pub(self) operator: SetOperator, /// Set quantifier - pub quantifier: SetQuantifier, + pub(self) quantifier: SetQuantifier, /// The schema description of the output - pub schema: Schema, + pub(self) schema: Schema, /// The size of the Set - pub size: Integer, + pub(self) size: Integer, /// Left input - pub left: Rc, + pub(self) left: Rc, /// Right input - pub right: Rc, + pub(self) right: Rc, } impl Set { @@ -996,13 +1028,13 @@ impl Variant for Set { #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Values { /// The name of the output - pub name: String, + pub(self) name: String, /// The values - pub values: Vec, + pub(self) values: Vec, /// The schema description of the output - pub schema: Schema, + pub(self) schema: Schema, /// The size of the Set - pub size: Integer, + pub(self) size: Integer, } impl Values { diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 2d109a5a..32907980 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -185,6 +185,21 @@ impl Reduce { self } + pub fn l2_clipped_all_sums(&self, entities: &str) -> Relation { + // TODO + let mut entities: Option<&str> = None; + let mut groups: Vec<&str> = vec![]; + let mut values: Vec<&str> = vec![]; + let mut clipping_values: Vec<(&str, f64)> = vec![]; + for name_expr in self.named_exprs() { + match name_expr { + (name, Expr::Aggregate(agg)) => values.push(agg.argument_name().unwrap()), + _ => (), + } + } + todo!() + } + pub fn clip_aggregates( self, vectors: &str, From f96d80d9dde1c9bce4cc3a241ae5e8339f6e82d3 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Sat, 9 Sep 2023 11:08:15 +0200 Subject: [PATCH 26/36] refactor in progress --- src/relation/mod.rs | 70 ++++++++++++++++++++++----------------------- src/sql/mod.rs | 16 +++++------ 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/src/relation/mod.rs b/src/relation/mod.rs index 6e8ecf93..6c84a282 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -152,13 +152,13 @@ pub trait Variant: #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Table { /// The name of the table - pub(self) name: String, + name: String, /// The path to the actual table - pub(self) path: Identifier, + path: Identifier, /// The schema description of the output - pub(self) schema: Schema, + schema: Schema, /// The size of the table - pub(self) size: Integer, + size: Integer, } impl Table { @@ -244,21 +244,21 @@ impl OrderBy { #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Map { /// The name of the output - pub(self) name: String, + name: String, /// The list of expressions (SELECT items) - pub(self) projection: Vec, + projection: Vec, /// The predicate expression, which must have Boolean type (WHERE clause). It is applied on the input columns. - pub(self) filter: Option, + filter: Option, /// The sort expressions (SORT) - pub(self) order_by: Vec, + order_by: Vec, /// The limit (LIMIT value) - pub(self) limit: Option, + limit: Option, /// The schema description of the output - pub(self) schema: Schema, + schema: Schema, /// The size of the Map - pub(self) size: Integer, + size: Integer, /// The incoming logical plan - pub(self) input: Rc, + input: Rc, } impl Map { @@ -425,17 +425,17 @@ impl Variant for Map { #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Reduce { /// The name of the output - pub(self) name: String, + name: String, /// Aggregate expressions - pub(self) aggregate: Vec, + aggregate: Vec, /// Grouping expressions - pub(self) group_by: Vec, + group_by: Vec, /// The schema description of the output - pub(self) schema: Schema, + schema: Schema, /// The size of the Reduce - pub(self) size: Integer, + size: Integer, /// The incoming relation - pub(self) input: Rc, + input: Rc, } impl Reduce { @@ -645,17 +645,17 @@ impl JoinConstraint { #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Join { /// The name of the output - pub(self) name: String, + name: String, /// Join constraint - pub(self) operator: JoinOperator, + operator: JoinOperator, /// The schema description of the output - pub(self) schema: Schema, + schema: Schema, /// The size of the Join - pub(self) size: Integer, + size: Integer, /// Left input - pub(self) left: Rc, + left: Rc, /// Right input - pub(self) right: Rc, + right: Rc, } impl Join { @@ -887,19 +887,19 @@ impl fmt::Display for SetQuantifier { #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Set { /// The name of the output - pub(self) name: String, + name: String, /// Set operator - pub(self) operator: SetOperator, + operator: SetOperator, /// Set quantifier - pub(self) quantifier: SetQuantifier, + quantifier: SetQuantifier, /// The schema description of the output - pub(self) schema: Schema, + schema: Schema, /// The size of the Set - pub(self) size: Integer, + size: Integer, /// Left input - pub(self) left: Rc, + left: Rc, /// Right input - pub(self) right: Rc, + right: Rc, } impl Set { @@ -1028,13 +1028,13 @@ impl Variant for Set { #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Values { /// The name of the output - pub(self) name: String, + name: String, /// The values - pub(self) values: Vec, + values: Vec, /// The schema description of the output - pub(self) schema: Schema, + schema: Schema, /// The size of the Set - pub(self) size: Integer, + size: Integer, } impl Values { diff --git a/src/sql/mod.rs b/src/sql/mod.rs index c2d1d2ca..c5869080 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -10,20 +10,20 @@ pub mod relation; pub mod visitor; pub mod writer; -use crate::ast as sql; +use crate::{ast, relation::Variant as _}; // I would put here the abstact AST Visitor. // Then in expr.rs module we write an implementation of the abstract visitor for Qrlew expr pub trait Visitor<'a, T> { - fn identifier(&self, identifier: &'a sql::Ident) -> T; - fn compound_identifier(&self, qident: &'a Vec) -> T; - fn unary_op(&self, op: &'a sql::UnaryOperator, expr: &'a Box) -> T; + fn identifier(&self, identifier: &'a ast::Ident) -> T; + fn compound_identifier(&self, qident: &'a Vec) -> T; + fn unary_op(&self, op: &'a ast::UnaryOperator, expr: &'a Box) -> T; fn binary_op( &self, - left: &'a Box, - op: &'a sql::BinaryOperator, - right: &'a Box, + left: &'a Box, + op: &'a ast::BinaryOperator, + right: &'a Box, ) -> T; } @@ -122,7 +122,7 @@ mod tests { let database = postgresql::test_database(); println!("database {} = {}", database.name(), database.relations()); for tab in database.tables() { - println!("schema {} = {}", tab, tab.schema); + println!("schema {} = {}", tab, tab.schema()); } for query in [ "SELECT 1+count(y) as a, sum(1+x) as b FROM table_2", From a04a0c3b7f2a8c412e8a14daf3637e205ad09a61 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Sat, 9 Sep 2023 16:07:59 +0200 Subject: [PATCH 27/36] changed --- src/differential_privacy/mod.rs | 4 +- .../protect_grouping_keys.rs | 6 +- src/protected/mod.rs | 29 +++---- src/relation/mod.rs | 81 ++++++++++++++----- src/sampling_adjustment/mod.rs | 30 +++---- tests/integration.rs | 5 +- 6 files changed, 95 insertions(+), 60 deletions(-) diff --git a/src/differential_privacy/mod.rs b/src/differential_privacy/mod.rs index 335ae63b..7a439d78 100644 --- a/src/differential_privacy/mod.rs +++ b/src/differential_privacy/mod.rs @@ -71,14 +71,14 @@ impl Reduce { .schema() .clone() .iter() - .zip(self.aggregate.clone().into_iter()) + .zip(self.aggregate().clone().into_iter()) .fold((vec![], vec![]), |(c, s), (f, x)| { if let (name, Expr::Aggregate(agg)) = (f.name(), x) { match agg.aggregate() { aggregate::Aggregate::Sum => { let mut c = c; let cvalue = self - .input + .input() .schema() .field(agg.argument_name().unwrap()) .unwrap() diff --git a/src/differential_privacy/protect_grouping_keys.rs b/src/differential_privacy/protect_grouping_keys.rs index 8c83e625..a9f91f9c 100644 --- a/src/differential_privacy/protect_grouping_keys.rs +++ b/src/differential_privacy/protect_grouping_keys.rs @@ -49,7 +49,7 @@ pub const PE_DISTINCT_COUNT: &str = "_PROTECTED_DISTINCT_COUNT_"; impl Reduce { pub fn grouping_columns(&self) -> Result> { - self.group_by + self.group_by() .iter() .cloned() .map(|x| { @@ -73,7 +73,7 @@ impl Reduce { fn join_with_grouping_values(self, grouping_values: Relation) -> Result { let on: Vec = self - .group_by + .group_by() .clone() .into_iter() .map(|c| { @@ -109,7 +109,7 @@ impl Reduce { delta: f64, sensitivity: f64, ) -> Result { - if self.group_by.is_empty() { + if self.group_by().is_empty() { // TODO: vec![PE_ID] ? return Ok(Relation::from(self)); } diff --git a/src/protected/mod.rs b/src/protected/mod.rs index c5af52eb..ca0a7bd4 100644 --- a/src/protected/mod.rs +++ b/src/protected/mod.rs @@ -160,36 +160,35 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis // Preserve names let names: Vec = join.schema().iter().map(|f| f.name().to_string()).collect(); let mut left_names = vec![format!("_LEFT{PE_ID}"), format!("_LEFT{PE_WEIGHT}")]; - left_names.extend(names.iter().take(join.left.schema().len()).cloned()); + left_names.extend(names.iter().take(join.left().schema().len()).cloned()); let mut right_names = vec![format!("_RIGHT{PE_ID}"), format!("_RIGHT{PE_WEIGHT}")]; - right_names.extend(names.iter().skip(join.left.schema().len()).cloned()); + right_names.extend(names.iter().skip(join.left().schema().len()).cloned()); // Create the protected join match self.strategy { Strategy::Soft => Err(Error::not_protected_entity_preserving(join)), Strategy::Hard => { - let Join { name, operator, .. } = join; let left = left?; let right = right?; // Compute the mapping between current and new columns //TODO clean this code a bit let columns: Hierarchy = join - .left + .left() .schema() .iter() .zip(left.schema().iter().skip(PROTECTION_COLUMNS)) .map(|(o, n)| { ( - vec![join.left.name().to_string(), o.name().to_string()], + vec![join.left().name().to_string(), o.name().to_string()], Identifier::from(vec![left_name.clone(), n.name().to_string()]), ) }) .chain( - join.right + join.right() .schema() .iter() .zip(right.schema().iter().skip(PROTECTION_COLUMNS)) .map(|(o, n)| { ( - vec![join.right.name().to_string(), o.name().to_string()], + vec![join.right().name().to_string(), o.name().to_string()], Identifier::from(vec![ right_name.clone(), n.name().to_string(), @@ -202,7 +201,7 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis let builder = Relation::join() .left_names(left_names) .right_names(right_names) - .operator(operator.rename(&columns)) + .operator(join.operator().rename(&columns)) .and(Expr::eq( Expr::qcol(left_name.as_str(), PE_ID), Expr::qcol(right_name.as_str(), PE_ID), @@ -210,7 +209,7 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis .left(left) .right(right); let join: Join = builder.build(); - let mut builder = Relation::map().name(name); + let mut builder = Relation::map().name(join.name()); builder = builder.with((PE_ID, Expr::col(format!("_LEFT{PE_ID}")))); builder = builder.with(( PE_WEIGHT, @@ -239,16 +238,10 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis left: Result, right: Result, ) -> Result { - let Set { - name, - operator, - quantifier, - .. - } = set; let builder = Relation::set() - .name(name) - .operator(operator.clone()) - .quantifier(quantifier.clone()) + .name(set.name()) + .operator(set.operator().clone()) + .quantifier(set.quantifier().clone()) .left(left?) .right(right?); Ok(builder.build()) diff --git a/src/relation/mod.rs b/src/relation/mod.rs index 6c84a282..97aa80e3 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -330,26 +330,38 @@ impl Map { }, ) } - - /// Return a new builder - pub fn builder() -> MapBuilder { - MapBuilder::new() - } - /// Get projections pub fn projection(&self) -> &[Expr] { &self.projection } - + /// Get filter + pub fn filter(&self) -> &Option { + &self.filter + } + /// Get order_by + pub fn order_by(&self) -> &[OrderBy] { + &self.order_by + } + /// Get limit + pub fn limit(&self) -> &Option { + &self.limit + } + /// Get the input + pub fn input(&self) -> &Relation { + &self.input + } /// Get names and expressions pub fn field_exprs(&self) -> Vec<(&Field, &Expr)> { self.schema.iter().zip(self.projection.iter()).collect() } - /// Get names and expressions pub fn named_exprs(&self) -> Vec<(&str, &Expr)> { self.schema.iter().map(|f| f.name()).zip(self.projection.iter()).collect() } + /// Return a new builder + pub fn builder() -> MapBuilder { + MapBuilder::new() + } } impl fmt::Display for Map { @@ -483,7 +495,6 @@ impl Reduce { .unzip(); (Schema::new(fields), exprs) } - /// Compute the size of the reduce /// The size of the reduce can be the same as its input and will be at least 0 fn size(input: &Relation) -> Integer { @@ -492,26 +503,30 @@ impl Reduce { |&max| Integer::from_interval(0, max), ) } - - /// Return a new builder - pub fn builder() -> ReduceBuilder { - ReduceBuilder::new() - } - /// Get aggregates pub fn aggregate(&self) -> &[Expr] { &self.aggregate } - + /// Get group_by + pub fn group_by(&self) -> &[Expr] { + &self.group_by + } + /// Get the input + pub fn input(&self) -> &Relation { + &self.input + } /// Get names and expressions pub fn field_exprs(&self) -> Vec<(&Field, &Expr)> { self.schema.iter().zip(self.aggregate.iter()).collect() } - /// Get names and expressions pub fn named_exprs(&self) -> Vec<(&str, &Expr)> { self.schema.iter().map(|f| f.name()).zip(self.aggregate.iter()).collect() } + /// Return a new builder + pub fn builder() -> ReduceBuilder { + ReduceBuilder::new() + } } impl fmt::Display for Reduce { @@ -740,11 +755,22 @@ impl Join { .zip(left_identifiers.chain(right_identifiers)) .map(|(f, i)| (f, i)) } - + /// Get the hyerarchy of names pub fn names(&self) -> Hierarchy { Hierarchy::from_iter(self.field_inputs().map(|(n, i)| (i, n))) } - + /// Get join operator + pub fn operator(&self) -> &JoinOperator { + &self.operator + } + /// Get left input + pub fn left(&self) -> &Relation { + &self.left + } + /// Get right input + pub fn right(&self) -> &Relation { + &self.right + } pub fn builder() -> JoinBuilder { JoinBuilder::new() } @@ -974,7 +1000,22 @@ impl Set { SetOperator::Intersect => Integer::from_interval(0, left_size_max.min(right_size_max)), } } - + /// Get set operator + pub fn operator(&self) -> &SetOperator { + &self.operator + } + /// Get set quantifier + pub fn quantifier(&self) -> &SetQuantifier { + &self.quantifier + } + /// Get left input + pub fn left(&self) -> &Relation { + &self.left + } + /// Get right input + pub fn right(&self) -> &Relation { + &self.right + } pub fn builder() -> SetBuilder { SetBuilder::new() } diff --git a/src/sampling_adjustment/mod.rs b/src/sampling_adjustment/mod.rs index 38b8d816..479d7ad9 100644 --- a/src/sampling_adjustment/mod.rs +++ b/src/sampling_adjustment/mod.rs @@ -197,7 +197,7 @@ impl<'a, F: Fn(&Table) -> RelationWithWeight> Visitor<'a, RelationWithWeight> .fields() .iter() .map(|field| field.name()) - .zip((&reduce.aggregate).iter()) + .zip(reduce.aggregate()) .collect(); // Apply corrections to aggregate function @@ -255,14 +255,14 @@ impl<'a, F: Fn(&Table) -> RelationWithWeight> Visitor<'a, RelationWithWeight> join.schema().iter().map(|f| f.name().to_string()).collect(); let mut left_names = vec![format!("_LEFT{ROW_WEIGHT}")]; - left_names.extend(schema_names.iter().take(join.left.schema().len()).cloned()); + left_names.extend(schema_names.iter().take(join.left().schema().len()).cloned()); let mut right_names = vec![format!("_RIGHT{ROW_WEIGHT}")]; - right_names.extend(schema_names.iter().skip(join.left.schema().len()).cloned()); + right_names.extend(schema_names.iter().skip(join.left().schema().len()).cloned()); // map old columns names (from the join) into new column names from the left and right let columns_mapping: Hierarchy = join - .left + .left() .schema() .iter() // skip 1 because the left (coming from the RelationWithWeight) @@ -270,18 +270,18 @@ impl<'a, F: Fn(&Table) -> RelationWithWeight> Visitor<'a, RelationWithWeight> .zip(left.schema().iter().skip(PROPAGATED_COLUMNS)) .map(|(o, n)| { ( - vec![join.left.name().to_string(), o.name().to_string()], + vec![join.left().name().to_string(), o.name().to_string()], Identifier::from(vec![left_new_name.clone(), n.name().to_string()]), ) }) .chain( - join.right + join.right() .schema() .iter() .zip(right.schema().iter().skip(PROPAGATED_COLUMNS)) .map(|(o, n)| { ( - vec![join.right.name().to_string(), o.name().to_string()], + vec![join.right().name().to_string(), o.name().to_string()], Identifier::from(vec![right_new_name.clone(), n.name().to_string()]), ) }), @@ -291,7 +291,7 @@ impl<'a, F: Fn(&Table) -> RelationWithWeight> Visitor<'a, RelationWithWeight> let builder = Relation::join() .left_names(left_names) .right_names(right_names) - .operator(join.operator.clone().rename(&columns_mapping)) + .operator(join.operator().clone().rename(&columns_mapping)) .left(left.clone()) .right(right.clone()); @@ -366,34 +366,34 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Relation> for TableSamplerVisito join.schema().iter().map(|f| f.name().to_string()).collect(); let left_names: Vec = schema_names .iter() - .take(join.left.schema().len()) + .take(join.left().schema().len()) .cloned() .collect(); let right_names: Vec = schema_names .iter() - .skip(join.left.schema().len()) + .skip(join.left().schema().len()) .cloned() .collect(); let columns_mapping: Hierarchy = join - .left + .left() .schema() .iter() .zip(left.schema().iter()) .map(|(o, n)| { ( - vec![join.left.name().to_string(), o.name().to_string()], + vec![join.left().name().to_string(), o.name().to_string()], Identifier::from(vec![left_new_name.clone(), n.name().to_string()]), ) }) .chain( - join.right + join.right() .schema() .iter() .zip(right.schema().iter()) .map(|(o, n)| { ( - vec![join.right.name().to_string(), o.name().to_string()], + vec![join.right().name().to_string(), o.name().to_string()], Identifier::from(vec![right_new_name.clone(), n.name().to_string()]), ) }), @@ -404,7 +404,7 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Relation> for TableSamplerVisito Relation::join() .left_names(left_names) .right_names(right_names) - .operator(join.operator.clone().rename(&columns_mapping)) + .operator(join.operator().clone().rename(&columns_mapping)) .left(left) .right(right) .build() diff --git a/tests/integration.rs b/tests/integration.rs index 1b4ba7fc..f5fcb3cf 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -10,6 +10,7 @@ use qrlew::{ ast, display::Dot, expr, + relation::Variant as _, io::{postgresql, Database}, protected::PE_ID, sql::parse, @@ -102,7 +103,7 @@ fn test_on_sqlite() { let mut database = sqlite::test_database(); println!("database {} = {}", database.name(), database.relations()); for tab in database.tables() { - println!("schema {} = {}", tab, tab.schema); + println!("schema {} = {}", tab, tab.schema()); } for &query in SQLITE_QUERIES.iter().chain(QUERIES) { assert!(test_rewritten_eq(&mut database, query)); @@ -125,7 +126,7 @@ fn test_on_postgresql() { let mut database = postgresql::test_database(); println!("database {} = {}", database.name(), database.relations()); for tab in database.tables() { - println!("schema {} = {}", tab, tab.schema); + println!("schema {} = {}", tab, tab.schema()); } for &query in POSTGRESQL_QUERIES.iter().chain(QUERIES) { assert!(test_rewritten_eq(&mut database, query)); From 0c25c7c383d729fb6629468a3e03da51b2e76939 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Sat, 9 Sep 2023 16:10:37 +0200 Subject: [PATCH 28/36] Fixed --- src/protected/mod.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/protected/mod.rs b/src/protected/mod.rs index ca0a7bd4..e09d1c0f 100644 --- a/src/protected/mod.rs +++ b/src/protected/mod.rs @@ -167,6 +167,8 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis match self.strategy { Strategy::Soft => Err(Error::not_protected_entity_preserving(join)), Strategy::Hard => { + let name = join.name(); + let operator = join.operator(); let left = left?; let right = right?; // Compute the mapping between current and new columns //TODO clean this code a bit @@ -201,7 +203,7 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis let builder = Relation::join() .left_names(left_names) .right_names(right_names) - .operator(join.operator().rename(&columns)) + .operator(operator.rename(&columns)) .and(Expr::eq( Expr::qcol(left_name.as_str(), PE_ID), Expr::qcol(right_name.as_str(), PE_ID), @@ -209,7 +211,7 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis .left(left) .right(right); let join: Join = builder.build(); - let mut builder = Relation::map().name(join.name()); + let mut builder = Relation::map().name(name); builder = builder.with((PE_ID, Expr::col(format!("_LEFT{PE_ID}")))); builder = builder.with(( PE_WEIGHT, From de0f5eea8d5edf7c3cccc88a71cb5463f38e0151 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Sun, 10 Sep 2023 23:09:32 +0200 Subject: [PATCH 29/36] ok --- src/expr/mod.rs | 4 +--- src/relation/mod.rs | 7 +++++++ src/relation/transforms.rs | 14 ++++++++------ 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 5390c3fb..25de42db 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -489,17 +489,15 @@ impl Aggregate { argument, } } - /// Get aggregate pub fn aggregate(&self) -> aggregate::Aggregate { self.aggregate } - /// Get argument pub fn argument(&self) -> &Expr { self.argument.as_ref() } - + /// Get the argument name pub fn argument_name(&self) -> Result<&String> { match self.argument.as_ref() { Expr::Column(col) => Ok(col.last().unwrap()), diff --git a/src/relation/mod.rs b/src/relation/mod.rs index 97aa80e3..aaa2c83c 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -527,6 +527,13 @@ impl Reduce { pub fn builder() -> ReduceBuilder { ReduceBuilder::new() } + /// Get group_by_names + pub fn group_by_names(&self) -> Vec<&str> { + self.group_by.iter().filter_map(|e| match e { + Expr::Column(col) => col.last(), + _ => None, + }).map(|s| s.as_str()).collect() + } } impl fmt::Display for Reduce { diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 32907980..dab0e146 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -186,14 +186,16 @@ impl Reduce { } pub fn l2_clipped_all_sums(&self, entities: &str) -> Relation { - // TODO - let mut entities: Option<&str> = None; - let mut groups: Vec<&str> = vec![]; + let groups: Vec<&str> = self.group_by_names(); let mut values: Vec<&str> = vec![]; let mut clipping_values: Vec<(&str, f64)> = vec![]; - for name_expr in self.named_exprs() { - match name_expr { - (name, Expr::Aggregate(agg)) => values.push(agg.argument_name().unwrap()), + for aggregate in self.aggregate() { + match aggregate { + Expr::Aggregate(agg) => { + let value_name = agg.argument_name().unwrap(); + values.push(value_name); + clipping_values.push((value_name, 1.))// TODO Fix this + }, _ => (), } } From 7f83e94b1b9a629182f0165ff53a5a0ac0d27a88 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Sep 2023 09:56:35 +0200 Subject: [PATCH 30/36] l2_clip reduce --- src/data_type/mod.rs | 13 ++++++++++++ src/relation/transforms.rs | 43 +++++++++++++++++++++++++++++++++++--- 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index fe9e281f..51763f1c 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -3041,6 +3041,19 @@ impl DataType { } } +// Return the bounds of a DataType if possible +impl DataType { + pub fn absolute_upper_bound(&self) -> Option { + match self { + DataType::Boolean(b) => Some(if *b.max()? {1.} else {0.}), + DataType::Integer(i) => Some(f64::max(i.min()?.abs() as f64, i.max()?.abs() as f64)), + DataType::Float(f) => Some(f64::max(f.min()?.abs(), f.max()?.abs())), + DataType::Optional(o) => o.data_type().absolute_upper_bound(), + _ => None + } + } +} + // TODO Write tests for all types #[cfg(test)] mod tests { diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index dab0e146..89159be7 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -193,13 +193,15 @@ impl Reduce { match aggregate { Expr::Aggregate(agg) => { let value_name = agg.argument_name().unwrap(); + let value_data_type = self.input().schema()[value_name.as_str()].data_type(); + let absolute_bound = value_data_type.absolute_upper_bound().unwrap_or(1.0); values.push(value_name); - clipping_values.push((value_name, 1.))// TODO Fix this + clipping_values.push((value_name, absolute_bound))// TODO Set a better clipping value }, _ => (), } - } - todo!() + }; + self.input().clone().l2_clipped_sums(entities, groups, values, clipping_values) } pub fn clip_aggregates( @@ -728,6 +730,14 @@ impl Relation { clipped_relation.sums_by_group(groups, values) } + /// Clip sums in the first `Reduce`s found + pub fn l2_clipped_all_sums(&self, vectors: &str) -> Self { + match self { + Relation::Reduce(reduce) => reduce.l2_clipped_all_sums(vectors), + _ => todo!(), + } + } + pub fn clip_aggregates(self, vectors: &str, clipping_values: Vec<(&str, f64)>) -> Result { match self { Relation::Reduce(reduce) => reduce.clip_aggregates(vectors, clipping_values), @@ -1355,6 +1365,33 @@ mod tests { assert!(database.query(&ast::Query::from(&clipped_relation_1000).to_string()).unwrap()==database.query("SELECT city, sum(age) FROM user_table GROUP BY city").unwrap()); } + #[test] + fn test_l2_clipped_all_sums_reduce() { + let mut database = postgresql::test_database(); + let relations = database.relations(); + + let table = relations + .get(&["item_table".into()]) + .unwrap() + .as_ref() + .clone(); + + // with GROUP BY + let my_relation: Relation = Relation::reduce() + .input(table.clone()) + .with(("sum_price", Expr::sum(Expr::col("price")))) + .with_group_by_column("item") + .with_group_by_column("order_id") + .build(); + + let schema = my_relation.inputs()[0].schema().clone(); + let price = schema.field_from_index(0).unwrap().name(); + let clipped_relation = my_relation + .l2_clipped_all_sums("order_id"); + let name_fields: Vec<&str> = clipped_relation.schema().iter().map(|f| f.name()).collect(); + clipped_relation.display_dot(); + } + #[test] fn test_clip_aggregates_reduce() { let mut database = postgresql::test_database(); From 0685663cfa98a94ec0720e52b26900153cc1294f Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Sep 2023 15:12:33 +0200 Subject: [PATCH 31/36] New methods for Map and Reduce --- src/relation/mod.rs | 44 ++++++++++++++++++++++++++++++++++++-- src/relation/transforms.rs | 34 +++++++++++++++-------------- 2 files changed, 60 insertions(+), 18 deletions(-) diff --git a/src/relation/mod.rs b/src/relation/mod.rs index aaa2c83c..08db5ed9 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -22,7 +22,7 @@ use crate::{ self, function::Function, intervals::Bound, DataType, DataTyped, Integer, Struct, Value, Variant as _, }, - expr::{self, Expr, Identifier, Split}, + expr::{self, Expr, Identifier, Split, Aggregate, aggregate, Column}, hierarchy::Hierarchy, namer, visitor::{self, Acceptor, Dependencies, Visited}, @@ -503,14 +503,34 @@ impl Reduce { |&max| Integer::from_interval(0, max), ) } - /// Get aggregates + /// Get aggregate exprs pub fn aggregate(&self) -> &[Expr] { &self.aggregate } + /// Get aggregate aggregates + pub fn aggregate_aggregates(&self) -> Vec<&Aggregate> { + self.aggregate.iter().filter_map(|e| { + if let Expr::Aggregate(aggregate) = e { + Some(aggregate) + } else { + None + } + }).collect() + } /// Get group_by pub fn group_by(&self) -> &[Expr] { &self.group_by } + /// Get group_by columns + pub fn group_by_columns(&self) -> Vec<&Column> { + self.group_by.iter().filter_map(|e| { + if let Expr::Column(column) = e { + Some(column) + } else { + None + } + }).collect() + } /// Get the input pub fn input(&self) -> &Relation { &self.input @@ -523,6 +543,26 @@ impl Reduce { pub fn named_exprs(&self) -> Vec<(&str, &Expr)> { self.schema.iter().map(|f| f.name()).zip(self.aggregate.iter()).collect() } + /// Get names and expressions + pub fn field_aggregates(&self) -> Vec<(&Field, &Aggregate)> { + self.schema.iter().zip(self.aggregate.iter()).filter_map(|(f,e)| { + if let Expr::Aggregate(aggregate) = e { + Some((f, aggregate)) + } else { + None + } + }).collect() + } + /// Get names and expressions + pub fn named_aggregates(&self) -> Vec<(&str, &Aggregate)> { + self.schema.iter().map(|f| f.name()).zip(self.aggregate.iter()).filter_map(|(f,e)| { + if let Expr::Aggregate(aggregate) = e { + Some((f, aggregate)) + } else { + None + } + }).collect() + } /// Return a new builder pub fn builder() -> ReduceBuilder { ReduceBuilder::new() diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 89159be7..2a64b76c 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -2,13 +2,11 @@ //! use super::{Join, Map, Reduce, Relation, Set, Table, Values, Variant as _}; -use crate::display::Dot; use crate::namer; use crate::{ builder::{Ready, With, WithIterator}, data_type::{ self, - intervals::{Bound, Intervals}, DataTyped, }, expr::{self, aggregate, Aggregate, Expr, Value}, @@ -185,23 +183,29 @@ impl Reduce { self } - pub fn l2_clipped_all_sums(&self, entities: &str) -> Relation { - let groups: Vec<&str> = self.group_by_names(); - let mut values: Vec<&str> = vec![]; + pub fn l2_clipped_all_sums(&self, entities: &str) -> Result { + let mut input_entities: Option<&str> = None; + let input_groups: Vec<&str> = self.group_by_names(); + let mut input_values: Vec<&str> = vec![]; let mut clipping_values: Vec<(&str, f64)> = vec![]; for aggregate in self.aggregate() { match aggregate { Expr::Aggregate(agg) => { - let value_name = agg.argument_name().unwrap(); - let value_data_type = self.input().schema()[value_name.as_str()].data_type(); - let absolute_bound = value_data_type.absolute_upper_bound().unwrap_or(1.0); - values.push(value_name); - clipping_values.push((value_name, absolute_bound))// TODO Set a better clipping value + if agg.aggregate() == aggregate::Aggregate::Sum { + if let Expr::Column(col) = agg.argument() { + let value_name = col.last().unwrap().as_str(); + let value_data_type = self.input().schema()[value_name].data_type(); + let absolute_bound = value_data_type.absolute_upper_bound().unwrap_or(1.0); + input_values.push(value_name); + clipping_values.push((value_name, absolute_bound))// TODO Set a better clipping value + } + } }, _ => (), } }; - self.input().clone().l2_clipped_sums(entities, groups, values, clipping_values) + println!("DEBUG entities = {entities}\ngroups = {input_groups:?}\nvalues = {input_values:?}\nclipping_values = {clipping_values:?}\n"); + Ok(self.input().clone().l2_clipped_sums(entities, input_groups, input_values, clipping_values)) } pub fn clip_aggregates( @@ -731,7 +735,7 @@ impl Relation { } /// Clip sums in the first `Reduce`s found - pub fn l2_clipped_all_sums(&self, vectors: &str) -> Self { + pub fn l2_clipped_all_sums(&self, vectors: &str) -> Result { match self { Relation::Reduce(reduce) => reduce.l2_clipped_all_sums(vectors), _ => todo!(), @@ -1384,11 +1388,9 @@ mod tests { .with_group_by_column("order_id") .build(); - let schema = my_relation.inputs()[0].schema().clone(); - let price = schema.field_from_index(0).unwrap().name(); + my_relation.display_dot(); let clipped_relation = my_relation - .l2_clipped_all_sums("order_id"); - let name_fields: Vec<&str> = clipped_relation.schema().iter().map(|f| f.name()).collect(); + .l2_clipped_all_sums("order_id").unwrap(); clipped_relation.display_dot(); } From e2e57ac4971e20c86f19f202a90fcfe37a3373c7 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Sep 2023 17:27:51 +0200 Subject: [PATCH 32/36] Clip all sums --- src/expr/mod.rs | 12 ++++++--- src/relation/transforms.rs | 54 +++++++++++++++++++++----------------- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 25de42db..8d326e80 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -497,13 +497,17 @@ impl Aggregate { pub fn argument(&self) -> &Expr { self.argument.as_ref() } - /// Get the argument name - pub fn argument_name(&self) -> Result<&String> { + /// Get argument + pub fn argument_column(&self) -> Result<&Column> { match self.argument.as_ref() { - Expr::Column(col) => Ok(col.last().unwrap()), - _ => Err(Error::other("Cannot return the argument_name")), + Expr::Column(col) => Ok(col), + _ => Err(Error::other("Cannot return the argument column")), } } + /// Get the argument name + pub fn argument_name(&self) -> Result<&String> { + Ok(self.argument_column()?.last().unwrap()) + } } impl fmt::Display for Aggregate { diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 2a64b76c..67d305c6 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -27,7 +27,7 @@ use std::{ pub enum Error { InvalidRelation(String), InvalidArguments(String), - NoPublicValuesError(String), + NoPublicValues(String), Other(String), } @@ -35,6 +35,12 @@ impl Error { pub fn invalid_relation(relation: impl fmt::Display) -> Error { Error::InvalidRelation(format!("{} is invalid", relation)) } + pub fn invalid_arguments(relation: impl fmt::Display) -> Error { + Error::InvalidArguments(format!("{} is invalid", relation)) + } + pub fn no_public_values(relation: impl fmt::Display) -> Error { + Error::NoPublicValues(format!("{} is invalid", relation)) + } } impl fmt::Display for Error { @@ -42,8 +48,8 @@ impl fmt::Display for Error { match self { Error::InvalidRelation(desc) => writeln!(f, "InvalidRelation: {}", desc), Error::InvalidArguments(desc) => writeln!(f, "InvalidArguments: {}", desc), - Error::NoPublicValuesError(desc) => { - writeln!(f, "NoPublicValuesError: {}", desc) + Error::NoPublicValues(desc) => { + writeln!(f, "NoPublicValues: {}", desc) } Error::Other(err) => writeln!(f, "{}", err), } @@ -183,29 +189,29 @@ impl Reduce { self } + /// Clip all sums in the `Reduce` pub fn l2_clipped_all_sums(&self, entities: &str) -> Result { let mut input_entities: Option<&str> = None; let input_groups: Vec<&str> = self.group_by_names(); let mut input_values: Vec<&str> = vec![]; let mut clipping_values: Vec<(&str, f64)> = vec![]; - for aggregate in self.aggregate() { - match aggregate { - Expr::Aggregate(agg) => { - if agg.aggregate() == aggregate::Aggregate::Sum { - if let Expr::Column(col) = agg.argument() { - let value_name = col.last().unwrap().as_str(); - let value_data_type = self.input().schema()[value_name].data_type(); - let absolute_bound = value_data_type.absolute_upper_bound().unwrap_or(1.0); - input_values.push(value_name); - clipping_values.push((value_name, absolute_bound))// TODO Set a better clipping value - } - } - }, - _ => (), + let mut names: HashMap<&str, &str> = HashMap::new(); + for (name, aggregate) in self.named_aggregates() { + if name == entities { + input_entities = Some(aggregate.argument_name()?); + } else if aggregate.aggregate() == aggregate::Aggregate::Sum { + let value_name = aggregate.argument_name()?.as_str(); + let value_data_type = self.input().schema()[value_name].data_type(); + let absolute_bound = value_data_type.absolute_upper_bound().unwrap_or(1.0); + input_values.push(value_name); + names.insert(value_name, name); + clipping_values.push((value_name, absolute_bound))// TODO Set a better clipping value } }; - println!("DEBUG entities = {entities}\ngroups = {input_groups:?}\nvalues = {input_values:?}\nclipping_values = {clipping_values:?}\n"); - Ok(self.input().clone().l2_clipped_sums(entities, input_groups, input_values, clipping_values)) + let input_entities = input_entities.ok_or(Error::invalid_arguments(entities))?; + println!("DEBUG {:#?}", names); + Ok(self.input().clone().l2_clipped_sums(input_entities, input_groups, input_values, clipping_values) + .rename_fields(|s, _| names.get(s).unwrap_or(&s).to_string())) } pub fn clip_aggregates( @@ -735,9 +741,9 @@ impl Relation { } /// Clip sums in the first `Reduce`s found - pub fn l2_clipped_all_sums(&self, vectors: &str) -> Result { + pub fn l2_clipped_all_sums(&self, entities: &str) -> Result { match self { - Relation::Reduce(reduce) => reduce.l2_clipped_all_sums(vectors), + Relation::Reduce(reduce) => reduce.l2_clipped_all_sums(entities), _ => todo!(), } } @@ -1384,14 +1390,14 @@ mod tests { let my_relation: Relation = Relation::reduce() .input(table.clone()) .with(("sum_price", Expr::sum(Expr::col("price")))) - .with_group_by_column("item") + .group_by(Expr::col("item")) .with_group_by_column("order_id") .build(); - my_relation.display_dot(); + my_relation.display_dot().unwrap(); let clipped_relation = my_relation .l2_clipped_all_sums("order_id").unwrap(); - clipped_relation.display_dot(); + clipped_relation.display_dot().unwrap(); } #[test] From dbe568f2bcf4078b1f096b463f6cc6ad463fc33c Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Sep 2023 17:31:01 +0200 Subject: [PATCH 33/36] ok --- src/relation/transforms.rs | 200 ------------------------------------- 1 file changed, 200 deletions(-) diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 67d305c6..38e13814 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -209,79 +209,10 @@ impl Reduce { } }; let input_entities = input_entities.ok_or(Error::invalid_arguments(entities))?; - println!("DEBUG {:#?}", names); Ok(self.input().clone().l2_clipped_sums(input_entities, input_groups, input_values, clipping_values) .rename_fields(|s, _| names.get(s).unwrap_or(&s).to_string())) } - pub fn clip_aggregates( - self, - vectors: &str, - clipping_values: Vec<(&str, f64)>, - ) -> Result { - let (map_names, out_vectors, base, coordinates): ( - Vec<(String, String)>, - Option, - Vec, - Vec, - ) = self - .schema() - .clone() - .iter() - .zip(self.aggregate.into_iter()) - .fold((vec![], None, vec![], vec![]), |(mn, v, b, c), (f, x)| { - if let (name, Expr::Aggregate(agg)) = (f.name(), x) { - let argname = agg.argument_name().unwrap().clone(); - let mut mn = mn; - mn.push((argname.clone(), name.to_string())); - match agg.aggregate() { - aggregate::Aggregate::Sum => { - let mut c = c; - c.push(argname); - (mn, v, b, c) - } - aggregate::Aggregate::First => { - if name == vectors { - let v = Some(argname); - (mn, v, b, c) - } else { - let mut b = b; - b.push(argname); - (mn, v, b, c) - } - } - _ => (mn, v, b, c), - } - } else { - (mn, v, b, c) - } - }); - - let vectors = if let Some(v) = out_vectors { - Ok(v) - } else { - Err(Error::InvalidArguments(format!( - "{vectors} should be in the input `Relation`" - ))) - }; - let len_clipping_values = clipping_values.len(); - let len_coordinates = coordinates.len(); - if len_clipping_values != len_coordinates { - return Err(Error::InvalidArguments(format!( - "You must provide one clipping_value for each output field. \n \ - Got {len_clipping_values} clipping values for {len_coordinates} output fields" - ))); - } - let clipped_relation = self.input.as_ref().clone().l2_clipped_sums( - vectors?.as_str(), - base.iter().map(|s| s.as_str()).collect(), - coordinates.iter().map(|s| s.as_str()).collect(), - clipping_values, - ); - let map_names: HashMap = map_names.into_iter().collect(); - Ok(clipped_relation.rename_fields(|n, _| map_names[n].to_string())) - } - /// Rename fields pub fn rename_fields String>(self, f: F) -> Reduce { Relation::reduce().rename_with(self, f).build() @@ -748,13 +679,6 @@ impl Relation { } } - pub fn clip_aggregates(self, vectors: &str, clipping_values: Vec<(&str, f64)>) -> Result { - match self { - Relation::Reduce(reduce) => reduce.clip_aggregates(vectors, clipping_values), - _ => todo!(), - } - } - /// Add gaussian noise of a given standard deviation to the given columns pub fn add_gaussian_noise(self, name_sigmas: Vec<(&str, f64)>) -> Relation { let name_sigmas: HashMap<&str, f64> = name_sigmas.into_iter().collect(); @@ -1400,130 +1324,6 @@ mod tests { clipped_relation.display_dot().unwrap(); } - #[test] - fn test_clip_aggregates_reduce() { - let mut database = postgresql::test_database(); - let relations = database.relations(); - - let table = relations - .get(&["item_table".into()]) - .unwrap() - .as_ref() - .clone(); - - // with GROUP BY - let my_relation: Relation = Relation::reduce() - .input(table.clone()) - .with(("sum_price", Expr::sum(Expr::col("price")))) - .with_group_by_column("item") - .with_group_by_column("order_id") - .build(); - - let schema = my_relation.inputs()[0].schema().clone(); - let price = schema.field_from_index(0).unwrap().name(); - let clipped_relation = my_relation - .clip_aggregates("order_id", vec![(price, 45.)]) - .unwrap(); - let name_fields: Vec<&str> = clipped_relation.schema().iter().map(|f| f.name()).collect(); - assert_eq!(name_fields, vec!["item", "sum_price"]); - clipped_relation.display_dot(); - - let query: &str = &ast::Query::from(&clipped_relation).to_string(); - println!("Query: {}", query); - let valid_query = r#" - WITH norms AS ( - SELECT order_id, SQRT(SUM(sum_by_group)) AS norm FROM ( - SELECT order_id, item, POWER(SUM(price), 2) AS sum_by_group FROM item_table GROUP BY order_id, item - ) AS subquery GROUP BY order_id - ), weights AS (SELECT order_id, CASE WHEN 45 / norm < 1 THEN 45 / norm ELSE 1 END AS weight FROM norms) - SELECT item, SUM(price*weight) FROM item_table LEFT JOIN weights USING (order_id) GROUP BY item; - "#; - let my_res = refacto_results(database.query(query).unwrap(), 2); - let true_res = refacto_results(database.query(valid_query).unwrap(), 2); - assert_eq!(my_res, true_res); - - // without GROUP BY - let my_relation: Relation = Relation::reduce() - .input(table) - .with(("sum_price", Expr::sum(Expr::col("price")))) - .with_group_by_column("order_id") - .build(); - - let schema = my_relation.inputs()[0].schema().clone(); - let price = schema.field_from_index(0).unwrap().name(); - let clipped_relation = my_relation - .clip_aggregates("order_id", vec![(price, 45.)]) - .unwrap(); - let name_fields: Vec<&str> = clipped_relation.schema().iter().map(|f| f.name()).collect(); - assert_eq!(name_fields, vec!["sum_price"]); - clipped_relation.display_dot(); - - let query: &str = &ast::Query::from(&clipped_relation).to_string(); - println!("Query: {}", query); - let valid_query = r#" - WITH norms AS ( - SELECT order_id, ABS(SUM(price)) AS norm FROM item_table GROUP BY order_id - ), weights AS ( - SELECT order_id, CASE WHEN 45 / norm < 1 THEN 45 / norm ELSE 1 END AS weight FROM norms - ) - SELECT SUM(price*weight) FROM item_table LEFT JOIN weights USING (order_id); - "#; - let my_res = refacto_results(database.query(query).unwrap(), 1); - let true_res = refacto_results(database.query(valid_query).unwrap(), 1); - assert_eq!(my_res, true_res); - } - - #[test] - fn test_clip_aggregates_complex_reduce() { - let mut database = postgresql::test_database(); - let relations = database.relations(); - let initial_query = r#" - SELECT user_id AS user_id, item AS item, 5 * price AS std_price, price AS price, date AS date - FROM item_table LEFT JOIN order_table ON item_table.order_id = order_table.id - "#; - let relation = Relation::try_from(parse(initial_query).unwrap().with(&relations)).unwrap(); - let relation: Relation = Relation::reduce() - .input(relation) - .with_group_by_column("user_id") - .with_group_by_column("item") - .with(("sum1", Expr::sum(Expr::col("price")))) - .with(("sum2", Expr::sum(Expr::col("std_price")))) - .build(); - relation.display_dot(); - - let schema = relation.inputs()[0].schema().clone(); - let price = schema.field_from_index(2).unwrap().name(); - let std_price = schema.field_from_index(3).unwrap().name(); - let clipped_relation = relation - .clip_aggregates("user_id", vec![(price, 45.), (std_price, 50.)]) - .unwrap(); - clipped_relation.display_dot(); - let name_fields: Vec<&str> = clipped_relation.schema().iter().map(|f| f.name()).collect(); - assert_eq!(name_fields, vec!["item", "sum1", "sum2"]); - - let query: &str = &ast::Query::from(&clipped_relation).to_string(); - println!("Query: {}", query); - let valid_query = r#" - WITH my_table AS ( - SELECT user_id AS user_id, item AS item, 5 * price AS std_price, price AS price - FROM item_table LEFT JOIN order_table ON item_table.order_id = order_table.id - ),norms AS ( - SELECT user_id, SQRT(SUM(sum_1)) AS norm, SQRT(SUM(sum_2)) AS norm2 FROM (SELECT user_id, item, POWER(SUM(price), 2) AS sum_1, POWER(SUM(std_price), 2) AS sum_2 FROM my_table GROUP BY user_id, item) As subq GROUP BY user_id - ), weights AS ( - SELECT user_id, CASE WHEN 45 / norm < 1 THEN 45 / norm ELSE 1 END AS weight, CASE WHEN 50 / norm2 < 1 THEN 50 / norm2 ELSE 1 END AS weight2 FROM norms - ) - SELECT my_table.item, SUM(price*weight) AS sum1, SUM(std_price*weight2) As sum2 FROM my_table LEFT JOIN weights USING (user_id) GROUP BY item; - "#; - let my_res: Vec> = refacto_results(database.query(query).unwrap(), 3); - let true_res = refacto_results(database.query(valid_query).unwrap(), 3); - // for (r1, r2) in my_res.iter().zip(true_res.iter()) { - // if r1!=r2 { - // println!("{:?} != {:?}", r1, r2); - // } - // } - // assert_eq!(my_res, true_res); // todo: fix that - } - #[test] fn test_add_noise() { let mut database = postgresql::test_database(); From fa65772f228e2b2c0dda1ef3460ad1d0c68ef0a7 Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Sep 2023 18:23:22 +0200 Subject: [PATCH 34/36] Updated protection --- examples/website.rs | 4 +- src/differential_privacy/mod.rs | 88 ++++++------ .../protect_grouping_keys.rs | 2 +- src/lib.rs | 2 +- src/{protected => protection}/mod.rs | 126 +++++++++++------- tests/integration.rs | 2 +- 6 files changed, 130 insertions(+), 94 deletions(-) rename src/{protected => protection}/mod.rs (80%) diff --git a/examples/website.rs b/examples/website.rs index 5ead3677..6b339f4e 100644 --- a/examples/website.rs +++ b/examples/website.rs @@ -62,7 +62,7 @@ fn protect() { ) .unwrap(); println!("relation = {relation}"); - let relation = relation.force_protect_from_field_paths( + let relation: Relation = relation.force_protect_from_field_paths( &relations, &[ ( @@ -76,7 +76,7 @@ fn protect() { ("order_table", &[("user_id", "user_table", "id")], "name"), ("user_table", &[], "name"), ], - ); + ).into(); println!("relation = {relation}"); relation.display_dot().unwrap(); let query = Query::from(&relation); diff --git a/src/differential_privacy/mod.rs b/src/differential_privacy/mod.rs index 7a439d78..b6bb0e73 100644 --- a/src/differential_privacy/mod.rs +++ b/src/differential_privacy/mod.rs @@ -10,7 +10,7 @@ use crate::data_type::DataTyped; use crate::{ expr::{aggregate, Expr}, hierarchy::Hierarchy, - protected::PE_ID, + protection::PE_ID, relation::{field::Field, transforms, Reduce, Relation, Variant as _}, DataType, }; @@ -66,47 +66,48 @@ impl Reduce { epsilon: f64, delta: f64, ) -> Result { - let multiplicity = 1; // TODO - let (clipping_values, name_sigmas): (Vec<(String, f64)>, Vec<(String, f64)>) = self - .schema() - .clone() - .iter() - .zip(self.aggregate().clone().into_iter()) - .fold((vec![], vec![]), |(c, s), (f, x)| { - if let (name, Expr::Aggregate(agg)) = (f.name(), x) { - match agg.aggregate() { - aggregate::Aggregate::Sum => { - let mut c = c; - let cvalue = self - .input() - .schema() - .field(agg.argument_name().unwrap()) - .unwrap() - .clone() - .clipping_value(multiplicity); - c.push((agg.argument_name().unwrap().to_string(), cvalue)); - let mut s = s; - s.push(( - name.to_string(), - mechanisms::gaussian_noise(epsilon, delta, cvalue), - )); - (c, s) - } - _ => (c, s), - } - } else { - (c, s) - } - }); - - let clipping_values = clipping_values - .iter() - .map(|(n, v)| (n.as_str(), *v)) - .collect(); - let clipped_relation = self.clip_aggregates(PE_ID, clipping_values)?; - - let name_sigmas = name_sigmas.iter().map(|(n, v)| (n.as_str(), *v)).collect(); - Ok(clipped_relation.add_gaussian_noise(name_sigmas)) + // let multiplicity = 1; // TODO + // let (clipping_values, name_sigmas): (Vec<(String, f64)>, Vec<(String, f64)>) = self + // .schema() + // .clone() + // .iter() + // .zip(self.aggregate().clone().into_iter()) + // .fold((vec![], vec![]), |(c, s), (f, x)| { + // if let (name, Expr::Aggregate(agg)) = (f.name(), x) { + // match agg.aggregate() { + // aggregate::Aggregate::Sum => { + // let mut c = c; + // let cvalue = self + // .input() + // .schema() + // .field(agg.argument_name().unwrap()) + // .unwrap() + // .clone() + // .clipping_value(multiplicity); + // c.push((agg.argument_name().unwrap().to_string(), cvalue)); + // let mut s = s; + // s.push(( + // name.to_string(), + // mechanisms::gaussian_noise(epsilon, delta, cvalue), + // )); + // (c, s) + // } + // _ => (c, s), + // } + // } else { + // (c, s) + // } + // }); + + // let clipping_values = clipping_values + // .iter() + // .map(|(n, v)| (n.as_str(), *v)) + // .collect(); + // let clipped_relation = self.clip_aggregates(PE_ID, clipping_values)?; + + // let name_sigmas = name_sigmas.iter().map(|(n, v)| (n.as_str(), *v)).collect(); + // Ok(clipped_relation.add_gaussian_noise(name_sigmas)) + todo!() } } @@ -118,7 +119,7 @@ impl Relation { epsilon: f64, delta: f64, ) -> Result { - let protected_relation = self.force_protect_from_field_paths(relations, protected_entity); + let protected_relation: Relation = self.force_protect_from_field_paths(relations, protected_entity).into(); match protected_relation { Relation::Reduce(reduce) => { reduce.dp_compilation(relations, protected_entity, epsilon, delta) @@ -164,6 +165,7 @@ mod tests { } } + #[ignore]// TODO reactivate this #[test] fn test_dp_compilation() { let mut database = postgresql::test_database(); diff --git a/src/differential_privacy/protect_grouping_keys.rs b/src/differential_privacy/protect_grouping_keys.rs index a9f91f9c..c4025882 100644 --- a/src/differential_privacy/protect_grouping_keys.rs +++ b/src/differential_privacy/protect_grouping_keys.rs @@ -7,7 +7,7 @@ use crate::{ }, expr::{aggregate, Aggregate, Expr, Value}, hierarchy::Hierarchy, - protected::PE_ID, + protection::PE_ID, relation::{transforms, Field, Join, Map, Reduce, Relation, Set, Table, Variant as _, Visitor}, visitor::Acceptor, DataType, diff --git a/src/lib.rs b/src/lib.rs index 7fb74ca3..91787595 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ pub mod encoder; pub mod hierarchy; pub mod io; pub mod namer; -pub mod protected; +pub mod protection; pub mod relation; pub mod sampling_adjustment; pub mod sql; diff --git a/src/protected/mod.rs b/src/protection/mod.rs similarity index 80% rename from src/protected/mod.rs rename to src/protection/mod.rs index e09d1c0f..798b8441 100644 --- a/src/protected/mod.rs +++ b/src/protection/mod.rs @@ -11,11 +11,12 @@ use crate::{ relation::{Join, Map, Reduce, Relation, Set, Table, Values, Variant as _, Visitor}, visitor::Acceptor, }; -use std::{error, fmt, rc::Rc, result}; +use std::{error, fmt, rc::Rc, result, ops::Deref}; #[derive(Debug, Clone)] pub enum Error { NotProtectedEntityPreserving(String), + UnprotectedTable(String), Other(String), } @@ -23,6 +24,9 @@ impl Error { pub fn not_protected_entity_preserving(relation: impl fmt::Display) -> Error { Error::NotProtectedEntityPreserving(format!("{} is not PEP", relation)) } + pub fn unprotected_table(table: impl fmt::Display) -> Error { + Error::NotProtectedEntityPreserving(format!("{} is not protected", table)) + } } impl fmt::Display for Error { @@ -31,6 +35,9 @@ impl fmt::Display for Error { Error::NotProtectedEntityPreserving(desc) => { writeln!(f, "NotProtectedEntityPreserving: {}", desc) } + Error::UnprotectedTable(desc) => { + writeln!(f, "UnprotectedTable: {}", desc) + } Error::Other(err) => writeln!(f, "{}", err), } } @@ -56,16 +63,43 @@ pub enum Strategy { Hard, } +#[derive(Clone, Debug)] +pub struct ProtectedRelation(pub Relation); + +impl ProtectedRelation { + pub fn protected_entity_id(&self) -> &str { + PE_ID + } + + pub fn protected_entity_weight(&self) -> &str { + PE_WEIGHT + } +} + +impl From for Relation { + fn from(value: ProtectedRelation) -> Self { + value.0 + } +} + +impl Deref for ProtectedRelation { + type Target = Relation; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + /// A visitor to compute Relation protection #[derive(Clone, Debug)] -pub struct ProtectVisitor Relation> { +pub struct ProtectVisitor Result> { /// The protected entity definition protect_tables: F, /// Strategy used strategy: Strategy, } -impl Relation> ProtectVisitor { +impl Result> ProtectVisitor { pub fn new(protect_tables: F, strategy: Strategy) -> Self { ProtectVisitor { protect_tables, @@ -78,14 +112,14 @@ impl Relation> ProtectVisitor { pub fn protect_visitor_from_exprs<'a>( protected_entity: &'a [(&'a Table, Expr)], strategy: Strategy, -) -> ProtectVisitor Relation + 'a> { +) -> ProtectVisitor Result + 'a> { ProtectVisitor::new( move |table: &Table| match protected_entity .iter() .find_map(|(t, e)| (table == *t).then(|| e.clone())) { - Some(expr) => Relation::from(table.clone()).identity_with_field(PE_ID, expr.clone()), - None => table.clone().into(), + Some(expr) => Ok(ProtectedRelation(Relation::from(table.clone()).identity_with_field(PE_ID, expr.clone()))), + None => Err(Error::unprotected_table(table)), }, strategy, ) @@ -96,13 +130,13 @@ pub fn protect_visitor_from_field_paths<'a>( relations: &'a Hierarchy>, protected_entity: &'a [(&'a str, &'a [(&'a str, &'a str, &'a str)], &'a str)], strategy: Strategy, -) -> ProtectVisitor Relation + 'a> { +) -> ProtectVisitor Result + 'a> { ProtectVisitor::new( move |table: &Table| match protected_entity .iter() .find(|(tab, _path, _field)| table.name() == relations[*tab].name()) { - Some((_tab, path, field)) => Relation::from(table.clone()) + Some((_tab, path, field)) => Ok(ProtectedRelation(Relation::from(table.clone()) .with_field_path(relations, path, field, PE_ID) .map_fields(|n, e| { if n == PE_ID { @@ -110,31 +144,31 @@ pub fn protect_visitor_from_field_paths<'a>( } else { e } - }), - None => table.clone().into(), + }))), + None => Err(Error::unprotected_table(table)), }, //TODO fix MD5 here strategy, ) } -impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVisitor { - fn table(&self, table: &'a Table) -> Result { - Ok((self.protect_tables)(table) +impl<'a, F: Fn(&Table) -> Result> Visitor<'a, Result> for ProtectVisitor { + fn table(&self, table: &'a Table) -> Result { + Ok(ProtectedRelation(Relation::from((self.protect_tables)(table)?) .insert_field(1, PE_WEIGHT, Expr::val(1)) // We preserve the name - .with_name(format!("{}{}", PROTECTION_PREFIX, table.name()))) + .with_name(format!("{}{}", PROTECTION_PREFIX, table.name())))) } - fn map(&self, map: &'a Map, input: Result) -> Result { + fn map(&self, map: &'a Map, input: Result) -> Result { let builder = Relation::map() .with((PE_ID, Expr::col(PE_ID))) .with((PE_WEIGHT, Expr::col(PE_WEIGHT))) .with(map.clone()) - .input(input?); - Ok(builder.build()) + .input(Relation::from(input?)); + Ok(ProtectedRelation(builder.build())) } - fn reduce(&self, reduce: &'a Reduce, input: Result) -> Result { + fn reduce(&self, reduce: &'a Reduce, input: Result) -> Result { match self.strategy { Strategy::Soft => Err(Error::not_protected_entity_preserving(reduce)), Strategy::Hard => { @@ -142,8 +176,8 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis .with_group_by_column(PE_ID) .with((PE_WEIGHT, Expr::sum(Expr::col(PE_WEIGHT)))) .with(reduce.clone()) - .input(input?); - Ok(builder.build()) + .input(Relation::from(input?)); + Ok(ProtectedRelation(builder.build())) } } } @@ -152,9 +186,9 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis //TODO this need to be cleaned (really) &self, join: &'a crate::relation::Join, - left: Result, - right: Result, - ) -> Result { + left: Result, + right: Result, + ) -> Result { let left_name = left.as_ref().unwrap().name().to_string(); let right_name: String = right.as_ref().unwrap().name().to_string(); // Preserve names @@ -208,8 +242,8 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis Expr::qcol(left_name.as_str(), PE_ID), Expr::qcol(right_name.as_str(), PE_ID), )) - .left(left) - .right(right); + .left(Relation::from(left)) + .right(Relation::from(right)); let join: Join = builder.build(); let mut builder = Relation::map().name(name); builder = builder.with((PE_ID, Expr::col(format!("_LEFT{PE_ID}")))); @@ -229,7 +263,7 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis }); let builder = builder.input(Rc::new(join.into())); - Ok(builder.build()) + Ok(ProtectedRelation(builder.build())) } } } @@ -237,34 +271,34 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Result> for ProtectVis fn set( &self, set: &'a crate::relation::Set, - left: Result, - right: Result, - ) -> Result { + left: Result, + right: Result, + ) -> Result { let builder = Relation::set() .name(set.name()) .operator(set.operator().clone()) .quantifier(set.quantifier().clone()) - .left(left?) - .right(right?); - Ok(builder.build()) + .left(Relation::from(left?)) + .right(Relation::from(right?)); + Ok(ProtectedRelation(builder.build())) } - fn values(&self, values: &'a Values) -> Result { - Ok(Relation::Values(values.clone())) + fn values(&self, values: &'a Values) -> Result { + Ok(ProtectedRelation(Relation::Values(values.clone()))) } } impl Relation { /// Add protection - pub fn protect_from_visitor Relation>( + pub fn protect_from_visitor Result>( self, protect_visitor: ProtectVisitor, - ) -> Result { + ) -> Result { self.accept(protect_visitor) } /// Add protection - pub fn protect Relation>(self, protect_tables: F) -> Result { + pub fn protect Result>(self, protect_tables: F) -> Result { self.accept(ProtectVisitor::new(protect_tables, Strategy::Soft)) } @@ -272,7 +306,7 @@ impl Relation { pub fn protect_from_exprs<'a>( self, protected_entity: &'a [(&'a Table, Expr)], - ) -> Result { + ) -> Result { self.accept(protect_visitor_from_exprs(protected_entity, Strategy::Soft)) } @@ -281,7 +315,7 @@ impl Relation { self, relations: &'a Hierarchy>, protected_entity: &'a [(&'a str, &'a [(&'a str, &'a str, &'a str)], &'a str)], - ) -> Result { + ) -> Result { self.accept(protect_visitor_from_field_paths( relations, protected_entity, @@ -290,7 +324,7 @@ impl Relation { } /// Force protection - pub fn force_protect Relation>(self, protect_tables: F) -> Relation { + pub fn force_protect Result>(self, protect_tables: F) -> ProtectedRelation { self.accept(ProtectVisitor::new(protect_tables, Strategy::Hard)) .unwrap() } @@ -299,7 +333,7 @@ impl Relation { pub fn force_protect_from_exprs<'a>( self, protected_entity: &'a [(&'a Table, Expr)], - ) -> Relation { + ) -> ProtectedRelation { self.accept(protect_visitor_from_exprs(protected_entity, Strategy::Hard)) .unwrap() } @@ -309,7 +343,7 @@ impl Relation { self, relations: &'a Hierarchy>, protected_entity: &'a [(&'a str, &'a [(&'a str, &'a str, &'a str)], &'a str)], - ) -> Relation { + ) -> ProtectedRelation { self.accept(protect_visitor_from_field_paths( relations, protected_entity, @@ -364,7 +398,7 @@ mod tests { .unwrap(); table.display_dot().unwrap(); println!("Schema protected = {}", table.schema()); - println!("Query protected = {}", ast::Query::from(&table)); + println!("Query protected = {}", ast::Query::from(&*table)); assert_eq!(table.schema()[0].name(), PE_ID) } @@ -399,7 +433,7 @@ mod tests { println!("Schema protected = {}", relation.schema()); assert_eq!(relation.schema()[0].name(), PE_ID); // Print query - let query: &str = &ast::Query::from(&relation).to_string(); + let query: &str = &ast::Query::from(&*relation).to_string(); println!( "{}\n{}", format!("{query}").yellow(), @@ -438,7 +472,7 @@ mod tests { let vector = PE_ID.clone(); let base = vec!["item"]; let coordinates = vec!["price"]; - let norm = relation.l2_norms(vector, base, coordinates); + let norm = Relation::from(relation).l2_norms(vector, base, coordinates); norm.display_dot().unwrap(); // Print query let query: &str = &ast::Query::from(&norm).to_string(); @@ -484,7 +518,7 @@ mod tests { println!("Schema protected = {}", relation.schema()); assert_eq!(relation.schema()[0].name(), PE_ID); // Print query - let query: &str = &ast::Query::from(&relation).to_string(); + let query: &str = &ast::Query::from(&*relation).to_string(); println!("{}", format!("{query}").yellow()); println!( "{}\n{}", diff --git a/tests/integration.rs b/tests/integration.rs index f5fcb3cf..d5a9e94a 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -12,7 +12,7 @@ use qrlew::{ expr, relation::Variant as _, io::{postgresql, Database}, - protected::PE_ID, + protection::PE_ID, sql::parse, Relation, With, }; From 2da4d575ba87530d04dc1ab11612e6b80e04630b Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Mon, 11 Sep 2023 23:37:19 +0200 Subject: [PATCH 35/36] compilation --- src/differential_privacy/mod.rs | 51 ++++++++++++++++++++++- src/protection/mod.rs | 72 ++++++++++++++++----------------- src/relation/transforms.rs | 31 +++++++------- 3 files changed, 99 insertions(+), 55 deletions(-) diff --git a/src/differential_privacy/mod.rs b/src/differential_privacy/mod.rs index b6bb0e73..036e9c2d 100644 --- a/src/differential_privacy/mod.rs +++ b/src/differential_privacy/mod.rs @@ -7,28 +7,42 @@ pub mod mechanisms; pub mod protect_grouping_keys; use crate::data_type::DataTyped; +use crate::protection::PEPRelation; use crate::{ - expr::{aggregate, Expr}, + expr::{self, aggregate, Expr}, hierarchy::Hierarchy, - protection::PE_ID, relation::{field::Field, transforms, Reduce, Relation, Variant as _}, DataType, }; +use std::collections::{HashMap, HashSet}; use std::{cmp, error, fmt, rc::Rc, result}; #[derive(Debug, PartialEq)] pub enum Error { + InvalidRelation(String), Other(String), } +impl Error { + pub fn invalid_relation(relation: impl fmt::Display) -> Error { + Error::InvalidRelation(format!("{} is invalid", relation)) + } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { + Error::InvalidRelation(relation) => writeln!(f, "{} invalid.", relation), Error::Other(err) => writeln!(f, "{}", err), } } } +impl From for Error { + fn from(err: expr::Error) -> Self { + Error::Other(err.to_string()) + } +} impl From for Error { fn from(err: transforms::Error) -> Self { Error::Other(err.to_string()) @@ -56,9 +70,42 @@ impl Field { } } +impl PEPRelation { + /// Compile a protected Relation into DP + pub fn dp_compile_sums(self, epsilon: f64, delta: f64) -> Result {// Return a DP relation + let protected_entity_id = self.protected_entity_id().to_string(); + if let PEPRelation(Relation::Reduce(reduce)) = self { + reduce.dp_compile_sums(&protected_entity_id, epsilon, delta) + } else { + Err(Error::invalid_relation(self.0)) + } + } +} + /* Reduce */ impl Reduce { + fn dp_compile_sums(self, protected_entity_id: &str, epsilon: f64, delta: f64) -> Result { + let input_groups: Vec<&str> = self.group_by_names(); + let mut input_values_bound: Vec<(&str, f64)> = vec![]; + let mut names: HashMap<&str, &str> = HashMap::new(); + for (name, aggregate) in self.named_aggregates() { + if aggregate.aggregate() == aggregate::Aggregate::Sum { + let value_name = aggregate.argument_name()?.as_str(); + let value_data_type = self.input().schema()[value_name].data_type(); + let absolute_bound = value_data_type.absolute_upper_bound().unwrap_or(1.0); + input_values_bound.push((value_name, absolute_bound)); + names.insert(value_name, name); + } + }; + self.input().clone().l2_clipped_sums( + protected_entity_id, + input_groups.into_iter().collect(), + input_values_bound.into_iter().collect() + ); + todo!() + } + pub fn dp_compilation<'a>( self, relations: &'a Hierarchy>, diff --git a/src/protection/mod.rs b/src/protection/mod.rs index 798b8441..6c3777ed 100644 --- a/src/protection/mod.rs +++ b/src/protection/mod.rs @@ -64,9 +64,9 @@ pub enum Strategy { } #[derive(Clone, Debug)] -pub struct ProtectedRelation(pub Relation); +pub struct PEPRelation(pub Relation); -impl ProtectedRelation { +impl PEPRelation { pub fn protected_entity_id(&self) -> &str { PE_ID } @@ -76,13 +76,13 @@ impl ProtectedRelation { } } -impl From for Relation { - fn from(value: ProtectedRelation) -> Self { +impl From for Relation { + fn from(value: PEPRelation) -> Self { value.0 } } -impl Deref for ProtectedRelation { +impl Deref for PEPRelation { type Target = Relation; fn deref(&self) -> &Self::Target { @@ -92,14 +92,14 @@ impl Deref for ProtectedRelation { /// A visitor to compute Relation protection #[derive(Clone, Debug)] -pub struct ProtectVisitor Result> { +pub struct ProtectVisitor Result> { /// The protected entity definition protect_tables: F, /// Strategy used strategy: Strategy, } -impl Result> ProtectVisitor { +impl Result> ProtectVisitor { pub fn new(protect_tables: F, strategy: Strategy) -> Self { ProtectVisitor { protect_tables, @@ -112,13 +112,13 @@ impl Result> ProtectVisitor { pub fn protect_visitor_from_exprs<'a>( protected_entity: &'a [(&'a Table, Expr)], strategy: Strategy, -) -> ProtectVisitor Result + 'a> { +) -> ProtectVisitor Result + 'a> { ProtectVisitor::new( move |table: &Table| match protected_entity .iter() .find_map(|(t, e)| (table == *t).then(|| e.clone())) { - Some(expr) => Ok(ProtectedRelation(Relation::from(table.clone()).identity_with_field(PE_ID, expr.clone()))), + Some(expr) => Ok(PEPRelation(Relation::from(table.clone()).identity_with_field(PE_ID, expr.clone()))), None => Err(Error::unprotected_table(table)), }, strategy, @@ -130,13 +130,13 @@ pub fn protect_visitor_from_field_paths<'a>( relations: &'a Hierarchy>, protected_entity: &'a [(&'a str, &'a [(&'a str, &'a str, &'a str)], &'a str)], strategy: Strategy, -) -> ProtectVisitor Result + 'a> { +) -> ProtectVisitor Result + 'a> { ProtectVisitor::new( move |table: &Table| match protected_entity .iter() .find(|(tab, _path, _field)| table.name() == relations[*tab].name()) { - Some((_tab, path, field)) => Ok(ProtectedRelation(Relation::from(table.clone()) + Some((_tab, path, field)) => Ok(PEPRelation(Relation::from(table.clone()) .with_field_path(relations, path, field, PE_ID) .map_fields(|n, e| { if n == PE_ID { @@ -151,24 +151,24 @@ pub fn protect_visitor_from_field_paths<'a>( ) } -impl<'a, F: Fn(&Table) -> Result> Visitor<'a, Result> for ProtectVisitor { - fn table(&self, table: &'a Table) -> Result { - Ok(ProtectedRelation(Relation::from((self.protect_tables)(table)?) +impl<'a, F: Fn(&Table) -> Result> Visitor<'a, Result> for ProtectVisitor { + fn table(&self, table: &'a Table) -> Result { + Ok(PEPRelation(Relation::from((self.protect_tables)(table)?) .insert_field(1, PE_WEIGHT, Expr::val(1)) // We preserve the name .with_name(format!("{}{}", PROTECTION_PREFIX, table.name())))) } - fn map(&self, map: &'a Map, input: Result) -> Result { + fn map(&self, map: &'a Map, input: Result) -> Result { let builder = Relation::map() .with((PE_ID, Expr::col(PE_ID))) .with((PE_WEIGHT, Expr::col(PE_WEIGHT))) .with(map.clone()) .input(Relation::from(input?)); - Ok(ProtectedRelation(builder.build())) + Ok(PEPRelation(builder.build())) } - fn reduce(&self, reduce: &'a Reduce, input: Result) -> Result { + fn reduce(&self, reduce: &'a Reduce, input: Result) -> Result { match self.strategy { Strategy::Soft => Err(Error::not_protected_entity_preserving(reduce)), Strategy::Hard => { @@ -177,7 +177,7 @@ impl<'a, F: Fn(&Table) -> Result> Visitor<'a, Result Result> Visitor<'a, Result, - right: Result, - ) -> Result { + left: Result, + right: Result, + ) -> Result { let left_name = left.as_ref().unwrap().name().to_string(); let right_name: String = right.as_ref().unwrap().name().to_string(); // Preserve names @@ -263,7 +263,7 @@ impl<'a, F: Fn(&Table) -> Result> Visitor<'a, Result Result> Visitor<'a, Result, - right: Result, - ) -> Result { + left: Result, + right: Result, + ) -> Result { let builder = Relation::set() .name(set.name()) .operator(set.operator().clone()) .quantifier(set.quantifier().clone()) .left(Relation::from(left?)) .right(Relation::from(right?)); - Ok(ProtectedRelation(builder.build())) + Ok(PEPRelation(builder.build())) } - fn values(&self, values: &'a Values) -> Result { - Ok(ProtectedRelation(Relation::Values(values.clone()))) + fn values(&self, values: &'a Values) -> Result { + Ok(PEPRelation(Relation::Values(values.clone()))) } } impl Relation { /// Add protection - pub fn protect_from_visitor Result>( + pub fn protect_from_visitor Result>( self, protect_visitor: ProtectVisitor, - ) -> Result { + ) -> Result { self.accept(protect_visitor) } /// Add protection - pub fn protect Result>(self, protect_tables: F) -> Result { + pub fn protect Result>(self, protect_tables: F) -> Result { self.accept(ProtectVisitor::new(protect_tables, Strategy::Soft)) } @@ -306,7 +306,7 @@ impl Relation { pub fn protect_from_exprs<'a>( self, protected_entity: &'a [(&'a Table, Expr)], - ) -> Result { + ) -> Result { self.accept(protect_visitor_from_exprs(protected_entity, Strategy::Soft)) } @@ -315,7 +315,7 @@ impl Relation { self, relations: &'a Hierarchy>, protected_entity: &'a [(&'a str, &'a [(&'a str, &'a str, &'a str)], &'a str)], - ) -> Result { + ) -> Result { self.accept(protect_visitor_from_field_paths( relations, protected_entity, @@ -324,7 +324,7 @@ impl Relation { } /// Force protection - pub fn force_protect Result>(self, protect_tables: F) -> ProtectedRelation { + pub fn force_protect Result>(self, protect_tables: F) -> PEPRelation { self.accept(ProtectVisitor::new(protect_tables, Strategy::Hard)) .unwrap() } @@ -333,7 +333,7 @@ impl Relation { pub fn force_protect_from_exprs<'a>( self, protected_entity: &'a [(&'a Table, Expr)], - ) -> ProtectedRelation { + ) -> PEPRelation { self.accept(protect_visitor_from_exprs(protected_entity, Strategy::Hard)) .unwrap() } @@ -343,7 +343,7 @@ impl Relation { self, relations: &'a Hierarchy>, protected_entity: &'a [(&'a str, &'a [(&'a str, &'a str, &'a str)], &'a str)], - ) -> ProtectedRelation { + ) -> PEPRelation { self.accept(protect_visitor_from_field_paths( relations, protected_entity, diff --git a/src/relation/transforms.rs b/src/relation/transforms.rs index 38e13814..5eb85002 100644 --- a/src/relation/transforms.rs +++ b/src/relation/transforms.rs @@ -13,7 +13,7 @@ use crate::{ hierarchy::Hierarchy, io, relation, DataType, }; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::{ convert::Infallible, error, fmt, @@ -193,7 +193,6 @@ impl Reduce { pub fn l2_clipped_all_sums(&self, entities: &str) -> Result { let mut input_entities: Option<&str> = None; let input_groups: Vec<&str> = self.group_by_names(); - let mut input_values: Vec<&str> = vec![]; let mut clipping_values: Vec<(&str, f64)> = vec![]; let mut names: HashMap<&str, &str> = HashMap::new(); for (name, aggregate) in self.named_aggregates() { @@ -203,13 +202,12 @@ impl Reduce { let value_name = aggregate.argument_name()?.as_str(); let value_data_type = self.input().schema()[value_name].data_type(); let absolute_bound = value_data_type.absolute_upper_bound().unwrap_or(1.0); - input_values.push(value_name); names.insert(value_name, name); clipping_values.push((value_name, absolute_bound))// TODO Set a better clipping value } }; let input_entities = input_entities.ok_or(Error::invalid_arguments(entities))?; - Ok(self.input().clone().l2_clipped_sums(input_entities, input_groups, input_values, clipping_values) + Ok(self.input().clone().l2_clipped_sums(input_entities, input_groups, clipping_values) .rename_fields(|s, _| names.get(s).unwrap_or(&s).to_string())) } @@ -640,23 +638,22 @@ impl Relation { self, entities: &str, groups: Vec<&str>, - values: Vec<&str>, - clipping_values: Vec<(&str, f64)>, + value_clippings: Vec<(&str, f64)>, ) -> Self { + // Arrange the values + let value_clippings: HashMap<&str, f64> = value_clippings.into_iter().collect(); // Compute the norm let norms = self .clone() - .l2_norms(entities.clone(), groups.clone(), values.clone()); - // Put the `clipping_values`in the right shape - let clipping_values: HashMap<&str, f64> = clipping_values.into_iter().collect(); + .l2_norms(entities.clone(), groups.clone(), value_clippings.keys().cloned().collect()); // Compute the scaling factors let scaling_factors = norms.map_fields(|field_name, expr| { - if values.contains(&field_name) { + if value_clippings.contains_key(&field_name) { Expr::divide( Expr::val(1), Expr::greatest( Expr::val(1), - Expr::divide(expr.clone(), Expr::val(clipping_values[&field_name])), + Expr::divide(expr.clone(), Expr::val(value_clippings[&field_name])), ), ) } else { @@ -665,10 +662,10 @@ impl Relation { }); let clipped_relation = self.scale( entities, - values.clone(), + value_clippings.keys().cloned().collect(), scaling_factors, ); - clipped_relation.sums_by_group(groups, values) + clipped_relation.sums_by_group(groups, value_clippings.keys().cloned().collect()) } /// Clip sums in the first `Reduce`s found @@ -1265,7 +1262,7 @@ mod tests { .as_ref() .clone(); // Compute l2 norm - let clipped_relation = relation.clone().l2_clipped_sums("id", vec!["city"], vec!["age"], vec![("age", 20.)]); + let clipped_relation = relation.clone().l2_clipped_sums("id", vec!["city"], vec![("age", 20.)]); clipped_relation.display_dot().unwrap(); // Print query let query = &ast::Query::from(&clipped_relation).to_string(); @@ -1275,20 +1272,20 @@ mod tests { } // 100 let norm = 100.; - let clipped_relation_100 = relation.clone().l2_clipped_sums("id", vec!["city"], vec!["age"], vec![("age", norm)]); + let clipped_relation_100 = relation.clone().l2_clipped_sums("id", vec!["city"], vec![("age", norm)]); for row in database.query(&ast::Query::from(&clipped_relation_100).to_string()).unwrap() { println!("{row}"); } // 1000 let norm = 1000.; - let clipped_relation_1000 = relation.clone().l2_clipped_sums("id", vec!["city"], vec!["age"], vec![("age", norm)]); + let clipped_relation_1000 = relation.clone().l2_clipped_sums("id", vec!["city"], vec![("age", norm)]); for row in database.query(&ast::Query::from(&clipped_relation_1000).to_string()).unwrap() { println!("{row}"); } assert!(database.query(&ast::Query::from(&clipped_relation_100).to_string()).unwrap()!=database.query(&ast::Query::from(&clipped_relation_1000).to_string()).unwrap()); // 10000 let norm = 10000.; - let clipped_relation_10000 = relation.clone().l2_clipped_sums("id", vec!["city"], vec!["age"], vec![("age", norm)]); + let clipped_relation_10000 = relation.clone().l2_clipped_sums("id", vec!["city"], vec![("age", norm)]); for row in database.query(&ast::Query::from(&clipped_relation_10000).to_string()).unwrap() { println!("{row}"); } From 84117b144e896b48678bf1a43002cf4c38caac5f Mon Sep 17 00:00:00 2001 From: Nicolas Grislain Date: Tue, 12 Sep 2023 00:37:00 +0200 Subject: [PATCH 36/36] Almost there --- src/differential_privacy/mod.rs | 76 +++++++++++++++++++++++++++------ 1 file changed, 63 insertions(+), 13 deletions(-) diff --git a/src/differential_privacy/mod.rs b/src/differential_privacy/mod.rs index 036e9c2d..3f841a04 100644 --- a/src/differential_privacy/mod.rs +++ b/src/differential_privacy/mod.rs @@ -74,8 +74,9 @@ impl PEPRelation { /// Compile a protected Relation into DP pub fn dp_compile_sums(self, epsilon: f64, delta: f64) -> Result {// Return a DP relation let protected_entity_id = self.protected_entity_id().to_string(); + let protected_entity_weight = self.protected_entity_weight().to_string(); if let PEPRelation(Relation::Reduce(reduce)) = self { - reduce.dp_compile_sums(&protected_entity_id, epsilon, delta) + reduce.dp_compile_sums(&protected_entity_id, &protected_entity_weight, epsilon, delta) } else { Err(Error::invalid_relation(self.0)) } @@ -85,25 +86,35 @@ impl PEPRelation { /* Reduce */ impl Reduce { - fn dp_compile_sums(self, protected_entity_id: &str, epsilon: f64, delta: f64) -> Result { - let input_groups: Vec<&str> = self.group_by_names(); + fn dp_compile_sums(self, protected_entity_id: &str, protected_entity_weight: &str, epsilon: f64, delta: f64) -> Result { + // Collect groups + let mut input_entities: Option<&str> = None; + let mut input_groups: HashSet<&str> = self.group_by_names().into_iter().collect(); let mut input_values_bound: Vec<(&str, f64)> = vec![]; let mut names: HashMap<&str, &str> = HashMap::new(); + // Collect names, sums and bounds for (name, aggregate) in self.named_aggregates() { - if aggregate.aggregate() == aggregate::Aggregate::Sum { - let value_name = aggregate.argument_name()?.as_str(); - let value_data_type = self.input().schema()[value_name].data_type(); - let absolute_bound = value_data_type.absolute_upper_bound().unwrap_or(1.0); - input_values_bound.push((value_name, absolute_bound)); - names.insert(value_name, name); + // Get value name + let input_name = aggregate.argument_name()?.as_str(); + names.insert(input_name, name); + if name == protected_entity_id {// remove pe group + input_groups.remove(&input_name); + input_entities = Some(input_name); + } else if aggregate.aggregate() == aggregate::Aggregate::Sum && name != protected_entity_weight {// add aggregate + let input_data_type = self.input().schema()[input_name].data_type(); + let absolute_bound = input_data_type.absolute_upper_bound().unwrap_or(1.0); + input_values_bound.push((input_name, absolute_bound)); } }; - self.input().clone().l2_clipped_sums( - protected_entity_id, + let clipped_relation = self.input().clone().l2_clipped_sums( + input_entities.unwrap(), input_groups.into_iter().collect(), - input_values_bound.into_iter().collect() + input_values_bound.iter().cloned().collect() ); - todo!() + let noise_multiplier = 1.; // TODO set this properly + let dp_clipped_relation = clipped_relation.add_gaussian_noise(input_values_bound.into_iter().map(|(name, bound)| (name,noise_multiplier*bound)).collect()); + let renamed_dp_clipped_relation = dp_clipped_relation.rename_fields(|n, e| names[n].to_string()); + Ok(renamed_dp_clipped_relation) } pub fn dp_compilation<'a>( @@ -212,6 +223,45 @@ mod tests { } } + #[test] + fn test_dp_compile_sums() { + let mut database = postgresql::test_database(); + let relations = database.relations(); + + let table = relations + .get(&["item_table".into()]) + .unwrap() + .as_ref() + .clone(); + + // with GROUP BY + let relation: Relation = Relation::reduce() + .input(table.clone()) + .with(("sum_price", Expr::sum(Expr::col("price")))) + .with_group_by_column("order_id") + .build(); + relation.display_dot().unwrap(); + + let pep_relation = relation.force_protect_from_field_paths(&relations, &[ + ( + "item_table", + &[ + ("order_id", "order_table", "id"), + ("user_id", "user_table", "id"), + ], + "name", + ), + ("order_table", &[("user_id", "user_table", "id")], "name"), + ("user_table", &[], "name"), + ]); + pep_relation.display_dot().unwrap(); + + let epsilon = 1.; + let delta = 1e-3; + let dp_relation = pep_relation.dp_compile_sums(epsilon, delta).unwrap(); + dp_relation.display_dot().unwrap(); + } + #[ignore]// TODO reactivate this #[test] fn test_dp_compilation() {