From 5bb5a6ae018f62da7b4eda34d46fbc3162a192be Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Wed, 29 Jan 2020 19:58:14 +0100 Subject: [PATCH 01/12] add derive macros for weak & strong enums and structs --- sqlx-core/src/mysql/mod.rs | 2 + sqlx-core/src/mysql/protocol/type.rs | 3 + sqlx-core/src/mysql/types/mod.rs | 9 +- sqlx-core/src/postgres/types/mod.rs | 4 + sqlx-macros/src/derives.rs | 879 +++++++++++++++++++++++++-- sqlx-macros/src/lib.rs | 26 +- src/lib.rs | 7 +- src/types.rs | 6 + tests/derives.rs | 285 ++++++++- 9 files changed, 1138 insertions(+), 83 deletions(-) create mode 100644 src/types.rs diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index 1f037b3fc4..91e55f1848 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -22,6 +22,8 @@ pub use error::MySqlError; pub use types::MySqlTypeInfo; +pub use protocol::TypeId; + pub use row::MySqlRow; /// An alias for [`Pool`], specialized for **MySQL**. diff --git a/sqlx-core/src/mysql/protocol/type.rs b/sqlx-core/src/mysql/protocol/type.rs index 6284401b08..74841c811a 100644 --- a/sqlx-core/src/mysql/protocol/type.rs +++ b/sqlx-core/src/mysql/protocol/type.rs @@ -13,6 +13,9 @@ impl TypeId { pub const VAR_CHAR: TypeId = TypeId(253); // or VAR_BINARY pub const TEXT: TypeId = TypeId(252); // or BLOB + // Enum + pub const ENUM: TypeId = TypeId(247); + // More Bytes pub const TINY_BLOB: TypeId = TypeId(249); pub const MEDIUM_BLOB: TypeId = TypeId(250); diff --git a/sqlx-core/src/mysql/types/mod.rs b/sqlx-core/src/mysql/types/mod.rs index 39e8252b1d..cc31e69518 100644 --- a/sqlx-core/src/mysql/types/mod.rs +++ b/sqlx-core/src/mysql/types/mod.rs @@ -49,6 +49,11 @@ impl MySqlTypeInfo { char_set: def.char_set, } } + + #[doc(hidden)] + pub fn r#enum() -> Self { + Self::new(TypeId::ENUM) + } } impl Display for MySqlTypeInfo { @@ -67,6 +72,7 @@ impl TypeInfo for MySqlTypeInfo { | TypeId::TINY_BLOB | TypeId::MEDIUM_BLOB | TypeId::LONG_BLOB + | TypeId::ENUM if (self.is_binary == other.is_binary) && match other.id { TypeId::VAR_CHAR @@ -74,7 +80,8 @@ impl TypeInfo for MySqlTypeInfo { | TypeId::CHAR | TypeId::TINY_BLOB | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB => true, + | TypeId::LONG_BLOB + | TypeId::ENUM => true, _ => false, } => diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 28404d8b26..a81a47ec16 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -32,6 +32,10 @@ impl PgTypeInfo { pub fn with_oid(oid: u32) -> Self { Self { id: TypeId(oid) } } + + pub fn oid(&self) -> u32 { + self.id.0 + } } impl Display for PgTypeInfo { diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index 460b05c8b5..fdba07428b 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -1,42 +1,288 @@ +use proc_macro2::Ident; use quote::quote; -use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed}; +use std::iter::FromIterator; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Arm, Attribute, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, + Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Variant,Stmt, +}; + +macro_rules! assert_attribute { + ($e:expr, $err:expr, $input:expr) => { + if !$e { + return Err(syn::Error::new_spanned($input, $err)); + } + }; +} + +struct SqlxAttributes { + transparent: bool, + postgres_oid: Option, + repr: Option, + rename: Option, +} + +fn parse_attributes(input: &[Attribute]) -> syn::Result { + let mut transparent = None; + let mut postgres_oid = None; + let mut repr = None; + let mut rename = None; + + macro_rules! fail { + ($t:expr, $m:expr) => { + return Err(syn::Error::new_spanned($t, $m)); + }; + } + + macro_rules! try_set { + ($i:ident, $v:expr, $t:expr) => { + match $i { + None => $i = Some($v), + Some(_) => fail!($t, "duplicate attribute"), + } + }; + } + + for attr in input { + let meta = attr + .parse_meta() + .map_err(|e| syn::Error::new_spanned(attr, e))?; + match meta { + Meta::List(list) if list.path.is_ident("sqlx") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(meta) => match meta { + Meta::Path(p) if p.is_ident("transparent") => { + try_set!(transparent, true, value) + } + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + Meta::List(list) if list.path.is_ident("postgres") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Int(val), + .. + })) if path.is_ident("oid") => { + try_set!(postgres_oid, val.base10_parse()?, value); + } + u => fail!(u, "unexpected value"), + } + } + } + + u => fail!(u, "unexpected attribute"), + }, + u => fail!(u, "unexpected attribute"), + } + } + } + Meta::List(list) if list.path.is_ident("repr") => { + if list.nested.len() != 1 { + fail!(&list.nested, "expected one value") + } + match list.nested.first().unwrap() { + NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { + try_set!(repr, p.get_ident().unwrap().clone(), list); + } + u => fail!(u, "unexpected value"), + } + } + _ => {} + } + } + + Ok(SqlxAttributes { + transparent: transparent.unwrap_or(false), + postgres_oid, + repr, + rename, + }) +} + +fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + attributes.transparent, + "expected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + Ok(()) +} + +fn check_enum_attributes<'a>( + input: &'a DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + variant + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + variant + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); + } + + Ok(attributes) +} + +fn check_weak_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + variant + ); + } + Ok(attributes.repr.unwrap()) +} + +fn check_strong_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + Ok(attributes) +} + +fn check_struct_attributes<'a>( + input: &'a DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + + for field in fields { + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + } + + Ok(attributes) +} + +pub(crate) fn expand_derive_encode(input: &DeriveInput) -> syn::Result { + let args = parse_attributes(&input.attrs)?; -pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result { match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) if unnamed.len() == 1 => { - let ident = &input.ident; - let ty = &unnamed.first().unwrap().ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::encode::Encode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - Ok(quote!( - impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut std::vec::Vec) { - sqlx::encode::Encode::encode(&self.0, buf) - } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_nullable(&self.0, buf) - } - fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&self.0) - } - } - )) + expand_derive_encode_transparent(&input, unnamed.first().unwrap()) } + Data::Enum(DataEnum { variants, .. }) => match args.repr { + Some(_) => expand_derive_encode_weak_enum(input, variants), + None => expand_derive_encode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_encode_struct(input, named), _ => Err(syn::Error::new_spanned( input, "expected a tuple struct with a single field", @@ -44,45 +290,558 @@ pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result syn::Result { +fn expand_derive_encode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::encode::Encode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&self.0, buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&self.0, buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&self.0) + } + } + )) +} + +fn expand_derive_encode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + Ok(quote!( + impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&(*self as #repr), buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&(*self as #repr)) + } + } + )) +} + +fn expand_derive_encode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#ident :: #id => #rename,)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#ident :: #id => #name,)); + } + } + + tts.extend(quote!( + impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + let val = match self { + #(#value_arms)* + }; + >::encode(val, buf) + } + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + >::size_hint(val) + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_encode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, &fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut writes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + writes.push(parse_quote!({ + // write oid + let info = >::type_info(); + buf.extend(&info.oid().to_be_bytes()); + + // write zeros for length + buf.extend(&[0; 4]); + + let start = buf.len(); + sqlx::encode::Encode::::encode(&self. #id, buf); + let end = buf.len(); + let size = end - start; + + // replaces zeros with actual length + buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); + })); + } + + let mut sizes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + sizes.push( + parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), + ); + } + + tts.extend(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + buf.extend(&(#column_count as u32).to_be_bytes()); + #(#writes)* + } + fn size_hint(&self) -> usize { + 4 + #column_count * (4 + 4) + #(#sizes)+* + } + } + )); + } + + Ok(tts) +} + +pub(crate) fn expand_derive_decode(input: &DeriveInput) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) if unnamed.len() == 1 => { - let ident = &input.ident; - let ty = &unnamed.first().unwrap().ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - Ok(quote!( - impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { - fn decode(raw: &[u8]) -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode(raw).map(Self) - } - fn decode_null() -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode_null().map(Self) - } - fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) - } + expand_derive_decode_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_decode_weak_enum(input, variants), + None => expand_derive_decode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_decode_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_decode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { + fn decode(raw: &[u8]) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode(raw).map(Self) + } + fn decode_null() -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_null().map(Self) + } + fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) + } + } + )) +} + +fn expand_derive_decode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + let arms = variants + .iter() + .map(|v| { + let id = &v.ident; + parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) + }) + .collect::>(); + + Ok(quote!( + impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { + fn decode(raw: &[u8]) -> std::result::Result { + let val = <#repr as sqlx::decode::Decode>::decode(raw)?; + match val { + #(#arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) } - )) + } } + )) +} + +fn expand_derive_decode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#rename => Ok(#ident :: #id),)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#name => Ok(#ident :: #id),)); + } + } + + // TODO: prevent heap allocation + Ok(quote!( + impl sqlx::decode::Decode for #ident where String: sqlx::decode::Decode { + fn decode(buf: &[u8]) -> std::result::Result { + let val = >::decode(buf)?; + match val.as_str() { + #(#value_arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, fields)?; + + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut reads: Vec> = Vec::new(); + let mut names: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + names.push(id.clone().unwrap()); + let ty = &field.ty; + reads.push(parse_quote!( + if buf.len() < 8 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); + if oid != >::type_info().oid() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); + } + + let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; + + if buf.len() < 8 + len { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let raw = &buf[8..8+len]; + let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; + + let buf = &buf[8+len..]; + )); + } + let reads = reads.into_iter().flatten(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { + fn decode(buf: &[u8]) -> std::result::Result { + if buf.len() < 4 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; + if column_count != #column_count { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); + } + let buf = &buf[4..]; + + #(#reads)* + + if !buf.is_empty() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); + } + + Ok(#ident { + #(#names),* + }) + } + } + )) +} + +pub(crate) fn expand_derive_has_sql_type( + input: &DeriveInput, +) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), + None => expand_derive_has_sql_type_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_has_sql_type_struct(input, named), _ => Err(syn::Error::new_spanned( input, "expected a tuple struct with a single field", )), } } + +fn expand_derive_has_sql_type_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + // add db type for clause + let mut generics = generics.clone(); + generics + .make_where_clause() + .predicates + .push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>)); + let (_, _, where_clause) = generics.split_for_impl(); + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_strong_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { + fn type_info() -> Self::TypeInfo { + sqlx::mysql::MySqlTypeInfo::r#enum() + } + } + )); + } + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = check_struct_attributes(input, fields)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} + +pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result { + let encode_tts = expand_derive_encode(input)?; + let decode_tts = expand_derive_decode(input)?; + let has_sql_type_tts = expand_derive_has_sql_type(input)?; + + let combined = proc_macro2::TokenStream::from_iter( + encode_tts + .into_iter() + .chain(decode_tts) + .chain(has_sql_type_tts), + ); + Ok(combined) +} diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index ede05206b4..57c4770710 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -137,19 +137,37 @@ pub fn query_file_as(input: TokenStream) -> TokenStream { async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db)) } -#[proc_macro_derive(Encode)] +#[proc_macro_derive(Encode, attributes(sqlx))] pub fn derive_encode(tokenstream: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); - match derives::expand_derive_encode(input) { + match derives::expand_derive_encode(&input) { Ok(ts) => ts.into(), Err(e) => e.to_compile_error().into(), } } -#[proc_macro_derive(Decode)] +#[proc_macro_derive(Decode, attributes(sqlx))] pub fn derive_decode(tokenstream: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); - match derives::expand_derive_decode(input) { + match derives::expand_derive_decode(&input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_derive(HasSqlType, attributes(sqlx))] +pub fn derive_has_sql_type(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_has_sql_type(&input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_derive(Type, attributes(sqlx))] +pub fn derive_type(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_type(&input) { Ok(ts) => ts.into(), Err(e) => e.to_compile_error().into(), } diff --git a/src/lib.rs b/src/lib.rs index f9ccab40fb..bf4efc0946 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled"); // Modules -pub use sqlx_core::{arguments, describe, error, pool, row, types}; +pub use sqlx_core::{arguments, describe, error, pool, row}; // Types pub use sqlx_core::{ @@ -48,3 +48,8 @@ pub mod result_ext; pub mod encode; pub mod decode; + +pub mod types; + +#[cfg(feature = "macros")] +pub use sqlx_macros::Type; diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 0000000000..6eb3cab6c0 --- /dev/null +++ b/src/types.rs @@ -0,0 +1,6 @@ +//! Traits linking Rust types to SQL types. + +pub use sqlx_core::types::*; + +#[cfg(feature = "macros")] +pub use sqlx_macros::HasSqlType; diff --git a/tests/derives.rs b/tests/derives.rs index 9c7aa411b4..6ec2cc902e 100644 --- a/tests/derives.rs +++ b/tests/derives.rs @@ -1,28 +1,59 @@ use sqlx::decode::Decode; use sqlx::encode::Encode; +use sqlx::types::{HasSqlType, TypeInfo}; +use std::fmt::Debug; -#[derive(PartialEq, Debug, Encode, Decode)] -struct Foo(i32); +#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +#[sqlx(transparent)] +struct Transparent(i32); + +#[derive(PartialEq, Debug, Clone, Copy, Encode, Decode, HasSqlType)] +#[repr(i32)] +#[allow(dead_code)] +enum Weak { + One, + Two, + Three, +} + +#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +#[sqlx(postgres(oid = 10101010))] +#[allow(dead_code)] +enum Strong { + One, + Two, + #[sqlx(rename = "four")] + Three, +} + +#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)] +#[sqlx(postgres(oid = 20202020))] +#[allow(dead_code)] +struct Struct { + field1: String, + field2: i64, + field3: bool, +} #[test] #[cfg(feature = "mysql")] -fn encode_mysql() { - encode_with_db::(); +fn encode_transparent_mysql() { + encode_transparent::(); } #[test] #[cfg(feature = "postgres")] -fn encode_postgres() { - encode_with_db::(); +fn encode_transparent_postgres() { + encode_transparent::(); } #[allow(dead_code)] -fn encode_with_db() +fn encode_transparent() where - Foo: Encode, + Transparent: Encode, i32: Encode, { - let example = Foo(0x1122_3344); + let example = Transparent(0x1122_3344); let mut encoded = Vec::new(); let mut encoded_orig = Vec::new(); @@ -35,26 +66,246 @@ where #[test] #[cfg(feature = "mysql")] -fn decode_mysql() { - decode_with_db::(); +fn encode_weak_enum_mysql() { + encode_weak_enum::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn encode_weak_enum_postgres() { + encode_weak_enum::(); +} + +#[allow(dead_code)] +fn encode_weak_enum() +where + Weak: Encode, + i32: Encode, +{ + for example in [Weak::One, Weak::Two, Weak::Three].iter() { + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(example, &mut encoded); + Encode::::encode(&(*example as i32), &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); + } +} + +#[test] +#[cfg(feature = "mysql")] +fn encode_strong_enum_mysql() { + encode_strong_enum::(); } #[test] #[cfg(feature = "postgres")] -fn decode_postgres() { - decode_with_db::(); +fn encode_strong_enum_postgres() { + encode_strong_enum::(); } #[allow(dead_code)] -fn decode_with_db() +fn encode_strong_enum() where - Foo: Decode + Encode, + Strong: Encode, + str: Encode, { - let example = Foo(0x1122_3344); + for (example, name) in [ + (Strong::One, "One"), + (Strong::Two, "Two"), + (Strong::Three, "four"), + ] + .iter() + { + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(example, &mut encoded); + Encode::::encode(name, &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); + } +} + +#[test] +#[cfg(feature = "postgres")] +fn encode_struct_postgres() { + let field1 = "Foo".to_string(); + let field2 = 3; + let field3 = false; + + let example = Struct { + field1: field1.clone(), + field2, + field3, + }; + + let mut encoded = Vec::new(); + Encode::::encode(&example, &mut encoded); + + let string_oid = >::type_info().oid(); + let i64_oid = >::type_info().oid(); + let bool_oid = >::type_info().oid(); + + // 3 columns + assert_eq!(&[0, 0, 0, 3], &encoded[..4]); + let encoded = &encoded[4..]; + + // check field1 (string) + assert_eq!(&string_oid.to_be_bytes(), &encoded[0..4]); + assert_eq!(&(field1.len() as u32).to_be_bytes(), &encoded[4..8]); + assert_eq!(field1.as_bytes(), &encoded[8..8 + field1.len()]); + let encoded = &encoded[8 + field1.len()..]; + + // check field2 (i64) + assert_eq!(&i64_oid.to_be_bytes(), &encoded[0..4]); + assert_eq!(&8u32.to_be_bytes(), &encoded[4..8]); + assert_eq!(field2.to_be_bytes(), &encoded[8..16]); + let encoded = &encoded[16..]; + + // check field3 (bool) + assert_eq!(&bool_oid.to_be_bytes(), &encoded[0..4]); + assert_eq!(&1u32.to_be_bytes(), &encoded[4..8]); + assert_eq!(field3, encoded[8] != 0); + let encoded = &encoded[9..]; + + assert!(encoded.is_empty()); + + let string_size = >::size_hint(&field1); + let i64_size = >::size_hint(&field2); + let bool_size = >::size_hint(&field3); + + assert_eq!( + 4 + 3 * (4 + 4) + string_size + i64_size + bool_size, + example.size_hint() + ); +} + +#[test] +#[cfg(feature = "mysql")] +fn decode_transparent_mysql() { + decode_with_db::(Transparent(0x1122_3344)); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_transparent_postgres() { + decode_with_db::(Transparent(0x1122_3344)); +} + +#[test] +#[cfg(feature = "mysql")] +fn decode_weak_enum_mysql() { + decode_with_db::(Weak::One); + decode_with_db::(Weak::Two); + decode_with_db::(Weak::Three); +} +#[test] +#[cfg(feature = "postgres")] +fn decode_weak_enum_postgres() { + decode_with_db::(Weak::One); + decode_with_db::(Weak::Two); + decode_with_db::(Weak::Three); +} + +#[test] +#[cfg(feature = "mysql")] +fn decode_strong_enum_mysql() { + decode_with_db::(Strong::One); + decode_with_db::(Strong::Two); + decode_with_db::(Strong::Three); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_strong_enum_postgres() { + decode_with_db::(Strong::One); + decode_with_db::(Strong::Two); + decode_with_db::(Strong::Three); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_struct_postgres() { + decode_with_db::(Struct { + field1: "Foo".to_string(), + field2: 3, + field3: true, + }); +} + +#[allow(dead_code)] +fn decode_with_db + Encode + PartialEq + Debug>(example: V) { let mut encoded = Vec::new(); Encode::::encode(&example, &mut encoded); - let decoded = Foo::decode(&encoded).unwrap(); + let decoded = V::decode(&encoded).unwrap(); assert_eq!(example, decoded); } + +#[test] +#[cfg(feature = "mysql")] +fn has_sql_type_transparent_mysql() { + has_sql_type_transparent::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_transparent_postgres() { + has_sql_type_transparent::(); +} + +#[allow(dead_code)] +fn has_sql_type_transparent() +where + DB: HasSqlType + HasSqlType, +{ + let info: DB::TypeInfo = >::type_info(); + let info_orig: DB::TypeInfo = >::type_info(); + assert!(info.compatible(&info_orig)); +} + +#[test] +#[cfg(feature = "mysql")] +fn has_sql_type_weak_enum_mysql() { + has_sql_type_weak_enum::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_weak_enum_postgres() { + has_sql_type_weak_enum::(); +} + +#[allow(dead_code)] +fn has_sql_type_weak_enum() +where + DB: HasSqlType + HasSqlType, +{ + let info: DB::TypeInfo = >::type_info(); + let info_orig: DB::TypeInfo = >::type_info(); + assert!(info.compatible(&info_orig)); +} + +#[test] +#[cfg(feature = "mysql")] +fn has_sql_type_strong_enum_mysql() { + let info: sqlx::mysql::MySqlTypeInfo = >::type_info(); + assert!(info.compatible(&sqlx::mysql::MySqlTypeInfo::r#enum())) +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_strong_enum_postgres() { + let info: sqlx::postgres::PgTypeInfo = >::type_info(); + assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(10101010))) +} + +#[test] +#[cfg(feature = "postgres")] +fn has_sql_type_struct_postgres() { + let info: sqlx::postgres::PgTypeInfo = >::type_info(); + assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(20202020))) +} From 6cf904dcbac8bf57b6d4c3cbdaf55692ee563c73 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Wed, 29 Jan 2020 19:59:09 +0100 Subject: [PATCH 02/12] format --- sqlx-macros/src/derives.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index fdba07428b..5716eb8a9e 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -5,7 +5,7 @@ use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ parse_quote, Arm, Attribute, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, - Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Variant,Stmt, + Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Stmt, Variant, }; macro_rules! assert_attribute { From 18f4d47fe2fc165842d9918d0045270a06ad17ad Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Wed, 29 Jan 2020 21:05:22 +0100 Subject: [PATCH 03/12] fix db type --- tests/derives.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/derives.rs b/tests/derives.rs index 6ec2cc902e..4c22324ef1 100644 --- a/tests/derives.rs +++ b/tests/derives.rs @@ -121,8 +121,8 @@ where let mut encoded = Vec::new(); let mut encoded_orig = Vec::new(); - Encode::::encode(example, &mut encoded); - Encode::::encode(name, &mut encoded_orig); + Encode::::encode(example, &mut encoded); + Encode::::encode(*name, &mut encoded_orig); assert_eq!(encoded, encoded_orig); } From 55ea866dca36366f3c68af4d69a05dcf4bf60e5f Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Wed, 29 Jan 2020 21:05:52 +0100 Subject: [PATCH 04/12] move feature guard from strong_enum to struct --- sqlx-macros/src/derives.rs | 111 ++++++++++++++++++------------------- 1 file changed, 55 insertions(+), 56 deletions(-) diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index 5716eb8a9e..8ddc6203b4 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -358,40 +358,34 @@ fn expand_derive_encode_strong_enum( let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - let mut value_arms = Vec::new(); - for v in variants { - let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; - if let Some(rename) = attributes.rename { - value_arms.push(quote!(#ident :: #id => #rename,)); - } else { - let name = id.to_string(); - value_arms.push(quote!(#ident :: #id => #name,)); - } + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#ident :: #id => #rename,)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#ident :: #id => #name,)); } - - tts.extend(quote!( - impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { - let val = match self { - #(#value_arms)* - }; - >::encode(val, buf) - } - fn size_hint(&self) -> usize { - let val = match self { - #(#value_arms)* - }; - >::size_hint(val) - } - } - )); } - Ok(tts) + Ok(quote!( + impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + let val = match self { + #(#value_arms)* + }; + >::encode(val, buf) + } + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + >::size_hint(val) + } + } + )) } fn expand_derive_encode_struct( @@ -579,7 +573,7 @@ fn expand_derive_decode_strong_enum( // TODO: prevent heap allocation Ok(quote!( - impl sqlx::decode::Decode for #ident where String: sqlx::decode::Decode { + impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { fn decode(buf: &[u8]) -> std::result::Result { let val = >::decode(buf)?; match val.as_str() { @@ -597,31 +591,34 @@ fn expand_derive_decode_struct( ) -> syn::Result { check_struct_attributes(input, fields)?; - let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); - let column_count = fields.len(); + if cfg!(feature = "postgres") { + let ident = &input.ident; - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); + let column_count = fields.len(); - // add db type for impl generics & where clause - let mut generics = generics.clone(); - let predicates = &mut generics.make_where_clause().predicates; - for field in fields { - let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); - } - let (impl_generics, _, where_clause) = generics.split_for_impl(); + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); - let mut reads: Vec> = Vec::new(); - let mut names: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - names.push(id.clone().unwrap()); - let ty = &field.ty; - reads.push(parse_quote!( + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut reads: Vec> = Vec::new(); + let mut names: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + names.push(id.clone().unwrap()); + let ty = &field.ty; + reads.push(parse_quote!( if buf.len() < 8 { return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); } @@ -642,10 +639,10 @@ fn expand_derive_decode_struct( let buf = &buf[8+len..]; )); - } - let reads = reads.into_iter().flatten(); + } + let reads = reads.into_iter().flatten(); - Ok(quote!( + tts.extend(quote!( impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { fn decode(buf: &[u8]) -> std::result::Result { if buf.len() < 4 { @@ -670,6 +667,8 @@ fn expand_derive_decode_struct( } } )) + } + Ok(tts) } pub(crate) fn expand_derive_has_sql_type( From 1d2caf76f3f3171f62afa732106efb9f48de2c55 Mon Sep 17 00:00:00 2001 From: freax13 Date: Thu, 30 Jan 2020 12:59:09 +0100 Subject: [PATCH 05/12] split derives into different files --- sqlx-macros/src/derives.rs | 846 ------------------------ sqlx-macros/src/derives/attributes.rs | 261 ++++++++ sqlx-macros/src/derives/decode.rs | 221 +++++++ sqlx-macros/src/derives/encode.rs | 208 ++++++ sqlx-macros/src/derives/has_sql_type.rs | 169 +++++ sqlx-macros/src/derives/mod.rs | 25 + 6 files changed, 884 insertions(+), 846 deletions(-) delete mode 100644 sqlx-macros/src/derives.rs create mode 100644 sqlx-macros/src/derives/attributes.rs create mode 100644 sqlx-macros/src/derives/decode.rs create mode 100644 sqlx-macros/src/derives/encode.rs create mode 100644 sqlx-macros/src/derives/has_sql_type.rs create mode 100644 sqlx-macros/src/derives/mod.rs diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs deleted file mode 100644 index 8ddc6203b4..0000000000 --- a/sqlx-macros/src/derives.rs +++ /dev/null @@ -1,846 +0,0 @@ -use proc_macro2::Ident; -use quote::quote; -use std::iter::FromIterator; -use syn::punctuated::Punctuated; -use syn::token::Comma; -use syn::{ - parse_quote, Arm, Attribute, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, - Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Stmt, Variant, -}; - -macro_rules! assert_attribute { - ($e:expr, $err:expr, $input:expr) => { - if !$e { - return Err(syn::Error::new_spanned($input, $err)); - } - }; -} - -struct SqlxAttributes { - transparent: bool, - postgres_oid: Option, - repr: Option, - rename: Option, -} - -fn parse_attributes(input: &[Attribute]) -> syn::Result { - let mut transparent = None; - let mut postgres_oid = None; - let mut repr = None; - let mut rename = None; - - macro_rules! fail { - ($t:expr, $m:expr) => { - return Err(syn::Error::new_spanned($t, $m)); - }; - } - - macro_rules! try_set { - ($i:ident, $v:expr, $t:expr) => { - match $i { - None => $i = Some($v), - Some(_) => fail!($t, "duplicate attribute"), - } - }; - } - - for attr in input { - let meta = attr - .parse_meta() - .map_err(|e| syn::Error::new_spanned(attr, e))?; - match meta { - Meta::List(list) if list.path.is_ident("sqlx") => { - for value in list.nested.iter() { - match value { - NestedMeta::Meta(meta) => match meta { - Meta::Path(p) if p.is_ident("transparent") => { - try_set!(transparent, true, value) - } - Meta::NameValue(MetaNameValue { - path, - lit: Lit::Str(val), - .. - }) if path.is_ident("rename") => try_set!(rename, val.value(), value), - Meta::List(list) if list.path.is_ident("postgres") => { - for value in list.nested.iter() { - match value { - NestedMeta::Meta(Meta::NameValue(MetaNameValue { - path, - lit: Lit::Int(val), - .. - })) if path.is_ident("oid") => { - try_set!(postgres_oid, val.base10_parse()?, value); - } - u => fail!(u, "unexpected value"), - } - } - } - - u => fail!(u, "unexpected attribute"), - }, - u => fail!(u, "unexpected attribute"), - } - } - } - Meta::List(list) if list.path.is_ident("repr") => { - if list.nested.len() != 1 { - fail!(&list.nested, "expected one value") - } - match list.nested.first().unwrap() { - NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { - try_set!(repr, p.get_ident().unwrap().clone(), list); - } - u => fail!(u, "unexpected value"), - } - } - _ => {} - } - } - - Ok(SqlxAttributes { - transparent: transparent.unwrap_or(false), - postgres_oid, - repr, - rename, - }) -} - -fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { - let attributes = parse_attributes(&input.attrs)?; - assert_attribute!( - attributes.transparent, - "expected #[sqlx(transparent)]", - input - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - field - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - let attributes = parse_attributes(&field.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - field - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - field - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - field - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); - Ok(()) -} - -fn check_enum_attributes<'a>( - input: &'a DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = parse_attributes(&input.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - input - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - input - ); - - for variant in variants { - let attributes = parse_attributes(&variant.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - variant - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - variant - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); - } - - Ok(attributes) -} - -fn check_weak_enum_attributes( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_enum_attributes(input, variants)?; - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); - for variant in variants { - let attributes = parse_attributes(&variant.attrs)?; - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - variant - ); - } - Ok(attributes.repr.unwrap()) -} - -fn check_strong_enum_attributes( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_enum_attributes(input, variants)?; - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_some(), - "expected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - Ok(attributes) -} - -fn check_struct_attributes<'a>( - input: &'a DeriveInput, - fields: &Punctuated, -) -> syn::Result { - let attributes = parse_attributes(&input.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - input - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_some(), - "expected #[sqlx(postgres(oid = ..))]", - input - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - input - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - - for field in fields { - let attributes = parse_attributes(&field.attrs)?; - assert_attribute!( - !attributes.transparent, - "unexpected #[sqlx(transparent)]", - field - ); - #[cfg(feature = "postgres")] - assert_attribute!( - attributes.postgres_oid.is_none(), - "unexpected #[sqlx(postgres(oid = ..))]", - field - ); - assert_attribute!( - attributes.rename.is_none(), - "unexpected #[sqlx(rename = ..)]", - field - ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); - } - - Ok(attributes) -} - -pub(crate) fn expand_derive_encode(input: &DeriveInput) -> syn::Result { - let args = parse_attributes(&input.attrs)?; - - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - expand_derive_encode_transparent(&input, unnamed.first().unwrap()) - } - Data::Enum(DataEnum { variants, .. }) => match args.repr { - Some(_) => expand_derive_encode_weak_enum(input, variants), - None => expand_derive_encode_strong_enum(input, variants), - }, - Data::Struct(DataStruct { - fields: Fields::Named(FieldsNamed { named, .. }), - .. - }) => expand_derive_encode_struct(input, named), - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} - -fn expand_derive_encode_transparent( - input: &DeriveInput, - field: &Field, -) -> syn::Result { - check_transparent_attributes(input, field)?; - - let ident = &input.ident; - let ty = &field.ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::encode::Encode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - Ok(quote!( - impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut std::vec::Vec) { - sqlx::encode::Encode::encode(&self.0, buf) - } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_nullable(&self.0, buf) - } - fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&self.0) - } - } - )) -} - -fn expand_derive_encode_weak_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let repr = check_weak_enum_attributes(input, &variants)?; - - let ident = &input.ident; - - Ok(quote!( - impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { - sqlx::encode::Encode::encode(&(*self as #repr), buf) - } - fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { - sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) - } - fn size_hint(&self) -> usize { - sqlx::encode::Encode::size_hint(&(*self as #repr)) - } - } - )) -} - -fn expand_derive_encode_strong_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - check_strong_enum_attributes(input, &variants)?; - - let ident = &input.ident; - - let mut value_arms = Vec::new(); - for v in variants { - let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; - if let Some(rename) = attributes.rename { - value_arms.push(quote!(#ident :: #id => #rename,)); - } else { - let name = id.to_string(); - value_arms.push(quote!(#ident :: #id => #name,)); - } - } - - Ok(quote!( - impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { - fn encode(&self, buf: &mut std::vec::Vec) { - let val = match self { - #(#value_arms)* - }; - >::encode(val, buf) - } - fn size_hint(&self) -> usize { - let val = match self { - #(#value_arms)* - }; - >::size_hint(val) - } - } - )) -} - -fn expand_derive_encode_struct( - input: &DeriveInput, - fields: &Punctuated, -) -> syn::Result { - check_struct_attributes(input, &fields)?; - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "postgres") { - let ident = &input.ident; - - let column_count = fields.len(); - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - let predicates = &mut generics.make_where_clause().predicates; - for field in fields { - let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); - } - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - let mut writes: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - let ty = &field.ty; - writes.push(parse_quote!({ - // write oid - let info = >::type_info(); - buf.extend(&info.oid().to_be_bytes()); - - // write zeros for length - buf.extend(&[0; 4]); - - let start = buf.len(); - sqlx::encode::Encode::::encode(&self. #id, buf); - let end = buf.len(); - let size = end - start; - - // replaces zeros with actual length - buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); - })); - } - - let mut sizes: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - let ty = &field.ty; - sizes.push( - parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), - ); - } - - tts.extend(quote!( - impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { - fn encode(&self, buf: &mut std::vec::Vec) { - buf.extend(&(#column_count as u32).to_be_bytes()); - #(#writes)* - } - fn size_hint(&self) -> usize { - 4 + #column_count * (4 + 4) + #(#sizes)+* - } - } - )); - } - - Ok(tts) -} - -pub(crate) fn expand_derive_decode(input: &DeriveInput) -> syn::Result { - let attrs = parse_attributes(&input.attrs)?; - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - expand_derive_decode_transparent(input, unnamed.first().unwrap()) - } - Data::Enum(DataEnum { variants, .. }) => match attrs.repr { - Some(_) => expand_derive_decode_weak_enum(input, variants), - None => expand_derive_decode_strong_enum(input, variants), - }, - Data::Struct(DataStruct { - fields: Fields::Named(FieldsNamed { named, .. }), - .. - }) => expand_derive_decode_struct(input, named), - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} - -fn expand_derive_decode_transparent( - input: &DeriveInput, - field: &Field, -) -> syn::Result { - check_transparent_attributes(input, field)?; - - let ident = &input.ident; - let ty = &field.ty; - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - generics.params.insert(0, parse_quote!(DB: sqlx::Database)); - generics - .make_where_clause() - .predicates - .push(parse_quote!(#ty: sqlx::decode::Decode)); - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - Ok(quote!( - impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { - fn decode(raw: &[u8]) -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode(raw).map(Self) - } - fn decode_null() -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode_null().map(Self) - } - fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { - <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) - } - } - )) -} - -fn expand_derive_decode_weak_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let repr = check_weak_enum_attributes(input, &variants)?; - - let ident = &input.ident; - let arms = variants - .iter() - .map(|v| { - let id = &v.ident; - parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) - }) - .collect::>(); - - Ok(quote!( - impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { - fn decode(raw: &[u8]) -> std::result::Result { - let val = <#repr as sqlx::decode::Decode>::decode(raw)?; - match val { - #(#arms)* - _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) - } - } - } - )) -} - -fn expand_derive_decode_strong_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - check_strong_enum_attributes(input, &variants)?; - - let ident = &input.ident; - - let mut value_arms = Vec::new(); - for v in variants { - let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; - if let Some(rename) = attributes.rename { - value_arms.push(quote!(#rename => Ok(#ident :: #id),)); - } else { - let name = id.to_string(); - value_arms.push(quote!(#name => Ok(#ident :: #id),)); - } - } - - // TODO: prevent heap allocation - Ok(quote!( - impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { - fn decode(buf: &[u8]) -> std::result::Result { - let val = >::decode(buf)?; - match val.as_str() { - #(#value_arms)* - _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) - } - } - } - )) -} - -fn expand_derive_decode_struct( - input: &DeriveInput, - fields: &Punctuated, -) -> syn::Result { - check_struct_attributes(input, fields)?; - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "postgres") { - let ident = &input.ident; - - let column_count = fields.len(); - - // extract type generics - let generics = &input.generics; - let (_, ty_generics, _) = generics.split_for_impl(); - - // add db type for impl generics & where clause - let mut generics = generics.clone(); - let predicates = &mut generics.make_where_clause().predicates; - for field in fields { - let ty = &field.ty; - predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); - predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); - } - let (impl_generics, _, where_clause) = generics.split_for_impl(); - - let mut reads: Vec> = Vec::new(); - let mut names: Vec = Vec::new(); - for field in fields { - let id = &field.ident; - names.push(id.clone().unwrap()); - let ty = &field.ty; - reads.push(parse_quote!( - if buf.len() < 8 { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); - if oid != >::type_info().oid() { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); - } - - let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; - - if buf.len() < 8 + len { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let raw = &buf[8..8+len]; - let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; - - let buf = &buf[8+len..]; - )); - } - let reads = reads.into_iter().flatten(); - - tts.extend(quote!( - impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { - fn decode(buf: &[u8]) -> std::result::Result { - if buf.len() < 4 { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; - if column_count != #column_count { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); - } - let buf = &buf[4..]; - - #(#reads)* - - if !buf.is_empty() { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); - } - - Ok(#ident { - #(#names),* - }) - } - } - )) - } - Ok(tts) -} - -pub(crate) fn expand_derive_has_sql_type( - input: &DeriveInput, -) -> syn::Result { - let attrs = parse_attributes(&input.attrs)?; - match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), - .. - }) if unnamed.len() == 1 => { - expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) - } - Data::Enum(DataEnum { variants, .. }) => match attrs.repr { - Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), - None => expand_derive_has_sql_type_strong_enum(input, variants), - }, - Data::Struct(DataStruct { - fields: Fields::Named(FieldsNamed { named, .. }), - .. - }) => expand_derive_has_sql_type_struct(input, named), - _ => Err(syn::Error::new_spanned( - input, - "expected a tuple struct with a single field", - )), - } -} - -fn expand_derive_has_sql_type_transparent( - input: &DeriveInput, - field: &Field, -) -> syn::Result { - check_transparent_attributes(input, field)?; - - let ident = &input.ident; - let ty = &field.ty; - - // extract type generics - let generics = &input.generics; - let (impl_generics, ty_generics, _) = generics.split_for_impl(); - - // add db type for clause - let mut generics = generics.clone(); - generics - .make_where_clause() - .predicates - .push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>)); - let (_, _, where_clause) = generics.split_for_impl(); - - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - Ok(tts) -} - -fn expand_derive_has_sql_type_weak_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let repr = check_weak_enum_attributes(input, variants)?; - - let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - if cfg!(feature = "postgres") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { - fn type_info() -> Self::TypeInfo { - >::type_info() - } - } - )); - } - - Ok(tts) -} - -fn expand_derive_has_sql_type_strong_enum( - input: &DeriveInput, - variants: &Punctuated, -) -> syn::Result { - let attributes = check_strong_enum_attributes(input, variants)?; - - let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "mysql") { - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { - fn type_info() -> Self::TypeInfo { - sqlx::mysql::MySqlTypeInfo::r#enum() - } - } - )); - } - - if cfg!(feature = "postgres") { - let oid = attributes.postgres_oid.unwrap(); - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { - fn type_info() -> Self::TypeInfo { - sqlx::postgres::PgTypeInfo::with_oid(#oid) - } - } - )); - } - - Ok(tts) -} - -fn expand_derive_has_sql_type_struct( - input: &DeriveInput, - fields: &Punctuated, -) -> syn::Result { - let attributes = check_struct_attributes(input, fields)?; - - let ident = &input.ident; - let mut tts = proc_macro2::TokenStream::new(); - - if cfg!(feature = "postgres") { - let oid = attributes.postgres_oid.unwrap(); - tts.extend(quote!( - impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { - fn type_info() -> Self::TypeInfo { - sqlx::postgres::PgTypeInfo::with_oid(#oid) - } - } - )); - } - - Ok(tts) -} - -pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result { - let encode_tts = expand_derive_encode(input)?; - let decode_tts = expand_derive_decode(input)?; - let has_sql_type_tts = expand_derive_has_sql_type(input)?; - - let combined = proc_macro2::TokenStream::from_iter( - encode_tts - .into_iter() - .chain(decode_tts) - .chain(has_sql_type_tts), - ); - Ok(combined) -} diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs new file mode 100644 index 0000000000..72df69910d --- /dev/null +++ b/sqlx-macros/src/derives/attributes.rs @@ -0,0 +1,261 @@ +use proc_macro2::Ident; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{Attribute, DeriveInput, Field, Lit, Meta, MetaNameValue, NestedMeta, Variant}; + +macro_rules! assert_attribute { + ($e:expr, $err:expr, $input:expr) => { + if !$e { + return Err(syn::Error::new_spanned($input, $err)); + } + }; +} + +pub struct SqlxAttributes { + pub transparent: bool, + pub postgres_oid: Option, + pub repr: Option, + pub rename: Option, +} + +pub fn parse_attributes(input: &[Attribute]) -> syn::Result { + let mut transparent = None; + let mut postgres_oid = None; + let mut repr = None; + let mut rename = None; + + macro_rules! fail { + ($t:expr, $m:expr) => { + return Err(syn::Error::new_spanned($t, $m)); + }; + } + + macro_rules! try_set { + ($i:ident, $v:expr, $t:expr) => { + match $i { + None => $i = Some($v), + Some(_) => fail!($t, "duplicate attribute"), + } + }; + } + + for attr in input { + let meta = attr + .parse_meta() + .map_err(|e| syn::Error::new_spanned(attr, e))?; + match meta { + Meta::List(list) if list.path.is_ident("sqlx") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(meta) => match meta { + Meta::Path(p) if p.is_ident("transparent") => { + try_set!(transparent, true, value) + } + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("rename") => try_set!(rename, val.value(), value), + Meta::List(list) if list.path.is_ident("postgres") => { + for value in list.nested.iter() { + match value { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Int(val), + .. + })) if path.is_ident("oid") => { + try_set!(postgres_oid, val.base10_parse()?, value); + } + u => fail!(u, "unexpected value"), + } + } + } + + u => fail!(u, "unexpected attribute"), + }, + u => fail!(u, "unexpected attribute"), + } + } + } + Meta::List(list) if list.path.is_ident("repr") => { + if list.nested.len() != 1 { + fail!(&list.nested, "expected one value") + } + match list.nested.first().unwrap() { + NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { + try_set!(repr, p.get_ident().unwrap().clone(), list); + } + u => fail!(u, "unexpected value"), + } + } + _ => {} + } + } + + Ok(SqlxAttributes { + transparent: transparent.unwrap_or(false), + postgres_oid, + repr, + rename, + }) +} + +pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + attributes.transparent, + "expected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + Ok(()) +} + +pub fn check_enum_attributes<'a>( + input: &'a DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + variant + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + variant + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant); + } + + Ok(attributes) +} + +pub fn check_weak_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); + for variant in variants { + let attributes = parse_attributes(&variant.attrs)?; + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + variant + ); + } + Ok(attributes.repr.unwrap()) +} + +pub fn check_strong_enum_attributes( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_enum_attributes(input, variants)?; + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + Ok(attributes) +} + +pub fn check_struct_attributes<'a>( + input: &'a DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = parse_attributes(&input.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + input + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_some(), + "expected #[sqlx(postgres(oid = ..))]", + input + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + input + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); + + for field in fields { + let attributes = parse_attributes(&field.attrs)?; + assert_attribute!( + !attributes.transparent, + "unexpected #[sqlx(transparent)]", + field + ); + #[cfg(feature = "postgres")] + assert_attribute!( + attributes.postgres_oid.is_none(), + "unexpected #[sqlx(postgres(oid = ..))]", + field + ); + assert_attribute!( + attributes.rename.is_none(), + "unexpected #[sqlx(rename = ..)]", + field + ); + assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field); + } + + Ok(attributes) +} diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs new file mode 100644 index 0000000000..2aa779f819 --- /dev/null +++ b/sqlx-macros/src/derives/decode.rs @@ -0,0 +1,221 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_attributes, +}; +use proc_macro2::Ident; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, + FieldsUnnamed, Stmt, Variant, +}; + +pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_decode_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_decode_weak_enum(input, variants), + None => expand_derive_decode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_decode_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_decode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { + fn decode(raw: &[u8]) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode(raw).map(Self) + } + fn decode_null() -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_null().map(Self) + } + fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) + } + } + )) +} + +fn expand_derive_decode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + let arms = variants + .iter() + .map(|v| { + let id = &v.ident; + parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),) + }) + .collect::>(); + + Ok(quote!( + impl sqlx::decode::Decode for #ident where #repr: sqlx::decode::Decode { + fn decode(raw: &[u8]) -> std::result::Result { + let val = <#repr as sqlx::decode::Decode>::decode(raw)?; + match val { + #(#arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#rename => Ok(#ident :: #id),)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#name => Ok(#ident :: #id),)); + } + } + + // TODO: prevent heap allocation + Ok(quote!( + impl sqlx::decode::Decode for #ident where std::string::String: sqlx::decode::Decode { + fn decode(buf: &[u8]) -> std::result::Result { + let val = >::decode(buf)?; + match val.as_str() { + #(#value_arms)* + _ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value"))) + } + } + } + )) +} + +fn expand_derive_decode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::decode::Decode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut reads: Vec> = Vec::new(); + let mut names: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + names.push(id.clone().unwrap()); + let ty = &field.ty; + reads.push(parse_quote!( + if buf.len() < 8 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); + if oid != >::type_info().oid() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); + } + + let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; + + if buf.len() < 8 + len { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let raw = &buf[8..8+len]; + let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; + + let buf = &buf[8+len..]; + )); + } + let reads = reads.into_iter().flatten(); + + tts.extend(quote!( + impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { + fn decode(buf: &[u8]) -> std::result::Result { + if buf.len() < 4 { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); + } + + let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize; + if column_count != #column_count { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); + } + let buf = &buf[4..]; + + #(#reads)* + + if !buf.is_empty() { + return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len())))); + } + + Ok(#ident { + #(#names),* + }) + } + } + )) + } + Ok(tts) +} diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs new file mode 100644 index 0000000000..03c91d00a4 --- /dev/null +++ b/sqlx-macros/src/derives/encode.rs @@ -0,0 +1,208 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_attributes, +}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, + FieldsUnnamed, Variant, +}; + +pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result { + let args = parse_attributes(&input.attrs)?; + + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_encode_transparent(&input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match args.repr { + Some(_) => expand_derive_encode_weak_enum(input, variants), + None => expand_derive_encode_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_encode_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_encode_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::encode::Encode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&self.0, buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&self.0, buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&self.0) + } + } + )) +} + +fn expand_derive_encode_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + Ok(quote!( + impl sqlx::encode::Encode for #ident where #repr: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&(*self as #repr), buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&(*self as #repr)) + } + } + )) +} + +fn expand_derive_encode_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + check_strong_enum_attributes(input, &variants)?; + + let ident = &input.ident; + + let mut value_arms = Vec::new(); + for v in variants { + let id = &v.ident; + let attributes = parse_attributes(&v.attrs)?; + if let Some(rename) = attributes.rename { + value_arms.push(quote!(#ident :: #id => #rename,)); + } else { + let name = id.to_string(); + value_arms.push(quote!(#ident :: #id => #name,)); + } + } + + Ok(quote!( + impl sqlx::encode::Encode for #ident where str: sqlx::encode::Encode { + fn encode(&self, buf: &mut std::vec::Vec) { + let val = match self { + #(#value_arms)* + }; + >::encode(val, buf) + } + fn size_hint(&self) -> usize { + let val = match self { + #(#value_arms)* + }; + >::size_hint(val) + } + } + )) +} + +fn expand_derive_encode_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + check_struct_attributes(input, &fields)?; + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let ident = &input.ident; + + let column_count = fields.len(); + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + let predicates = &mut generics.make_where_clause().predicates; + for field in fields { + let ty = &field.ty; + predicates.push(parse_quote!(#ty: sqlx::encode::Encode)); + predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>)); + } + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut writes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + writes.push(parse_quote!({ + // write oid + let info = >::type_info(); + buf.extend(&info.oid().to_be_bytes()); + + // write zeros for length + buf.extend(&[0; 4]); + + let start = buf.len(); + sqlx::encode::Encode::::encode(&self. #id, buf); + let end = buf.len(); + let size = end - start; + + // replaces zeros with actual length + buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); + })); + } + + let mut sizes: Vec = Vec::new(); + for field in fields { + let id = &field.ident; + let ty = &field.ty; + sizes.push( + parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), + ); + } + + tts.extend(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + buf.extend(&(#column_count as u32).to_be_bytes()); + #(#writes)* + } + fn size_hint(&self) -> usize { + 4 + #column_count * (4 + 4) + #(#sizes)+* + } + } + )); + } + + Ok(tts) +} diff --git a/sqlx-macros/src/derives/has_sql_type.rs b/sqlx-macros/src/derives/has_sql_type.rs new file mode 100644 index 0000000000..64da9ae1c5 --- /dev/null +++ b/sqlx-macros/src/derives/has_sql_type.rs @@ -0,0 +1,169 @@ +use super::attributes::{ + check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, + check_weak_enum_attributes, parse_attributes, +}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{ + parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, + FieldsUnnamed, Variant, +}; + +pub fn expand_derive_has_sql_type(input: &DeriveInput) -> syn::Result { + let attrs = parse_attributes(&input.attrs)?; + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) + } + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), + None => expand_derive_has_sql_type_strong_enum(input, variants), + }, + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_has_sql_type_struct(input, named), + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +fn expand_derive_has_sql_type_transparent( + input: &DeriveInput, + field: &Field, +) -> syn::Result { + check_transparent_attributes(input, field)?; + + let ident = &input.ident; + let ty = &field.ty; + + // extract type generics + let generics = &input.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + // add db type for clause + let mut generics = generics.clone(); + generics + .make_where_clause() + .predicates + .push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>)); + let (_, _, where_clause) = generics.split_for_impl(); + + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_weak_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let repr = check_weak_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + if cfg!(feature = "postgres") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > { + fn type_info() -> Self::TypeInfo { + >::type_info() + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_strong_enum( + input: &DeriveInput, + variants: &Punctuated, +) -> syn::Result { + let attributes = check_strong_enum_attributes(input, variants)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "mysql") { + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::MySql { + fn type_info() -> Self::TypeInfo { + sqlx::mysql::MySqlTypeInfo::r#enum() + } + } + )); + } + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} + +fn expand_derive_has_sql_type_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let attributes = check_struct_attributes(input, fields)?; + + let ident = &input.ident; + let mut tts = proc_macro2::TokenStream::new(); + + if cfg!(feature = "postgres") { + let oid = attributes.postgres_oid.unwrap(); + tts.extend(quote!( + impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres { + fn type_info() -> Self::TypeInfo { + sqlx::postgres::PgTypeInfo::with_oid(#oid) + } + } + )); + } + + Ok(tts) +} diff --git a/sqlx-macros/src/derives/mod.rs b/sqlx-macros/src/derives/mod.rs new file mode 100644 index 0000000000..28d9eee771 --- /dev/null +++ b/sqlx-macros/src/derives/mod.rs @@ -0,0 +1,25 @@ +mod attributes; +mod decode; +mod encode; +mod has_sql_type; + +pub(crate) use decode::expand_derive_decode; +pub(crate) use encode::expand_derive_encode; +pub(crate) use has_sql_type::expand_derive_has_sql_type; + +use std::iter::FromIterator; +use syn::DeriveInput; + +pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result { + let encode_tts = expand_derive_encode(input)?; + let decode_tts = expand_derive_decode(input)?; + let has_sql_type_tts = expand_derive_has_sql_type(input)?; + + let combined = proc_macro2::TokenStream::from_iter( + encode_tts + .into_iter() + .chain(decode_tts) + .chain(has_sql_type_tts), + ); + Ok(combined) +} From b2c759aed2028e6f3b0d21e4545ff0dc40722b17 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 11:43:37 +0100 Subject: [PATCH 06/12] move decode_struct_field and encode_struct_field to sqlx-core --- sqlx-core/src/postgres/mod.rs | 2 + sqlx-core/src/postgres/types/mod.rs | 1 + sqlx-core/src/postgres/types/struct.rs | 59 ++++++++++++++++++++++++++ sqlx-macros/src/derives/decode.rs | 27 ++---------- sqlx-macros/src/derives/encode.rs | 26 +++--------- 5 files changed, 72 insertions(+), 43 deletions(-) create mode 100644 sqlx-core/src/postgres/types/struct.rs diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 7afcee130d..43c8937950 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -5,6 +5,8 @@ pub use connection::PgConnection; pub use database::Postgres; pub use error::PgError; pub use row::PgRow; +#[doc(hidden)] +pub use types::r#struct::{decode_struct_field, encode_struct_field}; pub use types::PgTypeInfo; mod arguments; diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index a81a47ec16..7aefc2aff3 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -3,6 +3,7 @@ mod bytes; mod float; mod int; mod str; +pub mod r#struct; #[cfg(feature = "chrono")] mod chrono; diff --git a/sqlx-core/src/postgres/types/struct.rs b/sqlx-core/src/postgres/types/struct.rs new file mode 100644 index 0000000000..ee9a080e2c --- /dev/null +++ b/sqlx-core/src/postgres/types/struct.rs @@ -0,0 +1,59 @@ +use crate::decode::{Decode, DecodeError}; +use crate::encode::Encode; +use crate::postgres::protocol::TypeId; +use crate::postgres::types::PgTypeInfo; +use crate::types::HasSqlType; +use crate::Postgres; +use std::convert::TryInto; + +/// read a struct field and advance the buffer +pub fn decode_struct_field>(buf: &mut &[u8]) -> Result +where + Postgres: HasSqlType, +{ + if buf.len() < 8 { + return Err(DecodeError::Message(std::boxed::Box::new( + "Not enough data sent", + ))); + } + + let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); + if oid != >::type_info().oid() { + return Err(DecodeError::Message(std::boxed::Box::new("Invalid oid"))); + } + + let len = u32::from_be_bytes(buf[4..8].try_into().unwrap()) as usize; + + if buf.len() < 8 + len { + return Err(DecodeError::Message(std::boxed::Box::new( + "Not enough data sent", + ))); + } + + let raw = &buf[8..8 + len]; + let value = T::decode(raw)?; + + *buf = &buf[8 + len..]; + + Ok(value) +} + +pub fn encode_struct_field>(buf: &mut Vec, value: &T) +where + Postgres: HasSqlType, +{ + // write oid + let info = >::type_info(); + buf.extend(&info.oid().to_be_bytes()); + + // write zeros for length + buf.extend(&[0; 4]); + + let start = buf.len(); + value.encode(buf); + let end = buf.len(); + let size = end - start; + + // replaces zeros with actual length + buf[start - 4..start].copy_from_slice(&(size as u32).to_be_bytes()); +} diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index 2aa779f819..dc9322dbb9 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -161,35 +161,16 @@ fn expand_derive_decode_struct( } let (impl_generics, _, where_clause) = generics.split_for_impl(); - let mut reads: Vec> = Vec::new(); + let mut reads: Vec = Vec::new(); let mut names: Vec = Vec::new(); for field in fields { let id = &field.ident; names.push(id.clone().unwrap()); let ty = &field.ty; reads.push(parse_quote!( - if buf.len() < 8 { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap()); - if oid != >::type_info().oid() { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid"))); - } - - let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize; - - if buf.len() < 8 + len { - return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent"))); - } - - let raw = &buf[8..8+len]; - let #id = <#ty as sqlx::decode::Decode>::decode(raw)?; - - let buf = &buf[8+len..]; - )); + let #id = sqlx::postgres::decode_struct_field::<#ty>(&mut buf)?; + )); } - let reads = reads.into_iter().flatten(); tts.extend(quote!( impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { @@ -202,7 +183,7 @@ fn expand_derive_decode_struct( if column_count != #column_count { return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count"))); } - let buf = &buf[4..]; + let mut buf = &buf[4..]; #(#reads)* diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index 03c91d00a4..84aa53f12c 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -6,8 +6,8 @@ use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ - parse_quote, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, - FieldsUnnamed, Variant, + parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, + FieldsUnnamed, Stmt, Variant, }; pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result { @@ -160,26 +160,12 @@ fn expand_derive_encode_struct( } let (impl_generics, _, where_clause) = generics.split_for_impl(); - let mut writes: Vec = Vec::new(); + let mut writes: Vec = Vec::new(); for field in fields { let id = &field.ident; - let ty = &field.ty; - writes.push(parse_quote!({ - // write oid - let info = >::type_info(); - buf.extend(&info.oid().to_be_bytes()); - - // write zeros for length - buf.extend(&[0; 4]); - - let start = buf.len(); - sqlx::encode::Encode::::encode(&self. #id, buf); - let end = buf.len(); - let size = end - start; - - // replaces zeros with actual length - buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes()); - })); + writes.push(parse_quote!( + sqlx::postgres::encode_struct_field(buf, &self. #id); + )); } let mut sizes: Vec = Vec::new(); From dd27f7e0f53f9cb9cc0d9e3eb681122947a176d6 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 11:51:41 +0100 Subject: [PATCH 07/12] switch from vecs to iterator chains --- sqlx-macros/src/derives/decode.rs | 14 ++++++-------- sqlx-macros/src/derives/encode.rs | 20 +++++++++----------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index dc9322dbb9..fa4604ef8d 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -2,7 +2,6 @@ use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, check_weak_enum_attributes, parse_attributes, }; -use proc_macro2::Ident; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; @@ -161,16 +160,15 @@ fn expand_derive_decode_struct( } let (impl_generics, _, where_clause) = generics.split_for_impl(); - let mut reads: Vec = Vec::new(); - let mut names: Vec = Vec::new(); - for field in fields { + let reads = fields.iter().map(|field| -> Stmt { let id = &field.ident; - names.push(id.clone().unwrap()); let ty = &field.ty; - reads.push(parse_quote!( + parse_quote!( let #id = sqlx::postgres::decode_struct_field::<#ty>(&mut buf)?; - )); - } + ) + }); + + let names = fields.iter().map(|field| &field.ident); tts.extend(quote!( impl #impl_generics sqlx::decode::Decode for #ident#ty_generics #where_clause { diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index 84aa53f12c..e5f4da45cf 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -160,22 +160,20 @@ fn expand_derive_encode_struct( } let (impl_generics, _, where_clause) = generics.split_for_impl(); - let mut writes: Vec = Vec::new(); - for field in fields { + let writes = fields.iter().map(|field| -> Stmt { let id = &field.ident; - writes.push(parse_quote!( + parse_quote!( sqlx::postgres::encode_struct_field(buf, &self. #id); - )); - } + ) + }); - let mut sizes: Vec = Vec::new(); - for field in fields { + let sizes = fields.iter().map(|field| -> Expr { let id = &field.ident; let ty = &field.ty; - sizes.push( - parse_quote!(<#ty as sqlx::encode::Encode>::size_hint(&self. #id)), - ); - } + parse_quote!( + <#ty as sqlx::encode::Encode>::size_hint(&self. #id) + ) + }); tts.extend(quote!( impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { From 642e98da6e6e90c0dc6cb4fcd56baf5eda3b28e3 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 11:54:05 +0100 Subject: [PATCH 08/12] add explanation for size_hint --- sqlx-macros/src/derives/encode.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index e5f4da45cf..fc1d32c13a 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -182,7 +182,9 @@ fn expand_derive_encode_struct( #(#writes)* } fn size_hint(&self) -> usize { - 4 + #column_count * (4 + 4) + #(#sizes)+* + 4 // oid + + #column_count * (4 + 4) // oid (int) and length (int) for each column + + #(#sizes)+* // sum of the size hints for each column } } )); From 71a962dc63044e1aa97e10e8b1714cbf6ef42f59 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 11:58:20 +0100 Subject: [PATCH 09/12] fix error messages --- sqlx-macros/src/derives/decode.rs | 15 +++++++++++++-- sqlx-macros/src/derives/encode.rs | 15 +++++++++++++-- sqlx-macros/src/derives/has_sql_type.rs | 15 +++++++++++++-- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index fa4604ef8d..23c6427d32 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -27,9 +27,20 @@ pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result expand_derive_decode_struct(input, named), - _ => Err(syn::Error::new_spanned( + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), + Data::Struct(DataStruct { + fields: Fields::Unnamed(..), + .. + }) => Err(syn::Error::new_spanned( + input, + "structs with zero or more than one unnamed field are not supported", + )), + Data::Struct(DataStruct { + fields: Fields::Unit, + .. + }) => Err(syn::Error::new_spanned( input, - "expected a tuple struct with a single field", + "unit structs are not supported", )), } } diff --git a/sqlx-macros/src/derives/encode.rs b/sqlx-macros/src/derives/encode.rs index fc1d32c13a..1e0c2caa77 100644 --- a/sqlx-macros/src/derives/encode.rs +++ b/sqlx-macros/src/derives/encode.rs @@ -28,9 +28,20 @@ pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result expand_derive_encode_struct(input, named), - _ => Err(syn::Error::new_spanned( + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), + Data::Struct(DataStruct { + fields: Fields::Unnamed(..), + .. + }) => Err(syn::Error::new_spanned( + input, + "structs with zero or more than one unnamed field are not supported", + )), + Data::Struct(DataStruct { + fields: Fields::Unit, + .. + }) => Err(syn::Error::new_spanned( input, - "expected a tuple struct with a single field", + "unit structs are not supported", )), } } diff --git a/sqlx-macros/src/derives/has_sql_type.rs b/sqlx-macros/src/derives/has_sql_type.rs index 64da9ae1c5..8251d5005b 100644 --- a/sqlx-macros/src/derives/has_sql_type.rs +++ b/sqlx-macros/src/derives/has_sql_type.rs @@ -27,9 +27,20 @@ pub fn expand_derive_has_sql_type(input: &DeriveInput) -> syn::Result expand_derive_has_sql_type_struct(input, named), - _ => Err(syn::Error::new_spanned( + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), + Data::Struct(DataStruct { + fields: Fields::Unnamed(..), + .. + }) => Err(syn::Error::new_spanned( + input, + "structs with zero or more than one unnamed field are not supported", + )), + Data::Struct(DataStruct { + fields: Fields::Unit, + .. + }) => Err(syn::Error::new_spanned( input, - "expected a tuple struct with a single field", + "unit structs are not supported", )), } } From 9d0927c10e0e41998c190e4c0ae5b2873cfbe1fe Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 12:16:00 +0100 Subject: [PATCH 10/12] add tests for postgres struct field encoding --- Cargo.toml | 4 ++++ tests/postgres-struct.rs | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 tests/postgres-struct.rs diff --git a/Cargo.toml b/Cargo.toml index 6b55257e97..0f6d1ee458 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -90,5 +90,9 @@ required-features = [ "mysql", "chrono", "macros" ] name = "derives" required-features = [ "macros" ] +[[test]] +name = "postgres-struct" +required-features = [ "postgres" ] + [profile.release] lto = true diff --git a/tests/postgres-struct.rs b/tests/postgres-struct.rs new file mode 100644 index 0000000000..d108343100 --- /dev/null +++ b/tests/postgres-struct.rs @@ -0,0 +1,39 @@ +use sqlx::encode::Encode; +use sqlx::postgres::{decode_struct_field, encode_struct_field}; +use sqlx::types::HasSqlType; +use sqlx::Postgres; +use std::convert::TryInto; + +#[test] +fn test_encode_field() { + let value = "Foo Bar"; + let mut raw_encoded = Vec::new(); + <&str as Encode>::encode(&value, &mut raw_encoded); + let mut field_encoded = Vec::new(); + encode_struct_field(&mut field_encoded, &value); + + // check oid + let oid = >::type_info().oid(); + let field_encoded_oid = u32::from_be_bytes(field_encoded[0..4].try_into().unwrap()); + assert_eq!(oid, field_encoded_oid); + + // check length + let field_encoded_length = u32::from_be_bytes(field_encoded[4..8].try_into().unwrap()); + assert_eq!(raw_encoded.len(), field_encoded_length as usize); + + // check data + assert_eq!(raw_encoded, &field_encoded[8..]); +} + +#[test] +fn test_decode_field() { + let value = "Foo Bar".to_string(); + + let mut buf = Vec::new(); + encode_struct_field(&mut buf, &value); + + let mut buf = buf.as_slice(); + let value_decoded: String = decode_struct_field(&mut buf).unwrap(); + assert_eq!(value_decoded, value); + assert!(buf.is_empty()); +} From 57164438f39ed5a61672142fac7e95dfd86abfda Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 12:20:17 +0100 Subject: [PATCH 11/12] removed unused imports --- sqlx-core/src/postgres/types/struct.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/sqlx-core/src/postgres/types/struct.rs b/sqlx-core/src/postgres/types/struct.rs index ee9a080e2c..e07f5c389c 100644 --- a/sqlx-core/src/postgres/types/struct.rs +++ b/sqlx-core/src/postgres/types/struct.rs @@ -1,7 +1,5 @@ use crate::decode::{Decode, DecodeError}; use crate::encode::Encode; -use crate::postgres::protocol::TypeId; -use crate::postgres::types::PgTypeInfo; use crate::types::HasSqlType; use crate::Postgres; use std::convert::TryInto; From 96f7efc8c022362e129e31640ec08be03bf28586 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 10 Feb 2020 12:38:05 +0100 Subject: [PATCH 12/12] use iterator change in expand_derive_strong_enum --- sqlx-macros/src/derives/decode.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sqlx-macros/src/derives/decode.rs b/sqlx-macros/src/derives/decode.rs index 23c6427d32..93f1adaec7 100644 --- a/sqlx-macros/src/derives/decode.rs +++ b/sqlx-macros/src/derives/decode.rs @@ -118,17 +118,16 @@ fn expand_derive_decode_strong_enum( let ident = &input.ident; - let mut value_arms = Vec::new(); - for v in variants { + let value_arms = variants.iter().map(|v| -> Arm { let id = &v.ident; - let attributes = parse_attributes(&v.attrs)?; + let attributes = parse_attributes(&v.attrs).unwrap(); if let Some(rename) = attributes.rename { - value_arms.push(quote!(#rename => Ok(#ident :: #id),)); + parse_quote!(#rename => Ok(#ident :: #id),) } else { let name = id.to_string(); - value_arms.push(quote!(#name => Ok(#ident :: #id),)); + parse_quote!(#name => Ok(#ident :: #id),) } - } + }); // TODO: prevent heap allocation Ok(quote!(