Skip to content

Commit

Permalink
fix(task): fix support for "balance" option
Browse files Browse the repository at this point in the history
  • Loading branch information
FrenchKrab authored Sep 20, 2023
1 parent 1af9d44 commit 71f012b
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
- fix(pipeline): fix support for IOBase audio
- fix(pipeline): fix corner case with no speaker
- fix(train): prevent metadata preparation from happening twice
- fix(task): fix support for "balance" option
- improve(task): shorten and improve structure of Tensorboard tags

### Dependencies
Expand Down
9 changes: 6 additions & 3 deletions pyannote/audio/tasks/segmentation/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def train__iter__helper(self, rng: random.Random, **filters):
# indices of training files that matches domain filters
training = self.metadata["subset"] == Subsets.index("train")
for key, value in filters.items():
training &= self.metadata[key] == value
training &= self.metadata[key] == self.metadata_unique_values[key].index(value)
file_ids = np.where(training)[0]

# turn annotated duration into a probability distribution
Expand Down Expand Up @@ -485,16 +485,19 @@ def train__iter__(self):
# create a subchunk generator for each combination of "balance" keys
subchunks = dict()
for product in itertools.product(
[self.metadata_unique_values[key] for key in balance]
*[self.metadata_unique_values[key] for key in balance]
):
# we iterate on the cartesian product of the values in metadata_unique_values
# eg: for balance=["database", "split"], with 2 databases and 2 splits:
# ("DIHARD", "A"), ("DIHARD", "B"), ("REPERE", "A"), ("REPERE", "B")
filters = {key: value for key, value in zip(balance, product)}
subchunks[product] = self.train__iter__helper(rng, **filters)

while True:
# select one subchunk generator at random (with uniform probability)
# so that it is balanced on average
if balance is not None:
chunks = subchunks[rng.choice(subchunks)]
chunks = subchunks[rng.choice(list(subchunks))]

# generate random chunk
yield next(chunks)
Expand Down
10 changes: 5 additions & 5 deletions pyannote/audio/tasks/segmentation/multilabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ class MultiLabelSegmentation(SegmentationTaskMixin, Task):
parts, only the remaining central part of each chunk is used for computing the
loss during training, and for aggregating scores during inference.
Defaults to 0. (i.e. no warm-up).
balance: str, optional
When provided, training samples are sampled uniformly with respect to that key.
For instance, setting `balance` to "uri" will make sure that each file will be
equally represented in the training samples.
balance: Sequence[Text], optional
When provided, training samples are sampled uniformly with respect to these keys.
For instance, setting `balance` to ["database", "subset"] will make sure that each
database & subset combination will be equally represented in the training samples.
weight: str, optional
When provided, use this key as frame-wise weight in loss function.
batch_size : int, optional
Expand All @@ -87,7 +87,7 @@ def __init__(
classes: Optional[List[str]] = None,
duration: float = 2.0,
warm_up: Union[float, Tuple[float, float]] = 0.0,
balance: Text = None,
balance: Sequence[Text] = None,
weight: Text = None,
batch_size: int = 32,
num_workers: int = None,
Expand Down
10 changes: 5 additions & 5 deletions pyannote/audio/tasks/segmentation/overlapped_speech_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ class OverlappedSpeechDetection(SegmentationTaskMixin, Task):
parts, only the remaining central part of each chunk is used for computing the
loss during training, and for aggregating scores during inference.
Defaults to 0. (i.e. no warm-up).
balance: str, optional
When provided, training samples are sampled uniformly with respect to that key.
For instance, setting `balance` to "uri" will make sure that each file will be
equally represented in the training samples.
balance: Sequence[Text], optional
When provided, training samples are sampled uniformly with respect to these keys.
For instance, setting `balance` to ["database", "subset"] will make sure that each
database & subset combination will be equally represented in the training samples.
overlap: dict, optional
Controls how artificial chunks with overlapping speech are generated:
- "probability" key is the probability of artificial overlapping chunks. Setting
Expand Down Expand Up @@ -98,7 +98,7 @@ def __init__(
duration: float = 2.0,
warm_up: Union[float, Tuple[float, float]] = 0.0,
overlap: dict = OVERLAP_DEFAULTS,
balance: Text = None,
balance: Sequence[Text] = None,
weight: Text = None,
batch_size: int = 32,
num_workers: int = None,
Expand Down
10 changes: 5 additions & 5 deletions pyannote/audio/tasks/segmentation/speaker_diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ class SpeakerDiarization(SegmentationTaskMixin, Task):
parts, only the remaining central part of each chunk is used for computing the
loss during training, and for aggregating scores during inference.
Defaults to 0. (i.e. no warm-up).
balance: str, optional
When provided, training samples are sampled uniformly with respect to that key.
For instance, setting `balance` to "database" will make sure that each database
will be equally represented in the training samples.
balance: Sequence[Text], optional
When provided, training samples are sampled uniformly with respect to these keys.
For instance, setting `balance` to ["database", "subset"] will make sure that each
database & subset combination will be equally represented in the training samples.
weight: str, optional
When provided, use this key as frame-wise weight in loss function.
batch_size : int, optional
Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(
max_speakers_per_frame: int = None,
weigh_by_cardinality: bool = False,
warm_up: Union[float, Tuple[float, float]] = 0.0,
balance: Text = None,
balance: Sequence[Text] = None,
weight: Text = None,
batch_size: int = 32,
num_workers: int = None,
Expand Down
10 changes: 5 additions & 5 deletions pyannote/audio/tasks/segmentation/voice_activity_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ class VoiceActivityDetection(SegmentationTaskMixin, Task):
parts, only the remaining central part of each chunk is used for computing the
loss during training, and for aggregating scores during inference.
Defaults to 0. (i.e. no warm-up).
balance: str, optional
When provided, training samples are sampled uniformly with respect to that key.
For instance, setting `balance` to "uri" will make sure that each file will be
equally represented in the training samples.
balance: Sequence[Text], optional
When provided, training samples are sampled uniformly with respect to these keys.
For instance, setting `balance` to ["database", "subset"] will make sure that each
database & subset combination will be equally represented in the training samples.
weight: str, optional
When provided, use this key as frame-wise weight in loss function.
batch_size : int, optional
Expand All @@ -81,7 +81,7 @@ def __init__(
protocol: Protocol,
duration: float = 2.0,
warm_up: Union[float, Tuple[float, float]] = 0.0,
balance: Text = None,
balance: Sequence[Text] = None,
weight: Text = None,
batch_size: int = 32,
num_workers: int = None,
Expand Down

0 comments on commit 71f012b

Please sign in to comment.