Skip to content

Commit

Permalink
ok
Browse files Browse the repository at this point in the history
  • Loading branch information
victoria de sainte agathe committed Dec 19, 2023
1 parent 3bf0501 commit 6e4011b
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.6.0] - 2023-12-18
### Added
- Unsupported DP aggregations are handled gracefully
- Rewrite into DP Reduce with MAX, MIN, QUANTILE(S), FIRST, LAST, MEDIAN if they are applied on a grouping column [#227](https://github.com/Qrlew/qrlew/issues/227)

## [0.5.6] - 2023-12-18
### Fixed
Expand Down
14 changes: 10 additions & 4 deletions src/differential_privacy/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use crate::{
data_type::DataTyped,
differential_privacy::private_query::PrivateQuery,
differential_privacy::{private_query, DPRelation, Error, Result},
expr::{aggregate, AggregateColumn, Expr, Column, Identifier},
expr::{aggregate::{self, Aggregate}, AggregateColumn, Expr, Column, Identifier},
privacy_unit_tracking::PUPRelation,
relation::{field::Field, Map, Reduce, Relation, Variant},
DataType, Ready, display::Dot,
DataType, Ready,

};
use std::{cmp, collections::HashMap, ops::Deref};
Expand Down Expand Up @@ -161,10 +161,16 @@ impl PUPRelation {
let square_col = format!("_SQUARE_{}", col_name);
let sum_square_col = format!("_SUM_{}", square_col);
match aggregate.aggregate() {
aggregate::Aggregate::First => {
Aggregate::Min |
Aggregate::Max |
Aggregate::Median |
Aggregate::First |
Aggregate::Last |
Aggregate::Quantile(_) |
Aggregate::Quantiles(_) => {
assert!(group_by_names.contains(&col_name.as_str()));
output_b = output_b.with((name, Expr::col(col_name.as_str())))
}
},
aggregate::Aggregate::Mean => {
input_b = input_b
.with((col_name.as_str(), Expr::col(col_name.as_str())))
Expand Down
68 changes: 66 additions & 2 deletions src/differential_privacy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use crate::{
differential_privacy::private_query::PrivateQuery,
expr, privacy_unit_tracking,
relation::{rewriting, Reduce, Relation},
Ready, data_type::function::Aggregate,
Ready,
display::Dot
};
use std::{error, fmt, ops::Deref, result};

Expand Down Expand Up @@ -177,7 +178,7 @@ mod tests {
io::{postgresql, Database},
privacy_unit_tracking::PrivacyUnit,
privacy_unit_tracking::{PrivacyUnitTracking, Strategy},
relation::{Field, Map, Relation, Schema, Variant},
relation::{Field, Map, Relation, Schema, Variant, Constraint},
};
use std::collections::HashSet;

Expand Down Expand Up @@ -620,4 +621,67 @@ mod tests {
.collect();
assert_eq!(city_keys, correct_keys);
}

#[test]
fn test_dp_rewrite_reduce() {
let mut database = postgresql::test_database();
let relations = database.relations();

let table = relations
.get(&["table_1".to_string()])
.unwrap()
.deref()
.clone();
let (epsilon, delta) = (1., 1e-3);
let (epsilon_tau_thresholding, delta_tau_thresholding) = (0.5, 2e-3);

// privacy track the inputs
let privacy_unit_tracking = PrivacyUnitTracking::from((
&relations,
vec![("table_1", vec![], PrivacyUnit::privacy_unit_row())],
Strategy::Hard,
));
let pup_table = privacy_unit_tracking
.table(&table.try_into().unwrap())
.unwrap();
let reduce = Reduce::new(
"my_reduce".to_string(),
vec![
("sum_a".to_string(), AggregateColumn::sum("a")),
("d".to_string(), AggregateColumn::first("d")),
("max_d".to_string(), AggregateColumn::max("d")),
],
vec!["d".into()],
pup_table.deref().clone().into(),
);
let relation = Relation::from(reduce.clone());
relation.display_dot().unwrap();

let (dp_relation, private_query) = reduce
.differentially_private(
epsilon,
delta,
epsilon_tau_thresholding,
delta_tau_thresholding,
)
.unwrap()
.into();
dp_relation.display_dot().unwrap();
assert_eq!(
private_query,
PrivateQuery::EpsilonDelta(epsilon_tau_thresholding, delta_tau_thresholding)
.compose(PrivateQuery::gaussian_from_epsilon_delta_sensitivity(epsilon, delta, 10.))
);
let correct_schema: Schema = vec![
("sum_a", DataType::float_interval(0., 100.), None),
("d", DataType::integer_interval(0, 10), Some(Constraint::Unique)),
("max_d", DataType::integer_interval(0, 10), Some(Constraint::Unique)),
].into_iter()
.collect();
assert_eq!(dp_relation.schema(), &correct_schema);

let query: &str = &ast::Query::from(&dp_relation).to_string();
println!("{query}");
_ = database.query(query).unwrap();
}
}
1 change: 1 addition & 0 deletions src/rewriting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ mod tests {
"SELECT order_id, sum(price), sum(distinct price) FROM item_table GROUP BY order_id HAVING count(*) > 2",
"SELECT order_id, sum(order_id) FROM item_table GROUP BY order_id",
"SELECT order_id As my_order, sum(price) FROM item_table GROUP BY my_order",
"SELECT order_id, MAX(order_id), sum(price) FROM item_table GROUP BY order_id"
];

for q in queries {
Expand Down
2 changes: 1 addition & 1 deletion src/rewriting/rewriting_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ impl<'a> SetRewritingRulesVisitor<'a> for RewritingRulesSetter<'a> {
Aggregate::Quantile(_) |
Aggregate::Quantiles(_) => reduce.group_by().contains(f.column()),
_ => false,

}
}) {
rewriting_rules.push(
Expand Down

0 comments on commit 6e4011b

Please sign in to comment.