Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: member clauses handle struct fields & operators & recursive #2420

Merged
merged 9 commits into from
Sep 16, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 117 additions & 93 deletions crates/torii/grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use self::subscriptions::entity::EntityManager;
use self::subscriptions::event_message::EventMessageManager;
use self::subscriptions::model_diff::{ModelDiffRequest, StateDiffManager};
use crate::proto::types::clause::ClauseType;
use crate::proto::types::LogicalOperator;
use crate::proto::world::world_server::WorldServer;
use crate::proto::world::{
SubscribeEntitiesRequest, SubscribeEntityResponse, SubscribeEventsResponse,
Expand Down Expand Up @@ -259,7 +260,6 @@ impl DojoWorld {
// total count of rows without limit and offset
let total_count: u32 =
sqlx::query_scalar(&count_query).fetch_optional(&self.pool).await?.unwrap_or(0);

if total_count == 0 {
return Ok((Vec::new(), 0));
}
Expand Down Expand Up @@ -381,7 +381,6 @@ impl DojoWorld {
.fetch_optional(&self.pool)
.await?
.unwrap_or(0);

if total_count == 0 {
return Ok((Vec::new(), 0));
}
Expand Down Expand Up @@ -525,15 +524,13 @@ impl DojoWorld {
"#,
compute_selector_from_names(namespace, model)
);

let models_result: Option<(String,)> =
sqlx::query_as(&models_query).fetch_optional(&self.pool).await?;
// we return an empty array of entities if the table is empty
if models_result.is_none() {
let models_str: Option<String> =
sqlx::query_scalar(&models_query).fetch_optional(&self.pool).await?;
if models_str.is_none() {
return Ok((Vec::new(), 0));
}

let (models_str,) = models_result.unwrap();
let models_str = models_str.unwrap();

let model_ids = models_str
.split(',')
Expand All @@ -543,8 +540,14 @@ impl DojoWorld {
let schemas =
self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect();

let table_name = member_clause.model;
let column_name = format!("external_{}", member_clause.member);
let model = member_clause.model.clone();
let parts: Vec<&str> = member_clause.member.split('.').collect();
let (table_name, column_name) = if parts.len() > 1 {
let nested_table = parts[..parts.len() - 1].join("$");
(format!("{model}${nested_table}"), format!("external_{}", parts.last().unwrap()))
} else {
(model, format!("external_{}", member_clause.member))
};
let (entity_query, arrays_queries, count_query) = build_sql_query(
&schemas,
table,
Expand All @@ -560,7 +563,6 @@ impl DojoWorld {
.fetch_optional(&self.pool)
.await?
.unwrap_or(0);

let db_entities = sqlx::query(&entity_query)
.bind(comparison_value.clone())
.bind(limit)
Expand All @@ -581,7 +583,7 @@ impl DojoWorld {
Ok((entities_collection, total_count))
}

async fn query_by_composite(
pub(crate) async fn query_by_composite(
&self,
table: &str,
model_relation_table: &str,
Expand All @@ -590,92 +592,17 @@ impl DojoWorld {
limit: Option<u32>,
offset: Option<u32>,
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
// different types of clauses
let mut where_clauses = Vec::new();
let mut model_clauses: HashMap<String, Vec<(String, ComparisonOperator, String)>> =
HashMap::new();
let mut having_clauses = Vec::new();

// bind valeus for prepared statement
let mut bind_values = Vec::new();

for clause in composite.clauses {
match clause.clause_type.unwrap() {
ClauseType::HashedKeys(hashed_keys) => {
let ids = hashed_keys
.hashed_keys
.iter()
.map(|id| {
Ok(format!("{table}.id = '{:#x}'", Felt::from_bytes_be_slice(id)))
})
.collect::<Result<Vec<_>, Error>>()?;
where_clauses.push(format!("({})", ids.join(" OR ")));
}
ClauseType::Keys(keys) => {
let keys_pattern = build_keys_pattern(&keys)?;
where_clauses.push(format!("{table}.keys REGEXP '{keys_pattern}'"));
}
ClauseType::Member(member) => {
let comparison_operator =
ComparisonOperator::from_repr(member.operator as usize)
.expect("invalid comparison operator");
let value: Primitive = member.value.unwrap().try_into()?;
let comparison_value = value.to_sql_value()?;

let column_name = format!("external_{}", member.member);

model_clauses.entry(member.model.clone()).or_default().push((
column_name,
comparison_operator,
comparison_value,
));

let (namespace, model) = member
.model
.split_once('-')
.ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?;
let model_id: Felt = compute_selector_from_names(namespace, model);
having_clauses.push(format!("INSTR(model_ids, '{:#x}') > 0", model_id));
}
_ => return Err(QueryError::UnsupportedQuery.into()),
}
}

let mut join_clauses = Vec::new();
for (model, clauses) in model_clauses {
let model_conditions = clauses
.into_iter()
.map(|(column, op, value)| {
bind_values.push(value);
format!("[{}].{} {} ?", model, column, op)
})
.collect::<Vec<_>>()
.join(" AND ");

join_clauses.push(format!(
"JOIN [{}] ON [{}].id = [{}].entity_id AND ({})",
model, table, model, model_conditions
));
}

let join_clause = join_clauses.join(" ");
let where_clause = if !where_clauses.is_empty() {
format!("WHERE {}", where_clauses.join(" AND "))
} else {
String::new()
};
let having_clause = if !having_clauses.is_empty() {
format!("HAVING {}", having_clauses.join(" AND "))
} else {
String::new()
};
let (where_clause, having_clause, join_clause, bind_values) =
build_composite_clause(table, model_relation_table, &composite)?;

let count_query = format!(
r#"
SELECT COUNT(DISTINCT [{table}].id)
FROM [{table}]
JOIN {model_relation_table} ON [{table}].id = {model_relation_table}.entity_id
{join_clause}
{where_clause}
{having_clause}
"#
);

Expand All @@ -685,7 +612,6 @@ impl DojoWorld {
}

let total_count = count_query.fetch_optional(&self.pool).await?.unwrap_or(0);

if total_count == 0 {
return Ok((Vec::new(), 0));
}
Expand All @@ -705,7 +631,7 @@ impl DojoWorld {
);

let mut db_query = sqlx::query_as(&query);
for value in bind_values {
for value in &bind_values {
db_query = db_query.bind(value);
}
db_query = db_query.bind(limit.unwrap_or(u32::MAX)).bind(offset.unwrap_or(0));
Expand Down Expand Up @@ -1026,6 +952,104 @@ fn build_keys_pattern(clause: &proto::types::KeysClause) -> Result<String, Error
Ok(keys_pattern)
}

// builds a composite clause for a query
fn build_composite_clause(
table: &str,
model_relation_table: &str,
composite: &proto::types::CompositeClause,
) -> Result<(String, String, String, Vec<String>), Error> {
let is_or = composite.operator == LogicalOperator::Or as i32;
let mut where_clauses = Vec::new();
let mut join_clauses = Vec::new();
let mut having_clauses = Vec::new();
let mut bind_values = Vec::new();

for clause in &composite.clauses {
match clause.clause_type.as_ref().unwrap() {
ClauseType::HashedKeys(hashed_keys) => {
let ids = hashed_keys
.hashed_keys
.iter()
.map(|id| {
bind_values.push(Felt::from_bytes_be_slice(id).to_string());
"?".to_string()
})
.collect::<Vec<_>>()
.join(", ");
where_clauses.push(format!("{table}.id IN ({})", ids));
}
ClauseType::Keys(keys) => {
let keys_pattern = build_keys_pattern(keys)?;
bind_values.push(keys_pattern);
where_clauses.push(format!("{table}.keys REGEXP ?"));
}
ClauseType::Member(member) => {
let comparison_operator = ComparisonOperator::from_repr(member.operator as usize)
.expect("invalid comparison operator");
let value: Primitive = member.value.as_ref().unwrap().clone().try_into()?;
let comparison_value = value.to_sql_value()?;
bind_values.push(comparison_value);

let model = member.model.clone();
let parts: Vec<&str> = member.member.split('.').collect();
let (table_name, column_name) = if parts.len() > 1 {
let nested_table = parts[..parts.len() - 1].join("$");
(
format!("[{model}${nested_table}]"),
format!("external_{}", parts.last().unwrap()),
)
} else {
(format!("[{model}]"), format!("external_{}", member.member))
};

let (namespace, model) = member
.model
.split_once('-')
.ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?;
let model_id = compute_selector_from_names(namespace, model);
join_clauses.push(format!(
"LEFT JOIN {table_name} ON [{table}].id = {table_name}.entity_id"
));
where_clauses.push(format!("{table_name}.{column_name} {comparison_operator} ?"));
having_clauses.push(format!(
"INSTR(group_concat({model_relation_table}.model_id), '{:#x}') > 0",
model_id
));
}
ClauseType::Composite(nested_composite) => {
let (nested_where, nested_having, nested_join, nested_values) =
self.build_composite_clause(table, model_relation_table, nested_composite)?;
where_clauses.push(format!("({})", nested_where.trim_start_matches("WHERE ")));
if !nested_having.is_empty() {
having_clauses.push(nested_having.trim_start_matches("HAVING ").to_string());
}
join_clauses.extend(
nested_join
.split_whitespace()
.filter(|&s| s.starts_with("LEFT"))
.map(String::from),
);
bind_values.extend(nested_values);
}
_ => return Err(QueryError::UnsupportedQuery.into()),
}
}

let join_clause = join_clauses.join(" ");
let where_clause = if !where_clauses.is_empty() {
format!("WHERE {}", where_clauses.join(if is_or { " OR " } else { " AND " }))
} else {
String::new()
};
let having_clause = if !having_clauses.is_empty() {
format!("HAVING {}", having_clauses.join(if is_or { " OR " } else { " AND " }))
} else {
String::new()
};

Ok((where_clause, having_clause, join_clause, bind_values))
}

type ServiceResult<T> = Result<Response<T>, Status>;
type SubscribeModelsResponseStream =
Pin<Box<dyn Stream<Item = Result<SubscribeModelsResponse, Status>> + Send>>;
Expand Down
Loading