forked from langflow-ai/langflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(ollama): resolve model list loading issue and add Pytest for comp…
…onent testing (langflow-ai#3575) * Commit to solve Model not loading issue The issue was that the url of the models: api/tags was not parsed correctly. It was having a // hence used urlencode to parse it properly. Th e correct apporach works only if the base_url is correct,i.e a valid ollama URL: for DS LF this must be a public ollama Server URL. * updated the component Ollama Component changed the get model to take in base url and the function will make the expected url for the model names. This makes the function better, than providing the model url as paramter. Added Pytest, 7 tests, 1 test excluded for future implememtstion: test_build_model_failure Make lint and Make format had touched multiple files * removed unwanted print statements removed unwanted print statements. make format, formatted a lot of .tsx files also * removed skipped tests * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
eba95c2
commit 3a20531
Showing
13 changed files
with
3,129 additions
and
1,378 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
125 changes: 125 additions & 0 deletions
125
src/backend/tests/unit/components/models/test_ChatOllama_component.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import pytest | ||
from unittest.mock import patch, MagicMock | ||
from langflow.components.models.OllamaModel import ChatOllamaComponent | ||
from langchain_community.chat_models.ollama import ChatOllama | ||
from urllib.parse import urljoin | ||
|
||
|
||
@pytest.fixture | ||
def component(): | ||
return ChatOllamaComponent() | ||
|
||
|
||
@patch("httpx.Client.get") | ||
def test_get_model_success(mock_get, component): | ||
mock_response = MagicMock() | ||
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]} | ||
mock_response.raise_for_status.return_value = None | ||
mock_get.return_value = mock_response | ||
|
||
base_url = "http://localhost:11434" | ||
|
||
model_names = component.get_model(base_url) | ||
|
||
expected_url = urljoin(base_url, "/api/tags") | ||
|
||
mock_get.assert_called_once_with(expected_url) | ||
|
||
assert model_names == ["model1", "model2"] | ||
|
||
|
||
@patch("httpx.Client.get") | ||
def test_get_model_failure(mock_get, component): | ||
# Mock the response for the HTTP GET request to raise an exception | ||
mock_get.side_effect = Exception("HTTP request failed") | ||
|
||
url = "http://localhost:11434/api/tags" | ||
|
||
# Assert that the ValueError is raised when an exception occurs | ||
with pytest.raises(ValueError, match="Could not retrieve models"): | ||
component.get_model(url) | ||
|
||
|
||
def test_update_build_config_mirostat_disabled(component): | ||
build_config = { | ||
"mirostat_eta": {"advanced": False, "value": 0.1}, | ||
"mirostat_tau": {"advanced": False, "value": 5}, | ||
} | ||
field_value = "Disabled" | ||
field_name = "mirostat" | ||
|
||
updated_config = component.update_build_config(build_config, field_value, field_name) | ||
|
||
assert updated_config["mirostat_eta"]["advanced"] is True | ||
assert updated_config["mirostat_tau"]["advanced"] is True | ||
assert updated_config["mirostat_eta"]["value"] is None | ||
assert updated_config["mirostat_tau"]["value"] is None | ||
|
||
|
||
def test_update_build_config_mirostat_enabled(component): | ||
build_config = { | ||
"mirostat_eta": {"advanced": False, "value": None}, | ||
"mirostat_tau": {"advanced": False, "value": None}, | ||
} | ||
field_value = "Mirostat 2.0" | ||
field_name = "mirostat" | ||
|
||
updated_config = component.update_build_config(build_config, field_value, field_name) | ||
|
||
assert updated_config["mirostat_eta"]["advanced"] is False | ||
assert updated_config["mirostat_tau"]["advanced"] is False | ||
assert updated_config["mirostat_eta"]["value"] == 0.2 | ||
assert updated_config["mirostat_tau"]["value"] == 10 | ||
|
||
|
||
@patch("httpx.Client.get") | ||
def test_update_build_config_model_name(mock_get, component): | ||
# Mock the response for the HTTP GET request | ||
mock_response = MagicMock() | ||
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]} | ||
mock_response.raise_for_status.return_value = None | ||
mock_get.return_value = mock_response | ||
|
||
build_config = { | ||
"base_url": {"load_from_db": False, "value": None}, | ||
"model_name": {"options": []}, | ||
} | ||
field_value = None | ||
field_name = "model_name" | ||
|
||
updated_config = component.update_build_config(build_config, field_value, field_name) | ||
|
||
assert updated_config["model_name"]["options"] == ["model1", "model2"] | ||
|
||
|
||
def test_update_build_config_keep_alive(component): | ||
build_config = {"keep_alive": {"value": None, "advanced": False}} | ||
field_value = "Keep" | ||
field_name = "keep_alive_flag" | ||
|
||
updated_config = component.update_build_config(build_config, field_value, field_name) | ||
assert updated_config["keep_alive"]["value"] == "-1" | ||
assert updated_config["keep_alive"]["advanced"] is True | ||
|
||
field_value = "Immediately" | ||
updated_config = component.update_build_config(build_config, field_value, field_name) | ||
assert updated_config["keep_alive"]["value"] == "0" | ||
assert updated_config["keep_alive"]["advanced"] is True | ||
|
||
|
||
@patch( | ||
"langchain_community.chat_models.ChatOllama", | ||
return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"), | ||
) | ||
def test_build_model(mock_chat_ollama, component): | ||
component.base_url = "http://localhost:11434" | ||
component.model_name = "llama3.1" | ||
component.mirostat = "Mirostat 2.0" | ||
component.mirostat_eta = 0.2 # Ensure this is set as a float | ||
component.mirostat_tau = 10.0 # Ensure this is set as a float | ||
component.temperature = 0.2 | ||
component.verbose = True | ||
model = component.build_model() | ||
assert isinstance(model, ChatOllama) | ||
assert model.base_url == "http://localhost:11434" | ||
assert model.model == "llama3.1" |
Oops, something went wrong.