Skip to content

Commit

Permalink
ENH: expand versioned_branches feature to Python 3 minor version comp…
Browse files Browse the repository at this point in the history
…arison (<, >, <=, >= with else)
  • Loading branch information
neutrinoceros authored and asottile committed Sep 12, 2021
1 parent f86dc65 commit 31546d2
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 8 deletions.
41 changes: 39 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,21 +294,58 @@ def f():
yield (a, b)
```

### `if PY2` blocks
### Python2 and old Python3.x blocks

Availability:
- `--py3-plus` is passed on the commandline.

```python
# input
if six.PY2: # also understands `six.PY3` and `not` and `sys.version_info`
import sys
if sys.version_info < (3,): # also understands `six.PY2` (and `not`), `six.PY3` (and `not`)
print('py2')
else:
print('py3')
# output
import sys
print('py3')
```

Availability:
- `--py36-plus` will remove Python <= 3.5 only blocks
- `--py37-plus` will remove Python <= 3.6 only blocks
- so on and so forth

```python
# using --py36-plus for this example
# input
import sys
if sys.version_info < (3, 6):
print('py3.5')
else:
print('py3.6+')

if sys.version_info <= (3, 5):
print('py3.5')
else:
print('py3.6+')

if sys.version_info >= (3, 6):
print('py3.6+')
else:
print('py3.5')

# output
import sys
print('py3.6+')

print('py3.6+')

print('py3.6+')
```

Note that `if` blocks without an `else` will not be rewriten as it could introduce a syntax error.

### remove `six` compatibility code

Availability:
Expand Down
36 changes: 30 additions & 6 deletions pyupgrade/_plugins/versioned_branches.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._data import Version
from pyupgrade._token_helpers import Block


Expand Down Expand Up @@ -77,6 +78,7 @@ def _eq(test: ast.Compare, n: int) -> bool:
def _compare_to_3(
test: ast.Compare,
op: Union[Type[ast.cmpop], Tuple[Type[ast.cmpop], ...]],
minor: int = 0,
) -> bool:
if not (
isinstance(test.ops[0], op) and
Expand All @@ -87,9 +89,11 @@ def _compare_to_3(
return False

# checked above but mypy needs help
elts = cast('List[ast.Num]', test.comparators[0].elts)
ast_elts = cast('List[ast.Num]', test.comparators[0].elts)
# padding a 0 for compatibility with (3,) used as a spec
elts = tuple(e.n for e in ast_elts) + (0,)

return elts[0].n == 3 and all(n.n == 0 for n in elts[1:])
return elts[:2] == (3, minor) and all(n == 0 for n in elts[2:])


@register(ast.If)
Expand All @@ -98,8 +102,16 @@ def visit_If(
node: ast.If,
parent: ast.AST,
) -> Iterable[Tuple[Offset, TokenFunc]]:

min_version: Version
if state.settings.min_version == (3,):
min_version = (3, 0)
else:
min_version = state.settings.min_version
assert len(min_version) >= 2

if (
state.settings.min_version >= (3,) and (
min_version >= (3,) and (
# if six.PY2:
is_name_attr(node.test, state.from_imports, 'six', ('PY2',)) or
# if not six.PY3:
Expand All @@ -114,6 +126,7 @@ def visit_If(
)
) or
# sys.version_info == 2 or < (3,)
# or < (3, n) or <= (3, n) (with n<m)
(
isinstance(node.test, ast.Compare) and
is_name_attr(
Expand All @@ -124,15 +137,19 @@ def visit_If(
) and
len(node.test.ops) == 1 and (
_eq(node.test, 2) or
_compare_to_3(node.test, ast.Lt)
_compare_to_3(node.test, ast.Lt, min_version[1]) or
any(
_compare_to_3(node.test, (ast.Lt, ast.LtE), minor)
for minor in range(min_version[1])
)
)
)
)
):
if node.orelse and not isinstance(node.orelse[0], ast.If):
yield ast_to_offset(node), _fix_py2_block
elif (
state.settings.min_version >= (3,) and (
min_version >= (3,) and (
# if six.PY3:
is_name_attr(node.test, state.from_imports, 'six', ('PY3',)) or
# if not six.PY2:
Expand All @@ -147,6 +164,8 @@ def visit_If(
)
) or
# sys.version_info == 3 or >= (3,) or > (3,)
# sys.version_info >= (3, n) (with n<=m)
# or sys.version_info > (3, n) (with n<m)
(
isinstance(node.test, ast.Compare) and
is_name_attr(
Expand All @@ -157,7 +176,12 @@ def visit_If(
) and
len(node.test.ops) == 1 and (
_eq(node.test, 3) or
_compare_to_3(node.test, (ast.Gt, ast.GtE))
_compare_to_3(node.test, (ast.Gt, ast.GtE)) or
_compare_to_3(node.test, ast.GtE, min_version[1]) or
any(
_compare_to_3(node.test, (ast.Gt, ast.GtE), minor)
for minor in range(min_version[1])
)
)
)
)
Expand Down
152 changes: 152 additions & 0 deletions tests/features/versioned_branches_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,155 @@ def test_fix_py2_blocks(s, expected):
def test_fix_py3_only_code(s, expected):
ret = _fix_plugins(s, settings=Settings(min_version=(3,)))
assert ret == expected


@pytest.mark.parametrize(
('s', 'expected'),
(
pytest.param(
'import sys\n'
'if sys.version_info > (3, 5):\n'
' 3+6\n'
'else:\n'
' 3-5\n',
'import sys\n'
'3+6\n',
id='sys.version_info > (3, 5)',
),
pytest.param(
'from sys import version_info\n'
'if version_info > (3, 5):\n'
' 3+6\n'
'else:\n'
' 3-5\n',
'from sys import version_info\n'
'3+6\n',
id='from sys import version_info, > (3, 5)',
),
pytest.param(
'import sys\n'
'if sys.version_info >= (3, 6):\n'
' 3+6\n'
'else:\n'
' 3-5\n',
'import sys\n'
'3+6\n',
id='sys.version_info >= (3, 6)',
),
pytest.param(
'from sys import version_info\n'
'if version_info >= (3, 6):\n'
' 3+6\n'
'else:\n'
' 3-5\n',
'from sys import version_info\n'
'3+6\n',
id='from sys import version_info, >= (3, 6)',
),
pytest.param(
'import sys\n'
'if sys.version_info < (3, 6):\n'
' 3-5\n'
'else:\n'
' 3+6\n',
'import sys\n'
'3+6\n',
id='sys.version_info < (3, 6)',
),
pytest.param(
'from sys import version_info\n'
'if version_info < (3, 6):\n'
' 3-5\n'
'else:\n'
' 3+6\n',
'from sys import version_info\n'
'3+6\n',
id='from sys import version_info, < (3, 6)',
),
pytest.param(
'import sys\n'
'if sys.version_info <= (3, 5):\n'
' 3-5\n'
'else:\n'
' 3+6\n',
'import sys\n'
'3+6\n',
id='sys.version_info <= (3, 5)',
),
pytest.param(
'from sys import version_info\n'
'if version_info <= (3, 5):\n'
' 3-5\n'
'else:\n'
' 3+6\n',
'from sys import version_info\n'
'3+6\n',
id='from sys import version_info, <= (3, 5)',
),
),
)
def test_fix_py3x_only_code(s, expected):
ret = _fix_plugins(s, settings=Settings(min_version=(3, 6)))
assert ret == expected


@pytest.mark.parametrize(
's',
(
# we timidly skip `if` without `else` as it could cause a SyntaxError
'import sys'

This comment has been minimized.

Copy link
@graingert

graingert Sep 13, 2021

Contributor

Missing a \n?

This comment has been minimized.

Copy link
@asottile

This comment has been minimized.

Copy link
@neutrinoceros

neutrinoceros Sep 13, 2021

Author Contributor

my bad, here's a fix: #533

'if sys.version_info >= (3, 6):\n'
' pass',
# here's the case where it causes a SyntaxError
'import sys'
'if True'
' if sys.version_info >= (3, 6):\n'
' pass\n',
# both branches are still relevant in the following cases
'import sys\n'
'if sys.version_info > (3, 7):\n'
' 3-6\n'
'else:\n'
' 3+7\n',
'import sys\n'
'if sys.version_info < (3, 7):\n'
' 3-6\n'
'else:\n'
' 3+7\n',
'import sys\n'
'if sys.version_info >= (3, 7):\n'
' 3+7\n'
'else:\n'
' 3-6\n',
'import sys\n'
'if sys.version_info <= (3, 7):\n'
' 3-7\n'
'else:\n'
' 3+8\n',
'import sys\n'
'if sys.version_info <= (3, 6):\n'
' 3-6\n'
'else:\n'
' 3+7\n',
'import sys\n'
'if sys.version_info > (3, 6):\n'
' 3+7\n'
'else:\n'
' 3-6\n',
),
)
def test_fix_py3x_only_noop(s):
assert _fix_plugins(s, settings=Settings(min_version=(3, 6))) == s

0 comments on commit 31546d2

Please sign in to comment.