Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retriever that can re-phase user inputs #8026

Merged
merged 3 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 222 additions & 0 deletions docs/extras/integrations/retrievers/re_phrase.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e8624be2",
"metadata": {},
"source": [
"# RePhraseQueryRetriever\n",
"\n",
"Simple retriever that applies an LLM between the user input and the query pass the to retriever.\n",
"\n",
"It can be used to pre-process the user input in any way.\n",
"\n",
"The default prompt used in the `from_llm` classmethod:\n",
"\n",
"```\n",
"DEFAULT_TEMPLATE = \"\"\"You are an assistant tasked with taking a natural language \\\n",
"query from a user and converting it into a query for a vectorstore. \\\n",
"In this process, you strip out information that is not relevant for \\\n",
"the retrieval task. Here is the user query: {question}\"\"\"\n",
"```\n",
"\n",
"Create a vectorstore."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1bfa6834",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import WebBaseLoader\n",
"\n",
"loader = WebBaseLoader(\"https://lilianweng.github.io/posts/2023-06-23-agent/\")\n",
"data = loader.load()\n",
"\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
"all_splits = text_splitter.split_documents(data)\n",
"\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"\n",
"vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d0b51556",
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"\n",
"logging.basicConfig()\n",
"logging.getLogger(\"langchain.retrievers.re_phraser\").setLevel(logging.INFO)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "20e1e787",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.retrievers import RePhraseQueryRetriever"
]
},
{
"cell_type": "markdown",
"id": "88c0a972",
"metadata": {},
"source": [
"## Using the default prompt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "503994bd",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0)\n",
"retriever_from_llm = RePhraseQueryRetriever.from_llm(\n",
" retriever=vectorstore.as_retriever(), llm=llm\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8d17ecc9",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:langchain.retrievers.re_phraser:Re-phrased question: The user query can be converted into a query for a vectorstore as follows:\n",
"\n",
"\"approaches to Task Decomposition\"\n"
]
}
],
"source": [
"docs = retriever_from_llm.get_relevant_documents(\n",
" \"Hi I'm Lance. What are the approaches to Task Decomposition?\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "76d54f1a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:langchain.retrievers.re_phraser:Re-phrased question: Query for vectorstore: \"Types of Memory\"\n"
]
}
],
"source": [
"docs = retriever_from_llm.get_relevant_documents(\n",
" \"I live in San Francisco. What are the Types of Memory?\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0513a6e2",
"metadata": {},
"source": [
"## Supply a prompt"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "410d6a64",
"metadata": {},
"outputs": [],
"source": [
"from langchain import LLMChain\n",
"from langchain.prompts import PromptTemplate\n",
"\n",
"QUERY_PROMPT = PromptTemplate(\n",
" input_variables=[\"question\"],\n",
" template=\"\"\"You are an assistant tasked with taking a natural languge query from a user\n",
" and converting it into a query for a vectorstore. In the process, strip out all \n",
" information that is not relevant for the retrieval task and return a new, simplified\n",
" question for vectorstore retrieval. The new user query should be in pirate speech.\n",
" Here is the user query: {question} \"\"\",\n",
")\n",
"llm = ChatOpenAI(temperature=0)\n",
"llm_chain = LLMChain(llm=llm, prompt=QUERY_PROMPT)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2dbffdd3",
"metadata": {},
"outputs": [],
"source": [
"retriever_from_llm_chain = RePhraseQueryRetriever(\n",
" retriever=vectorstore.as_retriever(), llm_chain=llm_chain\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "103b4be3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:langchain.retrievers.re_phraser:Re-phrased question: Ahoy matey! What be Maximum Inner Product Search, ye scurvy dog?\n"
]
}
],
"source": [
"docs = retriever_from_llm_chain.get_relevant_documents(\n",
" \"Hi I'm Lance. What is Maximum Inner Product Search?\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions libs/langchain/langchain/retrievers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.pubmed import PubMedRetriever
from langchain.retrievers.re_phraser import RePhraseQueryRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.svm import SVMRetriever
Expand Down Expand Up @@ -86,6 +87,7 @@
"ZepRetriever",
"ZillizRetriever",
"DocArrayRetriever",
"RePhraseQueryRetriever",
"WebResearchRetriever",
"EnsembleRetriever",
]
87 changes: 87 additions & 0 deletions libs/langchain/langchain/retrievers/re_phraser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import logging
from typing import List

from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseRetriever, Document

logger = logging.getLogger(__name__)

# Default template
DEFAULT_TEMPLATE = """You are an assistant tasked with taking a natural language \
query from a user and converting it into a query for a vectorstore. \
In this process, you strip out information that is not relevant for \
the retrieval task. Here is the user query: {question}"""

# Default prompt
DEFAULT_QUERY_PROMPT = PromptTemplate.from_template(DEFAULT_TEMPLATE)


class RePhraseQueryRetriever(BaseRetriever):

"""Given a user query, use an LLM to re-phrase it.
Then, retrieve docs for re-phrased query."""

retriever: BaseRetriever
llm_chain: LLMChain

@classmethod
def from_llm(
cls,
retriever: BaseRetriever,
llm: BaseLLM,
prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
) -> "RePhraseQueryRetriever":
"""Initialize from llm using default template.

The prompt used here expects a single input: `question`

Args:
retriever: retriever to query documents from
llm: llm for query generation using DEFAULT_QUERY_PROMPT
prompt: prompt template for query generation

Returns:
RePhraseQueryRetriever
"""

llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(
retriever=retriever,
llm_chain=llm_chain,
)

def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""Get relevated documents given a user question.

Args:
query: user question

Returns:
Relevant documents for re-phrased question
"""
response = self.llm_chain(query, callbacks=run_manager.get_child())
re_phrased_question = response["text"]
logger.info(f"Re-phrased question: {re_phrased_question}")
docs = self.retriever.get_relevant_documents(
re_phrased_question, callbacks=run_manager.get_child()
)
return docs

async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError