Refactor test suite for data models (#327)
* refactor tests to account for invalids

* split original test_domain.py

* remove original test_domain.py

* tidy up test_constraints.py

* tidy-up of test_features

* tidy-up test_priors.py

* tidy-up test_kernels.py

* tidy-up test_molecular and test_molfeatures

* move test_samplers.py

* split filters from base

* remove duplicate test_util.py

* move test_domain_validators.py

* rename and move test_nchoosek_combinatorics.py

* update test pipeline
jduerholt authored Dec 21, 2023
1 parent 7889ad9 commit 38d0bd1
Showing 57 changed files with 3,676 additions and 3,673 deletions.
13 changes: 1 addition & 12 deletions .github/workflows/test.yaml
@@ -21,18 +21,7 @@ jobs:
      - name: Minimal Bofire, Python ${{ matrix.python-version }}
        run: pip install "." pytest
      - name: Run domain-only test, Python ${{ matrix.python-version }}
        run: |
          pytest tests/bofire/data_models/test_domain.py \
            tests/bofire/data_models/test_base.py \
            tests/bofire/data_models/test_constraint_fulfillment.py \
            tests/bofire/data_models/test_constraints.py \
            tests/bofire/data_models/test_domain_validators.py \
            tests/bofire/data_models/test_features.py \
            tests/bofire/data_models/test_filters.py \
            tests/bofire/data_models/test_nchoosek_combinatorics.py \
            tests/bofire/data_models/test_numeric.py \
            tests/bofire/data_models/test_unions.py \
            tests/bofire/data_models/test_util.py
        run: pytest tests/bofire/data_models

  testing:
    runs-on: ubuntu-latest
2 changes: 1 addition & 1 deletion bofire/data_models/api.py
@@ -4,7 +4,7 @@
    AnyAcquisitionFunction,
)
from bofire.data_models.constraints.api import AnyConstraint, Constraint
from bofire.data_models.domain.api import Domain, Features, Inputs, Outputs
from bofire.data_models.domain.api import Constraints, Domain, Features, Inputs, Outputs
from bofire.data_models.features.api import (
    AnyFeature,
    AnyInput,
104 changes: 0 additions & 104 deletions bofire/data_models/base.py
@@ -1,6 +1,3 @@
import collections.abc as collections
from typing import Any, Callable, List, Sequence, Type, Union, get_args, get_origin

import pandas as pd
from pydantic import BaseModel as PydanticBaseModel
from pydantic import Extra
@@ -17,104 +14,3 @@ class Config:
            pd.DataFrame: lambda x: x.to_dict(orient="list"),
            pd.Series: lambda x: x.to_list(),
        }


def filter_by_attribute(
    data: Sequence,
    attribute_getter: Callable[[Type], Any],
    includes: Union[Type, Sequence[Type]] = None,
    excludes: Union[Type, Sequence[Type]] = None,
    exact: bool = False,
) -> List:
    """Returns those data elements where the attribute is of one of the include types.
    Args:
        data: to be filtered
        attribute_getter: expects an item of the data list and returns the attribute to filter by
        includes: attribute types that should be kept, sub-type are included by default, see exact
        excludes: attribute types that will be excluded even if they are sub-types of or include types.
        exact: true for not including subtypes
    Returns:
        list of data point with attributes as filtered for
    """
    data_with_attr = []
    for d in data:
        try:
            attribute_getter(d)
            data_with_attr.append(d)
        except AttributeError:
            pass

    filtered = filter_by_class(
        data_with_attr,
        includes=includes,
        excludes=excludes,
        exact=exact,
        key=attribute_getter,
    )
    return filtered


def filter_by_class(
    data: Sequence,
    includes: Union[Type, Sequence[Type]] = None,
    excludes: Union[Type, Sequence[Type]] = None,
    exact: bool = False,
    key: Callable[[Type], Any] = lambda x: x,
) -> List:
    """Returns those data elements where are one of the include types.
    Args:
        data: to be filtered
        includes: attribute types that should be kept, sub-type are included by default, see exact
        excludes: attribute types that will be excluded even if they are sub-types of or include types.
        exact: true for not including subtypes
        key: maps a data list item to something that is used for filtering, identity by default
    Returns:
        filtered list of data points
    """
    if includes is None:
        includes = []
    if not isinstance(includes, collections.Sequence):
        includes = [includes]
    if excludes is None:
        excludes = []
    if not isinstance(excludes, collections.Sequence):
        excludes = [excludes]

    if len(includes) == len(excludes) == 0:
        raise ValueError("no filter provided")

    if len(includes) == 0:
        includes = [object]

    includes_ = []
    for incl in includes:
        if get_origin(incl) is Union:
            includes_ += get_args(incl)
        else:
            includes_.append(incl)
    includes = includes_
    excludes_ = []
    for excl in excludes:
        if get_origin(excl) is Union:
            excludes_ += get_args(excl)
        else:
            excludes_.append(excl)
    excludes = excludes_

    if len([x for x in includes if x in excludes]) > 0:
        raise ValueError("includes and excludes overlap")

    if exact:
        return [
            d for d in data if type(key(d)) in includes and type(key(d)) not in excludes
        ]
    return [
        d
        for d in data
        if isinstance(key(d), tuple(includes))  # type: ignore
        and not isinstance(key(d), tuple(excludes))  # type: ignore
    ]
7 changes: 4 additions & 3 deletions bofire/data_models/domain/constraints.py
@@ -1,12 +1,13 @@
import collections.abc
from itertools import chain
from typing import List, Literal, Sequence, Type, Union
from typing import List, Literal, Optional, Sequence, Type, Union

import pandas as pd
from pydantic import Field

from bofire.data_models.base import BaseModel, filter_by_class
from bofire.data_models.base import BaseModel
from bofire.data_models.constraints.api import AnyConstraint, Constraint
from bofire.data_models.filters import filter_by_class


class Constraints(BaseModel):
@@ -78,7 +79,7 @@ def is_fulfilled(self, experiments: pd.DataFrame, tol: float = 1e-6) -> pd.Serie
    def get(
        self,
        includes: Union[Type, List[Type]] = Constraint,
        excludes: Union[Type, List[Type]] = None,
        excludes: Optional[Union[Type, List[Type]]] = None,
        exact: bool = False,
    ) -> "Constraints":
        """get constraints of the domain
4 changes: 2 additions & 2 deletions bofire/data_models/domain/domain.py
@@ -240,12 +240,12 @@ def get_features(
        """
        assert isinstance(self.inputs, Inputs)
        features = self.inputs + self.outputs
        return features.get(includes, excludes, exact)
        return features.get(includes, excludes, exact)  # type: ignore

    def get_feature_keys(
        self,
        includes: Union[Type, List[Type]] = Feature,
        excludes: Union[Type, List[Type]] = None,
        excludes: Optional[Union[Type, List[Type]]] = None,
        exact: bool = False,
    ) -> List[str]:
        """Method to get feature keys of the domain
3 changes: 2 additions & 1 deletion bofire/data_models/domain/features.py
@@ -10,7 +10,7 @@
from pydantic import Field, validate_arguments
from scipy.stats.qmc import LatinHypercube, Sobol

from bofire.data_models.base import BaseModel, filter_by_attribute, filter_by_class
from bofire.data_models.base import BaseModel
from bofire.data_models.enum import CategoricalEncodingEnum, SamplingMethodEnum
from bofire.data_models.features.api import (
    _CAT_SEP,
@@ -28,6 +28,7 @@
    Output,
    TInputTransformSpecs,
)
from bofire.data_models.filters import filter_by_attribute, filter_by_class
from bofire.data_models.molfeatures.api import MolFeatures
from bofire.data_models.objectives.api import AbstractObjective, Objective

113 changes: 113 additions & 0 deletions bofire/data_models/filters.py
@@ -0,0 +1,113 @@
import collections.abc as collections
from typing import (
    Any,
    Callable,
    List,
    Optional,
    Sequence,
    Type,
    Union,
    get_args,
    get_origin,
)


def filter_by_attribute(
    data: Sequence,
    attribute_getter: Callable[[Type], Any],
    includes: Optional[Union[Type, Sequence[Type]]] = None,
    excludes: Optional[Union[Type, Sequence[Type]]] = None,
    exact: bool = False,
) -> List:
    """Returns those data elements where the attribute is of one of the include types.
    Args:
        data: to be filtered
        attribute_getter: expects an item of the data list and returns the attribute to filter by
        includes: attribute types that should be kept, sub-type are included by default, see exact
        excludes: attribute types that will be excluded even if they are sub-types of or include types.
        exact: true for not including subtypes
    Returns:
        list of data point with attributes as filtered for
    """
    data_with_attr = []
    for d in data:
        try:
            attribute_getter(d)
            data_with_attr.append(d)
        except AttributeError:
            pass

    filtered = filter_by_class(
        data_with_attr,
        includes=includes,
        excludes=excludes,
        exact=exact,
        key=attribute_getter,
    )
    return filtered


def filter_by_class(
    data: Sequence,
    includes: Optional[Union[Type, Sequence[Type]]] = None,
    excludes: Optional[Union[Type, Sequence[Type]]] = None,
    exact: bool = False,
    key: Callable[[Type], Any] = lambda x: x,
) -> List:
    """Returns those data elements where are one of the include types.
    Args:
        data: to be filtered
        includes: attribute types that should be kept, sub-type are included by default, see exact
        excludes: attribute types that will be excluded even if they are sub-types of or include types.
        exact: true for not including subtypes
        key: maps a data list item to something that is used for filtering, identity by default
    Returns:
        filtered list of data points
    """
    if includes is None:
        includes = []
    if not isinstance(includes, collections.Sequence):
        includes = [includes]
    if excludes is None:
        excludes = []
    if not isinstance(excludes, collections.Sequence):
        excludes = [excludes]

    if len(includes) == len(excludes) == 0:
        raise ValueError("no filter provided")

    if len(includes) == 0:
        includes = [object]

    includes_ = []
    for incl in includes:
        if get_origin(incl) is Union:
            includes_ += get_args(incl)
        else:
            includes_.append(incl)
    includes = includes_
    excludes_ = []
    for excl in excludes:
        if get_origin(excl) is Union:
            excludes_ += get_args(excl)
        else:
            excludes_.append(excl)
    excludes = excludes_

    if len([x for x in includes if x in excludes]) > 0:
        raise ValueError("includes and excludes overlap")

    if exact:
        return [
            d for d in data if type(key(d)) in includes and type(key(d)) not in excludes
        ]
    return [
        d
        for d in data
        if isinstance(key(d), tuple(includes))  # type: ignore
        and not isinstance(key(d), tuple(excludes))  # type: ignore
    ]
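
The snippet below is an editor-added usage sketch of the relocated helpers and is not part of the commit; it assumes a BoFire checkout that contains the new bofire.data_models.filters module, and the Animal/Dog/Cat demo classes are invented for illustration.

from bofire.data_models.filters import filter_by_attribute, filter_by_class

class Animal:            # demo classes, not part of BoFire
    pass

class Dog(Animal):
    sound = "woof"       # only Dog carries a `sound` attribute

class Cat(Animal):
    pass

data = [Animal(), Dog(), Cat()]

# Subtypes of an include are kept by default ...
assert len(filter_by_class(data, includes=Animal)) == 3
# ... unless exact=True restricts matches to the exact type ...
assert len(filter_by_class(data, includes=Animal, exact=True)) == 1
# ... and excludes drops items even when they are subtypes of an include.
assert len(filter_by_class(data, includes=Animal, excludes=Cat)) == 2

# filter_by_attribute first drops items whose getter raises AttributeError,
# then filters by the type of the returned attribute (str for Dog.sound).
assert len(filter_by_attribute(data, lambda d: d.sound, includes=str)) == 1
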
8 changes: 8 additions & 0 deletions bofire/data_models/kernels/continuous.py
@@ -1,5 +1,7 @@
from typing import Literal, Optional

from pydantic import validator

from bofire.data_models.kernels.kernel import Kernel
from bofire.data_models.priors.api import AnyPrior

@@ -20,6 +22,12 @@ class MaternKernel(ContinuousKernel):
    nu: float = 2.5
    lengthscale_prior: Optional[AnyPrior] = None

    @validator("nu")
    def validate_nu(cls, v, values):
        if v not in {0.5, 1.5, 2.5}:
            raise ValueError("nu expected to be 0.5, 1.5, or 2.5")
        return v


class LinearKernel(ContinuousKernel):
    type: Literal["LinearKernel"] = "LinearKernel"
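
As a quick, editor-added illustration of the new check (not part of the diff, assuming BoFire at this commit is installed): constructing a MaternKernel with an unsupported smoothness value should now fail at validation time.

from pydantic import ValidationError

from bofire.data_models.kernels.continuous import MaternKernel

MaternKernel(nu=1.5)  # accepted: nu is one of 0.5, 1.5, 2.5

try:
    MaternKernel(nu=3.0)  # rejected by the new validator
except ValidationError as err:
    print(err)  # "nu expected to be 0.5, 1.5, or 2.5"
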
3 changes: 1 addition & 2 deletions bofire/data_models/molfeatures/api.py
@@ -1,6 +1,6 @@
from typing import Union

from bofire.data_models.molfeatures.molfeatures import (  # BagOfCharacters
from bofire.data_models.molfeatures.molfeatures import (
    Fingerprints,
    FingerprintsFragments,
    Fragments,
@@ -14,6 +14,5 @@
    Fingerprints,
    Fragments,
    FingerprintsFragments,
    # BagOfCharacters,
    MordredDescriptors,
]
15 changes: 15 additions & 0 deletions bofire/kernels/mapper.py
@@ -2,6 +2,7 @@

import gpytorch
import torch
from botorch.models.kernels.categorical import CategoricalKernel
from gpytorch.kernels import Kernel as GpytorchKernel

import bofire.data_models.kernels.api as data_models
@@ -143,6 +144,19 @@ def map_TanimotoKernel(
    )


def map_HammondDistanceKernel(
    data_model: data_models.TanimotoKernel,
    batch_shape: torch.Size,
    ard_num_dims: int,
    active_dims: List[int],
) -> CategoricalKernel:
    return CategoricalKernel(
        batch_shape=batch_shape,
        ard_num_dims=len(active_dims) if data_model.ard else None,
        active_dims=active_dims,  # type: ignore
    )


KERNEL_MAP = {
    data_models.RBFKernel: map_RBFKernel,
    data_models.MaternKernel: map_MaternKernel,
@@ -152,6 +166,7 @@ def map_TanimotoKernel(
    data_models.MultiplicativeKernel: map_MultiplicativeKernel,
    data_models.ScaleKernel: map_ScaleKernel,
    data_models.TanimotoKernel: map_TanimotoKernel,
    data_models.HammondDistanceKernel: map_HammondDistanceKernel,
}
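
Finally, an editor-added sketch of how the newly registered mapper could be exercised; it is not part of the commit, and the HammondDistanceKernel data model together with its `ard` flag is assumed from the way KERNEL_MAP and the mapper body use it.

import torch

import bofire.data_models.kernels.api as data_models
from bofire.kernels.mapper import KERNEL_MAP

data_model = data_models.HammondDistanceKernel(ard=True)  # assumed constructor

# Dispatch on the data model type, the same way a generic mapping helper would.
map_fn = KERNEL_MAP[type(data_model)]
kernel = map_fn(
    data_model,
    batch_shape=torch.Size(),
    ard_num_dims=3,
    active_dims=[0, 1, 2],
)
print(type(kernel).__name__)  # CategoricalKernel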