Skip to content

Commit

Permalink
Auto merge of #120752 - compiler-errors:more-relevant-bounds, r=lcnr
Browse files Browse the repository at this point in the history
Collect relevant item bounds from trait clauses for nested rigid projections

Rust currently considers trait where-clauses that bound the trait's *own* associated types to act like an item bound:

```rust
trait Foo where Self::Assoc: Bar { type Assoc; }
// acts as if:
trait Foo { type Assoc: Bar; }
```

### Background

This behavior has existed since essentially forever (i.e. before Rust 1.0), since we originally started out by literally looking at the where clauses written on the trait when assembling `SelectionCandidate::ProjectionCandidate` for projections. However, looking at the predicates of the associated type themselves was not sound, since it was unclear which predicates were *assumed* and which predicates were *implied*, and therefore this was reworked in #72788 (which added a query for the predicates we consider for `ProjectionCandidate`s), and then finally item bounds and predicates were split in #73905.

### Problem 1: GATs don't uplift bounds correctly

All the while, we've still had logic to uplift associated type bounds from a trait's where clauses. However, with the introduction of GATs, this logic was never really generalized correctly for them, since we were using simple equality to test if the self type of a trait where clause is a projection. This leads to shortcomings, such as:

```rust
trait Foo
where
    for<'a> Self::Gat<'a>: Debug,
{
    type Gat<'a>;
}

fn test<T: Foo>(x: T::Gat<'static>) {
    //~^ ERROR `<T as Foo>::Gat<'a>` doesn't implement `Debug`
    println!("{:?}", x);
}
```

### Problem 2: Nested associated type bounds are not uplifted

We also don't attempt to uplift bounds on nested associated types, something that we couldn't really support until #120584. This can be demonstrated best with an example:

```rust
trait A
    where Self::Assoc: B,
    where <Self::Assoc as B>::Assoc2: C,
{
    type Assoc; // <~ The compiler *should* treat this like it has an item bound `B<Assoc2: C>`.
}

trait B { type Assoc2; }
trait C {}

fn is_c<T: C>() {}

fn test<T: A>() {
    is_c::<<Self::Assoc as B>::Assoc2>();
    //~^ ERROR the trait bound `<<T as A>::Assoc as B>::Assoc2: C` is not satisfied
}
```

Why does this matter?

Well, generalizing this behavior bridges a gap between the associated type bounds (ATB) feature and trait where clauses. Currently, all bounds that can be stably written on associated types can also be expressed as where clauses on traits; however, with the stabilization of ATB, there are now bounds that can't be desugared in the same way. This fixes that.

## How does this PR fix things?

First, when scraping item bounds from the trait's where clauses, given a trait predicate, we'll loop of the self type of the predicate as long as it's a projection. If we find a projection whose trait ref matches, we'll uplift the bound. This allows us to uplift, for example `<Self as Trait>::Assoc: Bound` (pre-existing), but also `<<Self as Trait>::Assoc as Iterator>::Item: Bound` (new).

If that projection is a GAT, we will check if all of the GAT's *own* args are all unique late-bound vars. We then map the late-bound vars to early-bound vars from the GAT -- this allows us to uplift `for<'a, 'b> Self::Assoc<'a, 'b>: Trait` into an item bound, but we will leave `for<'a> Self::Assoc<'a, 'a>: Trait` and `Self::Assoc<'static, 'static>: Trait` alone.

### Okay, but does this *really* matter?

I consider this to be an improvement of the status quo because it makes GATs a bit less magical, and makes rigid projections a bit more expressive.
  • Loading branch information
bors committed Sep 25, 2024
2 parents 0399709 + c591475 commit 9e394f5
Show file tree
Hide file tree
Showing 7 changed files with 380 additions and 10 deletions.
245 changes: 235 additions & 10 deletions compiler/rustc_hir_analysis/src/collect/item_bounds.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use rustc_data_structures::fx::FxIndexSet;
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
use rustc_hir as hir;
use rustc_infer::traits::util;
use rustc_middle::ty::fold::shift_vars;
use rustc_middle::ty::{
self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable,
self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
};
use rustc_middle::{bug, span_bug};
use rustc_span::Span;
Expand Down Expand Up @@ -42,14 +43,18 @@ fn associated_type_bounds<'tcx>(
let trait_def_id = tcx.local_parent(assoc_item_def_id);
let trait_predicates = tcx.trait_explicit_predicates_and_bounds(trait_def_id);

let bounds_from_parent = trait_predicates.predicates.iter().copied().filter(|(pred, _)| {
match pred.kind().skip_binder() {
ty::ClauseKind::Trait(tr) => tr.self_ty() == item_ty,
ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty() == item_ty,
ty::ClauseKind::TypeOutlives(outlives) => outlives.0 == item_ty,
_ => false,
}
});
let item_trait_ref = ty::TraitRef::identity(tcx, tcx.parent(assoc_item_def_id.to_def_id()));
let bounds_from_parent =
trait_predicates.predicates.iter().copied().filter_map(|(clause, span)| {
remap_gat_vars_and_recurse_into_nested_projections(
tcx,
filter,
item_trait_ref,
assoc_item_def_id,
span,
clause,
)
});

let all_bounds = tcx.arena.alloc_from_iter(bounds.clauses(tcx).chain(bounds_from_parent));
debug!(
Expand All @@ -63,6 +68,226 @@ fn associated_type_bounds<'tcx>(
all_bounds
}

/// The code below is quite involved, so let me explain.
///
/// We loop here, because we also want to collect vars for nested associated items as
/// well. For example, given a clause like `Self::A::B`, we want to add that to the
/// item bounds for `A`, so that we may use that bound in the case that `Self::A::B` is
/// rigid.
///
/// Secondly, regarding bound vars, when we see a where clause that mentions a GAT
/// like `for<'a, ...> Self::Assoc<'a, ...>: Bound<'b, ...>`, we want to turn that into
/// an item bound on the GAT, where all of the GAT args are substituted with the GAT's
/// param regions, and then keep all of the other late-bound vars in the bound around.
/// We need to "compress" the binder so that it doesn't mention any of those vars that
/// were mapped to params.
fn remap_gat_vars_and_recurse_into_nested_projections<'tcx>(
tcx: TyCtxt<'tcx>,
filter: PredicateFilter,
item_trait_ref: ty::TraitRef<'tcx>,
assoc_item_def_id: LocalDefId,
span: Span,
clause: ty::Clause<'tcx>,
) -> Option<(ty::Clause<'tcx>, Span)> {
let mut clause_ty = match clause.kind().skip_binder() {
ty::ClauseKind::Trait(tr) => tr.self_ty(),
ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty(),
ty::ClauseKind::TypeOutlives(outlives) => outlives.0,
_ => return None,
};

let gat_vars = loop {
if let ty::Alias(ty::Projection, alias_ty) = *clause_ty.kind() {
if alias_ty.trait_ref(tcx) == item_trait_ref
&& alias_ty.def_id == assoc_item_def_id.to_def_id()
{
// We have found the GAT in question...
// Return the vars, since we may need to remap them.
break &alias_ty.args[item_trait_ref.args.len()..];
} else {
// Only collect *self* type bounds if the filter is for self.
match filter {
PredicateFilter::SelfOnly | PredicateFilter::SelfThatDefines(_) => {
return None;
}
PredicateFilter::All | PredicateFilter::SelfAndAssociatedTypeBounds => {}
}

clause_ty = alias_ty.self_ty();
continue;
}
}

return None;
};

// Special-case: No GAT vars, no mapping needed.
if gat_vars.is_empty() {
return Some((clause, span));
}

// First, check that all of the GAT args are substituted with a unique late-bound arg.
// If we find a duplicate, then it can't be mapped to the definition's params.
let mut mapping = FxIndexMap::default();
let generics = tcx.generics_of(assoc_item_def_id);
for (param, var) in std::iter::zip(&generics.own_params, gat_vars) {
let existing = match var.unpack() {
ty::GenericArgKind::Lifetime(re) => {
if let ty::RegionKind::ReBound(ty::INNERMOST, bv) = re.kind() {
mapping.insert(bv.var, tcx.mk_param_from_def(param))
} else {
return None;
}
}
ty::GenericArgKind::Type(ty) => {
if let ty::Bound(ty::INNERMOST, bv) = *ty.kind() {
mapping.insert(bv.var, tcx.mk_param_from_def(param))
} else {
return None;
}
}
ty::GenericArgKind::Const(ct) => {
if let ty::ConstKind::Bound(ty::INNERMOST, bv) = ct.kind() {
mapping.insert(bv, tcx.mk_param_from_def(param))
} else {
return None;
}
}
};

if existing.is_some() {
return None;
}
}

// Finally, map all of the args in the GAT to the params we expect, and compress
// the remaining late-bound vars so that they count up from var 0.
let mut folder =
MapAndCompressBoundVars { tcx, binder: ty::INNERMOST, still_bound_vars: vec![], mapping };
let pred = clause.kind().skip_binder().fold_with(&mut folder);

Some((
ty::Binder::bind_with_vars(pred, tcx.mk_bound_variable_kinds(&folder.still_bound_vars))
.upcast(tcx),
span,
))
}

/// Given some where clause like `for<'b, 'c> <Self as Trait<'a_identity>>::Gat<'b>: Bound<'c>`,
/// the mapping will map `'b` back to the GAT's `'b_identity`. Then we need to compress the
/// remaining bound var `'c` to index 0.
///
/// This folder gives us: `for<'c> <Self as Trait<'a_identity>>::Gat<'b_identity>: Bound<'c>`,
/// which is sufficient for an item bound for `Gat`, since all of the GAT's args are identity.
struct MapAndCompressBoundVars<'tcx> {
tcx: TyCtxt<'tcx>,
/// How deep are we? Makes sure we don't touch the vars of nested binders.
binder: ty::DebruijnIndex,
/// List of bound vars that remain unsubstituted because they were not
/// mentioned in the GAT's args.
still_bound_vars: Vec<ty::BoundVariableKind>,
/// Subtle invariant: If the `GenericArg` is bound, then it should be
/// stored with the debruijn index of `INNERMOST` so it can be shifted
/// correctly during substitution.
mapping: FxIndexMap<ty::BoundVar, ty::GenericArg<'tcx>>,
}

impl<'tcx> TypeFolder<TyCtxt<'tcx>> for MapAndCompressBoundVars<'tcx> {
fn cx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn fold_binder<T>(&mut self, t: ty::Binder<'tcx, T>) -> ty::Binder<'tcx, T>
where
ty::Binder<'tcx, T>: TypeSuperFoldable<TyCtxt<'tcx>>,
{
self.binder.shift_in(1);
let out = t.super_fold_with(self);
self.binder.shift_out(1);
out
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
if !ty.has_bound_vars() {
return ty;
}

if let ty::Bound(binder, old_bound) = *ty.kind()
&& self.binder == binder
{
let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
mapped.expect_ty()
} else {
// If we didn't find a mapped generic, then make a new one.
// Allocate a new var idx, and insert a new bound ty.
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
self.still_bound_vars.push(ty::BoundVariableKind::Ty(old_bound.kind));
let mapped = Ty::new_bound(self.tcx, ty::INNERMOST, ty::BoundTy {
var,
kind: old_bound.kind,
});
self.mapping.insert(old_bound.var, mapped.into());
mapped
};

shift_vars(self.tcx, mapped, self.binder.as_u32())
} else {
ty.super_fold_with(self)
}
}

fn fold_region(&mut self, re: ty::Region<'tcx>) -> ty::Region<'tcx> {
if let ty::ReBound(binder, old_bound) = re.kind()
&& self.binder == binder
{
let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
mapped.expect_region()
} else {
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
self.still_bound_vars.push(ty::BoundVariableKind::Region(old_bound.kind));
let mapped = ty::Region::new_bound(self.tcx, ty::INNERMOST, ty::BoundRegion {
var,
kind: old_bound.kind,
});
self.mapping.insert(old_bound.var, mapped.into());
mapped
};

shift_vars(self.tcx, mapped, self.binder.as_u32())
} else {
re
}
}

fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
if !ct.has_bound_vars() {
return ct;
}

if let ty::ConstKind::Bound(binder, old_var) = ct.kind()
&& self.binder == binder
{
let mapped = if let Some(mapped) = self.mapping.get(&old_var) {
mapped.expect_const()
} else {
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
self.still_bound_vars.push(ty::BoundVariableKind::Const);
let mapped = ty::Const::new_bound(self.tcx, ty::INNERMOST, var);
self.mapping.insert(old_var, mapped.into());
mapped
};

shift_vars(self.tcx, mapped, self.binder.as_u32())
} else {
ct.super_fold_with(self)
}
}

fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if !p.has_bound_vars() { p } else { p.super_fold_with(self) }
}
}

/// Opaque types don't inherit bounds from their parent: for return position
/// impl trait it isn't possible to write a suitable predicate on the
/// containing function and for type-alias impl trait we don't have a backwards
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Demonstrates a mostly-theoretical inference guidance now that we turn the where
// clause on `Trait` into an item bound, given that we prefer item bounds somewhat
// greedily in trait selection.

trait Bound<T> {}
impl<T, U> Bound<T> for U {}

trait Trait
where
<<Self as Trait>::Assoc as Other>::Assoc: Bound<u32>,
{
type Assoc: Other;
}

trait Other {
type Assoc;
}

fn impls_trait<T: Bound<U>, U>() -> Vec<U> { vec![] }

fn foo<T: Trait>() {
let mut vec_u = impls_trait::<<<T as Trait>::Assoc as Other>::Assoc, _>();
vec_u.sort();
drop::<Vec<u8>>(vec_u);
//~^ ERROR mismatched types
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
error[E0308]: mismatched types
--> $DIR/nested-associated-type-bound-incompleteness.rs:24:21
|
LL | drop::<Vec<u8>>(vec_u);
| --------------- ^^^^^ expected `Vec<u8>`, found `Vec<u32>`
| |
| arguments to this function are incorrect
|
= note: expected struct `Vec<u8>`
found struct `Vec<u32>`
note: function defined here
--> $SRC_DIR/core/src/mem/mod.rs:LL:COL

error: aborting due to 1 previous error

For more information about this error, try `rustc --explain E0308`.
31 changes: 31 additions & 0 deletions tests/ui/associated-type-bounds/nested-gat-projection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//@ check-pass

trait Trait
where
for<'a> Self::Gat<'a>: OtherTrait,
for<'a, 'b, 'c> <Self::Gat<'a> as OtherTrait>::OtherGat<'b>: HigherRanked<'c>,
{
type Gat<'a>;
}

trait OtherTrait {
type OtherGat<'b>;
}

trait HigherRanked<'c> {}

fn lower_ranked<T: for<'b, 'c> OtherTrait<OtherGat<'b>: HigherRanked<'c>>>() {}

fn higher_ranked<T: Trait>()
where
for<'a> T::Gat<'a>: OtherTrait,
for<'a, 'b, 'c> <T::Gat<'a> as OtherTrait>::OtherGat<'b>: HigherRanked<'c>,
{
}

fn test<T: Trait>() {
lower_ranked::<T::Gat<'_>>();
higher_ranked::<T>();
}

fn main() {}
28 changes: 28 additions & 0 deletions tests/ui/associated-types/imply-relevant-nested-item-bounds-2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//@ check-pass
//@ revisions: current next
//@[next] compile-flags: -Znext-solver

trait Trait
where
Self::Assoc: Clone,
{
type Assoc;
}

fn foo<T: Trait>(x: &T::Assoc) -> T::Assoc {
x.clone()
}

trait Trait2
where
Self::Assoc: Iterator,
<Self::Assoc as Iterator>::Item: Clone,
{
type Assoc;
}

fn foo2<T: Trait2>(x: &<T::Assoc as Iterator>::Item) -> <T::Assoc as Iterator>::Item {
x.clone()
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//@ check-pass

// Test that `for<'a> Self::Gat<'a>: Debug` is implied in the definition of `Foo`,
// just as it would be if it weren't a GAT but just a regular associated type.

use std::fmt::Debug;

trait Foo
where
for<'a> Self::Gat<'a>: Debug,
{
type Gat<'a>;
}

fn test<T: Foo>(x: T::Gat<'static>) {
println!("{:?}", x);
}

fn main() {}
Loading

0 comments on commit 9e394f5

Please sign in to comment.