diff --git a/Cargo.toml b/Cargo.toml index 4244c3058..2345844cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ futures-util = { version = "^0.3" } log = { version = "^0.4", optional = true } rust_decimal = { version = "^1", optional = true } sea-orm-macros = { version = "^0.3.1", path = "sea-orm-macros", optional = true } -sea-query = { version = "^0.18.0", features = ["thread-safe"] } +sea-query = { version = "^0.18.0", git = "https://github.com/SeaQL/sea-query.git", features = ["thread-safe"] } sea-strum = { version = "^0.21", features = ["derive", "sea-orm"] } serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } diff --git a/sea-orm-macros/src/derives/active_enum.rs b/sea-orm-macros/src/derives/active_enum.rs new file mode 100644 index 000000000..bafb19086 --- /dev/null +++ b/sea-orm-macros/src/derives/active_enum.rs @@ -0,0 +1,286 @@ +use heck::CamelCase; +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use syn::{punctuated::Punctuated, token::Comma, Lit, LitInt, LitStr, Meta}; + +enum Error { + InputNotEnum, + Syn(syn::Error), + TT(TokenStream), +} + +struct ActiveEnum { + ident: syn::Ident, + enum_name: String, + rs_type: TokenStream, + db_type: TokenStream, + is_string: bool, + variants: Vec, +} + +struct ActiveEnumVariant { + ident: syn::Ident, + string_value: Option, + num_value: Option, +} + +impl ActiveEnum { + fn new(input: syn::DeriveInput) -> Result { + let ident_span = input.ident.span(); + let ident = input.ident; + + let mut enum_name = ident.to_string().to_camel_case(); + let mut rs_type = Err(Error::TT(quote_spanned! { + ident_span => compile_error!("Missing macro attribute `rs_type`"); + })); + let mut db_type = Err(Error::TT(quote_spanned! { + ident_span => compile_error!("Missing macro attribute `db_type`"); + })); + for attr in input.attrs.iter() { + if let Some(ident) = attr.path.get_ident() { + if ident != "sea_orm" { + continue; + } + } else { + continue; + } + if let Ok(list) = attr.parse_args_with(Punctuated::::parse_terminated) { + for meta in list.iter() { + if let Meta::NameValue(nv) = meta { + if let Some(name) = nv.path.get_ident() { + if name == "rs_type" { + if let Lit::Str(litstr) = &nv.lit { + rs_type = syn::parse_str::(&litstr.value()) + .map_err(Error::Syn); + } + } else if name == "db_type" { + if let Lit::Str(litstr) = &nv.lit { + let s = litstr.value(); + match s.as_ref() { + "Enum" => { + db_type = Ok(quote! { + Enum(Self::name(), Self::values()) + }) + } + _ => { + db_type = syn::parse_str::(&s) + .map_err(Error::Syn); + } + } + } + } else if name == "enum_name" { + if let Lit::Str(litstr) = &nv.lit { + enum_name = litstr.value(); + } + } + } + } + } + } + } + + let variant_vec = match input.data { + syn::Data::Enum(syn::DataEnum { variants, .. }) => variants, + _ => return Err(Error::InputNotEnum), + }; + + let mut is_string = false; + let mut is_int = false; + let mut variants = Vec::new(); + for variant in variant_vec { + let variant_span = variant.ident.span(); + let mut string_value = None; + let mut num_value = None; + for attr in variant.attrs.iter() { + if let Some(ident) = attr.path.get_ident() { + if ident != "sea_orm" { + continue; + } + } else { + continue; + } + if let Ok(list) = attr.parse_args_with(Punctuated::::parse_terminated) + { + for meta in list { + if let Meta::NameValue(nv) = meta { + if let Some(name) = nv.path.get_ident() { + if name == "string_value" { + if let Lit::Str(lit) = nv.lit { + is_string = true; + string_value = Some(lit); + } + } else if name == "num_value" { + if let Lit::Int(lit) = nv.lit { + is_int = true; + num_value = Some(lit); + } + } + } + } + } + } + } + + if is_string && is_int { + return Err(Error::TT(quote_spanned! { + ident_span => compile_error!("All enum variants should specify the same `*_value` macro attribute, either `string_value` or `num_value` but not both"); + })); + } + + if string_value.is_none() && num_value.is_none() { + return Err(Error::TT(quote_spanned! { + variant_span => compile_error!("Missing macro attribute, either `string_value` or `num_value` should be specified"); + })); + } + + variants.push(ActiveEnumVariant { + ident: variant.ident, + string_value, + num_value, + }); + } + + Ok(ActiveEnum { + ident, + enum_name, + rs_type: rs_type?, + db_type: db_type?, + is_string, + variants, + }) + } + + fn expand(&self) -> syn::Result { + let expanded_impl_active_enum = self.impl_active_enum(); + + Ok(expanded_impl_active_enum) + } + + fn impl_active_enum(&self) -> TokenStream { + let Self { + ident, + enum_name, + rs_type, + db_type, + is_string, + variants, + } = self; + + let variant_idents: Vec = variants + .iter() + .map(|variant| variant.ident.clone()) + .collect(); + + let variant_values: Vec = variants + .iter() + .map(|variant| { + let variant_span = variant.ident.span(); + + if let Some(string_value) = &variant.string_value { + let string = string_value.value(); + quote! { #string } + } else if let Some(num_value) = &variant.num_value { + quote! { #num_value } + } else { + quote_spanned! { + variant_span => compile_error!("Missing macro attribute, either `string_value` or `num_value` should be specified"); + } + } + }) + .collect(); + + let val = if *is_string { + quote! { v.as_ref() } + } else { + quote! { v } + }; + + quote!( + #[automatically_derived] + impl sea_orm::ActiveEnum for #ident { + type Value = #rs_type; + + fn name() -> String { + #enum_name.to_owned() + } + + fn to_value(&self) -> Self::Value { + match self { + #( Self::#variant_idents => #variant_values, )* + } + .to_owned() + } + + fn try_from_value(v: &Self::Value) -> Result { + match #val { + #( #variant_values => Ok(Self::#variant_idents), )* + _ => Err(sea_orm::DbErr::Type(format!( + "unexpected value for {} enum: {}", + stringify!(#ident), + v + ))), + } + } + + fn db_type() -> sea_orm::ColumnDef { + sea_orm::ColumnType::#db_type.def() + } + } + + #[automatically_derived] + #[allow(clippy::from_over_into)] + impl Into for #ident { + fn into(self) -> sea_orm::sea_query::Value { + ::to_value(&self).into() + } + } + + #[automatically_derived] + impl sea_orm::TryGetable for #ident { + fn try_get(res: &sea_orm::QueryResult, pre: &str, col: &str) -> Result { + let value = <::Value as sea_orm::TryGetable>::try_get(res, pre, col)?; + ::try_from_value(&value).map_err(sea_orm::TryGetError::DbErr) + } + } + + #[automatically_derived] + impl sea_orm::sea_query::ValueType for #ident { + fn try_from(v: sea_orm::sea_query::Value) -> Result { + let value = <::Value as sea_orm::sea_query::ValueType>::try_from(v)?; + ::try_from_value(&value).map_err(|_| sea_orm::sea_query::ValueTypeErr) + } + + fn type_name() -> String { + <::Value as sea_orm::sea_query::ValueType>::type_name() + } + + fn column_type() -> sea_orm::sea_query::ColumnType { + ::db_type() + .get_column_type() + .to_owned() + .into() + } + } + + #[automatically_derived] + impl sea_orm::sea_query::Nullable for #ident { + fn null() -> sea_orm::sea_query::Value { + <::Value as sea_orm::sea_query::Nullable>::null() + } + } + ) + } +} + +pub fn expand_derive_active_enum(input: syn::DeriveInput) -> syn::Result { + let ident_span = input.ident.span(); + + match ActiveEnum::new(input) { + Ok(model) => model.expand(), + Err(Error::InputNotEnum) => Ok(quote_spanned! { + ident_span => compile_error!("you can only derive ActiveEnum on enums"); + }), + Err(Error::TT(token_stream)) => Ok(token_stream), + Err(Error::Syn(e)) => Err(e), + } +} diff --git a/sea-orm-macros/src/derives/entity_model.rs b/sea-orm-macros/src/derives/entity_model.rs index 9b649922a..5f1508ef8 100644 --- a/sea-orm-macros/src/derives/entity_model.rs +++ b/sea-orm-macros/src/derives/entity_model.rs @@ -1,7 +1,7 @@ use crate::util::{escape_rust_keyword, trim_starting_raw_identifier}; use heck::CamelCase; use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; +use quote::{format_ident, quote, quote_spanned}; use syn::{ parse::Error, punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, Data, Fields, Lit, Meta, @@ -193,8 +193,8 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res primary_keys.push(quote! { #field_name }); } - let field_type = match sql_type { - Some(t) => t, + let col_type = match sql_type { + Some(t) => quote! { sea_orm::prelude::ColumnType::#t.def() }, None => { let field_type = &field.ty; let temp = quote! { #field_type } @@ -206,7 +206,7 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res } else { temp.as_str() }; - match temp { + let col_type = match temp { "char" => quote! { Char(None) }, "String" | "&str" => quote! { String(None) }, "u8" | "i8" => quote! { TinyInteger }, @@ -229,16 +229,24 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res "Decimal" => quote! { Decimal(None) }, "Vec" => quote! { Binary }, _ => { - return Err(Error::new( - field.span(), - format!("unrecognized type {}", temp), - )) + // Assumed it's ActiveEnum if none of the above type matches + quote! {} } + }; + if col_type.is_empty() { + let field_span = field.span(); + let ty = format_ident!("{}", temp); + let def = quote_spanned! { field_span => { + <#ty as ActiveEnum>::db_type() + }}; + quote! { #def } + } else { + quote! { sea_orm::prelude::ColumnType::#col_type.def() } } } }; - let mut match_row = quote! { Self::#field_name => sea_orm::prelude::ColumnType::#field_type.def() }; + let mut match_row = quote! { Self::#field_name => #col_type }; if nullable { match_row = quote! { #match_row.nullable() }; } diff --git a/sea-orm-macros/src/derives/mod.rs b/sea-orm-macros/src/derives/mod.rs index 6ba19a928..36b9f6698 100644 --- a/sea-orm-macros/src/derives/mod.rs +++ b/sea-orm-macros/src/derives/mod.rs @@ -1,3 +1,4 @@ +mod active_enum; mod active_model; mod active_model_behavior; mod column; @@ -9,6 +10,7 @@ mod model; mod primary_key; mod relation; +pub use active_enum::*; pub use active_model::*; pub use active_model_behavior::*; pub use column::*; diff --git a/sea-orm-macros/src/lib.rs b/sea-orm-macros/src/lib.rs index 1514895c2..00540aa4a 100644 --- a/sea-orm-macros/src/lib.rs +++ b/sea-orm-macros/src/lib.rs @@ -492,6 +492,41 @@ pub fn derive_active_model_behavior(input: TokenStream) -> TokenStream { } } +/// A derive macro to implement `sea_orm::ActiveEnum` trait for enums. +/// +/// # Limitations +/// +/// This derive macros can only be used on enums. +/// +/// # Macro Attributes +/// +/// All macro attributes listed below have to be annotated in the form of `#[sea_orm(attr = value)]`. +/// +/// - For enum +/// - `rs_type`: Define `ActiveEnum::Value` +/// - Possible values: `String`, `i8`, `i16`, `i32`, `i64`, `u8`, `u16`, `u32`, `u64` +/// - Note that value has to be passed as string, i.e. `rs_type = "i8"` +/// - `db_type`: Define `ColumnType` returned by `ActiveEnum::db_type()` +/// - Possible values: all available enum variants of `ColumnType`, e.g. `String(None)`, `String(Some(1))`, `Integer` +/// - Note that value has to be passed as string, i.e. `db_type = "Integer"` +/// - `enum_name`: Define `String` returned by `ActiveEnum::name()` +/// - This attribute is optional with default value being the name of enum in camel-case +/// - Note that value has to be passed as string, i.e. `db_type = "Integer"` +/// +/// - For enum variant +/// - `string_value` or `num_value`: +/// - For `string_value`, value should be passed as string, i.e. `string_value = "A"` +/// - For `num_value`, value should be passed as integer, i.e. `num_value = 1` or `num_value = 1i32` +/// - Note that only one of it can be specified, and all variants of an enum have to annotate with the same `*_value` macro attribute +#[proc_macro_derive(DeriveActiveEnum, attributes(sea_orm))] +pub fn derive_active_enum(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + match derives::expand_derive_active_enum(input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + /// Convert a query result into the corresponding Model. /// /// ### Usage diff --git a/src/database/statement.rs b/src/database/statement.rs index 1c970ea13..ce911295c 100644 --- a/src/database/statement.rs +++ b/src/database/statement.rs @@ -82,6 +82,15 @@ macro_rules! build_any_stmt { }; } +macro_rules! build_postgres_stmt { + ($stmt: expr, $db_backend: expr) => { + match $db_backend { + DbBackend::Postgres => $stmt.to_string(PostgresQueryBuilder), + DbBackend::MySql | DbBackend::Sqlite => unimplemented!(), + } + }; +} + macro_rules! build_query_stmt { ($stmt: ty) => { impl StatementBuilder for $stmt { @@ -114,3 +123,18 @@ build_schema_stmt!(sea_query::TableDropStatement); build_schema_stmt!(sea_query::TableAlterStatement); build_schema_stmt!(sea_query::TableRenameStatement); build_schema_stmt!(sea_query::TableTruncateStatement); + +macro_rules! build_type_stmt { + ($stmt: ty) => { + impl StatementBuilder for $stmt { + fn build(&self, db_backend: &DbBackend) -> Statement { + let stmt = build_postgres_stmt!(self, db_backend); + Statement::from_string(*db_backend, stmt) + } + } + }; +} + +build_type_stmt!(sea_query::extension::postgres::TypeAlterStatement); +build_type_stmt!(sea_query::extension::postgres::TypeCreateStatement); +build_type_stmt!(sea_query::extension::postgres::TypeDropStatement); diff --git a/src/entity/active_enum.rs b/src/entity/active_enum.rs new file mode 100644 index 000000000..292389412 --- /dev/null +++ b/src/entity/active_enum.rs @@ -0,0 +1,308 @@ +use crate::{ColumnDef, DbErr, Iterable, TryGetable}; +use sea_query::{Nullable, Value, ValueType}; + +/// A Rust representation of enum defined in database. +/// +/// # Implementations +/// +/// You can implement [ActiveEnum] manually by hand or use the derive macro [DeriveActiveEnum](sea_orm_macros::DeriveActiveEnum). +/// +/// # Examples +/// +/// Implementing it manually versus using the derive macro [DeriveActiveEnum](sea_orm_macros::DeriveActiveEnum). +/// +/// > See [DeriveActiveEnum](sea_orm_macros::DeriveActiveEnum) for the full specification of macro attributes. +/// +/// ```rust +/// use sea_orm::entity::prelude::*; +/// +/// // Using the derive macro +/// #[derive(Debug, PartialEq, EnumIter, DeriveActiveEnum)] +/// #[sea_orm( +/// rs_type = "String", +/// db_type = "String(Some(1))", +/// enum_name = "category" +/// )] +/// pub enum DeriveCategory { +/// #[sea_orm(string_value = "B")] +/// Big, +/// #[sea_orm(string_value = "S")] +/// Small, +/// } +/// +/// // Implementing it manually +/// #[derive(Debug, PartialEq, EnumIter)] +/// pub enum Category { +/// Big, +/// Small, +/// } +/// +/// impl ActiveEnum for Category { +/// // The macro attribute `rs_type` is being pasted here +/// type Value = String; +/// +/// // Will be atomically generated by `DeriveActiveEnum` +/// fn name() -> String { +/// "category".to_owned() +/// } +/// +/// // Will be atomically generated by `DeriveActiveEnum` +/// fn to_value(&self) -> Self::Value { +/// match self { +/// Self::Big => "B", +/// Self::Small => "S", +/// } +/// .to_owned() +/// } +/// +/// // Will be atomically generated by `DeriveActiveEnum` +/// fn try_from_value(v: &Self::Value) -> Result { +/// match v.as_ref() { +/// "B" => Ok(Self::Big), +/// "S" => Ok(Self::Small), +/// _ => Err(DbErr::Type(format!( +/// "unexpected value for Category enum: {}", +/// v +/// ))), +/// } +/// } +/// +/// fn db_type() -> ColumnDef { +/// // The macro attribute `db_type` is being pasted here +/// ColumnType::String(Some(1)).def() +/// } +/// } +/// ``` +/// +/// Using [ActiveEnum] on Model. +/// +/// ``` +/// use sea_orm::entity::prelude::*; +/// +/// // Define the `Category` active enum +/// #[derive(Debug, Clone, PartialEq, EnumIter, DeriveActiveEnum)] +/// #[sea_orm(rs_type = "String", db_type = "String(Some(1))")] +/// pub enum Category { +/// #[sea_orm(string_value = "B")] +/// Big, +/// #[sea_orm(string_value = "S")] +/// Small, +/// } +/// +/// #[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +/// #[sea_orm(table_name = "active_enum")] +/// pub struct Model { +/// #[sea_orm(primary_key)] +/// pub id: i32, +/// // Represents a db column using `Category` active enum +/// pub category: Category, +/// pub category_opt: Option, +/// } +/// +/// #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +/// pub enum Relation {} +/// +/// impl ActiveModelBehavior for ActiveModel {} +/// ``` +pub trait ActiveEnum: Sized + Iterable { + /// Define the Rust type that each enum variant represents. + type Value: Into + ValueType + Nullable + TryGetable; + + /// Get the name of enum + fn name() -> String; + + /// Convert enum variant into the corresponding value. + fn to_value(&self) -> Self::Value; + + /// Try to convert the corresponding value into enum variant. + fn try_from_value(v: &Self::Value) -> Result; + + /// Get the database column definition of this active enum. + fn db_type() -> ColumnDef; + + /// Convert an owned enum variant into the corresponding value. + fn into_value(self) -> Self::Value { + Self::to_value(&self) + } + + /// Get the name of all enum variants + fn values() -> Vec { + Self::iter().map(Self::into_value).collect() + } +} + +#[cfg(test)] +mod tests { + use crate as sea_orm; + use crate::{entity::prelude::*, *}; + use pretty_assertions::assert_eq; + + #[test] + fn active_enum_string() { + #[derive(Debug, PartialEq, EnumIter)] + pub enum Category { + Big, + Small, + } + + impl ActiveEnum for Category { + type Value = String; + + fn name() -> String { + "category".to_owned() + } + + fn to_value(&self) -> Self::Value { + match self { + Self::Big => "B", + Self::Small => "S", + } + .to_owned() + } + + fn try_from_value(v: &Self::Value) -> Result { + match v.as_ref() { + "B" => Ok(Self::Big), + "S" => Ok(Self::Small), + _ => Err(DbErr::Type(format!( + "unexpected value for Category enum: {}", + v + ))), + } + } + + fn db_type() -> ColumnDef { + ColumnType::String(Some(1)).def() + } + } + + #[derive(Debug, PartialEq, EnumIter, DeriveActiveEnum)] + #[sea_orm( + rs_type = "String", + db_type = "String(Some(1))", + enum_name = "category" + )] + pub enum DeriveCategory { + #[sea_orm(string_value = "B")] + Big, + #[sea_orm(string_value = "S")] + Small, + } + + assert_eq!(Category::Big.to_value(), "B".to_owned()); + assert_eq!(Category::Small.to_value(), "S".to_owned()); + assert_eq!(DeriveCategory::Big.to_value(), "B".to_owned()); + assert_eq!(DeriveCategory::Small.to_value(), "S".to_owned()); + + assert_eq!( + Category::try_from_value(&"A".to_owned()).err(), + Some(DbErr::Type( + "unexpected value for Category enum: A".to_owned() + )) + ); + assert_eq!( + Category::try_from_value(&"B".to_owned()).ok(), + Some(Category::Big) + ); + assert_eq!( + Category::try_from_value(&"S".to_owned()).ok(), + Some(Category::Small) + ); + assert_eq!( + DeriveCategory::try_from_value(&"A".to_owned()).err(), + Some(DbErr::Type( + "unexpected value for DeriveCategory enum: A".to_owned() + )) + ); + assert_eq!( + DeriveCategory::try_from_value(&"B".to_owned()).ok(), + Some(DeriveCategory::Big) + ); + assert_eq!( + DeriveCategory::try_from_value(&"S".to_owned()).ok(), + Some(DeriveCategory::Small) + ); + + assert_eq!(Category::db_type(), ColumnType::String(Some(1)).def()); + assert_eq!(DeriveCategory::db_type(), ColumnType::String(Some(1)).def()); + + assert_eq!(Category::name(), DeriveCategory::name()); + assert_eq!(Category::values(), DeriveCategory::values()); + } + + #[test] + fn active_enum_derive_signed_integers() { + macro_rules! test_int { + ($ident: ident, $rs_type: expr, $db_type: expr, $col_def: ident) => { + #[derive(Debug, PartialEq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = $rs_type, db_type = $db_type)] + pub enum $ident { + #[sea_orm(num_value = 1)] + Big, + #[sea_orm(num_value = 0)] + Small, + #[sea_orm(num_value = -10)] + Negative, + } + + assert_eq!($ident::Big.to_value(), 1); + assert_eq!($ident::Small.to_value(), 0); + assert_eq!($ident::Negative.to_value(), -10); + + assert_eq!($ident::try_from_value(&1).ok(), Some($ident::Big)); + assert_eq!($ident::try_from_value(&0).ok(), Some($ident::Small)); + assert_eq!($ident::try_from_value(&-10).ok(), Some($ident::Negative)); + assert_eq!( + $ident::try_from_value(&2).err(), + Some(DbErr::Type(format!( + "unexpected value for {} enum: 2", + stringify!($ident) + ))) + ); + + assert_eq!($ident::db_type(), ColumnType::$col_def.def()); + }; + } + + test_int!(I8, "i8", "TinyInteger", TinyInteger); + test_int!(I16, "i16", "SmallInteger", SmallInteger); + test_int!(I32, "i32", "Integer", Integer); + test_int!(I64, "i64", "BigInteger", BigInteger); + } + + #[test] + fn active_enum_derive_unsigned_integers() { + macro_rules! test_uint { + ($ident: ident, $rs_type: expr, $db_type: expr, $col_def: ident) => { + #[derive(Debug, PartialEq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = $rs_type, db_type = $db_type)] + pub enum $ident { + #[sea_orm(num_value = 1)] + Big, + #[sea_orm(num_value = 0)] + Small, + } + + assert_eq!($ident::Big.to_value(), 1); + assert_eq!($ident::Small.to_value(), 0); + + assert_eq!($ident::try_from_value(&1).ok(), Some($ident::Big)); + assert_eq!($ident::try_from_value(&0).ok(), Some($ident::Small)); + assert_eq!( + $ident::try_from_value(&2).err(), + Some(DbErr::Type(format!( + "unexpected value for {} enum: 2", + stringify!($ident) + ))) + ); + + assert_eq!($ident::db_type(), ColumnType::$col_def.def()); + }; + } + + test_uint!(U8, "u8", "TinyInteger", TinyInteger); + test_uint!(U16, "u16", "SmallInteger", SmallInteger); + test_uint!(U32, "u32", "Integer", Integer); + test_uint!(U64, "u64", "BigInteger", BigInteger); + } +} diff --git a/src/entity/column.rs b/src/entity/column.rs index bb92d46df..d5234b312 100644 --- a/src/entity/column.rs +++ b/src/entity/column.rs @@ -1,5 +1,5 @@ use crate::{EntityName, IdenStatic, Iterable}; -use sea_query::{DynIden, Expr, SeaRc, SelectStatement, SimpleExpr, Value}; +use sea_query::{Alias, BinOper, DynIden, Expr, SeaRc, SelectStatement, SimpleExpr, Value}; use std::str::FromStr; /// Defines a Column for an Entity @@ -62,6 +62,8 @@ pub enum ColumnType { Custom(String), /// A Universally Unique IDentifier that is specified in RFC 4122 Uuid, + /// `ENUM` data type with name and variants + Enum(String, Vec), } macro_rules! bind_oper { @@ -76,6 +78,25 @@ macro_rules! bind_oper { }; } +macro_rules! bind_oper_with_enum_casting { + ( $op: ident, $bin_op: ident ) => { + #[allow(missing_docs)] + fn $op(&self, v: V) -> SimpleExpr + where + V: Into, + { + let val = Expr::val(v); + let col_def = self.def(); + let col_type = col_def.get_column_type(); + let expr = match col_type.get_enum_name() { + Some(enum_name) => val.as_enum(Alias::new(enum_name)), + None => val.into(), + }; + Expr::tbl(self.entity_name(), *self).binary(BinOper::$bin_op, expr) + } + }; +} + macro_rules! bind_func_no_params { ( $func: ident ) => { /// See also SeaQuery's method with same name. @@ -128,8 +149,8 @@ pub trait ColumnTrait: IdenStatic + Iterable + FromStr { (self.entity_name(), SeaRc::new(*self) as DynIden) } - bind_oper!(eq); - bind_oper!(ne); + bind_oper_with_enum_casting!(eq, Equal); + bind_oper_with_enum_casting!(ne, NotEqual); bind_oper!(gt); bind_oper!(gte); bind_oper!(lt); @@ -281,6 +302,13 @@ impl ColumnType { indexed: false, } } + + pub(crate) fn get_enum_name(&self) -> Option<&String> { + match self { + ColumnType::Enum(s, _) => Some(s), + _ => None, + } + } } impl ColumnDef { @@ -306,6 +334,11 @@ impl ColumnDef { self.indexed = true; self } + + /// Get [ColumnType] as reference + pub fn get_column_type(&self) -> &ColumnType { + &self.col_type + } } impl From for sea_query::ColumnType { @@ -331,7 +364,7 @@ impl From for sea_query::ColumnType { ColumnType::Money(s) => sea_query::ColumnType::Money(s), ColumnType::Json => sea_query::ColumnType::Json, ColumnType::JsonBinary => sea_query::ColumnType::JsonBinary, - ColumnType::Custom(s) => { + ColumnType::Custom(s) | ColumnType::Enum(s, _) => { sea_query::ColumnType::Custom(sea_query::SeaRc::new(sea_query::Alias::new(&s))) } ColumnType::Uuid => sea_query::ColumnType::Uuid, diff --git a/src/entity/mod.rs b/src/entity/mod.rs index 55f58902b..6b413bdf1 100644 --- a/src/entity/mod.rs +++ b/src/entity/mod.rs @@ -95,6 +95,7 @@ /// // to create an ActiveModel using the [ActiveModelBehavior] /// impl ActiveModelBehavior for ActiveModel {} /// ``` +mod active_enum; mod active_model; mod base_entity; mod column; @@ -106,6 +107,7 @@ pub mod prelude; mod primary_key; mod relation; +pub use active_enum::*; pub use active_model::*; pub use base_entity::*; pub use column::*; diff --git a/src/entity/prelude.rs b/src/entity/prelude.rs index 67284b683..211cf853d 100644 --- a/src/entity/prelude.rs +++ b/src/entity/prelude.rs @@ -1,14 +1,15 @@ pub use crate::{ - error::*, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, ColumnType, - DatabaseConnection, DbConn, EntityName, EntityTrait, EnumIter, ForeignKeyAction, Iden, - IdenStatic, Linked, ModelTrait, PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, QueryResult, - Related, RelationDef, RelationTrait, Select, Value, + error::*, ActiveEnum, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, + ColumnType, DatabaseConnection, DbConn, EntityName, EntityTrait, EnumIter, ForeignKeyAction, + Iden, IdenStatic, Linked, ModelTrait, PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, + QueryResult, Related, RelationDef, RelationTrait, Select, Value, }; #[cfg(feature = "macros")] pub use crate::{ - DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, DeriveCustomColumn, DeriveEntity, - DeriveEntityModel, DeriveIntoActiveModel, DeriveModel, DerivePrimaryKey, DeriveRelation, + DeriveActiveEnum, DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, + DeriveCustomColumn, DeriveEntity, DeriveEntityModel, DeriveIntoActiveModel, DeriveModel, + DerivePrimaryKey, DeriveRelation, }; #[cfg(feature = "with-json")] diff --git a/src/error.rs b/src/error.rs index f8412c056..5289c45f8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,8 @@ pub enum DbErr { RecordNotFound(String), /// A custom error Custom(String), + /// Error occurred while parsing value as target type + Type(String), } impl std::error::Error for DbErr {} @@ -23,6 +25,7 @@ impl std::fmt::Display for DbErr { Self::Query(s) => write!(f, "Query Error: {}", s), Self::RecordNotFound(s) => write!(f, "RecordNotFound Error: {}", s), Self::Custom(s) => write!(f, "Custom Error: {}", s), + Self::Type(s) => write!(f, "Type Error: {}", s), } } } diff --git a/src/lib.rs b/src/lib.rs index 30d7cc0b7..1b517fe8d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -292,9 +292,9 @@ pub use schema::*; #[cfg(feature = "macros")] pub use sea_orm_macros::{ - DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, DeriveCustomColumn, DeriveEntity, - DeriveEntityModel, DeriveIntoActiveModel, DeriveModel, DerivePrimaryKey, DeriveRelation, - FromQueryResult, + DeriveActiveEnum, DeriveActiveModel, DeriveActiveModelBehavior, DeriveColumn, + DeriveCustomColumn, DeriveEntity, DeriveEntityModel, DeriveIntoActiveModel, DeriveModel, + DerivePrimaryKey, DeriveRelation, FromQueryResult, }; pub use sea_query; diff --git a/src/query/insert.rs b/src/query/insert.rs index f862ccba1..7cafc44c6 100644 --- a/src/query/insert.rs +++ b/src/query/insert.rs @@ -1,9 +1,9 @@ use crate::{ - ActiveModelTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, PrimaryKeyTrait, - QueryTrait, + ActiveModelTrait, ColumnTrait, EntityName, EntityTrait, IntoActiveModel, Iterable, + PrimaryKeyTrait, QueryTrait, }; use core::marker::PhantomData; -use sea_query::{InsertStatement, ValueTuple}; +use sea_query::{Alias, Expr, InsertStatement, ValueTuple}; /// Performs INSERT operations on a ActiveModel #[derive(Debug)] @@ -133,11 +133,18 @@ where } if av_has_val { columns.push(col); - values.push(av.into_value().unwrap()); + let val = Expr::val(av.into_value().unwrap()); + let col_def = col.def(); + let col_type = col_def.get_column_type(); + let expr = match col_type.get_enum_name() { + Some(enum_name) => val.as_enum(Alias::new(enum_name)), + None => val.into(), + }; + values.push(expr); } } self.query.columns(columns); - self.query.values_panic(values); + self.query.exprs_panic(values); self } diff --git a/src/query/select.rs b/src/query/select.rs index adfb07db6..3c9829480 100644 --- a/src/query/select.rs +++ b/src/query/select.rs @@ -2,7 +2,7 @@ use crate::{ColumnTrait, EntityTrait, Iterable, QueryFilter, QueryOrder, QuerySe use core::fmt::Debug; use core::marker::PhantomData; pub use sea_query::JoinType; -use sea_query::{DynIden, IntoColumnRef, SeaRc, SelectStatement, SimpleExpr}; +use sea_query::{Alias, DynIden, Expr, IntoColumnRef, SeaRc, SelectStatement, SimpleExpr}; /// Defines a structure to perform select operations #[derive(Clone, Debug)] @@ -114,13 +114,24 @@ where } fn prepare_select(mut self) -> Self { - self.query.columns(self.column_list()); + self.query.exprs(self.column_list()); self } - fn column_list(&self) -> Vec<(DynIden, E::Column)> { + fn column_list(&self) -> Vec { let table = SeaRc::new(E::default()) as DynIden; - E::Column::iter().map(|col| (table.clone(), col)).collect() + let text_type = SeaRc::new(Alias::new("text")) as DynIden; + E::Column::iter() + .map(|col| { + let expr = Expr::tbl(table.clone(), col); + let col_def = col.def(); + let col_type = col_def.get_column_type(); + match col_type.get_enum_name() { + Some(_) => expr.as_enum(text_type.clone()), + None => expr.into(), + } + }) + .collect() } fn prepare_from(mut self) -> Self { diff --git a/src/query/update.rs b/src/query/update.rs index 177d9d9e1..f12d3b950 100644 --- a/src/query/update.rs +++ b/src/query/update.rs @@ -3,7 +3,7 @@ use crate::{ QueryTrait, }; use core::marker::PhantomData; -use sea_query::{IntoIden, SimpleExpr, UpdateStatement}; +use sea_query::{Alias, Expr, IntoIden, SimpleExpr, UpdateStatement}; /// Defines a structure to perform UPDATE query operations on a ActiveModel #[derive(Clone, Debug)] @@ -109,7 +109,14 @@ where } let av = self.model.get(col); if av.is_set() { - self.query.value(col, av.unwrap()); + let val = Expr::val(av.into_value().unwrap()); + let col_def = col.def(); + let col_type = col_def.get_column_type(); + let expr = match col_type.get_enum_name() { + Some(enum_name) => val.as_enum(Alias::new(enum_name)), + None => val.into(), + }; + self.query.value_expr(col, expr); } } self diff --git a/src/schema/entity.rs b/src/schema/entity.rs index 51d5994ef..c24a1385d 100644 --- a/src/schema/entity.rs +++ b/src/schema/entity.rs @@ -1,20 +1,58 @@ use crate::{ - unpack_table_ref, ColumnTrait, EntityTrait, Identity, Iterable, PrimaryKeyToColumn, - PrimaryKeyTrait, RelationTrait, Schema, + unpack_table_ref, ColumnTrait, ColumnType, DbBackend, EntityTrait, Identity, Iterable, + PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema, +}; +use sea_query::{ + extension::postgres::{Type, TypeCreateStatement}, + Alias, ColumnDef, ForeignKeyCreateStatement, Iden, Index, TableCreateStatement, }; -use sea_query::{ColumnDef, ForeignKeyCreateStatement, Iden, Index, TableCreateStatement}; impl Schema { + /// Creates Postgres enums from an Entity. See [TypeCreateStatement] for more details + pub fn create_enum_from_entity(entity: E, db_backend: DbBackend) -> Vec + where + E: EntityTrait, + { + create_enum_from_entity(entity, db_backend) + } + /// Creates a table from an Entity. See [TableCreateStatement] for more details - pub fn create_table_from_entity(entity: E) -> TableCreateStatement + pub fn create_table_from_entity(entity: E, db_backend: DbBackend) -> TableCreateStatement where E: EntityTrait, { - create_table_from_entity(entity) + create_table_from_entity(entity, db_backend) } } -pub(crate) fn create_table_from_entity(entity: E) -> TableCreateStatement +pub(crate) fn create_enum_from_entity(_: E, db_backend: DbBackend) -> Vec +where + E: EntityTrait, +{ + if matches!(db_backend, DbBackend::MySql | DbBackend::Sqlite) { + return Vec::new(); + } + let mut vec = Vec::new(); + for col in E::Column::iter() { + let col_def = col.def(); + let col_type = col_def.get_column_type(); + if !matches!(col_type, ColumnType::Enum(_, _)) { + continue; + } + let (name, values) = match col_type { + ColumnType::Enum(s, v) => (s.as_str(), v), + _ => unreachable!(), + }; + let stmt = Type::create() + .as_enum(Alias::new(name)) + .values(values.iter().map(|val| Alias::new(val.as_str()))) + .to_owned(); + vec.push(stmt); + } + vec +} + +pub(crate) fn create_table_from_entity(entity: E, db_backend: DbBackend) -> TableCreateStatement where E: EntityTrait, { @@ -22,7 +60,17 @@ where for column in E::Column::iter() { let orm_column_def = column.def(); - let types = orm_column_def.col_type.into(); + let types = match orm_column_def.col_type { + ColumnType::Enum(s, variants) => match db_backend { + DbBackend::MySql => { + ColumnType::Custom(format!("ENUM('{}')", variants.join("', '"))) + } + DbBackend::Postgres => ColumnType::Custom(s), + DbBackend::Sqlite => ColumnType::Text, + } + .into(), + _ => orm_column_def.col_type.into(), + }; let mut column_def = ColumnDef::new_with_type(column, types); if !orm_column_def.null { column_def.not_null(); @@ -122,13 +170,14 @@ where #[cfg(test)] mod tests { - use crate::{sea_query::*, tests_cfg::*, Schema}; + use crate::{sea_query::*, tests_cfg::*, DbBackend, Schema}; use pretty_assertions::assert_eq; #[test] fn test_create_table_from_entity() { assert_eq!( - Schema::create_table_from_entity(CakeFillingPrice).to_string(MysqlQueryBuilder), + Schema::create_table_from_entity(CakeFillingPrice, DbBackend::MySql) + .to_string(MysqlQueryBuilder), Table::create() .table(CakeFillingPrice) .col( diff --git a/tests/active_enum_tests.rs b/tests/active_enum_tests.rs new file mode 100644 index 000000000..aaad419b2 --- /dev/null +++ b/tests/active_enum_tests.rs @@ -0,0 +1,92 @@ +pub mod common; + +pub use common::{features::*, setup::*, TestContext}; +use sea_orm::{entity::prelude::*, entity::*, DatabaseConnection}; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn main() -> Result<(), DbErr> { + let ctx = TestContext::new("active_enum_tests").await; + create_tables(&ctx.db).await?; + insert_active_enum(&ctx.db).await?; + ctx.delete().await; + + Ok(()) +} + +pub async fn insert_active_enum(db: &DatabaseConnection) -> Result<(), DbErr> { + use active_enum::*; + + let am = ActiveModel { + category: Set(None), + color: Set(None), + tea: Set(None), + ..Default::default() + } + .insert(db) + .await?; + + let model = Entity::find().one(db).await?.unwrap(); + assert_eq!( + model, + Model { + id: 1, + category: None, + color: None, + tea: None, + } + ); + assert_eq!( + model, + Entity::find() + .filter(Column::Id.is_not_null()) + .filter(Column::Category.is_null()) + .filter(Column::Color.is_null()) + .filter(Column::Tea.is_null()) + .one(db) + .await? + .unwrap() + ); + + let am = ActiveModel { + category: Set(Some(Category::Big)), + color: Set(Some(Color::Black)), + tea: Set(Some(Tea::EverydayTea)), + ..am + } + .save(db) + .await?; + + let model = Entity::find().one(db).await?.unwrap(); + assert_eq!( + model, + Model { + id: 1, + category: Some(Category::Big), + color: Some(Color::Black), + tea: Some(Tea::EverydayTea), + } + ); + assert_eq!( + model, + Entity::find() + .filter(Column::Id.eq(1)) + .filter(Column::Category.eq(Category::Big)) + .filter(Column::Color.eq(Color::Black)) + .filter(Column::Tea.eq(Tea::EverydayTea)) + .one(db) + .await? + .unwrap() + ); + + let res = am.delete(db).await?; + + assert_eq!(res.rows_affected, 1); + assert_eq!(Entity::find().one(db).await?, None); + + Ok(()) +} diff --git a/tests/common/features/active_enum.rs b/tests/common/features/active_enum.rs new file mode 100644 index 000000000..5285c5d94 --- /dev/null +++ b/tests/common/features/active_enum.rs @@ -0,0 +1,43 @@ +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "active_enum")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub category: Option, + pub color: Option, + pub tea: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Debug, Clone, PartialEq, EnumIter, DeriveActiveEnum)] +#[sea_orm(rs_type = "String", db_type = "String(Some(1))")] +pub enum Category { + #[sea_orm(string_value = "B")] + Big, + #[sea_orm(string_value = "S")] + Small, +} + +#[derive(Debug, Clone, PartialEq, EnumIter, DeriveActiveEnum)] +#[sea_orm(rs_type = "i32", db_type = r#"Integer"#)] +pub enum Color { + #[sea_orm(num_value = 0)] + Black, + #[sea_orm(num_value = 1)] + White, +} + +#[derive(Debug, Clone, PartialEq, EnumIter, DeriveActiveEnum)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "tea")] +pub enum Tea { + #[sea_orm(string_value = "EverydayTea")] + EverydayTea, + #[sea_orm(string_value = "BreakfastTea")] + BreakfastTea, +} diff --git a/tests/common/features/mod.rs b/tests/common/features/mod.rs index 186256ae3..e354292bb 100644 --- a/tests/common/features/mod.rs +++ b/tests/common/features/mod.rs @@ -1,3 +1,4 @@ +pub mod active_enum; pub mod applog; pub mod byte_primary_key; pub mod metadata; @@ -5,6 +6,7 @@ pub mod repository; pub mod schema; pub mod self_join; +pub use active_enum::Entity as ActiveEnum; pub use applog::Entity as Applog; pub use byte_primary_key::Entity as BytePrimaryKey; pub use metadata::Entity as Metadata; diff --git a/tests/common/features/schema.rs b/tests/common/features/schema.rs index fe1c7ea3c..942776047 100644 --- a/tests/common/features/schema.rs +++ b/tests/common/features/schema.rs @@ -1,11 +1,11 @@ pub use super::super::bakery_chain::*; use super::*; -use crate::common::setup::{create_table, create_table_without_asserts}; +use crate::common::setup::{create_enum, create_table, create_table_without_asserts}; use sea_orm::{ error::*, sea_query, ConnectionTrait, DatabaseConnection, DbBackend, DbConn, ExecResult, }; -use sea_query::{ColumnDef, ForeignKeyCreateStatement}; +use sea_query::{extension::postgres::Type, Alias, ColumnDef, ForeignKeyCreateStatement}; pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> { create_log_table(db).await?; @@ -13,6 +13,7 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> { create_repository_table(db).await?; create_self_join_table(db).await?; create_byte_primary_key_table(db).await?; + create_active_enum_table(db).await?; Ok(()) } @@ -123,3 +124,40 @@ pub async fn create_byte_primary_key_table(db: &DbConn) -> Result Result { + let db_backend = db.get_database_backend(); + let tea_enum = Alias::new("tea"); + + let create_enum_stmts = match db_backend { + DbBackend::MySql | DbBackend::Sqlite => Vec::new(), + DbBackend::Postgres => vec![Type::create() + .as_enum(tea_enum.clone()) + .values(vec![Alias::new("EverydayTea"), Alias::new("BreakfastTea")]) + .to_owned()], + }; + + create_enum(db, &create_enum_stmts, ActiveEnum).await?; + + let mut tea_col = ColumnDef::new(active_enum::Column::Tea); + match db_backend { + DbBackend::MySql => tea_col.custom(Alias::new("ENUM('EverydayTea', 'BreakfastTea')")), + DbBackend::Sqlite => tea_col.text(), + DbBackend::Postgres => tea_col.custom(tea_enum), + }; + let create_table_stmt = sea_query::Table::create() + .table(active_enum::Entity) + .col( + ColumnDef::new(active_enum::Column::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col(ColumnDef::new(active_enum::Column::Category).string_len(1)) + .col(ColumnDef::new(active_enum::Column::Color).integer()) + .col(&mut tea_col) + .to_owned(); + + create_table(db, &create_table_stmt, ActiveEnum).await +} diff --git a/tests/common/setup/mod.rs b/tests/common/setup/mod.rs index dfa1e29c5..fae61984b 100644 --- a/tests/common/setup/mod.rs +++ b/tests/common/setup/mod.rs @@ -1,9 +1,12 @@ +use pretty_assertions::assert_eq; use sea_orm::{ - ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, DbBackend, DbConn, DbErr, - EntityTrait, ExecResult, Schema, Statement, + ColumnTrait, ColumnType, ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, + DbBackend, DbConn, DbErr, EntityTrait, ExecResult, Iterable, Schema, Statement, +}; +use sea_query::{ + extension::postgres::{Type, TypeCreateStatement}, + Alias, Table, TableCreateStatement, }; - -use sea_query::{Alias, Table, TableCreateStatement}; pub async fn setup(base_url: &str, db_name: &str) -> DatabaseConnection { let db = if cfg!(feature = "sqlx-mysql") { @@ -74,6 +77,51 @@ pub async fn tear_down(base_url: &str, db_name: &str) { }; } +pub async fn create_enum( + db: &DbConn, + creates: &[TypeCreateStatement], + entity: E, +) -> Result<(), DbErr> +where + E: EntityTrait, +{ + let builder = db.get_database_backend(); + if builder == DbBackend::Postgres { + for col in E::Column::iter() { + let col_def = col.def(); + let col_type = col_def.get_column_type(); + if !matches!(col_type, ColumnType::Enum(_, _)) { + continue; + } + let name = match col_type { + ColumnType::Enum(s, _) => s.as_str(), + _ => unreachable!(), + }; + let drop_type_stmt = Type::drop() + .name(Alias::new(name)) + .if_exists() + .cascade() + .to_owned(); + let stmt = builder.build(&drop_type_stmt); + db.execute(stmt).await?; + } + } + + let expect_stmts: Vec = creates.iter().map(|stmt| builder.build(stmt)).collect(); + let create_from_entity_stmts: Vec = Schema::create_enum_from_entity(entity, builder) + .iter() + .map(|stmt| builder.build(stmt)) + .collect(); + + assert_eq!(expect_stmts, create_from_entity_stmts); + + for stmt in expect_stmts { + db.execute(stmt).await.map(|_| ())?; + } + + Ok(()) +} + pub async fn create_table( db: &DbConn, create: &TableCreateStatement, @@ -84,7 +132,7 @@ where { let builder = db.get_database_backend(); assert_eq!( - builder.build(&Schema::create_table_from_entity(entity)), + builder.build(&Schema::create_table_from_entity(entity, builder)), builder.build(create) );