Skip to content

Commit

Permalink
Avoid nested quotations in auto-quoting fix
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Dec 17, 2023
1 parent a336c1b commit e666f4d
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,6 @@ class C:
x: DataFrame[
int
] = 1

def func() -> DataFrame[[DataFrame[_P, _R]], DataFrame[_P, _R]]:
...
14 changes: 14 additions & 0 deletions crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,17 @@ pub(crate) fn quote_annotation(
expr.range(),
))
}

/// Filter out any [`Edit`]s that are completely contained by any other [`Edit`].
pub(crate) fn filter_contained(edits: Vec<Edit>) -> Vec<Edit> {
let mut filtered: Vec<Edit> = Vec::with_capacity(edits.len());
for edit in edits {
if filtered
.iter()
.all(|filtered_edit| !filtered_edit.range().contains_range(edit.range()))
{
filtered.push(edit);
}
}
filtered
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::borrow::Cow;

use anyhow::Result;
use itertools::Itertools;
use rustc_hash::FxHashMap;

use ruff_diagnostics::{Diagnostic, Fix, FixAvailability, Violation};
Expand All @@ -13,7 +12,7 @@ use crate::checkers::ast::Checker;
use crate::codes::Rule;
use crate::fix;
use crate::importer::ImportedMembers;
use crate::rules::flake8_type_checking::helpers::quote_annotation;
use crate::rules::flake8_type_checking::helpers::{filter_contained, quote_annotation};
use crate::rules::flake8_type_checking::imports::ImportBinding;

/// ## What it does
Expand Down Expand Up @@ -263,27 +262,29 @@ pub(crate) fn runtime_import_in_type_checking_block(

/// Generate a [`Fix`] to quote runtime usages for imports in a type-checking block.
fn quote_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) -> Result<Fix> {
let quote_reference_edits = imports
.iter()
.flat_map(|ImportBinding { binding, .. }| {
binding.references.iter().filter_map(|reference_id| {
let reference = checker.semantic().reference(*reference_id);
if reference.context().is_runtime() {
Some(quote_annotation(
reference.expression_id()?,
checker.semantic(),
checker.locator(),
checker.stylist(),
checker.generator(),
))
} else {
None
}
let quote_reference_edits = filter_contained(
imports
.iter()
.flat_map(|ImportBinding { binding, .. }| {
binding.references.iter().filter_map(|reference_id| {
let reference = checker.semantic().reference(*reference_id);
if reference.context().is_runtime() {
Some(quote_annotation(
reference.expression_id()?,
checker.semantic(),
checker.locator(),
checker.stylist(),
checker.generator(),
))
} else {
None
}
})
})
})
.collect::<Result<Vec<_>>>()?;
.collect::<Result<Vec<_>>>()?,
);

let mut rest = quote_reference_edits.into_iter().unique();
let mut rest = quote_reference_edits.into_iter();
let head = rest.next().expect("Expected at least one reference");
Ok(Fix::unsafe_edits(head, rest).isolate(Checker::isolation(
checker.semantic().parent_statement_id(node_id),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::borrow::Cow;

use anyhow::Result;
use itertools::Itertools;
use rustc_hash::FxHashMap;

use ruff_diagnostics::{Diagnostic, DiagnosticKind, Fix, FixAvailability, Violation};
Expand All @@ -13,7 +12,9 @@ use crate::checkers::ast::Checker;
use crate::codes::Rule;
use crate::fix;
use crate::importer::ImportedMembers;
use crate::rules::flake8_type_checking::helpers::{is_typing_reference, quote_annotation};
use crate::rules::flake8_type_checking::helpers::{
filter_contained, is_typing_reference, quote_annotation,
};
use crate::rules::flake8_type_checking::imports::ImportBinding;
use crate::rules::isort::{categorize, ImportSection, ImportType};

Expand Down Expand Up @@ -483,32 +484,34 @@ fn fix_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) ->
)?;

// Step 3) Quote any runtime usages of the referenced symbol.
let quote_reference_edits = imports
.iter()
.flat_map(|ImportBinding { binding, .. }| {
binding.references.iter().filter_map(|reference_id| {
let reference = checker.semantic().reference(*reference_id);
if reference.context().is_runtime() {
Some(quote_annotation(
reference.expression_id()?,
checker.semantic(),
checker.locator(),
checker.stylist(),
checker.generator(),
))
} else {
None
}
let quote_reference_edits = filter_contained(
imports
.iter()
.flat_map(|ImportBinding { binding, .. }| {
binding.references.iter().filter_map(|reference_id| {
let reference = checker.semantic().reference(*reference_id);
if reference.context().is_runtime() {
Some(quote_annotation(
reference.expression_id()?,
checker.semantic(),
checker.locator(),
checker.stylist(),
checker.generator(),
))
} else {
None
}
})
})
})
.collect::<Result<Vec<_>>>()?;
.collect::<Result<Vec<_>>>()?,
);

Ok(Fix::unsafe_edits(
remove_import_edit,
add_import_edit
.into_edits()
.into_iter()
.chain(quote_reference_edits.into_iter().unique()),
.chain(quote_reference_edits),
)
.isolate(Checker::isolation(
checker.semantic().parent_statement_id(node_id),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ quote.py:78:24: TCH002 [*] Move third-party import `pandas.DataFrame` into a typ
88 |- int
89 |- ] = 1
89 |+ x: "DataFrame[int]" = 1
90 90 |
91 |- def func() -> DataFrame[[DataFrame[_P, _R]], DataFrame[_P, _R]]:
91 |+ def func() -> "DataFrame[[DataFrame[_P, _R]], DataFrame[_P, _R]]":
92 92 | ...

quote.py:78:35: TCH002 [*] Move third-party import `pandas.Series` into a type-checking block
|
Expand Down Expand Up @@ -329,5 +333,9 @@ quote.py:78:35: TCH002 [*] Move third-party import `pandas.Series` into a type-c
88 |- int
89 |- ] = 1
89 |+ x: "DataFrame[int]" = 1
90 90 |
91 |- def func() -> DataFrame[[DataFrame[_P, _R]], DataFrame[_P, _R]]:
91 |+ def func() -> "DataFrame[[DataFrame[_P, _R]], DataFrame[_P, _R]]":
92 92 | ...


0 comments on commit e666f4d

Please sign in to comment.