Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Android: Allow re-downloading voice typing models on URL change and error #11557

Draft
wants to merge 5 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions packages/app-mobile/components/voiceTyping/VoiceTypingDialog.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import { AppState } from '../../utils/types';
import { connect } from 'react-redux';
import { View, StyleSheet } from 'react-native';
import AccessibleView from '../accessibility/AccessibleView';
import Logger from '@joplin/utils/Logger';

const logger = Logger.create('VoiceTypingDialog');

interface Props {
locale: string;
Expand All @@ -34,10 +37,11 @@ interface UseVoiceTypingProps {
onText: OnTextCallback;
}

const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps): [Error | null, boolean, VoiceTypingSession|null] => {
const useVoiceTyping = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps) => {
const [voiceTyping, setVoiceTyping] = useState<VoiceTypingSession>(null);
const [error, setError] = useState<Error>(null);
const [mustDownloadModel, setMustDownloadModel] = useState<boolean | null>(null);
const [modelIsOutdated, setModelIsOutdated] = useState(false);

const onTextRef = useRef(onText);
onTextRef.current = onText;
Expand All @@ -51,9 +55,20 @@ const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingPr
return new VoiceTyping(locale, provider?.startsWith('whisper') ? [whisper] : [vosk]);
}, [locale, provider]);

const [redownloadCounter, setRedownloadCounter] = useState(0);

useEffect(() => {
if (modelIsOutdated) {
logger.info('The downloaded version of the model is from an outdated URL.');
}
}, [modelIsOutdated]);

useAsyncEffect(async (event: AsyncEffectEvent) => {
try {
await voiceTypingRef.current?.stop();
onSetPreviewRef.current?.('');

setModelIsOutdated(await builder.isDownloadedFromOutdatedUrl());

if (!await builder.isDownloaded()) {
if (event.cancelled) return;
Expand All @@ -72,7 +87,7 @@ const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingPr
} finally {
setMustDownloadModel(false);
}
}, [builder]);
}, [builder, redownloadCounter]);

useAsyncEffect(async (_event: AsyncEffectEvent) => {
setMustDownloadModel(!(await builder.isDownloaded()));
Expand All @@ -82,7 +97,16 @@ const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingPr
void voiceTypingRef.current?.stop();
}, []);

return [error, mustDownloadModel, voiceTyping];
const onRequestRedownload = useCallback(async () => {
await voiceTypingRef.current?.stop();
await builder.clearDownloads();
setMustDownloadModel(true);
setRedownloadCounter(value => value + 1);
}, [builder]);

return {
error, mustDownloadModel, voiceTyping, onRequestRedownload, modelIsOutdated,
};
};

const styles = StyleSheet.create({
Expand Down Expand Up @@ -112,7 +136,13 @@ const styles = StyleSheet.create({
const VoiceTypingDialog: React.FC<Props> = props => {
const [recorderState, setRecorderState] = useState<RecorderState>(RecorderState.Loading);
const [preview, setPreview] = useState<string>('');
const [modelError, mustDownloadModel, voiceTyping] = useWhisper({
const {
error: modelError,
mustDownloadModel,
voiceTyping,
onRequestRedownload,
modelIsOutdated,
} = useVoiceTyping({
locale: props.locale,
onSetPreview: setPreview,
onText: props.onText,
Expand Down Expand Up @@ -172,6 +202,11 @@ const VoiceTypingDialog: React.FC<Props> = props => {
return <Text variant='labelSmall'>{preview}</Text>;
};

const reDownloadButton = <Button onPress={onRequestRedownload}>
{modelIsOutdated ? _('Download updated model') : _('Re-download model')}
</Button>;
const allowReDownload = recorderState === RecorderState.Error || modelIsOutdated;

return (
<Surface>
<View style={styles.container}>
Expand Down Expand Up @@ -203,6 +238,7 @@ const VoiceTypingDialog: React.FC<Props> = props => {
</View>
</View>
<View style={styles.actionContainer}>
{allowReDownload ? reDownloadButton : null}
<Button
onPress={onDismiss}
accessibilityHint={_('Ends voice typing')}
Expand Down
39 changes: 32 additions & 7 deletions packages/app-mobile/services/voiceTyping/VoiceTyping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import shim from '@joplin/lib/shim';
import Logger from '@joplin/utils/Logger';
import { PermissionsAndroid, Platform } from 'react-native';
import unzip from './utils/unzip';
import { _ } from '@joplin/lib/locale';
const md5 = require('md5');

const logger = Logger.create('voiceTyping');
Expand Down Expand Up @@ -30,6 +31,7 @@ export interface VoiceTypingProvider {
modelName: string;
supported(): boolean;
modelLocalFilepath(locale: string): string;
deleteCachedModels(locale: string): Promise<void>;
getDownloadUrl(locale: string): string;
getUuidPath(locale: string): string;
build(options: BuildProviderOptions): Promise<VoiceTypingSession>;
Expand All @@ -39,9 +41,9 @@ export default class VoiceTyping {
private provider: VoiceTypingProvider|null = null;
public constructor(
private locale: string,
providers: VoiceTypingProvider[],
allProviders: VoiceTypingProvider[],
) {
this.provider = providers.find(p => p.supported()) ?? null;
this.provider = allProviders.find(p => p.supported()) ?? null;
}

public supported() {
Expand All @@ -67,10 +69,31 @@ export default class VoiceTyping {
);
}

public async isDownloadedFromOutdatedUrl() {
const uuidPath = this.getUuidPath();
if (!await shim.fsDriver().exists(uuidPath)) {
// Not downloaded at all
return false;
}

const modelUrl = this.provider.getDownloadUrl(this.locale);
const urlHash = await shim.fsDriver().readFile(uuidPath);
return urlHash.trim() !== md5(modelUrl);
}

public async isDownloaded() {
return await shim.fsDriver().exists(this.getUuidPath());
}

public async clearDownloads() {
const confirmed = await shim.showConfirmationDialog(
_('Delete model and re-download?\nThis cannot be undone.'),
);
if (confirmed) {
await this.provider.deleteCachedModels(this.locale);
}
}

public async download() {
const modelPath = this.getModelPath();
const modelUrl = this.provider.getDownloadUrl(this.locale);
Expand Down Expand Up @@ -104,16 +127,18 @@ export default class VoiceTyping {

logger.info(`Moving ${fullUnzipPath} => ${modelPath}`);
await shim.fsDriver().move(fullUnzipPath, modelPath);

await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8');
if (!await this.isDownloaded()) {
logger.warn('Model should be downloaded!');
}
} finally {
await shim.fsDriver().remove(unzipDir);
await shim.fsDriver().remove(downloadPath);
}
}

await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8');
if (!await this.isDownloaded()) {
logger.warn('Model should be downloaded!');
} else {
logger.info('Model stats', await shim.fsDriver().stat(modelPath));
}
}

public async build(callbacks: SpeechToTextCallbacks) {
Expand Down
4 changes: 4 additions & 0 deletions packages/app-mobile/services/voiceTyping/vosk.android.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ export const startRecording = (vosk: Vosk, options: StartOptions): VoiceTypingSe
const vosk: VoiceTypingProvider = {
supported: () => true,
modelLocalFilepath: (locale: string) => getModelDir(locale),
deleteCachedModels: async (locale: string) => {
const path = getModelDir(locale);
await shim.fsDriver().remove(path, { recursive: true });
},
getDownloadUrl: (locale) => languageModelUrl(locale),
getUuidPath: (locale: string) => join(getModelDir(locale), 'uuid'),
build: async ({ callbacks, locale, modelPath }) => {
Expand Down
1 change: 1 addition & 0 deletions packages/app-mobile/services/voiceTyping/vosk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const vosk: VoiceTypingProvider = {
modelLocalFilepath: () => null,
getDownloadUrl: () => null,
getUuidPath: () => null,
deleteCachedModels: () => null,
build: async () => {
throw new Error('Unsupported!');
},
Expand Down
4 changes: 4 additions & 0 deletions packages/app-mobile/services/voiceTyping/whisper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ const whisper: VoiceTypingProvider = {

return urlTemplate.replace(/\{task\}/g, 'whisper_tiny.onnx');
},
deleteCachedModels: async (locale) => {
await shim.fsDriver().remove(modelLocalFilepath());
await shim.fsDriver().remove(whisper.getUuidPath(locale));
},
getUuidPath: () => {
return join(dirname(modelLocalFilepath()), 'uuid');
},
Expand Down
Loading