Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pretty print async fn sugar in opaques and trait bounds #132911

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions compiler/rustc_middle/src/middle/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ impl<'tcx> TyCtxt<'tcx> {
}
}

/// Given a [`ty::ClosureKind`], get the [`DefId`] of its corresponding `Fn`-family
/// trait, if it is defined.
pub fn async_fn_trait_kind_to_def_id(self, kind: ty::ClosureKind) -> Option<DefId> {
let items = self.lang_items();
match kind {
ty::ClosureKind::Fn => items.async_fn_trait(),
ty::ClosureKind::FnMut => items.async_fn_mut_trait(),
ty::ClosureKind::FnOnce => items.async_fn_once_trait(),
}
}

/// Returns `true` if `id` is a `DefId` of [`Fn`], [`FnMut`] or [`FnOnce`] traits.
pub fn is_fn_trait(self, id: DefId) -> bool {
self.fn_trait_kind_from_def_id(id).is_some()
Expand Down
203 changes: 85 additions & 118 deletions compiler/rustc_middle/src/ty/print/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ use rustc_hir::definitions::{DefKey, DefPathDataName};
use rustc_macros::{Lift, extension};
use rustc_session::Limit;
use rustc_session::cstore::{ExternCrate, ExternCrateSource};
use rustc_span::FileNameDisplayPreference;
use rustc_span::symbol::{Ident, Symbol, kw};
use rustc_span::{FileNameDisplayPreference, sym};
use rustc_type_ir::{Upcast as _, elaborate};
use smallvec::SmallVec;

Expand All @@ -26,8 +26,8 @@ use super::*;
use crate::mir::interpret::{AllocRange, GlobalAlloc, Pointer, Provenance, Scalar};
use crate::query::{IntoQueryParam, Providers};
use crate::ty::{
ConstInt, Expr, GenericArgKind, ParamConst, ScalarInt, Term, TermKind, TypeFoldable,
TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitableExt,
ConstInt, Expr, GenericArgKind, ParamConst, ScalarInt, Term, TermKind, TraitPredicate,
TypeFoldable, TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitableExt,
};

macro_rules! p {
Expand Down Expand Up @@ -993,10 +993,8 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {

match bound_predicate.skip_binder() {
ty::ClauseKind::Trait(pred) => {
let trait_ref = bound_predicate.rebind(pred.trait_ref);

// Don't print `+ Sized`, but rather `+ ?Sized` if absent.
if tcx.is_lang_item(trait_ref.def_id(), LangItem::Sized) {
if tcx.is_lang_item(pred.def_id(), LangItem::Sized) {
match pred.polarity {
ty::PredicatePolarity::Positive => {
has_sized_bound = true;
Expand All @@ -1007,24 +1005,22 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
}

self.insert_trait_and_projection(
trait_ref,
pred.polarity,
bound_predicate.rebind(pred),
Copy link
Member

@fmease fmease Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a nice consequence, now the

let trait_ref = bound_predicate.rebind(pred.trait_ref);
// ...
if tcx.is_lang_item(trait_ref.def_id(), LangItem::Sized) {

a few lines above can be simplified to

// ...
if tcx.is_lang_item(pred.def_id(), LangItem::Sized) {

(GH doesn't let me comment directly on those lines)

None,
&mut traits,
&mut fn_traits,
);
}
ty::ClauseKind::Projection(pred) => {
let proj_ref = bound_predicate.rebind(pred);
let trait_ref = proj_ref.required_poly_trait_ref(tcx);

// Projection type entry -- the def-id for naming, and the ty.
let proj_ty = (proj_ref.projection_def_id(), proj_ref.term());
let proj = bound_predicate.rebind(pred);
let trait_ref = proj.map_bound(|proj| TraitPredicate {
trait_ref: proj.projection_term.trait_ref(tcx),
polarity: ty::PredicatePolarity::Positive,
});

self.insert_trait_and_projection(
trait_ref,
ty::PredicatePolarity::Positive,
Some(proj_ty),
Some((proj.projection_def_id(), proj.term())),
&mut traits,
&mut fn_traits,
);
Expand All @@ -1042,88 +1038,66 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
// Insert parenthesis around (Fn(A, B) -> C) if the opaque ty has more than one other trait
let paren_needed = fn_traits.len() > 1 || traits.len() > 0 || !has_sized_bound;

for (fn_once_trait_ref, entry) in fn_traits {
for ((bound_args, is_async), entry) in fn_traits {
write!(self, "{}", if first { "" } else { " + " })?;
write!(self, "{}", if paren_needed { "(" } else { "" })?;

self.wrap_binder(&fn_once_trait_ref, |trait_ref, cx| {
define_scoped_cx!(cx);
// Get the (single) generic ty (the args) of this FnOnce trait ref.
let generics = tcx.generics_of(trait_ref.def_id);
let own_args = generics.own_args_no_defaults(tcx, trait_ref.args);

match (entry.return_ty, own_args[0].expect_ty()) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have noooooooOOoOooooooo idea wtf was going on here lmao

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, something about this part was load-bearing, cf. #133597

// We can only print `impl Fn() -> ()` if we have a tuple of args and we recorded
// a return type.
(Some(return_ty), arg_tys) if matches!(arg_tys.kind(), ty::Tuple(_)) => {
let name = if entry.fn_trait_ref.is_some() {
"Fn"
} else if entry.fn_mut_trait_ref.is_some() {
"FnMut"
} else {
"FnOnce"
};

p!(write("{}(", name));
let trait_def_id = if is_async {
tcx.async_fn_trait_kind_to_def_id(entry.kind).expect("expected AsyncFn lang items")
} else {
tcx.fn_trait_kind_to_def_id(entry.kind).expect("expected Fn lang items")
};

for (idx, ty) in arg_tys.tuple_fields().iter().enumerate() {
if idx > 0 {
p!(", ");
}
p!(print(ty));
}
if let Some(return_ty) = entry.return_ty {
self.wrap_binder(&bound_args, |args, cx| {
define_scoped_cx!(cx);
p!(write("{}", tcx.item_name(trait_def_id)));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could very easily start making this print like async Fn in the future.

p!("(");

p!(")");
if let Some(ty) = return_ty.skip_binder().as_type() {
if !ty.is_unit() {
p!(" -> ", print(return_ty));
}
for (idx, ty) in args.iter().enumerate() {
if idx > 0 {
p!(", ");
}
p!(write("{}", if paren_needed { ")" } else { "" }));

first = false;
p!(print(ty));
}
// If we got here, we can't print as a `impl Fn(A, B) -> C`. Just record the
// trait_refs we collected in the OpaqueFnEntry as normal trait refs.
_ => {
if entry.has_fn_once {
traits
.entry((fn_once_trait_ref, ty::PredicatePolarity::Positive))
.or_default()
.extend(
// Group the return ty with its def id, if we had one.
entry.return_ty.map(|ty| {
(tcx.require_lang_item(LangItem::FnOnceOutput, None), ty)
}),
);
}
if let Some(trait_ref) = entry.fn_mut_trait_ref {
traits.entry((trait_ref, ty::PredicatePolarity::Positive)).or_default();
}
if let Some(trait_ref) = entry.fn_trait_ref {
traits.entry((trait_ref, ty::PredicatePolarity::Positive)).or_default();

p!(")");
if let Some(ty) = return_ty.skip_binder().as_type() {
if !ty.is_unit() {
p!(" -> ", print(return_ty));
}
}
}
p!(write("{}", if paren_needed { ")" } else { "" }));

Ok(())
})?;
first = false;
Ok(())
})?;
} else {
// Otherwise, render this like a regular trait.
traits.insert(
bound_args.map_bound(|args| ty::TraitPredicate {
polarity: ty::PredicatePolarity::Positive,
trait_ref: ty::TraitRef::new(tcx, trait_def_id, [Ty::new_tup(tcx, args)]),
Copy link
Member

@fmease fmease Nov 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this is missing the self ty.

}),
FxIndexMap::default(),
);
}
}

// Print the rest of the trait types (that aren't Fn* family of traits)
for ((trait_ref, polarity), assoc_items) in traits {
for (trait_pred, assoc_items) in traits {
write!(self, "{}", if first { "" } else { " + " })?;

self.wrap_binder(&trait_ref, |trait_ref, cx| {
self.wrap_binder(&trait_pred, |trait_pred, cx| {
define_scoped_cx!(cx);

if polarity == ty::PredicatePolarity::Negative {
if trait_pred.polarity == ty::PredicatePolarity::Negative {
p!("!");
}
p!(print(trait_ref.print_only_trait_name()));
p!(print(trait_pred.trait_ref.print_only_trait_name()));

let generics = tcx.generics_of(trait_ref.def_id);
let own_args = generics.own_args_no_defaults(tcx, trait_ref.args);
let generics = tcx.generics_of(trait_pred.def_id());
let own_args = generics.own_args_no_defaults(tcx, trait_pred.trait_ref.args);

if !own_args.is_empty() || !assoc_items.is_empty() {
let mut first = true;
Expand Down Expand Up @@ -1230,51 +1204,48 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
/// traits map or fn_traits map, depending on if the trait is in the Fn* family of traits.
fn insert_trait_and_projection(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty disappointed with how sensitive to bound var renumbering this is... but 🤷

&mut self,
trait_ref: ty::PolyTraitRef<'tcx>,
polarity: ty::PredicatePolarity,
trait_pred: ty::PolyTraitPredicate<'tcx>,
proj_ty: Option<(DefId, ty::Binder<'tcx, Term<'tcx>>)>,
traits: &mut FxIndexMap<
(ty::PolyTraitRef<'tcx>, ty::PredicatePolarity),
ty::PolyTraitPredicate<'tcx>,
FxIndexMap<DefId, ty::Binder<'tcx, Term<'tcx>>>,
>,
fn_traits: &mut FxIndexMap<ty::PolyTraitRef<'tcx>, OpaqueFnEntry<'tcx>>,
fn_traits: &mut FxIndexMap<
(ty::Binder<'tcx, &'tcx ty::List<Ty<'tcx>>>, bool),
OpaqueFnEntry<'tcx>,
>,
) {
let trait_def_id = trait_ref.def_id();

// If our trait_ref is FnOnce or any of its children, project it onto the parent FnOnce
// super-trait ref and record it there.
// We skip negative Fn* bounds since they can't use parenthetical notation anyway.
if polarity == ty::PredicatePolarity::Positive
&& let Some(fn_once_trait) = self.tcx().lang_items().fn_once_trait()
{
// If we have a FnOnce, then insert it into
if trait_def_id == fn_once_trait {
let entry = fn_traits.entry(trait_ref).or_default();
// Optionally insert the return_ty as well.
if let Some((_, ty)) = proj_ty {
entry.return_ty = Some(ty);
}
entry.has_fn_once = true;
return;
} else if self.tcx().is_lang_item(trait_def_id, LangItem::FnMut) {
let super_trait_ref = elaborate::supertraits(self.tcx(), trait_ref)
.find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
.unwrap();
let tcx = self.tcx();
let trait_def_id = trait_pred.def_id();

fn_traits.entry(super_trait_ref).or_default().fn_mut_trait_ref = Some(trait_ref);
return;
} else if self.tcx().is_lang_item(trait_def_id, LangItem::Fn) {
let super_trait_ref = elaborate::supertraits(self.tcx(), trait_ref)
.find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
.unwrap();
let fn_trait_and_async = if let Some(kind) = tcx.fn_trait_kind_from_def_id(trait_def_id) {
Some((kind, false))
} else if let Some(kind) = tcx.async_fn_trait_kind_from_def_id(trait_def_id) {
Some((kind, true))
} else {
None
};

fn_traits.entry(super_trait_ref).or_default().fn_trait_ref = Some(trait_ref);
return;
if trait_pred.polarity() == ty::PredicatePolarity::Positive
&& let Some((kind, is_async)) = fn_trait_and_async
&& let ty::Tuple(types) = *trait_pred.skip_binder().trait_ref.args.type_at(1).kind()
{
let entry = fn_traits
.entry((trait_pred.rebind(types), is_async))
.or_insert_with(|| OpaqueFnEntry { kind, return_ty: None });
if kind.extends(entry.kind) {
entry.kind = kind;
}
if let Some((proj_def_id, proj_ty)) = proj_ty
&& tcx.item_name(proj_def_id) == sym::Output
{
entry.return_ty = Some(proj_ty);
}
return;
}

// Otherwise, just group our traits and projection types.
traits.entry((trait_ref, polarity)).or_default().extend(proj_ty);
traits.entry(trait_pred).or_default().extend(proj_ty);
}

fn pretty_print_inherent_projection(
Expand Down Expand Up @@ -3189,10 +3160,10 @@ define_print_and_forward_display! {

TraitRefPrintSugared<'tcx> {
if !with_reduced_queries()
&& let Some(kind) = cx.tcx().fn_trait_kind_from_def_id(self.0.def_id)
&& cx.tcx().trait_def(self.0.def_id).paren_sugar
&& let ty::Tuple(args) = self.0.args.type_at(1).kind()
{
p!(write("{}", kind.as_str()), "(");
p!(write("{}", cx.tcx().item_name(self.0.def_id)), "(");
for (i, arg) in args.iter().enumerate() {
if i > 0 {
p!(", ");
Expand Down Expand Up @@ -3415,11 +3386,7 @@ pub fn provide(providers: &mut Providers) {
*providers = Providers { trimmed_def_paths, ..*providers };
}

#[derive(Default)]
pub struct OpaqueFnEntry<'tcx> {
// The trait ref is already stored as a key, so just track if we have it as a real predicate
has_fn_once: bool,
fn_mut_trait_ref: Option<ty::PolyTraitRef<'tcx>>,
fn_trait_ref: Option<ty::PolyTraitRef<'tcx>>,
kind: ty::ClosureKind,
return_ty: Option<ty::Binder<'tcx, Term<'tcx>>>,
}
13 changes: 0 additions & 13 deletions compiler/rustc_type_ir/src/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,19 +684,6 @@ impl<I: Interner> ty::Binder<I, ProjectionPredicate<I>> {
self.skip_binder().projection_term.trait_def_id(cx)
}

/// Get the trait ref required for this projection to be well formed.
/// Note that for generic associated types the predicates of the associated
/// type also need to be checked.
#[inline]
pub fn required_poly_trait_ref(&self, cx: I) -> ty::Binder<I, TraitRef<I>> {
// Note: unlike with `TraitRef::to_poly_trait_ref()`,
// `self.0.trait_ref` is permitted to have escaping regions.
// This is because here `self` has a `Binder` and so does our
// return value, so we are preserving the number of binding
// levels.
self.map_bound(|predicate| predicate.projection_term.trait_ref(cx))
}

pub fn term(&self) -> ty::Binder<I, I::Term> {
self.map_bound(|predicate| predicate.term)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
error[E0277]: the trait bound `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}: AsyncFn<()>` is not satisfied
error[E0277]: the trait bound `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}: AsyncFn()` is not satisfied
--> $DIR/fn-exception-target-features.rs:16:10
|
LL | test(target_feature);
| ---- ^^^^^^^^^^^^^^ the trait `AsyncFn<()>` is not implemented for fn item `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}`
| ---- ^^^^^^^^^^^^^^ the trait `AsyncFn()` is not implemented for fn item `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}`
| |
| required by a bound introduced by this call
|
Expand Down
8 changes: 4 additions & 4 deletions tests/ui/async-await/async-closures/fn-exception.stderr
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
error[E0277]: the trait bound `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}: AsyncFn<()>` is not satisfied
error[E0277]: the trait bound `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}: AsyncFn()` is not satisfied
--> $DIR/fn-exception.rs:19:10
|
LL | test(unsafety);
| ---- ^^^^^^^^ the trait `AsyncFn<()>` is not implemented for fn item `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}`
| ---- ^^^^^^^^ the trait `AsyncFn()` is not implemented for fn item `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}`
| |
| required by a bound introduced by this call
|
Expand All @@ -12,11 +12,11 @@ note: required by a bound in `test`
LL | fn test(f: impl async Fn()) {}
| ^^^^^^^^^^ required by this bound in `test`

error[E0277]: the trait bound `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}: AsyncFn<()>` is not satisfied
error[E0277]: the trait bound `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}: AsyncFn()` is not satisfied
--> $DIR/fn-exception.rs:20:10
|
LL | test(abi);
| ---- ^^^ the trait `AsyncFn<()>` is not implemented for fn item `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}`
| ---- ^^^ the trait `AsyncFn()` is not implemented for fn item `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}`
| |
| required by a bound introduced by this call
|
Expand Down
14 changes: 14 additions & 0 deletions tests/ui/async-await/async-closures/pretty-async-fn-opaque.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//@ edition: 2021

#![feature(async_closure)]

use std::ops::AsyncFnMut;

fn produce() -> impl AsyncFnMut() -> &'static str {
async || ""
}

fn main() {
let x: i32 = produce();
//~^ ERROR mismatched types
}
Loading
Loading