Skip to content

Commit

Permalink
feat: Add try_from attribute for FromRow
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhengzhuo committed Sep 7, 2022
1 parent 18a76fb commit b8d0b1f
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 16 deletions.
8 changes: 8 additions & 0 deletions sqlx-macros/src/derives/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub struct SqlxChildAttributes {
pub rename: Option<String>,
pub default: bool,
pub flatten: bool,
pub try_from: Option<Ident>,
}

pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
Expand Down Expand Up @@ -178,6 +179,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContai
pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttributes> {
let mut rename = None;
let mut default = false;
let mut try_from = None;
let mut flatten = false;

for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) {
Expand All @@ -194,6 +196,11 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
lit: Lit::Str(val),
..
}) if path.is_ident("rename") => try_set!(rename, val.value(), value),
Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(val),
..
}) if path.is_ident("try_from") => try_set!(try_from, val.parse()?, value),
Meta::Path(path) if path.is_ident("default") => default = true,
Meta::Path(path) if path.is_ident("flatten") => flatten = true,
u => fail!(u, "unexpected attribute"),
Expand All @@ -208,6 +215,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
rename,
default,
flatten,
try_from,
})
}

Expand Down
55 changes: 39 additions & 16 deletions sqlx-macros/src/derives/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,45 @@ fn expand_derive_from_row_struct(
let attributes = parse_child_attributes(&field.attrs).unwrap();
let ty = &field.ty;

let expr: Expr = if attributes.flatten {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
} else {
predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
let expr: Expr = match (attributes.flatten, attributes.try_from) {
(true, None) => {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
}
(false, None) => {
predicates
.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
}
(true,Some(try_from)) => {
predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
(false,Some(try_from)) => {
predicates
.push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#try_from: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
};

if attributes.default {
Expand Down
81 changes: 81 additions & 0 deletions tests/mysql/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,85 @@ async fn test_column_override_exact_enum() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn test_try_from_attr_for_native_type() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "i64")]
id: u64,
}

let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;

let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.id, id.0 as u64);

Ok(())
}

#[sqlx_macros::test]
async fn test_try_from_attr_for_custom_type() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "i64")]
id: Id,
}

#[derive(Debug, PartialEq)]
struct Id(i64);
impl std::convert::TryFrom<i64> for Id {
type Error = std::io::Error;
fn try_from(value: i64) -> Result<Self, Self::Error> {
Ok(Id(value))
}
}

let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;

let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.id, Id(id.0));

Ok(())
}

#[sqlx_macros::test]
async fn test_try_from_attr_with_flatten() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "Id", flatten)]
id: u64,
}

#[derive(Debug, PartialEq, sqlx::FromRow)]
struct Id {
id: i64,
}

impl std::convert::TryFrom<Id> for u64 {
type Error = std::io::Error;
fn try_from(value: Id) -> Result<Self, Self::Error> {
Ok(value.id as u64)
}
}

let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;

let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.id, id.0 as u64);

Ok(())
}

// we don't emit bind parameter type-checks for MySQL so testing the overrides is redundant

0 comments on commit b8d0b1f

Please sign in to comment.