Skip to content

Commit

Permalink
Fix incorrect instantiation in constraint solving (#998)
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-Nak authored Apr 25, 2024
1 parent c80b36e commit 2fca91c
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 22 deletions.
15 changes: 8 additions & 7 deletions crates/analyzer/src/traversal/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
use std::{ops::RangeInclusive, str::FromStr};

use fe_common::{diagnostics::Label, numeric, Span};
use fe_parser::{ast as fe, ast::GenericArg, node::Node};
use num_bigint::BigInt;
use num_traits::{ToPrimitive, Zero};
use smol_str::SmolStr;

use super::borrowck;
use crate::{
builtins::{ContractTypeMethod, GlobalFunction, Intrinsic, ValueMethod},
Expand All @@ -23,13 +31,6 @@ use crate::{
},
};

use fe_common::{diagnostics::Label, numeric, Span};
use fe_parser::{ast as fe, ast::GenericArg, node::Node};
use num_bigint::BigInt;
use num_traits::{ToPrimitive, Zero};
use smol_str::SmolStr;
use std::{ops::RangeInclusive, str::FromStr};

// TODO: don't fail fatally if expected type is provided

pub fn expr_type(
Expand Down
5 changes: 3 additions & 2 deletions crates/hir-analysis/src/ty/constraint_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,14 @@ impl<'db> ConstraintSolver<'db> {
let res = if table.unify(gen_impl.ty(self.db), goal_ty).is_ok()
&& table.unify(gen_impl.trait_(self.db), goal_trait).is_ok()
{
let constraints = impl_.instantiate_identity().constraints(self.db);
let constraints = gen_impl.fold_with(&mut table).constraints(self.db);
Some(constraints.fold_with(&mut table))
} else {
None
};

table.rollback_to(snapshot);

res
})
else {
Expand All @@ -178,7 +179,7 @@ impl<'db> ConstraintSolver<'db> {
match is_goal_satisfiable(self.db, sub_goal, self.assumptions) {
GoalSatisfiability::Satisfied => {}
GoalSatisfiability::NotSatisfied(_) => {
return GoalSatisfiability::NotSatisfied(self.goal)
return GoalSatisfiability::NotSatisfied(self.goal);
}
GoalSatisfiability::InfiniteRecursion(_) => {
return GoalSatisfiability::InfiniteRecursion(self.goal)
Expand Down
2 changes: 1 addition & 1 deletion crates/hir-analysis/src/ty/def_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ impl<'db> Visitor for DefAnalyzer<'db> {
let span = ctxt.span().unwrap();
if let Some(diag) = ty.emit_diag(self.db, span.clone().into()) {
self.diags.push(diag)
} else if let Some(diag) = ty.emit_sat_diag(self.db, self.assumptions, span.into()) {
} else if let Some(diag) = ty.emit_wf_diag(self.db, self.assumptions, span.into()) {
self.diags.push(diag)
}
}
Expand Down
37 changes: 35 additions & 2 deletions crates/hir-analysis/src/ty/ty_check/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,42 @@ use super::TyChecker;
use crate::{
ty::{
diagnostics::{BodyDiag, FuncBodyDiag, FuncBodyDiagAccumulator},
fold::{TyFoldable, TyFolder},
func_def::FuncDef,
ty_def::{TyBase, TyData, TyId},
ty_lower::lower_generic_arg_list,
visitor::{TyVisitable, TyVisitor},
},
HirAnalysisDb,
};

pub(super) struct Callable {
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Callable {
func_def: FuncDef,
generic_args: Vec<TyId>,
}

impl<'db> TyVisitable<'db> for Callable {
fn visit_with<V>(&self, visitor: &mut V)
where
V: TyVisitor<'db>,
{
self.generic_args.visit_with(visitor)
}
}

impl<'db> TyFoldable<'db> for Callable {
fn super_fold_with<F>(self, folder: &mut F) -> Self
where
F: TyFolder<'db>,
{
Self {
func_def: self.func_def,
generic_args: self.generic_args.fold_with(folder),
}
}
}

impl Callable {
pub(super) fn new(
db: &dyn HirAnalysisDb,
Expand All @@ -45,10 +69,19 @@ impl Callable {
})
}

pub(super) fn ret_ty(&self, db: &dyn HirAnalysisDb) -> TyId {
pub fn ret_ty(&self, db: &dyn HirAnalysisDb) -> TyId {
self.func_def.ret_ty(db).instantiate(db, &self.generic_args)
}

pub fn ty(&self, db: &dyn HirAnalysisDb) -> TyId {
let mut ty = TyId::func(db, self.func_def);
for &arg in self.generic_args.iter() {
ty = TyId::app(db, ty, arg);
}

ty
}

pub(super) fn unify_generic_args(
&mut self,
tc: &mut TyChecker,
Expand Down
16 changes: 15 additions & 1 deletion crates/hir-analysis/src/ty/ty_check/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use hir::{
};
use rustc_hash::FxHashMap;

use super::TypedBody;
use super::{Callable, TypedBody};
use crate::{
ty::{
const_ty::{ConstTyData, ConstTyId, EvaluatedConstTy},
Expand All @@ -27,6 +27,7 @@ pub(super) struct TyCheckEnv<'db> {

pat_ty: FxHashMap<PatId, TyId>,
expr_ty: FxHashMap<ExprId, ExprProp>,
callables: FxHashMap<ExprId, Callable>,

var_env: Vec<BlockEnv>,

Expand All @@ -47,6 +48,7 @@ impl<'db> TyCheckEnv<'db> {
body,
pat_ty: FxHashMap::default(),
expr_ty: FxHashMap::default(),
callables: FxHashMap::default(),
var_env: vec![BlockEnv::new(func.scope(), 0)],
pending_vars: FxHashMap::default(),
loop_stack: Vec::new(),
Expand Down Expand Up @@ -91,6 +93,11 @@ impl<'db> TyCheckEnv<'db> {
binding.def_span(self)
}

pub(super) fn register_callable(&mut self, expr: ExprId, callable: Callable) {
if self.callables.insert(expr, callable).is_some() {
panic!("callable is already registered for the given expr")
}
}
pub(super) fn binding_name(&self, binding: LocalBinding) -> IdentId {
binding.binding_name(self)
}
Expand Down Expand Up @@ -219,10 +226,17 @@ impl<'db> TyCheckEnv<'db> {
.values_mut()
.for_each(|ty| *ty = ty.fold_with(&mut folder));

let callables = self
.callables
.into_iter()
.map(|(expr, callable)| (expr, callable.fold_with(&mut folder)))
.collect();

TypedBody {
body: Some(self.body),
pat_ty: self.pat_ty,
expr_ty: self.expr_ty,
callables,
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/hir-analysis/src/ty/ty_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ impl<'db> TyChecker<'db> {
callable.check_args(self, args, call_span.args_moved(), None);

let ret_ty = callable.ret_ty(self.db);
self.register_callable(expr, callable);
ExprProp::new(ret_ty, true)
}

Expand Down Expand Up @@ -344,6 +345,7 @@ impl<'db> TyChecker<'db> {
Some((*receiver, receiver_ty)),
);
let ret_ty = callable.ret_ty(self.db);
self.register_callable(expr, callable);
ExprProp::new(ret_ty, true)
}

Expand Down
36 changes: 28 additions & 8 deletions crates/hir-analysis/src/ty/ty_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod pat;
mod path;
mod stmt;

pub use callable::Callable;
pub use env::ExprProp;
use env::TyCheckEnv;
pub(super) use expr::TraitOps;
Expand Down Expand Up @@ -183,13 +184,18 @@ impl<'db> TyChecker<'db> {
}
}
}

fn register_callable(&mut self, expr: ExprId, callable: Callable) {
self.env.register_callable(expr, callable)
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedBody {
body: Option<Body>,
pat_ty: FxHashMap<PatId, TyId>,
expr_ty: FxHashMap<ExprId, ExprProp>,
callables: FxHashMap<ExprId, Callable>,
}

impl TypedBody {
Expand All @@ -211,11 +217,16 @@ impl TypedBody {
.unwrap_or_else(|| TyId::invalid(db, InvalidCause::Other))
}

pub fn callable_expr(&self, expr: ExprId) -> Option<&Callable> {
self.callables.get(&expr)
}

fn empty() -> Self {
Self {
body: None,
pat_ty: FxHashMap::default(),
expr_ty: FxHashMap::default(),
callables: FxHashMap::default(),
}
}
}
Expand Down Expand Up @@ -307,7 +318,20 @@ impl<'db> TyCheckerFinalizer<'db> {
let prop = self.body.expr_prop(self.db, expr);
let span = ctxt.span().unwrap();
self.check_unknown(prop.ty, span.clone().into());
self.check_wf(prop, span.into());
if prop.binding.is_none() {
self.check_wf(prop.ty, span.into());
}
}

// We need this additional check for method call because the callable type is
// not tied to the expression type.
if let Expr::MethodCall(..) = expr_data {
if let Some(callable) = self.body.callable_expr(expr) {
let callable_ty = callable.ty(self.db);
let span = ctxt.span().unwrap().into_method_call_expr().method_name();
self.check_unknown(callable_ty, span.clone().into());
self.check_wf(callable_ty, span.into())
}
}

walk_expr(self, ctxt, expr);
Expand Down Expand Up @@ -346,17 +370,13 @@ impl<'db> TyCheckerFinalizer<'db> {
}
}

fn check_wf(&mut self, prop: ExprProp, span: DynLazySpan) {
if prop.binding.is_some() {
// WF check is already performed.
return;
}
let flags = prop.ty.flags(self.db);
fn check_wf(&mut self, ty: TyId, span: DynLazySpan) {
let flags = ty.flags(self.db);
if flags.contains(TyFlags::HAS_INVALID) || flags.contains(TyFlags::HAS_VAR) {
return;
}

if let Some(diag) = prop.ty.emit_sat_diag(self.db, self.assumptions, span) {
if let Some(diag) = ty.emit_wf_diag(self.db, self.assumptions, span) {
self.diags.push(diag.into());
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/hir-analysis/src/ty/ty_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ impl TyId {
visitor.diag
}

pub(super) fn emit_sat_diag(
pub(super) fn emit_wf_diag(
self,
db: &dyn HirAnalysisDb,
assumptions: AssumptionListId,
Expand Down
31 changes: 31 additions & 0 deletions crates/hir-analysis/test_files/constraints/specialized.fe
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
trait Trait {
fn f(self) -> i32
}

struct S<T> {
t: T
}

struct S2<T> {
s: S<T>
}

impl Trait for S<i32> {
fn f(self) -> i32 {
self.t
}
}

impl<T> Trait for S2<T>
where S<T>: Trait {
fn f(self) -> i32 {
self.s.f()
}
}

fn bar() {
let t: i32 = 1
let s = S { t }
let s2 = S2 { s }
let _ = s2.f()
}
17 changes: 17 additions & 0 deletions crates/hir-analysis/tests/constraints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
mod test_db;
use std::path::Path;

use dir_test::{dir_test, Fixture};
use test_db::HirAnalysisTestDb;

#[dir_test(
dir: "$CARGO_MANIFEST_DIR/test_files/constraints",
glob: "*.fe"
)]
fn test_standalone(fixture: Fixture<&str>) {
let mut db = HirAnalysisTestDb::default();
let path = Path::new(fixture.path());
let file_name = path.file_name().and_then(|file| file.to_str()).unwrap();
let (top_mod, _) = db.new_stand_alone(file_name, fixture.content());
db.assert_no_diags(top_mod);
}
30 changes: 30 additions & 0 deletions crates/uitest/fixtures/ty_check/method_bound/specialized.fe
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
trait Trait {
fn f(self) -> i32
}

struct S<T> {
t: T
}

struct S2<T> {
s: S<T>
}

impl Trait for S<i32> {
fn f(self) -> i32 {
self.t
}
}

impl<T> Trait for S2<T>
where S<T>: Trait {
fn f(self) -> i32 {
self.s.f()
}
}

fn bar() {
let s = S { t: false }
let s2 = S2 { s }
let _ = s2.f()
}
12 changes: 12 additions & 0 deletions crates/uitest/fixtures/ty_check/method_bound/specialized.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
source: crates/uitest/tests/ty_check.rs
expression: diags
input_file: crates/uitest/fixtures/ty_check/method_bound/specialized.fe
---
error[6-0003]: trait bound is not satisfied
┌─ specialized.fe:29:16
29let _ = s2.f()
^ `S2<bool>` doesn't implement `Trait`

0 comments on commit 2fca91c

Please sign in to comment.