Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to disallowed_imports #646

Merged
merged 1 commit into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## Unreleased

- Add `disallowed_imports` configuration option to disallow
imports of specific modules (#645)
imports of specific modules (#645, #646)
- Consider an annotated assignment without a value to be
an exported name (#644)
- Improve the location where `missing_parameter_annotation`
Expand Down
23 changes: 14 additions & 9 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2320,15 +2320,20 @@ def check_deprecation(self, node: ast.AST, value: Value) -> bool:

# Imports

def check_for_disallowed_import(self, node: ast.AST, name: str) -> None:
print("CHECK", name)
def check_for_disallowed_import(
self, node: ast.AST, name: str, *, check_parents: bool = True
) -> None:
disallowed = self.options.get_value_for(DisallowedImports)
if name in disallowed:
self._show_error_if_checking(
node,
f"Disallowed import of module {name!r}",
error_code=ErrorCode.disallowed_import,
)
parts = name.split(".") if check_parents else [name]
for i in range(len(parts)):
name_to_check = ".".join(parts[: i + 1])
if name_to_check in disallowed:
self._show_error_if_checking(
node,
f"Disallowed import of module {name!r}",
error_code=ErrorCode.disallowed_import,
)
break

def visit_Import(self, node: ast.Import) -> None:
self.generic_visit(node)
Expand Down Expand Up @@ -2434,7 +2439,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
else:
error_node = node
self.check_for_disallowed_import(
error_node, f"{node.module}.{alias.name}"
error_node, f"{node.module}.{alias.name}", check_parents=False
)

self._maybe_record_usages_from_import(node)
Expand Down
1 change: 1 addition & 0 deletions pyanalyze/test.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ class_attribute_transformers = [
disallowed_imports = [
"getopt",
"email.quoprimime",
"xml",
]
21 changes: 14 additions & 7 deletions pyanalyze/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,38 +47,45 @@ class TestDisallowedImport(TestNameCheckVisitorBase):
@assert_passes()
def test_top_level(self):
import getopt # E: disallowed_import
import xml.etree.ElementTree # E: disallowed_import

from getopt import GetoptError # E: disallowed_import

print(getopt, GetoptError) # shut up flake8
print(getopt, GetoptError, xml) # shut up flake8

def capybara():
import getopt # E: disallowed_import
from getopt import GetoptError # E: disallowed_import
import xml.etree.ElementTree # E: disallowed_import

print(getopt, GetoptError)
print(getopt, GetoptError, xml)

@assert_passes()
def test_nested(self):
import email.quoprimime # E: disallowed_import
import email.base64mime # ok
from email.quoprimime import unquote # E: disallowed_import
from xml.etree import ElementTree # E: disallowed_import

print(email, unquote)
print(email, unquote, ElementTree)

def capybara():
import email.quoprimime # E: disallowed_import
import email.base64mime # ok
from email.quoprimime import unquote # E: disallowed_import
from email import quoprimime # E: disallowed_import
from xml.etree import ElementTree # E: disallowed_import

print(email, unquote, quoprimime)
print(email, unquote, ElementTree)

@assert_passes()
def test_import_from(self):
from email import quoprimime # E: disallowed_import
from email import base64mime # ok

print(quoprimime)
print(quoprimime, base64mime)

def capybara():
from email import quoprimime # E: disallowed_import
from email import base64mime # ok

print(quoprimime)
print(quoprimime, base64mime)