-
-
Notifications
You must be signed in to change notification settings - Fork 10
/
llm_gemini.py
100 lines (92 loc) · 3.18 KB
/
llm_gemini.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import base64
import httpx
import ijson
import llm
import urllib.parse
# We disable all of these to avoid random unexpected errors
SAFETY_SETTINGS = [
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE",
},
]
@llm.hookimpl
def register_models(register):
register(GeminiPro("gemini-pro"))
register(GeminiPro("gemini-1.5-pro-latest"))
class GeminiPro(llm.Model):
can_stream = True
supports_images = True
def __init__(self, model_id):
self.model_id = model_id
def build_messages(self, prompt, conversation):
messages = []
if conversation:
for response in conversation.responses:
messages.append(
{"role": "user", "parts": [{"text": response.prompt.prompt}]}
)
messages.append({"role": "model", "parts": [{"text": response.text()}]})
if prompt.images:
for image in prompt.images:
messages.append(
{
"role": "user",
"parts": [
{
"inlineData": {
"mimeType": "image/jpeg",
"data": base64.b64encode(image.read()).decode(
"utf-8"
),
}
}
],
}
)
messages.append({"role": "user", "parts": [{"text": prompt.prompt}]})
return messages
def execute(self, prompt, stream, response, conversation):
key = llm.get_key("", "gemini", "LLM_GEMINI_KEY")
url = "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?".format(
self.model_id
) + urllib.parse.urlencode(
{"key": key}
)
gathered = []
with httpx.stream(
"POST",
url,
timeout=None,
json={
"contents": self.build_messages(prompt, conversation),
"safetySettings": SAFETY_SETTINGS,
},
) as http_response:
events = ijson.sendable_list()
coro = ijson.items_coro(events, "item")
for chunk in http_response.iter_bytes():
coro.send(chunk)
if events:
event = events[0]
if isinstance(event, dict) and "error" in event:
raise llm.ModelError(event["error"]["message"])
try:
yield event["candidates"][0]["content"]["parts"][0]["text"]
except KeyError:
yield ""
gathered.append(event)
events.clear()
response.response_json = gathered