Skip to content

Commit

Permalink
add Cohere support to the chatbot example (elastic#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg authored Mar 8, 2024
1 parent ec14ed9 commit aed8c4e
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 12 deletions.
10 changes: 10 additions & 0 deletions example-apps/chatbot-rag-app/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ export MISTRAL_API_ENDPOINT=... # optional
export MISTRAL_MODEL=... # optional
```

### Cohere

To use Cohere you need to set the following environment variables:

```
export LLM_TYPE=cohere
export COHERE_API_KEY=...
export COHERE_MODEL=... # optional
```

## Running the App

Once you have indexed data into the Elasticsearch index, there are two ways to run the app: via Docker or locally. Docker is advised for testing & production use. Locally is advised for development.
Expand Down
5 changes: 4 additions & 1 deletion example-apps/chatbot-rag-app/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def ask_question(question, session_id):

answer = ""
for chunk in get_llm().stream(qa_prompt):
yield f"data: {chunk.content}\n\n"
content = chunk.content.replace(
"\n", " "
) # the stream can get messed up with newlines
yield f"data: {content}\n\n"
answer += chunk.content

yield f"data: {DONE_TAG}\n\n"
Expand Down
11 changes: 10 additions & 1 deletion example-apps/chatbot-rag-app/api/llm_integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
ChatVertexAI,
AzureChatOpenAI,
BedrockChat,
ChatCohere,
)
from langchain_core.messages import HumanMessage
from langchain_mistralai.chat_models import ChatMistralAI
import os
import vertexai
Expand Down Expand Up @@ -76,12 +76,21 @@ def init_mistral_chat(temperature):
return ChatMistralAI(**kwargs)


def init_cohere_chat(temperature):
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
COHERE_MODEL = os.getenv("COHERE_MODEL")
return ChatCohere(
cohere_api_key=COHERE_API_KEY, model=COHERE_MODEL, temperature=temperature
)


MAP_LLM_TYPE_TO_CHAT_MODEL = {
"azure": init_azure_chat,
"bedrock": init_bedrock,
"openai": init_openai_chat,
"vertex": init_vertex_chat,
"mistral": init_mistral_chat,
"cohere": init_cohere_chat,
}


Expand Down
5 changes: 5 additions & 0 deletions example-apps/chatbot-rag-app/env.example
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,8 @@ ES_INDEX_CHAT_HISTORY=workplace-app-docs-chat-history
# MISTRAL_API_KEY=
# MISTRAL_API_ENDPOINT=
# MISTRAL_MODEL=

# Uncomment and complete if you want to use Cohere
# LLM_TYPE=cohere
# COHERE_API_KEY=
# COHERE_MODEL=
3 changes: 3 additions & 0 deletions example-apps/chatbot-rag-app/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ boto3
# Mistral dependencies
langchain-mistralai

# Cohere dependencies
cohere

# TBD if these are still needed
exceptiongroup
importlib-metadata
Expand Down
90 changes: 80 additions & 10 deletions example-apps/chatbot-rag-app/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,24 @@
#
aiohttp==3.8.5
# via
# cohere
# langchain
# langchain-community
# openai
aiosignal==1.3.1
# via aiohttp
annotated-types==0.5.0
# via pydantic
anyio==3.7.1
# via
# langchain
# httpx
# langchain-core
async-timeout==4.0.3
# via aiohttp
attrs==23.1.0
# via aiohttp
backoff==2.2.1
# via cohere
blinker==1.6.2
# via flask
boto3==1.28.61
Expand All @@ -35,6 +39,8 @@ cachetools==5.3.1
certifi==2023.7.22
# via
# elastic-transport
# httpcore
# httpx
# requests
charset-normalizer==3.2.0
# via
Expand All @@ -44,8 +50,12 @@ click==8.1.7
# via
# flask
# pip-tools
cohere==4.52
# via -r requirements.in
dataclasses-json==0.5.14
# via langchain
# via
# langchain
# langchain-community
elastic-transport==8.4.0
# via elasticsearch
elasticsearch==8.12.1
Expand All @@ -54,6 +64,10 @@ elasticsearch==8.12.1
# langchain-elasticsearch
exceptiongroup==1.2.0
# via -r requirements.in
fastavro==1.9.4
# via cohere
filelock==3.13.1
# via huggingface-hub
flask==2.3.3
# via
# -r requirements.in
Expand All @@ -64,6 +78,8 @@ frozenlist==1.4.0
# via
# aiohttp
# aiosignal
fsspec==2024.2.0
# via huggingface-hub
google-api-core[grpc]==2.14.0
# via
# google-cloud-aiplatform
Expand Down Expand Up @@ -112,13 +128,24 @@ grpcio-status==1.59.3
# via
# -r requirements.in
# google-api-core
h11==0.14.0
# via httpcore
httpcore==1.0.4
# via httpx
httpx==0.25.2
# via mistralai
huggingface-hub==0.21.4
# via tokenizers
idna==3.4
# via
# anyio
# httpx
# requests
# yarl
importlib-metadata==6.8.0
# via -r requirements.in
# via
# -r requirements.in
# cohere
itsdangerous==2.1.2
# via flask
jinja2==3.1.2
Expand All @@ -135,20 +162,31 @@ jsonpointer==2.4
# via jsonpatch
langchain==0.1.9
# via -r requirements.in
langchain-core==0.1.23
# via langchain-elasticsearch
langchain-community==0.0.27
# via langchain
langchain-core==0.1.30
# via
# langchain
# langchain-community
# langchain-elasticsearch
# langchain-mistralai
langchain-elasticsearch==0.1.0
# via -r requirements.in
langchain-mistralai==0.0.5
# via -r requirements.in
langsmith==0.1.10
# via
# langchain
# langchain-community
# langchain-core
markupsafe==2.1.3
# via
# jinja2
# werkzeug
marshmallow==3.20.1
# via dataclasses-json
mistralai==0.1.3
# via langchain-mistralai
multidict==6.0.4
# via
# aiohttp
Expand All @@ -160,18 +198,28 @@ numexpr==2.8.5
numpy==1.25.2
# via
# langchain
# langchain-community
# langchain-elasticsearch
# numexpr
# pandas
# pyarrow
# shapely
openai==0.27.9
# via -r requirements.in
orjson==3.9.15
# via
# langsmith
# mistralai
packaging==23.2
# via
# build
# google-cloud-aiplatform
# google-cloud-bigquery
# huggingface-hub
# langchain-core
# marshmallow
pandas==2.2.1
# via mistralai
pip-tools==7.3.0
# via -r requirements.in
proto-plus==1.22.3
Expand All @@ -189,6 +237,8 @@ protobuf==4.25.1
# grpc-google-iam-v1
# grpcio-status
# proto-plus
pyarrow==15.0.1
# via mistralai
pyasn1==0.5.0
# via
# pyasn1-modules
Expand All @@ -200,6 +250,7 @@ pydantic==2.5.2
# langchain
# langchain-core
# langsmith
# mistralai
pydantic-core==2.14.5
# via pydantic
pyproject-hooks==1.0.0
Expand All @@ -208,20 +259,28 @@ python-dateutil==2.8.2
# via
# botocore
# google-cloud-bigquery
# pandas
python-dotenv==1.0.0
# via -r requirements.in
pytz==2024.1
# via pandas
pyyaml==6.0.1
# via
# huggingface-hub
# langchain
# langchain-community
# langchain-core
regex==2023.10.3
# via tiktoken
requests==2.31.0
# via
# cohere
# google-api-core
# google-cloud-bigquery
# google-cloud-storage
# huggingface-hub
# langchain
# langchain-community
# langchain-core
# langsmith
# openai
Expand All @@ -235,28 +294,41 @@ shapely==2.0.2
six==1.16.0
# via python-dateutil
sniffio==1.3.0
# via anyio
# via
# anyio
# httpx
sqlalchemy==2.0.20
# via langchain
# via
# langchain
# langchain-community
tenacity==8.2.3
# via
# langchain
# langchain-community
# langchain-core
tiktoken==0.5.1
# via -r requirements.in
tokenizers==0.15.2
# via langchain-mistralai
tqdm==4.66.1
# via openai
# via
# huggingface-hub
# openai
typing-extensions==4.7.1
# via
# huggingface-hub
# pydantic
# pydantic-core
# sqlalchemy
# typing-inspect
typing-inspect==0.9.0
# via dataclasses-json
tzdata==2024.1
# via pandas
urllib3==1.26.16
# via
# botocore
# cohere
# elastic-transport
# requests
werkzeug==2.3.7
Expand All @@ -268,8 +340,6 @@ yarl==1.9.2
zipp==3.17.0
# via importlib-metadata

langchain-mistralai==0.0.5
# via -r requirements.in
# The following packages are considered to be unsafe in a requirements file:
# pip
# setuptools

0 comments on commit aed8c4e

Please sign in to comment.