diff --git a/model/src/crud/create.rs b/model/src/crud/create.rs index 6315463..9292580 100644 --- a/model/src/crud/create.rs +++ b/model/src/crud/create.rs @@ -3,20 +3,26 @@ use sqlx::{postgres::PgRow, FromRow, PgConnection}; use crate::Error; use crate::{crud::util::build_query, FieldValue, Model}; -#[derive(Clone, Debug)] +use super::util::build_query_as; + +#[derive(Debug)] pub struct Create<'a, T: Model> { - value: &'a T, + value: &'a mut T, + idempotent: bool, } impl<'a, T> Create<'a, T> where T: Model + for<'b> FromRow<'b, PgRow> + Unpin + Sized + Send, { - pub(crate) fn new(value: &'a T) -> Self { - Self { value } + pub(crate) fn new(value: &'a mut T) -> Self { + Self { + value, + idempotent: false, + } } - pub async fn execute(&self, executor: &mut PgConnection) -> Result<(), Error> { + pub async fn execute(&mut self, executor: &mut PgConnection) -> Result<(), Error> { let table_name = T::table_name(); let fields = self.value.fields()?; @@ -33,13 +39,30 @@ where .collect::>() .join(", "); + let var_bindings: Vec = fields.into_iter().map(|(_, value)| value).collect(); + + if self.idempotent { + let primary_key_index = format!("{}_pkey", table_name); + + let statement = format!( + "INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ON CONSTRAINT {} DO NOTHING RETURNING *", + table_name, columns, placeholder_values, primary_key_index + ); + + let created: T = build_query_as(&statement, var_bindings) + .fetch_one(executor) + .await?; + + *self.value = created; + + return Ok(()); + } + let statement = format!( "INSERT INTO {} ({}) VALUES ({})", table_name, columns, placeholder_values ); - let var_bindings: Vec = fields.into_iter().map(|(_, value)| value).collect(); - build_query(&statement, var_bindings) .execute(executor) .await?; diff --git a/model/src/crud/mod.rs b/model/src/crud/mod.rs index 22d56e0..fbf80cd 100644 --- a/model/src/crud/mod.rs +++ b/model/src/crud/mod.rs @@ -46,7 +46,7 @@ where DeleteAssociation::new(self, relation_name, associated_id) } - fn create<'a>(&'a self) -> Create<'a, Self> { + fn create<'a>(&'a mut self) -> Create<'a, Self> { Create::new(self) } diff --git a/model/src/sort_by/grammar.lalrpop b/model/src/sort_by/grammar.lalrpop index c8f2ef6..1f48f34 100644 --- a/model/src/sort_by/grammar.lalrpop +++ b/model/src/sort_by/grammar.lalrpop @@ -1,11 +1,6 @@ -use std::str::FromStr; -use uuid::Uuid; -use rust_decimal::Decimal; -use chrono::{DateTime, NaiveDate}; use lalrpop_util::ParseError; -use crate::filter::ast::{Expr, LogicOp, CompOp, Var}; -use crate::filter::util::apply_string_escapes; -use crate::{FieldType, FieldValue, ModelDef}; +use crate::filter::ast::Var; +use crate::{ModelDef}; grammar(model_def: &ModelDef); diff --git a/model/tests/crud.rs b/model/tests/crud.rs index 50d8d67..146565a 100644 --- a/model/tests/crud.rs +++ b/model/tests/crud.rs @@ -54,7 +54,7 @@ async fn test_create() { let mut records = records.into_iter(); - let record = records.next().unwrap(); + let mut record = records.next().unwrap(); record.create().execute(&mut tx).await.unwrap(); @@ -83,7 +83,7 @@ async fn test_upsert() { let mut records = records.into_iter(); - let record = records.next().unwrap(); + let mut record = records.next().unwrap(); record.create().execute(&mut tx).await.unwrap(); diff --git a/model/tests/pagination.rs b/model/tests/pagination.rs index f53bb46..f0d1a81 100644 --- a/model/tests/pagination.rs +++ b/model/tests/pagination.rs @@ -34,7 +34,7 @@ async fn setup_tables(tx: &mut Transaction<'_, Postgres>) { async fn insert_records(tx: &mut Transaction<'_, Postgres>) { let records = read_records(); - for record in records.into_iter() { + for mut record in records.into_iter() { record.create().execute(tx).await.unwrap(); } }