diff --git a/nptdms/test/writer/test_acceptance_tests.py b/nptdms/test/writer/test_writer.py similarity index 81% rename from nptdms/test/writer/test_acceptance_tests.py rename to nptdms/test/writer/test_writer.py index 63b6486..27474bb 100644 --- a/nptdms/test/writer/test_acceptance_tests.py +++ b/nptdms/test/writer/test_writer.py @@ -403,3 +403,89 @@ def test_specifying_invalid_version(): error_message = str(exception.value) assert "4712,4713" in error_message + + +def test_root_object_added(): + """ When not explicitly included, a root object should be added + """ + group = GroupObject("group") + channel = ChannelObject("group", "a", np.linspace(0.0, 1.0, 10)) + + output_file = BytesIO() + with TdmsWriter(output_file) as tdms_writer: + tdms_writer.write_segment([group, channel]) + tdms_writer.write_segment([group, channel]) + + output_file.seek(0) + + tdms_file = TdmsFile(output_file) + first_segment_objects = tdms_file._reader._segments[0].ordered_objects + second_segment_objects = tdms_file._reader._segments[1].ordered_objects + + assert first_segment_objects[0].path == "/" + assert not any(obj.path == "/" for obj in second_segment_objects) + + +def test_group_object_added(): + """ When not explicitly included, a group object should be added + """ + root = RootObject() + channel = ChannelObject("group", "a", np.linspace(0.0, 1.0, 10)) + + output_file = BytesIO() + with TdmsWriter(output_file) as tdms_writer: + tdms_writer.write_segment([root, channel]) + tdms_writer.write_segment([root, channel]) + + output_file.seek(0) + + tdms_file = TdmsFile(output_file) + first_segment_objects = tdms_file._reader._segments[0].ordered_objects + second_segment_objects = tdms_file._reader._segments[1].ordered_objects + + assert first_segment_objects[1].path == "/'group'" + assert not any(obj.path == "/'group'" for obj in second_segment_objects) + + +def test_group_not_duplicated(): + root = RootObject() + group = GroupObject("group") + channel = ChannelObject("group", "a", np.linspace(0.0, 1.0, 10)) + + output_file = BytesIO() + with TdmsWriter(output_file) as tdms_writer: + tdms_writer.write_segment([root, group, channel]) + tdms_writer.write_segment([channel]) + + output_file.seek(0) + + tdms_file = TdmsFile(output_file) + first_segment_objects = tdms_file._reader._segments[0].ordered_objects + second_segment_objects = tdms_file._reader._segments[1].ordered_objects + + assert len(first_segment_objects) == 3 + assert len(second_segment_objects) == 1 + + +def test_root_and_groups_ordered_first(): + """ + The root and group objects should always come first + """ + root = RootObject() + group = GroupObject("group") + channel_0 = ChannelObject("group", "b", np.linspace(0.0, 1.0, 10)) + channel_1 = ChannelObject("group", "a", np.linspace(0.0, 1.0, 10)) + + output_file = BytesIO() + with TdmsWriter(output_file) as tdms_writer: + tdms_writer.write_segment([channel_0, group, channel_1, root]) + + output_file.seek(0) + + tdms_file = TdmsFile(output_file) + first_segment_objects = tdms_file._reader._segments[0].ordered_objects + + assert first_segment_objects[0].path == "/" + assert first_segment_objects[1].path == "/'group'" + assert first_segment_objects[2].path == "/'group'/'b'" + assert first_segment_objects[3].path == "/'group'/'a'" diff --git a/nptdms/writer.py b/nptdms/writer.py index dd2b5fd..94865d6 100644 --- a/nptdms/writer.py +++ b/nptdms/writer.py @@ -64,11 +64,11 @@ def __init__(self, file, mode='w', version=4712, index_file=False): It's important that if you are appending segments to an existing TDMS file, this matches the existing file version (this can be queried with the :py:attr:`~nptdms.TdmsFile.tdms_version` property). - :param index_file: Whether or not to write a index file besides the data file. Index files + :param index_file: Whether to write an index file besides the data file. Index files can be used to accelerate reading speeds for faster channel extraction and data positions inside - the data files. If ``file```variable is a path ``index_file`` can be ``True`` to store a ``.tdms_index`` - file at the same folder location or ``False`` to only write the data ``.tdms`` file. If ``file`` variable - is a readable object ``index_file`` can either be a readable object to write into or ``False`` to omit. + the data files. If ``file```variable is a path, ``index_file`` can be ``True`` to store a ``.tdms_index`` + file at the same folder location or ``False`` to only write the data ``.tdms`` file. If ``file`` + is a readable object, ``index_file`` can either be a readable object to write into or ``False`` to omit. """ valid_versions = (4712, 4713) if version not in valid_versions: @@ -79,6 +79,8 @@ def __init__(self, file, mode='w', version=4712, index_file=False): self._index_file_path = None self._file_mode = mode self._tdms_version = version + self._root_written = False + self._groups_written = set() if hasattr(file, "read"): # Is a file @@ -123,6 +125,26 @@ def write_segment(self, objects): :param objects: A list of TdmsObject instances to write """ + path_object_pairs = [(ObjectPath.from_string(o.path), o) for o in objects] + + # Make sure a root object is included if this is the first segment, + # and any groups used by channels have associated group objects + add_root = (not self._root_written) and (not any(p[0].is_root for p in path_object_pairs)) + groups_included = set(p[0].group for p in path_object_pairs if p[0].is_group) + groups_required = set(p[0].group for p in path_object_pairs if p[0].is_channel) + groups_to_add = sorted(groups_required - groups_included - self._groups_written) + + if add_root: + path_object_pairs.append((ObjectPath(), RootObject())) + if groups_to_add: + path_object_pairs.extend((ObjectPath(g), GroupObject(g)) for g in groups_to_add) + + # Ensure objects are ordered with root first, then groups, in case any readers depend + # on parent objects being defined before their children. + # Channel ordering will be unchanged as sorts are stable. + path_object_pairs.sort(key=lambda p: _path_ordering_key(p[0])) + + objects = [p[1] for p in path_object_pairs] segment = TdmsSegment(objects, version=self._tdms_version) segment.write(self._file) @@ -130,6 +152,10 @@ def write_segment(self, objects): segment = TdmsSegment(objects, is_index_file=True, version=self._tdms_version) segment.write(self._index_file) + self._root_written = True + self._groups_written.update(groups_included) + self._groups_written.update(groups_to_add) + def __enter__(self): self.open() return self @@ -450,3 +476,12 @@ def _infer_dtype(data): else: return np.dtype('int8') return None + + +def _path_ordering_key(path): + if path.is_root: + return 0 + if path.is_group: + return 1 + if path.is_channel: + return 2