Skip to content

Commit

Permalink
Use SpanlessEq for in trait_bounds lints
Browse files Browse the repository at this point in the history
  • Loading branch information
y21 committed Oct 3, 2024
1 parent db1bda3 commit d6be597
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 81 deletions.
107 changes: 49 additions & 58 deletions clippy_lints/src/trait_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@ use clippy_utils::source::{SpanRangeExt, snippet, snippet_with_applicability};
use clippy_utils::{SpanlessEq, SpanlessHash, is_from_proc_macro};
use core::hash::{Hash, Hasher};
use itertools::Itertools;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, IndexEntry};
use rustc_data_structures::unhash::UnhashMap;
use rustc_errors::Applicability;
use rustc_hir::def::Res;
use rustc_hir::{
GenericArg, GenericBound, Generics, Item, ItemKind, LangItem, Node, Path, PathSegment, PredicateOrigin, QPath,
GenericBound, Generics, Item, ItemKind, LangItem, Node, Path, PathSegment, PredicateOrigin, QPath,
TraitBoundModifier, TraitItem, TraitRef, Ty, TyKind, WherePredicate,
};
use rustc_lint::{LateContext, LateLintPass};
use rustc_session::impl_lint_pass;
use rustc_span::{BytePos, Span};
use std::collections::hash_map::Entry;

declare_clippy_lint! {
/// ### What it does
Expand Down Expand Up @@ -153,7 +152,10 @@ impl<'tcx> LateLintPass<'tcx> for TraitBounds {
.filter_map(get_trait_info_from_bound)
.for_each(|(trait_item_res, trait_item_segments, span)| {
if let Some(self_segments) = self_bounds_map.get(&trait_item_res) {
if SpanlessEq::new(cx).eq_path_segments(self_segments, trait_item_segments) {
if SpanlessEq::new(cx)
.paths_by_resolution()
.eq_path_segments(self_segments, trait_item_segments)
{
span_lint_and_help(
cx,
TRAIT_DUPLICATION_IN_BOUNDS,
Expand Down Expand Up @@ -302,7 +304,7 @@ impl TraitBounds {
}
}

fn check_trait_bound_duplication(cx: &LateContext<'_>, generics: &'_ Generics<'_>) {
fn check_trait_bound_duplication<'tcx>(cx: &LateContext<'tcx>, generics: &'_ Generics<'tcx>) {
if generics.span.from_expansion() {
return;
}
Expand All @@ -314,6 +316,7 @@ fn check_trait_bound_duplication(cx: &LateContext<'_>, generics: &'_ Generics<'_
// |
// collects each of these where clauses into a set keyed by generic name and comparable trait
// eg. (T, Clone)
#[expect(clippy::mutable_key_type)]
let where_predicates = generics
.predicates
.iter()
Expand Down Expand Up @@ -367,11 +370,27 @@ fn check_trait_bound_duplication(cx: &LateContext<'_>, generics: &'_ Generics<'_
}
}

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
struct ComparableTraitRef(Res, Vec<Res>);
impl Default for ComparableTraitRef {
fn default() -> Self {
Self(Res::Err, Vec::new())
struct ComparableTraitRef<'a, 'tcx> {
cx: &'a LateContext<'tcx>,
trait_ref: &'tcx TraitRef<'tcx>,
modifier: TraitBoundModifier,
}

impl PartialEq for ComparableTraitRef<'_, '_> {
fn eq(&self, other: &Self) -> bool {
self.modifier == other.modifier
&& SpanlessEq::new(self.cx)
.paths_by_resolution()
.eq_path(self.trait_ref.path, other.trait_ref.path)
}
}
impl Eq for ComparableTraitRef<'_, '_> {}
impl Hash for ComparableTraitRef<'_, '_> {
fn hash<H: Hasher>(&self, state: &mut H) {
let mut s = SpanlessHash::new(self.cx).paths_by_resolution();
s.hash_path(self.trait_ref.path);
state.write_u64(s.finish());
self.modifier.hash(state);
}
}

Expand All @@ -392,69 +411,41 @@ fn get_trait_info_from_bound<'a>(bound: &'a GenericBound<'_>) -> Option<(Res, &'
}
}

fn get_ty_res(ty: Ty<'_>) -> Option<Res> {
match ty.kind {
TyKind::Path(QPath::Resolved(_, path)) => Some(path.res),
TyKind::Path(QPath::TypeRelative(ty, _)) => get_ty_res(*ty),
_ => None,
}
}

// FIXME: ComparableTraitRef does not support nested bounds needed for associated_type_bounds
fn into_comparable_trait_ref(trait_ref: &TraitRef<'_>) -> ComparableTraitRef {
ComparableTraitRef(
trait_ref.path.res,
trait_ref
.path
.segments
.iter()
.filter_map(|segment| {
// get trait bound type arguments
Some(segment.args?.args.iter().filter_map(|arg| {
if let GenericArg::Type(ty) = arg {
return get_ty_res(**ty);
}
None
}))
})
.flatten()
.collect(),
)
}

fn rollup_traits(
cx: &LateContext<'_>,
bounds: &[GenericBound<'_>],
fn rollup_traits<'cx, 'tcx>(
cx: &'cx LateContext<'tcx>,
bounds: &'tcx [GenericBound<'tcx>],
msg: &'static str,
) -> Vec<(ComparableTraitRef, Span)> {
let mut map = FxHashMap::default();
) -> Vec<(ComparableTraitRef<'cx, 'tcx>, Span)> {
// Source order is needed for joining spans
let mut map = FxIndexMap::default();
let mut repeated_res = false;

let only_comparable_trait_refs = |bound: &GenericBound<'_>| {
if let GenericBound::Trait(t, _) = bound {
Some((into_comparable_trait_ref(&t.trait_ref), t.span))
let only_comparable_trait_refs = |bound: &'tcx GenericBound<'tcx>| {
if let GenericBound::Trait(t, modifier) = bound {
Some((
ComparableTraitRef {
cx,
trait_ref: &t.trait_ref,
modifier: *modifier,
},
t.span,
))
} else {
None
}
};

let mut i = 0usize;
for bound in bounds.iter().filter_map(only_comparable_trait_refs) {
let (comparable_bound, span_direct) = bound;
match map.entry(comparable_bound) {
Entry::Occupied(_) => repeated_res = true,
Entry::Vacant(e) => {
e.insert((span_direct, i));
i += 1;
IndexEntry::Occupied(_) => repeated_res = true,
IndexEntry::Vacant(e) => {
e.insert(span_direct);
},
}
}

// Put bounds in source order
let mut comparable_bounds = vec![Default::default(); map.len()];
for (k, (v, i)) in map {
comparable_bounds[i] = (k, v);
}
let comparable_bounds: Vec<_> = map.into_iter().collect();

if repeated_res && let [first_trait, .., last_trait] = bounds {
let all_trait_span = first_trait.span().to(last_trait.span());
Expand Down
119 changes: 106 additions & 13 deletions clippy_utils/src/hir_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::tokenize_with_text;
use rustc_ast::ast::InlineAsmTemplatePiece;
use rustc_data_structures::fx::FxHasher;
use rustc_hir::MatchSource::TryDesugar;
use rustc_hir::def::Res;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::{
ArrayLen, AssocItemConstraint, BinOpKind, BindingMode, Block, BodyId, Closure, ConstArg, ConstArgKind, Expr,
ExprField, ExprKind, FnRetTy, GenericArg, GenericArgs, HirId, HirIdMap, InlineAsmOperand, LetExpr, Lifetime,
Expand All @@ -17,11 +17,33 @@ use rustc_middle::ty::TypeckResults;
use rustc_span::{BytePos, ExpnKind, MacroKind, Symbol, SyntaxContext, sym};
use std::hash::{Hash, Hasher};
use std::ops::Range;
use std::slice;

/// Callback that is called when two expressions are not equal in the sense of `SpanlessEq`, but
/// other conditions would make them equal.
type SpanlessEqCallback<'a> = dyn FnMut(&Expr<'_>, &Expr<'_>) -> bool + 'a;

/// Determines how paths are hashed and compared for equality.
#[derive(Copy, Clone, Debug, Default)]
pub enum PathCheck {
/// Paths must match exactly and are hashed by their exact HIR tree.
///
/// Thus, `std::iter::Iterator` and `Iterator` are not considered equal even though they refer
/// to the same item.
#[default]
Exact,
/// Paths are compared and hashed based on their resolution.
///
/// They can appear different in the HIR tree but are still considered equal
/// and have equal hashes as long as they refer to the same item.
///
/// Note that this is currently only partially implemented specifically for paths that are
/// resolved before type-checking, i.e. the final segment must have a non-error resolution.
/// If a path with an error resolution is encountered, it falls back to the default exact
/// matching behavior.
Resolution,
}

/// Type used to check whether two ast are the same. This is different from the
/// operator `==` on ast types as this operator would compare true equality with
/// ID and span.
Expand All @@ -33,6 +55,7 @@ pub struct SpanlessEq<'a, 'tcx> {
maybe_typeck_results: Option<(&'tcx TypeckResults<'tcx>, &'tcx TypeckResults<'tcx>)>,
allow_side_effects: bool,
expr_fallback: Option<Box<SpanlessEqCallback<'a>>>,
path_check: PathCheck,
}

impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
Expand All @@ -42,6 +65,7 @@ impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
maybe_typeck_results: cx.maybe_typeck_results().map(|x| (x, x)),
allow_side_effects: true,
expr_fallback: None,
path_check: PathCheck::default(),
}
}

Expand All @@ -54,6 +78,16 @@ impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
}
}

/// Check paths by their resolution instead of exact equality. See [`PathCheck`] for more
/// details.
#[must_use]
pub fn paths_by_resolution(self) -> Self {
Self {
path_check: PathCheck::Resolution,
..self
}
}

#[must_use]
pub fn expr_fallback(self, expr_fallback: impl FnMut(&Expr<'_>, &Expr<'_>) -> bool + 'a) -> Self {
Self {
Expand Down Expand Up @@ -498,7 +532,7 @@ impl HirEqInterExpr<'_, '_, '_> {
match (left.res, right.res) {
(Res::Local(l), Res::Local(r)) => l == r || self.locals.get(&l) == Some(&r),
(Res::Local(_), _) | (_, Res::Local(_)) => false,
_ => over(left.segments, right.segments, |l, r| self.eq_path_segment(l, r)),
_ => self.eq_path_segments(left.segments, right.segments),
}
}

Expand All @@ -511,17 +545,39 @@ impl HirEqInterExpr<'_, '_, '_> {
}
}

pub fn eq_path_segments(&mut self, left: &[PathSegment<'_>], right: &[PathSegment<'_>]) -> bool {
left.len() == right.len() && left.iter().zip(right).all(|(l, r)| self.eq_path_segment(l, r))
pub fn eq_path_segments<'tcx>(
&mut self,
mut left: &'tcx [PathSegment<'tcx>],
mut right: &'tcx [PathSegment<'tcx>],
) -> bool {
if let PathCheck::Resolution = self.inner.path_check
&& let Some(left_seg) = generic_path_segments(left)
&& let Some(right_seg) = generic_path_segments(right)
{
// If we compare by resolution, then only check the last segments that could possibly have generic
// arguments
left = left_seg;
right = right_seg;
}

over(left, right, |l, r| self.eq_path_segment(l, r))
}

pub fn eq_path_segment(&mut self, left: &PathSegment<'_>, right: &PathSegment<'_>) -> bool {
// The == of idents doesn't work with different contexts,
// we have to be explicit about hygiene
left.ident.name == right.ident.name
&& both(left.args.as_ref(), right.args.as_ref(), |l, r| {
self.eq_path_parameters(l, r)
})
if !self.eq_path_parameters(left.args(), right.args()) {
return false;
}

if let PathCheck::Resolution = self.inner.path_check
&& left.res != Res::Err
&& right.res != Res::Err
{
left.res == right.res
} else {
// The == of idents doesn't work with different contexts,
// we have to be explicit about hygiene
left.ident.name == right.ident.name
}
}

pub fn eq_ty(&mut self, left: &Ty<'_>, right: &Ty<'_>) -> bool {
Expand Down Expand Up @@ -684,6 +740,21 @@ pub fn eq_expr_value(cx: &LateContext<'_>, left: &Expr<'_>, right: &Expr<'_>) ->
SpanlessEq::new(cx).deny_side_effects().eq_expr(left, right)
}

/// Returns the segments of a path that might have generic parameters.
/// Usually just the last segment for free items, except for when the path resolves to an associated
/// item, in which case it is the last two
fn generic_path_segments<'tcx>(segments: &'tcx [PathSegment<'tcx>]) -> Option<&'tcx [PathSegment<'tcx>]> {
match segments.last()?.res {
Res::Def(DefKind::AssocConst | DefKind::AssocFn | DefKind::AssocTy, _) => {
// <Ty as module::Trait<T>>::assoc::<U>
// ^^^^^^^^^^^^^^^^ ^^^^^^^^^^ segments: [module, Trait<T>, assoc<U>]
Some(&segments[segments.len().checked_sub(2)?..])
},
Res::Err => None,
_ => Some(slice::from_ref(segments.last()?)),
}
}

/// Type used to hash an ast element. This is different from the `Hash` trait
/// on ast types as this
/// trait would consider IDs and spans.
Expand All @@ -694,17 +765,29 @@ pub struct SpanlessHash<'a, 'tcx> {
cx: &'a LateContext<'tcx>,
maybe_typeck_results: Option<&'tcx TypeckResults<'tcx>>,
s: FxHasher,
path_check: PathCheck,
}

impl<'a, 'tcx> SpanlessHash<'a, 'tcx> {
pub fn new(cx: &'a LateContext<'tcx>) -> Self {
Self {
cx,
maybe_typeck_results: cx.maybe_typeck_results(),
path_check: PathCheck::default(),
s: FxHasher::default(),
}
}

/// Check paths by their resolution instead of exact equality. See [`PathCheck`] for more
/// details.
#[must_use]
pub fn paths_by_resolution(self) -> Self {
Self {
path_check: PathCheck::Resolution,
..self
}
}

pub fn finish(self) -> u64 {
self.s.finish()
}
Expand Down Expand Up @@ -1042,9 +1125,19 @@ impl<'a, 'tcx> SpanlessHash<'a, 'tcx> {
// even though the binding names are different and they have different `HirId`s.
Res::Local(_) => 1_usize.hash(&mut self.s),
_ => {
for seg in path.segments {
self.hash_name(seg.ident.name);
self.hash_generic_args(seg.args().args);
if let PathCheck::Resolution = self.path_check
&& let [.., last] = path.segments
&& let Some(segments) = generic_path_segments(path.segments)
{
for seg in segments {
self.hash_generic_args(seg.args().args);
}
last.res.hash(&mut self.s);
} else {
for seg in path.segments {
self.hash_name(seg.ident.name);
self.hash_generic_args(seg.args().args);
}
}
},
}
Expand Down
Loading

0 comments on commit d6be597

Please sign in to comment.