forked from run-llama/LlamaIndexTS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
OpenAIEmbedding.ts
146 lines (128 loc) · 4.09 KB
/
OpenAIEmbedding.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import { Tokenizers } from "@llamaindex/env";
import type { ClientOptions as OpenAIClientOptions } from "openai";
import type { AzureOpenAIConfig } from "../llm/azure.js";
import {
getAzureConfigFromEnv,
getAzureModel,
shouldUseAzure,
} from "../llm/azure.js";
import type { OpenAISession } from "../llm/openai.js";
import { getOpenAISession } from "../llm/openai.js";
import { BaseEmbedding } from "./types.js";
/**
 * Metadata for the OpenAI embedding models known to `OpenAIEmbedding`,
 * keyed by model name. The `OpenAIEmbedding` constructor copies the entry
 * matching its `model` into `embedInfo`; no metadata is assigned for model
 * names that are not listed here.
 *
 * - `dimensions`: default length of the returned embedding vector.
 * - `dimensionOptions`: present only on the text-embedding-3-* entries;
 *   candidate reduced vector sizes for models that support a shorter output.
 *   NOTE(review): presumably illustrative sizes rather than an exhaustive
 *   list of accepted values — confirm against the OpenAI docs.
 * - `maxTokens` / `tokenizer`: input token limit and tokenizer for the model;
 *   presumably consumed by the base class's token truncation — TODO confirm.
 */
export const ALL_OPENAI_EMBEDDING_MODELS = {
"text-embedding-ada-002": {
dimensions: 1536,
maxTokens: 8192,
tokenizer: Tokenizers.CL100K_BASE,
},
"text-embedding-3-small": {
dimensions: 1536,
dimensionOptions: [512, 1536],
maxTokens: 8192,
tokenizer: Tokenizers.CL100K_BASE,
},
"text-embedding-3-large": {
dimensions: 3072,
dimensionOptions: [256, 1024, 3072],
maxTokens: 8192,
tokenizer: Tokenizers.CL100K_BASE,
},
};
/** Union of the model names that have an entry in ALL_OPENAI_EMBEDDING_MODELS. */
type ModelKeys = keyof typeof ALL_OPENAI_EMBEDDING_MODELS;
export class OpenAIEmbedding extends BaseEmbedding {
  /** Embedding model name. Defaults to "text-embedding-ada-002". */
  model: string;
  /**
   * Number of dimensions of the resulting vector, for models that support
   * choosing fewer dimensions. When undefined, the model's default is used
   * and no `dimensions` value is sent to OpenAI.
   */
  dimensions: number | undefined;

  // OpenAI session params
  /** API key. */
  apiKey?: string;
  /** Maximum number of retries. Defaults to 10. */
  maxRetries: number;
  /** Request timeout in milliseconds. Defaults to 60 seconds. */
  timeout?: number;
  /**
   * Other session options for OpenAI. `apiKey`, `maxRetries` and `timeout`
   * are excluded because they are configured through the fields above.
   */
  additionalSessionOptions?: Omit<
    Partial<OpenAIClientOptions>,
    "apiKey" | "maxRetries" | "timeout"
  >;
  /** Session object whose `openai` client issues the embeddings requests. */
  session: OpenAISession;

  /**
   * OpenAI Embedding
   * @param init - initial parameters. An Azure session is created when
   *   `init.azure` is given or `shouldUseAzure()` returns true; otherwise a
   *   plain OpenAI session is created. An explicit `init.session` is used as-is.
   */
  constructor(init?: Partial<OpenAIEmbedding> & { azure?: AzureOpenAIConfig }) {
    super();
    this.model = init?.model ?? "text-embedding-ada-002";
    // Left undefined unless provided, so the request falls back to the model default.
    this.dimensions = init?.dimensions;
    this.embedBatchSize = init?.embedBatchSize ?? 10;
    this.maxRetries = init?.maxRetries ?? 10;
    this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
    this.additionalSessionOptions = init?.additionalSessionOptions;

    // Attach metadata for known models; unknown model names get none from this table.
    const model = this.model;
    if (OpenAIEmbedding.isKnownModel(model)) {
      this.embedInfo = ALL_OPENAI_EMBEDDING_MODELS[model];
    }

    if (init?.azure || shouldUseAzure()) {
      // Explicit `init.azure` values override those derived from the environment.
      const azureConfig = {
        ...getAzureConfigFromEnv({
          model: getAzureModel(this.model),
        }),
        ...init?.azure,
      };
      this.apiKey = azureConfig.apiKey;
      this.session =
        init?.session ??
        getOpenAISession({
          azure: true,
          maxRetries: this.maxRetries,
          timeout: this.timeout,
          ...this.additionalSessionOptions,
          ...azureConfig,
        });
    } else {
      this.apiKey = init?.apiKey;
      this.session =
        init?.session ??
        getOpenAISession({
          apiKey: this.apiKey,
          maxRetries: this.maxRetries,
          timeout: this.timeout,
          ...this.additionalSessionOptions,
        });
    }
  }

  /**
   * Type guard: true when `model` is an own key of ALL_OPENAI_EMBEDDING_MODELS.
   * Uses an own-property check so inherited names such as "toString" never match.
   */
  private static isKnownModel(model: string): model is ModelKeys {
    return Object.prototype.hasOwnProperty.call(
      ALL_OPENAI_EMBEDDING_MODELS,
      model,
    );
  }

  /**
   * Embed a batch of texts with a single embeddings API call.
   * Inputs are passed through `truncateMaxTokens` before being sent.
   * @param input - texts to embed
   * @returns one embedding per input, in the same order as `input`
   */
  private async getOpenAIEmbedding(input: string[]): Promise<number[][]> {
    // Nothing to embed: skip the network round-trip entirely.
    if (input.length === 0) {
      return [];
    }
    // TODO: ensure this for every sub class by calling it in the base class
    const truncated = this.truncateMaxTokens(input);
    const { data } = await this.session.openai.embeddings.create({
      model: this.model,
      dimensions: this.dimensions, // only sent to OpenAI if set by user
      input: truncated,
    });
    // Each result carries the `index` of the input it belongs to; sort on it
    // (on a copy) so the output order is guaranteed to match `input`.
    return [...data]
      .sort((a, b) => a.index - b.index)
      .map((d) => d.embedding);
  }

  /**
   * Get embeddings for a batch of texts.
   * @param texts - texts to embed
   * @returns one embedding per text, in the same order as `texts`
   */
  async getTextEmbeddings(texts: string[]): Promise<number[][]> {
    return this.getOpenAIEmbedding(texts);
  }

  /**
   * Get the embedding for a single text.
   * @param text - text to embed
   * @returns the embedding vector
   * @throws Error if the API response contains no embedding
   */
  async getTextEmbedding(text: string): Promise<number[]> {
    const [embedding] = await this.getOpenAIEmbedding([text]);
    if (embedding === undefined) {
      throw new Error("OpenAI returned no embedding for the input text");
    }
    return embedding;
  }
}