Skip to content

Commit

Permalink
added batched transcription
Browse files Browse the repository at this point in the history
  • Loading branch information
KoljaB committed Dec 11, 2024
1 parent 26066cf commit 6d14085
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 19 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ https://github.com/user-attachments/assets/797e6552-27cd-41b1-a7f3-e5cbc72094f5

### Updates

Latest Version: v0.3.81
Latest Version: v0.3.9

See [release history](https://github.com/KoljaB/RealtimeSTT/releases).

Expand Down Expand Up @@ -450,6 +450,8 @@ When you initialize the `AudioToTextRecorder` class, you have various options to
- **level** (int, default=logging.WARNING): Logging level.
- **batch_size** (int, default=16): Batch size for the main transcription. Set to 0 to deactivate.
- **init_logging** (bool, default=True): Whether to initialize the logging framework. Set to False to manage this yourself.
- **handle_buffer_overflow** (bool, default=True): If set, the system will log a warning when an input overflow occurs during recording and remove the data from the buffer.
Expand Down Expand Up @@ -489,6 +491,8 @@ When you initialize the `AudioToTextRecorder` class, you have various options to
- **on_realtime_transcription_stabilized**: A callback function that is triggered whenever there's an update in the real-time transcription and returns a higher quality, stabilized text as its argument.
- **realtime_batch_size** (int, default=16): Batch size for the real-time transcription model. Set to 0 to deactivate.
- **beam_size_realtime** (int, default=3): The beam size to use for real-time transcription beam search decoding.
#### Voice Activation Parameters
Expand Down
68 changes: 51 additions & 17 deletions RealtimeSTT/audio_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from scipy.signal import resample
from scipy import signal
import signal as system_signal
from faster_whisper import WhisperModel, BatchedInferencePipeline
import faster_whisper
import openwakeword
import collections
Expand Down Expand Up @@ -88,7 +89,7 @@

class TranscriptionWorker:
def __init__(self, conn, stdout_pipe, model_path, download_root, compute_type, gpu_device_index, device,
ready_event, shutdown_event, interrupt_stop_event, beam_size, initial_prompt, suppress_tokens):
ready_event, shutdown_event, interrupt_stop_event, beam_size, initial_prompt, suppress_tokens, batch_size):
self.conn = conn
self.stdout_pipe = stdout_pipe
self.model_path = model_path
Expand All @@ -102,6 +103,7 @@ def __init__(self, conn, stdout_pipe, model_path, download_root, compute_type, g
self.beam_size = beam_size
self.initial_prompt = initial_prompt
self.suppress_tokens = suppress_tokens
self.batch_size = batch_size
self.queue = queue.Queue()

def custom_print(self, *args, **kwargs):
Expand Down Expand Up @@ -137,6 +139,8 @@ def run(self):
device_index=self.gpu_device_index,
download_root=self.download_root,
)
if self.batch_size > 0:
model = BatchedInferencePipeline(model=model)
except Exception as e:
logging.exception(f"Error initializing main faster_whisper transcription model: {e}")
raise
Expand All @@ -153,13 +157,24 @@ def run(self):
try:
audio, language = self.queue.get(timeout=0.1)
try:
segments, info = model.transcribe(
audio,
language=language if language else None,
beam_size=self.beam_size,
initial_prompt=self.initial_prompt,
suppress_tokens=self.suppress_tokens
)
if self.batch_size > 0:
segments, info = model.transcribe(
audio,
language=language if language else None,
beam_size=self.beam_size,
initial_prompt=self.initial_prompt,
suppress_tokens=self.suppress_tokens,
batch_size=self.batch_size
)
else:
segments, info = model.transcribe(
audio,
language=language if language else None,
beam_size=self.beam_size,
initial_prompt=self.initial_prompt,
suppress_tokens=self.suppress_tokens
)

transcription = " ".join(seg.text for seg in segments).strip()
logging.debug(f"Final text detected with main model: {transcription}")
self.conn.send(('success', (transcription, info)))
Expand Down Expand Up @@ -211,6 +226,7 @@ def __init__(self,
use_microphone=True,
spinner=True,
level=logging.WARNING,
batch_size: int = 16,

# Realtime transcription parameters
enable_realtime_transcription=False,
Expand All @@ -220,6 +236,7 @@ def __init__(self,
init_realtime_after_seconds=INIT_REALTIME_INITIAL_PAUSE,
on_realtime_transcription_update=None,
on_realtime_transcription_stabilized=None,
realtime_batch_size: int = 16,

# Voice activation parameters
silero_sensitivity: float = INIT_SILERO_SENSITIVITY,
Expand Down Expand Up @@ -319,6 +336,7 @@ def __init__(self,
- spinner (bool, default=True): Show spinner animation with current
state.
- level (int, default=logging.WARNING): Logging level.
    - batch_size (int, default=16): Batch size for the main transcription.
      Set to 0 to deactivate.
- enable_realtime_transcription (bool, default=False): Enables or
disables real-time transcription of audio. When set to True, the
audio will be transcribed continuously as it is being recorded.
Expand Down Expand Up @@ -350,6 +368,8 @@ def __init__(self,
triggered when the transcribed text stabilizes in quality. The
stabilized text is generally more accurate but may arrive with a
slight delay compared to the regular real-time updates.
- realtime_batch_size (int, default=16): Batch size for the real-time
transcription model.
- silero_sensitivity (float, default=SILERO_SENSITIVITY): Sensitivity
for the Silero Voice Activity Detection model ranging from 0
(least sensitive) to 1 (most sensitive). Default is 0.5.
Expand Down Expand Up @@ -528,6 +548,8 @@ def __init__(self,
self.beam_size = beam_size
self.beam_size_realtime = beam_size_realtime
self.allowed_latency_limit = allowed_latency_limit
self.batch_size = batch_size
self.realtime_batch_size = realtime_batch_size

self.level = level
self.audio_queue = mp.Queue()
Expand Down Expand Up @@ -646,7 +668,8 @@ def __init__(self,
self.interrupt_stop_event,
self.beam_size,
self.initial_prompt,
self.suppress_tokens
self.suppress_tokens,
self.batch_size
)
)

Expand Down Expand Up @@ -676,14 +699,15 @@ def __init__(self,
logging.info("Initializing faster_whisper realtime "
f"transcription model {self.realtime_model_type}"
)
print(self.download_root)
self.realtime_model_type = faster_whisper.WhisperModel(
model_size_or_path=self.realtime_model_type,
device=self.device,
compute_type=self.compute_type,
device_index=self.gpu_device_index,
download_root=self.download_root,
)
if self.realtime_batch_size > 0:
self.realtime_model_type = BatchedInferencePipeline(model=self.realtime_model_type)

except Exception as e:
logging.exception("Error initializing faster_whisper "
Expand Down Expand Up @@ -2075,13 +2099,23 @@ def _realtime_worker(self):
continue
else:
# Perform transcription and assemble the text
segments, info = self.realtime_model_type.transcribe(
audio_array,
language=self.language if self.language else None,
beam_size=self.beam_size_realtime,
initial_prompt=self.initial_prompt,
suppress_tokens=self.suppress_tokens,
)
if self.realtime_batch_size > 0:
segments, info = self.realtime_model_type.transcribe(
audio_array,
language=self.language if self.language else None,
beam_size=self.beam_size_realtime,
initial_prompt=self.initial_prompt,
suppress_tokens=self.suppress_tokens,
batch_size=self.realtime_batch_size
)
else:
segments, info = self.realtime_model_type.transcribe(
audio_array,
language=self.language if self.language else None,
beam_size=self.beam_size_realtime,
initial_prompt=self.initial_prompt,
suppress_tokens=self.suppress_tokens
)

self.detected_realtime_language = info.language if info.language_probability > 0 else None
self.detected_realtime_language_probability = info.language_probability
Expand Down
4 changes: 3 additions & 1 deletion tests/realtimestt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
parser.add_argument('-d', '--root', type=str, # no default=None,
help='Root directory where the Whisper models are downloaded to.')

from tests.install_packages import check_and_install_packages
from install_packages import check_and_install_packages
check_and_install_packages([
{
'import_name': 'rich',
Expand Down Expand Up @@ -169,6 +169,8 @@ def process_text(text):
'early_transcription_on_silence': 0,
'beam_size': 5,
'beam_size_realtime': 3,
# 'batch_size': 0,
# 'realtime_batch_size': 0,
'no_log_file': True,
'initial_prompt': (
"End incomplete sentences with ellipses.\n"
Expand Down

0 comments on commit 6d14085

Please sign in to comment.