Skip to content

Commit

Permalink
initial function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
alx13 committed Mar 6, 2024
1 parent 932e1be commit abb89ae
Show file tree
Hide file tree
Showing 10 changed files with 888 additions and 396 deletions.
10 changes: 9 additions & 1 deletion packages/main/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import {
RequestOptions,
SafetySetting,
StartChatParams,
Tool,
} from "../../types";
import { ChatSession } from "../methods/chat-session";
import { countTokens } from "../methods/count-tokens";
Expand All @@ -53,6 +54,7 @@ export class GenerativeModel {
generationConfig: GenerationConfig;
safetySettings: SafetySetting[];
requestOptions: RequestOptions;
tools?: Tool[];

constructor(
public apiKey: string,
Expand All @@ -68,6 +70,7 @@ export class GenerativeModel {
}
this.generationConfig = modelParams.generationConfig || {};
this.safetySettings = modelParams.safetySettings || [];
this.tools = modelParams.tools;
this.requestOptions = requestOptions || {};
}

Expand All @@ -85,6 +88,7 @@ export class GenerativeModel {
{
generationConfig: this.generationConfig,
safetySettings: this.safetySettings,
tools: this.tools,
...formattedParams,
},
this.requestOptions,
Expand All @@ -107,6 +111,7 @@ export class GenerativeModel {
{
generationConfig: this.generationConfig,
safetySettings: this.safetySettings,
tools: this.tools,
...formattedParams,
},
this.requestOptions,
Expand All @@ -121,7 +126,10 @@ export class GenerativeModel {
return new ChatSession(
this.apiKey,
this.model,
startChatParams,
{
tools: this.tools,
...startChatParams,
},
this.requestOptions,
);
}
Expand Down
53 changes: 53 additions & 0 deletions packages/main/test-integration/node/count-tokens.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* @license
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { expect, use } from "chai";
import * as chaiAsPromised from "chai-as-promised";
import {
GoogleGenerativeAI,
HarmBlockThreshold,
HarmCategory,
} from "../..";

use(chaiAsPromised);

/**
* Integration tests against live backend.
*/

describe("countTokens", function () {
this.timeout(60e3);
this.slow(10e3);
it("counts tokens right", async () => {
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY || "");
const model = genAI.getGenerativeModel({
model: "gemini-pro",
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH,
},
],
});
const response1 = await model.countTokens("count me");
const response2 = await model.countTokens({
contents: [{ role: "user", parts: [{ text: "count me" }] }],
});
expect(response1.totalTokens).to.equal(3);
expect(response2.totalTokens).to.equal(3);
});
});
45 changes: 45 additions & 0 deletions packages/main/test-integration/node/embed-content.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/**
* @license
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { expect, use } from "chai";
import * as chaiAsPromised from "chai-as-promised";
import {
GoogleGenerativeAI,
} from "../..";

use(chaiAsPromised);

/**
* Integration tests against live backend.
*/

describe("embedContent", function () {
this.timeout(60e3);
this.slow(10e3);
it("embeds a single Content object", async () => {
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY || "");
const model = genAI.getGenerativeModel({
model: "embedding-001",
});
const response1 = await model.embedContent("embed me");
const response2 = await model.embedContent({
content: { role: "user", parts: [{ text: "embed me" }] },
});
expect(response1.embedding).to.not.be.empty;
expect(response1).to.eql(response2);
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/**
* @license
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import * as fs from "fs";
import { join } from "path";

import { expect, use } from "chai";
import * as chaiAsPromised from "chai-as-promised";
import {
GoogleGenerativeAI,
HarmBlockThreshold,
HarmCategory,
} from "../..";

use(chaiAsPromised);

/**
* Integration tests against live backend.
*/

describe("generateContent", function () {
this.timeout(60e3);
this.slow(10e3);
it("non-streaming, image buffer provided", async () => {
const imageBuffer = fs.readFileSync(
join(__dirname, "../../test-utils/cat.png"),
);
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY || "");
const base64Image = imageBuffer.toString("base64");
const model = genAI.getGenerativeModel({
model: "gemini-pro-vision",
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH,
},
],
});
const result = await model.generateContent({
contents: [
{
role: "user",
parts: [
{ text: "Is it a cat?" },
{
inlineData: {
mimeType: "image/png",
data: base64Image,
},
},
],
},
],
});
const response = result.response;
expect(response.text()).to.not.be.empty;
});
});
Loading

0 comments on commit abb89ae

Please sign in to comment.