From 8956ca46fc8ce24833fc6a189949511bb3690a7d Mon Sep 17 00:00:00 2001 From: Andi Cuko Date: Mon, 27 May 2024 15:42:28 +0200 Subject: [PATCH 1/8] fmt --- src/data_type/function.rs | 753 ++++++++++++---------- src/dialect_translation/bigquery.rs | 11 +- src/dialect_translation/mod.rs | 42 +- src/dialect_translation/mssql.rs | 24 +- src/dialect_translation/postgresql.rs | 16 +- src/differential_privacy/aggregates.rs | 529 +++++++++------ src/differential_privacy/dp_event.rs | 66 +- src/differential_privacy/dp_parameters.rs | 41 +- src/differential_privacy/group_by.rs | 7 +- src/differential_privacy/mod.rs | 137 ++-- src/expr/aggregate.rs | 4 +- src/expr/bijection.rs | 50 +- src/expr/dot.rs | 10 +- src/expr/function.rs | 14 +- src/expr/implementation.rs | 47 +- src/expr/mod.rs | 193 +++--- src/expr/split.rs | 16 +- src/expr/sql.rs | 402 +++++++----- src/hierarchy.rs | 6 +- src/io/bigquery.rs | 13 +- src/io/mod.rs | 8 +- src/io/mssql.rs | 2 +- src/io/sqlite.rs | 11 +- src/lib.rs | 2 +- src/relation/builder.rs | 2 +- src/relation/dot.rs | 5 +- src/relation/mod.rs | 44 +- src/relation/rewriting.rs | 59 +- src/relation/sql.rs | 27 +- src/sql/expr.rs | 432 +++++++------ src/sql/mod.rs | 19 +- src/sql/relation.rs | 528 +++++++-------- 32 files changed, 2016 insertions(+), 1504 deletions(-) diff --git a/src/data_type/function.rs b/src/data_type/function.rs index ebc8a523..31c5d208 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -1,7 +1,18 @@ +use super::{ + super::data_type, + injection, + intervals::{Bound, Intervals}, + product::{self, IntervalProduct, IntervalsProduct, Term, Unit}, + value::{self, Value, Variant as _}, + DataType, DataTyped, Integer, List, Variant, +}; +use chrono::{Datelike, Timelike}; +use itertools::Itertools; use std::{ borrow::BorrowMut, cell::RefCell, - cmp, collections::{self, HashSet}, + cmp, + collections::{self, HashSet}, convert::{Infallible, TryFrom, TryInto}, error, fmt, hash::Hasher, @@ -9,16 +20,6 @@ use std::{ result, sync::{Arc, Mutex}, }; -use itertools::Itertools; -use chrono::{Datelike, Timelike}; -use super::{ - super::data_type, - injection, - intervals::{Bound, Intervals}, - product::{self, IntervalProduct, IntervalsProduct, Term, Unit}, - value::{self, Value, Variant as _}, - DataType, DataTyped, Integer, List, Variant, -}; use crate::{ builder::With, @@ -1155,72 +1156,54 @@ Conversion function /// Builds the cast operator pub fn cast(into: DataType) -> impl Function { match into { - DataType::Text(t) if t == data_type::Text::full() => { - Polymorphic::default() - .with( - Pointwise::univariate( - DataType::Any, - DataType::text(), - |v| v.to_string().into() - ) - ) - } - DataType::Float(f) if f == data_type::Float::full() => { - Polymorphic::from(( - PartitionnedMonotonic::univariate( - data_type::Integer::default(), - |v| v as f64 - ), - Pointwise::univariate( - DataType::text(), - DataType::float(), - |v| v.to_string().parse::().unwrap().into() - ) - )) - } - DataType::Integer(i) if i == data_type::Integer::full() => { - Polymorphic::from(( - PartitionnedMonotonic::univariate( - data_type::Float::default(), - |v| v.round() as i64 - ), - Pointwise::univariate( - DataType::text(), - DataType::integer(), - |v| v.to_string().parse::().unwrap().into() - ) - )) - } - DataType::Boolean(b) if b == data_type::Boolean::full() => { - Polymorphic::default() - .with( - Pointwise::univariate( - DataType::text(), - DataType::boolean(), - |v| { - let true_list = vec![ - "t".to_string(), "tr".to_string(), "tru".to_string(), "true".to_string(), - "y".to_string(), "ye".to_string(), "yes".to_string(), - "on".to_string(), - "1".to_string() - ]; - let false_list = vec![ - "f".to_string(), "fa".to_string(), "fal".to_string(), "fals".to_string(), "false".to_string(), - "n".to_string(), "no".to_string(), - "off".to_string(), - "0".to_string() - ]; - if true_list.contains(&v.to_string().to_lowercase()) { - true.into() - } else if false_list.contains(&v.to_string().to_lowercase()) { - false.into() - } else { - panic!() - } - } - ) - ) - } + DataType::Text(t) if t == data_type::Text::full() => Polymorphic::default().with( + Pointwise::univariate(DataType::Any, DataType::text(), |v| v.to_string().into()), + ), + DataType::Float(f) if f == data_type::Float::full() => Polymorphic::from(( + PartitionnedMonotonic::univariate(data_type::Integer::default(), |v| v as f64), + Pointwise::univariate(DataType::text(), DataType::float(), |v| { + v.to_string().parse::().unwrap().into() + }), + )), + DataType::Integer(i) if i == data_type::Integer::full() => Polymorphic::from(( + PartitionnedMonotonic::univariate(data_type::Float::default(), |v| v.round() as i64), + Pointwise::univariate(DataType::text(), DataType::integer(), |v| { + v.to_string().parse::().unwrap().into() + }), + )), + DataType::Boolean(b) if b == data_type::Boolean::full() => Polymorphic::default().with( + Pointwise::univariate(DataType::text(), DataType::boolean(), |v| { + let true_list = vec![ + "t".to_string(), + "tr".to_string(), + "tru".to_string(), + "true".to_string(), + "y".to_string(), + "ye".to_string(), + "yes".to_string(), + "on".to_string(), + "1".to_string(), + ]; + let false_list = vec![ + "f".to_string(), + "fa".to_string(), + "fal".to_string(), + "fals".to_string(), + "false".to_string(), + "n".to_string(), + "no".to_string(), + "off".to_string(), + "0".to_string(), + ]; + if true_list.contains(&v.to_string().to_lowercase()) { + true.into() + } else if false_list.contains(&v.to_string().to_lowercase()) { + false.into() + } else { + panic!() + } + }), + ), _ => todo!(), } } @@ -1722,70 +1705,56 @@ pub fn cos() -> impl Function { pub fn least() -> impl Function { Polymorphic::default() - .with( - PartitionnedMonotonic::bivariate( - (data_type::Integer::default(), data_type::Integer::default()), - |x, y| x.min(y), - ) - ) - .with( - PartitionnedMonotonic::bivariate( - (data_type::Float::default(), data_type::Float::default()), - |x, y| x.min(y), - ) - ) - .with( - PartitionnedMonotonic::bivariate( - (data_type::Time::default(), data_type::Time::default()), - |x, y| x.min(y), - ) - ) - .with( - PartitionnedMonotonic::bivariate( - (data_type::Date::default(), data_type::Date::default()), - |x, y| x.min(y), - ) - ) - .with( - PartitionnedMonotonic::bivariate( - (data_type::DateTime::default(), data_type::DateTime::default()), - |x, y| x.min(y), - ) - ) + .with(PartitionnedMonotonic::bivariate( + (data_type::Integer::default(), data_type::Integer::default()), + |x, y| x.min(y), + )) + .with(PartitionnedMonotonic::bivariate( + (data_type::Float::default(), data_type::Float::default()), + |x, y| x.min(y), + )) + .with(PartitionnedMonotonic::bivariate( + (data_type::Time::default(), data_type::Time::default()), + |x, y| x.min(y), + )) + .with(PartitionnedMonotonic::bivariate( + (data_type::Date::default(), data_type::Date::default()), + |x, y| x.min(y), + )) + .with(PartitionnedMonotonic::bivariate( + ( + data_type::DateTime::default(), + data_type::DateTime::default(), + ), + |x, y| x.min(y), + )) } pub fn greatest() -> impl Function { Polymorphic::default() - .with( - PartitionnedMonotonic::bivariate( - (data_type::Integer::default(), data_type::Integer::default()), - |x, y| x.max(y), - ) - ) - .with( - PartitionnedMonotonic::bivariate( - (data_type::Float::default(), data_type::Float::default()), - |x, y| x.max(y), - ) - ) - .with( - PartitionnedMonotonic::bivariate( - (data_type::Time::default(), data_type::Time::default()), - |x, y| x.max(y), - ) - ) - .with( - PartitionnedMonotonic::bivariate( - (data_type::Date::default(), data_type::Date::default()), - |x, y| x.max(y), - ) - ) - .with( - PartitionnedMonotonic::bivariate( - (data_type::DateTime::default(), data_type::DateTime::default()), - |x, y| x.max(y), - ) - ) + .with(PartitionnedMonotonic::bivariate( + (data_type::Integer::default(), data_type::Integer::default()), + |x, y| x.max(y), + )) + .with(PartitionnedMonotonic::bivariate( + (data_type::Float::default(), data_type::Float::default()), + |x, y| x.max(y), + )) + .with(PartitionnedMonotonic::bivariate( + (data_type::Time::default(), data_type::Time::default()), + |x, y| x.max(y), + )) + .with(PartitionnedMonotonic::bivariate( + (data_type::Date::default(), data_type::Date::default()), + |x, y| x.max(y), + )) + .with(PartitionnedMonotonic::bivariate( + ( + data_type::DateTime::default(), + data_type::DateTime::default(), + ), + |x, y| x.max(y), + )) } // String functions @@ -1828,25 +1797,34 @@ pub fn regexp_contains() -> impl Function { Unimplemented::new( DataType::structured_from_data_types([DataType::text(), DataType::text()]), DataType::boolean(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } /// Regexp extract pub fn regexp_extract() -> impl Function { Unimplemented::new( - DataType::structured_from_data_types([DataType::text(), DataType::text(), DataType::integer(), DataType::integer()]), + DataType::structured_from_data_types([ + DataType::text(), + DataType::text(), + DataType::integer(), + DataType::integer(), + ]), DataType::optional(DataType::text()), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } /// Regexp replace pub fn regexp_replace() -> impl Function { Unimplemented::new( - DataType::structured_from_data_types([DataType::text(), DataType::text(), DataType::text()]), + DataType::structured_from_data_types([ + DataType::text(), + DataType::text(), + DataType::text(), + ]), DataType::text(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -1855,7 +1833,7 @@ pub fn newid() -> impl Function { Unimplemented::new( DataType::unit(), DataType::text(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -1864,7 +1842,7 @@ pub fn encode() -> impl Function { Unimplemented::new( DataType::structured_from_data_types([DataType::text(), DataType::text()]), DataType::text(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -1873,7 +1851,7 @@ pub fn decode() -> impl Function { Unimplemented::new( DataType::structured_from_data_types([DataType::text(), DataType::text()]), DataType::text(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -1882,7 +1860,7 @@ pub fn unhex() -> impl Function { Unimplemented::new( DataType::text(), DataType::text(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -1890,7 +1868,7 @@ pub fn like() -> impl Function { Unimplemented::new( DataType::text(), DataType::boolean(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -1898,33 +1876,33 @@ pub fn ilike() -> impl Function { Unimplemented::new( DataType::text(), DataType::boolean(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } pub fn is_null() -> impl Function { - Pointwise::univariate( - DataType::Any, - DataType::boolean(), - |v| if let Value::Optional(o) = v { + Pointwise::univariate(DataType::Any, DataType::boolean(), |v| { + if let Value::Optional(o) = v { o.is_none() } else { false - }.into() - ) + } + .into() + }) } pub fn is_bool() -> impl Function { Pointwise::bivariate( (DataType::optional(DataType::boolean()), DataType::boolean()), data_type::Boolean::default(), - |a, b| if let Value::Optional(o) = a { - o.as_ref() - .map(|x| *x.deref() == b) - .unwrap_or(false) - } else { - a == b - }.into() + |a, b| { + if let Value::Optional(o) = a { + o.as_ref().map(|x| *x.deref() == b).unwrap_or(false) + } else { + a == b + } + .into() + }, ) } @@ -1933,7 +1911,7 @@ pub fn current_date() -> impl Function { Unimplemented::new( DataType::unit(), DataType::date(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -1941,7 +1919,7 @@ pub fn current_time() -> impl Function { Unimplemented::new( DataType::unit(), DataType::time(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -1949,17 +1927,15 @@ pub fn current_timestamp() -> impl Function { Unimplemented::new( DataType::unit(), DataType::date_time(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } pub fn extract_year() -> impl Function { Polymorphic::from(( - Pointwise::univariate( - data_type::Date::default(), - DataType::integer_min(0), - |a| (a.year() as i64).into(), - ), + Pointwise::univariate(data_type::Date::default(), DataType::integer_min(0), |a| { + (a.year() as i64).into() + }), Pointwise::univariate( data_type::DateTime::default(), DataType::integer_min(0), @@ -2107,29 +2083,53 @@ pub fn dayname() -> impl Function { Polymorphic::from(( Pointwise::univariate( data_type::Date::default(), - DataType::text_values(["Monday".to_string(), "Tuesday".to_string(), "Wednesday".to_string(), "Thursday".to_string(), "Friday".to_string(), "Saturday".to_string(), "Sunday".to_string()]), - |a| (match a.weekday(){ - chrono::Weekday::Mon => "Monday", - chrono::Weekday::Tue => "Tuesday", - chrono::Weekday::Wed => "Wednesday", - chrono::Weekday::Thu => "Thursday", - chrono::Weekday::Fri => "Friday", - chrono::Weekday::Sat => "Saturday", - chrono::Weekday::Sun => "Sunday", - }).to_string().into(), + DataType::text_values([ + "Monday".to_string(), + "Tuesday".to_string(), + "Wednesday".to_string(), + "Thursday".to_string(), + "Friday".to_string(), + "Saturday".to_string(), + "Sunday".to_string(), + ]), + |a| { + (match a.weekday() { + chrono::Weekday::Mon => "Monday", + chrono::Weekday::Tue => "Tuesday", + chrono::Weekday::Wed => "Wednesday", + chrono::Weekday::Thu => "Thursday", + chrono::Weekday::Fri => "Friday", + chrono::Weekday::Sat => "Saturday", + chrono::Weekday::Sun => "Sunday", + }) + .to_string() + .into() + }, ), Pointwise::univariate( data_type::DateTime::default(), - DataType::text_values(["Monday".to_string(), "Tuesday".to_string(), "Wednesday".to_string(), "Thursday".to_string(), "Friday".to_string(), "Saturday".to_string(), "Sunday".to_string()]), - |a| (match a.weekday(){ - chrono::Weekday::Mon => "Monday", - chrono::Weekday::Tue => "Tuesday", - chrono::Weekday::Wed => "Wednesday", - chrono::Weekday::Thu => "Thursday", - chrono::Weekday::Fri => "Friday", - chrono::Weekday::Sat => "Saturday", - chrono::Weekday::Sun => "Sunday", - }).to_string().into(), + DataType::text_values([ + "Monday".to_string(), + "Tuesday".to_string(), + "Wednesday".to_string(), + "Thursday".to_string(), + "Friday".to_string(), + "Saturday".to_string(), + "Sunday".to_string(), + ]), + |a| { + (match a.weekday() { + chrono::Weekday::Mon => "Monday", + chrono::Weekday::Tue => "Tuesday", + chrono::Weekday::Wed => "Wednesday", + chrono::Weekday::Thu => "Thursday", + chrono::Weekday::Fri => "Friday", + chrono::Weekday::Sat => "Saturday", + chrono::Weekday::Sun => "Sunday", + }) + .to_string() + .into() + }, ), )) } @@ -2139,7 +2139,7 @@ pub fn from_unixtime() -> impl Function { Unimplemented::new( DataType::structured_from_data_types([DataType::integer(), DataType::text()]), DataType::sum([DataType::date(), DataType::date_time()]), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -2148,7 +2148,7 @@ pub fn unix_timestamp() -> impl Function { Unimplemented::new( DataType::sum([DataType::date(), DataType::date_time()]), DataType::integer(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -2157,10 +2157,10 @@ pub fn date_format() -> impl Function { Unimplemented::new( DataType::structured_from_data_types([ DataType::sum([DataType::date(), DataType::date_time(), DataType::text()]), - DataType::text() + DataType::text(), ]), DataType::text(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -2170,22 +2170,28 @@ pub fn quarter() -> impl Function { Pointwise::univariate( data_type::Date::default(), DataType::integer_interval(1, 4), - |a| match a.month() { - 1..=3 => 1, - 4..=6 => 2, - 7..=9 => 3, - _ => 4, - }.into() + |a| { + match a.month() { + 1..=3 => 1, + 4..=6 => 2, + 7..=9 => 3, + _ => 4, + } + .into() + }, ), Pointwise::univariate( data_type::DateTime::default(), DataType::integer_interval(1, 4), - |a| match a.month() { - 1..=3 => 1, - 4..=6 => 2, - 7..=9 => 3, - _ => 4, - }.into() + |a| { + match a.month() { + 1..=3 => 1, + 4..=6 => 2, + 7..=9 => 3, + _ => 4, + } + .into() + }, ), )) } @@ -2197,10 +2203,10 @@ pub fn datetime_diff() -> impl Function { DataType::structured_from_data_types([ DataType::sum([DataType::date(), DataType::date_time(), DataType::text()]), DataType::sum([DataType::date(), DataType::date_time(), DataType::text()]), - DataType::text() + DataType::text(), ]), DataType::integer(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -2224,7 +2230,7 @@ pub fn cast_as_date() -> impl Function { Unimplemented::new( DataType::sum([DataType::text(), DataType::date_time()]), DataType::date(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -2232,7 +2238,7 @@ pub fn cast_as_datetime() -> impl Function { Unimplemented::new( DataType::sum([DataType::text(), DataType::date()]), DataType::date(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -2240,7 +2246,7 @@ pub fn cast_as_time() -> impl Function { Unimplemented::new( DataType::sum([DataType::text(), DataType::date_time()]), DataType::date_time(), - Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))) + Arc::new(Mutex::new(RefCell::new(|_| unimplemented!()))), ) } @@ -2265,18 +2271,12 @@ pub fn coalesce() -> impl Function { // Ceil function pub fn ceil() -> impl Function { - PartitionnedMonotonic::univariate( - data_type::Float::default(), - |a| a.ceil(), - ) + PartitionnedMonotonic::univariate(data_type::Float::default(), |a| a.ceil()) } // Floor function pub fn floor() -> impl Function { - PartitionnedMonotonic::univariate( - data_type::Float::default(), - |a| a.floor(), - ) + PartitionnedMonotonic::univariate(data_type::Float::default(), |a| a.floor()) } // Round function @@ -2288,7 +2288,7 @@ pub fn round() -> impl Function { |a, b| { let multiplier = 10.0_f64.powi(b as i32); (a * multiplier).round() / multiplier - } + }, ) } @@ -2301,33 +2301,44 @@ pub fn trunc() -> impl Function { |a, b| { let multiplier = 10.0_f64.powi(b as i32); (a * multiplier).trunc() / multiplier - } + }, ) } // Sign function pub fn sign() -> impl Function { - PartitionnedMonotonic::univariate( - data_type::Float::default(), - |a| if a == 0. {0} else if a < 0. {-1} else {1} - ) + PartitionnedMonotonic::univariate(data_type::Float::default(), |a| { + if a == 0. { + 0 + } else if a < 0. { + -1 + } else { + 1 + } + }) } pub fn choose() -> impl Function { Pointwise::new( - DataType::structured_from_data_types([DataType::integer(), DataType::list(DataType::Any, 1, i64::MAX as usize)]), + DataType::structured_from_data_types([ + DataType::integer(), + DataType::list(DataType::Any, 1, i64::MAX as usize), + ]), DataType::optional(DataType::Any), Arc::new(|v| { if let Value::Struct(s) = v { - if let (Value::Integer(i),Value::List(l)) = (s[0].as_ref(), s[1].as_ref()) { - Ok(value::Optional::new(l.get(*i.deref() as usize).map(|v| Arc::new(v.clone()))).into()) + if let (Value::Integer(i), Value::List(l)) = (s[0].as_ref(), s[1].as_ref()) { + Ok(value::Optional::new( + l.get(*i.deref() as usize).map(|v| Arc::new(v.clone())), + ) + .into()) } else { - return Err(Error::other("Argument out of range")) + return Err(Error::other("Argument out of range")); } } else { Err(Error::other("Argument out of range")) } - }) + }), ) } @@ -2389,9 +2400,13 @@ pub fn mean_distinct() -> impl Function { Aggregate::from( data_type::Float::full(), |values| { - let (count, sum) = values.into_iter().collect::>().into_iter().fold((0.0, 0.0), |(count, sum), value| { - (count + 1.0, sum + f64::from(value)) - }); + let (count, sum) = values + .into_iter() + .collect::>() + .into_iter() + .fold((0.0, 0.0), |(count, sum), value| { + (count + 1.0, sum + f64::from(value)) + }); (sum / count).into() }, |(intervals, _size)| Ok(intervals.into_interval()), @@ -2433,7 +2448,7 @@ pub fn count_distinct() -> impl Function { // Any implementation Aggregate::from( DataType::Any, - |values| (values.iter().cloned().collect::>().len()as i64).into(), + |values| (values.iter().cloned().collect::>().len() as i64).into(), |(_dt, size)| Ok(data_type::Integer::from_interval(1, *size.max().unwrap())), ), // Optional implementation @@ -2557,7 +2572,16 @@ pub fn sum_distinct() -> impl Function { // Integer implementation Aggregate::from( data_type::Integer::full(), - |values| values.iter().cloned().collect::>().into_iter().map(|f| *f).sum::().into(), + |values| { + values + .iter() + .cloned() + .collect::>() + .into_iter() + .map(|f| *f) + .sum::() + .into() + }, |(intervals, size)| { Ok(data_type::Integer::try_from(multiply().super_image( &DataType::structured_from_data_types([intervals.into(), size.into()]), @@ -2567,7 +2591,16 @@ pub fn sum_distinct() -> impl Function { // Float implementation Aggregate::from( data_type::Float::full(), - |values| values.iter().cloned().collect::>().into_iter().map(|f| *f).sum::().into(), + |values| { + values + .iter() + .cloned() + .collect::>() + .into_iter() + .map(|f| *f) + .sum::() + .into() + }, |(intervals, size)| { Ok(data_type::Float::try_from(multiply().super_image( &DataType::structured_from_data_types([intervals.into(), size.into()]), @@ -2614,19 +2647,17 @@ pub fn std_distinct() -> impl Function { Aggregate::from( data_type::Float::full(), |values| { - let (count, sum, sum_2) = - values - .into_iter() - .collect::>() - .into_iter() - .fold((0.0, 0.0, 0.0), |(count, sum, sum_2), value| { - let value: f64 = value.into(); - ( - count + 1.0, - sum + f64::from(value), - sum_2 + (f64::from(value) * f64::from(value)), - ) - }); + let (count, sum, sum_2) = values.into_iter().collect::>().into_iter().fold( + (0.0, 0.0, 0.0), + |(count, sum, sum_2), value| { + let value: f64 = value.into(); + ( + count + 1.0, + sum + f64::from(value), + sum_2 + (f64::from(value) * f64::from(value)), + ) + }, + ); ((sum_2 - sum * sum / count) / (count - 1.)).sqrt().into() }, |(intervals, _size)| match (intervals.min(), intervals.max()) { @@ -2671,19 +2702,17 @@ pub fn var_distinct() -> impl Function { Aggregate::from( data_type::Float::full(), |values| { - let (count, sum, sum_2) = - values - .into_iter() - .collect::>() - .into_iter() - .fold((0.0, 0.0, 0.0), |(count, sum, sum_2), value| { - let value: f64 = value.into(); - ( - count + 1.0, - sum + f64::from(value), - sum_2 + (f64::from(value) * f64::from(value)), - ) - }); + let (count, sum, sum_2) = values.into_iter().collect::>().into_iter().fold( + (0.0, 0.0, 0.0), + |(count, sum, sum_2), value| { + let value: f64 = value.into(); + ( + count + 1.0, + sum + f64::from(value), + sum_2 + (f64::from(value) * f64::from(value)), + ) + }, + ); ((sum_2 - sum * sum / count) / (count - 1.)).into() }, |(intervals, _size)| match (intervals.min(), intervals.max()) { @@ -3813,11 +3842,11 @@ mod tests { let set: DataType = DataType::structured_from_data_types([ DataType::date_time_interval( NaiveDateTime::from_timestamp_opt(1662921288, 0).unwrap(), - NaiveDateTime::from_timestamp_opt(1862921288, 111110).unwrap() + NaiveDateTime::from_timestamp_opt(1862921288, 111110).unwrap(), ), DataType::date_time_interval( NaiveDateTime::from_timestamp_opt(1362921288, 0).unwrap(), - NaiveDateTime::from_timestamp_opt(2062921288, 111110).unwrap() + NaiveDateTime::from_timestamp_opt(2062921288, 111110).unwrap(), ), ]); let im = fun.super_image(&set).unwrap(); @@ -4228,8 +4257,12 @@ mod tests { println!("data_type = {}", fun.data_type()); let set = DataType::from(Struct::from_data_types(&[ - DataType::text_values(["foo@example.com".to_string(), "bar@example.org".to_string(), "www.example.net".to_string()]), - DataType::text_value(r"@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+".to_string()) + DataType::text_values([ + "foo@example.com".to_string(), + "bar@example.org".to_string(), + "www.example.net".to_string(), + ]), + DataType::text_value(r"@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+".to_string()), ])); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4249,7 +4282,7 @@ mod tests { DataType::text_value("Hello Helloo and Hellooo".to_string()), DataType::text_value("H?ello+".to_string()), DataType::integer_value(3), - DataType::integer_value(1) + DataType::integer_value(1), ])); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4287,15 +4320,21 @@ mod tests { let set = DataType::date_values([ NaiveDate::from_ymd_opt(2023, 01, 01).unwrap(), - NaiveDate::from_ymd_opt(2023, 12, 31).unwrap() + NaiveDate::from_ymd_opt(2023, 12, 31).unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::integer_value(2023)); let set = DataType::date_time_values([ - NaiveDate::from_ymd_opt(2016, 7, 8).unwrap().and_hms_opt(9, 10, 11).unwrap(), - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(9, 15, 11).unwrap(), + NaiveDate::from_ymd_opt(2016, 7, 8) + .unwrap() + .and_hms_opt(9, 10, 11) + .unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(9, 15, 11) + .unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4311,15 +4350,21 @@ mod tests { let set = DataType::date_values([ NaiveDate::from_ymd_opt(2023, 01, 01).unwrap(), - NaiveDate::from_ymd_opt(2023, 12, 31).unwrap() + NaiveDate::from_ymd_opt(2023, 12, 31).unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::integer_values([1, 12])); let set = DataType::date_time_values([ - NaiveDate::from_ymd_opt(2016, 7, 8).unwrap().and_hms_opt(9, 10, 11).unwrap(), - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(9, 15, 11).unwrap(), + NaiveDate::from_ymd_opt(2016, 7, 8) + .unwrap() + .and_hms_opt(9, 10, 11) + .unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(9, 15, 11) + .unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4335,15 +4380,21 @@ mod tests { let set = DataType::date_values([ NaiveDate::from_ymd_opt(2023, 01, 01).unwrap(), - NaiveDate::from_ymd_opt(2023, 12, 01).unwrap() + NaiveDate::from_ymd_opt(2023, 12, 01).unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::integer_value(1)); let set = DataType::date_time_values([ - NaiveDate::from_ymd_opt(2016, 7, 18).unwrap().and_hms_opt(9, 10, 11).unwrap(), - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(9, 15, 11).unwrap(), + NaiveDate::from_ymd_opt(2016, 7, 18) + .unwrap() + .and_hms_opt(9, 10, 11) + .unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(9, 15, 11) + .unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4359,15 +4410,21 @@ mod tests { let set = DataType::date_values([ NaiveDate::from_ymd_opt(2023, 01, 01).unwrap(), - NaiveDate::from_ymd_opt(2023, 12, 01).unwrap() + NaiveDate::from_ymd_opt(2023, 12, 01).unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::integer_values([0, 5])); let set = DataType::date_time_values([ - NaiveDate::from_ymd_opt(2016, 7, 18).unwrap().and_hms_opt(9, 10, 11).unwrap(), - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(9, 15, 11).unwrap(), + NaiveDate::from_ymd_opt(2016, 7, 18) + .unwrap() + .and_hms_opt(9, 10, 11) + .unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(9, 15, 11) + .unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4383,15 +4440,21 @@ mod tests { let set = DataType::date_values([ NaiveDate::from_ymd_opt(2023, 01, 01).unwrap(), - NaiveDate::from_ymd_opt(2023, 12, 01).unwrap() + NaiveDate::from_ymd_opt(2023, 12, 01).unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::integer_values([48, 52])); let set = DataType::date_time_values([ - NaiveDate::from_ymd_opt(2016, 7, 18).unwrap().and_hms_opt(9, 10, 11).unwrap(), - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(9, 15, 11).unwrap(), + NaiveDate::from_ymd_opt(2016, 7, 18) + .unwrap() + .and_hms_opt(9, 10, 11) + .unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(9, 15, 11) + .unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4407,15 +4470,21 @@ mod tests { let set = DataType::time_values([ NaiveTime::from_hms_milli_opt(10, 12, 13, 14).unwrap(), - NaiveTime::from_hms_milli_opt(11, 57, 58, 59).unwrap() + NaiveTime::from_hms_milli_opt(11, 57, 58, 59).unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::integer_values([10, 11])); let set = DataType::date_time_values([ - NaiveDate::from_ymd_opt(2016, 7, 18).unwrap().and_hms_opt(9, 10, 11).unwrap(), - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(6, 15, 1).unwrap(), + NaiveDate::from_ymd_opt(2016, 7, 18) + .unwrap() + .and_hms_opt(9, 10, 11) + .unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(6, 15, 1) + .unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4431,15 +4500,21 @@ mod tests { let set = DataType::time_values([ NaiveTime::from_hms_milli_opt(10, 12, 13, 14).unwrap(), - NaiveTime::from_hms_milli_opt(11, 57, 58, 59).unwrap() + NaiveTime::from_hms_milli_opt(11, 57, 58, 59).unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::integer_values([12, 57])); let set = DataType::date_time_values([ - NaiveDate::from_ymd_opt(2016, 7, 18).unwrap().and_hms_opt(9, 10, 11).unwrap(), - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(6, 15, 1).unwrap(), + NaiveDate::from_ymd_opt(2016, 7, 18) + .unwrap() + .and_hms_opt(9, 10, 11) + .unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(6, 15, 1) + .unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4455,15 +4530,21 @@ mod tests { let set = DataType::time_values([ NaiveTime::from_hms_milli_opt(10, 12, 13, 14).unwrap(), - NaiveTime::from_hms_milli_opt(11, 57, 58, 59).unwrap() + NaiveTime::from_hms_milli_opt(11, 57, 58, 59).unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::integer_values([13, 58])); let set = DataType::date_time_values([ - NaiveDate::from_ymd_opt(2016, 7, 18).unwrap().and_hms_opt(9, 10, 11).unwrap(), - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(6, 15, 1).unwrap(), + NaiveDate::from_ymd_opt(2016, 7, 18) + .unwrap() + .and_hms_opt(9, 10, 11) + .unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(6, 15, 1) + .unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4479,15 +4560,21 @@ mod tests { let set = DataType::time_values([ NaiveTime::from_hms_milli_opt(10, 12, 13, 14).unwrap(), - NaiveTime::from_hms_milli_opt(11, 57, 59, 999).unwrap() + NaiveTime::from_hms_milli_opt(11, 57, 59, 999).unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::integer_values([13014000, 59999000])); let set = DataType::date_time_values([ - NaiveDate::from_ymd_opt(2016, 7, 18).unwrap().and_hms_opt(9, 10, 11).unwrap(), - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(6, 15, 1).unwrap(), + NaiveDate::from_ymd_opt(2016, 7, 18) + .unwrap() + .and_hms_opt(9, 10, 11) + .unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(6, 15, 1) + .unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4503,15 +4590,21 @@ mod tests { let set = DataType::time_values([ NaiveTime::from_hms_milli_opt(10, 12, 13, 14).unwrap(), - NaiveTime::from_hms_milli_opt(11, 57, 59, 999).unwrap() + NaiveTime::from_hms_milli_opt(11, 57, 59, 999).unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::float_values([13014.000, 59999.000])); let set = DataType::date_time_values([ - NaiveDate::from_ymd_opt(2016, 7, 18).unwrap().and_hms_opt(9, 10, 11).unwrap(), - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(6, 15, 1).unwrap(), + NaiveDate::from_ymd_opt(2016, 7, 18) + .unwrap() + .and_hms_opt(9, 10, 11) + .unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(6, 15, 1) + .unwrap(), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4527,15 +4620,16 @@ mod tests { println!("co_domain = {}", fun.co_domain()); println!("data_type = {}", fun.data_type()); - let set = DataType::date_value( - NaiveDate::from_ymd_opt(2023, 01, 01).unwrap(), - ); + let set = DataType::date_value(NaiveDate::from_ymd_opt(2023, 01, 01).unwrap()); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::text_value("Sunday".to_string())); let set = DataType::date_time_value( - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(9, 15, 11).unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(9, 15, 11) + .unwrap(), ); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4551,15 +4645,16 @@ mod tests { println!("co_domain = {}", fun.co_domain()); println!("data_type = {}", fun.data_type()); - let set = DataType::date_value( - NaiveDate::from_ymd_opt(2023, 01, 01).unwrap(), - ); + let set = DataType::date_value(NaiveDate::from_ymd_opt(2023, 01, 01).unwrap()); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::integer_value(1)); let set = DataType::date_time_value( - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(9, 15, 11).unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(9, 15, 11) + .unwrap(), ); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4575,15 +4670,16 @@ mod tests { println!("co_domain = {}", fun.co_domain()); println!("data_type = {}", fun.data_type()); - let set = DataType::date_value( - NaiveDate::from_ymd_opt(2023, 01, 01).unwrap(), - ); + let set = DataType::date_value(NaiveDate::from_ymd_opt(2023, 01, 01).unwrap()); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::date_value(NaiveDate::from_ymd_opt(2023, 01, 01).unwrap())); let set = DataType::date_time_value( - NaiveDate::from_ymd_opt(2026, 7, 8).unwrap().and_hms_opt(9, 15, 11).unwrap(), + NaiveDate::from_ymd_opt(2026, 7, 8) + .unwrap() + .and_hms_opt(9, 15, 11) + .unwrap(), ); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); @@ -4601,20 +4697,29 @@ mod tests { let set = DataType::structured_from_data_types([ DataType::integer_value(2), - DataType::list(DataType::text_values( - ["a".to_string(), "b".to_string(), "c".to_string()] - ), 1, 3) + DataType::list( + DataType::text_values(["a".to_string(), "b".to_string(), "c".to_string()]), + 1, + 3, + ), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::optional(DataType::Any)); let arg = Value::structured_from_values([ 2.into(), - Value::list(["a".to_string(), "b".to_string(), "c".to_string()].into_iter().map(|s| Value::from(s))) + Value::list( + ["a".to_string(), "b".to_string(), "c".to_string()] + .into_iter() + .map(|s| Value::from(s)), + ), ]); let val = fun.value(&arg).unwrap(); println!("val({}) = {}", arg, val); - assert_eq!(val, Value::Optional(value::Optional::new(Some(Arc::new("c".to_string().into()))))); + assert_eq!( + val, + Value::Optional(value::Optional::new(Some(Arc::new("c".to_string().into())))) + ); } #[test] @@ -4654,22 +4759,18 @@ mod tests { // True let set = DataType::structured_from_data_types([ DataType::boolean(), - DataType::boolean_value(true) + DataType::boolean_value(true), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::boolean()); - let arg = Value::structured_from_values([ - Value::from(true), Value::from(true) - ]); + let arg = Value::structured_from_values([Value::from(true), Value::from(true)]); let val = fun.value(&arg).unwrap(); println!("val({}) = {}", arg, val); assert_eq!(val, true.into()); - let arg = Value::structured_from_values([ - Value::from(true), Value::from(false) - ]); + let arg = Value::structured_from_values([Value::from(true), Value::from(false)]); let val = fun.value(&arg).unwrap(); println!("val({}) = {}", arg, val); assert_eq!(val, false.into()); @@ -4677,26 +4778,20 @@ mod tests { // False let set = DataType::structured_from_data_types([ DataType::boolean(), - DataType::boolean_value(false) + DataType::boolean_value(false), ]); let im = fun.super_image(&set).unwrap(); println!("im({}) = {}", set, im); assert!(im == DataType::boolean()); - let arg = Value::structured_from_values([ - Value::from(false), Value::from(false) - ]); + let arg = Value::structured_from_values([Value::from(false), Value::from(false)]); let val = fun.value(&arg).unwrap(); println!("val({}) = {}", arg, val); assert_eq!(val, true.into()); - let arg = Value::structured_from_values([ - Value::from(false), Value::from(true) - ]); + let arg = Value::structured_from_values([Value::from(false), Value::from(true)]); let val = fun.value(&arg).unwrap(); println!("val({}) = {}", arg, val); assert_eq!(val, false.into()); } - - } diff --git a/src/dialect_translation/bigquery.rs b/src/dialect_translation/bigquery.rs index 73140d32..bfe23005 100644 --- a/src/dialect_translation/bigquery.rs +++ b/src/dialect_translation/bigquery.rs @@ -6,14 +6,15 @@ use crate::{ use super::{function_builder, QueryToRelationTranslator, RelationToQueryTranslator}; use sqlparser::{ast, dialect::BigQueryDialect}; - - #[derive(Clone, Copy)] pub struct BigQueryTranslator; impl RelationToQueryTranslator for BigQueryTranslator { fn identifier(&self, value: &expr::Identifier) -> Vec { - value.iter().map(|r| ast::Ident::with_quote('`', r)).collect() + value + .iter() + .map(|r| ast::Ident::with_quote('`', r)) + .collect() } fn cte(&self, name: ast::Ident, _columns: Vec, query: ast::Query) -> ast::Cte { @@ -24,7 +25,7 @@ impl RelationToQueryTranslator for BigQueryTranslator { }, query: Box::new(query), from: None, - materialized: None + materialized: None, } } fn first(&self, expr: &expr::Expr) -> ast::Expr { @@ -46,7 +47,7 @@ impl RelationToQueryTranslator for BigQueryTranslator { function_builder("STDDEV", vec![arg], false) } /// Converting LOG to LOG10 - fn log(&self,expr: &expr::Expr) -> ast::Expr { + fn log(&self, expr: &expr::Expr) -> ast::Expr { let arg = self.expr(expr); function_builder("LOG10", vec![arg], false) } diff --git a/src/dialect_translation/mod.rs b/src/dialect_translation/mod.rs index b0bbaaa8..7f5dac20 100644 --- a/src/dialect_translation/mod.rs +++ b/src/dialect_translation/mod.rs @@ -154,7 +154,7 @@ macro_rules! relation_to_query_tranlator_trait_constructor { named_window: vec![], window_before_qualify: false, value_table_mode: None, - connect_by: None + connect_by: None, }))), order_by, limit, @@ -174,12 +174,12 @@ macro_rules! relation_to_query_tranlator_trait_constructor { global: None, if_not_exists: true, transient: false, - name: ast::ObjectName(self.identifier( &(table.path().clone().into()) )), + name: ast::ObjectName(self.identifier(&(table.path().clone().into()))), columns: table .schema() .iter() .map(|f| ast::ColumnDef { - name: self.identifier( &(f.name().into()) )[0].clone(), + name: self.identifier(&(f.name().into()))[0].clone(), data_type: f.data_type().into(), collation: None, options: if let DataType::Optional(_) = f.data_type() { @@ -222,9 +222,13 @@ macro_rules! relation_to_query_tranlator_trait_constructor { ast::Statement::Insert(ast::Insert { or: None, into: true, - table_name: ast::ObjectName(self.identifier( &(table.path().clone().into()) )), + table_name: ast::ObjectName(self.identifier(&(table.path().clone().into()))), table_alias: None, - columns: table.schema().iter().map(|f| self.identifier( &(f.name().into()) )[0].clone()).collect(), + columns: table + .schema() + .iter() + .map(|f| self.identifier(&(f.name().into()))[0].clone()) + .collect(), overwrite: false, source: Some(Box::new(ast::Query { with: None, @@ -265,10 +269,10 @@ macro_rules! relation_to_query_tranlator_trait_constructor { query: ast::Query, ) -> ast::Cte { ast::Cte { - alias: ast::TableAlias {name, columns}, + alias: ast::TableAlias { name, columns }, query: Box::new(query), from: None, - materialized: None + materialized: None, } } fn join_projection(&self, join: &Join) -> Vec { @@ -278,7 +282,10 @@ macro_rules! relation_to_query_tranlator_trait_constructor { } fn identifier(&self, value: &expr::Identifier) -> Vec { - value.iter().map(|r| ast::Ident::with_quote('"', r)).collect() + value + .iter() + .map(|r| ast::Ident::with_quote('"', r)) + .collect() } fn table_factor(&self, relation: &Relation, alias: Option<&str>) -> ast::TableFactor { @@ -784,12 +791,15 @@ pub trait QueryToRelationTranslator { match args { ast::FunctionArguments::None => Ok(vec![]), ast::FunctionArguments::Subquery(_) => Ok(vec![]), - ast::FunctionArguments::List(arg_list) => arg_list.args + ast::FunctionArguments::List(arg_list) => arg_list + .args .iter() .map(|func_arg| match func_arg { - ast::FunctionArg::Named { name: _, arg, operator: _ } => { - self.try_function_arg_expr(arg, context) - } + ast::FunctionArg::Named { + name: _, + arg, + operator: _, + } => self.try_function_arg_expr(arg, context), ast::FunctionArg::Unnamed(arg) => self.try_function_arg_expr(arg, context), }) .collect(), @@ -826,7 +836,11 @@ fn function_builder(name: &str, exprs: Vec, distinct: bool) -> ast::E .collect(); let function_name = name.to_uppercase(); let name = ast::ObjectName(vec![ast::Ident::from(&function_name[..])]); - let ast_distinct = if distinct {Some(ast::DuplicateTreatment::Distinct)} else {None}; + let ast_distinct = if distinct { + Some(ast::DuplicateTreatment::Distinct) + } else { + None + }; let func_args_list = ast::FunctionArgumentList { duplicate_treatment: ast_distinct, args: function_args, @@ -838,7 +852,7 @@ fn function_builder(name: &str, exprs: Vec, distinct: bool) -> ast::E over: None, filter: None, null_treatment: None, - within_group: vec![] + within_group: vec![], }; ast::Expr::Function(function) } diff --git a/src/dialect_translation/mssql.rs b/src/dialect_translation/mssql.rs index d7bfd9c2..28bb23cb 100644 --- a/src/dialect_translation/mssql.rs +++ b/src/dialect_translation/mssql.rs @@ -60,11 +60,11 @@ impl RelationToQueryTranslator for MsSqlTranslator { /// Converting MD5(X) to CONVERT(VARCHAR(MAX), HASHBYTES('MD5', X), 2) fn md5(&self, expr: &expr::Expr) -> ast::Expr { let ast_expr = self.expr(expr); - + // Construct HASHBYTES('MD5', X) let md5_literal = ast::Expr::Value(ast::Value::SingleQuotedString("MD5".to_string())); let md5_literal_as_function_arg = - ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(md5_literal)); + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(md5_literal)); let ast_expr_as_function_arg = ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(ast_expr)); @@ -88,7 +88,7 @@ impl RelationToQueryTranslator for MsSqlTranslator { data_type: Some(ast::DataType::Varchar(Some(ast::CharacterLength::Max))), charset: None, target_before_value: true, - styles: vec![ast::Expr::Value(ast::Value::Number("2".to_string(), false))] + styles: vec![ast::Expr::Value(ast::Value::Number("2".to_string(), false))], } } @@ -103,7 +103,10 @@ impl RelationToQueryTranslator for MsSqlTranslator { let ast_expr = self.expr(expr); ast::Expr::Cast { expr: Box::new(ast_expr), - data_type: ast::DataType::Nvarchar(Some(ast::CharacterLength::IntegerLength {length: 255, unit:None })), + data_type: ast::DataType::Nvarchar(Some(ast::CharacterLength::IntegerLength { + length: 255, + unit: None, + })), format: None, kind: ast::CastKind::Cast, } @@ -190,7 +193,7 @@ impl RelationToQueryTranslator for MsSqlTranslator { named_window: vec![], window_before_qualify: false, value_table_mode: None, - connect_by: None + connect_by: None, }))), order_by, limit: None, @@ -210,12 +213,12 @@ impl RelationToQueryTranslator for MsSqlTranslator { global: None, if_not_exists: false, transient: false, - name: ast::ObjectName(self.identifier( &(table.path().clone().into()) )), + name: ast::ObjectName(self.identifier(&(table.path().clone().into()))), columns: table .schema() .iter() .map(|f| ast::ColumnDef { - name: self.identifier( &(f.name().into()) )[0].clone(), + name: self.identifier(&(f.name().into()))[0].clone(), // Need to override some convertions data_type: { translate_data_type(f.data_type()) }, collation: None, @@ -301,7 +304,7 @@ impl QueryToRelationTranslator for MsSqlTranslator { } else { let is_first_arg_valid = is_varchar_valid(&args[0]); let is_last_arg_valid = is_literal_two_arg(&args[2]); - let extract_x_arg = extract_hashbyte_expression_if_valid(&args[1]); + let extract_x_arg = extract_hashbyte_expression_if_valid(&args[1]); if is_first_arg_valid && is_last_arg_valid && extract_x_arg.is_some() { let function_args = ast::FunctionArgumentList { duplicate_treatment: None, @@ -379,7 +382,10 @@ fn extract_hashbyte_expression_if_valid(func_arg: &ast::FunctionArg) -> Option ast::DataType fn translate_data_type(dtype: DataType) -> ast::DataType { match dtype { - DataType::Text(_) => ast::DataType::Nvarchar(Some(ast::CharacterLength::IntegerLength { length: 255, unit: None})), + DataType::Text(_) => ast::DataType::Nvarchar(Some(ast::CharacterLength::IntegerLength { + length: 255, + unit: None, + })), //DataType::Boolean(_) => Boolean should be displayed as BIT for MSSQL, // SQLParser doesn't support the BIT DataType (mssql equivalent of bool) DataType::Optional(o) => translate_data_type(o.data_type().clone()), diff --git a/src/dialect_translation/postgresql.rs b/src/dialect_translation/postgresql.rs index 3a72dbcf..c7e7310a 100644 --- a/src/dialect_translation/postgresql.rs +++ b/src/dialect_translation/postgresql.rs @@ -9,7 +9,9 @@ use crate::{ Relation, }; -use super::{function_builder, RelationWithTranslator, QueryToRelationTranslator, RelationToQueryTranslator}; +use super::{ + function_builder, QueryToRelationTranslator, RelationToQueryTranslator, RelationWithTranslator, +}; use sqlparser::{ast, dialect::PostgreSqlDialect}; use crate::sql::{Error, Result}; @@ -105,7 +107,14 @@ mod tests { use super::*; use crate::{ - builder::{Ready, With}, data_type::{DataType, Value as _}, display::Dot, expr::Expr, io::{postgresql, Database as _}, namer, relation::{schema::Schema, Relation, TableBuilder}, sql::{parse, relation::QueryWithRelations} + builder::{Ready, With}, + data_type::{DataType, Value as _}, + display::Dot, + expr::Expr, + io::{postgresql, Database as _}, + namer, + relation::{schema::Schema, Relation, TableBuilder}, + sql::{parse, relation::QueryWithRelations}, }; use std::sync::Arc; @@ -156,7 +165,8 @@ mod tests { fn test_table_special() -> Result<()> { let mut database = postgresql::test_database(); let relations = database.relations(); - let query_str = r#"SELECT "Id", NORMAL_COL, "Na.Me" FROM "MY SPECIAL TABLE" ORDER BY "Id" "#; + let query_str = + r#"SELECT "Id", NORMAL_COL, "Na.Me" FROM "MY SPECIAL TABLE" ORDER BY "Id" "#; let translator = PostgreSqlTranslator; let query = parse_with_dialect(query_str, translator.dialect())?; let query_with_relation = QueryWithRelations::new(&query, &relations); diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index c562e998..f81fc589 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -3,11 +3,13 @@ use crate::{ data_type::DataTyped, differential_privacy::dp_event::DpEvent, differential_privacy::{dp_event, DpRelation, Error, Result}, - expr::{aggregate::{self, Aggregate}, AggregateColumn, Expr, Column, Identifier}, + expr::{ + aggregate::{self, Aggregate}, + AggregateColumn, Column, Expr, Identifier, + }, privacy_unit_tracking::PupRelation, relation::{field::Field, Map, Reduce, Relation, Variant}, DataType, Ready, - }; use std::{cmp, collections::HashMap, ops::Deref}; @@ -28,7 +30,14 @@ pub struct DpAggregatesParameters { } impl DpAggregatesParameters { - pub fn new(epsilon: f64, delta: f64, size: usize, privacy_unit_unique: bool, privacy_unit_max_multiplicity: f64, privacy_unit_max_multiplicity_share: f64) -> DpAggregatesParameters { + pub fn new( + epsilon: f64, + delta: f64, + size: usize, + privacy_unit_unique: bool, + privacy_unit_max_multiplicity: f64, + privacy_unit_max_multiplicity_share: f64, + ) -> DpAggregatesParameters { DpAggregatesParameters { epsilon, delta, @@ -40,7 +49,14 @@ impl DpAggregatesParameters { } pub fn from_dp_parameters(dp_parameters: DpParameters, share: f64) -> DpAggregatesParameters { - DpAggregatesParameters::new(dp_parameters.epsilon*share, dp_parameters.delta*share, 1, false, dp_parameters.privacy_unit_max_multiplicity, dp_parameters.privacy_unit_max_multiplicity_share) + DpAggregatesParameters::new( + dp_parameters.epsilon * share, + dp_parameters.delta * share, + 1, + false, + dp_parameters.privacy_unit_max_multiplicity, + dp_parameters.privacy_unit_max_multiplicity_share, + ) } pub fn split(self, n: usize) -> DpAggregatesParameters { @@ -59,7 +75,10 @@ impl DpAggregatesParameters { } pub fn with_privacy_unit_unique(self, unique_privacy_unit: bool) -> DpAggregatesParameters { - DpAggregatesParameters { privacy_unit_unique: unique_privacy_unit, ..self } + DpAggregatesParameters { + privacy_unit_unique: unique_privacy_unit, + ..self + } } /// Compute the multiplicity estimate to use for the computations @@ -67,14 +86,16 @@ impl DpAggregatesParameters { if self.privacy_unit_unique { 1. } else { - self.privacy_unit_max_multiplicity.min((self.size as f64)*self.privacy_unit_max_multiplicity_share).ceil() + self.privacy_unit_max_multiplicity + .min((self.size as f64) * self.privacy_unit_max_multiplicity_share) + .ceil() } } } impl Relation { fn gaussian_mechanisms(self, epsilon: f64, delta: f64, bounds: Vec<(&str, f64)>) -> DpRelation { - if epsilon>1. { + if epsilon > 1. { // Cf. Theorem A.1. in (Dwork, Roth et al. 2014) log::warn!("Warning, epsilon>1 the gaussian mechanism applied will not be exactly epsilon,delta-DP!") } @@ -100,15 +121,12 @@ impl Relation { .into(); ( self.add_clipped_gaussian_noise(&noise_multipliers), - dp_event + dp_event, ) } else { (self, DpEvent::no_op()) }; - DpRelation::new( - dp_relation, - dp_event, - ) + DpRelation::new(dp_relation, dp_event) } } @@ -126,8 +144,7 @@ impl PupRelation { if (parameters.epsilon == 0. || parameters.delta == 0.) && !named_sums.is_empty() { return Err(Error::BudgetError(format!( "Not enough budget for the aggregations. Got: (espilon, delta) = ({}, {})", - parameters.epsilon, - parameters.delta, + parameters.epsilon, parameters.delta, ))); } // let multiplicity_bound = parameters.clipping_quantile // TODO @@ -142,7 +159,8 @@ impl PupRelation { .absolute_upper_bound() .unwrap_or(1.0) // This may add a lot of noise depending on the parameters - * parameters.privacy_unit_multiplicity()).clamp(f64::MIN, f64::MAX), + * parameters.privacy_unit_multiplicity()) + .clamp(f64::MIN, f64::MAX), ) }) .collect::>(); @@ -173,13 +191,10 @@ impl PupRelation { let mut output_builder = Map::builder(); let mut named_sums = vec![]; let mut input_builder = Map::builder() - .with(( - self.privacy_unit(), - Expr::col(self.privacy_unit()) - )) + .with((self.privacy_unit(), Expr::col(self.privacy_unit()))) .with(( self.privacy_unit_weight(), - Expr::col(self.privacy_unit_weight()) + Expr::col(self.privacy_unit_weight()), )); let mut group_by_names = vec![]; @@ -202,20 +217,27 @@ impl PupRelation { let square_col = format!("_SQUARE_{}", col_name); let sum_square_col = format!("_SUM{}", square_col); match aggregate.aggregate() { - Aggregate::Min | - Aggregate::Max | - Aggregate::Median | - Aggregate::First | - Aggregate::Last | - Aggregate::Quantile(_) | - Aggregate::Quantiles(_) => { + Aggregate::Min + | Aggregate::Max + | Aggregate::Median + | Aggregate::First + | Aggregate::Last + | Aggregate::Quantile(_) + | Aggregate::Quantiles(_) => { assert!(group_by_names.contains(&col_name.as_str())); output_b = output_b.with((name, Expr::col(col_name.as_str()))) - }, + } aggregate::Aggregate::Mean => { input_b = input_b .with((col_name.as_str(), Expr::col(col_name.as_str()))) - .with((one_col.as_str(), Expr::case(Expr::is_null(Expr::col(col_name.as_str())), Expr::val(0.), Expr::val(1.)))); + .with(( + one_col.as_str(), + Expr::case( + Expr::is_null(Expr::col(col_name.as_str())), + Expr::val(0.), + Expr::val(1.), + ), + )); sums.push((count_col.clone(), one_col)); sums.push((sum_col.clone(), col_name)); output_b = output_b.with(( @@ -227,9 +249,17 @@ impl PupRelation { )) } aggregate::Aggregate::Count => { - input_b = input_b.with((one_col.as_str(), Expr::case(Expr::is_null(Expr::col(col_name.as_str())), Expr::val(0.), Expr::val(1.)))); + input_b = input_b.with(( + one_col.as_str(), + Expr::case( + Expr::is_null(Expr::col(col_name.as_str())), + Expr::val(0.), + Expr::val(1.), + ), + )); sums.push((count_col.clone(), one_col)); - output_b = output_b.with((name, Expr::cast_as_integer(Expr::col(count_col)))); + output_b = + output_b.with((name, Expr::cast_as_integer(Expr::col(count_col)))); } aggregate::Aggregate::Sum => { input_b = input_b.with((col_name.as_str(), Expr::col(col_name.as_str()))); @@ -239,8 +269,18 @@ impl PupRelation { aggregate::Aggregate::Std => { input_b = input_b .with((col_name.as_str(), Expr::col(col_name.as_str()))) - .with((square_col.as_str(), Expr::pow(Expr::col(col_name.as_str()), Expr::val(2)))) - .with((one_col.as_str(), Expr::case(Expr::is_null(Expr::col(col_name.as_str())), Expr::val(0.), Expr::val(1.)))); + .with(( + square_col.as_str(), + Expr::pow(Expr::col(col_name.as_str()), Expr::val(2)), + )) + .with(( + one_col.as_str(), + Expr::case( + Expr::is_null(Expr::col(col_name.as_str())), + Expr::val(0.), + Expr::val(1.), + ), + )); sums.push((count_col.clone(), one_col)); sums.push((sum_col.clone(), col_name)); sums.push((sum_square_col.clone(), square_col)); @@ -257,15 +297,25 @@ impl PupRelation { Expr::col(sum_col), Expr::greatest(Expr::val(1.), Expr::col(count_col)), ), - ) - )) + ), + )), )) } aggregate::Aggregate::Var => { input_b = input_b .with((col_name.as_str(), Expr::col(col_name.as_str()))) - .with((square_col.as_str(), Expr::pow(Expr::col(col_name.as_str()), Expr::val(2)))) - .with((one_col.as_str(), Expr::case(Expr::is_null(Expr::col(col_name.as_str())), Expr::val(0.), Expr::val(1.)))); + .with(( + square_col.as_str(), + Expr::pow(Expr::col(col_name.as_str()), Expr::val(2)), + )) + .with(( + one_col.as_str(), + Expr::case( + Expr::is_null(Expr::col(col_name.as_str())), + Expr::val(0.), + Expr::val(1.), + ), + )); sums.push((count_col.clone(), one_col)); sums.push((sum_col.clone(), col_name)); sums.push((sum_square_col.clone(), square_col)); @@ -282,8 +332,8 @@ impl PupRelation { Expr::col(sum_col), Expr::greatest(Expr::val(1.), Expr::col(count_col)), ), - ) - ) + ), + ), )) } _ => (), @@ -304,9 +354,7 @@ impl PupRelation { parameters, )? .into(); - let dp_relation = output_builder - .input(dp_relation) - .build(); + let dp_relation = output_builder.input(dp_relation).build(); Ok(DpRelation::new(dp_relation, dp_event)) } } @@ -322,21 +370,26 @@ impl Reduce { let reduces = self.split_distinct_aggregates(); let split_parameters = parameters.clone().split(reduces.len()); // Rewrite into differential privacy each `Reduce` then join them. - let (relation, dp_event) = reduces.iter() - .map(|r| pup_input.clone().differentially_private_aggregates( - r.named_aggregates() - .into_iter() - .map(|(n, agg)| (n, agg.clone())) - .collect(), - self.group_by(), - split_parameters.clone(), - )) + let (relation, dp_event) = reduces + .iter() + .map(|r| { + pup_input.clone().differentially_private_aggregates( + r.named_aggregates() + .into_iter() + .map(|(n, agg)| (n, agg.clone())) + .collect(), + self.group_by(), + split_parameters.clone(), + ) + }) .reduce(|acc, dp_rel| { let acc = acc?; let dp_rel = dp_rel?; Ok(DpRelation::new( - acc.relation().clone().natural_inner_join(dp_rel.relation().clone()), - acc.dp_event().clone().compose(dp_rel.dp_event().clone()) + acc.relation() + .clone() + .natural_inner_join(dp_rel.relation().clone()), + acc.dp_event().clone().compose(dp_rel.dp_event().clone()), )) }) .unwrap()? @@ -344,17 +397,21 @@ impl Reduce { let relation: Relation = Relation::map() .input(relation) - .with_iter(self.fields().into_iter().map(|f| (f.name(), Expr::col(f.name())))) + .with_iter( + self.fields() + .into_iter() + .map(|f| (f.name(), Expr::col(f.name()))), + ) .build(); Ok((relation, dp_event).into()) } - /// Returns a Vec of rewritten `Reduce` whose each item corresponds to a specific `DISTINCT` clause /// (e.g.: SUM(DISTINCT a) or COUNT(DISTINCT a) have the same `DISTINCT` clause). The original `Reduce`` /// has been rewritten with `GROUP BY`s for each `DISTINCT` clause. fn split_distinct_aggregates(&self) -> Vec { - let mut distinct_map: HashMap, Vec<(String, AggregateColumn)>> = HashMap::new(); + let mut distinct_map: HashMap, Vec<(String, AggregateColumn)>> = + HashMap::new(); let mut first_aggs: Vec<(String, AggregateColumn)> = vec![]; for (agg, f) in self.aggregate().iter().zip(self.fields()) { match agg.aggregate() { @@ -362,9 +419,15 @@ impl Reduce { | aggregate::Aggregate::SumDistinct | aggregate::Aggregate::MeanDistinct | aggregate::Aggregate::VarDistinct - | aggregate::Aggregate::StdDistinct => distinct_map.entry(Some(agg.column().clone())).or_insert(Vec::new()).push((f.name().to_string(), agg.clone())), + | aggregate::Aggregate::StdDistinct => distinct_map + .entry(Some(agg.column().clone())) + .or_insert(Vec::new()) + .push((f.name().to_string(), agg.clone())), aggregate::Aggregate::First => first_aggs.push((f.name().to_string(), agg.clone())), - _ => distinct_map.entry(None).or_insert(Vec::new()).push((f.name().to_string(), agg.clone())), + _ => distinct_map + .entry(None) + .or_insert(Vec::new()) + .push((f.name().to_string(), agg.clone())), } } @@ -373,11 +436,17 @@ impl Reduce { } else { first_aggs.extend( self.group_by() - .into_iter() - .map(|x| (x.to_string(), AggregateColumn::new(aggregate::Aggregate::First, x.clone()))) - .collect::>() + .into_iter() + .map(|x| { + ( + x.to_string(), + AggregateColumn::new(aggregate::Aggregate::First, x.clone()), + ) + }) + .collect::>(), ); - distinct_map.into_iter() + distinct_map + .into_iter() .map(|(identifier, mut aggs)| { aggs.extend(first_aggs.clone()); self.rewrite_distinct(identifier, aggs) @@ -401,21 +470,29 @@ impl Reduce { /// Example 2 : /// (SELECT sum(DISTINCT col1), count(*) FROM table GROUP BY a, None, ("my_count", count(*))) /// --> SELECT a AS a, count(*) AS my_count FROM table GROUP BY a - fn rewrite_distinct(&self, identifier: Option, aggs: Vec<(String, AggregateColumn)>) -> Reduce { - let builder = Relation::reduce() - .input(self.input().clone()); + fn rewrite_distinct( + &self, + identifier: Option, + aggs: Vec<(String, AggregateColumn)>, + ) -> Reduce { + let builder = Relation::reduce().input(self.input().clone()); if let Some(identifier) = identifier { - let mut group_by = self.group_by() + let mut group_by = self + .group_by() .into_iter() .map(|c| c.clone()) .collect::>(); group_by.push(identifier); - let first_aggs = group_by.clone() - .into_iter() - .map(|c| (c.to_string(), AggregateColumn::new(aggregate::Aggregate::First, c))); + let first_aggs = group_by.clone().into_iter().map(|c| { + ( + c.to_string(), + AggregateColumn::new(aggregate::Aggregate::First, c), + ) + }); - let group_by = group_by.into_iter() + let group_by = group_by + .into_iter() .map(|c| Expr::from(c.clone())) .collect::>(); @@ -424,19 +501,18 @@ impl Reduce { .with_iter(first_aggs) .build(); - let aggs = aggs.into_iter() - .map(|(s, agg)| { - let new_agg = match agg.aggregate() { - aggregate::Aggregate::MeanDistinct => aggregate::Aggregate::Mean, - aggregate::Aggregate::CountDistinct => aggregate::Aggregate::Count, - aggregate::Aggregate::SumDistinct => aggregate::Aggregate::Sum, - aggregate::Aggregate::StdDistinct => aggregate::Aggregate::Std, - aggregate::Aggregate::VarDistinct => aggregate::Aggregate::Var, - aggregate::Aggregate::First => aggregate::Aggregate::First, - _ => todo!(), - }; - (s, AggregateColumn::new(new_agg, agg.column().clone())) - }); + let aggs = aggs.into_iter().map(|(s, agg)| { + let new_agg = match agg.aggregate() { + aggregate::Aggregate::MeanDistinct => aggregate::Aggregate::Mean, + aggregate::Aggregate::CountDistinct => aggregate::Aggregate::Count, + aggregate::Aggregate::SumDistinct => aggregate::Aggregate::Sum, + aggregate::Aggregate::StdDistinct => aggregate::Aggregate::Std, + aggregate::Aggregate::VarDistinct => aggregate::Aggregate::Var, + aggregate::Aggregate::First => aggregate::Aggregate::First, + _ => todo!(), + }; + (s, AggregateColumn::new(new_agg, agg.column().clone())) + }); Relation::reduce() .input(reduce) .group_by_iter(self.group_by().to_vec()) @@ -444,9 +520,9 @@ impl Reduce { .build() } else { builder - .group_by_iter(self.group_by().clone().to_vec()) - .with_iter(aggs) - .build() + .group_by_iter(self.group_by().clone().to_vec()) + .with_iter(aggs) + .build() } } } @@ -460,13 +536,13 @@ mod tests { data_type::Variant, display::Dot, io::{postgresql, Database}, + privacy_unit_tracking::PrivacyUnit, privacy_unit_tracking::{PrivacyUnitTracking, Strategy}, + relation::{Constraint, Schema, Variant as _}, sql::parse, Relation, - relation::{Constraint, Schema, Variant as _}, - privacy_unit_tracking::PrivacyUnit }; - use std::{sync::Arc, ops::Deref}; + use std::{ops::Deref, sync::Arc}; #[test] fn test_table_with_noise() { @@ -494,7 +570,10 @@ mod tests { fn test_differentially_private_sums_no_group_by() { let mut database = postgresql::test_database(); let relations = database.relations(); - let parameters = DpAggregatesParameters::from_dp_parameters(DpParameters::from_epsilon_delta(1., 1e-3), 1.); + let parameters = DpAggregatesParameters::from_dp_parameters( + DpParameters::from_epsilon_delta(1., 1e-3), + 1., + ); // privacy tracking of the inputs let table = relations .get(&["item_table".to_string()]) @@ -554,7 +633,7 @@ mod tests { .unwrap(); let map = Map::new( "my_map".to_string(), - vec![("my_d".to_string(), expr!(d/100))], + vec![("my_d".to_string(), expr!(d / 100))], None, vec![], None, @@ -564,7 +643,7 @@ mod tests { let pup_map = privacy_unit_tracking .map( &map.clone().try_into().unwrap(), - PupRelation(Relation::from(pup_table)) + PupRelation(Relation::from(pup_table)), ) .unwrap(); let reduce = Reduce::new( @@ -598,7 +677,10 @@ mod tests { .unwrap() .deref() .clone(); - let parameters = DpAggregatesParameters::from_dp_parameters(DpParameters::from_epsilon_delta(1., 1e-3), 1.); + let parameters = DpAggregatesParameters::from_dp_parameters( + DpParameters::from_epsilon_delta(1., 1e-3), + 1., + ); // privacy tracking of the inputs let privacy_unit_tracking = PrivacyUnitTracking::from(( &relations, @@ -662,7 +744,10 @@ mod tests { ) .size(100) .build(); - let parameters = DpAggregatesParameters::from_dp_parameters(DpParameters::from_epsilon_delta(1., 1e-3), 1.); + let parameters = DpAggregatesParameters::from_dp_parameters( + DpParameters::from_epsilon_delta(1., 1e-3), + 1., + ); // GROUP BY and the aggregate input the same column let reduce: Reduce = Relation::reduce() @@ -700,7 +785,10 @@ mod tests { .unwrap() .deref() .clone(); - let parameters = DpAggregatesParameters::from_dp_parameters(DpParameters::from_epsilon_delta(1., 1e-3), 1.); + let parameters = DpAggregatesParameters::from_dp_parameters( + DpParameters::from_epsilon_delta(1., 1e-3), + 1., + ); // privacy tracking of the inputs let privacy_unit_tracking = PrivacyUnitTracking::from(( @@ -762,7 +850,10 @@ mod tests { .unwrap() .deref() .clone(); - let parameters = DpAggregatesParameters::from_dp_parameters(DpParameters::from_epsilon_delta(1., 1e-3), 1.); + let parameters = DpAggregatesParameters::from_dp_parameters( + DpParameters::from_epsilon_delta(1., 1e-3), + 1., + ); // privacy tracking of the inputs let privacy_unit_tracking = PrivacyUnitTracking::from(( @@ -862,7 +953,10 @@ mod tests { .unwrap() .deref() .clone(); - let parameters = DpAggregatesParameters::from_dp_parameters(DpParameters::from_epsilon_delta(1., 1e-3), 1.); + let parameters = DpAggregatesParameters::from_dp_parameters( + DpParameters::from_epsilon_delta(1., 1e-3), + 1., + ); let privacy_unit_tracking = PrivacyUnitTracking::from(( &relations, vec![ @@ -884,9 +978,15 @@ mod tests { "my_reduce".to_string(), vec![ ("count_price".to_string(), AggregateColumn::count("price")), - ("count_distinct_price".to_string(), AggregateColumn::count_distinct("price")), + ( + "count_distinct_price".to_string(), + AggregateColumn::count_distinct("price"), + ), ("sum_price".to_string(), AggregateColumn::sum("price")), - ("sum_distinct_price".to_string(), AggregateColumn::sum_distinct("price")), + ( + "sum_distinct_price".to_string(), + AggregateColumn::sum_distinct("price"), + ), ("item".to_string(), AggregateColumn::first("item")), ], vec!["item".into()], @@ -921,9 +1021,15 @@ mod tests { "my_reduce".to_string(), vec![ ("count_price".to_string(), AggregateColumn::count("price")), - ("count_distinct_price".to_string(), AggregateColumn::count_distinct("price")), + ( + "count_distinct_price".to_string(), + AggregateColumn::count_distinct("price"), + ), ("sum_price".to_string(), AggregateColumn::sum("price")), - ("sum_distinct_price".to_string(), AggregateColumn::sum_distinct("price")), + ( + "sum_distinct_price".to_string(), + AggregateColumn::sum_distinct("price"), + ), ], vec![], pup_table.deref().clone().into(), @@ -971,24 +1077,22 @@ mod tests { // No distinct + no group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_a", AggregateColumn::sum("a"))) - .build(); + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .build(); let reduces = reduce.split_distinct_aggregates(); assert_eq!(reduces.len(), 1); assert_eq!( reduces[0].data_type(), - DataType::structured([ - ("sum_a", DataType::float_interval(-2000., 2000.)) - ]) + DataType::structured([("sum_a", DataType::float_interval(-2000., 2000.))]) ); // No distinct + group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_a", AggregateColumn::sum("a"))) - .group_by(expr!(b)) - .build(); + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .group_by(expr!(b)) + .build(); let reduces = reduce.split_distinct_aggregates(); assert_eq!(reduces.len(), 1); assert_eq!( @@ -1001,25 +1105,23 @@ mod tests { // simple distinct let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) - .build(); + .input(table.clone()) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .build(); let reduces = reduce.split_distinct_aggregates(); assert_eq!(reduces.len(), 1); Relation::from(reduces[0].clone()).display_dot().unwrap(); assert_eq!( reduces[0].data_type(), - DataType::structured([ - ("sum_distinct_a", DataType::float_interval(-2000., 2000.)) - ]) + DataType::structured([("sum_distinct_a", DataType::float_interval(-2000., 2000.))]) ); // simple distinct with group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) - .group_by(expr!(b)) - .build(); + .input(table.clone()) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .group_by(expr!(b)) + .build(); let reduces = reduce.split_distinct_aggregates(); assert_eq!(reduces.len(), 1); Relation::from(reduces[0].clone()).display_dot().unwrap(); @@ -1033,10 +1135,10 @@ mod tests { // simple distinct with group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) - .with_group_by_column("b") - .build(); + .input(table.clone()) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with_group_by_column("b") + .build(); let reduces = reduce.split_distinct_aggregates(); assert_eq!(reduces.len(), 1); Relation::from(reduces[0].clone()).display_dot().unwrap(); @@ -1050,12 +1152,12 @@ mod tests { // multi distinct + no group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_a", AggregateColumn::sum("a"))) - .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) - .with(("count_b", AggregateColumn::count("b"))) - .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) - .build(); + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with(("count_b", AggregateColumn::count("b"))) + .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) + .build(); let reduces = reduce.split_distinct_aggregates(); assert_eq!(reduces.len(), 3); Relation::from(reduces[0].clone()).display_dot().unwrap(); @@ -1064,14 +1166,14 @@ mod tests { // multi distinct + group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_a", AggregateColumn::sum("a"))) - .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) - .with(("count_b", AggregateColumn::count("b"))) - .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) - .with(("my_c", AggregateColumn::first("c"))) - .group_by(expr!(c)) - .build(); + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with(("count_b", AggregateColumn::count("b"))) + .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) + .with(("my_c", AggregateColumn::first("c"))) + .group_by(expr!(c)) + .build(); let reduces = reduce.split_distinct_aggregates(); assert_eq!(reduces.len(), 3); Relation::from(reduces[0].clone()).display_dot().unwrap(); @@ -1080,10 +1182,10 @@ mod tests { // reduce without any aggregation let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with_group_by_column("a") - .with_group_by_column("c") - .build(); + .input(table.clone()) + .with_group_by_column("a") + .with_group_by_column("c") + .build(); let reduces = reduce.split_distinct_aggregates(); assert_eq!(reduces.len(), 1); } @@ -1104,14 +1206,19 @@ mod tests { .schema(schema.clone()) .size(1000) .build(); - let parameters = DpAggregatesParameters::from_dp_parameters(DpParameters::from_epsilon_delta(1., 1e-5), 1.); + let parameters = DpAggregatesParameters::from_dp_parameters( + DpParameters::from_epsilon_delta(1., 1e-5), + 1., + ); // No distinct + no group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_a", AggregateColumn::sum("a"))) - .build(); - let dp_relation = reduce.differentially_private_aggregates(parameters.clone()).unwrap(); + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .build(); + let dp_relation = reduce + .differentially_private_aggregates(parameters.clone()) + .unwrap(); assert_eq!( dp_relation.dp_event(), &DpEvent::gaussian_from_epsilon_delta_sensitivity( @@ -1122,18 +1229,18 @@ mod tests { ); assert_eq!( dp_relation.relation().data_type(), - DataType::structured([ - ("sum_a", DataType::float_interval(-2000., 2000.)) - ]) + DataType::structured([("sum_a", DataType::float_interval(-2000., 2000.))]) ); // No distinct + group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_a", AggregateColumn::sum("a"))) - .group_by(expr!(b)) - .build(); - let dp_relation = reduce.differentially_private_aggregates(parameters.clone()).unwrap(); + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .group_by(expr!(b)) + .build(); + let dp_relation = reduce + .differentially_private_aggregates(parameters.clone()) + .unwrap(); assert_eq!( dp_relation.dp_event(), &DpEvent::gaussian_from_epsilon_delta_sensitivity( @@ -1144,17 +1251,17 @@ mod tests { ); assert_eq!( dp_relation.relation().data_type(), - DataType::structured([ - ("sum_a", DataType::float_interval(-2000., 2000.)) - ]) + DataType::structured([("sum_a", DataType::float_interval(-2000., 2000.))]) ); // simple distinct let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) - .build(); - let dp_relation = reduce.differentially_private_aggregates(parameters.clone()).unwrap(); + .input(table.clone()) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .build(); + let dp_relation = reduce + .differentially_private_aggregates(parameters.clone()) + .unwrap(); //dp_relation.relation().display_dot().unwrap(); assert_eq!( dp_relation.dp_event(), @@ -1166,18 +1273,18 @@ mod tests { ); assert_eq!( dp_relation.relation().data_type(), - DataType::structured([ - ("sum_distinct_a", DataType::float_interval(-2000., 2000.)) - ]) + DataType::structured([("sum_distinct_a", DataType::float_interval(-2000., 2000.))]) ); // simple distinct with group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) - .with_group_by_column("b") - .build(); - let dp_relation = reduce.differentially_private_aggregates(parameters.clone()).unwrap(); + .input(table.clone()) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with_group_by_column("b") + .build(); + let dp_relation = reduce + .differentially_private_aggregates(parameters.clone()) + .unwrap(); //dp_relation.relation().display_dot().unwrap(); assert_eq!( dp_relation.dp_event(), @@ -1197,16 +1304,18 @@ mod tests { // multi distinct + no group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_a", AggregateColumn::sum("a"))) - .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) - .with(("count_b", AggregateColumn::count("b"))) - .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) - .with(("avg_distinct_b", AggregateColumn::mean_distinct("b"))) - .with(("var_distinct_b", AggregateColumn::var_distinct("b"))) - .with(("std_distinct_b", AggregateColumn::std_distinct("b"))) - .build(); - let dp_relation = reduce.differentially_private_aggregates(parameters.clone()).unwrap(); + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with(("count_b", AggregateColumn::count("b"))) + .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) + .with(("avg_distinct_b", AggregateColumn::mean_distinct("b"))) + .with(("var_distinct_b", AggregateColumn::var_distinct("b"))) + .with(("std_distinct_b", AggregateColumn::std_distinct("b"))) + .build(); + let dp_relation = reduce + .differentially_private_aggregates(parameters.clone()) + .unwrap(); dp_relation.relation().display_dot().unwrap(); assert_eq!( dp_relation.relation().data_type(), @@ -1217,24 +1326,29 @@ mod tests { ("count_distinct_b", DataType::integer_interval(0, 1000)), ("avg_distinct_b", DataType::float_interval(0., 10000.)), ("var_distinct_b", DataType::float_interval(0., 100000.)), - ("std_distinct_b", DataType::float_interval(0., 316.22776601683796)), + ( + "std_distinct_b", + DataType::float_interval(0., 316.22776601683796) + ), ]) ); // multi distinct + group by let reduce: Reduce = Relation::reduce() - .input(table.clone()) - .with(("sum_a", AggregateColumn::sum("a"))) - .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) - .with(("count_b", AggregateColumn::count("b"))) - .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) - .with(("my_c", AggregateColumn::first("c"))) - .with(("avg_distinct_b", AggregateColumn::mean_distinct("b"))) - .with(("var_distinct_b", AggregateColumn::var_distinct("b"))) - .with(("std_distinct_b", AggregateColumn::std_distinct("b"))) - .group_by(expr!(c)) - .build(); - let dp_relation = reduce.differentially_private_aggregates(parameters.clone()).unwrap(); + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with(("count_b", AggregateColumn::count("b"))) + .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) + .with(("my_c", AggregateColumn::first("c"))) + .with(("avg_distinct_b", AggregateColumn::mean_distinct("b"))) + .with(("var_distinct_b", AggregateColumn::var_distinct("b"))) + .with(("std_distinct_b", AggregateColumn::std_distinct("b"))) + .group_by(expr!(c)) + .build(); + let dp_relation = reduce + .differentially_private_aggregates(parameters.clone()) + .unwrap(); dp_relation.relation().display_dot().unwrap(); assert_eq!( dp_relation.relation().data_type(), @@ -1246,7 +1360,10 @@ mod tests { ("my_c", DataType::float_interval(10., 20.)), ("avg_distinct_b", DataType::float_interval(0., 10000.)), ("var_distinct_b", DataType::float_interval(0., 100000.)), - ("std_distinct_b", DataType::float_interval(0., 316.22776601683796)), + ( + "std_distinct_b", + DataType::float_interval(0., 316.22776601683796) + ), ]) ); } @@ -1272,7 +1389,10 @@ mod tests { ) .size(100) .build(); - let parameters = DpAggregatesParameters::from_dp_parameters(DpParameters::from_epsilon_delta(1., 1e-3), 1.); + let parameters = DpAggregatesParameters::from_dp_parameters( + DpParameters::from_epsilon_delta(1., 1e-3), + 1., + ); // GROUP BY and the aggregate input the same column let reduce: Reduce = Relation::reduce() .name("reduce_relation") @@ -1281,9 +1401,9 @@ mod tests { .input(table.clone()) .build(); let (dp_relation, dp_event) = reduce - .differentially_private_aggregates(parameters.clone()) - .unwrap() - .into(); + .differentially_private_aggregates(parameters.clone()) + .unwrap() + .into(); dp_relation.display_dot().unwrap(); assert_eq!( dp_event, @@ -1297,11 +1417,10 @@ mod tests { dp_relation.data_type(), DataType::structured([ ("sum_a", DataType::float_interval(0., 1000.)), - ("a", DataType::integer_range(1..=10) - )]) + ("a", DataType::integer_range(1..=10)) + ]) ); - let reduce: Reduce = Relation::reduce() .name("reduce_relation") .with(("sum_a".to_string(), AggregateColumn::sum("a"))) @@ -1309,9 +1428,9 @@ mod tests { .input(table.clone()) .build(); let (dp_relation, dp_event) = reduce - .differentially_private_aggregates(parameters.clone()) - .unwrap() - .into(); + .differentially_private_aggregates(parameters.clone()) + .unwrap() + .into(); dp_relation.display_dot().unwrap(); assert_eq!( dp_event, diff --git a/src/differential_privacy/dp_event.rs b/src/differential_privacy/dp_event.rs index 157580e0..a2e7985a 100644 --- a/src/differential_privacy/dp_event.rs +++ b/src/differential_privacy/dp_event.rs @@ -10,44 +10,35 @@ use std::fmt; #[derive(Clone, Debug, PartialEq)] pub enum DpEvent { /// Represents application of an operation with no privacy impact. - /// + /// /// A `NoOp` is generally never required, but it can be useful as a /// placeholder where a `DpEvent` is expected, such as in tests or some live /// accounting pipelines. NoOp, /// Represents an application of the Gaussian mechanism. - /// + /// /// For values v_i and noise z ~ N(0, s^2I), this mechanism returns sum_i v_i + z. /// If the norms of the values are bounded ||v_i|| <= C, the noise_multiplier is /// defined as s / C. - Gaussian { - noise_multiplier: f64, - }, + Gaussian { noise_multiplier: f64 }, /// Represents an application of the Laplace mechanism. - /// + /// /// For values v_i and noise z sampled coordinate-wise from the Laplace /// distribution L(0, s), this mechanism returns sum_i v_i + z. /// The probability density function of the Laplace distribution L(0, s) with /// parameter s is given as exp(-|x|/s) * (0.5/s) at x for any real value x. /// If the L_1 norm of the values are bounded ||v_i||_1 <= C, the noise_multiplier /// is defined as s / C. - Laplace{ - noise_multiplier: f64, - }, + Laplace { noise_multiplier: f64 }, /// Represents the application of a mechanism which is epsilon-delta approximate DP - EpsilonDelta { - epsilon: f64, - delta: f64, - }, + EpsilonDelta { epsilon: f64, delta: f64 }, /// Represents application of a series of composed mechanisms. - /// + /// /// The composition may be adaptive, where the query producing each event depends /// on the results of prior queries. - Composed { - events: Vec, - }, + Composed { events: Vec }, /// Represents an application of Poisson subsampling. - /// + /// /// Each record in the dataset is included in the sample independently with /// probability `sampling_probability`. Then the `DpEvent` `event` is applied /// to the sample of records. @@ -56,7 +47,7 @@ pub enum DpEvent { event: Box, }, /// Represents sampling a fixed sized batch of records with replacement. - /// + /// /// A sample of `sample_size` (possibly repeated) records is drawn uniformly at /// random from the set of possible samples of a source dataset of size /// `source_dataset_size`. Then the `DpEvent` `event` is applied to the sample of @@ -67,7 +58,7 @@ pub enum DpEvent { event: Box, }, /// Represents sampling a fixed sized batch of records without replacement. - /// + /// /// A sample of `sample_size` unique records is drawn uniformly at random from the /// set of possible samples of a source dataset of size `source_dataset_size`. /// Then the `DpEvent` `event` is applied to the sample of records. @@ -106,21 +97,25 @@ impl DpEvent { other } else { let (v1, v2) = match (self, other) { - (DpEvent::Composed {events: v1}, DpEvent::Composed {events: v2}) => (v1, v2), - (DpEvent::Composed {events: v}, other) => (v, vec![other]), - (current, DpEvent::Composed {events: v}) => (vec![current], v), + (DpEvent::Composed { events: v1 }, DpEvent::Composed { events: v2 }) => (v1, v2), + (DpEvent::Composed { events: v }, other) => (v, vec![other]), + (current, DpEvent::Composed { events: v }) => (vec![current], v), (current, other) => (vec![current], vec![other]), }; - DpEvent::Composed {events: v1.into_iter().chain(v2.into_iter()).collect()} + DpEvent::Composed { + events: v1.into_iter().chain(v2.into_iter()).collect(), + } } } pub fn is_no_op(&self) -> bool { match self { DpEvent::NoOp => true, - DpEvent::Gaussian {noise_multiplier} | DpEvent::Laplace {noise_multiplier} => noise_multiplier == &0.0, - DpEvent::EpsilonDelta {epsilon, delta} => epsilon == &0. && delta == &0., - DpEvent::Composed {events} => events.iter().all(|q| q.is_no_op()), + DpEvent::Gaussian { noise_multiplier } | DpEvent::Laplace { noise_multiplier } => { + noise_multiplier == &0.0 + } + DpEvent::EpsilonDelta { epsilon, delta } => epsilon == &0. && delta == &0., + DpEvent::Composed { events } => events.iter().all(|q| q.is_no_op()), _ => todo!(), } } @@ -130,7 +125,9 @@ impl DpEvent { delta: f64, sensitivity: f64, ) -> Self { - DpEvent::Gaussian {noise_multiplier: gaussian_noise(epsilon, delta, sensitivity)} + DpEvent::Gaussian { + noise_multiplier: gaussian_noise(epsilon, delta, sensitivity), + } } } @@ -138,10 +135,12 @@ impl fmt::Display for DpEvent { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { DpEvent::NoOp => writeln!(f, "NoOp"), - DpEvent::Gaussian {noise_multiplier} => writeln!(f, "Gaussian ({noise_multiplier})"), - DpEvent::Laplace {noise_multiplier} => writeln!(f, "Laplace ({noise_multiplier})"), - DpEvent::EpsilonDelta {epsilon, delta} => writeln!(f, "EpsilonDelta ({epsilon}, {delta})"), - DpEvent::Composed {events} => write!( + DpEvent::Gaussian { noise_multiplier } => writeln!(f, "Gaussian ({noise_multiplier})"), + DpEvent::Laplace { noise_multiplier } => writeln!(f, "Laplace ({noise_multiplier})"), + DpEvent::EpsilonDelta { epsilon, delta } => { + writeln!(f, "EpsilonDelta ({epsilon}, {delta})") + } + DpEvent::Composed { events } => write!( f, "Composed ({})", events.iter().map(|dpe| format!("{}", dpe)).join(", ") @@ -153,7 +152,8 @@ impl fmt::Display for DpEvent { impl FromIterator for DpEvent { fn from_iter>(iter: T) -> Self { - iter.into_iter().fold(DpEvent::NoOp, |composed, event| composed.compose(event)) + iter.into_iter() + .fold(DpEvent::NoOp, |composed, event| composed.compose(event)) } } diff --git a/src/differential_privacy/dp_parameters.rs b/src/differential_privacy/dp_parameters.rs index f2b53e7c..75b1ed20 100644 --- a/src/differential_privacy/dp_parameters.rs +++ b/src/differential_privacy/dp_parameters.rs @@ -16,8 +16,20 @@ pub struct DpParameters { } impl DpParameters { - pub fn new(epsilon: f64, delta: f64, tau_thresholding_share: f64, privacy_unit_max_multiplicity: f64, privacy_unit_max_multiplicity_share: f64) -> DpParameters { - DpParameters { epsilon, delta, tau_thresholding_share, privacy_unit_max_multiplicity, privacy_unit_max_multiplicity_share } + pub fn new( + epsilon: f64, + delta: f64, + tau_thresholding_share: f64, + privacy_unit_max_multiplicity: f64, + privacy_unit_max_multiplicity_share: f64, + ) -> DpParameters { + DpParameters { + epsilon, + delta, + tau_thresholding_share, + privacy_unit_max_multiplicity, + privacy_unit_max_multiplicity_share, + } } pub fn from_epsilon_delta(epsilon: f64, delta: f64) -> DpParameters { @@ -26,15 +38,30 @@ impl DpParameters { } pub fn with_tau_thresholding_share(self, tau_thresholding_share: f64) -> DpParameters { - DpParameters { tau_thresholding_share, ..self } + DpParameters { + tau_thresholding_share, + ..self + } } - pub fn with_privacy_unit_max_multiplicity(self, privacy_unit_max_multiplicity: f64) -> DpParameters { - DpParameters { privacy_unit_max_multiplicity, ..self } + pub fn with_privacy_unit_max_multiplicity( + self, + privacy_unit_max_multiplicity: f64, + ) -> DpParameters { + DpParameters { + privacy_unit_max_multiplicity, + ..self + } } - pub fn with_privacy_unit_max_multiplicity_share(self, privacy_unit_max_multiplicity_share: f64) -> DpParameters { - DpParameters { privacy_unit_max_multiplicity_share, ..self } + pub fn with_privacy_unit_max_multiplicity_share( + self, + privacy_unit_max_multiplicity_share: f64, + ) -> DpParameters { + DpParameters { + privacy_unit_max_multiplicity_share, + ..self + } } } diff --git a/src/differential_privacy/group_by.rs b/src/differential_privacy/group_by.rs index 7221d7e1..f409ac4d 100644 --- a/src/differential_privacy/group_by.rs +++ b/src/differential_privacy/group_by.rs @@ -1,10 +1,10 @@ use super::Error; use crate::{ builder::{Ready, With, WithIterator}, - differential_privacy::{dp_event, DpRelation, DpEvent, Result}, + differential_privacy::{dp_event, DpEvent, DpRelation, Result}, expr::{aggregate, Expr}, namer::{self, name_from_content}, - privacy_unit_tracking::{PupRelation, PrivacyUnit}, + privacy_unit_tracking::{PrivacyUnit, PupRelation}, relation::{Join, Reduce, Relation, Variant as _}, }; @@ -99,7 +99,8 @@ impl PupRelation { /// - Using the propagated public values of the grouping columns when they exist /// - Applying tau-thresholding mechanism with the (epsilon, delta) privacy parameters for t /// he columns that do not have public values - pub fn dp_values(self, epsilon: f64, delta: f64) -> Result {// TODO this code is super-ugly rewrite it + pub fn dp_values(self, epsilon: f64, delta: f64) -> Result { + // TODO this code is super-ugly rewrite it let public_columns: Vec = self .schema() .iter() diff --git a/src/differential_privacy/mod.rs b/src/differential_privacy/mod.rs index 702d550e..6e9dcb82 100644 --- a/src/differential_privacy/mod.rs +++ b/src/differential_privacy/mod.rs @@ -4,21 +4,22 @@ //! pub mod aggregates; +pub mod dp_event; pub mod dp_parameters; pub mod group_by; -pub mod dp_event; use crate::{ builder::With, - expr, privacy_unit_tracking::{self, privacy_unit, PupRelation}, + expr, + privacy_unit_tracking::{self, privacy_unit, PupRelation}, relation::{rewriting, Constraint, Reduce, Relation, Variant}, Ready, }; use std::{error, fmt, ops::Deref, result}; +pub use dp_event::DpEvent; /// Some exports pub use dp_parameters::DpParameters; -pub use dp_event::DpEvent; use self::aggregates::DpAggregatesParameters; @@ -83,10 +84,7 @@ impl From for Relation { impl DpRelation { pub fn new(relation: Relation, dp_event: DpEvent) -> Self { - DpRelation { - relation, - dp_event, - } + DpRelation { relation, dp_event } } pub fn relation(&self) -> &Relation { @@ -118,27 +116,26 @@ impl From<(Relation, DpEvent)> for DpRelation { } } - - impl Reduce { /// Rewrite a `Reduce` into DP: /// - Protect the grouping keys /// - Add noise on the aggregations - pub fn differentially_private( - self, - parameters: &DpParameters, - ) -> Result { + pub fn differentially_private(self, parameters: &DpParameters) -> Result { let mut dp_event = DpEvent::no_op(); let max_size = self.size().max().unwrap().clone(); let pup_input = PupRelation::try_from(self.input().clone())?; - let privacy_unit_unique = pup_input.schema()[pup_input.privacy_unit()].has_unique_or_primary_key_constraint(); + let privacy_unit_unique = + pup_input.schema()[pup_input.privacy_unit()].has_unique_or_primary_key_constraint(); // DP rewrite group by let reduce_with_dp_group_by = if self.group_by().is_empty() { self } else { let (dp_grouping_values, dp_event_group_by) = self - .differentially_private_group_by(parameters.epsilon*parameters.tau_thresholding_share, parameters.delta*parameters.tau_thresholding_share)? + .differentially_private_group_by( + parameters.epsilon * parameters.tau_thresholding_share, + parameters.delta * parameters.tau_thresholding_share, + )? .into(); let input_relation_with_privacy_tracked_group_by = self .input() @@ -154,11 +151,16 @@ impl Reduce { // if the (epsilon_tau_thresholding, delta_tau_thresholding) budget has // not been spent, allocate it to the aggregations. - let aggregation_share = if dp_event.is_no_op() {1.} else {1.-parameters.tau_thresholding_share}; - let aggregation_parameters = DpAggregatesParameters::from_dp_parameters(parameters.clone(), aggregation_share) - .with_size(usize::try_from(max_size).unwrap()) - .with_privacy_unit_unique(privacy_unit_unique); - + let aggregation_share = if dp_event.is_no_op() { + 1. + } else { + 1. - parameters.tau_thresholding_share + }; + let aggregation_parameters = + DpAggregatesParameters::from_dp_parameters(parameters.clone(), aggregation_share) + .with_size(usize::try_from(max_size).unwrap()) + .with_privacy_unit_unique(privacy_unit_unique); + // DP rewrite aggregates let (dp_relation, dp_event_agg) = reduce_with_dp_group_by .differentially_private_aggregates(aggregation_parameters)? @@ -178,8 +180,8 @@ mod tests { display::Dot, expr::{AggregateColumn, Expr}, io::{postgresql, Database}, - privacy_unit_tracking::{PrivacyUnit,PrivacyUnitTracking, Strategy, PupRelation}, - relation::{Field, Map, Relation, Schema, Variant, Constraint}, + privacy_unit_tracking::{PrivacyUnit, PrivacyUnitTracking, PupRelation, Strategy}, + relation::{Constraint, Field, Map, Relation, Schema, Variant}, }; use std::{collections::HashSet, sync::Arc}; @@ -219,13 +221,17 @@ mod tests { let relation = Relation::from(reduce.clone()); relation.display_dot().unwrap(); - let (dp_relation, dp_event) = reduce - .differentially_private(¶meters) - .unwrap() - .into(); + let (dp_relation, dp_event) = reduce.differentially_private(¶meters).unwrap().into(); dp_relation.display_dot().unwrap(); - let mult: f64 = 2000.*DpAggregatesParameters::from_dp_parameters(parameters.clone(), 1.).privacy_unit_multiplicity(); - assert!(matches!(dp_event, DpEvent::Gaussian { noise_multiplier: _ })); + let mult: f64 = 2000. + * DpAggregatesParameters::from_dp_parameters(parameters.clone(), 1.) + .privacy_unit_multiplicity(); + assert!(matches!( + dp_event, + DpEvent::Gaussian { + noise_multiplier: _ + } + )); assert!(dp_relation .data_type() .is_subset_of(&DataType::structured([("sum_price", DataType::float())]))); @@ -241,13 +247,7 @@ mod tests { .clone(); let privacy_unit_tracking = PrivacyUnitTracking::from(( &relations, - vec![ - ( - "table_1", - vec![], - PrivacyUnit::privacy_unit_row(), - ), - ], + vec![("table_1", vec![], PrivacyUnit::privacy_unit_row())], Strategy::Hard, )); let pup_table = privacy_unit_tracking @@ -255,7 +255,7 @@ mod tests { .unwrap(); let map = Map::new( "my_map".to_string(), - vec![("my_d".to_string(), expr!(d/100))], + vec![("my_d".to_string(), expr!(d / 100))], None, vec![], None, @@ -265,7 +265,7 @@ mod tests { let pup_map = privacy_unit_tracking .map( &map.clone().try_into().unwrap(), - PupRelation(Relation::from(pup_table)) + PupRelation(Relation::from(pup_table)), ) .unwrap(); let reduce = Reduce::new( @@ -277,13 +277,13 @@ mod tests { let relation = Relation::from(reduce.clone()); relation.display_dot().unwrap(); - let (dp_relation, dp_event) = reduce - .differentially_private(¶meters) - .unwrap() - .into(); + let (dp_relation, dp_event) = reduce.differentially_private(¶meters).unwrap().into(); dp_relation.display_dot().unwrap(); assert!(dp_event.is_no_op()); // private query is null beacause we have computed the sum of zeros - assert_eq!(dp_relation.data_type(), DataType::structured([("sum_d", DataType::float_value(0.))])); + assert_eq!( + dp_relation.data_type(), + DataType::structured([("sum_d", DataType::float_value(0.))]) + ); let query: &str = &ast::Query::from(&dp_relation).to_string(); _ = database.query(query).unwrap(); @@ -338,12 +338,14 @@ mod tests { let relation = Relation::from(reduce.clone()); relation.display_dot().unwrap(); - let (dp_relation, dp_event) = reduce - .differentially_private(¶meters) - .unwrap() - .into(); + let (dp_relation, dp_event) = reduce.differentially_private(¶meters).unwrap().into(); dp_relation.display_dot().unwrap(); - assert!(matches!(dp_event, DpEvent::Gaussian { noise_multiplier: _ })); + assert!(matches!( + dp_event, + DpEvent::Gaussian { + noise_multiplier: _ + } + )); assert!(dp_relation .data_type() .is_subset_of(&DataType::structured([("sum_price", DataType::float())]))); @@ -398,10 +400,7 @@ mod tests { let relation = Relation::from(reduce.clone()); relation.display_dot().unwrap(); - let (dp_relation, dp_event) = reduce - .differentially_private(¶meters) - .unwrap() - .into(); + let (dp_relation, dp_event) = reduce.differentially_private(¶meters).unwrap().into(); dp_relation.display_dot().unwrap(); assert!(matches!(dp_event, DpEvent::Composed { events: _ })); assert!(dp_relation @@ -467,10 +466,7 @@ mod tests { let relation = Relation::from(reduce.clone()); relation.display_dot().unwrap(); - let (dp_relation, dp_event) = reduce - .differentially_private(¶meters) - .unwrap() - .into(); + let (dp_relation, dp_event) = reduce.differentially_private(¶meters).unwrap().into(); dp_relation.display_dot().unwrap(); assert!(matches!(dp_event, DpEvent::Composed { events: _ })); assert!(dp_relation.schema()[0] @@ -655,21 +651,34 @@ mod tests { let relation = Relation::from(reduce.clone()); relation.display_dot().unwrap(); - let (dp_relation, dp_event) = reduce - .differentially_private(¶meters) - .unwrap() - .into(); + let (dp_relation, dp_event) = reduce.differentially_private(¶meters).unwrap().into(); dp_relation.display_dot().unwrap(); assert_eq!( dp_event, - DpEvent::epsilon_delta(parameters.epsilon*parameters.tau_thresholding_share, parameters.delta*parameters.tau_thresholding_share) - .compose(DpEvent::gaussian_from_epsilon_delta_sensitivity(parameters.epsilon*(1.-parameters.tau_thresholding_share), parameters.delta*(1.-parameters.tau_thresholding_share), 10.)) + DpEvent::epsilon_delta( + parameters.epsilon * parameters.tau_thresholding_share, + parameters.delta * parameters.tau_thresholding_share + ) + .compose(DpEvent::gaussian_from_epsilon_delta_sensitivity( + parameters.epsilon * (1. - parameters.tau_thresholding_share), + parameters.delta * (1. - parameters.tau_thresholding_share), + 10. + )) ); let correct_schema: Schema = vec![ ("sum_a", DataType::float_interval(0., 100.), None), - ("d", DataType::integer_interval(0, 10), Some(Constraint::Unique)), - ("max_d", DataType::integer_interval(0, 10), Some(Constraint::Unique)), - ].into_iter() + ( + "d", + DataType::integer_interval(0, 10), + Some(Constraint::Unique), + ), + ( + "max_d", + DataType::integer_interval(0, 10), + Some(Constraint::Unique), + ), + ] + .into_iter() .collect(); assert_eq!(dp_relation.schema(), &correct_schema); diff --git a/src/expr/aggregate.rs b/src/expr/aggregate.rs index 8350ad26..a4d7e49c 100644 --- a/src/expr/aggregate.rs +++ b/src/expr/aggregate.rs @@ -5,8 +5,6 @@ use itertools::Itertools; use super::{implementation, Result}; use crate::data_type::{value::Value, DataType}; - - /// The list of operators /// inspired by: https://docs.rs/sqlparser/latest/sqlparser/ast/enum.BinaryOperator.html /// and mostly: https://docs.rs/polars/latest/polars/prelude/enum.AggExpr.html @@ -32,7 +30,7 @@ pub enum Aggregate { Std, StdDistinct, Var, - VarDistinct + VarDistinct, } // TODO make sure f64::nan do not happen diff --git a/src/expr/bijection.rs b/src/expr/bijection.rs index aeda2b1d..0dbd4750 100644 --- a/src/expr/bijection.rs +++ b/src/expr/bijection.rs @@ -4,9 +4,14 @@ impl Expr { /// Reduce the expression modulo a bijection pub fn reduce_modulo_bijection(&self) -> &Expr { match self { - Expr::Function(Function { function, arguments }) => { + Expr::Function(Function { + function, + arguments, + }) => { if function.is_bijection() { - arguments.get(0).map(|arg|arg.reduce_modulo_bijection() ) + arguments + .get(0) + .map(|arg| arg.reduce_modulo_bijection()) .unwrap_or_else(|| self) } else { self @@ -29,9 +34,14 @@ impl Expr { pub fn is_unique(&self) -> bool { let expr = self.reduce_modulo_bijection(); match expr { - Expr::Function(Function { function, arguments }) => { + Expr::Function(Function { + function, + arguments, + }) => { if function.is_bijection() { - arguments.get(0).map(|arg|arg.is_unique() ) + arguments + .get(0) + .map(|arg| arg.is_unique()) .unwrap_or_else(|| false) } else { function.is_unique() @@ -43,7 +53,7 @@ impl Expr { /// True if 2 expressions are equal modulo a bijection pub fn eq_modulo_bijection(&self, expr: &Expr) -> bool { - self.reduce_modulo_bijection()==expr.reduce_modulo_bijection() + self.reduce_modulo_bijection() == expr.reduce_modulo_bijection() } } @@ -57,31 +67,31 @@ mod tests { fn test_into_column_modulo_bijection() { let a = expr!(md5(cast_as_text(exp(a)))); let b = expr!(md5(cast_as_text(sin(a)))); - println!("a.into_column_modulo_bijection() {:?}", a.into_column_modulo_bijection()); - println!("b.into_column_modulo_bijection() {:?}", b.into_column_modulo_bijection()); - assert!(a.into_column_modulo_bijection()==Some(Identifier::from_name("a"))); - assert!(b.into_column_modulo_bijection()==None); + println!( + "a.into_column_modulo_bijection() {:?}", + a.into_column_modulo_bijection() + ); + println!( + "b.into_column_modulo_bijection() {:?}", + b.into_column_modulo_bijection() + ); + assert!(a.into_column_modulo_bijection() == Some(Identifier::from_name("a"))); + assert!(b.into_column_modulo_bijection() == None); } #[test] fn test_eq_modulo_bijection() { let a = expr!(a + b); - let b = expr!(exp(a+b)); + let b = expr!(exp(a + b)); assert!(a.eq_modulo_bijection(&b)); let a = expr!(a + b); - let b = expr!(exp(sin(a+b))); + let b = expr!(exp(sin(a + b))); assert!(!a.eq_modulo_bijection(&b)); } #[test] fn test_is_unique() { - assert!( - Expr::md5(Expr::cast_as_text(Expr::exp(Expr::newid()))) - .is_unique() - ); - assert!( - !Expr::md5(Expr::cast_as_text(Expr::exp(Expr::col("a")))) - .is_unique() - ); + assert!(Expr::md5(Expr::cast_as_text(Expr::exp(Expr::newid()))).is_unique()); + assert!(!Expr::md5(Expr::cast_as_text(Expr::exp(Expr::col("a")))).is_unique()); } -} \ No newline at end of file +} diff --git a/src/expr/dot.rs b/src/expr/dot.rs index 3046f6e9..6a992aec 100644 --- a/src/expr/dot.rs +++ b/src/expr/dot.rs @@ -103,11 +103,11 @@ impl<'a, T: Clone + fmt::Display, V: Visitor<'a, T>> dot::Labeller<'a, Node<'a, Expr::Value(val) => { println!("{}", &val.to_string()); format!( - "{}
{}", - dot::escape_html(&val.to_string()), - &node.1 - ) - }, + "{}
{}", + dot::escape_html(&val.to_string()), + &node.1 + ) + } Expr::Function(fun) => { format!( "{}
{}", diff --git a/src/expr/function.rs b/src/expr/function.rs index b696d9c9..2986c96d 100644 --- a/src/expr/function.rs +++ b/src/expr/function.rs @@ -100,7 +100,7 @@ pub enum Function { Ilike, Choose, IsNull, - IsBool + IsBool, } #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] @@ -308,11 +308,12 @@ impl Function { | Function::Choose | Function::Like | Function::Ilike - | Function::IsBool => { - Arity::Nary(2) - } + | Function::IsBool => Arity::Nary(2), // Ternary Function - Function::Case | Function::SubstrWithSize | Function::RegexpReplace | Function::DatetimeDiff => Arity::Nary(3), + Function::Case + | Function::SubstrWithSize + | Function::RegexpReplace + | Function::DatetimeDiff => Arity::Nary(3), // Quaternary Function Function::RegexpExtract => Arity::Nary(4), // Nary Function @@ -348,8 +349,7 @@ impl Function { pub fn is_unique(self) -> bool { match self { // Unary Operators - Function::Random(_) - | Function::Newid => true, + Function::Random(_) | Function::Newid => true, _ => false, } } diff --git a/src/expr/implementation.rs b/src/expr/implementation.rs index 94a33322..2020b170 100644 --- a/src/expr/implementation.rs +++ b/src/expr/implementation.rs @@ -54,7 +54,27 @@ macro_rules! function_implementations { // Nary: Concat function_implementations!( [Pi, Newid, CurrentDate, CurrentTime, CurrentTimestamp], - [Opposite, Not, Exp, Ln, Log, Abs, Sin, Cos, Sqrt, Md5, Ceil, Floor, Sign, Unhex, Dayname, Quarter, Date, UnixTimestamp, IsNull], + [ + Opposite, + Not, + Exp, + Ln, + Log, + Abs, + Sin, + Cos, + Sqrt, + Md5, + Ceil, + Floor, + Sign, + Unhex, + Dayname, + Quarter, + Date, + UnixTimestamp, + IsNull + ], [ Plus, Minus, @@ -114,7 +134,9 @@ function_implementations!( Function::CastAsInteger => Arc::new(Optional::new(function::cast(DataType::integer()))), Function::CastAsFloat => Arc::new(Optional::new(function::cast(DataType::float()))), Function::CastAsBoolean => Arc::new(Optional::new(function::cast(DataType::boolean()))), - Function::CastAsDateTime => Arc::new(Optional::new(function::cast(DataType::date_time()))), + Function::CastAsDateTime => { + Arc::new(Optional::new(function::cast(DataType::date_time()))) + } Function::CastAsDate => Arc::new(Optional::new(function::cast(DataType::date()))), Function::CastAsTime => Arc::new(Optional::new(function::cast(DataType::time()))), Function::Concat(n) => Arc::new(function::concat(n)), @@ -152,7 +174,26 @@ macro_rules! aggregate_implementations { } aggregate_implementations!( - [Min, Max, Median, NUnique, First, Last, Mean, List, Count, Sum, AggGroups, Std, Var, MeanDistinct, CountDistinct, SumDistinct, StdDistinct, VarDistinct], + [ + Min, + Max, + Median, + NUnique, + First, + Last, + Mean, + List, + Count, + Sum, + AggGroups, + Std, + Var, + MeanDistinct, + CountDistinct, + SumDistinct, + StdDistinct, + VarDistinct + ], x, { match x { diff --git a/src/expr/mod.rs b/src/expr/mod.rs index af9478f4..e22cb88b 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -7,6 +7,7 @@ #[macro_use] pub mod dsl; pub mod aggregate; +pub mod bijection; pub mod dot; pub mod function; pub mod identifier; @@ -14,7 +15,6 @@ pub mod implementation; pub mod rewriting; pub mod split; pub mod sql; -pub mod bijection; use itertools::Itertools; use paste::paste; @@ -43,7 +43,7 @@ TODO - Remove */ -static EPSILON: f64 = 1.0/f64::MAX; +static EPSILON: f64 = 1.0 / f64::MAX; // Error management @@ -267,13 +267,7 @@ macro_rules! impl_nullary_function_constructors { }; } -impl_nullary_function_constructors!( - Pi, - Newid, - CurrentDate, - CurrentTime, - CurrentTimestamp -); +impl_nullary_function_constructors!(Pi, Newid, CurrentDate, CurrentTime, CurrentTimestamp); /// Implement unary function constructors macro_rules! impl_unary_function_constructors { @@ -409,9 +403,10 @@ impl Function { pub fn divide, R: Into>(left: L, right: R) -> Function { Function::new( function::Function::Divide, - <[_]>::into_vec( - Box::new([(Arc::new(left.into())), (Arc::new(right.into()))]), - ), + <[_]>::into_vec(Box::new([ + (Arc::new(left.into())), + (Arc::new(right.into())), + ])), ) } } @@ -422,10 +417,11 @@ impl Expr { let division = Expr::from(Function::divide(left, right.clone())); Expr::case( Expr::or( - Expr::gt_eq(right.clone(), Expr::val(EPSILON)), - Expr::lt_eq(right, - Expr::val(EPSILON)) - ), division, - Expr::val(0.0) + Expr::gt_eq(right.clone(), Expr::val(EPSILON)), + Expr::lt_eq(right, -Expr::val(EPSILON)), + ), + division, + Expr::val(0.0), ) } } @@ -589,7 +585,22 @@ macro_rules! impl_aggregation_constructors { }; } -impl_aggregation_constructors!(First, Last, Min, Max, Count, Mean, Sum, Var, Std, CountDistinct, MeanDistinct, SumDistinct, VarDistinct, StdDistinct); +impl_aggregation_constructors!( + First, + Last, + Min, + Max, + Count, + Mean, + Sum, + Var, + Std, + CountDistinct, + MeanDistinct, + SumDistinct, + VarDistinct, + StdDistinct +); /// An aggregate function expression #[derive(Clone, Debug, Hash, PartialEq, Eq)] @@ -1444,17 +1455,24 @@ impl DataType { } else { right_dt }; - let set = DataType::structured_from_data_types([left_dt.clone(), right_dt.clone()]); - if let (Expr::Column(col), Ok(dt)) = (left, data_type::function::greatest().super_image(&set)) { + let set = + DataType::structured_from_data_types([left_dt.clone(), right_dt.clone()]); + if let (Expr::Column(col), Ok(dt)) = + (left, data_type::function::greatest().super_image(&set)) + { datatype = datatype.replace(col, dt.super_intersection(&left_dt).unwrap()) } - if let (Expr::Column(col), Ok(dt)) = (right, data_type::function::least().super_image(&set)) { + if let (Expr::Column(col), Ok(dt)) = + (right, data_type::function::least().super_image(&set)) + { datatype = datatype.replace(col, dt.super_intersection(&right_dt).unwrap()) } } - }, + } (function::Function::Eq, [left, right]) => { - if let (Ok(left_dt), Ok(right_dt)) = (left.super_image(&datatype), right.super_image(&datatype)) { + if let (Ok(left_dt), Ok(right_dt)) = + (left.super_image(&datatype), right.super_image(&datatype)) + { let dt = left_dt.super_intersection(&right_dt).unwrap(); if let Expr::Column(col) = left { datatype = datatype.replace(&col, dt.clone()) @@ -2976,16 +2994,12 @@ mod tests { #[test] fn test_cast_integer_text() { println!("integer => text"); - let expression = Expr::cast_as_text( - Expr::col("col1".to_string()) - ); + let expression = Expr::cast_as_text(Expr::col("col1".to_string())); println!("expression = {}", expression); println!("expression domain = {}", expression.domain()); println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::integer_values([1, 2, 3])), - ]); + let set = DataType::structured([("col1", DataType::integer_values([1, 2, 3]))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -2996,16 +3010,15 @@ mod tests { ); println!("\ntext => integer"); - let expression = Expr::cast_as_integer( - Expr::col("col1".to_string()) - ); + let expression = Expr::cast_as_integer(Expr::col("col1".to_string())); println!("expression = {}", expression); println!("expression domain = {}", expression.domain()); println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::text_values(["1".to_string(), "2".to_string(), "3".to_string()])), - ]); + let set = DataType::structured([( + "col1", + DataType::text_values(["1".to_string(), "2".to_string(), "3".to_string()]), + )]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3019,16 +3032,12 @@ mod tests { #[test] fn test_cast_float_integer() { println!("float => integer"); - let expression = Expr::cast_as_integer( - Expr::col("col1".to_string()) - ); + let expression = Expr::cast_as_integer(Expr::col("col1".to_string())); println!("expression = {}", expression); println!("expression domain = {}", expression.domain()); println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::float_values([1.1, 1.9, 5.49])), - ]); + let set = DataType::structured([("col1", DataType::float_values([1.1, 1.9, 5.49]))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3037,9 +3046,7 @@ mod tests { expression.super_image(&set).unwrap(), DataType::integer_values([1, 2, 5]) ); - let set = DataType::structured([ - ("col1", DataType::float_interval(1.1, 5.49)), - ]); + let set = DataType::structured([("col1", DataType::float_interval(1.1, 5.49))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3048,9 +3055,7 @@ mod tests { expression.super_image(&set).unwrap(), DataType::integer_interval(1, 5) ); - let set = DataType::structured([ - ("col1", DataType::float_interval(1.1, 1.49)), - ]); + let set = DataType::structured([("col1", DataType::float_interval(1.1, 1.49))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3061,16 +3066,12 @@ mod tests { ); println!("integer => float"); - let expression = Expr::cast_as_float( - Expr::col("col1".to_string()) - ); + let expression = Expr::cast_as_float(Expr::col("col1".to_string())); println!("expression = {}", expression); println!("expression domain = {}", expression.domain()); println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::integer_values([1, 4, 7])), - ]); + let set = DataType::structured([("col1", DataType::integer_values([1, 4, 7]))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3079,9 +3080,7 @@ mod tests { expression.super_image(&set).unwrap(), DataType::float_values([1., 4., 7.]) ); - let set = DataType::structured([ - ("col1", DataType::integer_interval(1, 7)), - ]); + let set = DataType::structured([("col1", DataType::integer_interval(1, 7))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3095,16 +3094,12 @@ mod tests { #[test] fn test_cast_float_text() { println!("float => text"); - let expression = Expr::cast_as_text( - Expr::col("col1".to_string()) - ); + let expression = Expr::cast_as_text(Expr::col("col1".to_string())); println!("expression = {}", expression); println!("expression domain = {}", expression.domain()); println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::float_values([1.1, 2., 3.5])), - ]); + let set = DataType::structured([("col1", DataType::float_values([1.1, 2., 3.5]))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3115,16 +3110,15 @@ mod tests { ); println!("\ntext => float"); - let expression = Expr::cast_as_float( - Expr::col("col1".to_string()) - ); + let expression = Expr::cast_as_float(Expr::col("col1".to_string())); println!("expression = {}", expression); println!("expression domain = {}", expression.domain()); println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::text_values(["1.1".to_string(), "2".to_string(), "3.5".to_string()])), - ]); + let set = DataType::structured([( + "col1", + DataType::text_values(["1.1".to_string(), "2".to_string(), "3.5".to_string()]), + )]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3138,16 +3132,12 @@ mod tests { #[test] fn test_cast_boolean_text() { println!("boolean => text"); - let expression = Expr::cast_as_text( - Expr::col("col1".to_string()) - ); + let expression = Expr::cast_as_text(Expr::col("col1".to_string())); println!("expression = {}", expression); println!("expression domain = {}", expression.domain()); println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::boolean_values([true, false])), - ]); + let set = DataType::structured([("col1", DataType::boolean_values([true, false]))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3158,16 +3148,15 @@ mod tests { ); println!("\ntext => boolean"); - let expression = Expr::cast_as_boolean( - Expr::col("col1".to_string()) - ); + let expression = Expr::cast_as_boolean(Expr::col("col1".to_string())); println!("expression = {}", expression); println!("expression domain = {}", expression.domain()); println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::text_values(["n".to_string(), "fa".to_string(), "off".to_string()])), - ]); + let set = DataType::structured([( + "col1", + DataType::text_values(["n".to_string(), "fa".to_string(), "off".to_string()]), + )]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3181,17 +3170,13 @@ mod tests { #[test] fn test_sign() { println!("sign"); - let expression = Expr::sign( - Expr::col("col1".to_string()) - ); + let expression = Expr::sign(Expr::col("col1".to_string())); println!("expression = {}", expression); println!("expression domain = {}", expression.domain()); println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::float_interval(-10., 1.)), - ]); + let set = DataType::structured([("col1", DataType::float_interval(-10., 1.))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3201,9 +3186,7 @@ mod tests { DataType::integer_interval(-1, 1) ); - let set = DataType::structured([ - ("col1", DataType::integer_min(-0)), - ]); + let set = DataType::structured([("col1", DataType::integer_min(-0))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3213,9 +3196,7 @@ mod tests { DataType::integer_interval(0, 1) ); - let set = DataType::structured([ - ("col1", DataType::float_min(1.)), - ]); + let set = DataType::structured([("col1", DataType::float_min(1.))]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3235,9 +3216,7 @@ mod tests { println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::float()), - ]); + let set = DataType::structured([("col1", DataType::float())]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3257,9 +3236,7 @@ mod tests { println!("expression co domain = {}", expression.co_domain()); println!("expression data type = {}", expression.data_type()); - let set = DataType::structured([ - ("col1", DataType::float()), - ]); + let set = DataType::structured([("col1", DataType::float())]); println!( "expression super image = {}", expression.super_image(&set).unwrap() @@ -3280,17 +3257,20 @@ mod tests { println!("expression data type = {}", expression.data_type()); let set = DataType::structured([ - ("value", DataType::text_values(["a@foo.com".to_string(), "b@bar.org".to_string()])), - ("regexp", DataType::text_value(r"^[\w.+-]+@foo\.com|[\w.+-]+@bar\.org$".to_string())), + ( + "value", + DataType::text_values(["a@foo.com".to_string(), "b@bar.org".to_string()]), + ), + ( + "regexp", + DataType::text_value(r"^[\w.+-]+@foo\.com|[\w.+-]+@bar\.org$".to_string()), + ), ]); println!( "expression super image = {}", expression.super_image(&set).unwrap() ); - assert_eq!( - expression.super_image(&set).unwrap(), - DataType::boolean() - ); + assert_eq!(expression.super_image(&set).unwrap(), DataType::boolean()); } #[test] @@ -3330,15 +3310,12 @@ mod tests { let set = DataType::structured([ ("value", DataType::text_value("ab".to_string())), ("regexp", DataType::text_value(r"*b".to_string())), - ("replacement", DataType::text()) + ("replacement", DataType::text()), ]); println!( "expression super image = {}", expression.super_image(&set).unwrap() ); - assert_eq!( - expression.super_image(&set).unwrap(), - DataType::text() - ); + assert_eq!(expression.super_image(&set).unwrap(), DataType::text()); } } diff --git a/src/expr/split.rs b/src/expr/split.rs index d073e876..69b2e5e9 100644 --- a/src/expr/split.rs +++ b/src/expr/split.rs @@ -1,8 +1,8 @@ //! The splits with some improvements //! Each split has named Expr and anonymous exprs use super::{ - aggregate, function, visitor::Acceptor, AggregateColumn, Column, Expr, Function, - Identifier, Value, Visitor, + aggregate, function, visitor::Acceptor, AggregateColumn, Column, Expr, Function, Identifier, + Value, Visitor, }; use crate::{ namer::{self, FIELD}, @@ -36,9 +36,11 @@ impl Split { } pub fn group_by(expr: Expr) -> Reduce { - if let Expr::Column(col) = expr {// If the expression is a column + if let Expr::Column(col) = expr { + // If the expression is a column Reduce::new(vec![], vec![col], None) - } else {// If not + } else { + // If not let name = namer::name_from_content(FIELD, &expr); let map = Map::new(vec![(name.clone(), expr)], None, vec![], None); Reduce::new(vec![], vec![name.into()], Some(map)) @@ -933,9 +935,9 @@ mod tests { #[test] fn test_split_map_reduce_map_group_by_expr() { - let split = Split::from(("b", expr!(2*count(1 + y)))); - let split = split.and(Split::group_by(expr!(x-y)).into()); - let split = split.and(Split::from(("a", expr!(x-y)))); + let split = Split::from(("b", expr!(2 * count(1 + y)))); + let split = split.and(Split::group_by(expr!(x - y)).into()); + let split = split.and(Split::from(("a", expr!(x - y)))); println!("split = {split}"); } } diff --git a/src/expr/sql.rs b/src/expr/sql.rs index 7c89d8da..17e1b40b 100644 --- a/src/expr/sql.rs +++ b/src/expr/sql.rs @@ -1,7 +1,7 @@ //! Convert Expr into ast::Expr use crate::{ ast, - data_type::{DataType, Boolean}, + data_type::{Boolean, DataType}, expr::{self, Expr}, visitor::Acceptor, }; @@ -211,48 +211,58 @@ impl<'a> expr::Visitor<'a, ast::Expr> for FromExprVisitor { clauses: vec![], }; ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], - })}, + name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } expr::function::Function::RegexpExtract => { let func_args_list = ast::FunctionArgumentList { duplicate_treatment: None, - args: vec![arguments[0].clone(), arguments[1].clone(), arguments[2].clone(), arguments[3].clone()] - .into_iter() - .map(|e| ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))) - .collect(), + args: vec![ + arguments[0].clone(), + arguments[1].clone(), + arguments[2].clone(), + arguments[3].clone(), + ] + .into_iter() + .map(|e| ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))) + .collect(), clauses: vec![], }; ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], - })}, - expr::function::Function::Round - | expr::function::Function::Trunc => { + name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } + expr::function::Function::Round | expr::function::Function::Trunc => { let func_args_list = ast::FunctionArgumentList { duplicate_treatment: None, args: arguments .into_iter() - .filter_map(|e| (e!=ast::Expr::Value(ast::Value::Number("0".to_string(), false))).then_some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)))) + .filter_map(|e| { + (e != ast::Expr::Value(ast::Value::Number("0".to_string(), false))) + .then_some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))) + }) .collect(), clauses: vec![], }; ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], - })}, + name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } expr::function::Function::Case => ast::Expr::Case { operand: None, conditions: vec![arguments[0].clone()], @@ -327,16 +337,46 @@ impl<'a> expr::Visitor<'a, ast::Expr> for FromExprVisitor { null_treatment: None, within_group: vec![], }), - expr::function::Function::ExtractYear => ast::Expr::Extract {field: ast::DateTimeField::Year, expr: arguments[0].clone().into()}, - expr::function::Function::ExtractMonth => ast::Expr::Extract {field: ast::DateTimeField::Month, expr: arguments[0].clone().into()}, - expr::function::Function::ExtractDay => ast::Expr::Extract {field: ast::DateTimeField::Day, expr: arguments[0].clone().into()}, - expr::function::Function::ExtractHour => ast::Expr::Extract {field: ast::DateTimeField::Hour, expr: arguments[0].clone().into()}, - expr::function::Function::ExtractMinute => ast::Expr::Extract {field: ast::DateTimeField::Minute, expr: arguments[0].clone().into()}, - expr::function::Function::ExtractSecond => ast::Expr::Extract {field: ast::DateTimeField::Second, expr: arguments[0].clone().into()}, - expr::function::Function::ExtractMicrosecond => ast::Expr::Extract {field: ast::DateTimeField::Microsecond, expr: arguments[0].clone().into()}, - expr::function::Function::ExtractMillisecond => ast::Expr::Extract {field: ast::DateTimeField::Millisecond, expr: arguments[0].clone().into()}, - expr::function::Function::ExtractDow => ast::Expr::Extract {field: ast::DateTimeField::Dow, expr: arguments[0].clone().into()}, - expr::function::Function::ExtractWeek => ast::Expr::Extract {field: ast::DateTimeField::Week(None), expr: arguments[0].clone().into()}, + expr::function::Function::ExtractYear => ast::Expr::Extract { + field: ast::DateTimeField::Year, + expr: arguments[0].clone().into(), + }, + expr::function::Function::ExtractMonth => ast::Expr::Extract { + field: ast::DateTimeField::Month, + expr: arguments[0].clone().into(), + }, + expr::function::Function::ExtractDay => ast::Expr::Extract { + field: ast::DateTimeField::Day, + expr: arguments[0].clone().into(), + }, + expr::function::Function::ExtractHour => ast::Expr::Extract { + field: ast::DateTimeField::Hour, + expr: arguments[0].clone().into(), + }, + expr::function::Function::ExtractMinute => ast::Expr::Extract { + field: ast::DateTimeField::Minute, + expr: arguments[0].clone().into(), + }, + expr::function::Function::ExtractSecond => ast::Expr::Extract { + field: ast::DateTimeField::Second, + expr: arguments[0].clone().into(), + }, + expr::function::Function::ExtractMicrosecond => ast::Expr::Extract { + field: ast::DateTimeField::Microsecond, + expr: arguments[0].clone().into(), + }, + expr::function::Function::ExtractMillisecond => ast::Expr::Extract { + field: ast::DateTimeField::Millisecond, + expr: arguments[0].clone().into(), + }, + expr::function::Function::ExtractDow => ast::Expr::Extract { + field: ast::DateTimeField::Dow, + expr: arguments[0].clone().into(), + }, + expr::function::Function::ExtractWeek => ast::Expr::Extract { + field: ast::DateTimeField::Week(None), + expr: arguments[0].clone().into(), + }, expr::function::Function::Like => ast::Expr::Like { negated: false, expr: arguments[0].clone().into(), @@ -349,27 +389,29 @@ impl<'a> expr::Visitor<'a, ast::Expr> for FromExprVisitor { pattern: arguments[1].clone().into(), escape_char: None, }, - expr::function::Function::Choose => if let ast::Expr::Tuple(t) = &arguments[1] { - let func_args_list = ast::FunctionArgumentList { - duplicate_treatment: None, - args: vec![arguments[0].clone()] - .into_iter() - .chain(t.clone()) - .map(|e| ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))) - .collect(), - clauses: vec![], - }; - ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], - }) - } else { - todo!() - }, + expr::function::Function::Choose => { + if let ast::Expr::Tuple(t) = &arguments[1] { + let func_args_list = ast::FunctionArgumentList { + duplicate_treatment: None, + args: vec![arguments[0].clone()] + .into_iter() + .chain(t.clone()) + .map(|e| ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))) + .collect(), + clauses: vec![], + }; + ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } else { + todo!() + } + } expr::function::Function::IsNull => ast::Expr::IsNull(arguments[0].clone().into()), expr::function::Function::IsBool => { if let ast::Expr::Value(ast::Value::Boolean(b)) = arguments[1] { @@ -381,7 +423,7 @@ impl<'a> expr::Visitor<'a, ast::Expr> for FromExprVisitor { } else { unimplemented!() } - }, + } } } // TODO implement this properly @@ -397,106 +439,126 @@ impl<'a> expr::Visitor<'a, ast::Expr> for FromExprVisitor { | expr::aggregate::Aggregate::Last | expr::aggregate::Aggregate::Mean | expr::aggregate::Aggregate::Count - | expr::aggregate::Aggregate::Sum + | expr::aggregate::Aggregate::Sum | expr::aggregate::Aggregate::Std => { let func_args_list = ast::FunctionArgumentList { duplicate_treatment: None, - args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(argument))], + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], clauses: vec![], }; ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(aggregate.to_string())]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], - })}, + name: ast::ObjectName(vec![ast::Ident::new(aggregate.to_string())]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } expr::aggregate::Aggregate::Var => { let func_args_list = ast::FunctionArgumentList { duplicate_treatment: None, - args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(argument))], + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], clauses: vec![], }; ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(String::from("variance"))]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], + name: ast::ObjectName(vec![ast::Ident::new(String::from("variance"))]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], }) - }, + } expr::aggregate::Aggregate::MeanDistinct => { let func_args_list = ast::FunctionArgumentList { duplicate_treatment: Some(ast::DuplicateTreatment::Distinct), - args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(argument))], + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], clauses: vec![], }; ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(String::from("avg"))]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], - })}, + name: ast::ObjectName(vec![ast::Ident::new(String::from("avg"))]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } expr::aggregate::Aggregate::CountDistinct => { let func_args_list = ast::FunctionArgumentList { duplicate_treatment: Some(ast::DuplicateTreatment::Distinct), - args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(argument))], + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], clauses: vec![], }; ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(String::from("count"))]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], - })}, + name: ast::ObjectName(vec![ast::Ident::new(String::from("count"))]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } expr::aggregate::Aggregate::SumDistinct => { let func_args_list = ast::FunctionArgumentList { duplicate_treatment: Some(ast::DuplicateTreatment::Distinct), - args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(argument))], + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], clauses: vec![], }; ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(String::from("sum"))]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], - })}, + name: ast::ObjectName(vec![ast::Ident::new(String::from("sum"))]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } expr::aggregate::Aggregate::StdDistinct => { let func_args_list = ast::FunctionArgumentList { duplicate_treatment: Some(ast::DuplicateTreatment::Distinct), - args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(argument))], + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], clauses: vec![], }; ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(String::from("stddev"))]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], - })}, + name: ast::ObjectName(vec![ast::Ident::new(String::from("stddev"))]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } expr::aggregate::Aggregate::VarDistinct => { let func_args_list = ast::FunctionArgumentList { duplicate_treatment: Some(ast::DuplicateTreatment::Distinct), - args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(argument))], + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], clauses: vec![], }; ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident::new(String::from("variance"))]), - args: ast::FunctionArguments::List(func_args_list), - over: None, - filter: None, - null_treatment: None, - within_group: vec![], - })}, + name: ast::ObjectName(vec![ast::Ident::new(String::from("variance"))]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } expr::aggregate::Aggregate::Median => todo!(), expr::aggregate::Aggregate::NUnique => todo!(), expr::aggregate::Aggregate::List => todo!(), @@ -780,10 +842,12 @@ mod tests { println!("expr = {}", expr); let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr.to_string().to_lowercase(), gen_expr.to_string().to_lowercase()); + assert_eq!( + ast_expr.to_string().to_lowercase(), + gen_expr.to_string().to_lowercase() + ); } - #[test] fn test_floor() { let str_expr = "floor(a)"; @@ -792,7 +856,10 @@ mod tests { println!("expr = {}", expr); let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr.to_string().to_lowercase(), gen_expr.to_string().to_lowercase()); + assert_eq!( + ast_expr.to_string().to_lowercase(), + gen_expr.to_string().to_lowercase() + ); } #[test] @@ -803,7 +870,10 @@ mod tests { println!("expr = {}", expr); let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr.to_string().to_lowercase(), gen_expr.to_string().to_lowercase()); + assert_eq!( + ast_expr.to_string().to_lowercase(), + gen_expr.to_string().to_lowercase() + ); let str_expr = "round(a, 1)"; let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); @@ -811,7 +881,10 @@ mod tests { println!("expr = {}", expr); let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr.to_string().to_lowercase(), gen_expr.to_string().to_lowercase()); + assert_eq!( + ast_expr.to_string().to_lowercase(), + gen_expr.to_string().to_lowercase() + ); let str_expr = "round(a, 1)"; let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); @@ -819,7 +892,10 @@ mod tests { println!("expr = {}", expr); let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr.to_string().to_lowercase(), gen_expr.to_string().to_lowercase()); + assert_eq!( + ast_expr.to_string().to_lowercase(), + gen_expr.to_string().to_lowercase() + ); } #[test] @@ -830,7 +906,10 @@ mod tests { println!("expr = {}", expr); let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr.to_string().to_lowercase(), gen_expr.to_string().to_lowercase()); + assert_eq!( + ast_expr.to_string().to_lowercase(), + gen_expr.to_string().to_lowercase() + ); let str_expr = "trunc(a, 1)"; let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); @@ -838,7 +917,10 @@ mod tests { println!("expr = {}", expr); let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr.to_string().to_lowercase(), gen_expr.to_string().to_lowercase()); + assert_eq!( + ast_expr.to_string().to_lowercase(), + gen_expr.to_string().to_lowercase() + ); let str_expr = "trunc(a, 4)"; let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); @@ -846,7 +928,10 @@ mod tests { println!("expr = {}", expr); let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr.to_string().to_lowercase(), gen_expr.to_string().to_lowercase()); + assert_eq!( + ast_expr.to_string().to_lowercase(), + gen_expr.to_string().to_lowercase() + ); } #[test] @@ -889,7 +974,7 @@ mod tests { println!("ast::expr = {gen_expr}"); assert_eq!(ast_expr, gen_expr); - let epsilon = 1./f64::MAX; + let epsilon = 1. / f64::MAX; let str_expr = "log(b, x)"; let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); @@ -898,21 +983,30 @@ mod tests { let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); let true_expr = parse_expr( - format!("CASE WHEN ((log(b)) >= ({})) OR ((log(b)) <= (-({}))) - THEN (log(x)) / ((log(b))) ELSE 0 END", epsilon, epsilon).as_str() - ).unwrap(); + format!( + "CASE WHEN ((log(b)) >= ({})) OR ((log(b)) <= (-({}))) + THEN (log(x)) / ((log(b))) ELSE 0 END", + epsilon, epsilon + ) + .as_str(), + ) + .unwrap(); assert_eq!(gen_expr, true_expr); - let str_expr = "log10(x)"; let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); let expr = Expr::try_from(&ast_expr).unwrap(); println!("expr = {}", expr); let gen_expr = ast::Expr::from(&expr); let true_expr = parse_expr( - format!("CASE WHEN ((log(x)) >= ({})) OR ((log(x)) <= (-({}))) - THEN (log(10)) / ((log(x))) ELSE 0 END", epsilon, epsilon).as_str() - ).unwrap(); + format!( + "CASE WHEN ((log(x)) >= ({})) OR ((log(x)) <= (-({}))) + THEN (log(10)) / ((log(x))) ELSE 0 END", + epsilon, epsilon + ) + .as_str(), + ) + .unwrap(); assert_eq!(gen_expr, true_expr); let str_expr = "log2(x)"; @@ -922,9 +1016,14 @@ mod tests { let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); let true_expr = parse_expr( - format!("CASE WHEN ((log(x)) >= ({})) OR ((log(x)) <= (-({}))) - THEN (log(2)) / ((log(x))) ELSE 0 END", epsilon, epsilon).as_str() - ).unwrap(); + format!( + "CASE WHEN ((log(x)) >= ({})) OR ((log(x)) <= (-({}))) + THEN (log(2)) / ((log(x))) ELSE 0 END", + epsilon, epsilon + ) + .as_str(), + ) + .unwrap(); assert_eq!(gen_expr, true_expr); } @@ -946,7 +1045,7 @@ mod tests { println!("ast::expr = {gen_expr}"); assert_eq!(ast_expr, gen_expr); - let epsilon = 1./f64::MAX; + let epsilon = 1. / f64::MAX; let str_expr = "tan(x)"; let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); let expr = Expr::try_from(&ast_expr).unwrap(); @@ -954,9 +1053,14 @@ mod tests { let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); let true_expr = parse_expr( - format!("CASE WHEN ((cos(x)) >= ({})) OR ((cos(x)) <= (-({}))) - THEN (sin(x)) / ((cos(x))) ELSE 0 END", epsilon, epsilon).as_str() - ).unwrap(); + format!( + "CASE WHEN ((cos(x)) >= ({})) OR ((cos(x)) <= (-({}))) + THEN (sin(x)) / ((cos(x))) ELSE 0 END", + epsilon, epsilon + ) + .as_str(), + ) + .unwrap(); assert_eq!(gen_expr, true_expr); } @@ -990,11 +1094,16 @@ mod tests { println!("expr = {}", expr); let gen_expr = ast::Expr::from(&expr); println!("ast::expr = {gen_expr}"); - let epsilon = 1./f64::MAX; + let epsilon = 1. / f64::MAX; let true_expr = parse_expr( - format!("(100) * ((CASE WHEN ((pi()) >= ({})) OR ((pi()) <= (-({}))) - THEN (180) / ((pi())) ELSE 0 END))", epsilon, epsilon).as_str() - ).unwrap(); + format!( + "(100) * ((CASE WHEN ((pi()) >= ({})) OR ((pi()) <= (-({}))) + THEN (180) / ((pi())) ELSE 0 END))", + epsilon, epsilon + ) + .as_str(), + ) + .unwrap(); assert_eq!(gen_expr, true_expr); } @@ -1029,7 +1138,6 @@ mod tests { let true_expr = parse_expr("regexp_extract(value, regexp, 0, 1)").unwrap(); assert_eq!(gen_expr, true_expr); - let str_expr = "regexp_extract(value, regexp, position)"; let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); let expr = Expr::try_from(&ast_expr).unwrap(); diff --git a/src/hierarchy.rs b/src/hierarchy.rs index 03a36056..031293a2 100644 --- a/src/hierarchy.rs +++ b/src/hierarchy.rs @@ -8,9 +8,9 @@ use core::fmt; use itertools::Itertools; use std::{ collections::BTreeMap, + error, iter::Extend, ops::{Deref, DerefMut, Index}, - error, result, }; @@ -284,7 +284,9 @@ where fn index(&self, index: P) -> &Self::Output { let path = index.path(); - self.get(&path).ok_or_else(|| Error::invalid_path(path.join("."))).unwrap() + self.get(&path) + .ok_or_else(|| Error::invalid_path(path.join("."))) + .unwrap() } } diff --git a/src/io/bigquery.rs b/src/io/bigquery.rs index 0ce8a2ca..114bb5b4 100644 --- a/src/io/bigquery.rs +++ b/src/io/bigquery.rs @@ -500,7 +500,7 @@ impl DatabaseTrait for Database { json: map_as_json, }); } - + insert_query.add_rows(rows_for_bq.clone())?; rt.block_on(self.client.tabledata().insert_all( @@ -878,14 +878,13 @@ mod tests { .ok(); if let Some(tabs) = list_tabs { let tables_as_str: Vec = tabs - .tables - .unwrap_or_default() - .into_iter() - .map(|t| t.table_reference.table_id) - .collect(); + .tables + .unwrap_or_default() + .into_iter() + .map(|t| t.table_reference.table_id) + .collect(); println!("{:?}", tables_as_str); } - } // #[tokio::test] diff --git a/src/io/mod.rs b/src/io/mod.rs index dbdb2108..e2a95621 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -7,13 +7,13 @@ //! - BigQuery using the ["bigquery"] feature. //! +#[cfg(feature = "bigquery")] +pub mod bigquery; +#[cfg(feature = "mssql")] +pub mod mssql; pub mod postgresql; #[cfg(feature = "sqlite")] pub mod sqlite; -#[cfg(feature = "mssql")] -pub mod mssql; -#[cfg(feature = "bigquery")] -pub mod bigquery; use crate::{ builder::{Ready, With}, diff --git a/src/io/mssql.rs b/src/io/mssql.rs index 1001b96a..3a15fb15 100644 --- a/src/io/mssql.rs +++ b/src/io/mssql.rs @@ -361,7 +361,7 @@ impl DatabaseTrait for Database { for value in &values { insert_query = insert_query.bind(value); } - + rt.block_on(async_execute(insert_query, &self.pool))?; } Ok(()) diff --git a/src/io/sqlite.rs b/src/io/sqlite.rs index d69beb85..3081892b 100644 --- a/src/io/sqlite.rs +++ b/src/io/sqlite.rs @@ -5,7 +5,8 @@ use crate::{ value::{self, Value}, DataTyped, }, - relation::{Table, Variant as _}, dialect_translation::sqlite::SQLiteTranslator, + dialect_translation::sqlite::SQLiteTranslator, + relation::{Table, Variant as _}, }; use rand::{rngs::StdRng, SeedableRng}; use rusqlite::{ @@ -60,13 +61,17 @@ impl DatabaseTrait for Database { } fn create_table(&mut self, table: &Table) -> Result { - Ok(self.connection.execute(&table.create(SQLiteTranslator).to_string(), ())?) + Ok(self + .connection + .execute(&table.create(SQLiteTranslator).to_string(), ())?) } fn insert_data(&mut self, table: &Table) -> Result<()> { let mut rng = StdRng::seed_from_u64(DATA_GENERATION_SEED); let size = Database::MAX_SIZE.min(table.size().generate(&mut rng) as usize); - let mut statement = self.connection.prepare(&table.insert("?", SQLiteTranslator).to_string())?; + let mut statement = self + .connection + .prepare(&table.insert("?", SQLiteTranslator).to_string())?; for _ in 0..size { let structured: value::Struct = table.schema().data_type().generate(&mut rng).try_into()?; diff --git a/src/lib.rs b/src/lib.rs index 0329306c..43e5d586 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,7 @@ pub mod setup; pub mod expr; pub mod builder; pub mod debug; +pub mod dialect_translation; pub mod differential_privacy; pub mod display; pub mod encoder; @@ -35,7 +36,6 @@ pub mod sql; pub mod synthetic_data; pub mod types; pub mod visitor; -pub mod dialect_translation; pub use builder::{Ready, With, WithContext, WithIterator, WithoutContext}; pub use data_type::{value::Value, DataType}; diff --git a/src/relation/builder.rs b/src/relation/builder.rs index 31cc59f6..d4a74b01 100644 --- a/src/relation/builder.rs +++ b/src/relation/builder.rs @@ -3,7 +3,7 @@ use std::{hash::Hash, sync::Arc}; use itertools::Itertools; use super::{ - Error, Result, Join, JoinOperator, Map, OrderBy, Reduce, Relation, Schema, Set, SetOperator, + Error, Join, JoinOperator, Map, OrderBy, Reduce, Relation, Result, Schema, Set, SetOperator, SetQuantifier, Table, Values, Variant, }; use crate::{ diff --git a/src/relation/dot.rs b/src/relation/dot.rs index 70e58e71..4f029c29 100644 --- a/src/relation/dot.rs +++ b/src/relation/dot.rs @@ -406,7 +406,10 @@ mod tests { namer::reset(); let schema: Schema = vec![ ("a", DataType::float()), - ("b", DataType::text_values(&["A&B".into(), "C>D".into()]))].into_iter().collect(); + ("b", DataType::text_values(&["A&B".into(), "C>D".into()])), + ] + .into_iter() + .collect(); let table: Relation = Relation::table() .name("table") .schema(schema.clone()) diff --git a/src/relation/mod.rs b/src/relation/mod.rs index 60f5e508..92be3d59 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -12,7 +12,12 @@ pub mod schema; pub mod sql; use std::{ - cmp, collections::HashSet, error, fmt, hash, ops::{Deref, Index}, result, sync::Arc + cmp, + collections::HashSet, + error, fmt, hash, + ops::{Deref, Index}, + result, + sync::Arc, }; use colored::Colorize; @@ -700,7 +705,11 @@ impl JoinOperator { } // A utility function - fn expr_has_unique_constraint(expr: &Expr, left_schema: &Schema, right_schema: &Schema) -> (bool, bool) { + fn expr_has_unique_constraint( + expr: &Expr, + left_schema: &Schema, + right_schema: &Schema, + ) -> (bool, bool) { match expr { Expr::Function(f) => match f.function() { function::Function::Eq => { @@ -749,12 +758,20 @@ impl JoinOperator { (left, right) } function::Function::And => { - let arg_0 = JoinOperator::expr_has_unique_constraint(&f.arguments()[0], left_schema, right_schema); - let arg_1 = JoinOperator::expr_has_unique_constraint(&f.arguments()[1], left_schema, right_schema); + let arg_0 = JoinOperator::expr_has_unique_constraint( + &f.arguments()[0], + left_schema, + right_schema, + ); + let arg_1 = JoinOperator::expr_has_unique_constraint( + &f.arguments()[1], + left_schema, + right_schema, + ); (arg_0.0 || arg_1.0, arg_0.1 || arg_1.1) } _ => (false, false), - } + }, _ => (false, false), } } @@ -764,7 +781,9 @@ impl JoinOperator { JoinOperator::Inner(e) | JoinOperator::LeftOuter(e) | JoinOperator::RightOuter(e) - | JoinOperator::FullOuter(e) => JoinOperator::expr_has_unique_constraint(e, left_schema, right_schema), + | JoinOperator::FullOuter(e) => { + JoinOperator::expr_has_unique_constraint(e, left_schema, right_schema) + } _ => (false, false), } } @@ -1277,8 +1296,17 @@ impl Values { .data_type() .try_into() .unwrap(); - let unique = values.iter().collect::>().len()==values.iter().collect::>().len(); - Schema::from_field(Field::new(name.to_string(), list.data_type().clone(), if unique {Some(Constraint::Unique)} else {None})) + let unique = + values.iter().collect::>().len() == values.iter().collect::>().len(); + Schema::from_field(Field::new( + name.to_string(), + list.data_type().clone(), + if unique { + Some(Constraint::Unique) + } else { + None + }, + )) } pub fn builder() -> ValuesBuilder { diff --git a/src/relation/rewriting.rs b/src/relation/rewriting.rs index 292d32d4..4ec7df52 100644 --- a/src/relation/rewriting.rs +++ b/src/relation/rewriting.rs @@ -205,12 +205,12 @@ impl Join { /// To mimic the behavior of USING(col) and NATURAL JOIN in SQL we create /// a map where join columns identified by `vec` are coalesced. /// vec: vector of string identifying input columns present in both _LEFT_ - /// and _RIGHT_ relation of the join. + /// and _RIGHT_ relation of the join. /// columns: is the Hierarchy mapping input names in the JOIN to name field - /// + /// /// It returns a: /// - Map build on top the Join with coalesced column along with - /// the other fields of the join and + /// the other fields of the join and /// - coalesced columns mapping (name in join -> name in map) pub fn remove_duplicates_and_coalesce( self, @@ -218,36 +218,34 @@ impl Join { columns: &Hierarchy, ) -> (Relation, Hierarchy) { let mut coalesced_cols: Vec<(Identifier, Identifier)> = vec![]; - let coalesced = self - .field_inputs() - .filter_map(|(_, input_id)| { - let col = input_id.as_ref().last().unwrap(); - if input_id.as_ref().first().unwrap().as_str() == LEFT_INPUT_NAME && vec.contains(col) { - let left_col = columns[[LEFT_INPUT_NAME, col]].as_ref().last().unwrap(); - let right_col = columns[[RIGHT_INPUT_NAME, col]].as_ref().last().unwrap(); - coalesced_cols.push((left_col.as_str().into(), col[..].into())); - coalesced_cols.push((right_col.as_str().into(), col[..].into())); - Some(( - col.clone(), - Expr::coalesce( - Expr::col(left_col), - Expr::col(right_col), - ), - )) - } else { - None - } - }); + let coalesced = self.field_inputs().filter_map(|(_, input_id)| { + let col = input_id.as_ref().last().unwrap(); + if input_id.as_ref().first().unwrap().as_str() == LEFT_INPUT_NAME && vec.contains(col) { + let left_col = columns[[LEFT_INPUT_NAME, col]].as_ref().last().unwrap(); + let right_col = columns[[RIGHT_INPUT_NAME, col]].as_ref().last().unwrap(); + coalesced_cols.push((left_col.as_str().into(), col[..].into())); + coalesced_cols.push((right_col.as_str().into(), col[..].into())); + Some(( + col.clone(), + Expr::coalesce(Expr::col(left_col), Expr::col(right_col)), + )) + } else { + None + } + }); let coalesced_with_others = coalesced .chain(self.field_inputs().filter_map(|(name, id)| { let col = id.as_ref().last().unwrap(); (!vec.contains(col)).then_some((name.clone(), Expr::col(name))) })) .collect::>(); - (Relation::map() - .input(Relation::from(self)) - .with_iter(coalesced_with_others) - .build(), coalesced_cols.into_iter().collect()) + ( + Relation::map() + .input(Relation::from(self)) + .with_iter(coalesced_with_others) + .build(), + coalesced_cols.into_iter().collect(), + ) } } @@ -474,7 +472,12 @@ impl Relation { /// - The original fields from the current relation. /// - Rescaled columns, where each rescaled column is a product of the original column (specified by the second element of the corresponding tuple in `values`) /// and its scaling factor output by `scale_factors` Relation - pub fn scale(self, entities: &str, named_values: &[(&str, &str)], scale_factors: Relation) -> Self { + pub fn scale( + self, + entities: &str, + named_values: &[(&str, &str)], + scale_factors: Relation, + ) -> Self { // Join the two relations on the entity column let join: Relation = Relation::join() .left_outer(Expr::val(true)) diff --git a/src/relation/sql.rs b/src/relation/sql.rs index ce60f6f0..dde57baf 100644 --- a/src/relation/sql.rs +++ b/src/relation/sql.rs @@ -147,10 +147,10 @@ impl<'a, T: RelationToQueryTranslator> Visitor<'a, ast::Query> for FromRelationV // Add input query to CTEs input_ctes.push( self.translator.cte( - self.translator.identifier( &(map.name().into()) )[0].clone(), + self.translator.identifier(&(map.name().into()))[0].clone(), map.schema() .iter() - .map(|field| self.translator.identifier( &(field.name().into()) )[0].clone()) + .map(|field| self.translator.identifier(&(field.name().into()))[0].clone()) .collect(), self.translator.query( vec![], @@ -160,7 +160,7 @@ impl<'a, T: RelationToQueryTranslator> Visitor<'a, ast::Query> for FromRelationV .zip(map.schema.clone()) .map(|(expr, field)| ast::SelectItem::ExprWithAlias { expr: self.translator.expr(&expr), - alias: self.translator.identifier( &(field.name().into()) )[0].clone(), + alias: self.translator.identifier(&(field.name().into()))[0].clone(), }) .collect(), table_with_joins( @@ -210,11 +210,11 @@ impl<'a, T: RelationToQueryTranslator> Visitor<'a, ast::Query> for FromRelationV // Add input query to CTEs input_ctes.push( self.translator.cte( - self.translator.identifier( &(reduce.name().into()) )[0].clone(), + self.translator.identifier(&(reduce.name().into()))[0].clone(), reduce .schema() .iter() - .map(|field| self.translator.identifier( &(field.name().into()) )[0].clone()) + .map(|field| self.translator.identifier(&(field.name().into()))[0].clone()) .collect(), self.translator.query( vec![], @@ -225,7 +225,7 @@ impl<'a, T: RelationToQueryTranslator> Visitor<'a, ast::Query> for FromRelationV .zip(reduce.schema.clone()) .map(|(aggregate, field)| ast::SelectItem::ExprWithAlias { expr: self.translator.expr(aggregate.deref()), - alias: self.translator.identifier( &(field.name().into()) )[0].clone(), + alias: self.translator.identifier(&(field.name().into()))[0].clone(), }) .collect(), table_with_joins( @@ -280,10 +280,10 @@ impl<'a, T: RelationToQueryTranslator> Visitor<'a, ast::Query> for FromRelationV // Add input query to CTEs input_ctes.push( self.translator.cte( - self.translator.identifier( &(join.name().into()) )[0].clone(), + self.translator.identifier(&(join.name().into()))[0].clone(), join.schema() .iter() - .map(|field| self.translator.identifier( &(field.name().into()) )[0].clone()) + .map(|field| self.translator.identifier(&(field.name().into()))[0].clone()) .collect(), self.translator.query( vec![], @@ -341,7 +341,7 @@ impl<'a, T: RelationToQueryTranslator> Visitor<'a, ast::Query> for FromRelationV set.name().into(), set.schema() .iter() - .map(|field| self.translator.identifier( &(field.name().into()) )[0].clone()) + .map(|field| self.translator.identifier(&(field.name().into()))[0].clone()) .collect(), set_operation( vec![], @@ -397,11 +397,10 @@ impl<'a, T: RelationToQueryTranslator> Visitor<'a, ast::Query> for FromRelationV None, None, ); - let value_name = self.translator.identifier( &(values.name().into()) )[0].clone(); - let input_ctes = - vec![self - .translator - .cte(value_name.clone(), vec![value_name], cte_query)]; + let value_name = self.translator.identifier(&(values.name().into()))[0].clone(); + let input_ctes = vec![self + .translator + .cte(value_name.clone(), vec![value_name], cte_query)]; self.translator.query( input_ctes, all(), diff --git a/src/sql/expr.rs b/src/sql/expr.rs index 74ee1e95..6d1a1d97 100644 --- a/src/sql/expr.rs +++ b/src/sql/expr.rs @@ -7,8 +7,8 @@ use crate::{ builder::{WithContext, WithoutContext}, expr::{identifier::Identifier, Expr, Value}, hierarchy::{Hierarchy, Path}, - visitor::{self, Acceptor, Dependencies, Visited}, namer, + visitor::{self, Acceptor, Dependencies, Visited}, }; use itertools::Itertools; use sqlparser::{ @@ -43,10 +43,7 @@ impl<'a> Acceptor<'a> for ast::Expr { match self { ast::Expr::Identifier(_) => Dependencies::empty(), ast::Expr::CompoundIdentifier(_) => Dependencies::empty(), - ast::Expr::JsonAccess { - value, - path - } => Dependencies::from([value.as_ref()]), + ast::Expr::JsonAccess { value, path } => Dependencies::from([value.as_ref()]), ast::Expr::CompositeAccess { expr, key: _ } => Dependencies::from([expr.as_ref()]), ast::Expr::IsFalse(expr) => Dependencies::from([expr.as_ref()]), ast::Expr::IsNotFalse(expr) => Dependencies::from([expr.as_ref()]), @@ -171,14 +168,18 @@ impl<'a> Acceptor<'a> for ast::Expr { value: _, } => Dependencies::empty(), ast::Expr::MapAccess { column, keys } => Dependencies::from([column.as_ref()]), - ast::Expr::Function(function) => { - match &function.args { - ast::FunctionArguments::None => Dependencies::empty(), - ast::FunctionArguments::Subquery(_) => Dependencies::empty(), - ast::FunctionArguments::List(list_args) => list_args.args + ast::Expr::Function(function) => match &function.args { + ast::FunctionArguments::None => Dependencies::empty(), + ast::FunctionArguments::Subquery(_) => Dependencies::empty(), + ast::FunctionArguments::List(list_args) => list_args + .args .iter() .map(|arg| match arg { - ast::FunctionArg::Named { name: _, arg, operator: _} => arg, + ast::FunctionArg::Named { + name: _, + arg, + operator: _, + } => arg, ast::FunctionArg::Unnamed(arg) => arg, }) .filter_map(|arg| match arg { @@ -186,8 +187,7 @@ impl<'a> Acceptor<'a> for ast::Expr { _ => None, }) .collect(), - } - } + }, ast::Expr::Case { operand, conditions, @@ -229,7 +229,13 @@ impl<'a> Acceptor<'a> for ast::Expr { } => todo!(), ast::Expr::Struct { values, fields } => todo!(), ast::Expr::Named { expr, name } => todo!(), - ast::Expr::Convert { expr, data_type, charset, target_before_value, styles } => todo!(), + ast::Expr::Convert { + expr, + data_type, + charset, + target_before_value, + styles, + } => todo!(), ast::Expr::Wildcard => todo!(), ast::Expr::QualifiedWildcard(_) => todo!(), ast::Expr::Dictionary(_) => Dependencies::empty(), @@ -264,7 +270,7 @@ pub trait Visitor<'a, T: Clone> { fn substring(&self, expr: T, substring_from: Option, substring_for: Option) -> T; fn ceil(&self, expr: T, field: &'a ast::DateTimeField) -> T; fn floor(&self, expr: T, field: &'a ast::DateTimeField) -> T; - fn cast(&self, expr:T, data_type: &'a ast::DataType) -> T; + fn cast(&self, expr: T, data_type: &'a ast::DataType) -> T; fn extract(&self, field: &'a ast::DateTimeField, expr: T) -> T; fn like(&self, expr: T, pattern: T) -> T; fn ilike(&self, expr: T, pattern: T) -> T; @@ -285,37 +291,34 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { match acceptor { ast::Expr::Identifier(ident) => self.identifier(ident), ast::Expr::CompoundIdentifier(idents) => self.compound_identifier(idents), - ast::Expr::JsonAccess { - value, - path - } => todo!(), + ast::Expr::JsonAccess { value, path } => todo!(), ast::Expr::CompositeAccess { expr, key } => todo!(), ast::Expr::IsFalse(expr) => self.is( self.cast(dependencies.get(expr).clone(), &ast::DataType::Boolean), - Some(false) + Some(false), ), ast::Expr::IsNotFalse(expr) => self.unary_op( &ast::UnaryOperator::Not, self.is( self.cast(dependencies.get(expr).clone(), &ast::DataType::Boolean), - Some(false) - ) + Some(false), + ), ), ast::Expr::IsTrue(expr) => self.is( self.cast(dependencies.get(expr).clone(), &ast::DataType::Boolean), - Some(true) + Some(true), ), ast::Expr::IsNotTrue(expr) => self.unary_op( &ast::UnaryOperator::Not, self.is( self.cast(dependencies.get(expr).clone(), &ast::DataType::Boolean), - Some(true) - ) + Some(true), + ), ), ast::Expr::IsNull(expr) => self.is(dependencies.get(expr).clone(), None), ast::Expr::IsNotNull(expr) => self.unary_op( &ast::UnaryOperator::Not, - self.is(dependencies.get(expr).clone(), None) + self.is(dependencies.get(expr).clone(), None), ), ast::Expr::IsUnknown(_) => todo!(), ast::Expr::IsNotUnknown(_) => todo!(), @@ -356,13 +359,13 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { self.binary_op( dependencies.get(expr).clone(), &ast::BinaryOperator::GtEq, - dependencies.get(low).clone() + dependencies.get(low).clone(), ), &ast::BinaryOperator::And, self.binary_op( dependencies.get(expr).clone(), &ast::BinaryOperator::LtEq, - dependencies.get(high).clone() + dependencies.get(high).clone(), ), ); if *negated { @@ -387,14 +390,14 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { }; let x = self.like( dependencies.get(expr).clone(), - dependencies.get(pattern).clone() + dependencies.get(pattern).clone(), ); if *negated { self.unary_op(&ast::UnaryOperator::Not, x) } else { x } - }, + } ast::Expr::ILike { negated, expr, @@ -406,14 +409,14 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { }; let x = self.ilike( dependencies.get(expr).clone(), - dependencies.get(pattern).clone() + dependencies.get(pattern).clone(), ); if *negated { self.unary_op(&ast::UnaryOperator::Not, x) } else { x } - }, + } ast::Expr::SimilarTo { negated, expr, @@ -439,13 +442,15 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { expr, data_type, format: _, - kind: _ + kind: _, } => self.cast(dependencies.get(expr).clone(), data_type), ast::Expr::AtTimeZone { timestamp, time_zone, } => todo!(), - ast::Expr::Extract { field, expr } => self.extract(field, dependencies.get(expr).clone()), + ast::Expr::Extract { field, expr } => { + self.extract(field, dependencies.get(expr).clone()) + } ast::Expr::Ceil { expr, field } => self.ceil(dependencies.get(expr).clone(), field), ast::Expr::Floor { expr, field } => self.floor(dependencies.get(expr).clone(), field), ast::Expr::Position { expr, r#in } => self.position( @@ -495,38 +500,40 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { ast::Expr::Value(value) => self.value(value), ast::Expr::TypedString { data_type, value } => todo!(), ast::Expr::MapAccess { column, keys } => todo!(), - ast::Expr::Function(function) => { - self.function(function, { - let mut result = vec![]; - let function_args = match &function.args { - ast::FunctionArguments::None => vec![], - ast::FunctionArguments::Subquery(_) => vec![], - ast::FunctionArguments::List(arg_list) => arg_list.args.iter().collect(), - }; - for function_arg in function_args.iter() { - result.push(match function_arg { - ast::FunctionArg::Named { name, arg , operator} => FunctionArg::Named { - name: name.clone(), - arg: match arg { - ast::FunctionArgExpr::Expr(e) => dependencies.get(e).clone(), - ast::FunctionArgExpr::QualifiedWildcard(idents) => { - self.qualified_wildcard(&idents.0) - } - ast::FunctionArgExpr::Wildcard => self.wildcard(), - }, - }, - ast::FunctionArg::Unnamed(arg) => FunctionArg::Unnamed(match arg { + ast::Expr::Function(function) => self.function(function, { + let mut result = vec![]; + let function_args = match &function.args { + ast::FunctionArguments::None => vec![], + ast::FunctionArguments::Subquery(_) => vec![], + ast::FunctionArguments::List(arg_list) => arg_list.args.iter().collect(), + }; + for function_arg in function_args.iter() { + result.push(match function_arg { + ast::FunctionArg::Named { + name, + arg, + operator, + } => FunctionArg::Named { + name: name.clone(), + arg: match arg { ast::FunctionArgExpr::Expr(e) => dependencies.get(e).clone(), ast::FunctionArgExpr::QualifiedWildcard(idents) => { self.qualified_wildcard(&idents.0) } ast::FunctionArgExpr::Wildcard => self.wildcard(), - }), - }); - } - result - }) - } + }, + }, + ast::FunctionArg::Unnamed(arg) => FunctionArg::Unnamed(match arg { + ast::FunctionArgExpr::Expr(e) => dependencies.get(e).clone(), + ast::FunctionArgExpr::QualifiedWildcard(idents) => { + self.qualified_wildcard(&idents.0) + } + ast::FunctionArgExpr::Wildcard => self.wildcard(), + }), + }); + } + result + }), ast::Expr::Case { operand, conditions, @@ -567,7 +574,13 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { } => todo!(), ast::Expr::Struct { values, fields } => todo!(), ast::Expr::Named { expr, name } => todo!(), - ast::Expr::Convert { expr, data_type, charset, target_before_value, styles } => todo!(), + ast::Expr::Convert { + expr, + data_type, + charset, + target_before_value, + styles, + } => todo!(), ast::Expr::Wildcard => todo!(), ast::Expr::QualifiedWildcard(_) => todo!(), ast::Expr::Dictionary(_) => todo!(), @@ -706,7 +719,11 @@ impl<'a> Visitor<'a, String> for DisplayVisitor { format!( "CEIL ({}{})", expr, - if matches!(field, ast::DateTimeField::NoDateTime) {"".to_string()} else {format!(", {field}")} + if matches!(field, ast::DateTimeField::NoDateTime) { + "".to_string() + } else { + format!(", {field}") + } ) } @@ -714,7 +731,11 @@ impl<'a> Visitor<'a, String> for DisplayVisitor { format!( "FLOOR ({}{})", expr, - if matches!(field, ast::DateTimeField::NoDateTime) {"".to_string()} else {format!(", {field}")} + if matches!(field, ast::DateTimeField::NoDateTime) { + "".to_string() + } else { + format!(", {field}") + } ) } @@ -735,7 +756,13 @@ impl<'a> Visitor<'a, String> for DisplayVisitor { } fn is(&self, expr: String, value: Option) -> String { - format!("{} IS {}", expr, value.map(|b| b.to_string().to_uppercase()).unwrap_or("NULL".to_string())) + format!( + "{} IS {}", + expr, + value + .map(|b| b.to_string().to_uppercase()) + .unwrap_or("NULL".to_string()) + ) } } @@ -766,17 +793,13 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { } fn identifier(&self, ident: &'a ast::Ident) -> Result { - let column = self - .0 - .get(&ident.cloned()) - .cloned() - .unwrap_or_else(|| { - if let Some(_) = ident.quote_style { - ident.value.clone().into() - } else { - ident.value.to_lowercase().clone().into() - } - }); + let column = self.0.get(&ident.cloned()).cloned().unwrap_or_else(|| { + if let Some(_) = ident.quote_style { + ident.value.clone().into() + } else { + ident.value.to_lowercase().clone().into() + } + }); Ok(Expr::Column(column)) } @@ -913,7 +936,7 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { } else { false } - }, + } }; Ok(match function_name { // Math Functions @@ -925,7 +948,10 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { if flat_args.len() == 1 { Expr::log(flat_args[0].clone()) } else { - Expr::divide(Expr::log(flat_args[1].clone()), Expr::log(flat_args[0].clone())) + Expr::divide( + Expr::log(flat_args[1].clone()), + Expr::log(flat_args[0].clone()), + ) } } "log2" => Expr::divide(Expr::log(Expr::val(2)), Expr::log(flat_args[0].clone())), @@ -933,7 +959,10 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { "abs" => Expr::abs(flat_args[0].clone()), "sin" => Expr::sin(flat_args[0].clone()), "cos" => Expr::cos(flat_args[0].clone()), - "tan" => Expr::divide(Expr::sin(flat_args[0].clone()), Expr::cos(flat_args[0].clone())), + "tan" => Expr::divide( + Expr::sin(flat_args[0].clone()), + Expr::cos(flat_args[0].clone()), + ), "sqrt" => Expr::sqrt(flat_args[0].clone()), "pow" => Expr::pow(flat_args[0].clone(), flat_args[1].clone()), "power" => Expr::pow(flat_args[0].clone(), flat_args[1].clone()), @@ -965,10 +994,7 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { } else { Expr::val(0) }; - Expr::round( - flat_args[0].clone(), - precision, - ) + Expr::round(flat_args[0].clone(), precision) } "trunc" | "truncate" => { let precision = if flat_args.len() > 1 { @@ -976,29 +1002,24 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { } else { Expr::val(0) }; - Expr::trunc( - flat_args[0].clone(), - precision, - ) + Expr::trunc(flat_args[0].clone(), precision) } "sign" => Expr::sign(flat_args[0].clone()), "random" | "rand" => Expr::random(namer::new_id("UNIFORM_SAMPLING")), "pi" => Expr::pi(), "degrees" => Expr::multiply( flat_args[0].clone(), - Expr::divide(Expr::val(180.), Expr::pi()) + Expr::divide(Expr::val(180.), Expr::pi()), ), "choose" => Expr::choose( flat_args[0].clone(), Expr::val(Value::list( - flat_args.iter() + flat_args + .iter() .skip(1) - .map(|x| - Value::try_from(x.clone()) - .map_err(|e| Error::other(e)) - ) - .collect::>>()? - )) + .map(|x| Value::try_from(x.clone()).map_err(|e| Error::other(e))) + .collect::>>()?, + )), ), // String functions "lower" => Expr::lower(flat_args[0].clone()), @@ -1028,9 +1049,18 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { } else { Expr::val(1) }; - Expr::regexp_extract(flat_args[0].clone(), flat_args[1].clone(), position, occurrence) - }, - "regexp_replace" => Expr::regexp_replace(flat_args[0].clone(), flat_args[1].clone(), flat_args[2].clone()), + Expr::regexp_extract( + flat_args[0].clone(), + flat_args[1].clone(), + position, + occurrence, + ) + } + "regexp_replace" => Expr::regexp_replace( + flat_args[0].clone(), + flat_args[1].clone(), + flat_args[2].clone(), + ), "newid" => Expr::newid(), "encode" => Expr::encode(flat_args[0].clone(), flat_args[1].clone()), "decode" => Expr::decode(flat_args[0].clone(), flat_args[1].clone()), @@ -1042,7 +1072,11 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { "dayname" => Expr::dayname(flat_args[0].clone()), "date_format" => Expr::date_format(flat_args[0].clone(), flat_args[1].clone()), "quarter" => Expr::quarter(flat_args[0].clone()), - "datetime_diff" => Expr::datetime_diff(flat_args[0].clone(), flat_args[1].clone(), flat_args[2].clone()), + "datetime_diff" => Expr::datetime_diff( + flat_args[0].clone(), + flat_args[1].clone(), + flat_args[2].clone(), + ), "date" => Expr::date(flat_args[0].clone()), "from_unixtime" => { let format = if flat_args.len() > 1 { @@ -1051,7 +1085,7 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { Expr::val("%Y-%m-%d %H:%i:%S".to_string()) }; Expr::from_unixtime(flat_args[0].clone(), format) - }, + } "unix_timestamp" => { let arg = if flat_args.len() > 0 { flat_args[0].clone() @@ -1059,13 +1093,13 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { Expr::current_timestamp() }; Expr::unix_timestamp(arg) - }, + } // Aggregates "min" => Expr::min(flat_args[0].clone()), "max" => Expr::max(flat_args[0].clone()), "count" if distinct => Expr::count_distinct(flat_args[0].clone()), "count" => Expr::count(flat_args[0].clone()), - "avg" if distinct => Expr::mean_distinct(flat_args[0].clone()), + "avg" if distinct => Expr::mean_distinct(flat_args[0].clone()), "avg" => Expr::mean(flat_args[0].clone()), "sum" if distinct => Expr::sum_distinct(flat_args[0].clone()), "sum" => Expr::sum(flat_args[0].clone()), @@ -1153,110 +1187,110 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { } fn ceil(&self, expr: Result, field: &'a ast::DateTimeField) -> Result { - if !matches!(field, ast::DateTimeField::NoDateTime) {todo!()} + if !matches!(field, ast::DateTimeField::NoDateTime) { + todo!() + } Ok(Expr::ceil(expr.clone()?)) } fn floor(&self, expr: Result, field: &'a ast::DateTimeField) -> Result { - if !matches!(field, ast::DateTimeField::NoDateTime) {todo!()} + if !matches!(field, ast::DateTimeField::NoDateTime) { + todo!() + } Ok(Expr::floor(expr.clone()?)) } fn cast(&self, expr: Result, data_type: &'a ast::DataType) -> Result { - Ok( - match data_type { - //Text - ast::DataType::Character(_) - | ast::DataType::Char(_) - | ast::DataType::CharacterVarying(_) - |ast::DataType::CharVarying(_) - | ast::DataType::Varchar(_) - | ast::DataType::Nvarchar(_) - | ast::DataType::Uuid - | ast::DataType::CharacterLargeObject(_) - | ast::DataType::CharLargeObject(_) - | ast::DataType::Clob(_) - | ast::DataType::Text - | ast::DataType::String(_) => Expr::cast_as_text(expr.clone()?), - //Bytes - ast::DataType::Binary(_) - | ast::DataType::Varbinary(_) - | ast::DataType::Blob(_) - | ast::DataType::Bytes(_) - | ast::DataType::Bytea => todo!(), - //Float - ast::DataType::Numeric(_) - | ast::DataType::Decimal(_) - | ast::DataType::BigNumeric(_) - | ast::DataType::BigDecimal(_) - | ast::DataType::Dec(_) - | ast::DataType::Float(_) - | ast::DataType::Float4 - | ast::DataType::Float64 - | ast::DataType::Real - | ast::DataType::Float8 - | ast::DataType::Double - | ast::DataType::DoublePrecision => Expr::cast_as_float(expr.clone()?), - // Integer - ast::DataType::TinyInt(_) - | ast::DataType::UnsignedTinyInt(_) - | ast::DataType::Int2(_) - | ast::DataType::UnsignedInt2(_) - | ast::DataType::SmallInt(_) - | ast::DataType::UnsignedSmallInt(_) - | ast::DataType::MediumInt(_) - | ast::DataType::UnsignedMediumInt(_) - | ast::DataType::Int(_) - | ast::DataType::Int4(_) - | ast::DataType::Int64 - | ast::DataType::Integer(_) - | ast::DataType::UnsignedInt(_) - | ast::DataType::UnsignedInt4(_) - | ast::DataType::UnsignedInteger(_) - | ast::DataType::BigInt(_) - | ast::DataType::UnsignedBigInt(_) - | ast::DataType::Int8(_) - | ast::DataType::UnsignedInt8(_) => Expr::cast_as_integer(expr.clone()?), - // Boolean - ast::DataType::Bool - | ast::DataType::Boolean => Expr::cast_as_boolean(expr.clone()?), - // Date - ast::DataType::Date => Expr::cast_as_date(expr.clone()?), - // Time - ast::DataType::Time(_, _) => Expr::cast_as_time(expr.clone()?), - // DateTime - ast::DataType::Datetime(_) - | ast::DataType::Timestamp(_, _) => Expr::cast_as_date_time(expr.clone()?), - ast::DataType::Interval => todo!(), - ast::DataType::JSON => todo!(), - ast::DataType::Regclass => todo!(), - ast::DataType::Custom(_, _) => todo!(), - ast::DataType::Array(_) => todo!(), - ast::DataType::Enum(_) => todo!(), - ast::DataType::Set(_) => todo!(), - ast::DataType::Struct(_) => todo!(), - ast::DataType::JSONB => todo!(), - ast::DataType::Unspecified => todo!(), + Ok(match data_type { + //Text + ast::DataType::Character(_) + | ast::DataType::Char(_) + | ast::DataType::CharacterVarying(_) + | ast::DataType::CharVarying(_) + | ast::DataType::Varchar(_) + | ast::DataType::Nvarchar(_) + | ast::DataType::Uuid + | ast::DataType::CharacterLargeObject(_) + | ast::DataType::CharLargeObject(_) + | ast::DataType::Clob(_) + | ast::DataType::Text + | ast::DataType::String(_) => Expr::cast_as_text(expr.clone()?), + //Bytes + ast::DataType::Binary(_) + | ast::DataType::Varbinary(_) + | ast::DataType::Blob(_) + | ast::DataType::Bytes(_) + | ast::DataType::Bytea => todo!(), + //Float + ast::DataType::Numeric(_) + | ast::DataType::Decimal(_) + | ast::DataType::BigNumeric(_) + | ast::DataType::BigDecimal(_) + | ast::DataType::Dec(_) + | ast::DataType::Float(_) + | ast::DataType::Float4 + | ast::DataType::Float64 + | ast::DataType::Real + | ast::DataType::Float8 + | ast::DataType::Double + | ast::DataType::DoublePrecision => Expr::cast_as_float(expr.clone()?), + // Integer + ast::DataType::TinyInt(_) + | ast::DataType::UnsignedTinyInt(_) + | ast::DataType::Int2(_) + | ast::DataType::UnsignedInt2(_) + | ast::DataType::SmallInt(_) + | ast::DataType::UnsignedSmallInt(_) + | ast::DataType::MediumInt(_) + | ast::DataType::UnsignedMediumInt(_) + | ast::DataType::Int(_) + | ast::DataType::Int4(_) + | ast::DataType::Int64 + | ast::DataType::Integer(_) + | ast::DataType::UnsignedInt(_) + | ast::DataType::UnsignedInt4(_) + | ast::DataType::UnsignedInteger(_) + | ast::DataType::BigInt(_) + | ast::DataType::UnsignedBigInt(_) + | ast::DataType::Int8(_) + | ast::DataType::UnsignedInt8(_) => Expr::cast_as_integer(expr.clone()?), + // Boolean + ast::DataType::Bool | ast::DataType::Boolean => Expr::cast_as_boolean(expr.clone()?), + // Date + ast::DataType::Date => Expr::cast_as_date(expr.clone()?), + // Time + ast::DataType::Time(_, _) => Expr::cast_as_time(expr.clone()?), + // DateTime + ast::DataType::Datetime(_) | ast::DataType::Timestamp(_, _) => { + Expr::cast_as_date_time(expr.clone()?) } - ) + ast::DataType::Interval => todo!(), + ast::DataType::JSON => todo!(), + ast::DataType::Regclass => todo!(), + ast::DataType::Custom(_, _) => todo!(), + ast::DataType::Array(_) => todo!(), + ast::DataType::Enum(_) => todo!(), + ast::DataType::Set(_) => todo!(), + ast::DataType::Struct(_) => todo!(), + ast::DataType::JSONB => todo!(), + ast::DataType::Unspecified => todo!(), + }) } fn extract(&self, field: &'a ast::DateTimeField, expr: Result) -> Result { - Ok( - match field { - ast::DateTimeField::Year => Expr::extract_year(expr.clone()?), - ast::DateTimeField::Month => Expr::extract_month(expr.clone()?), - ast::DateTimeField::Week(_) => Expr::extract_week(expr.clone()?), - ast::DateTimeField::Day => Expr::extract_day(expr.clone()?), - ast::DateTimeField::Hour => Expr::extract_hour(expr.clone()?), - ast::DateTimeField::Minute => Expr::extract_minute(expr.clone()?), - ast::DateTimeField::Second => Expr::extract_second(expr.clone()?), - ast::DateTimeField::Dow => Expr::extract_dow(expr.clone()?), - ast::DateTimeField::Microsecond => Expr::extract_microsecond(expr.clone()?), - ast::DateTimeField::Millisecond => Expr::extract_millisecond(expr.clone()?), - _ => todo!(), - } - ) + Ok(match field { + ast::DateTimeField::Year => Expr::extract_year(expr.clone()?), + ast::DateTimeField::Month => Expr::extract_month(expr.clone()?), + ast::DateTimeField::Week(_) => Expr::extract_week(expr.clone()?), + ast::DateTimeField::Day => Expr::extract_day(expr.clone()?), + ast::DateTimeField::Hour => Expr::extract_hour(expr.clone()?), + ast::DateTimeField::Minute => Expr::extract_minute(expr.clone()?), + ast::DateTimeField::Second => Expr::extract_second(expr.clone()?), + ast::DateTimeField::Dow => Expr::extract_dow(expr.clone()?), + ast::DateTimeField::Microsecond => Expr::extract_microsecond(expr.clone()?), + ast::DateTimeField::Millisecond => Expr::extract_millisecond(expr.clone()?), + _ => todo!(), + }) } fn like(&self, expr: Result, pattern: Result) -> Result { @@ -1268,12 +1302,10 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { } fn is(&self, expr: Result, value: Option) -> Result { - Ok( - match value { - Some(b) => Expr::is_bool(expr.clone()?, Expr::val(b)), - None => Expr::is_null(expr.clone()?), - } - ) + Ok(match value { + Some(b) => Expr::is_bool(expr.clone()?, Expr::val(b)), + None => Expr::is_null(expr.clone()?), + }) } } diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 8a1da91d..13dfc1bb 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -112,13 +112,13 @@ mod tests { builder::With, display::Dot, io::{postgresql, Database}, - relation::Relation, DataType, + relation::Relation, + DataType, }; use colored::Colorize; use itertools::Itertools; use sqlparser::dialect::BigQueryDialect; - #[test] fn test_display() { let database = postgresql::test_database(); @@ -147,17 +147,18 @@ mod tests { let mut database = postgresql::test_database(); for query in [ - "SELECT CAST(a AS text) FROM table_1", // float => text - "SELECT CAST(b AS text) FROM table_1", // integer => text - "SELECT CAST(c AS text) FROM table_1", // date => text - "SELECT CAST(z AS text) FROM table_2", // text => text - "SELECT CAST(x AS float) FROM table_2", // integer => float + "SELECT CAST(a AS text) FROM table_1", // float => text + "SELECT CAST(b AS text) FROM table_1", // integer => text + "SELECT CAST(c AS text) FROM table_1", // date => text + "SELECT CAST(z AS text) FROM table_2", // text => text + "SELECT CAST(x AS float) FROM table_2", // integer => float "SELECT CAST('true' AS boolean) FROM table_2", // integer => float "SELECT CEIL(3 * b), FLOOR(3 * b), TRUNC(3 * b), ROUND(3 * b) FROM table_1", - "SELECT SUM(DISTINCT a), SUM(a) FROM table_1" + "SELECT SUM(DISTINCT a), SUM(a) FROM table_1", ] { let res1 = database.query(query).unwrap(); - let relation = Relation::try_from(parse(query).unwrap().with(&database.relations())).unwrap(); + let relation = + Relation::try_from(parse(query).unwrap().with(&database.relations())).unwrap(); let relation_query: &str = &ast::Query::from(&relation).to_string(); println!("{query} => {relation_query}"); let res2 = database.query(relation_query).unwrap(); diff --git a/src/sql/relation.rs b/src/sql/relation.rs index 198079d2..43531547 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -8,16 +8,33 @@ use super::{ Error, Result, }; use crate::{ - ast, builder::{Ready, With, WithIterator, WithoutContext}, dialect::{Dialect, GenericDialect}, dialect_translation::{postgresql::PostgreSqlTranslator, QueryToRelationTranslator}, display::Dot, expr::{Expr, Identifier, Reduce, Split}, hierarchy::{Hierarchy, Path}, namer::{self, FIELD}, parser::Parser, relation::{ - Join, JoinOperator, MapBuilder, Relation, SetOperator, SetQuantifier, - Variant as _, WithInput, - LEFT_INPUT_NAME, RIGHT_INPUT_NAME - }, tokenizer::Tokenizer, types::And, visitor::{Acceptor, Dependencies, Visited} + ast, + builder::{Ready, With, WithIterator, WithoutContext}, + dialect::{Dialect, GenericDialect}, + dialect_translation::{postgresql::PostgreSqlTranslator, QueryToRelationTranslator}, + display::Dot, + expr::{Expr, Identifier, Reduce, Split}, + hierarchy::{Hierarchy, Path}, + namer::{self, FIELD}, + parser::Parser, + relation::{ + Join, JoinOperator, MapBuilder, Relation, SetOperator, SetQuantifier, Variant as _, + WithInput, LEFT_INPUT_NAME, RIGHT_INPUT_NAME, + }, + tokenizer::Tokenizer, + types::And, + visitor::{Acceptor, Dependencies, Visited}, }; use dot::Id; use itertools::Itertools; use std::{ - collections::HashMap, convert::TryFrom, iter::{once, Iterator}, ops::Deref, result, str::FromStr, sync::Arc + collections::HashMap, + convert::TryFrom, + iter::{once, Iterator}, + ops::Deref, + result, + str::FromStr, + sync::Arc, }; /* @@ -32,12 +49,20 @@ This is done in the query_names module. struct TryIntoRelationVisitor<'a, T: QueryToRelationTranslator + Copy + Clone> { relations: &'a Hierarchy>, query_names: QueryNames<'a>, - translator: T + translator: T, } impl<'a, T: QueryToRelationTranslator + Copy + Clone> TryIntoRelationVisitor<'a, T> { - fn new(relations: &'a Hierarchy>, query_names: QueryNames<'a>, translator: T) -> Self { - TryIntoRelationVisitor{ relations, query_names, translator } + fn new( + relations: &'a Hierarchy>, + query_names: QueryNames<'a>, + translator: T, + ) -> Self { + TryIntoRelationVisitor { + relations, + query_names, + translator, + } } } @@ -80,10 +105,10 @@ impl RelationWithColumns { } /// A struct to hold the query being visited and its Relations -struct VisitedQueryRelations<'a, T: QueryToRelationTranslator + Copy + Clone>{ +struct VisitedQueryRelations<'a, T: QueryToRelationTranslator + Copy + Clone> { relations: Hierarchy>, visited: Visited<'a, ast::Query, Result>>, - translator: T + translator: T, } impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, T> { @@ -92,14 +117,22 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, query: &'a ast::Query, visited: Visited<'a, ast::Query, Result>>, ) -> Self { - let TryIntoRelationVisitor{relations, query_names, translator} = try_into_relation_visitor; + let TryIntoRelationVisitor { + relations, + query_names, + translator, + } = try_into_relation_visitor; let mut relations: Hierarchy> = (*relations).clone(); relations.extend( query_names .name_referred(query) .map(|(name, referred)| (name.clone(), visited.get(referred).clone().unwrap())), ); - VisitedQueryRelations{relations, visited, translator: *translator} + VisitedQueryRelations { + relations, + visited, + translator: *translator, + } } /// Convert a TableFactor into a RelationWithColumns @@ -107,7 +140,11 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, &self, table_factor: &'a ast::TableFactor, ) -> Result { - let VisitedQueryRelations{relations, visited, translator} = self; + let VisitedQueryRelations { + relations, + visited, + translator, + } = self; // Process the table_factor match &table_factor { @@ -169,43 +206,63 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, ) -> Result { Ok(match join_constraint { ast::JoinConstraint::On(expr) => self.translator.try_expr(expr, columns)?, - ast::JoinConstraint::Using(idents) => { // the "Using (id)" condition is equivalent to "ON _LEFT_.id = _RIGHT_.id" - Expr::and_iter( - idents.into_iter() - .map(|id| - Expr::eq( - Expr::Column(Identifier::from(vec![LEFT_INPUT_NAME.to_string(), id.value.to_string()])), - Expr::Column(Identifier::from(vec![RIGHT_INPUT_NAME.to_string(), id.value.to_string()])), - )) - ) - }, - ast::JoinConstraint::Natural => { // the NATURAL condition is equivalent to a "ON _LEFT_.col1 = _RIGHT_.col1 AND _LEFT_.col2 = _RIGHT_.col2" where col1, col2... are the columns present in both tables - let tables = columns.iter() - .map(|(k, _)| k.iter().take(k.len() - 1).map(|s| s.to_string()).collect::>()) - .dedup() - .collect::>(); + ast::JoinConstraint::Using(idents) => { + // the "Using (id)" condition is equivalent to "ON _LEFT_.id = _RIGHT_.id" + Expr::and_iter(idents.into_iter().map(|id| { + Expr::eq( + Expr::Column(Identifier::from(vec![ + LEFT_INPUT_NAME.to_string(), + id.value.to_string(), + ])), + Expr::Column(Identifier::from(vec![ + RIGHT_INPUT_NAME.to_string(), + id.value.to_string(), + ])), + ) + })) + } + ast::JoinConstraint::Natural => { + // the NATURAL condition is equivalent to a "ON _LEFT_.col1 = _RIGHT_.col1 AND _LEFT_.col2 = _RIGHT_.col2" where col1, col2... are the columns present in both tables + let tables = columns + .iter() + .map(|(k, _)| { + k.iter() + .take(k.len() - 1) + .map(|s| s.to_string()) + .collect::>() + }) + .dedup() + .collect::>(); assert_eq!(tables.len(), 2); let columns_1 = columns.filter(tables[0].as_slice()); let columns_2 = columns.filter(tables[1].as_slice()); let columns_1 = columns_1 - .iter() - .map(|(k, _)| k.last().unwrap()) - .collect::>(); + .iter() + .map(|(k, _)| k.last().unwrap()) + .collect::>(); let columns_2 = columns_2 - .iter() - .map(|(k, _)| k.last().unwrap()) - .collect::>(); + .iter() + .map(|(k, _)| k.last().unwrap()) + .collect::>(); Expr::and_iter( columns_1 - .iter() - .filter_map(|col| columns_2.contains(&col).then_some(col)) - .map(|id| Expr::eq( - Expr::Column(Identifier::from(vec![LEFT_INPUT_NAME.to_string(), id.to_string()])), - Expr::Column(Identifier::from(vec![RIGHT_INPUT_NAME.to_string(), id.to_string()])) - )) + .iter() + .filter_map(|col| columns_2.contains(&col).then_some(col)) + .map(|id| { + Expr::eq( + Expr::Column(Identifier::from(vec![ + LEFT_INPUT_NAME.to_string(), + id.to_string(), + ])), + Expr::Column(Identifier::from(vec![ + RIGHT_INPUT_NAME.to_string(), + id.to_string(), + ])), + ) + }), ) - }, + } ast::JoinConstraint::None => todo!(), }) } @@ -217,7 +274,7 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, ) -> Result { match join_operator { ast::JoinOperator::Inner(join_constraint) => Ok(JoinOperator::Inner( - self.try_from_join_constraint_with_columns(join_constraint, columns)? + self.try_from_join_constraint_with_columns(join_constraint, columns)?, )), ast::JoinOperator::LeftOuter(join_constraint) => Ok(JoinOperator::LeftOuter( self.try_from_join_constraint_with_columns(join_constraint, columns)?, @@ -237,7 +294,7 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, fn try_from_join( &self, left: RelationWithColumns, - ast_join: &'a ast::Join + ast_join: &'a ast::Join, ) -> Result { let RelationWithColumns(left_relation, left_columns) = left; let RelationWithColumns(right_relation, right_columns) = @@ -252,22 +309,18 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, v.extend(i.to_vec()); v.into() }); - // fully qualified input names -> fully qualified JOIN names + // fully qualified input names -> fully qualified JOIN names let all_columns: Hierarchy = left_columns.with(right_columns); - let operator = self.try_from_join_operator_with_columns( - &ast_join.join_operator, - &all_columns, - )?; + let operator = + self.try_from_join_operator_with_columns(&ast_join.join_operator, &all_columns)?; let join: Join = Relation::join() .operator(operator) .left(left_relation) .right(right_relation) .build(); - let join_columns: Hierarchy = join - .field_inputs() - .map(|(f, i)| (i, f.into())) - .collect(); + let join_columns: Hierarchy = + join.field_inputs().map(|(f, i)| (i, f.into())).collect(); // If the join constraint is of type "USING" or "NATURAL", add a map to coalesce the duplicate columns let (relation, coalesced) = match &ast_join.join_operator { @@ -276,19 +329,28 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, | ast::JoinOperator::RightOuter(ast::JoinConstraint::Using(v)) | ast::JoinOperator::FullOuter(ast::JoinConstraint::Using(v)) => { // Do we need to change all_columns? - let to_be_coalesced: Vec = v.into_iter().map(|id| id.value.to_string()).collect(); + let to_be_coalesced: Vec = + v.into_iter().map(|id| id.value.to_string()).collect(); join.remove_duplicates_and_coalesce(to_be_coalesced, &join_columns) - }, + } ast::JoinOperator::Inner(ast::JoinConstraint::Natural) | ast::JoinOperator::LeftOuter(ast::JoinConstraint::Natural) | ast::JoinOperator::RightOuter(ast::JoinConstraint::Natural) | ast::JoinOperator::FullOuter(ast::JoinConstraint::Natural) => { - let v: Vec = join.left().fields() + let v: Vec = join + .left() + .fields() .into_iter() - .filter_map(|f| join.right().schema().field(f.name()).is_ok().then_some(f.name().to_string())) + .filter_map(|f| { + join.right() + .schema() + .field(f.name()) + .is_ok() + .then_some(f.name().to_string()) + }) .collect(); join.remove_duplicates_and_coalesce(v, &join_columns) - }, + } ast::JoinOperator::LeftSemi(_) => todo!(), ast::JoinOperator::RightSemi(_) => todo!(), ast::JoinOperator::LeftAnti(_) => todo!(), @@ -310,11 +372,9 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, ) -> Result { // Process the relation // Then the JOIN if needed - let result = table_with_joins.joins - .iter() - .fold(self.try_from_table_factor(&table_with_joins.relation), - |left, ast_join| - self.try_from_join(left?, &ast_join), + let result = table_with_joins.joins.iter().fold( + self.try_from_table_factor(&table_with_joins.relation), + |left, ast_join| self.try_from_join(left?, &ast_join), ); result } @@ -327,9 +387,7 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, // TODO consider more tables // For now, only consider the first element // It should eventually be cross joined as described in: https://www.postgresql.org/docs/current/queries-table-expressions.html - self.try_from_table_with_joins( - &tables_with_joins[0] - ) + self.try_from_table_with_joins(&tables_with_joins[0]) } /// Extracts named expressions from the from relation and the select items @@ -339,21 +397,18 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, select_items: &'a [ast::SelectItem], from: &'a Arc, ) -> Result<(Vec<(String, Expr)>, Hierarchy)> { - let mut named_exprs: Vec<(String, Expr)> = vec![]; - + // It stores the update for the column mapping: // (old name in columns, new name forced by the select) let mut renamed_columns: Vec<(Identifier, Identifier)> = vec![]; - + for select_item in select_items { match select_item { ast::SelectItem::UnnamedExpr(expr) => { // Pull the original name for implicit aliasing let implicit_alias = match expr { - ast::Expr::Identifier(ident) => { - lower_case_unquoted_ident(ident) - }, + ast::Expr::Identifier(ident) => lower_case_unquoted_ident(ident), ast::Expr::CompoundIdentifier(idents) => { let ident = idents.last().unwrap(); lower_case_unquoted_ident(ident) @@ -364,15 +419,18 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, if let Some(name) = columns.get(&implicit_alias_ident) { renamed_columns.push((name.clone(), implicit_alias_ident)); }; - named_exprs.push((implicit_alias, self.translator.try_expr(expr,columns)?)) - }, + named_exprs.push((implicit_alias, self.translator.try_expr(expr, columns)?)) + } ast::SelectItem::ExprWithAlias { expr, alias } => { let alias_ident = Identifier::from_name(alias.clone().value); if let Some(name) = columns.get(&alias_ident) { renamed_columns.push((name.clone(), alias_ident)); }; - named_exprs.push((alias.clone().value, self.translator.try_expr(expr,columns)?)) - }, + named_exprs.push(( + alias.clone().value, + self.translator.try_expr(expr, columns)?, + )) + } ast::SelectItem::QualifiedWildcard(_, _) => todo!(), ast::SelectItem::Wildcard(_) => { // push all names that are present in the from into named_exprs. @@ -380,18 +438,19 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, // for the ambiguous ones used the name present in the relation. let non_ambiguous_cols = last(columns); // Invert mapping of non_ambiguous_cols - let new_aliases: Hierarchy = non_ambiguous_cols.iter() - .map(|(p, i)|(i.deref(), p.last().unwrap().clone())) + let new_aliases: Hierarchy = non_ambiguous_cols + .iter() + .map(|(p, i)| (i.deref(), p.last().unwrap().clone())) .collect(); - + for field in from.schema().iter() { let field_name = field.name().to_string(); let alias = new_aliases .get_key_value(&[field.name().to_string()]) - .and_then(|(k, v)|{ + .and_then(|(k, v)| { renamed_columns.push((k.to_vec().into(), v.clone().into())); Some(v.clone()) - } ); + }); named_exprs.push((alias.unwrap_or(field_name), Expr::col(field.name()))); } } @@ -416,11 +475,12 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, // Columns from names let columns = &names.map(|s| s.clone().into()); - let (named_expr_from_select, new_columns) = self.try_named_expr_columns_from_select_items(columns, select_items, &from)?; + let (named_expr_from_select, new_columns) = + self.try_named_expr_columns_from_select_items(columns, select_items, &from)?; named_exprs.extend(named_expr_from_select.into_iter()); // Prepare the GROUP BY - let group_by = match group_by { + let group_by = match group_by { ast::GroupByExpr::All => todo!(), ast::GroupByExpr::Expressions(group_by_exprs) => group_by_exprs .iter() @@ -430,7 +490,8 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, // If the GROUP BY contains aliases, then replace them by the corresponding expression in `named_exprs`. // Note that we mimic postgres behavior and support only GROUP BY alias column (no other expressions containing aliases are allowed) // The aliases cannot be used in HAVING - let group_by = group_by.into_iter() + let group_by = group_by + .into_iter() .map(|x| match &x { Expr::Column(c) if columns.get_key_value(&c).is_none() && c.len() == 1 => { named_exprs @@ -438,14 +499,14 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, .find(|&(name, _)| name == &c[0]) .map(|(_, expr)| expr.clone()) .unwrap_or(x) - }, - _ => x + } + _ => x, }) .collect::>(); // Add the having in named_exprs let having = if let Some(expr) = having { let having_name = namer::name_from_content(FIELD, &expr); - let mut expr = self.translator.try_expr(expr,columns)?; + let mut expr = self.translator.try_expr(expr, columns)?; let columns = named_exprs .iter() .map(|(s, x)| (Expr::col(s.to_string()), x.clone())) @@ -468,19 +529,20 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, let split = if group_by.is_empty() { Split::from_iter(named_exprs) } else { - let group_by = group_by.clone().into_iter() - .fold(Split::Reduce(Reduce::default()), - |s, expr| s.and(Split::Reduce(Split::group_by(expr))) - ); - named_exprs.into_iter() - .fold(group_by, - |s, named_expr| s.and(named_expr.into()) - ) + let group_by = group_by + .clone() + .into_iter() + .fold(Split::Reduce(Reduce::default()), |s, expr| { + s.and(Split::Reduce(Split::group_by(expr))) + }); + named_exprs + .into_iter() + .fold(group_by, |s, named_expr| s.and(named_expr.into())) }; // Prepare the WHERE let filter: Option = selection .as_ref() - // todo. Use pass the expression through the translator + // todo. Use pass the expression through the translator .map(|e| self.translator.try_expr(e, columns)) .map_or(Ok(None), |r| r.map(Some))?; @@ -520,7 +582,10 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, } // preserve old columns while composing with new ones let columns = &columns.clone().with(columns.and_then(new_columns)); - Ok(RelationWithColumns::new(Arc::new(relation), columns.clone())) + Ok(RelationWithColumns::new( + Arc::new(relation), + columns.clone(), + )) } /// Convert a Select into a Relation @@ -540,9 +605,9 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, having, named_window, qualify, - window_before_qualify , - value_table_mode , - connect_by + window_before_qualify, + value_table_mode, + connect_by, } = select; if top.is_some() { return Err(Error::other("TOP is not supported")); @@ -569,18 +634,17 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, return Err(Error::other("QUALIFY is not supported")); } - let RelationWithColumns(from, columns) = self.try_from_tables_with_joins( - from - )?; - let RelationWithColumns(relation, columns) = self.try_from_select_items_selection_and_group_by( - &columns.filter_map(|i| Some(i.split_last().ok()?.0)), - projection, - selection, - group_by, - from, - having, - distinct - )?; + let RelationWithColumns(from, columns) = self.try_from_tables_with_joins(from)?; + let RelationWithColumns(relation, columns) = self + .try_from_select_items_selection_and_group_by( + &columns.filter_map(|i| Some(i.split_last().ok()?.0)), + projection, + selection, + group_by, + from, + having, + distinct, + )?; Ok(RelationWithColumns::new(relation, columns)) } @@ -642,9 +706,9 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, }); // Add OFFSET let relation_builder: Result> = - offset.iter().fold(relation_builder, |builder, offset| { - Ok(builder?.offset(self.try_from_offset(offset)?)) - }); + offset.iter().fold(relation_builder, |builder, offset| { + Ok(builder?.offset(self.try_from_offset(offset)?)) + }); // Build a relation with ORDER BY and LIMIT Ok(Arc::new(relation_builder?.try_build()?)) } @@ -675,9 +739,15 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, } } -impl<'a, T:QueryToRelationTranslator + Copy + Clone> Visitor<'a, Result>> for TryIntoRelationVisitor<'a, T> { +impl<'a, T: QueryToRelationTranslator + Copy + Clone> Visitor<'a, Result>> + for TryIntoRelationVisitor<'a, T> +{ fn dependencies(&self, acceptor: &'a ast::Query) -> Dependencies<'a, ast::Query> { - let TryIntoRelationVisitor{relations, query_names, translator} = self; + let TryIntoRelationVisitor { + relations, + query_names, + translator, + } = self; let mut dependencies = acceptor.dependencies(); // Add subqueries from the body dependencies.extend( @@ -745,12 +815,18 @@ impl<'a> TryFrom> for Relation { let query_names = query.accept(IntoQueryNamesVisitor); // Visit for conversion query - .accept(TryIntoRelationVisitor::new(relations, query_names, PostgreSqlTranslator)) + .accept(TryIntoRelationVisitor::new( + relations, + query_names, + PostgreSqlTranslator, + )) .map(|r| r.as_ref().clone()) } } -impl<'a, T: QueryToRelationTranslator + Copy + Clone> TryFrom<(QueryWithRelations<'a>, T)> for Relation { +impl<'a, T: QueryToRelationTranslator + Copy + Clone> TryFrom<(QueryWithRelations<'a>, T)> + for Relation +{ type Error = Error; fn try_from(value: (QueryWithRelations<'a>, T)) -> result::Result { @@ -760,28 +836,32 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> TryFrom<(QueryWithRelation let query_names = query.accept(IntoQueryNamesVisitor); // Visit for conversion query - .accept(TryIntoRelationVisitor::new(relations, query_names, translator)) + .accept(TryIntoRelationVisitor::new( + relations, + query_names, + translator, + )) .map(|r| r.as_ref().clone()) } } /// It creates a new hierarchy with Identifier for which the last part of their -/// path is not ambiguous. The new hierarchy will contain one-element paths +/// path is not ambiguous. The new hierarchy will contain one-element paths fn last(columns: &Hierarchy) -> Hierarchy { columns - .iter() - .filter_map(|(path, _)|{ - let path_last = path.last().unwrap().clone(); - columns - .get(&[path_last.clone()]) - .and_then( |t| Some((path_last, t.clone())) ) - }) - .collect() + .iter() + .filter_map(|(path, _)| { + let path_last = path.last().unwrap().clone(); + columns + .get(&[path_last.clone()]) + .and_then(|t| Some((path_last, t.clone()))) + }) + .collect() } /// Returns the identifier value. If it is quoted it returns its value /// as it is whereas if unquoted it returns the lowercase value. -/// Used to create relations field's name. +/// Used to create relations field's name. fn lower_case_unquoted_ident(ident: &ast::Ident) -> String { if let Some(_) = ident.quote_style { ident.value.clone() @@ -790,7 +870,6 @@ fn lower_case_unquoted_ident(ident: &ast::Ident) -> String { } } - /// A simple SQL query parser with dialect pub fn parse_with_dialect(query: &str, dialect: D) -> Result { let mut tokenizer = Tokenizer::new(&dialect, query); @@ -816,8 +895,8 @@ mod tests { builder::Ready, data_type::{DataType, DataTyped, Variant}, display::Dot, + io::{postgresql, Database}, relation::schema::Schema, - io::{Database, postgresql} }; #[test] @@ -1451,11 +1530,7 @@ mod tests { let query_str = "SELECT 3*d, COUNT(*) AS my_count FROM table_1 GROUP BY 3*d;"; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); println!("relation = {relation}"); assert_eq!( @@ -1474,7 +1549,6 @@ mod tests { .map(ToString::to_string); } - #[test] fn test_order_by() { let mut database = postgresql::test_database(); @@ -1483,11 +1557,7 @@ mod tests { SELECT * FROM user_table u ORDER BY u.city, u.id "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); @@ -1501,11 +1571,7 @@ mod tests { SELECT * FROM order_table o JOIN user_table u ON (o.id=u.id) ORDER BY city "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); @@ -1519,11 +1585,7 @@ mod tests { SELECT * FROM order_table o JOIN user_table u ON (o.id=u.id) ORDER BY o.id "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); @@ -1537,11 +1599,7 @@ mod tests { SELECT city, SUM(o.id) FROM order_table o JOIN user_table u ON (o.id=u.id) GROUP BY city ORDER BY city "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); @@ -1550,16 +1608,12 @@ mod tests { .unwrap() .iter() .map(ToString::to_string); - + let query_str = r#" SELECT city AS mycity, SUM(o.id) AS mysum FROM order_table o JOIN user_table u ON (o.id=u.id) GROUP BY mycity ORDER BY mycity, mysum "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); @@ -1573,11 +1627,7 @@ mod tests { SELECT city AS date FROM order_table o JOIN user_table u ON (o.id=u.id) GROUP BY u.city ORDER BY date "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); @@ -1600,11 +1650,7 @@ mod tests { ORDER BY x, t2.y, t2.z "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); @@ -1619,11 +1665,7 @@ mod tests { SELECT * FROM my_tab WHERE id > 50; "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); @@ -1638,11 +1680,7 @@ mod tests { SELECT * FROM my_tab WHERE id > 50; "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); @@ -1657,11 +1695,7 @@ mod tests { SELECT * FROM my_tab WHERE user_id > 50; "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); // id becomes an ambiguous column since is present in both tables assert!(relation.schema().field("id").is_err()); relation.display_dot().unwrap(); @@ -1679,11 +1713,7 @@ mod tests { SELECT * FROM my_tab WHERE user_id > 50; "#; let query = parse(query_str).unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); println!("relation = {relation}"); let query: &str = &ast::Query::from(&relation).to_string(); @@ -1712,7 +1742,7 @@ mod tests { let relation = Relation::try_from(QueryWithRelations::new( &query, &Hierarchy::from([(["schema", "table_1"], Arc::new(table_1))]), - )) + )) .unwrap(); relation.display_dot().unwrap(); println!("relation = {relation}"); @@ -1734,19 +1764,21 @@ mod tests { vec![ ("a", DataType::integer_interval(0, 10)), ("b", DataType::float_interval(20., 50.)), - ].into_iter() - .collect::() + ] + .into_iter() + .collect::(), ) .size(100) .build(); - let table_2: Relation = Relation::table() + let table_2: Relation = Relation::table() .name("table_2") .schema( vec![ ("a", DataType::integer_interval(-5, 5)), ("c", DataType::float()), - ].into_iter() - .collect::() + ] + .into_iter() + .collect::(), ) .size(100) .build(); @@ -1757,11 +1789,7 @@ mod tests { // INNER JOIN let query = parse("SELECT * FROM table_1 INNER JOIN table_2 USING (a)").unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations, - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); assert!(matches!(relation.data_type(), DataType::Struct(_))); if let DataType::Struct(s) = relation.data_type() { @@ -1772,11 +1800,7 @@ mod tests { // LEFT JOIN let query = parse("SELECT * FROM table_1 LEFT JOIN table_2 USING (a)").unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations, - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); assert!(matches!(relation.data_type(), DataType::Struct(_))); if let DataType::Struct(s) = relation.data_type() { @@ -1787,31 +1811,32 @@ mod tests { // RIGHT JOIN let query = parse("SELECT * FROM table_1 RIGHT JOIN table_2 USING (a)").unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations, - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); assert!(matches!(relation.data_type(), DataType::Struct(_))); if let DataType::Struct(s) = relation.data_type() { assert_eq!(s[0], Arc::new(DataType::integer_interval(-5, 5))); - assert_eq!(s[1], Arc::new(DataType::optional(DataType::float_interval(20., 50.)))); + assert_eq!( + s[1], + Arc::new(DataType::optional(DataType::float_interval(20., 50.))) + ); assert_eq!(s[2], Arc::new(DataType::float())); } // FULL JOIN let query = parse("SELECT * FROM table_1 FULL JOIN table_2 USING (a)").unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations, - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); assert!(matches!(relation.data_type(), DataType::Struct(_))); if let DataType::Struct(s) = relation.data_type() { - assert_eq!(s[0], Arc::new(DataType::optional(DataType::integer_interval(-5, 10)))); - assert_eq!(s[1], Arc::new(DataType::optional(DataType::float_interval(20., 50.)))); + assert_eq!( + s[0], + Arc::new(DataType::optional(DataType::integer_interval(-5, 10))) + ); + assert_eq!( + s[1], + Arc::new(DataType::optional(DataType::float_interval(20., 50.))) + ); assert_eq!(s[2], Arc::new(DataType::optional(DataType::float()))); } } @@ -1826,20 +1851,22 @@ mod tests { ("a", DataType::integer_interval(0, 10)), ("b", DataType::float_interval(20., 50.)), ("d", DataType::float_interval(-10., 50.)), - ].into_iter() - .collect::() + ] + .into_iter() + .collect::(), ) .size(100) .build(); - let table_2: Relation = Relation::table() + let table_2: Relation = Relation::table() .name("table_2") .schema( vec![ ("a", DataType::integer_interval(-5, 5)), ("c", DataType::float()), ("d", DataType::float_interval(10., 100.)), - ].into_iter() - .collect::() + ] + .into_iter() + .collect::(), ) .size(100) .build(); @@ -1850,11 +1877,7 @@ mod tests { // INNER JOIN let query = parse("SELECT * FROM table_1 NATURAL INNER JOIN table_2").unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations, - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); assert!(matches!(relation.data_type(), DataType::Struct(_))); if let DataType::Struct(s) = relation.data_type() { @@ -1866,11 +1889,7 @@ mod tests { // LEFT JOIN let query = parse("SELECT * FROM table_1 NATURAL LEFT JOIN table_2").unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations, - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); assert!(matches!(relation.data_type(), DataType::Struct(_))); if let DataType::Struct(s) = relation.data_type() { @@ -1882,34 +1901,37 @@ mod tests { // RIGHT JOIN let query = parse("SELECT * FROM table_1 NATURAL RIGHT JOIN table_2").unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations, - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); assert!(matches!(relation.data_type(), DataType::Struct(_))); if let DataType::Struct(s) = relation.data_type() { assert_eq!(s[0], Arc::new(DataType::integer_interval(-5, 5))); assert_eq!(s[1], Arc::new(DataType::float_interval(10., 100.))); - assert_eq!(s[2], Arc::new(DataType::optional(DataType::float_interval(20., 50.)))); + assert_eq!( + s[2], + Arc::new(DataType::optional(DataType::float_interval(20., 50.))) + ); assert_eq!(s[3], Arc::new(DataType::float())); - } // FULL JOIN let query = parse("SELECT * FROM table_1 NATURAL FULL JOIN table_2").unwrap(); - let relation = Relation::try_from(QueryWithRelations::new( - &query, - &relations, - )) - .unwrap(); + let relation = Relation::try_from(QueryWithRelations::new(&query, &relations)).unwrap(); relation.display_dot().unwrap(); assert!(matches!(relation.data_type(), DataType::Struct(_))); if let DataType::Struct(s) = relation.data_type() { - assert_eq!(s[0], Arc::new(DataType::optional(DataType::integer_interval(-5, 10)))); - assert_eq!(s[1], Arc::new(DataType::optional(DataType::float_interval(-10., 100.)))); - assert_eq!(s[2], Arc::new(DataType::optional(DataType::float_interval(20., 50.)))); + assert_eq!( + s[0], + Arc::new(DataType::optional(DataType::integer_interval(-5, 10))) + ); + assert_eq!( + s[1], + Arc::new(DataType::optional(DataType::float_interval(-10., 100.))) + ); + assert_eq!( + s[2], + Arc::new(DataType::optional(DataType::float_interval(20., 50.))) + ); assert_eq!(s[3], Arc::new(DataType::optional(DataType::float()))); } } From 4bb4bbf6e4c16a4bc2270a6d627b55e4ec168d46 Mon Sep 17 00:00:00 2001 From: Andi Cuko Date: Mon, 27 May 2024 16:20:32 +0200 Subject: [PATCH 2/8] wip --- src/data_type/function.rs | 10 +- src/data_type/mod.rs | 6 +- src/dialect_translation/bigquery.rs | 4 +- src/dialect_translation/mod.rs | 16 ++-- src/dialect_translation/mssql.rs | 7 +- src/dialect_translation/postgresql.rs | 30 ++---- src/dialect_translation/sqlite.rs | 18 +--- src/differential_privacy/aggregates.rs | 7 +- src/differential_privacy/group_by.rs | 2 +- src/differential_privacy/mod.rs | 6 +- src/display/dot.rs | 4 +- src/display/mod.rs | 6 +- src/expr/bijection.rs | 4 +- src/expr/dot.rs | 3 +- src/expr/implementation.rs | 2 +- src/expr/rewriting.rs | 6 +- src/expr/sql.rs | 2 +- src/io/mod.rs | 2 +- src/io/postgresql.rs | 4 +- src/privacy_unit_tracking/mod.rs | 6 +- src/relation/builder.rs | 7 +- src/relation/mod.rs | 6 +- src/relation/rewriting.rs | 9 +- src/relation/schema.rs | 6 +- src/relation/sql.rs | 7 +- src/rewriting/dot.rs | 14 ++- src/rewriting/mod.rs | 2 +- src/rewriting/rewriting_rule.rs | 46 ++++----- src/sampling_adjustment/mod.rs | 14 ++- src/sql/expr.rs | 128 ++++++++++++------------- src/sql/mod.rs | 11 +-- src/sql/relation.rs | 18 ++-- src/sql/visitor.rs | 4 +- src/synthetic_data/mod.rs | 8 +- tests/integration.rs | 4 +- 35 files changed, 197 insertions(+), 232 deletions(-) diff --git a/src/data_type/function.rs b/src/data_type/function.rs index 31c5d208..821e42ed 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -708,7 +708,7 @@ impl Function for Optional { .map(|dt| DataType::optional(dt)), set => self.0.super_image(&set), } - .or_else(|err| Ok(self.co_domain())) + .or_else(|_err| Ok(self.co_domain())) } fn value(&self, arg: &Value) -> Result { @@ -719,7 +719,7 @@ impl Function for Optional { }, arg => self.0.value(arg), } - .or_else(|err| Ok(Value::none())) + .or_else(|_err| Ok(Value::none())) } } @@ -1427,11 +1427,11 @@ pub fn md5() -> impl Function { ) } -pub fn random(mut rng: Mutex) -> impl Function { +pub fn random(rng: Mutex) -> impl Function { Unimplemented::new( DataType::unit(), DataType::float_interval(0., 1.), - Arc::new(Mutex::new(RefCell::new(move |v| { + Arc::new(Mutex::new(RefCell::new(move |_v| { rng.lock().unwrap().borrow_mut().gen::().into() }))), ) @@ -3155,7 +3155,7 @@ mod tests { assert_eq!( optional_greatest .super_image(&DataType::optional( - (DataType::float_interval(0., 1.) & DataType::float_interval(-5., 2.)) + DataType::float_interval(0., 1.) & DataType::float_interval(-5., 2.) )) .unwrap(), optional_greatest diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index f682936d..5f0b0d4d 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -1724,7 +1724,7 @@ impl Variant for List { fn try_empty(&self) -> Result { Ok(Self::new( - self.data_type().deref().try_empty()?.into(), + self.data_type().try_empty()?.into(), 0.into(), )) } @@ -1844,7 +1844,7 @@ impl Variant for Set { fn try_empty(&self) -> Result { Ok(Self::new( - self.data_type().deref().try_empty()?.into(), + self.data_type().try_empty()?.into(), 0.into(), )) } @@ -1957,7 +1957,7 @@ impl Variant for Array { fn try_empty(&self) -> Result { Ok(Self::new( - self.data_type().deref().try_empty()?.into(), + self.data_type().try_empty()?.into(), Arc::new([0 as usize]), )) } diff --git a/src/dialect_translation/bigquery.rs b/src/dialect_translation/bigquery.rs index bfe23005..0744dad7 100644 --- a/src/dialect_translation/bigquery.rs +++ b/src/dialect_translation/bigquery.rs @@ -113,11 +113,11 @@ mod tests { use super::*; use crate::{ builder::{Ready, With}, - data_type::{DataType, Value as _}, + data_type::{DataType}, dialect_translation::RelationWithTranslator, expr::Expr, namer, - relation::{schema::Schema, Relation, Variant as _}, + relation::{schema::Schema, Relation}, }; use std::sync::Arc; diff --git a/src/dialect_translation/mod.rs b/src/dialect_translation/mod.rs index 7f5dac20..e2199f93 100644 --- a/src/dialect_translation/mod.rs +++ b/src/dialect_translation/mod.rs @@ -2,26 +2,24 @@ //! A specific Dialect is a struct holding: //! - a method to provide a sqlparser::Dialect for the parsing //! - methods varying from dialect to dialect regarding the conversion from AST to Expr+Relation and vice-versa -use std::{iter::once, ops::Deref}; + use sqlparser::{ ast, - dialect::{BigQueryDialect, Dialect, PostgreSqlDialect}, + dialect::{Dialect}, }; use crate::{ - data_type::function::cast, expr::{self, Function}, - relation::{self, sql::FromRelationVisitor}, - visitor::Acceptor, - WithContext, WithoutContext, + relation::sql::FromRelationVisitor, + visitor::Acceptor, WithoutContext, }; use crate::{ data_type::DataTyped, expr::Identifier, hierarchy::Hierarchy, relation::{Join, JoinOperator, Table, Variant}, - sql::{self, parse, parse_with_dialect, Error, Result}, + sql::Result, DataType, Relation, }; @@ -275,7 +273,7 @@ macro_rules! relation_to_query_tranlator_trait_constructor { materialized: None, } } - fn join_projection(&self, join: &Join) -> Vec { + fn join_projection(&self, _join: &Join) -> Vec { vec![ast::SelectItem::Wildcard( ast::WildcardAdditionalOptions::default(), )] @@ -813,7 +811,7 @@ pub trait QueryToRelationTranslator { ) -> Result { match func_arg_expr { ast::FunctionArgExpr::Expr(e) => self.try_expr(e, context), - ast::FunctionArgExpr::QualifiedWildcard(o) => todo!(), + ast::FunctionArgExpr::QualifiedWildcard(_o) => todo!(), ast::FunctionArgExpr::Wildcard => todo!(), } } diff --git a/src/dialect_translation/mssql.rs b/src/dialect_translation/mssql.rs index 28bb23cb..234431af 100644 --- a/src/dialect_translation/mssql.rs +++ b/src/dialect_translation/mssql.rs @@ -1,15 +1,14 @@ use crate::{ data_type::{DataType, DataTyped as _}, - expr::{self, Function as _}, + expr::{self}, hierarchy::Hierarchy, - relation::{sql::FromRelationVisitor, Relation, Table, Variant as _}, - visitor::Acceptor, + relation::{Table, Variant as _}, WithoutContext, }; use super::{function_builder, QueryToRelationTranslator, RelationToQueryTranslator, Result}; use sqlparser::{ - ast::{self, CharacterLength}, + ast::{self}, dialect::MsSqlDialect, }; #[derive(Clone, Copy)] diff --git a/src/dialect_translation/postgresql.rs b/src/dialect_translation/postgresql.rs index c7e7310a..669719bc 100644 --- a/src/dialect_translation/postgresql.rs +++ b/src/dialect_translation/postgresql.rs @@ -1,20 +1,13 @@ -use std::sync::Arc; - -use crate::{ - expr, - hierarchy::Hierarchy, - relation::sql::FromRelationVisitor, - sql::{parse_with_dialect, query_names::IntoQueryNamesVisitor}, - visitor::Acceptor, - Relation, -}; + + +use crate::expr; use super::{ - function_builder, QueryToRelationTranslator, RelationToQueryTranslator, RelationWithTranslator, + function_builder, QueryToRelationTranslator, RelationToQueryTranslator, }; use sqlparser::{ast, dialect::PostgreSqlDialect}; -use crate::sql::{Error, Result}; + #[derive(Clone, Copy)] pub struct PostgreSqlTranslator; @@ -103,18 +96,15 @@ impl QueryToRelationTranslator for PostgreSqlTranslator { #[cfg(test)] mod tests { - use sqlparser::dialect; + use super::*; use crate::{ - builder::{Ready, With}, - data_type::{DataType, Value as _}, - display::Dot, - expr::Expr, + builder::Ready, + data_type::DataType, io::{postgresql, Database as _}, - namer, - relation::{schema::Schema, Relation, TableBuilder}, - sql::{parse, relation::QueryWithRelations}, + relation::{schema::Schema, Relation}, + sql::relation::QueryWithRelations, }; use std::sync::Arc; diff --git a/src/dialect_translation/sqlite.rs b/src/dialect_translation/sqlite.rs index dcee05e7..6c138fc6 100644 --- a/src/dialect_translation/sqlite.rs +++ b/src/dialect_translation/sqlite.rs @@ -1,7 +1,7 @@ -use crate::{relation::sql::FromRelationVisitor, visitor::Acceptor, Relation}; + use super::RelationToQueryTranslator; -use sqlparser::{ast, dialect::SQLiteDialect}; + #[derive(Clone, Copy)] pub struct SQLiteTranslator; @@ -9,15 +9,7 @@ impl RelationToQueryTranslator for SQLiteTranslator {} #[cfg(test)] mod tests { - use super::*; - use crate::{ - builder::{Ready, With}, - data_type::{DataType, Value as _}, - display::Dot, - expr::Expr, - namer, - relation::{schema::Schema, Relation}, - sql::parse, - }; - use std::sync::Arc; + + + } diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index f81fc589..039305a7 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -8,8 +8,7 @@ use crate::{ AggregateColumn, Column, Expr, Identifier, }, privacy_unit_tracking::PupRelation, - relation::{field::Field, Map, Reduce, Relation, Variant}, - DataType, Ready, + relation::{Map, Reduce, Relation, Variant}, Ready, }; use std::{cmp, collections::HashMap, ops::Deref}; @@ -520,7 +519,7 @@ impl Reduce { .build() } else { builder - .group_by_iter(self.group_by().clone().to_vec()) + .group_by_iter(self.group_by().to_vec()) .with_iter(aggs) .build() } @@ -538,7 +537,7 @@ mod tests { io::{postgresql, Database}, privacy_unit_tracking::PrivacyUnit, privacy_unit_tracking::{PrivacyUnitTracking, Strategy}, - relation::{Constraint, Schema, Variant as _}, + relation::{Schema, Variant as _}, sql::parse, Relation, }; diff --git a/src/differential_privacy/group_by.rs b/src/differential_privacy/group_by.rs index f409ac4d..f3d6b4f5 100644 --- a/src/differential_privacy/group_by.rs +++ b/src/differential_privacy/group_by.rs @@ -204,7 +204,7 @@ mod tests { use crate::{ ast, builder::With, - data_type::{DataType, DataTyped, Integer, Variant}, + data_type::{DataType, DataTyped, Variant}, display::Dot, expr::AggregateColumn, io::{postgresql, Database}, diff --git a/src/differential_privacy/mod.rs b/src/differential_privacy/mod.rs index 6e9dcb82..03878117 100644 --- a/src/differential_privacy/mod.rs +++ b/src/differential_privacy/mod.rs @@ -11,8 +11,8 @@ pub mod group_by; use crate::{ builder::With, expr, - privacy_unit_tracking::{self, privacy_unit, PupRelation}, - relation::{rewriting, Constraint, Reduce, Relation, Variant}, + privacy_unit_tracking::{self, PupRelation}, + relation::{rewriting, Reduce, Relation, Variant}, Ready, }; use std::{error, fmt, ops::Deref, result}; @@ -223,7 +223,7 @@ mod tests { let (dp_relation, dp_event) = reduce.differentially_private(¶meters).unwrap().into(); dp_relation.display_dot().unwrap(); - let mult: f64 = 2000. + let _mult: f64 = 2000. * DpAggregatesParameters::from_dp_parameters(parameters.clone(), 1.) .privacy_unit_multiplicity(); assert!(matches!( diff --git a/src/display/dot.rs b/src/display/dot.rs index 6d2b4847..b36df019 100644 --- a/src/display/dot.rs +++ b/src/display/dot.rs @@ -59,7 +59,7 @@ pub fn render< )?; } for n in g.nodes().iter() { - let mut colorstring; + let colorstring; indent(w)?; let id = g.node_id(n); @@ -115,7 +115,7 @@ pub fn render< } for e in g.edges().iter() { - let mut colorstring; + let colorstring; let escaped_label = &g.edge_label(e).to_dot_string(); let start_arrow = g.edge_start_arrow(e); let end_arrow = g.edge_end_arrow(e); diff --git a/src/display/mod.rs b/src/display/mod.rs index 72b70bc6..292438f0 100644 --- a/src/display/mod.rs +++ b/src/display/mod.rs @@ -7,8 +7,7 @@ pub mod colors; pub mod dot; use crate::{ - builder::{WithContext, WithoutContext}, - data_type::DataTyped, + builder::WithContext, namer, rewriting::{RelationWithRewritingRule, RelationWithRewritingRules}, DataType, Expr, Relation, Value, @@ -17,7 +16,6 @@ use std::{ fs::File, io::{Result, Write}, process::Command, - sync::Arc, }; pub trait Dot { @@ -188,7 +186,7 @@ mod tests { builder::{Ready, With}, data_type::DataType, expr::Expr, - relation::{schema::Schema, Relation}, + relation::{schema::Schema, Relation}, WithoutContext as _, }; #[test] diff --git a/src/expr/bijection.rs b/src/expr/bijection.rs index 0dbd4750..f88d0ffd 100644 --- a/src/expr/bijection.rs +++ b/src/expr/bijection.rs @@ -1,4 +1,4 @@ -use super::{function, identifier, Column, Expr, Function}; +use super::{Column, Expr, Function}; impl Expr { /// Reduce the expression modulo a bijection @@ -17,7 +17,7 @@ impl Expr { self } } - expr => self, + _expr => self, } } diff --git a/src/expr/dot.rs b/src/expr/dot.rs index 6a992aec..ecb5112d 100644 --- a/src/expr/dot.rs +++ b/src/expr/dot.rs @@ -4,7 +4,6 @@ use std::{fmt, io, string}; use super::{aggregate, function, Column, Error, Expr, Value, Visitor}; use crate::{ - builder::{WithContext as _, WithoutContext as _}, data_type::{DataType, DataTyped}, display::{self, colors}, namer, @@ -202,7 +201,7 @@ mod tests { builder::{Ready, With}, data_type::DataType, display::Dot, - relation::{schema::Schema, Relation}, + relation::{schema::Schema, Relation}, WithoutContext as _, }; use std::sync::Arc; diff --git a/src/expr/implementation.rs b/src/expr/implementation.rs index 2020b170..6379dc88 100644 --- a/src/expr/implementation.rs +++ b/src/expr/implementation.rs @@ -140,7 +140,7 @@ function_implementations!( Function::CastAsDate => Arc::new(Optional::new(function::cast(DataType::date()))), Function::CastAsTime => Arc::new(Optional::new(function::cast(DataType::time()))), Function::Concat(n) => Arc::new(function::concat(n)), - Function::Random(n) => Arc::new(function::random(Mutex::new(OsRng))), //TODO change this initialization + Function::Random(_n) => Arc::new(function::random(Mutex::new(OsRng))), //TODO change this initialization Function::Coalesce => Arc::new(function::coalesce()), _ => unreachable!(), } diff --git a/src/expr/rewriting.rs b/src/expr/rewriting.rs index 04e74493..2bcfc717 100644 --- a/src/expr/rewriting.rs +++ b/src/expr/rewriting.rs @@ -1,5 +1,5 @@ use crate::{ - expr::{Expr, Variant as _}, + expr::{Expr}, namer, }; use std::f64::consts::PI; @@ -31,8 +31,8 @@ impl Expr { mod tests { use super::*; use crate::{ - builder::{With, WithoutContext}, - data_type::{function::Function as _, value::Value, DataType}, + builder::WithoutContext, + data_type::{function::Function as _, value::Value}, display::Dot, }; diff --git a/src/expr/sql.rs b/src/expr/sql.rs index 17e1b40b..c6211eec 100644 --- a/src/expr/sql.rs +++ b/src/expr/sql.rs @@ -1,7 +1,7 @@ //! Convert Expr into ast::Expr use crate::{ ast, - data_type::{Boolean, DataType}, + data_type::{DataType}, expr::{self, Expr}, visitor::Acceptor, }; diff --git a/src/io/mod.rs b/src/io/mod.rs index e2a95621..66708b45 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -301,7 +301,7 @@ mod tests { #[test] fn test_relation_hierarchy() -> Result<()> { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); println!("{}", database.relations()); Ok(()) } diff --git a/src/io/postgresql.rs b/src/io/postgresql.rs index 72328c90..201cefef 100644 --- a/src/io/postgresql.rs +++ b/src/io/postgresql.rs @@ -13,7 +13,7 @@ use crate::{ relation::{Table, Variant as _}, }; use std::{ - env, fmt, ops::Deref, process::Command, str::FromStr, sync::Arc, sync::Mutex, thread, time, + env, fmt, process::Command, str::FromStr, sync::Arc, sync::Mutex, thread, time, }; use colored::Colorize; @@ -204,7 +204,7 @@ impl DatabaseTrait for Database { fn create_table(&mut self, table: &Table) -> Result { let mut connection = self.pool.get()?; - let qq = table.create(PostgreSqlTranslator).to_string(); + let _qq = table.create(PostgreSqlTranslator).to_string(); Ok(connection.execute(&table.create(PostgreSqlTranslator).to_string(), &[])? as usize) } diff --git a/src/privacy_unit_tracking/mod.rs b/src/privacy_unit_tracking/mod.rs index 0745485d..6c3fee45 100644 --- a/src/privacy_unit_tracking/mod.rs +++ b/src/privacy_unit_tracking/mod.rs @@ -337,7 +337,7 @@ impl<'a> PrivacyUnitTracking<'a> { match self.strategy { Strategy::Soft => Err(Error::not_privacy_unit_preserving(join)), Strategy::Hard => { - let name = join.name(); + let _name = join.name(); let operator = join.operator().clone(); let names = join.names(); let names = names.with(vec![ @@ -406,7 +406,7 @@ impl<'a> PrivacyUnitTracking<'a> { left: Relation, right: PupRelation, ) -> Result { - let name = join.name(); + let _name = join.name(); let operator = join.operator().clone(); let names = join.names(); let names = names.with(vec![ @@ -458,7 +458,7 @@ impl<'a> PrivacyUnitTracking<'a> { left: PupRelation, right: Relation, ) -> Result { - let name = join.name(); + let _name = join.name(); let operator = join.operator().clone(); let names = join.names(); let names = names.with(vec![ diff --git a/src/relation/builder.rs b/src/relation/builder.rs index d4a74b01..14790653 100644 --- a/src/relation/builder.rs +++ b/src/relation/builder.rs @@ -9,7 +9,6 @@ use super::{ use crate::{ builder::{Ready, With, WithIterator}, data_type::{Integer, Value}, - display::Dot, expr::{self, AggregateColumn, Expr, Identifier, Split}, hierarchy::Hierarchy, namer::{self, FIELD, JOIN, MAP, REDUCE, SET}, @@ -141,7 +140,7 @@ impl MapBuilder { self } - pub fn filter_iter(mut self, iter: Vec) -> Self { + pub fn filter_iter(self, iter: Vec) -> Self { let filter = iter .into_iter() .fold(Expr::val(true), |f, x| Expr::and(f, x)); @@ -153,7 +152,7 @@ impl MapBuilder { self } - pub fn order_by_iter(mut self, iter: Vec<(Expr, bool)>) -> Self { + pub fn order_by_iter(self, iter: Vec<(Expr, bool)>) -> Self { iter.into_iter().fold(self, |w, (x, b)| w.order_by(x, b)) } @@ -1064,7 +1063,7 @@ impl Ready for ValuesBuilder { #[cfg(test)] mod tests { use super::*; - use crate::{data_type::DataTyped, display::Dot, expr::aggregate::Aggregate, DataType}; + use crate::{data_type::DataTyped, display::Dot, DataType}; #[test] fn test_map_building() { diff --git a/src/relation/mod.rs b/src/relation/mod.rs index 92be3d59..0b6fcc15 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -1154,7 +1154,7 @@ impl Set { fn schema( names: Vec, operator: &SetOperator, - quantifier: &SetQuantifier, + _quantifier: &SetQuantifier, left: &Relation, right: &Relation, ) -> Schema { @@ -1183,7 +1183,7 @@ impl Set { /// Compute the size of the join fn size( operator: &SetOperator, - quantifier: &SetQuantifier, + _quantifier: &SetQuantifier, left: &Relation, right: &Relation, ) -> Integer { @@ -1647,7 +1647,7 @@ impl Ready for ValuesBuilder { #[cfg(test)] mod tests { use super::{schema::Schema, *}; - use crate::ast; + use crate::{builder::With, data_type::DataType, display::Dot}; #[test] diff --git a/src/relation/rewriting.rs b/src/relation/rewriting.rs index 4ec7df52..878d779e 100644 --- a/src/relation/rewriting.rs +++ b/src/relation/rewriting.rs @@ -5,11 +5,10 @@ use super::{Join, Map, Reduce, Relation, Set, Table, Values, Variant as _}; use crate::{ builder::{Ready, With, WithIterator}, data_type::{self, function::Function, DataType, DataTyped, Variant as _}, - display::Dot, expr::{self, aggregate, Aggregate, Expr, Identifier, Value}, hierarchy::Hierarchy, io, - namer::{self, name_from_content}, + namer::{self}, relation::{self, LEFT_INPUT_NAME, RIGHT_INPUT_NAME}, }; use std::{ @@ -430,7 +429,7 @@ impl Relation { /// Compute L2 norms of the vectors formed by the group values for each entities pub fn l2_norms(self, entities: &str, groups: &[&str], values: &[&str]) -> Self { let mut entities_groups = vec![entities]; - entities_groups.extend(groups.clone()); + entities_groups.extend(groups); let names = values .iter() .map(|v| format!("_NORM_{}", v)) @@ -1537,7 +1536,7 @@ mod tests { #[test] fn test_poisson_sampling() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let relations = database.relations(); let proba = 0.5; @@ -1646,7 +1645,7 @@ mod tests { #[ignore] // Too fragile #[test] fn test_sampling_query() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let relations = database.relations(); // relation with reduce diff --git a/src/relation/schema.rs b/src/relation/schema.rs index eaa751cb..61402f09 100644 --- a/src/relation/schema.rs +++ b/src/relation/schema.rs @@ -8,8 +8,8 @@ use std::{ use super::{field::Field, Error, Result}; use crate::{ builder::{Ready, With}, - data_type::{self, DataType, DataTyped}, - expr::{identifier::Identifier, Expr}, + data_type::{DataType, DataTyped}, + expr::{identifier::Identifier}, }; /// A struct holding Fields as in https://github.com/apache/arrow-datafusion/blob/5b23180cf75ea7155d7c35a40f224ce4d5ad7fb8/datafusion/src/logical_plan/dfschema.rs#L36 @@ -260,7 +260,7 @@ impl Ready for Builder { #[cfg(test)] mod tests { use super::*; - use crate::data_type::{DataType, Variant}; + use crate::data_type::{DataType}; use std::panic::catch_unwind; #[test] diff --git a/src/relation/sql.rs b/src/relation/sql.rs index dde57baf..3460ae9f 100644 --- a/src/relation/sql.rs +++ b/src/relation/sql.rs @@ -1,18 +1,17 @@ //! Methods to convert Relations to ast::Query -use serde::de::value; + use super::{ - Error, Join, JoinOperator, Map, OrderBy, Reduce, Relation, Result, Set, SetOperator, + Join, Map, OrderBy, Reduce, Relation, Set, SetOperator, SetQuantifier, Table, Values, Variant as _, Visitor, }; use crate::{ ast, - data_type::{DataType, DataTyped}, dialect_translation::{postgresql::PostgreSqlTranslator, RelationToQueryTranslator}, expr::{identifier::Identifier, Expr}, visitor::Acceptor, }; -use std::{collections::HashSet, convert::TryFrom, iter::Iterator, ops::Deref}; +use std::{collections::HashSet, iter::Iterator, ops::Deref}; /// A simple Relation -> ast::Query conversion Visitor using CTE #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] diff --git a/src/rewriting/dot.rs b/src/rewriting/dot.rs index 48b0c73b..52024b7d 100644 --- a/src/rewriting/dot.rs +++ b/src/rewriting/dot.rs @@ -1,13 +1,12 @@ -use std::{fmt, io, iter, string}; +use std::{io, iter}; use itertools::Itertools; use super::{ - rewriting_rule, Property, RelationWithRewritingRule, RelationWithRewritingRules, RewritingRule, + Property, RelationWithRewritingRule, RelationWithRewritingRules, RewritingRule, }; use crate::{ display::{self, colors}, - expr::{rewriting, Reduce}, namer, relation::{Relation, Variant}, visitor::Acceptor, @@ -80,14 +79,14 @@ impl<'a> dot::Labeller<'a, Node<'a>, Edge<'a>> for RelationWithRewritingRules<'a } } - fn edge_label(&'a self, edge: &Edge<'a>) -> dot::LabelText<'a> { + fn edge_label(&'a self, _edge: &Edge<'a>) -> dot::LabelText<'a> { dot::LabelText::LabelStr("".into()) } fn edge_style(&'a self, edge: &Edge<'a>) -> dot::Style { match edge { - Edge::RelationInput(r, i) => dot::Style::None, - Edge::RelationRewritingRule(r, rr) => dot::Style::Dotted, + Edge::RelationInput(_r, _i) => dot::Style::None, + Edge::RelationRewritingRule(_r, _rr) => dot::Style::Dotted, } } } @@ -156,11 +155,10 @@ impl<'a> RelationWithRewritingRule<'a> { #[cfg(test)] mod tests { - use itertools::Itertools; + use super::*; use crate::{ - ast, builder::With, display::Dot, io::{postgresql, Database}, diff --git a/src/rewriting/mod.rs b/src/rewriting/mod.rs index 900742e4..c79a3ae1 100644 --- a/src/rewriting/mod.rs +++ b/src/rewriting/mod.rs @@ -145,7 +145,7 @@ mod tests { #[test] fn test_rewrite() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let relations = database.relations(); for (p, r) in relations.iter() { diff --git a/src/rewriting/rewriting_rule.rs b/src/rewriting/rewriting_rule.rs index d1a87f39..f1650544 100644 --- a/src/rewriting/rewriting_rule.rs +++ b/src/rewriting/rewriting_rule.rs @@ -618,7 +618,7 @@ impl<'a> SetRewritingRulesVisitor<'a> for RewritingRulesSetter<'a> { rewriting_rules } - fn map(&self, map: &'a Map, input: Arc>) -> Vec { + fn map(&self, _map: &'a Map, _input: Arc>) -> Vec { let mut rewriting_rules = vec![ RewritingRule::new(vec![Property::Public], Property::Public, Parameters::None), RewritingRule::new( @@ -650,7 +650,7 @@ impl<'a> SetRewritingRulesVisitor<'a> for RewritingRulesSetter<'a> { fn reduce( &self, reduce: &'a Reduce, - input: Arc>, + _input: Arc>, ) -> Vec { let mut rewriting_rules = vec![ RewritingRule::new(vec![Property::Public], Property::Public, Parameters::None), @@ -708,9 +708,9 @@ impl<'a> SetRewritingRulesVisitor<'a> for RewritingRulesSetter<'a> { fn join( &self, - join: &'a Join, - left: Arc>, - right: Arc>, + _join: &'a Join, + _left: Arc>, + _right: Arc>, ) -> Vec { let mut rewriting_rules = vec![ RewritingRule::new( @@ -782,9 +782,9 @@ impl<'a> SetRewritingRulesVisitor<'a> for RewritingRulesSetter<'a> { fn set( &self, - set: &'a Set, - left: Arc>, - right: Arc>, + _set: &'a Set, + _left: Arc>, + _right: Arc>, ) -> Vec { let mut rewriting_rules = vec![ RewritingRule::new( @@ -816,7 +816,7 @@ impl<'a> SetRewritingRulesVisitor<'a> for RewritingRulesSetter<'a> { rewriting_rules } - fn values(&self, values: &'a Values) -> Vec { + fn values(&self, _values: &'a Values) -> Vec { let mut rewriting_rules = vec![RewritingRule::new( vec![], Property::Public, @@ -837,13 +837,13 @@ impl<'a> SetRewritingRulesVisitor<'a> for RewritingRulesSetter<'a> { pub struct RewritingRulesEliminator; impl<'a> MapRewritingRulesVisitor<'a> for RewritingRulesEliminator { - fn table(&self, table: &'a Table, rewriting_rules: &'a [RewritingRule]) -> Vec { + fn table(&self, _table: &'a Table, rewriting_rules: &'a [RewritingRule]) -> Vec { rewriting_rules.into_iter().cloned().collect() } fn map( &self, - map: &'a Map, + _map: &'a Map, rewriting_rules: &'a [RewritingRule], input: Arc>, ) -> Vec { @@ -861,7 +861,7 @@ impl<'a> MapRewritingRulesVisitor<'a> for RewritingRulesEliminator { fn reduce( &self, - reduce: &'a Reduce, + _reduce: &'a Reduce, rewriting_rules: &'a [RewritingRule], input: Arc>, ) -> Vec { @@ -879,7 +879,7 @@ impl<'a> MapRewritingRulesVisitor<'a> for RewritingRulesEliminator { fn join( &self, - join: &'a Join, + _join: &'a Join, rewriting_rules: &'a [RewritingRule], left: Arc>, right: Arc>, @@ -906,7 +906,7 @@ impl<'a> MapRewritingRulesVisitor<'a> for RewritingRulesEliminator { fn set( &self, - set: &'a Set, + _set: &'a Set, rewriting_rules: &'a [RewritingRule], left: Arc>, right: Arc>, @@ -933,7 +933,7 @@ impl<'a> MapRewritingRulesVisitor<'a> for RewritingRulesEliminator { fn values( &self, - values: &'a Values, + _values: &'a Values, rewriting_rules: &'a [RewritingRule], ) -> Vec { rewriting_rules.into_iter().cloned().collect() @@ -944,13 +944,13 @@ impl<'a> MapRewritingRulesVisitor<'a> for RewritingRulesEliminator { pub struct RewritingRulesSelector; impl<'a> SelectRewritingRuleVisitor<'a> for RewritingRulesSelector { - fn table(&self, table: &'a Table, rewriting_rules: &'a [RewritingRule]) -> Vec { + fn table(&self, _table: &'a Table, rewriting_rules: &'a [RewritingRule]) -> Vec { rewriting_rules.into_iter().cloned().collect() } fn map( &self, - map: &'a Map, + _map: &'a Map, rewriting_rules: &'a [RewritingRule], input: &RelationWithRewritingRule<'a>, ) -> Vec { @@ -963,7 +963,7 @@ impl<'a> SelectRewritingRuleVisitor<'a> for RewritingRulesSelector { fn reduce( &self, - reduce: &'a Reduce, + _reduce: &'a Reduce, rewriting_rules: &'a [RewritingRule], input: &RelationWithRewritingRule<'a>, ) -> Vec { @@ -976,7 +976,7 @@ impl<'a> SelectRewritingRuleVisitor<'a> for RewritingRulesSelector { fn join( &self, - join: &'a Join, + _join: &'a Join, rewriting_rules: &'a [RewritingRule], left: &RelationWithRewritingRule<'a>, right: &RelationWithRewritingRule<'a>, @@ -993,7 +993,7 @@ impl<'a> SelectRewritingRuleVisitor<'a> for RewritingRulesSelector { fn set( &self, - set: &'a Set, + _set: &'a Set, rewriting_rules: &'a [RewritingRule], left: &RelationWithRewritingRule<'a>, right: &RelationWithRewritingRule<'a>, @@ -1010,7 +1010,7 @@ impl<'a> SelectRewritingRuleVisitor<'a> for RewritingRulesSelector { fn values( &self, - values: &'a Values, + _values: &'a Values, rewriting_rules: &'a [RewritingRule], ) -> Vec { rewriting_rules.into_iter().cloned().collect() @@ -1266,7 +1266,7 @@ impl<'a> RewriteVisitor<'a> for Rewriter<'a> { fn set( &self, set: &'a Set, - rewriting_rule: &'a RewritingRule, + _rewriting_rule: &'a RewritingRule, rewritten_left: RelationWithDpEvent, rewritten_right: RelationWithDpEvent, ) -> RelationWithDpEvent { @@ -1282,7 +1282,7 @@ impl<'a> RewriteVisitor<'a> for Rewriter<'a> { (relation, dp_event_left.compose(dp_event_right)).into() } - fn values(&self, values: &'a Values, rewriting_rule: &'a RewritingRule) -> RelationWithDpEvent { + fn values(&self, values: &'a Values, _rewriting_rule: &'a RewritingRule) -> RelationWithDpEvent { (Arc::new(values.clone().into()), DpEvent::no_op()).into() } } diff --git a/src/sampling_adjustment/mod.rs b/src/sampling_adjustment/mod.rs index dcef5e5c..02a7a13a 100644 --- a/src/sampling_adjustment/mod.rs +++ b/src/sampling_adjustment/mod.rs @@ -326,14 +326,14 @@ impl<'a, F: Fn(&Table) -> RelationWithWeight> Visitor<'a, RelationWithWeight> fn set( &self, - set: &'a Set, - left: RelationWithWeight, - right: RelationWithWeight, + _set: &'a Set, + _left: RelationWithWeight, + _right: RelationWithWeight, ) -> RelationWithWeight { todo!() } - fn values(&self, values: &'a Values) -> RelationWithWeight { + fn values(&self, _values: &'a Values) -> RelationWithWeight { todo!() } } @@ -415,7 +415,7 @@ impl<'a, F: Fn(&Table) -> Relation> Visitor<'a, Relation> for TableSamplerVisito .build() } - fn set(&self, set: &'a Set, left: Relation, right: Relation) -> Relation { + fn set(&self, _set: &'a Set, _left: Relation, _right: Relation) -> Relation { todo!() } @@ -536,12 +536,10 @@ mod tests { ast, display::Dot, io::{postgresql, Database}, - namer, - sql::parse, }; use colored::Colorize; - use itertools::Itertools; + #[cfg(feature = "tested_sampling_adjustment")] #[test] diff --git a/src/sql/expr.rs b/src/sql/expr.rs index 6d1a1d97..574f0530 100644 --- a/src/sql/expr.rs +++ b/src/sql/expr.rs @@ -43,7 +43,7 @@ impl<'a> Acceptor<'a> for ast::Expr { match self { ast::Expr::Identifier(_) => Dependencies::empty(), ast::Expr::CompoundIdentifier(_) => Dependencies::empty(), - ast::Expr::JsonAccess { value, path } => Dependencies::from([value.as_ref()]), + ast::Expr::JsonAccess { value, path: _ } => Dependencies::from([value.as_ref()]), ast::Expr::CompositeAccess { expr, key: _ } => Dependencies::from([expr.as_ref()]), ast::Expr::IsFalse(expr) => Dependencies::from([expr.as_ref()]), ast::Expr::IsNotFalse(expr) => Dependencies::from([expr.as_ref()]), @@ -132,7 +132,7 @@ impl<'a> Acceptor<'a> for ast::Expr { expr, substring_from, substring_for, - special, + special: _, } => vec![Some(expr), substring_from.as_ref(), substring_for.as_ref()] .iter() .filter_map(|expr| expr.map(AsRef::as_ref)) @@ -167,7 +167,7 @@ impl<'a> Acceptor<'a> for ast::Expr { data_type: _, value: _, } => Dependencies::empty(), - ast::Expr::MapAccess { column, keys } => Dependencies::from([column.as_ref()]), + ast::Expr::MapAccess { column, keys: _ } => Dependencies::from([column.as_ref()]), ast::Expr::Function(function) => match &function.args { ast::FunctionArguments::None => Dependencies::empty(), ast::FunctionArguments::Subquery(_) => Dependencies::empty(), @@ -216,25 +216,25 @@ impl<'a> Acceptor<'a> for ast::Expr { ast::Expr::Array(_) => Dependencies::empty(), ast::Expr::Interval(_) => Dependencies::empty(), ast::Expr::MatchAgainst { - columns, - match_value, - opt_search_modifier, + columns: _, + match_value: _, + opt_search_modifier: _, } => Dependencies::empty(), - ast::Expr::IntroducedString { introducer, value } => Dependencies::empty(), + ast::Expr::IntroducedString { introducer: _, value: _ } => Dependencies::empty(), ast::Expr::RLike { - negated, - expr, - pattern, - regexp, + negated: _, + expr: _, + pattern: _, + regexp: _, } => todo!(), - ast::Expr::Struct { values, fields } => todo!(), - ast::Expr::Named { expr, name } => todo!(), + ast::Expr::Struct { values: _, fields: _ } => todo!(), + ast::Expr::Named { expr: _, name: _ } => todo!(), ast::Expr::Convert { - expr, - data_type, - charset, - target_before_value, - styles, + expr: _, + data_type: _, + charset: _, + target_before_value: _, + styles: _, } => todo!(), ast::Expr::Wildcard => todo!(), ast::Expr::QualifiedWildcard(_) => todo!(), @@ -291,8 +291,8 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { match acceptor { ast::Expr::Identifier(ident) => self.identifier(ident), ast::Expr::CompoundIdentifier(idents) => self.compound_identifier(idents), - ast::Expr::JsonAccess { value, path } => todo!(), - ast::Expr::CompositeAccess { expr, key } => todo!(), + ast::Expr::JsonAccess { value: _, path: _ } => todo!(), + ast::Expr::CompositeAccess { expr: _, key: _ } => todo!(), ast::Expr::IsFalse(expr) => self.is( self.cast(dependencies.get(expr).clone(), &ast::DataType::Boolean), Some(false), @@ -340,14 +340,14 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { } } ast::Expr::InSubquery { - expr, - subquery, - negated, + expr: _, + subquery: _, + negated: _, } => todo!(), ast::Expr::InUnnest { - expr, - array_expr, - negated, + expr: _, + array_expr: _, + negated: _, } => todo!(), ast::Expr::Between { expr, @@ -418,22 +418,22 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { } } ast::Expr::SimilarTo { - negated, - expr, - pattern, - escape_char, + negated: _, + expr: _, + pattern: _, + escape_char: _, } => todo!(), ast::Expr::AnyOp { - left, + left: _, compare_op: _, - right, + right: _, } => { todo!() } ast::Expr::AllOp { - left, + left: _, compare_op: _, - right, + right: _, } => { todo!() } @@ -445,8 +445,8 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { kind: _, } => self.cast(dependencies.get(expr).clone(), data_type), ast::Expr::AtTimeZone { - timestamp, - time_zone, + timestamp: _, + time_zone: _, } => todo!(), ast::Expr::Extract { field, expr } => { self.extract(field, dependencies.get(expr).clone()) @@ -461,7 +461,7 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { expr, substring_from, substring_for, - special, + special: _, } => self.substring( dependencies.get(expr).clone(), substring_from @@ -480,7 +480,7 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { let trim_what = match (trim_what, trim_characters) { (None, None) => None, (Some(x), None) => Some(x.as_ref()), - (None, Some(v)) => todo!(), + (None, Some(_v)) => todo!(), _ => todo!(), }; self.trim( @@ -490,16 +490,16 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { ) } ast::Expr::Overlay { - expr, - overlay_what, - overlay_from, - overlay_for, + expr: _, + overlay_what: _, + overlay_from: _, + overlay_for: _, } => todo!(), - ast::Expr::Collate { expr, collation } => todo!(), + ast::Expr::Collate { expr: _, collation: _ } => todo!(), ast::Expr::Nested(expr) => dependencies.get(expr).clone(), ast::Expr::Value(value) => self.value(value), - ast::Expr::TypedString { data_type, value } => todo!(), - ast::Expr::MapAccess { column, keys } => todo!(), + ast::Expr::TypedString { data_type: _, value: _ } => todo!(), + ast::Expr::MapAccess { column: _, keys: _ } => todo!(), ast::Expr::Function(function) => self.function(function, { let mut result = vec![]; let function_args = match &function.args { @@ -512,7 +512,7 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { ast::FunctionArg::Named { name, arg, - operator, + operator: _, } => FunctionArg::Named { name: name.clone(), arg: match arg { @@ -551,35 +551,35 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { .collect(), else_result.clone().map(|x| dependencies.get(&*x).clone()), ), - ast::Expr::Exists { subquery, negated } => todo!(), + ast::Expr::Exists { subquery: _, negated: _ } => todo!(), ast::Expr::Subquery(_) => todo!(), ast::Expr::GroupingSets(_) => todo!(), ast::Expr::Cube(_) => todo!(), ast::Expr::Rollup(_) => todo!(), ast::Expr::Tuple(_) => todo!(), - ast::Expr::ArrayIndex { obj, indexes } => todo!(), + ast::Expr::ArrayIndex { obj: _, indexes: _ } => todo!(), ast::Expr::Array(_) => todo!(), ast::Expr::Interval(_) => todo!(), ast::Expr::MatchAgainst { - columns, - match_value, - opt_search_modifier, + columns: _, + match_value: _, + opt_search_modifier: _, } => todo!(), - ast::Expr::IntroducedString { introducer, value } => todo!(), + ast::Expr::IntroducedString { introducer: _, value: _ } => todo!(), ast::Expr::RLike { - negated, - expr, - pattern, - regexp, + negated: _, + expr: _, + pattern: _, + regexp: _, } => todo!(), - ast::Expr::Struct { values, fields } => todo!(), - ast::Expr::Named { expr, name } => todo!(), + ast::Expr::Struct { values: _, fields: _ } => todo!(), + ast::Expr::Named { expr: _, name: _ } => todo!(), ast::Expr::Convert { - expr, - data_type, - charset, - target_before_value, - styles, + expr: _, + data_type: _, + charset: _, + target_before_value: _, + styles: _, } => todo!(), ast::Expr::Wildcard => todo!(), ast::Expr::QualifiedWildcard(_) => todo!(), @@ -784,7 +784,7 @@ impl From<&Vec> for Identifier { } impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { - fn qualified_wildcard(&self, idents: &'a Vec) -> Result { + fn qualified_wildcard(&self, _idents: &'a Vec) -> Result { todo!() } diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 13dfc1bb..0c8ec9d7 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -10,7 +10,7 @@ pub mod relation; pub mod visitor; pub mod writer; -use crate::{ast, relation::Variant as _}; +use crate::ast; // I would put here the abstact AST Visitor. // Then in expr.rs module we write an implementation of the abstract visitor for Qrlew expr @@ -112,12 +112,11 @@ mod tests { builder::With, display::Dot, io::{postgresql, Database}, - relation::Relation, - DataType, + relation::{Relation, Variant as _,} }; - use colored::Colorize; - use itertools::Itertools; - use sqlparser::dialect::BigQueryDialect; + + + #[test] fn test_display() { diff --git a/src/sql/relation.rs b/src/sql/relation.rs index 43531547..bedfb053 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -12,7 +12,6 @@ use crate::{ builder::{Ready, With, WithIterator, WithoutContext}, dialect::{Dialect, GenericDialect}, dialect_translation::{postgresql::PostgreSqlTranslator, QueryToRelationTranslator}, - display::Dot, expr::{Expr, Identifier, Reduce, Split}, hierarchy::{Hierarchy, Path}, namer::{self, FIELD}, @@ -25,10 +24,9 @@ use crate::{ types::And, visitor::{Acceptor, Dependencies, Visited}, }; -use dot::Id; + use itertools::Itertools; use std::{ - collections::HashMap, convert::TryFrom, iter::{once, Iterator}, ops::Deref, @@ -143,7 +141,7 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, let VisitedQueryRelations { relations, visited, - translator, + translator: _, } = self; // Process the table_factor @@ -605,9 +603,9 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> VisitedQueryRelations<'a, having, named_window, qualify, - window_before_qualify, - value_table_mode, - connect_by, + window_before_qualify: _, + value_table_mode: _, + connect_by: _, } = select; if top.is_some() { return Err(Error::other("TOP is not supported")); @@ -744,9 +742,9 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> Visitor<'a, Result Dependencies<'a, ast::Query> { let TryIntoRelationVisitor { - relations, + relations: _, query_names, - translator, + translator: _, } = self; let mut dependencies = acceptor.dependencies(); // Add subqueries from the body @@ -1300,7 +1298,7 @@ mod tests { .schema(schema_1.clone()) .size(100) .build(); - let relation = Relation::try_from(QueryWithRelations::new( + let _relation = Relation::try_from(QueryWithRelations::new( &query, &Hierarchy::from([(["schema", "table_1"], Arc::new(table_1))]), )); diff --git a/src/sql/visitor.rs b/src/sql/visitor.rs index 00c1ed49..0cddf941 100644 --- a/src/sql/visitor.rs +++ b/src/sql/visitor.rs @@ -4,7 +4,7 @@ use crate::{ ast, visitor::{self, Acceptor, Dependencies, Visited}, }; -use itertools::Itertools; + use std::iter::Iterator; /// A type to hold queries and relations with their aliases @@ -75,7 +75,7 @@ fn queries_from_set_expr<'a>(set_expr: &'a ast::SetExpr) -> Vec<&'a ast::Query> .flat_map(|table_with_joins| TableWithJoins(table_with_joins).queries()) .collect(), ast::SetExpr::SetOperation { .. } => vec![], - ast::SetExpr::Values(values) => todo!(), + ast::SetExpr::Values(_values) => todo!(), _ => todo!(), // Not implemented } } diff --git a/src/synthetic_data/mod.rs b/src/synthetic_data/mod.rs index 1abcdaad..c06fd5ed 100644 --- a/src/synthetic_data/mod.rs +++ b/src/synthetic_data/mod.rs @@ -1,10 +1,10 @@ use crate::{ - builder::{Ready, With, WithIterator}, - expr::{AggregateColumn, Expr, Identifier}, + builder::{Ready}, + expr::{Identifier}, hierarchy::Hierarchy, - relation::{Join, Map, Reduce, Relation, Table, Values, Variant as _}, + relation::{Relation, Table, Variant as _}, }; -use std::{error, fmt, ops::Deref, result, sync::Arc}; +use std::{error, fmt, ops::Deref, result}; pub const SYNTHETIC_PREFIX: &str = "_SYNTHETIC_"; diff --git a/tests/integration.rs b/tests/integration.rs index 413dd920..f19deec4 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -11,12 +11,12 @@ use qrlew::io::sqlite; use qrlew::{ ast, dialect_translation::{ - postgresql::PostgreSqlTranslator, QueryToRelationTranslator, RelationToQueryTranslator, RelationWithTranslator + RelationToQueryTranslator, RelationWithTranslator }, expr, io::{postgresql, Database}, relation::Variant as _, - sql::{parse, parse_with_dialect, relation::QueryWithRelations}, + sql::{parse}, Relation, With, }; From ce715fc9cf6c1a029b1901d27945cb8df194dfaf Mon Sep 17 00:00:00 2001 From: Andi Cuko Date: Mon, 27 May 2024 17:29:46 +0200 Subject: [PATCH 3/8] fmt --- src/data_type/mod.rs | 10 ++----- src/dialect_translation/bigquery.rs | 2 +- src/dialect_translation/mod.rs | 25 +++++++--------- src/dialect_translation/mssql.rs | 8 ++--- src/dialect_translation/postgresql.rs | 8 +---- src/dialect_translation/sqlite.rs | 8 +---- src/differential_privacy/aggregates.rs | 3 +- src/display/mod.rs | 3 +- src/expr/dot.rs | 3 +- src/expr/rewriting.rs | 5 +--- src/expr/sql.rs | 2 +- src/io/postgresql.rs | 4 +-- src/relation/mod.rs | 2 +- src/relation/schema.rs | 4 +-- src/relation/sql.rs | 5 ++-- src/rewriting/dot.rs | 5 +--- src/rewriting/rewriting_rule.rs | 12 ++++++-- src/sampling_adjustment/mod.rs | 1 - src/sql/expr.rs | 41 +++++++++++++++++++++----- src/sql/mod.rs | 5 +--- src/synthetic_data/mod.rs | 4 +-- 21 files changed, 81 insertions(+), 79 deletions(-) diff --git a/src/data_type/mod.rs b/src/data_type/mod.rs index 5f0b0d4d..75d490ff 100644 --- a/src/data_type/mod.rs +++ b/src/data_type/mod.rs @@ -1723,10 +1723,7 @@ impl Variant for List { } fn try_empty(&self) -> Result { - Ok(Self::new( - self.data_type().try_empty()?.into(), - 0.into(), - )) + Ok(Self::new(self.data_type().try_empty()?.into(), 0.into())) } } @@ -1843,10 +1840,7 @@ impl Variant for Set { } fn try_empty(&self) -> Result { - Ok(Self::new( - self.data_type().try_empty()?.into(), - 0.into(), - )) + Ok(Self::new(self.data_type().try_empty()?.into(), 0.into())) } } diff --git a/src/dialect_translation/bigquery.rs b/src/dialect_translation/bigquery.rs index 0744dad7..96373f36 100644 --- a/src/dialect_translation/bigquery.rs +++ b/src/dialect_translation/bigquery.rs @@ -113,7 +113,7 @@ mod tests { use super::*; use crate::{ builder::{Ready, With}, - data_type::{DataType}, + data_type::DataType, dialect_translation::RelationWithTranslator, expr::Expr, namer, diff --git a/src/dialect_translation/mod.rs b/src/dialect_translation/mod.rs index 096379d6..2ccf0a47 100644 --- a/src/dialect_translation/mod.rs +++ b/src/dialect_translation/mod.rs @@ -3,17 +3,8 @@ //! - a method to provide a sqlparser::Dialect for the parsing //! - methods varying from dialect to dialect regarding the conversion from AST to Expr+Relation and vice-versa +use sqlparser::{ast, dialect::Dialect}; -use sqlparser::{ - ast, - dialect::{Dialect}, -}; - -use crate::{ - expr::{self, Function}, - relation::sql::FromRelationVisitor, - visitor::Acceptor, WithoutContext, -}; use crate::{ data_type::DataTyped, expr::Identifier, @@ -22,6 +13,12 @@ use crate::{ sql::Result, DataType, Relation, }; +use crate::{ + expr::{self, Function}, + relation::sql::FromRelationVisitor, + visitor::Acceptor, + WithoutContext, +}; use paste::paste; @@ -787,12 +784,12 @@ pub trait QueryToRelationTranslator { context: &Hierarchy, ) -> Result> { match args { - ast::FunctionArguments::None - | ast::FunctionArguments::Subquery(_) => Ok(vec![]), - ast::FunctionArguments::List(arg_list) => arg_list.args + ast::FunctionArguments::None | ast::FunctionArguments::Subquery(_) => Ok(vec![]), + ast::FunctionArguments::List(arg_list) => arg_list + .args .iter() .map(|func_arg| match func_arg { - ast::FunctionArg::Named {arg, .. } | ast::FunctionArg::Unnamed(arg) => { + ast::FunctionArg::Named { arg, .. } | ast::FunctionArg::Unnamed(arg) => { self.try_function_arg_expr(arg, context) } }) diff --git a/src/dialect_translation/mssql.rs b/src/dialect_translation/mssql.rs index 86e4e5ef..51861402 100644 --- a/src/dialect_translation/mssql.rs +++ b/src/dialect_translation/mssql.rs @@ -292,8 +292,7 @@ impl QueryToRelationTranslator for MsSqlTranslator { ) -> Result { // need to check func.args: let args = match &func.args { - ast::FunctionArguments::None - | ast::FunctionArguments::Subquery(_) => vec![], + ast::FunctionArguments::None | ast::FunctionArguments::Subquery(_) => vec![], ast::FunctionArguments::List(l) => l.args.iter().collect(), }; // We expect 2 args @@ -361,8 +360,9 @@ fn extract_hashbyte_expression_if_valid(func_arg: &ast::FunctionArg) -> Option
match e { ast::Expr::Function(f) => { let arg_vec = match &f.args { - ast::FunctionArguments::None - | ast::FunctionArguments::Subquery(_) => vec![], + ast::FunctionArguments::None | ast::FunctionArguments::Subquery(_) => { + vec![] + } ast::FunctionArguments::List(func_args) => func_args.args.iter().collect(), }; if (f.name == expected_f_name) && (arg_vec[0] == &expected_first_arg) { diff --git a/src/dialect_translation/postgresql.rs b/src/dialect_translation/postgresql.rs index 669719bc..63caa55c 100644 --- a/src/dialect_translation/postgresql.rs +++ b/src/dialect_translation/postgresql.rs @@ -1,13 +1,8 @@ - - use crate::expr; -use super::{ - function_builder, QueryToRelationTranslator, RelationToQueryTranslator, -}; +use super::{function_builder, QueryToRelationTranslator, RelationToQueryTranslator}; use sqlparser::{ast, dialect::PostgreSqlDialect}; - #[derive(Clone, Copy)] pub struct PostgreSqlTranslator; @@ -96,7 +91,6 @@ impl QueryToRelationTranslator for PostgreSqlTranslator { #[cfg(test)] mod tests { - use super::*; use crate::{ diff --git a/src/dialect_translation/sqlite.rs b/src/dialect_translation/sqlite.rs index 6c138fc6..b81cb9a0 100644 --- a/src/dialect_translation/sqlite.rs +++ b/src/dialect_translation/sqlite.rs @@ -1,5 +1,3 @@ - - use super::RelationToQueryTranslator; #[derive(Clone, Copy)] @@ -8,8 +6,4 @@ pub struct SQLiteTranslator; impl RelationToQueryTranslator for SQLiteTranslator {} #[cfg(test)] -mod tests { - - - -} +mod tests {} diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index 039305a7..00535fde 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -8,7 +8,8 @@ use crate::{ AggregateColumn, Column, Expr, Identifier, }, privacy_unit_tracking::PupRelation, - relation::{Map, Reduce, Relation, Variant}, Ready, + relation::{Map, Reduce, Relation, Variant}, + Ready, }; use std::{cmp, collections::HashMap, ops::Deref}; diff --git a/src/display/mod.rs b/src/display/mod.rs index 292438f0..0c512986 100644 --- a/src/display/mod.rs +++ b/src/display/mod.rs @@ -186,7 +186,8 @@ mod tests { builder::{Ready, With}, data_type::DataType, expr::Expr, - relation::{schema::Schema, Relation}, WithoutContext as _, + relation::{schema::Schema, Relation}, + WithoutContext as _, }; #[test] diff --git a/src/expr/dot.rs b/src/expr/dot.rs index ecb5112d..856e4d7e 100644 --- a/src/expr/dot.rs +++ b/src/expr/dot.rs @@ -201,7 +201,8 @@ mod tests { builder::{Ready, With}, data_type::DataType, display::Dot, - relation::{schema::Schema, Relation}, WithoutContext as _, + relation::{schema::Schema, Relation}, + WithoutContext as _, }; use std::sync::Arc; diff --git a/src/expr/rewriting.rs b/src/expr/rewriting.rs index 2bcfc717..a4827a08 100644 --- a/src/expr/rewriting.rs +++ b/src/expr/rewriting.rs @@ -1,7 +1,4 @@ -use crate::{ - expr::{Expr}, - namer, -}; +use crate::{expr::Expr, namer}; use std::f64::consts::PI; impl Expr { diff --git a/src/expr/sql.rs b/src/expr/sql.rs index c6211eec..a7586354 100644 --- a/src/expr/sql.rs +++ b/src/expr/sql.rs @@ -1,7 +1,7 @@ //! Convert Expr into ast::Expr use crate::{ ast, - data_type::{DataType}, + data_type::DataType, expr::{self, Expr}, visitor::Acceptor, }; diff --git a/src/io/postgresql.rs b/src/io/postgresql.rs index 201cefef..b480db56 100644 --- a/src/io/postgresql.rs +++ b/src/io/postgresql.rs @@ -12,9 +12,7 @@ use crate::{ namer, relation::{Table, Variant as _}, }; -use std::{ - env, fmt, process::Command, str::FromStr, sync::Arc, sync::Mutex, thread, time, -}; +use std::{env, fmt, process::Command, str::FromStr, sync::Arc, sync::Mutex, thread, time}; use colored::Colorize; use postgres::{ diff --git a/src/relation/mod.rs b/src/relation/mod.rs index 0b6fcc15..78f51371 100644 --- a/src/relation/mod.rs +++ b/src/relation/mod.rs @@ -1647,7 +1647,7 @@ impl Ready for ValuesBuilder { #[cfg(test)] mod tests { use super::{schema::Schema, *}; - + use crate::{builder::With, data_type::DataType, display::Dot}; #[test] diff --git a/src/relation/schema.rs b/src/relation/schema.rs index 61402f09..629cc5e1 100644 --- a/src/relation/schema.rs +++ b/src/relation/schema.rs @@ -9,7 +9,7 @@ use super::{field::Field, Error, Result}; use crate::{ builder::{Ready, With}, data_type::{DataType, DataTyped}, - expr::{identifier::Identifier}, + expr::identifier::Identifier, }; /// A struct holding Fields as in https://github.com/apache/arrow-datafusion/blob/5b23180cf75ea7155d7c35a40f224ce4d5ad7fb8/datafusion/src/logical_plan/dfschema.rs#L36 @@ -260,7 +260,7 @@ impl Ready for Builder { #[cfg(test)] mod tests { use super::*; - use crate::data_type::{DataType}; + use crate::data_type::DataType; use std::panic::catch_unwind; #[test] diff --git a/src/relation/sql.rs b/src/relation/sql.rs index 3460ae9f..0dcea561 100644 --- a/src/relation/sql.rs +++ b/src/relation/sql.rs @@ -1,9 +1,8 @@ //! Methods to convert Relations to ast::Query - use super::{ - Join, Map, OrderBy, Reduce, Relation, Set, SetOperator, - SetQuantifier, Table, Values, Variant as _, Visitor, + Join, Map, OrderBy, Reduce, Relation, Set, SetOperator, SetQuantifier, Table, Values, + Variant as _, Visitor, }; use crate::{ ast, diff --git a/src/rewriting/dot.rs b/src/rewriting/dot.rs index 52024b7d..7ac87a31 100644 --- a/src/rewriting/dot.rs +++ b/src/rewriting/dot.rs @@ -2,9 +2,7 @@ use std::{io, iter}; use itertools::Itertools; -use super::{ - Property, RelationWithRewritingRule, RelationWithRewritingRules, RewritingRule, -}; +use super::{Property, RelationWithRewritingRule, RelationWithRewritingRules, RewritingRule}; use crate::{ display::{self, colors}, namer, @@ -155,7 +153,6 @@ impl<'a> RelationWithRewritingRule<'a> { #[cfg(test)] mod tests { - use super::*; use crate::{ diff --git a/src/rewriting/rewriting_rule.rs b/src/rewriting/rewriting_rule.rs index f1650544..6139d7fb 100644 --- a/src/rewriting/rewriting_rule.rs +++ b/src/rewriting/rewriting_rule.rs @@ -618,7 +618,11 @@ impl<'a> SetRewritingRulesVisitor<'a> for RewritingRulesSetter<'a> { rewriting_rules } - fn map(&self, _map: &'a Map, _input: Arc>) -> Vec { + fn map( + &self, + _map: &'a Map, + _input: Arc>, + ) -> Vec { let mut rewriting_rules = vec![ RewritingRule::new(vec![Property::Public], Property::Public, Parameters::None), RewritingRule::new( @@ -1282,7 +1286,11 @@ impl<'a> RewriteVisitor<'a> for Rewriter<'a> { (relation, dp_event_left.compose(dp_event_right)).into() } - fn values(&self, values: &'a Values, _rewriting_rule: &'a RewritingRule) -> RelationWithDpEvent { + fn values( + &self, + values: &'a Values, + _rewriting_rule: &'a RewritingRule, + ) -> RelationWithDpEvent { (Arc::new(values.clone().into()), DpEvent::no_op()).into() } } diff --git a/src/sampling_adjustment/mod.rs b/src/sampling_adjustment/mod.rs index 02a7a13a..6dfc0f98 100644 --- a/src/sampling_adjustment/mod.rs +++ b/src/sampling_adjustment/mod.rs @@ -539,7 +539,6 @@ mod tests { }; use colored::Colorize; - #[cfg(feature = "tested_sampling_adjustment")] #[test] diff --git a/src/sql/expr.rs b/src/sql/expr.rs index 47bae59b..5e00de98 100644 --- a/src/sql/expr.rs +++ b/src/sql/expr.rs @@ -220,14 +220,20 @@ impl<'a> Acceptor<'a> for ast::Expr { match_value: _, opt_search_modifier: _, } => Dependencies::empty(), - ast::Expr::IntroducedString { introducer: _, value: _ } => Dependencies::empty(), + ast::Expr::IntroducedString { + introducer: _, + value: _, + } => Dependencies::empty(), ast::Expr::RLike { negated: _, expr: _, pattern: _, regexp: _, } => todo!(), - ast::Expr::Struct { values: _, fields: _ } => todo!(), + ast::Expr::Struct { + values: _, + fields: _, + } => todo!(), ast::Expr::Named { expr: _, name: _ } => todo!(), ast::Expr::Convert { expr: _, @@ -495,10 +501,16 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { overlay_from: _, overlay_for: _, } => todo!(), - ast::Expr::Collate { expr: _, collation: _ } => todo!(), + ast::Expr::Collate { + expr: _, + collation: _, + } => todo!(), ast::Expr::Nested(expr) => dependencies.get(expr).clone(), ast::Expr::Value(value) => self.value(value), - ast::Expr::TypedString { data_type: _, value: _ } => todo!(), + ast::Expr::TypedString { + data_type: _, + value: _, + } => todo!(), ast::Expr::MapAccess { column: _, keys: _ } => todo!(), ast::Expr::Function(function) => self.function(function, { let mut result = vec![]; @@ -551,7 +563,10 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { .collect(), else_result.clone().map(|x| dependencies.get(&*x).clone()), ), - ast::Expr::Exists { subquery: _, negated: _ } => todo!(), + ast::Expr::Exists { + subquery: _, + negated: _, + } => todo!(), ast::Expr::Subquery(_) => todo!(), ast::Expr::GroupingSets(_) => todo!(), ast::Expr::Cube(_) => todo!(), @@ -565,14 +580,20 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { match_value: _, opt_search_modifier: _, } => todo!(), - ast::Expr::IntroducedString { introducer: _, value: _ } => todo!(), + ast::Expr::IntroducedString { + introducer: _, + value: _, + } => todo!(), ast::Expr::RLike { negated: _, expr: _, pattern: _, regexp: _, } => todo!(), - ast::Expr::Struct { values: _, fields: _ } => todo!(), + ast::Expr::Struct { + values: _, + fields: _, + } => todo!(), ast::Expr::Named { expr: _, name: _ } => todo!(), ast::Expr::Convert { expr: _, @@ -925,7 +946,11 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { let flat_args = flat_args?; let function_name: &str = &function.name.0.iter().join(".").to_lowercase(); let distinct: bool = match &function.args { - ast::FunctionArguments::List(func_arg_list) if func_arg_list.duplicate_treatment == Some(ast::DuplicateTreatment::Distinct) => true, + ast::FunctionArguments::List(func_arg_list) + if func_arg_list.duplicate_treatment == Some(ast::DuplicateTreatment::Distinct) => + { + true + } _ => false, }; Ok(match function_name { diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 0c8ec9d7..beb03e8c 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -112,11 +112,8 @@ mod tests { builder::With, display::Dot, io::{postgresql, Database}, - relation::{Relation, Variant as _,} + relation::{Relation, Variant as _}, }; - - - #[test] fn test_display() { diff --git a/src/synthetic_data/mod.rs b/src/synthetic_data/mod.rs index c06fd5ed..8b36d98a 100644 --- a/src/synthetic_data/mod.rs +++ b/src/synthetic_data/mod.rs @@ -1,6 +1,6 @@ use crate::{ - builder::{Ready}, - expr::{Identifier}, + builder::Ready, + expr::Identifier, hierarchy::Hierarchy, relation::{Relation, Table, Variant as _}, }; From 0a22cb01d7413bb2324c0ecedceda03d7ca4703f Mon Sep 17 00:00:00 2001 From: Andi Cuko Date: Mon, 27 May 2024 18:09:18 +0200 Subject: [PATCH 4/8] cargo test ok --- src/data_type/function.rs | 14 ++++---- src/data_type/generator.rs | 8 ++--- src/data_type/injection.rs | 4 +-- src/data_type/intervals.rs | 12 +++---- src/dialect_translation/mod.rs | 15 -------- src/dialect_translation/postgresql.rs | 7 ++-- src/differential_privacy/aggregates.rs | 11 +----- src/display/mod.rs | 29 +++++++++------- src/expr/bijection.rs | 2 +- src/expr/dsl.rs | 1 + src/io/bigquery.rs | 47 +++++++++++++------------- src/io/mssql.rs | 31 ++++++++--------- src/io/postgresql.rs | 6 ++-- src/io/sqlite.rs | 2 +- 14 files changed, 81 insertions(+), 108 deletions(-) diff --git a/src/data_type/function.rs b/src/data_type/function.rs index 821e42ed..b2012223 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -2731,7 +2731,7 @@ mod tests { super::{value::Value, Struct}, *, }; - use chrono::{self, NaiveDate, NaiveDateTime, NaiveTime}; + use chrono::{self, NaiveDate, DateTime, NaiveTime}; #[test] fn test_argument_conversion() { @@ -3841,12 +3841,12 @@ mod tests { // im(struct{0: float(-∞, 10], 1: int[2 100]}) = float(-∞, 10] let set: DataType = DataType::structured_from_data_types([ DataType::date_time_interval( - NaiveDateTime::from_timestamp_opt(1662921288, 0).unwrap(), - NaiveDateTime::from_timestamp_opt(1862921288, 111110).unwrap(), + DateTime::from_timestamp(1662921288, 0).unwrap().naive_utc(), + DateTime::from_timestamp(1862921288, 111110).unwrap().naive_utc(), ), DataType::date_time_interval( - NaiveDateTime::from_timestamp_opt(1362921288, 0).unwrap(), - NaiveDateTime::from_timestamp_opt(2062921288, 111110).unwrap(), + DateTime::from_timestamp(1362921288, 0).unwrap().naive_utc(), + DateTime::from_timestamp(2062921288, 111110).unwrap().naive_utc(), ), ]); let im = fun.super_image(&set).unwrap(); @@ -3854,8 +3854,8 @@ mod tests { assert_eq!( im, DataType::date_time_interval( - NaiveDateTime::from_timestamp_opt(1662921288, 0).unwrap(), - NaiveDateTime::from_timestamp_opt(2062921288, 111110).unwrap() + DateTime::from_timestamp(1662921288, 0).unwrap().naive_utc(), + DateTime::from_timestamp(2062921288, 111110).unwrap().naive_utc() ), ); } diff --git a/src/data_type/generator.rs b/src/data_type/generator.rs index 7014c53f..722ccfcf 100644 --- a/src/data_type/generator.rs +++ b/src/data_type/generator.rs @@ -225,8 +225,8 @@ mod tests { chrono::NaiveDateTime::generate_between( &mut rng, &[ - chrono::NaiveDateTime::from_timestamp_opt(1662921288, 0).unwrap(), - chrono::NaiveDateTime::from_timestamp_opt(1662921288, 0).unwrap() + chrono::DateTime::from_timestamp(1662921288, 0).unwrap().naive_utc(), + chrono::DateTime::from_timestamp(1662921288, 0).unwrap().naive_utc() ] ) ); @@ -235,8 +235,8 @@ mod tests { chrono::NaiveDateTime::generate_between( &mut rng, &[ - chrono::NaiveDateTime::from_timestamp_opt(1662921288, 0).unwrap(), - chrono::NaiveDateTime::from_timestamp_opt(1693921288, 0).unwrap() + chrono::DateTime::from_timestamp(1662921288, 0).unwrap().naive_utc(), + chrono::DateTime::from_timestamp(1693921288, 0).unwrap().naive_utc() ] ) ); diff --git a/src/data_type/injection.rs b/src/data_type/injection.rs index af360465..bb19d9d7 100644 --- a/src/data_type/injection.rs +++ b/src/data_type/injection.rs @@ -689,7 +689,7 @@ impl Injection for Base { &self, arg: &::Element, ) -> Result<::Element> { - self.value_map(|arg| arg.and_hms(0, 0, 0), arg) + self.value_map(|arg| arg.and_hms_opt(0, 0, 0).unwrap(), arg) } } @@ -764,7 +764,7 @@ impl Injection for Base { self.value_map_option( |arg| { let date = arg.date(); - if *arg == date.and_hms(0, 0, 0) { + if *arg == date.and_hms_opt(0, 0, 0).unwrap() { Some(date) } else { None diff --git a/src/data_type/intervals.rs b/src/data_type/intervals.rs index 05bb8cd5..b4149c6c 100644 --- a/src/data_type/intervals.rs +++ b/src/data_type/intervals.rs @@ -109,10 +109,10 @@ impl Bound for chrono::NaiveTime { "time".to_string() } fn min() -> Self { - chrono::NaiveTime::from_num_seconds_from_midnight(0, 0) + chrono::NaiveTime::from_num_seconds_from_midnight_opt(0, 0).unwrap() } fn max() -> Self { - chrono::NaiveTime::from_num_seconds_from_midnight(86399, 1_999_999_999) + chrono::NaiveTime::from_num_seconds_from_midnight_opt(86399, 1_999_999_999).unwrap() } fn hash(&self, state: &mut H) { hash::Hash::hash(self, state) @@ -1135,12 +1135,12 @@ mod tests { } let dates: Intervals = [ [ - NaiveDate::from_ymd(2022, 12, 1), - NaiveDate::from_ymd(2022, 12, 25), + NaiveDate::from_ymd_opt(2022, 12, 1).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 25).unwrap(), ], [ - NaiveDate::from_ymd(1980, 12, 1), - NaiveDate::from_ymd(1980, 12, 25), + NaiveDate::from_ymd_opt(1980, 12, 1).unwrap(), + NaiveDate::from_ymd_opt(1980, 12, 25).unwrap(), ], ] .into_iter() diff --git a/src/dialect_translation/mod.rs b/src/dialect_translation/mod.rs index 2ccf0a47..eb019e37 100644 --- a/src/dialect_translation/mod.rs +++ b/src/dialect_translation/mod.rs @@ -715,21 +715,6 @@ macro_rules! relation_to_query_tranlator_trait_constructor { relation_to_query_tranlator_trait_constructor!(); -/// Constructors for creating functions that convert AST functions with -/// a single args to annequivalent sarus functions -macro_rules! try_unary_function_constructor { - ($( $enum:ident ),*) => { - paste! { - $( - fn [](&self, arg: &ast::Function, context: &Hierarchy) -> Result { - let converted = self.try_function_args(vec![arg.clone()], context)?; - Ok(expr::Expr::[<$enum:snake>](converted[0])) - } - )* - } - } -} - /// Build Sarus Relation from dialect specific AST pub trait QueryToRelationTranslator { type D: Dialect; diff --git a/src/dialect_translation/postgresql.rs b/src/dialect_translation/postgresql.rs index 63caa55c..3c7e6372 100644 --- a/src/dialect_translation/postgresql.rs +++ b/src/dialect_translation/postgresql.rs @@ -94,13 +94,10 @@ mod tests { use super::*; use crate::{ - builder::Ready, - data_type::DataType, - io::{postgresql, Database as _}, - relation::{schema::Schema, Relation}, - sql::relation::QueryWithRelations, + builder::Ready, data_type::DataType, dialect_translation::RelationWithTranslator, hierarchy::Hierarchy, io::{postgresql, Database as _}, relation::{schema::Schema, Relation}, sql::{parse_with_dialect, relation::QueryWithRelations} }; use std::sync::Arc; + use crate::sql::Result; fn assert_same_query_str(query_1: &str, query_2: &str) { let a_no_whitespace: String = query_1.chars().filter(|c| !c.is_whitespace()).collect(); diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index 00535fde..3ee0b547 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -531,16 +531,7 @@ impl Reduce { mod tests { use super::*; use crate::{ - ast, - builder::With, - data_type::Variant, - display::Dot, - io::{postgresql, Database}, - privacy_unit_tracking::PrivacyUnit, - privacy_unit_tracking::{PrivacyUnitTracking, Strategy}, - relation::{Schema, Variant as _}, - sql::parse, - Relation, + ast, builder::With, data_type::Variant, display::Dot, io::{postgresql, Database}, privacy_unit_tracking::{PrivacyUnit, PrivacyUnitTracking, Strategy}, relation::{Schema, Variant as _}, sql::parse, DataType, Relation }; use std::{ops::Deref, sync::Arc}; diff --git a/src/display/mod.rs b/src/display/mod.rs index 0c512986..aadc49ee 100644 --- a/src/display/mod.rs +++ b/src/display/mod.rs @@ -37,18 +37,19 @@ const HTML_STYLE: &str = r##" "##; -const HTML_DARK_STYLE: &str = r##" -"##; +// not used +// const HTML_DARK_STYLE: &str = r##" +// "##; const HTML_BODY: &str = r##" @@ -181,10 +182,12 @@ pub mod macos { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::{ builder::{Ready, With}, - data_type::DataType, + data_type::{DataType, DataTyped as _}, expr::Expr, relation::{schema::Schema, Relation}, WithoutContext as _, diff --git a/src/expr/bijection.rs b/src/expr/bijection.rs index f88d0ffd..2c2565c7 100644 --- a/src/expr/bijection.rs +++ b/src/expr/bijection.rs @@ -59,7 +59,7 @@ impl Expr { #[cfg(test)] mod tests { - use identifier::Identifier; + use crate::expr::identifier::Identifier; use super::*; diff --git a/src/expr/dsl.rs b/src/expr/dsl.rs index 855d3089..129a8238 100644 --- a/src/expr/dsl.rs +++ b/src/expr/dsl.rs @@ -3,6 +3,7 @@ // https://veykril.github.io/tlborm/introduction.html // https://stackoverflow.com/questions/36721733/is-there-a-way-to-pattern-match-infix-operations-with-precedence-in-rust-macros // Macro DSL for exprs +#![allow(unused)] macro_rules! expr { // Process functions (@expf [$($f:tt)*][$([$($x:tt)*])*]) => {$($f)*($(expr!(@exp+ $($x)*)),*)}; diff --git a/src/io/bigquery.rs b/src/io/bigquery.rs index 114bb5b4..bd06b119 100644 --- a/src/io/bigquery.rs +++ b/src/io/bigquery.rs @@ -31,17 +31,16 @@ use super::{Database as DatabaseTrait, Error, Result, DATA_GENERATION_SEED}; use crate::{ data_type::{ - self, generator::Generator, - value::{self, Value, Variant}, - DataTyped, List, + value::{self, Value}, + DataTyped, }, namer, relation::{Constraint, Schema, Table, TableBuilder, Variant as _}, DataType, Ready as _, }; use rand::{rngs::StdRng, SeedableRng}; -use std::{env, fmt, process::Command, result, str::FromStr, sync::Arc, sync::Mutex, thread, time}; +use std::{fmt, process::Command, result, sync::Arc, sync::Mutex, thread, time}; const DB: &str = "qrlew-bigquery-test"; const PORT: u16 = 9050; @@ -147,24 +146,24 @@ pub static BQ_CLIENT: Mutex> = Mutex::new(None); pub static BIGQUERY_CONTAINER: Mutex = Mutex::new(false); impl Database { - fn db() -> String { - env::var("BIGQUERY_DB").unwrap_or(DB.into()) - } + // fn db() -> String { + // env::var("BIGQUERY_DB").unwrap_or(DB.into()) + // } - fn port() -> u16 { - match env::var("BIGQUERY_PORT") { - Ok(port) => u16::from_str(&port).unwrap_or(PORT), - Err(_) => PORT, - } - } + // fn port() -> u16 { + // match env::var("BIGQUERY_PORT") { + // Ok(port) => u16::from_str(&port).unwrap_or(PORT), + // Err(_) => PORT, + // } + // } - fn project_id() -> String { - env::var("BIGQUERY_PROJECT_ID").unwrap_or(PROJECT_ID.into()) - } + // fn project_id() -> String { + // env::var("BIGQUERY_PROJECT_ID").unwrap_or(PROJECT_ID.into()) + // } fn check_client(client: &Client) -> Result<()> { let rt = tokio::runtime::Runtime::new()?; - let res = rt.block_on(async_query("SELECT 1", &client, None))?; + let _res = rt.block_on(async_query("SELECT 1", &client, None))?; Ok(()) } @@ -387,7 +386,7 @@ async fn build_client(auth_uri: String, tmp_file_credentials: &NamedTempFile) -> } pub async fn async_row_query(query_str: &str, client: &Client) -> ResultSet { - let mut rs = client + let rs = client .job() .query(PROJECT_ID, QueryRequest::new(query_str)) .await @@ -465,7 +464,7 @@ impl DatabaseTrait for Database { } fn create_table(&mut self, table: &Table) -> Result { - let mut rt = tokio::runtime::Runtime::new()?; + let rt = tokio::runtime::Runtime::new()?; let bq_table: BQTable = table.clone().try_into()?; rt.block_on(self.client.table().create(bq_table))?; @@ -473,7 +472,7 @@ impl DatabaseTrait for Database { } fn insert_data(&mut self, table: &Table) -> Result<()> { - let mut rt = tokio::runtime::Runtime::new()?; + let rt = tokio::runtime::Runtime::new()?; let mut rng = StdRng::seed_from_u64(DATA_GENERATION_SEED); let size = Database::MAX_SIZE.min(table.size().generate(&mut rng) as usize); @@ -551,7 +550,7 @@ async fn async_query( use_query_cache: None, format_options: None, }; - let mut rs = client.job().query(PROJECT_ID, query_request).await?; + let rs = client.job().query(PROJECT_ID, query_request).await?; let query_response = rs.query_response(); let schema = &query_response.schema; if let Some(table_schema) = schema { @@ -720,8 +719,8 @@ impl TryFrom<(Option, field_type::FieldType)> for SqlValue { let seconds = timestamp as i64; // Whole seconds part let nanoseconds = ((timestamp - seconds as f64) * 1_000_000_000.0) as u32; // Fractional part in nanoseconds let datetime = - chrono::NaiveDateTime::from_timestamp_opt(seconds, nanoseconds).unwrap(); - value::Value::date_time(datetime).try_into() + chrono::DateTime::from_timestamp(seconds, nanoseconds).unwrap(); + value::Value::date_time(datetime.naive_utc()).try_into() } field_type::FieldType::Date => value::Value::date( chrono::NaiveDate::parse_from_str(&val_as_str[..], "%Y-%m-%d")?, @@ -782,7 +781,7 @@ impl TryFrom for field_type::FieldType { DataType::Time(_) => Ok(field_type::FieldType::Time), DataType::DateTime(_) => Ok(field_type::FieldType::Datetime), DataType::Duration(_) => todo!(), - DataType::Id(i) => Ok(field_type::FieldType::String), + DataType::Id(_) => Ok(field_type::FieldType::String), DataType::Function(_) => todo!(), DataType::Any => todo!(), } diff --git a/src/io/mssql.rs b/src/io/mssql.rs index 3a15fb15..347aa5a7 100644 --- a/src/io/mssql.rs +++ b/src/io/mssql.rs @@ -5,8 +5,8 @@ use super::{Database as DatabaseTrait, Error, Result, DATA_GENERATION_SEED}; use crate::{ data_type::{ generator::Generator, - value::{self, Value, Variant}, - DataTyped, List, + value::{self, Value}, + DataTyped, }, namer, relation::{Schema, Table, TableBuilder, Variant as _}, @@ -17,11 +17,11 @@ use rand::{rngs::StdRng, SeedableRng}; use sqlx::{ self, mssql::{ - self, Mssql, MssqlArguments, MssqlConnectOptions, MssqlPoolOptions, MssqlQueryResult, + self, Mssql, MssqlArguments, MssqlPoolOptions, MssqlQueryResult, MssqlRow, MssqlValueRef, }, query::Query, - Connection, Decode, Encode, MssqlConnection, MssqlPool, Pool, Row, Type, TypeInfo, + Decode, Encode, MssqlPool, Pool, Row, Type, TypeInfo, ValueRef as _, }; use std::{ @@ -44,8 +44,7 @@ impl From for Error { pub struct Database { name: String, tables: Vec, - pool: MssqlPool, - drop: bool, + pool: MssqlPool } pub static MSSQL_POOL: Mutex>> = Mutex::new(None); @@ -53,9 +52,9 @@ pub static MSSQL_POOL: Mutex>> = Mutex::new(None); pub static MSSQL_CONTAINER: Mutex = Mutex::new(false); impl Database { - fn db() -> String { - env::var("MSSQL_DB").unwrap_or(DB.into()) - } + // fn db() -> String { + // env::var("MSSQL_DB").unwrap_or(DB.into()) + // } fn port() -> u16 { match env::var("MSSQL_PORT") { @@ -307,16 +306,14 @@ impl DatabaseTrait for Database { Database { name, tables: vec![], - pool, - drop: false, + pool } .with_tables(tables_to_be_created) } else { Ok(Database { name, tables, - pool, - drop: false, + pool }) } } @@ -555,9 +552,9 @@ impl Encode<'_, mssql::Mssql> for SqlValue { fn produces(&self) -> Option<::TypeInfo> { match self { - SqlValue::Boolean(b) => Some(>::type_info()), - SqlValue::Integer(i) => Some(>::type_info()), - SqlValue::Float(f) => Some(>::type_info()), + SqlValue::Boolean(_) => Some(>::type_info()), + SqlValue::Integer(_) => Some(>::type_info()), + SqlValue::Float(_) => Some(>::type_info()), SqlValue::Text(t) => >::produces(t.deref()), SqlValue::Optional(o) => { let value = o.clone().map(|v| v.as_ref().clone()); @@ -577,7 +574,7 @@ impl Type for SqlValue { >::type_info() } - fn compatible(ty: &::TypeInfo) -> bool { + fn compatible(_ty: &::TypeInfo) -> bool { true } } diff --git a/src/io/postgresql.rs b/src/io/postgresql.rs index b480db56..5fcc8f12 100644 --- a/src/io/postgresql.rs +++ b/src/io/postgresql.rs @@ -49,9 +49,9 @@ pub static POSTGRES_POOL: Mutex>>> pub static POSTGRES_CONTAINER: Mutex = Mutex::new(false); impl Database { - fn db() -> String { - env::var("POSTGRES_DB").unwrap_or(DB.into()) - } + // fn db() -> String { + // env::var("POSTGRES_DB").unwrap_or(DB.into()) + // } fn port() -> usize { match env::var("POSTGRES_PORT") { diff --git a/src/io/sqlite.rs b/src/io/sqlite.rs index 3081892b..62a5ccfd 100644 --- a/src/io/sqlite.rs +++ b/src/io/sqlite.rs @@ -122,7 +122,7 @@ impl FromSql for Value { rusqlite::types::ValueRef::Integer(i) => Value::integer(i), rusqlite::types::ValueRef::Real(f) => Value::float(f), rusqlite::types::ValueRef::Text(s) => Value::text(String::from_utf8_lossy(s)), - rusqlite::types::ValueRef::Blob(b) => Value::bytes(b.clone()), + rusqlite::types::ValueRef::Blob(b) => Value::bytes(b), }) } } From 0bd26258ebf49513203d016021cfc50f7115ab91 Mon Sep 17 00:00:00 2001 From: Andi Cuko Date: Mon, 27 May 2024 18:09:49 +0200 Subject: [PATCH 5/8] cargo fmt --- src/data_type/function.rs | 14 ++++++++++---- src/data_type/generator.rs | 16 ++++++++++++---- src/dialect_translation/postgresql.rs | 10 ++++++++-- src/differential_privacy/aggregates.rs | 10 +++++++++- src/io/bigquery.rs | 3 +-- src/io/mssql.rs | 16 +++++----------- 6 files changed, 45 insertions(+), 24 deletions(-) diff --git a/src/data_type/function.rs b/src/data_type/function.rs index b2012223..ffe399f7 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -2731,7 +2731,7 @@ mod tests { super::{value::Value, Struct}, *, }; - use chrono::{self, NaiveDate, DateTime, NaiveTime}; + use chrono::{self, DateTime, NaiveDate, NaiveTime}; #[test] fn test_argument_conversion() { @@ -3842,11 +3842,15 @@ mod tests { let set: DataType = DataType::structured_from_data_types([ DataType::date_time_interval( DateTime::from_timestamp(1662921288, 0).unwrap().naive_utc(), - DateTime::from_timestamp(1862921288, 111110).unwrap().naive_utc(), + DateTime::from_timestamp(1862921288, 111110) + .unwrap() + .naive_utc(), ), DataType::date_time_interval( DateTime::from_timestamp(1362921288, 0).unwrap().naive_utc(), - DateTime::from_timestamp(2062921288, 111110).unwrap().naive_utc(), + DateTime::from_timestamp(2062921288, 111110) + .unwrap() + .naive_utc(), ), ]); let im = fun.super_image(&set).unwrap(); @@ -3855,7 +3859,9 @@ mod tests { im, DataType::date_time_interval( DateTime::from_timestamp(1662921288, 0).unwrap().naive_utc(), - DateTime::from_timestamp(2062921288, 111110).unwrap().naive_utc() + DateTime::from_timestamp(2062921288, 111110) + .unwrap() + .naive_utc() ), ); } diff --git a/src/data_type/generator.rs b/src/data_type/generator.rs index 722ccfcf..0044441e 100644 --- a/src/data_type/generator.rs +++ b/src/data_type/generator.rs @@ -225,8 +225,12 @@ mod tests { chrono::NaiveDateTime::generate_between( &mut rng, &[ - chrono::DateTime::from_timestamp(1662921288, 0).unwrap().naive_utc(), - chrono::DateTime::from_timestamp(1662921288, 0).unwrap().naive_utc() + chrono::DateTime::from_timestamp(1662921288, 0) + .unwrap() + .naive_utc(), + chrono::DateTime::from_timestamp(1662921288, 0) + .unwrap() + .naive_utc() ] ) ); @@ -235,8 +239,12 @@ mod tests { chrono::NaiveDateTime::generate_between( &mut rng, &[ - chrono::DateTime::from_timestamp(1662921288, 0).unwrap().naive_utc(), - chrono::DateTime::from_timestamp(1693921288, 0).unwrap().naive_utc() + chrono::DateTime::from_timestamp(1662921288, 0) + .unwrap() + .naive_utc(), + chrono::DateTime::from_timestamp(1693921288, 0) + .unwrap() + .naive_utc() ] ) ); diff --git a/src/dialect_translation/postgresql.rs b/src/dialect_translation/postgresql.rs index 3c7e6372..53cde7a2 100644 --- a/src/dialect_translation/postgresql.rs +++ b/src/dialect_translation/postgresql.rs @@ -93,11 +93,17 @@ impl QueryToRelationTranslator for PostgreSqlTranslator { mod tests { use super::*; + use crate::sql::Result; use crate::{ - builder::Ready, data_type::DataType, dialect_translation::RelationWithTranslator, hierarchy::Hierarchy, io::{postgresql, Database as _}, relation::{schema::Schema, Relation}, sql::{parse_with_dialect, relation::QueryWithRelations} + builder::Ready, + data_type::DataType, + dialect_translation::RelationWithTranslator, + hierarchy::Hierarchy, + io::{postgresql, Database as _}, + relation::{schema::Schema, Relation}, + sql::{parse_with_dialect, relation::QueryWithRelations}, }; use std::sync::Arc; - use crate::sql::Result; fn assert_same_query_str(query_1: &str, query_2: &str) { let a_no_whitespace: String = query_1.chars().filter(|c| !c.is_whitespace()).collect(); diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index 3ee0b547..3a100864 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -531,7 +531,15 @@ impl Reduce { mod tests { use super::*; use crate::{ - ast, builder::With, data_type::Variant, display::Dot, io::{postgresql, Database}, privacy_unit_tracking::{PrivacyUnit, PrivacyUnitTracking, Strategy}, relation::{Schema, Variant as _}, sql::parse, DataType, Relation + ast, + builder::With, + data_type::Variant, + display::Dot, + io::{postgresql, Database}, + privacy_unit_tracking::{PrivacyUnit, PrivacyUnitTracking, Strategy}, + relation::{Schema, Variant as _}, + sql::parse, + DataType, Relation, }; use std::{ops::Deref, sync::Arc}; diff --git a/src/io/bigquery.rs b/src/io/bigquery.rs index bd06b119..9b057b35 100644 --- a/src/io/bigquery.rs +++ b/src/io/bigquery.rs @@ -718,8 +718,7 @@ impl TryFrom<(Option, field_type::FieldType)> for SqlValue { let timestamp: f64 = val_as_str.parse()?; let seconds = timestamp as i64; // Whole seconds part let nanoseconds = ((timestamp - seconds as f64) * 1_000_000_000.0) as u32; // Fractional part in nanoseconds - let datetime = - chrono::DateTime::from_timestamp(seconds, nanoseconds).unwrap(); + let datetime = chrono::DateTime::from_timestamp(seconds, nanoseconds).unwrap(); value::Value::date_time(datetime.naive_utc()).try_into() } field_type::FieldType::Date => value::Value::date( diff --git a/src/io/mssql.rs b/src/io/mssql.rs index 347aa5a7..0f26d4bf 100644 --- a/src/io/mssql.rs +++ b/src/io/mssql.rs @@ -17,12 +17,10 @@ use rand::{rngs::StdRng, SeedableRng}; use sqlx::{ self, mssql::{ - self, Mssql, MssqlArguments, MssqlPoolOptions, MssqlQueryResult, - MssqlRow, MssqlValueRef, + self, Mssql, MssqlArguments, MssqlPoolOptions, MssqlQueryResult, MssqlRow, MssqlValueRef, }, query::Query, - Decode, Encode, MssqlPool, Pool, Row, Type, TypeInfo, - ValueRef as _, + Decode, Encode, MssqlPool, Pool, Row, Type, TypeInfo, ValueRef as _, }; use std::{ env, fmt, ops::Deref, process::Command, str::FromStr, sync::Arc, sync::Mutex, thread, time, @@ -44,7 +42,7 @@ impl From for Error { pub struct Database { name: String, tables: Vec
, - pool: MssqlPool + pool: MssqlPool, } pub static MSSQL_POOL: Mutex>> = Mutex::new(None); @@ -306,15 +304,11 @@ impl DatabaseTrait for Database { Database { name, tables: vec![], - pool + pool, } .with_tables(tables_to_be_created) } else { - Ok(Database { - name, - tables, - pool - }) + Ok(Database { name, tables, pool }) } } From 9549943a485861cf060d06306e677676f19d1026 Mon Sep 17 00:00:00 2001 From: Andi Cuko Date: Mon, 27 May 2024 18:23:24 +0200 Subject: [PATCH 6/8] add unwrap to display dot --- src/relation/builder.rs | 4 ++-- src/relation/dot.rs | 12 ++++++------ src/relation/rewriting.rs | 34 +++++++++++++++++----------------- src/sql/relation.rs | 2 +- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/relation/builder.rs b/src/relation/builder.rs index 14790653..e271dbfb 100644 --- a/src/relation/builder.rs +++ b/src/relation/builder.rs @@ -1179,7 +1179,7 @@ mod tests { .on_eq("d", "x") .and(Expr::lt(Expr::col("a"), Expr::col("x"))) .build(); - join.display_dot(); + join.display_dot().unwrap(); println!("Join = {join}"); let query = &ast::Query::from(&join).to_string(); println!( @@ -1223,7 +1223,7 @@ mod tests { .left_names(vec!["a1", "b1"]) //.on_iter(vec![Expr::eq(Expr::col("a"), Expr::col("c")), Expr::eq(Expr::col("b"), Expr::col("d"))]) .build(); - join.display_dot(); + join.display_dot().unwrap(); } #[test] diff --git a/src/relation/dot.rs b/src/relation/dot.rs index 4f029c29..5a09f5ce 100644 --- a/src/relation/dot.rs +++ b/src/relation/dot.rs @@ -398,7 +398,7 @@ mod tests { .right(map_2.clone()) .build(); println!("join_2 = {}", join_2); - join_2.display_dot(); + join_2.display_dot().unwrap(); } #[test] @@ -428,7 +428,7 @@ mod tests { )) .input(table.clone()) .build(); - map.display_dot(); + map.display_dot().unwrap(); } #[test] @@ -454,7 +454,7 @@ mod tests { .left(left) .right(right) .build(); - join.display_dot(); + join.display_dot().unwrap(); } #[test] @@ -481,18 +481,18 @@ mod tests { .with_group_by_column("a") .with(Expr::sum(Expr::col("b"))) .build(); - reduce.display_dot(); + reduce.display_dot().unwrap(); } #[test] fn test_display_values() { let values: Relation = Relation::values().name("Float").values(vec![5.]).build(); - values.display_dot(); + values.display_dot().unwrap(); let values: Relation = Relation::values() .name("List_of_floats") .values(vec![Value::float(10.), Value::float(4.0)]) .build(); - values.display_dot(); + values.display_dot().unwrap(); } } diff --git a/src/relation/rewriting.rs b/src/relation/rewriting.rs index 878d779e..04728471 100644 --- a/src/relation/rewriting.rs +++ b/src/relation/rewriting.rs @@ -1408,7 +1408,7 @@ mod tests { .with_group_by_column("item") .with_group_by_column("order_id") .build(); - my_relation.display_dot(); + my_relation.display_dot().unwrap(); let renamed_relation = my_relation.clone().rename_fields(|n, _| { if n == "sum_price" { @@ -1419,7 +1419,7 @@ mod tests { "unknown".to_string() } }); - renamed_relation.display_dot(); + renamed_relation.display_dot().unwrap(); } #[test] @@ -1442,7 +1442,7 @@ mod tests { ); println!("{}", filtering_expr); let filtered_relation = relation.filter(filtering_expr); - _ = filtered_relation.display_dot(); + _ = filtered_relation.display_dot().unwrap(); assert_eq!( filtered_relation .schema() @@ -1484,7 +1484,7 @@ mod tests { Expr::gt(Expr::col("a"), Expr::val(5.)), Expr::lt(Expr::col("b"), Expr::val(0.5)), )); - _ = filtered_relation.display_dot(); + _ = filtered_relation.display_dot().unwrap(); assert_eq!( filtered_relation.schema().field("a").unwrap().data_type(), DataType::float_interval(5., 10.) @@ -1519,7 +1519,7 @@ mod tests { Expr::gt(Expr::col("a"), Expr::val(5.)), Expr::lt(Expr::col("sum_d"), Expr::val(15)), )); - _ = filtered_relation.display_dot(); + _ = filtered_relation.display_dot().unwrap(); assert_eq!( filtered_relation.schema().field("a").unwrap().data_type(), DataType::float_interval(5., 10.) @@ -1742,7 +1742,7 @@ mod tests { // Without group by let unique_rel = table.unique(&["a", "b"]); println!("{}", unique_rel); - _ = unique_rel.display_dot(); + _ = unique_rel.display_dot().unwrap(); } #[test] @@ -1766,7 +1766,7 @@ mod tests { ]; let rel = table.clone().ordered_reduce(grouping_exprs, aggregates); println!("{}", rel); - _ = rel.display_dot(); + _ = rel.display_dot().unwrap(); // With group by let grouping_exprs = vec![Expr::col("c")]; @@ -1776,7 +1776,7 @@ mod tests { ]; let rel = table.ordered_reduce(grouping_exprs, aggregates); println!("{}", rel); - _ = rel.display_dot(); + _ = rel.display_dot().unwrap(); } #[test] @@ -1803,7 +1803,7 @@ mod tests { .clone() .distinct_aggregates(column, group_by, aggregates); println!("{}", distinct_rel); - _ = distinct_rel.display_dot(); + _ = distinct_rel.display_dot().unwrap(); // With group by let column = "a"; @@ -1816,7 +1816,7 @@ mod tests { .clone() .distinct_aggregates(column, group_by, aggregates); println!("{}", distinct_rel); - _ = distinct_rel.display_dot(); + _ = distinct_rel.display_dot().unwrap(); } #[test] @@ -1834,7 +1834,7 @@ mod tests { // table let rel = table.public_values_column("b").unwrap(); let rel_values: Relation = Relation::values().name("b").values([1, 2, 5]).build(); - rel.display_dot(); + rel.display_dot().unwrap(); assert_eq!(rel, rel_values); assert!(table.public_values_column("a").is_err()); @@ -1846,7 +1846,7 @@ mod tests { .with(("exp_b", Expr::exp(Expr::col("b")))) .build(); let rel = map.public_values_column("exp_b").unwrap(); - rel.display_dot(); + rel.display_dot().unwrap(); assert!(map.public_values_column("exp_a").is_err()); } @@ -1863,7 +1863,7 @@ mod tests { ) .build(); let rel = table.public_values().unwrap(); - rel.display_dot(); + rel.display_dot().unwrap(); let table: Relation = Relation::table() .name("table") @@ -1894,7 +1894,7 @@ mod tests { .input(table) .build(); let rel = map.public_values().unwrap(); - rel.display_dot(); + rel.display_dot().unwrap(); // map let table: Relation = Relation::table() @@ -1917,7 +1917,7 @@ mod tests { .input(table) .build(); let rel = map.public_values().unwrap(); - rel.display_dot(); + rel.display_dot().unwrap(); } #[test] @@ -1943,7 +1943,7 @@ mod tests { .build(); let joined_rel = table_1.clone().cross_join(table_2.clone()).unwrap(); - joined_rel.display_dot(); + joined_rel.display_dot().unwrap(); } #[test] @@ -2097,7 +2097,7 @@ mod tests { .group_by(expr!(2 * c)) .build(); let distinct_relation = relation.clone().distinct(); - distinct_relation.display_dot(); + distinct_relation.display_dot().unwrap(); assert_eq!(distinct_relation.schema(), relation.schema()); assert!(matches!(distinct_relation, Relation::Reduce(_))); if let Relation::Reduce(red) = distinct_relation { diff --git a/src/sql/relation.rs b/src/sql/relation.rs index bedfb053..f202a02b 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -1073,7 +1073,7 @@ mod tests { println!("relation = {relation}"); let q = ast::Query::from(&relation); println!("query = {q}"); - relation.display_dot(); + relation.display_dot().unwrap(); } #[test] From 527e2e0b31faa2f8e122cf9c5bc908fbfb09658d Mon Sep 17 00:00:00 2001 From: Andi Cuko Date: Mon, 27 May 2024 18:32:55 +0200 Subject: [PATCH 7/8] tests reformatted --- examples/website.rs | 1 + src/dialect_translation/postgresql.rs | 10 ++++---- src/expr/dot.rs | 6 ++--- src/relation/rewriting.rs | 36 +++++++++++++-------------- src/sampling_adjustment/mod.rs | 10 ++------ 5 files changed, 29 insertions(+), 34 deletions(-) diff --git a/examples/website.rs b/examples/website.rs index d75234d7..8c8e7f45 100644 --- a/examples/website.rs +++ b/examples/website.rs @@ -1,3 +1,4 @@ +#[allow(unused)] fn rewrite() { use qrlew::ast::Query; use qrlew::display::Dot; diff --git a/src/dialect_translation/postgresql.rs b/src/dialect_translation/postgresql.rs index 53cde7a2..80196b82 100644 --- a/src/dialect_translation/postgresql.rs +++ b/src/dialect_translation/postgresql.rs @@ -105,11 +105,11 @@ mod tests { }; use std::sync::Arc; - fn assert_same_query_str(query_1: &str, query_2: &str) { - let a_no_whitespace: String = query_1.chars().filter(|c| !c.is_whitespace()).collect(); - let b_no_whitespace: String = query_2.chars().filter(|c| !c.is_whitespace()).collect(); - assert_eq!(a_no_whitespace, b_no_whitespace); - } + // fn assert_same_query_str(query_1: &str, query_2: &str) { + // let a_no_whitespace: String = query_1.chars().filter(|c| !c.is_whitespace()).collect(); + // let b_no_whitespace: String = query_2.chars().filter(|c| !c.is_whitespace()).collect(); + // assert_eq!(a_no_whitespace, b_no_whitespace); + // } #[test] fn test_query() -> Result<()> { diff --git a/src/expr/dot.rs b/src/expr/dot.rs index 856e4d7e..8a4b86ab 100644 --- a/src/expr/dot.rs +++ b/src/expr/dot.rs @@ -275,9 +275,9 @@ mod tests { ("c", Value::float(3.)), ("d", Value::integer(4)), ]); - &expr! { a*b+d }.with(val.clone()).display_dot().unwrap(); - &expr! { d+a*b }.with(val.clone()).display_dot().unwrap(); - &expr! { (a*b+d) }.with(val).display_dot().unwrap(); + let _ = &expr! { a*b+d }.with(val.clone()).display_dot().unwrap(); + let _ = &expr! { d+a*b }.with(val.clone()).display_dot().unwrap(); + let _ = &expr! { (a*b+d) }.with(val).display_dot().unwrap(); } #[test] diff --git a/src/relation/rewriting.rs b/src/relation/rewriting.rs index 04728471..3ff0a620 100644 --- a/src/relation/rewriting.rs +++ b/src/relation/rewriting.rs @@ -901,7 +901,7 @@ mod tests { use super::*; use crate::{ ast, - data_type::{value::List, DataType, DataTyped}, + data_type::{DataType, DataTyped}, display::Dot, expr::AggregateColumn, io::{postgresql, Database}, @@ -960,21 +960,21 @@ mod tests { assert!(relation.schema()[0].name() != "peid"); } - fn refacto_results(results: Vec, size: usize) -> Vec> { - let mut sorted_results: Vec> = vec![]; - for row in results { - let mut str_row = vec![]; - for i in 0..size { - str_row.push(match row[i].to_string().parse::() { - Ok(f) => ((f * 1000.).round() / 1000.).to_string(), - Err(_) => row[i].to_string(), - }) - } - sorted_results.push(str_row) - } - sorted_results.sort(); - sorted_results - } + // fn refacto_results(results: Vec, size: usize) -> Vec> { + // let mut sorted_results: Vec> = vec![]; + // for row in results { + // let mut str_row = vec![]; + // for i in 0..size { + // str_row.push(match row[i].to_string().parse::() { + // Ok(f) => ((f * 1000.).round() / 1000.).to_string(), + // Err(_) => row[i].to_string(), + // }) + // } + // sorted_results.push(str_row) + // } + // sorted_results.sort(); + // sorted_results + // } #[test] fn test_sums_by_group() { @@ -1057,12 +1057,12 @@ mod tests { } // group by and aggregates have the same argument - let mut relation = relations + let relation = relations .get(&["user_table".into()]) .unwrap() .as_ref() .clone(); - relation = relation.l1_norms("id", &vec!["age"], &vec!["age"]); + relation.l1_norms("id", &vec!["age"], &vec!["age"]); } #[test] diff --git a/src/sampling_adjustment/mod.rs b/src/sampling_adjustment/mod.rs index 6dfc0f98..0214e4ee 100644 --- a/src/sampling_adjustment/mod.rs +++ b/src/sampling_adjustment/mod.rs @@ -529,6 +529,7 @@ impl Relation { } } +#[cfg(feature = "tested_sampling_adjustment")] #[cfg(test)] mod tests { use super::*; @@ -539,8 +540,7 @@ mod tests { }; use colored::Colorize; - - #[cfg(feature = "tested_sampling_adjustment")] + #[test] fn test_uniform_poisson_sampling() { let mut database = postgresql::test_database(); @@ -611,7 +611,6 @@ mod tests { final_map.display_dot().unwrap(); } - #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_differenciated_poisson_sampling() { let mut database = postgresql::test_database(); @@ -684,7 +683,6 @@ mod tests { final_map.display_dot().unwrap(); } - #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_sampling_without_replacements() { let mut database = postgresql::test_database(); @@ -756,7 +754,6 @@ mod tests { final_map.display_dot().unwrap(); } - #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_table_with_weight() { let mut database = postgresql::test_database(); @@ -798,7 +795,6 @@ mod tests { ); } - #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_map_with_weight() { let mut database = postgresql::test_database(); @@ -838,7 +834,6 @@ mod tests { ); } - #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_reduce_with_weight() { let mut database = postgresql::test_database(); @@ -888,7 +883,6 @@ mod tests { ); } - #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_join_with_weight() { let mut database = postgresql::test_database(); From 1a9afd14fd50e0aa311696063b8be532e0801ec7 Mon Sep 17 00:00:00 2001 From: Andi Cuko Date: Mon, 27 May 2024 18:41:01 +0200 Subject: [PATCH 8/8] all ok --- src/dialect_translation/mssql.rs | 9 +++------ src/io/bigquery.rs | 9 +++------ src/io/mssql.rs | 2 +- src/sampling_adjustment/mod.rs | 31 +++++++++++++++++-------------- src/sql/mod.rs | 2 -- 5 files changed, 24 insertions(+), 29 deletions(-) diff --git a/src/dialect_translation/mssql.rs b/src/dialect_translation/mssql.rs index 51861402..d83f0e3c 100644 --- a/src/dialect_translation/mssql.rs +++ b/src/dialect_translation/mssql.rs @@ -395,19 +395,16 @@ fn translate_data_type(dtype: DataType) -> ast::DataType { #[cfg(test)] #[cfg(feature = "mssql")] mod tests { - use sqlparser::dialect::GenericDialect; - use super::*; use crate::{ builder::{Ready, With}, - data_type::{DataType, Value as _}, + data_type::DataType, dialect_translation::RelationWithTranslator, - display::Dot, expr::Expr, io::{mssql, Database as _}, namer, - relation::{schema::Schema, Relation, Variant as _}, - sql::{parse, parse_expr, parse_with_dialect, relation::QueryWithRelations}, + relation::{schema::Schema, Relation}, + sql::parse, }; use std::sync::Arc; diff --git a/src/io/bigquery.rs b/src/io/bigquery.rs index 9b057b35..ce8b8631 100644 --- a/src/io/bigquery.rs +++ b/src/io/bigquery.rs @@ -833,18 +833,15 @@ pub fn test_database() -> Database { #[cfg(test)] mod tests { - use std::{collections::HashMap, fmt::format}; + use std::collections::HashMap; + use super::*; use gcp_bigquery_client::{ model::table_data_insert_all_request_rows::TableDataInsertAllRequestRows, table::ListOptions, }; use serde_json::json; - use crate::dialect_translation::bigquery::BigQueryTranslator; - - use super::*; - #[tokio::test] async fn test_table_list() { println!("Connecting to a mocked server"); @@ -961,7 +958,7 @@ mod tests { let timestamp = 1703273535.453880; let seconds = timestamp as i64; // Whole seconds part let nanoseconds = ((timestamp - seconds as f64) * 1_000_000_000.0) as u32; // Fractional part in nanoseconds - let datetime = chrono::NaiveDateTime::from_timestamp_opt(seconds, nanoseconds); + let datetime = chrono::DateTime::from_timestamp(seconds, nanoseconds).unwrap(); println!("Datetime: {:?}", datetime); } diff --git a/src/io/mssql.rs b/src/io/mssql.rs index 0f26d4bf..c804e09d 100644 --- a/src/io/mssql.rs +++ b/src/io/mssql.rs @@ -583,7 +583,7 @@ mod tests { use crate::{ relation::{Schema, TableBuilder}, - DataType, Ready as _, + DataType, }; use super::*; diff --git a/src/sampling_adjustment/mod.rs b/src/sampling_adjustment/mod.rs index 0214e4ee..0988bdb1 100644 --- a/src/sampling_adjustment/mod.rs +++ b/src/sampling_adjustment/mod.rs @@ -537,13 +537,16 @@ mod tests { ast, display::Dot, io::{postgresql, Database}, + namer, + sql::parse, }; use colored::Colorize; - + use itertools::Itertools as _; + #[test] fn test_uniform_poisson_sampling() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let proba = 0.1; let relations = database.relations(); let relation = Relation::try_from( @@ -595,7 +598,7 @@ mod tests { let exprs: Vec<(&str, Expr)> = join .schema() .iter() - .map(|f| (f.name().clone(), Expr::col(f.name()))) + .map(|f| (f.name(), Expr::col(f.name()))) .collect(); let final_map: Relation = Relation::map() @@ -613,7 +616,7 @@ mod tests { #[test] fn test_differenciated_poisson_sampling() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let tables_and_proba: Vec<(Vec, f64)> = vec![ (vec!["order_table".to_string()], 0.1), (vec!["item_table".to_string()], 0.5), @@ -667,7 +670,7 @@ mod tests { let exprs: Vec<(&str, Expr)> = join .schema() .iter() - .map(|f| (f.name().clone(), Expr::col(f.name()))) + .map(|f| (f.name(), Expr::col(f.name()))) .collect(); let final_map: Relation = Relation::map() @@ -685,7 +688,7 @@ mod tests { #[test] fn test_sampling_without_replacements() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let proba = 0.1; let relations = database.relations(); namer::reset(); @@ -738,7 +741,7 @@ mod tests { let exprs: Vec<(&str, Expr)> = join .schema() .iter() - .map(|f| (f.name().clone(), Expr::col(f.name()))) + .map(|f| (f.name(), Expr::col(f.name()))) .collect(); let final_map: Relation = Relation::map() @@ -1023,7 +1026,7 @@ mod tests { #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_adjustment_simple_reduce() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let relations: Hierarchy> = database.relations(); let query = "SELECT COUNT(order_id), SUM(price), AVG(price) FROM item_table"; @@ -1038,7 +1041,7 @@ mod tests { #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_adjustment_join_reduce() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let relations: Hierarchy> = database.relations(); let query = "SELECT COUNT(id), SUM(price), AVG(price) FROM order_table JOIN item_table ON id=order_id"; @@ -1053,7 +1056,7 @@ mod tests { #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_adjustment_reduce_reduce() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let relations: Hierarchy> = database.relations(); let query = " @@ -1085,7 +1088,7 @@ mod tests { #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_adjustment_reduce_join_reduce() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let relations: Hierarchy> = database.relations(); // bug with USING (col) @@ -1106,7 +1109,7 @@ mod tests { #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_adjustment_join_reduce_reduce() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let relations: Hierarchy> = database.relations(); // 2 reduce after the join @@ -1135,7 +1138,7 @@ mod tests { #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_adjustment_reduce_reduce_reduce() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let relations: Hierarchy> = database.relations(); let weight: f64 = 2.0; @@ -1173,7 +1176,7 @@ mod tests { #[cfg(feature = "tested_sampling_adjustment")] #[test] fn test_adjustment_reduce_reduce_join_reduce() { - let mut database = postgresql::test_database(); + let database = postgresql::test_database(); let relations: Hierarchy> = database.relations(); let weight: f64 = 2.0; diff --git a/src/sql/mod.rs b/src/sql/mod.rs index beb03e8c..38dd0a7f 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -105,8 +105,6 @@ pub use relation::{parse, parse_with_dialect}; #[cfg(test)] mod tests { use super::*; - #[cfg(feature = "sqlite")] - use crate::io::sqlite; use crate::{ ast, builder::With,