Pass tokenizer from LabelingAgent to llm #927

Merged 2 commits on Oct 25, 2024
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3
- id: black
language_version: python3
14 changes: 7 additions & 7 deletions src/autolabel/labeler.py
@@ -2,15 +2,15 @@
import io
import json
import logging
from tqdm import tqdm
import os
import pickle
from typing import Dict, List, Optional, Tuple, Union
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from rich.console import Console
from tqdm import tqdm
from transformers import AutoTokenizer

from autolabel.cache import (
@@ -19,24 +19,24 @@
SQLAlchemyGenerationCache,
SQLAlchemyTransformCache,
)
from autolabel.schema import TaskType
from autolabel.confidence import ConfidenceCalculator
from autolabel.configs import AutolabelConfig
from autolabel.dataset import AutolabelDataset
from autolabel.few_shot import (
PROVIDER_TO_MODEL,
DEFAULT_EMBEDDING_PROVIDER,
PROVIDER_TO_MODEL,
BaseExampleSelector,
BaseLabelSelector,
ExampleSelectorFactory,
LabelSelector,
BaseLabelSelector,
)
from autolabel.metrics import BaseMetric
from autolabel.models import BaseModel, ModelFactory
from autolabel.schema import (
AggregationFunction,
LLMAnnotation,
MetricResult,
TaskType,
)
from autolabel.tasks import TaskFactory
from autolabel.transforms import BaseTransform, TransformFactory
@@ -94,7 +94,7 @@ def __init__(
self.confidence_cache = confidence_cache
if not cache:
logger.warning(
f"cache parameter is deprecated and will be removed soon. Please use generation_cache, transform_cache and confidence_cache instead."
"cache parameter is deprecated and will be removed soon. Please use generation_cache, transform_cache and confidence_cache instead."
)
self.generation_cache = None
self.transform_cache = None
@@ -116,7 +116,7 @@ def __init__(
)
self.task = TaskFactory.from_config(self.config)
self.llm: BaseModel = ModelFactory.from_config(
self.config, cache=self.generation_cache
self.config, cache=self.generation_cache, tokenizer=confidence_tokenizer
)

if self.config.confidence_chunk_column():
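For context, this is the heart of the PR: the agent now hands its confidence tokenizer to the model factory instead of each model loading one itself. Below is a minimal sketch of the new call path; the config path and the tokenizer checkpoint are placeholders, not values from this PR.

# Sketch only: how a tokenizer now flows from the labeling agent into the model.
from transformers import AutoTokenizer

from autolabel.configs import AutolabelConfig
from autolabel.models import ModelFactory

config = AutolabelConfig("config.json")  # hypothetical task config
confidence_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint

# The factory forwards the tokenizer to the selected provider's constructor.
llm = ModelFactory.from_config(config, cache=None, tokenizer=confidence_tokenizer)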
22 changes: 15 additions & 7 deletions src/autolabel/models/__init__.py
@@ -1,19 +1,23 @@
import logging
from .base import BaseModel

from transformers import AutoTokenizer

from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
from autolabel.schema import ModelProvider
from autolabel.cache import BaseCache

from .base import BaseModel

logger = logging.getLogger(__name__)

from autolabel.models.openai import OpenAILLM
from autolabel.models.openai_vision import OpenAIVisionLLM
from autolabel.models.anthropic import AnthropicLLM
from autolabel.models.cohere import CohereLLM
from autolabel.models.google import GoogleLLM
from autolabel.models.mistral import MistralLLM
from autolabel.models.hf_pipeline import HFPipelineLLM
from autolabel.models.hf_pipeline_vision import HFPipelineMultimodal
from autolabel.models.mistral import MistralLLM
from autolabel.models.openai import OpenAILLM
from autolabel.models.openai_vision import OpenAIVisionLLM
from autolabel.models.refuelV2 import RefuelLLMV2
from autolabel.models.vllm import VLLMModel

@@ -40,7 +44,11 @@ class ModelFactory:
"""The ModelFactory class is used to create a BaseModel object from the given AutoLabelConfig configuration."""

@staticmethod
def from_config(config: AutolabelConfig, cache: BaseCache = None) -> BaseModel:
def from_config(
config: AutolabelConfig,
cache: BaseCache = None,
tokenizer: AutoTokenizer = None,
) -> BaseModel:
"""
Returns a BaseModel object configured with the settings found in the provided AutolabelConfig.
Args:
@@ -52,7 +60,7 @@ def from_config(config: AutolabelConfig, cache: BaseCache = None) -> BaseModel:
provider = ModelProvider(config.provider())
try:
model_cls = MODEL_REGISTRY[provider]
model_obj = model_cls(config=config, cache=cache)
model_obj = model_cls(config=config, cache=cache, tokenizer=tokenizer)
# The assertion below ensures that user-created/registered custom models
# are based off of BaseModel.
assert isinstance(
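Note on the factory change above: because from_config now calls model_cls(config=config, cache=cache, tokenizer=tokenizer), any custom provider registered in MODEL_REGISTRY must accept the extra keyword. A minimal sketch of a compatible provider follows; the class name is illustrative, and the abstract methods required by BaseModel are omitted.

# Sketch only: a user-defined provider compatible with the updated factory call.
from typing import Optional

from transformers import AutoTokenizer

from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
from autolabel.models import BaseModel


class MyCustomLLM(BaseModel):  # hypothetical provider, not part of this PR
    def __init__(
        self,
        config: AutolabelConfig,
        cache: BaseCache = None,
        tokenizer: Optional[AutoTokenizer] = None,
    ) -> None:
        # BaseModel now stores the tokenizer on self.tokenizer (see base.py below).
        super().__init__(config, cache, tokenizer)
    # (Abstract labeling/cost methods required by BaseModel are omitted here.)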
14 changes: 10 additions & 4 deletions src/autolabel/models/anthropic.py
@@ -1,12 +1,13 @@
from time import time
from typing import Dict, List, Optional

from langchain.schema import HumanMessage, Generation
from langchain.schema import Generation, HumanMessage
from transformers import AutoTokenizer

from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
from autolabel.models import BaseModel
from autolabel.schema import ErrorType, RefuelLLMResult, LabelingError
from autolabel.schema import ErrorType, LabelingError, RefuelLLMResult


class AnthropicLLM(BaseModel):
@@ -36,8 +37,13 @@ class AnthropicLLM(BaseModel):
"claude-3-5-sonnet-20240620": (15 / 1_000_000),
}

def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
super().__init__(config, cache)
def __init__(
self,
config: AutolabelConfig,
cache: BaseCache = None,
tokenizer: Optional[AutoTokenizer] = None,
) -> None:
super().__init__(config, cache, tokenizer)

try:
from anthropic._tokenizers import sync_get_tokenizer
8 changes: 5 additions & 3 deletions src/autolabel/models/base.py
@@ -1,10 +1,9 @@
"""Base interface that all model providers will implement."""

from abc import ABC, abstractmethod
from time import time
from typing import Dict, List, Optional

from langchain.schema import Generation
from transformers import AutoTokenizer

from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
@@ -18,9 +17,12 @@ class BaseModel(ABC):
TTL_MS = 60 * 60 * 24 * 7 * 1000 # 1 week
DEFAULT_CONTEXT_LENGTH = None

def __init__(self, config: AutolabelConfig, cache: BaseCache) -> None:
def __init__(
self, config: AutolabelConfig, cache: BaseCache, tokenizer: AutoTokenizer
) -> None:
self.config = config
self.cache = cache
self.tokenizer = tokenizer
self.model_params = config.model_params()
self.max_context_length = config.max_context_length(
default=self.DEFAULT_CONTEXT_LENGTH
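With base.py storing the tokenizer on the instance, every provider can reach it as self.tokenizer. One plausible use is sketched below, assuming a provider wants to check prompt length against max_context_length (set in BaseModel.__init__ from the config); this helper is illustrative, not code from the repository.

# Sketch only: a standalone helper mirroring how a provider might use the tokenizer.
from typing import Optional

from transformers import AutoTokenizer


def fits_context(prompt: str, tokenizer: AutoTokenizer, max_context_length: Optional[int]) -> bool:
    # DEFAULT_CONTEXT_LENGTH is None in BaseModel, meaning "no limit configured".
    if max_context_length is None:
        return True
    # encode() returns the list of token ids for the prompt.
    return len(tokenizer.encode(prompt)) <= max_context_length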
10 changes: 8 additions & 2 deletions src/autolabel/models/cohere.py
@@ -3,6 +3,7 @@
from typing import Dict, List, Optional

from langchain.schema import Generation
from transformers import AutoTokenizer

from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
@@ -21,8 +22,13 @@ class CohereLLM(BaseModel):
# Reference: https://cohere.com/pricing
COST_PER_TOKEN = 15 / 1_000_000

def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
super().__init__(config, cache)
def __init__(
self,
config: AutolabelConfig,
cache: BaseCache = None,
tokenizer: Optional[AutoTokenizer] = None,
) -> None:
super().__init__(config, cache, tokenizer)
try:
import cohere
from langchain_community.llms import Cohere
8 changes: 5 additions & 3 deletions src/autolabel/models/google.py
@@ -1,9 +1,10 @@
import os
import logging
import os
from time import time
from typing import Dict, List, Optional

from langchain.schema import Generation
from transformers import AutoTokenizer

from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
@@ -34,13 +35,14 @@ def __init__(
self,
config: AutolabelConfig,
cache: BaseCache = None,
tokenizer: Optional[AutoTokenizer] = None,
) -> None:
Review comment (Contributor): These models don't use the tokenizer at all.

Reply (Author): No, they don't, but we need it since it's part of the interface.

try:
import tiktoken
from langchain_google_vertexai import (
VertexAI,
HarmBlockThreshold,
HarmCategory,
VertexAI,
)
from vertexai import generative_models
except ImportError:
@@ -59,7 +61,7 @@ def __init__(
},
}

super().__init__(config, cache)
super().__init__(config, cache, tokenizer)

if os.getenv("GOOGLE_APPLICATION_CREDENTIALS") is None:
raise ValueError(
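As the review thread above notes, the Google/Vertex models never touch the tokenizer; it is accepted purely so every provider exposes the same constructor and ModelFactory can instantiate them with identical keyword arguments. A minimal sketch of that argument, using a hypothetical provider that ignores the tokenizer:

# Sketch only: accepting an unused tokenizer keeps the constructor interface uniform.
from autolabel.models import BaseModel


class IgnoresTokenizerLLM(BaseModel):  # hypothetical provider
    def __init__(self, config, cache=None, tokenizer=None):
        super().__init__(config, cache, tokenizer)  # stored but never read

# The factory can then call this class exactly as it calls any other provider:
#     model_cls(config=config, cache=cache, tokenizer=tokenizer)
# and passing tokenizer=None remains harmless.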
20 changes: 13 additions & 7 deletions src/autolabel/models/hf_pipeline.py
@@ -1,31 +1,37 @@
import logging
from typing import List, Optional, Dict
from time import time
from typing import Dict, List, Optional

from langchain.schema import Generation
from transformers import AutoTokenizer

from autolabel.models import BaseModel
from autolabel.configs import AutolabelConfig
from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
from autolabel.models import BaseModel
from autolabel.schema import ErrorType, LabelingError, RefuelLLMResult


logger = logging.getLogger(__name__)


class HFPipelineLLM(BaseModel):
DEFAULT_MODEL = "google/flan-t5-xxl"
DEFAULT_PARAMS = {"temperature": 0.0, "quantize": 8}

def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
super().__init__(config, cache)
def __init__(
self,
config: AutolabelConfig,
cache: BaseCache = None,
tokenizer: Optional[AutoTokenizer] = None,
) -> None:
super().__init__(config, cache, tokenizer)

from langchain.llms import HuggingFacePipeline

try:
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
pipeline,
)
25 changes: 16 additions & 9 deletions src/autolabel/models/hf_pipeline_vision.py
@@ -1,11 +1,13 @@
import logging
import json
from typing import List, Optional, Dict
from autolabel.models import BaseModel
from autolabel.configs import AutolabelConfig
from autolabel.cache import BaseCache
from autolabel.schema import RefuelLLMResult, Generation
import logging
from typing import Dict, List, Optional

from transformers import AutoTokenizer

from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
from autolabel.models import BaseModel
from autolabel.schema import Generation, RefuelLLMResult

logger = logging.getLogger(__name__)

@@ -14,13 +16,18 @@ class HFPipelineMultimodal(BaseModel):
DEFAULT_MODEL = "HuggingFaceM4/idefics-9b-instruct"
DEFAULT_PARAMS = {"temperature": 0.0, "quantize": 8}

def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
super().__init__(config, cache)
def __init__(
self,
config: AutolabelConfig,
cache: BaseCache = None,
tokenizer: Optional[AutoTokenizer] = None,
) -> None:
super().__init__(config, cache, tokenizer)
try:
from transformers import (
AutoConfig,
AutoProcessor,
AutoModelForPreTraining,
AutoProcessor,
pipeline,
)
except ImportError:
31 changes: 17 additions & 14 deletions src/autolabel/models/mistral.py
@@ -1,27 +1,25 @@
import asyncio
import logging
import os
import requests
from time import time
from typing import Dict, List, Optional, Tuple
import httpx

from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
from autolabel.models import BaseModel
from autolabel.schema import LabelingError, ErrorType, RefuelLLMResult
import json
import logging
from transformers import AutoTokenizer

import httpx
import requests
from langchain.schema import Generation
from tenacity import (
before_sleep_log,
retry,
retry_if_not_exception_type,
stop_after_attempt,
wait_exponential,
retry_if_not_exception_type,
)
from transformers import AutoTokenizer

from langchain.schema import Generation
from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
from autolabel.models import BaseModel
from autolabel.schema import ErrorType, LabelingError, RefuelLLMResult

logger = logging.getLogger(__name__)

@@ -55,8 +53,13 @@ class MistralLLM(BaseModel):
"mistral-large-latest": (24 / 1_000_000),
}

def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
super().__init__(config, cache)
def __init__(
self,
config: AutolabelConfig,
cache: BaseCache = None,
tokenizer: Optional[AutoTokenizer] = None,
) -> None:
super().__init__(config, cache, tokenizer)

if os.getenv("MISTRAL_API_KEY") is None:
raise ValueError("MISTRAL_API_KEY environment variable not set")