Linting cleanup in the model server
RyanMullins committed Oct 11, 2024
1 parent 9baac29 commit 60bdc7c
Showing 1 changed file with 19 additions and 20 deletions.
lit_nlp/examples/gcp/model_server.py (19 additions, 20 deletions)
@@ -4,6 +4,7 @@
 import functools
 import os
 from typing import Optional
+
 from absl import app
 from lit_nlp import dev_server
 from lit_nlp.examples.prompt_debugging import models as pd_models
Expand All @@ -24,7 +25,7 @@ def get_wsgi_app() -> wsgi_app.App:

def wrap_handler(predict_fn):
@functools.wraps(predict_fn)
def _handler(app, request, environ):
def _handler(app: wsgi_app.App, request, unused_environ):
data = serialize.from_json(request.data) if len(request.data) else None
inputs = data['inputs']
outputs = predict_fn(inputs)
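The change in this hunk is signature-only: app gains a wsgi_app.App annotation and the unused environ parameter is renamed unused_environ to satisfy the linter's unused-argument check. For readers unfamiliar with the wrapping pattern, here is a self-contained sketch of the same idea using the standard library's json in place of LIT's serialize module; FakeRequest and the echo lambda are illustrative stand-ins, not part of the LIT codebase:

import functools
import json


def wrap_handler(predict_fn):
  """Wraps a predict function as a request handler (sketch only)."""

  @functools.wraps(predict_fn)
  def _handler(unused_app, request, unused_environ):
    # Deserialize the request body, run the model, serialize the result.
    data = json.loads(request.data) if request.data else None
    outputs = predict_fn(data['inputs'])
    return json.dumps({'outputs': outputs})

  return _handler


# Hypothetical usage with a stand-in request object.
class FakeRequest:
  data = json.dumps({'inputs': ['hello']})


echo_handler = wrap_handler(lambda inputs: inputs)
print(echo_handler(None, FakeRequest(), None))  # {"outputs": ["hello"]}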
@@ -33,15 +34,19 @@ def _handler(app, request, environ):

     return _handler

-  model_config = os.getenv('MODEL_CONFIG', DEFAULT_MODELS).split(',')
-  dl_framework = os.environ.get('DL_FRAMEWORK', DEFAULT_DL_FRAMEWORK)
-  dl_runtime = os.environ.get('DL_RUNTIME', DEFAULT_DL_RUNTIME)
-  precision = os.environ.get('PRECISION', DEFAULT_PRECISION)
-  batch_size = os.environ.get('BATCH_SIZE', DEFAULT_BATCH_SIZE)
-  sequence_length = os.environ.get('SEQUENCE_LENGTH', DEFAULT_SEQUENCE_LENGTH)
+  if not (model_config := os.getenv('MODEL_CONFIG', DEFAULT_MODELS).split(',')):
+    raise ValueError('No model configuration was provided')
+  elif (num_configs := len(model_config)) > 1:
+    raise ValueError(
+        f'Only 1 model configuration can be provided, got {num_configs}'
+    )
+
+  dl_framework = os.getenv('DL_FRAMEWORK', DEFAULT_DL_FRAMEWORK)
+  dl_runtime = os.getenv('DL_RUNTIME', DEFAULT_DL_RUNTIME)
+  precision = os.getenv('PRECISION', DEFAULT_PRECISION)
+  batch_size = int(os.getenv('BATCH_SIZE', DEFAULT_BATCH_SIZE))
+  sequence_length = int(os.getenv('SEQUENCE_LENGTH', DEFAULT_SEQUENCE_LENGTH))

   # Parse flags without calling app.run(main), to avoid conflict with
   # gunicorn command line flags.
   models = pd_models.get_models(
       models_config=model_config,
       dl_framework=dl_framework,
@@ -51,19 +56,13 @@ def _handler(app, request, environ):
       sequence_length=sequence_length,
   )

-  if len(DEFAULT_MODELS) < 1:
-    raise ValueError('No models specified in DEFAULT_MODELS')
-  model_name = model_config[0].split(':')[0]
-  sal_name, tok_name = pd_utils.generate_model_group_names(model_name)
-
-  generation_model = models[model_name]
-  salience_model = models[sal_name]
-  tokenize_model = models[tok_name]
+  gen_name = model_config[0].split(':')[0]
+  sal_name, tok_name = pd_utils.generate_model_group_names(gen_name)

   handlers = {
-      '/predict': generation_model.predict,
-      '/salience': salience_model.predict,
-      '/tokenize': tokenize_model.predict,
+      '/predict': models[gen_name].predict,
+      '/salience': models[sal_name].predict,
+      '/tokenize': models[tok_name].predict,
   }

   wrapped_handlers = {
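Two substantive fixes ride along with the lint cleanup in the environment-variable block: MODEL_CONFIG is now validated to contain exactly one entry, and BATCH_SIZE and SEQUENCE_LENGTH are coerced with int() so a string value from the environment and an integer default behave the same. A standalone sketch of the pattern, with a placeholder default rather than the module's real DEFAULT_MODELS constant:

import os

# Placeholder default for illustration; not the value used by model_server.py.
DEFAULT_MODELS = 'gemma:gemma_1.1_instruct_2b_en'

if not (model_config := os.getenv('MODEL_CONFIG', DEFAULT_MODELS).split(',')):
  raise ValueError('No model configuration was provided')
elif (num_configs := len(model_config)) > 1:
  raise ValueError(
      f'Only 1 model configuration can be provided, got {num_configs}'
  )

# int() accepts an integer default as well as a numeric string from the
# environment, e.g. BATCH_SIZE=8.
batch_size = int(os.getenv('BATCH_SIZE', 4))

Note that str.split always returns at least one element (''.split(',') is ['']), so the first branch is purely defensive; it is the num_configs check that actually rejects multi-model values like 'model_a:path,model_b:path'.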

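With the handlers wired up, the server exposes /predict, /salience, and /tokenize routes. A hypothetical smoke test against a locally running instance; the host, port, and payload shape below are assumptions for illustration, not taken from this diff:

import json
import urllib.request

# Assumed local address; adjust to wherever the model server is running.
BASE_URL = 'http://localhost:5432'

payload = json.dumps({'inputs': [{'prompt': 'The capital of France is'}]})
req = urllib.request.Request(
    f'{BASE_URL}/predict',
    data=payload.encode('utf-8'),
    headers={'Content-Type': 'application/json'},
)
with urllib.request.urlopen(req) as resp:
  print(json.loads(resp.read()))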