Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wrapAISDKModel method for Vercel's AI SDK #896

Merged
merged 5 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions js/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ Chinook_Sqlite.sql
/wrappers/openai.js
/wrappers/openai.d.ts
/wrappers/openai.d.cts
/wrappers/vercel.cjs
/wrappers/vercel.js
/wrappers/vercel.d.ts
/wrappers/vercel.d.cts
/singletons/traceable.cjs
/singletons/traceable.js
/singletons/traceable.d.ts
Expand Down
20 changes: 18 additions & 2 deletions js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
"wrappers/openai.js",
"wrappers/openai.d.ts",
"wrappers/openai.d.cts",
"wrappers/vercel.cjs",
"wrappers/vercel.js",
"wrappers/vercel.d.ts",
"wrappers/vercel.d.cts",
"singletons/traceable.cjs",
"singletons/traceable.js",
"singletons/traceable.d.ts",
Expand Down Expand Up @@ -101,17 +105,18 @@
"uuid": "^9.0.0"
},
"devDependencies": {
"@ai-sdk/openai": "^0.0.40",
"@babel/preset-env": "^7.22.4",
"@faker-js/faker": "^8.4.1",
"@jest/globals": "^29.5.0",
"langchain": "^0.2.10",
"@langchain/core": "^0.2.17",
"@langchain/langgraph": "^0.0.29",
"@langchain/openai": "^0.2.5",
"@tsconfig/recommended": "^1.0.2",
"@types/jest": "^29.5.1",
"@typescript-eslint/eslint-plugin": "^5.59.8",
"@typescript-eslint/parser": "^5.59.8",
"ai": "^3.2.37",
"babel-jest": "^29.5.0",
"cross-env": "^7.0.3",
"dotenv": "^16.1.3",
Expand All @@ -121,11 +126,13 @@
"eslint-plugin-no-instanceof": "^1.0.1",
"eslint-plugin-prettier": "^4.2.1",
"jest": "^29.5.0",
"langchain": "^0.2.10",
"openai": "^4.38.5",
"prettier": "^2.8.8",
"ts-jest": "^29.1.0",
"ts-node": "^10.9.1",
"typescript": "^5.4.5"
"typescript": "^5.4.5",
"zod": "^3.23.8"
},
"peerDependencies": {
"@langchain/core": "*",
Expand Down Expand Up @@ -249,6 +256,15 @@
"import": "./wrappers/openai.js",
"require": "./wrappers/openai.cjs"
},
"./wrappers/vercel": {
"types": {
"import": "./wrappers/vercel.d.ts",
"require": "./wrappers/vercel.d.cts",
"default": "./wrappers/vercel.d.ts"
},
"import": "./wrappers/vercel.js",
"require": "./wrappers/vercel.cjs"
},
"./singletons/traceable": {
"types": {
"import": "./singletons/traceable.d.ts",
Expand Down
1 change: 1 addition & 0 deletions js/scripts/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const entrypoints = {
wrappers: "wrappers/index",
anonymizer: "anonymizer/index",
"wrappers/openai": "wrappers/openai",
"wrappers/vercel": "wrappers/vercel",
"singletons/traceable": "singletons/traceable",
};

Expand Down
77 changes: 77 additions & 0 deletions js/src/tests/wrapped_ai_sdk.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import { openai } from "@ai-sdk/openai";
import {
generateObject,
generateText,
streamObject,
streamText,
tool,
} from "ai";
import { z } from "zod";
import { wrapAISDKModel } from "../wrappers/vercel.js";

test("AI SDK generateText", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { text } = await generateText({
model: modelWithTracing,
prompt: "Write a vegetarian lasagna recipe for 4 people.",
});
console.log(text);
});

test("AI SDK generateText with a tool", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { text } = await generateText({
model: modelWithTracing,
prompt:
"Write a vegetarian lasagna recipe for 4 people. Get ingredients first.",
tools: {
getIngredients: tool({
description: "get a list of ingredients",
parameters: z.object({
ingredients: z.array(z.string()),
}),
execute: async () =>
JSON.stringify(["pasta", "tomato", "cheese", "onions"]),
}),
},
maxToolRoundtrips: 2,
});
console.log(text);
});

test("AI SDK generateObject", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { object } = await generateObject({
model: modelWithTracing,
prompt: "Write a vegetarian lasagna recipe for 4 people.",
schema: z.object({
ingredients: z.array(z.string()),
}),
});
console.log(object);
});

test("AI SDK streamText", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { textStream } = await streamText({
model: modelWithTracing,
prompt: "Write a vegetarian lasagna recipe for 4 people.",
});
for await (const chunk of textStream) {
console.log(chunk);
}
});

test("AI SDK streamObject", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { partialObjectStream } = await streamObject({
model: modelWithTracing,
prompt: "Write a vegetarian lasagna recipe for 4 people.",
schema: z.object({
ingredients: z.array(z.string()),
}),
});
for await (const chunk of partialObjectStream) {
console.log(chunk);
}
});
93 changes: 92 additions & 1 deletion js/src/traceable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ export function traceable<Func extends (...args: any[]) => any>(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
aggregator?: (args: any[]) => any;
argsConfigPath?: [number] | [number, string];
__finalTracedIteratorKey?: string;

/**
* Extract invocation parameters from the arguments of the traced function.
Expand All @@ -294,7 +295,12 @@ export function traceable<Func extends (...args: any[]) => any>(
}
) {
type Inputs = Parameters<Func>;
const { aggregator, argsConfigPath, ...runTreeConfig } = config ?? {};
const {
aggregator,
__finalTracedIteratorKey,
argsConfigPath,
...runTreeConfig
} = config ?? {};

const traceableFunc = (
...args: Inputs | [RunTree, ...Inputs] | [RunnableConfigLike, ...Inputs]
Expand Down Expand Up @@ -434,6 +440,47 @@ export function traceable<Func extends (...args: any[]) => any>(
return chunks;
}

function tapReadableStreamForTracing(
stream: ReadableStream<unknown>,
snapshot: ReturnType<typeof AsyncLocalStorage.snapshot> | undefined
) {
const reader = stream.getReader();
let finished = false;
const chunks: unknown[] = [];

const tappedStream = new ReadableStream({
async start(controller) {
// eslint-disable-next-line no-constant-condition
while (true) {
const result = await (snapshot
? snapshot(() => reader.read())
: reader.read());
if (result.done) {
finished = true;
await currentRunTree?.end(
handleRunOutputs(await handleChunks(chunks))
);
await handleEnd();
controller.close();
break;
}
chunks.push(result.value);
controller.enqueue(result.value);
}
},
async cancel(reason) {
if (!finished) await currentRunTree?.end(undefined, "Cancelled");
await currentRunTree?.end(
handleRunOutputs(await handleChunks(chunks))
);
await handleEnd();
return reader.cancel(reason);
},
});

return tappedStream;
}

async function* wrapAsyncIteratorForTracing(
iterator: AsyncIterator<unknown, unknown, undefined>,
snapshot: ReturnType<typeof AsyncLocalStorage.snapshot> | undefined
Expand Down Expand Up @@ -463,10 +510,14 @@ export function traceable<Func extends (...args: any[]) => any>(
await handleEnd();
}
}

function wrapAsyncGeneratorForTracing(
iterable: AsyncIterable<unknown>,
snapshot: ReturnType<typeof AsyncLocalStorage.snapshot> | undefined
) {
if (isReadableStream(iterable)) {
return tapReadableStreamForTracing(iterable, snapshot);
}
const iterator = iterable[Symbol.asyncIterator]();
const wrappedIterator = wrapAsyncIteratorForTracing(iterator, snapshot);
iterable[Symbol.asyncIterator] = () => wrappedIterator;
Expand Down Expand Up @@ -512,6 +563,25 @@ export function traceable<Func extends (...args: any[]) => any>(
return wrapAsyncGeneratorForTracing(returnValue, snapshot);
}

if (
!Array.isArray(returnValue) &&
typeof returnValue === "object" &&
returnValue != null &&
__finalTracedIteratorKey !== undefined &&
isAsyncIterable(
(returnValue as Record<string, any>)[__finalTracedIteratorKey]
)
) {
const snapshot = AsyncLocalStorage.snapshot();
return {
...returnValue,
[__finalTracedIteratorKey]: wrapAsyncGeneratorForTracing(
(returnValue as Record<string, any>)[__finalTracedIteratorKey],
snapshot
),
};
}

const tracedPromise = new Promise<unknown>((resolve, reject) => {
Promise.resolve(returnValue)
.then(
Expand All @@ -523,6 +593,27 @@ export function traceable<Func extends (...args: any[]) => any>(
);
}

if (
!Array.isArray(rawOutput) &&
typeof rawOutput === "object" &&
rawOutput != null &&
__finalTracedIteratorKey !== undefined &&
isAsyncIterable(
(rawOutput as Record<string, any>)[__finalTracedIteratorKey]
)
) {
const snapshot = AsyncLocalStorage.snapshot();
return {
...rawOutput,
[__finalTracedIteratorKey]: wrapAsyncGeneratorForTracing(
(rawOutput as Record<string, any>)[
__finalTracedIteratorKey
],
snapshot
),
};
}

if (isGenerator(wrappedFunc) && isIteratorLike(rawOutput)) {
const chunks = gatherAll(rawOutput);

Expand Down
72 changes: 72 additions & 0 deletions js/src/wrappers/generic.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import type { RunTreeConfig } from "../index.js";
import { traceable } from "../traceable.js";

export const _wrapClient = <T extends object>(
sdk: T,
runName: string,
options?: Omit<RunTreeConfig, "name">
): T => {
return new Proxy(sdk, {
get(target, propKey, receiver) {
const originalValue = target[propKey as keyof T];
if (typeof originalValue === "function") {
return traceable(originalValue.bind(target), {
run_type: "llm",
...options,
name: [runName, propKey.toString()].join("."),
});
} else if (
originalValue != null &&
!Array.isArray(originalValue) &&
// eslint-disable-next-line no-instanceof/no-instanceof
!(originalValue instanceof Date) &&
typeof originalValue === "object"
) {
return _wrapClient(
originalValue,
[runName, propKey.toString()].join("."),
options
);
} else {
return Reflect.get(target, propKey, receiver);
}
},
});
};

type WrapSDKOptions = Partial<
RunTreeConfig & {
/**
* @deprecated Use `name` instead.
*/
runName: string;
}
>;

/**
* Wrap an arbitrary SDK, enabling automatic LangSmith tracing.
* Method signatures are unchanged.
*
* Note that this will wrap and trace ALL SDK methods, not just
* LLM completion methods. If the passed SDK contains other methods,
* we recommend using the wrapped instance for LLM calls only.
* @param sdk An arbitrary SDK instance.
* @param options LangSmith options.
* @returns
*/
export const wrapSDK = <T extends object>(
sdk: T,
options?: WrapSDKOptions
): T => {
const traceableOptions = options ? { ...options } : undefined;
if (traceableOptions != null) {
delete traceableOptions.runName;
delete traceableOptions.name;
}

return _wrapClient(
sdk,
options?.name ?? options?.runName ?? sdk.constructor?.name,
traceableOptions
);
};
1 change: 1 addition & 0 deletions js/src/wrappers/index.ts
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
export * from "./openai.js";
export { wrapSDK } from "./generic.js";
Loading
Loading