Skip to content

Commit

Permalink
Pretty print AsyncFn traits too
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Nov 22, 2024
1 parent 7540306 commit f79ab27
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 17 deletions.
11 changes: 11 additions & 0 deletions compiler/rustc_middle/src/middle/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ impl<'tcx> TyCtxt<'tcx> {
}
}

/// Given a [`ty::ClosureKind`], get the [`DefId`] of its corresponding `Fn`-family
/// trait, if it is defined.
pub fn async_fn_trait_kind_to_def_id(self, kind: ty::ClosureKind) -> Option<DefId> {
let items = self.lang_items();
match kind {
ty::ClosureKind::Fn => items.async_fn_trait(),
ty::ClosureKind::FnMut => items.async_fn_mut_trait(),
ty::ClosureKind::FnOnce => items.async_fn_once_trait(),
}
}

/// Returns `true` if `id` is a `DefId` of [`Fn`], [`FnMut`] or [`FnOnce`] traits.
pub fn is_fn_trait(self, id: DefId) -> bool {
self.fn_trait_kind_from_def_id(id).is_some()
Expand Down
34 changes: 23 additions & 11 deletions compiler/rustc_middle/src/ty/print/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -993,10 +993,8 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {

match bound_predicate.skip_binder() {
ty::ClauseKind::Trait(pred) => {
let trait_ref = bound_predicate.rebind(pred.trait_ref);

// Don't print `+ Sized`, but rather `+ ?Sized` if absent.
if tcx.is_lang_item(trait_ref.def_id(), LangItem::Sized) {
if tcx.is_lang_item(pred.def_id, LangItem::Sized) {
match pred.polarity {
ty::PredicatePolarity::Positive => {
has_sized_bound = true;
Expand Down Expand Up @@ -1040,12 +1038,15 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
// Insert parenthesis around (Fn(A, B) -> C) if the opaque ty has more than one other trait
let paren_needed = fn_traits.len() > 1 || traits.len() > 0 || !has_sized_bound;

for (bound_args, entry) in fn_traits {
for ((bound_args, is_async), entry) in fn_traits {
write!(self, "{}", if first { "" } else { " + " })?;
write!(self, "{}", if paren_needed { "(" } else { "" })?;

let trait_def_id =
tcx.fn_trait_kind_to_def_id(entry.kind).expect("expected Fn lang items");
let trait_def_id = if is_async {
tcx.async_fn_trait_kind_to_def_id(entry.kind).expect("expected AsyncFn lang items")
} else {
tcx.fn_trait_kind_to_def_id(entry.kind).expect("expected Fn lang items")
};

if let Some(return_ty) = entry.return_ty {
self.wrap_binder(&bound_args, |args, cx| {
Expand Down Expand Up @@ -1209,17 +1210,28 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
ty::PolyTraitPredicate<'tcx>,
FxIndexMap<DefId, ty::Binder<'tcx, Term<'tcx>>>,
>,
fn_traits: &mut FxIndexMap<ty::Binder<'tcx, &'tcx ty::List<Ty<'tcx>>>, OpaqueFnEntry<'tcx>>,
fn_traits: &mut FxIndexMap<
(ty::Binder<'tcx, &'tcx ty::List<Ty<'tcx>>>, bool),
OpaqueFnEntry<'tcx>,
>,
) {
let tcx = self.tcx();
let trait_def_id = trait_pred.def_id();

let fn_trait_and_async = if let Some(kind) = tcx.fn_trait_kind_from_def_id(trait_def_id) {
Some((kind, false))
} else if let Some(kind) = tcx.async_fn_trait_kind_from_def_id(trait_def_id) {
Some((kind, true))
} else {
None
};

if trait_pred.polarity() == ty::PredicatePolarity::Positive
&& let Some(kind) = tcx.fn_trait_kind_from_def_id(trait_def_id)
&& let Some((kind, is_async)) = fn_trait_and_async
&& let ty::Tuple(types) = *trait_pred.skip_binder().trait_ref.args.type_at(1).kind()
{
let entry = fn_traits
.entry(trait_pred.rebind(types))
.entry((trait_pred.rebind(types), is_async))
.or_insert_with(|| OpaqueFnEntry { kind, return_ty: None });
if kind.extends(entry.kind) {
entry.kind = kind;
Expand Down Expand Up @@ -3148,10 +3160,10 @@ define_print_and_forward_display! {

TraitRefPrintSugared<'tcx> {
if !with_reduced_queries()
&& let Some(kind) = cx.tcx().fn_trait_kind_from_def_id(self.0.def_id)
&& cx.tcx().trait_def(self.0.def_id).paren_sugar
&& let ty::Tuple(args) = self.0.args.type_at(1).kind()
{
p!(write("{}", kind.as_str()), "(");
p!(write("{}", cx.tcx().item_name(self.0.def_id)), "(");
for (i, arg) in args.iter().enumerate() {
if i > 0 {
p!(", ");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
error[E0277]: the trait bound `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}: AsyncFn<()>` is not satisfied
error[E0277]: the trait bound `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}: AsyncFn()` is not satisfied
--> $DIR/fn-exception-target-features.rs:16:10
|
LL | test(target_feature);
| ---- ^^^^^^^^^^^^^^ the trait `AsyncFn<()>` is not implemented for fn item `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}`
| ---- ^^^^^^^^^^^^^^ the trait `AsyncFn()` is not implemented for fn item `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}`
| |
| required by a bound introduced by this call
|
Expand Down
8 changes: 4 additions & 4 deletions tests/ui/async-await/async-closures/fn-exception.stderr
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
error[E0277]: the trait bound `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}: AsyncFn<()>` is not satisfied
error[E0277]: the trait bound `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}: AsyncFn()` is not satisfied
--> $DIR/fn-exception.rs:19:10
|
LL | test(unsafety);
| ---- ^^^^^^^^ the trait `AsyncFn<()>` is not implemented for fn item `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}`
| ---- ^^^^^^^^ the trait `AsyncFn()` is not implemented for fn item `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}`
| |
| required by a bound introduced by this call
|
Expand All @@ -12,11 +12,11 @@ note: required by a bound in `test`
LL | fn test(f: impl async Fn()) {}
| ^^^^^^^^^^ required by this bound in `test`

error[E0277]: the trait bound `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}: AsyncFn<()>` is not satisfied
error[E0277]: the trait bound `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}: AsyncFn()` is not satisfied
--> $DIR/fn-exception.rs:20:10
|
LL | test(abi);
| ---- ^^^ the trait `AsyncFn<()>` is not implemented for fn item `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}`
| ---- ^^^ the trait `AsyncFn()` is not implemented for fn item `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}`
| |
| required by a bound introduced by this call
|
Expand Down
14 changes: 14 additions & 0 deletions tests/ui/async-await/async-closures/pretty-async-fn-opaque.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//@ edition: 2021

#![feature(async_closure)]

use std::ops::AsyncFnMut;

fn produce() -> impl AsyncFnMut() -> &'static str {
async || ""
}

fn main() {
let x: i32 = produce();
//~^ ERROR mismatched types
}
17 changes: 17 additions & 0 deletions tests/ui/async-await/async-closures/pretty-async-fn-opaque.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
error[E0308]: mismatched types
--> $DIR/pretty-async-fn-opaque.rs:12:18
|
LL | fn produce() -> impl AsyncFnMut() -> &'static str {
| --------------------------------- the found opaque type
...
LL | let x: i32 = produce();
| --- ^^^^^^^^^ expected `i32`, found opaque type
| |
| expected due to this
|
= note: expected type `i32`
found opaque type `impl AsyncFnMut() -> &'static str`

error: aborting due to 1 previous error

For more information about this error, try `rustc --explain E0308`.

0 comments on commit f79ab27

Please sign in to comment.