From fa556d1c74417bd2f43bd22a3a749e7cd018b442 Mon Sep 17 00:00:00 2001 From: Flowrey <43378459+Flowrey@users.noreply.github.com> Date: Fri, 20 Oct 2023 22:47:00 +0200 Subject: [PATCH] Make SIM401 catch ternary operations (#7415) ## Summary Make SIM401 rules to catch ternary operations when preview is enabled. Fixes #7288. ## Test Plan Tested against `SIM401.py` fixtures. --- .../test/fixtures/flake8_simplify/SIM401.py | 13 ++ .../src/checkers/ast/analyze/expression.rs | 5 + .../src/rules/flake8_simplify/mod.rs | 19 ++ .../if_else_block_instead_of_dict_get.rs | 101 +++++++++ ...ify__tests__preview__SIM401_SIM401.py.snap | 202 ++++++++++++++++++ 5 files changed, 340 insertions(+) create mode 100644 crates/ruff_linter/src/rules/flake8_simplify/snapshots/ruff_linter__rules__flake8_simplify__tests__preview__SIM401_SIM401.py.snap diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_simplify/SIM401.py b/crates/ruff_linter/resources/test/fixtures/flake8_simplify/SIM401.py index 79b48a57dd1c5..26bc35ee63209 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_simplify/SIM401.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_simplify/SIM401.py @@ -114,3 +114,16 @@ vars[idx] = a_dict[key] else: vars[idx] = "default" + +### +# Positive cases (preview) +### + +# SIM401 +var = a_dict[key] if key in a_dict else "default3" + +# SIM401 +var = "default-1" if key not in a_dict else a_dict[key] + +# OK (default contains effect) +var = a_dict[key] if key in a_dict else val1 + val2 diff --git a/crates/ruff_linter/src/checkers/ast/analyze/expression.rs b/crates/ruff_linter/src/checkers/ast/analyze/expression.rs index 81cf642aee29a..13ad26bed7933 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/expression.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/expression.rs @@ -1273,6 +1273,11 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) { orelse, range: _, }) => { + if checker.enabled(Rule::IfElseBlockInsteadOfDictGet) { + flake8_simplify::rules::if_exp_instead_of_dict_get( + checker, expr, test, body, orelse, + ); + } if checker.enabled(Rule::IfExprWithTrueFalse) { flake8_simplify::rules::if_expr_with_true_false(checker, expr, test, body, orelse); } diff --git a/crates/ruff_linter/src/rules/flake8_simplify/mod.rs b/crates/ruff_linter/src/rules/flake8_simplify/mod.rs index 82a7d33c9b0c8..fe3eaecb8ddd3 100644 --- a/crates/ruff_linter/src/rules/flake8_simplify/mod.rs +++ b/crates/ruff_linter/src/rules/flake8_simplify/mod.rs @@ -9,6 +9,7 @@ mod tests { use test_case::test_case; use crate::registry::Rule; + use crate::settings::types::PreviewMode; use crate::test::test_path; use crate::{assert_messages, settings}; @@ -53,4 +54,22 @@ mod tests { assert_messages!(snapshot, diagnostics); Ok(()) } + + #[test_case(Rule::IfElseBlockInsteadOfDictGet, Path::new("SIM401.py"))] + fn preview_rules(rule_code: Rule, path: &Path) -> Result<()> { + let snapshot = format!( + "preview__{}_{}", + rule_code.noqa_code(), + path.to_string_lossy() + ); + let diagnostics = test_path( + Path::new("flake8_simplify").join(path).as_path(), + &settings::LinterSettings { + preview: PreviewMode::Enabled, + ..settings::LinterSettings::for_rule(rule_code) + }, + )?; + assert_messages!(snapshot, diagnostics); + Ok(()) + } } diff --git a/crates/ruff_linter/src/rules/flake8_simplify/rules/if_else_block_instead_of_dict_get.rs b/crates/ruff_linter/src/rules/flake8_simplify/rules/if_else_block_instead_of_dict_get.rs index 8afe4fb52ce79..392156acab946 100644 --- a/crates/ruff_linter/src/rules/flake8_simplify/rules/if_else_block_instead_of_dict_get.rs +++ b/crates/ruff_linter/src/rules/flake8_simplify/rules/if_else_block_instead_of_dict_get.rs @@ -20,6 +20,9 @@ use crate::fix::edits::fits; /// the key is not found. When possible, using `dict.get` is more concise and /// more idiomatic. /// +/// Under [preview mode](https://docs.astral.sh/ruff/preview), this rule will +/// also suggest replacing `if`-`else` _expressions_ with `dict.get` calls. +/// /// ## Example /// ```python /// if "bar" in foo: @@ -33,6 +36,16 @@ use crate::fix::edits::fits; /// value = foo.get("bar", 0) /// ``` /// +/// If preview mode is enabled: +/// ```python +/// value = foo["bar"] if "bar" in foo else 0 +/// ``` +/// +/// Use instead: +/// ```python +/// value = foo.get("bar", 0) +/// ``` +/// /// ## References /// - [Python documentation: Mapping Types](https://docs.python.org/3/library/stdtypes.html#mapping-types-dict) #[violation] @@ -202,3 +215,91 @@ pub(crate) fn if_else_block_instead_of_dict_get(checker: &mut Checker, stmt_if: } checker.diagnostics.push(diagnostic); } + +/// SIM401 +pub(crate) fn if_exp_instead_of_dict_get( + checker: &mut Checker, + expr: &Expr, + test: &Expr, + body: &Expr, + orelse: &Expr, +) { + if checker.settings.preview.is_disabled() { + return; + } + + let Expr::Compare(ast::ExprCompare { + left: test_key, + ops, + comparators: test_dict, + range: _, + }) = test + else { + return; + }; + let [test_dict] = test_dict.as_slice() else { + return; + }; + + let (body, default_value) = match ops.as_slice() { + [CmpOp::In] => (body, orelse), + [CmpOp::NotIn] => (orelse, body), + _ => { + return; + } + }; + + let Expr::Subscript(ast::ExprSubscript { + value: expected_subscript, + slice: expected_slice, + .. + }) = body + else { + return; + }; + + if ComparableExpr::from(expected_slice) != ComparableExpr::from(test_key) + || ComparableExpr::from(test_dict) != ComparableExpr::from(expected_subscript) + { + return; + } + + // Check that the default value is not "complex". + if contains_effect(default_value, |id| checker.semantic().is_builtin(id)) { + return; + } + + let default_value_node = default_value.clone(); + let dict_key_node = *test_key.clone(); + let dict_get_node = ast::ExprAttribute { + value: expected_subscript.clone(), + attr: Identifier::new("get".to_string(), TextRange::default()), + ctx: ExprContext::Load, + range: TextRange::default(), + }; + let fixed_node = ast::ExprCall { + func: Box::new(dict_get_node.into()), + arguments: Arguments { + args: vec![dict_key_node, default_value_node], + keywords: vec![], + range: TextRange::default(), + }, + range: TextRange::default(), + }; + + let contents = checker.generator().expr(&fixed_node.into()); + + let mut diagnostic = Diagnostic::new( + IfElseBlockInsteadOfDictGet { + contents: contents.clone(), + }, + expr.range(), + ); + if !checker.indexer().has_comments(expr, checker.locator()) { + diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement( + contents, + expr.range(), + ))); + } + checker.diagnostics.push(diagnostic); +} diff --git a/crates/ruff_linter/src/rules/flake8_simplify/snapshots/ruff_linter__rules__flake8_simplify__tests__preview__SIM401_SIM401.py.snap b/crates/ruff_linter/src/rules/flake8_simplify/snapshots/ruff_linter__rules__flake8_simplify__tests__preview__SIM401_SIM401.py.snap new file mode 100644 index 0000000000000..6f807586d4c77 --- /dev/null +++ b/crates/ruff_linter/src/rules/flake8_simplify/snapshots/ruff_linter__rules__flake8_simplify__tests__preview__SIM401_SIM401.py.snap @@ -0,0 +1,202 @@ +--- +source: crates/ruff_linter/src/rules/flake8_simplify/mod.rs +--- +SIM401.py:6:1: SIM401 [*] Use `var = a_dict.get(key, "default1")` instead of an `if` block + | + 5 | # SIM401 (pattern-1) + 6 | / if key in a_dict: + 7 | | var = a_dict[key] + 8 | | else: + 9 | | var = "default1" + | |____________________^ SIM401 +10 | +11 | # SIM401 (pattern-2) + | + = help: Replace with `var = a_dict.get(key, "default1")` + +ℹ Suggested fix +3 3 | ### +4 4 | +5 5 | # SIM401 (pattern-1) +6 |-if key in a_dict: +7 |- var = a_dict[key] +8 |-else: +9 |- var = "default1" + 6 |+var = a_dict.get(key, "default1") +10 7 | +11 8 | # SIM401 (pattern-2) +12 9 | if key not in a_dict: + +SIM401.py:12:1: SIM401 [*] Use `var = a_dict.get(key, "default2")` instead of an `if` block + | +11 | # SIM401 (pattern-2) +12 | / if key not in a_dict: +13 | | var = "default2" +14 | | else: +15 | | var = a_dict[key] + | |_____________________^ SIM401 +16 | +17 | # OK (default contains effect) + | + = help: Replace with `var = a_dict.get(key, "default2")` + +ℹ Suggested fix +9 9 | var = "default1" +10 10 | +11 11 | # SIM401 (pattern-2) +12 |-if key not in a_dict: +13 |- var = "default2" +14 |-else: +15 |- var = a_dict[key] + 12 |+var = a_dict.get(key, "default2") +16 13 | +17 14 | # OK (default contains effect) +18 15 | if key in a_dict: + +SIM401.py:24:1: SIM401 [*] Use `var = a_dict.get(keys[idx], "default")` instead of an `if` block + | +23 | # SIM401 (complex expression in key) +24 | / if keys[idx] in a_dict: +25 | | var = a_dict[keys[idx]] +26 | | else: +27 | | var = "default" + | |___________________^ SIM401 +28 | +29 | # SIM401 (complex expression in dict) + | + = help: Replace with `var = a_dict.get(keys[idx], "default")` + +ℹ Suggested fix +21 21 | var = val1 + val2 +22 22 | +23 23 | # SIM401 (complex expression in key) +24 |-if keys[idx] in a_dict: +25 |- var = a_dict[keys[idx]] +26 |-else: +27 |- var = "default" + 24 |+var = a_dict.get(keys[idx], "default") +28 25 | +29 26 | # SIM401 (complex expression in dict) +30 27 | if key in dicts[idx]: + +SIM401.py:30:1: SIM401 [*] Use `var = dicts[idx].get(key, "default")` instead of an `if` block + | +29 | # SIM401 (complex expression in dict) +30 | / if key in dicts[idx]: +31 | | var = dicts[idx][key] +32 | | else: +33 | | var = "default" + | |___________________^ SIM401 +34 | +35 | # SIM401 (complex expression in var) + | + = help: Replace with `var = dicts[idx].get(key, "default")` + +ℹ Suggested fix +27 27 | var = "default" +28 28 | +29 29 | # SIM401 (complex expression in dict) +30 |-if key in dicts[idx]: +31 |- var = dicts[idx][key] +32 |-else: +33 |- var = "default" + 30 |+var = dicts[idx].get(key, "default") +34 31 | +35 32 | # SIM401 (complex expression in var) +36 33 | if key in a_dict: + +SIM401.py:36:1: SIM401 [*] Use `vars[idx] = a_dict.get(key, "defaultß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789")` instead of an `if` block + | +35 | # SIM401 (complex expression in var) +36 | / if key in a_dict: +37 | | vars[idx] = a_dict[key] +38 | | else: +39 | | vars[idx] = "defaultß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789" + | |___________________________________________________________________________^ SIM401 +40 | +41 | # SIM401 + | + = help: Replace with `vars[idx] = a_dict.get(key, "defaultß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789")` + +ℹ Suggested fix +33 33 | var = "default" +34 34 | +35 35 | # SIM401 (complex expression in var) +36 |-if key in a_dict: +37 |- vars[idx] = a_dict[key] +38 |-else: +39 |- vars[idx] = "defaultß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789" + 36 |+vars[idx] = a_dict.get(key, "defaultß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789") +40 37 | +41 38 | # SIM401 +42 39 | if foo(): + +SIM401.py:45:5: SIM401 [*] Use `vars[idx] = a_dict.get(key, "default")` instead of an `if` block + | +43 | pass +44 | else: +45 | if key in a_dict: + | _____^ +46 | | vars[idx] = a_dict[key] +47 | | else: +48 | | vars[idx] = "default" + | |_____________________________^ SIM401 +49 | +50 | ### + | + = help: Replace with `vars[idx] = a_dict.get(key, "default")` + +ℹ Suggested fix +42 42 | if foo(): +43 43 | pass +44 44 | else: +45 |- if key in a_dict: +46 |- vars[idx] = a_dict[key] +47 |- else: +48 |- vars[idx] = "default" + 45 |+ vars[idx] = a_dict.get(key, "default") +49 46 | +50 47 | ### +51 48 | # Negative cases + +SIM401.py:123:7: SIM401 [*] Use `a_dict.get(key, "default3")` instead of an `if` block + | +122 | # SIM401 +123 | var = a_dict[key] if key in a_dict else "default3" + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SIM401 +124 | +125 | # SIM401 + | + = help: Replace with `a_dict.get(key, "default3")` + +ℹ Suggested fix +120 120 | ### +121 121 | +122 122 | # SIM401 +123 |-var = a_dict[key] if key in a_dict else "default3" + 123 |+var = a_dict.get(key, "default3") +124 124 | +125 125 | # SIM401 +126 126 | var = "default-1" if key not in a_dict else a_dict[key] + +SIM401.py:126:7: SIM401 [*] Use `a_dict.get(key, "default-1")` instead of an `if` block + | +125 | # SIM401 +126 | var = "default-1" if key not in a_dict else a_dict[key] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SIM401 +127 | +128 | # OK (default contains effect) + | + = help: Replace with `a_dict.get(key, "default-1")` + +ℹ Suggested fix +123 123 | var = a_dict[key] if key in a_dict else "default3" +124 124 | +125 125 | # SIM401 +126 |-var = "default-1" if key not in a_dict else a_dict[key] + 126 |+var = a_dict.get(key, "default-1") +127 127 | +128 128 | # OK (default contains effect) +129 129 | var = a_dict[key] if key in a_dict else val1 + val2 + +