Skip to content

Commit

Permalink
fix: throw ClientError or GoogleGenerativeAIError according to respon…
Browse files Browse the repository at this point in the history
…se status so that users can catch them and handle them according to class name.

PiperOrigin-RevId: 594135129
  • Loading branch information
yyyu-google authored and copybara-github committed Dec 29, 2023
1 parent 2a75efa commit cbd68db
Show file tree
Hide file tree
Showing 7 changed files with 368 additions and 88 deletions.
5 changes: 3 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@
"postpack": "if [ \"${CLEAN}\" ]; then npm run clean-after-pack; fi"
},
"dependencies": {
"@google-cloud/vertexai": "file:google-cloud-vertexai-0.1.3.tgz",
"google-auth-library": "^9.1.0"
},
"devDependencies": {
"@types/jasmine": "^5.1.2",
"@types/node": "^20.9.0",
"gts": "^5.2.0",
"jasmine": "^5.1.0",
"typescript": "~5.2.0",
"jsdoc": "^4.0.0",
"jsdoc-fresh": "^3.0.0",
"jsdoc-region-tag": "^3.0.0",
"linkinator": "^4.0.0"
"linkinator": "^4.0.0",
"typescript": "~5.2.0"
}
}
160 changes: 81 additions & 79 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
/* tslint:disable */
import {GoogleAuth} from 'google-auth-library';

import {processNonStream, processStream} from './process_stream';
import {
processCountTokenResponse,
processNonStream,
processStream,
} from './process_stream';
import {
Content,
CountTokensRequest,
Expand All @@ -32,7 +36,11 @@ import {
StreamGenerateContentResult,
VertexInit,
} from './types/content';
import {GoogleAuthError} from './types/errors';
import {
ClientError,
GoogleAuthError,
GoogleGenerativeAIError,
} from './types/errors';
import {constants, postRequest} from './util';
export * from './types';

Expand Down Expand Up @@ -101,7 +109,7 @@ export class VertexAI_Preview {
\n -`auth.authenticate_user()`\
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication';
const tokenPromise = this.googleAuth.getAccessToken().catch(e => {
throw new GoogleAuthError(`${credential_error_message}\n${e}`);
throw new GoogleAuthError(credential_error_message, e);
});
return tokenPromise;
}
Expand Down Expand Up @@ -194,10 +202,13 @@ export class ChatSession {
generation_config: this.generation_config,
};

const generateContentResult = await this._model_instance.generateContent(
generateContentrequest
);
const generateContentResponse = await generateContentResult.response;
const generateContentResult: GenerateContentResult =
await this._model_instance
.generateContent(generateContentrequest)
.catch(e => {
throw e;
});
const generateContentResponse = generateContentResult.response;
// Only push the latest message to history if the response returned a result
if (generateContentResponse.candidates.length !== 0) {
this.historyInternal.push(newContent);
Expand Down Expand Up @@ -253,13 +264,18 @@ export class ChatSession {
generation_config: this.generation_config,
};

const streamGenerateContentResultPromise =
this._model_instance.generateContentStream(generateContentrequest);
const streamGenerateContentResultPromise = this._model_instance
.generateContentStream(generateContentrequest)
.catch(e => {
throw e;
});

this._send_stream_promise = this.appendHistory(
streamGenerateContentResultPromise,
newContent
);
).catch(e => {
throw new GoogleGenerativeAIError('exception appending chat history', e);
});
return streamGenerateContentResultPromise;
}
}
Expand Down Expand Up @@ -320,7 +336,9 @@ export class GenerativeModel {

if (!this._use_non_stream) {
const streamGenerateContentResult: StreamGenerateContentResult =
await this.generateContentStream(request);
await this.generateContentStream(request).catch(e => {
throw e;
});
const result: GenerateContentResult = {
response: await streamGenerateContentResult.response,
};
Expand All @@ -333,27 +351,18 @@ export class GenerativeModel {
safety_settings: request.safety_settings ?? this.safety_settings,
};

let response;
try {
response = await postRequest({
region: this._vertex_instance.location,
project: this._vertex_instance.project,
resourcePath: this.publisherModelEndpoint,
resourceMethod: constants.GENERATE_CONTENT_METHOD,
token: await this._vertex_instance.token,
data: generateContentRequest,
apiEndpoint: this._vertex_instance.apiEndpoint,
});
if (response === undefined) {
throw new Error('did not get a valid response.');
}
if (!response.ok) {
throw new Error(`${response.status} ${response.statusText}`);
}
} catch (e) {
console.log(e);
}

const response: Response | undefined = await postRequest({
region: this._vertex_instance.location,
project: this._vertex_instance.project,
resourcePath: this.publisherModelEndpoint,
resourceMethod: constants.GENERATE_CONTENT_METHOD,
token: await this._vertex_instance.token,
data: generateContentRequest,
apiEndpoint: this._vertex_instance.apiEndpoint,
}).catch(e => {
throw new GoogleGenerativeAIError('exception posting request', e);
});
throwErrorIfNotOK(response);
const result: GenerateContentResult = processNonStream(response);
return Promise.resolve(result);
}
Expand All @@ -379,27 +388,18 @@ export class GenerativeModel {
generation_config: request.generation_config ?? this.generation_config,
safety_settings: request.safety_settings ?? this.safety_settings,
};
let response;
try {
response = await postRequest({
region: this._vertex_instance.location,
project: this._vertex_instance.project,
resourcePath: this.publisherModelEndpoint,
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD,
token: await this._vertex_instance.token,
data: generateContentRequest,
apiEndpoint: this._vertex_instance.apiEndpoint,
});
if (response === undefined) {
throw new Error('did not get a valid response.');
}
if (!response.ok) {
throw new Error(`${response.status} ${response.statusText}`);
}
} catch (e) {
console.log(e);
}

const response = await postRequest({
region: this._vertex_instance.location,
project: this._vertex_instance.project,
resourcePath: this.publisherModelEndpoint,
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD,
token: await this._vertex_instance.token,
data: generateContentRequest,
apiEndpoint: this._vertex_instance.apiEndpoint,
}).catch(e => {
throw new GoogleGenerativeAIError('exception posting request', e);
});
throwErrorIfNotOK(response);
const streamResult = processStream(response);
return Promise.resolve(streamResult);
}
Expand All @@ -410,32 +410,19 @@ export class GenerativeModel {
* @return The CountTokensResponse object with the token count.
*/
async countTokens(request: CountTokensRequest): Promise<CountTokensResponse> {
let response;
try {
response = await postRequest({
region: this._vertex_instance.location,
project: this._vertex_instance.project,
resourcePath: this.publisherModelEndpoint,
resourceMethod: 'countTokens',
token: await this._vertex_instance.token,
data: request,
apiEndpoint: this._vertex_instance.apiEndpoint,
});
if (response === undefined) {
throw new Error('did not get a valid response.');
}
if (!response.ok) {
throw new Error(`${response.status} ${response.statusText}`);
}
} catch (e) {
console.log(e);
}
if (response) {
const responseJson = await response.json();
return responseJson as CountTokensResponse;
} else {
throw new Error('did not get a valid response.');
}
const response = await postRequest({
region: this._vertex_instance.location,
project: this._vertex_instance.project,
resourcePath: this.publisherModelEndpoint,
resourceMethod: 'countTokens',
token: await this._vertex_instance.token,
data: request,
apiEndpoint: this._vertex_instance.apiEndpoint,
}).catch(e => {
throw new GoogleGenerativeAIError('exception posting request', e);
});
throwErrorIfNotOK(response);
return processCountTokenResponse(response);
}

/**
Expand Down Expand Up @@ -481,6 +468,21 @@ function formulateNewContent(request: string | Array<string | Part>): Content {
return newContent;
}

function throwErrorIfNotOK(response: Response | undefined) {
if (response === undefined) {
throw new GoogleGenerativeAIError('response is undefined');
}
const status: number = response.status;
const statusText: string = response.statusText;
const errorMessage = `got status: ${status} ${statusText}`;
if (status >= 400 && status < 500) {
throw new ClientError(errorMessage);
}
if (!response.ok) {
throw new GoogleGenerativeAIError(errorMessage);
}
}

function validateGcsInput(contents: Content[]) {
for (const content of contents) {
for (const part of content.parts) {
Expand Down
11 changes: 11 additions & 0 deletions src/process_stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import {
CitationSource,
CountTokensResponse,
GenerateContentCandidate,
GenerateContentResponse,
GenerateContentResult,
Expand Down Expand Up @@ -218,3 +219,13 @@ export function processNonStream(response: any): GenerateContentResult {
response: {candidates: []},
};
}

/**
* Process model responses from countTokens
* @ignore
*/
export function processCountTokenResponse(response: any): CountTokensResponse {
// ts-ignore
const responseJson = response.json();
return responseJson as CountTokensResponse;
}
45 changes: 43 additions & 2 deletions src/types/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,52 @@
* limitations under the License.
*/

/**
* GoogleAuthError is thrown when there is authentication issue with the request
*/
class GoogleAuthError extends Error {
constructor(message: string) {
public readonly stack_trace: any = undefined;
constructor(message: string, stack_trace: any = undefined) {
super(message);
this.message = constructErrorMessage('GoogleAuthError', message);
this.name = 'GoogleAuthError';
this.stack_trace = stack_trace;
}
}

/**
* ClientError is thrown when http 4XX status is received.
* For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses
*/
class ClientError extends Error {
public readonly stack_trace: any = undefined;
constructor(message: string, stack_trace: any = undefined) {
super(message);
this.message = constructErrorMessage('ClientError', message);
this.name = 'ClientError';
this.stack_trace = stack_trace;
}
}

export {GoogleAuthError};
/**
* GoogleGenerativeAIError is thrown when http response is not ok and status code is not 4XX
* For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status
*/
class GoogleGenerativeAIError extends Error {
public readonly stack_trace: any = undefined;
constructor(message: string, stack_trace: any = undefined) {
super(message);
this.message = constructErrorMessage('GoogleGenerativeAIError', message);
this.name = 'GoogleGenerativeAIError';
this.stack_trace = stack_trace;
}
}

function constructErrorMessage(
exceptionClass: string,
message: string
): string {
return `[VertexAI.${exceptionClass}]: ${message}`;
}

export {ClientError, GoogleAuthError, GoogleGenerativeAIError};
2 changes: 1 addition & 1 deletion src/util/post_request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export async function postRequest({
vertexEndpoint += '?alt=sse';
}

return await fetch(vertexEndpoint, {
return fetch(vertexEndpoint, {
method: 'POST',
headers: {
Authorization: `Bearer ${token}`,
Expand Down
22 changes: 21 additions & 1 deletion system_test/end_to_end_sample_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import * as assert from 'assert';

import {VertexAI, TextPart} from '../src';
import {ClientError} from '../src/';

// TODO: this env var isn't getting populated correctly
const PROJECT = process.env.GCLOUD_PROJECT;
Expand Down Expand Up @@ -129,7 +130,7 @@ describe('generateContentStream', () => {
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`
);
});
it('should should return a stream and aggregated response when passed multipart base64 content', async () => {
it('should return a stream and aggregated response when passed multipart base64 content', async () => {
const streamingResp = await generativeVisionModel.generateContentStream(
MULTI_PART_BASE64_REQUEST
);
Expand All @@ -147,6 +148,25 @@ describe('generateContentStream', () => {
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`
);
});
it('should throw ClientError when having invalid input', async () => {
const badRequest = {
'contents':[{
'role':'user','parts':[{'text':'describe this image:'},{'inline_data':{'mime_type':'image/png','data':'invalid data'}}
],
},
],
};
await generativeVisionModel.generateContentStream(badRequest).catch(e => {
assert(
e instanceof ClientError,
`sys test failure on generateContentStream when having bad request should throw ClientError but actually thrown ${e}`
);
assert(
e.message === '[VertexAI.ClientError]: got status: 400 Bad Request',
`sys test failure on generateContentStream when having bad request got wrong error message: ${e.message}`
);
});
});
// TODO: this is returning a 500 on the system test project
// it('should should return a stream and aggregated response when passed
// multipart GCS content',
Expand Down
Loading

0 comments on commit cbd68db

Please sign in to comment.