Skip to content

Commit

Permalink
Ensure root and group objects required are written (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreeve authored Jun 19, 2023
1 parent e1fe2ae commit e98b9eb
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,89 @@ def test_specifying_invalid_version():
error_message = str(exception.value)

assert "4712,4713" in error_message


def test_root_object_added():
    """A root object is written automatically when omitted, in the first segment only."""
    objects = [
        GroupObject("group"),
        ChannelObject("group", "a", np.linspace(0.0, 1.0, 10)),
    ]

    buffer = BytesIO()
    with TdmsWriter(buffer) as writer:
        # Write the same root-less object list as two consecutive segments.
        for _ in range(2):
            writer.write_segment(list(objects))
    buffer.seek(0)

    parsed = TdmsFile(buffer)
    segments = parsed._reader._segments
    first_paths = [obj.path for obj in segments[0].ordered_objects]
    second_paths = [obj.path for obj in segments[1].ordered_objects]

    # The root comes first in segment one and is not repeated in segment two.
    assert first_paths[0] == "/"
    assert "/" not in second_paths


def test_group_object_added():
    """A group object is written automatically when omitted, in the first segment only."""
    group_path = "/'group'"
    objects = [
        RootObject(),
        ChannelObject("group", "a", np.linspace(0.0, 1.0, 10)),
    ]

    buffer = BytesIO()
    with TdmsWriter(buffer) as writer:
        # Write the same group-less object list as two consecutive segments.
        for _ in range(2):
            writer.write_segment(list(objects))
    buffer.seek(0)

    parsed = TdmsFile(buffer)
    segments = parsed._reader._segments
    first_paths = [obj.path for obj in segments[0].ordered_objects]
    second_paths = [obj.path for obj in segments[1].ordered_objects]

    # The group follows the root in segment one and is not repeated in segment two.
    assert first_paths[1] == group_path
    assert group_path not in second_paths


def test_group_not_duplicated():
    """Root and group objects already written are not added again to a later segment."""
    root, group = RootObject(), GroupObject("group")
    channel = ChannelObject("group", "a", np.linspace(0.0, 1.0, 10))

    buffer = BytesIO()
    with TdmsWriter(buffer) as writer:
        # First segment supplies every object explicitly; the second only the channel.
        writer.write_segment([root, group, channel])
        writer.write_segment([channel])
    buffer.seek(0)

    parsed = TdmsFile(buffer)
    object_counts = [
        len(segment.ordered_objects) for segment in parsed._reader._segments[:2]
    ]

    assert object_counts[0] == 3
    assert object_counts[1] == 1


def test_root_and_groups_ordered_first():
    """
    The root and group objects should always come first, with the
    channels keeping the relative order in which they were passed.
    """
    # Deliberately pass the objects with channels ahead of their parents.
    shuffled_objects = [
        ChannelObject("group", "b", np.linspace(0.0, 1.0, 10)),
        GroupObject("group"),
        ChannelObject("group", "a", np.linspace(0.0, 1.0, 10)),
        RootObject(),
    ]
    expected_paths = ["/", "/'group'", "/'group'/'b'", "/'group'/'a'"]

    buffer = BytesIO()
    with TdmsWriter(buffer) as writer:
        writer.write_segment(shuffled_objects)
    buffer.seek(0)

    parsed = TdmsFile(buffer)
    written_objects = parsed._reader._segments[0].ordered_objects

    for position, expected_path in enumerate(expected_paths):
        assert written_objects[position].path == expected_path
43 changes: 39 additions & 4 deletions nptdms/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def __init__(self, file, mode='w', version=4712, index_file=False):
It's important that if you are appending segments to an
existing TDMS file, this matches the existing file version (this can be queried with the
:py:attr:`~nptdms.TdmsFile.tdms_version` property).
:param index_file: Whether or not to write a index file besides the data file. Index files
:param index_file: Whether to write an index file besides the data file. Index files
can be used to accelerate reading speeds for faster channel extraction and data positions inside
the data files. If ``file```variable is a path ``index_file`` can be ``True`` to store a ``.tdms_index``
file at the same folder location or ``False`` to only write the data ``.tdms`` file. If ``file`` variable
is a readable object ``index_file`` can either be a readable object to write into or ``False`` to omit.
the data files. If ``file`` is a path, ``index_file`` can be ``True`` to store a ``.tdms_index``
file at the same folder location or ``False`` to only write the data ``.tdms`` file. If ``file``
is a readable object, ``index_file`` can either be a readable object to write into or ``False`` to omit.
"""
valid_versions = (4712, 4713)
if version not in valid_versions:
Expand All @@ -79,6 +79,8 @@ def __init__(self, file, mode='w', version=4712, index_file=False):
self._index_file_path = None
self._file_mode = mode
self._tdms_version = version
self._root_written = False
self._groups_written = set()

if hasattr(file, "read"):
# Is a file
Expand Down Expand Up @@ -123,13 +125,37 @@ def write_segment(self, objects):
:param objects: A list of TdmsObject instances to write
"""
path_object_pairs = [(ObjectPath.from_string(o.path), o) for o in objects]

# Make sure a root object is included if this is the first segment,
# and any groups used by channels have associated group objects
add_root = (not self._root_written) and (not any(p[0].is_root for p in path_object_pairs))
groups_included = set(p[0].group for p in path_object_pairs if p[0].is_group)
groups_required = set(p[0].group for p in path_object_pairs if p[0].is_channel)
groups_to_add = sorted(groups_required - groups_included - self._groups_written)

if add_root:
path_object_pairs.append((ObjectPath(), RootObject()))
if groups_to_add:
path_object_pairs.extend((ObjectPath(g), GroupObject(g)) for g in groups_to_add)

# Ensure objects are ordered with root first, then groups, in case any readers depend
# on parent objects being defined before their children.
# Channel ordering will be unchanged as sorts are stable.
path_object_pairs.sort(key=lambda p: _path_ordering_key(p[0]))

objects = [p[1] for p in path_object_pairs]
segment = TdmsSegment(objects, version=self._tdms_version)
segment.write(self._file)

if self._index_file is not None:
segment = TdmsSegment(objects, is_index_file=True, version=self._tdms_version)
segment.write(self._index_file)

self._root_written = True
self._groups_written.update(groups_included)
self._groups_written.update(groups_to_add)

def __enter__(self):
self.open()
return self
Expand Down Expand Up @@ -450,3 +476,12 @@ def _infer_dtype(data):
else:
return np.dtype('int8')
return None


def _path_ordering_key(path):
if path.is_root:
return 0
if path.is_group:
return 1
if path.is_channel:
return 2

0 comments on commit e98b9eb

Please sign in to comment.