Skip to content

Commit

Permalink
Merge pull request #20 from homanp/finetune-replicate
Browse files Browse the repository at this point in the history
Add possibility to fine-tune `Replicate` models.
  • Loading branch information
homanp authored Oct 11, 2023
2 parents 6292536 + 949a524 commit 9d3ccb3
Show file tree
Hide file tree
Showing 14 changed files with 249 additions and 27 deletions.
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ DATABASE_URL=
DATABASE_MIGRATION_URL=
OPENAI_API_KEY=
HF_API_KEY=
PINECONE_API_KEY=
PINECONE_API_KEY=
REPLICATE_API_TOKEN=
19 changes: 15 additions & 4 deletions lib/api/ingest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import asyncio

from fastapi import APIRouter
from service.flows import create_embeddings, create_finetune

from lib.models.ingest import IngestRequest
from lib.service.flows import create_finetune
from lib.utils.prisma import prisma
from prisma.models import Datasource

router = APIRouter()

Expand All @@ -14,7 +17,15 @@
)
async def ingest(body: IngestRequest):
    """Endpoint for ingesting data.

    Persists the datasource record, then launches the fine-tuning flow as a
    background task so the HTTP response returns immediately.
    """
    datasource = await prisma.datasource.create(data=body.dict())

    async def run_training_flow(datasource: Datasource):
        # Thin wrapper so the flow can run as a fire-and-forget task.
        # (The original wrapped this in `try/except ...: raise`, which is a no-op.)
        await create_finetune(datasource=datasource)

    # NOTE(review): the task object is not stored anywhere; asyncio keeps only
    # a weak reference to running tasks, so this can be garbage-collected
    # mid-flight — consider holding a reference until completion.
    asyncio.create_task(run_training_flow(datasource=datasource))
    return {"success": True, "data": datasource}
9 changes: 8 additions & 1 deletion lib/models/ingest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from typing import Optional

from pydantic import BaseModel


class IngestRequest(BaseModel):
    """Payload accepted by the /ingest endpoint.

    The merged diff left a duplicate ``webhook_url: str`` field; the
    committed schema declares it once, as optional.
    """

    # Datasource type discriminator (semantics defined by the ingest flow).
    type: str
    # Key into the provider's model table, e.g. "GPT_35_TURBO" or "LLAMA2_7B_CHAT".
    base_model: str
    # Fine-tuning backend, e.g. "OPENAI" or "REPLICATE".
    provider: str
    # Exactly how url/content interact is decided downstream — presumably one
    # of the two supplies the source data; confirm against the ingest service.
    url: Optional[str]
    content: Optional[str]
    webhook_url: Optional[str]
2 changes: 1 addition & 1 deletion lib/service/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def generate_chunks(
async def generate_embeddings(
self, nodes: List[Union[Document, None]]
) -> List[ndarray]:
vectordb = get_vector_service(
vectordb = await get_vector_service(
provider="pinecone",
index_name="all-minilm-l6-v2",
namespace=self.datasource.id,
Expand Down
134 changes: 127 additions & 7 deletions lib/service/finetune.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,56 @@
# flake8: noqa

import asyncio
import json
import os
import uuid
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Union

import httpx
import openai
import replicate
from decouple import config
from llama_index import Document
from numpy import ndarray

from lib.service.prompts import GPT_DATA_FORMAT, generate_qa_pair_prompt
from lib.service.prompts import (
GPT_DATA_FORMAT,
REPLICATE_FORMAT,
generate_qa_pair_prompt,
)

# Configure the OpenAI client from the environment at import time.
openai.api_key = config("OPENAI_API_KEY")

# Base models trainable on Replicate, keyed by the enum-style name stored on
# the Datasource record. Values are pinned "owner/model:version" identifiers
# in the form replicate.trainings.create() expects.
REPLICATE_MODELS = {
    "LLAMA2_7B_CHAT": "meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e",
    "LLAMA2_7B": "meta/llama-2-7b:527827021d8756c7ab79fde0abbfaac885c37a3ed5fe23c7465093f0878d55ef",
    "LLAMA2_13B_CHAT": "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
    "LLAMA2_13B": "meta/llama-2-13b:078d7a002387bd96d93b0302a4c03b3f15824b63104034bfa943c63a8f208c38",
    "LLAMA2_70B_CHAT": "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3",
    "LLAMA2_70B": "meta/llama-2-70b:a52e56fee2269a78c9279800ec88898cecb6c8f1df22a6483132bea266648f00",
    "GPT_J_6B": "replicate/gpt-j-6b:b3546aeec6c9891f0dd9929c2d3bedbf013c12e02e7dd0346af09c37e008c827",
    "DOLLY_V2_12B": "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5",
}

# Fine-tunable OpenAI models, keyed the same way.
OPENAI_MODELS = {"GPT_35_TURBO": "gpt-3.5-turbo"}


class FinetuningService(ABC):
    """Abstract interface for provider-specific fine-tuning backends.

    A concrete service turns a list of document nodes into a JSONL training
    file, validates/sanitizes it, and launches a fine-tuning job.
    """

    def __init__(self, nodes: List[Union[Document, None]]):
        # Document chunks to derive training examples from.
        self.nodes = nodes

    @abstractmethod
    async def generate_dataset(self) -> str:
        """Create a JSONL training file from ``self.nodes``; return its path."""

    @abstractmethod
    async def validate_dataset(self, training_file: str) -> str:
        """Sanitize the training file in place; return its path."""

    @abstractmethod
    async def finetune(self, training_file: str) -> Dict:
        """Start a fine-tuning job; return provider metadata for the job.

        Signature aligned with the concrete implementations and the caller,
        which pass only ``training_file`` (the old declaration also took
        ``base_model``, which no implementation accepted).
        """

    async def cleanup(self, training_file: str) -> None:
        """Delete the temporary local training file."""
        os.remove(training_file)

Expand All @@ -37,10 +61,12 @@ def __init__(
nodes: List[Union[Document, None]],
num_questions_per_chunk: int = 10,
batch_size: int = 10,
base_model: str = "GPT_35_TURBO",
):
super().__init__(nodes=nodes)
self.num_questions_per_chunk = num_questions_per_chunk
self.batch_size = batch_size
self.base_model = base_model

async def generate_prompt_and_completion(self, node):
prompt = generate_qa_pair_prompt(
Expand Down Expand Up @@ -68,23 +94,95 @@ async def generate_dataset(self) -> str:
json_objects = qa_pair.split("\n\n")
for json_obj in json_objects:
f.write(json_obj + "\n")
return training_file

async def validate_dataset(self, training_file: str) -> str:
    """No-op validation for OpenAI datasets; returns the path unchanged.

    Must return the file path rather than ``None`` (the original ``pass``):
    the caller feeds the result of this method directly into ``finetune()``.
    """
    return training_file

async def finetune(self, training_file: str) -> Dict:
    """Upload the JSONL dataset to OpenAI and start a fine-tuning job.

    Returns the job payload merged with the local ``training_file`` path so
    the caller can clean the file up afterwards.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original leaked the handle opened inline).
    with open(training_file, "rb") as f:
        file = openai.File.create(file=f, purpose="fine-tune")
    finetune = await openai.FineTuningJob.acreate(
        training_file=file.get("id"), model=OPENAI_MODELS[self.base_model]
    )
    return {**finetune, "training_file": training_file}


class ReplicateFinetuningService(FinetuningService):
    """Fine-tunes a Replicate-hosted base model (LLaMA-2 family, GPT-J, ...).

    GPT-3.5 generates ``{"prompt": ..., "completion": ...}`` QA pairs from
    the ingested document nodes; the resulting JSONL file is uploaded to
    Replicate and a training job is started.
    """

    def __init__(
        self,
        nodes: List[Union[Document, None]],
        num_questions_per_chunk: int = 10,
        batch_size: int = 10,
        base_model: str = "LLAMA2_7B_CHAT",
    ):
        super().__init__(nodes=nodes)
        # QA pairs requested per document chunk.
        self.num_questions_per_chunk = num_questions_per_chunk
        # How many chunks are processed concurrently per gather() round.
        self.batch_size = batch_size
        # Key into REPLICATE_MODELS selecting the pinned model version.
        self.base_model = base_model

    async def generate_prompt_and_completion(self, node):
        """Ask GPT-3.5 for QA pairs covering one document node."""
        prompt = generate_qa_pair_prompt(
            context=node.text,
            # Honor the configured count (was hard-coded to 10, silently
            # ignoring num_questions_per_chunk).
            num_of_qa_paris=self.num_questions_per_chunk,
            format=REPLICATE_FORMAT,
        )
        completion = await openai.ChatCompletion.acreate(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        return completion.choices[0].message.content

    async def generate_dataset(self) -> str:
        """Write QA pairs for all nodes to a fresh JSONL file; return its path."""
        training_file = f"{uuid.uuid4()}.jsonl"
        with open(training_file, "w") as f:
            # Process nodes in chunks of batch_size to bound concurrent API calls.
            for i in range(0, len(self.nodes), self.batch_size):
                tasks = [
                    self.generate_prompt_and_completion(node)
                    for node in self.nodes[i : i + self.batch_size]
                ]
                qa_pairs = await asyncio.gather(*tasks)
                for qa_pair in qa_pairs:
                    # The model emits one JSON object per blank-line-separated chunk.
                    for json_obj in qa_pair.split("\n\n"):
                        f.write(json_obj + "\n")
        return training_file

    async def validate_dataset(self, training_file: str) -> str:
        """Keep only well-formed prompt/completion lines; rewrite the file in place."""
        valid_data = []
        with open(training_file, "r") as f:
            for line in f:
                # Skip lines that are not valid JSON instead of aborting the
                # whole run — the original raised on the model's occasional
                # malformed output, defeating the purpose of validation.
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if "prompt" in data and "completion" in data:
                    valid_data.append(data)
        with open(training_file, "w") as f:
            for data in valid_data:
                f.write(json.dumps(data) + "\n")
        return training_file

    async def finetune(self, training_file: str) -> Dict:
        """Upload the dataset to Replicate and start a training job."""
        training_file_url = await upload_replicate_dataset(training_file=training_file)
        training = replicate.trainings.create(
            version=REPLICATE_MODELS[self.base_model],
            input={
                "train_data": training_file_url,
                "num_train_epochs": 6,
            },
            # NOTE(review): hard-coded destination model — should come from
            # configuration before this ships beyond testing.
            destination="homanp/test",
        )
        return {"id": training.id, "training_file": training_file}


async def get_finetuning_service(
nodes: List[Union[Document, None]],
provider: str = "openai",
base_model: str = "GPT_35_TURBO",
num_questions_per_chunk: int = 10,
batch_size: int = 10,
):
services = {
"openai": OpenAIFinetuningService,
"OPENAI": OpenAIFinetuningService,
"REPLICATE": ReplicateFinetuningService,
# Add other providers here
}
service = services.get(provider)
Expand All @@ -94,4 +192,26 @@ async def get_finetuning_service(
nodes=nodes,
num_questions_per_chunk=num_questions_per_chunk,
batch_size=batch_size,
base_model=base_model,
)


async def upload_replicate_dataset(training_file: str) -> str:
    """Push a local JSONL training file to Replicate's upload endpoint.

    Requests a pre-signed upload slot, PUTs the file bytes to it, and
    returns the serving URL under which the uploaded data is reachable.
    """
    auth_header = {"Authorization": f"Token {config('REPLICATE_API_TOKEN')}"}
    async with httpx.AsyncClient() as client:
        # Ask the (experimental) dreambooth API for an upload slot.
        slot_response = await client.post(
            "https://dreambooth-api-experimental.replicate.com/v1/upload/data.jsonl",
            headers=auth_header,
        )
        slot = slot_response.json()

        with open(training_file, "rb") as fh:
            payload = fh.read()
        await client.put(
            slot["upload_url"],
            headers={"Content-Type": "application/jsonl"},
            content=payload,
        )

    return slot["serving_url"]
26 changes: 15 additions & 11 deletions lib/service/flows.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import List, Union

import openai
Expand Down Expand Up @@ -27,24 +28,27 @@ async def create_finetuned_model(datasource: Datasource):
documents = await embedding_service.generate_documents()
nodes = await embedding_service.generate_chunks(documents=documents)
finetunning_service = await get_finetuning_service(
nodes=nodes, provider="openai", batch_size=5
nodes=nodes,
provider=datasource.provider,
batch_size=5,
base_model=datasource.base_model,
)
await finetunning_service.generate_dataset()
finetune_job = await finetunning_service.finetune()
finetune = await openai.FineTune.retrieve(id=finetune_job.id)
await finetunning_service.cleanup(training_file=finetune_job.get("training_file"))
training_file = await finetunning_service.generate_dataset()
formatted_training_file = await finetunning_service.validate_dataset(
training_file=training_file
)
finetune = await finetunning_service.finetune(training_file=formatted_training_file)
if datasource.provider == "OPENAI":
finetune = await openai.FineTune.retrieve(id=finetune.get("id"))
await finetunning_service.cleanup(training_file=finetune.get("training_file"))
return finetune


@flow(name="create_embeddings", description="Create embeddings", retries=0)
async def create_embeddings(datasource: Datasource):
    """Flow wrapper: generate and store vector embeddings for *datasource*."""
    await create_vector_embeddings(datasource=datasource)


@flow(name="create_finetune", description="Create a finetune", retries=0)
async def create_finetune(datasource: Datasource):
    """Full fine-tuning pipeline for one datasource.

    Builds vector embeddings, runs the provider-specific fine-tune, and
    persists the resulting job metadata on the Datasource row. The merged
    diff showed duplicate ``data=`` keyword lines; the committed version
    serializes the metadata with ``json.dumps``.
    """
    await create_vector_embeddings(datasource=datasource)
    finetune = await create_finetuned_model(datasource=datasource)
    await prisma.datasource.update(
        where={"id": datasource.id},
        # Serialize the job metadata dict to a JSON string for storage.
        data={"finetune": json.dumps(finetune)},
    )
9 changes: 8 additions & 1 deletion lib/service/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@
'"messages": ['
'{"role": "system", "content": "You are an AI agent that\'s an expert at answering questions."}, '
'{"role": "user", "content": "What\'s the capital of France?"}, '
'{"role": "assistant", "content": "Paris, as if everyone doesn\'t know that already."}'
'{"role": "assistant", "content": "Paris, is the capital of France."}'
"]"
"}"
)

# Example output format injected into the QA-generation prompt for Replicate
# fine-tunes: one flat JSON object per pair, matching what
# ReplicateFinetuningService.validate_dataset() accepts ("prompt"/"completion").
REPLICATE_FORMAT = (
    "{"
    '"prompt": "What\'s the capital of France?",'
    '"completion": "Paris, is the capital of France"'
    "}"
)


def generate_qa_pair_prompt(format: str, context: str, num_of_qa_paris: int = 10):
prompt = (
Expand Down
21 changes: 20 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- CreateEnum
-- Enumerates which base model a Datasource fine-tunes against.
CREATE TYPE "BaseModelType" AS ENUM ('GPT_35_TURBO', 'LLAMA2');

-- AlterTable
-- NOTE(review): the column is plain TEXT rather than the enum type created
-- above — confirm this is intentional.
ALTER TABLE "Datasource" ADD COLUMN "base_model" TEXT NOT NULL DEFAULT 'GPT_35_TURBO';
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
  Warnings:
  - The values [LLAMA2] on the enum `BaseModelType` will be removed. If these variants are still used in the database, this will fail.
*/
-- AlterEnum
-- PostgreSQL cannot drop an enum value in place, so the type is rebuilt:
-- create the replacement, swap the names, drop the old type — one transaction.
BEGIN;
CREATE TYPE "BaseModelType_new" AS ENUM ('GPT_35_TURBO', 'LLAMA2_7B', 'LLAMA2_7B_CHAT', 'LLAMA2_13B', 'LLAMA2_13B_CHAT', 'LLAMA2_70B', 'LLAMA2_70B_CHAT');
ALTER TYPE "BaseModelType" RENAME TO "BaseModelType_old";
ALTER TYPE "BaseModelType_new" RENAME TO "BaseModelType";
DROP TYPE "BaseModelType_old";
COMMIT;
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- CreateEnum
-- Enumerates the supported fine-tuning providers.
CREATE TYPE "ModelProviders" AS ENUM ('OPENAI', 'REPLICATE');

-- AlterTable
-- NOTE(review): the column is plain TEXT rather than "ModelProviders" —
-- confirm this is intentional.
ALTER TABLE "Datasource" ADD COLUMN "provider" TEXT NOT NULL DEFAULT 'OPENAI';
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- AlterEnum
-- This migration adds more than one value to an enum.
-- With PostgreSQL versions 11 and earlier, this is not possible
-- in a single migration. This can be worked around by creating
-- multiple migrations, each migration adding only one value to
-- the enum.


-- Register the two additional Replicate-trainable base models.
ALTER TYPE "BaseModelType" ADD VALUE 'GPT_J_6B';
ALTER TYPE "BaseModelType" ADD VALUE 'DOLLY_V2_12B';
Loading

0 comments on commit 9d3ccb3

Please sign in to comment.