From c2ddac36fe3314cc508eec391684ebc990365349 Mon Sep 17 00:00:00 2001 From: Ikrk Date: Fri, 2 Feb 2024 16:46:02 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Modularized=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../client/src/fuzzer/snapshot_generator.rs | 167 +++++++++++------- 1 file changed, 99 insertions(+), 68 deletions(-) diff --git a/crates/client/src/fuzzer/snapshot_generator.rs b/crates/client/src/fuzzer/snapshot_generator.rs index 84b000f4..d358138c 100644 --- a/crates/client/src/fuzzer/snapshot_generator.rs +++ b/crates/client/src/fuzzer/snapshot_generator.rs @@ -6,7 +6,7 @@ use std::{error::Error, fs::File, io::Read}; use anchor_lang::anchor_syn::{AccountField, Ty}; use cargo_metadata::camino::Utf8PathBuf; -use proc_macro2::{Span, TokenStream}; +use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; use syn::parse::{Error as ParseError, Result as ParseResult}; use syn::spanned::Spanned; @@ -40,86 +40,102 @@ pub fn generate_snapshots_code(code_path: Vec<(String, Utf8PathBuf)>) -> Result< .content .ok_or("the content of program module is missing")?; - let mut ix_ctx_pairs = Vec::new(); - for item in items { - // Iterate through items in program module and find functions with the Context<_> parameter. Save the function name and the Context's inner type. - if let syn::Item::Fn(func) = item { - let func_name = &func.sig.ident; - let first_param_type = if let Some(param) = func.sig.inputs.into_iter().next() { - let mut ty = None::; - if let syn::FnArg::Typed(t) = param { - if let syn::Type::Path(tp) = *t.ty.clone() { - if let Some(seg) = tp.path.segments.into_iter().next() { - if let PathArguments::AngleBracketed(arg) = seg.arguments { - ty = arg.args.first().cloned(); - } - } - } - } - ty - } else { - None - }; + let ix_ctx_pairs = get_ix_ctx_pairs(&items)?; - let first_param_type = first_param_type.ok_or(format!( - "The function {} does not have the Context parameter and is malformed.", - func_name - ))?; + let (structs, impls) = get_snapshot_structs_and_impls(code, &ix_ctx_pairs)?; - ix_ctx_pairs.push((func_name.clone(), first_param_type)); - } - } - - // Find definition of each Context struct and create new struct with fields wrapped in Option<_> - let mut structs = String::new(); - let mut desers = String::new(); - let parse_result = syn::parse_file(code).map_err(|e| e.to_string())?; - for pair in ix_ctx_pairs { - let mut ty = None; - if let GenericArgument::Type(syn::Type::Path(tp)) = &pair.1 { - ty = tp.path.get_ident().cloned(); - // TODO add support for types with fully qualified path such as ix::Initialize - } - let ty = ty.ok_or(format!("malformed parameters of {} instruction", pair.0))?; - - // recursively find the context struct and create a new version with wrapped fields into Option - if let Some(ctx) = get_ctx_struct(&parse_result.items, &ty) { - let fields_parsed = if let Fields::Named(f) = ctx.fields.clone() { - let field_deser: ParseResult> = - f.named.iter().map(parse_account_field).collect(); - field_deser - } else { - Err(ParseError::new( - ctx.fields.span(), - "Context struct parse errror.", - )) - } - .map_err(|e| e.to_string())?; - - let wrapped_struct = wrap_fields_in_option(ctx, &fields_parsed).unwrap(); - let deser_code = deserialize_ctx_struct_anchor(ctx, &fields_parsed) - .map_err(|e| e.to_string())?; - // let deser_code = deserialize_ctx_struct(ctx).unwrap(); - structs = format!("{}{}", structs, wrapped_struct.into_token_stream()); - desers = format!("{}{}", desers, deser_code.into_token_stream()); - } else { - return Err(format!("The Context struct {} was not found", ty)); - } - } let use_statements = quote! { use trdelnik_client::anchor_lang::{prelude::*, self}; use trdelnik_client::anchor_lang::solana_program::instruction::AccountMeta; use trdelnik_client::fuzzing::{get_account_infos_option, FuzzingError}; } .into_token_stream(); - Ok(format!("{}{}{}", use_statements, structs, desers)) + Ok(format!("{}{}{}", use_statements, structs, impls)) }); code.into_iter().collect() } +/// Creates new snapshot structs with fields wrapped in Option<_> if approriate and the +/// respective implementations with snapshot deserialization methods +fn get_snapshot_structs_and_impls( + code: &str, + ix_ctx_pairs: &[(Ident, GenericArgument)], +) -> Result<(String, String), String> { + let mut structs = String::new(); + let mut impls = String::new(); + let parse_result = syn::parse_file(code).map_err(|e| e.to_string())?; + for pair in ix_ctx_pairs { + let mut ty = None; + if let GenericArgument::Type(syn::Type::Path(tp)) = &pair.1 { + ty = tp.path.get_ident().cloned(); + // TODO add support for types with fully qualified path such as ix::Initialize + } + let ty = ty.ok_or(format!("malformed parameters of {} instruction", pair.0))?; + + // recursively find the context struct and create a new version with wrapped fields into Option + if let Some(ctx) = find_ctx_struct(&parse_result.items, &ty) { + let fields_parsed = if let Fields::Named(f) = ctx.fields.clone() { + let field_deser: ParseResult> = + f.named.iter().map(parse_account_field).collect(); + field_deser + } else { + Err(ParseError::new( + ctx.fields.span(), + "Context struct parse errror.", + )) + } + .map_err(|e| e.to_string())?; + + let wrapped_struct = wrap_fields_in_option(ctx, &fields_parsed).unwrap(); + let deser_code = + deserialize_ctx_struct_anchor(ctx, &fields_parsed).map_err(|e| e.to_string())?; + // let deser_code = deserialize_ctx_struct(ctx).unwrap(); + structs = format!("{}{}", structs, wrapped_struct.into_token_stream()); + impls = format!("{}{}", impls, deser_code.into_token_stream()); + } else { + return Err(format!("The Context struct {} was not found", ty)); + } + } + + Ok((structs, impls)) +} + +/// Iterates through items and finds functions with the Context<_> parameter. Returns pairs with the function name and the Context's inner type. +fn get_ix_ctx_pairs(items: &[Item]) -> Result, String> { + let mut ix_ctx_pairs = Vec::new(); + for item in items { + if let syn::Item::Fn(func) = item { + let func_name = &func.sig.ident; + let first_param_type = if let Some(param) = func.sig.inputs.iter().next() { + let mut ty = None::; + if let syn::FnArg::Typed(t) = param { + if let syn::Type::Path(tp) = *t.ty.clone() { + if let Some(seg) = tp.path.segments.iter().next() { + if let PathArguments::AngleBracketed(arg) = &seg.arguments { + ty = arg.args.first().cloned(); + } + } + } + } + ty + } else { + None + }; + + let first_param_type = first_param_type.ok_or(format!( + "The function {} does not have the Context parameter and is malformed.", + func_name + ))?; + + ix_ctx_pairs.push((func_name.clone(), first_param_type)); + } + } + Ok(ix_ctx_pairs) +} + /// Recursively find a struct with a given `name` -fn get_ctx_struct<'a>(items: &'a Vec, name: &'a syn::Ident) -> Option<&'a ItemStruct> { +fn find_ctx_struct<'a>(items: &'a Vec, name: &'a syn::Ident) -> Option<&'a ItemStruct> { for item in items { if let Item::Struct(struct_item) = item { if struct_item.ident == *name { @@ -132,7 +148,7 @@ fn get_ctx_struct<'a>(items: &'a Vec, name: &'a syn::Ident) -> Option for item in items { if let Item::Mod(mod_item) = item { if let Some((_, items)) = &mod_item.content { - let r = get_ctx_struct(items, name); + let r = find_ctx_struct(items, name); if r.is_some() { return r; } @@ -143,6 +159,9 @@ fn get_ctx_struct<'a>(items: &'a Vec, name: &'a syn::Ident) -> Option None } +/// Determines if an Account should be wrapped into the `Option` type. +/// The function returns true if the account has the init or close constraints set +/// and is not already wrapped into the `Option` type. fn is_optional(parsed_field: &AccountField) -> bool { let is_optional = match parsed_field { AccountField::Field(field) => field.is_optional, @@ -156,6 +175,10 @@ fn is_optional(parsed_field: &AccountField) -> bool { (constraints.init.is_some() || constraints.is_close()) && !is_optional } +/// Determines if an Accout should be deserialized as optional. +/// The function returns true if the account has the init or close constraints set +/// or if it is explicitly optional (it was wrapped into the `Option` type already +/// in the definition of it's corresponding context structure). fn deserialize_as_option(parsed_field: &AccountField) -> bool { let is_optional = match parsed_field { AccountField::Field(field) => field.is_optional, @@ -211,6 +234,7 @@ fn wrap_fields_in_option( Ok(generated_struct.to_token_stream()) } +/// Generates code to deserialize the snapshot structs. fn deserialize_ctx_struct_anchor( snapshot_struct: &ItemStruct, parsed_fields: &[AccountField], @@ -266,6 +290,7 @@ fn deserialize_ctx_struct_anchor( Ok(generated_deser_impl.to_token_stream()) } +/// Get the identifier (name) of the passed sysvar type. fn sysvar_to_ident(sysvar: &anchor_lang::anchor_syn::SysvarTy) -> String { let str = match sysvar { anchor_lang::anchor_syn::SysvarTy::Clock => "Clock", @@ -282,6 +307,9 @@ fn sysvar_to_ident(sysvar: &anchor_lang::anchor_syn::SysvarTy) -> String { str.into() } +/// Converts passed account type to token streams. The function returns a pair of streams where the first +/// variable in the pair is the type itself and the second is a fully qualified function to deserialize +/// the given type. pub fn ty_to_tokens(ty: &anchor_lang::anchor_syn::Ty) -> Option<(TokenStream, TokenStream)> { let (return_type, deser_method) = match ty { Ty::AccountInfo | Ty::UncheckedAccount => return None, @@ -342,6 +370,7 @@ pub fn ty_to_tokens(ty: &anchor_lang::anchor_syn::Ty) -> Option<(TokenStream, To Some((return_type, deser_method)) } +/// Generates the code necessary to deserialize an account fn deserialize_account_tokens( name: &syn::Ident, is_optional: bool, @@ -369,6 +398,7 @@ fn deserialize_account_tokens( } } +/// Generates the code used with raw accounts as AccountInfo fn acc_info_tokens(name: &syn::Ident) -> TokenStream { quote! { let #name = accounts_iter @@ -377,6 +407,7 @@ fn acc_info_tokens(name: &syn::Ident) -> TokenStream { } } +/// Checks if the program attribute is present fn has_program_attribute(attrs: &Vec) -> bool { for attr in attrs { if attr.path.is_ident("program") {