Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add try_from attribute for FromRow derive #1081

Merged
merged 1 commit into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions sqlx-core/src/from_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,32 @@ use crate::row::Row;
/// }
/// }
/// ```
///
/// #### `try_from`
///
/// When your struct contains a field whose type is not matched with the database type,
/// if the field type has an implementation [`TryFrom`] for the database type,
/// you can use the `try_from` attribute to convert the database type to the field type.
/// For example:
///
/// ```rust,ignore
/// #[derive(sqlx::FromRow)]
/// struct User {
/// id: i32,
/// name: String,
/// #[sqlx(try_from = "i64")]
/// bigIntInMySql: u64
/// }
/// ```
///
/// Given a query such as:
///
/// ```sql
/// SELECT id, name, bigIntInMySql FROM users;
/// ```
///
/// In MySql, `BigInt` type matches `i64`, but you can convert it to `u64` by `try_from`.
///
pub trait FromRow<'r, R: Row>: Sized {
fn from_row(row: &'r R) -> Result<Self, Error>;
}
Expand Down
8 changes: 8 additions & 0 deletions sqlx-macros/src/derives/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub struct SqlxChildAttributes {
pub rename: Option<String>,
pub default: bool,
pub flatten: bool,
pub try_from: Option<Ident>,
}

pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
Expand Down Expand Up @@ -178,6 +179,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContai
pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttributes> {
let mut rename = None;
let mut default = false;
let mut try_from = None;
let mut flatten = false;

for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) {
Expand All @@ -194,6 +196,11 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
lit: Lit::Str(val),
..
}) if path.is_ident("rename") => try_set!(rename, val.value(), value),
Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(val),
..
}) if path.is_ident("try_from") => try_set!(try_from, val.parse()?, value),
Meta::Path(path) if path.is_ident("default") => default = true,
Meta::Path(path) if path.is_ident("flatten") => flatten = true,
u => fail!(u, "unexpected attribute"),
Expand All @@ -208,6 +215,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
rename,
default,
flatten,
try_from,
})
}

Expand Down
55 changes: 39 additions & 16 deletions sqlx-macros/src/derives/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,45 @@ fn expand_derive_from_row_struct(
let attributes = parse_child_attributes(&field.attrs).unwrap();
let ty = &field.ty;

let expr: Expr = if attributes.flatten {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
} else {
predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
let expr: Expr = match (attributes.flatten, attributes.try_from) {
(true, None) => {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
}
(false, None) => {
predicates
.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
}
(true,Some(try_from)) => {
predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
(false,Some(try_from)) => {
predicates
.push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#try_from: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
};

if attributes.default {
Expand Down
81 changes: 81 additions & 0 deletions tests/mysql/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,85 @@ async fn test_column_override_exact_enum() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn test_try_from_attr_for_native_type() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "i64")]
id: u64,
}

let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;

let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.id, id.0 as u64);

Ok(())
}

#[sqlx_macros::test]
async fn test_try_from_attr_for_custom_type() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "i64")]
id: Id,
}

#[derive(Debug, PartialEq)]
struct Id(i64);
impl std::convert::TryFrom<i64> for Id {
type Error = std::io::Error;
fn try_from(value: i64) -> Result<Self, Self::Error> {
Ok(Id(value))
}
}

let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;

let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.id, Id(id.0));

Ok(())
}

#[sqlx_macros::test]
async fn test_try_from_attr_with_flatten() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "Id", flatten)]
id: u64,
}

#[derive(Debug, PartialEq, sqlx::FromRow)]
struct Id {
id: i64,
}

impl std::convert::TryFrom<Id> for u64 {
type Error = std::io::Error;
fn try_from(value: Id) -> Result<Self, Self::Error> {
Ok(value.id as u64)
}
}

let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;

let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.id, id.0 as u64);

Ok(())
}

// we don't emit bind parameter type-checks for MySQL so testing the overrides is redundant