Skip to content

Commit

Permalink
fixed text_output for word_timestamps=False
Browse files Browse the repository at this point in the history
-fixed: [word_level]=False prevented text output when word_timestamps=False (now warns user instead)
-fixed: use "\n" instead of "\n\n" between segments for ASS output
-changed: [strip] default to True for result_to_srt_vtt() to follow old behavior
  • Loading branch information
jianfch committed Mar 19, 2023
1 parent 6ccfa17 commit ce4c7b3
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions stable_whisper/text_output.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import warnings
from typing import List, Tuple
from itertools import chain
from .stabilization import valid_ts
Expand Down Expand Up @@ -87,13 +88,32 @@ def to_word_level(segments: List[dict]) -> List[dict]:
return [dict(text=w['word'], start=w['start'], end=w['end']) for s in segments for w in s['words']]


def _confirm_word_level(segments: List[dict]) -> bool:
    """
    Check whether word-level timing can be exported for the given segments.

    Parameters
    ----------
    segments: List[dict]
        segment dicts; each one is inspected for a non-empty 'words' entry

    Returns
    -------
    True if every segment has a truthy 'words' value (also the case for an empty list of segments),
    else False after warning that word timestamps are missing
    """
    # A single segment with a missing or empty 'words' entry is enough to rule out word-level export.
    if all(seg.get('words') for seg in segments):
        return True

    warnings.warn('Result is missing word timestamps. Word-level timing cannot be exported. '
                  'Use `word_level=False` to avoid this warning and not export word-level timing.')
    return False


def _preprocess_args(result: (dict, list),
                     segment_level: bool,
                     word_level: bool):
    """
    Validate the level flags and resolve the segments shared by the output functions.

    Parameters
    ----------
    result: (dict, list)
        value passed to `_get_segments()` to obtain the segments
    segment_level: bool
        whether segment-level timing was requested; returned unchanged
    word_level: bool
        whether word-level timing was requested; only when truthy is it replaced
        by the outcome of `_confirm_word_level()`

    Returns
    -------
    tuple of (segments, segment_level, word_level)
    """
    assert segment_level or word_level, '`segment_level` or `word_level` must be True'

    segments = _get_segments(result)
    # Only check for word timestamps (and possibly warn) when word-level output was asked for;
    # a falsy `word_level` is passed through untouched.
    confirmed_word_level = _confirm_word_level(segments) if word_level else word_level
    return segments, segment_level, confirmed_word_level


def result_to_srt_vtt(result: (dict, list),
filepath: str = None,
segment_level=True,
word_level=True,
tag: Tuple[str, str] = None,
vtt: bool = None,
strip=False):
strip=True):
"""
Generate SRT/VTT from result to display segment-level and/or word-level timestamp.
Expand Down Expand Up @@ -122,7 +142,7 @@ def result_to_srt_vtt(result: (dict, list),
string of content if no [filepath] is provided, else None
"""
assert segment_level or word_level
segments, segment_level, word_level = _preprocess_args(result, segment_level, word_level)

is_srt = (filepath is None or not filepath.endswith('.vtt')) if vtt is None else vtt
if filepath:
Expand All @@ -134,7 +154,6 @@ def result_to_srt_vtt(result: (dict, list),

sub_str = '' if is_srt else 'WEBVTT\n\n'

segments = _get_segments(result)
if word_level and segment_level:
if tag is None:
tag = ('<font color="#00ff00">', '</font>') if is_srt else ('<u>', '</u>')
Expand Down Expand Up @@ -199,6 +218,7 @@ def result_to_ass(result: (dict, list),
string of content if no [filepath] is provided, else None
"""
segments, segment_level, word_level = _preprocess_args(result, segment_level, word_level)

fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
Expand All @@ -224,7 +244,6 @@ def result_to_ass(result: (dict, list),
f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'

segments = _get_segments(result)
if word_level and segment_level:
if tag is None:
color = 'HFF00'
Expand All @@ -235,7 +254,7 @@ def result_to_ass(result: (dict, list),

valid_ts(segments)

sub_str += '\n\n'.join(segment2assblock(s, i, strip=strip) for i, s in enumerate(segments))
sub_str += '\n'.join(segment2assblock(s, i, strip=strip) for i, s in enumerate(segments))

if filepath:
if not filepath.lower().endswith('.ass'):
Expand Down

0 comments on commit ce4c7b3

Please sign in to comment.