Merge pull request #354 from MannLabs/refactor_llm_II
Refactor llm ii
mschwoer authored Nov 8, 2024
2 parents 9c9e89a + 1c3fc86 commit e8e43ea
Showing 6 changed files with 96 additions and 318 deletions.
75 changes: 28 additions & 47 deletions alphastats/gui/pages/05_LLM.py
@@ -9,11 +9,9 @@
)
from alphastats.gui.utils.gpt_helper import (
display_proteins,
get_assistant_functions,
get_general_assistant_functions,
get_subgroups_for_each_group,
)
from alphastats.gui.utils.ollama_utils import LLMIntegration
from alphastats.gui.utils.ollama_utils import LLMIntegration, Models
from alphastats.gui.utils.openai_utils import set_api_key
from alphastats.gui.utils.ui_helper import StateKeys, init_session_state, sidebar_info

@@ -36,14 +34,15 @@ def llm_config():
with c1:
st.session_state[StateKeys.API_TYPE] = st.selectbox(
"Select LLM",
["gpt4o", "llama3.1 70b"],
[Models.GPT, Models.OLLAMA],
)

if st.session_state[StateKeys.API_TYPE] == "gpt4o":
if st.session_state[StateKeys.API_TYPE] == Models.GPT:
api_key = st.text_input("Enter OpenAI API Key", type="password")
set_api_key(api_key)
else:
st.info("Expecting Ollama API at http://localhost:11434.")
base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
st.info(f"Expecting Ollama API at {base_url}.")


llm_config()
@@ -72,12 +71,14 @@ def llm_config():
# genes_of_interest_colored_df[gene_names_colname].tolist(),
# )
# ) # TODO unused?
st.session_state[StateKeys.GENE_TO_PROT_ID] = dict(

gene_to_prot_id_map = dict(
zip(
genes_of_interest_colored_df[gene_names_colname].tolist(),
genes_of_interest_colored_df[prot_ids_colname].tolist(),
)
)
st.session_state[StateKeys.GENE_TO_PROT_ID] = gene_to_prot_id_map

with c2:
display_figure(volcano_plot.plot)
@@ -146,42 +147,24 @@ def llm_config():
st.stop()

try:
if st.session_state[StateKeys.API_TYPE] == "gpt4o":
llm = LLMIntegration(
api_type="gpt",
api_key=st.session_state[StateKeys.OPENAI_API_KEY],
dataset=st.session_state[StateKeys.DATASET],
metadata=st.session_state[StateKeys.DATASET].metadata,
)
else:
llm = LLMIntegration(
api_type="ollama",
base_url=os.getenv("OLLAMA_BASE_URL", None),
dataset=st.session_state[StateKeys.DATASET],
metadata=st.session_state[StateKeys.DATASET].metadata,
)
llm = LLMIntegration(
api_type=st.session_state[StateKeys.API_TYPE],
api_key=st.session_state[StateKeys.OPENAI_API_KEY],
base_url=os.getenv("OLLAMA_BASE_URL", None),
dataset=st.session_state[StateKeys.DATASET],
gene_to_prot_id_map=gene_to_prot_id_map,
)

# Set instructions and update tools
llm.tools = [
*get_general_assistant_functions(),
*get_assistant_functions(
gene_to_prot_id_dict=st.session_state[StateKeys.GENE_TO_PROT_ID],
metadata=st.session_state[StateKeys.DATASET].metadata,
subgroups_for_each_group=get_subgroups_for_each_group(
st.session_state[StateKeys.DATASET].metadata
),
),
]

st.session_state[StateKeys.ARTIFACTS] = {}

llm.messages = [{"role": "system", "content": system_message}]

st.session_state[StateKeys.LLM_INTEGRATION] = llm
st.success(
f"{st.session_state[StateKeys.API_TYPE].upper()} integration initialized successfully!"
)

response = llm.chat_completion(user_prompt)
llm.chat_completion(user_prompt)

except AuthenticationError:
st.warning(
@@ -195,19 +178,17 @@ def llm_chat():
"""The chat interface for the LLM analysis."""
llm = st.session_state[StateKeys.LLM_INTEGRATION]

for num, role_content_dict in enumerate(st.session_state[StateKeys.MESSAGES]):
if role_content_dict["role"] == "tool" or role_content_dict["role"] == "system":
continue
if "tool_calls" in role_content_dict:
continue
with st.chat_message(role_content_dict["role"]):
st.markdown(role_content_dict["content"])
if num in st.session_state[StateKeys.ARTIFACTS]:
for artefact in st.session_state[StateKeys.ARTIFACTS][num]:
if isinstance(artefact, pd.DataFrame):
st.dataframe(artefact)
elif "plotly" in str(type(artefact)):
st.plotly_chart(artefact)
for message in llm.get_print_view(show_all=False):
with st.chat_message(message["role"]):
st.markdown(message["content"])
for artifact in message["artifacts"]:
if isinstance(artifact, pd.DataFrame):
st.dataframe(artifact)
elif "plotly" in str(type(artifact)):
st.plotly_chart(artifact)
elif not isinstance(artifact, str):
st.warning("Don't know how to display artifact:")
st.write(artifact)

if prompt := st.chat_input("Say something"):
llm.chat_completion(prompt)
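For context, a minimal sketch (not part of the diff) of how the refactored class is now constructed: the per-backend branching above collapses into a single call, with values mirroring the new code in 05_LLM.py.

import os

from alphastats.gui.utils.ollama_utils import LLMIntegration, Models

# Ollama backend; inside the class, base_url falls back to http://localhost:11434.
llm = LLMIntegration(
    api_type=Models.OLLAMA,
    base_url=os.getenv("OLLAMA_BASE_URL", None),
    dataset=None,  # without a dataset and mapping, only the general tools are registered
    gene_to_prot_id_map=None,
)
# For OpenAI, pass api_type=Models.GPT and api_key=... instead of base_url.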
23 changes: 1 addition & 22 deletions alphastats/gui/utils/gpt_helper.py
@@ -1,6 +1,5 @@
import copy
import json
from typing import Dict, List, Union
from typing import Dict, List

import pandas as pd
import streamlit as st
@@ -310,26 +309,6 @@ def perform_dimensionality_reduction(group, method, circle, **kwargs):
return dr.plot


def turn_args_to_float(json_string: Union[str, bytes, bytearray]) -> Dict:
"""
Turn all values in a JSON string to floats if possible.
Args:
json_string (Union[str, bytes, bytearray]): The JSON string to convert.
Returns:
dict: The converted JSON string as a dictionary.
"""
data = json.loads(json_string)
for key, value in data.items():
if isinstance(value, str):
try:
data[key] = float(value)
except ValueError:
continue
return data


def get_gene_to_prot_id_mapping(gene_id: str) -> str:
"""Get protein id from gene id. If gene id is not present, return gene id, as we might already have a gene id.
'VCL;HEL114' -> 'P18206;A0A024QZN4;V9HWK2;B3KXA2;Q5JQ13;B4DKC9;B4DTM7;A0A096LPE1'
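As an aside, a hypothetical call to the retained get_gene_to_prot_id_mapping helper, following its docstring; this assumes the session's gene-to-protein map has already been populated earlier in the workflow.

from alphastats.gui.utils.gpt_helper import get_gene_to_prot_id_mapping

# Known gene ids resolve to their protein id group (docstring example):
# 'VCL;HEL114' -> 'P18206;A0A024QZN4;V9HWK2;B3KXA2;Q5JQ13;B4DKC9;B4DTM7;A0A096LPE1'
prot_ids = get_gene_to_prot_id_mapping("VCL;HEL114")

# An id with no mapping is returned unchanged, since it may already be a protein id.
assert get_gene_to_prot_id_mapping("P18206") == "P18206"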
121 changes: 60 additions & 61 deletions alphastats/gui/utils/ollama_utils.py
@@ -4,23 +4,28 @@

import pandas as pd
import plotly.io as pio
import streamlit as st
from IPython.display import HTML, Markdown, display
from openai import OpenAI

from alphastats.gui.utils.enrichment_analysis import get_enrichment_data
from alphastats.gui.utils.gpt_helper import (
get_assistant_functions,
get_general_assistant_functions,
get_subgroups_for_each_group,
perform_dimensionality_reduction,
)
from alphastats.gui.utils.ui_helper import StateKeys

# from alphastats.gui.utils.artefacts import ArtifactManager
from alphastats.gui.utils.uniprot_utils import get_gene_function

logger = logging.getLogger(__name__)


class Models:
GPT = "gpt-4o"
OLLAMA = "llama3.1:70b"


class LLMIntegration:
"""
A class to integrate different Language Model APIs and handle chat interactions.
@@ -43,8 +48,6 @@ class LLMIntegration:
Attributes
----------
api_type : str
The type of API being used
client : OpenAI
The OpenAI client instance
model : str
@@ -63,25 +66,27 @@ class LLMIntegration:

def __init__(
self,
api_type: str = "gpt",
api_type: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
dataset=None,
metadata=None,
gene_to_prot_id_map=None,
):
self.api_type = api_type
if api_type == "ollama":
self.model = api_type

if api_type == Models.OLLAMA:
url = f"{base_url or 'http://localhost:11434'}/v1"
self.client = OpenAI(base_url=url, api_key="ollama")
self.model = "llama3.1:70b"
else:
elif api_type == Models.GPT:
self.client = OpenAI(api_key=api_key)
# self.model = "gpt-4-0125-preview"
self.model = "gpt-4o"
else:
raise ValueError(f"Invalid API type: {api_type}")

self.messages = []
self.dataset = dataset
self.metadata = metadata
self.metadata = None if dataset is None else dataset.metadata
self._gene_to_prot_id_map = gene_to_prot_id_map

self.tools = self._get_tools()
self.artifacts = {}
# self.artifact_manager = ArtifactManager()
@@ -96,8 +101,22 @@ def _get_tools(self) -> List[Dict[str, Any]]:
List[Dict[str, Any]]
A list of dictionaries describing the available tools
"""
general_tools = get_general_assistant_functions()
return general_tools

tools = [
*get_general_assistant_functions(),
]
if self.metadata is not None and self._gene_to_prot_id_map is not None:
tools += (
*get_assistant_functions(
gene_to_prot_id_dict=self._gene_to_prot_id_map,
metadata=self.metadata,
subgroups_for_each_group=get_subgroups_for_each_group(
self.metadata
),
),
)

return tools

def truncate_conversation_history(self, max_tokens: int = 100000):
"""
@@ -117,17 +136,6 @@ def truncate_conversation_history(self, max_tokens: int = 100000):
removed_message = self.messages.pop(0)
total_tokens -= len(removed_message["content"].split())

def update_session_state(self):
"""
Update the Streamlit session state with current conversation data.
Returns
-------
None
"""
st.session_state[StateKeys.MESSAGES] = self.messages
st.session_state[StateKeys.ARTIFACTS] = self.artifacts

def parse_model_response(self, response: Any) -> Dict[str, Any]:
"""
Parse the response from the language model.
@@ -243,7 +251,7 @@ def handle_function_calls(
post_artefact_message_idx = len(self.messages)
self.artifacts[post_artefact_message_idx] = new_artifacts.values()
logger.info(
f"Calling 'chat.completions.create' {self.model=} {self.messages=} {self.tools=} .."
f"Calling 'chat.completions.create' {self.messages=} {self.tools=} .."
)
response = self.client.chat.completions.create(
model=self.model,
@@ -257,6 +265,25 @@ def handle_function_calls(

return parsed_response

def get_print_view(self, show_all=False) -> List[Dict[str, Any]]:
"""Get a structured view of the conversation history for display purposes."""

print_view = []
for num, role_content_dict in enumerate(self.messages):
if not show_all and (role_content_dict["role"] in ["tool", "system"]):
continue
if not show_all and "tool_calls" in role_content_dict:
continue

print_view.append(
{
"role": role_content_dict["role"],
"content": role_content_dict["content"],
"artifacts": self.artifacts.get(num, []),
}
)
return print_view

def chat_completion(
self, prompt: str, role: str = "user"
) -> Tuple[str, Dict[str, Any]]:
@@ -285,7 +312,7 @@ def chat_completion(

try:
logger.info(
f"Calling 'chat.completions.create' {self.model=} {self.messages=} {self.tools=} .."
f"Calling 'chat.completions.create' {self.messages=} {self.tools=} .."
)
response = self.client.chat.completions.create(
model=self.model,
@@ -306,45 +333,17 @@ def chat_completion(
self.messages.append(
{"role": "assistant", "content": parsed_response["content"]}
)
self.update_session_state()
return parsed_response["content"], new_artifacts
return parsed_response[
"content"
], new_artifacts # TODO response is not used

except ArithmeticError as e:
error_message = f"Error in chat completion: {str(e)}"
self.messages.append({"role": "system", "content": error_message})
self.update_session_state()
return error_message, {}

def switch_backend(
self,
new_api_type: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
):
"""
Switch between different API backends.
Parameters
----------
new_api_type : str
The new API type to switch to ('gpt' or 'ollama')
base_url : str, optional
The base URL for the new API, by default None
api_key : str, optional
The API key for the new API, by default None
Returns
-------
None
"""
self.__init__(
api_type=new_api_type,
base_url=base_url,
api_key=api_key,
dataset=self.dataset,
metadata=self.metadata,
)

# TODO this seems to be for notebooks?
# we need some "export mode" where everything is shown
def display_chat_history(self):
"""
Display the chat history, including messages, function calls, and associated artifacts.
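A standalone sketch of the truncation strategy in truncate_conversation_history, as shown in the hunk above: whitespace word count stands in as a rough token proxy, and the oldest messages are dropped first. The function name here is illustrative, not part of the module.

from typing import Dict, List

def truncate_history(
    messages: List[Dict[str, str]], max_tokens: int = 100_000
) -> List[Dict[str, str]]:
    """Drop the oldest messages until the word-count budget is met."""
    msgs = list(messages)
    total = sum(len(m["content"].split()) for m in msgs)
    while total > max_tokens and msgs:
        removed = msgs.pop(0)
        total -= len(removed["content"].split())
    return msgs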
[Diffs for the remaining 3 of the 6 changed files were not loaded.]
