diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py
index 0a61e2a6f..bdcb44cf9 100644
--- a/pyannote/audio/core/task.py
+++ b/pyannote/audio/core/task.py
@@ -642,6 +642,31 @@ def setup(self, stage=None):
                 f"does not correspond to the cached one ({self.prepared_data['protocol']})"
             )
 
+        # group annotations-segments by file_id into a dict, as a dict cannot be stored in the cached .npy file
+        annotations = self.prepared_data['annotations-segments']
+        annotations_dict = defaultdict(list)
+        file_ids = []
+        for annotation in annotations:
+            file_id = annotation[0]
+            file_ids.append(file_id)
+            annotations_dict[file_id].append(annotation)
+
+        segment_dtype = [
+            (
+                "file_id",
+                get_dtype(max(a[0] for a in annotations)),
+            ),
+            ("start", "f"),
+            ("end", "f"),
+            ("file_label_idx", get_dtype(max(a[3] for a in annotations))),
+            ("database_label_idx", get_dtype(max(a[4] for a in annotations))),
+            ("global_label_idx", get_dtype(max(a[5] for a in annotations))),
+        ]
+
+        for file_id in file_ids:
+            annotations_dict[file_id] = np.array(annotations_dict[file_id], dtype=segment_dtype)
+        self.prepared_data['annotations-segments'] = annotations_dict
+
     @property
     def specifications(self) -> Union[Specifications, Tuple[Specifications]]:
         # setup metadata on-demand the first time specifications are requested and missing
diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py
index 9184121c4..6171535a1 100644
--- a/pyannote/audio/tasks/segmentation/multilabel.py
+++ b/pyannote/audio/tasks/segmentation/multilabel.py
@@ -281,10 +281,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float):
 
         sample = dict()
         sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration)
 
+        # gather all annotations of current file
-        annotations = self.prepared_data["annotations-segments"][
-            self.prepared_data["annotations-segments"]["file_id"] == file_id
-        ]
+        annotations = self.prepared_data["annotations-segments"][file_id]
 
         # gather all annotations with non-empty intersection with current chunk
         chunk_annotations = annotations[
diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py
index 89d299a8d..68b35d66c 100644
--- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py
+++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py
@@ -173,9 +173,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float):
         sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration)
 
         # gather all annotations of current file
-        annotations = self.prepared_data["annotations-segments"][
-            self.prepared_data["annotations-segments"]["file_id"] == file_id
-        ]
+        annotations = self.prepared_data["annotations-segments"][file_id]
 
         # gather all annotations with non-empty intersection with current chunk
         chunk_annotations = annotations[
diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py
index 8a091b1f7..b21710ca4 100644
--- a/pyannote/audio/tasks/segmentation/speaker_diarization.py
+++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py
@@ -345,9 +345,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float):
         sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration)
 
         # gather all annotations of current file
-        annotations = self.prepared_data["annotations-segments"][
-            self.prepared_data["annotations-segments"]["file_id"] == file_id
-        ]
+        annotations = self.prepared_data["annotations-segments"][file_id]
 
         # gather all annotations with non-empty intersection with current chunk
         chunk_annotations = annotations[
diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py
index e52613aeb..94d4da08a 100644
--- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py
+++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py
@@ -154,9 +154,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float):
         sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration)
 
         # gather all annotations of current file
-        annotations = self.prepared_data["annotations-segments"][
-            self.prepared_data["annotations-segments"]["file_id"] == file_id
-        ]
+        annotations = self.prepared_data["annotations-segments"][file_id]
 
         # gather all annotations with non-empty intersection with current chunk
         chunk_annotations = annotations[