From 6e4011bd1c9d539903d61f6608eb2ec2f6322d78 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Tue, 19 Dec 2023 10:50:17 +0100 Subject: [PATCH 1/2] ok --- CHANGELOG.md | 1 + src/differential_privacy/aggregates.rs | 14 ++++-- src/differential_privacy/mod.rs | 68 +++++++++++++++++++++++++- src/rewriting/mod.rs | 1 + src/rewriting/rewriting_rule.rs | 2 +- 5 files changed, 79 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fa98c9f..22606f62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.6.0] - 2023-12-18 ### Added - Unsupported DP aggregations are handled gracefully +- Support rewriting into a DP Reduce when MAX, MIN, QUANTILE(S), FIRST, LAST or MEDIAN are applied to a grouping column [#227](https://github.com/Qrlew/qrlew/issues/227) ## [0.5.6] - 2023-12-18 ### Fixed diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index 790a80f8..f8081c09 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -3,10 +3,10 @@ use crate::{ data_type::DataTyped, differential_privacy::private_query::PrivateQuery, differential_privacy::{private_query, DPRelation, Error, Result}, - expr::{aggregate, AggregateColumn, Expr, Column, Identifier}, + expr::{aggregate::{self, Aggregate}, AggregateColumn, Expr, Column, Identifier}, privacy_unit_tracking::PUPRelation, relation::{field::Field, Map, Reduce, Relation, Variant}, - DataType, Ready, display::Dot, + DataType, Ready, }; use std::{cmp, collections::HashMap, ops::Deref}; @@ -161,10 +161,16 @@ impl PUPRelation { let square_col = format!("_SQUARE_{}", col_name); let sum_square_col = format!("_SUM_{}", square_col); match aggregate.aggregate() { - aggregate::Aggregate::First => { + Aggregate::Min | + Aggregate::Max | + Aggregate::Median | + Aggregate::First | + Aggregate::Last | + Aggregate::Quantile(_) |
Aggregate::Quantiles(_) => { assert!(group_by_names.contains(&col_name.as_str())); output_b = output_b.with((name, Expr::col(col_name.as_str()))) - } + }, aggregate::Aggregate::Mean => { input_b = input_b .with((col_name.as_str(), Expr::col(col_name.as_str()))) diff --git a/src/differential_privacy/mod.rs b/src/differential_privacy/mod.rs index 7a1533ca..af691227 100644 --- a/src/differential_privacy/mod.rs +++ b/src/differential_privacy/mod.rs @@ -13,7 +13,8 @@ use crate::{ differential_privacy::private_query::PrivateQuery, expr, privacy_unit_tracking, relation::{rewriting, Reduce, Relation}, - Ready, data_type::function::Aggregate, + Ready, + display::Dot }; use std::{error, fmt, ops::Deref, result}; @@ -177,7 +178,7 @@ mod tests { io::{postgresql, Database}, privacy_unit_tracking::PrivacyUnit, privacy_unit_tracking::{PrivacyUnitTracking, Strategy}, - relation::{Field, Map, Relation, Schema, Variant}, + relation::{Field, Map, Relation, Schema, Variant, Constraint}, }; use std::collections::HashSet; @@ -620,4 +621,67 @@ mod tests { .collect(); assert_eq!(city_keys, correct_keys); } + + #[test] + fn test_dp_rewrite_reduce() { + let mut database = postgresql::test_database(); + let relations = database.relations(); + + let table = relations + .get(&["table_1".to_string()]) + .unwrap() + .deref() + .clone(); + let (epsilon, delta) = (1., 1e-3); + let (epsilon_tau_thresholding, delta_tau_thresholding) = (0.5, 2e-3); + + // privacy track the inputs + let privacy_unit_tracking = PrivacyUnitTracking::from(( + &relations, + vec![("table_1", vec![], PrivacyUnit::privacy_unit_row())], + Strategy::Hard, + )); + let pup_table = privacy_unit_tracking + .table(&table.try_into().unwrap()) + .unwrap(); + let reduce = Reduce::new( + "my_reduce".to_string(), + vec![ + ("sum_a".to_string(), AggregateColumn::sum("a")), + ("d".to_string(), AggregateColumn::first("d")), + ("max_d".to_string(), AggregateColumn::max("d")), + ], + vec!["d".into()], + pup_table.deref().clone().into(), + ); 
+ let relation = Relation::from(reduce.clone()); + relation.display_dot().unwrap(); + + let (dp_relation, private_query) = reduce + .differentially_private( + epsilon, + delta, + epsilon_tau_thresholding, + delta_tau_thresholding, + ) + .unwrap() + .into(); + dp_relation.display_dot().unwrap(); + assert_eq!( + private_query, + PrivateQuery::EpsilonDelta(epsilon_tau_thresholding, delta_tau_thresholding) + .compose(PrivateQuery::gaussian_from_epsilon_delta_sensitivity(epsilon, delta, 10.)) + ); + let correct_schema: Schema = vec![ + ("sum_a", DataType::float_interval(0., 100.), None), + ("d", DataType::integer_interval(0, 10), Some(Constraint::Unique)), + ("max_d", DataType::integer_interval(0, 10), Some(Constraint::Unique)), + ].into_iter() + .collect(); + assert_eq!(dp_relation.schema(), &correct_schema); + + let query: &str = &ast::Query::from(&dp_relation).to_string(); + println!("{query}"); + _ = database.query(query).unwrap(); + } } diff --git a/src/rewriting/mod.rs b/src/rewriting/mod.rs index daf5425e..538e8fdd 100644 --- a/src/rewriting/mod.rs +++ b/src/rewriting/mod.rs @@ -187,6 +187,7 @@ mod tests { "SELECT order_id, sum(price), sum(distinct price) FROM item_table GROUP BY order_id HAVING count(*) > 2", "SELECT order_id, sum(order_id) FROM item_table GROUP BY order_id", "SELECT order_id As my_order, sum(price) FROM item_table GROUP BY my_order", + "SELECT order_id, MAX(order_id), sum(price) FROM item_table GROUP BY order_id" ]; for q in queries { diff --git a/src/rewriting/rewriting_rule.rs b/src/rewriting/rewriting_rule.rs index 81fdc779..3b4e688a 100644 --- a/src/rewriting/rewriting_rule.rs +++ b/src/rewriting/rewriting_rule.rs @@ -706,7 +706,7 @@ impl<'a> SetRewritingRulesVisitor<'a> for RewritingRulesSetter<'a> { Aggregate::Quantile(_) | Aggregate::Quantiles(_) => reduce.group_by().contains(f.column()), _ => false, - + } }) { rewriting_rules.push( From 975a7210e61cdb5902565387eba7dc5066383bd5 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe 
Date: Tue, 19 Dec 2023 10:51:43 +0100 Subject: [PATCH 2/2] rm useless line --- src/rewriting/rewriting_rule.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rewriting/rewriting_rule.rs b/src/rewriting/rewriting_rule.rs index 3b4e688a..f71640f5 100644 --- a/src/rewriting/rewriting_rule.rs +++ b/src/rewriting/rewriting_rule.rs @@ -706,7 +706,6 @@ impl<'a> SetRewritingRulesVisitor<'a> for RewritingRulesSetter<'a> { Aggregate::Quantile(_) | Aggregate::Quantiles(_) => reduce.group_by().contains(f.column()), _ => false, - } }) { rewriting_rules.push(