diff --git a/demos/whisper/README.md b/demos/whisper/README.md new file mode 100644 index 00000000000..455bca92b4a --- /dev/null +++ b/demos/whisper/README.md @@ -0,0 +1,89 @@ +([简体中文](./README_cn.md)|English) + +## Introduction +Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification. + +Whisper model trained by OpenAI whisper https://github.com/openai/whisper + +## Usage + ### 1. Installation + see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md). + + You can choose one way from easy, meduim and hard to install paddlespeech. + + ### 2. Prepare Input File + The input of this demo should be a WAV file(`.wav`), and the sample rate must be the same as the model. + + Here are sample files for this demo that can be downloaded: + ```bash + wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav + ``` + + ### 3. Usage + - Command Line(Recommended) + ```bash + # to recognize text + paddlespeech whisper --task transcribe --input ./zh.wav + + # to recognize text and translate to English + paddlespeech whisper --task translate --input ./zh.wav + ``` + + Usage: + ```bash + paddlespeech whisper --help + ``` + Arguments: + - `input`(required): Audio file to recognize. + - `model`: Model type of asr task. Default: `whisper-large`. + - `task`: Output type. Default: `transcribe`. + - `lang`: Model language. Default: `None`. Forcibly set the recognized language, which is determined by the model itself by default. + - `sample_rate`: Sample rate of the model. Default: `16000`. Other sampling rates are not supported now. + - `config`: Config of asr task. Use pretrained model when it is None. Default: `None`. + - `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`. + - `yes`: No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate. Default: `False`. + - `device`: Choose device to execute model inference. Default: default device of paddlepaddle in current environment. + - `verbose`: Show the log information. + + + - Python API + ```python + import paddle + from paddlespeech.cli.whisper import WhisperExecutor + + whisper_executor = WhisperExecutor() + + # to recognize text + text = whisper_executor( + model='whisper-large', + task='transcribe', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./zh.wav', + device=paddle.get_device()) + print('ASR Result: \n{}'.format(text)) + + # to recognize text and translate to English + feature = whisper_executor( + model='whisper-large', + task='translate', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./zh.wav', + device=paddle.get_device()) + print('Representation: \n{}'.format(feature)) + ``` + + Output: + ```bash + Transcribe Result: + Detected language: Chinese + [00:00.000 --> 00:05.000] 我认为跑步最重要的就是给我带来了身体健康 + {'text': '我认为跑步最重要的就是给我带来了身体健康', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': '我认为跑步最重要的就是给我带来了身体健康', 'tokens': [50364, 1654, 7422, 97, 13992, 32585, 31429, 8661, 24928, 1546, 5620, 49076, 4845, 99, 34912, 19847, 29485, 44201, 6346, 115, 50614], 'temperature': 0.0, 'avg_logprob': -0.23577967557040128, 'compression_ratio': 0.28169014084507044, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'} + + Translate Result: + Detected language: Chinese + [00:00.000 --> 00:05.000] I think the most important thing about running is that it brings me good health. + {'text': ' I think the most important thing about running is that it brings me good health.', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': ' I think the most important thing about running is that it brings me good health.', 'tokens': [50364, 286, 519, 264, 881, 1021, 551, 466, 2614, 307, 300, 309, 5607, 385, 665, 1585, 13, 50614], 'temperature': 0.0, 'avg_logprob': -0.47945233395225123, 'compression_ratio': 1.095890410958904, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'} diff --git a/demos/whisper/README_cn.md b/demos/whisper/README_cn.md new file mode 100644 index 00000000000..784952761c8 --- /dev/null +++ b/demos/whisper/README_cn.md @@ -0,0 +1,91 @@ +(简体中文|[English](./README.md)) + +# Whisper模型 +## 介绍 +Whisper是一种通用的语音识别模型。它是在多种音频的大数据集上训练的,也是一个多任务模型,可以执行多语言语音识别以及语音翻译和语言识别。 + +Whisper模型由OpenAI Whisper训练 https://github.com/openai/whisper + +## 使用方法 +### 1. 安装 + 请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。 + + 你可以从 easy,medium,hard 三中方式中选择一种方式安装。 + +### 2. 准备输入 + 这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。 + + 可以下载此 demo 的示例音频: + ```bash + wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav + ``` + +### 3. 使用方法 + - 命令行 (推荐使用) + ```bash + + # 识别文本 + paddlespeech whisper --task transcribe --input ./zh.wav + + # 将语音翻译成英语 + paddlespeech whisper --task translate --input ./zh.wav + ``` + 使用方法: + ```bash + paddlespeech whisper --help + ``` + 参数: + - `input`(必须输入):用于识别的音频文件。 + - `model`:ASR 任务的模型,默认值:`whisper-large`。 + - `task`:输出类别,默认值:`transcribe`。 + - `lang`:模型语言,默认值:`None`,强制设定识别出的语言,默认为模型自行判定。 + - `sample_rate`:音频采样率,默认值:`16000`,目前Whisper暂不支持其他采样率。 + - `config`:ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。 + - `ckpt_path`:模型参数文件,若不设置则下载解码模型使用,默认值:`None`。 + - `yes`;不需要设置额外的参数,一旦设置了该参数,说明你默认同意程序的所有请求,其中包括自动转换输入音频的采样率。默认值:`False`。 + - `device`:执行预测的设备,默认值:当前系统下 paddlepaddle 的默认 device。 + - `verbose`: 如果使用,显示 logger 信息。 + + +- Python API + ```python + import paddle + from paddlespeech.cli.whisper import WhisperExecutor + + whisper_executor = WhisperExecutor() + + # 识别文本 + text = whisper_executor( + model='whisper-large', + task='transcribe', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./zh.wav', + device=paddle.get_device()) + print('ASR Result: \n{}'.format(text)) + + # 将语音翻译成英语 + feature = whisper_executor( + model='whisper-large', + task='translate', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./zh.wav', + device=paddle.get_device()) + print('Representation: \n{}'.format(feature)) + ``` + + + 输出: + ```bash + Transcribe Result: + Detected language: Chinese + [00:00.000 --> 00:05.000] 我认为跑步最重要的就是给我带来了身体健康 + {'text': '我认为跑步最重要的就是给我带来了身体健康', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': '我认为跑步最重要的就是给我带来了身体健康', 'tokens': [50364, 1654, 7422, 97, 13992, 32585, 31429, 8661, 24928, 1546, 5620, 49076, 4845, 99, 34912, 19847, 29485, 44201, 6346, 115, 50614], 'temperature': 0.0, 'avg_logprob': -0.23577967557040128, 'compression_ratio': 0.28169014084507044, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'} + + Translate Result: + Detected language: Chinese + [00:00.000 --> 00:05.000] I think the most important thing about running is that it brings me good health. + {'text': ' I think the most important thing about running is that it brings me good health.', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': ' I think the most important thing about running is that it brings me good health.', 'tokens': [50364, 286, 519, 264, 881, 1021, 551, 466, 2614, 307, 300, 309, 5607, 385, 665, 1585, 13, 50614], 'temperature': 0.0, 'avg_logprob': -0.47945233395225123, 'compression_ratio': 1.095890410958904, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'} diff --git a/demos/whisper/run.sh b/demos/whisper/run.sh new file mode 100644 index 00000000000..1d758108d81 --- /dev/null +++ b/demos/whisper/run.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# audio download +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav + +# to recognize text +paddlespeech whisper --task transcribe --input ./zh.wav + +# to recognize text and translate to English +paddlespeech whisper --task translate --input ./zh.wav \ No newline at end of file diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py index 7210091a98a..7551b6c02be 100644 --- a/paddlespeech/cli/base_commands.py +++ b/paddlespeech/cli/base_commands.py @@ -83,7 +83,8 @@ def execute(self, argv: List[str]) -> bool: 'st': 'Model-Source language-Target language', 'text': 'Model-Task-Language', 'tts': 'Model-Language', - 'vector': 'Model-Sample Rate' + 'vector': 'Model-Sample Rate', + 'whisper': 'Model-Language-Sample Rate' } @@ -94,7 +95,9 @@ class StatsCommand: def __init__(self): self.parser = argparse.ArgumentParser( prog='paddlespeech.stats', add_help=True) - self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws'] + self.task_choices = [ + 'asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'whisper' + ] self.parser.add_argument( '--task', type=str, @@ -141,6 +144,10 @@ def execute(self, argv: List[str]) -> bool: 'tts': ['Text to Speech infer command.', 'TTSExecutor'], 'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'], 'kws': ['Keyword Spotting infer command.', 'KWSExecutor'], + 'whisper': [ + 'Whisper model for speech to text or translate speech to English.', + 'WhisperExecutor' + ] } for com, info in _commands.items(): diff --git a/paddlespeech/cli/whisper/__init__.py b/paddlespeech/cli/whisper/__init__.py new file mode 100644 index 00000000000..3bafc10d250 --- /dev/null +++ b/paddlespeech/cli/whisper/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .infer import WhisperExecutor diff --git a/paddlespeech/cli/whisper/infer.py b/paddlespeech/cli/whisper/infer.py new file mode 100644 index 00000000000..3b1771b2dd6 --- /dev/null +++ b/paddlespeech/cli/whisper/infer.py @@ -0,0 +1,468 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import io +import os +import sys +import time +from collections import OrderedDict +from typing import List +from typing import Optional +from typing import Union + +import librosa +import numpy as np +import paddle +import soundfile +from yacs.config import CfgNode + +from ..download import get_path_from_url +from ..executor import BaseExecutor +from ..log import logger +from ..utils import CLI_TIMER +from ..utils import stats_wrapper +from ..utils import timer_register +from paddlespeech.s2t.models.whisper import log_mel_spectrogram +from paddlespeech.s2t.models.whisper import ModelDimensions +from paddlespeech.s2t.models.whisper import Whisper +from paddlespeech.s2t.utils.utility import UpdateConfig + +__all__ = ['WhisperExecutor'] + + +@timer_register +class WhisperExecutor(BaseExecutor): + def __init__(self): + super().__init__('whisper') + self.parser = argparse.ArgumentParser( + prog='paddlespeech.whisper', add_help=True) + self.parser.add_argument( + '--input', type=str, default=None, help='Audio file to recognize.') + self.parser.add_argument( + '--model', + type=str, + default='whisper', + choices=[ + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() + ], + help='Choose model type of asr task.') + self.parser.add_argument( + '--lang', + type=str, + default='None', + help='Choose model decode language. Default is None, recognized by model.' + ) + self.parser.add_argument( + '--task', + type=str, + default='transcribe', + choices=["transcribe", "translate"], + help='Choose task tpye for transcribe or translate.') + self.parser.add_argument( + '--size', + type=str, + default='large', + help='Choose model size. now only support large, large:[whisper-large-16k]' + ) + self.parser.add_argument( + "--sample_rate", + type=int, + default=16000, + choices=[16000], + help='Choose the audio sample rate of the model. only support 16000') + self.parser.add_argument( + '--config', + type=str, + default=None, + help='Config of asr task. Use deault config when it is None.') + self.parser.add_argument( + '--decode_method', + type=str, + default='ctc_prefix_beam_search', + choices=['ctc_greedy_search', 'ctc_prefix_beam_search'], + help='only support transformer and conformer model') + self.parser.add_argument( + '--ckpt_path', + type=str, + default=None, + help='Checkpoint file of model.') + self.parser.add_argument( + '--yes', + '-y', + action="store_true", + default=False, + help='No additional parameters required. \ + Once set this parameter, it means accepting the request of the program by default, \ + which includes transforming the audio sample rate') + self.parser.add_argument( + '--rtf', + action="store_true", + default=False, + help='Show Real-time Factor(RTF).') + self.parser.add_argument( + '--device', + type=str, + default=paddle.get_device(), + help='Choose device to execute model inference.') + self.parser.add_argument( + '-d', + '--job_dump_result', + action='store_true', + help='Save job result into file.') + self.parser.add_argument( + '-v', + '--verbose', + action='store_true', + help='Increase logger verbosity of current task.') + + def _init_from_path(self, + model_type: str='whisper', + lang: str='None', + task: str='transcribe', + size: str='large', + sample_rate: int=16000, + cfg_path: Optional[os.PathLike]=None, + decode_method: str='ctc_prefix_beam_search', + num_decoding_left_chunks: int=-1, + ckpt_path: Optional[os.PathLike]=None): + """ + Init model and other resources from a specific path. + """ + logger.debug("start to init the model") + # default max_len: unit:second + self.max_len = 50 + if hasattr(self, 'model'): + logger.debug('Model had been initialized.') + return + + if cfg_path is None or ckpt_path is None: + sample_rate_str = '16k' if sample_rate == 16000 else '8k' + tag = model_type + '-' + size + '-' + sample_rate_str + self.task_resource.set_task_model(tag, version=None) + self.res_path = self.task_resource.res_dir + + self.cfg_path = os.path.join( + self.res_path, self.task_resource.res_dict['cfg_path']) + self.ckpt_path = os.path.join( + self.res_path, + self.task_resource.res_dict['ckpt_path'] + ".pdparams") + logger.debug(self.res_path) + + else: + self.cfg_path = os.path.abspath(cfg_path) + self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") + self.res_path = os.path.dirname( + os.path.dirname(os.path.abspath(self.cfg_path))) + logger.debug(self.cfg_path) + logger.debug(self.ckpt_path) + + #Init body. + self.config = CfgNode(new_allowed=True) + self.config.merge_from_file(self.cfg_path) + + with UpdateConfig(self.config): + if "whisper" in model_type: + resource_url = self.task_resource.res_dict['resuource_data'] + resource_md5 = self.task_resource.res_dict['resuource_data_md5'] + resuource_path = self.task_resource.res_dict['resuource_path'] + self.download_resource(resource_url, resuource_path, + resource_md5) + else: + raise Exception("wrong type") + + # load model + model_dict = paddle.load(self.ckpt_path) + dims = ModelDimensions(**model_dict["dims"]) + self.model = Whisper(dims) + self.model.load_dict(model_dict) + self.model.eval() + + #set task + if task is not None: + self.task = task + + #set language + if lang is not None: + self.language = lang + + def preprocess(self, model_type: str, input: Union[str, os.PathLike]): + """ + Input preprocess and return paddle.Tensor stored in self.input. + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + """ + + audio_file = input + if isinstance(audio_file, (str, os.PathLike)): + logger.debug("Preprocess audio_file:" + audio_file) + elif isinstance(audio_file, io.BytesIO): + audio_file.seek(0) + + # Get the object for feature extraction + # whisper hard-coded audio hyperparameters, params in paddlespeech/s2t/models/whisper/whisper.py + logger.debug("read the audio file") + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="float32", always_2d=True) + if self.change_format: + if audio.shape[1] >= 2: + audio = audio.mean(axis=1, dtype=np.int16) + else: + audio = audio[:, 0] + # pcm16 -> pcm 32 + audio = self._pcm16to32(audio) + audio = librosa.resample( + audio, orig_sr=audio_sample_rate, target_sr=self.sample_rate) + audio_sample_rate = self.sample_rate + # pcm32 -> pcm 16 + audio = self._pcm32to16(audio) + else: + audio = audio[:, 0] + + logger.debug(f"audio shape: {audio.shape}") + # fbank + audio = log_mel_spectrogram(audio) + + audio_len = paddle.to_tensor(audio.shape[0]) + #audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) + + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.debug(f"audio feat shape: {audio.shape}") + + logger.debug("audio feat process success") + + @paddle.no_grad() + def infer(self, model_type: str): + """ + Model inference and result stored in self.output. + """ + logger.debug("start to infer the model to get the output") + cfg = self.config + audio = self._inputs["audio"] + if cfg.temperature_increment_on_fallback is not None: + temperature = tuple( + np.arange(cfg.temperature, 1.0 + 1e-6, + cfg.temperature_increment_on_fallback)) + else: + temperature = [cfg.temperature] + + self._outputs["result"] = self.model.transcribe( + audio, + verbose=cfg.verbose, + task=self.task, + language=self.language, + temperature=temperature, + compression_ratio_threshold=cfg.compression_ratio_threshold, + logprob_threshold=cfg.logprob_threshold, + best_of=cfg.best_of, + beam_size=cfg.beam_size, + patience=cfg.patience, + length_penalty=cfg.length_penalty, + initial_prompt=cfg.initial_prompt, + condition_on_previous_text=cfg.condition_on_previous_text, + no_speech_threshold=cfg.no_speech_threshold) + + def postprocess(self) -> Union[str, os.PathLike]: + """ + Output postprocess and return human-readable results such as texts and audio files. + """ + return self._outputs["result"] + + def download_resource(self, url, lm_dir, md5sum): + download_path = get_path_from_url( + url=url, + root_dir=lm_dir, + md5sum=md5sum, + decompress=True, ) + + def _pcm16to32(self, audio): + assert (audio.dtype == np.int16) + audio = audio.astype("float32") + bits = np.iinfo(np.int16).bits + audio = audio / (2**(bits - 1)) + return audio + + def _pcm32to16(self, audio): + assert (audio.dtype == np.float32) + bits = np.iinfo(np.int16).bits + audio = audio * (2**(bits - 1)) + audio = np.round(audio).astype("int16") + return audio + + def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False): + self.sample_rate = sample_rate + if self.sample_rate != 16000 and self.sample_rate != 8000: + logger.error( + "invalid sample rate, please input --sr 8000 or --sr 16000") + return False + + if isinstance(audio_file, (str, os.PathLike)): + if not os.path.isfile(audio_file): + logger.error("Please input the right audio file path") + return False + elif isinstance(audio_file, io.BytesIO): + audio_file.seek(0) + + logger.debug("checking the audio file format......") + try: + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + audio_duration = audio.shape[0] / audio_sample_rate + if audio_duration > self.max_len: + logger.error( + f"Please input audio file less then {self.max_len} seconds.\n" + ) + return False + except Exception as e: + logger.exception(e) + logger.error( + f"can not open the audio file, please check the audio file({audio_file}) format is 'wav'. \n \ + you can try to use sox to change the file format.\n \ + For example: \n \ + sample rate: 16k \n \ + sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \ + sample rate: 8k \n \ + sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ + ") + return False + logger.debug("The sample rate is %d" % audio_sample_rate) + if audio_sample_rate != self.sample_rate: + logger.warning("The sample rate of the input file is not {}.\n \ + The program will resample the wav file to {}.\n \ + If the result does not meet your expectations,\n \ + Please input the 16k 16 bit 1 channel wav file. \ + ".format(self.sample_rate, self.sample_rate)) + if force_yes is False: + while (True): + logger.debug( + "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." + ) + content = input("Input(Y/N):") + if content.strip() == "Y" or content.strip( + ) == "y" or content.strip() == "yes" or content.strip( + ) == "Yes": + logger.debug( + "change the sampele rate, channel to 16k and 1 channel" + ) + break + elif content.strip() == "N" or content.strip( + ) == "n" or content.strip() == "no" or content.strip( + ) == "No": + logger.debug("Exit the program") + return False + else: + logger.warning("Not regular input, please input again") + + self.change_format = True + else: + logger.debug("The audio file format is right") + self.change_format = False + + return True + + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + parser_args = self.parser.parse_args(argv) + + model = parser_args.model + lang = parser_args.lang + task = parser_args.task + size = parser_args.size + sample_rate = parser_args.sample_rate + config = parser_args.config + ckpt_path = parser_args.ckpt_path + decode_method = parser_args.decode_method + force_yes = parser_args.yes + rtf = parser_args.rtf + device = parser_args.device + + if not parser_args.verbose: + self.disable_task_loggers() + + task_source = self.get_input_source(parser_args.input) + task_results = OrderedDict() + has_exceptions = False + + for id_, input_ in task_source.items(): + try: + res = self( + audio_file=input_, + model=model, + lang=lang, + task=task, + size=size, + sample_rate=sample_rate, + config=config, + ckpt_path=ckpt_path, + decode_method=decode_method, + force_yes=force_yes, + rtf=rtf, + device=device) + task_results[id_] = res + except Exception as e: + has_exceptions = True + task_results[id_] = f'{e.__class__.__name__}: {e}' + + if rtf: + self.show_rtf(CLI_TIMER[self.__class__.__name__]) + + self.process_task_results(parser_args.input, task_results, + parser_args.job_dump_result) + + if has_exceptions: + return False + else: + return True + + @stats_wrapper + def __call__(self, + audio_file: os.PathLike, + model: str='whisper', + lang: str='None', + task: str='transcribe', + size: str='large', + sample_rate: int=16000, + config: os.PathLike=None, + ckpt_path: os.PathLike=None, + decode_method: str='attention_rescoring', + num_decoding_left_chunks: int=-1, + force_yes: bool=False, + rtf: bool=False, + device=paddle.get_device()): + """ + Python API to call an executor. + """ + audio_file = os.path.abspath(audio_file) + paddle.set_device(device) + self._init_from_path(model, lang, task, size, sample_rate, config, + decode_method, num_decoding_left_chunks, ckpt_path) + if not self._check(audio_file, sample_rate, force_yes): + sys.exit(-1) + if rtf: + k = self.__class__.__name__ + CLI_TIMER[k]['start'].append(time.time()) + + self.preprocess(model, audio_file) + self.infer(model) + res = self.postprocess() # Retrieve result of asr. + + if rtf: + CLI_TIMER[k]['end'].append(time.time()) + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate) + + return res diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py index 8e9ecc4ba29..ce7fa662fe2 100644 --- a/paddlespeech/resource/model_alias.py +++ b/paddlespeech/resource/model_alias.py @@ -29,6 +29,11 @@ "transformer": ["paddlespeech.s2t.models.u2:U2Model"], "wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"], + # --------------------------------- + # ------------ Whisper ------------ + # --------------------------------- + "whisper": ["paddlespeech.s2t.models.whisper:Whisper"], + # --------------------------------- # -------------- CLS -------------- # --------------------------------- diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index df50a6a9d52..b83d66f26be 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -25,6 +25,7 @@ 'tts_static_pretrained_models', 'tts_onnx_pretrained_models', 'vector_dynamic_pretrained_models', + 'whisper_dynamic_pretrained_models', ] # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". @@ -424,6 +425,31 @@ }, } +whisper_dynamic_pretrained_models = { + "whisper-large-16k": { + '1.3': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/whisper-large-model.tar.gz', + 'md5': + '364c4d670835e5ca489045e1c29d75fe', + 'cfg_path': + 'whisper.yaml', + 'ckpt_path': + 'whisper-large-model', + 'model': + 'whisper-large-model.pdparams', + 'params': + 'whisper-large-model.pdparams', + 'resuource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar', + 'resuource_data_md5': + '37a0a8abdb3641a51194f79567a93b61', + 'resuource_path': + 'paddlespeech/s2t/models/whisper', + }, + }, +} + # --------------------------------- # -------------- CLS -------------- # --------------------------------- diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py index 8e9914b2e13..d3d89f4de8a 100644 --- a/paddlespeech/resource/resource.py +++ b/paddlespeech/resource/resource.py @@ -22,7 +22,7 @@ from ..utils.env import MODEL_HOME from .model_alias import model_alias -task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws'] +task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'whisper'] model_format_supported = ['dynamic', 'static', 'onnx'] inference_mode_supported = ['online', 'offline'] diff --git a/paddlespeech/s2t/exps/whisper/test_wav.py b/paddlespeech/s2t/exps/whisper/test_wav.py new file mode 100644 index 00000000000..63945b9eb94 --- /dev/null +++ b/paddlespeech/s2t/exps/whisper/test_wav.py @@ -0,0 +1,122 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.∏ +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from Whisper (https://github.com/openai/whisper/whisper/) +import os.path +import sys + +import distutils +import numpy as np +import paddle +import soundfile +from yacs.config import CfgNode + +from paddlespeech.s2t.models.whisper import log_mel_spectrogram +from paddlespeech.s2t.models.whisper import ModelDimensions +from paddlespeech.s2t.models.whisper import transcribe +from paddlespeech.s2t.models.whisper import Whisper +from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + + +class WhisperInfer(): + def __init__(self, config, args): + self.args = args + self.config = config + self.audio_file = args.audio_file + + paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') + config.pop("ngpu") + + #load_model + model_dict = paddle.load(self.config.model_file) + config.pop("model_file") + dims = ModelDimensions(**model_dict["dims"]) + self.model = Whisper(dims) + self.model.load_dict(model_dict) + + def run(self): + check(args.audio_file) + + with paddle.no_grad(): + temperature = config.pop("temperature") + temperature_increment_on_fallback = config.pop( + "temperature_increment_on_fallback") + if temperature_increment_on_fallback is not None: + temperature = tuple( + np.arange(temperature, 1.0 + 1e-6, + temperature_increment_on_fallback)) + else: + temperature = [temperature] + + #load audio + mel = log_mel_spectrogram(args.audio) + + result = transcribe( + self.model, mel, temperature=temperature, **config) + if args.result_file is not None: + with open(args.result_file, 'w') as f: + f.write(str(result)) + return result + + +def check(audio_file: str): + if not os.path.isfile(audio_file): + print("Please input the right audio file path") + sys.exit(-1) + + logger.info("checking the audio file format......") + try: + _, sample_rate = soundfile.read(audio_file) + except Exception as e: + logger.error(str(e)) + logger.error( + "can not open the wav file, please check the audio file format") + sys.exit(-1) + logger.info("The sample rate is %d" % sample_rate) + assert (sample_rate == 16000) + logger.info("The audio file format is right") + + +def main(config, args): + WhisperInfer(config, args).run() + + +if __name__ == "__main__": + parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") + parser.add_argument( + "--audio_file", type=str, help="path of the input audio file") + parser.add_argument( + "--debug", + type=distutils.util.strtobool, + default=False, + help="for debug.") + args = parser.parse_args() + + config = CfgNode(new_allowed=True) + + if args.config: + config.merge_from_file(args.config) + if args.decode_cfg: + decode_confs = CfgNode(new_allowed=True) + decode_confs.merge_from_file(args.decode_cfg) + config.decode = decode_confs + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + main(config, args) diff --git a/paddlespeech/s2t/models/whisper/__init__.py b/paddlespeech/s2t/models/whisper/__init__.py new file mode 100644 index 00000000000..98ab2361086 --- /dev/null +++ b/paddlespeech/s2t/models/whisper/__init__.py @@ -0,0 +1,12 @@ +# MIT License, Copyright (c) 2022 OpenAI. +# Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved. +# +# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/__init__.py) +from paddlespeech.s2t.models.whisper.whipser import decode +from paddlespeech.s2t.models.whisper.whipser import DecodingOptions +from paddlespeech.s2t.models.whisper.whipser import DecodingResult +from paddlespeech.s2t.models.whisper.whipser import detect_language +from paddlespeech.s2t.models.whisper.whipser import log_mel_spectrogram +from paddlespeech.s2t.models.whisper.whipser import ModelDimensions +from paddlespeech.s2t.models.whisper.whipser import transcribe +from paddlespeech.s2t.models.whisper.whipser import Whisper diff --git a/paddlespeech/s2t/models/whisper/tokenizer.py b/paddlespeech/s2t/models/whisper/tokenizer.py new file mode 100644 index 00000000000..1c58c94c756 --- /dev/null +++ b/paddlespeech/s2t/models/whisper/tokenizer.py @@ -0,0 +1,360 @@ +# MIT License, Copyright (c) 2022 OpenAI. +# Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved. +# +# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/tokenizer.py) +import os +from dataclasses import dataclass +from functools import lru_cache +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +import paddle +from paddlenlp.transformers import GPTTokenizer + +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "iw": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", +} + +# language code lookup by name, with a few language aliases +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", +} + + +@dataclass(frozen=True) +class Tokenizer: + """A thin wrapper around `GPTTokenizer` providing quick access to special tokens""" + + tokenizer: "GPTTokenizer" + language: Optional[str] + sot_sequence: Tuple[int] + + def encode(self, text, **kwargs): + return self.tokenizer.encode(text, **kwargs) + + def decode(self, + token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], + **kwargs): + if len(token_ids) > 1: + ids_list = [] + for ids in token_ids: + if paddle.is_tensor(ids): + ids = ids.item() + if ids < len(self.tokenizer): + ids_list.append(ids) + token_ids = ids_list + + return self.tokenizer.decode(token_ids, **kwargs) + + def decode_with_timestamps(self, tokens) -> str: + """ + Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. + This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". + """ + outputs = [[]] + for token in tokens: + if token >= self.timestamp_begin: + timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" + outputs.append(timestamp) + outputs.append([]) + else: + outputs[-1].append(token) + outputs = [ + s if isinstance(s, str) else self.tokenizer.decode(s) + for s in outputs + ] + return "".join(outputs) + + @property + @lru_cache() + def eot(self) -> int: + return self.tokenizer.eos_token_id + + @property + @lru_cache() + def sot(self) -> int: + return self._get_single_token_id("<|startoftranscript|>") + + @property + @lru_cache() + def sot_lm(self) -> int: + return self._get_single_token_id("<|startoflm|>") + + @property + @lru_cache() + def sot_prev(self) -> int: + return self._get_single_token_id("<|startofprev|>") + + @property + @lru_cache() + def no_speech(self) -> int: + return self._get_single_token_id("<|nospeech|>") + + @property + @lru_cache() + def no_timestamps(self) -> int: + return self._get_single_token_id("<|notimestamps|>") + + @property + @lru_cache() + def timestamp_begin(self) -> int: + return self.tokenizer.all_special_ids[-1] + 1 + + @property + @lru_cache() + def language_token(self) -> int: + """Returns the token id corresponding to the value of the `language` field""" + if self.language is None: + raise ValueError( + "This tokenizer does not have language token configured") + + additional_tokens = dict( + zip( + self.tokenizer.additional_special_tokens, + self.tokenizer.additional_special_tokens_ids, )) + candidate = f"<|{self.language}|>" + if candidate in additional_tokens: + return additional_tokens[candidate] + + raise KeyError(f"Language {self.language} not found in tokenizer.") + + @property + @lru_cache() + def all_language_tokens(self) -> Tuple[int]: + result = [] + for token, token_id in zip( + self.tokenizer.additional_special_tokens, + self.tokenizer.additional_special_tokens_ids, ): + if token.strip("<|>") in LANGUAGES: + result.append(token_id) + return tuple(result) + + @property + @lru_cache() + def all_language_codes(self) -> Tuple[str]: + return tuple( + self.decode([l]).strip("<|>") for l in self.all_language_tokens) + + @property + @lru_cache() + def sot_sequence_including_notimestamps(self) -> Tuple[int]: + return tuple(list(self.sot_sequence) + [self.no_timestamps]) + + @property + @lru_cache() + def non_speech_tokens(self) -> Tuple[int]: + """ + Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech + annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. + + - ♪♪♪ + - ( SPEAKING FOREIGN LANGUAGE ) + - [DAVID] Hey there, + + keeping basic punctuations like commas, periods, question marks, exclamation points, etc. + """ + symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") + symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split( + ) + + # symbols that may be a single token or multiple tokens depending on the tokenizer. + # In case they're multiple tokens, suppress the first token, which is safe because: + # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress + # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. + miscellaneous = set("♩♪♫♬♭♮♯") + assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) + + # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + result = { + self.tokenizer.encode(" -").input_ids[0], + self.tokenizer.encode(" '").input_ids[0] + } + for symbol in symbols + list(miscellaneous): + for tokens in [ + self.tokenizer.encode(symbol).input_ids, + self.tokenizer.encode(" " + symbol).input_ids + ]: + if len(tokens) == 1 or symbol in miscellaneous: + result.add(tokens[0]) + + return tuple(sorted(result)) + + def _get_single_token_id(self, text) -> int: + tokens = self.tokenizer.encode(text).input_ids + assert len(tokens) == 1, f"{text} is not encoded as a single token" + return tokens[0] + + +@lru_cache(maxsize=None) +def build_tokenizer(name: str="gpt2"): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + path = os.path.join(os.path.dirname(__file__), "assets", name) + tokenizer = GPTTokenizer.from_pretrained(path) + + specials = [ + "<|startoftranscript|>", + * [f"<|{lang}|>" for lang in LANGUAGES.keys()], + "<|translate|>", + "<|transcribe|>", + "<|startoflm|>", + "<|startofprev|>", + "<|nospeech|>", + "<|notimestamps|>", + ] + + tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) + return tokenizer + + +@lru_cache(maxsize=None) +def get_tokenizer( + multilingual: bool, + *, + task: Optional[str]=None, # Literal["transcribe", "translate", None] + language: Optional[str]=None, ) -> Tokenizer: + if language is not None: + language = language.lower() + if language not in LANGUAGES: + if language in TO_LANGUAGE_CODE: + language = TO_LANGUAGE_CODE[language] + else: + raise ValueError(f"Unsupported language: {language}") + + if multilingual: + tokenizer_name = "multilingual" + task = task or "transcribe" + language = language or "en" + else: + tokenizer_name = "gpt2" + task = None + language = None + + tokenizer = build_tokenizer(name=tokenizer_name) + all_special_ids: List[int] = tokenizer.all_special_ids + sot: int = all_special_ids[1] + translate: int = all_special_ids[-6] + transcribe: int = all_special_ids[-5] + + langs = tuple(LANGUAGES.keys()) + sot_sequence = [sot] + if language is not None: + sot_sequence.append(sot + 1 + langs.index(language)) + if task is not None: + sot_sequence.append(transcribe if task == "transcribe" else translate) + + return Tokenizer( + tokenizer=tokenizer, + language=language, + sot_sequence=tuple(sot_sequence)) diff --git a/paddlespeech/s2t/models/whisper/utils.py b/paddlespeech/s2t/models/whisper/utils.py new file mode 100644 index 00000000000..d067af7d2b6 --- /dev/null +++ b/paddlespeech/s2t/models/whisper/utils.py @@ -0,0 +1,92 @@ +# MIT License, Copyright (c) 2022 OpenAI. +# Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved. +# +# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/utils.py) +import zlib +from typing import Iterator +from typing import TextIO + + +def exact_div(x, y): + assert x % y == 0 + return x // y + + +def str2bool(string): + str2val = {"True": True, "False": False} + if string in str2val: + return str2val[string] + else: + raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") + + +def optional_int(string): + return None if string == "None" else int(string) + + +def optional_float(string): + return None if string == "None" else float(string) + + +def compression_ratio(text) -> float: + return len(text) / len(zlib.compress(text.encode("utf-8"))) + + +def format_timestamp(seconds: float, + always_include_hours: bool=False, + decimal_marker: str='.'): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + + +def write_txt(transcript: Iterator[dict], file: TextIO): + for segment in transcript: + print(segment['text'].strip(), file=file, flush=True) + + +def write_vtt(transcript: Iterator[dict], file: TextIO): + print("WEBVTT\n", file=file) + for segment in transcript: + print( + f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", + file=file, + flush=True, ) + + +def write_srt(transcript: Iterator[dict], file: TextIO): + """ + Write a transcript to a file in SRT format. + + Example usage: + from pathlib import Path + from whisper.utils import write_srt + + result = transcribe(model, audio_path, temperature=temperature, **args) + + # save SRT + audio_basename = Path(audio_path).stem + with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: + write_srt(result["segments"], file=srt) + """ + for i, segment in enumerate(transcript, start=1): + # write srt lines + print( + f"{i}\n" + f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " + f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", + file=file, + flush=True, ) diff --git a/paddlespeech/s2t/models/whisper/whipser.py b/paddlespeech/s2t/models/whisper/whipser.py new file mode 100644 index 00000000000..7d696123cb0 --- /dev/null +++ b/paddlespeech/s2t/models/whisper/whipser.py @@ -0,0 +1,1463 @@ +# MIT License, Copyright (c) 2022 OpenAI. +# Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved. +# +# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper) +import os +from dataclasses import dataclass +from dataclasses import field +from functools import lru_cache +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.nn.functional as F +import soundfile +import tqdm +from paddle import nn +from paddle.distribution import Categorical + +import paddlespeech.s2t.modules.align as paddlespeech_nn +from paddlespeech.s2t.models.whisper import utils +from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer +from paddlespeech.s2t.models.whisper.tokenizer import LANGUAGES +from paddlespeech.s2t.models.whisper.tokenizer import Tokenizer +from paddlespeech.s2t.utils.log import Log +logger = Log(__name__).getlog() + +_MODELS = ["large"] +SAMPLE_RATE = 16000 +N_FFT = 400 +N_MELS = 80 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk +N_FRAMES = utils.exact_div( + N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input + + +@dataclass +class ModelDimensions: + n_mels: int + n_audio_ctx: int + n_audio_state: int + n_audio_head: int + n_audio_layer: int + n_vocab: int + n_text_ctx: int + n_text_state: int + n_text_head: int + n_text_layer: int + + +class LayerNorm(paddlespeech_nn.LayerNorm): + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + return super().forward(x) + + +class Linear(paddlespeech_nn.Linear): + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + return F.linear(x, self.weight, None + if self.bias is None else self.bias) + + +class Conv1d(paddlespeech_nn.Conv1D): + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + return super().forward(x) + + +class MultiHeadAttention(nn.Layer): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state, bias_attr=True) + self.key = Linear(n_state, n_state, bias_attr=False) + self.value = Linear(n_state, n_state, bias_attr=True) + self.out = Linear(n_state, n_state, bias_attr=True) + + def forward( + self, + x: paddle.Tensor, + xa: Optional[paddle.Tensor]=None, + mask: Optional[paddle.Tensor]=None, + kv_cache: Optional[dict]=None, ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. + k = kv_cache[self.key] + v = kv_cache[self.value] + + wv = self.qkv_attention(q, k, v, mask) + return self.out(wv) + + def qkv_attention(self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + mask: Optional[paddle.Tensor]=None): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head)**-0.25 + q = paddle.transpose( + q.view(*q.shape[:2], self.n_head, -1), (0, 2, 1, 3)) * scale + k = paddle.transpose( + k.view(*k.shape[:2], self.n_head, -1), (0, 2, 3, 1)) * scale + v = paddle.transpose( + v.view(*v.shape[:2], self.n_head, -1), (0, 2, 1, 3)) + + qk = q @ k + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + + w = F.softmax(qk.float(), axis=-1).to(q.dtype) + return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2) + + +class ResidualAttentionBlock(nn.Layer): + def __init__(self, n_state: int, n_head: int, cross_attention: bool=False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = LayerNorm(n_state) + + self.cross_attn = MultiHeadAttention( + n_state, n_head) if cross_attention else None + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential( + Linear(n_state, n_mlp, bias_attr=True), + nn.GELU(), Linear(n_mlp, n_state, bias_attr=True)) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: paddle.Tensor, + xa: Optional[paddle.Tensor]=None, + mask: Optional[paddle.Tensor]=None, + kv_cache: Optional[dict]=None, ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache) + if self.cross_attn: + x = x + self.cross_attn( + self.cross_attn_ln(x), xa, kv_cache=kv_cache) + x = x + self.mlp(self.mlp_ln(x)) + return x + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = paddle.exp(-log_timescale_increment * paddle.arange( + channels // 2, dtype=paddle.float32)) + scaled_time = paddle.arange( + length, + dtype=paddle.float32)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return paddle.to_tensor( + paddle.concat( + [paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)) + + +class AudioEncoder(nn.Layer): + def __init__(self, + n_mels: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int): + super().__init__() + self.conv1 = Conv1d( + n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True) + self.conv2 = Conv1d( + n_state, + n_state, + kernel_size=3, + stride=2, + padding=1, + bias_attr=True) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.LayerList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]) + self.ln_post = LayerNorm(n_state) + + def forward(self, x: paddle.Tensor): + """ + x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = paddle.transpose(x, (0, 2, 1)) + + assert x.shape[ + 1:] == self.positional_embedding.shape, "incorrect audio shape" + x = (x + self.positional_embedding) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + +class TextDecoder(nn.Layer): + def __init__(self, + n_vocab: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int): + super().__init__() + + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = paddle.create_parameter( + shape=[n_ctx, n_state], dtype='float32') + + self.blocks: Iterable[ResidualAttentionBlock] = nn.LayerList([ + ResidualAttentionBlock(n_state, n_head, cross_attention=True) + for _ in range(n_layer) + ]) + self.ln = LayerNorm(n_state) + + mask = fluid.layers.fill_constant( + shape=[n_ctx, n_state], value=-np.inf, dtype='float32') + mask = paddle.triu(mask, diagonal=1) + self.register_buffer("mask", mask, persistable=False) + + def forward(self, + x: paddle.Tensor, + xa: paddle.Tensor, + kv_cache: Optional[dict]=None): + """ + x : paddle.LongTensor, shape = (batch_size, <= n_ctx) + the text tokens + xa : paddle.Tensor, shape = (batch_size, n_mels, n_audio_ctx) + the encoded audio features to be attended on + """ + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + x = self.token_embedding(x) + self.positional_embedding[offset:offset + + x.shape[-1]] + x = x.to(xa.dtype) + + for block in self.blocks: + x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + + x = self.ln(x) + logits = (x @ paddle.transpose(self.token_embedding.weight, (1, 0))) + + return logits + + +@dataclass(frozen=True) +class DecodingOptions: + task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate" + language: Optional[ + str] = None # language that the audio is in; uses detected language if None + + # sampling-related options + temperature: float = 0.0 + sample_len: Optional[int] = None # maximum number of tokens to sample + best_of: Optional[ + int] = None # number of independent samples to collect, when t > 0 + beam_size: Optional[ + int] = None # number of beams in beam search, when t == 0 + patience: Optional[ + float] = None # patience in beam search (https://arxiv.org/abs/2204.05424) + + # options for ranking generations (either beams or best-of-N samples) + length_penalty: Optional[ + float] = None # "alpha" in Google NMT, None defaults to length norm + + # prompt, prefix, and token suppression + prompt: Optional[Union[str, List[ + int]]] = None # text or tokens for the previous context + prefix: Optional[Union[str, List[ + int]]] = None # text or tokens to prefix the current context + suppress_blank: bool = True # this will suppress blank outputs + + # list of tokens ids (or comma-separated token ids) to suppress + # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` + suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1" + + # timestamp sampling options + without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only + max_initial_timestamp: Optional[ + float] = 1.0 # the initial timestamp cannot be later than this + + # implementation details + fp16: bool = False # use fp16 for most of the calculation + + +@dataclass(frozen=True) +class DecodingResult: + audio_features: paddle.Tensor + language: str + language_probs: Optional[Dict[str, float]] = None + tokens: List[int] = field(default_factory=list) + text: str = "" + avg_logprob: float = np.nan + no_speech_prob: float = np.nan + temperature: float = np.nan + compression_ratio: float = np.nan + + +class Inference: + def logits(self, tokens: paddle.Tensor, + audio_features: paddle.Tensor) -> paddle.Tensor: + """Perform a forward pass on the decoder and return per-token logits""" + raise NotImplementedError + + def rearrange_kv_cache(self, source_indices) -> None: + """Update the key-value cache according to the updated beams""" + raise NotImplementedError + + def cleanup_caching(self) -> None: + """Clean up any resources or hooks after decoding is finished""" + pass + + +class WhisperInference(Inference): + def __init__(self, model: "Whisper", initial_token_length: int): + self.model: "Whisper" = model + self.initial_token_length = initial_token_length + self.kv_cache = {} + self.hooks = [] + + def logits(self, tokens: paddle.Tensor, + audio_features: paddle.Tensor) -> paddle.Tensor: + if not self.kv_cache: + self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() + + if tokens.shape[-1] > self.initial_token_length: + # only need to use the last token except in the first forward pass + tokens = tokens[:, -1:] + + return self.model.decoder( + tokens, audio_features, kv_cache=self.kv_cache) + + def cleanup_caching(self): + for hook in self.hooks: + hook.remove() + + self.kv_cache = {} + self.hooks = [] + + def rearrange_kv_cache(self, source_indices): + for module, tensor in self.kv_cache.items(): + # update the key/value cache to contain the selected sequences + self.kv_cache[module] = tensor[source_indices].detach() + + +@paddle.no_grad() +def detect_language(model: "Whisper", + mel: paddle.Tensor, + tokenizer: Tokenizer=None + ) -> Tuple[paddle.Tensor, List[dict]]: + """ + Detect the spoken language in the audio, and return them as list of strings, along with the ids + of the most probable language tokens and the probability distribution over all language tokens. + This is performed outside the main decode loop in order to not interfere with kv-caching. + + Returns + ------- + language_tokens : Tensor, shape = (batch_size,) + ids of the most probable language tokens, which appears after the startoftranscript token. + language_probs : List[Dict[str, float]], length = batch_size + list of dictionaries containing the probability distribution over all languages. + """ + if tokenizer is None: + tokenizer = get_tokenizer(model.is_multilingual) + if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: + raise ValueError( + "This model doesn't have language tokens so it can't perform lang id" + ) + + single = mel.ndim == 2 + if single: + mel = mel.unsqueeze(0) + + # skip encoder forward pass if already-encoded audio features were given + if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): + mel = model.encoder(mel) + + # forward pass using a single token, startoftranscript + batch_size = mel.shape[0] + x = paddle.to_tensor([[tokenizer.sot]] * batch_size) # [batch_size, 1] + logits = model.logits(x, mel)[:, 0] + + # collect detected languages; suppress all non-language tokens + mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool) + mask[list(tokenizer.all_language_tokens)] = False + logits[:, mask] = -np.inf + language_tokens = paddle.argmax(logits, axis=-1) + language_token_probs = F.softmax(logits, axis=-1) + language_probs = [{ + c: language_token_probs[i, j].tolist() + for j, c in zip(tokenizer.all_language_tokens, + tokenizer.all_language_codes) + } for i in range(batch_size)] + + if single: + language_tokens = language_tokens[0] + language_probs = language_probs[0] + + return language_tokens, language_probs + + +def transcribe( + model: "Whisper", + mel: paddle.Tensor, + *, + verbose: Optional[bool]=None, + temperature: Union[float, Tuple[float, ...]]=(0.0, 0.2, 0.4, 0.6, 0.8, + 1.0), + compression_ratio_threshold: Optional[float]=2.4, + logprob_threshold: Optional[float]=-1.0, + no_speech_threshold: Optional[float]=0.6, + condition_on_previous_text: bool=True, + **decode_options, ): + """ + Transcribe an audio file using Whisper + + Parameters + ---------- + model: Whisper + The Whisper model instance + + mel: paddle.Tensor + The audio feature + + verbose: bool + Whether to display the text being decoded to the console. If True, displays all the details, + If False, displays minimal details. If None, does not display anything + + temperature: Union[float, Tuple[float, ...]] + Temperature for sampling. It can be a tuple of temperatures, which will be successfully used + upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. + + compression_ratio_threshold: float + If the gzip compression ratio is above this value, treat as failed + + logprob_threshold: float + If the average log probability over sampled tokens is below this value, treat as failed + + no_speech_threshold: float + If the no_speech probability is higher than this value AND the average log probability + over sampled tokens is below `logprob_threshold`, consider the segment as silent + + condition_on_previous_text: bool + if True, the previous output of the model is provided as a prompt for the next window; + disabling may make the text inconsistent across windows, but the model becomes less prone to + getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. + + decode_options: dict + Keyword arguments to construct `DecodingOptions` instances + + Returns + ------- + A dictionary containing the resulting text ("text") and segment-level details ("segments"), and + the spoken language ("language"), which is detected when `decode_options["language"]` is None. + """ + dtype = np.float32 #paddle only support float32 + + if dtype == np.float32: + decode_options["fp16"] = False + + if decode_options.get("language", None) is None: + if not model.is_multilingual: + decode_options["language"] = "en" + else: + if verbose: + print( + "Detecting language using up to the first 30 seconds. Use `--language` to specify the language" + ) + segment = pad_or_trim(mel, N_FRAMES) + _, probs = model.detect_language(segment) + decode_options["language"] = max(probs, key=probs.get) + if verbose is not None: + print( + f"Detected language: {LANGUAGES[decode_options['language']].title()}" + ) + + language = decode_options["language"] + task = decode_options.get("task", "transcribe") + tokenizer = get_tokenizer( + model.is_multilingual, language=language, task=task) + + def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult: + temperatures = [temperature] if isinstance(temperature, ( + int, float)) else temperature + decode_result = None + + for t in temperatures: + kwargs = {**decode_options} + if t > 0: + # disable beam_size and patience when t > 0 + kwargs.pop("beam_size", None) + kwargs.pop("patience", None) + else: + # disable best_of when t == 0 + kwargs.pop("best_of", None) + + options = DecodingOptions(**kwargs, temperature=t) + decode_result = model.decode(segment, options) + + needs_fallback = False + if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold: + needs_fallback = True # too repetitive + if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold: + needs_fallback = True # average log probability is too low + + if not needs_fallback: + break + + return decode_result + + seek = 0 + input_stride = utils.exact_div( + N_FRAMES, model.dims.n_audio_ctx) # mel frames per output token: 2 + time_precision = (input_stride * HOP_LENGTH / + SAMPLE_RATE) # time per output token: 0.02 (seconds) + all_tokens = [] + all_segments = [] + prompt_reset_since = 0 + + initial_prompt = decode_options.pop("initial_prompt", None) or [] + if initial_prompt: + initial_prompt = tokenizer.encode(" " + + initial_prompt.strip()).input_ids + all_tokens.extend(initial_prompt) + + def add_segment(*, + start: float, + end: float, + text_tokens: paddle.Tensor, + result: DecodingResult): + text = tokenizer.decode( + [token for token in text_tokens if token < tokenizer.eot]) + if len(text.strip()) == 0: # skip empty text output + return + + all_segments.append({ + "id": len(all_segments), + "seek": seek, + "start": start, + "end": end, + "text": text, + "tokens": result.tokens, + "temperature": result.temperature, + "avg_logprob": result.avg_logprob, + "compression_ratio": result.compression_ratio, + "no_speech_prob": result.no_speech_prob, + }) + if verbose: + print( + f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}" + ) + + # show the progress bar when verbose is False (otherwise the transcribed text will be printed) + num_frames = mel.shape[-1] + previous_seek_value = seek + + with tqdm.tqdm( + total=num_frames, unit='frames', + disable=verbose is not False) as pbar: + while seek < num_frames: + timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) + segment = pad_or_trim(mel[:, seek:], N_FRAMES) + segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE + + decode_options["prompt"] = all_tokens[prompt_reset_since:] + result: DecodingResult = decode_with_fallback(segment) + tokens = paddle.to_tensor(result.tokens) + + if no_speech_threshold is not None: + # no voice activity check + should_skip = result.no_speech_prob > no_speech_threshold + if logprob_threshold is not None and result.avg_logprob > logprob_threshold: + # don't skip if the logprob is high enough, despite the no_speech_prob + should_skip = False + + if should_skip: + seek += segment.shape[ + -1] # fast-forward to the next segment boundary + continue + + timestamp_tokens: paddle.Tensor = tokens.greater_equal( + paddle.to_tensor(tokenizer.timestamp_begin)) + + consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[ + 1:])[0] + if len( + consecutive + ) > 0: # if the output contains two consecutive timestamp tokens + consecutive = paddle.add(consecutive, paddle.to_tensor(1)) + last_slice = 0 + for current_slice in consecutive: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_position = ( + sliced_tokens[0].item() - tokenizer.timestamp_begin) + end_timestamp_position = ( + sliced_tokens[-1].item() - tokenizer.timestamp_begin) + add_segment( + start=timestamp_offset + start_timestamp_position * + time_precision, + end=timestamp_offset + end_timestamp_position * + time_precision, + text_tokens=sliced_tokens[1:-1], + result=result, ) + last_slice = current_slice + last_timestamp_position = ( + tokens[last_slice - 1].item() - tokenizer.timestamp_begin) + seek += last_timestamp_position * input_stride + all_tokens.extend(tokens[:last_slice + 1].tolist()) + else: + duration = segment_duration + timestamps = tokens[timestamp_tokens.nonzero().flatten()] + if len(timestamps) > 0 and timestamps[ + -1].item() != tokenizer.timestamp_begin: + # no consecutive timestamps but it has a timestamp; use the last one. + # single timestamp at the end means no speech after the last timestamp. + last_timestamp_position = timestamps[ + -1].item() - tokenizer.timestamp_begin + duration = last_timestamp_position * time_precision + + add_segment( + start=timestamp_offset, + end=timestamp_offset + duration, + text_tokens=tokens, + result=result, ) + + seek += segment.shape[-1] + all_tokens.extend(tokens.tolist()) + + if not condition_on_previous_text or result.temperature > 0.5: + # do not feed the prompt tokens if a high temperature was used + prompt_reset_since = len(all_tokens) + + # update progress bar + pbar.update(min(num_frames, seek) - previous_seek_value) + previous_seek_value = seek + + return dict( + text=tokenizer.decode(all_tokens[len(initial_prompt):]), + segments=all_segments, + language=language) + + +class SequenceRanker: + def rank(self, + tokens: List[List[paddle.Tensor]], + sum_logprobs: List[List[float]]) -> List[int]: + """ + Given a list of groups of samples and their cumulative log probabilities, + return the indices of the samples in each group to select as the final result + """ + raise NotImplementedError + + +class MaximumLikelihoodRanker(SequenceRanker): + """ + Select the sample with the highest log probabilities, penalized using either + a simple length normalization or Google NMT paper's length penalty + """ + + def __init__(self, length_penalty: Optional[float]): + self.length_penalty = length_penalty + + def rank(self, + tokens: List[List[paddle.Tensor]], + sum_logprobs: List[List[float]]): + def scores(logprobs, lengths): + result = [] + for logprob, length in zip(logprobs, lengths): + if self.length_penalty is None: + penalty = length + else: + # from the Google NMT paper + penalty = ((5 + length) / 6)**self.length_penalty + result.append(logprob / penalty) + return result + + # get the sequence with the highest score + lengths = [[len(t) for t in s] for s in tokens] + return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)] + + +class TokenDecoder: + def reset(self): + """Initialize any stateful variables for decoding a new sequence""" + + def update(self, + tokens: paddle.Tensor, + logits: paddle.Tensor, + sum_logprobs: paddle.Tensor) -> Tuple[paddle.Tensor, bool]: + """Specify how to select the next token, based on the current trace and logits + + Parameters + ---------- + tokens : Tensor, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + + logits : Tensor, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + sum_logprobs : Tensor, shape = (n_batch) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Tensor, shape = (n_batch, current_sequence_length + 1) + the tokens, appended with the selected next token + + completed : bool + True if all sequences has reached the end of text + + """ + raise NotImplementedError + + def finalize( + self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor + ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]: + """Finalize search and return the final candidate sequences + + Parameters + ---------- + tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence + + sum_logprobs : Tensor, shape = (batch_size, beam_size) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Sequence[Sequence[Tensor]], length = batch_size + sequence of Tensors containing candidate token sequences, for each audio input + + sum_logprobs : List[List[float]], length = batch_size + sequence of cumulative log probabilities corresponding to the above + + """ + raise NotImplementedError + + +class GreedyDecoder(TokenDecoder): + def __init__(self, temperature: float, eot: int): + self.temperature = temperature + self.eot = eot + + def update(self, + tokens: paddle.Tensor, + logits: paddle.Tensor, + sum_logprobs: paddle.Tensor) -> Tuple[paddle.Tensor, bool]: + temperature = self.temperature + if temperature == 0: + next_tokens = paddle.argmax(logits, axis=-1) + else: + next_tokens = Categorical(logits=logits / temperature).sample( + shape=logits.shape) + + logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32) + current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), + next_tokens] + sum_logprobs += current_logprobs * paddle.to_tensor( + (tokens[:, -1] != self.eot), dtype=paddle.float32) + + next_tokens[tokens[:, -1] == self.eot] = self.eot + tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1) + + completed = paddle.all((tokens[:, -1] == self.eot)) + return tokens, completed + + def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor): + # make sure each sequence has at least one EOT token at the end + tokens = F.pad(tokens, (0, 1), value=self.eot, data_format="NCL") + return tokens, sum_logprobs.tolist() + + +class BeamSearchDecoder(TokenDecoder): + def __init__(self, + beam_size: int, + eot: int, + inference: Inference, + patience: Optional[float]=None): + self.beam_size = beam_size + self.eot = eot + self.inference = inference + self.patience = patience or 1.0 + self.max_candidates: int = round(beam_size * self.patience) + self.finished_sequences = None + + assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})" + + def reset(self): + self.finished_sequences = None + + def update(self, + tokens: paddle.Tensor, + logits: paddle.Tensor, + sum_logprobs: paddle.Tensor) -> Tuple[paddle.Tensor, bool]: + if tokens.shape[0] % self.beam_size != 0: + raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") + + batch_size = tokens.shape[0] // self.beam_size + if self.finished_sequences is None: # for the first update + self.finished_sequences = [{} for _ in range(batch_size)] + + logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32) + next_tokens, source_indices, finished_sequences = [], [], [] + for i in range(batch_size): + scores, sources, finished = {}, {}, {} + + # STEP 1: calculate the cumulative log probabilities for possible candidates + for j in range(self.beam_size): + idx = i * self.beam_size + j + prefix = tokens[idx].tolist() + logprob, token = paddle.topk( + logprobs[idx], k=self.beam_size + 1) + for logprob, token in zip(logprob, token): + new_logprob = (sum_logprobs[idx] + logprob).tolist()[0] + sequence = tuple(prefix + [token.tolist()[0]]) + scores[sequence] = new_logprob + sources[sequence] = idx + + # STEP 2: rank the candidates and keep the top beam_size sequences for each audio + saved = 0 + for sequence in sorted(scores, key=scores.get, reverse=True): + if sequence[-1] == self.eot: + finished[sequence] = scores[sequence] + else: + sum_logprobs[len(next_tokens)] = scores[sequence] + next_tokens.append(sequence) + source_indices.append(sources[sequence]) + + saved += 1 + if saved == self.beam_size: + break + + finished_sequences.append(finished) + + tokens = paddle.to_tensor(next_tokens) + self.inference.rearrange_kv_cache(source_indices) + + # add newly finished sequences to self.finished_sequences + assert len(self.finished_sequences) == len(finished_sequences) + for previously_finished, newly_finished in zip(self.finished_sequences, + finished_sequences): + for seq in sorted( + newly_finished, key=newly_finished.get, reverse=True): + if len(previously_finished) >= self.max_candidates: + break # the candidate list is full + previously_finished[seq] = newly_finished[seq] + + # mark as completed if all audio has enough number of samples + completed = all( + len(sequences) >= self.max_candidates + for sequences in self.finished_sequences) + return tokens, completed + + def finalize(self, + preceding_tokens: paddle.Tensor, + sum_logprobs: paddle.Tensor): + # collect all finished sequences, including patience, and add unfinished ones if not enough + sum_logprobs = sum_logprobs.cpu() + for i, sequences in enumerate(self.finished_sequences): + if len(sequences + ) < self.beam_size: # when not enough sequences are finished + for j in list(np.argsort(sum_logprobs[i]))[::-1]: + sequence = preceding_tokens[i, j].tolist() + [self.eot] + sequences[tuple(sequence)] = sum_logprobs[i][j].item() + if len(sequences) >= self.beam_size: + break + + tokens: List[List[paddle.Tensor]] = [ + [paddle.to_tensor(seq) for seq in sequences.keys()] + for sequences in self.finished_sequences + ] + sum_logprobs: List[List[float]] = [ + list(sequences.values()) for sequences in self.finished_sequences + ] + return tokens, sum_logprobs + + +class LogitFilter: + def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None: + """Apply any filtering or masking to logits in-place + + Parameters + ---------- + logits : Tensor, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + tokens : Tensor, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + + """ + raise NotImplementedError + + +class SuppressBlank(LogitFilter): + def __init__(self, tokenizer: Tokenizer, sample_begin: int): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + + def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor): + if tokens.shape[1] == self.sample_begin: + logits[:, self.tokenizer.encode(" ").input_ids + + [self.tokenizer.eot]] = -np.inf + + +class SuppressTokens(LogitFilter): + def __init__(self, suppress_tokens: Sequence[int]): + self.suppress_tokens = list(suppress_tokens) + + def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor): + logits[:, self.suppress_tokens] = -np.inf + + +class ApplyTimestampRules(LogitFilter): + def __init__(self, + tokenizer: Tokenizer, + sample_begin: int, + max_initial_timestamp_index: Optional[int]): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + self.max_initial_timestamp_index = max_initial_timestamp_index + + def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor): + # suppress <|notimestamps|> which is handled by without_timestamps + if self.tokenizer.no_timestamps is not None: + logits[:, self.tokenizer.no_timestamps] = -np.inf + + # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + for k in range(tokens.shape[0]): + seq = [t for t in tokens[k, self.sample_begin:].tolist()] + last_was_timestamp = len(seq) >= 1 and seq[ + -1] >= self.tokenizer.timestamp_begin + penultimate_was_timestamp = len(seq) < 2 or seq[ + -2] >= self.tokenizer.timestamp_begin + + if last_was_timestamp: + if penultimate_was_timestamp: # has to be non-timestamp + logits[k, self.tokenizer.timestamp_begin:] = -np.inf + else: # cannot be normal text tokens + logits[k, :self.tokenizer.eot] = -np.inf + + # apply the `max_initial_timestamp` option + if tokens.shape[ + 1] == self.sample_begin and self.max_initial_timestamp_index is not None: + last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index + logits[:, last_allowed + 1:] = -np.inf + + # if sum of probability over timestamps is above any other token, sample timestamp + logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32) + for k in range(tokens.shape[0]): + timestamp_logprob = paddle.logsumexp( + logprobs[k, self.tokenizer.timestamp_begin:], axis=-1) + max_text_token_logprob = paddle.max( + logprobs[k, :self.tokenizer.timestamp_begin]) + if timestamp_logprob > max_text_token_logprob: + logits[k, :self.tokenizer.timestamp_begin] = -np.inf + + +class DecodingTask: + inference: Inference + sequence_ranker: SequenceRanker + decoder: TokenDecoder + logit_filters: List[LogitFilter] + + def __init__(self, model: "Whisper", options: DecodingOptions): + self.model = model + + language = options.language or "en" + tokenizer = get_tokenizer( + model.is_multilingual, language=language, task=options.task) + self.tokenizer: Tokenizer = tokenizer + self.options: DecodingOptions = self._verify_options(options) + + self.beam_size: int = options.beam_size or options.best_of or 1 + self.n_ctx: int = model.dims.n_text_ctx + self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2 + + self.sot_sequence: Tuple[int] = tokenizer.sot_sequence + if self.options.without_timestamps: + self.sot_sequence = tokenizer.sot_sequence_including_notimestamps + + self.initial_tokens: Tuple[int] = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.sot_index: int = self.initial_tokens.index(tokenizer.sot) + + # inference: implements the forward pass through the decoder, including kv caching + self.inference = WhisperInference(model, len(self.initial_tokens)) + + # sequence ranker: implements how to rank a group of sampled sequences + self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) + + # decoder: implements how to select the next tokens, given the autoregressive distribution + if options.beam_size is not None: + self.decoder = BeamSearchDecoder(options.beam_size, tokenizer.eot, + self.inference, options.patience) + else: + self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) + + # logit filters: applies various rules to suppress or penalize certain tokens + self.logit_filters = [] + if self.options.suppress_blank: + self.logit_filters.append( + SuppressBlank(self.tokenizer, self.sample_begin)) + if self.options.suppress_tokens: + self.logit_filters.append( + SuppressTokens(self._get_suppress_tokens())) + if not options.without_timestamps: + precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds + max_initial_timestamp_index = None + if options.max_initial_timestamp: + max_initial_timestamp_index = round( + self.options.max_initial_timestamp / precision) + self.logit_filters.append( + ApplyTimestampRules(tokenizer, self.sample_begin, + max_initial_timestamp_index)) + + def _verify_options(self, options: DecodingOptions) -> DecodingOptions: + if options.beam_size is not None and options.best_of is not None: + raise ValueError("beam_size and best_of can't be given together") + if options.temperature == 0: + if options.best_of is not None: + raise ValueError( + "best_of with greedy sampling (T=0) is not compatible") + if options.patience is not None and options.beam_size is None: + raise ValueError("patience requires beam_size to be given") + if options.length_penalty is not None and not ( + 0 <= options.length_penalty <= 1): + raise ValueError( + "length_penalty (alpha) should be a value between 0 and 1") + + return options + + def _get_initial_tokens(self) -> Tuple[int]: + tokens = list(self.sot_sequence) + prefix = self.options.prefix + prompt = self.options.prompt + + if prefix: + prefix_tokens = ( + self.tokenizer.encode(" " + prefix.strip().input_ids) + if isinstance(prefix, str) else prefix) + if self.sample_len is not None: + max_prefix_len = self.n_ctx // 2 - self.sample_len + prefix_tokens = prefix_tokens[-max_prefix_len:] + tokens = tokens + prefix_tokens + + if prompt: + prompt_tokens = ( + self.tokenizer.encode(" " + prompt.strip().input_ids) + if isinstance(prompt, str) else prompt) + tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 + - 1):] + tokens + + return tuple(tokens) + + def _get_suppress_tokens(self) -> Tuple[int]: + suppress_tokens = self.options.suppress_tokens + + if isinstance(suppress_tokens, str): + suppress_tokens = [int(t) for t in suppress_tokens.split(",")] + + if -1 in suppress_tokens: + suppress_tokens = [t for t in suppress_tokens if t >= 0] + suppress_tokens.extend(self.tokenizer.non_speech_tokens) + elif suppress_tokens is None or len(suppress_tokens) == 0: + suppress_tokens = [] # interpret empty string as an empty list + else: + assert isinstance(suppress_tokens, + list), "suppress_tokens must be a list" + + suppress_tokens.extend([ + self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm + ]) + if self.tokenizer.no_speech is not None: + # no-speech probability is collected separately + suppress_tokens.append(self.tokenizer.no_speech) + + return tuple(sorted(set(suppress_tokens))) + + def _get_audio_features(self, mel: paddle.Tensor): + #if self.options.fp16: + # mel = mel.half() + + if mel.shape[-2:] == (self.model.dims.n_audio_ctx, + self.model.dims.n_audio_state): + # encoded audio features are given; skip audio encoding + audio_features = mel + else: + audio_features = self.model.encoder(mel) + + #if audio_features.dtype != (np.float16 if self.options.fp16 else np.float32): + # return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}") + + return audio_features + + def _detect_language(self, + audio_features: paddle.Tensor, + tokens: paddle.Tensor): + languages = [self.options.language] * audio_features.shape[0] + lang_probs = None + + if self.options.language is None or self.options.task == "lang_id": + lang_tokens, lang_probs = self.model.detect_language(audio_features, + self.tokenizer) + languages = [max(probs, key=probs.get) for probs in lang_probs] + if self.options.language is None: + tokens[:, self.sot_index + + 1] = lang_tokens # write language tokens + + return languages, lang_probs + + def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor): + assert audio_features.shape[0] == tokens.shape[0] + n_batch = tokens.shape[0] + sum_logprobs: paddle.Tensor = paddle.zeros( + paddle.to_tensor(n_batch), dtype=paddle.float32) + no_speech_probs = [np.nan] * n_batch + + try: + for i in range(self.sample_len): + logits = self.inference.logits(tokens, audio_features) + + if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs + probs_at_sot = F.softmax( + logits[:, self.sot_index], + axis=-1, + dtype=paddle.float32) + no_speech_probs = probs_at_sot[:, self.tokenizer. + no_speech].tolist() + + # now we need to consider the logits at the last token only + logits = logits[:, -1] + + # apply the logit filters, e.g. for suppressing or applying penalty to + for logit_filter in self.logit_filters: + logit_filter.apply(logits, tokens) + + # expand the tokens tensor with the selected next tokens + tokens, completed = self.decoder.update(tokens, logits, + sum_logprobs) + if completed or tokens.shape[-1] > self.n_ctx: + break + finally: + self.inference.cleanup_caching() + + return tokens, sum_logprobs, no_speech_probs + + @paddle.no_grad() + def run(self, mel: paddle.Tensor) -> List[DecodingResult]: + self.decoder.reset() + tokenizer: Tokenizer = self.tokenizer + batch_size: int = mel.shape[0] + + audio_features: paddle.Tensor = self._get_audio_features( + mel) # encoder forward pass + + tokens: paddle.Tensor + if batch_size > 1: + for i in range(batch_size): + tokens = paddle.concat( + x=[ + paddle.to_tensor([self.initial_tokens]), + paddle.to_tensor([self.initial_tokens]) + ], + axis=0) + elif batch_size == 1: + tokens = paddle.to_tensor([self.initial_tokens]) + + # detect language if requested, overwriting the language token + languages, language_probs = self._detect_language( + paddle.to_tensor(audio_features), paddle.to_tensor(tokens)) + + if self.options.task == "lang_id": + return [ + DecodingResult( + audio_features=features, + language=language, + language_probs=probs) + for features, language, probs in zip(audio_features, languages, + language_probs) + ] + + # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling + + audio_features = paddle.repeat_interleave( + audio_features, self.beam_size, axis=0) + tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0) + + # call the main sampling loop + tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, + tokens) + + # reshape the tensors to have (batch_size, beam_size) as the first two dimensions + audio_features = audio_features[::self.beam_size] + no_speech_probs = no_speech_probs[::self.beam_size] + assert audio_features.shape[0] == len(no_speech_probs) == batch_size + + tokens = tokens.reshape([batch_size, self.beam_size, -1]) + sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size]) + + # get the final candidates for each group, and slice between the first sampled token and EOT + tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) + tokens: List[List[paddle.Tensor]] = [[ + t[self.sample_begin:(t == tokenizer.eot).nonzero()[0, 0]] for t in s + ] for s in tokens] + + # select the top-ranked sample in each group + selected = self.sequence_ranker.rank(tokens, sum_logprobs) + tokens: List[List[ + int]] = [t[i].tolist() for i, t in zip(selected, tokens)] + texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] + + sum_logprobs: List[ + float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] + avg_logprobs: List[ + float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] + + fields = (texts, languages, tokens, audio_features, avg_logprobs, + no_speech_probs) + if len(set(map(len, fields))) != 1: + raise RuntimeError( + f"inconsistent result lengths: {list(map(len, fields))}") + + return [ + DecodingResult( + audio_features=features, + language=language, + tokens=tokens, + text=text, + avg_logprob=avg_logprob, + no_speech_prob=no_speech_prob, + temperature=self.options.temperature, + compression_ratio=utils.compression_ratio(text), ) + for text, language, tokens, features, avg_logprob, no_speech_prob in + zip(*fields) + ] + + +@paddle.no_grad() +def decode(model: "Whisper", + mel: paddle.Tensor, + options: DecodingOptions=DecodingOptions() + ) -> Union[DecodingResult, List[DecodingResult]]: + """ + Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). + + Parameters + ---------- + model: Whisper + the Whisper model instance + + mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000) + A tensor containing the Mel spectrogram(s) + + options: DecodingOptions + A dataclass that contains all necessary options for decoding 30-second segments + + Returns + ------- + result: Union[DecodingResult, List[DecodingResult]] + The result(s) of decoding contained in `DecodingResult` dataclass instance(s) + """ + single = mel.ndim == 2 + if single: + mel = mel.unsqueeze(0) + + result = DecodingTask(model, options).run(mel) + + if single: + result = result[0] + + return result + + +class Whisper(nn.Layer): + def __init__(self, dims: ModelDimensions): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, ) + self.decoder = TextDecoder( + self.dims.n_vocab, + self.dims.n_text_ctx, + self.dims.n_text_state, + self.dims.n_text_head, + self.dims.n_text_layer, ) + + def embed_audio(self, mel: paddle.Tensor): + return self.encoder.forward(mel) + + def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor): + return self.decoder.forward(tokens, audio_features) + + def forward(self, mel: paddle.Tensor, + tokens: paddle.Tensor) -> Dict[str, paddle.Tensor]: + return self.decoder(tokens, self.encoder(mel)) + + @property + def device(self): + return paddle.device.get_device() + + @property + def is_multilingual(self): + return self.dims.n_vocab == 51865 + + def install_kv_cache_hooks(self, cache: Optional[dict]=None): + """ + The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value + tensors calculated for the previous positions. This method returns a dictionary that stores + all caches, and the necessary hooks for the key and value projection modules that save the + intermediate tensors to be reused during later calculations. + + Returns + ------- + cache : Dict[nn.Layer, paddle.Tensor] + A dictionary object mapping the key/value projection modules to its cache + hooks : List[RemovableHandle] + List of PyTorch RemovableHandle objects to stop the hooks to be called + """ + cache = {**cache} if cache is not None else {} + hooks = [] + + def save_to_cache(module, _, output): + if module not in cache or output.shape[ + 1] > self.decoder.positional_embedding.shape[0]: + cache[ + module] = output # save as-is, for the first token or cross attention + else: + cache[module] = paddle.concat( + [cache[module], output], axis=1).detach() + return cache[module] + + def install_hooks(layer: nn.Layer): + if isinstance(layer, MultiHeadAttention): + hooks.append( + layer.key.register_forward_post_hook(save_to_cache)) + hooks.append( + layer.value.register_forward_post_hook(save_to_cache)) + + self.decoder.apply(install_hooks) + return cache, hooks + + detect_language = detect_language + transcribe = transcribe + decode = decode + + +def pad_or_trim(array, length: int=N_SAMPLES, *, axis: int=-1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if paddle.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select(axis=axis, index=paddle.arange(length)) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = paddle.transpose(array, (1, 0)) + array = F.pad( + array, [pad for sizes in pad_widths[::-1] for pad in sizes], + data_format='NLC') + array = paddle.transpose(array, (1, 0)) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = paddle.transpose(array, (1, 0)) + array = np.pad(array, pad_widths) + array = paddle.transpose(array, (1, 0)) + + return array + + +def hann_window(n_fft: int=N_FFT): + """ + hanning window + n_fft: The number of frequency components of the discrete Fourier transform. + """ + return paddle.to_tensor( + [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)], + dtype=paddle.float32) + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int=N_MELS) -> paddle.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + with np.load( + os.path.join( + os.path.dirname(__file__), "assets", "mel_filters.npz")) as f: + return paddle.to_tensor(f[f"mel_{n_mels}"]) + + +def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor], + n_mels: int=N_MELS): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, paddle.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + Returns + ------- + paddle.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not paddle.is_tensor(audio): + if isinstance(audio, str): + audio, _ = soundfile.read(audio, dtype="float32", always_2d=True) + audio = audio[:, 0] + logger.info(f"audio shape: {audio.shape}") + audio = paddle.to_tensor(audio) + + window = hann_window(N_FFT) + stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window) + + magnitudes = stft[:, :-1].abs()**2 + + filters = mel_filters(audio, n_mels) + mel_spec = filters @ magnitudes + mel_spec = paddle.to_tensor(mel_spec.numpy().tolist()) + + log_spec = paddle.clip(mel_spec, min=1e-10).log10() + log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec diff --git a/paddlespeech/s2t/models/whisper/whisper_LICENSE b/paddlespeech/s2t/models/whisper/whisper_LICENSE new file mode 100644 index 00000000000..49e465e19ee --- /dev/null +++ b/paddlespeech/s2t/models/whisper/whisper_LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 OpenAI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh index d571aa78ff3..644008d5220 100755 --- a/tests/unit/cli/test_cli.sh +++ b/tests/unit/cli/test_cli.sh @@ -93,5 +93,11 @@ paddlespeech stats --task text paddlespeech stats --task vector paddlespeech stats --task st +# whisper text recognize +paddlespeech whisper --task transcribe --input ./zh.wav + +# whisper recognize text and translate to English +paddlespeech whisper --task translate --input ./zh.wav + echo -e "\033[32mTest success !!!\033[0m"