Skip to content

Commit

Permalink
Make SIM401 catch ternary operations (#7415)
Browse files Browse the repository at this point in the history
## Summary

Make SIM401 rules to catch ternary operations when preview is enabled.

Fixes #7288.

## Test Plan

Tested against `SIM401.py` fixtures.
  • Loading branch information
Flowrey authored Oct 20, 2023
1 parent 860ffb9 commit fa556d1
Show file tree
Hide file tree
Showing 5 changed files with 340 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,16 @@
vars[idx] = a_dict[key]
else:
vars[idx] = "default"

###
# Positive cases (preview)
###

# SIM401
var = a_dict[key] if key in a_dict else "default3"

# SIM401
var = "default-1" if key not in a_dict else a_dict[key]

# OK (default contains effect)
var = a_dict[key] if key in a_dict else val1 + val2
5 changes: 5 additions & 0 deletions crates/ruff_linter/src/checkers/ast/analyze/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,11 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) {
orelse,
range: _,
}) => {
if checker.enabled(Rule::IfElseBlockInsteadOfDictGet) {
flake8_simplify::rules::if_exp_instead_of_dict_get(
checker, expr, test, body, orelse,
);
}
if checker.enabled(Rule::IfExprWithTrueFalse) {
flake8_simplify::rules::if_expr_with_true_false(checker, expr, test, body, orelse);
}
Expand Down
19 changes: 19 additions & 0 deletions crates/ruff_linter/src/rules/flake8_simplify/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod tests {
use test_case::test_case;

use crate::registry::Rule;
use crate::settings::types::PreviewMode;
use crate::test::test_path;
use crate::{assert_messages, settings};

Expand Down Expand Up @@ -53,4 +54,22 @@ mod tests {
assert_messages!(snapshot, diagnostics);
Ok(())
}

#[test_case(Rule::IfElseBlockInsteadOfDictGet, Path::new("SIM401.py"))]
fn preview_rules(rule_code: Rule, path: &Path) -> Result<()> {
let snapshot = format!(
"preview__{}_{}",
rule_code.noqa_code(),
path.to_string_lossy()
);
let diagnostics = test_path(
Path::new("flake8_simplify").join(path).as_path(),
&settings::LinterSettings {
preview: PreviewMode::Enabled,
..settings::LinterSettings::for_rule(rule_code)
},
)?;
assert_messages!(snapshot, diagnostics);
Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ use crate::fix::edits::fits;
/// the key is not found. When possible, using `dict.get` is more concise and
/// more idiomatic.
///
/// Under [preview mode](https://docs.astral.sh/ruff/preview), this rule will
/// also suggest replacing `if`-`else` _expressions_ with `dict.get` calls.
///
/// ## Example
/// ```python
/// if "bar" in foo:
Expand All @@ -33,6 +36,16 @@ use crate::fix::edits::fits;
/// value = foo.get("bar", 0)
/// ```
///
/// If preview mode is enabled:
/// ```python
/// value = foo["bar"] if "bar" in foo else 0
/// ```
///
/// Use instead:
/// ```python
/// value = foo.get("bar", 0)
/// ```
///
/// ## References
/// - [Python documentation: Mapping Types](https://docs.python.org/3/library/stdtypes.html#mapping-types-dict)
#[violation]
Expand Down Expand Up @@ -202,3 +215,91 @@ pub(crate) fn if_else_block_instead_of_dict_get(checker: &mut Checker, stmt_if:
}
checker.diagnostics.push(diagnostic);
}

/// SIM401
pub(crate) fn if_exp_instead_of_dict_get(
checker: &mut Checker,
expr: &Expr,
test: &Expr,
body: &Expr,
orelse: &Expr,
) {
if checker.settings.preview.is_disabled() {
return;
}

let Expr::Compare(ast::ExprCompare {
left: test_key,
ops,
comparators: test_dict,
range: _,
}) = test
else {
return;
};
let [test_dict] = test_dict.as_slice() else {
return;
};

let (body, default_value) = match ops.as_slice() {
[CmpOp::In] => (body, orelse),
[CmpOp::NotIn] => (orelse, body),
_ => {
return;
}
};

let Expr::Subscript(ast::ExprSubscript {
value: expected_subscript,
slice: expected_slice,
..
}) = body
else {
return;
};

if ComparableExpr::from(expected_slice) != ComparableExpr::from(test_key)
|| ComparableExpr::from(test_dict) != ComparableExpr::from(expected_subscript)
{
return;
}

// Check that the default value is not "complex".
if contains_effect(default_value, |id| checker.semantic().is_builtin(id)) {
return;
}

let default_value_node = default_value.clone();
let dict_key_node = *test_key.clone();
let dict_get_node = ast::ExprAttribute {
value: expected_subscript.clone(),
attr: Identifier::new("get".to_string(), TextRange::default()),
ctx: ExprContext::Load,
range: TextRange::default(),
};
let fixed_node = ast::ExprCall {
func: Box::new(dict_get_node.into()),
arguments: Arguments {
args: vec![dict_key_node, default_value_node],
keywords: vec![],
range: TextRange::default(),
},
range: TextRange::default(),
};

let contents = checker.generator().expr(&fixed_node.into());

let mut diagnostic = Diagnostic::new(
IfElseBlockInsteadOfDictGet {
contents: contents.clone(),
},
expr.range(),
);
if !checker.indexer().has_comments(expr, checker.locator()) {
diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement(
contents,
expr.range(),
)));
}
checker.diagnostics.push(diagnostic);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
---
source: crates/ruff_linter/src/rules/flake8_simplify/mod.rs
---
SIM401.py:6:1: SIM401 [*] Use `var = a_dict.get(key, "default1")` instead of an `if` block
|
5 | # SIM401 (pattern-1)
6 | / if key in a_dict:
7 | | var = a_dict[key]
8 | | else:
9 | | var = "default1"
| |____________________^ SIM401
10 |
11 | # SIM401 (pattern-2)
|
= help: Replace with `var = a_dict.get(key, "default1")`

Suggested fix
3 3 | ###
4 4 |
5 5 | # SIM401 (pattern-1)
6 |-if key in a_dict:
7 |- var = a_dict[key]
8 |-else:
9 |- var = "default1"
6 |+var = a_dict.get(key, "default1")
10 7 |
11 8 | # SIM401 (pattern-2)
12 9 | if key not in a_dict:

SIM401.py:12:1: SIM401 [*] Use `var = a_dict.get(key, "default2")` instead of an `if` block
|
11 | # SIM401 (pattern-2)
12 | / if key not in a_dict:
13 | | var = "default2"
14 | | else:
15 | | var = a_dict[key]
| |_____________________^ SIM401
16 |
17 | # OK (default contains effect)
|
= help: Replace with `var = a_dict.get(key, "default2")`

Suggested fix
9 9 | var = "default1"
10 10 |
11 11 | # SIM401 (pattern-2)
12 |-if key not in a_dict:
13 |- var = "default2"
14 |-else:
15 |- var = a_dict[key]
12 |+var = a_dict.get(key, "default2")
16 13 |
17 14 | # OK (default contains effect)
18 15 | if key in a_dict:

SIM401.py:24:1: SIM401 [*] Use `var = a_dict.get(keys[idx], "default")` instead of an `if` block
|
23 | # SIM401 (complex expression in key)
24 | / if keys[idx] in a_dict:
25 | | var = a_dict[keys[idx]]
26 | | else:
27 | | var = "default"
| |___________________^ SIM401
28 |
29 | # SIM401 (complex expression in dict)
|
= help: Replace with `var = a_dict.get(keys[idx], "default")`

Suggested fix
21 21 | var = val1 + val2
22 22 |
23 23 | # SIM401 (complex expression in key)
24 |-if keys[idx] in a_dict:
25 |- var = a_dict[keys[idx]]
26 |-else:
27 |- var = "default"
24 |+var = a_dict.get(keys[idx], "default")
28 25 |
29 26 | # SIM401 (complex expression in dict)
30 27 | if key in dicts[idx]:

SIM401.py:30:1: SIM401 [*] Use `var = dicts[idx].get(key, "default")` instead of an `if` block
|
29 | # SIM401 (complex expression in dict)
30 | / if key in dicts[idx]:
31 | | var = dicts[idx][key]
32 | | else:
33 | | var = "default"
| |___________________^ SIM401
34 |
35 | # SIM401 (complex expression in var)
|
= help: Replace with `var = dicts[idx].get(key, "default")`

Suggested fix
27 27 | var = "default"
28 28 |
29 29 | # SIM401 (complex expression in dict)
30 |-if key in dicts[idx]:
31 |- var = dicts[idx][key]
32 |-else:
33 |- var = "default"
30 |+var = dicts[idx].get(key, "default")
34 31 |
35 32 | # SIM401 (complex expression in var)
36 33 | if key in a_dict:

SIM401.py:36:1: SIM401 [*] Use `vars[idx] = a_dict.get(key, "defaultß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789")` instead of an `if` block
|
35 | # SIM401 (complex expression in var)
36 | / if key in a_dict:
37 | | vars[idx] = a_dict[key]
38 | | else:
39 | | vars[idx] = "defaultß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789"
| |___________________________________________________________________________^ SIM401
40 |
41 | # SIM401
|
= help: Replace with `vars[idx] = a_dict.get(key, "defaultß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789")`

Suggested fix
33 33 | var = "default"
34 34 |
35 35 | # SIM401 (complex expression in var)
36 |-if key in a_dict:
37 |- vars[idx] = a_dict[key]
38 |-else:
39 |- vars[idx] = "defaultß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789"
36 |+vars[idx] = a_dict.get(key, "defaultß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789ß9💣2ℝ6789")
40 37 |
41 38 | # SIM401
42 39 | if foo():

SIM401.py:45:5: SIM401 [*] Use `vars[idx] = a_dict.get(key, "default")` instead of an `if` block
|
43 | pass
44 | else:
45 | if key in a_dict:
| _____^
46 | | vars[idx] = a_dict[key]
47 | | else:
48 | | vars[idx] = "default"
| |_____________________________^ SIM401
49 |
50 | ###
|
= help: Replace with `vars[idx] = a_dict.get(key, "default")`

Suggested fix
42 42 | if foo():
43 43 | pass
44 44 | else:
45 |- if key in a_dict:
46 |- vars[idx] = a_dict[key]
47 |- else:
48 |- vars[idx] = "default"
45 |+ vars[idx] = a_dict.get(key, "default")
49 46 |
50 47 | ###
51 48 | # Negative cases

SIM401.py:123:7: SIM401 [*] Use `a_dict.get(key, "default3")` instead of an `if` block
|
122 | # SIM401
123 | var = a_dict[key] if key in a_dict else "default3"
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SIM401
124 |
125 | # SIM401
|
= help: Replace with `a_dict.get(key, "default3")`

Suggested fix
120 120 | ###
121 121 |
122 122 | # SIM401
123 |-var = a_dict[key] if key in a_dict else "default3"
123 |+var = a_dict.get(key, "default3")
124 124 |
125 125 | # SIM401
126 126 | var = "default-1" if key not in a_dict else a_dict[key]

SIM401.py:126:7: SIM401 [*] Use `a_dict.get(key, "default-1")` instead of an `if` block
|
125 | # SIM401
126 | var = "default-1" if key not in a_dict else a_dict[key]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SIM401
127 |
128 | # OK (default contains effect)
|
= help: Replace with `a_dict.get(key, "default-1")`

Suggested fix
123 123 | var = a_dict[key] if key in a_dict else "default3"
124 124 |
125 125 | # SIM401
126 |-var = "default-1" if key not in a_dict else a_dict[key]
126 |+var = a_dict.get(key, "default-1")
127 127 |
128 128 | # OK (default contains effect)
129 129 | var = a_dict[key] if key in a_dict else val1 + val2


0 comments on commit fa556d1

Please sign in to comment.