Skip to content

Commit

Permalink
feat: migrate to Pydantic v2 (#656)
Browse files Browse the repository at this point in the history
* Upgrade to pydantic v2
* Replace deprecated dict with model_dump
* Replace update_forward_refs with model_rebuild
* Fix error message in test
* Fix Packages and ChannelSearch rest models

- platforms is nullable: it should be optional
- Channel description is optional - so should be ChannelSearch description

* Fix Job rest model

item_spec can't be set to required in JobBase and optional in Job
(mypy complains).
Create a new class JobCreate where item_spec is required.
In job, item_spec is optional.

* Add UserOptionalRole

User role is nullable and should be optional in get (but not in set)

* Remove nullable=True in pydantic Field

This isn't used.
To mark a field as nullable, it should be set as Optional or "| None".

Fix deprecation warning: Extra keyword arguments on `Field` is deprecated and will be removed.

* Fix Channel model

keyword should be examples (not example)

* Replace deprecated from_orm with model_validate

* Replace deprecated parse_obj with model_validate

* Fix UserWarning

UserWarning: `pydantic.error_wrappers:ValidationError` has been moved to `pydantic:ValidationError`.

* Fix test_post_new_job_invalid_items_spec

* Fix starlette Deprecation Warning

DeprecationWarning: The `allow_redirects` argument is deprecated. Use `follow_redirects` instead.

* Increase wait time for running task tests

Test failing with postgres. Task still pending after 2.5 seconds.
  • Loading branch information
beenje authored Aug 23, 2023
1 parent 69cdb84 commit 46a31d4
Show file tree
Hide file tree
Showing 13 changed files with 102 additions and 125 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ dependencies:
- conda-content-trust
- pyinstrument
- pytest-asyncio
- pydantic <2
- pydantic >=2
- pip:
- git+https://github.com/jupyter-server/jupyter_releaser.git@v2
2 changes: 1 addition & 1 deletion quetz/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class User(Base):
'Profile', uselist=False, back_populates='user', cascade="all,delete-orphan"
)

role = Column(String)
role = Column(String, nullable=True)

@classmethod
def find(cls, db, name):
Expand Down
4 changes: 2 additions & 2 deletions quetz/jobs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from quetz.rest_models import PaginatedResponse

from .models import JobStatus, TaskStatus
from .rest_models import Job, JobBase, JobUpdateModel, Task
from .rest_models import Job, JobCreate, JobUpdateModel, Task

api_router = APIRouter()

Expand All @@ -44,7 +44,7 @@ def get_jobs(

@api_router.post("/api/jobs", tags=["Jobs"], status_code=201, response_model=Job)
def create_job(
job: JobBase,
job: JobCreate,
dao: Dao = Depends(get_dao),
auth: authorization.Rules = Depends(get_rules),
):
Expand Down
24 changes: 14 additions & 10 deletions quetz/jobs/rest_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional

from importlib_metadata import entry_points as get_entry_points
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from . import handlers
from .models import JobStatus, TaskStatus
Expand Down Expand Up @@ -83,7 +83,6 @@ def parse_job_name(v):
class JobBase(BaseModel):
"""New job spec"""

items_spec: str = Field(..., title='Item selector spec')
manifest: str = Field(None, title='Name of the function')

start_at: Optional[datetime] = Field(
Expand All @@ -97,7 +96,8 @@ class JobBase(BaseModel):
),
)

@validator("manifest", pre=True)
@field_validator("manifest", mode="before")
@classmethod
def validate_job_name(cls, function_name):
if isinstance(function_name, bytes):
return parse_job_name(function_name)
Expand All @@ -107,6 +107,12 @@ def validate_job_name(cls, function_name):
return function_name.encode('ascii')


class JobCreate(JobBase):
"""Create job spec"""

items_spec: str = Field(..., title='Item selector spec')


class JobUpdateModel(BaseModel):
"""Modify job spec items (status and items_spec)"""

Expand All @@ -123,10 +129,8 @@ class Job(JobBase):

status: JobStatus = Field(None, title='Status of the job (running, paused, ...)')

items_spec: str = Field(None, title='Item selector spec')

class Config:
orm_mode = True
items_spec: Optional[str] = Field(None, title='Item selector spec')
model_config = ConfigDict(from_attributes=True)


class Task(BaseModel):
Expand All @@ -136,12 +140,12 @@ class Task(BaseModel):
created: datetime = Field(None, title='Created at')
status: TaskStatus = Field(None, title='Status of the task (running, paused, ...)')

@validator("package_version", pre=True)
@field_validator("package_version", mode="before")
@classmethod
def convert_package_version(cls, v):
if v:
return {'filename': v.filename, 'id': uuid.UUID(bytes=v.id).hex}
else:
return {}

class Config:
orm_mode = True
model_config = ConfigDict(from_attributes=True)
13 changes: 6 additions & 7 deletions quetz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Awaitable, Callable, List, Optional, Tuple, Type

import pydantic
import pydantic.error_wrappers
import requests
from fastapi import (
APIRouter,
Expand Down Expand Up @@ -477,7 +476,7 @@ def delete_user(

@api_router.get(
"/users/{username}/role",
response_model=rest_models.UserRole,
response_model=rest_models.UserOptionalRole,
tags=["users"],
)
def get_user_role(
Expand Down Expand Up @@ -732,7 +731,7 @@ def post_channel(
detail="Cannot use both `includelist` and `excludelist` together.",
)

user_attrs = new_channel.dict(exclude_unset=True)
user_attrs = new_channel.model_dump(exclude_unset=True)

if "size_limit" in user_attrs:
auth.assert_set_channel_size_limit()
Expand Down Expand Up @@ -789,7 +788,7 @@ def patch_channel(
):
auth.assert_update_channel_info(channel.name)

user_attrs = channel_data.dict(exclude_unset=True)
user_attrs = channel_data.model_dump(exclude_unset=True)

if "size_limit" in user_attrs:
auth.assert_set_channel_size_limit()
Expand Down Expand Up @@ -1064,7 +1063,7 @@ def get_package_versions(
version_list = []

for version, profile, api_key_profile in version_profile_list:
version_data = rest_models.PackageVersion.from_orm(version)
version_data = rest_models.PackageVersion.model_validate(version)
version_list.append(version_data)

return version_list
Expand All @@ -1089,7 +1088,7 @@ def get_paginated_package_versions(
version_list = []

for version, profile, api_key_profile in version_profile_list['result']:
version_data = rest_models.PackageVersion.from_orm(version)
version_data = rest_models.PackageVersion.model_validate(version)
version_list.append(version_data)

return {
Expand Down Expand Up @@ -1650,7 +1649,7 @@ def _delete_file(condainfo, filename):
summary=str(condainfo.about.get("summary", "n/a")),
description=str(condainfo.about.get("description", "n/a")),
)
except pydantic.error_wrappers.ValidationError as err:
except pydantic.ValidationError as err:
_delete_file(condainfo, file.filename)
raise errors.ValidationError(
"Validation Error for package: "
Expand Down
6 changes: 2 additions & 4 deletions quetz/metrics/rest_models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from datetime import datetime
from typing import Dict, List

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from quetz.metrics.db_models import IntervalType


class PackageVersionMetricItem(BaseModel):
timestamp: datetime
count: int

class Config:
orm_mode = True
model_config = ConfigDict(from_attributes=True)


class PackageVersionMetricSeries(BaseModel):
Expand Down
Loading

0 comments on commit 46a31d4

Please sign in to comment.