Skip to content

Commit

Permalink
feat: Implement cached_content with generateContent methods
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674994964
  • Loading branch information
happy-qiao authored and copybara-github committed Sep 16, 2024
1 parent 3e5e1bf commit b154bd0
Show file tree
Hide file tree
Showing 12 changed files with 793 additions and 49 deletions.
2 changes: 2 additions & 0 deletions src/functions/generate_content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export async function generateContent(
const generateContentRequest: GenerateContentRequest = {
contents: request.contents,
systemInstruction: request.systemInstruction,
cachedContent: request.cachedContent,
generationConfig: request.generationConfig ?? generationConfig,
safetySettings: request.safetySettings ?? safetySettings,
tools: request.tools ?? tools,
Expand Down Expand Up @@ -126,6 +127,7 @@ export async function generateContentStream(
const generateContentRequest: GenerateContentRequest = {
contents: request.contents,
systemInstruction: request.systemInstruction,
cachedContent: request.cachedContent,
generationConfig: request.generationConfig ?? generationConfig,
safetySettings: request.safetySettings ?? safetySettings,
tools: request.tools ?? tools,
Expand Down
8 changes: 7 additions & 1 deletion src/functions/pre_fetch_processing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ export function validateGenerationConfig(
export function getApiVersion(
request: GenerateContentRequest
): 'v1' | 'v1beta1' {
return hasVertexRagStore(request) ? 'v1beta1' : 'v1';
return hasVertexRagStore(request) || hasCachedContent(request)
? 'v1beta1'
: 'v1';
}

export function hasVertexRagStore(request: GenerateContentRequest): boolean {
Expand All @@ -94,6 +96,10 @@ export function hasVertexRagStore(request: GenerateContentRequest): boolean {
return false;
}

function hasCachedContent(request: GenerateContentRequest): boolean {
return !!request.cachedContent;
}

export function hasVertexAISearch(request: GenerateContentRequest): boolean {
for (const tool of request?.tools ?? []) {
const retrieval = (tool as RetrievalTool).retrieval;
Expand Down
32 changes: 32 additions & 0 deletions src/functions/util.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import {Content} from '../types/content';
import {constants} from '../util';

export function formulateSystemInstructionIntoContent(
systemInstruction: string | Content
): Content {
if (typeof systemInstruction === 'string') {
return {
role: constants.SYSTEM_ROLE,
parts: [{text: systemInstruction}],
} as Content;
}
systemInstruction.role = constants.SYSTEM_ROLE;
return systemInstruction;
}
2 changes: 1 addition & 1 deletion src/models/chat_session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
/* tslint:disable */
import {GoogleAuth} from 'google-auth-library';

import {formulateSystemInstructionIntoContent} from './util';
import {formulateSystemInstructionIntoContent} from '../functions/util';
import {
generateContent,
generateContentStream,
Expand Down
36 changes: 29 additions & 7 deletions src/models/generative_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
/* tslint:disable */
import {GoogleAuth} from 'google-auth-library';

import {formulateSystemInstructionIntoContent} from './util';
import {formulateSystemInstructionIntoContent} from '../functions/util';
import {countTokens} from '../functions/count_tokens';
import {
generateContent,
generateContentStream,
} from '../functions/generate_content';
import {
CachedContent,
Content,
CountTokensRequest,
CountTokensResponse,
Expand Down Expand Up @@ -295,6 +296,7 @@ export class GenerativeModelPreview {
private readonly publisherModelEndpoint: string;
private readonly resourcePath: string;
private readonly apiEndpoint?: string;
private readonly cachedContent?: CachedContent;

/**
* @constructor
Expand All @@ -310,6 +312,7 @@ export class GenerativeModelPreview {
this.safetySettings = getGenerativeModelParams.safetySettings;
this.tools = getGenerativeModelParams.tools;
this.toolConfig = getGenerativeModelParams.toolConfig;
this.cachedContent = getGenerativeModelParams.cachedContent;
this.requestOptions = getGenerativeModelParams.requestOptions ?? {};
if (getGenerativeModelParams.systemInstruction) {
this.systemInstruction = formulateSystemInstructionIntoContent(
Expand Down Expand Up @@ -358,11 +361,13 @@ export class GenerativeModelPreview {
request: GenerateContentRequest | string
): Promise<GenerateContentResult> {
request = formulateRequestToGenerateContentRequest(request);
const formulatedRequest =
formulateSystemInstructionIntoGenerateContentRequest(
const formulatedRequest = {
...formulateSystemInstructionIntoGenerateContentRequest(
request,
this.systemInstruction
);
),
cachedContent: this.cachedContent?.name,
};
return generateContent(
this.location,
this.resourcePath,
Expand Down Expand Up @@ -405,11 +410,13 @@ export class GenerativeModelPreview {
request: GenerateContentRequest | string
): Promise<StreamGenerateContentResult> {
request = formulateRequestToGenerateContentRequest(request);
const formulatedRequest =
formulateSystemInstructionIntoGenerateContentRequest(
const formulatedRequest = {
...formulateSystemInstructionIntoGenerateContentRequest(
request,
this.systemInstruction
);
),
cachedContent: this.cachedContent?.name,
};
return generateContentStream(
this.location,
this.resourcePath,
Expand Down Expand Up @@ -486,6 +493,7 @@ export class GenerativeModelPreview {
resourcePath: this.resourcePath,
tools: this.tools,
systemInstruction: this.systemInstruction,
cachedContent: this.cachedContent?.name,
};

if (request) {
Expand All @@ -497,9 +505,23 @@ export class GenerativeModelPreview {
startChatRequest.tools = request.tools ?? this.tools;
startChatRequest.systemInstruction =
request.systemInstruction ?? this.systemInstruction;
startChatRequest.cachedContent =
request.cachedContent ?? this.cachedContent?.name;
}
return new ChatSessionPreview(startChatRequest, this.requestOptions);
}

getModelName(): string {
return this.model;
}

getCachedContent(): CachedContent | undefined {
return this.cachedContent;
}

getSystemInstruction(): Content | undefined {
return this.systemInstruction;
}
}

function formulateResourcePathFromModel(
Expand Down
Loading

0 comments on commit b154bd0

Please sign in to comment.