Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interface for setting project, crystal, and dataset names #157

Merged
merged 6 commits into from
Jun 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions reciprocalspaceship/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,13 @@ def from_gemmi(cls, gemmiMtz):
"""
return cls(gemmiMtz)

def to_gemmi(self, skip_problem_mtztypes=False):
def to_gemmi(
self,
skip_problem_mtztypes=False,
project_name="reciprocalspaceship",
crystal_name="reciprocalspaceship",
dataset_name="reciprocalspaceship",
):
"""
Creates gemmi.Mtz object from DataSet object.
Expand All @@ -332,14 +338,22 @@ def to_gemmi(self, skip_problem_mtztypes=False):
skip_problem_mtztypes : bool
Whether to skip columns in DataSet that do not have specified
MTZ datatypes
project_name : str
Project name to assign to MTZ file
crystal_name : str
Crystal name to assign to MTZ file
dataset_name : str
Dataset name to assign to MTZ file
Returns
-------
gemmi.Mtz
"""
from reciprocalspaceship import io

return io.to_gemmi(self, skip_problem_mtztypes)
return io.to_gemmi(
self, skip_problem_mtztypes, project_name, crystal_name, dataset_name
)

def to_pickle(self, path, *args, **kwargs):
"""
Expand Down Expand Up @@ -471,7 +485,14 @@ def join(self, *args, check_isomorphous=True, **kwargs):
result = super().join(*args, **kwargs)
return result.__finalize__(self)

def write_mtz(self, mtzfile, skip_problem_mtztypes=False):
def write_mtz(
self,
mtzfile,
skip_problem_mtztypes=False,
project_name="reciprocalspaceship",
crystal_name="reciprocalspaceship",
dataset_name="reciprocalspaceship",
):
"""
Write DataSet to MTZ file.
Expand All @@ -489,10 +510,23 @@ def write_mtz(self, mtzfile, skip_problem_mtztypes=False):
skip_problem_mtztypes : bool
Whether to skip columns in DataSet that do not have specified
MTZ datatypes
project_name : str
Project name to assign to MTZ file
crystal_name : str
Crystal name to assign to MTZ file
dataset_name : str
Dataset name to assign to MTZ file
"""
from reciprocalspaceship import io

return io.write_mtz(self, mtzfile, skip_problem_mtztypes)
return io.write_mtz(
self,
mtzfile,
skip_problem_mtztypes,
project_name,
crystal_name,
dataset_name,
)

def select_mtzdtype(self, dtype):
"""
Expand Down
60 changes: 52 additions & 8 deletions reciprocalspaceship/io/mtz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import gemmi
import numpy as np
import pandas as pd

from reciprocalspaceship import DataSet
from reciprocalspaceship.dtypes.base import MTZDtype
Expand Down Expand Up @@ -58,7 +56,13 @@ def from_gemmi(gemmi_mtz):
return dataset


def to_gemmi(dataset, skip_problem_mtztypes=False):
def to_gemmi(
dataset,
skip_problem_mtztypes,
project_name,
crystal_name,
dataset_name,
):
"""
Construct gemmi.Mtz object from DataSet

Expand All @@ -76,6 +80,12 @@ def to_gemmi(dataset, skip_problem_mtztypes=False):
skip_problem_mtztypes : bool
Whether to skip columns in DataSet that do not have specified
mtz datatypes
project_name : str
Project name to assign to MTZ file
crystal_name : str
Crystal name to assign to MTZ file
dataset_name : str
Dataset name to assign to MTZ file

Returns
-------
Expand All @@ -91,6 +101,20 @@ def to_gemmi(dataset, skip_problem_mtztypes=False):
f"Instance of type {dataset.__class__.__name__} has no space group information"
)

# Check project_name, crystal_name, and dataset_name are str
if not isinstance(project_name, str):
raise ValueError(
f"project_name must be a string. Given type: {type(project_name)}"
)
if not isinstance(crystal_name, str):
raise ValueError(
f"crystal_name must be a string. Given type: {type(crystal_name)}"
)
if not isinstance(dataset_name, str):
raise ValueError(
f"dataset_name must be a string. Given type: {type(dataset_name)}"
)

# Build up a gemmi.Mtz object
mtz = gemmi.Mtz()
mtz.cell = dataset.cell
Expand All @@ -102,19 +126,24 @@ def to_gemmi(dataset, skip_problem_mtztypes=False):
if not all_in_asu:
dataset.hkl_to_asu(inplace=True)

# Construct data for Mtz object.
# Add Dataset with indicated names
mtz.add_dataset("reciprocalspaceship")
mtz.datasets[0].project_name = project_name
mtz.datasets[0].crystal_name = crystal_name
mtz.datasets[0].dataset_name = dataset_name

# Construct data for Mtz object
temp = dataset.reset_index()
columns = []
for c in temp.columns:
cseries = temp[c]
if isinstance(cseries.dtype, MTZDtype):
mtzcol = mtz.add_column(label=c, type=cseries.dtype.mtztype)
mtz.add_column(label=c, type=cseries.dtype.mtztype)
columns.append(c)
# Special case for CENTRIC and PARTIAL flags
elif cseries.dtype.name == "bool" and c in ["CENTRIC", "PARTIAL"]:
temp[c] = temp[c].astype("MTZInt")
mtzcol = mtz.add_column(label=c, type="I")
mtz.add_column(label=c, type="I")
columns.append(c)
elif skip_problem_mtztypes:
continue
Expand Down Expand Up @@ -159,7 +188,14 @@ def read_mtz(mtzfile):
return from_gemmi(gemmi_mtz)


def write_mtz(dataset, mtzfile, skip_problem_mtztypes=False):
def write_mtz(
dataset,
mtzfile,
skip_problem_mtztypes,
project_name,
crystal_name,
dataset_name,
):
"""
Write an MTZ reflection file from the reflection data in a DataSet.

Expand All @@ -179,7 +215,15 @@ def write_mtz(dataset, mtzfile, skip_problem_mtztypes=False):
skip_problem_mtztypes : bool
Whether to skip columns in DataSet that do not have specified
MTZ datatypes
project_name : str
Project name to assign to MTZ file
crystal_name : str
Crystal name to assign to MTZ file
dataset_name : str
Dataset name to assign to MTZ file
"""
mtz = to_gemmi(dataset, skip_problem_mtztypes)
mtz = to_gemmi(
dataset, skip_problem_mtztypes, project_name, crystal_name, dataset_name
)
mtz.write_to_file(mtzfile)
return
80 changes: 79 additions & 1 deletion tests/io/test_mtz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from os.path import exists

import gemmi
import numpy as np
import pytest
from pandas.testing import assert_frame_equal

Expand Down Expand Up @@ -145,3 +144,82 @@ def test_unmerged_after_write(data_unmerged, in_asu):
expected = data_unmerged.copy()
data_unmerged.write_mtz("/dev/null")
assert_frame_equal(data_unmerged, expected)


@pytest.mark.parametrize("project_name", [None, "project", "reciprocalspaceship", 1])
@pytest.mark.parametrize("crystal_name", [None, "crystal", "reciprocalspaceship", 1])
@pytest.mark.parametrize("dataset_name", [None, "dataset", "reciprocalspaceship", 1])
def test_to_gemmi_names(IOtest_mtz, project_name, crystal_name, dataset_name):
    """
    Check that DataSet.to_gemmi() propagates project/crystal/dataset names.

    Every combination of the parametrized names is exercised. When all three
    are strings, the first dataset of the returned gemmi.Mtz must carry them
    verbatim; if any one of them is not a string, a ValueError is expected.
    """
    dataset = rs.read_mtz(IOtest_mtz)

    # Collect the names once so both branches issue the exact same call
    names = {
        "project_name": project_name,
        "crystal_name": crystal_name,
        "dataset_name": dataset_name,
    }

    if all(isinstance(value, str) for value in names.values()):
        gemmimtz = dataset.to_gemmi(**names)

        assert gemmimtz.dataset(0).project_name == project_name
        assert gemmimtz.dataset(0).crystal_name == crystal_name
        assert gemmimtz.dataset(0).dataset_name == dataset_name
    else:
        # At least one name is not a str -> conversion must be rejected
        with pytest.raises(ValueError):
            dataset.to_gemmi(**names)


@pytest.mark.parametrize("project_name", [None, "project", "reciprocalspaceship", 1])
@pytest.mark.parametrize("crystal_name", [None, "crystal", "reciprocalspaceship", 1])
@pytest.mark.parametrize("dataset_name", [None, "dataset", "reciprocalspaceship", 1])
def test_write_mtz_names(IOtest_mtz, project_name, crystal_name, dataset_name):
    """
    Test that DataSet.write_mtz() sets project/crystal/dataset names when given.
    ValueError should be raised for anything other than a string.

    The MTZ file is written to a temporary file and read back with gemmi so
    that the names stored on disk are the ones being checked.
    """
    ds = rs.read_mtz(IOtest_mtz)

    names = {
        "project_name": project_name,
        "crystal_name": crystal_name,
        "dataset_name": dataset_name,
    }
    all_str = all(isinstance(name, str) for name in names.values())

    # The context manager guarantees the temporary file is closed (and thus
    # removed) even if write_mtz() raises unexpectedly or a check fails.
    with tempfile.NamedTemporaryFile(suffix=".mtz") as temp:
        if not all_str:
            with pytest.raises(ValueError):
                ds.write_mtz(temp.name, **names)
            return

        ds.write_mtz(temp.name, **names)

        # Read back while the temporary file still exists
        gemmimtz = gemmi.read_mtz_file(temp.name)

    assert gemmimtz.dataset(0).project_name == project_name
    assert gemmimtz.dataset(0).crystal_name == crystal_name
    assert gemmimtz.dataset(0).dataset_name == dataset_name