Skip to content

Commit

Permalink
fix: do not allow SequenceLocation to have negative start/end values (#443)
Browse files Browse the repository at this point in the history

close #442

* Add field validators to `Range` and `SequenceLocation`
  • Loading branch information
korikuzma authored Sep 17, 2024
1 parent 10157c3 commit 918d7a2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
45 changes: 42 additions & 3 deletions src/ga4gh/vrs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ga4gh.core.pydantic import get_pydantic_root

from canonicaljson import encode_canonical_json
from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict
from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict, ValidationInfo, field_validator

from ga4gh.core.pydantic import (
getattr_in
Expand Down Expand Up @@ -331,6 +331,25 @@ class Range(RootModel):
min_length=2,
)

@field_validator("root", mode="after")
@classmethod
def validate_range(cls, v: List[Optional[int]]) -> List[Optional[int]]:
    """Validate the bounds held in ``root``.

    Runs after pydantic's own validation of ``root``. The explicit
    ``@classmethod`` follows pydantic's recommendation for field validators
    so that type checkers treat ``cls`` correctly.

    :param v: Root value; the first two items are read as the lower and upper
        bound, and either may be ``None``
    :raises ValueError: If ``root`` does not include at least one integer or if
        the first element in ``root`` is greater than the second element in ``root``
    :return: The validated root value, unchanged
    """
    if v.count(None) == 2:
        err_msg = "Must provide at least one integer."
        raise ValueError(err_msg)

    lower, upper = v[0], v[1]
    # Ordering can only be checked when both bounds are concrete integers.
    if lower is not None and upper is not None and lower > upper:
        err_msg = "The first integer must be less than or equal to the second integer."
        raise ValueError(err_msg)

    return v

class Residue(RootModel):
"""A character representing a specific residue (i.e., molecular species) or
Expand Down Expand Up @@ -454,15 +473,35 @@ class SequenceLocation(_Ga4ghIdentifiableObject):
)
start: Optional[Union[Range, int]] = Field(
None,
description='The start coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range less than the value of `end`.',
description='The start coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0.',
)
end: Optional[Union[Range, int]] = Field(
None,
description='The end coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range greater than the value of `start`.',
description='The end coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0.',

)
sequence: Optional[SequenceString] = Field(None, description="The literal sequence encoded by the `sequenceReference` at these coordinates.")

@field_validator("start", "end", mode="after")
@classmethod
def validate_start_end(cls, v: Optional[Union[Range, int]], info: ValidationInfo) -> Optional[Union[Range, int]]:
    """Validate that ``start`` and ``end`` hold no negative coordinates.

    Runs once per field, after pydantic's own validation. The explicit
    ``@classmethod`` follows pydantic's recommendation for field validators
    so that type checkers treat ``cls`` correctly.

    :param v: ``start`` or ``end`` value; an integer, a ``Range``, or ``None``
    :param info: Validation info; supplies the field name used in the error message
    :raises ValueError: If ``start`` or ``end`` has a value less than 0
    :return: The validated field value, unchanged
    """
    if v is None:
        # Unset coordinates are allowed; there is nothing to check.
        return v

    if isinstance(v, int):
        int_values = [v]
    else:
        # A Range keeps its bounds in ``root``; open (None) bounds are skipped.
        int_values = [val for val in v.root if val is not None]

    if any(int_val < 0 for int_val in int_values):
        err_msg = f"The minimum value of `{info.field_name}` is 0."
        raise ValueError(err_msg)

    return v

def ga4gh_serialize_as_version(self, as_version: PrevVrsVersion):
"""This method will return a serialized string following the conventions for
SequenceLocation serialization as defined in the VRS version specified by
Expand Down
16 changes: 16 additions & 0 deletions tests/test_vrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,22 @@
cpb_431012 = models.CisPhasedBlock(**cpb_431012_dict)


@pytest.mark.parametrize(
    "vrs_model, expected_err_msg",
    [
        (
            lambda: models.Range(root=[None, None]),
            "Must provide at least one integer.",
        ),
        (
            lambda: models.Range(root=[2, 1]),
            "The first integer must be less than or equal to the second integer.",
        ),
        (
            lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, start=-1),
            "The minimum value of `start` is 0.",
        ),
        (
            lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, end=[-1, 0]),
            "The minimum value of `end` is 0.",
        ),
    ],
)
def test_model_validation_errors(vrs_model, expected_err_msg):
    """Check that building an invalid VRS model fails with the expected message.

    Each case supplies a zero-argument callable, so the model is constructed
    inside the ``pytest.raises`` context rather than at collection time.
    """
    with pytest.raises(ValueError) as exc_info:
        vrs_model()

    # Only the first reported error is compared; its ``ctx["error"]`` entry
    # is stringified before the comparison.
    first_error = exc_info.value.errors()[0]
    assert str(first_error["ctx"]["error"]) == expected_err_msg


def test_vr():
assert a.model_dump(exclude_none=True) == allele_dict
assert is_pydantic_instance(a)
Expand Down

0 comments on commit 918d7a2

Please sign in to comment.