
Commit

WIP
gagik committed Dec 4, 2024
1 parent d0d59bc commit cd06aa0
Showing 2 changed files with 73 additions and 32 deletions.
5 changes: 3 additions & 2 deletions src/participant/prompts/promptHistory.ts
@@ -148,12 +148,13 @@ export class PromptHistory {
});
}
if (addedMessage) {
if (model && tokenLimit) {
totalUsedTokens += await model.countTokens(addedMessage);
if (tokenLimit) {
totalUsedTokens += (await model?.countTokens(addedMessage)) || 0;
if (totalUsedTokens > tokenLimit) {
break;
}
}

messages.push(addedMessage);
}
}
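
The change above relaxes the token-budget guard: it previously required both model and tokenLimit, and now only requires tokenLimit, counting 0 tokens for a message when no model is available (model?.countTokens(...) || 0). As the test comment further down notes, the model returned by vscode.lm.selectChatModels is always undefined in tests, so history messages are still appended rather than dropped. A minimal sketch of the resulting loop shape, using assumed names (ChatModelLike, filterByTokenBudget) rather than the extension's actual API:

// Sketch only: assumed names for illustration; the real logic lives in PromptHistory.
interface ChatModelLike {
  countTokens(text: string): Promise<number>;
}

async function filterByTokenBudget(
  candidates: string[],
  model: ChatModelLike | undefined,
  tokenLimit?: number
): Promise<string[]> {
  const messages: string[] = [];
  let totalUsedTokens = 0;

  for (const addedMessage of candidates) {
    if (tokenLimit) {
      // Without a model the message counts as 0 tokens and is still kept;
      // previously the entire check was skipped unless a model was present.
      totalUsedTokens += (await model?.countTokens(addedMessage)) || 0;
      if (totalUsedTokens > tokenLimit) {
        break;
      }
    }
    messages.push(addedMessage);
  }

  return messages;
}
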
100 changes: 70 additions & 30 deletions src/test/suite/participant/participant.test.ts
@@ -103,8 +103,12 @@ suite('Participant Controller Test Suite', function () {
button: sinon.SinonSpy;
};
let chatTokenStub;
let countTokensStub;
let countTokensStub: sinon.SinonStub;
let sendRequestStub: sinon.SinonStub;
let getCopilotModelStub: SinonStub<
[],
Promise<vscode.LanguageModelChat | undefined>
>;
let telemetryTrackStub: SinonSpy;

const invokeChatHandler = async (
@@ -231,18 +235,18 @@ suite('Participant Controller Test Suite', function () {
countTokensStub = sinon.stub();
// The model returned by vscode.lm.selectChatModels is always undefined in tests.
sendRequestStub = sinon.stub();
sinon.replace(model, 'getCopilotModel', () =>
Promise.resolve({
id: 'modelId',
vendor: 'copilot',
family: 'gpt-4o',
version: 'gpt-4o-date',
name: 'GPT 4o (date)',
maxInputTokens: MAX_TOTAL_PROMPT_LENGTH_MOCK,
countTokens: countTokensStub,
sendRequest: sendRequestStub,
})
);
getCopilotModelStub = sinon.stub(model, 'getCopilotModel');

getCopilotModelStub.resolves({
id: 'modelId',
vendor: 'copilot',
family: 'gpt-4o',
version: 'gpt-4o-date',
name: 'GPT 4o (date)',
maxInputTokens: MAX_TOTAL_PROMPT_LENGTH_MOCK,
countTokens: countTokensStub,
sendRequest: sendRequestStub,
});

sinon.replace(testTelemetryService, 'track', telemetryTrackStub);
});
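
The setup now obtains a reusable handle via sinon.stub(model, 'getCopilotModel') instead of replacing the method with a fixed factory through sinon.replace. Because the stub is kept in the typed getCopilotModelStub variable, individual tests can later re-resolve it with a different fake model (the history-trimming test at the end of this diff does exactly that), and sinon's restore machinery puts the original back afterwards. A rough usage sketch, assuming a simplified model object rather than the extension's real module:

import * as sinon from 'sinon';

// Simplified stand-in for the extension's model module (assumption for illustration).
const model = {
  getCopilotModel: async (): Promise<{ maxInputTokens: number } | undefined> =>
    undefined,
};

const getCopilotModelStub = sinon.stub(model, 'getCopilotModel');

// Default fake model that most tests rely on.
getCopilotModelStub.resolves({ maxInputTokens: 16000 });

// An individual test can narrow the fake model, e.g. to force history trimming.
getCopilotModelStub.resolves({ maxInputTokens: 10 });

// Restores the original getCopilotModel after the test run.
sinon.restore();
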
@@ -829,8 +833,7 @@ suite('Participant Controller Test Suite', function () {
});

test('includes 1 sample document as an object', async function () {
countTokensStub.resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK);
sampleStub.resolves([
const sampleDocs = [
{
_id: new ObjectId('63ed1d522d8573fa5c203660'),
field: {
@@ -849,13 +852,27 @@
],
},
},
]);
];

// This is the offset of the history token calculation calls
const callsOffset = 5;

// Called when including sample documents
countTokensStub
.onCall(callsOffset)
.resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK);

sampleStub.resolves(sampleDocs);

const chatRequestMock = {
prompt: 'find all docs by a name example',
command: 'query',
references: [],
};
await invokeChatHandler(chatRequestMock);

expect(countTokensStub).callCount(callsOffset + 1);

const messages = sendRequestStub.secondCall
.args[0] as vscode.LanguageModelChatMessage[];
expect(getMessageContent(messages[1])).to.include(
@@ -890,11 +907,7 @@
});

test('includes 1 sample documents when 3 make prompt too long', async function () {
countTokensStub
.onCall(0)
.resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1);
countTokensStub.onCall(1).resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK);
sampleStub.resolves([
const sampleDocs = [
{
_id: new ObjectId('63ed1d522d8573fa5c203661'),
field: {
@@ -913,13 +926,30 @@
stringField: 'Text 3',
},
},
]);
];

// This is the offset of the history token calculation calls
const callsOffset = 5;

// Called when including sample documents
countTokensStub
.onCall(callsOffset)
.resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1);
countTokensStub
.onCall(callsOffset + 1)
.resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK);

sampleStub.resolves(sampleDocs);

const chatRequestMock = {
prompt: 'find all docs by a name example',
command: 'query',
references: [],
};
await invokeChatHandler(chatRequestMock);

expect(countTokensStub).callCount(callsOffset + 1);

const messages = sendRequestStub.secondCall
.args[0] as vscode.LanguageModelChatMessage[];
expect(getMessageContent(messages[1])).to.include(
@@ -949,13 +979,18 @@
});

test('does not include sample documents when even 1 makes prompt too long', async function () {
// This is the offset of the history token calculation calls
const callsOffset = 5;

// Called when including sample documents
countTokensStub
.onCall(0)
.onCall(callsOffset)
.resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1);
countTokensStub
.onCall(1)
.onCall(callsOffset + 1)
.resolves(MAX_TOTAL_PROMPT_LENGTH_MOCK + 1);
sampleStub.resolves([

const sampleDocs = [
{
_id: new ObjectId('63ed1d522d8573fa5c203661'),
field: {
@@ -974,13 +1009,19 @@
stringField: 'Text 3',
},
},
]);
];

sampleStub.resolves(sampleDocs);

const chatRequestMock = {
prompt: 'find all docs by a name example',
command: 'query',
references: [],
};
await invokeChatHandler(chatRequestMock);

expect(countTokensStub).callCount(callsOffset + 1);

const messages = sendRequestStub.secondCall
.args[0] as vscode.LanguageModelChatMessage[];
expect(getMessageContent(messages[1])).to.not.include(
@@ -2045,12 +2086,11 @@ Schema:
(_, index) => `Message ${index}`
);

sinon.stub(model, 'getCopilotModel').resolves({
getCopilotModelStub.resolves({
// Make each message count as 1 token for testing
countTokens: countTokensStub.resolves(1),
maxInputTokens: expectedMaxMessages,
// Make each message count as 1 token
countTokens: () => 1,
} as unknown as vscode.LanguageModelChat);

chatContextStub = {
history: mockedMessages.map((messageText) =>
createChatRequestTurn(undefined, messageText)
