support custom models via URLs #2 and #35
enjalot committed Nov 20, 2024
1 parent a02e742 commit 5d2a9db
Showing 14 changed files with 402 additions and 77 deletions.
16 changes: 16 additions & 0 deletions latentscope/models/__init__.py
@@ -94,13 +94,29 @@ def get_chat_model(id):
"name": model_name,
"params": {}
}
elif id.startswith("custom-"):
# Get custom model from custom_models.json
import os
from latentscope.util import get_data_dir
DATA_DIR = get_data_dir()
custom_models_path = os.path.join(DATA_DIR, "custom_models.json")
if os.path.exists(custom_models_path):
with open(custom_models_path, "r") as f:
custom_models = json.load(f)
model = next((m for m in custom_models if m["id"] == id), None)
if model is None:
raise ValueError(f"Custom model {id} not found")
else:
raise ValueError("No custom models found")
else:
model = get_chat_model_dict(id)

if model['provider'] == "🤗":
return TransformersChatProvider(model['name'], model['params'])
if model['provider'] == "openai":
return OpenAIChatProvider(model['name'], model['params'])
if model['provider'] == "custom":
return OpenAIChatProvider(model['name'], model['params'], base_url=model['url'])
if model['provider'] == "mistralai":
return MistralAIChatProvider(model['name'], model['params'])
if model['provider'] == "nltk":
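For reference, get_chat_model now resolves any id that starts with "custom-" against entries in custom_models.json under the data directory. A minimal sketch of one such entry, using only the fields read above and written by the new /custom-models endpoint later in this commit (the name and URL are placeholder values):

# Illustrative contents of DATA_DIR/custom_models.json, shown as the Python
# value json.load returns above (placeholder name and URL):
custom_models = [
    {
        "id": "custom-llama3.2",            # "custom-" prefix + name, assigned by the server
        "name": "llama3.2",                 # model name passed through to OpenAIChatProvider
        "url": "http://localhost:8080/v1",  # OpenAI-compatible base URL
        "provider": "custom",
        "params": {},
    }
]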
3 changes: 2 additions & 1 deletion latentscope/models/providers/base.py
@@ -10,9 +10,10 @@ def embed(self, text):
raise NotImplementedError("This method should be implemented by subclasses.")

class ChatModelProvider:
def __init__(self, name, params):
def __init__(self, name, params, base_url=None):
self.name = name
self.params = params
self.base_url = base_url

def load_model(self):
raise NotImplementedError("This method should be implemented by subclasses.")
16 changes: 12 additions & 4 deletions latentscope/models/providers/openai.py
@@ -43,12 +43,20 @@ def embed(self, inputs, dimensions=None):

class OpenAIChatProvider(ChatModelProvider):
def load_model(self):
from openai import OpenAI
from openai import OpenAI, AsyncOpenAI
import tiktoken
import outlines
self.client = OpenAI(api_key=get_key("OPENAI_API_KEY"))
self.encoder = tiktoken.encoding_for_model(self.name)
self.model = outlines.models.openai(self.name, api_key=get_key("OPENAI_API_KEY"))
from outlines.models.openai import OpenAIConfig
if self.base_url is None:
self.client = AsyncOpenAI(api_key=get_key("OPENAI_API_KEY"))
self.encoder = tiktoken.encoding_for_model(self.name)
else:
self.client = AsyncOpenAI(api_key=get_key("OPENAI_API_KEY"), base_url=self.base_url)
self.encoder = None
print("BASE URL", self.base_url)
print("MODEL", self.name)
config = OpenAIConfig(self.name)
self.model = outlines.models.openai(self.client, config)
self.generator = outlines.generate.text(self.model)


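Taken together with the base.py change above, a "custom" model now yields an OpenAIChatProvider whose AsyncOpenAI client points at any OpenAI-compatible server, while tiktoken is skipped (encoder stays None) because the model name is unknown to tiktoken. A minimal usage sketch, not part of the diff; the model name and URL are placeholders matching the UI defaults added below:

# Minimal sketch (placeholder name and URL): how __init__.py wires up a custom model.
from latentscope.models.providers.openai import OpenAIChatProvider

provider = OpenAIChatProvider("llama3.2", params={}, base_url="http://localhost:8080/v1")
provider.load_model()  # AsyncOpenAI(..., base_url=...) is used; provider.encoder is None
# Note: the client still reads OPENAI_API_KEY via get_key(), even for custom endpoints,
# and downstream token budgeting (label_clusters.py below) is skipped when encoder is None.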
3 changes: 2 additions & 1 deletion latentscope/scripts/label_clusters.py
@@ -116,7 +116,7 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
items = items.drop_duplicates()
tokens = 0
keep_items = []
if max_tokens > 0:
if max_tokens > 0 and enc is not None:
while tokens < max_tokens:
for item in items:
if item is None:
@@ -162,6 +162,7 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="

# do some cleanup of the labels when the model doesn't follow instructions
clean_label = label.replace("\n", " ")
clean_label = clean_label.replace("<|eot_id|>", "")
clean_label = clean_label.replace('"', '')
clean_label = clean_label.replace("'", '')
# clean_label = clean_label.replace("-", '')
55 changes: 6 additions & 49 deletions latentscope/server/app.py
@@ -1,7 +1,6 @@
import re
import os
import sys
import csv
import json
import math
import h5py
@@ -71,57 +70,15 @@ def check_read_only(s):
if not READ_ONLY:
app.register_blueprint(admin_bp, url_prefix='/api/admin')

from .models import models_bp, models_write_bp
app.register_blueprint(models_bp, url_prefix='/api/models')
if not READ_ONLY:
app.register_blueprint(models_write_bp, url_prefix='/api/models')

# ===========================================================
# File based routes for reading data and metadata from disk
# ===========================================================
@app.route('/api/embedding_models', methods=['GET'])
def get_embedding_models():
embedding_path = files('latentscope.models').joinpath('embedding_models.json')
with embedding_path.open('r', encoding='utf-8') as file:
models = json.load(file)
return jsonify(models)

@app.route('/api/chat_models', methods=['GET'])
def get_chat_models():
chat_path = files('latentscope.models').joinpath('chat_models.json')
with chat_path.open('r', encoding='utf-8') as file:
models = json.load(file)
return jsonify(models)

@app.route('/api/embedding_models/recent', methods=['GET'])
def get_recent_embedding_models():
return get_recent_models("embedding")

@app.route('/api/chat_models/recent', methods=['GET'])
def get_recent_chat_models():
return get_recent_models("chat")

def get_recent_models(model_type="embedding"):
recent_models_path = os.path.join(DATA_DIR, f"{model_type}_model_history.csv")
if not os.path.exists(recent_models_path):
return jsonify([])

with open(recent_models_path, 'r', encoding='utf-8') as file:
reader = csv.reader(file)
recent_models = []
for row in reader:
recent_models.append({
"timestamp": row[0],
"id": row[1],
"group": "recent",
"provider": row[1].split("-")[0],
"name": "-".join(row[1].split("-")[1:]).replace("___", "/")
})
recent_models.sort(key=lambda x: x["timestamp"], reverse=True)
# Deduplicate models with the same id
seen_ids = set()
unique_recent_models = []
for model in recent_models:
if model["id"] not in seen_ids:
unique_recent_models.append(model)
seen_ids.add(model["id"])
recent_models = unique_recent_models[:5]
return jsonify(recent_models)


"""
Allow fetching of dataset files directly from disk
4 changes: 2 additions & 2 deletions latentscope/server/jobs.py
@@ -486,7 +486,7 @@ def download_dataset():
dataset_name = request.args.get('dataset_name')

job_id = str(uuid.uuid4())
command = f'python latentscope/scripts/download_dataset.py "{dataset_repo}" "{dataset_name}" "{DATA_DIR}"'
command = f'ls-download-dataset "{dataset_repo}" "{dataset_name}" "{DATA_DIR}"'
threading.Thread(target=run_job, args=(dataset_name, job_id, command)).start()
return jsonify({"job_id": job_id})

@@ -499,6 +499,6 @@ def upload_dataset():

job_id = str(uuid.uuid4())
path = os.path.join(DATA_DIR, dataset)
command = f'python latentscope/scripts/upload_dataset.py "{path}" "{hf_dataset}" --main-parquet="{main_parquet}" --private={private}'
command = f'ls-upload-dataset "{path}" "{hf_dataset}" --main-parquet="{main_parquet}" --private={private}'
threading.Thread(target=run_job, args=(dataset, job_id, command)).start()
return jsonify({"job_id": job_id})
103 changes: 103 additions & 0 deletions latentscope/server/models.py
@@ -0,0 +1,103 @@
import os
import re
import csv
import json
import uuid
from importlib.resources import files
from flask import Blueprint, jsonify, request

# Create a Blueprint
models_bp = Blueprint('models_bp', __name__)
models_write_bp = Blueprint('models_write_bp', __name__)
DATA_DIR = os.getenv('LATENT_SCOPE_DATA')

@models_bp.route('/embedding_models', methods=['GET'])
def get_embedding_models():
embedding_path = files('latentscope.models').joinpath('embedding_models.json')
with embedding_path.open('r', encoding='utf-8') as file:
models = json.load(file)
return jsonify(models)

@models_bp.route('/chat_models', methods=['GET'])
def get_chat_models():
chat_path = files('latentscope.models').joinpath('chat_models.json')
with chat_path.open('r', encoding='utf-8') as file:
models = json.load(file)
return jsonify(models)

@models_bp.route('/embedding_models/recent', methods=['GET'])
def get_recent_embedding_models():
return get_recent_models("embedding")

@models_bp.route('/chat_models/recent', methods=['GET'])
def get_recent_chat_models():
return get_recent_models("chat")

def get_recent_models(model_type="embedding"):
recent_models_path = os.path.join(DATA_DIR, f"{model_type}_model_history.csv")
if not os.path.exists(recent_models_path):
return jsonify([])

with open(recent_models_path, 'r', encoding='utf-8') as file:
reader = csv.reader(file)
recent_models = []
for row in reader:
recent_models.append({
"timestamp": row[0],
"id": row[1],
"group": "recent",
"provider": row[1].split("-")[0],
"name": "-".join(row[1].split("-")[1:]).replace("___", "/")
})
recent_models.sort(key=lambda x: x["timestamp"], reverse=True)
# Deduplicate models with the same id
seen_ids = set()
unique_recent_models = []
for model in recent_models:
if model["id"] not in seen_ids:
unique_recent_models.append(model)
seen_ids.add(model["id"])
recent_models = unique_recent_models[:5]
return jsonify(recent_models)

@models_bp.route('/custom-models', methods=['GET'])
def get_custom_models():
custom_models_path = os.path.join(DATA_DIR, "custom_models.json")
if not os.path.exists(custom_models_path):
return jsonify([])
with open(custom_models_path, 'r', encoding='utf-8') as file:
custom_models = json.load(file)
return jsonify(custom_models)

@models_write_bp.route('/custom-models', methods=['POST'])
def add_custom_model():
data = request.json
custom_models_path = os.path.join(DATA_DIR, "custom_models.json")

# Read existing models
existing_models = []
if os.path.exists(custom_models_path):
with open(custom_models_path, 'r', encoding='utf-8') as file:
existing_models = json.load(file)

# Add new model
data["id"] = "custom-" + data["name"]
existing_models.append(data)

# Write updated models
with open(custom_models_path, 'w', encoding='utf-8') as file:
json.dump(existing_models, file)

return jsonify(existing_models)

@models_write_bp.route('/custom-models/<model_id>', methods=['DELETE'])
def delete_custom_model(model_id):
custom_models_path = os.path.join(DATA_DIR, "custom_models.json")
if not os.path.exists(custom_models_path):
return jsonify([])
with open(custom_models_path, 'r', encoding='utf-8') as file:
custom_models = json.load(file)
custom_models = [model for model in custom_models if model["id"] != model_id]
with open(custom_models_path, 'w', encoding='utf-8') as file:
json.dump(custom_models, file)
return jsonify(custom_models)
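This blueprint is mounted at /api/models in app.py above, and the write routes are only registered when the server is not read-only. An illustrative way to exercise the new endpoints from Python, not part of the diff; the host and port are assumptions, adjust to your local server:

# Illustrative client calls; the base URL is an assumption for a local server.
import requests

BASE = "http://localhost:5001/api/models"

# Register a custom OpenAI-compatible model; the server assigns id "custom-<name>".
requests.post(f"{BASE}/custom-models", json={
    "name": "llama3.2",
    "url": "http://localhost:8080/v1",
    "provider": "custom",
    "params": {},
})

# List registered custom models.
print(requests.get(f"{BASE}/custom-models").json())

# Delete one by id.
requests.delete(f"{BASE}/custom-models/custom-llama3.2")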
2 changes: 2 additions & 0 deletions setup.py
@@ -93,6 +93,8 @@ def run(self):
'ls-export-plot=latentscope.scripts.export_plot:main',
'ls-update-embedding-stats=latentscope.scripts.embed:update_embedding_stats',
'ls-sae=latentscope.scripts.sae:main',
'ls-download-dataset=latentscope.scripts.download_dataset:main',
'ls-upload-dataset=latentscope.scripts.upload_dataset:main',
],
},
include_package_data=True,
107 changes: 107 additions & 0 deletions web/src/components/CustomModels.jsx
@@ -0,0 +1,107 @@
import { useState, useEffect, useCallback } from 'react';
import { Button } from 'react-element-forge';
import { apiService } from '../lib/apiService';

import styles from './CustomModels.module.scss';

function CustomModels({ data_dir }) {
const [customModels, setCustomModels] = useState([]);
const [isSubmitting, setIsSubmitting] = useState(false);

// Fetch existing custom models on component mount
useEffect(() => {
apiService.fetchCustomModels().then((models) => setCustomModels(models));
}, []);

const handleAddModel = useCallback(async (e) => {
e.preventDefault();
setIsSubmitting(true);

const form = e.target;
const data = new FormData(form);
const newModel = {
name: data.get('name'),
url: data.get('url'),
params: {},
provider: 'custom',
};

try {
const updatedModels = await apiService.addCustomModel(newModel);
console.log('updatedModels', updatedModels);
setCustomModels(updatedModels);
form.reset();
} catch (error) {
console.error('Failed to add custom model:', error);
} finally {
setIsSubmitting(false);
}
}, []);

const handleDeleteModel = useCallback(
async (modelId) => {
try {
await apiService.deleteCustomModel(modelId);
setCustomModels(customModels.filter((model) => model.id !== modelId));
} catch (error) {
console.error('Failed to delete custom model:', error);
}
},
[customModels]
);

return (
<div className={styles['custom-models']}>
<div className={styles['custom-models-setup']}>
<div className={styles['custom-models-form']}>
<form onSubmit={handleAddModel}>
<label>
<span className={styles['form-label']}>Model: </span>
<input
type="text"
name="name"
placeholder="llama3.2"
required
disabled={isSubmitting}
/>
</label>

<label>
<span className={styles['form-label']}>URL: </span>
<input
type="url"
name="url"
placeholder="http://localhost:8080/v1"
required
disabled={isSubmitting}
/>
</label>

<Button type="submit" color="primary" disabled={isSubmitting} text="Add Model" />
</form>
</div>

<div className={styles['custom-models-list']}>
{customModels.map((model, index) => (
<div className={styles['item']} key={index}>
<div className={styles['item-info']}>
<span className={styles['item-name']}>{model.name}</span>
<span className={styles['item-url']}>{model.url}</span>
</div>

<Button
className={styles['delete']}
color="secondary"
onClick={() => handleDeleteModel(model.id)}
disabled={isSubmitting}
text="🗑️"
/>
</div>
))}
</div>
</div>
</div>
);
}

export default CustomModels;