Skip to content

Commit

Permalink
Support passing TypedDicts to WithAnnotations
Browse files Browse the repository at this point in the history
  • Loading branch information
syastrov committed Jul 11, 2021
1 parent 7108a0c commit 51f0448
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 65 deletions.
31 changes: 25 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,25 +181,44 @@ def use_my_model():

### How do I annotate cases where I called QuerySet.annotate?

Django-stubs provides a special type, `django_stubs_ext.WithAnnotations`, which marks a `Model` as having been
annotated, meaning it allows getting/setting any attribute on the model instance.
Django-stubs provides a special type, `django_stubs_ext.WithAnnotations[Model]`, which indicates that the `Model` has
been annotated, meaning it allows getting/setting extra attributes on the model instance.

Currently, the mypy plugin is not smart enough to recognize that specific names were passed to `QuerySet.annotate` and
include them in the type.
Optionally, you can provide a `TypedDict` of these attributes,
e.g. `WithAnnotations[MyModel, MyTypedDict]`, to specify which annotated attributes are present.

Currently, the mypy plugin can recognize that specific names were passed to `QuerySet.annotate` and
include them in the type, but does not record the types of these attributes.

The knowledge of the specific annotated fields is not yet used in creating more specific types for `QuerySet`'s
`values`, `values_list`, or `filter` methods, however knowledge that the model was annotated _is_ used to create a
broader type result type for `values`/`values_list`, and to allow `filter`ing on any field.

```python
from typing import TypedDict
from django_stubs_ext import WithAnnotations
from django.db import models
from django.db.models.expressions import Value

class MyModel(models.Model):
username = models.CharField(max_length=100)


def func(m: WithAnnotations[MyModel]) -> str:
return m.asdf # OK, since the model is annotated
return m.asdf # OK, since the model is annotated as allowing any attribute

func(MyModel.objects.annotate(foo="").get(id=1)) # OK
func(MyModel.objects.annotate(foo=Value("")).get(id=1)) # OK
func(MyModel.objects.get(id=1)) # Error, since this model will not allow access to any attribute


class MyTypedDict(TypedDict):
foo: str

def func2(m: WithAnnotations[MyModel, MyTypedDict]) -> str:
return m.foo # OK, since we said field "foo" was allowed

func(MyModel.objects.annotate(foo=Value("")).get(id=1)) # OK
func(MyModel.objects.annotate(bar=Value("")).get(id=1)) # Error
```

## Related projects
Expand Down
4 changes: 2 additions & 2 deletions django-stubs/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Protocol

from .utils.version import get_version as get_version

Expand All @@ -8,6 +8,6 @@ __version__: str
def setup(set_prefix: bool = ...) -> None: ...

# Used internally by mypy_django_plugin.
class _AnyAttrAllowed:
class _AnyAttrAllowed(Protocol):
def __getattr__(self, item: str) -> Any: ...
def __setattr__(self, item: str, value: Any) -> None: ...
18 changes: 13 additions & 5 deletions django_stubs_ext/django_stubs_ext/annotations.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from typing import Any, Type, TypeVar
from typing import Any, Generic, Mapping, TypeVar

from django.db.models.base import Model
from typing_extensions import Annotated

# Really, we would like to use TypedDict as a bound, but it's not possible
_Annotations = TypeVar("_Annotations", covariant=True, bound=Mapping[str, Any])

class Annotations:
def __init__(self, **kwargs: Type[Any]):
pass

class Annotations(Generic[_Annotations]):
"""Use as `Annotations[MyTypedDict]`"""

pass


_T = TypeVar("_T", bound=Model)

WithAnnotations = Annotated[_T, Annotations()]
WithAnnotations = Annotated[_T, Annotations[_Annotations]]
"""Alias to make it easy to annotate the model `_T` as having annotations `_Annotations` (a `TypedDict` or `Any` if not provided).
Use as `WithAnnotations[MyModel]` or `WithAnnotations[MyModel, MyTypedDict]`.
"""
3 changes: 3 additions & 0 deletions mypy_django_plugin/django/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]:
if fullname.startswith(annotated_prefix):
# For our "annotated models", extract the original model fullname
fullname = fullname[len(annotated_prefix) :].rstrip("]")
if "," in fullname:
# Remove second type arg, which might be present
fullname = fullname[: fullname.index(",")]

module, _, model_cls_name = fullname.rpartition(".")
for model_cls in self.model_modules.get(module, set()):
Expand Down
1 change: 1 addition & 0 deletions mypy_django_plugin/lib/fullnames.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RELATED_MANAGER_CLASS = "django.db.models.manager.RelatedManager"

WITH_ANNOTATIONS_FULLNAME = "django_stubs_ext.WithAnnotations"
ANNOTATIONS_FULLNAME = "django_stubs_ext.annotations.Annotations"

BASEFORM_CLASS_FULLNAME = "django.forms.forms.BaseForm"
FORM_CLASS_FULLNAME = "django.forms.forms.Form"
Expand Down
77 changes: 47 additions & 30 deletions mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from mypy.types import TypedDictType, TypeOfAny

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANY_ATTR_ALLOWED_CLASS_FULLNAME
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME
from mypy_django_plugin.lib.helpers import add_new_class_for_module
from mypy_django_plugin.transformers import fields
from mypy_django_plugin.transformers.fields import get_field_descriptor_types
Expand Down Expand Up @@ -402,11 +402,23 @@ def handle_annotated_type(ctx: AnalyzeTypeContext, django_context: DjangoContext
if not isinstance(type_arg, Instance):
return ctx.api.analyze_type(ctx.type)

return get_or_create_annotated_type(api, type_arg)
fields_dict = None
if len(args) > 1:
second_arg_type = ctx.api.analyze_type(args[1])
if isinstance(second_arg_type, TypedDictType):
fields_dict = second_arg_type
elif isinstance(second_arg_type, Instance) and second_arg_type.type.fullname == ANNOTATIONS_FULLNAME:
annotations_type_arg = second_arg_type.args[0]
if isinstance(annotations_type_arg, TypedDictType):
fields_dict = annotations_type_arg
elif not isinstance(annotations_type_arg, AnyType):
ctx.api.fail("Only TypedDicts are supported as type arguments to Annotations", ctx.context)

return get_or_create_annotated_type(api, type_arg, fields_dict=fields_dict)


def get_or_create_annotated_type(
api: Union[SemanticAnalyzer, CheckerPluginInterface], model_type: Instance
api: Union[SemanticAnalyzer, CheckerPluginInterface], model_type: Instance, fields_dict: Optional[TypedDictType]
) -> Instance:
"""
Expand All @@ -418,33 +430,38 @@ def get_or_create_annotated_type(
This is a bit of a hack to make a pretty type for error messages and which would make sense for users.
"""
model_module_name = "django_stubs_ext"
type_name = f"WithAnnotations[{model_type.type.fullname}]"

# If already existing annotated type for model exists, reuse it
if model_type.type.has_base(ANY_ATTR_ALLOWED_CLASS_FULLNAME):
annotated_type = model_type
if helpers.is_annotated_model_fullname(model_type.type.fullname):
# If it's already a generated class, we want to use the original model as a base
model_type = model_type.type.bases[0]

if fields_dict is not None:
type_name = f"WithAnnotations[{model_type.type.fullname}, {fields_dict}]"
else:
annotated_typeinfo = helpers.lookup_fully_qualified_typeinfo(
cast(TypeChecker, api), model_module_name + "." + type_name
type_name = f"WithAnnotations[{model_type.type.fullname}]"

annotated_typeinfo = helpers.lookup_fully_qualified_typeinfo(
cast(TypeChecker, api), model_module_name + "." + type_name
)
if annotated_typeinfo is None:
model_module_file = api.modules[model_module_name] # type: ignore

if isinstance(api, SemanticAnalyzer):
annotated_model_type = api.named_type_or_none(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])
assert annotated_model_type is not None
else:
annotated_model_type = api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])

annotated_typeinfo = add_new_class_for_module(
model_module_file,
type_name,
bases=[model_type] if fields_dict is not None else [model_type, annotated_model_type],
fields=fields_dict.items if fields_dict is not None else None,
)
if annotated_typeinfo is None:
model_module_file = api.modules[model_module_name] # type: ignore
if isinstance(api, SemanticAnalyzer):
annotated_model_type = api.named_type_or_none(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])
assert annotated_model_type is not None
else:
annotated_model_type = api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])

# Create a new class in the same module as the model, with the same name as the model but with a suffix
# The class inherits from the model and an internal class which allows get/set of any attribute.
# Essentially, this is a way of making an "intersection" type between the two types.
annotated_typeinfo = add_new_class_for_module(
model_module_file,
type_name,
bases=[
model_type,
annotated_model_type,
],
)
annotated_type = Instance(annotated_typeinfo, [])
if fields_dict is not None:
# To allow structural subtyping, make it a Protocol
annotated_typeinfo.is_protocol = True
# Save for later to easily find which field types were annotated
annotated_typeinfo.metadata["annotated_field_types"] = fields_dict.items
annotated_type = Instance(annotated_typeinfo, [])
return annotated_type
39 changes: 36 additions & 3 deletions mypy_django_plugin/transformers/querysets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from collections import OrderedDict
from typing import List, Optional, Sequence, Type
from typing import Dict, List, Optional, Sequence, Type

from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy.nodes import Expression, NameExpr
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, Expression, NameExpr
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, TupleType
from mypy.types import Type as MypyType
Expand Down Expand Up @@ -183,6 +183,23 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
return helpers.reparametrize_instance(default_return_type, [model_type, row_type])


def gather_kwargs(ctx: MethodContext) -> Optional[Dict[str, MypyType]]:
num_args = len(ctx.arg_kinds)
kwargs = {}
named = (ARG_NAMED, ARG_NAMED_OPT)
for i in range(num_args):
if not ctx.arg_kinds[i]:
continue
if any(kind not in named for kind in ctx.arg_kinds[i]):
# Only named arguments supported
return None
for j in range(len(ctx.arg_names[i])):
name = ctx.arg_names[i][j]
assert name is not None
kwargs[name] = ctx.arg_types[i][j]
return kwargs


def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
# called on the Instance, returns QuerySet of something
assert isinstance(ctx.type, Instance)
Expand All @@ -195,7 +212,23 @@ def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: Dj

api = ctx.api

annotated_type = get_or_create_annotated_type(api, model_type)
field_types = model_type.type.metadata.get("annotated_field_types")
kwargs = gather_kwargs(ctx)
if kwargs:
# For now, we don't try to resolve the output_field of the field would be, but use Any.
added_field_types = {name: AnyType(TypeOfAny.implementation_artifact) for name, typ in kwargs.items()}
if field_types is not None:
# Annotate was called more than once, so add/update existing field types
field_types.update(added_field_types)
else:
field_types = added_field_types

fields_dict = None
if field_types is not None:
fields_dict = helpers.make_typeddict(
api, fields=OrderedDict(field_types), required_keys=set(field_types.keys())
)
annotated_type = get_or_create_annotated_type(api, model_type, fields_dict=fields_dict)

row_type: MypyType
if len(default_return_type.args) > 1:
Expand Down
Loading

0 comments on commit 51f0448

Please sign in to comment.