diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index 46d7aef85f..85a853f27d 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -71,6 +71,7 @@ pub struct SqlxChildAttributes { pub rename: Option, pub default: bool, pub flatten: bool, + pub try_from: Option, } pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { @@ -178,6 +179,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result syn::Result { let mut rename = None; let mut default = false; + let mut try_from = None; let mut flatten = false; for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) { @@ -194,6 +196,11 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result try_set!(rename, val.value(), value), + Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(val), + .. + }) if path.is_ident("try_from") => try_set!(try_from, val.parse()?, value), Meta::Path(path) if path.is_ident("default") => default = true, Meta::Path(path) if path.is_ident("flatten") => flatten = true, u => fail!(u, "unexpected attribute"), @@ -208,6 +215,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result)); - parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row)) - } else { - predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>)); - predicates.push(parse_quote!(#ty: ::sqlx::types::Type)); - - let id_s = attributes - .rename - .or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned())) - .map(|s| match container_attributes.rename_all { - Some(pattern) => rename_all(&s, pattern), - None => s, - }) - .unwrap(); - parse_quote!(row.try_get(#id_s)) + let expr: Expr = match (attributes.flatten, attributes.try_from) { + (true, None) => { + predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>)); + parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row)) + } + (false, None) => { + predicates + .push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>)); + predicates.push(parse_quote!(#ty: ::sqlx::types::Type)); + + let id_s = attributes + .rename + .or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned())) + .map(|s| match container_attributes.rename_all { + Some(pattern) => rename_all(&s, pattern), + None => s, + }) + .unwrap(); + parse_quote!(row.try_get(#id_s)) + } + (true,Some(try_from)) => { + predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>)); + parse_quote!(<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string())))) + } + (false,Some(try_from)) => { + predicates + .push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>)); + predicates.push(parse_quote!(#try_from: ::sqlx::types::Type)); + + let id_s = attributes + .rename + .or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned())) + .map(|s| match container_attributes.rename_all { + Some(pattern) => rename_all(&s, pattern), + None => s, + }) + .unwrap(); + parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string())))) + } }; if attributes.default { diff --git a/tests/mysql/macros.rs b/tests/mysql/macros.rs index 80eb1b2e91..369a35c5e7 100644 --- a/tests/mysql/macros.rs +++ b/tests/mysql/macros.rs @@ -354,4 +354,85 @@ async fn test_column_override_exact_enum() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn test_try_from_attr_for_native_type() -> anyhow::Result<()> { + #[derive(sqlx::FromRow)] + struct Record { + #[sqlx(try_from = "i64")] + id: u64, + } + + let mut conn = new::().await?; + let (mut conn, id) = with_test_row(&mut conn).await?; + + let record = sqlx::query_as::<_, Record>("select id from tweet") + .fetch_one(&mut conn) + .await?; + + assert_eq!(record.id, id.0 as u64); + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_try_from_attr_for_custom_type() -> anyhow::Result<()> { + #[derive(sqlx::FromRow)] + struct Record { + #[sqlx(try_from = "i64")] + id: Id, + } + + #[derive(Debug, PartialEq)] + struct Id(i64); + impl std::convert::TryFrom for Id { + type Error = std::io::Error; + fn try_from(value: i64) -> Result { + Ok(Id(value)) + } + } + + let mut conn = new::().await?; + let (mut conn, id) = with_test_row(&mut conn).await?; + + let record = sqlx::query_as::<_, Record>("select id from tweet") + .fetch_one(&mut conn) + .await?; + + assert_eq!(record.id, Id(id.0)); + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_try_from_attr_with_flatten() -> anyhow::Result<()> { + #[derive(sqlx::FromRow)] + struct Record { + #[sqlx(try_from = "Id", flatten)] + id: u64, + } + + #[derive(Debug, PartialEq, sqlx::FromRow)] + struct Id { + id: i64, + } + + impl std::convert::TryFrom for u64 { + type Error = std::io::Error; + fn try_from(value: Id) -> Result { + Ok(value.id as u64) + } + } + + let mut conn = new::().await?; + let (mut conn, id) = with_test_row(&mut conn).await?; + + let record = sqlx::query_as::<_, Record>("select id from tweet") + .fetch_one(&mut conn) + .await?; + + assert_eq!(record.id, id.0 as u64); + + Ok(()) +} + // we don't emit bind parameter type-checks for MySQL so testing the overrides is redundant