Anthropic bugfix - upgrading sdk, tokenizer usage (#597)
DhruvaBansal00 authored Oct 12, 2023
1 parent 374dc58 commit 86a1e56
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
     "numpy >= 1.23.0",
     "requests >= 2.27.0",
     "datasets >= 2.7.0",
-    "langchain == 0.0.210",
+    "langchain == 0.0.226",
     "nervaluate >= 0.1.8",
     "pandas >= 1.3.0",
     "scikit-learn >= 1.0.0",
@@ -66,7 +66,7 @@ openai = [
     "tiktoken >= 0.3.3"
 ]
 anthropic = [
-    "anthropic == 0.2.6"
+    "anthropic == 0.3.0"
 ]
 huggingface = [
     "transformers >= 4.25.0",
@@ -88,7 +88,7 @@ all = [
     "pre-commit",
     "openai >= 0.27.4",
     "tiktoken >= 0.3.3",
-    "anthropic == 0.2.6",
+    "anthropic == 0.3.0",
     "transformers >= 4.25.0",
     "google-cloud-aiplatform>=1.25.0",
     "google-search-results>=2.4.2",
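The pins above move the Anthropic SDK from 0.2.6 to 0.3.0 in both the anthropic extra and the all extra, and bump langchain to 0.0.226. A minimal sanity-check sketch, not part of the commit, for confirming the environment picked up the new pin after reinstalling the extras; it assumes the anthropic package exposes __version__ (the 0.3.x releases do):

# Hypothetical check, not from this commit: verify the repinned SDK is
# installed before relying on its private tokenizer module (used below).
import anthropic

assert anthropic.__version__ == "0.3.0", (
    f"expected anthropic 0.3.0, found {anthropic.__version__}"
)
print("anthropic", anthropic.__version__)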
8 changes: 4 additions & 4 deletions src/autolabel/models/anthropic.py
@@ -31,7 +31,7 @@ def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
 
         try:
             from langchain.chat_models import ChatAnthropic
-            from anthropic import tokenizer
+            from anthropic._tokenizers import sync_get_tokenizer
         except ImportError:
             raise ImportError(
                 "anthropic is required to use the anthropic LLM. Please install it with the following command: pip install 'refuel-autolabel[anthropic]'"
@@ -45,7 +45,7 @@ def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
         # initialize LLM
         self.llm = ChatAnthropic(model=self.model_name, **self.model_params)
 
-        self.tokenizer = tokenizer
+        self.tokenizer = sync_get_tokenizer()
 
     def _label(self, prompts: List[str]) -> RefuelLLMResult:
         prompts = [[HumanMessage(content=prompt)] for prompt in prompts]
@@ -58,9 +58,9 @@ def _label(self, prompts: List[str]) -> RefuelLLMResult:
             return self._label_individually(prompts)
 
     def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
-        num_prompt_toks = self.tokenizer.count_tokens(prompt)
+        num_prompt_toks = len(self.tokenizer.encode(prompt).ids)
         if label:
-            num_label_toks = self.tokenizer.count_tokens(label)
+            num_label_toks = len(self.tokenizer.encode(label).ids)
         else:
             # get an upper bound
             num_label_toks = self.model_params["max_tokens_to_sample"]
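In anthropic 0.3.0 the module-level anthropic.tokenizer helper (and its count_tokens function) is gone, so the diff above fetches the underlying tokenizer via anthropic._tokenizers.sync_get_tokenizer and counts tokens as len(encode(text).ids); the encode(...).ids access pattern suggests a HuggingFace tokenizers.Tokenizer, so counts stay plain ints and get_cost only changes its counting expression. A minimal standalone sketch of the new counting path, assuming anthropic == 0.3.0 is installed (the count_tokens name and sample string are illustrative, not from the commit):

# Sketch of the migrated token counting (assumes anthropic == 0.3.0).
# sync_get_tokenizer() lives in a private module, so it may move between
# SDK releases -- the hard pin in pyproject.toml guards against that.
from anthropic._tokenizers import sync_get_tokenizer

tokenizer = sync_get_tokenizer()

def count_tokens(text: str) -> int:
    # encode() returns an Encoding object; its .ids field holds the token ids
    return len(tokenizer.encode(text).ids)

print(count_tokens("Hello, Claude!"))  # exact count depends on the vocabulary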
