Skip to content

Commit

Permalink
feat: Add tool config
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 667737781
  • Loading branch information
happy-qiao authored and copybara-github committed Aug 26, 2024
1 parent 9cdbe94 commit f618132
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 8 deletions.
5 changes: 5 additions & 0 deletions src/functions/generate_content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
Tool,
} from '../types/content';
import {GoogleGenerativeAIError} from '../types/errors';
import {ToolConfig} from '../types/tool';
import * as constants from '../util/constants';

import {
Expand Down Expand Up @@ -55,6 +56,7 @@ export async function generateContent(
generationConfig?: GenerationConfig,
safetySettings?: SafetySetting[],
tools?: Tool[],
toolConfig?: ToolConfig,
requestOptions?: RequestOptions
): Promise<GenerateContentResult> {
request = formatContentRequest(request, generationConfig, safetySettings);
Expand All @@ -73,6 +75,7 @@ export async function generateContent(
generationConfig: request.generationConfig ?? generationConfig,
safetySettings: request.safetySettings ?? safetySettings,
tools: request.tools ?? tools,
toolConfig: request.toolConfig ?? toolConfig,
};
const response: Response | undefined = await postRequest({
region: location,
Expand Down Expand Up @@ -108,6 +111,7 @@ export async function generateContentStream(
generationConfig?: GenerationConfig,
safetySettings?: SafetySetting[],
tools?: Tool[],
toolConfig?: ToolConfig,
requestOptions?: RequestOptions
): Promise<StreamGenerateContentResult> {
request = formatContentRequest(request, generationConfig, safetySettings);
Expand All @@ -125,6 +129,7 @@ export async function generateContentStream(
generationConfig: request.generationConfig ?? generationConfig,
safetySettings: request.safetySettings ?? safetySettings,
tools: request.tools ?? tools,
toolConfig: request.toolConfig ?? toolConfig,
};
const response = await postRequest({
region: location,
Expand Down
5 changes: 5 additions & 0 deletions src/functions/test/functions_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import {
SafetySetting,
StreamGenerateContentResult,
Tool,
ToolConfig,
} from '../../types';
import {constants} from '../../util';
import {countTokens} from '../count_tokens';
Expand Down Expand Up @@ -193,6 +194,8 @@ const TEST_MULTIPART_MESSAGE_BASE64 = [

const TEST_EMPTY_TOOLS: Tool[] = [];

const TEST_EMPTY_TOOL_CONFIG: ToolConfig = {};

const TEST_TOOLS_WITH_FUNCTION_DECLARATION: Tool[] = [
{
functionDeclarations: [
Expand Down Expand Up @@ -366,6 +369,7 @@ describe('generateContent', () => {
TEST_GENERATION_CONFIG,
TEST_SAFETY_SETTINGS,
TEST_EMPTY_TOOLS,
TEST_EMPTY_TOOL_CONFIG,
TEST_REQUEST_OPTIONS
)
).toBeRejected();
Expand Down Expand Up @@ -684,6 +688,7 @@ describe('generateContentStream', () => {
TEST_GENERATION_CONFIG,
TEST_SAFETY_SETTINGS,
TEST_EMPTY_TOOLS,
TEST_EMPTY_TOOL_CONFIG,
TEST_REQUEST_OPTIONS
)
).toBeRejected();
Expand Down
13 changes: 13 additions & 0 deletions src/models/chat_session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import {
StreamGenerateContentResult,
Tool,
} from '../types/content';
import {ToolConfig} from '../types';
import {ClientError, GoogleAuthError} from '../types/errors';
import {constants} from '../util';

Expand All @@ -57,6 +58,7 @@ export class ChatSession {
private readonly generationConfig?: GenerationConfig;
private readonly safetySettings?: SafetySetting[];
private readonly tools?: Tool[];
private readonly toolConfig?: ToolConfig;
private readonly apiEndpoint?: string;
private readonly systemInstruction?: Content;

Expand All @@ -80,6 +82,7 @@ export class ChatSession {
this.generationConfig = request.generationConfig;
this.safetySettings = request.safetySettings;
this.tools = request.tools;
this.toolConfig = request.toolConfig;
this.apiEndpoint = request.apiEndpoint;
this.requestOptions = requestOptions ?? {};
if (request.systemInstruction) {
Expand Down Expand Up @@ -130,6 +133,7 @@ export class ChatSession {
safetySettings: this.safetySettings,
generationConfig: this.generationConfig,
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
};

Expand All @@ -142,6 +146,7 @@ export class ChatSession {
this.generationConfig,
this.safetySettings,
this.tools,
this.toolConfig,
this.requestOptions
).catch(e => {
throw e;
Expand Down Expand Up @@ -213,6 +218,7 @@ export class ChatSession {
safetySettings: this.safetySettings,
generationConfig: this.generationConfig,
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
};

Expand All @@ -225,6 +231,7 @@ export class ChatSession {
this.generationConfig,
this.safetySettings,
this.tools,
this.toolConfig,
this.requestOptions
).catch(e => {
throw e;
Expand Down Expand Up @@ -261,6 +268,7 @@ export class ChatSessionPreview {
private readonly generationConfig?: GenerationConfig;
private readonly safetySettings?: SafetySetting[];
private readonly tools?: Tool[];
private readonly toolConfig?: ToolConfig;
private readonly apiEndpoint?: string;
private readonly systemInstruction?: Content;

Expand All @@ -284,6 +292,7 @@ export class ChatSessionPreview {
this.generationConfig = request.generationConfig;
this.safetySettings = request.safetySettings;
this.tools = request.tools;
this.toolConfig = request.toolConfig;
this.apiEndpoint = request.apiEndpoint;
this.requestOptions = requestOptions ?? {};
if (request.systemInstruction) {
Expand Down Expand Up @@ -333,6 +342,7 @@ export class ChatSessionPreview {
safetySettings: this.safetySettings,
generationConfig: this.generationConfig,
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
};

Expand All @@ -345,6 +355,7 @@ export class ChatSessionPreview {
this.generationConfig,
this.safetySettings,
this.tools,
this.toolConfig,
this.requestOptions
).catch(e => {
throw e;
Expand Down Expand Up @@ -417,6 +428,7 @@ export class ChatSessionPreview {
safetySettings: this.safetySettings,
generationConfig: this.generationConfig,
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
};

Expand All @@ -429,6 +441,7 @@ export class ChatSessionPreview {
this.generationConfig,
this.safetySettings,
this.tools,
this.toolConfig,
this.requestOptions
).catch(e => {
throw e;
Expand Down
10 changes: 10 additions & 0 deletions src/models/generative_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import {
StreamGenerateContentResult,
Tool,
} from '../types/content';
import {ToolConfig} from '../types/tool';
import {ClientError, GoogleAuthError} from '../types/errors';
import {constants} from '../util';

Expand All @@ -55,6 +56,7 @@ export class GenerativeModel {
private readonly generationConfig?: GenerationConfig;
private readonly safetySettings?: SafetySetting[];
private readonly tools?: Tool[];
private readonly toolConfig?: ToolConfig;
private readonly requestOptions?: RequestOptions;
private readonly systemInstruction?: Content;
private readonly project: string;
Expand All @@ -77,6 +79,7 @@ export class GenerativeModel {
this.generationConfig = getGenerativeModelParams.generationConfig;
this.safetySettings = getGenerativeModelParams.safetySettings;
this.tools = getGenerativeModelParams.tools;
this.toolConfig = getGenerativeModelParams.toolConfig;
this.requestOptions = getGenerativeModelParams.requestOptions ?? {};
if (getGenerativeModelParams.systemInstruction) {
this.systemInstruction = formulateSystemInstructionIntoContent(
Expand Down Expand Up @@ -140,6 +143,7 @@ export class GenerativeModel {
this.generationConfig,
this.safetySettings,
this.tools,
this.toolConfig,
this.requestOptions
);
}
Expand Down Expand Up @@ -186,6 +190,7 @@ export class GenerativeModel {
this.generationConfig,
this.safetySettings,
this.tools,
this.toolConfig,
this.requestOptions
);
}
Expand Down Expand Up @@ -251,6 +256,7 @@ export class GenerativeModel {
publisherModelEndpoint: this.publisherModelEndpoint,
resourcePath: this.resourcePath,
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
};

Expand Down Expand Up @@ -280,6 +286,7 @@ export class GenerativeModelPreview {
private readonly generationConfig?: GenerationConfig;
private readonly safetySettings?: SafetySetting[];
private readonly tools?: Tool[];
private readonly toolConfig?: ToolConfig;
private readonly requestOptions?: RequestOptions;
private readonly systemInstruction?: Content;
private readonly project: string;
Expand All @@ -302,6 +309,7 @@ export class GenerativeModelPreview {
this.generationConfig = getGenerativeModelParams.generationConfig;
this.safetySettings = getGenerativeModelParams.safetySettings;
this.tools = getGenerativeModelParams.tools;
this.toolConfig = getGenerativeModelParams.toolConfig;
this.requestOptions = getGenerativeModelParams.requestOptions ?? {};
if (getGenerativeModelParams.systemInstruction) {
this.systemInstruction = formulateSystemInstructionIntoContent(
Expand Down Expand Up @@ -364,6 +372,7 @@ export class GenerativeModelPreview {
this.generationConfig,
this.safetySettings,
this.tools,
this.toolConfig,
this.requestOptions
);
}
Expand Down Expand Up @@ -410,6 +419,7 @@ export class GenerativeModelPreview {
this.generationConfig,
this.safetySettings,
this.tools,
this.toolConfig,
this.requestOptions
);
}
Expand Down
16 changes: 8 additions & 8 deletions src/models/test/models_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ describe('GenerativeModel generateContent', () => {
);
await modelWithRequestOptions.generateContent(req);
// @ts-ignore
expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0);
expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0);
});
it('set system instruction in constructor, should send system instruction to functions', async () => {
const modelWithSystemInstruction = new GenerativeModel({
Expand Down Expand Up @@ -1311,7 +1311,7 @@ describe('GenerativeModelPreview generateContent', () => {
);
await modelWithRequestOptions.generateContent(req);
// @ts-ignore
expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0);
expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0);
});
it('set system instruction in constructor, should send system instruction to functions', async () => {
const modelWithSystemInstruction = new GenerativeModelPreview({
Expand Down Expand Up @@ -1766,7 +1766,7 @@ describe('GenerativeModel generateContentStream', () => {
);
await modelWithRequestOptions.generateContentStream(req);
// @ts-ignore
expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0);
expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0);
});
it('set system instruction in generateContent, should send system instruction to functions', async () => {
const modelWithSystemInstruction = new GenerativeModel({
Expand Down Expand Up @@ -2080,7 +2080,7 @@ describe('GenerativeModelPreview generateContentStream', () => {
);
await modelWithRequestOptions.generateContentStream(req);
// @ts-ignore
expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0);
expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0);
});

it('set system instruction in generateContent, should send system instruction to functions', async () => {
Expand Down Expand Up @@ -2425,7 +2425,7 @@ describe('ChatSession', () => {
expect(chatSessionWithRequestOptions.requestOptions).toEqual(
TEST_REQUEST_OPTIONS
);
expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0);
expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0);
});

it('returns a GenerateContentResponse and appends to history when startChat is passed with no args', async () => {
Expand Down Expand Up @@ -2674,7 +2674,7 @@ describe('ChatSession', () => {
expect(chatSessionWithRequestOptions.requestOptions).toEqual(
TEST_REQUEST_OPTIONS
);
expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0);
expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0);
});

it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => {
Expand Down Expand Up @@ -2894,7 +2894,7 @@ describe('ChatSessionPreview', () => {
expect(chatSessionWithRequestOptions.requestOptions).toEqual(
TEST_REQUEST_OPTIONS
);
expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0);
expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0);
});

it('returns a GenerateContentResponse and appends to history when startChat is passed with no args', async () => {
Expand Down Expand Up @@ -3133,7 +3133,7 @@ describe('ChatSessionPreview', () => {
expect(chatSessionWithRequestOptions.requestOptions).toEqual(
TEST_REQUEST_OPTIONS
);
expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0);
expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0);
});

it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => {
Expand Down
7 changes: 7 additions & 0 deletions src/types/content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

// @ts-nocheck
import {GoogleAuth, GoogleAuthOptions} from 'google-auth-library';
import {ToolConfig} from './tool';
import {SchemaType, Schema} from './common';

/**
Expand Down Expand Up @@ -114,6 +115,8 @@ export declare interface GetGenerativeModelParams extends ModelParams {
safetySettings?: SafetySetting[];
/** Optional. The tools to use for generation. */
tools?: Tool[];
/** Optional. This config is shared for all tools provided in the request. */
toolConfig?: ToolConfig;
/** Optional. The request options to use for generation. */
requestOptions?: RequestOptions;
/**
Expand Down Expand Up @@ -145,6 +148,8 @@ export declare interface BaseModelParams {
generationConfig?: GenerationConfig;
/** Optional. Array of {@link Tool}. */
tools?: Tool[];
/** Optional. This config is shared for all tools provided in the request. */
toolConfig?: ToolConfig;
/**
* Optional. The user provided system instructions for the model.
* Note: only text should be used in parts of {@link Content}
Expand Down Expand Up @@ -1007,6 +1012,8 @@ export declare interface StartChatParams {
generationConfig?: GenerationConfig;
/** Optional. Array of {@link Tool}. */
tools?: Tool[];
/** Optional. This config is shared for all tools provided in the request. */
toolConfig?: ToolConfig;
/** Optional. The base Vertex AI endpoint to use for the request. */
apiEndpoint?: string;
/**
Expand Down
1 change: 1 addition & 0 deletions src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@

export * from './content';
export * from './errors';
export * from './tool';
export * from './common';
export {GenerateContentResponseHandler} from './generate_content_response_handler';
Loading

0 comments on commit f618132

Please sign in to comment.