support openai embedding for topic clustering #2729

Merged
merged 4 commits on Jan 20, 2024
9 changes: 9 additions & 0 deletions fastchat/serve/monitor/summarize_cluster.py
@@ -6,6 +6,8 @@
import argparse
import pickle

import pandas as pd

from fastchat.llm_judge.common import (
chat_compeletion_openai,
chat_compeletion_openai_azure,
@@ -74,3 +76,10 @@ def truncate_string(s, l):
print()
print(f"topics: {topics}")
print(f"percentages: {percentages}")

# save the topics and percentages to a JSONL file
df = pd.DataFrame()
df["topic"] = topics
df["percentage"] = percentages

df.to_json(f"cluster_summary_{len(df)}.jsonl", lines=True, orient="records")
45 changes: 35 additions & 10 deletions fastchat/serve/monitor/topic_clustering.py
@@ -16,6 +16,7 @@
from sklearn.cluster import KMeans, AgglomerativeClustering
import torch
from tqdm import tqdm
from openai import OpenAI

from fastchat.utils import detect_language

@@ -46,6 +47,8 @@ def read_texts(input_file, min_length, max_length, english_only):
line_texts = [
x["content"] for x in l["conversation"] if x["role"] == "user"
]
elif "turns" in l:
line_texts = l["turns"]
Comment on lines +50 to +51
Member

sorry could you explain this a bit?

Collaborator Author

Yes, the input JSON files come in different formats. For example, the 54K prompt JSON file you sent me uses "turns" to store the conversation prompts, so I added support for JSON files that use "turns".
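For illustration, a minimal sketch of the two line formats that read_texts now handles (field values are made up, and the guard on the existing "conversation" branch is assumed from the surrounding code):

# Arena-style line: user prompts nested inside a "conversation" list
line_a = {"conversation": [{"role": "user", "content": "What is RLHF?"},
                           {"role": "assistant", "content": "..."}]}
# Prompt-dump style line: user prompts stored directly under "turns"
line_b = {"turns": ["What is RLHF?", "Give a concrete example."]}

for l in (line_a, line_b):
    if "conversation" in l:
        line_texts = [x["content"] for x in l["conversation"] if x["role"] == "user"]
    elif "turns" in l:
        line_texts = l["turns"]
    print(line_texts)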


for text in line_texts:
text = text.strip()
@@ -77,14 +80,26 @@ def read_texts(input_file, min_length, max_length, english_only):


def get_embeddings(texts, model_name, batch_size):
model = SentenceTransformer(model_name)
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=True,
device="cuda",
convert_to_tensor=True,
)
if model_name == "text-embedding-ada-002":
client = OpenAI()
texts = texts.tolist()

embeddings = []
for i in tqdm(range(0, len(texts), batch_size)):
text = texts[i : i + batch_size]
responses = client.embeddings.create(input=text, model=model_name).data
embeddings.extend([data.embedding for data in responses])
embeddings = torch.tensor(embeddings)
else:
model = SentenceTransformer(model_name)
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=True,
device="cuda",
convert_to_tensor=True,
)

embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
return embeddings.cpu()
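As a usage sketch of the new branch (assuming OPENAI_API_KEY is set in the environment for the OpenAI path, that texts arrives as a NumPy array as the .tolist() call implies, and that the model names are only examples):

import numpy as np

texts = np.array(["How do I sort a list in Python?", "Explain quantum entanglement simply."])

# OpenAI path: batches requests to the embeddings API, no GPU needed
emb = get_embeddings(texts, "text-embedding-ada-002", batch_size=32)

# Local path (requires a CUDA device, since encode() is called with device="cuda"):
# emb = get_embeddings(texts, "all-mpnet-base-v2", batch_size=32)

print(emb.shape)  # (2, embedding_dim); rows are L2-normalized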

@@ -218,6 +233,8 @@ def get_cluster_info(texts, labels, topk_indices):
)
parser.add_argument("--show-top-k", type=int, default=200)
parser.add_argument("--show-cut-off", type=int, default=512)
parser.add_argument("--save-embeddings", action="store_true")
parser.add_argument("--embeddings-file", type=str, default=None)
Member

this embedding cache is awesome!

args = parser.parse_args()

num_clusters = args.num_clusters
@@ -229,7 +246,15 @@ def get_cluster_info(texts, labels, topk_indices):
)
print(f"#text: {len(texts)}")

embeddings = get_embeddings(texts, args.model, args.batch_size)
if args.embeddings_file is None:
embeddings = get_embeddings(texts, args.model, args.batch_size)
if args.save_embeddings:
# cache embeddings to disk to save time and API cost on later runs
torch.save(embeddings, "embeddings.pt")
else:
embeddings = torch.load(args.embeddings_file)
print(f"embeddings shape: {embeddings.shape}")

if args.cluster_alg == "kmeans":
centers, labels = run_k_means(embeddings, num_clusters)
elif args.cluster_alg == "aggcls":
@@ -249,7 +274,7 @@ def get_cluster_info(texts, labels, topk_indices):
with open(filename_prefix + "_topk.txt", "w") as fout:
fout.write(topk_str)

with open(filename_prefix + "_all.txt", "w") as fout:
with open(filename_prefix + "_all.jsonl", "w") as fout:
for i in range(len(centers)):
tmp_indices = labels == i
tmp_embeddings = embeddings[tmp_indices]
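Taken together, a typical two-step workflow with the new flags might look like the following (a sketch only; flag names other than --save-embeddings and --embeddings-file, and the input file name, are assumptions based on the surrounding code):

python3 fastchat/serve/monitor/topic_clustering.py --input-file prompts.json --model text-embedding-ada-002 --save-embeddings
python3 fastchat/serve/monitor/topic_clustering.py --input-file prompts.json --embeddings-file embeddings.pt

The first run calls the OpenAI embeddings API once and caches the result to embeddings.pt; later runs reuse the cache, so re-clustering with different settings does not trigger additional API calls.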