diff --git a/reciprocalspaceship/utils/rfree.py b/reciprocalspaceship/utils/rfree.py index 3516b150..9aa6e1b3 100644 --- a/reciprocalspaceship/utils/rfree.py +++ b/reciprocalspaceship/utils/rfree.py @@ -3,7 +3,7 @@ from reciprocalspaceship.dtypes import MTZIntDtype -def add_rfree(dataset, fraction=0.05, ccp4_convention=False, inplace=False): +def add_rfree(dataset, fraction=0.05, ccp4_convention=False, inplace=False, seed=None): """ Add an r-free flag to the dataset object for refinement. R-free flags are used to identify reflections which are not used in automated refinement routines. @@ -23,6 +23,10 @@ def add_rfree(dataset, fraction=0.05, ccp4_convention=False, inplace=False): 1 is test set, 0 is working set, and key is "R-free-flags". See https://www.ccp4.ac.uk/html/freerflag.html#description for convention details. inplace : bool, optional + seed : int, optional + Seed to be passed to numpy.random.default_rng random number generator + for reproducible r-free flags. If None (default), r-free flags will + different each time. Returns ------- @@ -31,7 +35,9 @@ def add_rfree(dataset, fraction=0.05, ccp4_convention=False, inplace=False): """ if not inplace: dataset = dataset.copy() - test_set = np.random.random(len(dataset)) <= fraction + + rng = np.random.default_rng(seed) + test_set = rng.random(len(dataset)) <= fraction if not ccp4_convention: rfree_key = "R-free-flags" @@ -46,7 +52,7 @@ def add_rfree(dataset, fraction=0.05, ccp4_convention=False, inplace=False): return dataset -def copy_rfree(dataset, dataset_with_rfree, inplace=False): +def copy_rfree(dataset, dataset_with_rfree, inplace=False, rfree_key=None): """ Copy the rfree flag from one dataset object to another. @@ -58,6 +64,10 @@ def copy_rfree(dataset, dataset_with_rfree, inplace=False): A dataset with desired r-free flags. inplace : bool, optional Whether to operate in place or return a copy + rfree_key : str, optional + Name of the column containing rfree flags in dataset_with_rfree. + If None, dataset_with_rfree will be checked for column "R-free-flags" + (phenix convention) then column "FreeR_flag" (ccp4 convention) Returns ------- @@ -66,8 +76,22 @@ def copy_rfree(dataset, dataset_with_rfree, inplace=False): if not inplace: dataset = dataset.copy() - dataset["R-free-flags"] = 0 - dataset["R-free-flags"] = dataset["R-free-flags"].astype(MTZIntDtype()) + if rfree_key is not None: + if rfree_key not in dataset_with_rfree.columns: + raise ValueError( + f"""Supplied dataset_with_rfree contains no column {rfree_key}""" + ) + elif "R-free-flags" in dataset_with_rfree.columns: + rfree_key = "R-free-flags" + elif "FreeR_flag" in dataset_with_rfree.columns: + rfree_key = "FreeR_flag" + else: + raise ValueError( + """Failed to automatically find r-free flags in dataset_with_rfree. Please supply an rfree_key""" + ) + + dataset[rfree_key] = 0 + dataset[rfree_key] = dataset[rfree_key].astype(MTZIntDtype()) idx = dataset.index.intersection(dataset_with_rfree.index) - dataset.loc[idx, "R-free-flags"] = dataset_with_rfree.loc[idx, "R-free-flags"] + dataset.loc[idx, rfree_key] = dataset_with_rfree.loc[idx, rfree_key] return dataset diff --git a/tests/utils/test_rfree.py b/tests/utils/test_rfree.py index df12d7f6..6bdce312 100644 --- a/tests/utils/test_rfree.py +++ b/tests/utils/test_rfree.py @@ -1,6 +1,3 @@ -import unittest -from os.path import abspath, dirname, join - import numpy as np import pytest @@ -10,13 +7,18 @@ @pytest.mark.parametrize("fraction", [0.05, 0.10, 0.15]) @pytest.mark.parametrize("ccp4_convention", [False, True]) @pytest.mark.parametrize("inplace", [False, True]) -def test_add_rfree(data_fmodel, fraction, ccp4_convention, inplace): +@pytest.mark.parametrize("seed", [None, 2022]) +def test_add_rfree(data_fmodel, fraction, ccp4_convention, inplace, seed): """ - Test rs.utils.add_rfee + Test rs.utils.add_rfree """ data_copy = data_fmodel.copy() rfree = rs.utils.add_rfree( - data_fmodel, fraction=fraction, ccp4_convention=ccp4_convention, inplace=inplace + data_fmodel, + fraction=fraction, + ccp4_convention=ccp4_convention, + inplace=inplace, + seed=seed, ) if ccp4_convention: @@ -38,34 +40,78 @@ def test_add_rfree(data_fmodel, fraction, ccp4_convention, inplace): assert np.all(data_fmodel == data_copy) assert np.all(data_fmodel == rfree.loc[:, rfree.columns != label_name]) + repeat_rfree = rs.utils.add_rfree( + data_fmodel, + fraction=fraction, + ccp4_convention=ccp4_convention, + inplace=False, + seed=seed, + ) + if seed is not None: + assert np.all(rfree == repeat_rfree) + else: + assert not np.all(rfree == repeat_rfree) + + +@pytest.mark.parametrize("ccp4_convention", [False, True]) +@pytest.mark.parametrize("inplace", [False, True]) +@pytest.mark.parametrize("rfree_key", [None, "custom-rfree-key"]) +def test_copy_rfree(data_fmodel, ccp4_convention, inplace, rfree_key): + """ + Test rs.utils.copy_rfree + """ + data_copy = data_fmodel.copy() -class TestRfree(unittest.TestCase): - def test_copy_rfree(self): + # create dataset with rfree flags from which to copy + data_with_rfree = rs.utils.add_rfree( + data_fmodel, inplace=False, ccp4_convention=ccp4_convention + ) - datadir = join(abspath(dirname(__file__)), "../data/fmodel") - data = rs.read_mtz(join(datadir, "9LYZ.mtz")) - data_rfree = rs.utils.add_rfree(data, inplace=False) + # handle different possible column names for rfree flags + if rfree_key is not None: + if ccp4_convention: + rename_dict = {"FreeR_flag": rfree_key} + else: + rename_dict = {"R-free-flags": rfree_key} - # Test copy of R-free to copy of data - rfree = rs.utils.copy_rfree(data, data_rfree, inplace=False) - self.assertFalse(id(data) == id(rfree)) - self.assertFalse("R-free-flags" in data.columns) - self.assertTrue("R-free-flags" in rfree.columns) - self.assertTrue( - np.array_equal( - rfree["R-free-flags"].values, data_rfree["R-free-flags"].values - ) - ) + data_with_rfree.rename(columns=rename_dict, inplace=True) + else: + if ccp4_convention: + rfree_key = "FreeR_flag" + else: + rfree_key = "R-free-flags" + + data_with_copied_rfree = rs.utils.copy_rfree( + data_fmodel, data_with_rfree, inplace=inplace, rfree_key=rfree_key + ) - # Test copy of R-free inplace - rfree = rs.utils.copy_rfree(data, data_rfree, inplace=True) - self.assertTrue(id(data) == id(rfree)) - self.assertTrue("R-free-flags" in data.columns) - self.assertTrue("R-free-flags" in rfree.columns) - self.assertTrue( - np.array_equal( - rfree["R-free-flags"].values, data_rfree["R-free-flags"].values - ) + if inplace: + assert id(data_with_copied_rfree) == id(data_fmodel) + assert rfree_key in data_fmodel.columns + assert np.array_equal( + data_fmodel[rfree_key].values, data_with_rfree[rfree_key].values + ) + else: + assert id(data_with_copied_rfree) != id(data_fmodel) + assert rfree_key not in data_fmodel.columns + assert np.array_equal( + data_with_copied_rfree[rfree_key].values, data_with_rfree[rfree_key].values ) + assert np.all(data_fmodel == data_copy) + + +def test_copy_rfree_errors(data_fmodel): + """ + Test expected ValueErrors for rs.utils.copy_rfree + """ + # Raise ValueError because "R-free-flags" and "FreeR_flag" are missing + with pytest.raises(ValueError): + rs.utils.copy_rfree(data_fmodel, data_fmodel) - return + # Raise ValueError because "missing key" is missing, + # even though "R-free-flags" exists + data_with_standard_rfree = rs.utils.add_rfree(data_fmodel, inplace=False) + with pytest.raises(ValueError): + rs.utils.copy_rfree( + data_fmodel, data_with_standard_rfree, rfree_key="missing key" + )