From 658b550a6a39a6f9a858174be939dbe288a5dcc2 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Mon, 14 Aug 2023 18:29:59 -0400 Subject: [PATCH] Expand expressions to include parentheses in E712 --- .../test/fixtures/pycodestyle/E712.py | 6 ++ crates/ruff/src/rules/pycodestyle/helpers.rs | 18 ++++- .../pycodestyle/rules/literal_comparisons.rs | 9 ++- .../src/rules/pycodestyle/rules/not_tests.rs | 19 ++++- ...les__pycodestyle__tests__E712_E712.py.snap | 46 ++++++++++- crates/ruff_python_ast/src/lib.rs | 1 + crates/ruff_python_ast/src/parenthesize.rs | 66 +++++++++++++++ crates/ruff_python_ast/src/stmt_if.rs | 3 - crates/ruff_python_ast/tests/parenthesize.rs | 80 +++++++++++++++++++ 9 files changed, 232 insertions(+), 16 deletions(-) create mode 100644 crates/ruff_python_ast/src/parenthesize.rs create mode 100644 crates/ruff_python_ast/tests/parenthesize.rs diff --git a/crates/ruff/resources/test/fixtures/pycodestyle/E712.py b/crates/ruff/resources/test/fixtures/pycodestyle/E712.py index 818afddaefe95..c0be4d7aa1c47 100644 --- a/crates/ruff/resources/test/fixtures/pycodestyle/E712.py +++ b/crates/ruff/resources/test/fixtures/pycodestyle/E712.py @@ -25,6 +25,12 @@ if res == True != False: pass +if(True) == TrueElement or x == TrueElement: + pass + +if (yield i) == True: + print("even") + #: Okay if x not in y: pass diff --git a/crates/ruff/src/rules/pycodestyle/helpers.rs b/crates/ruff/src/rules/pycodestyle/helpers.rs index 76a752b8f92fe..ff7c837a22333 100644 --- a/crates/ruff/src/rules/pycodestyle/helpers.rs +++ b/crates/ruff/src/rules/pycodestyle/helpers.rs @@ -1,8 +1,10 @@ -use ruff_python_ast::{CmpOp, Expr, Ranged}; -use ruff_text_size::{TextLen, TextRange}; use unicode_width::UnicodeWidthStr; +use ruff_python_ast::node::AnyNodeRef; +use ruff_python_ast::parenthesize::ParenthesizedExpression; +use ruff_python_ast::{CmpOp, Expr, Ranged}; use ruff_source_file::{Line, Locator}; +use ruff_text_size::{TextLen, TextRange}; use crate::line_width::{LineLength, LineWidth, TabSize}; @@ -14,6 +16,7 @@ pub(super) fn generate_comparison( left: &Expr, ops: &[CmpOp], comparators: &[Expr], + parent: AnyNodeRef, locator: &Locator, ) -> String { let start = left.start(); @@ -21,7 +24,9 @@ pub(super) fn generate_comparison( let mut contents = String::with_capacity(usize::from(end - start)); // Add the left side of the comparison. - contents.push_str(locator.slice(left.range())); + contents.push_str(locator.slice( + ParenthesizedExpression::from_expr(left.into(), parent, locator.contents()).range(), + )); for (op, comparator) in ops.iter().zip(comparators) { // Add the operator. @@ -39,7 +44,12 @@ pub(super) fn generate_comparison( }); // Add the right side of the comparison. - contents.push_str(locator.slice(comparator.range())); + contents.push_str( + locator.slice( + ParenthesizedExpression::from_expr(comparator.into(), parent, locator.contents()) + .range(), + ), + ); } contents diff --git a/crates/ruff/src/rules/pycodestyle/rules/literal_comparisons.rs b/crates/ruff/src/rules/pycodestyle/rules/literal_comparisons.rs index d16fa579f74de..9c7f525b6ca72 100644 --- a/crates/ruff/src/rules/pycodestyle/rules/literal_comparisons.rs +++ b/crates/ruff/src/rules/pycodestyle/rules/literal_comparisons.rs @@ -279,8 +279,13 @@ pub(crate) fn literal_comparisons(checker: &mut Checker, compare: &ast::ExprComp .map(|(idx, op)| bad_ops.get(&idx).unwrap_or(op)) .copied() .collect::>(); - let content = - generate_comparison(&compare.left, &ops, &compare.comparators, checker.locator()); + let content = generate_comparison( + &compare.left, + &ops, + &compare.comparators, + compare.into(), + checker.locator(), + ); for diagnostic in &mut diagnostics { diagnostic.set_fix(Fix::suggested(Edit::range_replacement( content.to_string(), diff --git a/crates/ruff/src/rules/pycodestyle/rules/not_tests.rs b/crates/ruff/src/rules/pycodestyle/rules/not_tests.rs index 0ee126d046933..d5a0c3d5898e7 100644 --- a/crates/ruff/src/rules/pycodestyle/rules/not_tests.rs +++ b/crates/ruff/src/rules/pycodestyle/rules/not_tests.rs @@ -83,8 +83,7 @@ pub(crate) fn not_tests(checker: &mut Checker, unary_op: &ast::ExprUnaryOp) { ops, comparators, range: _, - }) = unary_op.operand.as_ref() - else { + }) = unary_op.operand.as_ref() else { return; }; @@ -94,7 +93,13 @@ pub(crate) fn not_tests(checker: &mut Checker, unary_op: &ast::ExprUnaryOp) { let mut diagnostic = Diagnostic::new(NotInTest, unary_op.operand.range()); if checker.patch(diagnostic.kind.rule()) { diagnostic.set_fix(Fix::automatic(Edit::range_replacement( - generate_comparison(left, &[CmpOp::NotIn], comparators, checker.locator()), + generate_comparison( + left, + &[CmpOp::NotIn], + comparators, + unary_op.into(), + checker.locator(), + ), unary_op.range(), ))); } @@ -106,7 +111,13 @@ pub(crate) fn not_tests(checker: &mut Checker, unary_op: &ast::ExprUnaryOp) { let mut diagnostic = Diagnostic::new(NotIsTest, unary_op.operand.range()); if checker.patch(diagnostic.kind.rule()) { diagnostic.set_fix(Fix::automatic(Edit::range_replacement( - generate_comparison(left, &[CmpOp::IsNot], comparators, checker.locator()), + generate_comparison( + left, + &[CmpOp::IsNot], + comparators, + unary_op.into(), + checker.locator(), + ), unary_op.range(), ))); } diff --git a/crates/ruff/src/rules/pycodestyle/snapshots/ruff__rules__pycodestyle__tests__E712_E712.py.snap b/crates/ruff/src/rules/pycodestyle/snapshots/ruff__rules__pycodestyle__tests__E712_E712.py.snap index e2a9be7b88be8..ba3f1143bd887 100644 --- a/crates/ruff/src/rules/pycodestyle/snapshots/ruff__rules__pycodestyle__tests__E712_E712.py.snap +++ b/crates/ruff/src/rules/pycodestyle/snapshots/ruff__rules__pycodestyle__tests__E712_E712.py.snap @@ -181,7 +181,7 @@ E712.py:22:5: E712 [*] Comparison to `True` should be `cond is True` or `if cond 20 20 | var = 1 if cond == True else -1 if cond == False else cond 21 21 | #: E712 22 |-if (True) == TrueElement or x == TrueElement: - 22 |+if True is TrueElement or x == TrueElement: + 22 |+if (True) is TrueElement or x == TrueElement: 23 23 | pass 24 24 | 25 25 | if res == True != False: @@ -204,7 +204,7 @@ E712.py:25:11: E712 [*] Comparison to `True` should be `cond is True` or `if con 25 |+if res is True is not False: 26 26 | pass 27 27 | -28 28 | #: Okay +28 28 | if(True) == TrueElement or x == TrueElement: E712.py:25:19: E712 [*] Comparison to `False` should be `cond is not False` or `if cond:` | @@ -224,6 +224,46 @@ E712.py:25:19: E712 [*] Comparison to `False` should be `cond is not False` or ` 25 |+if res is True is not False: 26 26 | pass 27 27 | -28 28 | #: Okay +28 28 | if(True) == TrueElement or x == TrueElement: + +E712.py:28:4: E712 [*] Comparison to `True` should be `cond is True` or `if cond:` + | +26 | pass +27 | +28 | if(True) == TrueElement or x == TrueElement: + | ^^^^ E712 +29 | pass + | + = help: Replace with `cond is True` + +ℹ Suggested fix +25 25 | if res == True != False: +26 26 | pass +27 27 | +28 |-if(True) == TrueElement or x == TrueElement: + 28 |+if(True) is TrueElement or x == TrueElement: +29 29 | pass +30 30 | +31 31 | if (yield i) == True: + +E712.py:31:17: E712 [*] Comparison to `True` should be `cond is True` or `if cond:` + | +29 | pass +30 | +31 | if (yield i) == True: + | ^^^^ E712 +32 | print("even") + | + = help: Replace with `cond is True` + +ℹ Suggested fix +28 28 | if(True) == TrueElement or x == TrueElement: +29 29 | pass +30 30 | +31 |-if (yield i) == True: + 31 |+if (yield i) is True: +32 32 | print("even") +33 33 | +34 34 | #: Okay diff --git a/crates/ruff_python_ast/src/lib.rs b/crates/ruff_python_ast/src/lib.rs index ac615c12803aa..d28f459dd4af4 100644 --- a/crates/ruff_python_ast/src/lib.rs +++ b/crates/ruff_python_ast/src/lib.rs @@ -12,6 +12,7 @@ pub mod identifier; pub mod imports; pub mod node; mod nodes; +pub mod parenthesize; pub mod relocate; pub mod statement_visitor; pub mod stmt_if; diff --git a/crates/ruff_python_ast/src/parenthesize.rs b/crates/ruff_python_ast/src/parenthesize.rs new file mode 100644 index 0000000000000..9949c9c2e9d8a --- /dev/null +++ b/crates/ruff_python_ast/src/parenthesize.rs @@ -0,0 +1,66 @@ +use ruff_python_trivia::{SimpleTokenKind, SimpleTokenizer}; +use ruff_text_size::{TextRange, TextSize}; +use std::ops::Sub; + +use crate::node::AnyNodeRef; +use crate::Ranged; + +/// A wrapper around an expression that may be parenthesized. +#[derive(Debug)] +pub struct ParenthesizedExpression<'a> { + /// The underlying AST node. + expr: AnyNodeRef<'a>, + /// The range of the expression including parentheses, if the expression is parenthesized; + /// or `None`, if the expression is not parenthesized. + range: Option, +} + +impl<'a> ParenthesizedExpression<'a> { + /// Given an expression and its parent, returns a parenthesized expression. + pub fn from_expr(expr: AnyNodeRef<'a>, parent: AnyNodeRef<'a>, contents: &str) -> Self { + Self { + expr, + range: parenthesized_range(expr, parent, contents), + } + } + + /// Returns `true` if the expression is parenthesized. + pub fn is_parenthesized(&self) -> bool { + self.range.is_some() + } +} + +impl Ranged for ParenthesizedExpression<'_> { + fn range(&self) -> TextRange { + self.range.unwrap_or_else(|| self.expr.range()) + } +} + +/// Returns the [`TextRange`] of a given expression including parentheses, if the expression is +/// parenthesized; or `None`, if the expression is not parenthesized. +fn parenthesized_range(expr: AnyNodeRef, parent: AnyNodeRef, contents: &str) -> Option { + // If the parent is an `arguments` node, then the range of the expression includes the closing + // parenthesis, so exclude it from our test range. + let exclusive_parent_end = if parent.is_arguments() { + parent.end().sub(TextSize::new(1)) + } else { + parent.end() + }; + + // First, test if there's a closing parenthesis because it tends to be cheaper. + let tokenizer = + SimpleTokenizer::new(contents, TextRange::new(expr.end(), exclusive_parent_end)); + let right = tokenizer.skip_trivia().next()?; + + if right.kind == SimpleTokenKind::RParen { + // Next, test for the opening parenthesis. + let mut tokenizer = + SimpleTokenizer::up_to_without_back_comment(expr.start(), contents).skip_trivia(); + let left = tokenizer.next_back()?; + if left.kind == SimpleTokenKind::LParen { + return Some(TextRange::new(left.start(), right.end())); + } + } + + None +} diff --git a/crates/ruff_python_ast/src/stmt_if.rs b/crates/ruff_python_ast/src/stmt_if.rs index 0acc04a40535f..af2ae6bdb5674 100644 --- a/crates/ruff_python_ast/src/stmt_if.rs +++ b/crates/ruff_python_ast/src/stmt_if.rs @@ -40,6 +40,3 @@ pub fn if_elif_branches(stmt_if: &StmtIf) -> impl Iterator }) })) } - -#[cfg(test)] -mod test {} diff --git a/crates/ruff_python_ast/tests/parenthesize.rs b/crates/ruff_python_ast/tests/parenthesize.rs new file mode 100644 index 0000000000000..82f4489c102dd --- /dev/null +++ b/crates/ruff_python_ast/tests/parenthesize.rs @@ -0,0 +1,80 @@ +use ruff_python_ast::parenthesize::ParenthesizedExpression; +use ruff_python_parser::parse_expression; + +#[test] +fn test_parenthesized_name() { + let source_code = r#"(x) + 1"#; + let expr = parse_expression(source_code, "").unwrap(); + + let bin_op = expr.as_bin_op_expr().unwrap(); + let name = bin_op.left.as_ref(); + + let parenthesized = ParenthesizedExpression::from_expr(name.into(), bin_op.into(), source_code); + assert!(parenthesized.is_parenthesized()); +} + +#[test] +fn test_un_parenthesized_name() { + let source_code = r#"x + 1"#; + let expr = parse_expression(source_code, "").unwrap(); + + let bin_op = expr.as_bin_op_expr().unwrap(); + let name = bin_op.left.as_ref(); + + let parenthesized = ParenthesizedExpression::from_expr(name.into(), bin_op.into(), source_code); + assert!(!parenthesized.is_parenthesized()); +} + +#[test] +fn test_parenthesized_argument() { + let source_code = r#"f((a))"#; + let expr = parse_expression(source_code, "").unwrap(); + + let call = expr.as_call_expr().unwrap(); + let arguments = &call.arguments; + let argument = arguments.args.first().unwrap(); + + let parenthesized = + ParenthesizedExpression::from_expr(argument.into(), arguments.into(), source_code); + assert!(parenthesized.is_parenthesized()); +} + +#[test] +fn test_unparenthesized_argument() { + let source_code = r#"f(a)"#; + let expr = parse_expression(source_code, "").unwrap(); + + let call = expr.as_call_expr().unwrap(); + let arguments = &call.arguments; + let argument = arguments.args.first().unwrap(); + + let parenthesized = + ParenthesizedExpression::from_expr(argument.into(), arguments.into(), source_code); + assert!(!parenthesized.is_parenthesized()); +} + +#[test] +fn test_parenthesized_tuple_member() { + let source_code = r#"(a, (b))"#; + let expr = parse_expression(source_code, "").unwrap(); + + let tuple = expr.as_tuple_expr().unwrap(); + let member = tuple.elts.last().unwrap(); + + let parenthesized = + ParenthesizedExpression::from_expr(member.into(), tuple.into(), source_code); + assert!(parenthesized.is_parenthesized()); +} + +#[test] +fn test_unparenthesized_tuple_member() { + let source_code = r#"(a, b)"#; + let expr = parse_expression(source_code, "").unwrap(); + + let tuple = expr.as_tuple_expr().unwrap(); + let member = tuple.elts.last().unwrap(); + + let parenthesized = + ParenthesizedExpression::from_expr(member.into(), tuple.into(), source_code); + assert!(!parenthesized.is_parenthesized()); +}