Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: do not allow SequenceLocation to have negative start/end v… #443

Merged
merged 5 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions src/ga4gh/vrs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ga4gh.core.pydantic import get_pydantic_root

from canonicaljson import encode_canonical_json
from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict
from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict, model_validator

from ga4gh.core.pydantic import (
getattr_in
Expand Down Expand Up @@ -331,6 +331,23 @@ class Range(RootModel):
min_length=2,
)

@model_validator(mode="after")
def validate_range(self) -> "Range":
"""Validate range values

:raises ValueError: If ``root`` does not include at least one integer or if
the first element in ``root`` is greater than the second element in ``root``
"""
if self.root.count(None) == 2:
err_msg = "Must provide at least one integer."
raise ValueError(err_msg)

if self.root[0] is not None and self.root[1] is not None:
if self.root[0] > self.root[1]:
err_msg = "The first integer must be less than or equal to the second integer."
raise ValueError(err_msg)

return self

class Residue(RootModel):
"""A character representing a specific residue (i.e., molecular species) or
Expand Down Expand Up @@ -454,15 +471,55 @@ class SequenceLocation(_Ga4ghIdentifiableObject):
)
start: Optional[Union[Range, int]] = Field(
None,
description='The start coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range less than the value of `end`.',
description='The start coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range less than or equal to the value of `end`.',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like that this needs to be >= 0. There should not be a constraint on start < end for VRS 2.0, but just noted that this wasn't updated in the VRS 2.0 field definitions. Will address presently.

)
end: Optional[Union[Range, int]] = Field(
None,
description='The end coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range greater than the value of `start`.',
description='The end coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range greater than or equal to the value of `start`.',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like that this needs to be >= 0. There should not be a constraint on start < end for VRS 2.0, but just noted that this wasn't updated in the VRS 2.0 field definitions. Will address presently.


)
sequence: Optional[SequenceString] = Field(None, description="The literal sequence encoded by the `sequenceReference` at these coordinates.")

@model_validator(mode="after")
def validate_start_end(self) -> "SequenceLocation":
"""Validate ``start`` and ``end`` fields

:raises ValueError: If ``start`` or ``end`` has a value less than 0 or if
``start`` is greater than ``end``
:return: Sequence Location
"""
def _get_int_values(start_or_end: int | Range | None) -> list[int]:
"""Get list of integers from ``start`` or ``end`` fields

:param start_or_end: ``start`` or ``end`` field
:raises ValueError: If ``start_or_end`` has a value less than 0
:return: List of integer values
"""
int_values = []

if start_or_end is not None:
if isinstance(start_or_end, int):
int_values = [start_or_end]
else:
int_values = [val for val in start_or_end.root if val is not None]

if any(int_val < 0 for int_val in int_values):
err_msg = "The minimum value of `start` or `end` is 0."
raise ValueError(err_msg)

return int_values

start_values = _get_int_values(self.start)
end_values = _get_int_values(self.end)

if start_values and end_values:
for start_val in start_values:
if any(start_val > end_val for end_val in end_values):
err_msg = "`start` must be less than or equal to `end`."
raise ValueError(err_msg)

return self

def ga4gh_serialize_as_version(self, as_version: PrevVrsVersion):
"""This method will return a serialized string following the conventions for
SequenceLocation serialization as defined in the VRS version specified by
Expand Down
18 changes: 18 additions & 0 deletions tests/test_vrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,24 @@
cpb_431012 = models.CisPhasedBlock(**cpb_431012_dict)


@pytest.mark.parametrize(
"vrs_model, expected_err_msg",
[
(lambda: models.Range(root=[None, None]), "Must provide at least one integer."),
(lambda: models.Range(root=[2, 1]), "The first integer must be less than or equal to the second integer."),
(lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, start=-1), "The minimum value of `start` or `end` is 0."),
(lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, end=[-1, 0]), "The minimum value of `start` or `end` is 0."),
(lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, start=1, end=0), "`start` must be less than or equal to `end`."),
(lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, start=[3,4], end=[1,2]), "`start` must be less than or equal to `end`.")
]
)
def test_model_validation_errors(vrs_model, expected_err_msg):
"""Test that invalid VRS models raise errors"""
with pytest.raises(ValueError) as e:
vrs_model()
assert str(e.value.errors()[0]["ctx"]["error"]) == expected_err_msg


def test_vr():
assert a.model_dump(exclude_none=True) == allele_dict
assert is_pydantic_instance(a)
Expand Down