diff --git a/embeddings.py b/embeddings.py index 4ad70f46..f7e3acbb 100644 --- a/embeddings.py +++ b/embeddings.py @@ -183,11 +183,11 @@ def analyzeFile(item): # Set number of threads if os.path.isdir(cfg.INPUT_PATH): - cfg.CPU_THREADS = int(args.threads) + cfg.CPU_THREADS = max(1, int(args.threads)) cfg.TFLITE_THREADS = 1 else: cfg.CPU_THREADS = 1 - cfg.TFLITE_THREADS = int(args.threads) + cfg.TFLITE_THREADS = max(1, int(args.threads)) # Set batch size cfg.BATCH_SIZE = max(1, int(args.batchsize)) diff --git a/segments.py b/segments.py index f20bf844..cfb8db73 100644 --- a/segments.py +++ b/segments.py @@ -4,6 +4,7 @@ """ import argparse import os +import multiprocessing from multiprocessing import Pool import numpy as np @@ -152,6 +153,7 @@ def findSegments(afile: str, rfile: str): for i, line in enumerate(lines): if rtype == "table" and i > 0: + # TODO: Use header columns to get the right indices d = line.split("\t") start = float(d[5]) end = float(d[6]) @@ -186,8 +188,8 @@ def findSegments(afile: str, rfile: str): species = d[3] confidence = float(d[4]) - # Check if confidence is high enough - if confidence >= cfg.MIN_CONFIDENCE: + # Check if confidence is high enough and label is not "nocall" + if confidence >= cfg.MIN_CONFIDENCE and species.lower() != "nocall": segments.append({"audio": afile, "start": start, "end": end, "species": species, "confidence": confidence}) return segments @@ -239,8 +241,8 @@ def extractSegments(item: tuple[tuple[str, list[dict]], float, dict[str]]): os.makedirs(outpath, exist_ok=True) # Save segment - seg_name = "{:.3f}_{}_{}.wav".format( - seg["confidence"], seg_cnt, seg["audio"].rsplit(os.sep, 1)[-1].rsplit(".", 1)[0] + seg_name = "{:.3f}_{}_{}_{:.1f}s_{:.1f}s.wav".format( + seg["confidence"], seg_cnt, seg["audio"].rsplit(os.sep, 1)[-1].rsplit(".", 1)[0], seg["start"], seg["end"] ) seg_path = os.path.join(outpath, seg_name) audio.saveSignal(seg_sig, seg_path) @@ -267,7 +269,7 @@ def extractSegments(item: tuple[tuple[str, list[dict]], float, dict[str]]): parser.add_argument( "--seg_length", type=float, default=3.0, help="Length of extracted segments in seconds. Defaults to 3.0." ) - parser.add_argument("--threads", type=int, default=4, help="Number of CPU threads.") + parser.add_argument("--threads", type=int, default=min(8, max(1, multiprocessing.cpu_count() // 2)), help="Number of CPU threads.") args = parser.parse_args() diff --git a/train.py b/train.py index d7ef0932..33e99d1b 100644 --- a/train.py +++ b/train.py @@ -398,7 +398,7 @@ def run_trial(self, trial, *args, **kwargs): cfg.TRAIN_CACHE_MODE = args.cache_mode.lower() cfg.TRAIN_CACHE_FILE = args.cache_file cfg.TFLITE_THREADS = 1 - cfg.CPU_THREADS = cfg.CPU_THREADS = max(1, int(args.threads)) + cfg.CPU_THREADS = max(1, int(args.threads)) cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(args.fmin))) cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(args.fmax)))