Skip to content

Commit

Permalink
Generate an enum for @oneOf input (#450)
Browse files Browse the repository at this point in the history
* Add `is_one_of` to input

* Extract struct building from input codegen

* Generate enums for `oneOf` input

* Clippy fixes

* Empty commit for CI

---------

Co-authored-by: Surma <surma@surma.dev>
  • Loading branch information
jbourassa and surma authored May 25, 2023
1 parent c5847ce commit fcd1f34
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 53 deletions.
4 changes: 2 additions & 2 deletions graphql_client/tests/input_object_variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn input_object_variables_default() {
msg: default_input_object_variables_query::Variables::default_msg(),
};

let out = serde_json::to_value(&variables).unwrap();
let out = serde_json::to_value(variables).unwrap();

let expected_default = serde_json::json!({
"msg":{"content":null,"to":{"category":null,"email":"rosa.luxemburg@example.com","name":null}}
Expand Down Expand Up @@ -130,7 +130,7 @@ pub struct RustNameQuery;
#[test]
fn rust_name_correctly_mapped() {
use rust_name_query::*;
let value = serde_json::to_value(&Variables {
let value = serde_json::to_value(Variables {
extern_: Some("hello".to_owned()),
msg: <_>::default(),
})
Expand Down
30 changes: 30 additions & 0 deletions graphql_client/tests/one_of_input.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use graphql_client::*;
use serde_json::*;

#[derive(GraphQLQuery)]
#[graphql(
schema_path = "tests/one_of_input/schema.graphql",
query_path = "tests/one_of_input/query.graphql",
variables_derives = "Clone"
)]
pub struct OneOfMutation;

#[test]
fn one_of_input() {
use one_of_mutation::*;

let author = Param::Author(Author { id: 1 });
let _ = Param::Name("Mark Twain".to_string());
let _ = Param::RecursiveDirect(Box::new(author.clone()));
let _ = Param::RecursiveIndirect(Box::new(Recursive {
param: Box::new(author.clone()),
}));
let _ = Param::RequiredInts(vec![1]);
let _ = Param::OptionalInts(vec![Some(1)]);

let query = OneOfMutation::build_query(Variables { param: author });
assert_eq!(
json!({ "param": { "author":{ "id": 1 } } }),
serde_json::to_value(&query.variables).expect("json"),
);
}
3 changes: 3 additions & 0 deletions graphql_client/tests/one_of_input/query.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mutation OneOfMutation($param: Param!) {
oneOfMutation(query: $param)
}
24 changes: 24 additions & 0 deletions graphql_client/tests/one_of_input/schema.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
schema {
mutation: Mutation
}

type Mutation {
oneOfMutation(mutation: Param!): Int
}

input Param @oneOf {
author: Author
name: String
recursiveDirect: Param
recursiveIndirect: Recursive
requiredInts: [Int!]
optionalInts: [Int]
}

input Author {
id: Int!
}

input Recursive {
param: Param!
}
2 changes: 1 addition & 1 deletion graphql_client_cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ fn set_env_logger() {
.init();
}

fn colored_level<'a>(style: &'a mut Style, level: Level) -> StyledValue<'a, &'static str> {
fn colored_level(style: &mut Style, level: Level) -> StyledValue<'_, &'static str> {
match level {
Level::Trace => style.set_color(Color::Magenta).value("TRACE"),
Level::Debug => style.set_color(Color::Blue).value("DEBUG"),
Expand Down
151 changes: 108 additions & 43 deletions graphql_client_codegen/src/codegen/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use super::shared::{field_rename_annotation, keyword_replace};
use crate::{
codegen_options::GraphQLClientCodegenOptions,
query::{BoundQuery, UsedTypes},
schema::input_is_recursive_without_indirection,
schema::{input_is_recursive_without_indirection, StoredInputType},
type_qualifiers::GraphqlTypeQualifier,
};
use heck::ToSnakeCase;
use heck::{ToSnakeCase, ToUpperCamelCase};
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;

Expand All @@ -17,48 +18,112 @@ pub(super) fn generate_input_object_definitions(
all_used_types
.inputs(query.schema)
.map(|(_input_id, input)| {
let normalized_name = options.normalization().input_name(input.name.as_str());
let safe_name = keyword_replace(normalized_name);
let struct_name = Ident::new(safe_name.as_ref(), Span::call_site());

let fields = input.fields.iter().map(|(field_name, field_type)| {
let safe_field_name = keyword_replace(field_name.to_snake_case());
let annotation = field_rename_annotation(field_name, safe_field_name.as_ref());
let name_ident = Ident::new(safe_field_name.as_ref(), Span::call_site());
let normalized_field_type_name = options
.normalization()
.field_type(field_type.id.name(query.schema));
let optional_skip_serializing_none =
if *options.skip_serializing_none() && field_type.is_optional() {
Some(quote!(#[serde(skip_serializing_if = "Option::is_none")]))
} else {
None
};
let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site());
let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers);
let field_type = if field_type
.id
.as_input_id()
.map(|input_id| input_is_recursive_without_indirection(input_id, query.schema))
.unwrap_or(false)
{
quote!(Box<#field_type_tokens>)
} else {
field_type_tokens
};

quote!(
#optional_skip_serializing_none
#annotation pub #name_ident: #field_type
)
});

quote! {
#variable_derives
pub struct #struct_name{
#(#fields,)*
}
if input.is_one_of {
generate_enum(input, options, variable_derives, query)
} else {
generate_struct(input, options, variable_derives, query)
}
})
.collect()
}

fn generate_struct(
input: &StoredInputType,
options: &GraphQLClientCodegenOptions,
variable_derives: &impl quote::ToTokens,
query: &BoundQuery<'_>,
) -> TokenStream {
let normalized_name = options.normalization().input_name(input.name.as_str());
let safe_name = keyword_replace(normalized_name);
let struct_name = Ident::new(safe_name.as_ref(), Span::call_site());

let fields = input.fields.iter().map(|(field_name, field_type)| {
let safe_field_name = keyword_replace(field_name.to_snake_case());
let annotation = field_rename_annotation(field_name, safe_field_name.as_ref());
let name_ident = Ident::new(safe_field_name.as_ref(), Span::call_site());
let normalized_field_type_name = options
.normalization()
.field_type(field_type.id.name(query.schema));
let optional_skip_serializing_none =
if *options.skip_serializing_none() && field_type.is_optional() {
Some(quote!(#[serde(skip_serializing_if = "Option::is_none")]))
} else {
None
};
let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site());
let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers);
let field_type = if field_type
.id
.as_input_id()
.map(|input_id| input_is_recursive_without_indirection(input_id, query.schema))
.unwrap_or(false)
{
quote!(Box<#field_type_tokens>)
} else {
field_type_tokens
};

quote!(
#optional_skip_serializing_none
#annotation pub #name_ident: #field_type
)
});

quote! {
#variable_derives
pub struct #struct_name{
#(#fields,)*
}
}
}

fn generate_enum(
input: &StoredInputType,
options: &GraphQLClientCodegenOptions,
variable_derives: &impl quote::ToTokens,
query: &BoundQuery<'_>,
) -> TokenStream {
let normalized_name = options.normalization().input_name(input.name.as_str());
let safe_name = keyword_replace(normalized_name);
let enum_name = Ident::new(safe_name.as_ref(), Span::call_site());

let variants = input.fields.iter().map(|(field_name, field_type)| {
let variant_name = field_name.to_upper_camel_case();
let safe_variant_name = keyword_replace(&variant_name);

let annotation = field_rename_annotation(field_name.as_ref(), &variant_name);
let name_ident = Ident::new(safe_variant_name.as_ref(), Span::call_site());

let normalized_field_type_name = options
.normalization()
.field_type(field_type.id.name(query.schema));
let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site());

// Add the required qualifier so that the variant's field isn't wrapped in Option
let mut qualifiers = vec![GraphqlTypeQualifier::Required];
qualifiers.extend(field_type.qualifiers.iter().cloned());

let field_type_tokens = super::decorate_type(&type_name, &qualifiers);
let field_type = if field_type
.id
.as_input_id()
.map(|input_id| input_is_recursive_without_indirection(input_id, query.schema))
.unwrap_or(false)
{
quote!(Box<#field_type_tokens>)
} else {
field_type_tokens
};

quote!(
#annotation #name_ident(#field_type)
)
});

quote! {
#variable_derives
pub enum #enum_name{
#(#variants,)*
}
}
}
9 changes: 2 additions & 7 deletions graphql_client_codegen/src/deprecation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,17 @@ pub enum DeprecationStatus {
}

/// The available deprecation strategies.
#[derive(Debug, PartialEq, Eq, Clone)]
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub enum DeprecationStrategy {
/// Allow use of deprecated items in queries, and say nothing.
Allow,
/// Fail compilation if a deprecated item is used.
Deny,
/// Allow use of deprecated items in queries, but warn about them (default).
#[default]
Warn,
}

impl Default for DeprecationStrategy {
fn default() -> Self {
DeprecationStrategy::Warn
}
}

impl std::str::FromStr for DeprecationStrategy {
type Err = ();

Expand Down
1 change: 1 addition & 0 deletions graphql_client_codegen/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ impl StoredInputFieldType {
pub(crate) struct StoredInputType {
pub(crate) name: String,
pub(crate) fields: Vec<(String, StoredInputFieldType)>,
pub(crate) is_one_of: bool,
}

/// Intermediate representation for a parsed GraphQL schema used during code generation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,11 @@ fn ingest_input<'doc, T>(schema: &mut Schema, input: &mut parser::InputObjectTyp
where
T: graphql_parser::query::Text<'doc>,
{
let is_one_of = input
.directives
.iter()
.any(|directive| directive.name.as_ref() == "oneOf");

let input = super::StoredInputType {
name: input.name.as_ref().into(),
fields: input
Expand All @@ -305,6 +310,7 @@ where
)
})
.collect(),
is_one_of,
};

schema.stored_inputs.push(input);
Expand Down
3 changes: 3 additions & 0 deletions graphql_client_codegen/src/schema/json_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ fn ingest_input(schema: &mut Schema, input: &mut FullType) {
let input = super::StoredInputType {
fields,
name: input.name.take().expect("Input without a name"),
// The one-of input spec is not stable yet, thus the introspection query does not have
// `isOneOf`, so this is always false.
is_one_of: false,
};

schema.stored_inputs.push(input);
Expand Down

0 comments on commit fcd1f34

Please sign in to comment.