feat(adapters): add GCP VertexAI adapter
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
akihikokuroda committed Nov 15, 2024
1 parent b333594 commit 98ca4a9
Showing 11 changed files with 610 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .env.template
@@ -20,6 +20,11 @@ BEE_FRAMEWORK_LOG_SINGLE_LINE="false"
# For Groq LLM Adapter
# GROQ_API_KEY=

# For GCP VertexAI Adapter
# GOOGLE_APPLICATION_CREDENTIALS=
# GCP_VERTEXAI_PROJECT=
# GCP_VERTEXAI_LOCATION=

# Tools
# CODE_INTERPRETER_URL=http://127.0.0.1:50051

@@ -30,3 +35,4 @@ BEE_FRAMEWORK_LOG_SINGLE_LINE="false"
# For Elasticsearch Tool
# ELASTICSEARCH_NODE=
# ELASTICSEARCH_API_KEY=

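For reference, the three new variables map onto the VertexAI adapter's constructor options. A minimal sketch of wiring them up (the import path and the non-null assertions are assumptions, not part of this commit; GOOGLE_APPLICATION_CREDENTIALS is picked up implicitly by google-auth-library's Application Default Credentials, so it is never passed explicitly):

import { VertexAIChatLLM } from "bee-agent-framework/adapters/vertexai/chat";

// Hypothetical wiring of the .env values above into the adapter's constructor.
const llm = new VertexAIChatLLM({
  modelId: "gemini-1.5-flash-001", // same model id the adapter's tests use
  project: process.env.GCP_VERTEXAI_PROJECT!, // from .env.template
  location: process.env.GCP_VERTEXAI_LOCATION!, // from .env.template
});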
1 change: 1 addition & 0 deletions docs/llms.md
@@ -20,6 +20,7 @@ To unify differences between various APIs, the framework defines a common interf
| `OpenAI` ||| ⚠️ (JSON schema only) |
| `LangChain` | ⚠️ (depends on a provider) | ⚠️ (depends on a provider) ||
| `Groq` ||| ⚠️ (JSON object only) |
| `VertexAI` ||| ⚠️ (JSON only) |
| `BAM (Internal)` || ⚠️ (model specific template must be provided) ||
| [Request](https://github.com/i-am-bee/bee-agent-framework/discussions) | | | |

12 changes: 12 additions & 0 deletions package.json
@@ -194,12 +194,14 @@
},
"peerDependencies": {
"@elastic/elasticsearch": "^8.0.0",
"@google-cloud/vertexai": "*",
"@googleapis/customsearch": "^3.2.0",
"@grpc/grpc-js": "^1.11.3",
"@grpc/proto-loader": "^0.7.13",
"@ibm-generative-ai/node-sdk": "~3.2.4",
"@langchain/community": ">=0.2.28",
"@langchain/core": ">=0.2.27",
"google-auth-library": "*",
"groq-sdk": "^0.7.0",
"ollama": "^0.5.8",
"openai": "^4.67.3",
@@ -210,6 +212,9 @@
"@elastic/elasticsearch": {
"optional": true
},
"@google-cloud/vertexai": {
"optional": true
},
"@googleapis/customsearch": {
"optional": true
},
@@ -228,6 +233,9 @@
"@langchain/core": {
"optional": true
},
"google-auth-library": {
"optional": true
},
"groq-sdk": {
"optional": true
},
@@ -309,5 +317,9 @@
"vite-tsconfig-paths": "^5.0.1",
"vitest": "^2.1.3",
"yaml": "^2.6.0"
},
"optionalDependencies": {
"@google-cloud/vertexai": "^1.9.0",
"google-auth-library": "^9.15.0"
}
}
162 changes: 162 additions & 0 deletions src/adapters/vertexai/chat.ts
@@ -0,0 +1,162 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import {
AsyncStream,
BaseLLMTokenizeOutput,
ExecutionOptions,
GenerateCallbacks,
GenerateOptions,
LLMCache,
LLMMeta,
} from "@/llms/base.js";
import { shallowCopy } from "@/serializer/utils.js";
import type { GetRunContext } from "@/context.js";
import { Emitter } from "@/emitter/emitter.js";
import { VertexAI, GenerativeModel } from "@google-cloud/vertexai";
import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { signalRace } from "@/internals/helpers/promise.js";
import { processContentResponse, registerGenerativeModel } from "./utils.js";

export class VertexAIChatLLMOutput extends ChatLLMOutput {
public readonly chunks: BaseMessage[] = [];

constructor(chunk: BaseMessage) {
super();
this.chunks.push(chunk);
}

get messages() {
return this.chunks;
}

merge(other: VertexAIChatLLMOutput): void {
this.chunks.push(...other.chunks);
}

getTextContent(): string {
return this.chunks.map((result) => result.text).join("");
}

toString() {
return this.getTextContent();
}

createSnapshot() {
return { chunks: shallowCopy(this.chunks) };
}

loadSnapshot(snapshot: ReturnType<typeof this.createSnapshot>): void {
Object.assign(this, snapshot);
}
}

type VertexAIGenerateOptions = GenerateOptions;

export interface VertexAIChatLLMInput {
modelId: string;
project: string;
location: string;
client?: GenerativeModel;
executionOptions?: ExecutionOptions;
cache?: LLMCache<VertexAIChatLLMOutput>;
parameters?: Record<string, any>;
}

export class VertexAIChatLLM extends ChatLLM<VertexAIChatLLMOutput> {
public readonly emitter: Emitter<GenerateCallbacks> = Emitter.root.child({
namespace: ["vertexai", "llm"],
creator: this,
});

protected client: GenerativeModel;

constructor(protected readonly input: VertexAIChatLLMInput) {
super(input.modelId, input.executionOptions, input.cache);
const vertexAI = new VertexAI({ project: input.project, location: input.location });
this.client =
input.client ??
vertexAI.getGenerativeModel({
model: input.modelId,
});
}

static {
this.register();
registerGenerativeModel();
}

async meta(): Promise<LLMMeta> {
return { tokenLimit: Infinity };
}

async tokenize(input: BaseMessage[]): Promise<BaseLLMTokenizeOutput> {
const response = await this.client.countTokens({
contents: input.map((msg) => ({ parts: [{ text: msg.text }], role: msg.role })),
});
return {
tokensCount: response.totalTokens,
};
}

protected async _generate(
input: BaseMessage[],
options: VertexAIGenerateOptions,
run: GetRunContext<this>,
): Promise<VertexAIChatLLMOutput> {
const response = await signalRace(
() =>
this.client.generateContent({
contents: input.map((msg) => ({ parts: [{ text: msg.text }], role: msg.role })),
}),
run.signal,
);
const result = BaseMessage.of({
role: Role.ASSISTANT,
text: processContentResponse(response.response),
});
return new VertexAIChatLLMOutput(result);
}

protected async *_stream(
input: BaseMessage[],
options: VertexAIGenerateOptions | undefined,
run: GetRunContext<this>,
): AsyncStream<VertexAIChatLLMOutput, void> {
const chat = this.client.startChat();
const response = await chat.sendMessageStream(input.map((msg) => msg.text));
for await (const chunk of await response.stream) {
if (options?.signal?.aborted) {
break;
}
const result = BaseMessage.of({
role: Role.ASSISTANT,
text: processContentResponse(chunk),
});
yield new VertexAIChatLLMOutput(result);
}
run.signal.throwIfAborted();
}

createSnapshot() {
return {
...super.createSnapshot(),
input: shallowCopy(this.input),
client: this.client,
};
}
}
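For context, a hedged usage sketch of the adapter defined above. generate and stream are inherited from the framework's ChatLLM/BaseLLM base classes, and the import paths assume the package's usual subpath exports; none of this is part of the commit itself:

import { VertexAIChatLLM } from "bee-agent-framework/adapters/vertexai/chat";
import { BaseMessage, Role } from "bee-agent-framework/llms/primitives/message";

const llm = new VertexAIChatLLM({
  modelId: "gemini-1.5-flash-001",
  project: "my-gcp-project", // placeholder project id
  location: "us-central1",
});

// One-shot generation: _generate() above maps each message to a Vertex `contents` entry.
const output = await llm.generate([
  BaseMessage.of({ role: Role.USER, text: "What is the capital of France?" }),
]);
console.log(output.getTextContent());

// Streaming: _stream() above yields one VertexAIChatLLMOutput per streamed chunk.
for await (const chunk of llm.stream([
  BaseMessage.of({ role: Role.USER, text: "Tell me a short story." }),
])) {
  process.stdout.write(chunk.getTextContent());
}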
53 changes: 53 additions & 0 deletions src/adapters/vertexai/llm.test.ts
@@ -0,0 +1,53 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { verifyDeserialization } from "@tests/e2e/utils.js";
import { VertexAILLM } from "@/adapters/vertexai/llm.js";
import { VertexAIChatLLM } from "@/adapters/vertexai/chat.js";

describe("VertexAI LLM", () => {
const getInstance = () => {
return new VertexAILLM({
modelId: "gemini-1.5-flash-001",
location: "us-central1",
project: "systemInstruction",
});
};

it("Serializes", async () => {
const instance = getInstance();
const serialized = instance.serialize();
const deserialized = VertexAILLM.fromSerialized(serialized);
verifyDeserialization(instance, deserialized);
});
});

describe("VertexAI ChatLLM", () => {
const getInstance = () => {
return new VertexAIChatLLM({
modelId: "gemini-1.5-flash-001",
location: "us-central1",
project: "systemInstruction",
});
};

it("Serializes", async () => {
const instance = getInstance();
const serialized = instance.serialize();
const deserialized = VertexAIChatLLM.fromSerialized(serialized);
verifyDeserialization(instance, deserialized);
});
});
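VertexAILLM, imported by the test above from src/adapters/vertexai/llm.ts (one of the files not expanded in this view), is the completion-style counterpart to the chat adapter. A hedged sketch of how it might be used, assuming it takes the same constructor options the tests show and a plain string prompt like the framework's other non-chat adapters:

import { VertexAILLM } from "bee-agent-framework/adapters/vertexai/llm";

const llm = new VertexAILLM({
  modelId: "gemini-1.5-flash-001",
  project: "my-gcp-project", // placeholder project id
  location: "us-central1",
});

// Assumption: a plain string prompt, matching the framework's other non-chat LLM adapters.
const output = await llm.generate("Write a haiku about Vertex AI.");
console.log(output.getTextContent());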
