Skip to content

Commit

Permalink
Merge pull request #365 from MannLabs/refactor_llm_III
Browse files Browse the repository at this point in the history
Refactor llm iii
  • Loading branch information
mschwoer authored Nov 8, 2024
2 parents e8e43ea + 3bd1e86 commit a27cd95
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 61 deletions.
3 changes: 1 addition & 2 deletions alphastats/gui/pages/05_LLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@ def llm_config():
# )
# ) # TODO unused?

gene_to_prot_id_map = dict(
gene_to_prot_id_map = dict( # TODO move this logic to dataset
zip(
genes_of_interest_colored_df[gene_names_colname].tolist(),
genes_of_interest_colored_df[prot_ids_colname].tolist(),
)
)
st.session_state[StateKeys.GENE_TO_PROT_ID] = gene_to_prot_id_map

with c2:
display_figure(volcano_plot.plot)
Expand Down
64 changes: 34 additions & 30 deletions alphastats/gui/utils/gpt_helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
from typing import Dict, List

import pandas as pd
Expand Down Expand Up @@ -124,26 +123,23 @@ def get_general_assistant_functions() -> List[Dict]:


def get_assistant_functions(
gene_to_prot_id_dict: Dict,
gene_to_prot_id_map: Dict,
metadata: pd.DataFrame,
subgroups_for_each_group: Dict,
) -> List[Dict]:
"""
Get a list of assistant functions for function calling in the ChatGPT model.
You can call this function with no arguments, arguments are given for clarity on what changes the behavior of the function.
For more information on how to format functions for Assistants, see https://platform.openai.com/docs/assistants/tools/function-calling
Args:
gene_to_prot_id_dict (dict, optional): A dictionary with gene names as keys and protein IDs as values.
metadata (pd.DataFrame, optional): The metadata dataframe (which sample has which disease/treatment/condition/etc).
subgroups_for_each_group (dict, optional): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group().
gene_to_prot_id_map (dict): A dictionary with gene names as keys and protein IDs as values.
metadata (pd.DataFrame): The metadata dataframe (which sample has which disease/treatment/condition/etc).
subgroups_for_each_group (dict): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group().
Returns:
list[dict]: A list of assistant functions.
"""
# TODO figure out how this relates to the parameter `subgroups_for_each_group`
subgroups_for_each_group_ = str(
get_subgroups_for_each_group(st.session_state[StateKeys.DATASET].metadata)
)
gene_names = list(gene_to_prot_id_map.keys())
groups = [str(col) for col in metadata.columns.to_list()]
return [
{
"type": "function",
Expand All @@ -153,21 +149,21 @@ def get_assistant_functions(
"parameters": {
"type": "object",
"properties": {
"protein_id": {
"gene_name": { # this will be mapped to "protein_id" when calling the function
"type": "string",
"enum": [i for i in gene_to_prot_id_dict],
"description": "Identifier for the protein of interest",
"enum": gene_names,
"description": "Identifier for the gene of interest",
},
"group": {
"type": "string",
"enum": [str(i) for i in metadata.columns.to_list()],
"enum": groups,
"description": "Column name in the dataset for the group variable",
},
"subgroups": {
"type": "array",
"items": {"type": "string"},
"description": f"Specific subgroups within the group to analyze. For each group you need to look up the subgroups in the dict"
f" {subgroups_for_each_group_} or present user with them first if you are not sure what to choose",
f" {subgroups_for_each_group} or present user with them first if you are not sure what to choose",
},
"method": {
"type": "string",
Expand Down Expand Up @@ -198,7 +194,7 @@ def get_assistant_functions(
"group": {
"type": "string",
"description": "The name of the group column in the dataset",
"enum": [str(i) for i in metadata.columns.to_list()],
"enum": groups,
},
"method": {
"type": "string",
Expand All @@ -225,7 +221,7 @@ def get_assistant_functions(
"color": {
"type": "string",
"description": "The name of the group column in the dataset to color the samples by",
"enum": [str(i) for i in metadata.columns.to_list()],
"enum": groups,
},
"method": {
"type": "string",
Expand Down Expand Up @@ -303,29 +299,37 @@ def get_assistant_functions(


def perform_dimensionality_reduction(group, method, circle, **kwargs):
dataset = st.session_state[StateKeys.DATASET]
dr = DimensionalityReduction(
st.session_state[StateKeys.DATASET], group, method, circle, **kwargs
mat=dataset.mat,
metadate=dataset.metadata,
sample=dataset.sample,
preprocessing_info=dataset.preprocessing_info,
group=group,
circle=circle,
method=method,
**kwargs,
)
return dr.plot


def get_gene_to_prot_id_mapping(gene_id: str) -> str:
def get_protein_id_for_gene_name(
gene_name: str, gene_to_prot_id_map: Dict[str, str]
) -> str:
"""Get protein id from gene id. If gene id is not present, return gene id, as we might already have a gene id.
'VCL;HEL114' -> 'P18206;A0A024QZN4;V9HWK2;B3KXA2;Q5JQ13;B4DKC9;B4DTM7;A0A096LPE1'
Args:
gene_id (str): Gene id
gene_name (str): Gene id
Returns:
str: Protein id or gene id if not present in the mapping.
"""
import streamlit as st
if gene_name in gene_to_prot_id_map:
return gene_to_prot_id_map[gene_name]

for gene, protein_id in gene_to_prot_id_map.items():
if gene_name in gene.split(";"):
return protein_id

session_state_copy = dict(copy.deepcopy(st.session_state))
if StateKeys.GENE_TO_PROT_ID not in session_state_copy:
session_state_copy[StateKeys.GENE_TO_PROT_ID] = {}
if gene_id in session_state_copy[StateKeys.GENE_TO_PROT_ID]:
return session_state_copy[StateKeys.GENE_TO_PROT_ID][gene_id]
for gene, prot_id in session_state_copy[StateKeys.GENE_TO_PROT_ID].items():
if gene_id in gene.split(";"):
return prot_id
return gene_id
return gene_name
45 changes: 32 additions & 13 deletions alphastats/gui/utils/ollama_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from alphastats.gui.utils.gpt_helper import (
get_assistant_functions,
get_general_assistant_functions,
get_protein_id_for_gene_name,
get_subgroups_for_each_group,
perform_dimensionality_reduction,
)
Expand Down Expand Up @@ -108,7 +109,7 @@ def _get_tools(self) -> List[Dict[str, Any]]:
if self.metadata is not None and self._gene_to_prot_id_map is not None:
tools += (
*get_assistant_functions(
gene_to_prot_id_dict=self._gene_to_prot_id_map,
gene_to_prot_id_map=self._gene_to_prot_id_map,
metadata=self.metadata,
subgroups_for_each_group=get_subgroups_for_each_group(
self.metadata
Expand All @@ -133,6 +134,7 @@ def truncate_conversation_history(self, max_tokens: int = 100000):
"""
total_tokens = sum(len(m["content"].split()) for m in self.messages)
while total_tokens > max_tokens and len(self.messages) > 1:
# TODO messages should still be displayed!
removed_message = self.messages.pop(0)
total_tokens -= len(removed_message["content"].split())

Expand Down Expand Up @@ -179,24 +181,39 @@ def execute_function(
If the function is not implemented or the dataset is not available
"""
try:
if function_name == "get_gene_function":
# TODO log whats going on
return get_gene_function(**function_args)
elif function_name == "get_enrichment_data":
return get_enrichment_data(**function_args)
elif function_name == "perform_dimensionality_reduction":
return perform_dimensionality_reduction(**function_args)
elif function_name.startswith("plot_") or function_name.startswith(
"perform_"
):
# first try to find the function in the non-Dataset functions
if (
function := {
"get_gene_function": get_gene_function,
"get_enrichment_data": get_enrichment_data,
"perform_dimensionality_reduction": perform_dimensionality_reduction,
}.get(function_name)
) is not None:
return function(**function_args)

# special treatment for this one
elif function_name == "plot_intensity":
gene_name = function_args.pop("gene_name")
protein_id = get_protein_id_for_gene_name(
gene_name, self._gene_to_prot_id_map
)
function_args["protein_id"] = protein_id

return self.dataset.plot_intensity(**function_args)

# fallback: try to find the function in the Dataset functions
else:
plot_function = getattr(
self.dataset, function_name.split(".")[-1], None
self.dataset,
function_name.split(".")[-1],
None, # TODO why split?
)
if plot_function:
return plot_function(**function_args)
raise ValueError(
f"Function {function_name} not implemented or dataset not available"
)

except Exception as e:
return f"Error executing {function_name}: {str(e)}"

Expand All @@ -219,6 +236,7 @@ def handle_function_calls(
"""
new_artifacts = {}

funcs_and_args = "\n".join(
[
f"Calling function: {tool_call.function.name} with arguments: {tool_call.function.arguments}"
Expand All @@ -231,7 +249,6 @@ def handle_function_calls(

for tool_call in tool_calls:
function_name = tool_call.function.name
print(f"Calling function: {function_name}")
function_args = json.loads(tool_call.function.arguments)

function_result = self.execute_function(function_name, function_args)
Expand All @@ -248,8 +265,10 @@ def handle_function_calls(
"tool_call_id": tool_call.id,
}
)

post_artefact_message_idx = len(self.messages)
self.artifacts[post_artefact_message_idx] = new_artifacts.values()

logger.info(
f"Calling 'chat.completions.create' {self.messages=} {self.tools=} .."
)
Expand Down
4 changes: 0 additions & 4 deletions alphastats/gui/utils/ui_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ def init_session_state() -> None:
if StateKeys.USER_SESSION_ID not in st.session_state:
st.session_state[StateKeys.USER_SESSION_ID] = str(uuid.uuid4())

if StateKeys.GENE_TO_PROT_ID not in st.session_state:
st.session_state[StateKeys.GENE_TO_PROT_ID] = {}

if StateKeys.ORGANISM not in st.session_state:
st.session_state[StateKeys.ORGANISM] = 9606 # human

Expand All @@ -97,7 +94,6 @@ class StateKeys:
## 02_Data Import
# on 1st run
ORGANISM = "organism"
GENE_TO_PROT_ID = "gene_to_prot_id"
USER_SESSION_ID = "user_session_id"
LOADER = "loader"
# on sample run (function load_sample_data), removed on new session click
Expand Down
3 changes: 1 addition & 2 deletions alphastats/plots/IntensityPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import plotly.graph_objects as go
import scipy

from alphastats.gui.utils.gpt_helper import get_gene_to_prot_id_mapping
from alphastats.plots.PlotUtils import PlotUtils, plotly_object

plotly.io.templates["alphastats_colors"] = plotly.graph_objects.layout.Template(
Expand Down Expand Up @@ -54,7 +53,7 @@ def __init__(
self.intensity_column = intensity_column
self.preprocessing_info = preprocessing_info

self.protein_id = get_gene_to_prot_id_mapping(protein_id)
self.protein_id = protein_id
self.group = group
self.subgroups = subgroups
self.method = method
Expand Down
2 changes: 0 additions & 2 deletions tests/gui/test_02_import_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def test_page_02_loads_without_input():

assert at.session_state[StateKeys.ORGANISM] == 9606
assert at.session_state[StateKeys.USER_SESSION_ID] is not None
assert at.session_state[StateKeys.GENE_TO_PROT_ID] == {}


@patch("streamlit.file_uploader")
Expand All @@ -31,7 +30,6 @@ def test_patched_page_02_loads_without_input(mock_file_uploader: MagicMock):

assert at.session_state[StateKeys.ORGANISM] == 9606
assert at.session_state[StateKeys.USER_SESSION_ID] is not None
assert at.session_state[StateKeys.GENE_TO_PROT_ID] == {}


@patch(
Expand Down
4 changes: 0 additions & 4 deletions tests/test_DataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from alphastats.DataSet import DataSet
from alphastats.dataset_factory import DataSetFactory
from alphastats.DataSet_Preprocess import PreprocessingStateKeys
from alphastats.gui.utils.ui_helper import StateKeys
from alphastats.loader.AlphaPeptLoader import AlphaPeptLoader
from alphastats.loader.DIANNLoader import DIANNLoader
from alphastats.loader.FragPipeLoader import FragPipeLoader
Expand Down Expand Up @@ -517,9 +516,6 @@ def test_plot_intenstity_subgroup(self):
self.assertEqual(len(plot_dict.get("data")), 3)

def test_plot_intensity_subgroup_gracefully_handle_one_group(self):
import streamlit as st

st.session_state[StateKeys.GENE_TO_PROT_ID] = {}
plot = self.obj.plot_intensity(
protein_id="K7ERI9;A0A024R0T8;P02654;K7EJI9;K7ELM9;K7EPF9;K7EKP1",
group="disease",
Expand Down
4 changes: 0 additions & 4 deletions tests/test_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@
from alphastats.gui.utils.uniprot_utils import extract_data, get_uniprot_data
from alphastats.loader.MaxQuantLoader import MaxQuantLoader

if StateKeys.GENE_TO_PROT_ID not in st.session_state:
st.session_state[StateKeys.GENE_TO_PROT_ID] = {}


logger = logging.getLogger(__name__)


Expand Down

0 comments on commit a27cd95

Please sign in to comment.