Skip to content

Commit

Permalink
working
Browse files Browse the repository at this point in the history
  • Loading branch information
alx13 committed Mar 8, 2024
1 parent 269ea48 commit 7343708
Show file tree
Hide file tree
Showing 12 changed files with 360 additions and 25 deletions.
8 changes: 5 additions & 3 deletions packages/main/src/methods/chat-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ export class ChatSession {
"Missing role for history item: " + JSON.stringify(content),
);
}
return formatNewContent(content.parts, content.role);
return formatNewContent(content.parts);
});
}
}
Expand All @@ -81,10 +81,11 @@ export class ChatSession {
request: string | Array<string | Part>,
): Promise<GenerateContentResult> {
await this._sendPromise;
const newContent = formatNewContent(request, "user");
const newContent = formatNewContent(request);
const generateContentRequest: GenerateContentRequest = {
safetySettings: this.params?.safetySettings,
generationConfig: this.params?.generationConfig,
tools: this.params?.tools,
contents: [...this._history, newContent],
};
let finalResult;
Expand Down Expand Up @@ -134,10 +135,11 @@ export class ChatSession {
request: string | Array<string | Part>,
): Promise<GenerateContentStreamResult> {
await this._sendPromise;
const newContent = formatNewContent(request, "user");
const newContent = formatNewContent(request);
const generateContentRequest: GenerateContentRequest = {
safetySettings: this.params?.safetySettings,
generationConfig: this.params?.generationConfig,
tools: this.params?.tools,
contents: [...this._history, newContent],
};
const streamPromise = generateContentStream(
Expand Down
53 changes: 49 additions & 4 deletions packages/main/src/requests/request-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ import {
EmbedContentRequest,
GenerateContentRequest,
Part,
Role,
} from "../../types";
import {GoogleGenerativeAIError} from "../errors";

export function formatNewContent(
request: string | Array<string | Part>,
role: string,
): Content {
let newParts: Part[] = [];
if (typeof request === "string") {
Expand All @@ -38,7 +39,52 @@ export function formatNewContent(
}
}
}
return { role, parts: newParts };
// return { role, parts: newParts };
return assignRoleToPartsAndValidateSendMessageRequest(newParts);
}

/**
* When multiple Part types (i.e. FunctionResponsePart and TextPart) are
* passed in a single Part array, we may need to assign different roles to each
* part. Currently only FunctionResponsePart requires a role other than 'user'.
* @private
* @param parts Array of parts to pass to the model
* @returns Array of content items
*/
function assignRoleToPartsAndValidateSendMessageRequest(
parts: Part[]
): Content {
const userContent: Content = {role: Role.USER, parts: []};
const functionContent: Content = {role: Role.FUNCTION, parts: []};
let hasUserContent = false;
let hasFunctionContent = false;
for (const part of parts) {
if ('functionResponse' in part) {
functionContent.parts.push(part);
hasFunctionContent = true;
} else {
userContent.parts.push(part);
hasUserContent = true;
}
}

if (hasUserContent && hasFunctionContent) {
throw new GoogleGenerativeAIError(
'Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.'
);
}

if (!hasUserContent && !hasFunctionContent) {
throw new GoogleGenerativeAIError(
'Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.'
);
}

if (hasUserContent) {
return userContent;
}

return functionContent;
}

export function formatGenerateContentInput(
Expand All @@ -49,7 +95,6 @@ export function formatGenerateContentInput(
} else {
const content = formatNewContent(
params as string | Array<string | Part>,
"user",
);
return { contents: [content] };
}
Expand All @@ -59,7 +104,7 @@ export function formatEmbedContentInput(
params: EmbedContentRequest | string | Array<string | Part>,
): EmbedContentRequest {
if (typeof params === "string" || Array.isArray(params)) {
const content = formatNewContent(params, "user");
const content = formatNewContent(params);
return { content };
}
return params;
Expand Down
39 changes: 33 additions & 6 deletions packages/main/src/requests/response-helpers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,34 @@ import {

use(sinonChai);

const fakeResponse: GenerateContentResponse = {
const fakeResponseText: GenerateContentResponse = {
candidates: [
{
index: 0,
content: {
role: "model",
parts: [{ text: "Some text" }],
parts: [{ text: "Some text" }, { text: " and some more text" }],
},
},
],
};
const fakeResponseFunctionCall: GenerateContentResponse = {
candidates: [
{
index: 0,
content: {
role: "model",
parts: [
{
functionCall: {
name: "find_theaters",
args: {
location: "Mountain View, CA",
movie: "Barbie",
},
},
},
],
},
},
],
Expand All @@ -52,11 +73,17 @@ describe("response-helpers methods", () => {
restore();
});
describe("addHelpers", () => {
it("good response", async () => {
const enhancedResponse = addHelpers(fakeResponse);
expect(enhancedResponse.text()).to.equal("Some text");
it("good response text", async () => {
const enhancedResponse = addHelpers(fakeResponseText);
expect(enhancedResponse.text()).to.equal("Some text and some more text");
});
it("good response functionCall", async () => {
const enhancedResponse = addHelpers(fakeResponseFunctionCall);
expect(enhancedResponse.functionCall()).to.deep.equal(
fakeResponseFunctionCall.candidates[0].content.parts[0].functionCall,
);
});
it("bad response", async () => {
it("bad response safety", async () => {
const enhancedResponse = addHelpers(badFakeResponse);
expect(enhancedResponse.text).to.throw("SAFETY");
});
Expand Down
34 changes: 33 additions & 1 deletion packages/main/src/requests/response-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import {
EnhancedGenerateContentResponse,
FinishReason,
FunctionCall,
GenerateContentCandidate,
GenerateContentResponse,
} from "../../types";
Expand Down Expand Up @@ -54,6 +55,30 @@ export function addHelpers(
}
return "";
};
(response as EnhancedGenerateContentResponse).functionCall = () => {
if (response.candidates && response.candidates.length > 0) {
if (response.candidates.length > 1) {
console.warn(
`This response had ${response.candidates.length} ` +
`candidates. Returning text from the first candidate only. ` +
`Access response.candidates directly to use the other candidates.`,
);
}
if (hadBadFinishReason(response.candidates[0])) {
throw new GoogleGenerativeAIResponseError<GenerateContentResponse>(
`${formatBlockErrorMessage(response)}`,
response,
);
}
return getFunctionCall(response);
} else if (response.promptFeedback) {
throw new GoogleGenerativeAIResponseError<GenerateContentResponse>(
`Text not available. ${formatBlockErrorMessage(response)}`,
response,
);
}
return undefined;
};
return response as EnhancedGenerateContentResponse;
}

Expand All @@ -62,12 +87,19 @@ export function addHelpers(
*/
export function getText(response: GenerateContentResponse): string {
if (response.candidates?.[0].content?.parts?.[0]?.text) {
return response.candidates[0].content.parts[0].text;
return response.candidates[0].content.parts.map((({text})=>text)).join("");
} else {
return "";
}
}

/**
* Returns functionCall of first candidate.
*/
export function getFunctionCall(response: GenerateContentResponse): FunctionCall {
return response.candidates?.[0].content?.parts?.[0]?.functionCall;
}

const badFinishReasons = [FinishReason.RECITATION, FinishReason.SAFETY];

function hadBadFinishReason(candidate: GenerateContentCandidate): boolean {
Expand Down
25 changes: 22 additions & 3 deletions packages/main/src/requests/stream-reader.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,25 @@ describe("processStream", () => {
expect(aggregatedResponse.text()).to.include("秋风瑟瑟,叶落纷纷");
expect(aggregatedResponse.text()).to.include("家人围坐在一起");
});
it("streaming response - functioncall", async () => {
const fakeResponse = getMockResponseStreaming(
"streaming-success-function-call-short.txt",
);
const result = processStream(fakeResponse as Response);
for await (const response of result.stream) {
expect(response.text()).to.be.empty;
expect(response.functionCall()).to.be.deep.equal({
name: "getTemperature",
args: { city: "San Jose" },
});
}
const aggregatedResponse = await result.response;
expect(aggregatedResponse.text()).to.be.empty;
expect(aggregatedResponse.functionCall()).to.be.deep.equal({
name: "getTemperature",
args: { city: "San Jose" },
});
});
it("candidate had finishReason", async () => {
const fakeResponse = getMockResponseStreaming(
"streaming-failure-finish-reason-safety.txt",
Expand Down Expand Up @@ -335,9 +354,9 @@ describe("aggregateResponses", () => {

it("aggregates text across responses", () => {
expect(response.candidates.length).to.equal(1);
expect(response.candidates[0].content.parts[0].text).to.equal(
"hello.angry stuff...more stuff",
);
expect(
response.candidates[0].content.parts.map(({ text }) => text),
).to.deep.equal(["hello.", "angry stuff", "...more stuff"]);
});

it("takes the last response's promptFeedback", () => {
Expand Down
15 changes: 12 additions & 3 deletions packages/main/src/requests/stream-reader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
GenerateContentCandidate,
GenerateContentResponse,
GenerateContentStreamResult,
Part,
} from "../../types";
import { GoogleGenerativeAIError } from "../errors";
import { addHelpers } from "./response-helpers";
Expand Down Expand Up @@ -166,14 +167,22 @@ export function aggregateResponses(
if (!aggregatedResponse.candidates[i].content) {
aggregatedResponse.candidates[i].content = {
role: candidate.content.role || "user",
parts: [{ text: "" }],
parts: [],
};
}
const newPart: Partial<Part> = {};
for (const part of candidate.content.parts) {
if (part.text) {
aggregatedResponse.candidates[i].content.parts[0].text +=
part.text;
newPart.text = part.text;
}
if (part.functionCall) {
newPart.functionCall =
part.functionCall;
}
if (Object.keys(newPart).length === 0) {
newPart.text = "";
}
aggregatedResponse.candidates[i].content.parts.push(newPart as Part);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use(chaiAsPromised);
* Integration tests against live backend.
*/

describe("generateContent", function () {
describe("generateContent - multimodal", function () {
this.timeout(60e3);
this.slow(10e3);
it("non-streaming, image buffer provided", async () => {
Expand Down
Loading

0 comments on commit 7343708

Please sign in to comment.