From 918d7a2e5f5d831601d4970e7cd93bf8919215ed Mon Sep 17 00:00:00 2001 From: Kori Kuzma Date: Tue, 17 Sep 2024 03:13:00 -0400 Subject: [PATCH] =?UTF-8?q?fix:=20do=20not=20allow=20`SequenceLocation`=20?= =?UTF-8?q?to=20have=20negative=20`start`/`end`=20v=E2=80=A6=20(#443)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …alues close #442 * Add `field_validator`s to `Range` and `SequenceLocation` --- src/ga4gh/vrs/models.py | 45 ++++++++++++++++++++++++++++++++++++++--- tests/test_vrs.py | 16 +++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/src/ga4gh/vrs/models.py b/src/ga4gh/vrs/models.py index bfe6279f..70a85c80 100644 --- a/src/ga4gh/vrs/models.py +++ b/src/ga4gh/vrs/models.py @@ -27,7 +27,7 @@ from ga4gh.core.pydantic import get_pydantic_root from canonicaljson import encode_canonical_json -from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict +from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict, ValidationInfo, field_validator from ga4gh.core.pydantic import ( getattr_in @@ -331,6 +331,25 @@ class Range(RootModel): min_length=2, ) + @field_validator("root", mode="after") + def validate_range(cls, v: List[Optional[int]]) -> List[Optional[int]]: + """Validate range values + + :param v: Root value + :raises ValueError: If ``root`` does not include at least one integer or if + the first element in ``root`` is greater than the second element in ``root`` + :return: Inclusive range + """ + if v.count(None) == 2: + err_msg = "Must provide at least one integer." + raise ValueError(err_msg) + + if v[0] is not None and v[1] is not None: + if v[0] > v[1]: + err_msg = "The first integer must be less than or equal to the second integer." 
+ raise ValueError(err_msg) + + return v class Residue(RootModel): """A character representing a specific residue (i.e., molecular species) or @@ -454,15 +473,35 @@ class SequenceLocation(_Ga4ghIdentifiableObject): ) start: Optional[Union[Range, int]] = Field( None, - description='The start coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range less than the value of `end`.', + description='The start coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0.', ) end: Optional[Union[Range, int]] = Field( None, - description='The end coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0. MUST represent a coordinate or range greater than the value of `start`.', + description='The end coordinate or range of the SequenceLocation. The minimum value of this coordinate or range is 0.', ) sequence: Optional[SequenceString] = Field(None, description="The literal sequence encoded by the `sequenceReference` at these coordinates.") + @field_validator("start", "end", mode="after") + def validate_start_end(cls, v: Optional[Union[Range, int]], info: ValidationInfo) -> Optional[Union[Range, int]]: + """Validate ``start`` and ``end`` fields + + :param v: ``start`` or ``end`` value + :param info: Validation info + :raises ValueError: If ``start`` or ``end`` has a value less than 0 + :return: The validated ``start`` or ``end`` value + """ + if v is not None: + if isinstance(v, int): + int_values = [v] + else: + int_values = [val for val in v.root if val is not None] + + if any(int_val < 0 for int_val in int_values): + err_msg = f"The minimum value of `{info.field_name}` is 0." 
+ raise ValueError(err_msg) + return v + def ga4gh_serialize_as_version(self, as_version: PrevVrsVersion): """This method will return a serialized string following the conventions for SequenceLocation serialization as defined in the VRS version specified by diff --git a/tests/test_vrs.py b/tests/test_vrs.py index fc3cc90e..0ee47f16 100644 --- a/tests/test_vrs.py +++ b/tests/test_vrs.py @@ -96,6 +96,22 @@ cpb_431012 = models.CisPhasedBlock(**cpb_431012_dict) +@pytest.mark.parametrize( + "vrs_model, expected_err_msg", + [ + (lambda: models.Range(root=[None, None]), "Must provide at least one integer."), + (lambda: models.Range(root=[2, 1]), "The first integer must be less than or equal to the second integer."), + (lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, start=-1), "The minimum value of `start` is 0."), + (lambda: models.SequenceLocation(sequenceReference=allele_280320.location.sequenceReference, end=[-1, 0]), "The minimum value of `end` is 0."), + ] +) +def test_model_validation_errors(vrs_model, expected_err_msg): + """Test that invalid VRS models raise errors""" + with pytest.raises(ValueError) as e: + vrs_model() + assert str(e.value.errors()[0]["ctx"]["error"]) == expected_err_msg + + def test_vr(): assert a.model_dump(exclude_none=True) == allele_dict assert is_pydantic_instance(a)