From 0896fca7a04fa4beee5c11fa0d9a4ffedd487342 Mon Sep 17 00:00:00 2001 From: Nate Baer Date: Fri, 21 Jun 2024 00:01:53 -0700 Subject: [PATCH] Improve stats and cost warnings --- src/simulacrum.py | 11 ++--------- src/telegram/telegram_bot.py | 26 ++++++++++++++------------ 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/src/simulacrum.py b/src/simulacrum.py index 747ec4d..a5b0086 100644 --- a/src/simulacrum.py +++ b/src/simulacrum.py @@ -7,9 +7,7 @@ class Simulacrum: def __init__(self, context_file): self.context = Context(context_file) - self.last_cost = None - self.last_prompt_tokens = None - self.last_completion_tokens = None + self.last_completion = None self.warned_about_cost = False async def chat(self, user_input, user_name, image_url): @@ -19,7 +17,7 @@ async def chat(self, user_input, user_name, image_url): self.context.add_message("user", user_input, image_url) completion = await ChatExecutor(self.context).execute() content = completion.content.strip() - self._set_stats(completion) + self.last_completion = completion self.context.add_message("assistant", content) self.context.save() speech = self._filter_hidden(content) @@ -68,11 +66,6 @@ def _filter_hidden(self, response): response = re.sub(r".*?", "", response, flags=re.DOTALL) return response.strip() - def _set_stats(self, completion): - self.last_cost = completion.cost - self.last_prompt_tokens = completion.prompt_tokens - self.last_completion_tokens = completion.completion_tokens - def _apply_attribution(self, input, name): attribute_messages = self.context.vars.get("attribute_messages") if attribute_messages and name: diff --git a/src/telegram/telegram_bot.py b/src/telegram/telegram_bot.py index b3acce7..6f538e9 100644 --- a/src/telegram/telegram_bot.py +++ b/src/telegram/telegram_bot.py @@ -87,10 +87,11 @@ async def stats_command_handler(self, ctx): lines.append("*Conversation*") lines.append(f"`Cost: ${self.sim.get_conversation_cost():.2f}`") lines.append("\n*Last Message*") - if self.sim.last_cost: - lines.append(f"`Cost: ${self.sim.last_cost:.2f}`") - lines.append(f"`Prompt tokens: {self.sim.last_prompt_tokens}`") - lines.append(f"`Completion tokens: {self.sim.last_completion_tokens}`") + last_completion = self.sim.last_completion + if last_completion: + lines.append(f"`Cost: ${last_completion.cost:.2f}`") + lines.append(f"`Prompt tokens: {last_completion.prompt_tokens}`") + lines.append(f"`Completion tokens: {last_completion.completion_tokens}`") else: lines.append("`Not available`") await ctx.send_message("\n".join(lines)) @@ -150,16 +151,17 @@ async def _chat(self, ctx, user_message, image_url=None): response = await self.sim.chat(user_message, ctx.user_name, image_url) response = response.translate(str.maketrans("*_", "_*")) await ctx.send_message(response) - await self._warn_high_cost(ctx) + await self._warn_cost(ctx) - async def _warn_high_cost(self, ctx): - if not self.sim.last_cost: + async def _warn_cost(self, ctx, threshold_high=0.15, threshold_elevated=0.10): + cost = self.sim.last_completion.cost + if not cost: return - if self.sim.last_cost > 0.20: - await ctx.send_message("`🔴 Cost is high. Start a new conversation soon.`") - elif self.sim.last_cost > 0.15 and not self.sim.warned_about_cost: - self.sim.warned_about_cost = True + if cost > threshold_high: + await ctx.send_message("🔴 Cost is high. Start a new conversation soon.") + elif cost > threshold_elevated and not self.sim.cost_warning_sent: + self.sim.cost_warning_sent = True await ctx.send_message( - "`🟡 Cost is elevated. Start a new conversation when ready.`" + "🟡 Cost is elevated. Start a new conversation when ready." )