From 890df9d7f44367b329cb6a56873840bb91f79214 Mon Sep 17 00:00:00 2001
From: Ismail Pelaseyed
Date: Mon, 9 Oct 2023 15:52:45 +0200
Subject: [PATCH 1/4] WIP

---
 .../20231007205348_datasource_base_model/migration.sql | 5 +++++
 prisma/schema.prisma                                   | 8 +++++++-
 2 files changed, 12 insertions(+), 1 deletion(-)
 create mode 100644 prisma/migrations/20231007205348_datasource_base_model/migration.sql

diff --git a/prisma/migrations/20231007205348_datasource_base_model/migration.sql b/prisma/migrations/20231007205348_datasource_base_model/migration.sql
new file mode 100644
index 0000000..bb2cb6c
--- /dev/null
+++ b/prisma/migrations/20231007205348_datasource_base_model/migration.sql
@@ -0,0 +1,5 @@
+-- CreateEnum
+CREATE TYPE "BaseModelType" AS ENUM ('GPT_35_TURBO', 'LLAMA2');
+
+-- AlterTable
+ALTER TABLE "Datasource" ADD COLUMN "base_model" TEXT NOT NULL DEFAULT 'GPT_35_TURBO';
diff --git a/prisma/schema.prisma b/prisma/schema.prisma
index efa399c..af61954 100644
--- a/prisma/schema.prisma
+++ b/prisma/schema.prisma
@@ -5,7 +5,7 @@ generator client {
 datasource db {
   provider          = "postgresql"
-  url               = env("DATABASE_URL")
+  url               = env("DATABASE_MIGRATION_URL")
   shadowDatabaseUrl = env("DATABASE_SHADOW_URL")
 }
 
@@ -16,6 +16,11 @@ enum DatasourceType {
   MARKDOWN
 }
 
+enum BaseModelType {
+  GPT_35_TURBO
+  LLAMA2
+}
+
 enum DatasourceStatus {
   IN_PROGRESS
   DONE
@@ -24,6 +29,7 @@ enum DatasourceStatus {
 model Datasource {
   id         String           @id @default(uuid())
+  base_model String           @default("GPT_35_TURBO")
   content    String?          @db.Text()
   status     DatasourceStatus @default(IN_PROGRESS)
   type       DatasourceType

From 69e54480a68b09f4b70bdd4ddd5d2ddf864615ee Mon Sep 17 00:00:00 2001
From: Ismail Pelaseyed
Date: Tue, 10 Oct 2023 22:41:27 +0200
Subject: [PATCH 2/4] WIP

---
 lib/api/ingest.py        | 28 ++++++++++++++++++++++------
 lib/models/ingest.py     |  6 +++++-
 lib/service/embedding.py |  2 +-
 lib/service/finetune.py  |  2 +-
 lib/service/flows.py     |  6 +++---
 prisma/schema.prisma     |  2 +-
 6 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/lib/api/ingest.py b/lib/api/ingest.py
index ece3f61..6fa3aa0 100644
--- a/lib/api/ingest.py
+++ b/lib/api/ingest.py
@@ -1,5 +1,8 @@
-from fastapi import APIRouter
-from service.flows import create_embeddings, create_finetune
+import asyncio
+
+from prisma.models import Datasource
+from fastapi import APIRouter, BackgroundTasks
+from lib.service.flows import create_embeddings, create_finetune
 from lib.models.ingest import IngestRequest
 from lib.utils.prisma import prisma
 
@@ -7,14 +10,27 @@
 router = APIRouter()
 
 
+async def run_embedding_flow(datasource: Datasource):
+    await create_embeddings(
+        datasource=datasource,
+    )
+
+
+async def run_finetune_flow(datasource: Datasource):
+    await create_finetune(
+        datasource=datasource,
+    )
+
+
 @router.post(
     "/ingest",
     name="ingest",
     description="Ingest data",
 )
-async def ingest(body: IngestRequest):
+async def ingest(body: IngestRequest, background_tasks: BackgroundTasks):
     """Endpoint for ingesting data"""
-    datasource = await prisma.datasource.create(data={**body})
-    await create_embeddings(datasource=datasource)
-    await create_finetune(datasource=datasource)
+    datasource = await prisma.datasource.create(data=body.dict())
+
+    background_tasks.add_task(run_embedding_flow, datasource=datasource)
+    background_tasks.add_task(run_finetune_flow, datasource=datasource)
     return {"success": True, "data": datasource}
diff --git a/lib/models/ingest.py b/lib/models/ingest.py
index d1fc22f..8ba02ae 100644
--- a/lib/models/ingest.py
+++ b/lib/models/ingest.py
@@ -1,5 +1,9 @@
 from pydantic import BaseModel
+from typing import Optional
 
 
 class IngestRequest(BaseModel):
-    webhook_url: str
+    type: str
+    url: Optional[str]
+    content: Optional[str]
+    webhook_url: Optional[str]
diff --git a/lib/service/embedding.py b/lib/service/embedding.py
index 3f36259..a0c0160 100644
--- a/lib/service/embedding.py
+++ b/lib/service/embedding.py
@@ -47,7 +47,7 @@ async def generate_chunks(
     async def generate_embeddings(
         self, nodes: List[Union[Document, None]]
     ) -> List[ndarray]:
-        vectordb = get_vector_service(
+        vectordb = await get_vector_service(
             provider="pinecone",
             index_name="all-minilm-l6-v2",
             namespace=self.datasource.id,
diff --git a/lib/service/finetune.py b/lib/service/finetune.py
index eb67f66..8d3d882 100644
--- a/lib/service/finetune.py
+++ b/lib/service/finetune.py
@@ -26,7 +26,6 @@ async def generate_dataset(self) -> List[Tuple[str, ndarray]]:
     async def finetune(self, training_file: str) -> Dict:
         pass
 
-    @abstractmethod
     async def cleanup(self, training_file: str) -> None:
         os.remove(training_file)
 
@@ -68,6 +67,7 @@ async def generate_dataset(self) -> str:
                     json_objects = qa_pair.split("\n\n")
                     for json_obj in json_objects:
                         f.write(json_obj + "\n")
+        return training_file
 
     async def finetune(self, training_file: str) -> Dict:
         file = openai.File.create(file=open(training_file, "rb"), purpose="fine-tune")
diff --git a/lib/service/flows.py b/lib/service/flows.py
index 4c1c4f4..f625c9e 100644
--- a/lib/service/flows.py
+++ b/lib/service/flows.py
@@ -29,9 +29,9 @@ async def create_finetuned_model(datasource: Datasource):
     finetunning_service = await get_finetuning_service(
         nodes=nodes, provider="openai", batch_size=5
     )
-    await finetunning_service.generate_dataset()
-    finetune_job = await finetunning_service.finetune()
-    finetune = await openai.FineTune.retrieve(id=finetune_job.id)
+    training_file = await finetunning_service.generate_dataset()
+    finetune_job = await finetunning_service.finetune(training_file=training_file)
+    finetune = await openai.FineTune.retrieve(id=finetune_job.get("id"))
     await finetunning_service.cleanup(training_file=finetune_job.get("training_file"))
     return finetune
diff --git a/prisma/schema.prisma b/prisma/schema.prisma
index af61954..a344050 100644
--- a/prisma/schema.prisma
+++ b/prisma/schema.prisma
@@ -5,7 +5,7 @@ generator client {
 datasource db {
   provider          = "postgresql"
-  url               = env("DATABASE_MIGRATION_URL")
+  url               = env("DATABASE_URL")
   shadowDatabaseUrl = env("DATABASE_SHADOW_URL")
 }

From 3545fd55c9c971950441a666b211b1a6a8410d35 Mon Sep 17 00:00:00 2001
From: Ismail Pelaseyed
Date: Wed, 11 Oct 2023 10:03:44 +0200
Subject: [PATCH 3/4] Add possibility to create fine-tunes on Replicate

---
 .env.example            |   3 +-
 lib/api/ingest.py       |  25 ++--
 lib/models/ingest.py    |   2 +
 lib/service/finetune.py | 137 ++++++++++++++++--
 lib/service/flows.py    |  28 ++--
 lib/service/prompts.py  |   9 +-
 poetry.lock             |  21 ++-
 .../migration.sql       |  13 ++
 .../migration.sql       |   5 +
 .../migration.sql       |  10 ++
 prisma/schema.prisma    |  15 +-
 pyproject.toml          |   1 +
 12 files changed, 229 insertions(+), 40 deletions(-)
 create mode 100644 prisma/migrations/20231010204838_add_base_model_types/migration.sql
 create mode 100644 prisma/migrations/20231011062332_add_datasource_provider/migration.sql
 create mode 100644 prisma/migrations/20231011075903_add_dolly_gptj_base_models/migration.sql

diff --git a/.env.example b/.env.example
index 940959b..9a27584 100644
--- a/.env.example
+++ b/.env.example
@@ -2,4 +2,5 @@ DATABASE_URL=
 DATABASE_MIGRATION_URL=
 OPENAI_API_KEY=
 HF_API_KEY=
-PINECONE_API_KEY=
\ No newline at end of file
+PINECONE_API_KEY=
+REPLICATE_API_TOKEN=
\ No newline at end of file
diff --git a/lib/api/ingest.py b/lib/api/ingest.py
index 6fa3aa0..b1ba744 100644
--- a/lib/api/ingest.py
+++ b/lib/api/ingest.py
@@ -2,7 +2,7 @@
 
 from prisma.models import Datasource
 from fastapi import APIRouter, BackgroundTasks
-from lib.service.flows import create_embeddings, create_finetune
+from lib.service.flows import create_finetune
 from lib.models.ingest import IngestRequest
 from lib.utils.prisma import prisma
 
@@ -10,18 +10,6 @@
 router = APIRouter()
 
 
-async def run_embedding_flow(datasource: Datasource):
-    await create_embeddings(
-        datasource=datasource,
-    )
-
-
-async def run_finetune_flow(datasource: Datasource):
-    await create_finetune(
-        datasource=datasource,
-    )
-
-
 @router.post(
     "/ingest",
     name="ingest",
@@ -31,6 +19,13 @@ async def ingest(body: IngestRequest, background_tasks: BackgroundTasks):
     """Endpoint for ingesting data"""
     datasource = await prisma.datasource.create(data=body.dict())
 
-    background_tasks.add_task(run_embedding_flow, datasource=datasource)
-    background_tasks.add_task(run_finetune_flow, datasource=datasource)
+    async def run_training_flow(datasource: Datasource):
+        try:
+            await create_finetune(
+                datasource=datasource,
+            )
+        except Exception as flow_exception:
+            raise flow_exception
+
+    asyncio.create_task(run_training_flow(datasource=datasource))
     return {"success": True, "data": datasource}
diff --git a/lib/models/ingest.py b/lib/models/ingest.py
index 8ba02ae..59d68c3 100644
--- a/lib/models/ingest.py
+++ b/lib/models/ingest.py
@@ -4,6 +4,8 @@
 
 class IngestRequest(BaseModel):
     type: str
+    base_model: str
+    provider: str
     url: Optional[str]
     content: Optional[str]
     webhook_url: Optional[str]
diff --git a/lib/service/finetune.py b/lib/service/finetune.py
index 8d3d882..6b9a1c4 100644
--- a/lib/service/finetune.py
+++ b/lib/service/finetune.py
@@ -1,29 +1,52 @@
 import asyncio
+import httpx
 import os
 import uuid
-from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple, Union
-
 import openai
+import replicate
+import json
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, Union
 from decouple import config
 from llama_index import Document
-from numpy import ndarray
 
-from lib.service.prompts import GPT_DATA_FORMAT, generate_qa_pair_prompt
+from lib.service.prompts import (
+    GPT_DATA_FORMAT,
+    REPLICATE_FORMAT,
+    generate_qa_pair_prompt,
+)
 
 openai.api_key = config("OPENAI_API_KEY")
 
+REPLICATE_MODELS = {
+    "LLAMA2_7B_CHAT": "meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e",
+    "LLAMA2_7B": "meta/llama-2-7b:527827021d8756c7ab79fde0abbfaac885c37a3ed5fe23c7465093f0878d55ef",
+    "LLAMA2_13B_CHAT": "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
+    "LLAMA2_13B": "meta/llama-2-13b:078d7a002387bd96d93b0302a4c03b3f15824b63104034bfa943c63a8f208c38",
+    "LLAMA2_70B_CHAT": "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3",
+    "LLAMA2_70B": "meta/llama-2-70b:a52e56fee2269a78c9279800ec88898cecb6c8f1df22a6483132bea266648f00",
+    "GPT_J_6B": "replicate/gpt-j-6b:b3546aeec6c9891f0dd9929c2d3bedbf013c12e02e7dd0346af09c37e008c827",
+    "DOLLY_V2_12B": "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5",
+}
+
+OPENAI_MODELS = {"GPT_35_TURBO": "gpt-3.5-turbo"}
+
 
 class FinetuningService(ABC):
     def __init__(self, nodes: List[Union[Document, None]]):
         self.nodes = nodes
 
     @abstractmethod
-    async def generate_dataset(self) -> List[Tuple[str, ndarray]]:
+    async def generate_dataset(self) -> str:
         pass
 
     @abstractmethod
-    async def finetune(self, training_file: str) -> Dict:
+    async def validate_dataset(self, training_file: str) -> str:
+        pass
+
+    @abstractmethod
+    async def finetune(self, training_file: str, base_model: str) -> Dict:
         pass
 
     async def cleanup(self, training_file: str) -> None:
@@ -36,10 +59,12 @@ def __init__(
         nodes: List[Union[Document, None]],
         num_questions_per_chunk: int = 10,
         batch_size: int = 10,
+        base_model: str = "GPT_35_TURBO",
     ):
         super().__init__(nodes=nodes)
         self.num_questions_per_chunk = num_questions_per_chunk
         self.batch_size = batch_size
+        self.base_model = base_model
 
     async def generate_prompt_and_completion(self, node):
         prompt = generate_qa_pair_prompt(
@@ -69,22 +94,93 @@ async def generate_dataset(self) -> str:
                         f.write(json_obj + "\n")
         return training_file
 
+    async def validate_dataset(self, training_file: str) -> str:
+        pass
+
     async def finetune(self, training_file: str) -> Dict:
         file = openai.File.create(file=open(training_file, "rb"), purpose="fine-tune")
         finetune = await openai.FineTuningJob.acreate(
-            training_file=file.get("id"), model="gpt-3.5-turbo"
+            training_file=file.get("id"), model=OPENAI_MODELS[self.base_model]
         )
         return {**finetune, "training_file": training_file}
 
 
+class ReplicateFinetuningService(FinetuningService):
+    def __init__(
+        self,
+        nodes: List[Union[Document, None]],
+        num_questions_per_chunk: int = 1,
+        batch_size: int = 10,
+        base_model: str = "LLAMA2_7B_CHAT",
+    ):
+        super().__init__(nodes=nodes)
+        self.num_questions_per_chunk = num_questions_per_chunk
+        self.batch_size = batch_size
+        self.base_model = base_model
+
+    async def generate_prompt_and_completion(self, node):
+        prompt = generate_qa_pair_prompt(
+            context=node.text, num_of_qa_paris=10, format=REPLICATE_FORMAT
+        )
+        completion = await openai.ChatCompletion.acreate(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": prompt}],
+            temperature=0,
+        )
+        return completion.choices[0].message.content
+
+    async def generate_dataset(self) -> str:
+        training_file = f"{uuid.uuid4()}.jsonl"
+        with open(training_file, "w") as f:
+            for i in range(
+                0, len(self.nodes), self.batch_size
+            ):  # Process nodes in chunks of batch_size
+                tasks = [
+                    self.generate_prompt_and_completion(node)
+                    for node in self.nodes[i : i + self.batch_size]
+                ]
+                qa_pairs = await asyncio.gather(*tasks)
+                for qa_pair in qa_pairs:
+                    json_objects = qa_pair.split("\n\n")
+                    for json_obj in json_objects:
+                        f.write(json_obj + "\n")
+        return training_file
+
+    async def validate_dataset(self, training_file: str) -> str:
+        valid_data = []
+        with open(training_file, "r") as f:
+            for line in f:
+                data = json.loads(line)
+                if "prompt" in data and "completion" in data:
+                    valid_data.append(data)
+        with open(training_file, "w") as f:
+            for data in valid_data:
+                f.write(json.dumps(data) + "\n")
+        return training_file
+
+    async def finetune(self, training_file: str) -> Dict:
+        training_file_url = await upload_replicate_dataset(training_file=training_file)
+        training = replicate.trainings.create(
+            version=REPLICATE_MODELS[self.base_model],
+            input={
+                "train_data": training_file_url,
+                "num_train_epochs": 3,
+            },
+            destination="homanp/test",
+        )
+        return {"id": training.id, "training_file": training_file}
+
+
 async def get_finetuning_service(
     nodes: List[Union[Document, None]],
     provider: str = "openai",
+    base_model: str = "GPT_35_TURBO",
     num_questions_per_chunk: int = 10,
     batch_size: int = 10,
 ):
     services = {
-        "openai": OpenAIFinetuningService,
+        "OPENAI": OpenAIFinetuningService,
+        "REPLICATE": ReplicateFinetuningService,
         # Add other providers here
     }
     service = services.get(provider)
@@ -94,4 +190,27 @@ async def get_finetuning_service(
         nodes=nodes,
         num_questions_per_chunk=num_questions_per_chunk,
         batch_size=batch_size,
+        base_model=base_model,
     )
+
+
+async def upload_replicate_dataset(training_file: str) -> str:
+    headers = {"Authorization": f"Token {config('REPLICATE_API_TOKEN')}"}
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
+            "https://dreambooth-api-experimental.replicate.com/v1/upload/data.jsonl",
+            headers=headers,
+        )
+        response_data = response.json()
+        upload_url = response_data["upload_url"]
+
+        with open(training_file, "rb") as f:
+            await client.put(
+                upload_url,
+                headers={"Content-Type": "application/jsonl"},
+                content=f.read(),
+            )
+
+        serving_url = response_data["serving_url"]
+        print(serving_url)
+        return serving_url
diff --git a/lib/service/flows.py b/lib/service/flows.py
index f625c9e..c53d5cf 100644
--- a/lib/service/flows.py
+++ b/lib/service/flows.py
@@ -1,6 +1,7 @@
-from typing import List, Union
-
 import openai
+import json
+
+from typing import List, Union
 from llama_index import Document
 from prefect import flow, task
@@ -27,24 +28,27 @@ async def create_finetuned_model(datasource: Datasource):
     documents = await embedding_service.generate_documents()
     nodes = await embedding_service.generate_chunks(documents=documents)
     finetunning_service = await get_finetuning_service(
-        nodes=nodes, provider="openai", batch_size=5
+        nodes=nodes,
+        provider=datasource.provider,
+        batch_size=5,
+        base_model=datasource.base_model,
     )
     training_file = await finetunning_service.generate_dataset()
-    finetune_job = await finetunning_service.finetune(training_file=training_file)
-    finetune = await openai.FineTune.retrieve(id=finetune_job.get("id"))
-    await finetunning_service.cleanup(training_file=finetune_job.get("training_file"))
+    formatted_training_file = await finetunning_service.validate_dataset(
+        training_file=training_file
+    )
+    finetune = await finetunning_service.finetune(training_file=formatted_training_file)
+    if datasource.provider == "OPENAI":
+        finetune = await openai.FineTune.retrieve(id=finetune.get("id"))
+    await finetunning_service.cleanup(training_file=finetune.get("training_file"))
     return finetune
 
 
-@flow(name="create_embeddings", description="Create embeddings", retries=0)
-async def create_embeddings(datasource: Datasource):
-    await create_vector_embeddings(datasource=datasource)
-
-
 @flow(name="create_finetune", description="Create a finetune", retries=0)
 async def create_finetune(datasource: Datasource):
+    await create_vector_embeddings(datasource=datasource)
     finetune = await create_finetuned_model(datasource=datasource)
     await prisma.datasource.update(
         where={"id": datasource.id},
-        data={"finetune": finetune},
+        data={"finetune": json.dumps(finetune)},
     )
diff --git a/lib/service/prompts.py b/lib/service/prompts.py
index 223735c..f87c87d 100644
--- a/lib/service/prompts.py
+++ b/lib/service/prompts.py
@@ -5,11 +5,18 @@
     '"messages": ['
     '{"role": "system", "content": "You are an AI agent that\'s an expert at answering questions."}, '
     '{"role": "user", "content": "What\'s the capital of France?"}, '
-    '{"role": "assistant", "content": "Paris, as if everyone doesn\'t know that already."}'
+    '{"role": "assistant", "content": "Paris is the capital of France."}'
     "]"
     "}"
 )
+
+REPLICATE_FORMAT = (
+    "{"
+    '"prompt": "What\'s the capital of France?",'
+    '"completion": "Paris is the capital of France"'
+    "}"
+)
 
 
 def generate_qa_pair_prompt(format: str, context: str, num_of_qa_paris: int = 10):
     prompt = (
diff --git a/poetry.lock b/poetry.lock
index 4ee1fd9..bcc50f5 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3019,6 +3019,25 @@ files = [
     {file = "regex-2023.8.8.tar.gz", hash = "sha256:fcbdc5f2b0f1cd0f6a56cdb46fe41d2cce1e644e3b68832f3eeebc5fb0f7712e"},
 ]
 
+[[package]]
+name = "replicate"
+version = "0.15.4"
+description = "Python client for Replicate"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "replicate-0.15.4-py3-none-any.whl", hash = "sha256:082cc363357ba02da820ede147cc35677499d4ae67e554e48d4e8212122b3f14"},
+    {file = "replicate-0.15.4.tar.gz", hash = "sha256:8f3fd07685da42aa6de31f895129762322ef07352294c13af8080bf53f126c83"},
+]
+
+[package.dependencies]
+httpx = ">=0.21.0,<1"
+packaging = "*"
+pydantic = ">1"
+
+[package.extras]
+dev = ["black", "mypy", "pytest", "pytest-asyncio", "pytest-recording", "respx", "ruff"]
+
 [[package]]
 name = "requests"
 version = "2.31.0"
@@ -4729,4 +4748,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "0ad1dc0620c169bc927aa392d4cccc6f2e87be994f03b41a01dad5647c96e495"
+content-hash = "9201706ae52ca01c96b099e04d699ffd06e43de885386c6632ab24dcc8525e46"
diff --git a/prisma/migrations/20231010204838_add_base_model_types/migration.sql b/prisma/migrations/20231010204838_add_base_model_types/migration.sql
new file mode 100644
index 0000000..6effade
--- /dev/null
+++ b/prisma/migrations/20231010204838_add_base_model_types/migration.sql
@@ -0,0 +1,13 @@
+/*
+  Warnings:
+
+  - The values [LLAMA2] on the enum `BaseModelType` will be removed. If these variants are still used in the database, this will fail.
+
+*/
+-- AlterEnum
+BEGIN;
+CREATE TYPE "BaseModelType_new" AS ENUM ('GPT_35_TURBO', 'LLAMA2_7B', 'LLAMA2_7B_CHAT', 'LLAMA2_13B', 'LLAMA2_13B_CHAT', 'LLAMA2_70B', 'LLAMA2_70B_CHAT');
+ALTER TYPE "BaseModelType" RENAME TO "BaseModelType_old";
+ALTER TYPE "BaseModelType_new" RENAME TO "BaseModelType";
+DROP TYPE "BaseModelType_old";
+COMMIT;
diff --git a/prisma/migrations/20231011062332_add_datasource_provider/migration.sql b/prisma/migrations/20231011062332_add_datasource_provider/migration.sql
new file mode 100644
index 0000000..4a5137a
--- /dev/null
+++ b/prisma/migrations/20231011062332_add_datasource_provider/migration.sql
@@ -0,0 +1,5 @@
+-- CreateEnum
+CREATE TYPE "ModelProviders" AS ENUM ('OPENAI', 'REPLICATE');
+
+-- AlterTable
+ALTER TABLE "Datasource" ADD COLUMN "provider" TEXT NOT NULL DEFAULT 'OPENAI';
diff --git a/prisma/migrations/20231011075903_add_dolly_gptj_base_models/migration.sql b/prisma/migrations/20231011075903_add_dolly_gptj_base_models/migration.sql
new file mode 100644
index 0000000..613ee1c
--- /dev/null
+++ b/prisma/migrations/20231011075903_add_dolly_gptj_base_models/migration.sql
@@ -0,0 +1,10 @@
+-- AlterEnum
+-- This migration adds more than one value to an enum.
+-- With PostgreSQL versions 11 and earlier, this is not possible
+-- in a single migration. This can be worked around by creating
+-- multiple migrations, each migration adding only one value to
+-- the enum.
+
+
+ALTER TYPE "BaseModelType" ADD VALUE 'GPT_J_6B';
+ALTER TYPE "BaseModelType" ADD VALUE 'DOLLY_V2_12B';
diff --git a/prisma/schema.prisma b/prisma/schema.prisma
index a344050..fa8947f 100644
--- a/prisma/schema.prisma
+++ b/prisma/schema.prisma
@@ -18,7 +18,19 @@ enum DatasourceType {
 
 enum BaseModelType {
   GPT_35_TURBO
-  LLAMA2
+  LLAMA2_7B
+  LLAMA2_7B_CHAT
+  LLAMA2_13B
+  LLAMA2_13B_CHAT
+  LLAMA2_70B
+  LLAMA2_70B_CHAT
+  GPT_J_6B
+  DOLLY_V2_12B
+}
+
+enum ModelProviders {
+  OPENAI
+  REPLICATE
 }
 
 enum DatasourceStatus {
@@ -30,6 +42,7 @@ enum DatasourceStatus {
 model Datasource {
   id         String           @id @default(uuid())
   base_model String           @default("GPT_35_TURBO")
+  provider   String           @default("OPENAI")
   content    String?          @db.Text()
   status     DatasourceStatus @default(IN_PROGRESS)
   type       DatasourceType
diff --git a/pyproject.toml b/pyproject.toml
index f69bb7e..49bc10c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ llama-index = "^0.8.37"
 pypdf = "^3.16.2"
 tiktoken = "^0.5.1"
 sentence-transformers = "^2.2.2"
+replicate = "^0.15.4"
 
 [build-system]
 requires = ["poetry-core"]

From 949a5242e7f82d1873193fd7e9294b3e0b4a47d6 Mon Sep 17 00:00:00 2001
From: Ismail Pelaseyed
Date: Wed, 11 Oct 2023 10:21:47 +0200
Subject: [PATCH 4/4] Minor tweaks

---
 lib/api/ingest.py       |  8 ++++----
 lib/models/ingest.py    |  3 ++-
 lib/service/finetune.py | 17 +++++++++--------
 lib/service/flows.py    |  4 ++--
 4 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/lib/api/ingest.py b/lib/api/ingest.py
index b1ba744..c938e00 100644
--- a/lib/api/ingest.py
+++ b/lib/api/ingest.py
@@ -1,11 +1,11 @@
 import asyncio
 
-from prisma.models import Datasource
-from fastapi import APIRouter, BackgroundTasks
-from lib.service.flows import create_finetune
+from fastapi import APIRouter
 from lib.models.ingest import IngestRequest
+from lib.service.flows import create_finetune
 from lib.utils.prisma import prisma
+from prisma.models import Datasource
 
 router = APIRouter()
 
@@ -15,7 +15,7 @@
     name="ingest",
     description="Ingest data",
 )
-async def ingest(body: IngestRequest, background_tasks: BackgroundTasks):
+async def ingest(body: IngestRequest):
     """Endpoint for ingesting data"""
     datasource = await prisma.datasource.create(data=body.dict())
 
diff --git a/lib/models/ingest.py b/lib/models/ingest.py
index 59d68c3..55f90f7 100644
--- a/lib/models/ingest.py
+++ b/lib/models/ingest.py
@@ -1,6 +1,7 @@
-from pydantic import BaseModel
 from typing import Optional
+
+from pydantic import BaseModel
 
 
 class IngestRequest(BaseModel):
     type: str
diff --git a/lib/service/finetune.py b/lib/service/finetune.py
index 6b9a1c4..7ed10c8 100644
--- a/lib/service/finetune.py
+++ b/lib/service/finetune.py
@@ -1,13 +1,15 @@
+# flake8: noqa
+
 import asyncio
-import httpx
+import json
 import os
 import uuid
-import openai
-import replicate
-import json
-
 from abc import ABC, abstractmethod
 from typing import Dict, List, Union
+
+import httpx
+import openai
+import replicate
 from decouple import config
 from llama_index import Document
 
@@ -109,7 +111,7 @@ class ReplicateFinetuningService(FinetuningService):
     def __init__(
         self,
         nodes: List[Union[Document, None]],
-        num_questions_per_chunk: int = 1,
+        num_questions_per_chunk: int = 10,
         batch_size: int = 10,
         base_model: str = "LLAMA2_7B_CHAT",
     ):
@@ -164,7 +166,7 @@ async def finetune(self, training_file: str) -> Dict:
             version=REPLICATE_MODELS[self.base_model],
             input={
                 "train_data": training_file_url,
-                "num_train_epochs": 3,
+                "num_train_epochs": 6,
             },
             destination="homanp/test",
         )
@@ -212,5 +214,4 @@ async def upload_replicate_dataset(training_file: str) -> str:
             )
 
         serving_url = response_data["serving_url"]
-        print(serving_url)
         return serving_url
diff --git a/lib/service/flows.py b/lib/service/flows.py
index c53d5cf..7d138de 100644
--- a/lib/service/flows.py
+++ b/lib/service/flows.py
@@ -1,7 +1,7 @@
-import openai
 import json
-
 from typing import List, Union
+
+import openai
 from llama_index import Document
 from prefect import flow, task
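
---

Reviewer note: below is a minimal sketch of how a client could exercise the
ingest flow this series ends up with. It is illustrative only, not part of the
patches: the payload fields mirror IngestRequest in lib/models/ingest.py and
the enum values come from prisma/schema.prisma, but the host/port and the
sample content are assumptions.

    import httpx

    # Assumes the API is running locally on port 8000 and that the keys from
    # .env.example (OPENAI_API_KEY, REPLICATE_API_TOKEN, etc.) are set.
    payload = {
        "type": "MARKDOWN",              # a DatasourceType value
        "provider": "REPLICATE",         # ModelProviders: OPENAI or REPLICATE
        "base_model": "LLAMA2_7B_CHAT",  # a key in REPLICATE_MODELS
        "content": "# Handbook\nSome markdown to fine-tune on.",
    }

    response = httpx.post("http://localhost:8000/ingest", json=payload)
    response.raise_for_status()
    # The endpoint returns {"success": True, "data": <datasource row>} right
    # away and runs create_finetune as a background asyncio task.
    print(response.json())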