Skip to content

Commit

Permalink
More compare_impl_item simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Oct 23, 2024
1 parent 73a37a1 commit 21d95fb
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 76 deletions.
125 changes: 52 additions & 73 deletions compiler/rustc_hir_analysis/src/check/compare_impl_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ mod refine;
/// - `impl_m`: type of the method we are checking
/// - `trait_m`: the method in the trait
/// - `impl_trait_ref`: the TraitRef corresponding to the trait implementation
#[instrument(level = "debug", skip(tcx))]
pub(super) fn compare_impl_method<'tcx>(
tcx: TyCtxt<'tcx>,
impl_m: ty::AssocItem,
trait_m: ty::AssocItem,
impl_trait_ref: ty::TraitRef<'tcx>,
) {
debug!("compare_impl_method(impl_trait_ref={:?})", impl_trait_ref);

let _: Result<_, ErrorGuaranteed> = try {
check_method_is_structurally_compatible(tcx, impl_m, trait_m, impl_trait_ref, false)?;
compare_method_predicate_entailment(tcx, impl_m, trait_m, impl_trait_ref)?;
Expand Down Expand Up @@ -167,8 +166,6 @@ fn compare_method_predicate_entailment<'tcx>(
trait_m: ty::AssocItem,
impl_trait_ref: ty::TraitRef<'tcx>,
) -> Result<(), ErrorGuaranteed> {
let trait_to_impl_args = impl_trait_ref.args;

// This node-id should be used for the `body_id` field on each
// `ObligationCause` (and the `FnCtxt`).
//
Expand All @@ -183,13 +180,13 @@ fn compare_method_predicate_entailment<'tcx>(
kind: impl_m.kind,
});

// Create mapping from impl to placeholder.
let impl_to_placeholder_args = GenericArgs::identity_for_item(tcx, impl_m.def_id);

// Create mapping from trait to placeholder.
let trait_to_placeholder_args =
impl_to_placeholder_args.rebase_onto(tcx, impl_m.container_id(tcx), trait_to_impl_args);
debug!("compare_impl_method: trait_to_placeholder_args={:?}", trait_to_placeholder_args);
// Create mapping from trait method to impl method.
let trait_to_impl_args = GenericArgs::identity_for_item(tcx, impl_m.def_id).rebase_onto(
tcx,
impl_m.container_id(tcx),
impl_trait_ref.args,
);
debug!(?trait_to_impl_args);

let impl_m_predicates = tcx.predicates_of(impl_m.def_id);
let trait_m_predicates = tcx.predicates_of(trait_m.def_id);
Expand All @@ -204,28 +201,22 @@ fn compare_method_predicate_entailment<'tcx>(
let impl_predicates = tcx.predicates_of(impl_m_predicates.parent.unwrap());
let mut hybrid_preds = impl_predicates.instantiate_identity(tcx).predicates;
hybrid_preds.extend(
trait_m_predicates
.instantiate_own(tcx, trait_to_placeholder_args)
.map(|(predicate, _)| predicate),
trait_m_predicates.instantiate_own(tcx, trait_to_impl_args).map(|(predicate, _)| predicate),
);

// Construct trait parameter environment and then shift it into the placeholder viewpoint.
// The key step here is to update the caller_bounds's predicates to be
// the new hybrid bounds we computed.
let normalize_cause = traits::ObligationCause::misc(impl_m_span, impl_m_def_id);
let param_env = ty::ParamEnv::new(tcx.mk_clauses(&hybrid_preds), Reveal::UserFacing);
let param_env = traits::normalize_param_env_or_error(tcx, param_env, normalize_cause);
debug!(caller_bounds=?param_env.caller_bounds());

let infcx = &tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new_with_diagnostics(infcx);

debug!("compare_impl_method: caller_bounds={:?}", param_env.caller_bounds());

// Create obligations for each predicate declared by the impl
// definition in the context of the hybrid param-env. This makes
// sure that the impl's method's where clauses are not more
// restrictive than the trait's method (and the impl itself).
let impl_m_own_bounds = impl_m_predicates.instantiate_own(tcx, impl_to_placeholder_args);
let impl_m_own_bounds = impl_m_predicates.instantiate_own_identity();
for (predicate, span) in impl_m_own_bounds {
let normalize_cause = traits::ObligationCause::misc(span, impl_m_def_id);
let predicate = ocx.normalize(&normalize_cause, param_env, predicate);
Expand All @@ -252,7 +243,6 @@ fn compare_method_predicate_entailment<'tcx>(
// any associated types appearing in the fn arguments or return
// type.

// Compute placeholder form of impl and trait method tys.
let mut wf_tys = FxIndexSet::default();

let unnormalized_impl_sig = infcx.instantiate_binder_with_fresh_vars(
Expand All @@ -263,9 +253,9 @@ fn compare_method_predicate_entailment<'tcx>(

let norm_cause = ObligationCause::misc(impl_m_span, impl_m_def_id);
let impl_sig = ocx.normalize(&norm_cause, param_env, unnormalized_impl_sig);
debug!("compare_impl_method: impl_fty={:?}", impl_sig);
debug!(?impl_sig);

let trait_sig = tcx.fn_sig(trait_m.def_id).instantiate(tcx, trait_to_placeholder_args);
let trait_sig = tcx.fn_sig(trait_m.def_id).instantiate(tcx, trait_to_impl_args);
let trait_sig = tcx.liberate_late_bound_regions(impl_m.def_id, trait_sig);

// Next, add all inputs and output as well-formed tys. Importantly,
Expand All @@ -276,9 +266,7 @@ fn compare_method_predicate_entailment<'tcx>(
// We also have to add the normalized trait signature
// as we don't normalize during implied bounds computation.
wf_tys.extend(trait_sig.inputs_and_output.iter());
let trait_fty = Ty::new_fn_ptr(tcx, ty::Binder::dummy(trait_sig));

debug!("compare_impl_method: trait_fty={:?}", trait_fty);
debug!(?trait_sig);

// FIXME: We'd want to keep more accurate spans than "the method signature" when
// processing the comparison between the trait and impl fn, but we sadly lose them
Expand Down Expand Up @@ -451,8 +439,6 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
// just so we don't ICE during instantiation later.
check_method_is_structurally_compatible(tcx, impl_m, trait_m, impl_trait_ref, true)?;

let trait_to_impl_args = impl_trait_ref.args;

let impl_m_hir_id = tcx.local_def_id_to_hir_id(impl_m_def_id);
let return_span = tcx.hir().fn_decl_by_hir_id(impl_m_hir_id).unwrap().output.span();
let cause =
Expand All @@ -462,18 +448,18 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
kind: impl_m.kind,
});

// Create mapping from impl to placeholder.
let impl_to_placeholder_args = GenericArgs::identity_for_item(tcx, impl_m.def_id);

// Create mapping from trait to placeholder.
let trait_to_placeholder_args =
impl_to_placeholder_args.rebase_onto(tcx, impl_m.container_id(tcx), trait_to_impl_args);
// Create mapping from trait to impl (i.e. impl trait header + impl method identity args).
let trait_to_impl_args = GenericArgs::identity_for_item(tcx, impl_m.def_id).rebase_onto(
tcx,
impl_m.container_id(tcx),
impl_trait_ref.args,
);

let hybrid_preds = tcx
.predicates_of(impl_m.container_id(tcx))
.instantiate_identity(tcx)
.into_iter()
.chain(tcx.predicates_of(trait_m.def_id).instantiate_own(tcx, trait_to_placeholder_args))
.chain(tcx.predicates_of(trait_m.def_id).instantiate_own(tcx, trait_to_impl_args))
.map(|(clause, _)| clause);
let param_env = ty::ParamEnv::new(tcx.mk_clauses_from_iter(hybrid_preds), Reveal::UserFacing);
let param_env = traits::normalize_param_env_or_error(
Expand Down Expand Up @@ -507,7 +493,7 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
.instantiate_binder_with_fresh_vars(
return_span,
infer::HigherRankedType,
tcx.fn_sig(trait_m.def_id).instantiate(tcx, trait_to_placeholder_args),
tcx.fn_sig(trait_m.def_id).instantiate(tcx, trait_to_impl_args),
)
.fold_with(&mut collector);

Expand Down Expand Up @@ -701,7 +687,7 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
// Also, we only need to account for a difference in trait and impl args,
// since we previously enforce that the trait method and impl method have the
// same generics.
let num_trait_args = trait_to_impl_args.len();
let num_trait_args = impl_trait_ref.args.len();
let num_impl_args = tcx.generics_of(impl_m.container_id(tcx)).own_params.len();
let ty = match ty.try_fold_with(&mut RemapHiddenTyRegions {
tcx,
Expand Down Expand Up @@ -1037,12 +1023,7 @@ fn check_region_bounds_on_impl_item<'tcx>(
let trait_generics = tcx.generics_of(trait_m.def_id);
let trait_params = trait_generics.own_counts().lifetimes;

debug!(
"check_region_bounds_on_impl_item: \
trait_generics={:?} \
impl_generics={:?}",
trait_generics, impl_generics
);
debug!(?trait_generics, ?impl_generics);

// Must have same number of early-bound lifetime parameters.
// Unfortunately, if the user screws up the bounds, then this
Expand Down Expand Up @@ -1706,8 +1687,7 @@ pub(super) fn compare_impl_const_raw(
let trait_const_item = tcx.associated_item(trait_const_item_def);
let impl_trait_ref =
tcx.impl_trait_ref(impl_const_item.container_id(tcx)).unwrap().instantiate_identity();

debug!("compare_impl_const(impl_trait_ref={:?})", impl_trait_ref);
debug!(?impl_trait_ref);

compare_number_of_generics(tcx, impl_const_item, trait_const_item, false)?;
compare_generic_param_kinds(tcx, impl_const_item, trait_const_item, false)?;
Expand All @@ -1718,6 +1698,7 @@ pub(super) fn compare_impl_const_raw(
/// The equivalent of [compare_method_predicate_entailment], but for associated constants
/// instead of associated functions.
// FIXME(generic_const_items): If possible extract the common parts of `compare_{type,const}_predicate_entailment`.
#[instrument(level = "debug", skip(tcx))]
fn compare_const_predicate_entailment<'tcx>(
tcx: TyCtxt<'tcx>,
impl_ct: ty::AssocItem,
Expand All @@ -1732,13 +1713,14 @@ fn compare_const_predicate_entailment<'tcx>(
// because we shouldn't really have to deal with lifetimes or
// predicates. In fact some of this should probably be put into
// shared functions because of DRY violations...
let impl_args = GenericArgs::identity_for_item(tcx, impl_ct.def_id);
let trait_to_impl_args =
impl_args.rebase_onto(tcx, impl_ct.container_id(tcx), impl_trait_ref.args);
let trait_to_impl_args = GenericArgs::identity_for_item(tcx, impl_ct.def_id).rebase_onto(
tcx,
impl_ct.container_id(tcx),
impl_trait_ref.args,
);

// Create a parameter environment that represents the implementation's
// method.
// Compute placeholder form of impl and trait const tys.
// associated const.
let impl_ty = tcx.type_of(impl_ct_def_id).instantiate_identity();

let trait_ty = tcx.type_of(trait_ct.def_id).instantiate(tcx, trait_to_impl_args);
Expand Down Expand Up @@ -1772,7 +1754,7 @@ fn compare_const_predicate_entailment<'tcx>(
let infcx = tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new_with_diagnostics(&infcx);

let impl_ct_own_bounds = impl_ct_predicates.instantiate_own(tcx, impl_args);
let impl_ct_own_bounds = impl_ct_predicates.instantiate_own_identity();
for (predicate, span) in impl_ct_own_bounds {
let cause = ObligationCause::misc(span, impl_ct_def_id);
let predicate = ocx.normalize(&cause, param_env, predicate);
Expand All @@ -1783,20 +1765,15 @@ fn compare_const_predicate_entailment<'tcx>(

// There is no "body" here, so just pass dummy id.
let impl_ty = ocx.normalize(&cause, param_env, impl_ty);

debug!("compare_const_impl: impl_ty={:?}", impl_ty);
debug!(?impl_ty);

let trait_ty = ocx.normalize(&cause, param_env, trait_ty);

debug!("compare_const_impl: trait_ty={:?}", trait_ty);
debug!(?trait_ty);

let err = ocx.sup(&cause, param_env, trait_ty, impl_ty);

if let Err(terr) = err {
debug!(
"checking associated const for compatibility: impl ty {:?}, trait ty {:?}",
impl_ty, trait_ty
);
debug!(?impl_ty, ?trait_ty);

// Locate the Span containing just the type of the offending impl
let (ty, _) = tcx.hir().expect_impl_item(impl_ct_def_id).expect_const();
Expand Down Expand Up @@ -1841,14 +1818,13 @@ fn compare_const_predicate_entailment<'tcx>(
ocx.resolve_regions_and_report_errors(impl_ct_def_id, &outlives_env)
}

#[instrument(level = "debug", skip(tcx))]
pub(super) fn compare_impl_ty<'tcx>(
tcx: TyCtxt<'tcx>,
impl_ty: ty::AssocItem,
trait_ty: ty::AssocItem,
impl_trait_ref: ty::TraitRef<'tcx>,
) {
debug!("compare_impl_type(impl_trait_ref={:?})", impl_trait_ref);

let _: Result<(), ErrorGuaranteed> = try {
compare_number_of_generics(tcx, impl_ty, trait_ty, false)?;
compare_generic_param_kinds(tcx, impl_ty, trait_ty, false)?;
Expand All @@ -1860,20 +1836,23 @@ pub(super) fn compare_impl_ty<'tcx>(

/// The equivalent of [compare_method_predicate_entailment], but for associated types
/// instead of associated functions.
#[instrument(level = "debug", skip(tcx))]
fn compare_type_predicate_entailment<'tcx>(
tcx: TyCtxt<'tcx>,
impl_ty: ty::AssocItem,
trait_ty: ty::AssocItem,
impl_trait_ref: ty::TraitRef<'tcx>,
) -> Result<(), ErrorGuaranteed> {
let impl_args = GenericArgs::identity_for_item(tcx, impl_ty.def_id);
let trait_to_impl_args =
impl_args.rebase_onto(tcx, impl_ty.container_id(tcx), impl_trait_ref.args);
let trait_to_impl_args = GenericArgs::identity_for_item(tcx, impl_ty.def_id).rebase_onto(
tcx,
impl_ty.container_id(tcx),
impl_trait_ref.args,
);

let impl_ty_predicates = tcx.predicates_of(impl_ty.def_id);
let trait_ty_predicates = tcx.predicates_of(trait_ty.def_id);

let impl_ty_own_bounds = impl_ty_predicates.instantiate_own(tcx, impl_args);
let impl_ty_own_bounds = impl_ty_predicates.instantiate_own_identity();
if impl_ty_own_bounds.len() == 0 {
// Nothing to check.
return Ok(());
Expand All @@ -1883,7 +1862,7 @@ fn compare_type_predicate_entailment<'tcx>(
// `ObligationCause` (and the `FnCtxt`). This is what
// `regionck_item` expects.
let impl_ty_def_id = impl_ty.def_id.expect_local();
debug!("compare_type_predicate_entailment: trait_to_impl_args={:?}", trait_to_impl_args);
debug!(?trait_to_impl_args);

// The predicates declared by the impl definition, the trait and the
// associated type in the trait are assumed.
Expand All @@ -1894,18 +1873,18 @@ fn compare_type_predicate_entailment<'tcx>(
.instantiate_own(tcx, trait_to_impl_args)
.map(|(predicate, _)| predicate),
);

debug!("compare_type_predicate_entailment: bounds={:?}", hybrid_preds);
debug!(?hybrid_preds);

let impl_ty_span = tcx.def_span(impl_ty_def_id);
let normalize_cause = ObligationCause::misc(impl_ty_span, impl_ty_def_id);

let param_env = ty::ParamEnv::new(tcx.mk_clauses(&hybrid_preds), Reveal::UserFacing);
let param_env = traits::normalize_param_env_or_error(tcx, param_env, normalize_cause);
debug!(caller_bounds=?param_env.caller_bounds());

let infcx = tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new_with_diagnostics(&infcx);

debug!("compare_type_predicate_entailment: caller_bounds={:?}", param_env.caller_bounds());

for (predicate, span) in impl_ty_own_bounds {
let cause = ObligationCause::misc(span, impl_ty_def_id);
let predicate = ocx.normalize(&cause, param_env, predicate);
Expand Down Expand Up @@ -2005,11 +1984,11 @@ pub(super) fn check_type_bounds<'tcx>(
.explicit_item_bounds(trait_ty.def_id)
.iter_instantiated_copied(tcx, rebased_args)
.map(|(concrete_ty_bound, span)| {
debug!("check_type_bounds: concrete_ty_bound = {:?}", concrete_ty_bound);
debug!(?concrete_ty_bound);
traits::Obligation::new(tcx, mk_cause(span), param_env, concrete_ty_bound)
})
.collect();
debug!("check_type_bounds: item_bounds={:?}", obligations);
debug!(item_bounds=?obligations);

// Normalize predicates with the assumption that the GAT may always normalize
// to its definition type. This should be the param-env we use to *prove* the
Expand All @@ -2028,7 +2007,7 @@ pub(super) fn check_type_bounds<'tcx>(
} else {
ocx.normalize(&normalize_cause, normalize_param_env, obligation.predicate)
};
debug!("compare_projection_bounds: normalized predicate = {:?}", normalized_predicate);
debug!(?normalized_predicate);
obligation.predicate = normalized_predicate;

ocx.register_obligation(obligation);
Expand Down
4 changes: 3 additions & 1 deletion compiler/rustc_middle/src/ty/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,9 @@ impl<'tcx> GenericPredicates<'tcx> {
EarlyBinder::bind(self.predicates).iter_instantiated_copied(tcx, args)
}

pub fn instantiate_own_identity(self) -> impl Iterator<Item = (Clause<'tcx>, Span)> {
pub fn instantiate_own_identity(
self,
) -> impl Iterator<Item = (Clause<'tcx>, Span)> + DoubleEndedIterator + ExactSizeIterator {
EarlyBinder::bind(self.predicates).iter_identity_copied()
}

Expand Down
Loading

0 comments on commit 21d95fb

Please sign in to comment.