Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gui+dataset #67

Merged
merged 2 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion casalioy/gui.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""LLM through a GUI"""

import streamlit as st
from load_env import get_embedding_model, model_n_ctx, model_path, model_stop, model_temp, n_gpu_layers, persist_directory, use_mlock, print_HTML
from load_env import get_embedding_model, model_n_ctx, model_path, model_stop, model_temp, n_gpu_layers, persist_directory, use_mlock
from streamlit_chat import message
from streamlit_extras.add_vertical_space import add_vertical_space
from streamlit_extras.colored_header import colored_header

from casalioy import startLLM
from casalioy.startLLM import QASystem
from casalioy.utils import print_HTML

title = "CASALIOY"


Expand Down
4 changes: 2 additions & 2 deletions casalioy/load_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
from langchain.embeddings import HuggingFaceEmbeddings, LlamaCppEmbeddings
from langchain.prompts import PromptTemplate

from casalioy.utils import download_if_repo, print_HTML
from casalioy.utils import download_if_repo

load_dotenv()
text_embeddings_model = os.environ.get("TEXT_EMBEDDINGS_MODEL")
text_embeddings_model_type = os.environ.get("TEXT_EMBEDDINGS_MODEL_TYPE")
model_n_ctx = int(os.environ.get("MODEL_N_CTX"))
use_mlock = os.environ.get("USE_MLOCK").lower() == "true"

print_HTML
# ingest
persist_directory = os.environ.get("PERSIST_DIRECTORY")
documents_directory = os.environ.get("DOCUMENTS_DIRECTORY")
Expand All @@ -35,6 +34,7 @@
text_embeddings_model = download_if_repo(text_embeddings_model)
model_path = download_if_repo(model_path)


def get_embedding_model() -> tuple[HuggingFaceEmbeddings | LlamaCppEmbeddings, Callable]:
"""get the text embedding model
:returns: tuple[the model, its encoding function]"""
Expand Down
24 changes: 13 additions & 11 deletions casalioy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,29 @@ def download_if_repo(path: str, file: str = None, allow_patterns: str | list[str
"""download model from HF if not local"""
if allow_patterns is None:
allow_patterns = ["*.bin", "*.json"]
p = Path("models/"+path)

# check if dataset
split = path.split("/")
is_dataset = split[0] == "datasets"
if is_dataset:
split = split[1:]
path = "/".join(split)

p = "models/datasets" / Path(path) if is_dataset else "models" / Path(path)
if p.is_file() or p.is_dir():
print(p, "already installed")
print_HTML(f"<r>found local model at {p}</r>")
return str(p)

try:
split = path.split("/")
is_dataset = split[0] == "datasets"
if is_dataset:
split = split[1:]
path = "/".join(split)

try:
if path.endswith(".bin"):
path, file = "/".join(split[: 3 if is_dataset else 2]), split[-1]
validate_repo_id(path)
print_HTML("<r>Downloading {model} from HF</r>", model=path)
print_HTML("<r>Downloading {model_type} {model} from HF</r>", model=path, model_type="dataset" if is_dataset else "model")
new_path = Path(
snapshot_download(
repo_id=path,
allow_patterns=file or allow_patterns,
local_dir=f"models/{path}",
local_dir=str(p),
repo_type="dataset" if is_dataset else None,
local_dir_use_symlinks=False,
)
Expand Down