Skip to content

Commit

Permalink
SIM401: Implement check if-block -> dict.get(key, default) (#73)
Browse files Browse the repository at this point in the history
Closes #72
  • Loading branch information
MartinThoma authored Nov 27, 2021
1 parent ced1950 commit 31ea767
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 6 deletions.
25 changes: 21 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ Python-specific rules:
* [`SIM115`](https://github.com/MartinThoma/flake8-simplify/issues/17): Use context handler for opening files ([example](#SIM115))
* [`SIM116`](https://github.com/MartinThoma/flake8-simplify/issues/31): Use a dictionary instead of many if/else equality checks ([example](#SIM116))
* [`SIM117`](https://github.com/MartinThoma/flake8-simplify/issues/35): Merge with-statements that use the same scope ([example](#SIM117))
* [`SIM118`](https://github.com/MartinThoma/flake8-simplify/issues/40): Use 'key in dict' instead of 'key in dict.keys()' ([example](#SIM118))
* [`SIM119`](https://github.com/MartinThoma/flake8-simplify/issues/37) ![](https://shields.io/badge/-legacyfix-inactive): Use dataclasses for data containers ([example](#SIM119))
* `SIM120` ![](https://shields.io/badge/-legacyfix-inactive): Use 'class FooBar:' instead of 'class FooBar(object):' ([example](#SIM120))

Comparations:
Simplifying Comparations:

* `SIM201`: Use 'a != b' instead of 'not a == b' ([example](#SIM201))
* `SIM202`: Use 'a == b' instead of 'not a != b' ([example](#SIM202))
Expand All @@ -69,6 +68,11 @@ Comparations:
* [`SIM223`](https://github.com/MartinThoma/flake8-simplify/issues/6): Use 'False' instead of '... and False' ([example](#SIM223))
* [`SIM300`](https://github.com/MartinThoma/flake8-simplify/issues/16): Use 'age == 42' instead of '42 == age' ([example](#SIM300))

Simplifying usage of dictionaries:

* [`SIM401`](https://github.com/MartinThoma/flake8-simplify/issues/72): Use 'a_dict.get(key, "default_value")' instead of an if-block ([example](#SIM401))
* [`SIM118`](https://github.com/MartinThoma/flake8-simplify/issues/40): Use 'key in dict' instead of 'key in dict.keys()' ([example](#SIM118))

General Code Style:

* `SIM102`: Use a single if-statement instead of nested if-statements ([example](#SIM102))
Expand Down Expand Up @@ -279,10 +283,10 @@ Thank you for pointing this one out, [Aaron Gokaslan](https://github.com/Skylion

```python
# Bad
key in dict.keys()
key in a_dict.keys()

# Good
key in dict
key in a_dict
```

Thank you for pointing this one out, [Aaron Gokaslan](https://github.com/Skylion007)!
Expand Down Expand Up @@ -477,3 +481,16 @@ False
# Good
age == 42
```

### SIM401

```python
# Bad
if key in a_dict:
value = a_dict[key]
else:
value = "default_value"

# Good
thing = a_dict.get(key, "default_value")
```
157 changes: 155 additions & 2 deletions flake8_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def __init__(self, orig: ast.Call) -> None:
"SIM300 Use '{right} == {left}' instead of "
"'{left} == {right}' (Yoda-conditions)"
)
SIM401 = (
"SIM401 Use '{value} = {dict}.get({key}, \"{default_value}\")' "
"instead of an if-block"
)

# ast.Constant in Python 3.8, ast.NameConstant in Python 3.6 and 3.7
BOOL_CONST_TYPES = (ast.Constant, ast.NameConstant)
Expand All @@ -129,7 +133,9 @@ def strip_triple_quotes(string: str) -> str:
return string


def to_source(node: Union[None, ast.expr, ast.Expr, ast.withitem]) -> str:
def to_source(
node: Union[None, ast.expr, ast.Expr, ast.withitem, ast.slice]
) -> str:
if node is None:
return "None"
source: str = astor.to_source(node).strip()
Expand Down Expand Up @@ -1597,6 +1603,152 @@ def _get_sim300(node: ast.Compare) -> List[Tuple[int, int, str]]:
return errors


def _get_sim401(node: ast.If) -> List[Tuple[int, int, str]]:
"""
Get all calls that should use default values for dictionary access.
Pattern 1
---------
if key in a_dict:
value = a_dict[key]
else:
value = "default"
which is
If(
test=Compare(
left=Name(id='key', ctx=Load()),
ops=[In()],
comparators=[Name(id='a_dict', ctx=Load())],
),
body=[
Assign(
targets=[Name(id='value', ctx=Store())],
value=Subscript(
value=Name(id='a_dict', ctx=Load()),
slice=Name(id='key', ctx=Load()),
ctx=Load(),
),
type_comment=None,
),
],
orelse=[
Assign(
targets=[Name(id='value', ctx=Store())],
value=Constant(value='default', kind=None),
type_comment=None,
),
],
),
Pattern 2
---------
if key not in a_dict:
value = 'default'
else:
value = a_dict[key]
which is
If(
test=Compare(
left=Name(id='key', ctx=Load()),
ops=[NotIn()],
comparators=[Name(id='a_dict', ctx=Load())],
),
body=[
Assign(
targets=[Name(id='value', ctx=Store())],
value=Constant(value='default', kind=None),
type_comment=None,
),
],
orelse=[
Assign(
targets=[Name(id='value', ctx=Store())],
value=Subscript(
value=Name(id='a_dict', ctx=Load()),
slice=Name(id='key', ctx=Load()),
ctx=Load(),
),
type_comment=None,
),
],
)
"""
errors: List[Tuple[int, int, str]] = []
is_pattern_1 = (
len(node.body) == 1
and isinstance(node.body[0], ast.Assign)
and isinstance(node.body[0].value, ast.Subscript)
and len(node.orelse) == 1
and isinstance(node.orelse[0], ast.Assign)
and isinstance(node.test, ast.Compare)
and len(node.test.ops) == 1
and isinstance(node.test.ops[0], ast.In)
)

# just like pattern_1, but using NotIn and reversing if/else
is_pattern_2 = (
len(node.body) == 1
and isinstance(node.body[0], ast.Assign)
and len(node.orelse) == 1
and isinstance(node.orelse[0], ast.Assign)
and isinstance(node.orelse[0].value, ast.Subscript)
and isinstance(node.test, ast.Compare)
and len(node.test.ops) == 1
and isinstance(node.test.ops[0], ast.NotIn)
)
if is_pattern_1:
assert isinstance(node.test, ast.Compare)
assert isinstance(node.body[0], ast.Assign)
assert isinstance(node.body[0].value, ast.Subscript)
assert isinstance(node.orelse[0], ast.Assign)
key = node.test.left
if to_source(key) != to_source(node.body[0].value.slice):
return errors # second part of pattern 1
dict_name = node.test.comparators[0]
default_value = node.orelse[0].value
value_node = node.body[0].targets[0]
key_str = to_source(key)
dict_str = to_source(dict_name)
default_str = to_source(default_value)
value_str = to_source(value_node)
elif is_pattern_2:
assert isinstance(node.test, ast.Compare)
assert isinstance(node.body[0], ast.Assign)
assert isinstance(node.orelse[0], ast.Assign)
assert isinstance(node.orelse[0].value, ast.Subscript)
key = node.test.left
if to_source(key) != to_source(node.orelse[0].value.slice):
return errors # second part of pattern 1
dict_name = node.test.comparators[0]
default_value = node.body[0].value
value_node = node.body[0].targets[0]
key_str = to_source(key)
dict_str = to_source(dict_name)
default_str = to_source(default_value)
value_str = to_source(value_node)
else:
return errors
errors.append(
(
node.lineno,
node.col_offset,
SIM401.format(
key=key_str,
dict=dict_str,
default_value=default_str,
value=value_str,
),
)
)
return errors


class Visitor(ast.NodeVisitor):
def __init__(self) -> None:
self.errors: List[Tuple[int, int, str]] = []
Expand Down Expand Up @@ -1629,6 +1781,7 @@ def visit_If(self, node: ast.If) -> None:
self.errors += _get_sim108(node)
self.errors += _get_sim114(node)
self.errors += _get_sim116(node)
self.errors += _get_sim401(node)
self.generic_visit(node)

def visit_For(self, node: ast.For) -> None:
Expand Down Expand Up @@ -1673,7 +1826,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:

class Plugin:
name = __name__
version = importlib_metadata.version(__name__)
version = importlib_metadata.version(__name__) # type: ignore

def __init__(self, tree: ast.AST):
self._tree = tree
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ classifiers =
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Topic :: Software Development
Framework :: Flake8

Expand Down Expand Up @@ -59,6 +60,7 @@ tests_dir = tests/
paths_to_mutate=flake8_simplify.py

[mypy]
exclude = build/lib/flake8_simplify.py
ignore_missing_imports = true
strict = true
check_untyped_defs = true
Expand All @@ -67,6 +69,7 @@ disallow_incomplete_defs = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_unused_ignores = false
show_error_codes = true

[mypy-testing.*]
disallow_untyped_defs = false
Expand Down
42 changes: 42 additions & 0 deletions tests/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,45 @@ def test_sim300_int():
assert ret == {
"1:0 SIM300 Use 'age == 42' instead of '42 == age' (Yoda-conditions)"
}


def test_sim401_if_else():
ret = _results(
"""if key in a_dict:
value = a_dict[key]
else:
value = 'default'"""
)
assert ret == {
"""1:0 SIM401 Use 'value = a_dict.get(key, "default")' """
"""instead of an if-block"""
}


def test_sim401_negated_if_else():
ret = _results(
"""if key not in a_dict:
value = 'default'
else:
value = a_dict[key] """
)
assert (
"""1:0 SIM401 Use 'value = a_dict.get(key, "default")' """
"""instead of an if-block""" in ret
)


def test_sim401_prefix_negated_if_else():
ret = _results(
"""if not key in a_dict:
value = 'default'
else:
value = a_dict[key] """
)
assert (
"""1:3 SIM401 Use 'value = a_dict.get(key, "default")' """
"""instead of an if-block""" in ret
) or (
"1:3 SIM203 Use 'key not in a_dict' instead of 'not key in a_dict'"
in ret
)

0 comments on commit 31ea767

Please sign in to comment.