diff --git a/models/requirements.txt b/models/requirements.txt
index 46f76f0e4..d779002c9 100644
--- a/models/requirements.txt
+++ b/models/requirements.txt
@@ -5,3 +5,7 @@ transformers==4.37.1
 accelerate
 diffusers==0.24.0
 brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
+# turbine tank downloading/uploading
+azure-storage-blob
+# microsoft/phi model
+einops
diff --git a/models/setup.py b/models/setup.py
index 657b8e94a..fae7c4a61 100644
--- a/models/setup.py
+++ b/models/setup.py
@@ -61,5 +61,7 @@ def load_version_info():
         "transformers==4.37.1",
         "accelerate",
         "diffusers==0.24.0",
+        "azure-storage-blob",
+        "einops",
     ],
 )
diff --git a/models/turbine_models/custom_models/sd_inference/clip.py b/models/turbine_models/custom_models/sd_inference/clip.py
index 996d5fb83..4cc5f91dd 100644
--- a/models/turbine_models/custom_models/sd_inference/clip.py
+++ b/models/turbine_models/custom_models/sd_inference/clip.py
@@ -16,6 +16,7 @@
 import torch
 import torch._dynamo as dynamo
 from transformers import CLIPTextModel, CLIPTokenizer
+from turbine_models.turbine_tank import turbine_tank
 
 import argparse
 
@@ -57,6 +58,7 @@ def export_clip_model(
     device=None,
     target_triple=None,
     max_alloc=None,
+    upload_ir=False,
 ):
     # Load the tokenizer and text encoder to tokenize and encode the text.
     tokenizer = CLIPTokenizer.from_pretrained(
@@ -64,6 +66,7 @@ def export_clip_model(
         subfolder="tokenizer",
         token=hf_auth_token,
     )
+
     text_encoder_model = CLIPTextModel.from_pretrained(
         hf_model_name,
         subfolder="text_encoder",
@@ -94,6 +97,15 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
 
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = utils.create_safe_name(hf_model_name, "-clip")
+    if upload_ir:
+        with open(f"{safe_name}.mlir", "w+") as f:
+            f.write(module_str)
+        model_name_upload = hf_model_name.replace("/", "_")
+        model_name_upload += "-clip"
+        turbine_tank.uploadToBlobStorage(
+            str(os.path.abspath(f"{safe_name}.mlir")),
+            f"{model_name_upload}/{model_name_upload}.mlir",
+        )
     if compile_to != "vmfb":
         return module_str, tokenizer
     else:
diff --git a/models/turbine_models/custom_models/sd_inference/schedulers.py b/models/turbine_models/custom_models/sd_inference/schedulers.py
index 97bd2418f..6dafeb313 100644
--- a/models/turbine_models/custom_models/sd_inference/schedulers.py
+++ b/models/turbine_models/custom_models/sd_inference/schedulers.py
@@ -23,6 +23,8 @@
 import safetensors
 import argparse
 
+from turbine_models.turbine_tank import turbine_tank
+
 parser = argparse.ArgumentParser()
 parser.add_argument(
     "--hf_auth_token", type=str, help="The Hugging Face auth token, required"
@@ -111,6 +113,7 @@ def export_scheduler(
     device=None,
     target_triple=None,
     max_alloc=None,
+    upload_ir=False,
 ):
     mapper = {}
     utils.save_external_weights(
@@ -145,6 +148,15 @@ def main(
 
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = utils.create_safe_name(hf_model_name, "-scheduler")
+    if upload_ir:
+        with open(f"{safe_name}.mlir", "w+") as f:
+            f.write(module_str)
+        model_name_upload = hf_model_name.replace("/", "_")
+        model_name_upload += "_scheduler"
+        turbine_tank.uploadToBlobStorage(
+            str(os.path.abspath(f"{safe_name}.mlir")),
+            f"{model_name_upload}/{model_name_upload}.mlir",
+        )
     if compile_to != "vmfb":
         return module_str
     else:
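Reviewer note: the same `upload_ir` hook recurs in every export entry point in this patch. For orientation, a sketch of how a caller would exercise it for CLIP; the positional argument order mirrors sd_test.py, while the `compile_to` value and the weight filename are illustrative assumptions, not part of this diff:

```python
from turbine_models.custom_models.sd_inference import clip

# Hypothetical invocation: export CLIP to MLIR and push the IR to turbine tank.
# "linalg" stands in for any compile_to value other than "vmfb", which makes
# the function return the MLIR string instead of compiling it.
module_str, tokenizer = clip.export_clip_model(
    "CompVis/stable-diffusion-v1-4",
    None,  # hf_auth_token: public model, no token required
    "linalg",
    "safetensors",
    "stable_diffusion_v1_4_clip.safetensors",
    "cpu",
    upload_ir=True,  # writes {safe_name}.mlir locally, then uploads it
)
```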
diff --git a/models/turbine_models/custom_models/sd_inference/unet.py b/models/turbine_models/custom_models/sd_inference/unet.py
index 272c7af7f..398ed9bc5 100644
--- a/models/turbine_models/custom_models/sd_inference/unet.py
+++ b/models/turbine_models/custom_models/sd_inference/unet.py
@@ -18,6 +18,7 @@
 import safetensors
 import argparse
 
+from turbine_models.turbine_tank import turbine_tank
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -90,6 +91,7 @@ def export_unet_model(
     device=None,
     target_triple=None,
     max_alloc=None,
+    upload_ir=False,
 ):
     mapper = {}
     utils.save_external_weights(
@@ -125,6 +127,15 @@ def main(
 
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = utils.create_safe_name(hf_model_name, "-unet")
+    if upload_ir:
+        with open(f"{safe_name}.mlir", "w+") as f:
+            f.write(module_str)
+        model_name_upload = hf_model_name.replace("/", "_")
+        model_name_upload += "_unet"
+        turbine_tank.uploadToBlobStorage(
+            str(os.path.abspath(f"{safe_name}.mlir")),
+            f"{model_name_upload}/{model_name_upload}.mlir",
+        )
     if compile_to != "vmfb":
         return module_str
     else:
diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py
index 03ef85556..fcf9453b4 100644
--- a/models/turbine_models/custom_models/sd_inference/vae.py
+++ b/models/turbine_models/custom_models/sd_inference/vae.py
@@ -18,6 +18,7 @@
 import safetensors
 import argparse
 
+from turbine_models.turbine_tank import turbine_tank
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -89,6 +90,7 @@ def export_vae_model(
     target_triple=None,
     max_alloc=None,
     variant="decode",
+    upload_ir=False,
 ):
     mapper = {}
     utils.save_external_weights(
@@ -113,6 +115,15 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
 
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = utils.create_safe_name(hf_model_name, "-vae")
+    if upload_ir:
+        with open(f"{safe_name}.mlir", "w+") as f:
+            f.write(module_str)
+        model_name_upload = hf_model_name.replace("/", "_")
+        model_name_upload += "-vae-" + variant
+        turbine_tank.uploadToBlobStorage(
+            str(os.path.abspath(f"{safe_name}.mlir")),
+            f"{model_name_upload}/{model_name_upload}.mlir",
+        )
     if compile_to != "vmfb":
         return module_str
     else:
diff --git a/models/turbine_models/custom_models/stateless_llama.py b/models/turbine_models/custom_models/stateless_llama.py
index 762690603..6863fd5c2 100644
--- a/models/turbine_models/custom_models/stateless_llama.py
+++ b/models/turbine_models/custom_models/stateless_llama.py
@@ -2,6 +2,7 @@
 import sys
 import re
 import json
+from turbine_models.turbine_tank import turbine_tank
 
 os.environ["TORCH_LOGS"] = "dynamic"
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -107,7 +108,14 @@ def export_transformer_model(
     vulkan_max_allocation=None,
     streaming_llm=False,
     vmfb_path=None,
+    upload_ir=False,
 ):
+    tokenizer = AutoTokenizer.from_pretrained(
+        hf_model_name,
+        use_fast=False,
+        token=hf_auth_token,
+    )
+
     mod = AutoModelForCausalLM.from_pretrained(
         hf_model_name,
         torch_dtype=torch.float,
@@ -121,11 +129,7 @@ def export_transformer_model(
     if precision == "f16":
         mod = mod.half()
         dtype = torch.float16
-    tokenizer = AutoTokenizer.from_pretrained(
-        hf_model_name,
-        use_fast=False,
-        token=hf_auth_token,
-    )
+
     # TODO: generate these values instead of magic numbers
     NUM_LAYERS = mod.config.num_hidden_layers
     HEADS = getattr(mod.config, "num_key_value_heads", None)
@@ -319,6 +323,14 @@ def evict_kvcache_space(self):
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = hf_model_name.split("/")[-1].strip()
     safe_name = re.sub("-", "_", safe_name)
+    if upload_ir:
+        with open(f"{safe_name}.mlir", "w+") as f:
+            f.write(module_str)
+        model_name_upload = hf_model_name.replace("/", "_")
+        turbine_tank.uploadToBlobStorage(
+            str(os.path.abspath(f"{safe_name}.mlir")),
+            f"{model_name_upload}/{model_name_upload}.mlir",
+        )
     if compile_to != "vmfb":
         return module_str, tokenizer
     else:
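Reviewer note: all of the exporters above share one naming convention. The slash in the Hugging Face id becomes an underscore, a component suffix is appended, and turbine_tank prepends a date_commit prefix at upload time. A worked example of the resulting blob key (the commit hash is illustrative):

```python
import datetime

# Mirrors the key construction used by the exporters and by uploadToBlobStorage.
hf_model_name = "CompVis/stable-diffusion-v1-4"
model_name_upload = hf_model_name.replace("/", "_") + "-clip"
file_name = f"{model_name_upload}/{model_name_upload}.mlir"

today = str(datetime.date.today())
commit = "26d6428"  # in practice, `git rev-parse --short HEAD` (see turbine_tank below)
blob_key = f"{today}_{commit}/{file_name}"
# e.g. 2024-02-13_26d6428/CompVis_stable-diffusion-v1-4-clip/CompVis_stable-diffusion-v1-4-clip.mlir
```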
open(f"{safe_name}.mlir", "w+") as f: + f.write(module_str) + model_name_upload = hf_model_name.replace("/", "_") + turbine_tank.uploadToBlobStorage( + str(os.path.abspath(f"{safe_name}.mlir")), + f"{model_name_upload}/{model_name_upload}.mlir", + ) if compile_to != "vmfb": return module_str, tokenizer else: diff --git a/models/turbine_models/model_builder.py b/models/turbine_models/model_builder.py index 22139ca64..035244534 100644 --- a/models/turbine_models/model_builder.py +++ b/models/turbine_models/model_builder.py @@ -1,6 +1,9 @@ from transformers import AutoModel, AutoTokenizer, AutoConfig import torch import shark_turbine.aot as aot +from turbine_models.turbine_tank import turbine_tank +import os +import re class HFTransformerBuilder: @@ -18,11 +21,15 @@ class HFTransformerBuilder: def __init__( self, example_input: torch.Tensor, - hf_id: str, + hf_id: str = None, auto_model: AutoModel = AutoModel, auto_tokenizer: AutoTokenizer = None, auto_config: AutoConfig = None, hf_auth_token=None, + upload_ir=False, + model=None, + model_type: str = None, + compile_to_vmfb: bool = None, ) -> None: self.example_input = example_input self.hf_id = hf_id @@ -30,24 +37,29 @@ def __init__( self.auto_tokenizer = auto_tokenizer self.auto_config = auto_config self.hf_auth_token = hf_auth_token - self.model = None + self.model = model self.tokenizer = None - self.build_model() + self.upload_ir = upload_ir + self.model_type = model_type + self.compile_to_vmfb = compile_to_vmfb + if self.model == None: + self.build_model() def build_model(self) -> None: """ Builds a PyTorch model using Hugging Face's transformers library. """ # TODO: check cloud storage for existing ir - self.model = self.auto_model.from_pretrained( - self.hf_id, token=self.hf_auth_token, config=self.auto_config - ) - if self.auto_tokenizer is not None: - self.tokenizer = self.auto_tokenizer.from_pretrained( - self.hf_id, token=self.hf_auth_token + if self.hf_id: + self.model = self.auto_model.from_pretrained( + self.hf_id, token=self.hf_auth_token, config=self.auto_config ) - else: - self.tokenizer = None + if self.auto_tokenizer is not None: + self.tokenizer = self.auto_tokenizer.from_pretrained( + self.hf_id, token=self.hf_auth_token + ) + else: + self.tokenizer = None def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule: """ @@ -59,6 +71,24 @@ def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule: Returns: aot.CompiledModule: The compiled module binary. 
""" - module = aot.export(self.model, self.example_input) - compiled_binary = module.compile(save_to=save_to) - return compiled_binary + if self.model_type and self.model_type == "hf_seq2seq": + module = aot.export(self.model, *self.example_input) + else: + module = aot.export(self.model, self.example_input) + if self.hf_id: + module_str = str(module.mlir_module) + safe_name = self.hf_id.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if self.upload_ir: + with open(f"{safe_name}.mlir", "w+") as f: + f.write(module_str) + model_name_upload = self.hf_id.replace("/", "_") + turbine_tank.uploadToBlobStorage( + str(os.path.abspath(f"{safe_name}.mlir")), + f"{model_name_upload}/{model_name_upload}.mlir", + ) + os.remove(f"{safe_name}.mlir") + if self.compile_to_vmfb and not self.compile_to_vmfb: + return + compiled_binary = module.compile(save_to=save_to) + return compiled_binary diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py index 961b920a0..9d00fb9e5 100644 --- a/models/turbine_models/tests/sd_test.py +++ b/models/turbine_models/tests/sd_test.py @@ -68,6 +68,7 @@ class StableDiffusionTest(unittest.TestCase): def testExportClipModel(self): + upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload") with self.assertRaises(SystemExit) as cm: clip.export_clip_model( # This is a public model, so no auth required @@ -77,6 +78,7 @@ def testExportClipModel(self): "safetensors", "stable_diffusion_v1_4_clip.safetensors", "cpu", + upload_ir=upload_ir_var == "upload", ) self.assertEqual(cm.exception.code, None) arguments["external_weight_path"] = "stable_diffusion_v1_4_clip.safetensors" @@ -98,6 +100,7 @@ def testExportClipModel(self): os.remove("stable_diffusion_v1_4_clip.vmfb") def testExportUnetModel(self): + upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload") with self.assertRaises(SystemExit) as cm: unet.export_unet_model( unet_model, @@ -111,6 +114,7 @@ def testExportUnetModel(self): "safetensors", "stable_diffusion_v1_4_unet.safetensors", "cpu", + upload_ir=upload_ir_var == "upload", ) self.assertEqual(cm.exception.code, None) arguments["external_weight_path"] = "stable_diffusion_v1_4_unet.safetensors" @@ -148,6 +152,7 @@ def testExportUnetModel(self): os.remove("stable_diffusion_v1_4_unet.vmfb") def testExportVaeModelDecode(self): + upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload") with self.assertRaises(SystemExit) as cm: vae.export_vae_model( vae_model, @@ -162,6 +167,7 @@ def testExportVaeModelDecode(self): "stable_diffusion_v1_4_vae.safetensors", "cpu", variant="decode", + upload_ir=upload_ir_var == "upload", ) self.assertEqual(cm.exception.code, None) arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" @@ -193,6 +199,7 @@ def testExportVaeModelDecode(self): os.remove("stable_diffusion_v1_4_vae.vmfb") def testExportVaeModelEncode(self): + upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload") with self.assertRaises(SystemExit) as cm: vae.export_vae_model( vae_model, @@ -207,6 +214,7 @@ def testExportVaeModelEncode(self): "stable_diffusion_v1_4_vae.safetensors", "cpu", variant="encode", + upload_ir=upload_ir_var == "upload", ) self.assertEqual(cm.exception.code, None) arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" @@ -239,6 +247,7 @@ def testExportVaeModelEncode(self): @unittest.expectedFailure def testExportPNDMScheduler(self): + upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload") with 
diff --git a/models/turbine_models/tests/stateless_llama_test.py b/models/turbine_models/tests/stateless_llama_test.py
index 574902101..1e87120fa 100644
--- a/models/turbine_models/tests/stateless_llama_test.py
+++ b/models/turbine_models/tests/stateless_llama_test.py
@@ -53,6 +53,8 @@ def test_vmfb_comparison(self):
         For VMFB, quantization can be int4 or None, but right now only using none for compatibility with torch.
         """
+        upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
+
         llama.export_transformer_model(
             hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
             hf_auth_token=None,
@@ -63,9 +65,12 @@ def test_vmfb_comparison(self):
             precision=precision,
             device="llvm-cpu",
             target_triple="host",
+            upload_ir=upload_ir_var == "upload",
         )
 
-        torch_str_cache_path = f"models/turbine_models/tests/vmfb_comparison_cached_torch_output_{precision}_{quantization}.txt"
+        torch_str_cache_path = (
+            f"vmfb_comparison_cached_torch_output_{precision}_{quantization}.txt"
+        )
         # if cached, just read
         if os.path.exists(torch_str_cache_path):
             with open(torch_str_cache_path, "r") as f:
@@ -106,7 +111,9 @@ def test_streaming_vmfb_comparison(self):
             vmfb_path="streaming_llama.vmfb",
         )
 
-        torch_str_cache_path = f"models/turbine_models/tests/vmfb_comparison_cached_torch_output_{precision}_{quantization}.txt"
+        torch_str_cache_path = (
+            f"vmfb_comparison_cached_torch_output_{precision}_{quantization}.txt"
+        )
         # if cached, just read
         if os.path.exists(torch_str_cache_path):
             with open(torch_str_cache_path, "r") as f:
diff --git a/models/turbine_models/turbine_tank/__init__.py b/models/turbine_models/turbine_tank/__init__.py
new file mode 100644
index 000000000..e69de29bb
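Reviewer note: the new turbine_tank module below reads its configuration at import time, so the cache directory and Azure settings have to be in the environment before the first import. A short sketch with placeholder values:

```python
import os

# Values are placeholders; set real credentials in your environment or CI.
os.environ["TURBINE_TANK_CACHE_DIR"] = "/tmp/turbine_tank_cache"
os.environ["AZURE_CONNECTION_STRING"] = "<connection-string>"
os.environ["AZURE_CONTAINER_NAME"] = "<container-name>"

from turbine_models.turbine_tank import turbine_tank
```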
diff --git a/models/turbine_models/turbine_tank/turbine_tank.py b/models/turbine_models/turbine_tank/turbine_tank.py
new file mode 100644
index 000000000..36dc07a4a
--- /dev/null
+++ b/models/turbine_models/turbine_tank/turbine_tank.py
@@ -0,0 +1,179 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from azure.storage.blob import BlobServiceClient
+
+import subprocess
+import datetime
+import os
+from pathlib import Path
+from functools import cmp_to_key
+
+custom_path = os.getenv("TURBINE_TANK_CACHE_DIR")
+if custom_path is not None:
+    if not os.path.exists(custom_path):
+        os.mkdir(custom_path)
+
+    WORKDIR = custom_path
+
+    print(f"Using {WORKDIR} as local turbine_tank cache directory.")
+else:
+    WORKDIR = os.path.join(str(Path.home()), ".local/turbine_tank/")
+    print(
+        f"turbine_tank local cache is located at {WORKDIR}. You may change this by assigning the TURBINE_TANK_CACHE_DIR environment variable."
+    )
+os.makedirs(WORKDIR, exist_ok=True)
+
+storage_account_key = os.environ.get("AZURE_STORAGE_ACCOUNT_KEY")
+storage_account_name = os.environ.get("AZURE_STORAGE_ACCOUNT_NAME")
+connection_string = os.environ.get("AZURE_CONNECTION_STRING")
+container_name = os.environ.get("AZURE_CONTAINER_NAME")
+
+
+def get_short_git_sha() -> str:
+    try:
+        return (
+            subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
+            .decode("utf-8")
+            .strip()
+        )
+    except FileNotFoundError:
+        return None
+
+
+def uploadToBlobStorage(file_path, file_name):
+    # create our prefix (we use this to keep track of when and what version
+    # of turbine is being used)
+    today = str(datetime.date.today())
+    commit = get_short_git_sha()
+    prefix = today + "_" + commit
+    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
+    blob_client = blob_service_client.get_blob_client(
+        container=container_name, blob=prefix + "/" + file_name
+    )
+    # check whether the blob was already uploaded (we don't want duplicates
+    # for the same day and commit)
+    if blob_client.exists():
+        print(
+            f"model artifacts have already been uploaded for {today} on the same github commit ({commit})"
+        )
+        return
+    # upload to the azure storage container
+    with open(file_path, "rb") as data:
+        blob_client.upload_blob(data)
+        print(f"Uploaded {file_name}.")
+
+
+def checkAndRemoveIfDownloadedOld(model_name: str, model_dir: str, prefix: str):
+    if os.path.isdir(model_dir) and len(os.listdir(model_dir)) > 0:
+        for item in os.listdir(model_dir):
+            item_path = os.path.join(model_dir, item)
+            # model artifacts already downloaded and up to date
+            # (we detect stale artifacts via the prefix: day + git_sha)
+            if os.path.isdir(item_path) and item == prefix:
+                return True
+            # model artifacts are stale, so remove them before a fresh download
+            if os.path.isdir(item_path) and os.path.isfile(
+                os.path.join(item_path, model_name + ".mlir")
+            ):
+                os.remove(os.path.join(item_path, model_name + ".mlir"))
+                os.rmdir(item_path)
+                return False
+            if os.path.isdir(item_path) and os.path.isfile(
+                os.path.join(item_path, model_name + "-param.mlir")
+            ):
+                os.remove(os.path.join(item_path, model_name + "-param.mlir"))
+                os.rmdir(item_path)
+                return False
+    # these model artifacts have not been downloaded yet
+    return False
+
+
+def download_public_folder(model_name: str, prefix: str, model_dir: str):
+    """Downloads a folder of blobs from the azure container."""
+    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
+    container_client = blob_service_client.get_container_client(
+        container=container_name
+    )
+    blob_list = container_client.list_blobs(name_starts_with=prefix)
+    empty = True
+
+    # go through the blobs with our target prefix
+    # example prefix: "2024-02-13_26d6428/CompVis_stable-diffusion-v1-4-clip"
+    for blob in blob_list:
+        empty = False
+        blob_client = blob_service_client.get_blob_client(
+            container=container_name, blob=blob.name
+        )
+        # create the path if the directory doesn't exist locally
+        dest_path = model_dir
+        if not os.path.isdir(dest_path):
+            os.makedirs(dest_path)
+        # download the blob into the local turbine tank cache
+        if "param" in blob.name:
+            file_path = os.path.join(model_dir, model_name + "-param.mlir")
+        else:
+            file_path = os.path.join(model_dir, model_name + ".mlir")
+        with open(file=file_path, mode="wb") as sample_blob:
+            download_stream = blob_client.download_blob()
+            sample_blob.write(download_stream.readall())
+
+    if empty:
+        print(f"Model ({model_name}) has not been uploaded yet")
+        return True
+
+    return False
+
+
+# comparator used to sort blobs by last-modified time
+def compare(item1, item2):
+    if item1.last_modified < item2.last_modified:
+        return -1
+    elif item1.last_modified > item2.last_modified:
+        return 1
+    else:
+        return 0
+
+
+def downloadModelArtifacts(model_name: str) -> str:
+    model_name = model_name.replace("/", "_")
+    container_client = BlobServiceClient.from_connection_string(
+        connection_string
+    ).get_container_client(container=container_name)
+    blob_list = container_client.list_blobs()
+    # sort by last-modified and take the most recent upload to turbine tank
+    blob_list = sorted(blob_list, key=cmp_to_key(compare))
+    latest_blob = blob_list[-1]
+    # get the prefix for the latest blob (e.g. 2024-02-13_26d6428)
+    download_latest_prefix = latest_blob.name.split("/")[0]
+    model_dir = os.path.join(WORKDIR, model_name)
+    # check if we already downloaded the model artifacts for this day + commit
+    exists = checkAndRemoveIfDownloadedOld(
+        model_name=model_name, model_dir=model_dir, prefix=download_latest_prefix
+    )
+    if exists:
+        print("Already downloaded most recent version")
+        return "NA"
+    # download the model artifacts (passing in the model name, the path to the
+    # artifacts in azure storage, and the local directory to store them in)
+    blobDNE = download_public_folder(
+        model_name,
+        download_latest_prefix + "/" + model_name,
+        os.path.join(model_dir, download_latest_prefix),
+    )
+    if blobDNE:
+        return
+    model_dir = os.path.join(WORKDIR, model_name + "/" + download_latest_prefix)
+    mlir_filename = os.path.join(model_dir, model_name + ".mlir")
+    print(
+        f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
+    )
+    assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
+
+    return mlir_filename
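Reviewer note: for completeness, the consumer side of the round trip; the model name follows the upload convention shown earlier:

```python
from turbine_models.turbine_tank import turbine_tank

# Fetch the most recent IR for the CLIP export. Returns the local .mlir path,
# "NA" when the cache already holds the latest prefix, or None if nothing has
# been uploaded for this model yet.
mlir_path = turbine_tank.downloadModelArtifacts("CompVis/stable-diffusion-v1-4-clip")
if mlir_path and mlir_path != "NA":
    with open(mlir_path) as f:
        module_str = f.read()
```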