Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ollama): resolve model list loading issue and add Pytest for component testing #3575

Merged
merged 8 commits into from
Aug 27, 2024
8 changes: 5 additions & 3 deletions src/backend/base/langflow/components/models/OllamaModel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any
from urllib.parse import urljoin

import httpx
from langchain_community.chat_models import ChatOllama
Expand Down Expand Up @@ -41,8 +42,7 @@ def update_build_config(self, build_config: dict, field_value: Any, field_name:
base_url_value = self.variables(base_url_value)
elif not base_url_value:
base_url_value = "http://localhost:11434"
build_config["model_name"]["options"] = self.get_model(base_url_value + "/api/tags")

build_config["model_name"]["options"] = self.get_model(base_url_value)
if field_name == "keep_alive_flag":
if field_value == "Keep":
build_config["keep_alive"]["value"] = "-1"
Expand All @@ -55,8 +55,9 @@ def update_build_config(self, build_config: dict, field_value: Any, field_name:

return build_config

def get_model(self, url: str) -> list[str]:
def get_model(self, base_url_value: str) -> list[str]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could just call this base_url

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it so that the models' endpoint URL is formed inside the function.
Previously the input to the function was the URL of the tags/models endpoint, which is OK, but this way it is clearer that the model list is loaded from the base URL, and the endpoint URL is built reliably inside the get_model function. Should I revert to the previous approach and build the model endpoint URL before calling the function?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nah, that's fine

try:
url = urljoin(base_url_value, "/api/tags")
with httpx.Client() as client:
response = client.get(url)
response.raise_for_status()
Expand Down Expand Up @@ -252,6 +253,7 @@ def build_model(self) -> LanguageModel: # type: ignore[type-var]
try:
output = ChatOllama(**llm_params) # type: ignore
except Exception as e:
print(f"Exception caught: {e}")
edwinjosechittilappilly marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("Could not initialize Ollama LLM.") from e

return output # type: ignore
131 changes: 131 additions & 0 deletions src/backend/tests/unit/components/models/test_ChatOllama_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import pytest
from unittest.mock import patch, MagicMock
from langflow.components.models.OllamaModel import ChatOllamaComponent
from langchain_community.chat_models.ollama import ChatOllama
from urllib.parse import urljoin


@pytest.fixture
def component():
    """Provide a fresh ChatOllamaComponent instance for each test."""
    return ChatOllamaComponent()


@patch("httpx.Client.get")
def test_get_model_success(mock_get, component):
    """get_model should return the model names reported by the /api/tags endpoint."""
    fake_response = MagicMock()
    fake_response.raise_for_status.return_value = None
    fake_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
    mock_get.return_value = fake_response

    base_url = "http://localhost:11434"
    names = component.get_model(base_url)

    # The component builds the tags endpoint from the base URL internally.
    mock_get.assert_called_once_with(urljoin(base_url, "/api/tags"))
    assert names == ["model1", "model2"]


@patch("httpx.Client.get")
def test_get_model_failure(mock_get, component):
    """get_model should surface an HTTP failure as a ValueError."""
    # Simulate the HTTP GET request blowing up inside get_model.
    mock_get.side_effect = Exception("HTTP request failed")

    # get_model now receives the *base* URL and joins "/api/tags" itself
    # (see the OllamaModel change); passing the full tags URL here, as the
    # old test did, contradicted that contract.
    base_url = "http://localhost:11434"

    # The component is expected to wrap the failure in a ValueError.
    with pytest.raises(ValueError, match="Could not retrieve models"):
        component.get_model(base_url)


def test_update_build_config_mirostat_disabled(component):
    """Disabling mirostat should hide the eta/tau fields and clear their values."""
    build_config = {
        "mirostat_eta": {"advanced": False, "value": 0.1},
        "mirostat_tau": {"advanced": False, "value": 5},
    }

    updated = component.update_build_config(build_config, "Disabled", "mirostat")

    # Both dependent parameters are demoted to advanced and reset.
    for key in ("mirostat_eta", "mirostat_tau"):
        assert updated[key]["advanced"] is True
        assert updated[key]["value"] is None


def test_update_build_config_mirostat_enabled(component):
    """Selecting Mirostat 2.0 should expose eta/tau and set their defaults."""
    build_config = {
        "mirostat_eta": {"advanced": False, "value": None},
        "mirostat_tau": {"advanced": False, "value": None},
    }

    updated = component.update_build_config(build_config, "Mirostat 2.0", "mirostat")

    assert updated["mirostat_eta"]["advanced"] is False
    assert updated["mirostat_eta"]["value"] == 0.2
    assert updated["mirostat_tau"]["advanced"] is False
    assert updated["mirostat_tau"]["value"] == 10


@patch("httpx.Client.get")
def test_update_build_config_model_name(mock_get, component):
    """Touching model_name should populate its options from the Ollama tags API."""
    fake_response = MagicMock()
    fake_response.raise_for_status.return_value = None
    fake_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
    mock_get.return_value = fake_response

    build_config = {
        # No base_url configured, so the component falls back to its default.
        "base_url": {"load_from_db": False, "value": None},
        "model_name": {"options": []},
    }

    updated = component.update_build_config(build_config, None, "model_name")

    assert updated["model_name"]["options"] == ["model1", "model2"]


def test_update_build_config_keep_alive(component):
    """keep_alive_flag choices should map onto the expected keep_alive values."""
    build_config = {"keep_alive": {"value": None, "advanced": False}}

    # "Keep" pins the model indefinitely (-1); "Immediately" unloads it (0).
    for flag, expected in (("Keep", "-1"), ("Immediately", "0")):
        updated = component.update_build_config(build_config, flag, "keep_alive_flag")
        assert updated["keep_alive"]["value"] == expected
        assert updated["keep_alive"]["advanced"] is True


@patch(
    "langchain_community.chat_models.ChatOllama",
    return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
)
def test_build_model(mock_chat_ollama, component):
    """build_model should yield a ChatOllama configured from the component fields."""
    settings = {
        "base_url": "http://localhost:11434",
        "model_name": "llama3.1",
        "mirostat": "Mirostat 2.0",
        "mirostat_eta": 0.2,  # ChatOllama expects a float here
        "mirostat_tau": 10.0,  # ChatOllama expects a float here
        "temperature": 0.2,
        "verbose": True,
    }
    for attr, value in settings.items():
        setattr(component, attr, value)

    model = component.build_model()

    assert isinstance(model, ChatOllama)
    assert model.base_url == "http://localhost:11434"
    assert model.model == "llama3.1"


# Patch the name as bound inside OllamaModel: the module does
# `from langchain_community.chat_models import ChatOllama`, so patching
# `langchain_community.chat_models.ChatOllama` would never intercept the call.
@patch("langflow.components.models.OllamaModel.ChatOllama")
def test_build_model_failure(mock_chat_ollama, component):
    """build_model should wrap a ChatOllama construction error in a ValueError."""
    mock_chat_ollama.side_effect = Exception("ChatOllama initialization failed")

    # Same field setup as the happy-path test so only the constructor failure
    # differs between the two tests.
    component.base_url = "http://localhost:11434"
    component.model_name = "llama3.1"
    component.mirostat = "Mirostat 2.0"
    component.mirostat_eta = 0.2
    component.mirostat_tau = 10.0
    component.temperature = 0.2
    component.verbose = True

    with pytest.raises(ValueError, match="Could not initialize Ollama LLM."):
        component.build_model()
Loading
Loading