Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Csse pyd2 model convert classes #352

Merged
merged 5 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ New Features

Enhancements
++++++++++++
* ``v2.FailedOperation`` field `id` is becoming `Optional[str]` instead of plain `str` so that the default validates.
* v1.ProtoModel learned `model_copy`, `model_dump`, `model_dump_json` methods (all w/o warnings) so downstream can unify on newer syntax. Levi's work alternately/additionally taught v2 `copy`, `dict`, `json` (all w/warning) but dict has an alternate use in Pydantic v2.
* ``AtomicInput`` and ``AtomicResult`` ``OptimizationInput``, ``OptimizationResult``, ``TorsionDriveInput``, ``TorsionDriveResult``, ``FailedOperation`` (both versions) learned a ``.convert_v(ver)`` function that returns self or the other version.
* The ``models.v2`` ``AtomicInput``, ``AtomicResult``, ``AtomicResultProperties`` ``OptimizationInput``, ``OptimizationResult``, ``TorsionDriveInput``, ``TorsionDriveResult`` had their `schema_version` changed to a Literal[2] and validated so new instances will be 2, even if another value passed in.
* The ``models.v1`` ``AtomicInput``, ``AtomicResult``, ``OptimizationInput``, ``OptimizationResult``, ``TorsionDriveInput``, ``TorsionDriveResult`` had their `schema_version` changed to a Literal[1] and validated so new instances will be 1, even if another value passed in.
* The ``models.v1`` and ``models.v2`` ``OptimizationResult`` given schema_version for the first time.
* The ``models.v2`` have had their `schema_version` bumped for ``BasisSet``, ``AtomicInput``, ``OptimizationInput`` (implicit for ``AtomicResult`` and ``OptimizationResult``), ``TorsionDriveInput`` , and ``TorsionDriveResult``.
* The ``models.v2`` ``AtomicResultProperties`` has been given a ``schema_name`` and ``schema_version`` (2) for the first time.
* Note that ``models.v2`` ``QCInputSpecification`` and ``OptimizationSpecification`` have *not* had schema_version bumped.
Expand Down
2 changes: 1 addition & 1 deletion qcelemental/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def dict(self, *args, **kwargs):

def json(self, *args, **kwargs):
"""
Passthrough to model_dump_sjon without deprecation warning
Passthrough to model_dump_json without deprecation warning
exclude_unset is forced through the model_serializer
"""
return super().model_dump_json(*args, **kwargs)
Expand Down
5 changes: 3 additions & 2 deletions qcelemental/models/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import qcelemental

from .common_models import _qcsk_v2_default_v1_importpathschange

_nonapi_file = "align"
_shim_classes_removed_version = "0.40.0"

warn(
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_shim_classes_removed_version}",
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_qcsk_v2_default_v1_importpathschange}.",
DeprecationWarning,
)

Expand Down
5 changes: 3 additions & 2 deletions qcelemental/models/basemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import qcelemental

from .common_models import _qcsk_v2_default_v1_importpathschange

_nonapi_file = "basemodels"
_shim_classes_removed_version = "0.40.0"

warn(
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_shim_classes_removed_version}",
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_qcsk_v2_default_v1_importpathschange}.",
DeprecationWarning,
)

Expand Down
5 changes: 3 additions & 2 deletions qcelemental/models/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import qcelemental

from .common_models import _qcsk_v2_default_v1_importpathschange

_nonapi_file = "basis"
_shim_classes_removed_version = "0.40.0"

warn(
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_shim_classes_removed_version}",
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_qcsk_v2_default_v1_importpathschange}.",
DeprecationWarning,
)

Expand Down
7 changes: 5 additions & 2 deletions qcelemental/models/common_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import qcelemental

_qcsk_v2_available_v1_nochange = "0.50.0"
_qcsk_v2_default_v1_importpathschange = "0.70.0"
_qcsk_v2_nochange_v1_dropped = "1.0.0"

_nonapi_file = "common_models"
_shim_classes_removed_version = "0.40.0"

warn(
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_shim_classes_removed_version}",
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_qcsk_v2_default_v1_importpathschange}.",
DeprecationWarning,
)

Expand Down
4 changes: 2 additions & 2 deletions qcelemental/models/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import qcelemental

_nonapi_file = "molecule"
_shim_classes_removed_version = "0.40.0"
from .common_models import _qcsk_v2_default_v1_importpathschange

warn(
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_shim_classes_removed_version}",
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_qcsk_v2_default_v1_importpathschange}.",
DeprecationWarning,
)

Expand Down
4 changes: 2 additions & 2 deletions qcelemental/models/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import qcelemental

_nonapi_file = "procedures"
_shim_classes_removed_version = "0.40.0"
from .common_models import _qcsk_v2_default_v1_importpathschange

warn(
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_shim_classes_removed_version}",
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_qcsk_v2_default_v1_importpathschange}.",
DeprecationWarning,
)

Expand Down
5 changes: 3 additions & 2 deletions qcelemental/models/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import qcelemental

from .common_models import _qcsk_v2_default_v1_importpathschange

_nonapi_file = "results"
_shim_classes_removed_version = "0.40.0"

warn(
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_shim_classes_removed_version}",
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_qcsk_v2_default_v1_importpathschange}.",
DeprecationWarning,
)

Expand Down
4 changes: 2 additions & 2 deletions qcelemental/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import qcelemental

_nonapi_file = "types"
_shim_classes_removed_version = "0.40.0"
from .common_models import _qcsk_v2_default_v1_importpathschange

warn(
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_shim_classes_removed_version}",
f"qcelemental.models.{_nonapi_file} should be accessed through qcelemental.models (or qcelemental.models.v1 or .v2 for fixed QCSchema version). The 'models.{_nonapi_file}' route will be removed as soon as v{_qcsk_v2_default_v1_importpathschange}.",
DeprecationWarning,
)

Expand Down
12 changes: 12 additions & 0 deletions qcelemental/models/v1/basemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,18 @@ def dict(self, **kwargs) -> Dict[str, Any]:
else:
raise KeyError(f"Unknown encoding type '{encoding}', valid encoding types: 'json'.")

def model_dump(self, **kwargs):
# forwarding pydantic v2 API function to pydantic v1 API so downstream can unify on new syntax
return self.dict(**kwargs)

def model_dump_json(self, **kwargs):
# forwarding pydantic v2 API function to pydantic v1 API so downstream can unify on new syntax
return self.json(**kwargs)

def model_copy(self, **kwargs):
# forwarding pydantic v2 API function to pydantic v1 API so downstream can unify on new syntax
return self.copy(**kwargs)

def serialize(
self,
encoding: str,
Expand Down
24 changes: 24 additions & 0 deletions qcelemental/models/v1/common_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,30 @@ class FailedOperation(ProtoModel):
def __repr_args__(self) -> "ReprArgs":
return [("error", self.error)]

def convert_v(
self, version: int
) -> Union["qcelemental.models.v1.FailedOperation", "qcelemental.models.v2.FailedOperation"]:
"""Convert to instance of particular QCSchema version."""
import qcelemental as qcel

if check_convertible_version(version, error="FailedOperation") == "self":
return self

dself = self.dict()
if version == 2:
self_vN = qcel.models.v2.FailedOperation(**dself)

return self_vN


def check_convertible_version(ver: int, error: str):
if ver == 1:
return "self"
elif ver == 2:
return True
else:
raise ValueError(f"QCSchema {error} version={version} does not exist for conversion.")


qcschema_input_default = "qcschema_input"
qcschema_output_default = "qcschema_output"
Expand Down
96 changes: 90 additions & 6 deletions qcelemental/models/v1/procedures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

try:
from typing import Literal
except ImportError:
# remove when minimum py38
from typing_extensions import Literal

from pydantic.v1 import Field, conlist, constr, validator

Expand All @@ -10,6 +16,7 @@
DriverEnum,
Model,
Provenance,
check_convertible_version,
qcschema_input_default,
qcschema_optimization_input_default,
qcschema_optimization_output_default,
Expand Down Expand Up @@ -53,7 +60,7 @@ class QCInputSpecification(ProtoModel):
"""

schema_name: constr(strip_whitespace=True, regex=qcschema_input_default) = qcschema_input_default # type: ignore
schema_version: int = 1
schema_version: int = 1 # TODO

driver: DriverEnum = Field(DriverEnum.gradient, description=str(DriverEnum.__doc__))
model: Model = Field(..., description=str(Model.__doc__))
Expand All @@ -71,7 +78,7 @@ class OptimizationInput(ProtoModel):
schema_name: constr( # type: ignore
strip_whitespace=True, regex=qcschema_optimization_input_default
) = qcschema_optimization_input_default
schema_version: int = 1
schema_version: Literal[1] = 1

keywords: Dict[str, Any] = Field({}, description="The optimization specific keywords to be used.")
extras: Dict[str, Any] = Field({}, description="Extra fields that are not part of the schema.")
Expand All @@ -88,11 +95,31 @@ def __repr_args__(self) -> "ReprArgs":
("molecule_hash", self.initial_molecule.get_hash()[:7]),
]

@validator("schema_version", pre=True)
def _version_stamp(cls, v):
return 1

def convert_v(
self, version: int
) -> Union["qcelemental.models.v1.OptimizationInput", "qcelemental.models.v2.OptimizationInput"]:
"""Convert to instance of particular QCSchema version."""
import qcelemental as qcel

if check_convertible_version(version, error="OptimizationInput") == "self":
return self

dself = self.dict()
if version == 2:
self_vN = qcel.models.v2.OptimizationInput(**dself)

return self_vN


class OptimizationResult(OptimizationInput):
schema_name: constr( # type: ignore
strip_whitespace=True, regex=qcschema_optimization_output_default
) = qcschema_optimization_output_default
schema_version: Literal[1] = 1

final_molecule: Optional[Molecule] = Field(..., description="The final molecule of the geometry optimization.")
trajectory: List[AtomicResult] = Field(
Expand Down Expand Up @@ -131,6 +158,25 @@ def _trajectory_protocol(cls, v, values):

return v

@validator("schema_version", pre=True)
def _version_stamp(cls, v):
return 1

def convert_v(
self, version: int
) -> Union["qcelemental.models.v1.OptimizationResult", "qcelemental.models.v2.OptimizationResult"]:
"""Convert to instance of particular QCSchema version."""
import qcelemental as qcel

if check_convertible_version(version, error="OptimizationResult") == "self":
return self

dself = self.dict()
if version == 2:
self_vN = qcel.models.v2.OptimizationResult(**dself)

return self_vN


class OptimizationSpecification(ProtoModel):
"""
Expand All @@ -143,7 +189,7 @@ class OptimizationSpecification(ProtoModel):
"""

schema_name: constr(strip_whitespace=True, regex="qcschema_optimization_specification") = "qcschema_optimization_specification" # type: ignore
schema_version: int = 1
schema_version: int = 1 # TODO

procedure: str = Field(..., description="Optimization procedure to run the optimization with.")
keywords: Dict[str, Any] = Field({}, description="The optimization specific keywords to be used.")
Expand Down Expand Up @@ -200,7 +246,7 @@ class TorsionDriveInput(ProtoModel):
"""

schema_name: constr(strip_whitespace=True, regex=qcschema_torsion_drive_input_default) = qcschema_torsion_drive_input_default # type: ignore
schema_version: int = 1
schema_version: Literal[1] = 1

keywords: TDKeywords = Field(..., description="The torsion drive specific keywords to be used.")
extras: Dict[str, Any] = Field({}, description="Extra fields that are not part of the schema.")
Expand All @@ -221,6 +267,25 @@ def _check_input_specification(cls, value):
assert value.driver == DriverEnum.gradient, "driver must be set to gradient"
return value

@validator("schema_version", pre=True)
def _version_stamp(cls, v):
return 1

def convert_v(
self, version: int
) -> Union["qcelemental.models.v1.TorsionDriveInput", "qcelemental.models.v2.TorsionDriveInput"]:
"""Convert to instance of particular QCSchema version."""
import qcelemental as qcel

if check_convertible_version(version, error="TorsionDriveInput") == "self":
return self

dself = self.dict()
if version == 2:
self_vN = qcel.models.v2.TorsionDriveInput(**dself)

return self_vN


class TorsionDriveResult(TorsionDriveInput):
"""Results from running a torsion drive.
Expand All @@ -231,7 +296,7 @@ class TorsionDriveResult(TorsionDriveInput):
"""

schema_name: constr(strip_whitespace=True, regex=qcschema_torsion_drive_output_default) = qcschema_torsion_drive_output_default # type: ignore
schema_version: int = 1
schema_version: Literal[1] = 1

final_energies: Dict[str, float] = Field(
..., description="The final energy at each angle of the TorsionDrive scan."
Expand All @@ -254,6 +319,25 @@ class TorsionDriveResult(TorsionDriveInput):
error: Optional[ComputeError] = Field(None, description=str(ComputeError.__doc__))
provenance: Provenance = Field(..., description=str(Provenance.__doc__))

@validator("schema_version", pre=True)
def _version_stamp(cls, v):
return 1

def convert_v(
self, version: int
) -> Union["qcelemental.models.v1.TorsionDriveResult", "qcelemental.models.v2.TorsionDriveResult"]:
"""Convert to instance of particular QCSchema version."""
import qcelemental as qcel

if check_convertible_version(version, error="TorsionDriveResult") == "self":
return self

dself = self.dict()
if version == 2:
self_vN = qcel.models.v2.TorsionDriveResult(**dself)

return self_vN


def Optimization(*args, **kwargs):
"""QC Optimization Results Schema.
Expand Down
Loading
Loading