From 59b6f2d9f2fe40b0176b64bf295562d0beaabc3e Mon Sep 17 00:00:00 2001 From: hkalbasi Date: Thu, 6 Apr 2023 16:14:38 +0330 Subject: [PATCH] Compute closure captures --- crates/hir-def/src/body/lower.rs | 101 ++- crates/hir-def/src/body/pretty.rs | 12 +- crates/hir-def/src/hir.rs | 25 + crates/hir-ty/src/chalk_ext.rs | 13 +- crates/hir-ty/src/consteval/tests.rs | 75 ++ crates/hir-ty/src/db.rs | 11 +- crates/hir-ty/src/display.rs | 77 +- crates/hir-ty/src/infer.rs | 35 +- crates/hir-ty/src/infer/closure.rs | 722 +++++++++++++++++- crates/hir-ty/src/infer/coerce.rs | 6 +- crates/hir-ty/src/infer/expr.rs | 107 ++- crates/hir-ty/src/infer/mutability.rs | 14 +- crates/hir-ty/src/layout.rs | 15 +- crates/hir-ty/src/layout/tests.rs | 26 +- crates/hir-ty/src/layout/tests/closure.rs | 175 +++++ crates/hir-ty/src/mir.rs | 156 +++- crates/hir-ty/src/mir/borrowck.rs | 48 +- crates/hir-ty/src/mir/eval.rs | 172 +++-- crates/hir-ty/src/mir/lower.rs | 367 ++++++--- .../hir-ty/src/mir/lower/pattern_matching.rs | 17 +- crates/hir-ty/src/mir/pretty.rs | 49 +- crates/hir-ty/src/tests/coercion.rs | 12 +- crates/hir-ty/src/tests/macros.rs | 4 +- crates/hir-ty/src/tests/patterns.rs | 12 +- crates/hir-ty/src/tests/regression.rs | 29 +- crates/hir-ty/src/tests/simple.rs | 208 ++++- crates/hir-ty/src/tests/traits.rs | 74 +- crates/hir-ty/src/traits.rs | 14 +- crates/hir-ty/src/utils.rs | 23 +- crates/hir/src/lib.rs | 72 +- .../src/handlers/generate_function.rs | 3 +- .../src/handlers/mutability_errors.rs | 82 ++ .../replace_filter_map_next_with_find_map.rs | 2 +- .../src/handlers/type_mismatch.rs | 19 +- crates/ide/src/hover/tests.rs | 2 +- crates/ide/src/inlay_hints.rs | 16 +- crates/ide/src/inlay_hints/bind_pat.rs | 105 ++- crates/ide/src/signature_help.rs | 18 + crates/ide/src/static_index.rs | 1 + crates/rust-analyzer/src/config.rs | 27 + docs/user/generated_config.adoc | 5 + editors/code/package.json | 17 + 42 files changed, 2536 insertions(+), 432 deletions(-) create mode 100644 crates/hir-ty/src/layout/tests/closure.rs diff --git a/crates/hir-def/src/body/lower.rs b/crates/hir-def/src/body/lower.rs index e6bea706d838b..688c9e86bbb61 100644 --- a/crates/hir-def/src/body/lower.rs +++ b/crates/hir-def/src/body/lower.rs @@ -28,9 +28,9 @@ use crate::{ data::adt::StructKind, db::DefDatabase, hir::{ - dummy_expr_id, Array, Binding, BindingAnnotation, BindingId, ClosureKind, Expr, ExprId, - Label, LabelId, Literal, MatchArm, Movability, Pat, PatId, RecordFieldPat, RecordLitField, - Statement, + dummy_expr_id, Array, Binding, BindingAnnotation, BindingId, CaptureBy, ClosureKind, Expr, + ExprId, Label, LabelId, Literal, MatchArm, Movability, Pat, PatId, RecordFieldPat, + RecordLitField, Statement, }, item_scope::BuiltinShadowMode, lang_item::LangItem, @@ -67,6 +67,7 @@ pub(super) fn lower( is_lowering_assignee_expr: false, is_lowering_generator: false, label_ribs: Vec::new(), + current_binding_owner: None, } .collect(params, body, is_async_fn) } @@ -92,6 +93,7 @@ struct ExprCollector<'a> { // resolution label_ribs: Vec, + current_binding_owner: Option, } #[derive(Clone, Debug)] @@ -261,11 +263,16 @@ impl ExprCollector<'_> { } Some(ast::BlockModifier::Const(_)) => { self.with_label_rib(RibKind::Constant, |this| { - this.collect_block_(e, |id, statements, tail| Expr::Const { - id, - statements, - tail, - }) + this.collect_as_a_binding_owner_bad( + |this| { + this.collect_block_(e, |id, statements, tail| Expr::Const { + id, + statements, + tail, + }) + }, + syntax_ptr, + ) }) } None => self.collect_block(e), @@ -461,6 +468,8 @@ impl ExprCollector<'_> { } } ast::Expr::ClosureExpr(e) => self.with_label_rib(RibKind::Closure, |this| { + let (result_expr_id, prev_binding_owner) = + this.initialize_binding_owner(syntax_ptr); let mut args = Vec::new(); let mut arg_types = Vec::new(); if let Some(pl) = e.param_list() { @@ -494,17 +503,19 @@ impl ExprCollector<'_> { ClosureKind::Closure }; this.is_lowering_generator = prev_is_lowering_generator; - - this.alloc_expr( - Expr::Closure { - args: args.into(), - arg_types: arg_types.into(), - ret_type, - body, - closure_kind, - }, - syntax_ptr, - ) + let capture_by = + if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref }; + this.is_lowering_generator = prev_is_lowering_generator; + this.current_binding_owner = prev_binding_owner; + this.body.exprs[result_expr_id] = Expr::Closure { + args: args.into(), + arg_types: arg_types.into(), + ret_type, + body, + closure_kind, + capture_by, + }; + result_expr_id }), ast::Expr::BinExpr(e) => { let op = e.op_kind(); @@ -545,7 +556,15 @@ impl ExprCollector<'_> { ArrayExprKind::Repeat { initializer, repeat } => { let initializer = self.collect_expr_opt(initializer); let repeat = self.with_label_rib(RibKind::Constant, |this| { - this.collect_expr_opt(repeat) + if let Some(repeat) = repeat { + let syntax_ptr = AstPtr::new(&repeat); + this.collect_as_a_binding_owner_bad( + |this| this.collect_expr(repeat), + syntax_ptr, + ) + } else { + this.missing_expr() + } }); self.alloc_expr( Expr::Array(Array::Repeat { initializer, repeat }), @@ -592,6 +611,32 @@ impl ExprCollector<'_> { }) } + fn initialize_binding_owner( + &mut self, + syntax_ptr: AstPtr, + ) -> (ExprId, Option) { + let result_expr_id = self.alloc_expr(Expr::Missing, syntax_ptr); + let prev_binding_owner = self.current_binding_owner.take(); + self.current_binding_owner = Some(result_expr_id); + (result_expr_id, prev_binding_owner) + } + + /// FIXME: This function is bad. It will produce a dangling `Missing` expr which wastes memory. Currently + /// it is used only for const blocks and repeat expressions, which are also hacky and ideally should have + /// their own body. Don't add more usage for this function so that we can remove this function after + /// separating those bodies. + fn collect_as_a_binding_owner_bad( + &mut self, + job: impl FnOnce(&mut ExprCollector<'_>) -> ExprId, + syntax_ptr: AstPtr, + ) -> ExprId { + let (id, prev_owner) = self.initialize_binding_owner(syntax_ptr); + let tmp = job(self); + self.body.exprs[id] = mem::replace(&mut self.body.exprs[tmp], Expr::Missing); + self.current_binding_owner = prev_owner; + id + } + /// Desugar `try { ; }` into `': { ; ::std::ops::Try::from_output() }`, /// `try { ; }` into `': { ; ::std::ops::Try::from_output(()) }` /// and save the `` to use it as a break target for desugaring of the `?` operator. @@ -1112,8 +1157,13 @@ impl ExprCollector<'_> { } ast::Pat::ConstBlockPat(const_block_pat) => { if let Some(block) = const_block_pat.block_expr() { - let expr_id = - self.with_label_rib(RibKind::Constant, |this| this.collect_block(block)); + let expr_id = self.with_label_rib(RibKind::Constant, |this| { + let syntax_ptr = AstPtr::new(&block.clone().into()); + this.collect_as_a_binding_owner_bad( + |this| this.collect_block(block), + syntax_ptr, + ) + }); Pat::ConstBlock(expr_id) } else { Pat::Missing @@ -1272,7 +1322,12 @@ impl ExprCollector<'_> { } fn alloc_binding(&mut self, name: Name, mode: BindingAnnotation) -> BindingId { - self.body.bindings.alloc(Binding { name, mode, definitions: SmallVec::new() }) + self.body.bindings.alloc(Binding { + name, + mode, + definitions: SmallVec::new(), + owner: self.current_binding_owner, + }) } fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId { diff --git a/crates/hir-def/src/body/pretty.rs b/crates/hir-def/src/body/pretty.rs index c539224073914..c3bd99b94876a 100644 --- a/crates/hir-def/src/body/pretty.rs +++ b/crates/hir-def/src/body/pretty.rs @@ -5,7 +5,9 @@ use std::fmt::{self, Write}; use syntax::ast::HasName; use crate::{ - hir::{Array, BindingAnnotation, BindingId, ClosureKind, Literal, Movability, Statement}, + hir::{ + Array, BindingAnnotation, BindingId, CaptureBy, ClosureKind, Literal, Movability, Statement, + }, pretty::{print_generic_args, print_path, print_type_ref}, type_ref::TypeRef, }; @@ -360,7 +362,7 @@ impl<'a> Printer<'a> { self.print_expr(*index); w!(self, "]"); } - Expr::Closure { args, arg_types, ret_type, body, closure_kind } => { + Expr::Closure { args, arg_types, ret_type, body, closure_kind, capture_by } => { match closure_kind { ClosureKind::Generator(Movability::Static) => { w!(self, "static "); @@ -370,6 +372,12 @@ impl<'a> Printer<'a> { } _ => (), } + match capture_by { + CaptureBy::Value => { + w!(self, "move "); + } + CaptureBy::Ref => (), + } w!(self, "|"); for (i, (pat, ty)) in args.iter().zip(arg_types.iter()).enumerate() { if i != 0 { diff --git a/crates/hir-def/src/hir.rs b/crates/hir-def/src/hir.rs index 8321ba1da6f06..8709ad0e99b8d 100644 --- a/crates/hir-def/src/hir.rs +++ b/crates/hir-def/src/hir.rs @@ -275,6 +275,7 @@ pub enum Expr { ret_type: Option>, body: ExprId, closure_kind: ClosureKind, + capture_by: CaptureBy, }, Tuple { exprs: Box<[ExprId]>, @@ -292,6 +293,14 @@ pub enum ClosureKind { Async, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CaptureBy { + /// `move |x| y + x`. + Value, + /// `move` keyword was not specified. + Ref, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Movability { Static, @@ -484,6 +493,22 @@ pub struct Binding { pub name: Name, pub mode: BindingAnnotation, pub definitions: SmallVec<[PatId; 1]>, + /// Id of the closure/generator that owns this binding. If it is owned by the + /// top level expression, this field would be `None`. + pub owner: Option, +} + +impl Binding { + pub fn is_upvar(&self, relative_to: ExprId) -> bool { + match self.owner { + Some(x) => { + // We assign expression ids in a way that outer closures will recieve + // a lower id + x.into_raw() < relative_to.into_raw() + } + None => true, + } + } } #[derive(Debug, Clone, Eq, PartialEq)] diff --git a/crates/hir-ty/src/chalk_ext.rs b/crates/hir-ty/src/chalk_ext.rs index 2141894922f7b..d6a5612485651 100644 --- a/crates/hir-ty/src/chalk_ext.rs +++ b/crates/hir-ty/src/chalk_ext.rs @@ -12,8 +12,9 @@ use hir_def::{ use crate::{ db::HirDatabase, from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id, from_placeholder_idx, to_chalk_trait_id, utils::generics, AdtId, AliasEq, AliasTy, Binders, - CallableDefId, CallableSig, DynTy, FnPointer, ImplTraitId, Interner, Lifetime, ProjectionTy, - QuantifiedWhereClause, Substitution, TraitRef, Ty, TyBuilder, TyKind, TypeFlags, WhereClause, + CallableDefId, CallableSig, ClosureId, DynTy, FnPointer, ImplTraitId, Interner, Lifetime, + ProjectionTy, QuantifiedWhereClause, Substitution, TraitRef, Ty, TyBuilder, TyKind, TypeFlags, + WhereClause, }; pub trait TyExt { @@ -28,6 +29,7 @@ pub trait TyExt { fn as_adt(&self) -> Option<(hir_def::AdtId, &Substitution)>; fn as_builtin(&self) -> Option; fn as_tuple(&self) -> Option<&Substitution>; + fn as_closure(&self) -> Option; fn as_fn_def(&self, db: &dyn HirDatabase) -> Option; fn as_reference(&self) -> Option<(&Ty, Lifetime, Mutability)>; fn as_reference_or_ptr(&self) -> Option<(&Ty, Rawness, Mutability)>; @@ -128,6 +130,13 @@ impl TyExt for Ty { } } + fn as_closure(&self) -> Option { + match self.kind(Interner) { + TyKind::Closure(id, _) => Some(*id), + _ => None, + } + } + fn as_fn_def(&self, db: &dyn HirDatabase) -> Option { match self.callable_def(db) { Some(CallableDefId::FunctionId(func)) => Some(func), diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index a0efc7541e3f5..d987f41c70647 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -1105,6 +1105,81 @@ fn try_block() { ); } +#[test] +fn closures() { + check_number( + r#" + //- minicore: fn, copy + const GOAL: i32 = { + let y = 5; + let c = |x| x + y; + c(2) + }; + "#, + 7, + ); + check_number( + r#" + //- minicore: fn, copy + const GOAL: i32 = { + let y = 5; + let c = |(a, b): &(i32, i32)| *a + *b + y; + c(&(2, 3)) + }; + "#, + 10, + ); + check_number( + r#" + //- minicore: fn, copy + const GOAL: i32 = { + let mut y = 5; + let c = |x| { + y = y + x; + }; + c(2); + c(3); + y + }; + "#, + 10, + ); + check_number( + r#" + //- minicore: fn, copy + struct X(i32); + impl X { + fn mult(&mut self, n: i32) { + self.0 = self.0 * n + } + } + const GOAL: i32 = { + let x = X(1); + let c = || { + x.mult(2); + || { + x.mult(3); + || { + || { + x.mult(4); + || { + x.mult(x.0); + || { + x.0 + } + } + } + } + } + }; + let r = c()()()()()(); + r + x.0 + }; + "#, + 24 * 24 * 2, + ); +} + #[test] fn or_pattern() { check_number( diff --git a/crates/hir-ty/src/db.rs b/crates/hir-ty/src/db.rs index e11770e12313a..3a8fb665c44f5 100644 --- a/crates/hir-ty/src/db.rs +++ b/crates/hir-ty/src/db.rs @@ -19,9 +19,9 @@ use crate::{ consteval::ConstEvalError, method_resolution::{InherentImpls, TraitImpls, TyFingerprint}, mir::{BorrowckResult, MirBody, MirLowerError}, - Binders, CallableDefId, Const, FnDefId, GenericArg, ImplTraitId, InferenceResult, Interner, - PolyFnSig, QuantifiedWhereClause, ReturnTypeImplTraits, Substitution, TraitRef, Ty, TyDefId, - ValueTyDefId, + Binders, CallableDefId, ClosureId, Const, FnDefId, GenericArg, ImplTraitId, InferenceResult, + Interner, PolyFnSig, QuantifiedWhereClause, ReturnTypeImplTraits, Substitution, TraitRef, Ty, + TyDefId, ValueTyDefId, }; use hir_expand::name::Name; @@ -38,8 +38,11 @@ pub trait HirDatabase: DefDatabase + Upcast { #[salsa::cycle(crate::mir::mir_body_recover)] fn mir_body(&self, def: DefWithBodyId) -> Result, MirLowerError>; + #[salsa::invoke(crate::mir::mir_body_for_closure_query)] + fn mir_body_for_closure(&self, def: ClosureId) -> Result, MirLowerError>; + #[salsa::invoke(crate::mir::borrowck_query)] - fn borrowck(&self, def: DefWithBodyId) -> Result, MirLowerError>; + fn borrowck(&self, def: DefWithBodyId) -> Result, MirLowerError>; #[salsa::invoke(crate::lower::ty_query)] #[salsa::cycle(crate::lower::ty_recover)] diff --git a/crates/hir-ty/src/display.rs b/crates/hir-ty/src/display.rs index f1a649157c354..f892a81519708 100644 --- a/crates/hir-ty/src/display.rs +++ b/crates/hir-ty/src/display.rs @@ -23,6 +23,7 @@ use hir_expand::{hygiene::Hygiene, name::Name}; use intern::{Internable, Interned}; use itertools::Itertools; use smallvec::SmallVec; +use stdx::never; use crate::{ db::HirDatabase, @@ -64,6 +65,7 @@ pub struct HirFormatter<'a> { curr_size: usize, pub(crate) max_size: Option, omit_verbose_types: bool, + closure_style: ClosureStyle, display_target: DisplayTarget, } @@ -87,6 +89,7 @@ pub trait HirDisplay { max_size: Option, omit_verbose_types: bool, display_target: DisplayTarget, + closure_style: ClosureStyle, ) -> HirDisplayWrapper<'a, Self> where Self: Sized, @@ -95,7 +98,14 @@ pub trait HirDisplay { !matches!(display_target, DisplayTarget::SourceCode { .. }), "HirDisplayWrapper cannot fail with DisplaySourceCodeError, use HirDisplay::hir_fmt directly instead" ); - HirDisplayWrapper { db, t: self, max_size, omit_verbose_types, display_target } + HirDisplayWrapper { + db, + t: self, + max_size, + omit_verbose_types, + display_target, + closure_style, + } } /// Returns a `Display`able type that is human-readable. @@ -109,6 +119,7 @@ pub trait HirDisplay { t: self, max_size: None, omit_verbose_types: false, + closure_style: ClosureStyle::ImplFn, display_target: DisplayTarget::Diagnostics, } } @@ -128,6 +139,7 @@ pub trait HirDisplay { t: self, max_size, omit_verbose_types: true, + closure_style: ClosureStyle::ImplFn, display_target: DisplayTarget::Diagnostics, } } @@ -147,6 +159,7 @@ pub trait HirDisplay { curr_size: 0, max_size: None, omit_verbose_types: false, + closure_style: ClosureStyle::ImplFn, display_target: DisplayTarget::SourceCode { module_id }, }) { Ok(()) => {} @@ -166,6 +179,7 @@ pub trait HirDisplay { t: self, max_size: None, omit_verbose_types: false, + closure_style: ClosureStyle::ImplFn, display_target: DisplayTarget::Test, } } @@ -253,7 +267,6 @@ impl DisplayTarget { pub enum DisplaySourceCodeError { PathNotFound, UnknownType, - Closure, Generator, } @@ -274,9 +287,23 @@ pub struct HirDisplayWrapper<'a, T> { t: &'a T, max_size: Option, omit_verbose_types: bool, + closure_style: ClosureStyle, display_target: DisplayTarget, } +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ClosureStyle { + /// `impl FnX(i32, i32) -> i32`, where `FnX` is the most special trait between `Fn`, `FnMut`, `FnOnce` that the + /// closure implements. This is the default. + ImplFn, + /// `|i32, i32| -> i32` + RANotation, + /// `{closure#14825}`, useful for some diagnostics (like type mismatch) and internal usage. + ClosureWithId, + /// `…`, which is the `TYPE_HINT_TRUNCATION` + Hide, +} + impl HirDisplayWrapper<'_, T> { pub fn write_to(&self, f: &mut F) -> Result<(), HirDisplayError> { self.t.hir_fmt(&mut HirFormatter { @@ -287,8 +314,14 @@ impl HirDisplayWrapper<'_, T> { max_size: self.max_size, omit_verbose_types: self.omit_verbose_types, display_target: self.display_target, + closure_style: self.closure_style, }) } + + pub fn with_closure_style(mut self, c: ClosureStyle) -> Self { + self.closure_style = c; + self + } } impl<'a, T> fmt::Display for HirDisplayWrapper<'a, T> @@ -919,26 +952,42 @@ impl HirDisplay for Ty { } } } - TyKind::Closure(.., substs) => { - if f.display_target.is_source_code() { - return Err(HirDisplayError::DisplaySourceCodeError( - DisplaySourceCodeError::Closure, - )); + TyKind::Closure(id, substs) => { + if f.display_target.is_source_code() && f.closure_style != ClosureStyle::ImplFn { + never!("Only `impl Fn` is valid for displaying closures in source code"); + } + match f.closure_style { + ClosureStyle::Hide => return write!(f, "{TYPE_HINT_TRUNCATION}"), + ClosureStyle::ClosureWithId => { + return write!(f, "{{closure#{:?}}}", id.0.as_u32()) + } + _ => (), } let sig = substs.at(Interner, 0).assert_ty_ref(Interner).callable_sig(db); if let Some(sig) = sig { + let (def, _) = db.lookup_intern_closure((*id).into()); + let infer = db.infer(def); + let (_, kind) = infer.closure_info(id); + match f.closure_style { + ClosureStyle::ImplFn => write!(f, "impl {kind:?}(")?, + ClosureStyle::RANotation => write!(f, "|")?, + _ => unreachable!(), + } if sig.params().is_empty() { - write!(f, "||")?; } else if f.should_truncate() { - write!(f, "|{TYPE_HINT_TRUNCATION}|")?; + write!(f, "{TYPE_HINT_TRUNCATION}")?; } else { - write!(f, "|")?; f.write_joined(sig.params(), ", ")?; - write!(f, "|")?; }; - - write!(f, " -> ")?; - sig.ret().hir_fmt(f)?; + match f.closure_style { + ClosureStyle::ImplFn => write!(f, ")")?, + ClosureStyle::RANotation => write!(f, "|")?, + _ => unreachable!(), + } + if f.closure_style == ClosureStyle::RANotation || !sig.ret().is_unit() { + write!(f, " -> ")?; + sig.ret().hir_fmt(f)?; + } } else { write!(f, "{{closure}}")?; } diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index c34b24bee82f4..b4da1a308d3bf 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -39,9 +39,9 @@ use stdx::{always, never}; use crate::{ db::HirDatabase, fold_tys, infer::coerce::CoerceMany, lower::ImplTraitLoweringMode, - static_lifetime, to_assoc_type_id, AliasEq, AliasTy, DomainGoal, GenericArg, Goal, ImplTraitId, - InEnvironment, Interner, ProjectionTy, RpitId, Substitution, TraitRef, Ty, TyBuilder, TyExt, - TyKind, + static_lifetime, to_assoc_type_id, traits::FnTrait, AliasEq, AliasTy, ClosureId, DomainGoal, + GenericArg, Goal, ImplTraitId, InEnvironment, Interner, ProjectionTy, RpitId, Substitution, + TraitRef, Ty, TyBuilder, TyExt, TyKind, }; // This lint has a false positive here. See the link below for details. @@ -52,6 +52,8 @@ pub use coerce::could_coerce; #[allow(unreachable_pub)] pub use unify::could_unify; +pub(crate) use self::closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy}; + pub(crate) mod unify; mod path; mod expr; @@ -103,6 +105,8 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc Mutability { + let (AutoBorrow::Ref(m) | AutoBorrow::RawPtr(m)) = self; + m + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum PointerCast { /// Go from a fn-item type to a fn-pointer type. @@ -373,6 +384,9 @@ pub struct InferenceResult { pub pat_adjustments: FxHashMap>, pub pat_binding_modes: FxHashMap, pub expr_adjustments: FxHashMap>, + pub(crate) closure_info: FxHashMap, FnTrait)>, + // FIXME: remove this field + pub mutated_bindings_in_closure: FxHashSet, } impl InferenceResult { @@ -409,6 +423,9 @@ impl InferenceResult { _ => None, }) } + pub(crate) fn closure_info(&self, closure: &ClosureId) -> &(Vec, FnTrait) { + self.closure_info.get(closure).unwrap() + } } impl Index for InferenceResult { @@ -460,6 +477,14 @@ pub(crate) struct InferenceContext<'a> { resume_yield_tys: Option<(Ty, Ty)>, diverges: Diverges, breakables: Vec, + + // fields related to closure capture + current_captures: Vec, + current_closure: Option, + /// Stores the list of closure ids that need to be analyzed before this closure. See the + /// comment on `InferenceContext::sort_closures` + closure_dependecies: FxHashMap>, + deferred_closures: FxHashMap, ExprId)>>, } #[derive(Clone, Debug)] @@ -527,6 +552,10 @@ impl<'a> InferenceContext<'a> { resolver, diverges: Diverges::Maybe, breakables: Vec::new(), + current_captures: vec![], + current_closure: None, + deferred_closures: FxHashMap::default(), + closure_dependecies: FxHashMap::default(), } } diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index 916f29466f373..e994546356b2c 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -1,12 +1,29 @@ //! Inference of closure parameter types based on the closure's expected type. -use chalk_ir::{cast::Cast, AliasEq, AliasTy, FnSubst, WhereClause}; -use hir_def::{hir::ExprId, HasModule}; +use std::{cmp, collections::HashMap, convert::Infallible, mem}; + +use chalk_ir::{cast::Cast, AliasEq, AliasTy, FnSubst, Mutability, TyKind, WhereClause}; +use hir_def::{ + hir::{ + Array, BinaryOp, BindingAnnotation, BindingId, CaptureBy, Expr, ExprId, Pat, PatId, + Statement, UnaryOp, + }, + lang_item::LangItem, + resolver::{resolver_for_expr, ResolveValueResult, ValueNs}, + FieldId, HasModule, VariantId, +}; +use hir_expand::name; +use rustc_hash::FxHashMap; use smallvec::SmallVec; +use stdx::never; use crate::{ - to_chalk_trait_id, utils, ChalkTraitId, DynTy, FnPointer, FnSig, Interner, Substitution, Ty, - TyExt, TyKind, + mir::{BorrowKind, ProjectionElem}, + static_lifetime, to_chalk_trait_id, + traits::FnTrait, + utils::{self, pattern_matching_dereference_count}, + Adjust, Adjustment, Canonical, CanonicalVarKinds, ChalkTraitId, ClosureId, DynTy, FnPointer, + FnSig, InEnvironment, Interner, Substitution, Ty, TyBuilder, TyExt, }; use super::{Expectation, InferenceContext}; @@ -86,3 +103,700 @@ impl InferenceContext<'_> { None } } + +// The below functions handle capture and closure kind (Fn, FnMut, ..) + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct HirPlace { + pub(crate) local: BindingId, + pub(crate) projections: Vec>, +} +impl HirPlace { + fn ty(&self, ctx: &mut InferenceContext<'_>) -> Ty { + let mut ty = ctx.table.resolve_completely(ctx.result[self.local].clone()); + for p in &self.projections { + ty = p.projected_ty(ty, ctx.db, |_, _| { + unreachable!("Closure field only happens in MIR"); + }); + } + ty.clone() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) enum CaptureKind { + ByRef(BorrowKind), + ByValue, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct CapturedItem { + pub(crate) place: HirPlace, + pub(crate) kind: CaptureKind, + pub(crate) ty: Ty, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct CapturedItemWithoutTy { + pub(crate) place: HirPlace, + pub(crate) kind: CaptureKind, +} + +impl CapturedItemWithoutTy { + fn with_ty(self, ctx: &mut InferenceContext<'_>) -> CapturedItem { + let ty = self.place.ty(ctx).clone(); + let ty = match &self.kind { + CaptureKind::ByValue => ty, + CaptureKind::ByRef(bk) => { + let m = match bk { + BorrowKind::Mut { .. } => Mutability::Mut, + _ => Mutability::Not, + }; + TyKind::Ref(m, static_lifetime(), ty).intern(Interner) + } + }; + CapturedItem { place: self.place, kind: self.kind, ty } + } +} + +impl InferenceContext<'_> { + fn place_of_expr(&mut self, tgt_expr: ExprId) -> Option { + let r = self.place_of_expr_without_adjust(tgt_expr)?; + let default = vec![]; + let adjustments = self.result.expr_adjustments.get(&tgt_expr).unwrap_or(&default); + apply_adjusts_to_place(r, adjustments) + } + + fn place_of_expr_without_adjust(&mut self, tgt_expr: ExprId) -> Option { + match &self.body[tgt_expr] { + Expr::Path(p) => { + let resolver = resolver_for_expr(self.db.upcast(), self.owner, tgt_expr); + if let Some(r) = resolver.resolve_path_in_value_ns(self.db.upcast(), p) { + if let ResolveValueResult::ValueNs(v) = r { + if let ValueNs::LocalBinding(b) = v { + return Some(HirPlace { local: b, projections: vec![] }); + } + } + } + } + Expr::Field { expr, name } => { + let mut place = self.place_of_expr(*expr)?; + if let TyKind::Tuple(..) = self.expr_ty(*expr).kind(Interner) { + let index = name.as_tuple_index()?; + place.projections.push(ProjectionElem::TupleOrClosureField(index)) + } else { + let field = self.result.field_resolution(tgt_expr)?; + place.projections.push(ProjectionElem::Field(field)); + } + return Some(place); + } + _ => (), + } + None + } + + fn push_capture(&mut self, capture: CapturedItemWithoutTy) { + self.current_captures.push(capture); + } + + fn ref_expr(&mut self, expr: ExprId) { + if let Some(place) = self.place_of_expr(expr) { + self.add_capture(place, CaptureKind::ByRef(BorrowKind::Shared)); + } + self.walk_expr(expr); + } + + fn add_capture(&mut self, place: HirPlace, kind: CaptureKind) { + if self.is_upvar(&place) { + self.push_capture(CapturedItemWithoutTy { place, kind }); + } + } + + fn mutate_expr(&mut self, expr: ExprId) { + if let Some(place) = self.place_of_expr(expr) { + self.add_capture( + place, + CaptureKind::ByRef(BorrowKind::Mut { allow_two_phase_borrow: false }), + ); + } + self.walk_expr(expr); + } + + fn consume_expr(&mut self, expr: ExprId) { + if let Some(place) = self.place_of_expr(expr) { + self.consume_place(place); + } + self.walk_expr(expr); + } + + fn consume_place(&mut self, place: HirPlace) { + if self.is_upvar(&place) { + let ty = place.ty(self).clone(); + let kind = if self.is_ty_copy(ty) { + CaptureKind::ByRef(BorrowKind::Shared) + } else { + CaptureKind::ByValue + }; + self.push_capture(CapturedItemWithoutTy { place, kind }); + } + } + + fn walk_expr_with_adjust(&mut self, tgt_expr: ExprId, adjustment: &[Adjustment]) { + if let Some((last, rest)) = adjustment.split_last() { + match last.kind { + Adjust::NeverToAny | Adjust::Deref(None) | Adjust::Pointer(_) => { + self.walk_expr_with_adjust(tgt_expr, rest) + } + Adjust::Deref(Some(m)) => match m.0 { + Some(m) => { + self.ref_capture_with_adjusts(m, tgt_expr, rest); + } + None => unreachable!(), + }, + Adjust::Borrow(b) => { + self.ref_capture_with_adjusts(b.mutability(), tgt_expr, rest); + } + } + } else { + self.walk_expr_without_adjust(tgt_expr); + } + } + + fn ref_capture_with_adjusts(&mut self, m: Mutability, tgt_expr: ExprId, rest: &[Adjustment]) { + let capture_kind = match m { + Mutability::Mut => { + CaptureKind::ByRef(BorrowKind::Mut { allow_two_phase_borrow: false }) + } + Mutability::Not => CaptureKind::ByRef(BorrowKind::Shared), + }; + if let Some(place) = self.place_of_expr_without_adjust(tgt_expr) { + if let Some(place) = apply_adjusts_to_place(place, rest) { + if self.is_upvar(&place) { + self.push_capture(CapturedItemWithoutTy { place, kind: capture_kind }); + } + } + } + self.walk_expr_with_adjust(tgt_expr, rest); + } + + fn walk_expr(&mut self, tgt_expr: ExprId) { + if let Some(x) = self.result.expr_adjustments.get_mut(&tgt_expr) { + // FIXME: this take is completely unneeded, and just is here to make borrow checker + // happy. Remove it if you can. + let x_taken = mem::take(x); + self.walk_expr_with_adjust(tgt_expr, &x_taken); + *self.result.expr_adjustments.get_mut(&tgt_expr).unwrap() = x_taken; + } else { + self.walk_expr_without_adjust(tgt_expr); + } + } + + fn walk_expr_without_adjust(&mut self, tgt_expr: ExprId) { + match &self.body[tgt_expr] { + Expr::If { condition, then_branch, else_branch } => { + self.consume_expr(*condition); + self.consume_expr(*then_branch); + if let &Some(expr) = else_branch { + self.consume_expr(expr); + } + } + Expr::Async { statements, tail, .. } + | Expr::Const { statements, tail, .. } + | Expr::Unsafe { statements, tail, .. } + | Expr::Block { statements, tail, .. } => { + for s in statements.iter() { + match s { + Statement::Let { pat, type_ref: _, initializer, else_branch } => { + if let Some(else_branch) = else_branch { + self.consume_expr(*else_branch); + if let Some(initializer) = initializer { + self.consume_expr(*initializer); + } + return; + } + if let Some(initializer) = initializer { + self.walk_expr(*initializer); + if let Some(place) = self.place_of_expr(*initializer) { + let ty = self.expr_ty(*initializer); + self.consume_with_pat( + place, + ty, + BindingAnnotation::Unannotated, + *pat, + ); + } + } + } + Statement::Expr { expr, has_semi: _ } => { + self.consume_expr(*expr); + } + } + } + if let Some(tail) = tail { + self.consume_expr(*tail); + } + } + Expr::While { condition, body, label: _ } + | Expr::For { iterable: condition, pat: _, body, label: _ } => { + self.consume_expr(*condition); + self.consume_expr(*body); + } + Expr::Call { callee, args, is_assignee_expr: _ } => { + self.consume_expr(*callee); + self.consume_exprs(args.iter().copied()); + } + Expr::MethodCall { receiver, args, .. } => { + self.consume_expr(*receiver); + self.consume_exprs(args.iter().copied()); + } + Expr::Match { expr, arms } => { + self.consume_expr(*expr); + for arm in arms.iter() { + self.consume_expr(arm.expr); + } + } + Expr::Break { expr, label: _ } + | Expr::Return { expr } + | Expr::Yield { expr } + | Expr::Yeet { expr } => { + if let &Some(expr) = expr { + self.consume_expr(expr); + } + } + Expr::RecordLit { fields, spread, .. } => { + if let &Some(expr) = spread { + self.consume_expr(expr); + } + self.consume_exprs(fields.iter().map(|x| x.expr)); + } + Expr::Field { expr, name: _ } => self.select_from_expr(*expr), + Expr::UnaryOp { expr, op: UnaryOp::Deref } => { + if let Some((f, _)) = self.result.method_resolution(tgt_expr) { + let mutability = 'b: { + if let Some(deref_trait) = + self.resolve_lang_item(LangItem::DerefMut).and_then(|x| x.as_trait()) + { + if let Some(deref_fn) = + self.db.trait_data(deref_trait).method_by_name(&name![deref_mut]) + { + break 'b deref_fn == f; + } + } + false + }; + if mutability { + self.mutate_expr(*expr); + } else { + self.ref_expr(*expr); + } + } else { + self.select_from_expr(*expr); + } + } + Expr::UnaryOp { expr, op: _ } + | Expr::Array(Array::Repeat { initializer: expr, repeat: _ }) + | Expr::Await { expr } + | Expr::Loop { body: expr, label: _ } + | Expr::Let { pat: _, expr } + | Expr::Box { expr } + | Expr::Cast { expr, type_ref: _ } => { + self.consume_expr(*expr); + } + Expr::Ref { expr, rawness: _, mutability } => match mutability { + hir_def::type_ref::Mutability::Shared => self.ref_expr(*expr), + hir_def::type_ref::Mutability::Mut => self.mutate_expr(*expr), + }, + Expr::BinaryOp { lhs, rhs, op } => { + let Some(op) = op else { + return; + }; + if matches!(op, BinaryOp::Assignment { .. }) { + self.mutate_expr(*lhs); + self.consume_expr(*rhs); + return; + } + self.consume_expr(*lhs); + self.consume_expr(*rhs); + } + Expr::Range { lhs, rhs, range_type: _ } => { + if let &Some(expr) = lhs { + self.consume_expr(expr); + } + if let &Some(expr) = rhs { + self.consume_expr(expr); + } + } + Expr::Index { base, index } => { + self.select_from_expr(*base); + self.consume_expr(*index); + } + Expr::Closure { .. } => { + let ty = self.expr_ty(tgt_expr); + let TyKind::Closure(id, _) = ty.kind(Interner) else { + never!("closure type is always closure"); + return; + }; + let (captures, _) = + self.result.closure_info.get(id).expect( + "We sort closures, so we should always have data for inner closures", + ); + let mut cc = mem::take(&mut self.current_captures); + cc.extend( + captures + .iter() + .filter(|x| self.is_upvar(&x.place)) + .map(|x| CapturedItemWithoutTy { place: x.place.clone(), kind: x.kind }), + ); + self.current_captures = cc; + } + Expr::Array(Array::ElementList { elements: exprs, is_assignee_expr: _ }) + | Expr::Tuple { exprs, is_assignee_expr: _ } => { + self.consume_exprs(exprs.iter().copied()) + } + Expr::Missing + | Expr::Continue { .. } + | Expr::Path(_) + | Expr::Literal(_) + | Expr::Underscore => (), + } + } + + fn expr_ty(&mut self, expr: ExprId) -> Ty { + self.infer_expr_no_expect(expr) + } + + fn is_upvar(&self, place: &HirPlace) -> bool { + let b = &self.body[place.local]; + if let Some(c) = self.current_closure { + let (_, root) = self.db.lookup_intern_closure(c.into()); + return b.is_upvar(root); + } + false + } + + fn is_ty_copy(&self, ty: Ty) -> bool { + if let TyKind::Closure(id, _) = ty.kind(Interner) { + // FIXME: We handle closure as a special case, since chalk consider every closure as copy. We + // should probably let chalk know which closures are copy, but I don't know how doing it + // without creating query cycles. + return self.result.closure_info.get(id).map(|x| x.1 == FnTrait::Fn).unwrap_or(true); + } + let crate_id = self.owner.module(self.db.upcast()).krate(); + let Some(copy_trait) = self.db.lang_item(crate_id, LangItem::Copy).and_then(|x| x.as_trait()) else { + return false; + }; + let trait_ref = TyBuilder::trait_ref(self.db, copy_trait).push(ty).build(); + let env = self.db.trait_environment_for_body(self.owner); + let goal = Canonical { + value: InEnvironment::new(&env.env, trait_ref.cast(Interner)), + binders: CanonicalVarKinds::empty(Interner), + }; + self.db.trait_solve(crate_id, None, goal).is_some() + } + + fn select_from_expr(&mut self, expr: ExprId) { + self.walk_expr(expr); + } + + fn adjust_for_move_closure(&mut self) { + for capture in &mut self.current_captures { + if let Some(first_deref) = + capture.place.projections.iter().position(|proj| *proj == ProjectionElem::Deref) + { + capture.place.projections.truncate(first_deref); + } + capture.kind = CaptureKind::ByValue; + } + } + + fn minimize_captures(&mut self) { + self.current_captures.sort_by_key(|x| x.place.projections.len()); + let mut hash_map = HashMap::::new(); + let result = mem::take(&mut self.current_captures); + for item in result { + let mut lookup_place = HirPlace { local: item.place.local, projections: vec![] }; + let mut it = item.place.projections.iter(); + let prev_index = loop { + if let Some(k) = hash_map.get(&lookup_place) { + break Some(*k); + } + match it.next() { + Some(x) => lookup_place.projections.push(x.clone()), + None => break None, + } + }; + match prev_index { + Some(p) => { + self.current_captures[p].kind = + cmp::max(item.kind, self.current_captures[p].kind); + } + None => { + hash_map.insert(item.place.clone(), self.current_captures.len()); + self.current_captures.push(item); + } + } + } + } + + fn consume_with_pat( + &mut self, + mut place: HirPlace, + mut ty: Ty, + mut bm: BindingAnnotation, + pat: PatId, + ) { + match &self.body[pat] { + Pat::Missing | Pat::Wild => (), + Pat::Tuple { args, ellipsis } => { + pattern_matching_dereference(&mut ty, &mut bm, &mut place); + let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len())); + let subst = match ty.kind(Interner) { + TyKind::Tuple(_, s) => s, + _ => return, + }; + let fields = subst.iter(Interner).map(|x| x.assert_ty_ref(Interner)).enumerate(); + let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); + for (arg, (i, ty)) in it { + let mut p = place.clone(); + p.projections.push(ProjectionElem::TupleOrClosureField(i)); + self.consume_with_pat(p, ty.clone(), bm, *arg); + } + } + Pat::Or(pats) => { + for pat in pats.iter() { + self.consume_with_pat(place.clone(), ty.clone(), bm, *pat); + } + } + Pat::Record { args, .. } => { + pattern_matching_dereference(&mut ty, &mut bm, &mut place); + let subst = match ty.kind(Interner) { + TyKind::Adt(_, s) => s, + _ => return, + }; + let Some(variant) = self.result.variant_resolution_for_pat(pat) else { + return; + }; + match variant { + VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { + self.consume_place(place) + } + VariantId::StructId(s) => { + let vd = &*self.db.struct_data(s).variant_data; + let field_types = self.db.field_types(variant); + for field_pat in args.iter() { + let arg = field_pat.pat; + let Some(local_id) = vd.field(&field_pat.name) else { + continue; + }; + let mut p = place.clone(); + p.projections.push(ProjectionElem::Field(FieldId { + parent: variant.into(), + local_id, + })); + self.consume_with_pat( + p, + field_types[local_id].clone().substitute(Interner, subst), + bm, + arg, + ); + } + } + } + } + Pat::Range { .. } + | Pat::Slice { .. } + | Pat::ConstBlock(_) + | Pat::Path(_) + | Pat::Lit(_) => self.consume_place(place), + Pat::Bind { id, subpat: _ } => { + let mode = self.body.bindings[*id].mode; + if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) { + bm = mode; + } + let capture_kind = match bm { + BindingAnnotation::Unannotated | BindingAnnotation::Mutable => { + self.consume_place(place); + return; + } + BindingAnnotation::Ref => BorrowKind::Shared, + BindingAnnotation::RefMut => BorrowKind::Mut { allow_two_phase_borrow: false }, + }; + self.add_capture(place, CaptureKind::ByRef(capture_kind)); + } + Pat::TupleStruct { path: _, args, ellipsis } => { + pattern_matching_dereference(&mut ty, &mut bm, &mut place); + let subst = match ty.kind(Interner) { + TyKind::Adt(_, s) => s, + _ => return, + }; + let Some(variant) = self.result.variant_resolution_for_pat(pat) else { + return; + }; + match variant { + VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { + self.consume_place(place) + } + VariantId::StructId(s) => { + let vd = &*self.db.struct_data(s).variant_data; + let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len())); + let fields = vd.fields().iter(); + let it = + al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); + let field_types = self.db.field_types(variant); + for (arg, (i, _)) in it { + let mut p = place.clone(); + p.projections.push(ProjectionElem::Field(FieldId { + parent: variant.into(), + local_id: i, + })); + self.consume_with_pat( + p, + field_types[i].clone().substitute(Interner, subst), + bm, + *arg, + ); + } + } + } + } + Pat::Ref { pat, mutability: _ } => { + if let Some((inner, _, _)) = ty.as_reference() { + ty = inner.clone(); + place.projections.push(ProjectionElem::Deref); + self.consume_with_pat(place, ty, bm, *pat) + } + } + Pat::Box { .. } => (), // not supported + } + } + + fn consume_exprs(&mut self, exprs: impl Iterator) { + for expr in exprs { + self.consume_expr(expr); + } + } + + fn closure_kind(&self) -> FnTrait { + let mut r = FnTrait::Fn; + for x in &self.current_captures { + r = cmp::min( + r, + match &x.kind { + CaptureKind::ByRef(BorrowKind::Unique | BorrowKind::Mut { .. }) => { + FnTrait::FnMut + } + CaptureKind::ByRef(BorrowKind::Shallow | BorrowKind::Shared) => FnTrait::Fn, + CaptureKind::ByValue => FnTrait::FnOnce, + }, + ) + } + r + } + + fn analyze_closure(&mut self, closure: ClosureId) -> FnTrait { + let (_, root) = self.db.lookup_intern_closure(closure.into()); + self.current_closure = Some(closure); + let Expr::Closure { body, capture_by, .. } = &self.body[root] else { + unreachable!("Closure expression id is always closure"); + }; + self.consume_expr(*body); + for item in &self.current_captures { + if matches!(item.kind, CaptureKind::ByRef(BorrowKind::Mut { .. })) { + // FIXME: remove the `mutated_bindings_in_closure` completely and add proper fake reads in + // MIR. I didn't do that due duplicate diagnostics. + self.result.mutated_bindings_in_closure.insert(item.place.local); + } + } + // closure_kind should be done before adjust_for_move_closure + let closure_kind = self.closure_kind(); + match capture_by { + CaptureBy::Value => self.adjust_for_move_closure(), + CaptureBy::Ref => (), + } + self.minimize_captures(); + let result = mem::take(&mut self.current_captures); + let captures = result.into_iter().map(|x| x.with_ty(self)).collect::>(); + self.result.closure_info.insert(closure, (captures, closure_kind)); + closure_kind + } + + pub(crate) fn infer_closures(&mut self) { + let deferred_closures = self.sort_closures(); + for (closure, exprs) in deferred_closures.into_iter().rev() { + self.current_captures = vec![]; + let kind = self.analyze_closure(closure); + + for (derefed_callee, callee_ty, params, expr) in exprs { + if let &Expr::Call { callee, .. } = &self.body[expr] { + let mut adjustments = + self.result.expr_adjustments.remove(&callee).unwrap_or_default(); + self.write_fn_trait_method_resolution( + kind, + &derefed_callee, + &mut adjustments, + &callee_ty, + ¶ms, + expr, + ); + self.result.expr_adjustments.insert(callee, adjustments); + } + } + } + } + + /// We want to analyze some closures before others, to have a correct analysis: + /// * We should analyze nested closures before the parent, since the parent should capture some of + /// the things that its children captures. + /// * If a closure calls another closure, we need to analyze the callee, to find out how we should + /// capture it (e.g. by move for FnOnce) + /// + /// These dependencies are collected in the main inference. We do a topological sort in this function. It + /// will consume the `deferred_closures` field and return its content in a sorted vector. + fn sort_closures(&mut self) -> Vec<(ClosureId, Vec<(Ty, Ty, Vec, ExprId)>)> { + let mut deferred_closures = mem::take(&mut self.deferred_closures); + let mut dependents_count: FxHashMap = + deferred_closures.keys().map(|x| (*x, 0)).collect(); + for (_, deps) in &self.closure_dependecies { + for dep in deps { + *dependents_count.entry(*dep).or_default() += 1; + } + } + let mut queue: Vec<_> = + deferred_closures.keys().copied().filter(|x| dependents_count[x] == 0).collect(); + let mut result = vec![]; + while let Some(x) = queue.pop() { + if let Some(d) = deferred_closures.remove(&x) { + result.push((x, d)); + } + for dep in self.closure_dependecies.get(&x).into_iter().flat_map(|x| x.iter()) { + let cnt = dependents_count.get_mut(dep).unwrap(); + *cnt -= 1; + if *cnt == 0 { + queue.push(*dep); + } + } + } + result + } +} + +fn apply_adjusts_to_place(mut r: HirPlace, adjustments: &[Adjustment]) -> Option { + for adj in adjustments { + match &adj.kind { + Adjust::Deref(None) => { + r.projections.push(ProjectionElem::Deref); + } + _ => return None, + } + } + Some(r) +} + +fn pattern_matching_dereference( + cond_ty: &mut Ty, + binding_mode: &mut BindingAnnotation, + cond_place: &mut HirPlace, +) { + let cnt = pattern_matching_dereference_count(cond_ty, binding_mode); + cond_place.projections.extend((0..cnt).map(|_| ProjectionElem::Deref)); +} diff --git a/crates/hir-ty/src/infer/coerce.rs b/crates/hir-ty/src/infer/coerce.rs index f2e1ab269c71c..2249d84edbfc9 100644 --- a/crates/hir-ty/src/infer/coerce.rs +++ b/crates/hir-ty/src/infer/coerce.rs @@ -7,7 +7,7 @@ use std::{iter, sync::Arc}; -use chalk_ir::{cast::Cast, BoundVar, Goal, Mutability, TyVariableKind}; +use chalk_ir::{cast::Cast, BoundVar, Goal, Mutability, TyKind, TyVariableKind}; use hir_def::{ hir::ExprId, lang_item::{LangItem, LangItemTarget}, @@ -22,7 +22,7 @@ use crate::{ TypeError, TypeMismatch, }, static_lifetime, Canonical, DomainGoal, FnPointer, FnSig, Guidance, InEnvironment, Interner, - Solution, Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, + Solution, Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, }; use super::unify::InferenceTable; @@ -111,6 +111,8 @@ impl CoerceMany { // pointers to have a chance at getting a match. See // https://github.com/rust-lang/rust/blob/7b805396bf46dce972692a6846ce2ad8481c5f85/src/librustc_typeck/check/coercion.rs#L877-L916 let sig = match (self.merged_ty().kind(Interner), expr_ty.kind(Interner)) { + (TyKind::FnDef(x, _), TyKind::FnDef(y, _)) if x == y => None, + (TyKind::Closure(x, _), TyKind::Closure(y, _)) if x == y => None, (TyKind::FnDef(..) | TyKind::Closure(..), TyKind::FnDef(..) | TyKind::Closure(..)) => { // FIXME: we're ignoring safety here. To be more correct, if we have one FnDef and one Closure, // we should be coercing the closure to a fn pointer of the safety of the FnDef diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index d6a205e6086f5..64f37c761e963 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -221,7 +221,7 @@ impl<'a> InferenceContext<'a> { self.diverges = Diverges::Maybe; TyBuilder::unit() } - Expr::Closure { body, args, ret_type, arg_types, closure_kind } => { + Expr::Closure { body, args, ret_type, arg_types, closure_kind, capture_by: _ } => { assert_eq!(args.len(), arg_types.len()); let mut sig_tys = Vec::with_capacity(arg_types.len() + 1); @@ -256,7 +256,7 @@ impl<'a> InferenceContext<'a> { }) .intern(Interner); - let (ty, resume_yield_tys) = match closure_kind { + let (id, ty, resume_yield_tys) = match closure_kind { ClosureKind::Generator(_) => { // FIXME: report error when there are more than 1 parameter. let resume_ty = match sig_tys.first() { @@ -276,7 +276,7 @@ impl<'a> InferenceContext<'a> { let generator_id = self.db.intern_generator((self.owner, tgt_expr)).into(); let generator_ty = TyKind::Generator(generator_id, subst).intern(Interner); - (generator_ty, Some((resume_ty, yield_ty))) + (None, generator_ty, Some((resume_ty, yield_ty))) } ClosureKind::Closure | ClosureKind::Async => { let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into(); @@ -285,8 +285,11 @@ impl<'a> InferenceContext<'a> { Substitution::from1(Interner, sig_ty.clone()), ) .intern(Interner); - - (closure_ty, None) + self.deferred_closures.entry(closure_id).or_default(); + if let Some(c) = self.current_closure { + self.closure_dependecies.entry(c).or_default().push(closure_id); + } + (Some(closure_id), closure_ty, None) } }; @@ -302,6 +305,7 @@ impl<'a> InferenceContext<'a> { // FIXME: lift these out into a struct let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); + let prev_closure = mem::replace(&mut self.current_closure, id); let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone()); let prev_ret_coercion = mem::replace(&mut self.return_coercion, Some(CoerceMany::new(ret_ty))); @@ -315,6 +319,7 @@ impl<'a> InferenceContext<'a> { self.diverges = prev_diverges; self.return_ty = prev_ret_ty; self.return_coercion = prev_ret_coercion; + self.current_closure = prev_closure; self.resume_yield_tys = prev_resume_yield_tys; ty @@ -340,43 +345,28 @@ impl<'a> InferenceContext<'a> { let (param_tys, ret_ty) = match res { Some((func, params, ret_ty)) => { let mut adjustments = auto_deref_adjust_steps(&derefs); - if let Some(fn_x) = func { - match fn_x { - FnTrait::FnOnce => (), - FnTrait::FnMut => { - if !matches!( - derefed_callee.kind(Interner), - TyKind::Ref(Mutability::Mut, _, _) - ) { - adjustments.push(Adjustment::borrow( - Mutability::Mut, - derefed_callee.clone(), - )); - } - } - FnTrait::Fn => { - if !matches!( - derefed_callee.kind(Interner), - TyKind::Ref(Mutability::Not, _, _) - ) { - adjustments.push(Adjustment::borrow( - Mutability::Not, - derefed_callee.clone(), - )); - } - } - } - let trait_ = fn_x - .get_id(self.db, self.table.trait_env.krate) - .expect("We just used it"); - let trait_data = self.db.trait_data(trait_); - if let Some(func) = trait_data.method_by_name(&fn_x.method_name()) { - let subst = TyBuilder::subst_for_def(self.db, trait_, None) - .push(callee_ty.clone()) - .push(TyBuilder::tuple_with(params.iter().cloned())) - .build(); - self.write_method_resolution(tgt_expr, func, subst) + if let TyKind::Closure(c, _) = + self.table.resolve_completely(callee_ty.clone()).kind(Interner) + { + if let Some(par) = self.current_closure { + self.closure_dependecies.entry(par).or_default().push(*c); } + self.deferred_closures.entry(*c).or_default().push(( + derefed_callee.clone(), + callee_ty.clone(), + params.clone(), + tgt_expr, + )); + } + if let Some(fn_x) = func { + self.write_fn_trait_method_resolution( + fn_x, + &derefed_callee, + &mut adjustments, + &callee_ty, + ¶ms, + tgt_expr, + ); } self.write_expr_adj(*callee, adjustments); (params, ret_ty) @@ -906,6 +896,41 @@ impl<'a> InferenceContext<'a> { TyKind::OpaqueType(opaque_ty_id, Substitution::from1(Interner, inner_ty)).intern(Interner) } + pub(crate) fn write_fn_trait_method_resolution( + &mut self, + fn_x: FnTrait, + derefed_callee: &Ty, + adjustments: &mut Vec, + callee_ty: &Ty, + params: &Vec, + tgt_expr: ExprId, + ) { + match fn_x { + FnTrait::FnOnce => (), + FnTrait::FnMut => { + if !matches!(derefed_callee.kind(Interner), TyKind::Ref(Mutability::Mut, _, _)) { + adjustments.push(Adjustment::borrow(Mutability::Mut, derefed_callee.clone())); + } + } + FnTrait::Fn => { + if !matches!(derefed_callee.kind(Interner), TyKind::Ref(Mutability::Not, _, _)) { + adjustments.push(Adjustment::borrow(Mutability::Not, derefed_callee.clone())); + } + } + } + let Some(trait_) = fn_x.get_id(self.db, self.table.trait_env.krate) else { + return; + }; + let trait_data = self.db.trait_data(trait_); + if let Some(func) = trait_data.method_by_name(&fn_x.method_name()) { + let subst = TyBuilder::subst_for_def(self.db, trait_, None) + .push(callee_ty.clone()) + .push(TyBuilder::tuple_with(params.iter().cloned())) + .build(); + self.write_method_resolution(tgt_expr, func, subst.clone()); + } + } + fn infer_expr_array( &mut self, array: &Array, diff --git a/crates/hir-ty/src/infer/mutability.rs b/crates/hir-ty/src/infer/mutability.rs index 52f3563421522..f344b0610c99c 100644 --- a/crates/hir-ty/src/infer/mutability.rs +++ b/crates/hir-ty/src/infer/mutability.rs @@ -3,7 +3,7 @@ use chalk_ir::Mutability; use hir_def::{ - hir::{Array, BindingAnnotation, Expr, ExprId, PatId, Statement, UnaryOp}, + hir::{Array, BinaryOp, BindingAnnotation, Expr, ExprId, PatId, Statement, UnaryOp}, lang_item::LangItem, }; use hir_expand::name; @@ -80,6 +80,9 @@ impl<'a> InferenceContext<'a> { self.infer_mut_expr(*expr, m); for arm in arms.iter() { self.infer_mut_expr(arm.expr, Mutability::Not); + if let Some(g) = arm.guard { + self.infer_mut_expr(g, Mutability::Not); + } } } Expr::Yield { expr } @@ -158,14 +161,19 @@ impl<'a> InferenceContext<'a> { let mutability = lower_to_chalk_mutability(*mutability); self.infer_mut_expr(*expr, mutability); } + Expr::BinaryOp { lhs, rhs, op: Some(BinaryOp::Assignment { .. }) } => { + self.infer_mut_expr(*lhs, Mutability::Mut); + self.infer_mut_expr(*rhs, Mutability::Not); + } Expr::Array(Array::Repeat { initializer: lhs, repeat: rhs }) | Expr::BinaryOp { lhs, rhs, op: _ } | Expr::Range { lhs: Some(lhs), rhs: Some(rhs), range_type: _ } => { self.infer_mut_expr(*lhs, Mutability::Not); self.infer_mut_expr(*rhs, Mutability::Not); } - // not implemented - Expr::Closure { .. } => (), + Expr::Closure { body, .. } => { + self.infer_mut_expr(*body, Mutability::Not); + } Expr::Tuple { exprs, is_assignee_expr: _ } | Expr::Array(Array::ElementList { elements: exprs, is_assignee_expr: _ }) => { self.infer_mut_not_expr_iter(exprs.iter().copied()); diff --git a/crates/hir-ty/src/layout.rs b/crates/hir-ty/src/layout.rs index b95bb01fcefeb..277998b617844 100644 --- a/crates/hir-ty/src/layout.rs +++ b/crates/hir-ty/src/layout.rs @@ -229,7 +229,20 @@ pub fn layout_of_ty(db: &dyn HirDatabase, ty: &Ty, krate: CrateId) -> Result { + TyKind::Closure(c, _) => { + let (def, _) = db.lookup_intern_closure((*c).into()); + let infer = db.infer(def); + let (captures, _) = infer.closure_info(c); + let fields = captures + .iter() + .map(|x| layout_of_ty(db, &x.ty, krate)) + .collect::, _>>()?; + let fields = fields.iter().collect::>(); + let fields = fields.iter().collect::>(); + cx.univariant(dl, &fields, &ReprOptions::default(), StructKind::AlwaysSized) + .ok_or(LayoutError::Unknown)? + } + TyKind::Generator(_, _) | TyKind::GeneratorWitness(_, _) => { return Err(LayoutError::NotImplemented) } TyKind::AssociatedType(_, _) diff --git a/crates/hir-ty/src/layout/tests.rs b/crates/hir-ty/src/layout/tests.rs index a8971fde3c21e..43ace8ff0326c 100644 --- a/crates/hir-ty/src/layout/tests.rs +++ b/crates/hir-ty/src/layout/tests.rs @@ -11,6 +11,8 @@ use crate::{db::HirDatabase, test_db::TestDB, Interner, Substitution}; use super::layout_of_ty; +mod closure; + fn current_machine_data_layout() -> String { project_model::target_data_layout::get(None, None, &HashMap::default()).unwrap() } @@ -81,8 +83,8 @@ fn check_size_and_align(ra_fixture: &str, minicore: &str, size: u64, align: u64) #[track_caller] fn check_size_and_align_expr(ra_fixture: &str, minicore: &str, size: u64, align: u64) { let l = eval_expr(ra_fixture, minicore).unwrap(); - assert_eq!(l.size.bytes(), size); - assert_eq!(l.align.abi.bytes(), align); + assert_eq!(l.size.bytes(), size, "size mismatch"); + assert_eq!(l.align.abi.bytes(), align, "align mismatch"); } #[track_caller] @@ -118,13 +120,31 @@ macro_rules! size_and_align { }; } +#[macro_export] macro_rules! size_and_align_expr { + (minicore: $($x:tt),*; stmts: [$($s:tt)*] $($t:tt)*) => { + { + #[allow(dead_code)] + #[allow(unused_must_use)] + #[allow(path_statements)] + { + $($s)* + let val = { $($t)* }; + $crate::layout::tests::check_size_and_align_expr( + &format!("{{ {} let val = {{ {} }}; val }}", stringify!($($s)*), stringify!($($t)*)), + &format!("//- minicore: {}\n", stringify!($($x),*)), + ::std::mem::size_of_val(&val) as u64, + ::std::mem::align_of_val(&val) as u64, + ); + } + } + }; ($($t:tt)*) => { { #[allow(dead_code)] { let val = { $($t)* }; - check_size_and_align_expr( + $crate::layout::tests::check_size_and_align_expr( stringify!($($t)*), "", ::std::mem::size_of_val(&val) as u64, diff --git a/crates/hir-ty/src/layout/tests/closure.rs b/crates/hir-ty/src/layout/tests/closure.rs new file mode 100644 index 0000000000000..31b6765a7a2b2 --- /dev/null +++ b/crates/hir-ty/src/layout/tests/closure.rs @@ -0,0 +1,175 @@ +use crate::size_and_align_expr; + +#[test] +fn zero_capture_simple() { + size_and_align_expr! { + |x: i32| x + 2 + } +} + +#[test] +fn move_simple() { + size_and_align_expr! { + minicore: copy; + stmts: [] + let y: i32 = 5; + move |x: i32| { + x + y + } + } +} + +#[test] +fn ref_simple() { + size_and_align_expr! { + minicore: copy; + stmts: [ + let y: i32 = 5; + ] + |x: i32| { + x + y + } + } + size_and_align_expr! { + minicore: copy; + stmts: [ + let mut y: i32 = 5; + ] + |x: i32| { + y = y + x; + y + } + } + size_and_align_expr! { + minicore: copy; + stmts: [ + struct X(i32, i64); + let x: X = X(2, 6); + ] + || { + x + } + } +} + +#[test] +fn ref_then_mut_then_move() { + size_and_align_expr! { + minicore: copy; + stmts: [ + struct X(i32, i64); + let mut x: X = X(2, 6); + ] + || { + &x; + &mut x; + x; + } + } +} + +#[test] +fn nested_closures() { + size_and_align_expr! { + || { + || { + || { + let x = 2; + move || { + move || { + x + } + } + } + } + } + } +} + +#[test] +fn capture_specific_fields() { + size_and_align_expr! { + struct X(i64, i32, (u8, i128)); + let y: X = X(2, 5, (7, 3)); + move |x: i64| { + y.0 + x + (y.2 .0 as i64) + } + } + size_and_align_expr! { + struct X(i64, i32, (u8, i128)); + let y: X = X(2, 5, (7, 3)); + move |x: i64| { + let _ = &y; + y.0 + x + (y.2 .0 as i64) + } + } + size_and_align_expr! { + minicore: copy; + stmts: [ + struct X(i64, i32, (u8, i128)); + let y: X = X(2, 5, (7, 3)); + ] + let y = &y; + move |x: i64| { + y.0 + x + (y.2 .0 as i64) + } + } + size_and_align_expr! { + struct X(i64, i32, (u8, i128)); + let y: X = X(2, 5, (7, 3)); + move |x: i64| { + let X(a, _, (b, _)) = y; + a + x + (b as i64) + } + } + size_and_align_expr! { + struct X(i64, i32, (u8, i128)); + let y = &&X(2, 5, (7, 3)); + move |x: i64| { + let X(a, _, (b, _)) = y; + *a + x + (*b as i64) + } + } + size_and_align_expr! { + struct X(i64, i32, (u8, i128)); + let y: X = X(2, 5, (7, 3)); + move |x: i64| { + match y { + X(a, _, (b, _)) => a + x + (b as i64), + } + } + } + size_and_align_expr! { + struct X(i64, i32, (u8, i128)); + let y: X = X(2, 5, (7, 3)); + move |x: i64| { + let X(a @ 2, _, (b, _)) = y else { return 5 }; + a + x + (b as i64) + } + } +} + +#[test] +fn ellipsis_pattern() { + size_and_align_expr! { + struct X(i8, u16, i32, u64, i128, u8); + let y: X = X(1, 2, 3, 4, 5, 6); + move |_: i64| { + let X(_a, .., _b, _c) = y; + } + } + size_and_align_expr! { + struct X { a: i32, b: u8, c: i128} + let y: X = X { a: 1, b: 2, c: 3 }; + move |_: i64| { + let X { a, b, .. } = y; + _ = (a, b); + } + } + size_and_align_expr! { + let y: (&&&(i8, u16, i32, u64, i128, u8), u16, i32, u64, i128, u8) = (&&&(1, 2, 3, 4, 5, 6), 2, 3, 4, 5, 6); + move |_: i64| { + let ((_a, .., _b, _c), .., _e, _f) = y; + } + } +} diff --git a/crates/hir-ty/src/mir.rs b/crates/hir-ty/src/mir.rs index ab5f199a0383b..2fa1bf2b7e645 100644 --- a/crates/hir-ty/src/mir.rs +++ b/crates/hir-ty/src/mir.rs @@ -3,7 +3,8 @@ use std::{fmt::Display, iter}; use crate::{ - infer::PointerCast, Const, ConstScalar, InferenceResult, Interner, MemoryMap, Substitution, Ty, + db::HirDatabase, infer::PointerCast, ClosureId, Const, ConstScalar, InferenceResult, Interner, + MemoryMap, Substitution, Ty, TyKind, }; use chalk_ir::Mutability; use hir_def::{ @@ -19,9 +20,11 @@ mod pretty; pub use borrowck::{borrowck_query, BorrowckResult, MutabilityReason}; pub use eval::{interpret_mir, pad16, Evaluator, MirEvalError}; -pub use lower::{lower_to_mir, mir_body_query, mir_body_recover, MirLowerError}; +pub use lower::{ + lower_to_mir, mir_body_for_closure_query, mir_body_query, mir_body_recover, MirLowerError, +}; use smallvec::{smallvec, SmallVec}; -use stdx::impl_from; +use stdx::{impl_from, never}; use super::consteval::{intern_const_scalar, try_const_usize}; @@ -89,11 +92,12 @@ impl Operand { } } -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum ProjectionElem { Deref, Field(FieldId), - TupleField(usize), + // FIXME: get rid of this, and use FieldId for tuples and closures + TupleOrClosureField(usize), Index(V), ConstantIndex { offset: u64, min_length: u64, from_end: bool }, Subslice { from: u64, to: u64, from_end: bool }, @@ -101,6 +105,63 @@ pub enum ProjectionElem { OpaqueCast(T), } +impl ProjectionElem { + pub fn projected_ty( + &self, + base: Ty, + db: &dyn HirDatabase, + closure_field: impl FnOnce(ClosureId, usize) -> Ty, + ) -> Ty { + match self { + ProjectionElem::Deref => match &base.data(Interner).kind { + TyKind::Raw(_, inner) | TyKind::Ref(_, _, inner) => inner.clone(), + _ => { + never!("Overloaded deref is not a projection"); + return TyKind::Error.intern(Interner); + } + }, + ProjectionElem::Field(f) => match &base.data(Interner).kind { + TyKind::Adt(_, subst) => { + db.field_types(f.parent)[f.local_id].clone().substitute(Interner, subst) + } + _ => { + never!("Only adt has field"); + return TyKind::Error.intern(Interner); + } + }, + ProjectionElem::TupleOrClosureField(f) => match &base.data(Interner).kind { + TyKind::Tuple(_, subst) => subst + .as_slice(Interner) + .get(*f) + .map(|x| x.assert_ty_ref(Interner)) + .cloned() + .unwrap_or_else(|| { + never!("Out of bound tuple field"); + TyKind::Error.intern(Interner) + }), + TyKind::Closure(id, _) => closure_field(*id, *f), + _ => { + never!("Only tuple or closure has tuple or closure field"); + return TyKind::Error.intern(Interner); + } + }, + ProjectionElem::Index(_) => match &base.data(Interner).kind { + TyKind::Array(inner, _) | TyKind::Slice(inner) => inner.clone(), + _ => { + never!("Overloaded index is not a projection"); + return TyKind::Error.intern(Interner); + } + }, + ProjectionElem::ConstantIndex { .. } + | ProjectionElem::Subslice { .. } + | ProjectionElem::OpaqueCast(_) => { + never!("We don't emit these yet"); + return TyKind::Error.intern(Interner); + } + } + } +} + type PlaceElem = ProjectionElem; #[derive(Debug, Clone, PartialEq, Eq)] @@ -123,7 +184,7 @@ pub enum AggregateKind { Tuple(Ty), Adt(VariantId, Substitution), Union(UnionId, FieldId), - //Closure(LocalDefId, SubstsRef), + Closure(Ty), //Generator(LocalDefId, SubstsRef, Movability), } @@ -418,7 +479,7 @@ pub enum Terminator { }, } -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)] pub enum BorrowKind { /// Data must be immutable and is aliasable. Shared, @@ -847,6 +908,87 @@ pub struct MirBody { pub arg_count: usize, pub binding_locals: ArenaMap, pub param_locals: Vec, + /// This field stores the closures directly owned by this body. It is used + /// in traversing every mir body. + pub closures: Vec, +} + +impl MirBody { + fn walk_places(&mut self, mut f: impl FnMut(&mut Place)) { + fn for_operand(op: &mut Operand, f: &mut impl FnMut(&mut Place)) { + match op { + Operand::Copy(p) | Operand::Move(p) => { + f(p); + } + Operand::Constant(_) => (), + } + } + for (_, block) in self.basic_blocks.iter_mut() { + for statement in &mut block.statements { + match &mut statement.kind { + StatementKind::Assign(p, r) => { + f(p); + match r { + Rvalue::ShallowInitBox(o, _) + | Rvalue::UnaryOp(_, o) + | Rvalue::Cast(_, o, _) + | Rvalue::Use(o) => for_operand(o, &mut f), + Rvalue::CopyForDeref(p) + | Rvalue::Discriminant(p) + | Rvalue::Len(p) + | Rvalue::Ref(_, p) => f(p), + Rvalue::CheckedBinaryOp(_, o1, o2) => { + for_operand(o1, &mut f); + for_operand(o2, &mut f); + } + Rvalue::Aggregate(_, ops) => { + for op in ops { + for_operand(op, &mut f); + } + } + } + } + StatementKind::Deinit(p) => f(p), + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Nop => (), + } + } + match &mut block.terminator { + Some(x) => match x { + Terminator::SwitchInt { discr, .. } => for_operand(discr, &mut f), + Terminator::FalseEdge { .. } + | Terminator::FalseUnwind { .. } + | Terminator::Goto { .. } + | Terminator::Resume + | Terminator::GeneratorDrop + | Terminator::Abort + | Terminator::Return + | Terminator::Unreachable => (), + Terminator::Drop { place, .. } => { + f(place); + } + Terminator::DropAndReplace { place, value, .. } => { + f(place); + for_operand(value, &mut f); + } + Terminator::Call { func, args, destination, .. } => { + for_operand(func, &mut f); + args.iter_mut().for_each(|x| for_operand(x, &mut f)); + f(destination); + } + Terminator::Assert { cond, .. } => { + for_operand(cond, &mut f); + } + Terminator::Yield { value, resume_arg, .. } => { + for_operand(value, &mut f); + f(resume_arg); + } + }, + None => (), + } + } + } } fn const_as_usize(c: &Const) -> usize { diff --git a/crates/hir-ty/src/mir/borrowck.rs b/crates/hir-ty/src/mir/borrowck.rs index c8729af86a9ea..5b2bca955f593 100644 --- a/crates/hir-ty/src/mir/borrowck.rs +++ b/crates/hir-ty/src/mir/borrowck.rs @@ -3,13 +3,13 @@ // Currently it is an ad-hoc implementation, only useful for mutability analysis. Feel free to remove all of these // if needed for implementing a proper borrow checker. -use std::sync::Arc; +use std::{iter, sync::Arc}; use hir_def::DefWithBodyId; use la_arena::ArenaMap; use stdx::never; -use crate::db::HirDatabase; +use crate::{db::HirDatabase, ClosureId}; use super::{ BasicBlockId, BorrowKind, LocalId, MirBody, MirLowerError, MirSpan, Place, ProjectionElem, @@ -29,14 +29,48 @@ pub struct BorrowckResult { pub mutability_of_locals: ArenaMap, } +fn all_mir_bodies( + db: &dyn HirDatabase, + def: DefWithBodyId, +) -> Box, MirLowerError>> + '_> { + fn for_closure( + db: &dyn HirDatabase, + c: ClosureId, + ) -> Box, MirLowerError>> + '_> { + match db.mir_body_for_closure(c) { + Ok(body) => { + let closures = body.closures.clone(); + Box::new( + iter::once(Ok(body)) + .chain(closures.into_iter().flat_map(|x| for_closure(db, x))), + ) + } + Err(e) => Box::new(iter::once(Err(e))), + } + } + match db.mir_body(def) { + Ok(body) => { + let closures = body.closures.clone(); + Box::new( + iter::once(Ok(body)).chain(closures.into_iter().flat_map(|x| for_closure(db, x))), + ) + } + Err(e) => Box::new(iter::once(Err(e))), + } +} + pub fn borrowck_query( db: &dyn HirDatabase, def: DefWithBodyId, -) -> Result, MirLowerError> { +) -> Result, MirLowerError> { let _p = profile::span("borrowck_query"); - let body = db.mir_body(def)?; - let r = BorrowckResult { mutability_of_locals: mutability_of_locals(&body), mir_body: body }; - Ok(Arc::new(r)) + let r = all_mir_bodies(db, def) + .map(|body| { + let body = body?; + Ok(BorrowckResult { mutability_of_locals: mutability_of_locals(&body), mir_body: body }) + }) + .collect::, MirLowerError>>()?; + Ok(r.into()) } fn is_place_direct(lvalue: &Place) -> bool { @@ -60,7 +94,7 @@ fn place_case(lvalue: &Place) -> ProjectionCase { ProjectionElem::ConstantIndex { .. } | ProjectionElem::Subslice { .. } | ProjectionElem::Field(_) - | ProjectionElem::TupleField(_) + | ProjectionElem::TupleOrClosureField(_) | ProjectionElem::Index(_) => { is_part_of = true; } diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs index 84b59b5eeb91a..8c911b7f7b778 100644 --- a/crates/hir-ty/src/mir/eval.rs +++ b/crates/hir-ty/src/mir/eval.rs @@ -25,8 +25,8 @@ use crate::{ mapping::from_chalk, method_resolution::{is_dyn_method, lookup_impl_method}, traits::FnTrait, - CallableDefId, Const, ConstScalar, FnDefId, GenericArgData, Interner, MemoryMap, Substitution, - TraitEnvironment, Ty, TyBuilder, TyExt, + CallableDefId, ClosureId, Const, ConstScalar, FnDefId, GenericArgData, Interner, MemoryMap, + Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, }; use super::{ @@ -92,6 +92,7 @@ pub struct Evaluator<'a> { enum Address { Stack(usize), Heap(usize), + Invalid(usize), } use Address::*; @@ -169,8 +170,10 @@ impl Address { fn from_usize(x: usize) -> Self { if x > usize::MAX / 2 { Stack(x - usize::MAX / 2) + } else if x > usize::MAX / 4 { + Heap(x - usize::MAX / 4) } else { - Heap(x) + Invalid(x) } } @@ -181,7 +184,8 @@ impl Address { fn to_usize(&self) -> usize { let as_num = match self { Stack(x) => *x + usize::MAX / 2, - Heap(x) => *x, + Heap(x) => *x + usize::MAX / 4, + Invalid(x) => *x, }; as_num } @@ -190,6 +194,7 @@ impl Address { match self { Stack(x) => Stack(f(*x)), Heap(x) => Heap(f(*x)), + Invalid(x) => Invalid(f(*x)), } } @@ -209,6 +214,7 @@ pub enum MirEvalError { UndefinedBehavior(&'static str), Panic(String), MirLowerError(FunctionId, MirLowerError), + MirLowerErrorForClosure(ClosureId, MirLowerError), TypeIsUnsized(Ty, &'static str), NotSupported(String), InvalidConst(Const), @@ -238,6 +244,9 @@ impl std::fmt::Debug for MirEvalError { Self::MirLowerError(arg0, arg1) => { f.debug_tuple("MirLowerError").field(arg0).field(arg1).finish() } + Self::MirLowerErrorForClosure(arg0, arg1) => { + f.debug_tuple("MirLowerError").field(arg0).field(arg1).finish() + } Self::InvalidVTableId(arg0) => f.debug_tuple("InvalidVTableId").field(arg0).finish(), Self::NotSupported(arg0) => f.debug_tuple("NotSupported").field(arg0).finish(), Self::InvalidConst(arg0) => { @@ -355,16 +364,15 @@ impl Evaluator<'_> { self.ty_filler(&locals.body.locals[p.local].ty, locals.subst, locals.body.owner)?; let mut metadata = None; // locals are always sized for proj in &p.projection { + let prev_ty = ty.clone(); + ty = proj.projected_ty(ty, self.db, |c, f| { + let (def, _) = self.db.lookup_intern_closure(c.into()); + let infer = self.db.infer(def); + let (captures, _) = infer.closure_info(&c); + captures.get(f).expect("broken closure field").ty.clone() + }); match proj { ProjectionElem::Deref => { - ty = match &ty.data(Interner).kind { - TyKind::Raw(_, inner) | TyKind::Ref(_, _, inner) => inner.clone(), - _ => { - return Err(MirEvalError::TypeError( - "Overloaded deref in MIR is disallowed", - )) - } - }; metadata = if self.size_of(&ty, locals)?.is_none() { Some(Interval { addr: addr.offset(self.ptr_size()), size: self.ptr_size() }) } else { @@ -377,78 +385,41 @@ impl Evaluator<'_> { let offset = from_bytes!(usize, self.read_memory(locals.ptr[*op], self.ptr_size())?); metadata = None; // Result of index is always sized - match &ty.data(Interner).kind { - TyKind::Ref(_, _, inner) => match &inner.data(Interner).kind { - TyKind::Slice(inner) => { - ty = inner.clone(); - let ty_size = self.size_of_sized( - &ty, - locals, - "slice inner type should be sized", - )?; - let value = self.read_memory(addr, self.ptr_size() * 2)?; - addr = Address::from_bytes(&value[0..8])?.offset(ty_size * offset); - } - x => not_supported!("MIR index for ref type {x:?}"), - }, - TyKind::Array(inner, _) | TyKind::Slice(inner) => { - ty = inner.clone(); - let ty_size = self.size_of_sized( - &ty, - locals, - "array inner type should be sized", - )?; - addr = addr.offset(ty_size * offset); + let ty_size = + self.size_of_sized(&ty, locals, "array inner type should be sized")?; + addr = addr.offset(ty_size * offset); + } + &ProjectionElem::TupleOrClosureField(f) => { + let layout = self.layout(&prev_ty)?; + let offset = layout.fields.offset(f).bytes_usize(); + addr = addr.offset(offset); + metadata = None; // tuple field is always sized + } + ProjectionElem::Field(f) => { + let layout = self.layout(&prev_ty)?; + let variant_layout = match &layout.variants { + Variants::Single { .. } => &layout, + Variants::Multiple { variants, .. } => { + &variants[match f.parent { + hir_def::VariantId::EnumVariantId(x) => { + RustcEnumVariantIdx(x.local_id) + } + _ => { + return Err(MirEvalError::TypeError( + "Multivariant layout only happens for enums", + )) + } + }] } - x => not_supported!("MIR index for type {x:?}"), - } + }; + let offset = variant_layout + .fields + .offset(u32::from(f.local_id.into_raw()) as usize) + .bytes_usize(); + addr = addr.offset(offset); + // FIXME: support structs with unsized fields + metadata = None; } - &ProjectionElem::TupleField(f) => match &ty.data(Interner).kind { - TyKind::Tuple(_, subst) => { - let layout = self.layout(&ty)?; - ty = subst - .as_slice(Interner) - .get(f) - .ok_or(MirEvalError::TypeError("not enough tuple fields"))? - .assert_ty_ref(Interner) - .clone(); - let offset = layout.fields.offset(f).bytes_usize(); - addr = addr.offset(offset); - metadata = None; // tuple field is always sized - } - _ => return Err(MirEvalError::TypeError("Only tuple has tuple fields")), - }, - ProjectionElem::Field(f) => match &ty.data(Interner).kind { - TyKind::Adt(adt, subst) => { - let layout = self.layout_adt(adt.0, subst.clone())?; - let variant_layout = match &layout.variants { - Variants::Single { .. } => &layout, - Variants::Multiple { variants, .. } => { - &variants[match f.parent { - hir_def::VariantId::EnumVariantId(x) => { - RustcEnumVariantIdx(x.local_id) - } - _ => { - return Err(MirEvalError::TypeError( - "Multivariant layout only happens for enums", - )) - } - }] - } - }; - ty = self.db.field_types(f.parent)[f.local_id] - .clone() - .substitute(Interner, subst); - let offset = variant_layout - .fields - .offset(u32::from(f.local_id.into_raw()) as usize) - .bytes_usize(); - addr = addr.offset(offset); - // FIXME: support structs with unsized fields - metadata = None; - } - _ => return Err(MirEvalError::TypeError("Only adt has fields")), - }, ProjectionElem::ConstantIndex { .. } => { not_supported!("constant index") } @@ -845,6 +816,15 @@ impl Evaluator<'_> { values.iter().copied(), )?) } + AggregateKind::Closure(ty) => { + let layout = self.layout(&ty)?; + Owned(self.make_by_layout( + layout.size.bytes_usize(), + &layout, + None, + values.iter().copied(), + )?) + } } } Rvalue::Cast(kind, operand, target_ty) => match kind { @@ -1065,6 +1045,9 @@ impl Evaluator<'_> { let (mem, pos) = match addr { Stack(x) => (&self.stack, x), Heap(x) => (&self.heap, x), + Invalid(_) => { + return Err(MirEvalError::UndefinedBehavior("read invalid memory address")) + } }; mem.get(pos..pos + size).ok_or(MirEvalError::UndefinedBehavior("out of bound memory read")) } @@ -1073,6 +1056,9 @@ impl Evaluator<'_> { let (mem, pos) = match addr { Stack(x) => (&mut self.stack, x), Heap(x) => (&mut self.heap, x), + Invalid(_) => { + return Err(MirEvalError::UndefinedBehavior("write invalid memory address")) + } }; mem.get_mut(pos..pos + r.len()) .ok_or(MirEvalError::UndefinedBehavior("out of bound memory write"))? @@ -1394,6 +1380,25 @@ impl Evaluator<'_> { Ok(()) } + fn exec_closure( + &mut self, + closure: ClosureId, + closure_data: Interval, + generic_args: &Substitution, + destination: Interval, + args: &[IntervalAndTy], + ) -> Result<()> { + let mir_body = self + .db + .mir_body_for_closure(closure) + .map_err(|x| MirEvalError::MirLowerErrorForClosure(closure, x))?; + let arg_bytes = iter::once(Ok(closure_data.get(self)?.to_owned())) + .chain(args.iter().map(|x| Ok(x.get(&self)?.to_owned()))) + .collect::>>()?; + let bytes = self.interpret_mir(&mir_body, arg_bytes.into_iter(), generic_args.clone())?; + destination.write_from_bytes(self, &bytes) + } + fn exec_fn_def( &mut self, def: FnDefId, @@ -1546,6 +1551,9 @@ impl Evaluator<'_> { TyKind::Function(_) => { self.exec_fn_pointer(func_data, destination, &args[1..], locals)?; } + TyKind::Closure(closure, subst) => { + self.exec_closure(*closure, func_data, subst, destination, &args[1..])?; + } x => not_supported!("Call FnTrait methods with type {x:?}"), } Ok(()) diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 7f3fdf343a2d8..78a2d90f7fb01 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -21,9 +21,16 @@ use la_arena::ArenaMap; use rustc_hash::FxHashMap; use crate::{ - consteval::ConstEvalError, db::HirDatabase, display::HirDisplay, infer::TypeMismatch, - inhabitedness::is_ty_uninhabited_from, layout::layout_of_ty, mapping::ToChalk, static_lifetime, - utils::generics, Adjust, Adjustment, AutoBorrow, CallableDefId, TyBuilder, TyExt, + consteval::ConstEvalError, + db::HirDatabase, + display::HirDisplay, + infer::{CaptureKind, CapturedItem, TypeMismatch}, + inhabitedness::is_ty_uninhabited_from, + layout::layout_of_ty, + mapping::ToChalk, + static_lifetime, + utils::generics, + Adjust, Adjustment, AutoBorrow, CallableDefId, TyBuilder, TyExt, }; use super::*; @@ -74,10 +81,12 @@ pub enum MirLowerError { BreakWithoutLoop, Loop, /// Something that should never happen and is definitely a bug, but we don't want to panic if it happened - ImplementationError(&'static str), + ImplementationError(String), LangItemNotFound(LangItem), MutatingRvalue, UnresolvedLabel, + UnresolvedUpvar(Place), + UnaccessableLocal, } macro_rules! not_supported { @@ -88,8 +97,8 @@ macro_rules! not_supported { macro_rules! implementation_error { ($x: expr) => {{ - ::stdx::never!("MIR lower implementation bug: {}", $x); - return Err(MirLowerError::ImplementationError($x)); + ::stdx::never!("MIR lower implementation bug: {}", format!($x)); + return Err(MirLowerError::ImplementationError(format!($x))); }}; } @@ -116,7 +125,44 @@ impl MirLowerError { type Result = std::result::Result; -impl MirLowerCtx<'_> { +impl<'ctx> MirLowerCtx<'ctx> { + fn new( + db: &'ctx dyn HirDatabase, + owner: DefWithBodyId, + body: &'ctx Body, + infer: &'ctx InferenceResult, + ) -> Self { + let mut basic_blocks = Arena::new(); + let start_block = basic_blocks.alloc(BasicBlock { + statements: vec![], + terminator: None, + is_cleanup: false, + }); + let locals = Arena::new(); + let binding_locals: ArenaMap = ArenaMap::new(); + let mir = MirBody { + basic_blocks, + locals, + start_block, + binding_locals, + param_locals: vec![], + owner, + arg_count: body.params.len(), + closures: vec![], + }; + let ctx = MirLowerCtx { + result: mir, + db, + infer, + body, + owner, + current_loop_blocks: None, + labeled_loop_blocks: Default::default(), + discr_temp: None, + }; + ctx + } + fn temp(&mut self, ty: Ty) -> Result { if matches!(ty.kind(Interner), TyKind::Slice(_) | TyKind::Dyn(_)) { implementation_error!("unsized temporaries"); @@ -268,7 +314,7 @@ impl MirLowerCtx<'_> { self.push_assignment( current, place, - Operand::Copy(self.result.binding_locals[pat_id].into()).into(), + Operand::Copy(self.binding_local(pat_id)?.into()).into(), expr_id.into(), ); Ok(Some(current)) @@ -823,7 +869,51 @@ impl MirLowerCtx<'_> { ); Ok(Some(current)) }, - Expr::Closure { .. } => not_supported!("closure"), + Expr::Closure { .. } => { + let ty = self.expr_ty(expr_id); + let TyKind::Closure(id, _) = ty.kind(Interner) else { + not_supported!("closure with non closure type"); + }; + self.result.closures.push(*id); + let (captures, _) = self.infer.closure_info(id); + let mut operands = vec![]; + for capture in captures.iter() { + let p = Place { + local: self.binding_local(capture.place.local)?, + projection: capture.place.projections.clone().into_iter().map(|x| { + match x { + ProjectionElem::Deref => ProjectionElem::Deref, + ProjectionElem::Field(x) => ProjectionElem::Field(x), + ProjectionElem::TupleOrClosureField(x) => ProjectionElem::TupleOrClosureField(x), + ProjectionElem::ConstantIndex { offset, min_length, from_end } => ProjectionElem::ConstantIndex { offset, min_length, from_end }, + ProjectionElem::Subslice { from, to, from_end } => ProjectionElem::Subslice { from, to, from_end }, + ProjectionElem::OpaqueCast(x) => ProjectionElem::OpaqueCast(x), + ProjectionElem::Index(x) => match x { }, + } + }).collect(), + }; + match &capture.kind { + CaptureKind::ByRef(bk) => { + let tmp: Place = self.temp(capture.ty.clone())?.into(); + self.push_assignment( + current, + tmp.clone(), + Rvalue::Ref(bk.clone(), p), + expr_id.into(), + ); + operands.push(Operand::Move(tmp)); + }, + CaptureKind::ByValue => operands.push(Operand::Move(p)), + } + } + self.push_assignment( + current, + place, + Rvalue::Aggregate(AggregateKind::Closure(ty), operands), + expr_id.into(), + ); + Ok(Some(current)) + }, Expr::Tuple { exprs, is_assignee_expr: _ } => { let Some(values) = exprs .iter() @@ -893,7 +983,7 @@ impl MirLowerCtx<'_> { let index = name .as_tuple_index() .ok_or(MirLowerError::TypeError("named field on tuple"))?; - place.projection.push(ProjectionElem::TupleField(index)) + place.projection.push(ProjectionElem::TupleOrClosureField(index)) } else { let field = self.infer.field_resolution(expr_id).ok_or(MirLowerError::UnresolvedField)?; @@ -1126,8 +1216,9 @@ impl MirLowerCtx<'_> { }; self.set_goto(prev_block, begin); f(self, begin)?; - let my = mem::replace(&mut self.current_loop_blocks, prev) - .ok_or(MirLowerError::ImplementationError("current_loop_blocks is corrupt"))?; + let my = mem::replace(&mut self.current_loop_blocks, prev).ok_or( + MirLowerError::ImplementationError("current_loop_blocks is corrupt".to_string()), + )?; if let Some(prev) = prev_label { self.labeled_loop_blocks.insert(label.unwrap(), prev); } @@ -1159,7 +1250,9 @@ impl MirLowerCtx<'_> { let r = match self .current_loop_blocks .as_mut() - .ok_or(MirLowerError::ImplementationError("Current loop access out of loop"))? + .ok_or(MirLowerError::ImplementationError( + "Current loop access out of loop".to_string(), + ))? .end { Some(x) => x, @@ -1167,7 +1260,9 @@ impl MirLowerCtx<'_> { let s = self.new_basic_block(); self.current_loop_blocks .as_mut() - .ok_or(MirLowerError::ImplementationError("Current loop access out of loop"))? + .ok_or(MirLowerError::ImplementationError( + "Current loop access out of loop".to_string(), + ))? .end = Some(s); s } @@ -1181,7 +1276,7 @@ impl MirLowerCtx<'_> { /// This function push `StorageLive` statement for the binding, and applies changes to add `StorageDead` in /// the appropriated places. - fn push_storage_live(&mut self, b: BindingId, current: BasicBlockId) { + fn push_storage_live(&mut self, b: BindingId, current: BasicBlockId) -> Result<()> { // Current implementation is wrong. It adds no `StorageDead` at the end of scope, and before each break // and continue. It just add a `StorageDead` before the `StorageLive`, which is not wrong, but unneeeded in // the proper implementation. Due this limitation, implementing a borrow checker on top of this mir will falsely @@ -1206,9 +1301,10 @@ impl MirLowerCtx<'_> { .copied() .map(MirSpan::PatId) .unwrap_or(MirSpan::Unknown); - let l = self.result.binding_locals[b]; + let l = self.binding_local(b)?; self.push_statement(current, StatementKind::StorageDead(l).with_span(span)); self.push_statement(current, StatementKind::StorageLive(l).with_span(span)); + Ok(()) } fn resolve_lang_item(&self, item: LangItem) -> Result { @@ -1256,9 +1352,15 @@ impl MirLowerCtx<'_> { } } } else { + let mut err = None; self.body.walk_bindings_in_pat(*pat, |b| { - self.push_storage_live(b, current); + if let Err(e) = self.push_storage_live(b, current) { + err = Some(e); + } }); + if let Some(e) = err { + return Err(e); + } } } hir_def::hir::Statement::Expr { expr, has_semi: _ } => { @@ -1274,6 +1376,67 @@ impl MirLowerCtx<'_> { None => Ok(Some(current)), } } + + fn lower_params_and_bindings( + &mut self, + params: impl Iterator + Clone, + pick_binding: impl Fn(BindingId) -> bool, + ) -> Result { + let base_param_count = self.result.param_locals.len(); + self.result.param_locals.extend(params.clone().map(|(x, ty)| { + let local_id = self.result.locals.alloc(Local { ty }); + if let Pat::Bind { id, subpat: None } = self.body[x] { + if matches!( + self.body.bindings[id].mode, + BindingAnnotation::Unannotated | BindingAnnotation::Mutable + ) { + self.result.binding_locals.insert(id, local_id); + } + } + local_id + })); + // and then rest of bindings + for (id, _) in self.body.bindings.iter() { + if !pick_binding(id) { + continue; + } + if !self.result.binding_locals.contains_idx(id) { + self.result + .binding_locals + .insert(id, self.result.locals.alloc(Local { ty: self.infer[id].clone() })); + } + } + let mut current = self.result.start_block; + for ((param, _), local) in + params.zip(self.result.param_locals.clone().into_iter().skip(base_param_count)) + { + if let Pat::Bind { id, .. } = self.body[param] { + if local == self.binding_local(id)? { + continue; + } + } + let r = self.pattern_match( + current, + None, + local.into(), + self.result.locals[local].ty.clone(), + param, + BindingAnnotation::Unannotated, + )?; + if let Some(b) = r.1 { + self.set_terminator(b, Terminator::Unreachable); + } + current = r.0; + } + Ok(current) + } + + fn binding_local(&self, b: BindingId) -> Result { + match self.result.binding_locals.get(b) { + Some(x) => Ok(*x), + None => Err(MirLowerError::UnaccessableLocal), + } + } } fn cast_kind(source_ty: &Ty, target_ty: &Ty) -> Result { @@ -1297,6 +1460,87 @@ fn cast_kind(source_ty: &Ty, target_ty: &Ty) -> Result { }) } +pub fn mir_body_for_closure_query( + db: &dyn HirDatabase, + closure: ClosureId, +) -> Result> { + let (owner, expr) = db.lookup_intern_closure(closure.into()); + let body = db.body(owner); + let infer = db.infer(owner); + let Expr::Closure { args, body: root, .. } = &body[expr] else { + implementation_error!("closure expression is not closure"); + }; + let TyKind::Closure(_, substs) = &infer[expr].kind(Interner) else { + implementation_error!("closure expression is not closure"); + }; + let (captures, _) = infer.closure_info(&closure); + let mut ctx = MirLowerCtx::new(db, owner, &body, &infer); + ctx.result.arg_count = args.len() + 1; + // 0 is return local + ctx.result.locals.alloc(Local { ty: infer[*root].clone() }); + ctx.result.locals.alloc(Local { ty: infer[expr].clone() }); + let Some(sig) = substs.at(Interner, 0).assert_ty_ref(Interner).callable_sig(db) else { + implementation_error!("closure has not callable sig"); + }; + let current = ctx.lower_params_and_bindings( + args.iter().zip(sig.params().iter()).map(|(x, y)| (*x, y.clone())), + |_| true, + )?; + if let Some(b) = ctx.lower_expr_to_place(*root, return_slot().into(), current)? { + ctx.set_terminator(b, Terminator::Return); + } + let mut upvar_map: FxHashMap> = FxHashMap::default(); + for (i, capture) in captures.iter().enumerate() { + let local = ctx.binding_local(capture.place.local)?; + upvar_map.entry(local).or_default().push((capture, i)); + } + let mut err = None; + let closure_local = ctx.result.locals.iter().nth(1).unwrap().0; + ctx.result.walk_places(|p| { + if let Some(x) = upvar_map.get(&p.local) { + let r = x.iter().find(|x| { + if p.projection.len() < x.0.place.projections.len() { + return false; + } + for (x, y) in p.projection.iter().zip(x.0.place.projections.iter()) { + match (x, y) { + (ProjectionElem::Deref, ProjectionElem::Deref) => (), + (ProjectionElem::Field(x), ProjectionElem::Field(y)) if x == y => (), + ( + ProjectionElem::TupleOrClosureField(x), + ProjectionElem::TupleOrClosureField(y), + ) if x == y => (), + _ => return false, + } + } + true + }); + match r { + Some(x) => { + p.local = closure_local; + let prev_projs = + mem::replace(&mut p.projection, vec![PlaceElem::TupleOrClosureField(x.1)]); + if x.0.kind != CaptureKind::ByValue { + p.projection.push(ProjectionElem::Deref); + } + p.projection.extend(prev_projs.into_iter().skip(x.0.place.projections.len())); + } + None => err = Some(p.clone()), + } + } + }); + ctx.result.binding_locals = ctx + .result + .binding_locals + .into_iter() + .filter(|x| ctx.body[x.0].owner == Some(expr)) + .collect(); + if let Some(err) = err { + return Err(MirLowerError::UnresolvedUpvar(err)); + } + Ok(Arc::new(ctx.result)) +} + pub fn mir_body_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Result> { let _p = profile::span("mir_body_query").detail(|| match def { DefWithBodyId::FunctionId(it) => db.function_data(it).name.to_string(), @@ -1334,86 +1578,29 @@ pub fn lower_to_mir( if let Some((_, x)) = infer.type_mismatches().next() { return Err(MirLowerError::TypeMismatch(x.clone())); } - let mut basic_blocks = Arena::new(); - let start_block = - basic_blocks.alloc(BasicBlock { statements: vec![], terminator: None, is_cleanup: false }); - let mut locals = Arena::new(); + let mut ctx = MirLowerCtx::new(db, owner, body, infer); // 0 is return local - locals.alloc(Local { ty: infer[root_expr].clone() }); - let mut binding_locals: ArenaMap = ArenaMap::new(); + ctx.result.locals.alloc(Local { ty: infer[root_expr].clone() }); + let binding_picker = |b: BindingId| { + if root_expr == body.body_expr { + body[b].owner.is_none() + } else { + body[b].owner == Some(root_expr) + } + }; // 1 to param_len is for params - let param_locals: Vec = if let DefWithBodyId::FunctionId(fid) = owner { + let current = if let DefWithBodyId::FunctionId(fid) = owner { let substs = TyBuilder::placeholder_subst(db, fid); let callable_sig = db.callable_item_signature(fid.into()).substitute(Interner, &substs); - body.params - .iter() - .zip(callable_sig.params().iter()) - .map(|(&x, ty)| { - let local_id = locals.alloc(Local { ty: ty.clone() }); - if let Pat::Bind { id, subpat: None } = body[x] { - if matches!( - body.bindings[id].mode, - BindingAnnotation::Unannotated | BindingAnnotation::Mutable - ) { - binding_locals.insert(id, local_id); - } - } - local_id - }) - .collect() + ctx.lower_params_and_bindings( + body.params.iter().zip(callable_sig.params().iter()).map(|(x, y)| (*x, y.clone())), + binding_picker, + )? } else { - if !body.params.is_empty() { - return Err(MirLowerError::TypeError("Unexpected parameter for non function body")); - } - vec![] + ctx.lower_params_and_bindings([].into_iter(), binding_picker)? }; - // and then rest of bindings - for (id, _) in body.bindings.iter() { - if !binding_locals.contains_idx(id) { - binding_locals.insert(id, locals.alloc(Local { ty: infer[id].clone() })); - } - } - let mir = MirBody { - basic_blocks, - locals, - start_block, - binding_locals, - param_locals, - owner, - arg_count: body.params.len(), - }; - let mut ctx = MirLowerCtx { - result: mir, - db, - infer, - body, - owner, - current_loop_blocks: None, - labeled_loop_blocks: Default::default(), - discr_temp: None, - }; - let mut current = start_block; - for (¶m, local) in body.params.iter().zip(ctx.result.param_locals.clone().into_iter()) { - if let Pat::Bind { id, .. } = body[param] { - if local == ctx.result.binding_locals[id] { - continue; - } - } - let r = ctx.pattern_match( - current, - None, - local.into(), - ctx.result.locals[local].ty.clone(), - param, - BindingAnnotation::Unannotated, - )?; - if let Some(b) = r.1 { - ctx.set_terminator(b, Terminator::Unreachable); - } - current = r.0; - } if let Some(b) = ctx.lower_expr_to_place(root_expr, return_slot().into(), current)? { - ctx.result.basic_blocks[b].terminator = Some(Terminator::Return); + ctx.set_terminator(b, Terminator::Return); } Ok(ctx.result) } diff --git a/crates/hir-ty/src/mir/lower/pattern_matching.rs b/crates/hir-ty/src/mir/lower/pattern_matching.rs index c3ced82aab7bd..12a77715e9c67 100644 --- a/crates/hir-ty/src/mir/lower/pattern_matching.rs +++ b/crates/hir-ty/src/mir/lower/pattern_matching.rs @@ -1,5 +1,7 @@ //! MIR lowering for patterns +use crate::utils::pattern_matching_dereference_count; + use super::*; macro_rules! not_supported { @@ -52,7 +54,7 @@ impl MirLowerCtx<'_> { args, *ellipsis, subst.iter(Interner).enumerate().map(|(i, x)| { - (PlaceElem::TupleField(i), x.assert_ty_ref(Interner).clone()) + (PlaceElem::TupleOrClosureField(i), x.assert_ty_ref(Interner).clone()) }), &cond_place, binding_mode, @@ -142,7 +144,7 @@ impl MirLowerCtx<'_> { if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) { binding_mode = mode; } - self.push_storage_live(*id, current); + self.push_storage_live(*id, current)?; self.push_assignment( current, target_place.into(), @@ -387,13 +389,6 @@ fn pattern_matching_dereference( binding_mode: &mut BindingAnnotation, cond_place: &mut Place, ) { - while let Some((ty, _, mu)) = cond_ty.as_reference() { - if mu == Mutability::Mut && *binding_mode != BindingAnnotation::Ref { - *binding_mode = BindingAnnotation::RefMut; - } else { - *binding_mode = BindingAnnotation::Ref; - } - *cond_ty = ty.clone(); - cond_place.projection.push(ProjectionElem::Deref); - } + let cnt = pattern_matching_dereference_count(cond_ty, binding_mode); + cond_place.projection.extend((0..cnt).map(|_| ProjectionElem::Deref)); } diff --git a/crates/hir-ty/src/mir/pretty.rs b/crates/hir-ty/src/mir/pretty.rs index d6253b378b2c0..3e1f2ecef1b5b 100644 --- a/crates/hir-ty/src/mir/pretty.rs +++ b/crates/hir-ty/src/mir/pretty.rs @@ -1,6 +1,9 @@ //! A pretty-printer for MIR. -use std::fmt::{Debug, Display, Write}; +use std::{ + fmt::{Debug, Display, Write}, + mem, +}; use hir_def::{body::Body, hir::BindingId}; use hir_expand::name::Name; @@ -20,7 +23,7 @@ impl MirBody { pub fn pretty_print(&self, db: &dyn HirDatabase) -> String { let hir_body = db.body(self.owner); let mut ctx = MirPrettyCtx::new(self, &hir_body, db); - ctx.for_body(); + ctx.for_body(ctx.body.owner); ctx.result } @@ -42,7 +45,7 @@ struct MirPrettyCtx<'a> { hir_body: &'a Body, db: &'a dyn HirDatabase, result: String, - ident: String, + indent: String, local_to_binding: ArenaMap, } @@ -88,22 +91,43 @@ impl Display for LocalName { } impl<'a> MirPrettyCtx<'a> { - fn for_body(&mut self) { - wln!(self, "// {:?}", self.body.owner); + fn for_body(&mut self, name: impl Debug) { + wln!(self, "// {:?}", name); self.with_block(|this| { this.locals(); wln!(this); this.blocks(); }); + for &closure in &self.body.closures { + let body = match self.db.mir_body_for_closure(closure) { + Ok(x) => x, + Err(e) => { + wln!(self, "// error in {closure:?}: {e:?}"); + continue; + } + }; + let result = mem::take(&mut self.result); + let indent = mem::take(&mut self.indent); + let mut ctx = MirPrettyCtx { + body: &body, + local_to_binding: body.binding_locals.iter().map(|(x, y)| (*y, x)).collect(), + result, + indent, + ..*self + }; + ctx.for_body(closure); + self.result = ctx.result; + self.indent = ctx.indent; + } } fn with_block(&mut self, f: impl FnOnce(&mut MirPrettyCtx<'_>)) { - self.ident += " "; + self.indent += " "; wln!(self, "{{"); f(self); for _ in 0..4 { self.result.pop(); - self.ident.pop(); + self.indent.pop(); } wln!(self, "}}"); } @@ -114,7 +138,7 @@ impl<'a> MirPrettyCtx<'a> { body, db, result: String::new(), - ident: String::new(), + indent: String::new(), local_to_binding, hir_body, } @@ -122,7 +146,7 @@ impl<'a> MirPrettyCtx<'a> { fn write_line(&mut self) { self.result.push('\n'); - self.result += &self.ident; + self.result += &self.indent; } fn write(&mut self, line: &str) { @@ -247,7 +271,7 @@ impl<'a> MirPrettyCtx<'a> { } } } - ProjectionElem::TupleField(x) => { + ProjectionElem::TupleOrClosureField(x) => { f(this, local, head); w!(this, ".{}", x); } @@ -302,6 +326,11 @@ impl<'a> MirPrettyCtx<'a> { self.operand_list(x); w!(self, ")"); } + Rvalue::Aggregate(AggregateKind::Closure(_), x) => { + w!(self, "Closure("); + self.operand_list(x); + w!(self, ")"); + } Rvalue::Aggregate(AggregateKind::Union(_, _), x) => { w!(self, "Union("); self.operand_list(x); diff --git a/crates/hir-ty/src/tests/coercion.rs b/crates/hir-ty/src/tests/coercion.rs index c165a7cb161e2..16e5ef85d09d5 100644 --- a/crates/hir-ty/src/tests/coercion.rs +++ b/crates/hir-ty/src/tests/coercion.rs @@ -575,7 +575,7 @@ fn two_closures_lub() { fn foo(c: i32) { let add = |a: i32, b: i32| a + b; let sub = |a, b| a - b; - //^^^^^^^^^^^^ |i32, i32| -> i32 + //^^^^^^^^^^^^ impl Fn(i32, i32) -> i32 if c > 42 { add } else { sub }; //^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ fn(i32, i32) -> i32 } @@ -875,6 +875,16 @@ fn test() { fn adjust_index() { check_no_mismatches( r" +//- minicore: index, slice, coerce_unsized +fn test() { + let x = [1, 2, 3]; + x[2] = 6; + // ^ adjustments: Borrow(Ref(Mut)) +} + ", + ); + check_no_mismatches( + r" //- minicore: index struct Struct; impl core::ops::Index for Struct { diff --git a/crates/hir-ty/src/tests/macros.rs b/crates/hir-ty/src/tests/macros.rs index 8b75ec842a4f6..d45edf730ad5f 100644 --- a/crates/hir-ty/src/tests/macros.rs +++ b/crates/hir-ty/src/tests/macros.rs @@ -198,7 +198,7 @@ fn expr_macro_def_expanded_in_various_places() { 100..119 'for _ ...!() {}': () 104..105 '_': {unknown} 117..119 '{}': () - 124..134 '|| spam!()': || -> isize + 124..134 '|| spam!()': impl Fn() -> isize 140..156 'while ...!() {}': () 154..156 '{}': () 161..174 'break spam!()': ! @@ -279,7 +279,7 @@ fn expr_macro_rules_expanded_in_various_places() { 114..133 'for _ ...!() {}': () 118..119 '_': {unknown} 131..133 '{}': () - 138..148 '|| spam!()': || -> isize + 138..148 '|| spam!()': impl Fn() -> isize 154..170 'while ...!() {}': () 168..170 '{}': () 175..188 'break spam!()': ! diff --git a/crates/hir-ty/src/tests/patterns.rs b/crates/hir-ty/src/tests/patterns.rs index be67329fee441..c8c31bdea5cbf 100644 --- a/crates/hir-ty/src/tests/patterns.rs +++ b/crates/hir-ty/src/tests/patterns.rs @@ -70,8 +70,8 @@ fn infer_pattern() { 228..233 '&true': &bool 229..233 'true': bool 234..236 '{}': () - 246..252 'lambda': |u64, u64, i32| -> i32 - 255..287 '|a: u6...b; c }': |u64, u64, i32| -> i32 + 246..252 'lambda': impl Fn(u64, u64, i32) -> i32 + 255..287 '|a: u6...b; c }': impl Fn(u64, u64, i32) -> i32 256..257 'a': u64 264..265 'b': u64 267..268 'c': i32 @@ -677,25 +677,25 @@ fn test() { 51..58 'loop {}': ! 56..58 '{}': () 72..171 '{ ... x); }': () - 78..81 'foo': fn foo<&(i32, &str), i32, |&(i32, &str)| -> i32>(&(i32, &str), |&(i32, &str)| -> i32) -> i32 + 78..81 'foo': fn foo<&(i32, &str), i32, impl Fn(&(i32, &str)) -> i32>(&(i32, &str), impl Fn(&(i32, &str)) -> i32) -> i32 78..105 'foo(&(...y)| x)': i32 82..91 '&(1, "a")': &(i32, &str) 83..91 '(1, "a")': (i32, &str) 84..85 '1': i32 87..90 '"a"': &str - 93..104 '|&(x, y)| x': |&(i32, &str)| -> i32 + 93..104 '|&(x, y)| x': impl Fn(&(i32, &str)) -> i32 94..101 '&(x, y)': &(i32, &str) 95..101 '(x, y)': (i32, &str) 96..97 'x': i32 99..100 'y': &str 103..104 'x': i32 - 142..145 'foo': fn foo<&(i32, &str), &i32, |&(i32, &str)| -> &i32>(&(i32, &str), |&(i32, &str)| -> &i32) -> &i32 + 142..145 'foo': fn foo<&(i32, &str), &i32, impl Fn(&(i32, &str)) -> &i32>(&(i32, &str), impl Fn(&(i32, &str)) -> &i32) -> &i32 142..168 'foo(&(...y)| x)': &i32 146..155 '&(1, "a")': &(i32, &str) 147..155 '(1, "a")': (i32, &str) 148..149 '1': i32 151..154 '"a"': &str - 157..167 '|(x, y)| x': |&(i32, &str)| -> &i32 + 157..167 '|(x, y)| x': impl Fn(&(i32, &str)) -> &i32 158..164 '(x, y)': (i32, &str) 159..160 'x': &i32 162..163 'y': &&str diff --git a/crates/hir-ty/src/tests/regression.rs b/crates/hir-ty/src/tests/regression.rs index 74c7e3efc058f..d78d6eba7fb37 100644 --- a/crates/hir-ty/src/tests/regression.rs +++ b/crates/hir-ty/src/tests/regression.rs @@ -805,19 +805,19 @@ fn issue_4966() { 225..229 'iter': T 244..246 '{}': Vec 258..402 '{ ...r(); }': () - 268..273 'inner': Map<|&f64| -> f64> - 276..300 'Map { ... 0.0 }': Map<|&f64| -> f64> - 285..298 '|_: &f64| 0.0': |&f64| -> f64 + 268..273 'inner': Map f64> + 276..300 'Map { ... 0.0 }': Map f64> + 285..298 '|_: &f64| 0.0': impl Fn(&f64) -> f64 286..287 '_': &f64 295..298 '0.0': f64 - 311..317 'repeat': Repeat f64>> - 320..345 'Repeat...nner }': Repeat f64>> - 338..343 'inner': Map<|&f64| -> f64> - 356..359 'vec': Vec f64>>>> - 362..371 'from_iter': fn from_iter f64>>>, Repeat f64>>>(Repeat f64>>) -> Vec f64>>>> - 362..379 'from_i...epeat)': Vec f64>>>> - 372..378 'repeat': Repeat f64>> - 386..389 'vec': Vec f64>>>> + 311..317 'repeat': Repeat f64>> + 320..345 'Repeat...nner }': Repeat f64>> + 338..343 'inner': Map f64> + 356..359 'vec': Vec f64>>>> + 362..371 'from_iter': fn from_iter f64>>>, Repeat f64>>>(Repeat f64>>) -> Vec f64>>>> + 362..379 'from_i...epeat)': Vec f64>>>> + 372..378 'repeat': Repeat f64>> + 386..389 'vec': Vec f64>>>> 386..399 'vec.foo_bar()': {unknown} "#]], ); @@ -852,7 +852,7 @@ fn main() { 123..126 'S()': S 132..133 's': S 132..144 's.g(|_x| {})': () - 136..143 '|_x| {}': |&i32| -> () + 136..143 '|_x| {}': impl Fn(&i32) 137..139 '_x': &i32 141..143 '{}': () 150..151 's': S @@ -1759,13 +1759,14 @@ const C: usize = 2 + 2; #[test] fn regression_14456() { - check_no_mismatches( + check_types( r#" //- minicore: future async fn x() {} fn f() { let fut = x(); - let t = [0u8; 2 + 2]; + let t = [0u8; { let a = 2 + 2; a }]; + //^ [u8; 4] } "#, ); diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs index 0c037a39ec1f5..2b14a28a67221 100644 --- a/crates/hir-ty/src/tests/simple.rs +++ b/crates/hir-ty/src/tests/simple.rs @@ -1906,8 +1906,8 @@ fn closure_return() { "#, expect![[r#" 16..58 '{ ...; }; }': u32 - 26..27 'x': || -> usize - 30..55 '|| -> ...n 1; }': || -> usize + 26..27 'x': impl Fn() -> usize + 30..55 '|| -> ...n 1; }': impl Fn() -> usize 42..55 '{ return 1; }': usize 44..52 'return 1': ! 51..52 '1': usize @@ -1925,8 +1925,8 @@ fn closure_return_unit() { "#, expect![[r#" 16..47 '{ ...; }; }': u32 - 26..27 'x': || -> () - 30..44 '|| { return; }': || -> () + 26..27 'x': impl Fn() + 30..44 '|| { return; }': impl Fn() 33..44 '{ return; }': () 35..41 'return': ! "#]], @@ -1943,8 +1943,8 @@ fn closure_return_inferred() { "#, expect![[r#" 16..46 '{ ..." }; }': u32 - 26..27 'x': || -> &str - 30..43 '|| { "test" }': || -> &str + 26..27 'x': impl Fn() -> &str + 30..43 '|| { "test" }': impl Fn() -> &str 33..43 '{ "test" }': &str 35..41 '"test"': &str "#]], @@ -2050,7 +2050,7 @@ fn fn_pointer_return() { 47..120 '{ ...hod; }': () 57..63 'vtable': Vtable 66..90 'Vtable...| {} }': Vtable - 83..88 '|| {}': || -> () + 83..88 '|| {}': impl Fn() 86..88 '{}': () 100..101 'm': fn() 104..110 'vtable': Vtable @@ -2142,9 +2142,9 @@ fn main() { 149..151 'Ok': Ok<(), ()>(()) -> Result<(), ()> 149..155 'Ok(())': Result<(), ()> 152..154 '()': () - 167..171 'test': fn test<(), (), || -> impl Future>, impl Future>>(|| -> impl Future>) + 167..171 'test': fn test<(), (), impl Fn() -> impl Future>, impl Future>>(impl Fn() -> impl Future>) 167..228 'test(|... })': () - 172..227 '|| asy... }': || -> impl Future> + 172..227 '|| asy... }': impl Fn() -> impl Future> 175..227 'async ... }': impl Future> 191..205 'return Err(())': ! 198..201 'Err': Err<(), ()>(()) -> Result<(), ()> @@ -2270,8 +2270,8 @@ fn infer_labelled_break_with_val() { "#, expect![[r#" 9..335 '{ ... }; }': () - 19..21 '_x': || -> bool - 24..332 '|| 'ou... }': || -> bool + 19..21 '_x': impl Fn() -> bool + 24..332 '|| 'ou... }': impl Fn() -> bool 27..332 ''outer... }': bool 40..332 '{ ... }': () 54..59 'inner': i8 @@ -2695,6 +2695,179 @@ impl B for Astruct {} ) } +#[test] +fn capture_kinds_simple() { + check_types( + r#" +struct S; + +impl S { + fn read(&self) -> &S { self } + fn write(&mut self) -> &mut S { self } + fn consume(self) -> S { self } +} + +fn f() { + let x = S; + let c1 = || x.read(); + //^^ impl Fn() -> &S + let c2 = || x.write(); + //^^ impl FnMut() -> &mut S + let c3 = || x.consume(); + //^^ impl FnOnce() -> S + let c3 = || x.consume().consume().consume(); + //^^ impl FnOnce() -> S + let c3 = || x.consume().write().read(); + //^^ impl FnOnce() -> &S + let x = &mut x; + let c1 = || x.write(); + //^^ impl FnMut() -> &mut S + let x = S; + let c1 = || { let ref t = x; t }; + //^^ impl Fn() -> &S + let c2 = || { let ref mut t = x; t }; + //^^ impl FnMut() -> &mut S + let c3 = || { let t = x; t }; + //^^ impl FnOnce() -> S +} + "#, + ) +} + +#[test] +fn capture_kinds_closure() { + check_types( + r#" +//- minicore: copy, fn +fn f() { + let mut x = 2; + x = 5; + let mut c1 = || { x = 3; x }; + //^^^^^^ impl FnMut() -> i32 + let mut c2 = || { c1() }; + //^^^^^^ impl FnMut() -> i32 + let mut c1 = || { x }; + //^^^^^^ impl Fn() -> i32 + let mut c2 = || { c1() }; + //^^^^^^ impl Fn() -> i32 + struct X; + let x = X; + let mut c1 = || { x }; + //^^^^^^ impl FnOnce() -> X + let mut c2 = || { c1() }; + //^^^^^^ impl FnOnce() -> X +} + "#, + ); +} + +#[test] +fn capture_kinds_overloaded_deref() { + check_types( + r#" +//- minicore: fn, deref_mut +use core::ops::{Deref, DerefMut}; + +struct Foo; +impl Deref for Foo { + type Target = (i32, u8); + fn deref(&self) -> &(i32, u8) { + &(5, 2) + } +} +impl DerefMut for Foo { + fn deref_mut(&mut self) -> &mut (i32, u8) { + &mut (5, 2) + } +} +fn test() { + let mut x = Foo; + let c1 = || *x; + //^^ impl Fn() -> (i32, u8) + let c2 = || { *x = (2, 5); }; + //^^ impl FnMut() + let c3 = || { x.1 }; + //^^ impl Fn() -> u8 + let c4 = || { x.1 = 6; }; + //^^ impl FnMut() +} + "#, + ); +} + +#[test] +fn capture_kinds_with_copy_types() { + check_types( + r#" +//- minicore: copy, clone, derive +#[derive(Clone, Copy)] +struct Copy; +struct NotCopy; +#[derive(Clone, Copy)] +struct Generic(T); + +trait Tr { + type Assoc; +} + +impl Tr for Copy { + type Assoc = NotCopy; +} + +#[derive(Clone, Copy)] +struct AssocGeneric(T::Assoc); + +fn f() { + let a = Copy; + let b = NotCopy; + let c = Generic(Copy); + let d = Generic(NotCopy); + let e: AssocGeneric = AssocGeneric(NotCopy); + let c1 = || a; + //^^ impl Fn() -> Copy + let c2 = || b; + //^^ impl FnOnce() -> NotCopy + let c3 = || c; + //^^ impl Fn() -> Generic + let c3 = || d; + //^^ impl FnOnce() -> Generic + let c3 = || e; + //^^ impl FnOnce() -> AssocGeneric +} + "#, + ) +} + +#[test] +fn derive_macro_should_work_for_associated_type() { + check_types( + r#" +//- minicore: copy, clone, derive +#[derive(Clone)] +struct X; +#[derive(Clone)] +struct Y; + +trait Tr { + type Assoc; +} + +impl Tr for X { + type Assoc = Y; +} + +#[derive(Clone)] +struct AssocGeneric(T::Assoc); + +fn f() { + let e: AssocGeneric = AssocGeneric(Y); + let e_clone = e.clone(); + //^^^^^^^ AssocGeneric +} + "#, + ) +} + #[test] fn cfgd_out_assoc_items() { check_types( @@ -3291,6 +3464,19 @@ fn f(t: Ark) { ); } +#[test] +fn const_dependent_on_local() { + check_types( + r#" +fn main() { + let s = 5; + let t = [2; s]; + //^ [i32; _] +} +"#, + ); +} + #[test] fn issue_14275() { // FIXME: evaluate const generic diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index 3564ed4133463..fdfae44c2d13e 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -90,7 +90,7 @@ fn infer_async_closure() { async fn test() { let f = async move |x: i32| x + 42; f; -// ^ |i32| -> impl Future +// ^ impl Fn(i32) -> impl Future let a = f(4); a; // ^ impl Future @@ -99,7 +99,7 @@ async fn test() { // ^ i32 let f = async move || 42; f; -// ^ || -> impl Future +// ^ impl Fn() -> impl Future let a = f(); a; // ^ impl Future @@ -116,7 +116,7 @@ async fn test() { }; let _: Option = c().await; c; -// ^ || -> impl Future> +// ^ impl Fn() -> impl Future> } "#, ); @@ -550,7 +550,7 @@ fn test() -> u64 { 53..54 'a': S 57..58 'S': S(fn(u32) -> u64) -> S 57..74 'S(|i| ...s u64)': S - 59..73 '|i| 2*i as u64': |u32| -> u64 + 59..73 '|i| 2*i as u64': impl Fn(u32) -> u64 60..61 'i': u32 63..64 '2': u64 63..73 '2*i as u64': u64 @@ -1333,9 +1333,9 @@ fn foo() -> (impl FnOnce(&str, T), impl Trait) { } "#, expect![[r#" - 134..165 '{ ...(C)) }': (|&str, T| -> (), Bar) - 140..163 '(|inpu...ar(C))': (|&str, T| -> (), Bar) - 141..154 '|input, t| {}': |&str, T| -> () + 134..165 '{ ...(C)) }': (impl Fn(&str, T), Bar) + 140..163 '(|inpu...ar(C))': (impl Fn(&str, T), Bar) + 141..154 '|input, t| {}': impl Fn(&str, T) 142..147 'input': &str 149..150 't': T 152..154 '{}': () @@ -1506,8 +1506,8 @@ fn main() { 71..105 '{ ...()); }': () 77..78 'f': fn f(&dyn Fn(S)) 77..102 'f(&|nu...foo())': () - 79..101 '&|numb....foo()': &|S| -> () - 80..101 '|numbe....foo()': |S| -> () + 79..101 '&|numb....foo()': &impl Fn(S) + 80..101 '|numbe....foo()': impl Fn(S) 81..87 'number': S 89..95 'number': S 89..101 'number.foo()': () @@ -1912,13 +1912,13 @@ fn test() { 131..132 'f': F 151..153 '{}': Lazy 251..497 '{ ...o(); }': () - 261..266 'lazy1': Lazy Foo> - 283..292 'Lazy::new': fn new Foo>(|| -> Foo) -> Lazy Foo> - 283..300 'Lazy::...| Foo)': Lazy Foo> - 293..299 '|| Foo': || -> Foo + 261..266 'lazy1': Lazy Foo> + 283..292 'Lazy::new': fn new Foo>(impl Fn() -> Foo) -> Lazy Foo> + 283..300 'Lazy::...| Foo)': Lazy Foo> + 293..299 '|| Foo': impl Fn() -> Foo 296..299 'Foo': Foo 310..312 'r1': usize - 315..320 'lazy1': Lazy Foo> + 315..320 'lazy1': Lazy Foo> 315..326 'lazy1.foo()': usize 368..383 'make_foo_fn_ptr': fn() -> Foo 399..410 'make_foo_fn': fn make_foo_fn() -> Foo @@ -1963,20 +1963,20 @@ fn test() { 163..167 '1u32': u32 174..175 'x': Option 174..190 'x.map(...v + 1)': Option - 180..189 '|v| v + 1': |u32| -> u32 + 180..189 '|v| v + 1': impl Fn(u32) -> u32 181..182 'v': u32 184..185 'v': u32 184..189 'v + 1': u32 188..189 '1': u32 196..197 'x': Option 196..212 'x.map(... 1u64)': Option - 202..211 '|_v| 1u64': |u32| -> u64 + 202..211 '|_v| 1u64': impl Fn(u32) -> u64 203..205 '_v': u32 207..211 '1u64': u64 222..223 'y': Option 239..240 'x': Option 239..252 'x.map(|_v| 1)': Option - 245..251 '|_v| 1': |u32| -> i64 + 245..251 '|_v| 1': impl Fn(u32) -> i64 246..248 '_v': u32 250..251 '1': i64 "#]], @@ -2005,11 +2005,11 @@ fn test u64>(f: F) { //^^^^ u64 let g = |v| v + 1; //^^^^^ u64 - //^^^^^^^^^ |u64| -> u64 + //^^^^^^^^^ impl Fn(u64) -> u64 g(1u64); //^^^^^^^ u64 let h = |v| 1u128 + v; - //^^^^^^^^^^^^^ |u128| -> u128 + //^^^^^^^^^^^^^ impl Fn(u128) -> u128 }"#, ); } @@ -2062,17 +2062,17 @@ fn test() { 312..314 '{}': () 330..489 '{ ... S); }': () 340..342 'x1': u64 - 345..349 'foo1': fn foo1 u64>(S, |S| -> u64) -> u64 + 345..349 'foo1': fn foo1 u64>(S, impl Fn(S) -> u64) -> u64 345..368 'foo1(S...hod())': u64 350..351 'S': S - 353..367 '|s| s.method()': |S| -> u64 + 353..367 '|s| s.method()': impl Fn(S) -> u64 354..355 's': S 357..358 's': S 357..367 's.method()': u64 378..380 'x2': u64 - 383..387 'foo2': fn foo2 u64>(|S| -> u64, S) -> u64 + 383..387 'foo2': fn foo2 u64>(impl Fn(S) -> u64, S) -> u64 383..406 'foo2(|...(), S)': u64 - 388..402 '|s| s.method()': |S| -> u64 + 388..402 '|s| s.method()': impl Fn(S) -> u64 389..390 's': S 392..393 's': S 392..402 's.method()': u64 @@ -2081,14 +2081,14 @@ fn test() { 421..422 'S': S 421..446 'S.foo1...hod())': u64 428..429 'S': S - 431..445 '|s| s.method()': |S| -> u64 + 431..445 '|s| s.method()': impl Fn(S) -> u64 432..433 's': S 435..436 's': S 435..445 's.method()': u64 456..458 'x4': u64 461..462 'S': S 461..486 'S.foo2...(), S)': u64 - 468..482 '|s| s.method()': |S| -> u64 + 468..482 '|s| s.method()': impl Fn(S) -> u64 469..470 's': S 472..473 's': S 472..482 's.method()': u64 @@ -2562,9 +2562,9 @@ fn main() { 72..74 '_v': F 117..120 '{ }': () 132..163 '{ ... }); }': () - 138..148 'f::<(), _>': fn f<(), |&()| -> ()>(|&()| -> ()) + 138..148 'f::<(), _>': fn f<(), impl Fn(&())>(impl Fn(&())) 138..160 'f::<()... z; })': () - 149..159 '|z| { z; }': |&()| -> () + 149..159 '|z| { z; }': impl Fn(&()) 150..151 'z': &() 153..159 '{ z; }': () 155..156 'z': &() @@ -2721,9 +2721,9 @@ fn main() { 983..998 'Vec::::new': fn new() -> Vec 983..1000 'Vec::<...:new()': Vec 983..1012 'Vec::<...iter()': IntoIter - 983..1075 'Vec::<...one })': FilterMap, |i32| -> Option> + 983..1075 'Vec::<...one })': FilterMap, impl Fn(i32) -> Option> 983..1101 'Vec::<... y; })': () - 1029..1074 '|x| if...None }': |i32| -> Option + 1029..1074 '|x| if...None }': impl Fn(i32) -> Option 1030..1031 'x': i32 1033..1074 'if x >...None }': Option 1036..1037 'x': i32 @@ -2736,7 +2736,7 @@ fn main() { 1049..1057 'x as u32': u32 1066..1074 '{ None }': Option 1068..1072 'None': Option - 1090..1100 '|y| { y; }': |u32| -> () + 1090..1100 '|y| { y; }': impl Fn(u32) 1091..1092 'y': u32 1094..1100 '{ y; }': () 1096..1097 'y': u32 @@ -2979,13 +2979,13 @@ fn foo() { 52..126 '{ ...)(s) }': () 62..63 's': Option 66..78 'Option::None': Option - 88..89 'f': |Option| -> () - 92..111 '|x: Op...2>| {}': |Option| -> () + 88..89 'f': impl Fn(Option) + 92..111 '|x: Op...2>| {}': impl Fn(Option) 93..94 'x': Option 109..111 '{}': () 117..124 '(&f)(s)': () - 118..120 '&f': &|Option| -> () - 119..120 'f': |Option| -> () + 118..120 '&f': &impl Fn(Option) + 119..120 'f': impl Fn(Option) 122..123 's': Option "#]], ); @@ -3072,15 +3072,15 @@ fn foo() { 228..229 's': Option 232..236 'None': Option 246..247 'f': Box)> - 281..294 'box (|ps| {})': Box<|&Option| -> ()> - 286..293 '|ps| {}': |&Option| -> () + 281..294 'box (|ps| {})': Box)> + 286..293 '|ps| {}': impl Fn(&Option) 287..289 'ps': &Option 291..293 '{}': () 300..301 'f': Box)> 300..305 'f(&s)': () 302..304 '&s': &Option 303..304 's': Option - 281..294: expected Box)>, got Box<|&Option| -> ()> + 281..294: expected Box)>, got Box)> "#]], ); } diff --git a/crates/hir-ty/src/traits.rs b/crates/hir-ty/src/traits.rs index deb6ce56773d3..8bc38aca4722b 100644 --- a/crates/hir-ty/src/traits.rs +++ b/crates/hir-ty/src/traits.rs @@ -4,7 +4,7 @@ use std::{env::var, sync::Arc}; use chalk_ir::GoalData; use chalk_recursive::Cache; -use chalk_solve::{logging_db::LoggingRustIrDatabase, Solver}; +use chalk_solve::{logging_db::LoggingRustIrDatabase, rust_ir, Solver}; use base_db::CrateId; use hir_def::{ @@ -177,8 +177,10 @@ fn is_chalk_print() -> bool { std::env::var("CHALK_PRINT").is_ok() } -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum FnTrait { + // Warning: Order is important. If something implements `x` it should also implement + // `y` if `y <= x`. FnOnce, FnMut, Fn, @@ -193,6 +195,14 @@ impl FnTrait { } } + pub const fn to_chalk_ir(self) -> rust_ir::ClosureKind { + match self { + FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce, + FnTrait::FnMut => rust_ir::ClosureKind::FnMut, + FnTrait::Fn => rust_ir::ClosureKind::Fn, + } + } + pub fn method_name(self) -> Name { match self { FnTrait::FnOnce => name!(call_once), diff --git a/crates/hir-ty/src/utils.rs b/crates/hir-ty/src/utils.rs index 9d97ab84a56ab..3b2a726688d01 100644 --- a/crates/hir-ty/src/utils.rs +++ b/crates/hir-ty/src/utils.rs @@ -4,7 +4,7 @@ use std::iter; use base_db::CrateId; -use chalk_ir::{cast::Cast, fold::Shift, BoundVar, DebruijnIndex}; +use chalk_ir::{cast::Cast, fold::Shift, BoundVar, DebruijnIndex, Mutability}; use either::Either; use hir_def::{ db::DefDatabase, @@ -12,6 +12,7 @@ use hir_def::{ GenericParams, TypeOrConstParamData, TypeParamProvenance, WherePredicate, WherePredicateTypeTarget, }, + hir::BindingAnnotation, lang_item::LangItem, resolver::{HasResolver, TypeNs}, type_ref::{TraitBoundModifier, TypeRef}, @@ -24,7 +25,8 @@ use rustc_hash::FxHashSet; use smallvec::{smallvec, SmallVec}; use crate::{ - db::HirDatabase, ChalkTraitId, Interner, Substitution, TraitRef, TraitRefExt, WhereClause, + db::HirDatabase, ChalkTraitId, Interner, Substitution, TraitRef, TraitRefExt, Ty, TyExt, + WhereClause, }; pub(crate) fn fn_traits( @@ -352,3 +354,20 @@ pub fn is_fn_unsafe_to_call(db: &dyn HirDatabase, func: FunctionId) -> bool { _ => false, } } + +pub(crate) fn pattern_matching_dereference_count( + cond_ty: &mut Ty, + binding_mode: &mut BindingAnnotation, +) -> usize { + let mut r = 0; + while let Some((ty, _, mu)) = cond_ty.as_reference() { + if mu == Mutability::Mut && *binding_mode != BindingAnnotation::Ref { + *binding_mode = BindingAnnotation::RefMut; + } else { + *binding_mode = BindingAnnotation::Ref; + } + *cond_ty = ty.clone(); + r += 1; + } + r +} diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index fad9f19d25359..db923cb0fe233 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -129,7 +129,7 @@ pub use { ExpandResult, HirFileId, InFile, MacroFile, Origin, }, hir_ty::{ - display::{HirDisplay, HirDisplayError, HirWrite}, + display::{ClosureStyle, HirDisplay, HirDisplayError, HirWrite}, mir::MirEvalError, PointerCast, Safety, }, @@ -1530,35 +1530,44 @@ impl DefWithBody { let hir_body = db.body(self.into()); - if let Ok(borrowck_result) = db.borrowck(self.into()) { - let mir_body = &borrowck_result.mir_body; - let mol = &borrowck_result.mutability_of_locals; - for (binding_id, _) in hir_body.bindings.iter() { - let need_mut = &mol[mir_body.binding_locals[binding_id]]; - let local = Local { parent: self.into(), binding_id }; - match (need_mut, local.is_mut(db)) { - (mir::MutabilityReason::Mut { .. }, true) - | (mir::MutabilityReason::Not, false) => (), - (mir::MutabilityReason::Mut { spans }, false) => { - for span in spans { - let span: InFile = match span { - mir::MirSpan::ExprId(e) => match source_map.expr_syntax(*e) { - Ok(s) => s.map(|x| x.into()), - Err(_) => continue, - }, - mir::MirSpan::PatId(p) => match source_map.pat_syntax(*p) { - Ok(s) => s.map(|x| match x { - Either::Left(e) => e.into(), - Either::Right(e) => e.into(), - }), - Err(_) => continue, - }, - mir::MirSpan::Unknown => continue, - }; - acc.push(NeedMut { local, span }.into()); + if let Ok(borrowck_results) = db.borrowck(self.into()) { + for borrowck_result in borrowck_results.iter() { + let mir_body = &borrowck_result.mir_body; + let mol = &borrowck_result.mutability_of_locals; + for (binding_id, _) in hir_body.bindings.iter() { + let Some(&local) = mir_body.binding_locals.get(binding_id) else { + continue; + }; + let need_mut = &mol[local]; + let local = Local { parent: self.into(), binding_id }; + match (need_mut, local.is_mut(db)) { + (mir::MutabilityReason::Mut { .. }, true) + | (mir::MutabilityReason::Not, false) => (), + (mir::MutabilityReason::Mut { spans }, false) => { + for span in spans { + let span: InFile = match span { + mir::MirSpan::ExprId(e) => match source_map.expr_syntax(*e) { + Ok(s) => s.map(|x| x.into()), + Err(_) => continue, + }, + mir::MirSpan::PatId(p) => match source_map.pat_syntax(*p) { + Ok(s) => s.map(|x| match x { + Either::Left(e) => e.into(), + Either::Right(e) => e.into(), + }), + Err(_) => continue, + }, + mir::MirSpan::Unknown => continue, + }; + acc.push(NeedMut { local, span }.into()); + } + } + (mir::MutabilityReason::Not, true) => { + if !infer.mutated_bindings_in_closure.contains(&binding_id) { + acc.push(UnusedMut { local }.into()) + } } } - (mir::MutabilityReason::Not, true) => acc.push(UnusedMut { local }.into()), } } } @@ -3383,7 +3392,12 @@ impl Type { } pub fn as_callable(&self, db: &dyn HirDatabase) -> Option { + let mut the_ty = &self.ty; let callee = match self.ty.kind(Interner) { + TyKind::Ref(_, _, ty) if ty.as_closure().is_some() => { + the_ty = ty; + Callee::Closure(ty.as_closure().unwrap()) + } TyKind::Closure(id, _) => Callee::Closure(*id), TyKind::Function(_) => Callee::FnPtr, TyKind::FnDef(..) => Callee::Def(self.ty.callable_def(db)?), @@ -3398,7 +3412,7 @@ impl Type { } }; - let sig = self.ty.callable_sig(db)?; + let sig = the_ty.callable_sig(db)?; Some(Callable { ty: self.clone(), sig, callee, is_bound_method: false }) } diff --git a/crates/ide-assists/src/handlers/generate_function.rs b/crates/ide-assists/src/handlers/generate_function.rs index 0768389281ca3..2372fe28e1974 100644 --- a/crates/ide-assists/src/handlers/generate_function.rs +++ b/crates/ide-assists/src/handlers/generate_function.rs @@ -1910,7 +1910,6 @@ fn bar(new: fn) ${0:-> _} { #[test] fn add_function_with_closure_arg() { - // FIXME: The argument in `bar` is wrong. check_assist( generate_function, r" @@ -1925,7 +1924,7 @@ fn foo() { bar(closure) } -fn bar(closure: _) { +fn bar(closure: impl Fn(i64) -> i64) { ${0:todo!()} } ", diff --git a/crates/ide-diagnostics/src/handlers/mutability_errors.rs b/crates/ide-diagnostics/src/handlers/mutability_errors.rs index 564b756402de6..8c4ca23e06e8d 100644 --- a/crates/ide-diagnostics/src/handlers/mutability_errors.rs +++ b/crates/ide-diagnostics/src/handlers/mutability_errors.rs @@ -771,6 +771,88 @@ fn fn_once(mut x: impl FnOnce(u8) -> u8) -> u8 { ); } + #[test] + fn closure() { + // FIXME: Diagnositc spans are too large + check_diagnostics( + r#" + //- minicore: copy, fn + struct X; + + impl X { + fn mutate(&mut self) {} + } + + fn f() { + let x = 5; + let closure1 = || { x = 2; }; + //^^^^^^^^^^^^^ 💡 error: cannot mutate immutable variable `x` + let _ = closure1(); + //^^^^^^^^ 💡 error: cannot mutate immutable variable `closure1` + let closure2 = || { x = x; }; + //^^^^^^^^^^^^^ 💡 error: cannot mutate immutable variable `x` + let closure3 = || { + let x = 2; + x = 5; + //^^^^^ 💡 error: cannot mutate immutable variable `x` + x + }; + let x = X; + let closure4 = || { x.mutate(); }; + //^^^^^^^^^^^^^^^^^^ 💡 error: cannot mutate immutable variable `x` + } + "#, + ); + check_diagnostics( + r#" + //- minicore: copy, fn + fn f() { + let mut x = 5; + //^^^^^ 💡 weak: variable does not need to be mutable + let mut y = 2; + y = 7; + let closure = || { + let mut z = 8; + z = 3; + let mut k = z; + //^^^^^ 💡 weak: variable does not need to be mutable + }; + } + "#, + ); + check_diagnostics( + r#" +//- minicore: copy, fn +fn f() { + let closure = || { + || { + || { + let x = 2; + || { || { x = 5; } } + //^^^^^^^^^^^^^^^^^^^^ 💡 error: cannot mutate immutable variable `x` + } + } + }; +} + "#, + ); + check_diagnostics( + r#" +//- minicore: copy, fn +fn f() { + struct X; + let mut x = X; + //^^^^^ 💡 weak: variable does not need to be mutable + let c1 = || x; + let mut x = X; + let c2 = || { x = X; x }; + let mut x = X; + let c2 = move || { x = X; }; +} + "#, + ); + } + #[test] fn respect_allow_unused_mut() { // FIXME: respect diff --git a/crates/ide-diagnostics/src/handlers/replace_filter_map_next_with_find_map.rs b/crates/ide-diagnostics/src/handlers/replace_filter_map_next_with_find_map.rs index 9b1c65983e615..738339cfa6b5c 100644 --- a/crates/ide-diagnostics/src/handlers/replace_filter_map_next_with_find_map.rs +++ b/crates/ide-diagnostics/src/handlers/replace_filter_map_next_with_find_map.rs @@ -115,7 +115,7 @@ fn foo() { r#" //- minicore: iterators fn foo() { - let m = core::iter::repeat(()) + let mut m = core::iter::repeat(()) .filter_map(|()| Some(92)); let n = m.next(); } diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs index 4abc25a28fbc0..c5fa1cb027e14 100644 --- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs +++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs @@ -1,5 +1,5 @@ use either::Either; -use hir::{db::ExpandDatabase, HirDisplay, InFile, Type}; +use hir::{db::ExpandDatabase, ClosureStyle, HirDisplay, InFile, Type}; use ide_db::{famous_defs::FamousDefs, source_change::SourceChange}; use syntax::{ ast::{self, BlockExpr, ExprStmt}, @@ -32,8 +32,8 @@ pub(crate) fn type_mismatch(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) "type-mismatch", format!( "expected {}, found {}", - d.expected.display(ctx.sema.db), - d.actual.display(ctx.sema.db) + d.expected.display(ctx.sema.db).with_closure_style(ClosureStyle::ClosureWithId), + d.actual.display(ctx.sema.db).with_closure_style(ClosureStyle::ClosureWithId), ), display_range, ) @@ -596,6 +596,19 @@ fn test() -> String { ); } + #[test] + fn closure_mismatch_show_different_type() { + check_diagnostics( + r#" +fn f() { + let mut x = (|| 1, 2); + x = (|| 3, 4); + //^^^^ error: expected {closure#0}, found {closure#1} +} + "#, + ); + } + #[test] fn type_mismatch_on_block() { cov_mark::check!(type_mismatch_on_block); diff --git a/crates/ide/src/hover/tests.rs b/crates/ide/src/hover/tests.rs index 70ec915e96745..8a58fbeb860bb 100644 --- a/crates/ide/src/hover/tests.rs +++ b/crates/ide/src/hover/tests.rs @@ -225,7 +225,7 @@ fn main() { *iter* ```rust - let mut iter: Iter>, |&mut u32, &u32, &mut u32| -> Option, u32>> + let mut iter: Iter>, impl Fn(&mut u32, &u32, &mut u32) -> Option, u32>> ``` "#]], ); diff --git a/crates/ide/src/inlay_hints.rs b/crates/ide/src/inlay_hints.rs index ac477339ec233..e6360bc6ec1c9 100644 --- a/crates/ide/src/inlay_hints.rs +++ b/crates/ide/src/inlay_hints.rs @@ -5,7 +5,8 @@ use std::{ use either::Either; use hir::{ - known, HasVisibility, HirDisplay, HirDisplayError, HirWrite, ModuleDef, ModuleDefId, Semantics, + known, ClosureStyle, HasVisibility, HirDisplay, HirDisplayError, HirWrite, ModuleDef, + ModuleDefId, Semantics, }; use ide_db::{base_db::FileRange, famous_defs::FamousDefs, RootDatabase}; use itertools::Itertools; @@ -45,6 +46,7 @@ pub struct InlayHintsConfig { pub param_names_for_lifetime_elision_hints: bool, pub hide_named_constructor_hints: bool, pub hide_closure_initialization_hints: bool, + pub closure_style: ClosureStyle, pub max_length: Option, pub closing_brace_hints_min_lines: Option, } @@ -291,6 +293,7 @@ fn label_of_ty( mut max_length: Option, ty: hir::Type, label_builder: &mut InlayHintLabelBuilder<'_>, + config: &InlayHintsConfig, ) -> Result<(), HirDisplayError> { let iter_item_type = hint_iterator(sema, famous_defs, &ty); match iter_item_type { @@ -321,11 +324,14 @@ fn label_of_ty( label_builder.write_str(LABEL_ITEM)?; label_builder.end_location_link(); label_builder.write_str(LABEL_MIDDLE2)?; - rec(sema, famous_defs, max_length, ty, label_builder)?; + rec(sema, famous_defs, max_length, ty, label_builder, config)?; label_builder.write_str(LABEL_END)?; Ok(()) } - None => ty.display_truncated(sema.db, max_length).write_to(label_builder), + None => ty + .display_truncated(sema.db, max_length) + .with_closure_style(config.closure_style) + .write_to(label_builder), } } @@ -335,7 +341,7 @@ fn label_of_ty( location: None, result: InlayHintLabel::default(), }; - let _ = rec(sema, famous_defs, config.max_length, ty, &mut label_builder); + let _ = rec(sema, famous_defs, config.max_length, ty, &mut label_builder, config); let r = label_builder.finish(); Some(r) } @@ -481,6 +487,7 @@ fn closure_has_block_body(closure: &ast::ClosureExpr) -> bool { #[cfg(test)] mod tests { use expect_test::Expect; + use hir::ClosureStyle; use itertools::Itertools; use test_utils::extract_annotations; @@ -504,6 +511,7 @@ mod tests { binding_mode_hints: false, hide_named_constructor_hints: false, hide_closure_initialization_hints: false, + closure_style: ClosureStyle::ImplFn, param_names_for_lifetime_elision_hints: false, max_length: None, closing_brace_hints_min_lines: None, diff --git a/crates/ide/src/inlay_hints/bind_pat.rs b/crates/ide/src/inlay_hints/bind_pat.rs index 6a50927333d4f..5f571d0448211 100644 --- a/crates/ide/src/inlay_hints/bind_pat.rs +++ b/crates/ide/src/inlay_hints/bind_pat.rs @@ -176,6 +176,7 @@ fn pat_is_enum_variant(db: &RootDatabase, bind_pat: &ast::IdentPat, pat_ty: &hir mod tests { // This module also contains tests for super::closure_ret + use hir::ClosureStyle; use syntax::{TextRange, TextSize}; use test_utils::extract_annotations; @@ -235,7 +236,7 @@ fn main() { let zz_ref = &zz; //^^^^^^ &Test let test = || zz; - //^^^^ || -> Test + //^^^^ impl FnOnce() -> Test }"#, ); } @@ -753,7 +754,7 @@ fn main() { let func = times2; // ^^^^ fn times2(i32) -> i32 let closure = |x: i32| x * 2; - // ^^^^^^^ |i32| -> i32 + // ^^^^^^^ impl Fn(i32) -> i32 } fn fallible() -> ControlFlow<()> { @@ -821,7 +822,7 @@ fn main() { //^^^^^^^^^ i32 let multiply = - //^^^^^^^^ |i32, i32| -> i32 + //^^^^^^^^ impl Fn(i32, i32) -> i32 | a, b| a * b //^ i32 ^ i32 @@ -830,10 +831,10 @@ fn main() { let _: i32 = multiply(1, 2); //^ a ^ b let multiply_ref = &multiply; - //^^^^^^^^^^^^ &|i32, i32| -> i32 + //^^^^^^^^^^^^ &impl Fn(i32, i32) -> i32 let return_42 = || 42; - //^^^^^^^^^ || -> i32 + //^^^^^^^^^ impl Fn() -> i32 || { 42 }; //^^ i32 }"#, @@ -857,6 +858,94 @@ fn main() { ); } + #[test] + fn closure_style() { + check_with_config( + InlayHintsConfig { type_hints: true, ..DISABLED_CONFIG }, + r#" +//- minicore: fn +fn main() { + let x = || 2; + //^ impl Fn() -> i32 + let y = |t: i32| x() + t; + //^ impl Fn(i32) -> i32 + let mut t = 5; + //^ i32 + let z = |k: i32| { t += k; }; + //^ impl FnMut(i32) + let p = (y, z); + //^ (impl Fn(i32) -> i32, impl FnMut(i32)) +} + "#, + ); + check_with_config( + InlayHintsConfig { + type_hints: true, + closure_style: ClosureStyle::RANotation, + ..DISABLED_CONFIG + }, + r#" +//- minicore: fn +fn main() { + let x = || 2; + //^ || -> i32 + let y = |t: i32| x() + t; + //^ |i32| -> i32 + let mut t = 5; + //^ i32 + let z = |k: i32| { t += k; }; + //^ |i32| -> () + let p = (y, z); + //^ (|i32| -> i32, |i32| -> ()) +} + "#, + ); + check_with_config( + InlayHintsConfig { + type_hints: true, + closure_style: ClosureStyle::ClosureWithId, + ..DISABLED_CONFIG + }, + r#" +//- minicore: fn +fn main() { + let x = || 2; + //^ {closure#0} + let y = |t: i32| x() + t; + //^ {closure#1} + let mut t = 5; + //^ i32 + let z = |k: i32| { t += k; }; + //^ {closure#2} + let p = (y, z); + //^ ({closure#1}, {closure#2}) +} + "#, + ); + check_with_config( + InlayHintsConfig { + type_hints: true, + closure_style: ClosureStyle::Hide, + ..DISABLED_CONFIG + }, + r#" +//- minicore: fn +fn main() { + let x = || 2; + //^ … + let y = |t: i32| x() + t; + //^ … + let mut t = 5; + //^ i32 + let z = |k: i32| { t += k; }; + //^ … + let p = (y, z); + //^ (…, …) +} + "#, + ); + } + #[test] fn skip_closure_type_hints() { check_with_config( @@ -871,13 +960,13 @@ fn main() { let multiple_2 = |x: i32| { x * 2 }; let multiple_2 = |x: i32| x * 2; - // ^^^^^^^^^^ |i32| -> i32 + // ^^^^^^^^^^ impl Fn(i32) -> i32 let (not) = (|x: bool| { !x }); - // ^^^ |bool| -> bool + // ^^^ impl Fn(bool) -> bool let (is_zero, _b) = (|x: usize| { x == 0 }, false); - // ^^^^^^^ |usize| -> bool + // ^^^^^^^ impl Fn(usize) -> bool // ^^ bool let plus_one = |x| { x + 1 }; diff --git a/crates/ide/src/signature_help.rs b/crates/ide/src/signature_help.rs index 4b2c139f6f455..cdee705cbfdaa 100644 --- a/crates/ide/src/signature_help.rs +++ b/crates/ide/src/signature_help.rs @@ -1227,6 +1227,24 @@ fn main() { ) } + #[test] + fn call_info_for_fn_def_over_reference() { + check( + r#" +struct S; +fn foo(s: S) -> i32 { 92 } +fn main() { + let bar = &&&&&foo; + bar($0); +} + "#, + expect![[r#" + fn foo(s: S) -> i32 + ^^^^ + "#]], + ) + } + #[test] fn call_info_for_fn_ptr() { check( diff --git a/crates/ide/src/static_index.rs b/crates/ide/src/static_index.rs index c97691b14a57f..774b07775b9dc 100644 --- a/crates/ide/src/static_index.rs +++ b/crates/ide/src/static_index.rs @@ -118,6 +118,7 @@ impl StaticIndex<'_> { adjustment_hints_hide_outside_unsafe: false, hide_named_constructor_hints: false, hide_closure_initialization_hints: false, + closure_style: hir::ClosureStyle::ImplFn, param_names_for_lifetime_elision_hints: false, binding_mode_hints: false, max_length: Some(25), diff --git a/crates/rust-analyzer/src/config.rs b/crates/rust-analyzer/src/config.rs index 06fb95515ae93..27445d1f712d8 100644 --- a/crates/rust-analyzer/src/config.rs +++ b/crates/rust-analyzer/src/config.rs @@ -338,6 +338,8 @@ config_data! { inlayHints_closingBraceHints_minLines: usize = "25", /// Whether to show inlay type hints for return types of closures. inlayHints_closureReturnTypeHints_enable: ClosureReturnTypeHintsDef = "\"never\"", + /// Closure notation in type and chaining inaly hints. + inlayHints_closureStyle: ClosureStyle = "\"impl_fn\"", /// Whether to show enum variant discriminant hints. inlayHints_discriminantHints_enable: DiscriminantHintsDef = "\"never\"", /// Whether to show inlay hints for type adjustments. @@ -1301,6 +1303,12 @@ impl Config { hide_closure_initialization_hints: self .data .inlayHints_typeHints_hideClosureInitialization, + closure_style: match self.data.inlayHints_closureStyle { + ClosureStyle::ImplFn => hir::ClosureStyle::ImplFn, + ClosureStyle::RustAnalyzer => hir::ClosureStyle::RANotation, + ClosureStyle::WithId => hir::ClosureStyle::ClosureWithId, + ClosureStyle::Hide => hir::ClosureStyle::Hide, + }, adjustment_hints: match self.data.inlayHints_expressionAdjustmentHints_enable { AdjustmentHintsDef::Always => ide::AdjustmentHints::Always, AdjustmentHintsDef::Never => match self.data.inlayHints_reborrowHints_enable { @@ -1807,6 +1815,15 @@ enum ClosureReturnTypeHintsDef { WithBlock, } +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +enum ClosureStyle { + ImplFn, + RustAnalyzer, + WithId, + Hide, +} + #[derive(Deserialize, Debug, Clone)] #[serde(untagged)] enum ReborrowHintsDef { @@ -2288,6 +2305,16 @@ fn field_props(field: &str, ty: &str, doc: &[&str], default: &str) -> serde_json }, ], }, + "ClosureStyle" => set! { + "type": "string", + "enum": ["impl_fn", "rust_analyzer", "with_id", "hide"], + "enumDescriptions": [ + "`impl_fn`: `impl FnMut(i32, u64) -> i8`", + "`rust_analyzer`: `|i32, u64| -> i8`", + "`with_id`: `{closure#14352}`, where that id is the unique number of the closure in r-a internals", + "`hide`: Shows `...` for every closure type", + ], + }, _ => panic!("missing entry for {ty}: {default}"), } diff --git a/docs/user/generated_config.adoc b/docs/user/generated_config.adoc index 12dfe394f1d5b..e92e6ae92cce3 100644 --- a/docs/user/generated_config.adoc +++ b/docs/user/generated_config.adoc @@ -474,6 +474,11 @@ to always show them). -- Whether to show inlay type hints for return types of closures. -- +[[rust-analyzer.inlayHints.closureStyle]]rust-analyzer.inlayHints.closureStyle (default: `"impl_fn"`):: ++ +-- +Closure notation in type and chaining inaly hints. +-- [[rust-analyzer.inlayHints.discriminantHints.enable]]rust-analyzer.inlayHints.discriminantHints.enable (default: `"never"`):: + -- diff --git a/editors/code/package.json b/editors/code/package.json index 81fa97269a9f0..087fd1296b310 100644 --- a/editors/code/package.json +++ b/editors/code/package.json @@ -1028,6 +1028,23 @@ "Only show type hints for return types of closures with blocks." ] }, + "rust-analyzer.inlayHints.closureStyle": { + "markdownDescription": "Closure notation in type and chaining inaly hints.", + "default": "impl_fn", + "type": "string", + "enum": [ + "impl_fn", + "rust_analyzer", + "with_id", + "hide" + ], + "enumDescriptions": [ + "`impl_fn`: `impl FnMut(i32, u64) -> i8`", + "`rust_analyzer`: `|i32, u64| -> i8`", + "`with_id`: `{closure#14352}`, where that id is the unique number of the closure in r-a internals", + "`hide`: Shows `...` for every closure type" + ] + }, "rust-analyzer.inlayHints.discriminantHints.enable": { "markdownDescription": "Whether to show enum variant discriminant hints.", "default": "never",