Fix gui+dataset (#67)
* fix gui import

* fix dataset import
hippalectryon-0 authored May 16, 2023
1 parent de9209c · commit 31570a1
Showing 3 changed files with 18 additions and 14 deletions.
casalioy/gui.py (4 changes: 3 additions & 1 deletion)
@@ -1,13 +1,15 @@
 """LLM through a GUI"""
 
 import streamlit as st
-from load_env import get_embedding_model, model_n_ctx, model_path, model_stop, model_temp, n_gpu_layers, persist_directory, use_mlock, print_HTML
+from load_env import get_embedding_model, model_n_ctx, model_path, model_stop, model_temp, n_gpu_layers, persist_directory, use_mlock
 from streamlit_chat import message
 from streamlit_extras.add_vertical_space import add_vertical_space
 from streamlit_extras.colored_header import colored_header
 
 from casalioy import startLLM
+from casalioy.startLLM import QASystem
+from casalioy.utils import print_HTML
 
 title = "CASALIOY"
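The two gui.py changes go together: load_env never defined print_HTML itself, it only re-exported what it imported from casalioy.utils, and this commit drops that import (see the load_env.py diff below). So gui.py now imports print_HTML from its real home, casalioy.utils, and pulls QASystem in directly. A minimal usage sketch, with the call convention inferred from the print_HTML call sites elsewhere in this commit (the format-string-plus-keyword style is an inference, not documented API):

# Sketch only: print_HTML appears to take an HTML-tagged format string
# plus keyword substitutions, as at its call sites in casalioy/utils.py.
from casalioy.utils import print_HTML

print_HTML("<r>loaded {name}</r>", name="example-model.bin")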
casalioy/load_env.py (4 changes: 2 additions & 2 deletions)
@@ -6,15 +6,14 @@
 from langchain.embeddings import HuggingFaceEmbeddings, LlamaCppEmbeddings
 from langchain.prompts import PromptTemplate
 
-from casalioy.utils import download_if_repo, print_HTML
+from casalioy.utils import download_if_repo
 
 load_dotenv()
 text_embeddings_model = os.environ.get("TEXT_EMBEDDINGS_MODEL")
 text_embeddings_model_type = os.environ.get("TEXT_EMBEDDINGS_MODEL_TYPE")
 model_n_ctx = int(os.environ.get("MODEL_N_CTX"))
 use_mlock = os.environ.get("USE_MLOCK").lower() == "true"
 
-print_HTML
 # ingest
 persist_directory = os.environ.get("PERSIST_DIRECTORY")
 documents_directory = os.environ.get("DOCUMENTS_DIRECTORY")
@@ -35,6 +34,7 @@
 text_embeddings_model = download_if_repo(text_embeddings_model)
 model_path = download_if_repo(model_path)
 
+
 def get_embedding_model() -> tuple[HuggingFaceEmbeddings | LlamaCppEmbeddings, Callable]:
     """get the text embedding model
     :returns: tuple[the model, its encoding function]"""
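The second load_env.py deletion removes a stray bare-name statement. print_HTML on a line by itself only evaluates the name and discards the result, so it never did anything; worse, once the trimmed import above stops binding the name, the module would fail at import time. A hypothetical snippet (not from the repo) showing why the stray line had to go along with the import:

# After the import is trimmed, a leftover bare reference breaks module import:
from casalioy.utils import download_if_repo  # print_HTML no longer imported

print_HTML  # NameError: name 'print_HTML' is not defined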
casalioy/utils.py (24 changes: 13 additions & 11 deletions)
@@ -48,27 +48,29 @@ def download_if_repo(path: str, file: str = None, allow_patterns: str | list[str
     """download model from HF if not local"""
     if allow_patterns is None:
         allow_patterns = ["*.bin", "*.json"]
-    p = Path("models/"+path)
+
+    # check if dataset
+    split = path.split("/")
+    is_dataset = split[0] == "datasets"
+    if is_dataset:
+        split = split[1:]
+        path = "/".join(split)
+
+    p = "models/datasets" / Path(path) if is_dataset else "models" / Path(path)
     if p.is_file() or p.is_dir():
-        print(p, "already installed")
+        print_HTML(f"<r>found local model at {p}</r>")
         return str(p)
 
-    try:
-        split = path.split("/")
-        is_dataset = split[0] == "datasets"
-        if is_dataset:
-            split = split[1:]
-            path = "/".join(split)
-
+    try:
         if path.endswith(".bin"):
             path, file = "/".join(split[: 3 if is_dataset else 2]), split[-1]
         validate_repo_id(path)
-        print_HTML("<r>Downloading {model} from HF</r>", model=path)
+        print_HTML("<r>Downloading {model_type} {model} from HF</r>", model=path, model_type="dataset" if is_dataset else "model")
         new_path = Path(
             snapshot_download(
                 repo_id=path,
                 allow_patterns=file or allow_patterns,
-                local_dir=f"models/{path}",
+                local_dir=str(p),
                 repo_type="dataset" if is_dataset else None,
                 local_dir_use_symlinks=False,
             )
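This hunk is the actual dataset fix. Previously the local-cache check built its path from the raw argument (models/datasets/user/repo for a dataset) while the download landed in models/ joined with the stripped path (models/user/repo), so a cached dataset was never detected and was re-downloaded on every run. Moving the datasets/ prefix handling above the check makes both sides agree on one directory, which snapshot_download then receives as local_dir. A sketch of the new resolution logic pulled out as a standalone function (the function name and example paths are mine, for illustration only):

# Sketch: mirrors the target-directory logic download_if_repo now uses.
from pathlib import Path

def resolve_local_dir(path: str) -> Path:
    split = path.split("/")
    is_dataset = split[0] == "datasets"
    if is_dataset:
        path = "/".join(split[1:])  # drop the leading "datasets/" segment
    return "models/datasets" / Path(path) if is_dataset else "models" / Path(path)

print(resolve_local_dir("datasets/someuser/some-dataset"))  # models/datasets/someuser/some-dataset
print(resolve_local_dir("someuser/some-model"))             # models/someuser/some-model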
