Skip to content

Commit

Permalink
Add derived causes for host effect predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Dec 10, 2024
1 parent 33c245b commit b921c6a
Show file tree
Hide file tree
Showing 14 changed files with 274 additions and 37 deletions.
89 changes: 72 additions & 17 deletions compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ impl<'tcx> ObligationCause<'tcx> {
self
}

pub fn derived_host_cause(
mut self,
parent_host_pred: ty::Binder<'tcx, ty::HostEffectPredicate<'tcx>>,
variant: impl FnOnce(DerivedHostCause<'tcx>) -> ObligationCauseCode<'tcx>,
) -> ObligationCause<'tcx> {
self.code = variant(DerivedHostCause { parent_host_pred, parent_code: self.code }).into();
self
}

pub fn to_constraint_category(&self) -> ConstraintCategory<'tcx> {
match self.code() {
ObligationCauseCode::MatchImpl(cause, _) => cause.to_constraint_category(),
Expand Down Expand Up @@ -279,6 +288,14 @@ pub enum ObligationCauseCode<'tcx> {
/// Derived obligation for WF goals.
WellFormedDerived(DerivedCause<'tcx>),

/// Derived obligation (i.e. `where` clause) on an user-provided impl
/// or a trait alias.
ImplDerivedHost(Box<ImplDerivedHostCause<'tcx>>),

/// Derived obligation (i.e. `where` clause) on an user-provided impl
/// or a trait alias.
BuiltinDerivedHost(DerivedHostCause<'tcx>),

/// Derived obligation refined to point at a specific argument in
/// a call or method expression.
FunctionArg {
Expand Down Expand Up @@ -438,36 +455,38 @@ pub enum WellFormedLoc {
},
}

#[derive(Clone, Debug, PartialEq, Eq, HashStable, TyEncodable, TyDecodable)]
#[derive(TypeVisitable, TypeFoldable)]
pub struct ImplDerivedCause<'tcx> {
pub derived: DerivedCause<'tcx>,
/// The `DefId` of the `impl` that gave rise to the `derived` obligation.
/// If the `derived` obligation arose from a trait alias, which conceptually has a synthetic impl,
/// then this will be the `DefId` of that trait alias. Care should therefore be taken to handle
/// that exceptional case where appropriate.
pub impl_or_alias_def_id: DefId,
/// The index of the derived predicate in the parent impl's predicates.
pub impl_def_predicate_index: Option<usize>,
pub span: Span,
}

impl<'tcx> ObligationCauseCode<'tcx> {
/// Returns the base obligation, ignoring derived obligations.
pub fn peel_derives(&self) -> &Self {
let mut base_cause = self;
while let Some((parent_code, _)) = base_cause.parent() {
while let Some(parent_code) = base_cause.parent() {
base_cause = parent_code;
}
base_cause
}

pub fn parent(&self) -> Option<&Self> {
match self {
ObligationCauseCode::FunctionArg { parent_code, .. } => Some(parent_code),
ObligationCauseCode::BuiltinDerived(derived)
| ObligationCauseCode::WellFormedDerived(derived)
| ObligationCauseCode::ImplDerived(box ImplDerivedCause { derived, .. }) => {
Some(&derived.parent_code)
}
ObligationCauseCode::BuiltinDerivedHost(derived)
| ObligationCauseCode::ImplDerivedHost(box ImplDerivedHostCause { derived, .. }) => {
Some(&derived.parent_code)
}
_ => None,
}
}

/// Returns the base obligation and the base trait predicate, if any, ignoring
/// derived obligations.
pub fn peel_derives_with_predicate(&self) -> (&Self, Option<ty::PolyTraitPredicate<'tcx>>) {
let mut base_cause = self;
let mut base_trait_pred = None;
while let Some((parent_code, parent_pred)) = base_cause.parent() {
while let Some((parent_code, parent_pred)) = base_cause.parent_with_predicate() {
base_cause = parent_code;
if let Some(parent_pred) = parent_pred {
base_trait_pred = Some(parent_pred);
Expand All @@ -477,7 +496,7 @@ impl<'tcx> ObligationCauseCode<'tcx> {
(base_cause, base_trait_pred)
}

pub fn parent(&self) -> Option<(&Self, Option<ty::PolyTraitPredicate<'tcx>>)> {
pub fn parent_with_predicate(&self) -> Option<(&Self, Option<ty::PolyTraitPredicate<'tcx>>)> {
match self {
ObligationCauseCode::FunctionArg { parent_code, .. } => Some((parent_code, None)),
ObligationCauseCode::BuiltinDerived(derived)
Expand Down Expand Up @@ -555,6 +574,42 @@ pub struct DerivedCause<'tcx> {
pub parent_code: InternedObligationCauseCode<'tcx>,
}

#[derive(Clone, Debug, PartialEq, Eq, HashStable, TyEncodable, TyDecodable)]
#[derive(TypeVisitable, TypeFoldable)]
pub struct ImplDerivedCause<'tcx> {
pub derived: DerivedCause<'tcx>,
/// The `DefId` of the `impl` that gave rise to the `derived` obligation.
/// If the `derived` obligation arose from a trait alias, which conceptually has a synthetic impl,
/// then this will be the `DefId` of that trait alias. Care should therefore be taken to handle
/// that exceptional case where appropriate.
pub impl_or_alias_def_id: DefId,
/// The index of the derived predicate in the parent impl's predicates.
pub impl_def_predicate_index: Option<usize>,
pub span: Span,
}

#[derive(Clone, Debug, PartialEq, Eq, HashStable, TyEncodable, TyDecodable)]
#[derive(TypeVisitable, TypeFoldable)]
pub struct DerivedHostCause<'tcx> {
/// The trait predicate of the parent obligation that led to the
/// current obligation. Note that only trait obligations lead to
/// derived obligations, so we just store the trait predicate here
/// directly.
pub parent_host_pred: ty::Binder<'tcx, ty::HostEffectPredicate<'tcx>>,

/// The parent trait had this cause.
pub parent_code: InternedObligationCauseCode<'tcx>,
}

#[derive(Clone, Debug, PartialEq, Eq, HashStable, TyEncodable, TyDecodable)]
#[derive(TypeVisitable, TypeFoldable)]
pub struct ImplDerivedHostCause<'tcx> {
pub derived: DerivedHostCause<'tcx>,
/// The `DefId` of the `impl` that gave rise to the `derived` obligation.
pub impl_def_id: DefId,
pub span: Span,
}

#[derive(Clone, Debug, PartialEq, Eq, TypeVisitable)]
pub enum SelectionError<'tcx> {
/// The trait is not implemented.
Expand Down
22 changes: 22 additions & 0 deletions compiler/rustc_middle/src/ty/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,28 @@ impl<'tcx> UpcastFrom<TyCtxt<'tcx>, PolyProjectionPredicate<'tcx>> for Clause<'t
}
}

impl<'tcx> UpcastFrom<TyCtxt<'tcx>, ty::Binder<'tcx, ty::HostEffectPredicate<'tcx>>>
for Predicate<'tcx>
{
fn upcast_from(
from: ty::Binder<'tcx, ty::HostEffectPredicate<'tcx>>,
tcx: TyCtxt<'tcx>,
) -> Self {
from.map_bound(ty::ClauseKind::HostEffect).upcast(tcx)
}
}

impl<'tcx> UpcastFrom<TyCtxt<'tcx>, ty::Binder<'tcx, ty::HostEffectPredicate<'tcx>>>
for Clause<'tcx>
{
fn upcast_from(
from: ty::Binder<'tcx, ty::HostEffectPredicate<'tcx>>,
tcx: TyCtxt<'tcx>,
) -> Self {
from.map_bound(ty::ClauseKind::HostEffect).upcast(tcx)
}
}

impl<'tcx> UpcastFrom<TyCtxt<'tcx>, NormalizesTo<'tcx>> for Predicate<'tcx> {
fn upcast_from(from: NormalizesTo<'tcx>, tcx: TyCtxt<'tcx>) -> Self {
PredicateKind::NormalizesTo(from).upcast(tcx)
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_next_trait_solver/src/solve/effect_goals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ where

ecx.probe_builtin_trait_candidate(BuiltinImplSource::Misc).enter(|ecx| {
ecx.add_goals(
GoalSource::Misc,
GoalSource::ImplWhereBound,
const_conditions.into_iter().map(|trait_ref| {
goal.with(
cx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
applied_do_not_recommend = true;
}
}
if let Some((parent_cause, _parent_pred)) = base_cause.parent() {
if let Some(parent_cause) = base_cause.parent() {
base_cause = parent_cause.clone();
} else {
break;
Expand All @@ -779,7 +779,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
trait_ref.skip_binder().args.type_at(1).to_opt_closure_kind()
&& !found_kind.extends(expected_kind)
{
if let Some((_, Some(parent))) = obligation.cause.code().parent() {
if let Some((_, Some(parent))) = obligation.cause.code().parent_with_predicate() {
// If we have a derived obligation, then the parent will be a `AsyncFn*` goal.
trait_ref = parent.to_poly_trait_ref();
} else if let &ObligationCauseCode::FunctionArg { arg_hir_id, .. } =
Expand Down Expand Up @@ -925,7 +925,8 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
let Some(typeck) = &self.typeck_results else {
return false;
};
let Some((ObligationCauseCode::QuestionMark, Some(y))) = obligation.cause.code().parent()
let Some((ObligationCauseCode::QuestionMark, Some(y))) =
obligation.cause.code().parent_with_predicate()
else {
return false;
};
Expand Down Expand Up @@ -1178,7 +1179,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {

let mut code = obligation.cause.code();
let mut pred = obligation.predicate.as_trait_clause();
while let Some((next_code, next_pred)) = code.parent() {
while let Some((next_code, next_pred)) = code.parent_with_predicate() {
if let Some(pred) = pred {
self.enter_forall(pred, |pred| {
diag.note(format!(
Expand Down Expand Up @@ -2093,7 +2094,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
let mut code = obligation.cause.code();
let mut trait_pred = trait_predicate;
let mut peeled = false;
while let Some((parent_code, parent_trait_pred)) = code.parent() {
while let Some((parent_code, parent_trait_pred)) = code.parent_with_predicate() {
code = parent_code;
if let Some(parent_trait_pred) = parent_trait_pred {
trait_pred = parent_trait_pred;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
{
// Suggest dereferencing the argument to a function/method call if possible
let mut real_trait_pred = trait_pred;
while let Some((parent_code, parent_trait_pred)) = code.parent() {
while let Some((parent_code, parent_trait_pred)) = code.parent_with_predicate() {
code = parent_code;
if let Some(parent_trait_pred) = parent_trait_pred {
real_trait_pred = parent_trait_pred;
Expand Down Expand Up @@ -1476,7 +1476,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
let mut span = obligation.cause.span;
let mut trait_pred = trait_pred;
let mut code = obligation.cause.code();
while let Some((c, Some(parent_trait_pred))) = code.parent() {
while let Some((c, Some(parent_trait_pred))) = code.parent_with_predicate() {
// We want the root obligation, in order to detect properly handle
// `for _ in &mut &mut vec![] {}`.
code = c;
Expand Down Expand Up @@ -3501,6 +3501,59 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
)
});
}
ObligationCauseCode::ImplDerivedHost(ref data) => {
let self_ty =
self.resolve_vars_if_possible(data.derived.parent_host_pred.self_ty());
let msg = format!(
"required for `{self_ty}` to implement `{} {}`",
data.derived.parent_host_pred.skip_binder().constness,
data.derived
.parent_host_pred
.map_bound(|pred| pred.trait_ref)
.print_only_trait_path(),
);
match tcx.hir().get_if_local(data.impl_def_id) {
Some(Node::Item(hir::Item {
kind: hir::ItemKind::Impl(hir::Impl { of_trait, self_ty, .. }),
..
})) => {
let mut spans = vec![self_ty.span];
spans.extend(of_trait.as_ref().map(|t| t.path.span));
let mut spans: MultiSpan = spans.into();
spans.push_span_label(data.span, "unsatisfied trait bound introduced here");
err.span_note(spans, msg);
}
_ => {
err.note(msg);
}
}
ensure_sufficient_stack(|| {
self.note_obligation_cause_code(
body_id,
err,
data.derived.parent_host_pred,
param_env,
&data.derived.parent_code,
obligated_types,
seen_requirements,
long_ty_file,
)
});
}
ObligationCauseCode::BuiltinDerivedHost(ref data) => {
ensure_sufficient_stack(|| {
self.note_obligation_cause_code(
body_id,
err,
data.parent_host_pred,
param_env,
&data.parent_code,
obligated_types,
seen_requirements,
long_ty_file,
)
});
}
ObligationCauseCode::WellFormedDerived(ref data) => {
let parent_trait_ref = self.resolve_vars_if_possible(data.parent_trait_pred);
let parent_predicate = parent_trait_ref;
Expand Down
Loading

0 comments on commit b921c6a

Please sign in to comment.