Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add speaker identification APIs for HarmonyOS #1607

Merged
merged 2 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,5 @@ sherpa-onnx-online-punct-en-2024-08-06
sherpa-onnx-pyannote-segmentation-3-0
sherpa-onnx-moonshine-tiny-en-int8
sherpa-onnx-moonshine-base-en-int8
harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE
harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md
6 changes: 6 additions & 0 deletions harmony-os/SherpaOnnxHar/sherpa_onnx/Index.ets
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,9 @@ export {
TtsOutput,
TtsInput,
} from './src/main/ets/components/NonStreamingTts';

export {
SpeakerEmbeddingExtractorConfig,
SpeakerEmbeddingExtractor,
SpeakerEmbeddingManager,
} from './src/main/ets/components/SpeakerIdentification';
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@
static Napi::External<SherpaOnnxSpeakerEmbeddingExtractor>
CreateSpeakerEmbeddingExtractorWrapper(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();

#if __OHOS__
if (info.Length() != 2) {
std::ostringstream os;
os << "Expect only 2 arguments. Given: " << info.Length();

Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

return {};
}
#else
if (info.Length() != 1) {
std::ostringstream os;
os << "Expect only 1 argument. Given: " << info.Length();
Expand All @@ -19,6 +30,7 @@ CreateSpeakerEmbeddingExtractorWrapper(const Napi::CallbackInfo &info) {

return {};
}
#endif

if (!info[0].IsObject()) {
Napi::TypeError::New(env, "You should pass an object as the only argument.")
Expand Down Expand Up @@ -46,8 +58,18 @@ CreateSpeakerEmbeddingExtractorWrapper(const Napi::CallbackInfo &info) {

SHERPA_ONNX_ASSIGN_ATTR_STR(provider, provider);

#if __OHOS__
std::unique_ptr<NativeResourceManager,
decltype(&OH_ResourceManager_ReleaseNativeResourceManager)>
mgr(OH_ResourceManager_InitNativeResourceManager(env, info[1]),
&OH_ResourceManager_ReleaseNativeResourceManager);

const SherpaOnnxSpeakerEmbeddingExtractor *extractor =
SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS(&c, mgr.get());
#else
const SherpaOnnxSpeakerEmbeddingExtractor *extractor =
SherpaOnnxCreateSpeakerEmbeddingExtractor(&c);
#endif

if (c.model) {
delete[] c.model;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,18 @@ export type TtsOutput = {

export const offlineTtsGenerate: (handle: object, input: object) => TtsOutput;
export const offlineTtsGenerateAsync: (handle: object, input: object) => Promise<TtsOutput>;

// Type declarations for the speaker-identification functions of the native
// module. Every `handle`/`stream` argument is an opaque object returned by
// one of the create* functions declared here.

// Creates a speaker embedding extractor from `config`.
// NOTE(review): `mgr` is presumably the HarmonyOS resource manager used to
// read the model; confirm against the native wrapper.
export const createSpeakerEmbeddingExtractor: (config: object, mgr?: object) => object;
// Returns a number for the given extractor handle (by its name, the
// embedding dimension).
export const speakerEmbeddingExtractorDim: (handle: object) => number;
// Creates a stream object bound to the given extractor handle.
export const speakerEmbeddingExtractorCreateStream: (handle: object) => object;
// Reports whether `stream` is ready for the extractor `handle`.
export const speakerEmbeddingExtractorIsReady: (handle: object, stream: object) => boolean;
// Computes the embedding for `stream` and returns it as a Float32Array.
// NOTE(review): `enableExternalBuffer` presumably controls whether the result
// wraps native memory instead of copying it; confirm in the native code.
export const speakerEmbeddingExtractorComputeEmbedding: (handle: object, stream: object, enableExternalBuffer: boolean) => Float32Array;
// Creates a speaker embedding manager for embeddings of dimension `dim`.
export const createSpeakerEmbeddingManager: (dim: number) => object;
// Adds one embedding `v` under `name`; the boolean result presumably
// signals success.
export const speakerEmbeddingManagerAdd: (handle: object, speaker: {name: string, v: Float32Array}) => boolean;
// Adds `n` embeddings under `name`, passed as one flat array `vv`.
export const speakerEmbeddingManagerAddListFlattened: (handle: object, speaker: {name: string, vv: Float32Array, n: number}) => boolean;
// Removes the speaker called `name`.
export const speakerEmbeddingManagerRemove: (handle: object, name: string) => boolean;
// Searches with embedding `v` and `threshold`; returns a string
// (presumably the matched speaker name; confirm the no-match value).
export const speakerEmbeddingManagerSearch: (handle: object, obj: {v: Float32Array, threshold: number}) => string;
// Checks embedding `v` against the speaker `name` using `threshold`.
export const speakerEmbeddingManagerVerify: (handle: object, obj: {name: string, v: Float32Array, threshold: number}) => boolean;
// Reports whether a speaker called `name` is present.
export const speakerEmbeddingManagerContains: (handle: object, name: string) => boolean;
// Returns the number of speakers held by the manager.
export const speakerEmbeddingManagerNumSpeakers: (handle: object) => number;
// Returns the names of all speakers held by the manager.
export const speakerEmbeddingManagerGetAllSpeakers: (handle: object) => Array<string>;
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import {
getOfflineTtsSampleRate,
offlineTtsGenerate,
offlineTtsGenerateAsync,
} from "libsherpa_onnx.so";
} from 'libsherpa_onnx.so';

export class OfflineTtsVitsModelConfig {
public model: string = '';
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import {
createSpeakerEmbeddingExtractor,
createSpeakerEmbeddingManager,
speakerEmbeddingExtractorComputeEmbedding,
speakerEmbeddingExtractorCreateStream,
speakerEmbeddingExtractorDim,
speakerEmbeddingExtractorIsReady,
speakerEmbeddingManagerAdd,
speakerEmbeddingManagerAddListFlattened,
speakerEmbeddingManagerContains,
speakerEmbeddingManagerGetAllSpeakers,
speakerEmbeddingManagerNumSpeakers,
speakerEmbeddingManagerRemove,
speakerEmbeddingManagerSearch,
speakerEmbeddingManagerVerify
} from 'libsherpa_onnx.so';
import { OnlineStream } from './StreamingAsr';

/**
 * Configuration passed to the native layer when a
 * SpeakerEmbeddingExtractor is created. All fields have defaults, so a
 * caller only needs to set the ones it wants to change.
 */
export class SpeakerEmbeddingExtractorConfig {
// Speaker embedding model to load; defaults to the empty string.
// NOTE(review): presumably a file path or resource name; confirm.
public model: string = '';
// NOTE(review): presumably the thread count used for inference; confirm.
public numThreads: number = 1;
// NOTE(review): the C API appears to log the resolved config when this is
// true; confirm the field is forwarded unchanged.
public debug: boolean = false;
// Execution provider name; defaults to 'cpu'.
public provider: string = 'cpu';
}

/**
 * Wrapper around a native speaker embedding extractor.
 *
 * The constructor creates the native extractor and caches its embedding
 * dimension in `dim`. Streams returned by `createStream()` are the objects
 * to pass to `isReady()` and `compute()`.
 */
export class SpeakerEmbeddingExtractor {
  public config: SpeakerEmbeddingExtractorConfig = new SpeakerEmbeddingExtractorConfig();
  public dim: number;
  private handle: object;

  /**
   * @param config Extractor configuration forwarded to the native layer.
   * @param mgr Optional object forwarded as the second argument of the
   *            native creation function.
   */
  constructor(config: SpeakerEmbeddingExtractorConfig, mgr?: object) {
    const nativeHandle = createSpeakerEmbeddingExtractor(config, mgr);
    this.handle = nativeHandle;
    this.config = config;
    this.dim = speakerEmbeddingExtractorDim(nativeHandle);
  }

  /** Creates a new stream bound to this extractor. */
  createStream(): OnlineStream {
    const streamHandle = speakerEmbeddingExtractorCreateStream(this.handle);
    return new OnlineStream(streamHandle);
  }

  /** Returns the native readiness check for `stream`. */
  isReady(stream: OnlineStream): boolean {
    const ready = speakerEmbeddingExtractorIsReady(this.handle, stream.handle);
    return ready;
  }

  /**
   * Computes the embedding for `stream`.
   *
   * @param stream Stream created by `createStream()`.
   * @param enableExternalBuffer Forwarded unchanged to the native call;
   *                             defaults to true.
   * @returns The embedding as a Float32Array.
   */
  compute(stream: OnlineStream, enableExternalBuffer: boolean = true): Float32Array {
    const embedding = speakerEmbeddingExtractorComputeEmbedding(
      this.handle, stream.handle, enableExternalBuffer);
    return embedding;
  }
}

/**
 * Concatenates a list of Float32Array chunks into one contiguous
 * Float32Array, preserving the order of the chunks.
 *
 * @param arrayList Chunks to concatenate; may be empty.
 * @returns A new Float32Array whose length is the sum of the chunk lengths.
 */
function flatten(arrayList: Float32Array[]): Float32Array {
  // First pass: work out how much space the merged array needs.
  let total = 0;
  for (const chunk of arrayList) {
    total += chunk.length;
  }

  // Second pass: copy each chunk to its position in the merged array.
  const merged = new Float32Array(total);
  let writePos = 0;
  for (const chunk of arrayList) {
    merged.set(chunk, writePos);
    writePos += chunk.length;
  }

  return merged;
}

// Argument of SpeakerEmbeddingManager.add(): one embedding for one speaker.
interface SpeakerNameWithEmbedding {
name: string;
v: Float32Array;
}

// Argument of SpeakerEmbeddingManager.addMulti(): several embeddings for
// the same speaker, one Float32Array per embedding.
interface SpeakerNameWithEmbeddingList {
name: string;
v: Float32Array[];
}

// Flattened form handed to the native layer by addMulti(): `vv` is the
// concatenation of all embeddings and `n` is the number of embeddings.
interface SpeakerNameWithEmbeddingN {
name: string;
vv: Float32Array;
n: number;
}

// Argument of SpeakerEmbeddingManager.search().
// NOTE(review): `threshold` is presumably a similarity cut-off; confirm its
// range against the native implementation.
interface EmbeddingWithThreshold {
v: Float32Array;
threshold: number;
}

// Argument of SpeakerEmbeddingManager.verify(): the speaker name to check
// against, the embedding, and the threshold.
interface SpeakerNameEmbeddingThreshold {
name: string;
v: Float32Array;
threshold: number;
}

/**
 * Wrapper around a native speaker embedding manager, which holds named
 * speaker embeddings of a fixed dimension. Each method forwards directly to
 * the corresponding native function.
 */
export class SpeakerEmbeddingManager {
  public dim: number;
  private handle: object;

  /**
   * @param dim Embedding dimension passed to the native manager.
   */
  constructor(dim: number) {
    const nativeHandle = createSpeakerEmbeddingManager(dim);
    this.handle = nativeHandle;
    this.dim = dim;
  }

  /** Adds a single embedding for `speaker.name`. */
  add(speaker: SpeakerNameWithEmbedding): boolean {
    const added = speakerEmbeddingManagerAdd(this.handle, speaker);
    return added;
  }

  /**
   * Adds several embeddings for `speaker.name` in one call. The embeddings
   * are flattened into one array because the native function takes them
   * as a flat buffer plus a count.
   */
  addMulti(speaker: SpeakerNameWithEmbeddingList): boolean {
    const embeddings = speaker.v;
    const flattened: SpeakerNameWithEmbeddingN = {
      name: speaker.name,
      vv: flatten(embeddings),
      n: embeddings.length,
    };
    const added = speakerEmbeddingManagerAddListFlattened(this.handle, flattened);
    return added;
  }

  /** Removes the speaker called `name`. */
  remove(name: string): boolean {
    const removed = speakerEmbeddingManagerRemove(this.handle, name);
    return removed;
  }

  /** Searches with `obj.v` and `obj.threshold`; returns the native result string. */
  search(obj: EmbeddingWithThreshold): string {
    const match = speakerEmbeddingManagerSearch(this.handle, obj);
    return match;
  }

  /** Checks `obj.v` against the speaker `obj.name` using `obj.threshold`. */
  verify(obj: SpeakerNameEmbeddingThreshold): boolean {
    const verified = speakerEmbeddingManagerVerify(this.handle, obj);
    return verified;
  }

  /** Reports whether a speaker called `name` is present. */
  contains(name: string): boolean {
    const present = speakerEmbeddingManagerContains(this.handle, name);
    return present;
  }

  /** Returns the number of speakers held by the manager. */
  getNumSpeakers(): number {
    const count = speakerEmbeddingManagerNumSpeakers(this.handle);
    return count;
  }

  /** Returns the names of all speakers held by the manager. */
  getAllSpeakerNames(): string[] {
    const names = speakerEmbeddingManagerGetAllSpeakers(this.handle);
    return names;
  }
}
33 changes: 31 additions & 2 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1328,8 +1328,8 @@ struct SherpaOnnxSpeakerEmbeddingExtractor {
std::unique_ptr<sherpa_onnx::SpeakerEmbeddingExtractor> impl;
};

const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractor(
static sherpa_onnx::SpeakerEmbeddingExtractorConfig
GetSpeakerEmbeddingExtractorConfig(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config) {
sherpa_onnx::SpeakerEmbeddingExtractorConfig c;
c.model = SHERPA_ONNX_OR(config->model, "");
Expand All @@ -1342,9 +1342,21 @@ SherpaOnnxCreateSpeakerEmbeddingExtractor(
}

if (config->debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", c.ToString().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", c.ToString().c_str());
#endif
}

return c;
}

const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractor(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config) {
auto c = GetSpeakerEmbeddingExtractorConfig(config);

if (!c.Validate()) {
SHERPA_ONNX_LOGE("Errors in config!");
return nullptr;
Expand Down Expand Up @@ -1983,6 +1995,23 @@ SherpaOnnxVoiceActivityDetector *SherpaOnnxCreateVoiceActivityDetectorOHOS(
return p;
}

// Creates a speaker embedding extractor for HarmonyOS.
//
// @param config Extractor configuration from the C API.
// @param mgr    Native resource manager passed to the extractor. When it is
//               null, creation is delegated to
//               SherpaOnnxCreateSpeakerEmbeddingExtractor().
// @return A newly allocated extractor when mgr is non-null; otherwise
//         whatever SherpaOnnxCreateSpeakerEmbeddingExtractor() returns.
const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS(
    const SherpaOnnxSpeakerEmbeddingExtractorConfig *config,
    NativeResourceManager *mgr) {
  if (mgr == nullptr) {
    // No resource manager: fall back to the regular creation path.
    return SherpaOnnxCreateSpeakerEmbeddingExtractor(config);
  }

  auto extractor_config = GetSpeakerEmbeddingExtractorConfig(config);

  auto *extractor = new SherpaOnnxSpeakerEmbeddingExtractor;
  extractor->impl = std::make_unique<sherpa_onnx::SpeakerEmbeddingExtractor>(
      mgr, extractor_config);

  return extractor;
}

#if SHERPA_ONNX_ENABLE_TTS == 1
SherpaOnnxOfflineTts *SherpaOnnxCreateOfflineTtsOHOS(
const SherpaOnnxOfflineTtsConfig *config, NativeResourceManager *mgr) {
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1572,6 +1572,11 @@ SherpaOnnxCreateVoiceActivityDetectorOHOS(

SHERPA_ONNX_API SherpaOnnxOfflineTts *SherpaOnnxCreateOfflineTtsOHOS(
const SherpaOnnxOfflineTtsConfig *config, NativeResourceManager *mgr);

// Create a speaker embedding extractor on HarmonyOS.
//
// @param config Configuration of the extractor.
// @param mgr    Native resource manager. If it is NULL, this function
//               behaves like SherpaOnnxCreateSpeakerEmbeddingExtractor().
// @return A pointer to the created extractor.
// NOTE(review): presumably freed with the same destroy function used for
// SherpaOnnxCreateSpeakerEmbeddingExtractor(); confirm.
SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config,
NativeResourceManager *mgr);
#endif

#if defined(__GNUC__)
Expand Down
4 changes: 2 additions & 2 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
for (const auto &f : files) {
if (config.model.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
#else
SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str());
#else
SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
#endif
}
std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
Expand Down
16 changes: 13 additions & 3 deletions sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ class SpeakerEmbeddingExtractorGeneralImpl
const SpeakerEmbeddingExtractorConfig &config)
: model_(config) {}

#if __ANDROID_API__ >= 9
template <typename Manager>
SpeakerEmbeddingExtractorGeneralImpl(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
Manager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: model_(mgr, config) {}
#endif

int32_t Dim() const override { return model_.GetMetaData().output_dim; }

Expand All @@ -46,9 +45,15 @@ class SpeakerEmbeddingExtractorGeneralImpl
std::vector<float> Compute(OnlineStream *s) const override {
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
if (num_frames <= 0) {
#if __OHOS__
SHERPA_ONNX_LOGE(
"Please make sure IsReady(s) returns true. num_frames: %{public}d",
num_frames);
#else
SHERPA_ONNX_LOGE(
"Please make sure IsReady(s) returns true. num_frames: %d",
num_frames);
#endif
return {};
}

Expand All @@ -64,8 +69,13 @@ class SpeakerEmbeddingExtractorGeneralImpl
if (meta_data.feature_normalize_type == "global-mean") {
SubtractGlobalMean(features.data(), num_frames, feat_dim);
} else {
#if __OHOS__
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %{public}s",
meta_data.feature_normalize_type.c_str());
#else
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
meta_data.feature_normalize_type.c_str());
#endif
exit(-1);
}
}
Expand Down
Loading
Loading