Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an API example using server.cpp similar to OAI. #2009

Merged
merged 11 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,19 @@ Run with bash:
```sh
bash chat.sh
```

### API like OAI

API example using Python Flask: [api_like_OAI.py](api_like_OAI.py)
This example must be used with server.cpp

```sh
python api_like_OAI.py
```

After running the API server, you can use it in Python by setting the API base URL.
```python
openai.api_base = "http://<Your api-server IP>:port"
```

Then you can utilize llama.cpp as an OpenAI's **chat.completion** or **text_completion** API
214 changes: 214 additions & 0 deletions examples/server/api_like_OAI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import argparse
from flask import Flask, jsonify, request, Response
import urllib.parse
import requests
import time
import json


app = Flask(__name__)

parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
parser.add_argument("--host", type=str, help="Set the ip address to listen.(default: 127.0.0.1)", default='127.0.0.1')
parser.add_argument("--port", type=int, help="Set the port to listen.(default: 8081)", default=8081)

args = parser.parse_args()

def is_present(json, key):
try:
buf = json[key]
except KeyError:
return False
return True



#convert chat to prompt
def convert_chat(messages):
prompt = "" + args.chat_prompt.replace("\\n", "\n")

system_n = args.system_name.replace("\\n", "\n")
user_n = args.user_name.replace("\\n", "\n")
ai_n = args.ai_name.replace("\\n", "\n")
stop = args.stop.replace("\\n", "\n")


for line in messages:
if (line["role"] == "system"):
prompt += f"{system_n}{line['content']}"
if (line["role"] == "user"):
prompt += f"{user_n}{line['content']}"
if (line["role"] == "assistant"):
prompt += f"{ai_n}{line['content']}{stop}"
prompt += ai_n.rstrip()

return prompt

def make_postData(body, chat=False, stream=False):
postData = {}
if (chat):
postData["prompt"] = convert_chat(body["messages"])
else:
postData["prompt"] = body["prompt"]
if(is_present(body, "temperature")): postData["temperature"] = body["temperature"]
if(is_present(body, "top_k")): postData["top_k"] = body["top_k"]
if(is_present(body, "top_p")): postData["top_p"] = body["top_p"]
if(is_present(body, "max_tokens")): postData["n_predict"] = body["max_tokens"]
if(is_present(body, "presence_penalty")): postData["presence_penalty"] = body["presence_penalty"]
if(is_present(body, "frequency_penalty")): postData["frequency_penalty"] = body["frequency_penalty"]
if(is_present(body, "repeat_penalty")): postData["repeat_penalty"] = body["repeat_penalty"]
if(is_present(body, "mirostat")): postData["mirostat"] = body["mirostat"]
if(is_present(body, "mirostat_tau")): postData["mirostat_tau"] = body["mirostat_tau"]
if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"]
if(is_present(body, "seed")): postData["seed"] = body["seed"]
if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()]
postData["stop"] = [args.stop]
if(is_present(body, "stop")): postData["stop"] += body["stop"]
postData["n_keep"] = -1
postData["stream"] = stream

return postData

def make_resData(data, chat=False, promptToken=[]):
resData = {
"id": "chatcmpl" if (chat) else "cmpl",
"object": "chat.completion" if (chat) else "text_completion",
"created": int(time.time()),
"model": "LLaMA_CPP",
"usage": {
"prompt_tokens": data["tokens_evaluated"],
"completion_tokens": data["tokens_predicted"],
"total_tokens": data["tokens_evaluated"] + data["tokens_predicted"]
}
}
if (len(promptToken) != 0):
resData["promptToken"] = promptToken
if (chat):
#only one choice is supported
resData["choices"] = [{
"index": 0,
"message": {
"role": "assistant",
"content": data["content"],
},
"finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
}]
else:
#only one choice is supported
resData["choices"] = [{
"text": data["content"],
"index": 0,
"logprobs": None,
"finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
}]
return resData

def make_resData_stream(data, chat=False, time_now = 0, start=False):
resData = {
"id": "chatcmpl" if (chat) else "cmpl",
"object": "chat.completion.chunk" if (chat) else "text_completion.chunk",
"created": time_now,
"model": "LLaMA_CPP",
"choices": [
{
"finish_reason": None,
"index": 0
}
]
}
if (chat):
if (start):
resData["choices"][0]["delta"] = {
"role": "assistant"
}
else:
resData["choices"][0]["delta"] = {
"content": data["content"]
}
if (data["stop"]):
resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
else:
resData["choices"][0]["text"] = data["content"]
if (data["stop"]):
resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"

return resData


@app.route('/chat/completions', methods=['POST'])
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
return Response(status=403)
body = request.get_json()
stream = False
tokenize = False
if(is_present(body, "stream")): stream = body["stream"]
if(is_present(body, "tokenize")): tokenize = body["tokenize"]
postData = make_postData(body, chat=True, stream=stream)

promptToken = []
if (tokenize):
tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
promptToken = tokenData["tokens"]

if (not stream):
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
resData = make_resData(data.json(), chat=True, promptToken=promptToken)
return jsonify(resData)
else:
def generate():
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
time_now = int(time.time())
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
yield 'data: {}\n'.format(json.dumps(resData))
for line in data.iter_lines():
if line:
decoded_line = line.decode('utf-8')
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
yield 'data: {}\n'.format(json.dumps(resData))
return Response(generate(), mimetype='text/event-stream')


@app.route('/completions', methods=['POST'])
@app.route('/v1/completions', methods=['POST'])
def completion():
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
return Response(status=403)
body = request.get_json()
stream = False
tokenize = False
if(is_present(body, "stream")): stream = body["stream"]
if(is_present(body, "tokenize")): tokenize = body["tokenize"]
postData = make_postData(body, chat=False, stream=stream)

promptToken = []
if (tokenize):
tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
promptToken = tokenData["tokens"]

if (not stream):
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
resData = make_resData(data.json(), chat=False, promptToken=promptToken)
return jsonify(resData)
else:
def generate():
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
time_now = int(time.time())
for line in data.iter_lines():
if line:
decoded_line = line.decode('utf-8')
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
yield 'data: {}\n'.format(json.dumps(resData))
return Response(generate(), mimetype='text/event-stream')


if __name__ == '__main__':
app.run(args.host, port=args.port)
14 changes: 9 additions & 5 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ struct llama_server_context {
bool has_next_token = false;
std::string generated_text;

size_t num_prompt_tokens = 0;
size_t num_tokens_predicted = 0;
size_t n_past = 0;
size_t n_remain = 0;
Expand Down Expand Up @@ -139,6 +140,7 @@ struct llama_server_context {

void rewind() {
params.antiprompt.clear();
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
Expand Down Expand Up @@ -169,17 +171,18 @@ struct llama_server_context {
void loadPrompt() {
params.prompt.insert(0, 1, ' '); // always add a first space
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
num_prompt_tokens = prompt_tokens.size();

if (params.n_keep < 0) {
params.n_keep = (int)prompt_tokens.size();
params.n_keep = (int)num_prompt_tokens;
}
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);

// if input prompt is too big, truncate like normal
if (prompt_tokens.size() >= (size_t)params.n_ctx) {
if (num_prompt_tokens>= (size_t)params.n_ctx) {
const int n_left = (params.n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_left - 1) / n_left;
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());

Expand All @@ -193,15 +196,15 @@ struct llama_server_context {
truncated = true;
prompt_tokens = new_tokens;
} else {
const size_t ps = prompt_tokens.size();
const size_t ps = num_prompt_tokens;
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
}

// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
embd = prompt_tokens;
if (n_past == prompt_tokens.size()) {
if (n_past == num_prompt_tokens) {
// we have to evaluate at least 1 token to generate logits.
n_past--;
}
Expand Down Expand Up @@ -684,6 +687,7 @@ static json format_final_response(llama_server_context & llama, const std::strin
{ "stop", true },
{ "model", llama.params.model_alias },
{ "tokens_predicted", llama.num_tokens_predicted },
{ "tokens_evaluated", llama.num_prompt_tokens },
{ "generation_settings", format_generation_settings(llama) },
{ "prompt", llama.params.prompt },
{ "truncated", llama.truncated },
Expand Down