Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PCSwingle committed Apr 5, 2024
1 parent 4f42e1c commit 1d9df0f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
2 changes: 1 addition & 1 deletion benchmarks/exercism_practice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from functools import partial
from pathlib import Path
from typing import List
from spice import SpiceMessage

import tqdm
from openai import BadRequestError
from spice import SpiceMessage

from benchmarks.arg_parser import common_benchmark_parser
from benchmarks.benchmark_result import BenchmarkResult
Expand Down
28 changes: 18 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from uuid import uuid4

import pytest
from spice import SpiceResponse
from spice.spice import SpiceCallArgs

from mentat import config
from mentat.agent_handler import AgentHandler
Expand Down Expand Up @@ -96,20 +98,26 @@ def mock_call_llm_api(mocker):
completion_mock = mocker.patch.object(LlmApiHandler, "call_llm_api")

def wrap_unstreamed_string(value):
mock_spice_response = MagicMock()
mock_spice_response.text = value

return mock_spice_response
return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], False), value, 1, 0, 0, True)

def wrap_streamed_strings(values):
async def _async_generator():
for value in values:
yield value
class MockStreamingSpiceResponse:
def __init__(self):
self.cur_value = 0

def __aiter__(self):
return self

async def __anext__(self):
if self.cur_value >= len(values):
raise StopAsyncIteration
self.cur_value += 1
return values[self.cur_value - 1]

mock_spice_response = MagicMock()
mock_spice_response.stream = _async_generator
mock_spice_response.text = "".join(values)
def current_response(self):
return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], True), "".join(values), 1, 0, 0, True)

mock_spice_response = MockStreamingSpiceResponse()
return mock_spice_response

def set_streamed_values(values):
Expand Down

0 comments on commit 1d9df0f

Please sign in to comment.