-
Notifications
You must be signed in to change notification settings - Fork 15.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New LLM integration: Ctranslate2 (#10400)
## Description: I've integrated CTranslate2 with LangChain. CTranlate2 is a recently popular library for efficient inference with Transformer models that compares favorably to alternatives such as HF Text Generation Inference and vLLM in [benchmarks](https://hamel.dev/notes/llm/inference/03_inference.html).
- Loading branch information
Showing
3 changed files
with
371 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# CTranslate2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**CTranslate2** is a C++ and Python library for efficient inference with Transformer models.\n", | ||
"\n", | ||
"The project implements a custom runtime that applies many performance optimization techniques such as weights quantization, layers fusion, batch reordering, etc., to accelerate and reduce the memory usage of Transformer models on CPU and GPU.\n", | ||
"\n", | ||
"Full list of features and supported models is included in the [project's repository](https://opennmt.net/CTranslate2/guides/transformers.html). To start, please check out the official [quickstart guide](https://opennmt.net/CTranslate2/quickstart.html).\n", | ||
"\n", | ||
"To use, you should have `ctranslate2` python package installed." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#!pip install ctranslate2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To use a Hugging Face model with CTranslate2, it has to be first converted to CTranslate2 format using the `ct2-transformers-converter` command. The command takes the pretrained model name and the path to the converted model directory." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading checkpoint shards: 100%|██████████████████| 2/2 [00:01<00:00, 1.81it/s]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# converstion can take several minutes\n", | ||
"!ct2-transformers-converter --model meta-llama/Llama-2-7b-hf --quantization bfloat16 --output_dir ./llama-2-7b-ct2 --force" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.llms import CTranslate2\n", | ||
"\n", | ||
"llm = CTranslate2(\n", | ||
" # output_dir from above:\n", | ||
" model_path=\"./llama-2-7b-ct2\",\n", | ||
" tokenizer_name=\"meta-llama/Llama-2-7b-hf\",\n", | ||
" device=\"cuda\",\n", | ||
" # device_index can be either single int or list or ints,\n", | ||
" # indicating the ids of GPUs to use for inference:\n", | ||
" device_index=[0,1], \n", | ||
" compute_type=\"bfloat16\"\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Single call" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 31, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"He presented me with plausible evidence for the existence of unicorns: 1) they are mentioned in ancient texts; and, more importantly to him (and not so much as a matter that would convince most people), he had seen one.\n", | ||
"I was skeptical but I didn't want my friend upset by his belief being dismissed outright without any consideration or argument on its behalf whatsoever - which is why we were having this conversation at all! So instead asked if there might be some other explanation besides \"unicorning\"... maybe it could have been an ostrich? Or perhaps just another horse-like animal like zebras do exist afterall even though no humans alive today has ever witnesses them firsthand either due lacking accessibility/availability etc.. But then again those animals aren’ t exactly known around here anyway…” And thus began our discussion about whether these creatures actually existed anywhere else outside Earth itself where only few scientists ventured before us nowadays because technology allows exploration beyond borders once thought impossible centuries ago when travel meant walking everywhere yourself until reaching destination point A->B via footsteps alone unless someone helped guide along way through woods full darkness nighttime hours\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(\n", | ||
" llm(\n", | ||
" \"He presented me with plausible evidence for the existence of unicorns: \",\n", | ||
" max_length=256,\n", | ||
" sampling_topk=50,\n", | ||
" sampling_temperature=0.2,\n", | ||
" repetition_penalty=2,\n", | ||
" cache_static_prompt=False,\n", | ||
" )\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Multiple calls:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 34, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"generations=[[Generation(text='The list of top romantic songs:\\n1. “I Will Always Love You” by Whitney Houston\\n2. “Can’t Help Falling in Love” by Elvis Presley\\n3. “Unchained Melody” by The Righteous Brothers\\n4. “I Will Always Love You” by Dolly Parton\\n5. “I Will Always Love You” by Whitney Houston\\n6. “I Will Always Love You” by Dolly Parton\\n7. “I Will Always Love You” by The Beatles\\n8. “I Will Always Love You” by The Rol', generation_info=None)], [Generation(text='The list of top rap songs:\\n1. “God’s Plan” by Drake\\n2. “Rockstar” by Post Malone\\n3. “Bad and Boujee” by Migos\\n4. “Humble” by Kendrick Lamar\\n5. “Bodak Yellow” by Cardi B\\n6. “I’m the One” by DJ Khaled\\n7. “Motorsport” by Migos\\n8. “No Limit” by G-Eazy\\n9. “Bounce Back” by Big Sean\\n10. “', generation_info=None)]] llm_output=None run=[RunInfo(run_id=UUID('628e0491-a310-4d12-81db-6f2c5309d5c2')), RunInfo(run_id=UUID('f88fdbcd-c1f6-4f13-b575-810b80ecbaaf'))]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(\n", | ||
" llm.generate(\n", | ||
" [\"The list of top romantic songs:\\n1.\", \"The list of top rap songs:\\n1.\"],\n", | ||
" max_length=128\n", | ||
" )\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Integrate the model in an LLMChain" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 46, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Who was the US president in the year the first Pokemon game was released?\n", | ||
"\n", | ||
"Let's think step by step. 1996 was the year the first Pokemon game was released.\n", | ||
"\n", | ||
"\\begin{blockquote}\n", | ||
"\n", | ||
"\\begin{itemize}\n", | ||
" \\item 1996 was the year Bill Clinton was president.\n", | ||
" \\item 1996 was the year the first Pokemon game was released.\n", | ||
" \\item 1996 was the year the first Pokemon game was released.\n", | ||
"\n", | ||
"\\end{itemize}\n", | ||
"\\end{blockquote}\n", | ||
"\n", | ||
"I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: @JoeZ. I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: @JoeZ. I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: @JoeZ. I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: @JoeZ. I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: @JoeZ. I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: @JoeZ. I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: @JoeZ. I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: @JoeZ. I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: @JoeZ. I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n", | ||
"Comment: @JoeZ. I'm not sure if this is a valid question, but I'm sure it's a fun one.\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from langchain import PromptTemplate, LLMChain\n", | ||
"\n", | ||
"template = \"\"\"{question}\n", | ||
"\n", | ||
"Let's think step by step. \"\"\"\n", | ||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n", | ||
"\n", | ||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n", | ||
"\n", | ||
"question = \"Who was the US president in the year the first Pokemon game was released?\"\n", | ||
"\n", | ||
"print(llm_chain.run(question))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3.10.12 ('langchain_venv': venv)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
}, | ||
"orig_nbformat": 4, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "d1d3a3c58a58885896c5459933a599607cdbb9917d7e1ad7516c8786c51f2dd2" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from langchain.callbacks.manager import CallbackManagerForLLMRun | ||
from langchain.llms.base import BaseLLM | ||
from langchain.pydantic_v1 import Field, root_validator | ||
from langchain.schema.output import Generation, LLMResult | ||
|
||
|
||
class CTranslate2(BaseLLM): | ||
"""CTranslate2 language model.""" | ||
|
||
model_path: str = "" | ||
"""Path to the CTranslate2 model directory.""" | ||
|
||
tokenizer_name: str = "" | ||
"""Name of the original Hugging Face model needed to load the proper tokenizer.""" | ||
|
||
device: str = "cpu" | ||
"""Device to use (possible values are: cpu, cuda, auto).""" | ||
|
||
device_index: Union[int, List[int]] = 0 | ||
"""Device IDs where to place this generator on.""" | ||
|
||
compute_type: Union[str, Dict[str, str]] = "default" | ||
""" | ||
Model computation type or a dictionary mapping a device name to the computation type | ||
(possible values are: default, auto, int8, int8_float32, int8_float16, | ||
int8_bfloat16, int16, float16, bfloat16, float32). | ||
""" | ||
|
||
max_length: int = 512 | ||
"""Maximum generation length.""" | ||
|
||
sampling_topk: int = 1 | ||
"""Randomly sample predictions from the top K candidates.""" | ||
|
||
sampling_topp: float = 1 | ||
"""Keep the most probable tokens whose cumulative probability exceeds this value.""" | ||
|
||
sampling_temperature: float = 1 | ||
"""Sampling temperature to generate more random samples.""" | ||
|
||
client: Any #: :meta private: | ||
|
||
tokenizer: Any #: :meta private: | ||
|
||
ctranslate2_kwargs: Dict[str, Any] = Field(default_factory=dict) | ||
""" | ||
Holds any model parameters valid for `ctranslate2.Generator` call not | ||
explicitly specified. | ||
""" | ||
|
||
@root_validator() | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validate that python package exists in environment.""" | ||
|
||
try: | ||
import ctranslate2 | ||
except ImportError: | ||
raise ImportError( | ||
"Could not import ctranslate2 python package. " | ||
"Please install it with `pip install ctranslate2`." | ||
) | ||
|
||
try: | ||
import transformers | ||
except ImportError: | ||
raise ImportError( | ||
"Could not import transformers python package. " | ||
"Please install it with `pip install transformers`." | ||
) | ||
|
||
values["client"] = ctranslate2.Generator( | ||
model_path=values["model_path"], | ||
device=values["device"], | ||
device_index=values["device_index"], | ||
compute_type=values["compute_type"], | ||
**values["ctranslate2_kwargs"], | ||
) | ||
|
||
values["tokenizer"] = transformers.AutoTokenizer.from_pretrained( | ||
values["tokenizer_name"] | ||
) | ||
|
||
return values | ||
|
||
@property | ||
def _default_params(self) -> Dict[str, Any]: | ||
"""Get the default parameters.""" | ||
return { | ||
"max_length": self.max_length, | ||
"sampling_topk": self.sampling_topk, | ||
"sampling_topp": self.sampling_topp, | ||
"sampling_temperature": self.sampling_temperature, | ||
} | ||
|
||
def _generate( | ||
self, | ||
prompts: List[str], | ||
stop: Optional[List[str]] = None, | ||
run_manager: Optional[CallbackManagerForLLMRun] = None, | ||
**kwargs: Any, | ||
) -> LLMResult: | ||
# build sampling parameters | ||
params = {**self._default_params, **kwargs} | ||
|
||
# call the model | ||
encoded_prompts = self.tokenizer(prompts)["input_ids"] | ||
tokenized_prompts = [ | ||
self.tokenizer.convert_ids_to_tokens(encoded_prompt) | ||
for encoded_prompt in encoded_prompts | ||
] | ||
|
||
results = self.client.generate_batch(tokenized_prompts, **params) | ||
|
||
sequences = [result.sequences_ids[0] for result in results] | ||
decoded_sequences = [self.tokenizer.decode(seq) for seq in sequences] | ||
|
||
generations = [] | ||
for text in decoded_sequences: | ||
generations.append([Generation(text=text)]) | ||
|
||
return LLMResult(generations=generations) | ||
|
||
@property | ||
def _llm_type(self) -> str: | ||
"""Return type of llm.""" | ||
return "ctranslate2" |