Skip to content

Commit

Permalink
Merge pull request #1743 from SeaQL/find_with_linked
Browse files Browse the repository at this point in the history
Find with linked (cont'd)
  • Loading branch information
tyt2y3 authored Jul 7, 2023
2 parents acca634 + 6dc19d9 commit f3fb347
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ tracing = { version = "0.1", default-features = false, features = ["attributes",
rust_decimal = { version = "1", default-features = false, optional = true }
bigdecimal = { version = "0.3", default-features = false, optional = true }
sea-orm-macros = { version = "0.12.0-rc.3", path = "sea-orm-macros", default-features = false, features = ["strum"] }
sea-query = { version = "0.29.0-rc.2", features = ["thread-safe"] }
sea-query = { version = "0.29.0-rc.2", features = ["thread-safe", "hashable-value"] }
sea-query-binder = { version = "0.4.0-rc.2", default-features = false, optional = true }
strum = { version = "0.24", default-features = false }
serde = { version = "1.0", default-features = false }
Expand Down
86 changes: 83 additions & 3 deletions src/executor/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use crate::{
SelectB, SelectTwo, SelectTwoMany, Statement, StreamTrait, TryGetableMany,
};
use futures::{Stream, TryStreamExt};
use sea_query::SelectStatement;
use std::marker::PhantomData;
use std::pin::Pin;
use sea_query::{SelectStatement, Value};
use std::collections::HashMap;
use std::{hash::Hash, marker::PhantomData, pin::Pin};

#[cfg(feature = "with-json")]
use crate::JsonValue;
Expand Down Expand Up @@ -993,6 +993,86 @@ where
fn consolidate_query_result<L, R>(
rows: Vec<(L::Model, Option<R::Model>)>,
) -> Vec<(L::Model, Vec<R::Model>)>
where
L: EntityTrait,
R: EntityTrait,
{
// This is a strong point to consider adding a trait associated constant
// to PrimaryKeyTrait to indicate the arity
let pkcol: Vec<_> = <L::PrimaryKey as Iterable>::iter()
.map(|pk| pk.into_column())
.collect();
if pkcol.len() == 1 {
consolidate_query_result_of::<L, R, UnitPk<L>>(rows, UnitPk(pkcol[0]))
} else {
consolidate_query_result_of::<L, R, TuplePk<L>>(rows, TuplePk(pkcol))
}
}

trait ModelKey<E: EntityTrait> {
type Type: Hash + PartialEq + Eq;
fn get(&self, model: &E::Model) -> Self::Type;
}

// This could have been an array of [E::Column; <E::PrimaryKey as PrimaryKeyTrait>::ARITY]
struct UnitPk<E: EntityTrait>(E::Column);
struct TuplePk<E: EntityTrait>(Vec<E::Column>);

impl<E: EntityTrait> ModelKey<E> for UnitPk<E> {
type Type = Value;
fn get(&self, model: &E::Model) -> Self::Type {
model.get(self.0)
}
}

impl<E: EntityTrait> ModelKey<E> for TuplePk<E> {
type Type = Vec<Value>;
fn get(&self, model: &E::Model) -> Self::Type {
let mut key = Vec::new();
for col in self.0.iter() {
key.push(model.get(*col));
}
key
}
}

fn consolidate_query_result_of<L, R, KEY: ModelKey<L>>(
mut rows: Vec<(L::Model, Option<R::Model>)>,
model_key: KEY,
) -> Vec<(L::Model, Vec<R::Model>)>
where
L: EntityTrait,
R: EntityTrait,
{
let mut hashmap: HashMap<KEY::Type, Vec<R::Model>> =
rows.iter_mut().fold(HashMap::new(), |mut acc, row| {
let key = model_key.get(&row.0);
if let Some(value) = row.1.take() {
let vec: Option<&mut Vec<R::Model>> = acc.get_mut(&key);
if let Some(vec) = vec {
vec.push(value)
} else {
acc.insert(key, vec![value]);
}
}

acc
});

rows.into_iter()
.filter_map(|(l_model, _)| {
let l_pk = model_key.get(&l_model);
let r_models = hashmap.remove(&l_pk);
r_models.map(|r_models| (l_model, r_models))
})
.collect()
}

/// This is the legacy consolidate algorithm. Kept for reference
#[allow(dead_code)]
fn consolidate_query_result_of_ordered_rows<L, R>(
rows: Vec<(L::Model, Option<R::Model>)>,
) -> Vec<(L::Model, Vec<R::Model>)>
where
L: EntityTrait,
R: EntityTrait,
Expand Down
8 changes: 6 additions & 2 deletions src/query/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,16 @@ where
F: EntityTrait,
{
pub(crate) fn new(query: SelectStatement) -> Self {
Self::new_without_prepare(query)
.prepare_select()
.prepare_order_by()
}

pub(crate) fn new_without_prepare(query: SelectStatement) -> Self {
Self {
query,
entity: PhantomData,
}
.prepare_select()
.prepare_order_by()
}

fn prepare_select(mut self) -> Self {
Expand Down
46 changes: 46 additions & 0 deletions src/query/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,52 @@ where
}
select_two
}

/// Left Join with a Linked Entity and select Entity as a `Vec`.
pub fn find_with_linked<L, T>(self, l: L) -> SelectTwoMany<E, T>
where
L: Linked<FromEntity = E, ToEntity = T>,
T: EntityTrait,
{
let mut slf = self;
for (i, mut rel) in l.link().into_iter().enumerate() {
let to_tbl = Alias::new(format!("r{i}")).into_iden();
let from_tbl = if i > 0 {
Alias::new(format!("r{}", i - 1)).into_iden()
} else {
unpack_table_ref(&rel.from_tbl)
};
let table_ref = rel.to_tbl;

let mut condition = Condition::all().add(join_tbl_on_condition(
SeaRc::clone(&from_tbl),
SeaRc::clone(&to_tbl),
rel.from_col,
rel.to_col,
));
if let Some(f) = rel.on_condition.take() {
condition = condition.add(f(SeaRc::clone(&from_tbl), SeaRc::clone(&to_tbl)));
}

slf.query()
.join_as(JoinType::LeftJoin, table_ref, to_tbl, condition);
}
slf = slf.apply_alias(SelectA.as_str());
let mut select_two_many = SelectTwoMany::new_without_prepare(slf.query);
for col in <T::Column as Iterable>::iter() {
let alias = format!("{}{}", SelectB.as_str(), col.as_str());
let expr = Expr::col((
Alias::new(format!("r{}", l.link().len() - 1)).into_iden(),
col.into_iden(),
));
select_two_many.query().expr(SelectExpr {
expr: col.select_as(expr),
alias: Some(SeaRc::new(Alias::new(alias))),
window: None,
});
}
select_two_many
}
}

#[cfg(test)]
Expand Down
80 changes: 80 additions & 0 deletions tests/relational_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod common;

pub use chrono::offset::Utc;
pub use common::{bakery_chain::*, setup::*, TestContext};
use pretty_assertions::assert_eq;
pub use rust_decimal::prelude::*;
pub use rust_decimal_macros::dec;
use sea_orm::{entity::*, query::*, DbErr, DerivePartialModel, FromQueryResult};
Expand Down Expand Up @@ -747,6 +748,85 @@ pub async fn linked() -> Result<(), DbErr> {
}]
);

let select_baker_with_customer = Baker::find()
.find_with_linked(baker::BakedForCustomer)
.order_by_asc(baker::Column::Id)
.order_by_asc(Expr::col((Alias::new("r4"), customer::Column::Id)));

assert_eq!(
select_baker_with_customer
.build(sea_orm::DatabaseBackend::MySql)
.to_string(),
[
// FIXME: This might be faulty!
"SELECT `baker`.`id` AS `A_id`,",
"`baker`.`name` AS `A_name`,",
"`baker`.`contact_details` AS `A_contact_details`,",
"`baker`.`bakery_id` AS `A_bakery_id`,",
"`r4`.`id` AS `B_id`,",
"`r4`.`name` AS `B_name`,",
"`r4`.`notes` AS `B_notes`",
"FROM `baker`",
"LEFT JOIN `cakes_bakers` AS `r0` ON `baker`.`id` = `r0`.`baker_id`",
"LEFT JOIN `cake` AS `r1` ON `r0`.`cake_id` = `r1`.`id`",
"LEFT JOIN `lineitem` AS `r2` ON `r1`.`id` = `r2`.`cake_id`",
"LEFT JOIN `order` AS `r3` ON `r2`.`order_id` = `r3`.`id`",
"LEFT JOIN `customer` AS `r4` ON `r3`.`customer_id` = `r4`.`id`",
"ORDER BY `baker`.`id` ASC, `r4`.`id` ASC"
]
.join(" ")
);

assert_eq!(
select_baker_with_customer.all(&ctx.db).await?,
[
(
baker::Model {
id: 1,
name: "Baker Bob".into(),
contact_details: serde_json::json!({
"mobile": "+61424000000",
"home": "0395555555",
"address": "12 Test St, Testville, Vic, Australia",
}),
bakery_id: Some(1),
},
vec![customer::Model {
id: 2,
name: "Kara".into(),
notes: Some("Loves all cakes".into()),
}]
),
(
baker::Model {
id: 2,
name: "Baker Bobby".into(),
contact_details: serde_json::json!({
"mobile": "+85212345678",
}),
bakery_id: Some(1),
},
vec![
customer::Model {
id: 1,
name: "Kate".into(),
notes: Some("Loves cheese cake".into()),
},
customer::Model {
id: 1,
name: "Kate".into(),
notes: Some("Loves cheese cake".into()),
},
customer::Model {
id: 2,
name: "Kara".into(),
notes: Some("Loves all cakes".into()),
},
]
),
]
);

ctx.delete().await;

Ok(())
Expand Down

0 comments on commit f3fb347

Please sign in to comment.