From 92b8b5676682fe8cc50b4145a6bd25075454d283 Mon Sep 17 00:00:00 2001 From: Ivar Scholten Date: Sat, 31 Aug 2024 00:56:16 +0200 Subject: [PATCH] fix: use Result type aliases in "Wrap return type in Result" assist This commit makes the "Wrap return type in Result" assist prefer type aliases of standard library type when the are in scope, use at least one generic parameter, and have the name "Result". The last restriction was made in an attempt to avoid false assumptions about which type the user is referring to, but that might be overly strict. We could also do something like this, in order of priority: * Use the alias named "Result". * Use any alias if only a single one is in scope, otherwise: * Use the standard library type. This is easy to add if others feel differently that is appropriate, just let me know. --- .../rust-analyzer/crates/hir/src/semantics.rs | 17 +- .../handlers/wrap_return_type_in_result.rs | 293 +++++++++++++++++- 2 files changed, 294 insertions(+), 16 deletions(-) diff --git a/src/tools/rust-analyzer/crates/hir/src/semantics.rs b/src/tools/rust-analyzer/crates/hir/src/semantics.rs index 763f53031e4c5..c78b59826c955 100644 --- a/src/tools/rust-analyzer/crates/hir/src/semantics.rs +++ b/src/tools/rust-analyzer/crates/hir/src/semantics.rs @@ -14,6 +14,7 @@ use hir_def::{ hir::Expr, lower::LowerCtx, nameres::MacroSubNs, + path::ModPath, resolver::{self, HasResolver, Resolver, TypeNs}, type_ref::Mutability, AsMacroCall, DefWithBodyId, FunctionId, MacroId, TraitId, VariantId, @@ -46,9 +47,9 @@ use crate::{ source_analyzer::{resolve_hir_path, SourceAnalyzer}, Access, Adjust, Adjustment, Adt, AutoBorrow, BindingMode, BuiltinAttr, Callable, Const, ConstParam, Crate, DeriveHelper, Enum, Field, Function, HasSource, HirFileId, Impl, InFile, - Label, LifetimeParam, Local, Macro, Module, ModuleDef, Name, OverloadedDeref, Path, ScopeDef, - Static, Struct, ToolModule, Trait, TraitAlias, TupleField, Type, TypeAlias, TypeParam, Union, - Variant, VariantDef, + ItemInNs, Label, LifetimeParam, Local, Macro, Module, ModuleDef, Name, OverloadedDeref, Path, + ScopeDef, Static, Struct, ToolModule, Trait, TraitAlias, TupleField, Type, TypeAlias, + TypeParam, Union, Variant, VariantDef, }; const CONTINUE_NO_BREAKS: ControlFlow = ControlFlow::Continue(()); @@ -1384,6 +1385,16 @@ impl<'db> SemanticsImpl<'db> { self.analyze(path.syntax())?.resolve_path(self.db, path) } + pub fn resolve_mod_path( + &self, + scope: &SyntaxNode, + path: &ModPath, + ) -> Option> { + let analyze = self.analyze(scope)?; + let items = analyze.resolver.resolve_module_path_in_items(self.db.upcast(), path); + Some(items.iter_items().map(|(item, _)| item.into())) + } + fn resolve_variant(&self, record_lit: ast::RecordExpr) -> Option { self.analyze(record_lit.syntax())?.resolve_variant(self.db, record_lit) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs index b68ed00f77210..8f0e9b4fe09d5 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs @@ -1,12 +1,14 @@ use std::iter; +use hir::HasSource; use ide_db::{ famous_defs::FamousDefs, syntax_helpers::node_ext::{for_each_tail_expr, walk_expr}, }; +use itertools::Itertools; use syntax::{ - ast::{self, make, Expr}, - match_ast, ted, AstNode, + ast::{self, make, Expr, HasGenericParams}, + match_ast, ted, AstNode, ToSmolStr, }; use crate::{AssistContext, AssistId, AssistKind, Assists}; @@ -39,25 +41,22 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext< }; let type_ref = &ret_type.ty()?; - let ty = ctx.sema.resolve_type(type_ref)?.as_adt(); - let result_enum = + let core_result = FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?; - if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) { + let ty = ctx.sema.resolve_type(type_ref)?.as_adt(); + if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == core_result) { + // The return type is already wrapped in a Result cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result); return None; } - let new_result_ty = - make::ext::ty_result(type_ref.clone(), make::ty_placeholder()).clone_for_update(); - let generic_args = new_result_ty.syntax().descendants().find_map(ast::GenericArgList::cast)?; - let last_genarg = generic_args.generic_args().last()?; - acc.add( AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite), "Wrap return type in Result", type_ref.syntax().text_range(), |edit| { + let new_result_ty = result_type(ctx, &core_result, type_ref).clone_for_update(); let body = edit.make_mut(ast::Expr::BlockExpr(body)); let mut exprs_to_wrap = Vec::new(); @@ -81,16 +80,72 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext< } let old_result_ty = edit.make_mut(type_ref.clone()); - ted::replace(old_result_ty.syntax(), new_result_ty.syntax()); - if let Some(cap) = ctx.config.snippet_cap { - edit.add_placeholder_snippet(cap, last_genarg); + // Add a placeholder snippet at the first generic argument that doesn't equal the return type. + // This is normally the error type, but that may not be the case when we inserted a type alias. + let args = new_result_ty.syntax().descendants().find_map(ast::GenericArgList::cast); + let error_type_arg = args.and_then(|list| { + list.generic_args().find(|arg| match arg { + ast::GenericArg::TypeArg(_) => arg.syntax().text() != type_ref.syntax().text(), + ast::GenericArg::LifetimeArg(_) => false, + _ => true, + }) + }); + if let Some(error_type_arg) = error_type_arg { + if let Some(cap) = ctx.config.snippet_cap { + edit.add_placeholder_snippet(cap, error_type_arg); + } } }, ) } +fn result_type( + ctx: &AssistContext<'_>, + core_result: &hir::Enum, + ret_type: &ast::Type, +) -> ast::Type { + // Try to find a Result type alias in the current scope (shadowing the default). + let result_path = hir::ModPath::from_segments( + hir::PathKind::Plain, + iter::once(hir::Name::new_symbol_root(hir::sym::Result.clone())), + ); + let alias = ctx.sema.resolve_mod_path(ret_type.syntax(), &result_path).and_then(|def| { + def.filter_map(|def| match def.as_module_def()? { + hir::ModuleDef::TypeAlias(alias) => { + let enum_ty = alias.ty(ctx.db()).as_adt()?.as_enum()?; + (&enum_ty == core_result).then_some(alias) + } + _ => None, + }) + .find_map(|alias| { + let mut inserted_ret_type = false; + let generic_params = alias + .source(ctx.db())? + .value + .generic_param_list()? + .generic_params() + .map(|param| match param { + // Replace the very first type parameter with the functions return type. + ast::GenericParam::TypeParam(_) if !inserted_ret_type => { + inserted_ret_type = true; + ret_type.to_smolstr() + } + ast::GenericParam::LifetimeParam(_) => make::lifetime("'_").to_smolstr(), + _ => make::ty_placeholder().to_smolstr(), + }) + .join(", "); + + let name = alias.name(ctx.db()); + let name = name.as_str(); + Some(make::ty(&format!("{name}<{generic_params}>"))) + }) + }); + // If there is no applicable alias in scope use the default Result type. + alias.unwrap_or_else(|| make::ext::ty_result(ret_type.clone(), make::ty_placeholder())) +} + fn tail_cb_impl(acc: &mut Vec, e: &ast::Expr) { match e { Expr::BreakExpr(break_expr) => { @@ -998,4 +1053,216 @@ fn foo(the_field: u32) -> Result { "#, ); } + + #[test] + fn wrap_return_type_in_local_result_type() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result = core::result::Result; + +fn foo() -> i3$02 { + return 42i32; +} +"#, + r#" +type Result = core::result::Result; + +fn foo() -> Result { + return Ok(42i32); +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result2 = core::result::Result; + +fn foo() -> i3$02 { + return 42i32; +} +"#, + r#" +type Result2 = core::result::Result; + +fn foo() -> Result { + return Ok(42i32); +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_imported_local_result_type() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +mod some_module { + pub type Result = core::result::Result; +} + +use some_module::Result; + +fn foo() -> i3$02 { + return 42i32; +} +"#, + r#" +mod some_module { + pub type Result = core::result::Result; +} + +use some_module::Result; + +fn foo() -> Result { + return Ok(42i32); +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +mod some_module { + pub type Result = core::result::Result; +} + +use some_module::*; + +fn foo() -> i3$02 { + return 42i32; +} +"#, + r#" +mod some_module { + pub type Result = core::result::Result; +} + +use some_module::*; + +fn foo() -> Result { + return Ok(42i32); +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_local_result_type_from_function_body() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i3$02 { + type Result = core::result::Result; + 0 +} +"#, + r#" +fn foo() -> Result { + type Result = core::result::Result; + Ok(0) +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_local_result_type_already_using_alias() { + check_assist_not_applicable( + wrap_return_type_in_result, + r#" +//- minicore: result +pub type Result = core::result::Result; + +fn foo() -> Result { + return Ok(42i32); +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_local_result_type_multiple_generics() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result = core::result::Result; + +fn foo() -> i3$02 { + 0 +} +"#, + r#" +type Result = core::result::Result; + +fn foo() -> Result { + Ok(0) +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result = core::result::Result, ()>; + +fn foo() -> i3$02 { + 0 +} + "#, + r#" +type Result = core::result::Result, ()>; + +fn foo() -> Result { + Ok(0) +} + "#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result<'a, T, E> = core::result::Result, &'a ()>; + +fn foo() -> i3$02 { + 0 +} + "#, + r#" +type Result<'a, T, E> = core::result::Result, &'a ()>; + +fn foo() -> Result<'_, i32, ${0:_}> { + Ok(0) +} + "#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result = core::result::Result, Bar>; + +fn foo() -> i3$02 { + 0 +} + "#, + r#" +type Result = core::result::Result, Bar>; + +fn foo() -> Result { + Ok(0) +} + "#, + ); + } }