Commit 55722d9

Merge branch 'main' into fix-embedding-similarity-crash

aphinx committed Oct 28, 2024
2 parents 2fdcd12 + f30aa1e
Showing 16 changed files with 7,022 additions and 42 deletions.
27 changes: 19 additions & 8 deletions README.md
@@ -267,12 +267,15 @@ Issue Description: {{input}}

const choiceScores = { 1: 1, 2: 0 };

const evaluator = LLMClassifierFromTemplate({
name: "TitleQuality",
promptTemplate,
choiceScores,
useCoT: true,
});
const evaluator = LLMClassifierFromTemplate<{ input: string }>({
  name: "TitleQuality",
  promptTemplate,
  choiceScores,
  useCoT: true,
});

const input = `As suggested by Nicolo, we should standardize the error responses coming from GoTrue, postgres, and realtime (and any other/future APIs) so that it's better DX when writing a client,
We can make this change on the servers themselves, but since postgrest and gotrue are fully/partially external may be harder to change, it might be an option to transform the errors within the client libraries/supabase-js, could be messy?
@@ -316,7 +319,15 @@ print(f"Banana score: {result.score}")
```javascript
import { Score } from "autoevals";

const bananaScorer = ({ output, expected, input }): Score => {
const bananaScorer = ({
output,
expected,
input,
}: {
output: string;
expected: string;
input: string;
}): Score => {
return { name: "banana_scorer", score: output.includes("banana") ? 1 : 0 };
};

@@ -325,7 +336,7 @@ const bananaScorer = ({ output, expected, input }): Score => {
const output = "3";
const expected = "3 bananas";

const result = await bananaScorer({ output, expected, input });
const result = bananaScorer({ output, expected, input });
console.log(`Banana score: ${result.score}`);
})();
```
3 changes: 2 additions & 1 deletion js/index.ts
@@ -20,7 +20,7 @@
*
* const result = await Factuality({ output, expected, input });
* console.log(`Factuality score: ${result.score}`);
* console.log(`Factuality metadata: ${result.metadata.rationale}`);
* console.log(`Factuality metadata: ${result.metadata?.rationale}`);
* })();
* ```
*
@@ -36,5 +36,6 @@ export * from "./number";
export * from "./json";
export * from "./templates";
export * from "./ragas";
export * from "./value";
export { Evaluators } from "./manifest";
export { makePartial, ScorerWithPartial } from "./partial";
15 changes: 3 additions & 12 deletions js/llm.ts
@@ -1,9 +1,8 @@
import * as yaml from "js-yaml";
import mustache from "mustache";

import { Score, Scorer, ScorerArgs } from "@braintrust/core";
import { ChatCache, OpenAIAuth, cachedChatCompletion } from "./oai";
import { templates } from "./templates";
import { ModelGradedSpec, templates } from "./templates";
import {
ChatCompletionMessage,
ChatCompletionMessageParam,
@@ -217,7 +216,7 @@ export function LLMClassifierFromTemplate<RenderArgs>({
const prompt =
promptTemplate + "\n" + (useCoT ? COT_SUFFIX : NO_COT_SUFFIX);

let maxTokens = 512;
const maxTokens = 512;
const messages: ChatCompletionMessageParam[] = [
{
role: "user",
@@ -249,14 +248,6 @@ export function LLMClassifierFromTemplate<RenderArgs>({
return ret;
}

export interface ModelGradedSpec {
prompt: string;
choice_scores: Record<string, number>;
model?: string;
use_cot?: boolean;
temperature?: number;
}

export function LLMClassifierFromSpec<RenderArgs>(
name: string,
spec: ModelGradedSpec,
@@ -275,7 +266,7 @@ export function LLMClassifierFromSpecFile<RenderArgs>(
name: string,
templateName: keyof typeof templates,
): Scorer<any, LLMClassifierArgs<RenderArgs>> {
const doc = yaml.load(templates[templateName]) as ModelGradedSpec;
const doc = templates[templateName];
return LLMClassifierFromSpec(name, doc);
}
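With `templates` now exporting pre-parsed `ModelGradedSpec` objects, the spec-file path no longer round-trips through YAML. A hedged usage sketch, assuming `LLMClassifierFromSpecFile` is re-exported from the package root and using `"factuality"`, a template key that appears in this commit's manifest changes:

```typescript
import { LLMClassifierFromSpecFile } from "autoevals";

(async () => {
  // Build a classifier scorer from a bundled, pre-parsed template by name.
  const factuality = LLMClassifierFromSpecFile("Factuality", "factuality");

  // The returned scorer is called like any other autoevals scorer.
  const result = await factuality({
    input: "What color is the sky?",
    output: "The sky is blue.",
    expected: "Blue.",
  });
  console.log(result.score);
})();
```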

15 changes: 15 additions & 0 deletions js/manifest.ts
@@ -25,10 +25,13 @@ import { ListContains } from "./list";
import { ScorerWithPartial } from "./partial";
import { Moderation } from "./moderation";
import { ExactMatch } from "./value";
import { ModelGradedSpec, templates } from "./templates";

interface AutoevalMethod {
method: ScorerWithPartial<any, any>;
description: string;
template?: ModelGradedSpec;
requiresExtraParams?: boolean;
}

export const Evaluators: {
@@ -42,20 +45,26 @@ export const Evaluators: {
method: Battle,
description:
"Test whether an output _better_ performs the `instructions` than the original (expected) value.",
template: templates.battle,
requiresExtraParams: true,
},
{
method: ClosedQA,
description:
"Test whether an output answers the `input` using knowledge built into the model. You can specify `criteria` to further constrain the answer.",
template: templates.closed_q_a,
requiresExtraParams: true,
},
{
method: Humor,
description: "Test whether an output is funny.",
template: templates.humor,
},
{
method: Factuality,
description:
"Test whether an output is factual, compared to an original (`expected`) value.",
template: templates.factuality,
},
{
method: Moderation,
@@ -66,25 +75,31 @@
method: Possible,
description:
"Test whether an output is a possible solution to the challenge posed in the input.",
template: templates.possible,
},
{
method: Security,
description: "Test whether an output is malicious.",
template: templates.security,
},
{
method: Sql,
description:
"Test whether a SQL query is semantically the same as a reference (output) query.",
template: templates.sql,
},
{
method: Summary,
description:
"Test whether an output is a better summary of the `input` than the original (`expected`) value.",
template: templates.summary,
},
{
method: Translation,
description:
"Test whether an `output` is as good of a translation of the `input` in the specified `language` as an expert (`expected`) value.",
template: templates.translation,
requiresExtraParams: true,
},
],
},
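The new `template` and `requiresExtraParams` fields make the manifest self-describing. A rough consumer sketch, assuming each entry of `Evaluators` is a category carrying a `methods` array as the surrounding diff suggests (the exact top-level shape is truncated here):

```typescript
import { Evaluators } from "autoevals";

// Hypothetical consumer of the new manifest metadata: list the scorers
// that ship with a model-graded prompt template, flagging those that
// need extra parameters beyond input/output/expected.
for (const category of Object.values(Evaluators)) {
  for (const { method, template, requiresExtraParams } of category.methods) {
    if (template) {
      console.log(
        `${method.name}: template available` +
          (requiresExtraParams ? " (requires extra params)" : ""),
      );
    }
  }
}
```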
28 changes: 28 additions & 0 deletions js/oai.test.ts
@@ -0,0 +1,28 @@
import { buildOpenAIClient } from "./oai";

describe.skip("OAI", () => {
test("should use Azure OpenAI", async () => {
/*
* You can plug in your own valid Azure OpenAI info
* to make sure it works.
*/
const client = buildOpenAIClient({
azureOpenAi: {
apiKey: "<some api key>",
endpoint: "https://<some resource>.openai.azure.com/",
apiVersion: "<some valid version>",
},
});
const {
choices: [
{
message: { content },
},
],
} = await client.chat.completions.create({
model: "<Azure OpenAI LLM deployment name>",
messages: [{ role: "system", content: "Hello" }],
});
expect(content).toBeTruthy();
});
});
60 changes: 51 additions & 9 deletions js/oai.ts
@@ -4,17 +4,24 @@ import {
ChatCompletionTool,
ChatCompletionToolChoiceOption,
} from "openai/resources";
import { OpenAI } from "openai";
import { AzureOpenAI, OpenAI } from "openai";

import { Env } from "./env";

export interface CachedLLMParams {
/**
Model to use for the completion.
Note: If using Azure OpenAI, this should be the deployment name.
*/
model: string;
messages: ChatCompletionMessageParam[];
tools?: ChatCompletionTool[];
tool_choice?: ChatCompletionToolChoiceOption;
temperature?: number;
max_tokens?: number;
span_info?: {
spanAttributes?: Record<string, string>;
};
}

export interface ChatCache {
@@ -28,6 +35,17 @@ export interface OpenAIAuth {
openAiBaseUrl?: string;
openAiDefaultHeaders?: Record<string, string>;
openAiDangerouslyAllowBrowser?: boolean;
/**
If present, use [Azure OpenAI Service](https://learn.microsoft.com/en-us/azure/ai-services/openai/)
instead of OpenAI.
*/
azureOpenAi?: AzureOpenAiAuth;
}

export interface AzureOpenAiAuth {
apiKey: string;
endpoint: string;
apiVersion: string;
}

export function extractOpenAIArgs<T extends Record<string, unknown>>(
@@ -39,6 +57,7 @@ export function extractOpenAIArgs<T extends Record<string, unknown>>(
openAiBaseUrl: args.openAiBaseUrl,
openAiDefaultHeaders: args.openAiDefaultHeaders,
openAiDangerouslyAllowBrowser: args.openAiDangerouslyAllowBrowser,
azureOpenAi: args.azureOpenAi,
};
}

@@ -51,15 +70,24 @@ export function buildOpenAIClient(options: OpenAIAuth): OpenAI {
openAiBaseUrl,
openAiDefaultHeaders,
openAiDangerouslyAllowBrowser,
azureOpenAi,
} = options;

const client = new OpenAI({
apiKey: openAiApiKey || Env.OPENAI_API_KEY || Env.BRAINTRUST_API_KEY,
organization: openAiOrganizationId,
baseURL: openAiBaseUrl || Env.OPENAI_BASE_URL || PROXY_URL,
defaultHeaders: openAiDefaultHeaders,
dangerouslyAllowBrowser: openAiDangerouslyAllowBrowser,
});
const client = azureOpenAi
? new AzureOpenAI({
apiKey: azureOpenAi.apiKey,
endpoint: azureOpenAi.endpoint,
apiVersion: azureOpenAi.apiVersion,
defaultHeaders: openAiDefaultHeaders,
dangerouslyAllowBrowser: openAiDangerouslyAllowBrowser,
})
: new OpenAI({
apiKey: openAiApiKey || Env.OPENAI_API_KEY || Env.BRAINTRUST_API_KEY,
organization: openAiOrganizationId,
baseURL: openAiBaseUrl || Env.OPENAI_BASE_URL || PROXY_URL,
defaultHeaders: openAiDefaultHeaders,
dangerouslyAllowBrowser: openAiDangerouslyAllowBrowser,
});

if (globalThis.__inherited_braintrust_wrap_openai) {
return globalThis.__inherited_braintrust_wrap_openai(client);
@@ -69,6 +97,7 @@ }
}

declare global {
/* eslint-disable no-var */
var __inherited_braintrust_wrap_openai: ((openai: any) => any) | undefined;
}

@@ -77,5 +106,18 @@ export async function cachedChatCompletion(
options: { cache?: ChatCache } & OpenAIAuth,
): Promise<ChatCompletion> {
const openai = buildOpenAIClient(options);
return await openai.chat.completions.create(params);

const fullParams = globalThis.__inherited_braintrust_wrap_openai
? {
...params,
span_info: {
spanAttributes: {
...params.span_info?.spanAttributes,
purpose: "scorer",
},
},
}
: params;

return await openai.chat.completions.create(fullParams);
}
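A minimal end-to-end sketch of the new Azure path, mirroring the skipped test in `oai.test.ts` above; every Azure value here is a placeholder for your own resource:

```typescript
import { buildOpenAIClient } from "./oai";

(async () => {
  // Placeholders only: substitute your own Azure OpenAI resource details.
  const client = buildOpenAIClient({
    azureOpenAi: {
      apiKey: process.env.AZURE_OPENAI_API_KEY ?? "",
      endpoint: "https://my-resource.openai.azure.com/",
      apiVersion: "2024-06-01",
    },
  });

  // Per the CachedLLMParams doc above, `model` is the deployment name
  // when talking to Azure OpenAI.
  const completion = await client.chat.completions.create({
    model: "my-gpt-4o-deployment",
    messages: [{ role: "user", content: "Hello" }],
  });
  console.log(completion.choices[0].message.content);
})();
```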
1 change: 1 addition & 0 deletions js/ragas.test.ts
@@ -79,6 +79,7 @@ test("Ragas end-to-end test", async () => {

if (score === 1) {
expect(actualScore.score).toBeCloseTo(score, 4);
expect(actualScore.score).toBeLessThanOrEqual(1);
}
}
}, 600000);
11 changes: 10 additions & 1 deletion js/ragas.ts
@@ -17,6 +17,14 @@ type RagasArgs = {
model?: string;
} & LLMArgs;

interface RagasEmbeddingModelArgs extends Record<string, unknown> {
/**
@default
If not provided, the default model of {@link EmbeddingSimilarity} is used.
*/
embeddingModel?: string;
}

const ENTITY_PROMPT = `Given a text, extract unique entities without repetition. Ensure you consider different forms or mentions of the same entity as a single entity.
The output should be a well-formatted JSON instance that conforms to the JSON schema below.
@@ -603,7 +611,7 @@ export const AnswerRelevancy: ScorerWithPartial<
string,
RagasArgs & {
strictness?: number;
}
} & RagasEmbeddingModelArgs
> = makePartial(async (args) => {
const { chatArgs, client, ...inputs } = parseArgs(args);

@@ -656,6 +664,7 @@ export const AnswerRelevancy: ScorerWithPartial<
...extractOpenAIArgs(args),
output: question,
expected: input,
model: args.embeddingModel,
});
return { question, score };
}),
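The new `embeddingModel` option threads through to the `EmbeddingSimilarity` call above. A usage sketch, where the model name is a placeholder and the argument list assumes the usual Ragas inputs of `input`, `output`, and `context`:

```typescript
import { AnswerRelevancy } from "autoevals";

(async () => {
  // `embeddingModel` is optional; when omitted, EmbeddingSimilarity's
  // default embedding model is used. The name below is a placeholder.
  const result = await AnswerRelevancy({
    input: "What is the capital of France?",
    output: "Paris is the capital of France.",
    context: "France's capital and largest city is Paris.",
    strictness: 3,
    embeddingModel: "text-embedding-3-small",
  });
  console.log(`Answer relevancy: ${result.score}`);
})();
```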
2 changes: 1 addition & 1 deletion js/string.ts
@@ -87,5 +87,5 @@ export const EmbeddingSimilarity: ScorerWithPartial<
}, "EmbeddingSimilarity");

function scaleScore(score: number, expectedMin: number): number {
return Math.max((score - expectedMin) / (1 - expectedMin), 0);
return Math.min(Math.max((score - expectedMin) / (1 - expectedMin), 0), 1);
}
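The added upper clamp lines up with the `fix-embedding-similarity-crash` branch name: floating-point noise can push a raw cosine similarity marginally above 1, and the old lower-clamp-only rescaling then returned a score outside [0, 1] — exactly what the new assertion in `ragas.test.ts` guards against. A small illustration of the boundary case:

```typescript
// With only the lower clamp, a raw similarity like 1.0000001 (a plausible
// floating-point artifact of cosine similarity) rescales to > 1.
const expectedMin = 0.7;
const rawSimilarity = 1.0000001;

const unclamped = Math.max((rawSimilarity - expectedMin) / (1 - expectedMin), 0);
const clamped = Math.min(unclamped, 1);

console.log(unclamped); // ≈ 1.0000003 — out of range
console.log(clamped); // 1
```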