From 968abf3020768500a91cce89107f33f787905032 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Mon, 26 Feb 2024 20:42:09 +0000 Subject: [PATCH] Remove a_is_expected from combine relations --- .../src/type_check/relate_tys.rs | 1 - compiler/rustc_infer/src/infer/at.rs | 91 ++++++------------- .../rustc_infer/src/infer/opaque_types.rs | 8 +- .../rustc_infer/src/infer/relate/combine.rs | 19 ++-- compiler/rustc_infer/src/infer/relate/glb.rs | 18 ++-- .../src/infer/relate/higher_ranked.rs | 8 +- compiler/rustc_infer/src/infer/relate/lub.rs | 18 ++-- .../src/infer/relate/type_relating.rs | 36 +++----- .../error_reporting/type_err_ctxt_ext.rs | 9 +- 9 files changed, 74 insertions(+), 134 deletions(-) diff --git a/compiler/rustc_borrowck/src/type_check/relate_tys.rs b/compiler/rustc_borrowck/src/type_check/relate_tys.rs index a17178385f63b..5892a305f5ba0 100644 --- a/compiler/rustc_borrowck/src/type_check/relate_tys.rs +++ b/compiler/rustc_borrowck/src/type_check/relate_tys.rs @@ -120,7 +120,6 @@ impl<'me, 'bccx, 'tcx> NllTypeRelating<'me, 'bccx, 'tcx> { fn relate_opaques(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> { let infcx = self.type_checker.infcx; debug_assert!(!infcx.next_trait_solver()); - let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) }; // `handle_opaque_type` cannot handle subtyping, so to support subtyping // we instead eagerly generalize here. This is a bit of a mess but will go // away once we're using the new solver. diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs index cc09a09468869..59e05cbb60148 100644 --- a/compiler/rustc_infer/src/infer/at.rs +++ b/compiler/rustc_infer/src/infer/at.rs @@ -49,7 +49,6 @@ pub struct At<'a, 'tcx> { pub struct Trace<'a, 'tcx> { at: At<'a, 'tcx>, - a_is_expected: bool, trace: TypeTrace<'tcx>, } @@ -106,23 +105,6 @@ pub trait ToTrace<'tcx>: Relate<'tcx> + Copy { } impl<'a, 'tcx> At<'a, 'tcx> { - /// Makes `a <: b`, where `a` may or may not be expected. - /// - /// See [`At::trace_exp`] and [`Trace::sub`] for a version of - /// this method that only requires `T: Relate<'tcx>` - pub fn sub_exp( - self, - define_opaque_types: DefineOpaqueTypes, - a_is_expected: bool, - a: T, - b: T, - ) -> InferResult<'tcx, ()> - where - T: ToTrace<'tcx>, - { - self.trace_exp(a_is_expected, a, b).sub(define_opaque_types, a, b) - } - /// Makes `actual <: expected`. For example, if type-checking a /// call like `foo(x)`, where `foo: fn(i32)`, you might have /// `sup(i32, x)`, since the "expected" type is the type that @@ -139,7 +121,7 @@ impl<'a, 'tcx> At<'a, 'tcx> { where T: ToTrace<'tcx>, { - self.sub_exp(define_opaque_types, false, actual, expected) + self.trace(expected, actual).sup(define_opaque_types, expected, actual) } /// Makes `expected <: actual`. @@ -155,24 +137,7 @@ impl<'a, 'tcx> At<'a, 'tcx> { where T: ToTrace<'tcx>, { - self.sub_exp(define_opaque_types, true, expected, actual) - } - - /// Makes `expected <: actual`. - /// - /// See [`At::trace_exp`] and [`Trace::eq`] for a version of - /// this method that only requires `T: Relate<'tcx>` - pub fn eq_exp( - self, - define_opaque_types: DefineOpaqueTypes, - a_is_expected: bool, - a: T, - b: T, - ) -> InferResult<'tcx, ()> - where - T: ToTrace<'tcx>, - { - self.trace_exp(a_is_expected, a, b).eq(define_opaque_types, a, b) + self.trace(expected, actual).sub(define_opaque_types, expected, actual) } /// Makes `expected <: actual`. @@ -261,48 +226,50 @@ impl<'a, 'tcx> At<'a, 'tcx> { where T: ToTrace<'tcx>, { - self.trace_exp(true, expected, actual) + let trace = ToTrace::to_trace(self.cause, true, expected, actual); + Trace { at: self, trace } } +} - /// Like `trace`, but the expected value is determined by the - /// boolean argument (if true, then the first argument `a` is the - /// "expected" value). - pub fn trace_exp(self, a_is_expected: bool, a: T, b: T) -> Trace<'a, 'tcx> +impl<'a, 'tcx> Trace<'a, 'tcx> { + /// Makes `a <: b`. + #[instrument(skip(self), level = "debug")] + pub fn sub(self, define_opaque_types: DefineOpaqueTypes, a: T, b: T) -> InferResult<'tcx, ()> where - T: ToTrace<'tcx>, + T: Relate<'tcx>, { - let trace = ToTrace::to_trace(self.cause, a_is_expected, a, b); - Trace { at: self, trace, a_is_expected } + let Trace { at, trace } = self; + let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types); + fields + .sub() + .relate(a, b) + .map(move |_| InferOk { value: (), obligations: fields.obligations }) } -} -impl<'a, 'tcx> Trace<'a, 'tcx> { - /// Makes `a <: b` where `a` may or may not be expected (if - /// `a_is_expected` is true, then `a` is expected). + /// Makes `a :> b`. #[instrument(skip(self), level = "debug")] - pub fn sub(self, define_opaque_types: DefineOpaqueTypes, a: T, b: T) -> InferResult<'tcx, ()> + pub fn sup(self, define_opaque_types: DefineOpaqueTypes, a: T, b: T) -> InferResult<'tcx, ()> where T: Relate<'tcx>, { - let Trace { at, trace, a_is_expected } = self; + let Trace { at, trace } = self; let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types); fields - .sub(a_is_expected) + .sup() .relate(a, b) .map(move |_| InferOk { value: (), obligations: fields.obligations }) } - /// Makes `a == b`; the expectation is set by the call to - /// `trace()`. + /// Makes `a == b`. #[instrument(skip(self), level = "debug")] pub fn eq(self, define_opaque_types: DefineOpaqueTypes, a: T, b: T) -> InferResult<'tcx, ()> where T: Relate<'tcx>, { - let Trace { at, trace, a_is_expected } = self; + let Trace { at, trace } = self; let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types); fields - .equate(StructurallyRelateAliases::No, a_is_expected) + .equate(StructurallyRelateAliases::No) .relate(a, b) .map(move |_| InferOk { value: (), obligations: fields.obligations }) } @@ -314,11 +281,11 @@ impl<'a, 'tcx> Trace<'a, 'tcx> { where T: Relate<'tcx>, { - let Trace { at, trace, a_is_expected } = self; + let Trace { at, trace } = self; debug_assert!(at.infcx.next_trait_solver()); let mut fields = at.infcx.combine_fields(trace, at.param_env, DefineOpaqueTypes::No); fields - .equate(StructurallyRelateAliases::Yes, a_is_expected) + .equate(StructurallyRelateAliases::Yes) .relate(a, b) .map(move |_| InferOk { value: (), obligations: fields.obligations }) } @@ -328,10 +295,10 @@ impl<'a, 'tcx> Trace<'a, 'tcx> { where T: Relate<'tcx>, { - let Trace { at, trace, a_is_expected } = self; + let Trace { at, trace } = self; let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types); fields - .lub(a_is_expected) + .lub() .relate(a, b) .map(move |t| InferOk { value: t, obligations: fields.obligations }) } @@ -341,10 +308,10 @@ impl<'a, 'tcx> Trace<'a, 'tcx> { where T: Relate<'tcx>, { - let Trace { at, trace, a_is_expected } = self; + let Trace { at, trace } = self; let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types); fields - .glb(a_is_expected) + .glb() .relate(a, b) .map(move |t| InferOk { value: t, obligations: fields.obligations }) } diff --git a/compiler/rustc_infer/src/infer/opaque_types.rs b/compiler/rustc_infer/src/infer/opaque_types.rs index d381c77ec666e..07245643ef59c 100644 --- a/compiler/rustc_infer/src/infer/opaque_types.rs +++ b/compiler/rustc_infer/src/infer/opaque_types.rs @@ -522,13 +522,7 @@ impl<'tcx> InferCtxt<'tcx> { ) -> InferResult<'tcx, ()> { let mut obligations = Vec::new(); - self.insert_hidden_type( - opaque_type_key, - &cause, - param_env, - hidden_ty, - &mut obligations, - )?; + self.insert_hidden_type(opaque_type_key, &cause, param_env, hidden_ty, &mut obligations)?; self.add_item_bounds_for_hidden_type( opaque_type_key.def_id.to_def_id(), diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs index 749c50b57b540..099b7ff7c04b0 100644 --- a/compiler/rustc_infer/src/infer/relate/combine.rs +++ b/compiler/rustc_infer/src/infer/relate/combine.rs @@ -321,21 +321,24 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> { pub fn equate<'a>( &'a mut self, structurally_relate_aliases: StructurallyRelateAliases, - a_is_expected: bool, ) -> TypeRelating<'a, 'infcx, 'tcx> { - TypeRelating::new(self, a_is_expected, structurally_relate_aliases, ty::Invariant) + TypeRelating::new(self, structurally_relate_aliases, ty::Invariant) } - pub fn sub<'a>(&'a mut self, a_is_expected: bool) -> TypeRelating<'a, 'infcx, 'tcx> { - TypeRelating::new(self, a_is_expected, StructurallyRelateAliases::No, ty::Covariant) + pub fn sub<'a>(&'a mut self) -> TypeRelating<'a, 'infcx, 'tcx> { + TypeRelating::new(self, StructurallyRelateAliases::No, ty::Covariant) } - pub fn lub<'a>(&'a mut self, a_is_expected: bool) -> Lub<'a, 'infcx, 'tcx> { - Lub::new(self, a_is_expected) + pub fn sup<'a>(&'a mut self) -> TypeRelating<'a, 'infcx, 'tcx> { + TypeRelating::new(self, StructurallyRelateAliases::No, ty::Contravariant) } - pub fn glb<'a>(&'a mut self, a_is_expected: bool) -> Glb<'a, 'infcx, 'tcx> { - Glb::new(self, a_is_expected) + pub fn lub<'a>(&'a mut self) -> Lub<'a, 'infcx, 'tcx> { + Lub::new(self) + } + + pub fn glb<'a>(&'a mut self) -> Glb<'a, 'infcx, 'tcx> { + Glb::new(self) } pub fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) { diff --git a/compiler/rustc_infer/src/infer/relate/glb.rs b/compiler/rustc_infer/src/infer/relate/glb.rs index 9b77e6888b2c8..f6796861b12e0 100644 --- a/compiler/rustc_infer/src/infer/relate/glb.rs +++ b/compiler/rustc_infer/src/infer/relate/glb.rs @@ -13,15 +13,11 @@ use crate::traits::{ObligationCause, PredicateObligations}; /// "Greatest lower bound" (common subtype) pub struct Glb<'combine, 'infcx, 'tcx> { fields: &'combine mut CombineFields<'infcx, 'tcx>, - a_is_expected: bool, } impl<'combine, 'infcx, 'tcx> Glb<'combine, 'infcx, 'tcx> { - pub fn new( - fields: &'combine mut CombineFields<'infcx, 'tcx>, - a_is_expected: bool, - ) -> Glb<'combine, 'infcx, 'tcx> { - Glb { fields, a_is_expected } + pub fn new(fields: &'combine mut CombineFields<'infcx, 'tcx>) -> Glb<'combine, 'infcx, 'tcx> { + Glb { fields } } } @@ -35,7 +31,7 @@ impl<'tcx> TypeRelation<'tcx> for Glb<'_, '_, 'tcx> { } fn a_is_expected(&self) -> bool { - self.a_is_expected + true } fn relate_with_variance>( @@ -46,13 +42,11 @@ impl<'tcx> TypeRelation<'tcx> for Glb<'_, '_, 'tcx> { b: T, ) -> RelateResult<'tcx, T> { match variance { - ty::Invariant => { - self.fields.equate(StructurallyRelateAliases::No, self.a_is_expected).relate(a, b) - } + ty::Invariant => self.fields.equate(StructurallyRelateAliases::No).relate(a, b), ty::Covariant => self.relate(a, b), // FIXME(#41044) -- not correct, need test ty::Bivariant => Ok(a), - ty::Contravariant => self.fields.lub(self.a_is_expected).relate(a, b), + ty::Contravariant => self.fields.lub().relate(a, b), } } @@ -126,7 +120,7 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Glb<'combine, 'infcx, } fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> { - let mut sub = self.fields.sub(self.a_is_expected); + let mut sub = self.fields.sub(); sub.relate(v, a)?; sub.relate(v, b)?; Ok(()) diff --git a/compiler/rustc_infer/src/infer/relate/higher_ranked.rs b/compiler/rustc_infer/src/infer/relate/higher_ranked.rs index 90be80f67b4d0..c94cbb0db030b 100644 --- a/compiler/rustc_infer/src/infer/relate/higher_ranked.rs +++ b/compiler/rustc_infer/src/infer/relate/higher_ranked.rs @@ -49,7 +49,13 @@ impl<'a, 'tcx> CombineFields<'a, 'tcx> { debug!("b_prime={:?}", sup_prime); // Compare types now that bound regions have been replaced. - let result = self.sub(sub_is_expected).relate(sub_prime, sup_prime); + // Reorder the inputs so that the expected is passed first. + let result = if sub_is_expected { + self.sub().relate(sub_prime, sup_prime) + } else { + self.sup().relate(sup_prime, sub_prime) + }; + if result.is_ok() { debug!("OK result={result:?}"); } diff --git a/compiler/rustc_infer/src/infer/relate/lub.rs b/compiler/rustc_infer/src/infer/relate/lub.rs index db04e3231d6a4..3d9cfe7bf05b5 100644 --- a/compiler/rustc_infer/src/infer/relate/lub.rs +++ b/compiler/rustc_infer/src/infer/relate/lub.rs @@ -13,15 +13,11 @@ use rustc_span::Span; /// "Least upper bound" (common supertype) pub struct Lub<'combine, 'infcx, 'tcx> { fields: &'combine mut CombineFields<'infcx, 'tcx>, - a_is_expected: bool, } impl<'combine, 'infcx, 'tcx> Lub<'combine, 'infcx, 'tcx> { - pub fn new( - fields: &'combine mut CombineFields<'infcx, 'tcx>, - a_is_expected: bool, - ) -> Lub<'combine, 'infcx, 'tcx> { - Lub { fields, a_is_expected } + pub fn new(fields: &'combine mut CombineFields<'infcx, 'tcx>) -> Lub<'combine, 'infcx, 'tcx> { + Lub { fields } } } @@ -35,7 +31,7 @@ impl<'tcx> TypeRelation<'tcx> for Lub<'_, '_, 'tcx> { } fn a_is_expected(&self) -> bool { - self.a_is_expected + true } fn relate_with_variance>( @@ -46,13 +42,11 @@ impl<'tcx> TypeRelation<'tcx> for Lub<'_, '_, 'tcx> { b: T, ) -> RelateResult<'tcx, T> { match variance { - ty::Invariant => { - self.fields.equate(StructurallyRelateAliases::No, self.a_is_expected).relate(a, b) - } + ty::Invariant => self.fields.equate(StructurallyRelateAliases::No).relate(a, b), ty::Covariant => self.relate(a, b), // FIXME(#41044) -- not correct, need test ty::Bivariant => Ok(a), - ty::Contravariant => self.fields.glb(self.a_is_expected).relate(a, b), + ty::Contravariant => self.fields.glb().relate(a, b), } } @@ -126,7 +120,7 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx, } fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> { - let mut sub = self.fields.sub(self.a_is_expected); + let mut sub = self.fields.sub(); sub.relate(a, v)?; sub.relate(b, v)?; Ok(()) diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs index ddc4bf9a514bc..7464b52572498 100644 --- a/compiler/rustc_infer/src/infer/relate/type_relating.rs +++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs @@ -12,7 +12,6 @@ use rustc_span::Span; /// Enforce that `a` is equal to or a subtype of `b`. pub struct TypeRelating<'combine, 'a, 'tcx> { fields: &'combine mut CombineFields<'a, 'tcx>, - a_is_expected: bool, structurally_relate_aliases: StructurallyRelateAliases, ambient_variance: ty::Variance, } @@ -20,11 +19,10 @@ pub struct TypeRelating<'combine, 'a, 'tcx> { impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> { pub fn new( f: &'combine mut CombineFields<'infcx, 'tcx>, - a_is_expected: bool, structurally_relate_aliases: StructurallyRelateAliases, ambient_variance: ty::Variance, ) -> TypeRelating<'combine, 'infcx, 'tcx> { - TypeRelating { fields: f, a_is_expected, structurally_relate_aliases, ambient_variance } + TypeRelating { fields: f, structurally_relate_aliases, ambient_variance } } } @@ -38,7 +36,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { } fn a_is_expected(&self) -> bool { - self.a_is_expected + true } fn relate_with_variance>( @@ -79,7 +77,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { self.fields.trace.cause.clone(), self.fields.param_env, ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { - a_is_expected: self.a_is_expected, + a_is_expected: true, a, b, })), @@ -93,7 +91,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { self.fields.trace.cause.clone(), self.fields.param_env, ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate { - a_is_expected: !self.a_is_expected, + a_is_expected: false, a: b, b: a, })), @@ -109,18 +107,12 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { } (&ty::Infer(TyVar(a_vid)), _) => { - infcx.instantiate_ty_var( - self, - self.a_is_expected, - a_vid, - self.ambient_variance, - b, - )?; + infcx.instantiate_ty_var(self, true, a_vid, self.ambient_variance, b)?; } (_, &ty::Infer(TyVar(b_vid))) => { infcx.instantiate_ty_var( self, - !self.a_is_expected, + false, b_vid, self.ambient_variance.xform(ty::Contravariant), a, @@ -147,13 +139,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { { self.fields.obligations.extend( infcx - .handle_opaque_type( - a, - b, - self.a_is_expected, - &self.fields.trace.cause, - self.param_env(), - )? + .handle_opaque_type(a, b, true, &self.fields.trace.cause, self.param_env())? .obligations, ); } @@ -239,14 +225,14 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> { } else { match self.ambient_variance { ty::Covariant => { - self.fields.higher_ranked_sub(a, b, self.a_is_expected)?; + self.fields.higher_ranked_sub(a, b, true)?; } ty::Contravariant => { - self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?; + self.fields.higher_ranked_sub(b, a, false)?; } ty::Invariant => { - self.fields.higher_ranked_sub(a, b, self.a_is_expected)?; - self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?; + self.fields.higher_ranked_sub(a, b, true)?; + self.fields.higher_ranked_sub(b, a, false)?; } ty::Bivariant => { unreachable!("Expected bivariance to be handled in relate_with_variance") diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs index 41855d5fb4b6b..5cd0b56c311bf 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs @@ -1541,12 +1541,9 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> { // since the normalization is just done to improve the error message. let _ = ocx.select_where_possible(); - if let Err(new_err) = ocx.eq( - &obligation.cause, - obligation.param_env, - expected, - actual, - ) { + if let Err(new_err) = + ocx.eq(&obligation.cause, obligation.param_env, expected, actual) + { (Some((data, is_normalized_term_expected, normalized_term, data.term)), new_err) } else { (None, error.err)