Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix attrs on MVV within Annotated #393

Merged
merged 2 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

- Fix accessing attributes on Unions nested within
Annotated (#393)
- Fix interaction of `register_error_code()` with new
configuration mechanism (#391)
- Check against invalid `Signature` objects and prepare
Expand Down
77 changes: 26 additions & 51 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@
UNINITIALIZED_VALUE,
NO_RETURN_VALUE,
NoReturnConstraintExtension,
flatten_values,
is_union,
kv_pairs_from_mapping,
make_weak,
unannotate_value,
Expand Down Expand Up @@ -3959,7 +3961,9 @@ def composite_from_attribute(self, node: ast.Attribute) -> Composite:
self.reexport_tracker.record_attribute_accessed(
root_composite.value.val.__name__, node.attr, node, self
)
value = self._get_attribute_with_fallback(root_composite, node.attr, node)
value = self.get_attribute(
root_composite, node.attr, node, use_fallback=True
)
if self._should_use_varname_value(value):
varname_value = self.checker.maybe_get_variable_name_value(node.attr)
if varname_value is not None:
Expand All @@ -3981,7 +3985,9 @@ def get_attribute(
root_composite: Composite,
attr: str,
node: Optional[ast.AST] = None,
*,
ignore_none: bool = False,
use_fallback: bool = False,
) -> Value:
"""Get an attribute of this value.

Expand All @@ -3994,69 +4000,38 @@ def get_attribute(
varname=root_composite.varname,
node=root_composite.node,
)
if isinstance(root_composite.value, MultiValuedValue):
values = [
self.get_attribute(
Composite(subval, root_composite.varname, root_composite.node),
attr,
node,
ignore_none=ignore_none,
)
for subval in root_composite.value.vals
]
if any(value is UNINITIALIZED_VALUE for value in values):
return UNINITIALIZED_VALUE
return unite_values(*values)
return self._get_attribute_no_mvv(
root_composite, attr, node, ignore_none=ignore_none
)

def get_attribute_from_value(self, root_value: Value, attribute: str) -> Value:
return self.get_attribute(Composite(root_value), attribute)

def _get_attribute_no_mvv(
self,
root_composite: Composite,
attr: str,
node: Optional[ast.AST] = None,
ignore_none: bool = False,
) -> Value:
"""Get an attribute. root_value must not be a MultiValuedValue."""
ctx = _AttrContext(
root_composite, attr, self, node=node, ignore_none=ignore_none
)
return attributes.get_attribute(ctx)

def _get_attribute_with_fallback(
self, root_composite: Composite, attr: str, node: ast.AST
) -> Value:
ignore_none = self.options.get_value_for(IgnoreNoneAttributes)
if isinstance(root_composite.value, TypeVarValue):
root_composite = Composite(
value=root_composite.value.get_fallback_value(),
varname=root_composite.varname,
node=root_composite.node,
)
if isinstance(root_composite.value, MultiValuedValue):
if is_union(root_composite.value):
results = []
for subval in root_composite.value.vals:
for subval in flatten_values(root_composite.value):
composite = Composite(
subval, root_composite.varname, root_composite.node
)
subresult = self.get_attribute(
composite, attr, node, ignore_none=ignore_none
composite,
attr,
node,
ignore_none=ignore_none,
use_fallback=use_fallback,
)
if subresult is UNINITIALIZED_VALUE:
if (
subresult is UNINITIALIZED_VALUE
and use_fallback
and node is not None
):
subresult = self._get_attribute_fallback(subval, attr, node)
results.append(subresult)
return unite_values(*results)
result = self._get_attribute_no_mvv(
root_composite, attr, node, ignore_none=ignore_none
ctx = _AttrContext(
root_composite, attr, self, node=node, ignore_none=ignore_none
)
if result is UNINITIALIZED_VALUE:
result = attributes.get_attribute(ctx)
if result is UNINITIALIZED_VALUE and use_fallback and node is not None:
return self._get_attribute_fallback(root_composite.value, attr, node)
return result

def get_attribute_from_value(self, root_value: Value, attribute: str) -> Value:
return self.get_attribute(Composite(root_value), attribute)

def _get_attribute_fallback(
self, root_value: Value, attr: str, node: ast.AST
) -> Value:
Expand Down
7 changes: 7 additions & 0 deletions pyanalyze/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,13 @@ def capybara():
annotated_global, MultiValuedValue([TypedValue(str), KnownValue(None)])
)

@assert_passes()
def test_unwrap_mvv(self):
def render_task(name: str):
if not (name or "").strip():
name = "x"
assert_is_value(name, TypedValue(str) | KnownValue("x"))


class TestHasAttrExtension(TestNameCheckVisitorBase):
@assert_passes()
Expand Down
6 changes: 6 additions & 0 deletions pyanalyze/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,6 +1914,12 @@ def from_varname(
return None


def is_union(val: Value) -> bool:
return isinstance(val, MultiValuedValue) or (
isinstance(val, AnnotatedValue) and isinstance(val.value, MultiValuedValue)
)


def flatten_values(val: Value, *, unwrap_annotated: bool = False) -> Iterable[Value]:
"""Flatten a :class:`MultiValuedValue` into its constituent values.

Expand Down