Skip to content

Commit

Permalink
Expand expressions to include parentheses in E712
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Aug 17, 2023
1 parent d9bb51d commit 658b550
Show file tree
Hide file tree
Showing 9 changed files with 232 additions and 16 deletions.
6 changes: 6 additions & 0 deletions crates/ruff/resources/test/fixtures/pycodestyle/E712.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
if res == True != False:
pass

if(True) == TrueElement or x == TrueElement:
pass

if (yield i) == True:
print("even")

#: Okay
if x not in y:
pass
Expand Down
18 changes: 14 additions & 4 deletions crates/ruff/src/rules/pycodestyle/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use ruff_python_ast::{CmpOp, Expr, Ranged};
use ruff_text_size::{TextLen, TextRange};
use unicode_width::UnicodeWidthStr;

use ruff_python_ast::node::AnyNodeRef;
use ruff_python_ast::parenthesize::ParenthesizedExpression;
use ruff_python_ast::{CmpOp, Expr, Ranged};
use ruff_source_file::{Line, Locator};
use ruff_text_size::{TextLen, TextRange};

use crate::line_width::{LineLength, LineWidth, TabSize};

Expand All @@ -14,14 +16,17 @@ pub(super) fn generate_comparison(
left: &Expr,
ops: &[CmpOp],
comparators: &[Expr],
parent: AnyNodeRef,
locator: &Locator,
) -> String {
let start = left.start();
let end = comparators.last().map_or_else(|| left.end(), Ranged::end);
let mut contents = String::with_capacity(usize::from(end - start));

// Add the left side of the comparison.
contents.push_str(locator.slice(left.range()));
contents.push_str(locator.slice(
ParenthesizedExpression::from_expr(left.into(), parent, locator.contents()).range(),
));

for (op, comparator) in ops.iter().zip(comparators) {
// Add the operator.
Expand All @@ -39,7 +44,12 @@ pub(super) fn generate_comparison(
});

// Add the right side of the comparison.
contents.push_str(locator.slice(comparator.range()));
contents.push_str(
locator.slice(
ParenthesizedExpression::from_expr(comparator.into(), parent, locator.contents())
.range(),
),
);
}

contents
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,13 @@ pub(crate) fn literal_comparisons(checker: &mut Checker, compare: &ast::ExprComp
.map(|(idx, op)| bad_ops.get(&idx).unwrap_or(op))
.copied()
.collect::<Vec<_>>();
let content =
generate_comparison(&compare.left, &ops, &compare.comparators, checker.locator());
let content = generate_comparison(
&compare.left,
&ops,
&compare.comparators,
compare.into(),
checker.locator(),
);
for diagnostic in &mut diagnostics {
diagnostic.set_fix(Fix::suggested(Edit::range_replacement(
content.to_string(),
Expand Down
19 changes: 15 additions & 4 deletions crates/ruff/src/rules/pycodestyle/rules/not_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ pub(crate) fn not_tests(checker: &mut Checker, unary_op: &ast::ExprUnaryOp) {
ops,
comparators,
range: _,
}) = unary_op.operand.as_ref()
else {
}) = unary_op.operand.as_ref() else {
return;
};

Expand All @@ -94,7 +93,13 @@ pub(crate) fn not_tests(checker: &mut Checker, unary_op: &ast::ExprUnaryOp) {
let mut diagnostic = Diagnostic::new(NotInTest, unary_op.operand.range());
if checker.patch(diagnostic.kind.rule()) {
diagnostic.set_fix(Fix::automatic(Edit::range_replacement(
generate_comparison(left, &[CmpOp::NotIn], comparators, checker.locator()),
generate_comparison(
left,
&[CmpOp::NotIn],
comparators,
unary_op.into(),
checker.locator(),
),
unary_op.range(),
)));
}
Expand All @@ -106,7 +111,13 @@ pub(crate) fn not_tests(checker: &mut Checker, unary_op: &ast::ExprUnaryOp) {
let mut diagnostic = Diagnostic::new(NotIsTest, unary_op.operand.range());
if checker.patch(diagnostic.kind.rule()) {
diagnostic.set_fix(Fix::automatic(Edit::range_replacement(
generate_comparison(left, &[CmpOp::IsNot], comparators, checker.locator()),
generate_comparison(
left,
&[CmpOp::IsNot],
comparators,
unary_op.into(),
checker.locator(),
),
unary_op.range(),
)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ E712.py:22:5: E712 [*] Comparison to `True` should be `cond is True` or `if cond
20 20 | var = 1 if cond == True else -1 if cond == False else cond
21 21 | #: E712
22 |-if (True) == TrueElement or x == TrueElement:
22 |+if True is TrueElement or x == TrueElement:
22 |+if (True) is TrueElement or x == TrueElement:
23 23 | pass
24 24 |
25 25 | if res == True != False:
Expand All @@ -204,7 +204,7 @@ E712.py:25:11: E712 [*] Comparison to `True` should be `cond is True` or `if con
25 |+if res is True is not False:
26 26 | pass
27 27 |
28 28 | #: Okay
28 28 | if(True) == TrueElement or x == TrueElement:

E712.py:25:19: E712 [*] Comparison to `False` should be `cond is not False` or `if cond:`
|
Expand All @@ -224,6 +224,46 @@ E712.py:25:19: E712 [*] Comparison to `False` should be `cond is not False` or `
25 |+if res is True is not False:
26 26 | pass
27 27 |
28 28 | #: Okay
28 28 | if(True) == TrueElement or x == TrueElement:

E712.py:28:4: E712 [*] Comparison to `True` should be `cond is True` or `if cond:`
|
26 | pass
27 |
28 | if(True) == TrueElement or x == TrueElement:
| ^^^^ E712
29 | pass
|
= help: Replace with `cond is True`

Suggested fix
25 25 | if res == True != False:
26 26 | pass
27 27 |
28 |-if(True) == TrueElement or x == TrueElement:
28 |+if(True) is TrueElement or x == TrueElement:
29 29 | pass
30 30 |
31 31 | if (yield i) == True:

E712.py:31:17: E712 [*] Comparison to `True` should be `cond is True` or `if cond:`
|
29 | pass
30 |
31 | if (yield i) == True:
| ^^^^ E712
32 | print("even")
|
= help: Replace with `cond is True`

Suggested fix
28 28 | if(True) == TrueElement or x == TrueElement:
29 29 | pass
30 30 |
31 |-if (yield i) == True:
31 |+if (yield i) is True:
32 32 | print("even")
33 33 |
34 34 | #: Okay


1 change: 1 addition & 0 deletions crates/ruff_python_ast/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod identifier;
pub mod imports;
pub mod node;
mod nodes;
pub mod parenthesize;
pub mod relocate;
pub mod statement_visitor;
pub mod stmt_if;
Expand Down
66 changes: 66 additions & 0 deletions crates/ruff_python_ast/src/parenthesize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use ruff_python_trivia::{SimpleTokenKind, SimpleTokenizer};
use ruff_text_size::{TextRange, TextSize};
use std::ops::Sub;

use crate::node::AnyNodeRef;
use crate::Ranged;

/// A wrapper around an expression that may be parenthesized.
#[derive(Debug)]
pub struct ParenthesizedExpression<'a> {
/// The underlying AST node.
expr: AnyNodeRef<'a>,
/// The range of the expression including parentheses, if the expression is parenthesized;
/// or `None`, if the expression is not parenthesized.
range: Option<TextRange>,
}

impl<'a> ParenthesizedExpression<'a> {
/// Given an expression and its parent, returns a parenthesized expression.
pub fn from_expr(expr: AnyNodeRef<'a>, parent: AnyNodeRef<'a>, contents: &str) -> Self {
Self {
expr,
range: parenthesized_range(expr, parent, contents),
}
}

/// Returns `true` if the expression is parenthesized.
pub fn is_parenthesized(&self) -> bool {
self.range.is_some()
}
}

impl Ranged for ParenthesizedExpression<'_> {
fn range(&self) -> TextRange {
self.range.unwrap_or_else(|| self.expr.range())
}
}

/// Returns the [`TextRange`] of a given expression including parentheses, if the expression is
/// parenthesized; or `None`, if the expression is not parenthesized.
fn parenthesized_range(expr: AnyNodeRef, parent: AnyNodeRef, contents: &str) -> Option<TextRange> {
// If the parent is an `arguments` node, then the range of the expression includes the closing
// parenthesis, so exclude it from our test range.
let exclusive_parent_end = if parent.is_arguments() {
parent.end().sub(TextSize::new(1))
} else {
parent.end()
};

// First, test if there's a closing parenthesis because it tends to be cheaper.
let tokenizer =
SimpleTokenizer::new(contents, TextRange::new(expr.end(), exclusive_parent_end));
let right = tokenizer.skip_trivia().next()?;

if right.kind == SimpleTokenKind::RParen {
// Next, test for the opening parenthesis.
let mut tokenizer =
SimpleTokenizer::up_to_without_back_comment(expr.start(), contents).skip_trivia();
let left = tokenizer.next_back()?;
if left.kind == SimpleTokenKind::LParen {
return Some(TextRange::new(left.start(), right.end()));
}
}

None
}
3 changes: 0 additions & 3 deletions crates/ruff_python_ast/src/stmt_if.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,3 @@ pub fn if_elif_branches(stmt_if: &StmtIf) -> impl Iterator<Item = IfElifBranch>
})
}))
}

#[cfg(test)]
mod test {}
80 changes: 80 additions & 0 deletions crates/ruff_python_ast/tests/parenthesize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use ruff_python_ast::parenthesize::ParenthesizedExpression;
use ruff_python_parser::parse_expression;

#[test]
fn test_parenthesized_name() {
let source_code = r#"(x) + 1"#;
let expr = parse_expression(source_code, "<filename>").unwrap();

let bin_op = expr.as_bin_op_expr().unwrap();
let name = bin_op.left.as_ref();

let parenthesized = ParenthesizedExpression::from_expr(name.into(), bin_op.into(), source_code);
assert!(parenthesized.is_parenthesized());
}

#[test]
fn test_un_parenthesized_name() {
let source_code = r#"x + 1"#;
let expr = parse_expression(source_code, "<filename>").unwrap();

let bin_op = expr.as_bin_op_expr().unwrap();
let name = bin_op.left.as_ref();

let parenthesized = ParenthesizedExpression::from_expr(name.into(), bin_op.into(), source_code);
assert!(!parenthesized.is_parenthesized());
}

#[test]
fn test_parenthesized_argument() {
let source_code = r#"f((a))"#;
let expr = parse_expression(source_code, "<filename>").unwrap();

let call = expr.as_call_expr().unwrap();
let arguments = &call.arguments;
let argument = arguments.args.first().unwrap();

let parenthesized =
ParenthesizedExpression::from_expr(argument.into(), arguments.into(), source_code);
assert!(parenthesized.is_parenthesized());
}

#[test]
fn test_unparenthesized_argument() {
let source_code = r#"f(a)"#;
let expr = parse_expression(source_code, "<filename>").unwrap();

let call = expr.as_call_expr().unwrap();
let arguments = &call.arguments;
let argument = arguments.args.first().unwrap();

let parenthesized =
ParenthesizedExpression::from_expr(argument.into(), arguments.into(), source_code);
assert!(!parenthesized.is_parenthesized());
}

#[test]
fn test_parenthesized_tuple_member() {
let source_code = r#"(a, (b))"#;
let expr = parse_expression(source_code, "<filename>").unwrap();

let tuple = expr.as_tuple_expr().unwrap();
let member = tuple.elts.last().unwrap();

let parenthesized =
ParenthesizedExpression::from_expr(member.into(), tuple.into(), source_code);
assert!(parenthesized.is_parenthesized());
}

#[test]
fn test_unparenthesized_tuple_member() {
let source_code = r#"(a, b)"#;
let expr = parse_expression(source_code, "<filename>").unwrap();

let tuple = expr.as_tuple_expr().unwrap();
let member = tuple.elts.last().unwrap();

let parenthesized =
ParenthesizedExpression::from_expr(member.into(), tuple.into(), source_code);
assert!(!parenthesized.is_parenthesized());
}

0 comments on commit 658b550

Please sign in to comment.