Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add local models to non-streaming accept list #14420

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,13 @@ export class OpenAiFrontendApplicationContribution implements FrontendApplicatio
const newModels = createCustomModelDescriptionsFromPreferences(event.newValue);

const modelsToRemove = oldModels.filter(model => !newModels.some(newModel => newModel.id === model.id));
const modelsToAddOrUpdate = newModels.filter(newModel => !oldModels.some(model =>
model.id === newModel.id && model.model === newModel.model && model.url === newModel.url && model.apiKey === newModel.apiKey));
const modelsToAddOrUpdate = newModels.filter(newModel =>
!oldModels.some(model =>
model.id === newModel.id &&
model.model === newModel.model &&
model.url === newModel.url &&
model.apiKey === newModel.apiKey &&
model.enableStreaming === newModel.enableStreaming));

this.manager.removeLanguageModels(...modelsToRemove.map(model => model.id));
this.manager.createOrUpdateLanguageModels(...modelsToAddOrUpdate);
Expand All @@ -74,11 +79,14 @@ export class OpenAiFrontendApplicationContribution implements FrontendApplicatio
}
}

const openAIModelsWithDisabledStreaming = ['o1-preview'];

function createOpenAIModelDescription(modelId: string): OpenAiModelDescription {
return {
id: `openai/${modelId}`,
model: modelId,
apiKey: true
apiKey: true,
enableStreaming: !openAIModelsWithDisabledStreaming.includes(modelId)
};
}

Expand All @@ -93,7 +101,8 @@ function createCustomModelDescriptionsFromPreferences(preferences: Partial<OpenA
id: pref.id && typeof pref.id === 'string' ? pref.id : pref.model,
model: pref.model,
url: pref.url,
apiKey: typeof pref.apiKey === 'string' || pref.apiKey === true ? pref.apiKey : undefined
apiKey: typeof pref.apiKey === 'string' || pref.apiKey === true ? pref.apiKey : undefined,
enableStreaming: pref.enableStreaming ?? true
}
];
}, []);
Expand Down
6 changes: 6 additions & 0 deletions packages/ai-openai/src/browser/openai-preferences.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ export const OpenAiPreferencesSchema: PreferenceSchema = {
\n\
- provide an `apiKey` to access the API served at the given url. Use `true` to indicate the use of the global OpenAI API key.\
\n\
- specify `enableStreaming: false` to indicate that streaming shall not be used.\
\n\
Refer to [our documentation](https://theia-ide.org/docs/user_ai/#openai-compatible-models-eg-via-vllm) for more information.',
default: [],
items: {
Expand All @@ -71,6 +73,10 @@ export const OpenAiPreferencesSchema: PreferenceSchema = {
type: ['string', 'boolean'],
title: 'Either the key to access the API served at the given url or `true` to use the global OpenAI API key',
},
enableStreaming: {
type: 'boolean',
title: 'Indicates whether the streaming API shall be used. `true` by default.',
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ export interface OpenAiModelDescription {
* The key for the model. If 'true' is provided the global OpenAI API key will be used.
*/
apiKey: string | true | undefined;
/**
* Indicate whether the streaming API shall be used.
*/
enableStreaming: boolean;
}
export interface OpenAiLanguageModelsManager {
apiKey: string | undefined;
Expand Down
10 changes: 6 additions & 4 deletions packages/ai-openai/src/node/openai-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ export class OpenAiModel implements LanguageModel {
/**
* @param id the unique id for this language model. It will be used to identify the model in the UI.
* @param model the model id as it is used by the OpenAI API
* @param openAIInitializer initializer for the OpenAI client, used for each request.
* @param enableStreaming whether the streaming API shall be used
* @param apiKey a function that returns the API key to use for this model, called on each request
* @param url the OpenAI API compatible endpoint where the model is hosted. If not provided the default OpenAI endpoint will be used.
*/
constructor(public readonly id: string, public model: string, public apiKey: () => string | undefined, public url: string | undefined) { }
constructor(public readonly id: string, public model: string, public enableStreaming: boolean, public apiKey: () => string | undefined, public url: string | undefined) { }

async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
const openai = this.initializeOpenAi();
Expand Down Expand Up @@ -152,8 +154,8 @@ export class OpenAiModel implements LanguageModel {
};
}

protected isNonStreamingModel(model: string): boolean {
return ['o1-preview'].includes(model);
protected isNonStreamingModel(_model: string): boolean {
return !this.enableStreaming;
}

protected supportsStructuredOutput(): boolean {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ export class OpenAiLanguageModelsManagerImpl implements OpenAiLanguageModelsMana
}
model.url = modelDescription.url;
model.model = modelDescription.model;
model.enableStreaming = modelDescription.enableStreaming;
model.apiKey = apiKeyProvider;
} else {
this.languageModelRegistry.addLanguageModels([new OpenAiModel(modelDescription.id, modelDescription.model, apiKeyProvider, modelDescription.url)]);
this.languageModelRegistry.addLanguageModels([
new OpenAiModel(modelDescription.id, modelDescription.model, modelDescription.enableStreaming, apiKeyProvider, modelDescription.url)
]);
}
}
}
Expand Down
Loading