-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_context.py
75 lines (56 loc) · 1.97 KB
/
gen_context.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Fetch and cache retrieval context for benchmark questions.

For up to MAX_QUESTIONS questions, look up previously stored context chunk
ids in Redis; for questions without cached context, retrieve context via an
LLM-backed search over the corpus and store the resulting ids and texts
back into Redis.
"""
import logging
import os
import sys
from dotenv import load_dotenv
from llm import LLM
from openai import OpenAI
from redis import Redis
# isort: off
# Local imports
from de_wiki_context import get_context_ids, load_corpus
from data.read_data import questions
# Imports from the main project
# load_dotenv() must run before the sys.path append below: MAIN_PROJECT is
# read from the environment populated by the .env file.
load_dotenv()
sys.path.append(os.environ["MAIN_PROJECT"])
from access_redis import ( # noqa: E402
    create_redis_client,
    get_redis_context_ids,
    put_redis_context_ids,
)
# isort: on
# Upper bound on how many questions to process in one run.
MAX_QUESTIONS = 25
if __name__ == "__main__":
    # Wire up logging and the external services: the Pulze-hosted OpenAI
    # endpoint, the local corpus, the LLM wrapper, and the Redis cache.
    logging.basicConfig(format="%(asctime)s %(message)s", level=logging.WARNING)
    client = OpenAI(
        api_key=os.environ["PULZE_API_KEY"],
        base_url="https://api.pulze.ai/v1",
    )
    corpus = load_corpus()
    llm = LLM(client, "pulze", 1000)
    redis_: Redis = create_redis_client()

    # Run counters, reported in the summary even on early interruption.
    existed = saved = saved_chunks = errored = 0
    try:
        # zip with a bounded range caps the run at MAX_QUESTIONS questions.
        for _n, question in zip(range(MAX_QUESTIONS), questions()):
            print()
            print(question)
            # Skip questions whose context is already cached in Redis.
            cached_ids = get_redis_context_ids(redis_, question.question_id)
            if cached_ids:
                print("Already there, skipping getting context")
                existed += 1
                continue
            # Retrieve fresh context; an empty result counts as a failure.
            context_ids, _ = get_context_ids(question.phrase, corpus, llm)
            if not context_ids:
                print("Did not get context")
                errored += 1
                continue
            # Persist both the chunk ids and their texts for later reuse.
            context_texts = [corpus.data[chunk_id]["text"] for chunk_id in context_ids]
            put_redis_context_ids(redis_, question.question_id, context_ids, context_texts)
            for chunk_id in context_ids:
                print(corpus.format_chunk(chunk_id))
            saved += 1
            saved_chunks += len(context_ids)
    finally:
        # Always print the summary, even if the loop was interrupted.
        print(f"Saved {saved_chunks} context chunks for {saved} questions")
        print(f"There existed {existed} questions with context already provided")
        print(f"Weren't able to retrieve context for {errored} questions")