diff --git a/graphql_client_cli/src/generate.rs b/graphql_client_cli/src/generate.rs index 763acad4c..dd4e2ebd2 100644 --- a/graphql_client_cli/src/generate.rs +++ b/graphql_client_cli/src/generate.rs @@ -17,6 +17,7 @@ pub(crate) struct CliCodegenParams { pub no_formatting: bool, pub module_visibility: Option, pub output_directory: Option, + pub custom_scalars_module: Option, } pub(crate) fn generate_code(params: CliCodegenParams) -> Result<()> { @@ -30,6 +31,7 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> Result<()> { query_path, schema_path, selected_operation, + custom_scalars_module, } = params; let deprecation_strategy = deprecation_strategy.as_ref().and_then(|s| s.parse().ok()); @@ -59,6 +61,17 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> Result<()> { options.set_deprecation_strategy(deprecation_strategy); } + if let Some(custom_scalars_module) = custom_scalars_module { + let custom_scalars_module = syn::parse_str(&custom_scalars_module).with_context(|| { + format!( + "Invalid custom scalars module path: {}", + custom_scalars_module + ) + })?; + + options.set_custom_scalars_module(custom_scalars_module); + } + let gen = generate_module_token_stream(query_path.clone(), &schema_path, options).unwrap(); let generated_code = gen.to_string(); diff --git a/graphql_client_cli/src/main.rs b/graphql_client_cli/src/main.rs index 138dfa10e..ab1f94ef0 100644 --- a/graphql_client_cli/src/main.rs +++ b/graphql_client_cli/src/main.rs @@ -70,6 +70,10 @@ enum Cli { /// file, with the same name and the .rs extension. #[structopt(short = "o", long = "output-directory")] output_directory: Option, + /// The module where the custom scalar definitions are located. + /// --custom-scalars-module='crate::gql::custom_scalars' + #[structopt(short = "p", long = "custom-scalars-module")] + custom_scalars_module: Option, }, } @@ -101,6 +105,7 @@ fn main() -> anyhow::Result<()> { query_path, schema_path, selected_operation, + custom_scalars_module, } => generate::generate_code(generate::CliCodegenParams { variables_derives, response_derives, @@ -111,6 +116,7 @@ fn main() -> anyhow::Result<()> { query_path, schema_path, selected_operation, + custom_scalars_module, }), } } diff --git a/graphql_client_codegen/src/codegen.rs b/graphql_client_codegen/src/codegen.rs index 1559b28ab..6fb7ae8bd 100644 --- a/graphql_client_codegen/src/codegen.rs +++ b/graphql_client_codegen/src/codegen.rs @@ -156,7 +156,11 @@ fn generate_scalar_definitions<'a, 'schema: 'a>( proc_macro2::Span::call_site(), ); - quote!(type #ident = super::#ident;) + if let Some(custom_scalars_module) = options.custom_scalars_module() { + quote!(type #ident = #custom_scalars_module::#ident;) + } else { + quote!(type #ident = super::#ident;) + } }) } diff --git a/graphql_client_codegen/src/codegen_options.rs b/graphql_client_codegen/src/codegen_options.rs index 2d1e83579..95276c15a 100644 --- a/graphql_client_codegen/src/codegen_options.rs +++ b/graphql_client_codegen/src/codegen_options.rs @@ -2,7 +2,7 @@ use crate::deprecation::DeprecationStrategy; use crate::normalization::Normalization; use proc_macro2::Ident; use std::path::{Path, PathBuf}; -use syn::Visibility; +use syn::{self, Visibility}; /// Which context is this code generation effort taking place. #[derive(Debug)] @@ -39,6 +39,8 @@ pub struct GraphQLClientCodegenOptions { schema_file: Option, /// Normalization pattern for query types and names. normalization: Normalization, + /// Custom scalar definitions module path + custom_scalars_module: Option, } impl GraphQLClientCodegenOptions { @@ -56,6 +58,7 @@ impl GraphQLClientCodegenOptions { query_file: Default::default(), schema_file: Default::default(), normalization: Normalization::None, + custom_scalars_module: Default::default(), } } @@ -174,4 +177,14 @@ impl GraphQLClientCodegenOptions { pub fn normalization(&self) -> &Normalization { &self.normalization } + + /// Get the custom scalar definitions module + pub fn custom_scalars_module(&self) -> Option<&syn::Path> { + self.custom_scalars_module.as_ref() + } + + /// Set the custom scalar definitions module + pub fn set_custom_scalars_module(&mut self, module: syn::Path) { + self.custom_scalars_module = Some(module) + } } diff --git a/graphql_query_derive/src/lib.rs b/graphql_query_derive/src/lib.rs index 78ba0337a..e7ff68b04 100644 --- a/graphql_query_derive/src/lib.rs +++ b/graphql_query_derive/src/lib.rs @@ -70,6 +70,7 @@ fn build_graphql_client_derive_options( ) -> Result { let variables_derives = attributes::extract_attr(input, "variables_derives").ok(); let response_derives = attributes::extract_attr(input, "response_derives").ok(); + let custom_scalars_module = attributes::extract_attr(input, "custom_scalars_module").ok(); let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Derive); options.set_query_file(query_path); @@ -92,6 +93,18 @@ fn build_graphql_client_derive_options( options.set_normalization(normalization); }; + // The user can give a path to a module that provides definitions for the custom scalars. + if let Some(custom_scalars_module) = custom_scalars_module { + let custom_scalars_module = syn::parse_str(&custom_scalars_module).map_err(|err| { + GeneralError(format!( + "Invalid custom scalars module path: {}. {}", + custom_scalars_module, err + )) + })?; + + options.set_custom_scalars_module(custom_scalars_module); + } + options.set_struct_ident(input.ident.clone()); options.set_module_visibility(input.vis.clone()); options.set_operation_name(input.ident.to_string());