Skip to content

Commit

Permalink
Code health update on model server tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanMullins committed Oct 11, 2024
1 parent 2488aa7 commit 9baac29
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions lit_nlp/examples/gcp/model_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from lit_nlp.examples.gcp import model_server
from lit_nlp.examples.prompt_debugging import utils as pd_utils
import webtest


class TestWSGIApp(absltest.TestCase):
class TestWSGIApp(parameterized.TestCase):

@mock.patch('lit_nlp.examples.prompt_debugging.models.get_models')
def test_predict_endpoint(self, mock_get_models):
@classmethod
def setUpClass(cls):
test_model_name = 'lit_on_gcp_test_model'
sal_name, tok_name = pd_utils.generate_model_group_names(test_model_name)
test_model_config = f'{test_model_name}:test_model_path'
os.environ['MODEL_CONFIG'] = test_model_config

mock_model = mock.MagicMock()
mock_model.predict.side_effect = [[{'response': 'test output text'}]]
generation_model = mock.MagicMock()
generation_model.predict.side_effect = [[{'response': 'test output text'}]]

salience_model = mock.MagicMock()
salience_model.predict.side_effect = [[{
Expand All @@ -30,33 +32,42 @@ def test_predict_endpoint(self, mock_get_models):
[{'tokens': ['test', 'output', 'text']}]
]

sal_name, tok_name = pd_utils.generate_model_group_names(test_model_name)

mock_get_models.return_value = {
test_model_name: mock_model,
cls.mock_models = {
test_model_name: generation_model,
sal_name: salience_model,
tok_name: tokenize_model,
}
app = webtest.TestApp(model_server.get_wsgi_app())

response = app.post_json('/predict', {'inputs': 'test_input'})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json, [{'response': 'test output text'}])

response = app.post_json('/salience', {'inputs': 'test_input'})
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json,
[{
@parameterized.named_parameters(
dict(
testcase_name='predict',
endpoint='/predict',
expected=[{'response': 'test output text'}],
),
dict(
testcase_name='salience',
endpoint='/salience',
expected=[{
'tokens': ['test', 'output', 'text'],
'grad_l2': [0.1234, 0.3456, 0.5678],
'grad_dot_input': [0.1234, -0.3456, 0.5678],
}],
)
),
dict(
testcase_name='tokenize',
endpoint='/tokenize',
expected=[{'tokens': ['test', 'output', 'text']}],
),
)
@mock.patch('lit_nlp.examples.prompt_debugging.models.get_models')
def test_endpoint(self, mock_get_models, endpoint, expected):
mock_get_models.return_value = self.mock_models
app = webtest.TestApp(model_server.get_wsgi_app())

response = app.post_json('/tokenize', {'inputs': 'test_input'})
response = app.post_json(endpoint, {'inputs': [{'prompt': 'test input'}]})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json, [{'tokens': ['test', 'output', 'text']}])
self.assertEqual(response.json, expected)


if __name__ == '__main__':
Expand Down

0 comments on commit 9baac29

Please sign in to comment.