Skip to content

Commit

Permalink
fix: Merge pull request #3 from AI21Labs/pr_fixes_1
Browse files Browse the repository at this point in the history
fix: Fix LC CR
  • Loading branch information
asafgardin authored Feb 11, 2024
2 parents 7bd791e + 88e79a9 commit f3c5228
Show file tree
Hide file tree
Showing 15 changed files with 166 additions and 105 deletions.
1 change: 1 addition & 0 deletions .github/workflows/_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ jobs:
- name: Run integration tests
if: ${{ startsWith(inputs.working-directory, 'libs/partners/') }}
env:
AI21_API_KEY: ${{ secrets.AI21_API_KEY }}
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
Expand Down
7 changes: 2 additions & 5 deletions libs/partners/ai21/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
poetry run pytest $(TEST_FILE)

tests:
integration_test integration_tests: TEST_FILE = tests/integration_tests/
test tests integration_test integration_tests:
poetry run pytest $(TEST_FILE)


Expand Down
74 changes: 74 additions & 0 deletions libs/partners/ai21/README.md
Original file line number Diff line number Diff line change
@@ -1 +1,75 @@
# langchain-ai21

This package contains the LangChain integrations for [AI21](https://docs.ai21.com/) through their [AI21](https://pypi.org/project/ai21/) SDK.

## Installation and Setup

- Install the AI21 partner package
```bash
pip install langchain-ai21
```
- Get an AI21 api key and set it as an environment variable (`AI21_API_KEY`)


## Chat Models

This package contains the `ChatAI21` class, which is the recommended way to interface with AI21 Chat models.

To use, install the requirements, and configure your environment.

```bash
export AI21_API_KEY=your-api-key
```

Then initialize

```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21

chat = ChatAI21(model="j2-ultra")
messages = [HumanMessage(content="Hello from AI21")]
chat.invoke(messages)
```

## LLMs
You can use AI21's generative AI models as Langchain LLMs:

```python
from langchain.prompts import PromptTemplate
from langchain_ai21 import AI21LLM

llm = AI21LLM(model="j2-ultra")

template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)

chain = prompt | llm

question = "Which scientist discovered relativity?"
print(chain.invoke({"question": question}))
```

## Embeddings

You can use AI21's embeddings models as:

### Query

```python
from langchain_ai21 import AI21Embeddings

embeddings = AI21Embeddings()
embeddings.embed_query("Hello! This is some query")
```

### Document

```python
from langchain_ai21 import AI21Embeddings

embeddings = AI21Embeddings()
embeddings.embed_documents(["Hello! This is document 1", "And this is document 2!"])
```
4 changes: 2 additions & 2 deletions libs/partners/ai21/langchain_ai21/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from langchain_ai21.chat_models import ChatAI21
from langchain_ai21.embeddings import AI21Embeddings
from langchain_ai21.llms import AI21
from langchain_ai21.llms import AI21LLM

__all__ = [
"AI21",
"AI21LLM",
"ChatAI21",
"AI21Embeddings",
]
4 changes: 3 additions & 1 deletion libs/partners/ai21/langchain_ai21/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ class ChatAI21(BaseChatModel, AI21Base):
model = ChatAI21()
"""

model: str = "j2-ultra"
model: str
"""Model type you wish to interact with.
You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
num_results: int = 1
"""The number of responses to generate for a given prompt."""

Expand Down
7 changes: 4 additions & 3 deletions libs/partners/ai21/langchain_ai21/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from langchain_ai21.ai21_base import AI21Base


class AI21(BaseLLM, AI21Base):
class AI21LLM(BaseLLM, AI21Base):
"""AI21LLM large language models.
Example:
Expand All @@ -28,8 +28,9 @@ class AI21(BaseLLM, AI21Base):
model = AI21LLM()
"""

model: str = "j2-ultra"
"""Model type you wish to interact with."""
model: str
"""Model type you wish to interact with.
You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""

num_results: int = 1
"""The number of responses to generate for a given prompt."""
Expand Down
12 changes: 6 additions & 6 deletions libs/partners/ai21/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/partners/ai21/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.0.12"
langchain-core = "^0.1.22"
ai21 = "^2.0.0"

[tool.poetry.group.test]
Expand Down
10 changes: 3 additions & 7 deletions libs/partners/ai21/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
"""Test ChatAI21 chat model."""
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.outputs import ChatGeneration

from langchain_ai21.chat_models import ChatAI21


@pytest.mark.requires("ai21")
def test_invoke() -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21()
llm = ChatAI21(model="j2-ultra")

result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)


@pytest.mark.requires("ai21")
def test_generation() -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21()
llm = ChatAI21(model="j2-ultra")
message = HumanMessage(content="Hello")

result = llm.generate([[message], [message]], config=dict(tags=["foo"]))
Expand All @@ -31,10 +28,9 @@ def test_generation() -> None:
assert generation.text == generation.message.content


@pytest.mark.requires("ai21")
async def test_ageneration() -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21()
llm = ChatAI21(model="j2-ultra")
message = HumanMessage(content="Hello")

result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
Expand Down
79 changes: 29 additions & 50 deletions libs/partners/ai21/tests/integration_tests/test_llms.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,47 @@
"""Test AI21LLM llm."""

import pytest
from ai21.models import Penalty

from langchain_ai21.llms import AI21
from langchain_ai21.llms import AI21LLM


def _generate_llm_client_parameters() -> AI21:
return AI21(
max_tokens=2,
temperature=0,
top_p=1,
top_k_return=0,
num_results=1,
def _generate_llm() -> AI21LLM:
"""
Testing AI21LLm using non default parameters with the following parameters
"""
return AI21LLM(
model="j2-ultra",
max_tokens=2, # Use less tokens for a faster response
temperature=0, # for a consistent response
epoch=1,
count_penalty=Penalty(
scale=0,
apply_to_emojis=False,
apply_to_numbers=False,
apply_to_stopwords=False,
apply_to_punctuation=False,
apply_to_whitespaces=False,
),
frequency_penalty=Penalty(
scale=0,
apply_to_emojis=False,
apply_to_numbers=False,
apply_to_stopwords=False,
apply_to_punctuation=False,
apply_to_whitespaces=False,
),
presence_penalty=Penalty(
scale=0,
apply_to_emojis=False,
apply_to_numbers=False,
apply_to_stopwords=False,
apply_to_punctuation=False,
apply_to_whitespaces=False,
),
)


@pytest.mark.requires("ai21")
def test_stream() -> None:
"""Test streaming tokens from AI21."""
llm = AI21()
llm = AI21LLM(
model="j2-ultra",
)

for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token, str)


@pytest.mark.requires("ai21")
async def test_abatch() -> None:
"""Test streaming tokens from AI21LLM."""
llm = AI21()
llm = AI21LLM(
model="j2-ultra",
)

result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)


@pytest.mark.requires("ai21")
async def test_abatch_tags() -> None:
"""Test batch tokens from AI21LLM."""
llm = AI21()
llm = AI21LLM(
model="j2-ultra",
)

result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
Expand All @@ -72,37 +50,39 @@ async def test_abatch_tags() -> None:
assert isinstance(token, str)


@pytest.mark.requires("ai21")
def test_batch() -> None:
"""Test batch tokens from AI21LLM."""
llm = AI21()
llm = AI21LLM(
model="j2-ultra",
)

result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)


@pytest.mark.requires("ai21")
async def test_ainvoke() -> None:
"""Test invoke tokens from AI21LLM."""
llm = AI21()
llm = AI21LLM(
model="j2-ultra",
)

result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result, str)


@pytest.mark.requires("ai21")
def test_invoke() -> None:
"""Test invoke tokens from AI21LLM."""
llm = AI21()
llm = AI21LLM(
model="j2-ultra",
)

result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result, str)


@pytest.mark.requires("ai21")
def test__generate() -> None:
llm = _generate_llm_client_parameters()
llm = _generate_llm()
llm_result = llm.generate(
prompts=["Hey there, my name is Pickle Rick. What is your name?"],
stop=["##"],
Expand All @@ -112,9 +92,8 @@ def test__generate() -> None:
assert llm_result.llm_output["token_count"] != 0 # type: ignore


@pytest.mark.requires("ai21")
async def test__agenerate() -> None:
llm = _generate_llm_client_parameters()
llm = _generate_llm()
llm_result = await llm.agenerate(
prompts=["Hey there, my name is Pickle Rick. What is your name?"],
stop=["##"],
Expand Down
Loading

0 comments on commit f3c5228

Please sign in to comment.