diff --git a/latentscope/models/__init__.py b/latentscope/models/__init__.py
index 4fdb7ae..8da2eb6 100644
--- a/latentscope/models/__init__.py
+++ b/latentscope/models/__init__.py
@@ -94,6 +94,20 @@ def get_chat_model(id):
             "name": model_name,
             "params": {}
         }
+    elif id.startswith("custom-"):
+        # Get custom model from custom_models.json
+        import os
+        from latentscope.util import get_data_dir
+        DATA_DIR = get_data_dir()
+        custom_models_path = os.path.join(DATA_DIR, "custom_models.json")
+        if os.path.exists(custom_models_path):
+            with open(custom_models_path, "r") as f:
+                custom_models = json.load(f)
+            model = next((m for m in custom_models if m["id"] == id), None)
+            if model is None:
+                raise ValueError(f"Custom model {id} not found")
+        else:
+            raise ValueError("No custom models found")
     else:
         model = get_chat_model_dict(id)
@@ -101,6 +115,8 @@ def get_chat_model(id):
         return TransformersChatProvider(model['name'], model['params'])
     if model['provider'] == "openai":
         return OpenAIChatProvider(model['name'], model['params'])
+    if model['provider'] == "custom":
+        return OpenAIChatProvider(model['name'], model['params'], base_url=model['url'])
     if model['provider'] == "mistralai":
         return MistralAIChatProvider(model['name'], model['params'])
     if model['provider'] == "nltk":
diff --git a/latentscope/models/providers/base.py b/latentscope/models/providers/base.py
index 7f976f5..b5906ab 100644
--- a/latentscope/models/providers/base.py
+++ b/latentscope/models/providers/base.py
@@ -10,9 +10,10 @@ def embed(self, text):
         raise NotImplementedError("This method should be implemented by subclasses.")
 
 class ChatModelProvider:
-    def __init__(self, name, params):
+    def __init__(self, name, params, base_url=None):
         self.name = name
         self.params = params
+        self.base_url = base_url
 
     def load_model(self):
         raise NotImplementedError("This method should be implemented by subclasses.")
diff --git a/latentscope/models/providers/openai.py b/latentscope/models/providers/openai.py
index 7cdbc85..8b9dfde 100644
--- a/latentscope/models/providers/openai.py
+++ b/latentscope/models/providers/openai.py
@@ -43,12 +43,20 @@ def embed(self, inputs, dimensions=None):
 
 class OpenAIChatProvider(ChatModelProvider):
     def load_model(self):
-        from openai import OpenAI
+        from openai import OpenAI, AsyncOpenAI
         import tiktoken
         import outlines
-        self.client = OpenAI(api_key=get_key("OPENAI_API_KEY"))
-        self.encoder = tiktoken.encoding_for_model(self.name)
-        self.model = outlines.models.openai(self.name, api_key=get_key("OPENAI_API_KEY"))
+        from outlines.models.openai import OpenAIConfig
+        if self.base_url is None:
+            self.client = AsyncOpenAI(api_key=get_key("OPENAI_API_KEY"))
+            self.encoder = tiktoken.encoding_for_model(self.name)
+        else:
+            self.client = AsyncOpenAI(api_key=get_key("OPENAI_API_KEY"), base_url=self.base_url)
+            self.encoder = None
+            print("BASE URL", self.base_url)
+            print("MODEL", self.name)
+        config = OpenAIConfig(self.name)
+        self.model = outlines.models.openai(self.client, config)
         self.generator = outlines.generate.text(self.model)
diff --git a/latentscope/scripts/label_clusters.py b/latentscope/scripts/label_clusters.py
index 4ce1461..23edada 100644
--- a/latentscope/scripts/label_clusters.py
+++ b/latentscope/scripts/label_clusters.py
@@ -116,7 +116,7 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
     items = items.drop_duplicates()
     tokens = 0
     keep_items = []
-    if max_tokens > 0:
+    if max_tokens > 0 and enc is not None:
        while tokens < max_tokens:
            for item in items:
                if item is None:
@@ -162,6 +162,7 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
         # do some cleanup of the labels when the model doesn't follow instructions
         clean_label = label.replace("\n", " ")
+        clean_label = clean_label.replace("<|eot_id|>", "")
         clean_label = clean_label.replace('"', '')
         clean_label = clean_label.replace("'", '')
         # clean_label = clean_label.replace("-", '')
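For orientation: the new `elif` branch reads `custom_models.json` from the data directory and routes any matching record to `OpenAIChatProvider` with a `base_url`, so any OpenAI-compatible server can back a chat model. A minimal sketch of a record and its resolution; the model name, the localhost URL, and the data directory path are placeholders, not values from this PR:

```python
# Sketch: what get_chat_model("custom-llama3") resolves against.
# The id is assigned as "custom-" + name by the new add_custom_model endpoint;
# "llama3" and the localhost URL are hypothetical, standing in for any
# OpenAI-compatible server (e.g. a local inference runtime).
import json
import os

DATA_DIR = os.path.expanduser("~/latent-scope-data")  # assumption: where get_data_dir() points
os.makedirs(DATA_DIR, exist_ok=True)

entry = {
    "id": "custom-llama3",
    "name": "llama3",                     # model name sent to the server
    "url": "http://localhost:11434/v1",   # becomes base_url on AsyncOpenAI
    "provider": "custom",
    "params": {},
}
with open(os.path.join(DATA_DIR, "custom_models.json"), "w") as f:
    json.dump([entry], f)

# get_chat_model("custom-llama3") now returns
# OpenAIChatProvider("llama3", {}, base_url="http://localhost:11434/v1"),
# whose load_model() sets encoder=None, which the label_clusters change guards against.
```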
diff --git a/latentscope/server/app.py b/latentscope/server/app.py
index bc71529..8e29a4b 100644
--- a/latentscope/server/app.py
+++ b/latentscope/server/app.py
@@ -1,7 +1,6 @@
 import re
 import os
 import sys
-import csv
 import json
 import math
 import h5py
@@ -71,57 +70,15 @@ def check_read_only(s):
 if not READ_ONLY:
     app.register_blueprint(admin_bp, url_prefix='/api/admin')
 
+from .models import models_bp, models_write_bp
+app.register_blueprint(models_bp, url_prefix='/api/models')
+if not READ_ONLY:
+    app.register_blueprint(models_write_bp, url_prefix='/api/models')
+
 # ===========================================================
 # File based routes for reading data and metadata from disk
 # ===========================================================
 
-@app.route('/api/embedding_models', methods=['GET'])
-def get_embedding_models():
-    embedding_path = files('latentscope.models').joinpath('embedding_models.json')
-    with embedding_path.open('r', encoding='utf-8') as file:
-        models = json.load(file)
-    return jsonify(models)
-
-@app.route('/api/chat_models', methods=['GET'])
-def get_chat_models():
-    chat_path = files('latentscope.models').joinpath('chat_models.json')
-    with chat_path.open('r', encoding='utf-8') as file:
-        models = json.load(file)
-    return jsonify(models)
-
-@app.route('/api/embedding_models/recent', methods=['GET'])
-def get_recent_embedding_models():
-    return get_recent_models("embedding")
-
-@app.route('/api/chat_models/recent', methods=['GET'])
-def get_recent_chat_models():
-    return get_recent_models("chat")
-
-def get_recent_models(model_type="embedding"):
-    recent_models_path = os.path.join(DATA_DIR, f"{model_type}_model_history.csv")
-    if not os.path.exists(recent_models_path):
-        return jsonify([])
-
-    with open(recent_models_path, 'r', encoding='utf-8') as file:
-        reader = csv.reader(file)
-        recent_models = []
-        for row in reader:
-            recent_models.append({
-                "timestamp": row[0],
-                "id": row[1],
-                "group": "recent",
-                "provider": row[1].split("-")[0],
-                "name": "-".join(row[1].split("-")[1:]).replace("___", "/")
-            })
-    recent_models.sort(key=lambda x: x["timestamp"], reverse=True)
-    # Deduplicate models with the same id
-    seen_ids = set()
-    unique_recent_models = []
-    for model in recent_models:
-        if model["id"] not in seen_ids:
-            unique_recent_models.append(model)
-            seen_ids.add(model["id"])
-    recent_models = unique_recent_models[:5]
-    return jsonify(recent_models)
+
 """ Allow fetching of dataset files directly from disk
diff --git a/latentscope/server/jobs.py b/latentscope/server/jobs.py
index 367a1f0..fc7ec62 100644
--- a/latentscope/server/jobs.py
+++ b/latentscope/server/jobs.py
@@ -486,7 +486,7 @@ def download_dataset():
     dataset_name = request.args.get('dataset_name')
 
     job_id = str(uuid.uuid4())
-    command = f'python latentscope/scripts/download_dataset.py "{dataset_repo}" "{dataset_name}" "{DATA_DIR}"'
+    command = f'ls-download-dataset "{dataset_repo}" "{dataset_name}" "{DATA_DIR}"'
     threading.Thread(target=run_job, args=(dataset_name, job_id, command)).start()
     return jsonify({"job_id": job_id})
 
@@ -499,6 +499,6 @@ def upload_dataset():
 
     job_id = str(uuid.uuid4())
     path = os.path.join(DATA_DIR, dataset)
-    command = f'python latentscope/scripts/upload_dataset.py "{path}" "{hf_dataset}" --main-parquet="{main_parquet}" --private={private}'
+    command = f'ls-upload-dataset "{path}" "{hf_dataset}" --main-parquet="{main_parquet}" --private={private}'
     threading.Thread(target=run_job, args=(dataset, job_id, command)).start()
     return jsonify({"job_id": job_id})
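The jobs.py change swaps `python latentscope/scripts/...` invocations, which only work from a source checkout, for the `ls-download-dataset` / `ls-upload-dataset` console scripts registered in setup.py below, which work from any installed copy. Those entry points require a `main()` in each script; the scripts themselves are not shown in this diff, so the following argparse sketch is an assumption about their shape inferred from the command strings above, not their real CLI:

```python
# Hypothetical sketch of latentscope/scripts/download_dataset.py's main(),
# inferred from the positional arguments in the jobs.py command string.
# The real script is not part of this diff; argument names are illustrative.
import argparse

def main():
    parser = argparse.ArgumentParser(description="Download a dataset into the data directory")
    parser.add_argument("dataset_repo")   # e.g. a Hugging Face repo id
    parser.add_argument("dataset_name")   # directory name under the data dir
    parser.add_argument("data_dir")       # DATA_DIR quoted into the job command
    args = parser.parse_args()
    # ... fetch the dataset and write it under args.data_dir/args.dataset_name ...

if __name__ == "__main__":
    main()
```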
diff --git a/latentscope/server/models.py b/latentscope/server/models.py
new file mode 100644
index 0000000..47de8bd
--- /dev/null
+++ b/latentscope/server/models.py
@@ -0,0 +1,103 @@
+import os
+import re
+import csv
+import json
+import uuid
+from importlib.resources import files
+from flask import Blueprint, jsonify, request
+
+# Create a Blueprint
+models_bp = Blueprint('models_bp', __name__)
+models_write_bp = Blueprint('models_write_bp', __name__)
+DATA_DIR = os.getenv('LATENT_SCOPE_DATA')
+
+@models_bp.route('/embedding_models', methods=['GET'])
+def get_embedding_models():
+    embedding_path = files('latentscope.models').joinpath('embedding_models.json')
+    with embedding_path.open('r', encoding='utf-8') as file:
+        models = json.load(file)
+    return jsonify(models)
+
+@models_bp.route('/chat_models', methods=['GET'])
+def get_chat_models():
+    chat_path = files('latentscope.models').joinpath('chat_models.json')
+    with chat_path.open('r', encoding='utf-8') as file:
+        models = json.load(file)
+    return jsonify(models)
+
+@models_bp.route('/embedding_models/recent', methods=['GET'])
+def get_recent_embedding_models():
+    return get_recent_models("embedding")
+
+@models_bp.route('/chat_models/recent', methods=['GET'])
+def get_recent_chat_models():
+    return get_recent_models("chat")
+
+def get_recent_models(model_type="embedding"):
+    recent_models_path = os.path.join(DATA_DIR, f"{model_type}_model_history.csv")
+    if not os.path.exists(recent_models_path):
+        return jsonify([])
+
+    with open(recent_models_path, 'r', encoding='utf-8') as file:
+        reader = csv.reader(file)
+        recent_models = []
+        for row in reader:
+            recent_models.append({
+                "timestamp": row[0],
+                "id": row[1],
+                "group": "recent",
+                "provider": row[1].split("-")[0],
+                "name": "-".join(row[1].split("-")[1:]).replace("___", "/")
+            })
+    recent_models.sort(key=lambda x: x["timestamp"], reverse=True)
+    # Deduplicate models with the same id
+    seen_ids = set()
+    unique_recent_models = []
+    for model in recent_models:
+        if model["id"] not in seen_ids:
+            unique_recent_models.append(model)
+            seen_ids.add(model["id"])
+    recent_models = unique_recent_models[:5]
+    return jsonify(recent_models)
+
+@models_bp.route('/custom-models', methods=['GET'])
+def get_custom_models():
+    custom_models_path = os.path.join(DATA_DIR, "custom_models.json")
+    if not os.path.exists(custom_models_path):
+        return jsonify([])
+    with open(custom_models_path, 'r', encoding='utf-8') as file:
+        custom_models = json.load(file)
+    return jsonify(custom_models)
+
+@models_write_bp.route('/custom-models', methods=['POST'])
+def add_custom_model():
+    data = request.json
+    custom_models_path = os.path.join(DATA_DIR, "custom_models.json")
+
+    # Read existing models
+    existing_models = []
+    if os.path.exists(custom_models_path):
+        with open(custom_models_path, 'r', encoding='utf-8') as file:
+            existing_models = json.load(file)
+
+    # Add new model
+    data["id"] = "custom-" + data["name"]
+    existing_models.append(data)
+
+    # Write updated models
+    with open(custom_models_path, 'w', encoding='utf-8') as file:
+        json.dump(existing_models, file)
+
+    return jsonify(existing_models)
+
+@models_write_bp.route('/custom-models/<model_id>', methods=['DELETE'])
+def delete_custom_model(model_id):
+    custom_models_path = os.path.join(DATA_DIR, "custom_models.json")
+    if not os.path.exists(custom_models_path):
+        return jsonify([])
+    with open(custom_models_path, 'r', encoding='utf-8') as file:
+        custom_models = json.load(file)
+    custom_models = [model for model in custom_models if model["id"] != model_id]
+    with open(custom_models_path, 'w', encoding='utf-8') as file:
+        json.dump(custom_models, file)
+    return jsonify(custom_models)
\ No newline at end of file
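To exercise the new blueprint by hand, something like the following works once the server is running and not in read-only mode (the write blueprint is only registered when `READ_ONLY` is false). This is a sketch: the port and the model record are assumptions, not values from this PR:

```python
# Sketch: exercising the new /api/models routes with requests.
# BASE assumes a local latent-scope server; adjust host/port to your setup.
import requests

BASE = "http://localhost:5001/api/models"  # port is an assumption

# POST assigns id = "custom-" + name and returns the updated list
created = requests.post(f"{BASE}/custom-models", json={
    "name": "llama3",                     # hypothetical model name
    "url": "http://localhost:11434/v1",   # hypothetical OpenAI-compatible endpoint
    "provider": "custom",
    "params": {},
}).json()
print(created)  # [..., {"id": "custom-llama3", ...}]

print(requests.get(f"{BASE}/custom-models").json())      # list custom models
requests.delete(f"{BASE}/custom-models/custom-llama3")   # delete by id
```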
diff --git a/setup.py b/setup.py
index 8d696de..8550d05 100644
--- a/setup.py
+++ b/setup.py
@@ -93,6 +93,8 @@ def run(self):
             'ls-export-plot=latentscope.scripts.export_plot:main',
             'ls-update-embedding-stats=latentscope.scripts.embed:update_embedding_stats',
             'ls-sae=latentscope.scripts.sae:main',
+            'ls-download-dataset=latentscope.scripts.download_dataset:main',
+            'ls-upload-dataset=latentscope.scripts.upload_dataset:main',
         ],
     },
     include_package_data=True,
diff --git a/web/src/components/CustomModels.jsx b/web/src/components/CustomModels.jsx
new file mode 100644
index 0000000..7aac063
--- /dev/null
+++ b/web/src/components/CustomModels.jsx
@@ -0,0 +1,107 @@
+import { useState, useEffect, useCallback } from 'react';
+import { Button } from 'react-element-forge';
+import { apiService } from '../lib/apiService';
+
+import styles from './CustomModels.module.scss';
+
+function CustomModels({ data_dir }) {
+  const [customModels, setCustomModels] = useState([]);
+  const [isSubmitting, setIsSubmitting] = useState(false);
+
+  // Fetch existing custom models on component mount
+  useEffect(() => {
+    apiService.fetchCustomModels().then((models) => setCustomModels(models));
+  }, []);
+
+  const handleAddModel = useCallback(async (e) => {
+    e.preventDefault();
+    setIsSubmitting(true);
+
+    const form = e.target;
+    const data = new FormData(form);
+    const newModel = {
+      name: data.get('name'),
+      url: data.get('url'),
+      params: {},
+      provider: 'custom',
+    };
+
+    try {
+      const updatedModels = await apiService.addCustomModel(newModel);
+      console.log('updatedModels', updatedModels);
+      setCustomModels(updatedModels);
+      form.reset();
+    } catch (error) {
+      console.error('Failed to add custom model:', error);
+    } finally {
+      setIsSubmitting(false);
+    }
+  }, []);
+
+  const handleDeleteModel = useCallback(
+    async (modelId) => {
+      try {
+        await apiService.deleteCustomModel(modelId);
+        setCustomModels(customModels.filter((model) => model.id !== modelId));
+      } catch (error) {
+        console.error('Failed to delete custom model:', error);
+      }
+    },
+    [customModels]
+  );
+
+  return (
+    <div className={styles['custom-models']}>
+      <div className={styles['custom-models-setup']}>
+        <div className={styles['custom-models-form']}>
+          <form onSubmit={handleAddModel}>
+            <label>
+              <span className={styles['form-label']}>Name</span>
+              <input type="text" name="name" required />
+            </label>
+            <label>
+              <span className={styles['form-label']}>URL</span>
+              <input type="text" name="url" required />
+            </label>
+            <Button type="submit" disabled={isSubmitting} text="Add model" />
+          </form>
+        </div>
+        <div className={styles['custom-models-list']}>
+          {customModels.map((model, index) => (
+            <div className={styles.item} key={index}>
+              <div className={styles['item-info']}>
+                <span className={styles['item-name']}>{model.name}</span>
+                <span className={styles['item-url']}>{model.url}</span>
+              </div>
+              <Button className={styles.delete} onClick={() => handleDeleteModel(model.id)} text="Delete" />
+            </div>
+          ))}
+        </div>
+      </div>
+    </div>
+  );
+}
+
+export default CustomModels;
diff --git a/web/src/components/CustomModels.module.scss b/web/src/components/CustomModels.module.scss
new file mode 100644
index 0000000..93901d4
--- /dev/null
+++ b/web/src/components/CustomModels.module.scss
@@ -0,0 +1,74 @@
+.custom-models {
+  display: flex;
+  flex-direction: column;
+  gap: 2rem;
+  padding: 1rem;
+
+  &-setup {
+    display: flex;
+    gap: 2rem;
+  }
+
+  &-form {
+    flex: 1;
+
+    form {
+      display: flex;
+      flex-direction: column;
+      gap: 1rem;
+      margin-top: 1rem;
+      max-width: 400px;
+
+      label {
+        display: flex;
+        flex-direction: column;
+        gap: 0.5rem;
+      }
+
+      input {
+        padding: 0.5rem;
+        border: 1px solid #ccc;
+        border-radius: 4px;
+      }
+
+      .form-label {
+        font-weight: 500;
+      }
+    }
+  }
+
+  &-list {
+    flex: 1;
+    display: flex;
+    flex-direction: column;
+    gap: 1rem;
+
+    .item {
+      display: flex;
+      justify-content: space-between;
+      align-items: center;
+      padding: 1rem;
+      border: 1px solid #ccc;
+      border-radius: 4px;
+
+      &-info {
+        display: flex;
+        flex-direction: column;
+        gap: 0.5rem;
+      }
+
+      &-name {
+        font-weight: 500;
+      }
+
+      &-url, &-model {
+        font-size: 0.9rem;
+        color: #666;
+      }
+
+      .delete {
+        padding: 0.5rem;
+      }
+    }
+  }
+}
diff --git a/web/src/components/SettingsModal.jsx b/web/src/components/SettingsModal.jsx
index d86c613..a40127a 100644
--- a/web/src/components/SettingsModal.jsx
+++ b/web/src/components/SettingsModal.jsx
@@ -1,11 +1,24 @@
-import { useState } from 'react';
+import { useState, useCallback } from 'react';
 import { Button, Modal } from 'react-element-forge';
 import Settings from '../pages/Settings';
 import { Tooltip } from 'react-tooltip';
 
 import styles from './SettingsModal.module.scss';
 
-const SettingsModal = ({ tooltip = '', color = 'primary', variant = 'outline' }) => {
+const SettingsModal = ({
+  tooltip = '',
+  color = 'primary',
+  variant = 'outline',
+  test = () => {},
+  onClose = () => {}, // Provide a no-op default function
+}) => {
   const [showSettings, setShowSettings] = useState(false);
+  const handleClose = useCallback(() => {
+    setShowSettings(false);
+    console.log('ON CLOSE', onClose);
+    if (onClose) onClose();
+    console.log('TESTING', test);
+    test();
+  }, [setShowSettings, onClose]);
 
   return (
     <>
@@ -27,19 +40,10 @@ const SettingsModal = ({ tooltip = '', color = 'primary', variant = 'outline' })
       )}
       {showSettings && (
-        <Modal
-          isVisible={showSettings}
-          onClose={() => setShowSettings(false)}
-          className={styles.modal}
-        >
+        <Modal isVisible={showSettings} onClose={handleClose} className={styles.modal}>
           <Settings />
         </Modal>
       )}
diff --git a/web/src/components/Setup/ClusterLabels.jsx b/web/src/components/Setup/ClusterLabels.jsx
index fac048a..7d8e0f4 100644
--- a/web/src/components/Setup/ClusterLabels.jsx
+++ b/web/src/components/Setup/ClusterLabels.jsx
@@ -106,6 +106,14 @@ function ClusterLabels() {
       .catch(console.error);
   }, [setPresetModels]);
 
+  const [customModels, setCustomModels] = useState([]);
+  useEffect(() => {
+    apiService.fetchCustomModels().then((data) => {
+      console.log('custom models', data);
+      setCustomModels(data);
+    });
+  }, [setCustomModels]);
+
   const [recentModels, setRecentModels] = useState([]);
   const fetchRecentModels = useCallback(() => {
     apiService.getRecentChatModels().then((data) => {
@@ -126,6 +134,7 @@ function ClusterLabels() {
   useEffect(() => {
     const am = [presetModels[0]]
       .concat(recentModels)
+      .concat(customModels)
       .concat(HFModels)
       .concat(presetModels.slice(1))
       .filter((d) => !!d);
@@ -152,7 +161,7 @@ function ClusterLabels() {
       setDefaultModel(defaultOption);
       setChatModel(defaultOption.id);
     }
-  }, [presetModels, HFModels, recentModels, defaultModel]);
+  }, [presetModels, HFModels, recentModels, defaultModel, customModels]);
 
   const handleModelSelectChange = useCallback(
     (selectedOption) => {
@@ -274,6 +283,14 @@ function ClusterLabels() {
     goToNextStep();
   }, [updateScope, goToNextStep, selected, savedScope, cluster]);
 
+  const handleSettingsClose = useCallback(() => {
+    console.log('CLOSING SETTINGS');
+    apiService.fetchCustomModels().then((data) => {
+      console.log('FETCHED CUSTOM MODELS', data);
+      setCustomModels(data);
+    });
+  }, [setCustomModels]);
+
   return (
@@ -283,7 +300,7 @@ function ClusterLabels() {
           {cluster ? ` in ${cluster.id}` : ''} using a chat model. For quickest
           CPU based results use nltk top-words.
-          <SettingsModal />
+          <SettingsModal onClose={handleSettingsClose} />
diff --git a/web/src/pages/Settings.jsx b/web/src/pages/Settings.jsx
--- a/web/src/pages/Settings.jsx
+++ b/web/src/pages/Settings.jsx
@@ ... @@
+
+      <div>
+        <h3>Custom Models</h3>
+        <p>
+          Add a custom model to use for embeddings. Saved in:
+          {envSettings.data_dir}/custom_models.json
+        </p>
+        <CustomModels data_dir={envSettings.data_dir} />
+      </div>
   );
 };
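End to end, once a custom model is registered it flows through the labeling pipeline like any built-in provider. A sketch under the same assumptions as above (a record for "custom-llama3" in custom_models.json with an OpenAI-compatible server behind its url); the `labeler` keyword arguments are the ones visible in this diff's hunk headers, and the dataset/cluster ids are placeholders:

```python
# Sketch: resolving and using a registered custom model for cluster labeling.
# Assumes "custom-llama3" exists in custom_models.json and its server is
# reachable; dataset and cluster ids below are placeholders.
from latentscope.models import get_chat_model
from latentscope.scripts.label_clusters import labeler

# Resolution check: returns OpenAIChatProvider with base_url from the record
model = get_chat_model("custom-llama3")
model.load_model()  # builds AsyncOpenAI(base_url=...) and leaves encoder=None

# Or run the labeling step directly; labeler resolves model_id itself
labeler(
    dataset_id="my-dataset",
    text_column="text",
    cluster_id="cluster-001",
    model_id="custom-llama3",
)
```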