Skip to content

Commit

Permalink
feat: add function calling support
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 594254845
  • Loading branch information
sararob authored and copybara-github committed Jan 2, 2024
1 parent ea0dcb7 commit 3dab6bb
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 25 deletions.
49 changes: 25 additions & 24 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,7 @@ import {
processNonStream,
processStream,
} from './process_stream';
import {
Content,
CountTokensRequest,
CountTokensResponse,
GenerateContentRequest,
GenerateContentResult,
GenerationConfig,
ModelParams,
Part,
SafetySetting,
StreamGenerateContentResult,
VertexInit,
} from './types/content';
import {Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, GenerateContentResult, GenerationConfig, ModelParams, Part, SafetySetting, StreamGenerateContentResult, Tool, VertexInit,} from './types/content';
import {
ClientError,
GoogleAuthError,
Expand Down Expand Up @@ -126,11 +114,8 @@ export class VertexAI_Preview {
}

return new GenerativeModel(
this,
modelParams.model,
modelParams.generation_config,
modelParams.safety_settings
);
this, modelParams.model, modelParams.generation_config,
modelParams.safety_settings, modelParams.tools);
}
}

Expand All @@ -141,6 +126,7 @@ export declare interface StartChatParams {
history?: Content[];
safety_settings?: SafetySetting[];
generation_config?: GenerationConfig;
tools?: Tool[];
stream?: boolean;
}

Expand Down Expand Up @@ -174,6 +160,7 @@ export class ChatSession {
private _send_stream_promise: Promise<void> = Promise.resolve();
generation_config?: GenerationConfig;
safety_settings?: SafetySetting[];
tools?: Tool[];

get history(): Content[] {
return this.historyInternal;
Expand All @@ -200,6 +187,7 @@ export class ChatSession {
contents: this.historyInternal.concat([newContent]),
safety_settings: this.safety_settings,
generation_config: this.generation_config,
tools: this.tools,
};

const generateContentResult: GenerateContentResult =
Expand Down Expand Up @@ -262,6 +250,7 @@ export class ChatSession {
contents: this.historyInternal.concat([newContent]),
safety_settings: this.safety_settings,
generation_config: this.generation_config,
tools: this.tools,
};

const streamGenerateContentResultPromise = this._model_instance
Expand Down Expand Up @@ -289,6 +278,7 @@ export class GenerativeModel {
model: string;
generation_config?: GenerationConfig;
safety_settings?: SafetySetting[];
tools?: Tool[];
private _vertex_instance: VertexAI_Preview;
private _use_non_stream = false;
private publisherModelEndpoint: string;
Expand All @@ -302,15 +292,14 @@ export class GenerativeModel {
* @param {SafetySetting[]} safety_settings - Optional. {@link SafetySetting}
*/
constructor(
vertex_instance: VertexAI_Preview,
model: string,
generation_config?: GenerationConfig,
safety_settings?: SafetySetting[]
) {
vertex_instance: VertexAI_Preview, model: string,
generation_config?: GenerationConfig, safety_settings?: SafetySetting[],
tools?: Tool[]) {
this._vertex_instance = vertex_instance;
this.model = model;
this.generation_config = generation_config;
this.safety_settings = safety_settings;
this.tools = tools;
if (model.startsWith('models/')) {
this.publisherModelEndpoint = `publishers/google/${this.model}`;
} else {
Expand Down Expand Up @@ -349,6 +338,7 @@ export class GenerativeModel {
contents: request.contents,
generation_config: request.generation_config ?? this.generation_config,
safety_settings: request.safety_settings ?? this.safety_settings,
tools: request.tools ?? [],
};

const response: Response | undefined = await postRequest({
Expand Down Expand Up @@ -387,6 +377,7 @@ export class GenerativeModel {
contents: request.contents,
generation_config: request.generation_config ?? this.generation_config,
safety_settings: request.safety_settings ?? this.safety_settings,
tools: request.tools ?? [],
};
const response = await postRequest({
region: this._vertex_instance.location,
Expand Down Expand Up @@ -444,6 +435,7 @@ export class GenerativeModel {
request.generation_config ?? this.generation_config;
startChatRequest.safety_settings =
request.safety_settings ?? this.safety_settings;
startChatRequest.tools = request.tools ?? this.tools;
}
return new ChatSession(startChatRequest);
}
Expand All @@ -464,7 +456,16 @@ function formulateNewContent(request: string | Array<string | Part>): Content {
}
}

const newContent: Content = {role: constants.USER_ROLE, parts: newParts};
let newContent = {} as Content;

// TODO: this assumes all elements of newParts are of the same Part type
if ('functionResponse' in newParts[0]) {
newContent.role = constants.FUNCTION_ROLE;
} else {
newContent.role = constants.USER_ROLE;
}

newContent.parts = newParts;
return newContent;
}

Expand Down
4 changes: 4 additions & 0 deletions src/process_stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ function aggregateResponses(
if (part.text) {
aggregatedResponse.candidates[i].content.parts[0].text += part.text;
}
if (part.functionCall) {
aggregatedResponse.candidates[i].content.parts[0].functionCall =
part.functionCall;
}
}
}
}
Expand Down
42 changes: 41 additions & 1 deletion src/types/content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ export declare interface ModelParams extends BaseModelParams {
export declare interface BaseModelParams {
safety_settings?: SafetySetting[];
generation_config?: GenerationConfig;
tools?: Tool[];
}

/**
Expand Down Expand Up @@ -156,24 +157,38 @@ export interface BasePart {}
export interface TextPart extends BasePart {
text: string;
inline_data?: never;
functionResponse?: never;
functionCall?: never;
}

export interface InlineDataPart extends BasePart {
text?: never;
inline_data: GenerativeContentBlob;
functionResponse?: never;
functionCall?: never;
}

export interface FileData {
mime_type: string;
file_uri: string;
functionCall?: never;
}

export interface FileDataPart extends BasePart {
text?: never;
file_data: FileData;
functionResponse?: never;
functionCall?: never;
}

export declare type Part = TextPart | InlineDataPart | FileDataPart;
export interface FunctionResponsePart extends BasePart {
text?: never;
functionResponse?: FunctionResponse;
functionCall?: FunctionCall;
}

export declare type Part =
TextPart | InlineDataPart | FileDataPart | FunctionResponsePart;

/**
* Raw media bytes sent directly in the request. Text should not be sent as
Expand Down Expand Up @@ -269,6 +284,12 @@ export declare interface GenerateContentCandidate {
finishMessage?: string;
safetyRatings?: SafetyRating[];
citationMetadata?: CitationMetadata;
functionCall?: FunctionCall;
}

export declare interface FunctionCall {
name: string;
args: object;
}

/**
Expand All @@ -287,3 +308,22 @@ export declare interface CitationSource {
uri?: string;
license?: string;
}

export declare interface FunctionResponse {
name: string;
response: object;
}

export declare interface FunctionParameters {
properties: object;
}

export declare interface FunctionDeclaration {
name: string;
parameters: FunctionParameters;
description?: string;
}

export declare interface Tool {
function_declarations: FunctionDeclaration[];
}
1 change: 1 addition & 0 deletions src/util/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export const GENERATE_CONTENT_METHOD = 'generateContent';
export const STREAMING_GENERATE_CONTENT_METHOD = 'streamGenerateContent';
export const USER_ROLE = 'user';
export const MODEL_ROLE = 'model';
export const FUNCTION_ROLE = 'function';
const USER_AGENT_PRODUCT = 'model-builder';
const CLIENT_LIBRARY_LANGUAGE = 'grpc-node/18.0.0';
const CLIENT_LIBRARY_VERSION = '0.1.3';
Expand Down

0 comments on commit 3dab6bb

Please sign in to comment.