Skip to content

Commit

Permalink
ok
Browse files Browse the repository at this point in the history
  • Loading branch information
victoria de sainte agathe committed Dec 15, 2023
1 parent 9a390ad commit 635fd18
Showing 1 changed file with 33 additions and 48 deletions.
81 changes: 33 additions & 48 deletions src/relation/rewriting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ impl Relation {
/// This transform multiplies the values in the `self` relation by their corresponding `scale_factors`.
/// `scale_factors` contains the entities' scaling factors and the vector columns.
/// `self` contains the coordinates, the base and the vector columns.
pub fn scale(self, entities: &str, values: Vec<&str>, scale_factors: Relation) -> Self {
pub fn scale(self, entities: &str, values: &[(&str, &str)], scale_factors: Relation) -> Self {
// TODO fix this
// Join the two relations on the entity column
let join: Relation = Relation::join()
Expand All @@ -459,14 +459,28 @@ impl Relation {
.left(self)
.right(scale_factors)
.build();
//Multiply the values by the factors
join.map_fields(|field_name, expr| {
if values.contains(&field_name) {
Expr::multiply(expr, Expr::col(format!("_SCALE_FACTOR_{}", field_name)))
} else {
expr
}
})
//TODO: Multiply the values by the factors
let fields = join.schema()
.iter()
.map(|field| (field.name(), Expr::col(field.name())))
.chain(
values.into_iter()
.map(|(name, col)| {
let field_name = join.schema().field(*col).unwrap().name();
(
*name,
Expr::multiply(
Expr::col(field_name),
Expr::col(format!("_SCALE_FACTOR_{}", field_name))
)
)
})
)
.collect::<Vec<_>>();
Relation::map()
.with_iter(fields)
.input(join)
.build()
}

/// For each coordinate, rescale the columns by 1 / greatest(1, norm_l2/C)
Expand All @@ -481,11 +495,11 @@ impl Relation {
value_clippings: Vec<(&str, &str, f64)>,
) -> Self {
let named_values = value_clippings.iter()
.map(|(s1, s2, _)| (s1.to_string(), format!("_CLIPPED_{}", s2)))
.map(|(s1, s2, _)| (format!("_CLIPPED_{}", s2), s1.to_string(), s2.to_string()))
.collect::<Vec<_>>();
// Arrange the values
let value_clippings: HashMap<&str, f64> = value_clippings.iter()
.map(|(_, s, f)| (*s, *f))
let value_clippings: HashMap<&str, (f64, &str)> = value_clippings.iter()
.map(|(s1, s2, f)| (*s2, (*f, *s1)))
.collect();
// Compute the norm
let norms = self.clone().l2_norms(
Expand All @@ -496,7 +510,7 @@ impl Relation {
// Compute the scaling factors
let scaling_factors = norms.map_fields(|field_name, expr| {
if value_clippings.contains_key(&field_name) {
let value_clipping = value_clippings[&field_name];
let (value_clipping, _) = value_clippings[&field_name];
if value_clipping == 0.0 {
Expr::val(value_clipping)
} else {
Expand All @@ -514,47 +528,18 @@ impl Relation {
});
let clipped_relation = self.clone().scale(
entities,
value_clippings.keys().cloned().collect(),
named_values.iter()
.map(|(s1, _, s2)| (s1.as_str(), s2.as_str()))
.collect::<Vec<_>>()
.as_slice(),
scaling_factors,
);
// Join the clipped relation with the `self` relation, then filter the fields.
// The resulting relation is composed of the fields of the `self` relation
// and their clipped values, whose names are prefixed by `_CLIPPED_`.
let left_names = self.fields()
.into_iter()
.map(|field| field.name())
.collect::<Vec<_>>();
let join: Relation = Relation::join()
.inner(Expr::val(true))
.on_eq(entities, entities)
.left_names(left_names.clone())
.right_names(
clipped_relation
.fields()
.into_iter()
.map(|field| format!("_CLIPPED_{}", field.name()))
.collect(),
)
.left(self.clone())
.right(clipped_relation)
.build();
let fields = left_names.into_iter()
.chain(named_values.iter().map(|(_, s)| s.as_str()))
.collect::<Vec<_>>();
let clipped_relation:Relation = Relation::map()
.with_iter(join.schema().iter().filter_map(|f| {
fields.contains(&f.name()).then_some((f.name(), Expr::col(f.name())))
}))
.input(join)
// The join has squared the size of the relation. Force the size to be equal to the `self` relation size.
.limit(*self.size().max().unwrap() as usize)
.build();
// Aggregate
clipped_relation.sums_by_group(
groups,
named_values.iter()
.map(|(s1, s2)| (s1.as_str(), s2.as_str()))
.collect()
.map(|(s1, s2, _)| (s2.as_str(), s1.as_str()))
.collect::<Vec<_>>()
)
}

Expand Down

0 comments on commit 635fd18

Please sign in to comment.