feat(adapters): add GCP VertexAI adapter
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
akihikokuroda committed Nov 21, 2024
1 parent 7ac23a4 commit ba5c8a3
Showing 12 changed files with 610 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .env.template
@@ -20,6 +20,11 @@ BEE_FRAMEWORK_LOG_SINGLE_LINE="false"
# For Groq LLM Adapter
# GROQ_API_KEY=

# For GCP VertexAI Adapter
# GOOGLE_APPLICATION_CREDENTIALS=
# GCP_VERTEXAI_PROJECT=
# GCP_VERTEXAI_LOCATION=

# Tools
# CODE_INTERPRETER_URL=http://127.0.0.1:50051

@@ -30,3 +35,4 @@ BEE_FRAMEWORK_LOG_SINGLE_LINE="false"
# For Elasticsearch Tool
# ELASTICSEARCH_NODE=
# ELASTICSEARCH_API_KEY=

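A minimal sketch of wiring these variables into the new adapter (the `bee-agent-framework/adapters/vertexai/chat` import path and the env-var handling are assumptions; `GOOGLE_APPLICATION_CREDENTIALS` is read implicitly by `google-auth-library`, so only the project and location are passed explicitly):

```ts
// Sketch: build the adapter from the env vars above.
// GOOGLE_APPLICATION_CREDENTIALS is picked up automatically by google-auth-library.
import { VertexAIChatLLM } from "bee-agent-framework/adapters/vertexai/chat";

const llm = new VertexAIChatLLM({
  modelId: "gemini-1.5-flash-001",
  project: process.env.GCP_VERTEXAI_PROJECT!, // e.g. "my-gcp-project"
  location: process.env.GCP_VERTEXAI_LOCATION ?? "us-central1",
});
```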
1 change: 1 addition & 0 deletions docs/llms.md
@@ -21,6 +21,7 @@ To unify differences between various APIs, the framework defines a common interf
| `LangChain` | ⚠️ (depends on a provider) | ⚠️ (depends on a provider) | ❌ |
| `Groq` | ❌ | ✅ | ⚠️ (JSON object only) |
| `AWS Bedrock` | ❌ | ✅ | ⚠️ (JSON only) - model specific |
| `VertexAI` | ✅ | ✅ | ⚠️ (JSON only) |
| `BAM (Internal)` | ✅ | ⚠️ (model specific template must be provided) | ✅ |
| ➕ [Request](https://github.com/i-am-bee/bee-agent-framework/discussions) | | | |

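The `VertexAI` structured-output cell corresponds to the `guided.json` option that the adapter (see `chat.ts` below) forwards to `createModel`. A hedged sketch, assuming `guided.json` takes a JSON-schema-like object:

```ts
import { VertexAIChatLLM } from "bee-agent-framework/adapters/vertexai/chat";
import { BaseMessage, Role } from "bee-agent-framework/llms/primitives/message";

const llm = new VertexAIChatLLM({
  modelId: "gemini-1.5-flash-001",
  project: "my-gcp-project", // placeholder
  location: "us-central1",
});

// `guided.json` flows into createModel(), which configures the Gemini model
// for JSON-constrained decoding.
const output = await llm.generate(
  [BaseMessage.of({ role: Role.USER, text: "List three GCP regions as a JSON array." })],
  { guided: { json: { type: "array", items: { type: "string" } } } },
);
console.log(output.getTextContent());
```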
9 changes: 9 additions & 0 deletions package.json
@@ -195,12 +195,14 @@
"peerDependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.687.0",
"@elastic/elasticsearch": "^8.0.0",
"@google-cloud/vertexai": "*",
"@googleapis/customsearch": "^3.2.0",
"@grpc/grpc-js": "^1.11.3",
"@grpc/proto-loader": "^0.7.13",
"@ibm-generative-ai/node-sdk": "~3.2.4",
"@langchain/community": ">=0.2.28",
"@langchain/core": ">=0.2.27",
"google-auth-library": "*",
"groq-sdk": "^0.7.0",
"ollama": "^0.5.8",
"openai": "^4.67.3",
@@ -214,6 +216,9 @@
"@elastic/elasticsearch": {
"optional": true
},
"@google-cloud/vertexai": {
"optional": true
},
"@googleapis/customsearch": {
"optional": true
},
@@ -232,6 +237,9 @@
"@langchain/core": {
"optional": true
},
"google-auth-library": {
"optional": true
},
"groq-sdk": {
"optional": true
},
@@ -255,6 +263,7 @@
"@elastic/elasticsearch": "^8.0.0",
"@eslint/js": "^9.13.0",
"@eslint/markdown": "^6.2.1",
"@google-cloud/vertexai": "^1.9.0",
"@googleapis/customsearch": "^3.2.0",
"@grpc/grpc-js": "^1.12.2",
"@grpc/proto-loader": "^0.7.13",
35 changes: 35 additions & 0 deletions src/adapters/vertexai/chat.test.ts
@@ -0,0 +1,35 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { verifyDeserialization } from "@tests/e2e/utils.js";
import { VertexAIChatLLM } from "@/adapters/vertexai/chat.js";

describe("VertexAI ChatLLM", () => {
const getInstance = () => {
return new VertexAIChatLLM({
modelId: "gemini-1.5-flash-001",
location: "us-central1",
project: "systemInstruction",
});
};

it("Serializes", async () => {
const instance = getInstance();
const serialized = instance.serialize();
const deserialized = VertexAIChatLLM.fromSerialized(serialized);
verifyDeserialization(instance, deserialized);
});
});
158 changes: 158 additions & 0 deletions src/adapters/vertexai/chat.ts
@@ -0,0 +1,158 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import {
AsyncStream,
BaseLLMTokenizeOutput,
ExecutionOptions,
GenerateCallbacks,
GenerateOptions,
LLMCache,
LLMMeta,
} from "@/llms/base.js";
import { shallowCopy } from "@/serializer/utils.js";
import type { GetRunContext } from "@/context.js";
import { Emitter } from "@/emitter/emitter.js";
import { VertexAI } from "@google-cloud/vertexai";
import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { signalRace } from "@/internals/helpers/promise.js";
import { processContentResponse, registerVertexAI, createModel } from "./utils.js";

export class VertexAIChatLLMOutput extends ChatLLMOutput {
public readonly chunks: BaseMessage[] = [];

constructor(chunk: BaseMessage) {
super();
this.chunks.push(chunk);
}

get messages() {
return this.chunks;
}

merge(other: VertexAIChatLLMOutput): void {
this.chunks.push(...other.chunks);
}

getTextContent(): string {
return this.chunks.map((result) => result.text).join("");
}

toString() {
return this.getTextContent();
}

createSnapshot() {
return { chunks: shallowCopy(this.chunks) };
}

loadSnapshot(snapshot: ReturnType<typeof this.createSnapshot>): void {
Object.assign(this, snapshot);
}
}

export interface VertexAIChatLLMInput {
modelId: string;
project: string;
location: string;
client?: VertexAI;
executionOptions?: ExecutionOptions;
cache?: LLMCache<VertexAIChatLLMOutput>;
parameters?: Record<string, any>;
}

export class VertexAIChatLLM extends ChatLLM<VertexAIChatLLMOutput> {
public readonly emitter: Emitter<GenerateCallbacks> = Emitter.root.child({
namespace: ["vertexai", "llm"],
creator: this,
});

protected client: VertexAI;

constructor(protected readonly input: VertexAIChatLLMInput) {
super(input.modelId, input.executionOptions, input.cache);
this.client = new VertexAI({ project: input.project, location: input.location });
}

static {
this.register();
registerVertexAI();
}

async meta(): Promise<LLMMeta> {
return { tokenLimit: Infinity };
}

async tokenize(input: BaseMessage[]): Promise<BaseLLMTokenizeOutput> {
const generativeModel = createModel(this.client, this.modelId);
const response = await generativeModel.countTokens({
contents: input.map((msg) => ({ parts: [{ text: msg.text }], role: msg.role })),
});
return {
tokensCount: response.totalTokens,
};
}

protected async _generate(
input: BaseMessage[],
options: GenerateOptions,
run: GetRunContext<this>,
): Promise<VertexAIChatLLMOutput> {
const generativeModel = createModel(this.client, this.modelId, options.guided?.json);
const response = await signalRace(
() =>
generativeModel.generateContent({
contents: input.map((msg) => ({ parts: [{ text: msg.text }], role: msg.role })),
}),
run.signal,
);
const result = BaseMessage.of({
role: Role.ASSISTANT,
text: processContentResponse(response.response),
});
return new VertexAIChatLLMOutput(result);
}

protected async *_stream(
input: BaseMessage[],
options: GenerateOptions | undefined,
run: GetRunContext<this>,
): AsyncStream<VertexAIChatLLMOutput, void> {
const generativeModel = createModel(this.client, this.modelId, options?.guided?.json);
const chat = generativeModel.startChat();
const response = await chat.sendMessageStream(input.map((msg) => msg.text));
for await (const chunk of response.stream) {
if (options?.signal?.aborted) {
break;
}
const result = BaseMessage.of({
role: Role.ASSISTANT,
text: processContentResponse(chunk),
});
yield new VertexAIChatLLMOutput(result);
}
run.signal.throwIfAborted();
}

createSnapshot() {
return {
...super.createSnapshot(),
input: shallowCopy(this.input),
client: this.client,
};
}
}
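A usage sketch for the streaming path above, assuming the base class exposes the usual `stream()` helper that drives `_stream`; each yielded `VertexAIChatLLMOutput` wraps one assistant chunk:

```ts
import { VertexAIChatLLM } from "bee-agent-framework/adapters/vertexai/chat";
import { BaseMessage, Role } from "bee-agent-framework/llms/primitives/message";

const llm = new VertexAIChatLLM({
  modelId: "gemini-1.5-flash-001",
  project: "my-gcp-project", // placeholder
  location: "us-central1",
});

// Consume chunks incrementally; getTextContent() joins each chunk's messages.
for await (const chunk of llm.stream([
  BaseMessage.of({ role: Role.USER, text: "Hello, Vertex!" }),
])) {
  process.stdout.write(chunk.getTextContent());
}
```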
35 changes: 35 additions & 0 deletions src/adapters/vertexai/llm.test.ts
@@ -0,0 +1,35 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { verifyDeserialization } from "@tests/e2e/utils.js";
import { VertexAILLM } from "@/adapters/vertexai/llm.js";

describe("VertexAI LLM", () => {
const getInstance = () => {
return new VertexAILLM({
modelId: "gemini-1.5-flash-001",
location: "us-central1",
project: "systemInstruction",
});
};

it("Serializes", async () => {
const instance = getInstance();
const serialized = instance.serialize();
const deserialized = VertexAILLM.fromSerialized(serialized);
verifyDeserialization(instance, deserialized);
});
});
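The completion-style `VertexAILLM` exercised by this test (its `llm.ts` source is collapsed in this view) takes the same constructor input. A hedged sketch of a plain generation call, assuming the base `LLM` interface accepts a prompt string:

```ts
import { VertexAILLM } from "bee-agent-framework/adapters/vertexai/llm";

const llm = new VertexAILLM({
  modelId: "gemini-1.5-flash-001",
  project: "my-gcp-project", // placeholder
  location: "us-central1",
});

const output = await llm.generate("Write one sentence about Vertex AI.");
console.log(output.getTextContent());
```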
