Skip to content

Commit

Permalink
Add support outlines
Browse files Browse the repository at this point in the history
Signed-off-by: GitHub <noreply@github.com>
  • Loading branch information
Aisuko committed Nov 15, 2023
1 parent 991ecce commit 0423e4c
Show file tree
Hide file tree
Showing 9 changed files with 649 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@
"LIBRARY_PATH": "${workspaceFolder}/go-llama:${workspaceFolder}/go-stable-diffusion/:${workspaceFolder}/gpt4all/gpt4all-bindings/golang/:${workspaceFolder}/go-gpt2:${workspaceFolder}/go-rwkv:${workspaceFolder}/whisper.cpp:${workspaceFolder}/go-bert:${workspaceFolder}/bloomz",
"DEBUG": "true"
}
},
{
"name":"Launch outlines",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/backend/python/backend_outlines/backend_outlines.py",
"console": "integratedTerminal",
"justMyCode": true,
"env": {}
}
]
}
11 changes: 11 additions & 0 deletions backend/python/backend_outlines/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.PONY: outlines
outlines:
@echo "Creating virtual environment..."
@conda env create --name outlines --file outlines.yml
@echo "Virtual environment created."

.PONY: run
run:
@echo "Running outlines..."
bash run.sh
@echo "outlines run."
5 changes: 5 additions & 0 deletions backend/python/backend_outlines/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Creating a separate environment for the outlines project

```
make outlines
```
80 changes: 80 additions & 0 deletions backend/python/backend_outlines/backend_outlines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
This is the extra gRPC server for outlines of LocalAI
"""
from concurrent import futures
import argparse
import os
import signal
import sys
import time

import backend_pb2
import backend_pb2_grpc

import grpc

import outlines.text.generate as generate
import outlines.models as models

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
BackendServicer is the class that implements the gRPC service
"""
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

def LoadModel(self, request, context):
try:
# model should be name of the model, e.g. gpt2
if request.Model == "":
return backend_pb2.Result(success=False, message="Model name is empty")
# It includes cache of the model, we do not need to add cache here.
self.model = models.transformers(request.Model)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(message="Model loaded successfully", success=True)

def Predict(self, request, context):
try:
output=generate.continuation(self.model, stop=[str(request.StopPrompts)])(str(request.Prompt))
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(message=bytes(output, encoding='utf-8'))

def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print("Server started. Listening on: " + address, file=sys.stderr)

# Define the signal handler function
def signal_handler(sig, frame):
print("Received termination signal. Shutting down...")
server.stop(0)
sys.exit(0)

# Set the signal handlers for SIGINT and SIGTERM
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
parser.add_argument(
"--addr", default="localhost:50051", help="The address to bind the server to."
)
args = parser.parse_args()

serve(args.addr)
61 changes: 61 additions & 0 deletions backend/python/backend_outlines/backend_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 0423e4c

Please sign in to comment.