From b154bd00d663583481e17d4343d6fefd7cb77035 Mon Sep 17 00:00:00 2001 From: Qiao Wang Date: Sun, 15 Sep 2024 21:12:53 -0700 Subject: [PATCH] feat: Implement cached_content with generateContent methods PiperOrigin-RevId: 674994964 --- src/functions/generate_content.ts | 2 + src/functions/pre_fetch_processing.ts | 8 +- src/functions/util.ts | 32 +++ src/models/chat_session.ts | 2 +- src/models/generative_models.ts | 36 +++- src/models/test/generative_models_test.ts | 148 +++++++++---- src/models/test/models_test.ts | 6 + src/resources/cached_contents.ts | 240 ++++++++++++++++++++++ src/resources/index.ts | 19 ++ src/resources/shared/api_client.ts | 142 +++++++++++++ src/types/content.ts | 105 ++++++++++ src/vertex_ai.ts | 102 ++++++++- 12 files changed, 793 insertions(+), 49 deletions(-) create mode 100644 src/functions/util.ts create mode 100644 src/resources/cached_contents.ts create mode 100644 src/resources/index.ts create mode 100644 src/resources/shared/api_client.ts diff --git a/src/functions/generate_content.ts b/src/functions/generate_content.ts index 12a6405f..6b820f2d 100644 --- a/src/functions/generate_content.ts +++ b/src/functions/generate_content.ts @@ -72,6 +72,7 @@ export async function generateContent( const generateContentRequest: GenerateContentRequest = { contents: request.contents, systemInstruction: request.systemInstruction, + cachedContent: request.cachedContent, generationConfig: request.generationConfig ?? generationConfig, safetySettings: request.safetySettings ?? safetySettings, tools: request.tools ?? tools, @@ -126,6 +127,7 @@ export async function generateContentStream( const generateContentRequest: GenerateContentRequest = { contents: request.contents, systemInstruction: request.systemInstruction, + cachedContent: request.cachedContent, generationConfig: request.generationConfig ?? generationConfig, safetySettings: request.safetySettings ?? safetySettings, tools: request.tools ?? tools, diff --git a/src/functions/pre_fetch_processing.ts b/src/functions/pre_fetch_processing.ts index 9330729e..4ad85184 100644 --- a/src/functions/pre_fetch_processing.ts +++ b/src/functions/pre_fetch_processing.ts @@ -80,7 +80,9 @@ export function validateGenerationConfig( export function getApiVersion( request: GenerateContentRequest ): 'v1' | 'v1beta1' { - return hasVertexRagStore(request) ? 'v1beta1' : 'v1'; + return hasVertexRagStore(request) || hasCachedContent(request) + ? 'v1beta1' + : 'v1'; } export function hasVertexRagStore(request: GenerateContentRequest): boolean { @@ -94,6 +96,10 @@ export function hasVertexRagStore(request: GenerateContentRequest): boolean { return false; } +function hasCachedContent(request: GenerateContentRequest): boolean { + return !!request.cachedContent; +} + export function hasVertexAISearch(request: GenerateContentRequest): boolean { for (const tool of request?.tools ?? []) { const retrieval = (tool as RetrievalTool).retrieval; diff --git a/src/functions/util.ts b/src/functions/util.ts new file mode 100644 index 00000000..1e353b5a --- /dev/null +++ b/src/functions/util.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Content} from '../types/content'; +import {constants} from '../util'; + +export function formulateSystemInstructionIntoContent( + systemInstruction: string | Content +): Content { + if (typeof systemInstruction === 'string') { + return { + role: constants.SYSTEM_ROLE, + parts: [{text: systemInstruction}], + } as Content; + } + systemInstruction.role = constants.SYSTEM_ROLE; + return systemInstruction; +} diff --git a/src/models/chat_session.ts b/src/models/chat_session.ts index 6256ade8..0c015024 100644 --- a/src/models/chat_session.ts +++ b/src/models/chat_session.ts @@ -18,7 +18,7 @@ /* tslint:disable */ import {GoogleAuth} from 'google-auth-library'; -import {formulateSystemInstructionIntoContent} from './util'; +import {formulateSystemInstructionIntoContent} from '../functions/util'; import { generateContent, generateContentStream, diff --git a/src/models/generative_models.ts b/src/models/generative_models.ts index b803382c..17a75cf2 100644 --- a/src/models/generative_models.ts +++ b/src/models/generative_models.ts @@ -18,13 +18,14 @@ /* tslint:disable */ import {GoogleAuth} from 'google-auth-library'; -import {formulateSystemInstructionIntoContent} from './util'; +import {formulateSystemInstructionIntoContent} from '../functions/util'; import {countTokens} from '../functions/count_tokens'; import { generateContent, generateContentStream, } from '../functions/generate_content'; import { + CachedContent, Content, CountTokensRequest, CountTokensResponse, @@ -295,6 +296,7 @@ export class GenerativeModelPreview { private readonly publisherModelEndpoint: string; private readonly resourcePath: string; private readonly apiEndpoint?: string; + private readonly cachedContent?: CachedContent; /** * @constructor @@ -310,6 +312,7 @@ export class GenerativeModelPreview { this.safetySettings = getGenerativeModelParams.safetySettings; this.tools = getGenerativeModelParams.tools; this.toolConfig = getGenerativeModelParams.toolConfig; + this.cachedContent = getGenerativeModelParams.cachedContent; this.requestOptions = getGenerativeModelParams.requestOptions ?? {}; if (getGenerativeModelParams.systemInstruction) { this.systemInstruction = formulateSystemInstructionIntoContent( @@ -358,11 +361,13 @@ export class GenerativeModelPreview { request: GenerateContentRequest | string ): Promise { request = formulateRequestToGenerateContentRequest(request); - const formulatedRequest = - formulateSystemInstructionIntoGenerateContentRequest( + const formulatedRequest = { + ...formulateSystemInstructionIntoGenerateContentRequest( request, this.systemInstruction - ); + ), + cachedContent: this.cachedContent?.name, + }; return generateContent( this.location, this.resourcePath, @@ -405,11 +410,13 @@ export class GenerativeModelPreview { request: GenerateContentRequest | string ): Promise { request = formulateRequestToGenerateContentRequest(request); - const formulatedRequest = - formulateSystemInstructionIntoGenerateContentRequest( + const formulatedRequest = { + ...formulateSystemInstructionIntoGenerateContentRequest( request, this.systemInstruction - ); + ), + cachedContent: this.cachedContent?.name, + }; return generateContentStream( this.location, this.resourcePath, @@ -486,6 +493,7 @@ export class GenerativeModelPreview { resourcePath: this.resourcePath, tools: this.tools, systemInstruction: this.systemInstruction, + cachedContent: this.cachedContent?.name, }; if (request) { @@ -497,9 +505,23 @@ export class GenerativeModelPreview { startChatRequest.tools = request.tools ?? this.tools; startChatRequest.systemInstruction = request.systemInstruction ?? this.systemInstruction; + startChatRequest.cachedContent = + request.cachedContent ?? this.cachedContent?.name; } return new ChatSessionPreview(startChatRequest, this.requestOptions); } + + getModelName(): string { + return this.model; + } + + getCachedContent(): CachedContent | undefined { + return this.cachedContent; + } + + getSystemInstruction(): Content | undefined { + return this.systemInstruction; + } } function formulateResourcePathFromModel( diff --git a/src/models/test/generative_models_test.ts b/src/models/test/generative_models_test.ts index 7a524389..20edf0fc 100644 --- a/src/models/test/generative_models_test.ts +++ b/src/models/test/generative_models_test.ts @@ -107,16 +107,25 @@ const BASE_MODEL_PARAMS = { auth: FAKE_GOOGLE_AUTH, }; -describe('GenerativeModel', () => { +describe('', () => { const modelTestCases = [ - (param: GetGenerativeModelParams) => new GenerativeModel(param), - (param: GetGenerativeModelParams) => new GenerativeModelPreview(param), + { + createModel: (param: GetGenerativeModelParams) => + new GenerativeModel(param), + isPreviewModel: false, + }, + { + createModel: (param: GetGenerativeModelParams) => + new GenerativeModelPreview(param), + isPreviewModel: true, + }, ]; describe('generate method should call internal function', () => { const testCases = [ { name: 'when passed a string prompt', + previewOnly: false, modelParams: { ...BASE_MODEL_PARAMS, model: MODEL_NAME, @@ -127,7 +136,7 @@ describe('GenerativeModel', () => { location: LOCATION, resourcePath: RESOURCE_PATH, token: jasmine.any(Promise), - request: TEST_USER_CONTENT_MESSAGE, + request: jasmine.objectContaining(TEST_USER_CONTENT_MESSAGE), apiEndpoint: undefined, generationConfig: undefined, safetySettings: undefined, @@ -138,6 +147,7 @@ describe('GenerativeModel', () => { }, { name: 'when passed a object prompt', + previewOnly: false, modelParams: { ...BASE_MODEL_PARAMS, model: MODEL_NAME, @@ -148,7 +158,7 @@ describe('GenerativeModel', () => { location: LOCATION, resourcePath: RESOURCE_PATH, token: jasmine.any(Promise), - request: TEST_USER_CONTENT_MESSAGE, + request: jasmine.objectContaining(TEST_USER_CONTENT_MESSAGE), apiEndpoint: undefined, generationConfig: undefined, safetySettings: undefined, @@ -159,6 +169,7 @@ describe('GenerativeModel', () => { }, { name: 'when the model name has `models` prefix', + previewOnly: false, modelParams: { ...BASE_MODEL_PARAMS, model: 'models/model-name', @@ -169,7 +180,7 @@ describe('GenerativeModel', () => { location: LOCATION, resourcePath: RESOURCE_PATH, token: jasmine.any(Promise), - request: TEST_USER_CONTENT_MESSAGE, + request: jasmine.objectContaining(TEST_USER_CONTENT_MESSAGE), apiEndpoint: undefined, generationConfig: undefined, safetySettings: undefined, @@ -180,6 +191,7 @@ describe('GenerativeModel', () => { }, { name: 'when the model name has `project` prefix', + previewOnly: false, modelParams: { ...BASE_MODEL_PARAMS, model: @@ -192,7 +204,7 @@ describe('GenerativeModel', () => { resourcePath: 'projects/my-project/locations/my-location/models/my-tuned-model', token: jasmine.any(Promise), - request: TEST_USER_CONTENT_MESSAGE, + request: jasmine.objectContaining(TEST_USER_CONTENT_MESSAGE), apiEndpoint: undefined, generationConfig: undefined, safetySettings: undefined, @@ -203,6 +215,7 @@ describe('GenerativeModel', () => { }, { name: 'when pass params at model constructor level', + previewOnly: false, modelParams: { ...BASE_MODEL_PARAMS, model: MODEL_NAME, @@ -220,10 +233,44 @@ describe('GenerativeModel', () => { location: LOCATION, resourcePath: RESOURCE_PATH, token: jasmine.any(Promise), - request: { + request: jasmine.objectContaining({ systemInstruction: TEST_SYSTEM_INSTRUCTION, ...TEST_USER_CONTENT_MESSAGE, - }, + }), + apiEndpoint: TEST_ENDPOINT_BASE_PATH, + generationConfig: TEST_GENERATION_CONFIG, + safetySettings: TEST_SAFETY_SETTINGS, + tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, + toolConfig: TEST_TOOLS_CONFIG, + requestOptions: TEST_REQUEST_OPTIONS, + }), + }, + { + name: 'when pass params at model constructor level', + previewOnly: true, + modelParams: { + ...BASE_MODEL_PARAMS, + model: MODEL_NAME, + googleAuth: FAKE_GOOGLE_AUTH, + apiEndpoint: TEST_ENDPOINT_BASE_PATH, + generationConfig: TEST_GENERATION_CONFIG, + systemInstruction: TEST_SYSTEM_INSTRUCTION, + tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, + toolConfig: TEST_TOOLS_CONFIG, + safetySettings: TEST_SAFETY_SETTINGS, + requestOptions: TEST_REQUEST_OPTIONS, + cachedContent: {name: 'cachedContentName'}, + }, + generateContentParams: TEST_CHAT_MESSSAGE_TEXT, + expectedParams: Object.values({ + location: LOCATION, + resourcePath: RESOURCE_PATH, + token: jasmine.any(Promise), + request: jasmine.objectContaining({ + systemInstruction: TEST_SYSTEM_INSTRUCTION, + ...TEST_USER_CONTENT_MESSAGE, + cachedContent: 'cachedContentName', + }), apiEndpoint: TEST_ENDPOINT_BASE_PATH, generationConfig: TEST_GENERATION_CONFIG, safetySettings: TEST_SAFETY_SETTINGS, @@ -234,6 +281,7 @@ describe('GenerativeModel', () => { }, { name: 'when pass params at model method level', + previewOnly: false, modelParams: { ...BASE_MODEL_PARAMS, model: MODEL_NAME, @@ -258,14 +306,14 @@ describe('GenerativeModel', () => { location: LOCATION, resourcePath: RESOURCE_PATH, token: jasmine.any(Promise), - request: { + request: jasmine.objectContaining({ ...TEST_USER_CONTENT_MESSAGE, systemInstruction: TEST_SYSTEM_INSTRUCTION, generationConfig: TEST_GENERATION_CONFIG, tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, toolConfig: TEST_TOOLS_CONFIG, safetySettings: TEST_SAFETY_SETTINGS, - }, + }), apiEndpoint: TEST_ENDPOINT_BASE_PATH, generationConfig: undefined, tools: undefined, @@ -276,6 +324,7 @@ describe('GenerativeModel', () => { }, { name: 'when set system instruction(wrong role) in model constructor', + previewOnly: false, modelParams: { ...BASE_MODEL_PARAMS, model: MODEL_NAME, @@ -287,10 +336,10 @@ describe('GenerativeModel', () => { location: LOCATION, resourcePath: RESOURCE_PATH, token: jasmine.any(Promise), - request: { + request: jasmine.objectContaining({ systemInstruction: TEST_SYSTEM_INSTRUCTION_WRONG_ROLE, ...TEST_USER_CONTENT_MESSAGE, - }, + }), apiEndpoint: undefined, generationConfig: undefined, safetySettings: undefined, @@ -301,6 +350,7 @@ describe('GenerativeModel', () => { }, { name: 'when set system instruction(wrong role) in model method level', + previewOnly: false, modelParams: { ...BASE_MODEL_PARAMS, model: MODEL_NAME, @@ -314,10 +364,10 @@ describe('GenerativeModel', () => { location: LOCATION, resourcePath: RESOURCE_PATH, token: jasmine.any(Promise), - request: { + request: jasmine.objectContaining({ ...TEST_USER_CONTENT_MESSAGE, systemInstruction: TEST_SYSTEM_INSTRUCTION_WRONG_ROLE, - }, + }), apiEndpoint: undefined, generationConfig: undefined, safetySettings: undefined, @@ -326,15 +376,22 @@ describe('GenerativeModel', () => { requestOptions: {}, }), }, - ].flatMap(testCase => - modelTestCases.map(createModel => ({ - createModel, - ...testCase, - })) - ); + ] + .flatMap(testCase => + modelTestCases.map((modelTestCase: any) => ({ + createModel: modelTestCase.createModel, + isPreviewModel: modelTestCase.isPreviewModel, + ...testCase, + })) + ) + .filter( + (testCase: any) => + !testCase.previewOnly || + (testCase.previewOnly && testCase.isPreviewModel) + ); testCases.forEach((testCase: any) => { - it(`${testCase.name} when call generateContent`, async () => { + it(`${testCase.name} when call generateContent (isPreviewModel=${testCase.isPreviewModel})`, async () => { const model = testCase.createModel(testCase.modelParams); const generateContentSpy: jasmine.Spy = spyOn( GenerateContentFunctions, @@ -348,20 +405,26 @@ describe('GenerativeModel', () => { }); }); - testCases.forEach((testCase: any) => { - it(`${testCase.name} when call generateContentStream`, async () => { - const model = testCase.createModel(testCase.modelParams); - const generateContentSpy: jasmine.Spy = spyOn( - GenerateContentFunctions, - 'generateContentStream' - ); + testCases + .filter( + (testCase: any) => + !testCase.previewOnly || + (testCase.previewOnly && testCase.isPreviewModel) + ) + .forEach((testCase: any) => { + it(`${testCase.name} when call generateContentStream (isPreviewModel=${testCase.isPreviewModel})`, async () => { + const model = testCase.createModel(testCase.modelParams); + const generateContentSpy: jasmine.Spy = spyOn( + GenerateContentFunctions, + 'generateContentStream' + ); - await model.generateContentStream(testCase.generateContentParams); + await model.generateContentStream(testCase.generateContentParams); - const expectedParams = testCase.expectedParams; - expect(generateContentSpy).toHaveBeenCalledWith(...expectedParams); + const expectedParams = testCase.expectedParams; + expect(generateContentSpy).toHaveBeenCalledWith(...expectedParams); + }); }); - }); }); describe('countTokens method should call internal function', () => { @@ -401,12 +464,19 @@ describe('GenerativeModel', () => { requestOptions: TEST_REQUEST_OPTIONS, }), }, - ].flatMap(testCase => - modelTestCases.map(createModel => ({ - createModel, - ...testCase, - })) - ); + ] + .flatMap(testCase => + modelTestCases.map((modelTestCase: any) => ({ + createModel: modelTestCase.createModel, + isPreviewModel: modelTestCase.isPreviewModel, + ...testCase, + })) + ) + .filter( + (testCase: any) => + !testCase.previewOnly || + (testCase.previewOnly && testCase.isPreviewModel) + ); testCases.forEach((testCase: any) => { it(`${testCase.name} when call countTokens`, async () => { const model = testCase.createModel(testCase.modelParams); diff --git a/src/models/test/models_test.ts b/src/models/test/models_test.ts index 50616f8c..e8185b59 100644 --- a/src/models/test/models_test.ts +++ b/src/models/test/models_test.ts @@ -1347,6 +1347,7 @@ describe('GenerativeModelPreview generateContent', () => { }, ], }, + cachedContent: undefined, }; await modelWithSystemInstruction.generateContent(req); // @ts-ignore @@ -1386,6 +1387,7 @@ describe('GenerativeModelPreview generateContent', () => { }, ], }, + cachedContent: undefined, }; await modelWithSystemInstruction.generateContent(req); // @ts-ignore @@ -1425,6 +1427,7 @@ describe('GenerativeModelPreview generateContent', () => { }, ], }, + cachedContent: undefined, }; await modelWithSystemInstruction.generateContent(req); // @ts-ignore @@ -1464,6 +1467,7 @@ describe('GenerativeModelPreview generateContent', () => { }, ], }, + cachedContent: undefined, }; await modelWithSystemInstruction.generateContent(req); // @ts-ignore @@ -2117,6 +2121,7 @@ describe('GenerativeModelPreview generateContentStream', () => { }, ], }, + cachedContent: undefined, }; await modelWithSystemInstruction.generateContentStream(req); // @ts-ignore @@ -2156,6 +2161,7 @@ describe('GenerativeModelPreview generateContentStream', () => { }, ], }, + cachedContent: undefined, }; await modelWithSystemInstruction.generateContentStream(req); // @ts-ignore diff --git a/src/resources/cached_contents.ts b/src/resources/cached_contents.ts new file mode 100644 index 00000000..977241c3 --- /dev/null +++ b/src/resources/cached_contents.ts @@ -0,0 +1,240 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {formulateSystemInstructionIntoContent} from '../functions/util'; +import {ClientError} from '../types'; +import {CachedContent, ListCachedContentsResponse} from '../types'; +import {ApiClient} from './shared/api_client'; + +function camelToSnake(str: string): string { + return str.replace(/[A-Z]/g, letter => `_${letter.toLowerCase()}`); +} + +class CachedContentsClient { + constructor(readonly apiClient: ApiClient) {} + + create(cachedContent: CachedContent): Promise { + return this.apiClient.unaryApiCall( + new URL( + this.apiClient.getBaseUrl() + + '/' + + this.apiClient.getBaseResourePath() + + '/cachedContents' + ), + { + body: JSON.stringify(cachedContent), + }, + 'POST' + ); + } + + update( + cachedContent: CachedContent, + updateMask: string[] + ): Promise { + const url = new URL(this.apiClient.getBaseUrl() + '/' + cachedContent.name); + url.searchParams.append( + 'updateMask', + updateMask.map(e => camelToSnake(e)).join(',') + ); + return this.apiClient.unaryApiCall( + url, + { + body: JSON.stringify(cachedContent), + }, + 'PATCH' + ); + } + + delete(name: string): Promise { + return this.apiClient.unaryApiCall( + new URL(this.apiClient.getBaseUrl() + '/' + name), + {}, + 'DELETE' + ); + } + + list( + pageSize?: number, + pageToken?: string + ): Promise { + const url = new URL( + this.apiClient.getBaseUrl() + '/' + this.apiClient.getBaseResourePath() + ); + if (pageSize) url.searchParams.append('pageSize', String(pageSize)); + if (pageToken) url.searchParams.append('pageToken', pageToken); + return this.apiClient.unaryApiCall(url, {}, 'GET'); + } + + get(name: string): Promise { + return this.apiClient.unaryApiCall( + new URL(this.apiClient.getBaseUrl() + '/' + name), + {}, + 'GET' + ); + } +} + +export function inferFullResourceName( + project: string, + location: string, + cachedContentId: string +): string { + if (cachedContentId.startsWith('projects/')) { + return cachedContentId; + } + if (cachedContentId.startsWith('locations/')) { + return `projects/${project}/${cachedContentId}`; + } + if (cachedContentId.startsWith('cachedContents/')) { + return `projects/${project}/locations/${location}/${cachedContentId}`; + } + if (!cachedContentId.includes('/')) { + return `projects/${project}/locations/${location}/cachedContents/${cachedContentId}`; + } + throw new ClientError( + `Invalid CachedContent.name: ${cachedContentId}. CachedContent.name should start with 'projects/', 'locations/', 'cachedContents/' or is a number type.` + ); +} + +/** + * Infers the full model name based on the provided project, location, and model. + * + * @internal + */ +export function inferModelName( + project: string, + location: string, + model?: string +) { + if (!model) { + throw new ClientError('Model name is required.'); + } + if (model.startsWith('publishers/')) { + return `projects/${project}/locations/${location}/${model}`; + } + if (!model.startsWith('projects/')) { + return `projects/${project}/locations/${location}/publishers/google/models/${model}`; + } + return model; +} + +/** + * This class is for managing Vertex AI's CachedContent resource. + * @public + */ +export class CachedContents { + private readonly client: CachedContentsClient; + constructor(client: ApiClient) { + this.client = new CachedContentsClient(client); + } + + /** + * Creates cached content, this call will initialize the cached content in the data storage, and users need to pay for the cache data storage. + * @param cachedContent + * @param parent - Required. The parent resource where the cached content will be created. + */ + create(cachedContent: CachedContent): Promise { + const curatedCachedContent = { + ...cachedContent, + systemInstruction: cachedContent.systemInstruction + ? formulateSystemInstructionIntoContent(cachedContent.systemInstruction) + : undefined, + model: inferModelName( + this.client.apiClient.project, + this.client.apiClient.location, + cachedContent.model + ), + } as CachedContent; + return this.client.create(curatedCachedContent); + } + + /** + * Updates cached content configurations + * + * @param updateMask - Required. The list of fields to update. Format: google-fieldmask. See {@link https://cloud.google.com/docs/discovery/type-format} + * @param name - Immutable. Identifier. The server-generated resource name of the cached content Format: projects/{project}/locations/{location}/cachedContents/{cached_content}. + */ + update( + cachedContent: CachedContent, + updateMask: string[] + ): Promise { + if (!cachedContent.name) { + throw new ClientError('Cached content name is required for update.'); + } + if (!updateMask || updateMask.length === 0) { + throw new ClientError( + 'Update mask is required for update. Fields set in cachedContent but not in updateMask will be ignored. Examples: ["ttl"] or ["expireTime"].' + ); + } + const curatedCachedContent = { + ...cachedContent, + systemInstruction: cachedContent.systemInstruction + ? formulateSystemInstructionIntoContent(cachedContent.systemInstruction) + : undefined, + name: inferFullResourceName( + this.client.apiClient.project, + this.client.apiClient.location, + cachedContent.name + ), + }; + return this.client.update(curatedCachedContent, updateMask); + } + + /** + * Deletes cached content. + * + * @param name - Required. The resource name referring to the cached content. + */ + delete(name: string): Promise { + return this.client.delete( + inferFullResourceName( + this.client.apiClient.project, + this.client.apiClient.location, + name + ) + ); + } + + /** + * Lists cached contents in a project. + * + * @param pageSize - Optional. The maximum number of cached contents to return. The service may return fewer than this value. If unspecified, some default (under maximum) number of items will be returned. The maximum value is 1000; values above 1000 will be coerced to 1000. + * @param pageToken - Optional. A page token, received from a previous `ListCachedContents` call. Provide this to retrieve the subsequent page. When paginating, all other parameters provided to `ListCachedContents` must match the call that provided the page token. + */ + list( + pageSize?: number, + pageToken?: string + ): Promise { + return this.client.list(pageSize, pageToken); + } + + /** + * Gets cached content configurations. + * + * @param name - Required. The resource name referring to the cached content. + */ + get(name: string): Promise { + return this.client.get( + inferFullResourceName( + this.client.apiClient.project, + this.client.apiClient.location, + name + ) + ); + } +} diff --git a/src/resources/index.ts b/src/resources/index.ts new file mode 100644 index 00000000..13571ffb --- /dev/null +++ b/src/resources/index.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {CachedContents} from './cached_contents'; +export {ApiClient} from './shared/api_client'; diff --git a/src/resources/shared/api_client.ts b/src/resources/shared/api_client.ts new file mode 100644 index 00000000..59d48add --- /dev/null +++ b/src/resources/shared/api_client.ts @@ -0,0 +1,142 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {GoogleAuth} from 'google-auth-library'; +import {constants} from '../../util'; +import { + ClientError, + GoogleApiError, + GoogleAuthError, + GoogleGenerativeAIError, +} from '../../types'; + +const AUTHORIZATION_HEADER = 'Authorization'; +const CONTENT_TYPE_HEADER = 'Content-Type'; +const USER_AGENT_HEADER = 'User-Agent'; + +export class ApiClient { + constructor( + readonly project: string, + readonly location: string, + readonly apiVersion: 'v1' | 'v1beta1', + private readonly googleAuth: GoogleAuth + ) {} + + /** + * Gets access token from GoogleAuth. Throws {@link GoogleAuthError} when + * fails. + * @returns Promise of token string. + */ + public fetchToken(): Promise { + const tokenPromise = this.googleAuth.getAccessToken().catch(e => { + throw new GoogleAuthError(constants.CREDENTIAL_ERROR_MESSAGE, e); + }); + return tokenPromise; + } + + getBaseUrl() { + return `https://${this.location}-aiplatform.googleapis.com/${this.apiVersion}`; + } + + getBaseResourePath() { + return `projects/${this.project}/locations/${this.location}`; + } + + async unaryApiCall( + url: URL, + requestInit: RequestInit, + httpMethod: 'GET' | 'POST' | 'PATCH' | 'DELETE' + ): Promise { + const token = await this.getHeaders(); + return this.apiCall(url.toString(), { + ...requestInit, + method: httpMethod, + headers: token, + }); + } + + private async apiCall( + url: string, + requestInit: RequestInit + ): Promise { + const response = await fetch(url, requestInit).catch(e => { + throw new GoogleGenerativeAIError( + `exception sending request to url: ${url} with requestInit: ${JSON.stringify(requestInit)}}`, + e + ); + }); + await throwErrorIfNotOK(response, url, requestInit).catch(e => { + throw e; + }); + try { + return await response.json(); + } catch (e) { + throw new GoogleGenerativeAIError(JSON.stringify(response), e as Error); + } + } + + private async getHeaders(): Promise { + const token = await this.fetchToken(); + return new Headers({ + [AUTHORIZATION_HEADER]: `Bearer ${token}`, + [CONTENT_TYPE_HEADER]: 'application/json', + [USER_AGENT_HEADER]: constants.USER_AGENT, + }); + } +} + +async function throwErrorIfNotOK( + response: Response | undefined, + url: string, + requestInit: RequestInit +) { + if (response === undefined) { + throw new GoogleGenerativeAIError('response is undefined'); + } + if (!response.ok) { + const status: number = response.status; + const statusText: string = response.statusText; + let errorBody; + if (response.headers.get('content-type')?.includes('application/json')) { + errorBody = await response.json(); + } else { + errorBody = { + error: { + message: `exception sending request to url: ${url} with requestInit: ${JSON.stringify(requestInit)}}`, + code: response.status, + status: response.statusText, + }, + }; + } + const errorMessage = `got status: ${status} ${statusText}. ${JSON.stringify( + errorBody + )}`; + if (status >= 400 && status < 500) { + const error = new ClientError( + errorMessage, + new GoogleApiError( + errorBody.error.message, + errorBody.error.code, + errorBody.error.status, + errorBody.error.details + ) + ); + throw error; + } + throw new GoogleGenerativeAIError(errorMessage); + } +} diff --git a/src/types/content.ts b/src/types/content.ts index 1babe1c7..e6ad741e 100644 --- a/src/types/content.ts +++ b/src/types/content.ts @@ -62,6 +62,12 @@ export declare interface GenerateContentRequest extends BaseModelParams { * Note: only text should be used in parts of {@link Content} */ systemInstruction?: string | Content; + + /** + * Optional. The name of the cached content used as context to serve the prediction. + * This is the name of a `CachedContent` and not the cache object itself. + */ + cachedContent?: string; } /** @@ -136,6 +142,13 @@ export declare interface ModelParams extends BaseModelParams { * @example "gemini-1.0-pro". */ model: string; + + /** + * Optional. The cached content used as context to serve the prediction. + * Note: only used in explicit caching, where users can have control over caching + * (e.g. what content to cache) and enjoy guaranteed cost savings. + */ + cachedContent?: CachedContent; } /** @@ -1021,6 +1034,11 @@ export declare interface StartChatParams { * Note: only text should be used in parts of {@link Content} */ systemInstruction?: string | Content; + /** + * Optional. The name of the cached content used as context to serve the prediction. + * This is the name of a `CachedContent` and not the cache object itself. + */ + cachedContent?: string; } /** @@ -1061,3 +1079,90 @@ export interface RequestOptions { */ customHeaders?: Headers; } + +/** + * A resource used in LLM queries for users to explicitly specify + * what to cache and how to cache. + */ +export interface CachedContent { + /** + * Immutable. Identifier. The server-generated resource name of the cached content. + * Format: projects/{project}/locations/{location}/cachedContents/{cached_content} + */ + name?: string; + + /** Optional. Immutable. The user-generated meaningful display name of the cached content. */ + displayName?: string; + + /** + * Immutable. The name of the publisher model to use for cached content. + * Format: projects/{project}/locations/{location}/publishers/{publisher}/models/{model} + */ + model?: string; + + /** Developer set system instruction. Currently, text only. */ + systemInstruction?: Content | string; + + /** Optional. Input only. Immutable. The content to cache. */ + contents?: Content[]; + + /** Optional. Input only. Immutable. A list of `Tools` the model may use to generate the next response. */ + tools?: Tool[]; + + /** Optional. Input only. Immutable. Tool config. This config is shared for all tools. */ + toolConfig?: ToolConfig; + + /** + * Output only. Creatation time of the cache entry. + * Format: google-datetime. See {@link https://cloud.google.com/docs/discovery/type-format} + */ + createTime?: string; + + /** + * Output only. When the cache entry was last updated in UTC time. + * Format: google-datetime. See {@link https://cloud.google.com/docs/discovery/type-format} + */ + updateTime?: string; + + /** Output only. Metadata on the usage of the cached content. */ + usageMetadata?: CachedContentUsageMetadata; + + /** + * Timestamp of when this resource is considered expired. + * This is *always* provided on output, regardless of what was sent on input. + */ + expireTime?: string; + + /** + * Input only. The TTL seconds for this resource. The expiration time + * is computed: now + TTL. + * Format: google-duration. See {@link https://cloud.google.com/docs/discovery/type-format} + */ + ttl?: string; +} + +/** Metadata on the usage of the cached content. */ +export interface CachedContentUsageMetadata { + /** Total number of tokens that the cached content consumes. */ + totalTokenCount?: number; + + /** Number of text characters. */ + textCount?: number; + + /** Number of images. */ + imageCount?: number; + + /** Duration of video in seconds. */ + videoDurationSeconds?: number; + + /** Duration of audio in seconds. */ + audioDurationSeconds?: number; +} + +/** Response with a list of CachedContents. */ +export interface ListCachedContentsResponse { + /** List of cached contents. */ + cachedContents?: CachedContent[]; + /** A token, which can be sent as `page_token` to retrieve the next page. If this field is omitted, there are no subsequent pages. */ + nextPageToken?: string; +} diff --git a/src/vertex_ai.ts b/src/vertex_ai.ts index 27194ee0..7d1858e7 100644 --- a/src/vertex_ai.ts +++ b/src/vertex_ai.ts @@ -20,12 +20,19 @@ import {GoogleAuth, GoogleAuthOptions} from 'google-auth-library'; import {GenerativeModelPreview, GenerativeModel} from './models'; import { + CachedContent, GetGenerativeModelParams, ModelParams, RequestOptions, VertexInit, } from './types/content'; -import {GoogleAuthError, IllegalArgumentError} from './types/errors'; +import { + GoogleAuthError, + IllegalArgumentError, + ClientError, +} from './types/errors'; +import * as Resources from './resources'; +import {inferFullResourceName} from './resources/cached_contents'; /** * The `VertexAI` class is the base class for authenticating to Vertex AI. @@ -149,6 +156,9 @@ class VertexAIPreview { private readonly googleAuth: GoogleAuth; private readonly apiEndpoint?: string; + readonly apiClient: Resources.ApiClient; + readonly cachedContents: Resources.CachedContents; + /** * @constructor * @param project - The Google Cloud project to use for the request @@ -174,6 +184,14 @@ class VertexAIPreview { this.location = location; this.googleAuth = googleAuth; this.apiEndpoint = apiEndpoint; + + this.apiClient = new Resources.ApiClient( + this.project, + this.location, + 'v1beta1', + this.googleAuth + ); + this.cachedContents = new Resources.CachedContents(this.apiClient); } /** @@ -200,6 +218,88 @@ class VertexAIPreview { }; return new GenerativeModelPreview(getGenerativeModelParams); } + + getGenerativeModelFromCachedContent( + cachedContent: CachedContent, + modelParams?: Partial, + requestOptions?: RequestOptions + ) { + if (!cachedContent.name) { + throw new ClientError('Cached content must contain a `name` field.'); + } + if (!cachedContent.model) { + throw new ClientError('Cached content must contain a `model` field.'); + } + validateCachedContentModel(cachedContent.model); + /** + * Not checking tools and toolConfig for now as it would require a deep + * equality comparison and isn't likely to be a common case. + */ + const disallowedDuplicates: Array = + ['model', 'systemInstruction']; + + for (const key of disallowedDuplicates) { + if ( + modelParams?.[key] && + cachedContent[key] && + modelParams?.[key] !== cachedContent[key] + ) { + if (key === 'model') { + const modelParamsComp = parseModelName(modelParams[key]!); + const cachedContentComp = parseModelName(cachedContent[key]!); + if (modelParamsComp === cachedContentComp) { + continue; + } + } + throw new ClientError( + `Different value for "${key}" specified in modelParams` + + ` (${modelParams[key]}) and cachedContent (${cachedContent[key]})` + ); + } + } + + cachedContent.name = inferFullResourceName( + this.project, + this.location, + cachedContent.name + ); + const modelParamsFromCache: GetGenerativeModelParams = { + model: cachedContent.model, + project: this.project, + location: this.location, + googleAuth: this.googleAuth, + apiEndpoint: this.apiEndpoint, + safetySettings: modelParams?.safetySettings, + generationConfig: modelParams?.generationConfig, + tools: cachedContent.tools, + toolConfig: cachedContent.toolConfig, + requestOptions: requestOptions, + systemInstruction: cachedContent.systemInstruction, + cachedContent, + }; + return new GenerativeModelPreview(modelParamsFromCache); + } +} + +function validateCachedContentModel(modelName: string) { + if ( + modelName.startsWith('models/') || + (modelName.startsWith('projects/') && + modelName.includes('/publishers/google/models/')) || + !modelName.includes('/') + ) { + return; + } + throw new ClientError( + `Cached content model name must start with "models/" or match "projects/.*/publishers/google/models/.*" or is a model name listed at https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versions. Received: ${modelName}` + ); +} + +function parseModelName(modelName: string): string { + if (!modelName.includes('/')) { + return modelName; + } + return modelName.split('/').pop()!; } function validateGoogleAuthOptions(