Skip to content

Commit

Permalink
fix(ollama): resolve model list loading issue and add Pytest for comp…
Browse files Browse the repository at this point in the history
…onent testing (langflow-ai#3575)

* Commit to solve Model not loading issue

The issue was that the url of the models: api/tags was not parsed correctly.
It was having a // hence used urlencode to parse it properly.

Th e correct apporach works only if the base_url is correct,i.e a valid ollama URL:
for DS LF this must be a public ollama Server URL.

* updated the component Ollama Component

changed the get model to take in base url and the function will make the expected url for the model names. This makes the function better, than providing the model url as paramter.

Added Pytest, 7 tests, 1 test excluded for future implememtstion: test_build_model_failure

Make lint and Make format had touched multiple files

* removed unwanted print statements

removed unwanted print statements.

make format, formatted a lot of .tsx files also

* removed skipped tests

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and diogocabral committed Nov 26, 2024
1 parent eba95c2 commit 3a20531
Show file tree
Hide file tree
Showing 13 changed files with 3,129 additions and 1,378 deletions.
7 changes: 4 additions & 3 deletions src/backend/base/langflow/components/models/OllamaModel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any
from urllib.parse import urljoin

import httpx
from langchain_community.chat_models import ChatOllama
Expand Down Expand Up @@ -41,8 +42,7 @@ def update_build_config(self, build_config: dict, field_value: Any, field_name:
base_url_value = self.variables(base_url_value)
elif not base_url_value:
base_url_value = "http://localhost:11434"
build_config["model_name"]["options"] = self.get_model(base_url_value + "/api/tags")

build_config["model_name"]["options"] = self.get_model(base_url_value)
if field_name == "keep_alive_flag":
if field_value == "Keep":
build_config["keep_alive"]["value"] = "-1"
Expand All @@ -55,8 +55,9 @@ def update_build_config(self, build_config: dict, field_value: Any, field_name:

return build_config

def get_model(self, url: str) -> list[str]:
def get_model(self, base_url_value: str) -> list[str]:
try:
url = urljoin(base_url_value, "/api/tags")
with httpx.Client() as client:
response = client.get(url)
response.raise_for_status()
Expand Down
125 changes: 125 additions & 0 deletions src/backend/tests/unit/components/models/test_ChatOllama_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import pytest
from unittest.mock import patch, MagicMock
from langflow.components.models.OllamaModel import ChatOllamaComponent
from langchain_community.chat_models.ollama import ChatOllama
from urllib.parse import urljoin


@pytest.fixture
def component():
return ChatOllamaComponent()


@patch("httpx.Client.get")
def test_get_model_success(mock_get, component):
mock_response = MagicMock()
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response

base_url = "http://localhost:11434"

model_names = component.get_model(base_url)

expected_url = urljoin(base_url, "/api/tags")

mock_get.assert_called_once_with(expected_url)

assert model_names == ["model1", "model2"]


@patch("httpx.Client.get")
def test_get_model_failure(mock_get, component):
# Mock the response for the HTTP GET request to raise an exception
mock_get.side_effect = Exception("HTTP request failed")

url = "http://localhost:11434/api/tags"

# Assert that the ValueError is raised when an exception occurs
with pytest.raises(ValueError, match="Could not retrieve models"):
component.get_model(url)


def test_update_build_config_mirostat_disabled(component):
build_config = {
"mirostat_eta": {"advanced": False, "value": 0.1},
"mirostat_tau": {"advanced": False, "value": 5},
}
field_value = "Disabled"
field_name = "mirostat"

updated_config = component.update_build_config(build_config, field_value, field_name)

assert updated_config["mirostat_eta"]["advanced"] is True
assert updated_config["mirostat_tau"]["advanced"] is True
assert updated_config["mirostat_eta"]["value"] is None
assert updated_config["mirostat_tau"]["value"] is None


def test_update_build_config_mirostat_enabled(component):
build_config = {
"mirostat_eta": {"advanced": False, "value": None},
"mirostat_tau": {"advanced": False, "value": None},
}
field_value = "Mirostat 2.0"
field_name = "mirostat"

updated_config = component.update_build_config(build_config, field_value, field_name)

assert updated_config["mirostat_eta"]["advanced"] is False
assert updated_config["mirostat_tau"]["advanced"] is False
assert updated_config["mirostat_eta"]["value"] == 0.2
assert updated_config["mirostat_tau"]["value"] == 10


@patch("httpx.Client.get")
def test_update_build_config_model_name(mock_get, component):
# Mock the response for the HTTP GET request
mock_response = MagicMock()
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response

build_config = {
"base_url": {"load_from_db": False, "value": None},
"model_name": {"options": []},
}
field_value = None
field_name = "model_name"

updated_config = component.update_build_config(build_config, field_value, field_name)

assert updated_config["model_name"]["options"] == ["model1", "model2"]


def test_update_build_config_keep_alive(component):
build_config = {"keep_alive": {"value": None, "advanced": False}}
field_value = "Keep"
field_name = "keep_alive_flag"

updated_config = component.update_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "-1"
assert updated_config["keep_alive"]["advanced"] is True

field_value = "Immediately"
updated_config = component.update_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "0"
assert updated_config["keep_alive"]["advanced"] is True


@patch(
"langchain_community.chat_models.ChatOllama",
return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
)
def test_build_model(mock_chat_ollama, component):
component.base_url = "http://localhost:11434"
component.model_name = "llama3.1"
component.mirostat = "Mirostat 2.0"
component.mirostat_eta = 0.2 # Ensure this is set as a float
component.mirostat_tau = 10.0 # Ensure this is set as a float
component.temperature = 0.2
component.verbose = True
model = component.build_model()
assert isinstance(model, ChatOllama)
assert model.base_url == "http://localhost:11434"
assert model.model == "llama3.1"
Loading

0 comments on commit 3a20531

Please sign in to comment.