diff --git a/README.md b/README.md
index a75e59e..a341edb 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@ https://github.com/user-attachments/assets/797e6552-27cd-41b1-a7f3-e5cbc72094f5
 
 ### Updates
 
-Latest Version: v0.3.81
+Latest Version: v0.3.9
 
 See [release history](https://github.com/KoljaB/RealtimeSTT/releases).
 
@@ -450,6 +450,8 @@ When you initialize the `AudioToTextRecorder` class, you have various options to
 
 - **level** (int, default=logging.WARNING): Logging level.
 
+- **batch_size** (int, default=16): Batch size for the main transcription. Set to 0 to deactivate.
+
 - **init_logging** (bool, default=True): Whether to initialize the logging framework. Set to False to manage this yourself.
 
 - **handle_buffer_overflow** (bool, default=True): If set, the system will log a warning when an input overflow occurs during recording and remove the data from the buffer.
@@ -489,6 +491,8 @@ When you initialize the `AudioToTextRecorder` class, you have various options to
 
 - **on_realtime_transcription_stabilized**: A callback function that is triggered whenever there's an update in the real-time transcription and returns a higher quality, stabilized text as its argument.
 
+- **realtime_batch_size**: (int, default=16): Batch size for the real-time transcription model. Set to 0 to deactivate.
+
 - **beam_size_realtime** (int, default=3): The beam size to use for real-time transcription beam search decoding.
 
 #### Voice Activation Parameters
diff --git a/RealtimeSTT/audio_recorder.py b/RealtimeSTT/audio_recorder.py
index 8f6c287..688bfc1 100644
--- a/RealtimeSTT/audio_recorder.py
+++ b/RealtimeSTT/audio_recorder.py
@@ -34,6 +34,7 @@ from scipy.signal import resample
 from scipy import signal
 import signal as system_signal
+from faster_whisper import WhisperModel, BatchedInferencePipeline
 import faster_whisper
 import openwakeword
 import collections
@@ -88,7 +89,7 @@ class TranscriptionWorker:
     def __init__(self, conn, stdout_pipe, model_path, download_root, compute_type, gpu_device_index, device,
-                 ready_event, shutdown_event, interrupt_stop_event, beam_size, initial_prompt, suppress_tokens):
+                 ready_event, shutdown_event, interrupt_stop_event, beam_size, initial_prompt, suppress_tokens, batch_size):
         self.conn = conn
         self.stdout_pipe = stdout_pipe
         self.model_path = model_path
@@ -102,6 +103,7 @@ def __init__(self, conn, stdout_pipe, model_path, download_root, compute_type, g
         self.beam_size = beam_size
         self.initial_prompt = initial_prompt
         self.suppress_tokens = suppress_tokens
+        self.batch_size = batch_size
         self.queue = queue.Queue()
 
     def custom_print(self, *args, **kwargs):
@@ -137,6 +139,8 @@ def run(self):
                 device_index=self.gpu_device_index,
                 download_root=self.download_root,
             )
+            if self.batch_size > 0:
+                model = BatchedInferencePipeline(model=model)
         except Exception as e:
             logging.exception(f"Error initializing main faster_whisper transcription model: {e}")
             raise
@@ -153,13 +157,24 @@ def run(self):
             try:
                 audio, language = self.queue.get(timeout=0.1)
                 try:
-                    segments, info = model.transcribe(
-                        audio,
-                        language=language if language else None,
-                        beam_size=self.beam_size,
-                        initial_prompt=self.initial_prompt,
-                        suppress_tokens=self.suppress_tokens
-                    )
+                    if self.batch_size > 0:
+                        segments, info = model.transcribe(
+                            audio,
+                            language=language if language else None,
+                            beam_size=self.beam_size,
+                            initial_prompt=self.initial_prompt,
+                            suppress_tokens=self.suppress_tokens,
+                            batch_size=self.batch_size
+                        )
+                    else:
+                        segments, info = model.transcribe(
+                            audio,
+                            language=language if language else None,
+                            beam_size=self.beam_size,
+                            initial_prompt=self.initial_prompt,
+                            suppress_tokens=self.suppress_tokens
+                        )
+
                     transcription = " ".join(seg.text for seg in segments).strip()
                     logging.debug(f"Final text detected with main model: {transcription}")
                     self.conn.send(('success', (transcription, info)))
@@ -211,6 +226,7 @@ def __init__(self,
                  use_microphone=True,
                  spinner=True,
                  level=logging.WARNING,
+                 batch_size: int = 16,
 
                  # Realtime transcription parameters
                  enable_realtime_transcription=False,
@@ -220,6 +236,7 @@ def __init__(self,
                  init_realtime_after_seconds=INIT_REALTIME_INITIAL_PAUSE,
                  on_realtime_transcription_update=None,
                  on_realtime_transcription_stabilized=None,
+                 realtime_batch_size: int = 16,
 
                  # Voice activation parameters
                  silero_sensitivity: float = INIT_SILERO_SENSITIVITY,
@@ -319,6 +336,7 @@ def __init__(self,
         - spinner (bool, default=True): Show spinner animation with current
             state.
         - level (int, default=logging.WARNING): Logging level.
+        - batch_size (int, default=16): Batch size for the main transcription
         - enable_realtime_transcription (bool, default=False): Enables or
             disables real-time transcription of audio. When set to True, the
             audio will be transcribed continuously as it is being recorded.
@@ -350,6 +368,8 @@ def __init__(self,
             triggered when the transcribed text stabilizes in quality. The
             stabilized text is generally more accurate but may arrive with a
             slight delay compared to the regular real-time updates.
+        - realtime_batch_size (int, default=16): Batch size for the real-time
+            transcription model.
         - silero_sensitivity (float, default=SILERO_SENSITIVITY): Sensitivity
             for the Silero Voice Activity Detection model ranging from 0
             (least sensitive) to 1 (most sensitive). Default is 0.5.
@@ -528,6 +548,8 @@ def __init__(self,
         self.beam_size = beam_size
         self.beam_size_realtime = beam_size_realtime
         self.allowed_latency_limit = allowed_latency_limit
+        self.batch_size = batch_size
+        self.realtime_batch_size = realtime_batch_size
         self.level = level
 
         self.audio_queue = mp.Queue()
@@ -646,7 +668,8 @@ def __init__(self,
                 self.interrupt_stop_event,
                 self.beam_size,
                 self.initial_prompt,
-                self.suppress_tokens
+                self.suppress_tokens,
+                self.batch_size
             )
         )
@@ -676,7 +699,6 @@ def __init__(self,
                 logging.info("Initializing faster_whisper realtime "
                              f"transcription model {self.realtime_model_type}"
                              )
-                print(self.download_root)
                 self.realtime_model_type = faster_whisper.WhisperModel(
                     model_size_or_path=self.realtime_model_type,
                     device=self.device,
@@ -684,6 +706,8 @@ def __init__(self,
                     device_index=self.gpu_device_index,
                     download_root=self.download_root,
                 )
+                if self.realtime_batch_size > 0:
+                    self.realtime_model_type = BatchedInferencePipeline(model=self.realtime_model_type)
 
             except Exception as e:
                 logging.exception("Error initializing faster_whisper "
@@ -2075,13 +2099,23 @@ def _realtime_worker(self):
                         continue
                     else:
                         # Perform transcription and assemble the text
-                        segments, info = self.realtime_model_type.transcribe(
-                            audio_array,
-                            language=self.language if self.language else None,
-                            beam_size=self.beam_size_realtime,
-                            initial_prompt=self.initial_prompt,
-                            suppress_tokens=self.suppress_tokens,
-                        )
+                        if self.realtime_batch_size > 0:
+                            segments, info = self.realtime_model_type.transcribe(
+                                audio_array,
+                                language=self.language if self.language else None,
+                                beam_size=self.beam_size_realtime,
+                                initial_prompt=self.initial_prompt,
+                                suppress_tokens=self.suppress_tokens,
+                                batch_size=self.realtime_batch_size
+                            )
+                        else:
+                            segments, info = self.realtime_model_type.transcribe(
+                                audio_array,
+                                language=self.language if self.language else None,
+                                beam_size=self.beam_size_realtime,
+                                initial_prompt=self.initial_prompt,
+                                suppress_tokens=self.suppress_tokens
+                            )
                         self.detected_realtime_language = info.language if info.language_probability > 0 else None
                         self.detected_realtime_language_probability = info.language_probability
diff --git a/tests/realtimestt_test.py b/tests/realtimestt_test.py
index bdce35a..4cd0189 100644
--- a/tests/realtimestt_test.py
+++ b/tests/realtimestt_test.py
@@ -21,7 +21,7 @@
     parser.add_argument('-d', '--root', type=str,  # no default=None,
                         help='Root directory where the Whisper models are downloaded to.')
 
-    from tests.install_packages import check_and_install_packages
+    from install_packages import check_and_install_packages
     check_and_install_packages([
         {
             'import_name': 'rich',
@@ -169,6 +169,8 @@ def process_text(text):
         'early_transcription_on_silence': 0,
         'beam_size': 5,
         'beam_size_realtime': 3,
+        # 'batch_size': 0,
+        # 'realtime_batch_size': 0,
         'no_log_file': True,
         'initial_prompt': (
             "End incomplete sentences with ellipses.\n"
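Below is a minimal usage sketch of the two new parameters. It assumes the `AudioToTextRecorder` constructor options and callback pattern shown in the README and `tests/realtimestt_test.py` excerpts above; the model name and the callback body are illustrative only and not part of this diff.

```python
# Illustrative sketch: exercises the batch_size and realtime_batch_size
# arguments added in this diff alongside existing RealtimeSTT options.
from RealtimeSTT import AudioToTextRecorder

def process_text(text):
    # Handle each finalized transcription; here we just print it.
    print(text)

if __name__ == '__main__':
    recorder = AudioToTextRecorder(
        model="tiny.en",                     # hypothetical model choice
        enable_realtime_transcription=True,  # also exercises realtime_batch_size
        batch_size=16,                       # batched inference for the main model; 0 disables batching
        realtime_batch_size=16,              # batched inference for the realtime model; 0 disables batching
    )
    while True:
        recorder.text(process_text)
```

Setting either value to 0 falls back to the unbatched `model.transcribe(...)` call, as shown in the `audio_recorder.py` hunks above.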