Skip to content

Commit

Permalink
Merge pull request #692 from maresb/migrate-pydantic
Browse files Browse the repository at this point in the history
Code migrations for Pydantic v2
  • Loading branch information
maresb authored Sep 13, 2024
2 parents 41d3451 + 345fbcb commit 0837772
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 32 deletions.
21 changes: 12 additions & 9 deletions conda_lock/lockfile/v1/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import logging
import pathlib
import typing

from collections import namedtuple
from typing import (
Expand All @@ -22,7 +21,7 @@
if TYPE_CHECKING:
from hashlib import _Hash

from pydantic import Field, validator
from pydantic import Field, ValidationInfo, field_validator
from typing_extensions import Literal

from conda_lock.common import ordered_union, relative_path
Expand Down Expand Up @@ -60,9 +59,10 @@ class BaseLockedDependency(StrictModel):
def key(self) -> LockKey:
return LockKey(self.manager, self.name, self.platform)

@validator("hash")
def validate_hash(cls, v: HashModel, values: Dict[str, typing.Any]) -> HashModel:
if (values["manager"] == "conda") and (v.md5 is None):
@field_validator("hash")
@classmethod
def validate_hash(cls, v: HashModel, info: ValidationInfo) -> HashModel:
if (info.data["manager"] == "conda") and (v.md5 is None):
raise ValueError("conda package hashes must use MD5")
return v

Expand Down Expand Up @@ -217,7 +217,7 @@ class LockMeta(StrictModel):
..., description="Hash of dependencies for each target platform"
)
channels: List[Channel] = Field(
..., description="Channels used to resolve dependencies"
..., description="Channels used to resolve dependencies", validate_default=True
)
platforms: List[str] = Field(..., description="Target platforms")
sources: List[str] = Field(
Expand Down Expand Up @@ -282,7 +282,8 @@ def __or__(self, other: "LockMeta") -> "LockMeta":
custom_metadata=new_custom_metadata,
)

@validator("channels", pre=True, always=True)
@field_validator("channels", mode="before")
@classmethod
def ensure_channels(cls, v: List[Union[str, Channel]]) -> List[Channel]:
res: List[Channel] = []
for e in v:
Expand All @@ -304,10 +305,12 @@ def dict_for_output(self) -> Dict[str, Any]:
return {
"version": Lockfile.version,
"metadata": json.loads(
self.metadata.json(by_alias=True, exclude_unset=True, exclude_none=True)
self.metadata.model_dump_json(
by_alias=True, exclude_unset=True, exclude_none=True
)
),
"package": [
package.dict(by_alias=True, exclude_unset=True, exclude_none=True)
package.model_dump(by_alias=True, exclude_unset=True, exclude_none=True)
for package in self.package
],
}
12 changes: 6 additions & 6 deletions conda_lock/models/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ class CondaUrl(BaseModel):
raw_url: str
env_var_url: str

token: Optional[str]
token_env_var: Optional[str]
token: Optional[str] = None
token_env_var: Optional[str] = None

user: Optional[str]
user_env_var: Optional[str]
user: Optional[str] = None
user_env_var: Optional[str] = None

password: Optional[str]
password_env_var: Optional[str]
password: Optional[str] = None
password_env_var: Optional[str] = None

@classmethod
def from_string(cls, value: str) -> "CondaUrl":
Expand Down
25 changes: 15 additions & 10 deletions conda_lock/models/lock_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Literal

from conda_lock.models import StrictModel
Expand All @@ -21,7 +21,8 @@ class _BaseDependency(StrictModel):
extras: List[str] = []
markers: Optional[str] = None

@validator("extras")
@field_validator("extras")
@classmethod
def sorted_extras(cls, v: List[str]) -> List[str]:
return sorted(v)

Expand Down Expand Up @@ -53,11 +54,11 @@ class Package(StrictModel):


class PoetryMappedDependencySpec(StrictModel):
url: Optional[str]
url: Optional[str] = None
manager: Literal["conda", "pip"]
extras: List
markers: Optional[str]
poetry_version_spec: Optional[str]
markers: Optional[str] = None
poetry_version_spec: Optional[str] = None


class LockSpecification(BaseModel):
Expand All @@ -84,16 +85,18 @@ def content_hash_for_platform(
self, platform: str, virtual_package_repo: Optional[FakeRepoData]
) -> str:
data = {
"channels": [c.json() for c in self.channels],
"channels": [c.model_dump_json() for c in self.channels],
"specs": [
p.dict()
p.model_dump()
for p in sorted(
self.dependencies[platform], key=lambda p: (p.manager, p.name)
)
],
}
if self.pip_repositories:
data["pip_repositories"] = [repo.json() for repo in self.pip_repositories]
data["pip_repositories"] = [
repo.model_dump_json() for repo in self.pip_repositories
]
if virtual_package_repo is not None:
vpr_data = virtual_package_repo.all_repodata
data["virtual_package_hash"] = {
Expand All @@ -104,7 +107,8 @@ def content_hash_for_platform(
env_spec = json.dumps(data, sort_keys=True)
return hashlib.sha256(env_spec.encode("utf-8")).hexdigest()

@validator("channels", pre=True)
@field_validator("channels", mode="before")
@classmethod
def validate_channels(cls, v: List[Union[Channel, str]]) -> List[Channel]:
for i, e in enumerate(v):
if isinstance(e, str):
Expand All @@ -114,7 +118,8 @@ def validate_channels(cls, v: List[Union[Channel, str]]) -> List[Channel]:
raise ValueError("nodefaults channel is not allowed, ref #418")
return typing.cast(List[Channel], v)

@validator("pip_repositories", pre=True)
@field_validator("pip_repositories", mode="before")
@classmethod
def validate_pip_repositories(
cls, value: List[Union[PipRepository, str]]
) -> List[PipRepository]:
Expand Down
10 changes: 5 additions & 5 deletions conda_lock/virtual_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from types import TracebackType
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from conda_lock.interfaces.vendored_conda import MatchSpec
from conda_lock.models.channel import Channel
Expand All @@ -23,8 +23,7 @@
class FakePackage(BaseModel):
"""A minimal representation of the required metadata for a conda package"""

class Config:
frozen = True
model_config = ConfigDict(frozen=True)

name: str
version: str = "1.0"
Expand All @@ -36,7 +35,7 @@ class Config:
package_type: Optional[str] = "virtual_system"

def to_repodata_entry(self) -> Tuple[str, Dict[str, Any]]:
out = self.dict()
out = self.model_dump()
if self.build_string:
build = f"{self.build_string}_{self.build_number}"
else:
Expand Down Expand Up @@ -236,7 +235,8 @@ def default_virtual_package_repodata(cuda_version: str = "11.4") -> FakeRepoData
class VirtualPackageSpecSubdir(BaseModel):
packages: Dict[str, str]

@validator("packages")
@field_validator("packages")
@classmethod
def validate_packages(cls, v: Dict[str, str]) -> Dict[str, str]:
for package_name in v:
if not package_name.startswith("__"):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies = [
"ensureconda >=1.4.4",
"gitpython >=3.1.30",
"jinja2",
"pydantic >=1.10",
"pydantic >=2",
"pyyaml >= 5.1",
# constraint on version comes from poetry
"requests >=2.26,<3.0",
Expand Down
4 changes: 3 additions & 1 deletion tests/test_conda_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1725,7 +1725,9 @@ def test_aggregate_lock_specs():
],
sources=[],
)
assert actual.dict(exclude={"sources"}) == expected.dict(exclude={"sources"})
assert actual.model_dump(exclude={"sources"}) == expected.model_dump(
exclude={"sources"}
)
assert actual.content_hash(None) == expected.content_hash(None)


Expand Down

0 comments on commit 0837772

Please sign in to comment.