From ea0dcb717be8d22d98916252ccee352e9af4a09f Mon Sep 17 00:00:00 2001 From: Yvonne Yu Date: Tue, 2 Jan 2024 10:19:41 -0800 Subject: [PATCH] fix: throw ClientError or GoogleGenerativeAIError according to response status so that users can catch them and handle them according to class name. PiperOrigin-RevId: 595149601 --- package.json | 4 +- src/index.ts | 160 +++++++++---------- src/process_stream.ts | 11 ++ src/types/errors.ts | 45 +++++- src/util/post_request.ts | 2 +- system_test/end_to_end_sample_test.ts | 27 +++- test/index_test.ts | 211 +++++++++++++++++++++++++- 7 files changed, 371 insertions(+), 89 deletions(-) diff --git a/package.json b/package.json index 1681a8c9..df4090c0 100644 --- a/package.json +++ b/package.json @@ -39,10 +39,10 @@ "@types/node": "^20.9.0", "gts": "^5.2.0", "jasmine": "^5.1.0", - "typescript": "~5.2.0", "jsdoc": "^4.0.0", "jsdoc-fresh": "^3.0.0", "jsdoc-region-tag": "^3.0.0", - "linkinator": "^4.0.0" + "linkinator": "^4.0.0", + "typescript": "~5.2.0" } } diff --git a/src/index.ts b/src/index.ts index 4be8f664..e06a4e3c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -18,7 +18,11 @@ /* tslint:disable */ import {GoogleAuth} from 'google-auth-library'; -import {processNonStream, processStream} from './process_stream'; +import { + processCountTokenResponse, + processNonStream, + processStream, +} from './process_stream'; import { Content, CountTokensRequest, @@ -32,7 +36,11 @@ import { StreamGenerateContentResult, VertexInit, } from './types/content'; -import {GoogleAuthError} from './types/errors'; +import { + ClientError, + GoogleAuthError, + GoogleGenerativeAIError, +} from './types/errors'; import {constants, postRequest} from './util'; export * from './types'; @@ -101,7 +109,7 @@ export class VertexAI_Preview { \n -`auth.authenticate_user()`\ \n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication'; const tokenPromise = this.googleAuth.getAccessToken().catch(e => { - throw new GoogleAuthError(`${credential_error_message}\n${e}`); + throw new GoogleAuthError(credential_error_message, e); }); return tokenPromise; } @@ -194,10 +202,13 @@ export class ChatSession { generation_config: this.generation_config, }; - const generateContentResult = await this._model_instance.generateContent( - generateContentrequest - ); - const generateContentResponse = await generateContentResult.response; + const generateContentResult: GenerateContentResult = + await this._model_instance + .generateContent(generateContentrequest) + .catch(e => { + throw e; + }); + const generateContentResponse = generateContentResult.response; // Only push the latest message to history if the response returned a result if (generateContentResponse.candidates.length !== 0) { this.historyInternal.push(newContent); @@ -253,13 +264,18 @@ export class ChatSession { generation_config: this.generation_config, }; - const streamGenerateContentResultPromise = - this._model_instance.generateContentStream(generateContentrequest); + const streamGenerateContentResultPromise = this._model_instance + .generateContentStream(generateContentrequest) + .catch(e => { + throw e; + }); this._send_stream_promise = this.appendHistory( streamGenerateContentResultPromise, newContent - ); + ).catch(e => { + throw new GoogleGenerativeAIError('exception appending chat history', e); + }); return streamGenerateContentResultPromise; } } @@ -320,7 +336,9 @@ export class GenerativeModel { if (!this._use_non_stream) { const streamGenerateContentResult: StreamGenerateContentResult = - await this.generateContentStream(request); + await this.generateContentStream(request).catch(e => { + throw e; + }); const result: GenerateContentResult = { response: await streamGenerateContentResult.response, }; @@ -333,27 +351,18 @@ export class GenerativeModel { safety_settings: request.safety_settings ?? this.safety_settings, }; - let response; - try { - response = await postRequest({ - region: this._vertex_instance.location, - project: this._vertex_instance.project, - resourcePath: this.publisherModelEndpoint, - resourceMethod: constants.GENERATE_CONTENT_METHOD, - token: await this._vertex_instance.token, - data: generateContentRequest, - apiEndpoint: this._vertex_instance.apiEndpoint, - }); - if (response === undefined) { - throw new Error('did not get a valid response.'); - } - if (!response.ok) { - throw new Error(`${response.status} ${response.statusText}`); - } - } catch (e) { - console.log(e); - } - + const response: Response | undefined = await postRequest({ + region: this._vertex_instance.location, + project: this._vertex_instance.project, + resourcePath: this.publisherModelEndpoint, + resourceMethod: constants.GENERATE_CONTENT_METHOD, + token: await this._vertex_instance.token, + data: generateContentRequest, + apiEndpoint: this._vertex_instance.apiEndpoint, + }).catch(e => { + throw new GoogleGenerativeAIError('exception posting request', e); + }); + throwErrorIfNotOK(response); const result: GenerateContentResult = processNonStream(response); return Promise.resolve(result); } @@ -379,27 +388,18 @@ export class GenerativeModel { generation_config: request.generation_config ?? this.generation_config, safety_settings: request.safety_settings ?? this.safety_settings, }; - let response; - try { - response = await postRequest({ - region: this._vertex_instance.location, - project: this._vertex_instance.project, - resourcePath: this.publisherModelEndpoint, - resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD, - token: await this._vertex_instance.token, - data: generateContentRequest, - apiEndpoint: this._vertex_instance.apiEndpoint, - }); - if (response === undefined) { - throw new Error('did not get a valid response.'); - } - if (!response.ok) { - throw new Error(`${response.status} ${response.statusText}`); - } - } catch (e) { - console.log(e); - } - + const response = await postRequest({ + region: this._vertex_instance.location, + project: this._vertex_instance.project, + resourcePath: this.publisherModelEndpoint, + resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD, + token: await this._vertex_instance.token, + data: generateContentRequest, + apiEndpoint: this._vertex_instance.apiEndpoint, + }).catch(e => { + throw new GoogleGenerativeAIError('exception posting request', e); + }); + throwErrorIfNotOK(response); const streamResult = processStream(response); return Promise.resolve(streamResult); } @@ -410,32 +410,19 @@ export class GenerativeModel { * @return The CountTokensResponse object with the token count. */ async countTokens(request: CountTokensRequest): Promise { - let response; - try { - response = await postRequest({ - region: this._vertex_instance.location, - project: this._vertex_instance.project, - resourcePath: this.publisherModelEndpoint, - resourceMethod: 'countTokens', - token: await this._vertex_instance.token, - data: request, - apiEndpoint: this._vertex_instance.apiEndpoint, - }); - if (response === undefined) { - throw new Error('did not get a valid response.'); - } - if (!response.ok) { - throw new Error(`${response.status} ${response.statusText}`); - } - } catch (e) { - console.log(e); - } - if (response) { - const responseJson = await response.json(); - return responseJson as CountTokensResponse; - } else { - throw new Error('did not get a valid response.'); - } + const response = await postRequest({ + region: this._vertex_instance.location, + project: this._vertex_instance.project, + resourcePath: this.publisherModelEndpoint, + resourceMethod: 'countTokens', + token: await this._vertex_instance.token, + data: request, + apiEndpoint: this._vertex_instance.apiEndpoint, + }).catch(e => { + throw new GoogleGenerativeAIError('exception posting request', e); + }); + throwErrorIfNotOK(response); + return processCountTokenResponse(response); } /** @@ -481,6 +468,21 @@ function formulateNewContent(request: string | Array): Content { return newContent; } +function throwErrorIfNotOK(response: Response | undefined) { + if (response === undefined) { + throw new GoogleGenerativeAIError('response is undefined'); + } + const status: number = response.status; + const statusText: string = response.statusText; + const errorMessage = `got status: ${status} ${statusText}`; + if (status >= 400 && status < 500) { + throw new ClientError(errorMessage); + } + if (!response.ok) { + throw new GoogleGenerativeAIError(errorMessage); + } +} + function validateGcsInput(contents: Content[]) { for (const content of contents) { for (const part of content.parts) { diff --git a/src/process_stream.ts b/src/process_stream.ts index 1f640303..1e62cf78 100644 --- a/src/process_stream.ts +++ b/src/process_stream.ts @@ -17,6 +17,7 @@ import { CitationSource, + CountTokensResponse, GenerateContentCandidate, GenerateContentResponse, GenerateContentResult, @@ -218,3 +219,13 @@ export function processNonStream(response: any): GenerateContentResult { response: {candidates: []}, }; } + +/** + * Process model responses from countTokens + * @ignore + */ +export function processCountTokenResponse(response: any): CountTokensResponse { + // ts-ignore + const responseJson = response.json(); + return responseJson as CountTokensResponse; +} diff --git a/src/types/errors.ts b/src/types/errors.ts index ca4a7658..ee8c959c 100644 --- a/src/types/errors.ts +++ b/src/types/errors.ts @@ -15,11 +15,52 @@ * limitations under the License. */ +/** + * GoogleAuthError is thrown when there is authentication issue with the request + */ class GoogleAuthError extends Error { - constructor(message: string) { + public readonly stack_trace: any = undefined; + constructor(message: string, stack_trace: any = undefined) { super(message); + this.message = constructErrorMessage('GoogleAuthError', message); this.name = 'GoogleAuthError'; + this.stack_trace = stack_trace; + } +} + +/** + * ClientError is thrown when http 4XX status is received. + * For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses + */ +class ClientError extends Error { + public readonly stack_trace: any = undefined; + constructor(message: string, stack_trace: any = undefined) { + super(message); + this.message = constructErrorMessage('ClientError', message); + this.name = 'ClientError'; + this.stack_trace = stack_trace; } } -export {GoogleAuthError}; +/** + * GoogleGenerativeAIError is thrown when http response is not ok and status code is not 4XX + * For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status + */ +class GoogleGenerativeAIError extends Error { + public readonly stack_trace: any = undefined; + constructor(message: string, stack_trace: any = undefined) { + super(message); + this.message = constructErrorMessage('GoogleGenerativeAIError', message); + this.name = 'GoogleGenerativeAIError'; + this.stack_trace = stack_trace; + } +} + +function constructErrorMessage( + exceptionClass: string, + message: string +): string { + return `[VertexAI.${exceptionClass}]: ${message}`; +} + +export {ClientError, GoogleAuthError, GoogleGenerativeAIError}; diff --git a/src/util/post_request.ts b/src/util/post_request.ts index 5af1dbbb..4b1e6795 100644 --- a/src/util/post_request.ts +++ b/src/util/post_request.ts @@ -52,7 +52,7 @@ export async function postRequest({ vertexEndpoint += '?alt=sse'; } - return await fetch(vertexEndpoint, { + return fetch(vertexEndpoint, { method: 'POST', headers: { Authorization: `Bearer ${token}`, diff --git a/system_test/end_to_end_sample_test.ts b/system_test/end_to_end_sample_test.ts index bbfbb474..2852bd81 100644 --- a/system_test/end_to_end_sample_test.ts +++ b/system_test/end_to_end_sample_test.ts @@ -18,7 +18,7 @@ // @ts-ignore import * as assert from 'assert'; -import {VertexAI, TextPart} from '../src'; +import {ClientError, VertexAI, TextPart} from '../src'; // TODO: this env var isn't getting populated correctly const PROJECT = process.env.GCLOUD_PROJECT; @@ -129,7 +129,7 @@ describe('generateContentStream', () => { `sys test failure on generateContentStream for aggregated response: ${aggregatedResp}` ); }); - it('should should return a stream and aggregated response when passed multipart base64 content', async () => { + it('should return a stream and aggregated response when passed multipart base64 content', async () => { const streamingResp = await generativeVisionModel.generateContentStream( MULTI_PART_BASE64_REQUEST ); @@ -147,6 +147,29 @@ describe('generateContentStream', () => { `sys test failure on generateContentStream for aggregated response: ${aggregatedResp}` ); }); + it('should throw ClientError when having invalid input', async () => { + const badRequest = { + contents: [ + { + role: 'user', + parts: [ + {text: 'describe this image:'}, + {inline_data: {mime_type: 'image/png', data: 'invalid data'}}, + ], + }, + ], + }; + await generativeVisionModel.generateContentStream(badRequest).catch(e => { + assert( + e instanceof ClientError, + `sys test failure on generateContentStream when having bad request should throw ClientError but actually thrown ${e}` + ); + assert( + e.message === '[VertexAI.ClientError]: got status: 400 Bad Request', + `sys test failure on generateContentStream when having bad request got wrong error message: ${e.message}` + ); + }); + }); // TODO: this is returning a 500 on the system test project // it('should should return a stream and aggregated response when passed // multipart GCS content', diff --git a/test/index_test.ts b/test/index_test.ts index b587cbc4..d267c187 100644 --- a/test/index_test.ts +++ b/test/index_test.ts @@ -148,7 +148,9 @@ const TEST_MULTIPART_MESSAGE = [ const fetchResponseObj = { status: 200, statusText: 'OK', + ok: true, headers: {'Content-Type': 'application/json'}, + url: 'url', }; /** @@ -468,10 +470,11 @@ describe('countTokens', () => { const responseBody = { totalTokens: 1, }; - const response = Promise.resolve( - new Response(JSON.stringify(responseBody), fetchResponseObj) + const response = new Response( + JSON.stringify(responseBody), + fetchResponseObj ); - spyOn(global, 'fetch').and.returnValue(response); + spyOn(global, 'fetch').and.resolveTo(response); const resp = await model.countTokens(req); expect(resp).toEqual(responseBody); }); @@ -620,3 +623,205 @@ describe('ChatSession', () => { }); }); }); + +describe('when exception at fetch', () => { + const expectedErrorMessage = + '[VertexAI.GoogleGenerativeAIError]: exception posting request'; + const vertexai = new VertexAI({ + project: PROJECT, + location: LOCATION, + }); + const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); + const chatSession = model.startChat(); + const message = 'hi'; + const req: GenerateContentRequest = { + contents: TEST_USER_CHAT_MESSAGE, + }; + const countTokenReq: CountTokensRequest = { + contents: TEST_USER_CHAT_MESSAGE, + }; + beforeEach(() => { + spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN); + spyOn(global, 'fetch').and.throwError('error'); + }); + + it('generateContent should throw GoogleGenerativeAI error', async () => { + await expectAsync(model.generateContent(req)).toBeRejected(); + }); + + it('generateContentStream should throw GoogleGenerativeAI error', async () => { + await expectAsync(model.generateContentStream(req)).toBeRejected(); + }); + + it('sendMessage should throw GoogleGenerativeAI error', async () => { + await expectAsync(chatSession.sendMessage(message)).toBeRejected(); + }); + + it('countTokens should throw GoogleGenerativeAI error', async () => { + await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); + }); +}); + +describe('when response is undefined', () => { + const expectedErrorMessage = + '[VertexAI.GoogleGenerativeAIError]: response is undefined'; + const vertexai = new VertexAI({ + project: PROJECT, + location: LOCATION, + }); + const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); + const req: GenerateContentRequest = { + contents: TEST_USER_CHAT_MESSAGE, + }; + const message = 'hi'; + const chatSession = model.startChat(); + const countTokenReq: CountTokensRequest = { + contents: TEST_USER_CHAT_MESSAGE, + }; + beforeEach(() => { + spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN); + spyOn(global, 'fetch').and.resolveTo(); + }); + + it('generateContent should throw GoogleGenerativeAI error', async () => { + await expectAsync(model.generateContent(req)).toBeRejected(); + await model.generateContent(req).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); + + it('generateContentStream should throw GoogleGenerativeAI error', async () => { + await expectAsync(model.generateContentStream(req)).toBeRejected(); + await model.generateContentStream(req).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); + + it('sendMessage should throw GoogleGenerativeAI error', async () => { + await expectAsync(chatSession.sendMessage(message)).toBeRejected(); + await chatSession.sendMessage(message).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); + + it('countTokens should throw GoogleGenerativeAI error', async () => { + await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); + await model.countTokens(countTokenReq).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); +}); + +describe('when response is 4XX', () => { + const expectedErrorMessage = + '[VertexAI.ClientError]: got status: 400 Bad Request'; + const req: GenerateContentRequest = { + contents: TEST_USER_CHAT_MESSAGE, + }; + const vertexai = new VertexAI({ + project: PROJECT, + location: LOCATION, + }); + const fetch400Obj = { + status: 400, + statusText: 'Bad Request', + ok: false, + }; + const body = {}; + const response = new Response(JSON.stringify(body), fetch400Obj); + const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); + const message = 'hi'; + const chatSession = model.startChat(); + const countTokenReq: CountTokensRequest = { + contents: TEST_USER_CHAT_MESSAGE, + }; + beforeEach(() => { + spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN); + spyOn(global, 'fetch').and.resolveTo(response); + }); + + it('generateContent should throw ClientError error', async () => { + await expectAsync(model.generateContent(req)).toBeRejected(); + await model.generateContent(req).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); + + it('generateContentStream should throw ClientError error', async () => { + await expectAsync(model.generateContentStream(req)).toBeRejected(); + await model.generateContentStream(req).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); + + it('sendMessage should throw ClientError error', async () => { + await expectAsync(chatSession.sendMessage(message)).toBeRejected(); + await chatSession.sendMessage(message).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); + + it('countTokens should throw ClientError error', async () => { + await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); + await model.countTokens(countTokenReq).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); +}); + +describe('when response is not OK and not 4XX', () => { + const expectedErrorMessage = + '[VertexAI.GoogleGenerativeAIError]: got status: 500 Internal Server Error'; + const req: GenerateContentRequest = { + contents: TEST_USER_CHAT_MESSAGE, + }; + const vertexai = new VertexAI({ + project: PROJECT, + location: LOCATION, + }); + const fetch500Obj = { + status: 500, + statusText: 'Internal Server Error', + ok: false, + }; + const body = {}; + const response = new Response(JSON.stringify(body), fetch500Obj); + const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); + const message = 'hi'; + const chatSession = model.startChat(); + const countTokenReq: CountTokensRequest = { + contents: TEST_USER_CHAT_MESSAGE, + }; + beforeEach(() => { + spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN); + spyOn(global, 'fetch').and.resolveTo(response); + }); + + it('generateContent should throws GoogleGenerativeAIError', async () => { + await expectAsync(model.generateContent(req)).toBeRejected(); + await model.generateContent(req).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); + + it('generateContentStream should throws GoogleGenerativeAIError', async () => { + await expectAsync(model.generateContentStream(req)).toBeRejected(); + await model.generateContentStream(req).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); + + it('sendMessage should throws GoogleGenerativeAIError', async () => { + await expectAsync(chatSession.sendMessage(message)).toBeRejected(); + await chatSession.sendMessage(message).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); + + it('countTokens should throws GoogleGenerativeAIError', async () => { + await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); + await model.countTokens(countTokenReq).catch(e => { + expect(e.message).toEqual(expectedErrorMessage); + }); + }); +});