diff --git a/src/index.ts b/src/index.ts index e06a4e3c..1e54f412 100644 --- a/src/index.ts +++ b/src/index.ts @@ -23,19 +23,7 @@ import { processNonStream, processStream, } from './process_stream'; -import { - Content, - CountTokensRequest, - CountTokensResponse, - GenerateContentRequest, - GenerateContentResult, - GenerationConfig, - ModelParams, - Part, - SafetySetting, - StreamGenerateContentResult, - VertexInit, -} from './types/content'; +import {Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, GenerateContentResult, GenerationConfig, ModelParams, Part, SafetySetting, StreamGenerateContentResult, Tool, VertexInit,} from './types/content'; import { ClientError, GoogleAuthError, @@ -126,11 +114,8 @@ export class VertexAI_Preview { } return new GenerativeModel( - this, - modelParams.model, - modelParams.generation_config, - modelParams.safety_settings - ); + this, modelParams.model, modelParams.generation_config, + modelParams.safety_settings, modelParams.tools); } } @@ -141,6 +126,7 @@ export declare interface StartChatParams { history?: Content[]; safety_settings?: SafetySetting[]; generation_config?: GenerationConfig; + tools?: Tool[]; stream?: boolean; } @@ -174,6 +160,7 @@ export class ChatSession { private _send_stream_promise: Promise = Promise.resolve(); generation_config?: GenerationConfig; safety_settings?: SafetySetting[]; + tools?: Tool[]; get history(): Content[] { return this.historyInternal; @@ -200,6 +187,7 @@ export class ChatSession { contents: this.historyInternal.concat([newContent]), safety_settings: this.safety_settings, generation_config: this.generation_config, + tools: this.tools, }; const generateContentResult: GenerateContentResult = @@ -262,6 +250,7 @@ export class ChatSession { contents: this.historyInternal.concat([newContent]), safety_settings: this.safety_settings, generation_config: this.generation_config, + tools: this.tools, }; const streamGenerateContentResultPromise = this._model_instance @@ -289,6 +278,7 @@ export class GenerativeModel { model: string; generation_config?: GenerationConfig; safety_settings?: SafetySetting[]; + tools?: Tool[]; private _vertex_instance: VertexAI_Preview; private _use_non_stream = false; private publisherModelEndpoint: string; @@ -302,15 +292,14 @@ export class GenerativeModel { * @param {SafetySetting[]} safety_settings - Optional. {@link SafetySetting} */ constructor( - vertex_instance: VertexAI_Preview, - model: string, - generation_config?: GenerationConfig, - safety_settings?: SafetySetting[] - ) { + vertex_instance: VertexAI_Preview, model: string, + generation_config?: GenerationConfig, safety_settings?: SafetySetting[], + tools?: Tool[]) { this._vertex_instance = vertex_instance; this.model = model; this.generation_config = generation_config; this.safety_settings = safety_settings; + this.tools = tools; if (model.startsWith('models/')) { this.publisherModelEndpoint = `publishers/google/${this.model}`; } else { @@ -349,6 +338,7 @@ export class GenerativeModel { contents: request.contents, generation_config: request.generation_config ?? this.generation_config, safety_settings: request.safety_settings ?? this.safety_settings, + tools: request.tools ?? [], }; const response: Response | undefined = await postRequest({ @@ -387,6 +377,7 @@ export class GenerativeModel { contents: request.contents, generation_config: request.generation_config ?? this.generation_config, safety_settings: request.safety_settings ?? this.safety_settings, + tools: request.tools ?? [], }; const response = await postRequest({ region: this._vertex_instance.location, @@ -444,6 +435,7 @@ export class GenerativeModel { request.generation_config ?? this.generation_config; startChatRequest.safety_settings = request.safety_settings ?? this.safety_settings; + startChatRequest.tools = request.tools ?? this.tools; } return new ChatSession(startChatRequest); } @@ -464,7 +456,16 @@ function formulateNewContent(request: string | Array): Content { } } - const newContent: Content = {role: constants.USER_ROLE, parts: newParts}; + let newContent = {} as Content; + + // TODO: this assumes all elements of newParts are of the same Part type + if ('functionResponse' in newParts[0]) { + newContent.role = constants.FUNCTION_ROLE; + } else { + newContent.role = constants.USER_ROLE; + } + + newContent.parts = newParts; return newContent; } diff --git a/src/process_stream.ts b/src/process_stream.ts index 1e62cf78..a762bfdd 100644 --- a/src/process_stream.ts +++ b/src/process_stream.ts @@ -193,6 +193,10 @@ function aggregateResponses( if (part.text) { aggregatedResponse.candidates[i].content.parts[0].text += part.text; } + if (part.functionCall) { + aggregatedResponse.candidates[i].content.parts[0].functionCall = + part.functionCall; + } } } } diff --git a/src/types/content.ts b/src/types/content.ts index 6adb61c2..18d5c659 100644 --- a/src/types/content.ts +++ b/src/types/content.ts @@ -64,6 +64,7 @@ export declare interface ModelParams extends BaseModelParams { export declare interface BaseModelParams { safety_settings?: SafetySetting[]; generation_config?: GenerationConfig; + tools?: Tool[]; } /** @@ -156,24 +157,38 @@ export interface BasePart {} export interface TextPart extends BasePart { text: string; inline_data?: never; + functionResponse?: never; + functionCall?: never; } export interface InlineDataPart extends BasePart { text?: never; inline_data: GenerativeContentBlob; + functionResponse?: never; + functionCall?: never; } export interface FileData { mime_type: string; file_uri: string; + functionCall?: never; } export interface FileDataPart extends BasePart { text?: never; file_data: FileData; + functionResponse?: never; + functionCall?: never; } -export declare type Part = TextPart | InlineDataPart | FileDataPart; +export interface FunctionResponsePart extends BasePart { + text?: never; + functionResponse?: FunctionResponse; + functionCall?: FunctionCall; +} + +export declare type Part = + TextPart | InlineDataPart | FileDataPart | FunctionResponsePart; /** * Raw media bytes sent directly in the request. Text should not be sent as @@ -269,6 +284,12 @@ export declare interface GenerateContentCandidate { finishMessage?: string; safetyRatings?: SafetyRating[]; citationMetadata?: CitationMetadata; + functionCall?: FunctionCall; +} + +export declare interface FunctionCall { + name: string; + args: object; } /** @@ -287,3 +308,22 @@ export declare interface CitationSource { uri?: string; license?: string; } + +export declare interface FunctionResponse { + name: string; + response: object; +} + +export declare interface FunctionParameters { + properties: object; +} + +export declare interface FunctionDeclaration { + name: string; + parameters: FunctionParameters; + description?: string; +} + +export declare interface Tool { + function_declarations: FunctionDeclaration[]; +} \ No newline at end of file diff --git a/src/util/constants.ts b/src/util/constants.ts index 9bff4fa1..4fae89e2 100644 --- a/src/util/constants.ts +++ b/src/util/constants.ts @@ -18,6 +18,7 @@ export const GENERATE_CONTENT_METHOD = 'generateContent'; export const STREAMING_GENERATE_CONTENT_METHOD = 'streamGenerateContent'; export const USER_ROLE = 'user'; export const MODEL_ROLE = 'model'; +export const FUNCTION_ROLE = 'function'; const USER_AGENT_PRODUCT = 'model-builder'; const CLIENT_LIBRARY_LANGUAGE = 'grpc-node/18.0.0'; const CLIENT_LIBRARY_VERSION = '0.1.3';