Skip to content

Commit

Permalink
Update pydantic to >= 2 (matrix-org#15858)
Browse files Browse the repository at this point in the history
Adapt codebase to use pydantic >= 2 models and functionalities.

Remove unneeded checks from `scripts-dev/check_pydantic_models.py`,
since pydantic can now be used in a strict mode which will prevent the
type coercion:
https://docs.pydantic.dev/2.0/usage/strict_mode/#type-coercions-in-strict-mode

Closes matrix-org#15858

Signed-off-by: David Runge <dave@sleepmap.de>
  • Loading branch information
dvzrv committed Jul 24, 2023
1 parent 641ff9e commit a465aa8
Show file tree
Hide file tree
Showing 13 changed files with 201 additions and 147 deletions.
1 change: 1 addition & 0 deletions changelog.d/15979.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update the `pydantic` dependency to `>= 2`. Contributed by @dvzrv.
188 changes: 141 additions & 47 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ matrix-common = "^1.3.0"
packaging = ">=16.1"
# This is the most recent version of Pydantic with available on common distros.
# We are currently incompatible with >=2.0.0: (https://github.com/matrix-org/synapse/issues/15858)
pydantic = "^1.7.4"
pydantic = ">=2.0.0"

# This is for building the rust components during "poetry install", which
# currently ignores the `build-system.requires` directive (c.f.
Expand Down
41 changes: 1 addition & 40 deletions scripts-dev/check_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@
import traceback
import unittest.mock
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Set, Type, TypeVar
from typing import Callable, Dict, Generator, List, Set, TypeVar

from parameterized import parameterized
from pydantic import BaseModel as PydanticBaseModel, conbytes, confloat, conint, constr
from pydantic.typing import get_args
from typing_extensions import ParamSpec

logger = logging.getLogger(__name__)
Expand All @@ -52,14 +51,6 @@
confloat,
]

TYPES_THAT_PYDANTIC_WILL_COERCE_TO = [
str,
bytes,
int,
float,
bool,
]


P = ParamSpec("P")
R = TypeVar("R")
Expand Down Expand Up @@ -97,42 +88,12 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return wrapper


def field_type_unwanted(type_: Any) -> bool:
"""Very rough attempt to detect if a type is unwanted as a Pydantic annotation.
At present, we exclude types which will coerce, or any generic type involving types
which will coerce."""
logger.debug("Is %s unwanted?")
if type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO:
logger.debug("yes")
return True
logger.debug("Maybe. Subargs are %s", get_args(type_))
rv = any(field_type_unwanted(t) for t in get_args(type_))
logger.debug("Conclusion: %s %s unwanted", type_, "is" if rv else "is not")
return rv


class PatchedBaseModel(PydanticBaseModel):
"""A patched version of BaseModel that inspects fields after models are defined.
We complain loudly if we see an unwanted type.
Beware: ModelField.type_ is presumably private; this is likely to be very brittle.
"""

@classmethod
def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object):
for field in cls.__fields__.values():
# Note that field.type_ and field.outer_type are computed based on the
# annotation type, see pydantic.fields.ModelField._type_analysis
if field_type_unwanted(field.outer_type_):
# TODO: this only reports the first bad field. Can we find all bad ones
# and report them all?
raise FieldHasUnwantedTypeException(
f"{cls.__module__}.{cls.__qualname__} has field '{field.name}' "
f"with unwanted type `{field.outer_type_}`"
)


@contextmanager
def monkeypatch_pydantic() -> Generator[None, None, None]:
Expand Down
5 changes: 3 additions & 2 deletions synapse/config/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Dict, Type, TypeVar

import jsonschema
from pydantic import BaseModel, ValidationError, parse_obj_as
from pydantic import BaseModel, TypeAdapter, ValidationError

from synapse.config._base import ConfigError
from synapse.types import JsonDict, StrSequence
Expand Down Expand Up @@ -86,7 +86,8 @@ def parse_and_validate_mapping(
try:
# type-ignore: mypy doesn't like constructing `Dict[str, model_type]` because
# `model_type` is a runtime variable. Pydantic is fine with this.
instances = parse_obj_as(Dict[str, model_type], config) # type: ignore[valid-type]
adapter = TypeAdapter(Dict[str, model_type]) # type: ignore[valid-type]
instances = adapter.validate_python(config, strict=True)
except ValidationError as e:
raise ConfigError(str(e)) from e
return instances
8 changes: 2 additions & 6 deletions synapse/config/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any, Dict, List, Optional, Union

import attr
from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
from pydantic import BaseModel, ConfigDict, StrictBool, StrictInt, StrictStr

from synapse.config._base import (
Config,
Expand Down Expand Up @@ -87,11 +87,7 @@ class ConfigModel(BaseModel):
https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
"""

class Config:
# By default, ignore fields that we don't recognise.
extra = Extra.ignore
# By default, don't allow fields to be reassigned after parsing.
allow_mutation = False
model_config = ConfigDict(extra="ignore", frozen=True)


class InstanceTcpLocationConfig(ConfigModel):
Expand Down
15 changes: 7 additions & 8 deletions synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
overload,
)

from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError
from pydantic.error_wrappers import ErrorWrapper
from pydantic import BaseModel, ValidationError
from typing_extensions import Literal

from twisted.web.server import Request
Expand Down Expand Up @@ -786,20 +785,20 @@ def validate_json_object(content: JsonDict, model_type: Type[Model]) -> Model:
if it wasn't a JSON object.
"""
try:
instance = model_type.parse_obj(content)
instance = model_type.model_validate(content, strict=True)
except ValidationError as e:
# Choose a matrix error code. The catch-all is BAD_JSON, but we try to find a
# more specific error if possible (which occasionally helps us to be spec-
# compliant) This is a bit awkward because the spec's error codes aren't very
# clear-cut: BAD_JSON arguably overlaps with MISSING_PARAM and INVALID_PARAM.
errcode = Codes.BAD_JSON

raw_errors = e.raw_errors
if len(raw_errors) == 1 and isinstance(raw_errors[0], ErrorWrapper):
raw_error = raw_errors[0].exc
if isinstance(raw_error, MissingError):
raw_errors = e.errors()
if len(raw_errors) == 1:
error_type = raw_errors[0].get("type")
if error_type == "missing":
errcode = Codes.MISSING_PARAM
elif isinstance(raw_error, PydanticValueError):
else:
errcode = Codes.INVALID_PARAM

raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=errcode)
Expand Down
20 changes: 9 additions & 11 deletions synapse/rest/client/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple

from pydantic import Extra, StrictStr
from pydantic import ConfigDict, StrictStr

from synapse.api import errors
from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler()

class PostBody(RequestBodyModel):
auth: Optional[AuthenticationData]
auth: Optional[AuthenticationData] = None
devices: List[StrictStr]

@interactive_auth_handler
Expand All @@ -105,7 +105,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# TODO: Can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = self.PostBody.parse_obj({})
body = self.PostBody.model_validate({}, strict=True)
else:
raise e

Expand Down Expand Up @@ -164,7 +164,7 @@ async def on_GET(
return 200, device

class DeleteBody(RequestBodyModel):
auth: Optional[AuthenticationData]
auth: Optional[AuthenticationData] = None

@interactive_auth_handler
async def on_DELETE(
Expand All @@ -183,7 +183,7 @@ async def on_DELETE(
# TODO: can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = self.DeleteBody.parse_obj({})
body = self.DeleteBody.model_validate({}, strict=True)
else:
raise

Expand All @@ -203,7 +203,7 @@ async def on_DELETE(
return 200, {}

class PutBody(RequestBodyModel):
display_name: Optional[StrictStr]
display_name: Optional[StrictStr] = None

async def on_PUT(
self, request: SynapseRequest, device_id: str
Expand All @@ -223,8 +223,7 @@ class DehydratedDeviceDataModel(RequestBodyModel):
Expects other freeform fields. Use .dict() to access them.
"""

class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")

algorithm: StrictStr

Expand Down Expand Up @@ -295,7 +294,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

class PutBody(RequestBodyModel):
device_data: DehydratedDeviceDataModel
initial_device_display_name: Optional[StrictStr]
initial_device_display_name: Optional[StrictStr] = None

async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_and_validate_json_object_from_request(request, self.PutBody)
Expand Down Expand Up @@ -530,8 +529,7 @@ class PutBody(RequestBodyModel):
device_id: StrictStr
initial_device_display_name: Optional[StrictStr]

class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")

async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_and_validate_json_object_from_request(request, self.PutBody)
Expand Down
38 changes: 22 additions & 16 deletions synapse/rest/client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from pydantic import Extra, StrictInt, StrictStr, constr, validator
from pydantic import (
ConfigDict,
StrictInt,
StrictStr,
constr,
field_validator,
model_validator,
)

from synapse.rest.models import RequestBodyModel
from synapse.util.threepids import validate_email
Expand All @@ -29,8 +36,7 @@ class AuthenticationData(RequestBodyModel):
`.dict(exclude_unset=True)` to access them.
"""

class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")

session: Optional[StrictStr] = None
type: Optional[StrictStr] = None
Expand All @@ -41,7 +47,7 @@ class Config:
else:
# See also assert_valid_client_secret()
ClientSecretStr = constr(
regex="[0-9a-zA-Z.=_-]", # noqa: F722
pattern="[0-9a-zA-Z.=_-]", # noqa: F722
min_length=1,
max_length=255,
strict=True,
Expand All @@ -50,18 +56,18 @@ class Config:

class ThreepidRequestTokenBody(RequestBodyModel):
client_secret: ClientSecretStr
id_server: Optional[StrictStr]
id_access_token: Optional[StrictStr]
next_link: Optional[StrictStr]
id_server: Optional[StrictStr] = None
id_access_token: Optional[StrictStr] = None
next_link: Optional[StrictStr] = None
send_attempt: StrictInt

@validator("id_access_token", always=True)
def token_required_for_identity_server(
cls, token: Optional[str], values: Dict[str, object]
) -> Optional[str]:
if values.get("id_server") is not None and token is None:
@model_validator(mode="before")
def token_required_for_identity_server(cls, data: Dict[str, Any]) -> Dict[str, Any]:
"""Ensure that an access token is provided when a server is provided."""
if data.get("id_server") is not None and data.get("id_access_token") is None:
raise ValueError("id_access_token is required if an id_server is supplied.")
return token

return data


class EmailRequestTokenBody(ThreepidRequestTokenBody):
Expand All @@ -72,14 +78,14 @@ class EmailRequestTokenBody(ThreepidRequestTokenBody):
# know the exact spelling (eg. upper and lower case) of address in the database.
# Without this, an email stored in the database as "foo@bar.com" would cause
# user requests for "FOO@bar.com" to raise a Not Found error.
_email_validator = validator("email", allow_reuse=True)(validate_email)
email_validator = field_validator("email")(validate_email)


if TYPE_CHECKING:
ISO3116_1_Alpha_2 = StrictStr
else:
# Per spec: two-letter uppercase ISO-3166-1-alpha-2
ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
ISO3116_1_Alpha_2 = constr(pattern="[A-Z]{2}", strict=True)


class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
Expand Down
8 changes: 2 additions & 6 deletions synapse/rest/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Extra
from pydantic import BaseModel, ConfigDict


class RequestBodyModel(BaseModel):
Expand All @@ -16,8 +16,4 @@ class RequestBodyModel(BaseModel):
https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
"""

class Config:
# By default, ignore fields that we don't recognise.
extra = Extra.ignore
# By default, don't allow fields to be reassigned after parsing.
allow_mutation = False
model_config = ConfigDict(extra="ignore", frozen=True)
4 changes: 3 additions & 1 deletion synapse/storage/background_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,9 @@ async def validate_constraint_and_delete_in_background(
# match the constraint.
# 3. We try re-validating the constraint.

parsed_progress = ValidateConstraintProgress.parse_obj(progress)
parsed_progress = ValidateConstraintProgress.model_validate(
progress, strict=True
)

if parsed_progress.state == ValidateConstraintProgress.State.check:
return_columns = ", ".join(unique_columns)
Expand Down
6 changes: 3 additions & 3 deletions tests/rest/client/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,21 +794,21 @@ def test_add_valid_email_second_time_canonicalise(self) -> None:
def test_add_email_no_at(self) -> None:
self._request_token_invalid_email(
"address-without-at.bar",
expected_errcode=Codes.BAD_JSON,
expected_errcode=Codes.INVALID_PARAM,
expected_error="Unable to parse email address",
)

def test_add_email_two_at(self) -> None:
self._request_token_invalid_email(
"foo@foo@test.bar",
expected_errcode=Codes.BAD_JSON,
expected_errcode=Codes.INVALID_PARAM,
expected_error="Unable to parse email address",
)

def test_add_email_bad_format(self) -> None:
self._request_token_invalid_email(
"user@bad.example.net@good.example.com",
expected_errcode=Codes.BAD_JSON,
expected_errcode=Codes.INVALID_PARAM,
expected_error="Unable to parse email address",
)

Expand Down
Loading

0 comments on commit a465aa8

Please sign in to comment.