Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fmt #114

Merged
merged 1 commit into from
Sep 15, 2023
Merged

fmt #114

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 81 additions & 35 deletions src/differential_privacy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ pub mod mechanisms;
pub mod protect_grouping_keys;

use crate::{
Ready,
builder::With,
data_type::DataTyped,
expr::{self, aggregate, Expr, AggregateColumn},
hierarchy::Hierarchy,
relation::{field::Field, transforms, Map, Reduce, Relation, Variant as _},
DataType,
display::Dot,
expr::{self, aggregate, AggregateColumn, Expr},
hierarchy::Hierarchy,
protection::PEPRelation,
builder::With,
relation::{field::Field, transforms, Map, Reduce, Relation, Variant as _},
DataType, Ready,
};
use std::collections::{HashMap, HashSet};
use std::ops::Deref;
Expand Down Expand Up @@ -75,7 +74,6 @@ impl Deref for DPRelation {
}
}


impl Field {
pub fn clipping_value(self, multiplicity: i64) -> f64 {
match self.data_type() {
Expand Down Expand Up @@ -107,16 +105,27 @@ impl PEPRelation {
// }

/// Compile a protected Relation into DP
pub fn dp_compile(self, epsilon: f64, delta: f64) -> Result<DPRelation> {// Return a DP relation
pub fn dp_compile(self, epsilon: f64, delta: f64) -> Result<DPRelation> {
// Return a DP relation
let protected_entity_id = self.protected_entity_id().to_string();
let protected_entity_weight = self.protected_entity_weight().to_string();
match Relation::from(self) {
Relation::Map(map) => {
let dp_input = PEPRelation(map.input().clone()).dp_compile(epsilon, delta)?;
Ok(DPRelation(Map::builder().with(map).input(Relation::from(dp_input)).build()))
},
Relation::Reduce(reduce) => reduce.dp_compile_sums(&protected_entity_id, &protected_entity_weight, epsilon, delta),
relation => Err(Error::invalid_relation(relation))
Ok(DPRelation(
Map::builder()
.with(map)
.input(Relation::from(dp_input))
.build(),
))
}
Relation::Reduce(reduce) => reduce.dp_compile_sums(
&protected_entity_id,
&protected_entity_weight,
epsilon,
delta,
),
relation => Err(Error::invalid_relation(relation)),
}
}
}
Expand All @@ -125,7 +134,13 @@ impl PEPRelation {
*/
impl Reduce {
/// DP compile the sums
fn dp_compile_sums(self, protected_entity_id: &str, protected_entity_weight: &str, epsilon: f64, delta: f64) -> Result<DPRelation> {
fn dp_compile_sums(
self,
protected_entity_id: &str,
protected_entity_weight: &str,
epsilon: f64,
delta: f64,
) -> Result<DPRelation> {
// Collect groups
let mut input_entities: Option<&str> = None;
let mut input_groups: HashSet<&str> = self.group_by_names().into_iter().collect();
Expand All @@ -140,21 +155,28 @@ impl Reduce {
// remove pe group
input_groups.remove(&input_name);
input_entities = Some(input_name);
} else if aggregate.aggregate() == &aggregate::Aggregate::Sum && name != protected_entity_weight {// add aggregate
} else if aggregate.aggregate() == &aggregate::Aggregate::Sum
&& name != protected_entity_weight
{
// add aggregate
let input_data_type = self.input().schema()[input_name].data_type();
let absolute_bound = input_data_type.absolute_upper_bound().unwrap_or(1.0);
input_values_bound.push((input_name, absolute_bound));
}
};
}
// Check that groups are public
if !input_groups.iter().all(|e| match self.input().schema()[*e].data_type() {// TODO improve this
DataType::Boolean(b) if b.all_values() => true,
DataType::Integer(i) if i.all_values() => true,
DataType::Enum(e) => true,
DataType::Float(f) if f.all_values() => true,
DataType::Text(t) if t.all_values() => true,
_ => false,
}) {
if !input_groups
.iter()
.all(|e| match self.input().schema()[*e].data_type() {
// TODO improve this
DataType::Boolean(b) if b.all_values() => true,
DataType::Integer(i) if i.all_values() => true,
DataType::Enum(e) => true,
DataType::Float(f) if f.all_values() => true,
DataType::Text(t) if t.all_values() => true,
_ => false,
})
{
//return Err(Error::invalid_relation(self));
println!("GROUPS SHOULD BE PUBLIC")
};
Expand All @@ -165,48 +187,72 @@ impl Reduce {
input_values_bound.iter().cloned().collect(),
);
let noise_multiplier = 1.; // TODO set this properly
let dp_clipped_relation = clipped_relation.add_gaussian_noise(input_values_bound.into_iter().map(|(name, bound)| (name,noise_multiplier*bound)).collect());
let renamed_dp_clipped_relation = dp_clipped_relation.rename_fields(|n, e| names.get(n).unwrap_or(&n).to_string());
let dp_clipped_relation = clipped_relation.add_gaussian_noise(
input_values_bound
.into_iter()
.map(|(name, bound)| (name, noise_multiplier * bound))
.collect(),
);
let renamed_dp_clipped_relation =
dp_clipped_relation.rename_fields(|n, e| names.get(n).unwrap_or(&n).to_string());
Ok(DPRelation(renamed_dp_clipped_relation))
}

/// Rewrite aggregations as sums and compile sums
pub fn dp_compile(self, protected_entity_id: &str, protected_entity_weight: &str, epsilon: f64, delta: f64) -> Result<DPRelation> {
pub fn dp_compile(
self,
protected_entity_id: &str,
protected_entity_weight: &str,
epsilon: f64,
delta: f64,
) -> Result<DPRelation> {
let mut output = Map::builder();
let mut sums = Reduce::builder();
// Add aggregate colums
for (name, aggregate) in self.named_aggregates().into_iter() {
match aggregate.aggregate() {
aggregate::Aggregate::First => {
sums = sums.with((aggregate.column_name()?, AggregateColumn::col(aggregate.column_name()?)));
},
sums = sums.with((
aggregate.column_name()?,
AggregateColumn::col(aggregate.column_name()?),
));
}
aggregate::Aggregate::Mean => {
let sum_col = &format!("_SUM_{}", aggregate.column_name()?);
let count_col = &format!("_COUNT_{}", aggregate.column_name()?);
sums = sums
.with((count_col, Expr::sum(Expr::val(1.))))
.with((sum_col, Expr::sum(Expr::col(aggregate.column_name()?))));
output = output
.with((name, Expr::divide(Expr::col(sum_col), Expr::greatest(Expr::val(1.), Expr::col(count_col)))))
},
output = output.with((
name,
Expr::divide(
Expr::col(sum_col),
Expr::greatest(Expr::val(1.), Expr::col(count_col)),
),
))
}
aggregate::Aggregate::Count => {
let count_col = &format!("_COUNT_{}", aggregate.column_name()?);
sums = sums.with((count_col, Expr::sum(Expr::val(1.))));
output = output.with((name, Expr::col(count_col)));
},
aggregate::Aggregate::Sum if aggregate.column_name()? != protected_entity_weight => {
}
aggregate::Aggregate::Sum
if aggregate.column_name()? != protected_entity_weight =>
{
let sum_col = &format!("_SUM_{}", aggregate.column_name()?);
sums = sums.with((sum_col, Expr::sum(Expr::col(aggregate.column_name()?))));
output = output.with((name, Expr::col(sum_col)));
},
}
aggregate::Aggregate::Std => todo!(),
aggregate::Aggregate::Var => todo!(),
_ => (),
}
}
sums = sums.group_by_iter(self.group_by().iter().cloned());
let sums: Reduce = sums.input(self.input().clone()).build();
let dp_sums: Relation = sums.dp_compile_sums(protected_entity_id, protected_entity_weight, epsilon, delta)?.into();
let dp_sums: Relation = sums
.dp_compile_sums(protected_entity_id, protected_entity_weight, epsilon, delta)?
.into();
Ok(DPRelation(output.input(dp_sums).build()))
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/differential_privacy/protect_grouping_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ impl Relation {
#[cfg(test)]
mod tests {
use super::*;
use crate::{display::Dot, relation::Schema, expr::AggregateColumn};
use crate::{display::Dot, expr::AggregateColumn, relation::Schema};
use std::rc::Rc;

#[test]
Expand Down
4 changes: 2 additions & 2 deletions src/expr/identifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl Identifier {
pub fn head(&self) -> Result<&str> {
self.0.get(0).map_or_else(
|| Err(Error::invalid_expression("Identifier too short")),
|h| Ok(h.as_str())
|h| Ok(h.as_str()),
)
}

Expand All @@ -36,7 +36,7 @@ impl Identifier {
pub fn last(&self) -> Result<&str> {
self.0.last().map_or_else(
|| Err(Error::invalid_expression("Identifier too short")),
|h| Ok(h.as_str())
|h| Ok(h.as_str()),
)
}

Expand Down
6 changes: 3 additions & 3 deletions src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::{
collections::BTreeMap,
convert::identity,
error, fmt, hash,
ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Not, Rem, Sub, Deref},
ops::{Add, BitAnd, BitOr, BitXor, Deref, Div, Mul, Neg, Not, Rem, Sub},
rc::Rc,
result,
};
Expand Down Expand Up @@ -732,7 +732,7 @@ impl AggregateColumn {
AggregateColumn {
aggregate,
column: column.clone(),
expr: Expr::Aggregate(Aggregate::new(aggregate, Rc::new(Expr::Column(column))))
expr: Expr::Aggregate(Aggregate::new(aggregate, Rc::new(Expr::Column(column)))),
}
}
/// Access aggregate
Expand Down Expand Up @@ -782,7 +782,7 @@ impl TryFrom<Expr> for AggregateColumn {
} else {
Err(Error::invalid_conversion(argument, "Column"))
}
},
}
_ => Err(Error::invalid_conversion(value, "AggregateColumn")),
}
}
Expand Down
Loading