Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make same_type_modulo_infer a proper TypeRelation #100691

Merged
merged 1 commit into from
Aug 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 79 additions & 53 deletions compiler/rustc_infer/src/infer/error_reporting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ use rustc_hir::lang_items::LangItem;
use rustc_hir::Node;
use rustc_middle::dep_graph::DepContext;
use rustc_middle::ty::print::with_no_trimmed_paths;
use rustc_middle::ty::relate::{self, RelateResult, TypeRelation};
use rustc_middle::ty::{
self, error::TypeError, Binder, List, Region, Subst, Ty, TyCtxt, TypeFoldable,
TypeSuperVisitable, TypeVisitable,
Expand Down Expand Up @@ -2660,67 +2661,92 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
/// Float types, respectively). When comparing two ADTs, these rules apply recursively.
pub fn same_type_modulo_infer(&self, a: Ty<'tcx>, b: Ty<'tcx>) -> bool {
let (a, b) = self.resolve_vars_if_possible((a, b));
match (a.kind(), b.kind()) {
(&ty::Adt(def_a, substs_a), &ty::Adt(def_b, substs_b)) => {
if def_a != def_b {
return false;
}
SameTypeModuloInfer(self).relate(a, b).is_ok()
}
}

substs_a
.types()
.zip(substs_b.types())
.all(|(a, b)| self.same_type_modulo_infer(a, b))
}
(&ty::FnDef(did_a, substs_a), &ty::FnDef(did_b, substs_b)) => {
if did_a != did_b {
return false;
}
struct SameTypeModuloInfer<'a, 'tcx>(&'a InferCtxt<'a, 'tcx>);

substs_a
.types()
.zip(substs_b.types())
.all(|(a, b)| self.same_type_modulo_infer(a, b))
}
(&ty::Int(_) | &ty::Uint(_), &ty::Infer(ty::InferTy::IntVar(_)))
impl<'tcx> TypeRelation<'tcx> for SameTypeModuloInfer<'_, 'tcx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.0.tcx
}

fn param_env(&self) -> ty::ParamEnv<'tcx> {
// Unused, only for consts which we treat as always equal
ty::ParamEnv::empty()
}

fn tag(&self) -> &'static str {
"SameTypeModuloInfer"
}

fn a_is_expected(&self) -> bool {
true
}

fn relate_with_variance<T: relate::Relate<'tcx>>(
&mut self,
_variance: ty::Variance,
_info: ty::VarianceDiagInfo<'tcx>,
a: T,
b: T,
) -> relate::RelateResult<'tcx, T> {
self.relate(a, b)
}

fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
match (a.kind(), b.kind()) {
(ty::Int(_) | ty::Uint(_), ty::Infer(ty::InferTy::IntVar(_)))
| (
&ty::Infer(ty::InferTy::IntVar(_)),
&ty::Int(_) | &ty::Uint(_) | &ty::Infer(ty::InferTy::IntVar(_)),
ty::Infer(ty::InferTy::IntVar(_)),
ty::Int(_) | ty::Uint(_) | ty::Infer(ty::InferTy::IntVar(_)),
)
| (&ty::Float(_), &ty::Infer(ty::InferTy::FloatVar(_)))
| (ty::Float(_), ty::Infer(ty::InferTy::FloatVar(_)))
| (
&ty::Infer(ty::InferTy::FloatVar(_)),
&ty::Float(_) | &ty::Infer(ty::InferTy::FloatVar(_)),
ty::Infer(ty::InferTy::FloatVar(_)),
ty::Float(_) | ty::Infer(ty::InferTy::FloatVar(_)),
)
| (&ty::Infer(ty::InferTy::TyVar(_)), _)
| (_, &ty::Infer(ty::InferTy::TyVar(_))) => true,
(&ty::Ref(_, ty_a, mut_a), &ty::Ref(_, ty_b, mut_b)) => {
mut_a == mut_b && self.same_type_modulo_infer(ty_a, ty_b)
}
(&ty::RawPtr(a), &ty::RawPtr(b)) => {
a.mutbl == b.mutbl && self.same_type_modulo_infer(a.ty, b.ty)
}
(&ty::Slice(a), &ty::Slice(b)) => self.same_type_modulo_infer(a, b),
(&ty::Array(a_ty, a_ct), &ty::Array(b_ty, b_ct)) => {
self.same_type_modulo_infer(a_ty, b_ty) && a_ct == b_ct
}
(&ty::Tuple(a), &ty::Tuple(b)) => {
if a.len() != b.len() {
return false;
}
std::iter::zip(a.iter(), b.iter()).all(|(a, b)| self.same_type_modulo_infer(a, b))
}
(&ty::FnPtr(a), &ty::FnPtr(b)) => {
let a = a.skip_binder().inputs_and_output;
let b = b.skip_binder().inputs_and_output;
if a.len() != b.len() {
return false;
}
std::iter::zip(a.iter(), b.iter()).all(|(a, b)| self.same_type_modulo_infer(a, b))
}
// FIXME(compiler-errors): This needs to be generalized more
_ => a == b,
| (ty::Infer(ty::InferTy::TyVar(_)), _)
| (_, ty::Infer(ty::InferTy::TyVar(_))) => Ok(a),
(ty::Infer(_), _) | (_, ty::Infer(_)) => Err(TypeError::Mismatch),
_ => relate::super_relate_tys(self, a, b),
}
}

fn regions(
&mut self,
a: ty::Region<'tcx>,
b: ty::Region<'tcx>,
) -> RelateResult<'tcx, ty::Region<'tcx>> {
if (a.is_var() && b.is_free_or_static()) || (b.is_var() && a.is_free_or_static()) || a == b
{
Ok(a)
} else {
Err(TypeError::Mismatch)
}
}

fn binders<T>(
&mut self,
a: ty::Binder<'tcx, T>,
b: ty::Binder<'tcx, T>,
) -> relate::RelateResult<'tcx, ty::Binder<'tcx, T>>
where
T: relate::Relate<'tcx>,
{
Ok(ty::Binder::dummy(self.relate(a.skip_binder(), b.skip_binder())?))
}

fn consts(
&mut self,
a: ty::Const<'tcx>,
_b: ty::Const<'tcx>,
) -> relate::RelateResult<'tcx, ty::Const<'tcx>> {
// FIXME(compiler-errors): This could at least do some first-order
// relation
Ok(a)
}
}

impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,10 @@ impl<'tcx> Region<'tcx> {
_ => self.is_free(),
}
}

pub fn is_var(self) -> bool {
matches!(self.kind(), ty::ReVar(_))
}
}

/// Type utilities
Expand Down
45 changes: 45 additions & 0 deletions src/test/ui/implied-bounds/issue-100690.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// This code (probably) _should_ compile, but it currently does not because we
// are not smart enough about implied bounds.

use std::io;

fn real_dispatch<T, F>(f: F) -> Result<(), io::Error>
//~^ NOTE required by a bound in this
where
F: FnOnce(&mut UIView<T>) -> Result<(), io::Error> + Send + 'static,
//~^ NOTE required by this bound in `real_dispatch`
//~| NOTE required by a bound in `real_dispatch`
{
todo!()
}

#[derive(Debug)]
struct UIView<'a, T: 'a> {
_phantom: std::marker::PhantomData<&'a mut T>,
}

trait Handle<'a, T: 'a, V, R> {
fn dispatch<F>(&self, f: F) -> Result<(), io::Error>
where
F: FnOnce(&mut V) -> R + Send + 'static;
}

#[derive(Debug, Clone)]
struct TUIHandle<T> {
_phantom: std::marker::PhantomData<T>,
}

impl<'a, T: 'a> Handle<'a, T, UIView<'a, T>, Result<(), io::Error>> for TUIHandle<T> {
fn dispatch<F>(&self, f: F) -> Result<(), io::Error>
where
F: FnOnce(&mut UIView<'a, T>) -> Result<(), io::Error> + Send + 'static,
{
real_dispatch(f)
//~^ ERROR expected a `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
//~| NOTE expected an `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
//~| NOTE expected a closure with arguments
//~| NOTE required by a bound introduced by this call
}
}

fn main() {}
22 changes: 22 additions & 0 deletions src/test/ui/implied-bounds/issue-100690.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
error[E0277]: expected a `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
--> $DIR/issue-100690.rs:37:23
|
LL | real_dispatch(f)
| ------------- ^ expected an `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
| |
| required by a bound introduced by this call
|
= note: expected a closure with arguments `(&mut UIView<'a, T>,)`
found a closure with arguments `(&mut UIView<'_, T>,)`
note: required by a bound in `real_dispatch`
--> $DIR/issue-100690.rs:9:8
|
LL | fn real_dispatch<T, F>(f: F) -> Result<(), io::Error>
| ------------- required by a bound in this
...
LL | F: FnOnce(&mut UIView<T>) -> Result<(), io::Error> + Send + 'static,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `real_dispatch`

error: aborting due to previous error

For more information about this error, try `rustc --explain E0277`.