From 46a31d47ae1456d54dedc9d3123e97c7a2e8424e Mon Sep 17 00:00:00 2001 From: Benjamin Bertrand Date: Wed, 23 Aug 2023 09:30:15 +0200 Subject: [PATCH] feat: migrate to Pydantic v2 (#656) * Upgrade to pydantic v2 * Replace deprecated dict with model_dump * Replace update_forward_refs with model_rebuild * Fix error message in test * Fix Packages and ChannelSearch rest models - platforms is nullable: it should be optional - Channel description is optional - so should be ChannelSearch description * Fix Job rest model item_spec can't be set to required in JobBase and optional in Job (mypy complains). Create a new class JobCreate where item_spec is required. In job, item_spec is optional. * Add UserOptionalRole User role is nullable and should be optional in get (but not in set) * Remove nullable=True in pydantic Field This isn't used. To mark a field as nullable, it should be set as Optional or "| None". Fix deprecation warning: Extra keyword arguments on `Field` is deprecated and will be removed. * Fix Channel model keyword should be examples (not example) * Replace deprecated from_orm with model_validate * Replace deprecated parse_obj with model_validate * Fix UserWarning UserWarning: `pydantic.error_wrappers:ValidationError` has been moved to `pydantic:ValidationError`. * Fix test_post_new_job_invalid_items_spec * Fix starlette Deprecation Warning DeprecationWarning: The `allow_redirects` argument is deprecated. Use `follow_redirects` instead. * Increase wait time for running task tests Test failing with postgres. Task still pending after 2.5 seconds. --- environment.yml | 2 +- quetz/db_models.py | 2 +- quetz/jobs/api.py | 4 +- quetz/jobs/rest_models.py | 24 +++-- quetz/main.py | 13 ++- quetz/metrics/rest_models.py | 6 +- quetz/rest_models.py | 139 ++++++++++---------------- quetz/tests/api/test_api_keys.py | 6 +- quetz/tests/api/test_channels.py | 2 +- quetz/tests/api/test_main_packages.py | 4 +- quetz/tests/test_jobs.py | 12 ++- quetz/tests/test_mirror.py | 12 +-- setup.cfg | 1 + 13 files changed, 102 insertions(+), 125 deletions(-) diff --git a/environment.yml b/environment.yml index 083b5a56..9c4dccfe 100644 --- a/environment.yml +++ b/environment.yml @@ -50,6 +50,6 @@ dependencies: - conda-content-trust - pyinstrument - pytest-asyncio - - pydantic <2 + - pydantic >=2 - pip: - git+https://github.com/jupyter-server/jupyter_releaser.git@v2 diff --git a/quetz/db_models.py b/quetz/db_models.py index 91c5db49..10ece686 100644 --- a/quetz/db_models.py +++ b/quetz/db_models.py @@ -55,7 +55,7 @@ class User(Base): 'Profile', uselist=False, back_populates='user', cascade="all,delete-orphan" ) - role = Column(String) + role = Column(String, nullable=True) @classmethod def find(cls, db, name): diff --git a/quetz/jobs/api.py b/quetz/jobs/api.py index 79cf087c..38579b2f 100644 --- a/quetz/jobs/api.py +++ b/quetz/jobs/api.py @@ -17,7 +17,7 @@ from quetz.rest_models import PaginatedResponse from .models import JobStatus, TaskStatus -from .rest_models import Job, JobBase, JobUpdateModel, Task +from .rest_models import Job, JobCreate, JobUpdateModel, Task api_router = APIRouter() @@ -44,7 +44,7 @@ def get_jobs( @api_router.post("/api/jobs", tags=["Jobs"], status_code=201, response_model=Job) def create_job( - job: JobBase, + job: JobCreate, dao: Dao = Depends(get_dao), auth: authorization.Rules = Depends(get_rules), ): diff --git a/quetz/jobs/rest_models.py b/quetz/jobs/rest_models.py index 64607b2d..c518a587 100644 --- a/quetz/jobs/rest_models.py +++ b/quetz/jobs/rest_models.py @@ -5,7 +5,7 @@ from typing import Optional from importlib_metadata import entry_points as get_entry_points -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from . import handlers from .models import JobStatus, TaskStatus @@ -83,7 +83,6 @@ def parse_job_name(v): class JobBase(BaseModel): """New job spec""" - items_spec: str = Field(..., title='Item selector spec') manifest: str = Field(None, title='Name of the function') start_at: Optional[datetime] = Field( @@ -97,7 +96,8 @@ class JobBase(BaseModel): ), ) - @validator("manifest", pre=True) + @field_validator("manifest", mode="before") + @classmethod def validate_job_name(cls, function_name): if isinstance(function_name, bytes): return parse_job_name(function_name) @@ -107,6 +107,12 @@ def validate_job_name(cls, function_name): return function_name.encode('ascii') +class JobCreate(JobBase): + """Create job spec""" + + items_spec: str = Field(..., title='Item selector spec') + + class JobUpdateModel(BaseModel): """Modify job spec items (status and items_spec)""" @@ -123,10 +129,8 @@ class Job(JobBase): status: JobStatus = Field(None, title='Status of the job (running, paused, ...)') - items_spec: str = Field(None, title='Item selector spec') - - class Config: - orm_mode = True + items_spec: Optional[str] = Field(None, title='Item selector spec') + model_config = ConfigDict(from_attributes=True) class Task(BaseModel): @@ -136,12 +140,12 @@ class Task(BaseModel): created: datetime = Field(None, title='Created at') status: TaskStatus = Field(None, title='Status of the task (running, paused, ...)') - @validator("package_version", pre=True) + @field_validator("package_version", mode="before") + @classmethod def convert_package_version(cls, v): if v: return {'filename': v.filename, 'id': uuid.UUID(bytes=v.id).hex} else: return {} - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) diff --git a/quetz/main.py b/quetz/main.py index 88bd8fe6..fe7d0c53 100644 --- a/quetz/main.py +++ b/quetz/main.py @@ -17,7 +17,6 @@ from typing import Awaitable, Callable, List, Optional, Tuple, Type import pydantic -import pydantic.error_wrappers import requests from fastapi import ( APIRouter, @@ -477,7 +476,7 @@ def delete_user( @api_router.get( "/users/{username}/role", - response_model=rest_models.UserRole, + response_model=rest_models.UserOptionalRole, tags=["users"], ) def get_user_role( @@ -732,7 +731,7 @@ def post_channel( detail="Cannot use both `includelist` and `excludelist` together.", ) - user_attrs = new_channel.dict(exclude_unset=True) + user_attrs = new_channel.model_dump(exclude_unset=True) if "size_limit" in user_attrs: auth.assert_set_channel_size_limit() @@ -789,7 +788,7 @@ def patch_channel( ): auth.assert_update_channel_info(channel.name) - user_attrs = channel_data.dict(exclude_unset=True) + user_attrs = channel_data.model_dump(exclude_unset=True) if "size_limit" in user_attrs: auth.assert_set_channel_size_limit() @@ -1064,7 +1063,7 @@ def get_package_versions( version_list = [] for version, profile, api_key_profile in version_profile_list: - version_data = rest_models.PackageVersion.from_orm(version) + version_data = rest_models.PackageVersion.model_validate(version) version_list.append(version_data) return version_list @@ -1089,7 +1088,7 @@ def get_paginated_package_versions( version_list = [] for version, profile, api_key_profile in version_profile_list['result']: - version_data = rest_models.PackageVersion.from_orm(version) + version_data = rest_models.PackageVersion.model_validate(version) version_list.append(version_data) return { @@ -1650,7 +1649,7 @@ def _delete_file(condainfo, filename): summary=str(condainfo.about.get("summary", "n/a")), description=str(condainfo.about.get("description", "n/a")), ) - except pydantic.error_wrappers.ValidationError as err: + except pydantic.ValidationError as err: _delete_file(condainfo, file.filename) raise errors.ValidationError( "Validation Error for package: " diff --git a/quetz/metrics/rest_models.py b/quetz/metrics/rest_models.py index bce4b063..c6df59ce 100644 --- a/quetz/metrics/rest_models.py +++ b/quetz/metrics/rest_models.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Dict, List -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from quetz.metrics.db_models import IntervalType @@ -9,9 +9,7 @@ class PackageVersionMetricItem(BaseModel): timestamp: datetime count: int - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class PackageVersionMetricSeries(BaseModel): diff --git a/quetz/rest_models.py b/quetz/rest_models.py index a9b4700d..ac6d8e0c 100644 --- a/quetz/rest_models.py +++ b/quetz/rest_models.py @@ -9,18 +9,15 @@ from enum import Enum from typing import Generic, List, Optional, TypeVar -from pydantic import BaseModel, Field, root_validator, validator -from pydantic.generics import GenericModel +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator T = TypeVar('T') class BaseProfile(BaseModel): - name: Optional[str] = Field(None, nullable=True) + name: Optional[str] = Field(None) avatar_url: str - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class Profile(BaseProfile): @@ -30,27 +27,23 @@ class Profile(BaseProfile): class BaseUser(BaseModel): id: uuid.UUID username: str - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class User(BaseUser): profile: BaseProfile -Profile.update_forward_refs() +Profile.model_rebuild() -Role = Field(None, regex='owner|maintainer|member') +Role = Field(None, pattern='owner|maintainer|member') class Member(BaseModel): role: str = Role user: User - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class Pagination(BaseModel): @@ -67,26 +60,22 @@ class MirrorMode(str, Enum): class ChannelBase(BaseModel): name: str = Field(None, title='The name of the channel', max_length=50) description: Optional[str] = Field( - None, title='The description of the channel', max_length=300, nullable=True + None, title='The description of the channel', max_length=300 ) private: bool = Field(True, title="channel should be private") - size_limit: Optional[int] = Field( - None, title="size limit of the channel", nullable=True - ) + size_limit: Optional[int] = Field(None, title="size limit of the channel") ttl: int = Field(36000, title="ttl of the channel") - mirror_channel_url: Optional[str] = Field( - None, regex="^(http|https)://.+", nullable=True - ) - mirror_mode: Optional[MirrorMode] = Field(None, nullable=True) + mirror_channel_url: Optional[str] = Field(None, pattern="^(http|https)://.+") + mirror_mode: Optional[MirrorMode] = Field(None) - @validator("size_limit") + @field_validator("size_limit") + @classmethod def check_positive(cls, v): if v is not None and v < 0: return ValueError("must be positive value") return v - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class ChannelExtra(ChannelBase): @@ -97,9 +86,7 @@ class ChannelExtra(ChannelBase): class ChannelRole(BaseModel): name: str = Field(title="channel name") role: str = Field(title="user role") - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class ChannelActionEnum(str, Enum): @@ -133,37 +120,33 @@ class ChannelMetadata(BaseModel): includelist: Optional[List[str]] = Field( None, title="list of packages to include while creating a channel", - nullable=True, ) excludelist: Optional[List[str]] = Field( None, title="list of packages to exclude while creating a channel", - nullable=True, ) proxylist: Optional[List[str]] = Field( None, title="list of packages that should only be proxied (not copied, " "stored and redistributed)", - nullable=True, ) class Channel(ChannelBase): metadata: ChannelMetadata = Field( - default_factory=ChannelMetadata, title="channel metadata", example={} + default_factory=ChannelMetadata, title="channel metadata", examples={} ) actions: Optional[List[ChannelActionEnum]] = Field( None, title="list of actions to run after channel creation " "(see /channels/{}/actions for description)", - nullable=True, ) - @root_validator - def check_mirror_params(cls, values): - mirror_url = values.get("mirror_channel_url") - mirror_mode = values.get("mirror_mode") + @model_validator(mode='after') + def check_mirror_params(self) -> "Channel": + mirror_url = self.mirror_channel_url + mirror_mode = self.mirror_mode if mirror_url and not mirror_mode: raise ValueError( @@ -174,18 +157,14 @@ def check_mirror_params(cls, values): "'mirror_mode' provided but 'mirror_channel_url' is undefined" ) - return values + return self class ChannelMirrorBase(BaseModel): - url: str = Field(None, regex="^(http|https)://.+") - api_endpoint: Optional[str] = Field(None, regex="^(http|https)://.+", nullable=True) - metrics_endpoint: Optional[str] = Field( - None, regex="^(http|https)://.+", nullable=True - ) - - class Config: - orm_mode = True + url: str = Field(None, pattern="^(http|https)://.+") + api_endpoint: Optional[str] = Field(None, pattern="^(http|https)://.+") + metrics_endpoint: Optional[str] = Field(None, pattern="^(http|https)://.+") + model_config = ConfigDict(from_attributes=True) class ChannelMirror(ChannelMirrorBase): @@ -194,41 +173,31 @@ class ChannelMirror(ChannelMirrorBase): class Package(BaseModel): name: str = Field( - None, title='The name of package', max_length=1500, regex=r'^[a-z0-9-_\.]*$' + None, title='The name of package', max_length=1500, pattern=r'^[a-z0-9-_\.]*$' ) - summary: Optional[str] = Field( - None, title='The summary of the package', nullable=True - ) - description: Optional[str] = Field( - None, title='The description of the package', nullable=True - ) - url: Optional[str] = Field(None, title="project url", nullable=True) - platforms: List[str] = Field(None, title="supported platforms", nullable=True) - current_version: Optional[str] = Field( - None, title="latest version of any platform", nullable=True - ) - latest_change: Optional[datetime] = Field( - None, title="date of latest change", nullable=True - ) - - @validator("platforms", pre=True) + summary: Optional[str] = Field(None, title='The summary of the package') + description: Optional[str] = Field(None, title='The description of the package') + url: Optional[str] = Field(None, title="project url") + platforms: Optional[List[str]] = Field(None, title="supported platforms") + current_version: Optional[str] = Field(None, title="latest version of any platform") + latest_change: Optional[datetime] = Field(None, title="date of latest change") + + @field_validator("platforms", mode="before") + @classmethod def parse_list_of_platforms(cls, v): if isinstance(v, str): return v.split(":") else: return v - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class PackageRole(BaseModel): name: str = Field(title='The name of package') channel_name: str = Field(title='The channel this package belongs to') role: str = Field(title="user role for this package") - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class PackageSearch(Package): @@ -237,14 +206,12 @@ class PackageSearch(Package): class ChannelSearch(BaseModel): name: str = Field(None, title='The name of the channel', max_length=1500) - description: str = Field(None, title='The description of the channel') + description: Optional[str] = Field(None, title='The description of the channel') private: bool = Field(None, title='The visibility of the channel') + model_config = ConfigDict(from_attributes=True) - class Config: - orm_mode = True - -class PaginatedResponse(GenericModel, Generic[T]): +class PaginatedResponse(BaseModel, Generic[T]): pagination: Pagination = Field(None, title="Pagination object") result: List[T] = Field([], title="Result objects") @@ -254,21 +221,25 @@ class PostMember(BaseModel): role: str = Role +class UserOptionalRole(BaseModel): + role: Optional[str] = Role + + class UserRole(BaseModel): role: str = Role class CPRole(BaseModel): channel: str - package: Optional[str] = Field(None, nullable=True) + package: Optional[str] = Field(None) role: str = Role class BaseApiKey(BaseModel): description: str - time_created: Optional[date] = Field(None, nullable=True) - expire_at: Optional[date] = Field(None, nullable=True) - roles: Optional[List[CPRole]] = Field(None, nullable=True) + time_created: Optional[date] = Field(None) + expire_at: Optional[date] = Field(None) + roles: Optional[List[CPRole]] = Field(None) class ApiKey(BaseApiKey): @@ -289,18 +260,18 @@ class PackageVersion(BaseModel): uploader: BaseProfile time_created: datetime download_count: int + model_config = ConfigDict(from_attributes=True) - class Config: - orm_mode = True - - @validator("uploader", pre=True) + @field_validator("uploader", mode="before") + @classmethod def convert_uploader(cls, v): if hasattr(v, "profile"): return v.profile else: return v - @validator("info", pre=True) + @field_validator("info", mode="before") + @classmethod def load_json(cls, v): if isinstance(v, str): return json.loads(v) @@ -310,5 +281,5 @@ def load_json(cls, v): class ChannelAction(BaseModel): action: ChannelActionEnum - start_at: Optional[datetime] = Field(None, nullable=True) - repeat_every_seconds: Optional[int] = Field(None, nullable=True) + start_at: Optional[datetime] = Field(None) + repeat_every_seconds: Optional[int] = Field(None) diff --git a/quetz/tests/api/test_api_keys.py b/quetz/tests/api/test_api_keys.py index a78a86a6..85b4e83b 100644 --- a/quetz/tests/api/test_api_keys.py +++ b/quetz/tests/api/test_api_keys.py @@ -11,7 +11,7 @@ def api_keys(other_user, user, db, dao: Dao): def key_factory(key_user, descr, expire_at, roles): return dao.create_api_key( key_user.id, - BaseApiKey.parse_obj( + BaseApiKey.model_validate( dict(description=descr, expire_at=expire_at, roles=roles) ), descr, @@ -111,7 +111,7 @@ def test_list_keys_with_package_roles( def test_list_keys_subrole(auth_client, dao, user, private_channel): dao.create_api_key( user.id, - BaseApiKey.parse_obj( + BaseApiKey.model_validate( dict( description="user-key", roles=[ @@ -134,7 +134,7 @@ def test_list_keys_subrole(auth_client, dao, user, private_channel): def test_list_keys_without_roles(auth_client, dao, user): dao.create_api_key( user.id, - BaseApiKey.parse_obj(dict(description="user-key", roles=[])), + BaseApiKey.model_validate(dict(description="user-key", roles=[])), "user-key", ) diff --git a/quetz/tests/api/test_channels.py b/quetz/tests/api/test_channels.py index 534352c4..3976feab 100644 --- a/quetz/tests/api/test_channels.py +++ b/quetz/tests/api/test_channels.py @@ -689,7 +689,7 @@ def test_url_with_slash(auth_client, public_channel, db, remote_session): response = auth_client.post( f"/api/channels/{public_channel.name}/mirrors/", json={"url": mirror_url}, - allow_redirects=False, + follow_redirects=False, ) assert response.status_code == 307 diff --git a/quetz/tests/api/test_main_packages.py b/quetz/tests/api/test_main_packages.py index 3678c80b..cc8f308f 100644 --- a/quetz/tests/api/test_main_packages.py +++ b/quetz/tests/api/test_main_packages.py @@ -609,7 +609,7 @@ def test_validate_package_names(auth_client, public_channel, remove_package_vers @pytest.mark.parametrize( "package_name,msg", [ - ("TestPackage", "string does not match"), + ("TestPackage", "String should match"), ("test-package", None), ], ) @@ -835,7 +835,7 @@ def api_key(db, dao: Dao, owner, private_channel): # create an api key with restriction key = dao.create_api_key( owner.id, - BaseApiKey.parse_obj( + BaseApiKey.model_validate( dict( description="test api key", expire_at="2099-12-31", diff --git a/quetz/tests/test_jobs.py b/quetz/tests/test_jobs.py index 32163a0c..5680c4dc 100644 --- a/quetz/tests/test_jobs.py +++ b/quetz/tests/test_jobs.py @@ -244,7 +244,7 @@ async def test_running_task(db, user, package_version, supervisor): assert task.status == TaskStatus.pending # wait for task status to change - for i in range(50): + for i in range(100): time.sleep(0.05) db.refresh(task) @@ -283,7 +283,7 @@ async def test_restart_worker_process( assert task.status == TaskStatus.pending # wait for task status to change - for i in range(50): + for i in range(100): time.sleep(0.05) db.refresh(task) @@ -598,16 +598,20 @@ def test_post_new_job_manifest_validation( @pytest.mark.parametrize("user_role", ["owner"]) -def test_post_new_job_invalid_items_spec(auth_client, user, db, dummy_job_plugin): +def test_post_new_job_invalid_items_spec( + auth_client, user, db, dummy_job_plugin, mocker +): # items_spec=None is not allowed for jobs # (but it works with actions) manifest = "quetz-dummyplugin:dummy_func" + dummy_func = mocker.Mock() + mocker.patch("quetz_dummyplugin.jobs.dummy_func", dummy_func, create=True) response = auth_client.post( "/api/jobs", json={"items_spec": None, "manifest": manifest} ) assert response.status_code == 422 msg = response.json()['detail'] - assert "not an allowed value" in msg[0]['msg'] + assert "Input should be a valid string" in msg[0]['msg'] @pytest.mark.parametrize("user_role", ["owner"]) diff --git a/quetz/tests/test_mirror.py b/quetz/tests/test_mirror.py index ee59efb9..e715a253 100644 --- a/quetz/tests/test_mirror.py +++ b/quetz/tests/test_mirror.py @@ -860,11 +860,11 @@ def test_wrong_package_format(client, dummy_repo, owner, job_supervisor): [ ("proxy", None, "'mirror_channel_url' is undefined"), (None, "http://my-host", "'mirror_mode' is undefined"), - ("undefined", "http://my-host", "not a valid enumeration member"), - ("proxy", "my-host", "does not match"), - ("proxy", "http://", "does not match"), - ("proxy", "http:my-host", "does not match"), - ("proxy", "hosthttp://my-host", "does not match"), + ("undefined", "http://my-host", "Input should be 'proxy' or 'mirror'"), + ("proxy", "my-host", "String should match pattern"), + ("proxy", "http://", "String should match pattern"), + ("proxy", "http:my-host", "String should match pattern"), + ("proxy", "hosthttp://my-host", "String should match pattern"), (None, None, None), # non-mirror channel ("proxy", "http://my-host", None), ("proxy", "https://my-host", None), @@ -1070,7 +1070,7 @@ def test_proxylist_mirror_channel(owner, client, mirror_mode): response = client.get( "/get/mirror-channel-btel/linux-64/nrnpython-0.1-0.tar.bz2", - allow_redirects=False, + follow_redirects=False, ) assert response.status_code == 307 assert ( diff --git a/setup.cfg b/setup.cfg index 91215c78..f5b2737d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ install_requires = pluggy prometheus_client python-multipart + pydantic>=2.0.0 pyyaml requests sqlalchemy