Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Change the typing for the LDAIConfig. #688

Merged
merged 6 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { LDContext } from '@launchdarkly/js-server-sdk-common';

import { LDGenerationConfig } from '../src/api/config';
import { LDAIDefaults } from '../src/api/config';
import { LDAIClientImpl } from '../src/LDAIClientImpl';
import { LDClientMin } from '../src/LDClientMin';

Expand Down Expand Up @@ -32,13 +32,14 @@ it('handles empty variables in template interpolation', () => {
it('returns model config with interpolated prompts', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'test-flag';
const defaultValue: LDGenerationConfig = {
const defaultValue: LDAIDefaults = {
model: { modelId: 'test', name: 'test-model' },
prompt: [],
enabled: true,
};

const mockVariation = {
model: { modelId: 'example-provider', name: 'imagination' },
model: { modelId: 'example-provider', name: 'imagination', temperature: 0.7, maxTokens: 4096 },
prompt: [
{ role: 'system', content: 'Hello {{name}}' },
{ role: 'user', content: 'Score: {{score}}' },
Expand All @@ -55,13 +56,11 @@ it('returns model config with interpolated prompts', async () => {
const result = await client.modelConfig(key, testContext, defaultValue, variables);

expect(result).toEqual({
config: {
model: { modelId: 'example-provider', name: 'imagination' },
prompt: [
{ role: 'system', content: 'Hello John' },
{ role: 'user', content: 'Score: 42' },
],
},
model: { modelId: 'example-provider', name: 'imagination', temperature: 0.7, maxTokens: 4096 },
prompt: [
{ role: 'system', content: 'Hello John' },
{ role: 'user', content: 'Score: 42' },
],
tracker: expect.any(Object),
enabled: true,
});
Expand All @@ -70,7 +69,7 @@ it('returns model config with interpolated prompts', async () => {
it('includes context in variables for prompt interpolation', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'test-flag';
const defaultValue: LDGenerationConfig = {
const defaultValue: LDAIDefaults = {
model: { modelId: 'test', name: 'test-model' },
prompt: [],
};
Expand All @@ -84,13 +83,13 @@ it('includes context in variables for prompt interpolation', async () => {

const result = await client.modelConfig(key, testContext, defaultValue);

expect(result.config.prompt?.[0].content).toBe('User key: test-user');
expect(result.prompt?.[0].content).toBe('User key: test-user');
});

it('handles missing metadata in variation', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'test-flag';
const defaultValue: LDGenerationConfig = {
const defaultValue: LDAIDefaults = {
model: { modelId: 'test', name: 'test-model' },
prompt: [],
};
Expand All @@ -105,10 +104,8 @@ it('handles missing metadata in variation', async () => {
const result = await client.modelConfig(key, testContext, defaultValue);

expect(result).toEqual({
config: {
model: { modelId: 'example-provider', name: 'imagination' },
prompt: [{ role: 'system', content: 'Hello' }],
},
model: { modelId: 'example-provider', name: 'imagination' },
prompt: [{ role: 'system', content: 'Hello' }],
tracker: expect.any(Object),
enabled: false,
});
Expand All @@ -117,17 +114,19 @@ it('handles missing metadata in variation', async () => {
it('passes the default value to the underlying client', async () => {
const client = new LDAIClientImpl(mockLdClient);
const key = 'non-existent-flag';
const defaultValue: LDGenerationConfig = {
const defaultValue: LDAIDefaults = {
model: { modelId: 'default-model', name: 'default' },
prompt: [{ role: 'system', content: 'Default prompt' }],
enabled: true,
};

mockLdClient.variation.mockResolvedValue(defaultValue);

const result = await client.modelConfig(key, testContext, defaultValue);

expect(result).toEqual({
config: defaultValue,
model: defaultValue.model,
prompt: defaultValue.prompt,
tracker: expect.any(Object),
enabled: false,
});
Expand Down
8 changes: 6 additions & 2 deletions packages/sdk/server-ai/examples/bedrock/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ async function main() {
const completion = tracker.trackBedrockConverse(
await awsClient.send(
new ConverseCommand({
modelId: aiConfig.config.model?.modelId ?? 'no-model',
messages: mapPromptToConversation(aiConfig.config.prompt ?? []),
modelId: aiConfig.model?.modelId ?? 'no-model',
messages: mapPromptToConversation(aiConfig.prompt ?? []),
inferenceConfig: {
temperature: aiConfig.model?.temperature ?? 0.5,
maxTokens: aiConfig.model?.maxTokens ?? 4096,
},
}),
),
);
Expand Down
6 changes: 4 additions & 2 deletions packages/sdk/server-ai/examples/openai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ async function main(): Promise<void> {
const { tracker } = aiConfig;
const completion = await tracker.trackOpenAI(async () =>
client.chat.completions.create({
messages: aiConfig.config.prompt || [],
model: aiConfig.config.model?.modelId || 'gpt-4',
messages: aiConfig.prompt || [],
model: aiConfig.model?.modelId || 'gpt-4',
temperature: aiConfig.model?.temperature ?? 0.5,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added temperature and maxTokens usage to the bedrock and OpenAI examples.

max_tokens: aiConfig.model?.maxTokens ?? 4096,
}),
);

Expand Down
34 changes: 17 additions & 17 deletions packages/sdk/server-ai/src/LDAIClientImpl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import * as Mustache from 'mustache';

import { LDContext } from '@launchdarkly/js-server-sdk-common';

import { LDAIConfig, LDGenerationConfig, LDMessage, LDModelConfig } from './api/config';
import { LDAIConfig, LDAIDefaults, LDMessage, LDModelConfig } from './api/config';
import { LDAIClient } from './api/LDAIClient';
import { LDAIConfigTrackerImpl } from './LDAIConfigTrackerImpl';
import { LDClientMin } from './LDClientMin';
Expand Down Expand Up @@ -32,16 +32,28 @@ export class LDAIClientImpl implements LDAIClient {
return Mustache.render(template, variables, undefined, { escape: (item: any) => item });
}

async modelConfig<TDefault extends LDGenerationConfig>(
async modelConfig(
key: string,
context: LDContext,
defaultValue: TDefault,
defaultValue: LDAIDefaults,
variables?: Record<string, unknown>,
): Promise<LDAIConfig> {
const value: VariationContent = await this._ldClient.variation(key, context, defaultValue);
const tracker = new LDAIConfigTrackerImpl(
this._ldClient,
key,
// eslint-disable-next-line no-underscore-dangle
value._ldMeta?.versionKey ?? '',
context,
);
// eslint-disable-next-line no-underscore-dangle
const enabled = !!value._ldMeta?.enabled;
const config: LDAIConfig = {
tracker,
enabled,
};
// We are going to modify the contents before returning them, so we make a copy.
// This isn't a deep copy and the application developer should not modify the returned content.
const config: LDGenerationConfig = {};
if (value.model) {
config.model = { ...value.model };
}
Expand All @@ -54,18 +66,6 @@ export class LDAIClientImpl implements LDAIClient {
}));
}

return {
config,
// eslint-disable-next-line no-underscore-dangle
tracker: new LDAIConfigTrackerImpl(
this._ldClient,
key,
// eslint-disable-next-line no-underscore-dangle
value._ldMeta?.versionKey ?? '',
context,
),
// eslint-disable-next-line no-underscore-dangle
enabled: !!value._ldMeta?.enabled,
};
return config;
}
}
16 changes: 3 additions & 13 deletions packages/sdk/server-ai/src/api/LDAIClient.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import { LDContext } from '@launchdarkly/js-server-sdk-common';

import { LDAIConfig, LDGenerationConfig } from './config/LDAIConfig';

/**
* Interface for default model configuration.
*/
export interface LDAIDefaults extends LDGenerationConfig {
/**
* Whether the configuration is enabled.
*/
enabled?: boolean;
}
import { LDAIConfig, LDAIDefaults } from './config/LDAIConfig';

/**
* Interface for performing AI operations using LaunchDarkly.
Expand Down Expand Up @@ -77,10 +67,10 @@ export interface LDAIClient {
* }
* ```
*/
modelConfig<TDefault extends LDAIDefaults>(
modelConfig(
key: string,
context: LDContext,
defaultValue: TDefault,
defaultValue: LDAIDefaults,
variables?: Record<string, unknown>,
): Promise<LDAIConfig>;
}
29 changes: 16 additions & 13 deletions packages/sdk/server-ai/src/api/config/LDAIConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export interface LDModelConfig {
/**
* The ID of the model.
*/
modelId?: string;
modelId: string;

/**
* Tuning parameter for randomness versus determinism. Exact effect will be determined by the
Expand Down Expand Up @@ -41,9 +41,9 @@ export interface LDMessage {
}

/**
* Configuration which affects generation.
* AI configuration and tracker.
*/
export interface LDGenerationConfig {
export interface LDAIConfig {
/**
* Optional model configuration.
*/
Expand All @@ -52,16 +52,6 @@ export interface LDGenerationConfig {
* Optional prompt data.
*/
prompt?: LDMessage[];
}

/**
* AI Config value and tracker.
*/
export interface LDAIConfig {
/**
* The result of the AI Config customization.
*/
config: LDGenerationConfig;

/**
* A tracker which can be used to generate analytics.
Expand All @@ -73,3 +63,16 @@ export interface LDAIConfig {
*/
enabled: boolean;
}

/**
* Default value for a `modelConfig`. This is the same as the LDAIConfig, but it does not include
* a tracker and `enabled` is optional.
*/
export type LDAIDefaults = Omit<LDAIConfig, 'tracker' | 'enabled'> & {
/**
* Whether the configuration is enabled.
*
* defaults to false
*/
enabled?: boolean;
};