Skip to content

Commit

Permalink
implement type_implments_trait query
Browse files Browse the repository at this point in the history
  • Loading branch information
csmoe committed May 15, 2020
1 parent a1104b4 commit 10d7da4
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/librustc_middle/hir/map/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ impl<'hir> Map<'hir> {
/// Given a `HirId`, returns the `BodyId` associated with it,
/// if the node is a body owner, otherwise returns `None`.
pub fn maybe_body_owned_by(&self, hir_id: HirId) -> Option<BodyId> {
if let Some(node) = self.find(hir_id) { associated_body(node) } else { None }
self.find(hir_id).map(associated_body).flatten()
}

/// Given a body owner's id, returns the `BodyId` associated with it.
Expand Down
6 changes: 6 additions & 0 deletions src/librustc_middle/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,12 @@ rustc_queries! {
desc { "evaluating trait selection obligation `{}`", goal.value }
}

query type_implements_trait(
key: (DefId, Ty<'tcx>, SubstsRef<'tcx>, ty::ParamEnv<'tcx>, )
) -> bool {
desc { "evaluating `type_implements_trait` `{:?}`", key }
}

/// Do not call this query directly: part of the `Eq` type-op
query type_op_ascribe_user_type(
goal: CanonicalTypeOpAscribeUserTypeGoal<'tcx>
Expand Down
12 changes: 12 additions & 0 deletions src/librustc_middle/ty/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,15 @@ impl Key for (Symbol, u32, u32) {
DUMMY_SP
}
}

impl<'tcx> Key for (DefId, Ty<'tcx>, SubstsRef<'tcx>, ty::ParamEnv<'tcx>) {
type CacheSelector = DefaultCacheSelector;

fn query_crate(&self) -> CrateNum {
LOCAL_CRATE
}

fn default_span(&self, _tcx: TyCtxt<'_>) -> Span {
DUMMY_SP
}
}
37 changes: 19 additions & 18 deletions src/librustc_trait_selection/traits/error_reporting/suggestions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::DefId;
use rustc_hir::intravisit::Visitor;
use rustc_hir::lang_items;
use rustc_hir::{AsyncGeneratorKind, GeneratorKind, Node};
use rustc_middle::ty::TypeckTables;
use rustc_middle::ty::{
Expand Down Expand Up @@ -1785,37 +1786,37 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
span: Span,
) {
debug!(
"suggest_await_befor_try: obligation={:?}, span={:?}, trait_ref={:?}",
obligation, span, trait_ref
"suggest_await_befor_try: obligation={:?}, span={:?}, trait_ref={:?}, trait_ref_self_ty={:?}",
obligation,
span,
trait_ref,
trait_ref.self_ty()
);
let body_hir_id = obligation.cause.body_id;
let item_id = self.tcx.hir().get_parent_node(body_hir_id);

let mut is_future = false;
if let ty::Opaque(def_id, substs) = trait_ref.self_ty().kind {
let preds = self.tcx.predicates_of(def_id).instantiate(self.tcx, substs);
for p in preds.predicates {
if let Some(trait_ref) = p.to_opt_poly_trait_ref() {
if Some(trait_ref.def_id()) == self.tcx.lang_items().future_trait() {
is_future = true;
break;
}
}
}
}

if let Some(body_id) = self.tcx.hir().maybe_body_owned_by(item_id) {
let body = self.tcx.hir().body(body_id);
if let Some(hir::GeneratorKind::Async(_)) = body.generator_kind {
let future_trait = self.tcx.lang_items().future_trait().unwrap();
let future_trait =
self.tcx.require_lang_item(lang_items::FutureTraitLangItem, None);

let self_ty = self.resolve_vars_if_possible(&trait_ref.self_ty());

let impls_future = self.tcx.type_implements_trait((
future_trait,
self_ty,
ty::List::empty(),
obligation.param_env,
));

let item_def_id = self
.tcx
.associated_items(future_trait)
.in_definition_order()
.next()
.unwrap()
.def_id;
debug!("trait_ref_self_ty: {:?}", trait_ref.self_ty());
// `<T as Future>::Output`
let projection_ty = ty::ProjectionTy {
// `T`
Expand Down Expand Up @@ -1850,7 +1851,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
obligation.param_env,
);
debug!("suggest_await_befor_try: try_trait_obligation {:?}", try_obligation);
if self.predicate_may_hold(&try_obligation) && is_future {
if self.predicate_may_hold(&try_obligation) && impls_future {
if let Ok(snippet) = self.tcx.sess.source_map().span_to_snippet(span) {
if snippet.ends_with('?') {
err.span_suggestion(
Expand Down
42 changes: 41 additions & 1 deletion src/librustc_trait_selection/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ use rustc_hir::def_id::DefId;
use rustc_middle::middle::region;
use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::subst::{InternalSubsts, SubstsRef};
use rustc_middle::ty::{self, GenericParamDefKind, ToPredicate, Ty, TyCtxt, WithConstness};
use rustc_middle::ty::{
self, GenericParamDefKind, ParamEnv, ToPredicate, Ty, TyCtxt, WithConstness,
};
use rustc_span::Span;

use std::fmt::Debug;
Expand Down Expand Up @@ -523,6 +525,43 @@ fn vtable_methods<'tcx>(
}))
}

/// Check whether a `ty` implements given trait(trait_def_id).
///
/// NOTE: Always return `false` for a type which needs inference.
fn type_implements_trait<'tcx>(
tcx: TyCtxt<'tcx>,
key: (
DefId, // trait_def_id,
Ty<'tcx>, // type
SubstsRef<'tcx>,
ParamEnv<'tcx>,
),
) -> bool {
let (trait_def_id, ty, params, param_env) = key;

debug!(
"type_implements_trait: trait_def_id={:?}, type={:?}, params={:?}, param_env={:?}",
trait_def_id, ty, params, param_env
);

// Do not check on infer_types to avoid panic in evaluate_obligation.
if ty.has_infer_types() {
return false;
}

let ty = tcx.erase_regions(&ty);

let trait_ref = ty::TraitRef { def_id: trait_def_id, substs: tcx.mk_substs_trait(ty, params) };

let obligation = Obligation {
cause: ObligationCause::dummy(),
param_env,
recursion_depth: 0,
predicate: trait_ref.without_const().to_predicate(),
};
tcx.infer_ctxt().enter(|infcx| infcx.predicate_must_hold_modulo_regions(&obligation))
}

pub fn provide(providers: &mut ty::query::Providers<'_>) {
object_safety::provide(providers);
*providers = ty::query::Providers {
Expand All @@ -531,6 +570,7 @@ pub fn provide(providers: &mut ty::query::Providers<'_>) {
codegen_fulfill_obligation: codegen::codegen_fulfill_obligation,
vtable_methods,
substitute_normalize_and_test_predicates,
type_implements_trait,
..*providers
};
}
20 changes: 20 additions & 0 deletions src/test/ui/async-await/issue-61076.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
// edition:2018

use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};

struct T;

impl Future for T {
type Output = Result<(), ()>;

fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
Poll::Pending
}
}

async fn foo() -> Result<(), ()> {
Ok(())
}
Expand All @@ -9,4 +23,10 @@ async fn bar() -> Result<(), ()> {
Ok(())
}

async fn baz() -> Result<(), ()> {
let t = T;
t?; //~ ERROR the `?` operator can only be applied to values that implement `std::ops::Try`
Ok(())
}

fn main() {}
16 changes: 14 additions & 2 deletions src/test/ui/async-await/issue-61076.stderr
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
error[E0277]: the `?` operator can only be applied to values that implement `std::ops::Try`
--> $DIR/issue-61076.rs:8:5
--> $DIR/issue-61076.rs:22:5
|
LL | foo()?;
| ^^^^^^
Expand All @@ -10,6 +10,18 @@ LL | foo()?;
= help: the trait `std::ops::Try` is not implemented for `impl std::future::Future`
= note: required by `std::ops::Try::into_result`

error: aborting due to previous error
error[E0277]: the `?` operator can only be applied to values that implement `std::ops::Try`
--> $DIR/issue-61076.rs:28:5
|
LL | t?;
| ^^
| |
| the `?` operator cannot be applied to type `T`
| help: consider using `.await` here: `t.await?`
|
= help: the trait `std::ops::Try` is not implemented for `T`
= note: required by `std::ops::Try::into_result`

error: aborting due to 2 previous errors

For more information about this error, try `rustc --explain E0277`.
3 changes: 2 additions & 1 deletion src/test/ui/async-await/try-on-option-in-async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ async fn an_async_block() -> u32 {
let x: Option<u32> = None;
x?; //~ ERROR the `?` operator
22
}.await
}
.await
}

async fn async_closure_containing_fn() -> u32 {
Expand Down
6 changes: 3 additions & 3 deletions src/test/ui/async-await/try-on-option-in-async.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ LL | | let x: Option<u32> = None;
LL | | x?;
| | ^^ cannot use the `?` operator in an async block that returns `{integer}`
LL | | 22
LL | | }.await
LL | | }
| |_____- this function should return `Result` or `Option` to accept `?`
|
= help: the trait `std::ops::Try` is not implemented for `{integer}`
= note: required by `std::ops::Try::from_error`

error[E0277]: the `?` operator can only be used in an async closure that returns `Result` or `Option` (or another type that implements `std::ops::Try`)
--> $DIR/try-on-option-in-async.rs:16:9
--> $DIR/try-on-option-in-async.rs:17:9
|
LL | let async_closure = async || {
| __________________________________-
Expand All @@ -29,7 +29,7 @@ LL | | };
= note: required by `std::ops::Try::from_error`

error[E0277]: the `?` operator can only be used in an async function that returns `Result` or `Option` (or another type that implements `std::ops::Try`)
--> $DIR/try-on-option-in-async.rs:25:5
--> $DIR/try-on-option-in-async.rs:26:5
|
LL | async fn an_async_function() -> u32 {
| _____________________________________-
Expand Down
18 changes: 2 additions & 16 deletions src/tools/clippy/clippy_lints/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,12 @@ use rustc_hir::{
use rustc_infer::infer::TyCtxtInferExt;
use rustc_lint::{LateContext, Level, Lint, LintContext};
use rustc_middle::hir::map::Map;
use rustc_middle::traits;
use rustc_middle::ty::{self, layout::IntegerExt, subst::GenericArg, Binder, Ty, TyCtxt, TypeFoldable};
use rustc_span::hygiene::{ExpnKind, MacroKind};
use rustc_span::source_map::original_sp;
use rustc_span::symbol::{self, kw, Symbol};
use rustc_span::{BytePos, Pos, Span, DUMMY_SP};
use rustc_target::abi::Integer;
use rustc_trait_selection::traits::predicate_for_trait_def;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
use rustc_trait_selection::traits::query::normalize::AtExt;
use smallvec::SmallVec;

Expand Down Expand Up @@ -326,19 +323,8 @@ pub fn implements_trait<'a, 'tcx>(
trait_id: DefId,
ty_params: &[GenericArg<'tcx>],
) -> bool {
let ty = cx.tcx.erase_regions(&ty);
let obligation = predicate_for_trait_def(
cx.tcx,
cx.param_env,
traits::ObligationCause::dummy(),
trait_id,
0,
ty,
ty_params,
);
cx.tcx
.infer_ctxt()
.enter(|infcx| infcx.predicate_must_hold_modulo_regions(&obligation))
let ty_params = cx.tcx.mk_substs(ty_params.iter());
cx.tcx.type_implements_trait((trait_id, ty, ty_params, cx.param_env))
}

/// Gets the `hir::TraitRef` of the trait the given method is implemented for.
Expand Down

0 comments on commit 10d7da4

Please sign in to comment.