Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wf-check user types before normalization #104746

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3661,6 +3661,7 @@ dependencies = [
name = "rustc_hir_typeck"
version = "0.1.0"
dependencies = [
"either",
"rustc_ast",
"rustc_data_structures",
"rustc_errors",
Expand Down
78 changes: 74 additions & 4 deletions compiler/rustc_borrowck/src/type_check/canonical.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::fmt;

use rustc_infer::infer::canonical::Canonical;
use rustc_infer::traits::query::NoSolution;
use rustc_infer::infer::{canonical::Canonical, InferOk};
use rustc_middle::mir::ConstraintCategory;
use rustc_middle::ty::{self, ToPredicate, TypeFoldable};
use rustc_middle::ty::{self, ToPredicate, Ty, TypeFoldable};
use rustc_span::def_id::DefId;
use rustc_span::Span;
use rustc_trait_selection::traits::query::type_op::{self, TypeOpOutput};
use rustc_trait_selection::traits::query::Fallible;
use rustc_trait_selection::traits::query::{Fallible, NoSolution};
use rustc_trait_selection::traits::{ObligationCause, ObligationCtxt};

use crate::diagnostics::{ToUniverseInfo, UniverseInfo};

Expand Down Expand Up @@ -179,4 +179,74 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
value
})
}

#[instrument(skip(self), level = "debug")]
pub(super) fn ascribe_user_type(
&mut self,
mir_ty: Ty<'tcx>,
user_ty: ty::UserType<'tcx>,
span: Span,
) {
// FIXME: Ideally MIR types are normalized, but this is not always true.
let mir_ty = self.normalize(mir_ty, Locations::All(span));

self.fully_perform_op(
Locations::All(span),
ConstraintCategory::Boring,
self.param_env.and(type_op::ascribe_user_type::AscribeUserType::new(mir_ty, user_ty)),
)
.unwrap_or_else(|err| {
span_mirbug!(
self,
span,
"ascribe_user_type `{mir_ty:?}=={user_ty:?}` failed with `{err:?}`",
);
});
}

/// *Incorrectly* skips the WF checks we normally do in `ascribe_user_type`.
///
/// FIXME(#104478, #104477): This is a hack for backward-compatibility.
#[instrument(skip(self), level = "debug")]
pub(super) fn ascribe_user_type_skip_wf(
&mut self,
mir_ty: Ty<'tcx>,
user_ty: ty::UserType<'tcx>,
span: Span,
) {
let ty::UserType::Ty(user_ty) = user_ty else { bug!() };

// A fast path for a common case with closure input/output types.
if let ty::Infer(_) = user_ty.kind() {
self.eq_types(user_ty, mir_ty, Locations::All(span), ConstraintCategory::Boring)
.unwrap();
return;
}

let mir_ty = self.normalize(mir_ty, Locations::All(span));
let cause = ObligationCause::dummy_with_span(span);
let param_env = self.param_env;
let op = |infcx: &'_ _| {
let ocx = ObligationCtxt::new_in_snapshot(infcx);
let user_ty = ocx.normalize(cause.clone(), param_env, user_ty);
ocx.eq(&cause, param_env, user_ty, mir_ty)?;
if !ocx.select_all_or_error().is_empty() {
return Err(NoSolution);
}
Ok(InferOk { value: (), obligations: vec![] })
};

self.fully_perform_op(
Locations::All(span),
ConstraintCategory::Boring,
type_op::custom::CustomTypeOp::new(op, || "ascribe_user_type_skip_wf".to_string()),
)
.unwrap_or_else(|err| {
span_mirbug!(
self,
span,
"ascribe_user_type_skip_wf `{mir_ty:?}=={user_ty:?}` failed with `{err:?}`",
);
});
}
}
127 changes: 47 additions & 80 deletions compiler/rustc_borrowck/src/type_check/input_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,60 @@
use rustc_index::vec::Idx;
use rustc_infer::infer::LateBoundRegionConversionTime;
use rustc_middle::mir::*;
use rustc_middle::ty::Ty;
use rustc_middle::ty::{self, Ty};
use rustc_span::Span;

use crate::universal_regions::UniversalRegions;

use super::{Locations, TypeChecker};

impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
/// Check explicit closure signature annotation,
/// e.g., `|x: FxHashMap<_, &'static u32>| ...`.
#[instrument(skip(self), level = "debug")]
pub(super) fn check_signature_annotation(&mut self, body: &Body<'tcx>) {
let mir_def_id = body.source.def_id().expect_local();
if !self.tcx().is_closure(mir_def_id.to_def_id()) {
return;
}
let Some(user_provided_poly_sig) =
self.tcx().typeck(mir_def_id).user_provided_sigs.get(&mir_def_id)
else {
return;
};

// Instantiate the canonicalized variables from user-provided signature
// (e.g., the `_` in the code above) with fresh variables.
// Then replace the bound items in the fn sig with fresh variables,
// so that they represent the view from "inside" the closure.
let user_provided_sig = self
.instantiate_canonical_with_fresh_inference_vars(body.span, &user_provided_poly_sig);
let user_provided_sig = self.infcx.replace_bound_vars_with_fresh_vars(
body.span,
LateBoundRegionConversionTime::FnCall,
user_provided_sig,
);

for (&user_ty, arg_decl) in user_provided_sig.inputs().iter().zip(
// In MIR, closure args begin with an implicit `self`. Skip it!
body.args_iter().skip(1).map(|local| &body.local_decls[local]),
) {
self.ascribe_user_type_skip_wf(
arg_decl.ty,
ty::UserType::Ty(user_ty),
arg_decl.source_info.span,
);
}

// If the user explicitly annotated the output type, enforce it.
let output_decl = &body.local_decls[RETURN_PLACE];
self.ascribe_user_type_skip_wf(
output_decl.ty,
ty::UserType::Ty(user_provided_sig.output()),
output_decl.source_info.span,
);
}

#[instrument(skip(self, body, universal_regions), level = "debug")]
pub(super) fn equate_inputs_and_outputs(
&mut self,
Expand All @@ -31,40 +77,6 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
debug!(?normalized_output_ty);
debug!(?normalized_input_tys);

let mir_def_id = body.source.def_id().expect_local();

// If the user explicitly annotated the input types, extract
// those.
//
// e.g., `|x: FxHashMap<_, &'static u32>| ...`
let user_provided_sig;
if !self.tcx().is_closure(mir_def_id.to_def_id()) {
user_provided_sig = None;
} else {
let typeck_results = self.tcx().typeck(mir_def_id);
user_provided_sig =
typeck_results.user_provided_sigs.get(&mir_def_id).map(|user_provided_poly_sig| {
// Instantiate the canonicalized variables from
// user-provided signature (e.g., the `_` in the code
// above) with fresh variables.
let poly_sig = self.instantiate_canonical_with_fresh_inference_vars(
body.span,
&user_provided_poly_sig,
);

// Replace the bound items in the fn sig with fresh
// variables, so that they represent the view from
// "inside" the closure.
self.infcx.replace_bound_vars_with_fresh_vars(
body.span,
LateBoundRegionConversionTime::FnCall,
poly_sig,
)
});
}

debug!(?normalized_input_tys, ?body.local_decls);

// Equate expected input tys with those in the MIR.
for (argument_index, &normalized_input_ty) in normalized_input_tys.iter().enumerate() {
if argument_index + 1 >= body.local_decls.len() {
Expand All @@ -87,28 +99,6 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
);
}

if let Some(user_provided_sig) = user_provided_sig {
for (argument_index, &user_provided_input_ty) in
user_provided_sig.inputs().iter().enumerate()
{
// In MIR, closures begin an implicit `self`, so
// argument N is stored in local N+2.
let local = Local::new(argument_index + 2);
let mir_input_ty = body.local_decls[local].ty;
let mir_input_span = body.local_decls[local].source_info.span;

// If the user explicitly annotated the input types, enforce those.
let user_provided_input_ty =
self.normalize(user_provided_input_ty, Locations::All(mir_input_span));

self.equate_normalized_input_or_output(
user_provided_input_ty,
mir_input_ty,
mir_input_span,
);
}
}

debug!(
"equate_inputs_and_outputs: body.yield_ty {:?}, universal_regions.yield_ty {:?}",
body.yield_ty(),
Expand Down Expand Up @@ -154,29 +144,6 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
terr
);
};

// If the user explicitly annotated the output types, enforce those.
// Note that this only happens for closures.
if let Some(user_provided_sig) = user_provided_sig {
let user_provided_output_ty = user_provided_sig.output();
let user_provided_output_ty =
self.normalize(user_provided_output_ty, Locations::All(output_span));
if let Err(err) = self.eq_types(
user_provided_output_ty,
mir_output_ty,
Locations::All(output_span),
ConstraintCategory::BoringNoLocation,
) {
span_mirbug!(
self,
Location::START,
"equate_inputs_and_outputs: `{:?}=={:?}` failed with `{:?}`",
mir_output_ty,
user_provided_output_ty,
err
);
}
}
}

#[instrument(skip(self), level = "debug")]
Expand Down
76 changes: 9 additions & 67 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ use rustc_middle::ty::{
use rustc_span::def_id::CRATE_DEF_ID;
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::VariantIdx;
use rustc_trait_selection::traits::query::type_op;
use rustc_trait_selection::traits::query::type_op::custom::scrape_region_constraints;
use rustc_trait_selection::traits::query::type_op::custom::CustomTypeOp;
use rustc_trait_selection::traits::query::type_op::{TypeOp, TypeOpOutput};
Expand Down Expand Up @@ -197,6 +196,8 @@ pub(crate) fn type_check<'mir, 'tcx>(
}

checker.equate_inputs_and_outputs(&body, universal_regions, &normalized_inputs_and_output);
checker.check_signature_annotation(&body);

liveness::generate(
&mut checker,
body,
Expand Down Expand Up @@ -391,23 +392,14 @@ impl<'a, 'b, 'tcx> Visitor<'tcx> for TypeVerifier<'a, 'b, 'tcx> {
check_err(self, promoted_body, ty, promoted_ty);
}
} else {
if let Err(terr) = self.cx.fully_perform_op(
locations,
ConstraintCategory::Boring,
self.cx.param_env.and(type_op::ascribe_user_type::AscribeUserType::new(
constant.literal.ty(),
self.cx.ascribe_user_type(
constant.literal.ty(),
UserType::TypeOf(
uv.def.did,
UserSubsts { substs: uv.substs, user_self_ty: None },
)),
) {
span_mirbug!(
self,
constant,
"bad constant type {:?} ({:?})",
constant,
terr
);
}
),
locations.span(&self.cx.body),
);
}
} else if let Some(static_def_id) = constant.check_static_ptr(tcx) {
let unnormalized_ty = tcx.type_of(static_def_id);
Expand Down Expand Up @@ -1041,58 +1033,8 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
debug!(?self.user_type_annotations);
for user_annotation in self.user_type_annotations {
let CanonicalUserTypeAnnotation { span, ref user_ty, inferred_ty } = *user_annotation;
let inferred_ty = self.normalize(inferred_ty, Locations::All(span));
let annotation = self.instantiate_canonical_with_fresh_inference_vars(span, user_ty);
debug!(?annotation);
match annotation {
UserType::Ty(mut ty) => {
ty = self.normalize(ty, Locations::All(span));

if let Err(terr) = self.eq_types(
ty,
inferred_ty,
Locations::All(span),
ConstraintCategory::BoringNoLocation,
) {
span_mirbug!(
self,
user_annotation,
"bad user type ({:?} = {:?}): {:?}",
ty,
inferred_ty,
terr
);
}

self.prove_predicate(
ty::Binder::dummy(ty::PredicateKind::WellFormed(inferred_ty.into())),
Locations::All(span),
ConstraintCategory::TypeAnnotation,
);
}
UserType::TypeOf(def_id, user_substs) => {
if let Err(terr) = self.fully_perform_op(
Locations::All(span),
ConstraintCategory::BoringNoLocation,
self.param_env.and(type_op::ascribe_user_type::AscribeUserType::new(
inferred_ty,
def_id,
user_substs,
)),
) {
span_mirbug!(
self,
user_annotation,
"bad user type AscribeUserType({:?}, {:?} {:?}, type_of={:?}): {:?}",
inferred_ty,
def_id,
user_substs,
self.tcx().type_of(def_id),
terr,
);
}
}
}
self.ascribe_user_type(inferred_ty, annotation, span);
}
}

Expand Down
Loading