Update OpenAI version to v4 and add JS tracing (#21)
ankrgyl authored Oct 17, 2023
1 parent d7a45ae commit 8379aa2
Showing 6 changed files with 316 additions and 87 deletions.
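For orientation, the heart of this migration is the change in how the OpenAI client is constructed and called. A minimal sketch of the v4 pattern the diffs below adopt (the model name and prompt are placeholders, not taken from the commit):

import { OpenAI } from "openai";

async function example() {
  // v4: a single OpenAI instance replaces v3's Configuration + OpenAIApi pair.
  const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

  // v4: chat.completions.create() resolves directly to a ChatCompletion;
  // there is no `.data` unwrapping as with v3's createChatCompletion().
  const completion = await openai.chat.completions.create({
    model: "gpt-3.5-turbo",
    messages: [{ role: "user", content: "Say hello." }],
  });
  return completion.choices[0].message.content;
}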
47 changes: 31 additions & 16 deletions js/llm.ts
@@ -2,13 +2,13 @@ import * as yaml from "js-yaml";
 import mustache from "mustache";
 
 import { Score, Scorer, ScorerArgs } from "./base.js";
-import {
-  ChatCompletionFunctions,
-  ChatCompletionRequestMessage,
-  ChatCompletionResponseMessage,
-} from "openai";
 import { ChatCache, cachedChatCompletion } from "./oai.js";
 import { templates } from "./templates.js";
+import {
+  ChatCompletionCreateParams,
+  ChatCompletionMessage,
+} from "openai/resources/index.mjs";
+import { currentSpan } from "./util.js";
 
 const NO_COT_SUFFIX =
   "Answer the question by calling `select_choice` with a single choice from {{__choices}}.";
@@ -63,9 +63,9 @@ export function buildClassificationFunctions(useCoT: boolean) {
 export type OpenAIClassifierArgs<RenderArgs> = {
   name: string;
   model: string;
-  messages: ChatCompletionRequestMessage[];
+  messages: ChatCompletionMessage[];
   choiceScores: Record<string, number>;
-  classificationFunctions: ChatCompletionFunctions[];
+  classificationFunctions: ChatCompletionCreateParams.Function[];
   cache?: ChatCache;
 } & LLMArgs &
   RenderArgs;
@@ -77,17 +77,21 @@ export async function OpenAIClassifier<RenderArgs, Output>(
     name,
     output,
     expected,
+    openAiApiKey,
+    openAiOrganizationId,
+    ...remaining
+  } = args;
+
+  const {
     messages: messagesArg,
     model,
     choiceScores,
     classificationFunctions,
     maxTokens,
     temperature,
     cache,
-    openAiApiKey,
-    openAiOrganizationId,
     ...remainingRenderArgs
-  } = args;
+  } = remaining;
 
   let found = false;
   for (const m of SUPPORTED_MODELS) {
@@ -113,11 +117,13 @@ export async function OpenAIClassifier<RenderArgs, Output>(
     ...remainingRenderArgs,
   };
 
-  const messages: ChatCompletionRequestMessage[] = messagesArg.map((m) => ({
+  const messages: ChatCompletionMessage[] = messagesArg.map((m) => ({
     ...m,
     content: m.content && mustache.render(m.content, renderArgs),
   }));
 
+  let ret = null;
+  let validityScore = 1;
   try {
     const resp = await cachedChatCompletion(
       {
@@ -135,24 +141,27 @@ export async function OpenAIClassifier<RenderArgs, Output>(
     );
 
     if (resp.choices.length > 0) {
-      return {
+      ret = {
         name,
         ...parseResponse(resp.choices[0].message!, choiceScores),
       };
     } else {
       throw new Error("Empty response from OpenAI");
     }
   } catch (error) {
-    return {
+    validityScore = 0;
+    ret = {
      name,
      score: 0,
      error: `${error}`,
    };
  }
+
+  return ret;
 }
 
 function parseResponse(
-  resp: ChatCompletionResponseMessage,
+  resp: ChatCompletionMessage,
   choiceScores: Record<string, number>
 ): Omit<Score, "name"> {
   let score = 0;
@@ -202,7 +211,7 @@ export function LLMClassifierFromTemplate<RenderArgs>({
   temperature?: number;
 }): Scorer<string, LLMClassifierArgs<RenderArgs>> {
   const choiceStrings = Object.keys(choiceScores);
-  return async (
+  const ret = async (
     runtimeArgs: ScorerArgs<string, LLMClassifierArgs<RenderArgs>>
   ) => {
     const useCoT = runtimeArgs.useCoT ?? useCoTArg ?? true;
@@ -211,7 +220,7 @@ export function LLMClassifierFromTemplate<RenderArgs>({
       promptTemplate + "\n" + (useCoT ? COT_SUFFIX : NO_COT_SUFFIX);
 
     let maxTokens = 512;
-    const messages: ChatCompletionRequestMessage[] = [
+    const messages: ChatCompletionMessage[] = [
       {
        role: "user",
        content: prompt,
@@ -234,6 +243,12 @@ export function LLMClassifierFromTemplate<RenderArgs>({
       useCoT,
     });
   };
+  Object.defineProperty(ret, "name", {
+    value: name,
+    configurable: true,
+  });
+
+  return ret;
 }
 
 export interface ModelGradedSpec {
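For context on the Object.defineProperty change above: scorers built by LLMClassifierFromTemplate are plain async functions, and defineProperty is what makes the returned function's name property report the classifier's name rather than the anonymous-function default. A hypothetical usage sketch (the classifier name, template, and output are invented for illustration):

const funniness = LLMClassifierFromTemplate({
  name: "Funniness",
  promptTemplate: "Is the following response funny? {{output}}",
  choiceScores: { Yes: 1, No: 0 },
  useCoT: true,
});

async function exampleScore() {
  // Thanks to the defineProperty change above, this logs "Funniness".
  console.log(funniness.name);
  return await funniness({ output: "Why did the chicken cross the road?" });
}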
78 changes: 46 additions & 32 deletions js/oai.ts
@@ -1,28 +1,25 @@
 import {
-  ChatCompletionFunctions,
-  ChatCompletionRequestMessage,
-  Configuration,
-  CreateChatCompletionRequestFunctionCall,
-  CreateChatCompletionResponse,
-  OpenAIApi,
-} from "openai";
+  ChatCompletion,
+  ChatCompletionCreateParams,
+  ChatCompletionMessage,
+} from "openai/resources/index.mjs";
+import { OpenAI } from "openai";
 
 import { Env } from "./env.js";
+import { currentSpan } from "./util.js";
 
 export interface CachedLLMParams {
   model: string;
-  messages: ChatCompletionRequestMessage[];
-  functions?: ChatCompletionFunctions[];
-  function_call?: CreateChatCompletionRequestFunctionCall;
+  messages: ChatCompletionMessage[];
+  functions?: ChatCompletionCreateParams.Function[];
+  function_call?: ChatCompletionCreateParams.FunctionCallOption;
   temperature?: number;
   max_tokens?: number;
 }
 
 export interface ChatCache {
-  get(params: CachedLLMParams): Promise<CreateChatCompletionResponse | null>;
-  set(
-    params: CachedLLMParams,
-    response: CreateChatCompletionResponse
-  ): Promise<void>;
+  get(params: CachedLLMParams): Promise<ChatCompletion | null>;
+  set(params: CachedLLMParams, response: ChatCompletion): Promise<void>;
 }
 
 export interface OpenAIAuth {
@@ -33,28 +30,45 @@ export interface OpenAIAuth {
 export async function cachedChatCompletion(
   params: CachedLLMParams,
   options: { cache?: ChatCache } & OpenAIAuth
-): Promise<CreateChatCompletionResponse> {
+): Promise<ChatCompletion> {
   const { cache, openAiApiKey, openAiOrganizationId } = options;
 
-  const cached = await cache?.get(params);
-  if (cached) {
-    return cached;
-  }
+  return await currentSpan().traced("OpenAI Completion", async (span: any) => {
+    let cached = false;
+    let ret = await cache?.get(params);
+    if (ret) {
+      cached = true;
+    } else {
+      const openai = new OpenAI({
+        apiKey: openAiApiKey || Env.OPENAI_API_KEY,
+        organization: openAiOrganizationId,
+      });
 
-  const config = new Configuration({
-    apiKey: openAiApiKey || Env.OPENAI_API_KEY,
-    organization: openAiOrganizationId,
-  });
-  const openai = new OpenAIApi(config);
+      if (openai === null) {
+        throw new Error("OPENAI_API_KEY not set");
+      }
 
-  if (openai === null) {
-    throw new Error("OPENAI_API_KEY not set");
-  }
+      const completion = await openai.chat.completions.create(params);
 
-  const completion = await openai.createChatCompletion(params);
-  const data = completion.data;
+      await cache?.set(params, completion);
+      ret = completion;
+    }
 
-  await cache?.set(params, data);
+    const { messages, ...rest } = params;
+    span.log({
+      input: messages,
+      metadata: {
+        ...rest,
+        cached,
+      },
+      output: ret.choices[0],
+      metrics: {
+        tokens: ret.usage?.total_tokens,
+        prompt_tokens: ret.usage?.prompt_tokens,
+        completion_tokens: ret.usage?.completion_tokens,
+      },
+    });
 
-  return data;
+    return ret;
+  });
 }
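The ChatCache interface above only requires async get/set keyed by the request parameters. A minimal in-memory implementation, as a sketch (this commit does not ship one; the class name is invented):

import { ChatCompletion } from "openai/resources/index.mjs";
import { CachedLLMParams, ChatCache } from "./oai.js";

class InMemoryChatCache implements ChatCache {
  // Keyed by the JSON-serialized request; fine for a sketch, though real
  // code might want a canonical serialization so key order cannot vary.
  private store = new Map<string, ChatCompletion>();

  async get(params: CachedLLMParams): Promise<ChatCompletion | null> {
    return this.store.get(JSON.stringify(params)) ?? null;
  }

  async set(params: CachedLLMParams, response: ChatCompletion): Promise<void> {
    this.store.set(JSON.stringify(params), response);
  }
}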
41 changes: 41 additions & 0 deletions js/util.ts
@@ -0,0 +1,41 @@
+/* This is copy/pasted from braintrust-sdk */
+export class NoopSpan {
+  public id: string;
+  public span_id: string;
+  public root_span_id: string;
+  public kind: "span" = "span";
+
+  constructor() {
+    this.id = "";
+    this.span_id = "";
+    this.root_span_id = "";
+  }
+
+  public log(_: any) {}
+
+  public startSpan(_0: string, _1?: any) {
+    return this;
+  }
+
+  public traced<R>(_0: string, callback: (span: any) => R, _1: any): R {
+    return callback(this);
+  }
+
+  public end(args?: any): number {
+    return args?.endTime ?? new Date().getTime() / 1000;
+  }
+
+  public close(args?: any): number {
+    return this.end(args);
+  }
+}
+declare global {
+  var __inherited_braintrust_state: any;
+}
+export function currentSpan() {
+  if (globalThis.__inherited_braintrust_state) {
+    return globalThis.__inherited_braintrust_state.currentSpan.getStore();
+  } else {
+    return new NoopSpan();
+  }
+}
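The net effect: inside a process where braintrust-sdk has installed globalThis.__inherited_braintrust_state, currentSpan() returns the live span and the OpenAI calls above are traced; everywhere else it degrades to the NoopSpan. A small sketch of the contract (illustrative, not part of the commit), mirroring the two-argument traced() call used in js/oai.ts:

import { currentSpan } from "./util.js";

async function tracedExample() {
  // With no braintrust state installed, traced() just runs the callback
  // against a NoopSpan, so log() silently discards its arguments.
  return await currentSpan().traced("example-operation", async (span: any) => {
    span.log({ input: "hello" });
    return 42;
  });
}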
4 changes: 2 additions & 2 deletions node/llm.test.ts
@@ -1,10 +1,10 @@
+import { ChatCompletionMessage } from "openai/resources/index.mjs";
 import {
   Battle,
   LLMClassifierFromTemplate,
   OpenAIClassifier,
   buildClassificationFunctions,
 } from "../js/llm";
-import { ChatCompletionRequestMessage } from "openai";
 import { ChatCache } from "../js/oai";
 
 let cache: ChatCache | undefined;
@@ -18,7 +18,7 @@ test("openai", async () => {
     return grade.match(/Winner: (\d+)/)![1];
   };
 
-  const messages: ChatCompletionRequestMessage[] = [
+  const messages: ChatCompletionMessage[] = [
     {
       role: "system",
       content: `You are a technical project manager who helps software engineers generate better titles for their GitHub issues.