diff --git a/src/functions/generate_content.ts b/src/functions/generate_content.ts index 42a21200..12a6405f 100644 --- a/src/functions/generate_content.ts +++ b/src/functions/generate_content.ts @@ -25,6 +25,7 @@ import { Tool, } from '../types/content'; import {GoogleGenerativeAIError} from '../types/errors'; +import {ToolConfig} from '../types/tool'; import * as constants from '../util/constants'; import { @@ -55,6 +56,7 @@ export async function generateContent( generationConfig?: GenerationConfig, safetySettings?: SafetySetting[], tools?: Tool[], + toolConfig?: ToolConfig, requestOptions?: RequestOptions ): Promise { request = formatContentRequest(request, generationConfig, safetySettings); @@ -73,6 +75,7 @@ export async function generateContent( generationConfig: request.generationConfig ?? generationConfig, safetySettings: request.safetySettings ?? safetySettings, tools: request.tools ?? tools, + toolConfig: request.toolConfig ?? toolConfig, }; const response: Response | undefined = await postRequest({ region: location, @@ -108,6 +111,7 @@ export async function generateContentStream( generationConfig?: GenerationConfig, safetySettings?: SafetySetting[], tools?: Tool[], + toolConfig?: ToolConfig, requestOptions?: RequestOptions ): Promise { request = formatContentRequest(request, generationConfig, safetySettings); @@ -125,6 +129,7 @@ export async function generateContentStream( generationConfig: request.generationConfig ?? generationConfig, safetySettings: request.safetySettings ?? safetySettings, tools: request.tools ?? tools, + toolConfig: request.toolConfig ?? toolConfig, }; const response = await postRequest({ region: location, diff --git a/src/functions/test/functions_test.ts b/src/functions/test/functions_test.ts index 11803c38..8b9afb58 100644 --- a/src/functions/test/functions_test.ts +++ b/src/functions/test/functions_test.ts @@ -31,6 +31,7 @@ import { SafetySetting, StreamGenerateContentResult, Tool, + ToolConfig, } from '../../types'; import {constants} from '../../util'; import {countTokens} from '../count_tokens'; @@ -193,6 +194,8 @@ const TEST_MULTIPART_MESSAGE_BASE64 = [ const TEST_EMPTY_TOOLS: Tool[] = []; +const TEST_EMPTY_TOOL_CONFIG: ToolConfig = {}; + const TEST_TOOLS_WITH_FUNCTION_DECLARATION: Tool[] = [ { functionDeclarations: [ @@ -366,6 +369,7 @@ describe('generateContent', () => { TEST_GENERATION_CONFIG, TEST_SAFETY_SETTINGS, TEST_EMPTY_TOOLS, + TEST_EMPTY_TOOL_CONFIG, TEST_REQUEST_OPTIONS ) ).toBeRejected(); @@ -684,6 +688,7 @@ describe('generateContentStream', () => { TEST_GENERATION_CONFIG, TEST_SAFETY_SETTINGS, TEST_EMPTY_TOOLS, + TEST_EMPTY_TOOL_CONFIG, TEST_REQUEST_OPTIONS ) ).toBeRejected(); diff --git a/src/models/chat_session.ts b/src/models/chat_session.ts index db2374df..6256ade8 100644 --- a/src/models/chat_session.ts +++ b/src/models/chat_session.ts @@ -35,6 +35,7 @@ import { StreamGenerateContentResult, Tool, } from '../types/content'; +import {ToolConfig} from '../types'; import {ClientError, GoogleAuthError} from '../types/errors'; import {constants} from '../util'; @@ -57,6 +58,7 @@ export class ChatSession { private readonly generationConfig?: GenerationConfig; private readonly safetySettings?: SafetySetting[]; private readonly tools?: Tool[]; + private readonly toolConfig?: ToolConfig; private readonly apiEndpoint?: string; private readonly systemInstruction?: Content; @@ -80,6 +82,7 @@ export class ChatSession { this.generationConfig = request.generationConfig; this.safetySettings = request.safetySettings; this.tools = request.tools; + this.toolConfig = request.toolConfig; this.apiEndpoint = request.apiEndpoint; this.requestOptions = requestOptions ?? {}; if (request.systemInstruction) { @@ -130,6 +133,7 @@ export class ChatSession { safetySettings: this.safetySettings, generationConfig: this.generationConfig, tools: this.tools, + toolConfig: this.toolConfig, systemInstruction: this.systemInstruction, }; @@ -142,6 +146,7 @@ export class ChatSession { this.generationConfig, this.safetySettings, this.tools, + this.toolConfig, this.requestOptions ).catch(e => { throw e; @@ -213,6 +218,7 @@ export class ChatSession { safetySettings: this.safetySettings, generationConfig: this.generationConfig, tools: this.tools, + toolConfig: this.toolConfig, systemInstruction: this.systemInstruction, }; @@ -225,6 +231,7 @@ export class ChatSession { this.generationConfig, this.safetySettings, this.tools, + this.toolConfig, this.requestOptions ).catch(e => { throw e; @@ -261,6 +268,7 @@ export class ChatSessionPreview { private readonly generationConfig?: GenerationConfig; private readonly safetySettings?: SafetySetting[]; private readonly tools?: Tool[]; + private readonly toolConfig?: ToolConfig; private readonly apiEndpoint?: string; private readonly systemInstruction?: Content; @@ -284,6 +292,7 @@ export class ChatSessionPreview { this.generationConfig = request.generationConfig; this.safetySettings = request.safetySettings; this.tools = request.tools; + this.toolConfig = request.toolConfig; this.apiEndpoint = request.apiEndpoint; this.requestOptions = requestOptions ?? {}; if (request.systemInstruction) { @@ -333,6 +342,7 @@ export class ChatSessionPreview { safetySettings: this.safetySettings, generationConfig: this.generationConfig, tools: this.tools, + toolConfig: this.toolConfig, systemInstruction: this.systemInstruction, }; @@ -345,6 +355,7 @@ export class ChatSessionPreview { this.generationConfig, this.safetySettings, this.tools, + this.toolConfig, this.requestOptions ).catch(e => { throw e; @@ -417,6 +428,7 @@ export class ChatSessionPreview { safetySettings: this.safetySettings, generationConfig: this.generationConfig, tools: this.tools, + toolConfig: this.toolConfig, systemInstruction: this.systemInstruction, }; @@ -429,6 +441,7 @@ export class ChatSessionPreview { this.generationConfig, this.safetySettings, this.tools, + this.toolConfig, this.requestOptions ).catch(e => { throw e; diff --git a/src/models/generative_models.ts b/src/models/generative_models.ts index 0ca081ea..b803382c 100644 --- a/src/models/generative_models.ts +++ b/src/models/generative_models.ts @@ -39,6 +39,7 @@ import { StreamGenerateContentResult, Tool, } from '../types/content'; +import {ToolConfig} from '../types/tool'; import {ClientError, GoogleAuthError} from '../types/errors'; import {constants} from '../util'; @@ -55,6 +56,7 @@ export class GenerativeModel { private readonly generationConfig?: GenerationConfig; private readonly safetySettings?: SafetySetting[]; private readonly tools?: Tool[]; + private readonly toolConfig?: ToolConfig; private readonly requestOptions?: RequestOptions; private readonly systemInstruction?: Content; private readonly project: string; @@ -77,6 +79,7 @@ export class GenerativeModel { this.generationConfig = getGenerativeModelParams.generationConfig; this.safetySettings = getGenerativeModelParams.safetySettings; this.tools = getGenerativeModelParams.tools; + this.toolConfig = getGenerativeModelParams.toolConfig; this.requestOptions = getGenerativeModelParams.requestOptions ?? {}; if (getGenerativeModelParams.systemInstruction) { this.systemInstruction = formulateSystemInstructionIntoContent( @@ -140,6 +143,7 @@ export class GenerativeModel { this.generationConfig, this.safetySettings, this.tools, + this.toolConfig, this.requestOptions ); } @@ -186,6 +190,7 @@ export class GenerativeModel { this.generationConfig, this.safetySettings, this.tools, + this.toolConfig, this.requestOptions ); } @@ -251,6 +256,7 @@ export class GenerativeModel { publisherModelEndpoint: this.publisherModelEndpoint, resourcePath: this.resourcePath, tools: this.tools, + toolConfig: this.toolConfig, systemInstruction: this.systemInstruction, }; @@ -280,6 +286,7 @@ export class GenerativeModelPreview { private readonly generationConfig?: GenerationConfig; private readonly safetySettings?: SafetySetting[]; private readonly tools?: Tool[]; + private readonly toolConfig?: ToolConfig; private readonly requestOptions?: RequestOptions; private readonly systemInstruction?: Content; private readonly project: string; @@ -302,6 +309,7 @@ export class GenerativeModelPreview { this.generationConfig = getGenerativeModelParams.generationConfig; this.safetySettings = getGenerativeModelParams.safetySettings; this.tools = getGenerativeModelParams.tools; + this.toolConfig = getGenerativeModelParams.toolConfig; this.requestOptions = getGenerativeModelParams.requestOptions ?? {}; if (getGenerativeModelParams.systemInstruction) { this.systemInstruction = formulateSystemInstructionIntoContent( @@ -364,6 +372,7 @@ export class GenerativeModelPreview { this.generationConfig, this.safetySettings, this.tools, + this.toolConfig, this.requestOptions ); } @@ -410,6 +419,7 @@ export class GenerativeModelPreview { this.generationConfig, this.safetySettings, this.tools, + this.toolConfig, this.requestOptions ); } diff --git a/src/models/test/models_test.ts b/src/models/test/models_test.ts index ab45531a..50616f8c 100644 --- a/src/models/test/models_test.ts +++ b/src/models/test/models_test.ts @@ -837,7 +837,7 @@ describe('GenerativeModel generateContent', () => { ); await modelWithRequestOptions.generateContent(req); // @ts-ignore - expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0); + expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0); }); it('set system instruction in constructor, should send system instruction to functions', async () => { const modelWithSystemInstruction = new GenerativeModel({ @@ -1311,7 +1311,7 @@ describe('GenerativeModelPreview generateContent', () => { ); await modelWithRequestOptions.generateContent(req); // @ts-ignore - expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0); + expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0); }); it('set system instruction in constructor, should send system instruction to functions', async () => { const modelWithSystemInstruction = new GenerativeModelPreview({ @@ -1766,7 +1766,7 @@ describe('GenerativeModel generateContentStream', () => { ); await modelWithRequestOptions.generateContentStream(req); // @ts-ignore - expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0); + expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0); }); it('set system instruction in generateContent, should send system instruction to functions', async () => { const modelWithSystemInstruction = new GenerativeModel({ @@ -2080,7 +2080,7 @@ describe('GenerativeModelPreview generateContentStream', () => { ); await modelWithRequestOptions.generateContentStream(req); // @ts-ignore - expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0); + expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0); }); it('set system instruction in generateContent, should send system instruction to functions', async () => { @@ -2425,7 +2425,7 @@ describe('ChatSession', () => { expect(chatSessionWithRequestOptions.requestOptions).toEqual( TEST_REQUEST_OPTIONS ); - expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0); + expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0); }); it('returns a GenerateContentResponse and appends to history when startChat is passed with no args', async () => { @@ -2674,7 +2674,7 @@ describe('ChatSession', () => { expect(chatSessionWithRequestOptions.requestOptions).toEqual( TEST_REQUEST_OPTIONS ); - expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0); + expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0); }); it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => { @@ -2894,7 +2894,7 @@ describe('ChatSessionPreview', () => { expect(chatSessionWithRequestOptions.requestOptions).toEqual( TEST_REQUEST_OPTIONS ); - expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0); + expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0); }); it('returns a GenerateContentResponse and appends to history when startChat is passed with no args', async () => { @@ -3133,7 +3133,7 @@ describe('ChatSessionPreview', () => { expect(chatSessionWithRequestOptions.requestOptions).toEqual( TEST_REQUEST_OPTIONS ); - expect(generateContentSpy.calls.allArgs()[0][8].timeout).toEqual(0); + expect(generateContentSpy.calls.allArgs()[0][9].timeout).toEqual(0); }); it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => { diff --git a/src/types/content.ts b/src/types/content.ts index 32e881ae..1babe1c7 100644 --- a/src/types/content.ts +++ b/src/types/content.ts @@ -17,6 +17,7 @@ // @ts-nocheck import {GoogleAuth, GoogleAuthOptions} from 'google-auth-library'; +import {ToolConfig} from './tool'; import {SchemaType, Schema} from './common'; /** @@ -114,6 +115,8 @@ export declare interface GetGenerativeModelParams extends ModelParams { safetySettings?: SafetySetting[]; /** Optional. The tools to use for generation. */ tools?: Tool[]; + /** Optional. This config is shared for all tools provided in the request. */ + toolConfig?: ToolConfig; /** Optional. The request options to use for generation. */ requestOptions?: RequestOptions; /** @@ -145,6 +148,8 @@ export declare interface BaseModelParams { generationConfig?: GenerationConfig; /** Optional. Array of {@link Tool}. */ tools?: Tool[]; + /** Optional. This config is shared for all tools provided in the request. */ + toolConfig?: ToolConfig; /** * Optional. The user provided system instructions for the model. * Note: only text should be used in parts of {@link Content} @@ -1007,6 +1012,8 @@ export declare interface StartChatParams { generationConfig?: GenerationConfig; /** Optional. Array of {@link Tool}. */ tools?: Tool[]; + /** Optional. This config is shared for all tools provided in the request. */ + toolConfig?: ToolConfig; /** Optional. The base Vertex AI endpoint to use for the request. */ apiEndpoint?: string; /** diff --git a/src/types/index.ts b/src/types/index.ts index 6ab23a39..86440337 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -17,5 +17,6 @@ export * from './content'; export * from './errors'; +export * from './tool'; export * from './common'; export {GenerateContentResponseHandler} from './generate_content_response_handler'; diff --git a/src/types/tool.ts b/src/types/tool.ts new file mode 100644 index 00000000..4bc1e678 --- /dev/null +++ b/src/types/tool.ts @@ -0,0 +1,57 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** This config is shared for all tools provided in the request. */ +export interface ToolConfig { + /** Function calling config. */ + functionCallingConfig?: FunctionCallingConfig; +} + +/** Function calling mode. */ +export enum FunctionCallingMode { + /** Unspecified function calling mode. This value should not be used. */ + MODE_UNSPECIFIED = 'MODE_UNSPECIFIED', + /** + * Default model behavior, model decides to predict either function calls + * or natural language response. + */ + AUTO = 'AUTO', + /** + * Model is constrained to always predicting function calls only. + * If "allowedFunctionNames" are set, the predicted function calls will be + * limited to any one of "allowedFunctionNames", else the predicted + * function calls will be any one of the provided "function_declarations". + */ + ANY = 'ANY', + /** + * Model will not predict any function calls. Model behavior is same as when + * not passing any function declarations. + */ + NONE = 'NONE', +} + +export interface FunctionCallingConfig { + /** Optional. Function calling mode. */ + mode?: FunctionCallingMode; + + /** + * Optional. Function names to call. Only set when the Mode is ANY. Function + * names should match [FunctionDeclaration.name]. With mode set to ANY, model + * will predict a function call from the set of function names provided. + */ + allowedFunctionNames?: string[]; +} diff --git a/src/vertex_ai.ts b/src/vertex_ai.ts index 1d25f472..27194ee0 100644 --- a/src/vertex_ai.ts +++ b/src/vertex_ai.ts @@ -123,6 +123,7 @@ export class VertexAI { safetySettings: modelParams.safetySettings, generationConfig: modelParams.generationConfig, tools: modelParams.tools, + toolConfig: modelParams.toolConfig, requestOptions: requestOptions, systemInstruction: modelParams.systemInstruction, }; @@ -193,6 +194,7 @@ class VertexAIPreview { safetySettings: modelParams.safetySettings, generationConfig: modelParams.generationConfig, tools: modelParams.tools, + toolConfig: modelParams.toolConfig, requestOptions: requestOptions, systemInstruction: modelParams.systemInstruction, };