From b7905197387b8f6f2b91667a47146068ea21d4e6 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 19 Apr 2024 21:26:39 -0700 Subject: [PATCH] Improve: PAss tests for small models --- .vscode/settings.json | 7 +- python/scripts/export_encoders.ipynb | 130 ++++++++++++++++----------- python/scripts/test_encoders.py | 16 ++-- python/uform/numpy_processors.py | 2 +- python/uform/onnx_encoders.py | 37 +++++++- python/uform/torch_encoders.py | 39 ++++++-- 6 files changed, 156 insertions(+), 75 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 3a060e1..3275f93 100755 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -21,7 +21,9 @@ "ndarray", "numpy", "ONNX", + "onnxconverter", "onnxruntime", + "opset", "packbits", "preprocess", "pretrained", @@ -48,5 +50,8 @@ "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" }, - "python.formatting.provider": "none" + "python.formatting.provider": "none", + "window.autoDetectColorScheme": true, + "workbench.colorTheme": "Default Dark+", + "workbench.preferredDarkColorTheme": "Default Dark+" } \ No newline at end of file diff --git a/python/scripts/export_encoders.ipynb b/python/scripts/export_encoders.ipynb index 029e60a..a8b868d 100644 --- a/python/scripts/export_encoders.ipynb +++ b/python/scripts/export_encoders.ipynb @@ -19,7 +19,6 @@ "metadata": {}, "outputs": [], "source": [ - "!pip uninstall -y uform\n", "!pip install --upgrade \"uform[torch]\" coremltools" ] }, @@ -30,8 +29,13 @@ "outputs": [], "source": [ "import os\n", - "model_name = \"uform-vl-english-small\"\n", - "output_directory = \"../../\"" + "\n", + "working_directory = \"../..\"\n", + "model_name = \"uform3-image-text-english-small\"\n", + "model_directory = os.path.join(working_directory, \"models\", model_name)\n", + "model_weights_path = os.path.join(model_directory, \"torch_weight.pt\")\n", + "config_path = os.path.join(model_directory, \"config.json\")\n", + "tokenizer_path = os.path.join(model_directory, \"tokenizer.json\")" ] }, { @@ -40,20 +44,20 @@ "metadata": {}, "outputs": [], "source": [ - "import uform\n", - "from PIL import Image\n", - "\n", - "model, processor = uform.get_model('unum-cloud/' + model_name)\n", - "text = 'a small red panda in a zoo'\n", - "image = Image.open('../../assets/unum.png')\n", - "\n", - "image_data = processor.preprocess_image(image)\n", - "text_data = processor.preprocess_text(text)\n", - "\n", - "image_features, image_embedding = model.encode_image(image_data, return_features=True)\n", - "text_features, text_embedding = model.encode_text(text_data, return_features=True)\n", + "import torch\n", "\n", - "image_features.shape, text_features.shape, image_embedding.shape, text_embedding.shape" + "state_dict = torch.load(model_weights_path)\n", + "list(state_dict.keys())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from uform.torch_encoders import ImageEncoder, TextEncoder\n", + "from uform.torch_processors import ImageProcessor, TextProcessor" ] }, { @@ -62,7 +66,9 @@ "metadata": {}, "outputs": [], "source": [ - "model.text_encoder" + "image_encoder = ImageEncoder.from_pretrained(config_path, state_dict)\n", + "text_encoder = TextEncoder.from_pretrained(config_path, state_dict)\n", + "image_encoder, text_encoder" ] }, { @@ -71,7 +77,9 @@ "metadata": {}, "outputs": [], "source": [ - "model.image_encoder" + "text_processor = TextProcessor(config_path, tokenizer_path)\n", + "image_processor = ImageProcessor(config_path)\n", + "text_processor, image_processor" ] }, { @@ -80,14 +88,19 @@ "metadata": {}, "outputs": [], "source": [ - "# Assuming `model` is your loaded model with image_encoder and text_encoder attributes\n", - "for name, module in model.image_encoder.named_children():\n", - " print(f\"First layer of image_encoder: {name}\")\n", - " break # We break after the first layer\n", + "import uform\n", + "from PIL import Image\n", "\n", - "for name, module in model.text_encoder.named_children():\n", - " print(f\"First layer of text_encoder: {name}\")\n", - " break # We break after the first layer" + "text = 'a small red panda in a zoo'\n", + "image = Image.open('../../assets/unum.png')\n", + "\n", + "text_data = text_processor(text)\n", + "image_data = image_processor(image)\n", + "\n", + "image_features, image_embedding = image_encoder.forward(image_data, return_features=True)\n", + "text_features, text_embedding = text_encoder.forward(text_data, return_features=True)\n", + "\n", + "image_features.shape, text_features.shape, image_embedding.shape, text_embedding.shape" ] }, { @@ -147,7 +160,7 @@ " input_shape = (ct.RangeDim(lower_bound=1, upper_bound=upper_bound, default=1),) + input_shape[1:]\n", " return input_shape\n", "\n", - "generalize_first_dimensions(image_data.shape), generalize_first_dimensions(text_data[\"input_ids\"].shape), generalize_first_dimensions(text_data[\"attention_mask\"].shape)" + "generalize_first_dimensions(image_data[\"images\"].shape), generalize_first_dimensions(text_data[\"input_ids\"].shape), generalize_first_dimensions(text_data[\"attention_mask\"].shape)" ] }, { @@ -156,7 +169,7 @@ "metadata": {}, "outputs": [], "source": [ - "image_input = ct.TensorType(name=\"images\", shape=generalize_first_dimensions(image_data.shape, 1))\n", + "image_input = ct.TensorType(name=\"images\", shape=generalize_first_dimensions(image_data[\"images\"].shape, 1))\n", "text_input = ct.TensorType(name=\"input_ids\", shape=generalize_first_dimensions(text_data[\"input_ids\"].shape, 1))\n", "text_attention_input = ct.TensorType(name=\"attention_mask\", shape=generalize_first_dimensions(text_data[\"attention_mask\"].shape, 1))\n", "text_features = ct.TensorType(name=\"features\")\n", @@ -171,11 +184,11 @@ "metadata": {}, "outputs": [], "source": [ - "module = model.image_encoder\n", + "module = image_encoder\n", "module.eval()\n", "module.return_features = True\n", "\n", - "traced_script_module = torch.jit.trace(module, example_inputs=image_data)\n", + "traced_script_module = torch.jit.trace(module, example_inputs=image_data[\"images\"])\n", "traced_script_module" ] }, @@ -193,7 +206,7 @@ "coreml_model.author = 'Unum Cloud'\n", "coreml_model.license = 'Apache 2.0'\n", "coreml_model.short_description = 'Pocket-Sized Multimodal AI for Content Understanding'\n", - "coreml_model.save(os.path.join(output_directory, \"image_encoder.mlpackage\"))" + "coreml_model.save(os.path.join(model_directory, \"image_encoder.mlpackage\"))" ] }, { @@ -202,7 +215,7 @@ "metadata": {}, "outputs": [], "source": [ - "module = model.text_encoder\n", + "module = text_encoder\n", "module.eval()\n", "module.return_features = True\n", "\n", @@ -224,7 +237,7 @@ "coreml_model.author = 'Unum Cloud'\n", "coreml_model.license = 'Apache 2.0'\n", "coreml_model.short_description = 'Pocket-Sized Multimodal AI for Content Understanding'\n", - "coreml_model.save(os.path.join(output_directory, \"text_encoder.mlpackage\"))" + "coreml_model.save(os.path.join(model_directory, \"text_encoder.mlpackage\"))" ] }, { @@ -257,8 +270,8 @@ "metadata": {}, "outputs": [], "source": [ - "model.image_encoder.eval()\n", - "model.image_encoder.to(dtype=torch.bfloat16)" + "image_encoder.eval()\n", + "image_encoder.to(dtype=torch.bfloat16)" ] }, { @@ -267,7 +280,7 @@ "metadata": {}, "outputs": [], "source": [ - "torch.save(model.image_encoder.state_dict(), os.path.join(output_directory, \"image_encoder.pt\"))" + "torch.save(image_encoder.state_dict(), os.path.join(model_directory, \"image_encoder.pt\"))" ] }, { @@ -276,7 +289,7 @@ "metadata": {}, "outputs": [], "source": [ - "save_file(model.image_encoder.state_dict(), os.path.join(output_directory, \"image_encoder.safetensors\"))" + "save_file(image_encoder.state_dict(), os.path.join(model_directory, \"image_encoder.safetensors\"))" ] }, { @@ -285,8 +298,8 @@ "metadata": {}, "outputs": [], "source": [ - "model.text_encoder.eval()\n", - "model.text_encoder.to(dtype=torch.bfloat16)" + "text_encoder.eval()\n", + "text_encoder.to(dtype=torch.bfloat16)" ] }, { @@ -295,7 +308,7 @@ "metadata": {}, "outputs": [], "source": [ - "torch.save(model.text_encoder.state_dict(), os.path.join(output_directory, \"text_encoder.pt\"))" + "torch.save(text_encoder.state_dict(), os.path.join(model_directory, \"text_encoder.pt\"))" ] }, { @@ -304,7 +317,7 @@ "metadata": {}, "outputs": [], "source": [ - "save_file(model.text_encoder.state_dict(), os.path.join(output_directory, \"text_encoder.safetensors\"))" + "save_file(text_encoder.state_dict(), os.path.join(model_directory, \"text_encoder.safetensors\"))" ] }, { @@ -313,8 +326,8 @@ "metadata": {}, "outputs": [], "source": [ - "image_features, image_embedding = model.encode_image(image_data.to(dtype=torch.bfloat16), return_features=True)\n", - "text_features, text_embedding = model.encode_text(text_data, return_features=True)\n", + "image_features, image_embedding = image_encoder.forward(image_data[\"images\"].to(dtype=torch.bfloat16), return_features=True)\n", + "text_features, text_embedding = text_encoder.forward(text_data, return_features=True)\n", "\n", "image_features.shape, text_features.shape, image_embedding.shape, text_embedding.shape" ] @@ -358,7 +371,7 @@ "metadata": {}, "outputs": [], "source": [ - "module = model.text_encoder\n", + "module = text_encoder\n", "module.eval()\n", "module.return_features = True\n", "module.to(dtype=torch.float32)\n", @@ -366,7 +379,7 @@ "onnx_export(\n", " module,\n", " (text_data[\"input_ids\"], text_data[\"attention_mask\"]), \n", - " os.path.join(output_directory, \"text_encoder.onnx\"), \n", + " os.path.join(model_directory, \"text_encoder.onnx\"), \n", " export_params=True,\n", " opset_version=15,\n", " do_constant_folding=True,\n", @@ -392,15 +405,15 @@ "metadata": {}, "outputs": [], "source": [ - "module = model.image_encoder\n", + "module = image_encoder\n", "module.eval()\n", "module.return_features = True\n", "module.to(dtype=torch.float32)\n", "\n", "torch.onnx.export(\n", " module,\n", - " image_data, \n", - " os.path.join(output_directory, \"image_encoder.onnx\"), \n", + " image_data[\"images\"], \n", + " os.path.join(model_directory, \"image_encoder.onnx\"), \n", " export_params=True,\n", " opset_version=15,\n", " do_constant_folding=True,\n", @@ -437,7 +450,7 @@ "metadata": {}, "outputs": [], "source": [ - "module_path = os.path.join(output_directory, \"text_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"text_encoder.onnx\")\n", "module = onnx.load(module_path)\n", "module_fp16 = float16.convert_float_to_float16(module)\n", "onnx.save(module_fp16, module_path)" @@ -449,7 +462,7 @@ "metadata": {}, "outputs": [], "source": [ - "module_path = os.path.join(output_directory, \"image_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"image_encoder.onnx\")\n", "module = onnx.load(module_path)\n", "module_fp16 = float16.convert_float_to_float16(module)\n", "onnx.save(module_fp16, module_path)" @@ -480,7 +493,7 @@ "metadata": {}, "outputs": [], "source": [ - "module_path = os.path.join(output_directory, \"text_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"text_encoder.onnx\")\n", "quantize_dynamic(module_path, module_path, weight_type=QuantType.QUInt8)" ] }, @@ -490,7 +503,7 @@ "metadata": {}, "outputs": [], "source": [ - "module_path = os.path.join(output_directory, \"image_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"image_encoder.onnx\")\n", "quantize_dynamic(module_path, module_path, weight_type=QuantType.QUInt8)" ] }, @@ -512,7 +525,7 @@ "from onnx import helper\n", "\n", "# Load the ONNX model\n", - "module_path = os.path.join(output_directory, \"text_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"text_encoder.onnx\")\n", "module = onnx.load(module_path)\n", "\n", "# Get the module's graph\n", @@ -599,7 +612,7 @@ "metadata": {}, "outputs": [], "source": [ - "module_path = os.path.join(output_directory, \"text_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"text_encoder.onnx\")\n", "session = ort.InferenceSession(module_path, sess_options=session_options)" ] }, @@ -609,7 +622,7 @@ "metadata": {}, "outputs": [], "source": [ - "module_path = os.path.join(output_directory, \"image_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"image_encoder.onnx\")\n", "session = ort.InferenceSession(module_path, sess_options=session_options)" ] }, @@ -620,6 +633,15 @@ "# Upload to Hugging Face" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!huggingface-cli upload unum-cloud/uform3-image-text-english-small ../../models/uform3-image-text-english-small/ . --exclude=\"torch_weight.pt\"" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/python/scripts/test_encoders.py b/python/scripts/test_encoders.py index d26e4f2..bd26690 100644 --- a/python/scripts/test_encoders.py +++ b/python/scripts/test_encoders.py @@ -27,16 +27,16 @@ torch_models = [ "unum-cloud/uform3-image-text-english-small", - "unum-cloud/uform3-image-text-english-base", - "unum-cloud/uform3-image-text-english-large", - "unum-cloud/uform3-image-text-multilingual-base", + # "unum-cloud/uform3-image-text-english-base", + # "unum-cloud/uform3-image-text-english-large", + # "unum-cloud/uform3-image-text-multilingual-base", ] onnx_models = [ "unum-cloud/uform3-image-text-english-small", - "unum-cloud/uform3-image-text-english-base", - "unum-cloud/uform3-image-text-english-large", - "unum-cloud/uform3-image-text-multilingual-base", + # "unum-cloud/uform3-image-text-english-base", + # "unum-cloud/uform3-image-text-english-large", + # "unum-cloud/uform3-image-text-multilingual-base", ] # Let's check if the HuggingFace Hub API token is set in the environment variable. @@ -198,8 +198,8 @@ def test_onnx_one_embedding(model_name: str, device: str): # Test if the model outputs actually make sense cross_references_image_and_text_embeddings( - lambda text: model_text(processor_text(text)), - lambda image: model_image(processor_image(image)), + lambda text: model_text(processor_text(text))[1], + lambda image: model_image(processor_image(image))[1], ) except ExecutionProviderError as e: diff --git a/python/uform/numpy_processors.py b/python/uform/numpy_processors.py index a5faca2..027bc0d 100644 --- a/python/uform/numpy_processors.py +++ b/python/uform/numpy_processors.py @@ -34,7 +34,7 @@ def __call__(self, texts: Union[str, List[str]]) -> Dict[str, np.ndarray]: input_ids = np.full( (len(texts), self._max_seq_len), fill_value=self._pad_token_idx, - dtype=np.int64, + dtype=np.int32, ) attention_mask = np.zeros( diff --git a/python/uform/onnx_encoders.py b/python/uform/onnx_encoders.py index 9f63fa4..a6f27d3 100644 --- a/python/uform/onnx_encoders.py +++ b/python/uform/onnx_encoders.py @@ -64,6 +64,7 @@ def __init__( model_path: str, *, device: Literal["cpu", "cuda"] = "cpu", + return_features: bool = True, ): """ :param model_path: Path to onnx model @@ -73,14 +74,21 @@ def __init__( session_options = ort.SessionOptions() session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + self.return_features = return_features self.session = ort.InferenceSession( model_path, sess_options=session_options, providers=available_providers(device), ) - def __call__(self, images: ndarray) -> Tuple[ndarray, ndarray]: - return self.session.run(None, {"images": images}) + def __call__( + self, images: ndarray, return_features: Optional[bool] = None + ) -> Union[ndarray, Tuple[ndarray, ndarray]]: + features, embeddings = self.session.run(None, {"images": images}) + return_features = return_features if return_features is not None else self.return_features + if return_features: + return features, embeddings + return embeddings class TextEncoder: @@ -89,6 +97,7 @@ def __init__( model_path: str, *, device: Literal["cpu", "cuda"] = "cpu", + return_features: bool = True, ): """ :param text_encoder_path: Path to onnx of text encoder @@ -98,11 +107,31 @@ def __init__( session_options = ort.SessionOptions() session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + self.return_features = return_features self.text_encoder_session = ort.InferenceSession( model_path, sess_options=session_options, providers=available_providers(device), ) - def __call__(self, input_ids: ndarray, attention_mask: ndarray) -> Tuple[ndarray, ndarray]: - return self.text_encoder_session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask}) + def __call__( + self, + x: Union[ndarray, dict], + attention_mask: Optional[ndarray] = None, + return_features: Optional[bool] = None, + ) -> Union[ndarray, Tuple[ndarray, ndarray]]: + if isinstance(x, dict): + assert attention_mask is None, "If `x` is a dictionary, then `attention_mask` should be None" + attention_mask = x["attention_mask"] + input_ids = x["input_ids"] + else: + input_ids = x + + features, embeddings = self.text_encoder_session.run( + None, {"input_ids": input_ids, "attention_mask": attention_mask} + ) + + return_features = return_features if return_features is not None else self.return_features + if return_features: + return features, embeddings + return embeddings diff --git a/python/uform/torch_encoders.py b/python/uform/torch_encoders.py index 8ac7c36..0504a74 100644 --- a/python/uform/torch_encoders.py +++ b/python/uform/torch_encoders.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from os import PathLike -from typing import Dict, Optional, Tuple, Union, Callable +from typing import Dict, Optional, Union, Mapping, Any import json import torch @@ -274,7 +274,12 @@ def forward( return embeddings @staticmethod - def from_pretrained(config: Union[PathLike, str, object], model_path: Union[PathLike, str]) -> TextEncoder: + def from_pretrained(config: Union[PathLike, str, object], model: Union[PathLike, str]) -> TextEncoder: + """Load the image encoder from the given configuration and model path. + + :param config: the configuration dictionary or path to the JSON configuration file + :param model: the model state dictionary or path to the `.pt` model file + """ if isinstance(config, (PathLike, str)): config = json.load(open(config, "r")) if "text_encoder" in config: @@ -283,9 +288,15 @@ def from_pretrained(config: Union[PathLike, str, object], model_path: Union[Path # We must strip all the non-member attributes before initializing the classes. text_fields = TextEncoder.__dataclass_fields__ config = {k: v for k, v in config.items() if k in text_fields} - - state = torch.load(model_path) encoder = TextEncoder(**config) + + # Load from disk + if isinstance(model, (PathLike, str)): + state = torch.load(model) + else: + state = model + if "text_encoder" in state: + state = state["text_encoder"] encoder.load_state_dict(state) return encoder @@ -351,7 +362,15 @@ def forward(self, x: Tensor, return_features: Optional[bool] = None) -> Tensor: return embeddings @staticmethod - def from_pretrained(config: Union[PathLike, str, object], model_path: Union[PathLike, str]) -> ImageEncoder: + def from_pretrained( + config: Union[PathLike, str, object], + model: Union[PathLike, str, Mapping[str, Any]], + ) -> ImageEncoder: + """Load the image encoder from the given configuration and model path. + + :param config: the configuration dictionary or path to the JSON configuration file + :param model: the model state dictionary or path to the `.pt` model file + """ if isinstance(config, (PathLike, str)): config = json.load(open(config, "r")) if "image_encoder" in config: @@ -360,8 +379,14 @@ def from_pretrained(config: Union[PathLike, str, object], model_path: Union[Path # We must strip all the non-member attributes before initializing the classes. image_fields = ImageEncoder.__dataclass_fields__ config = {k: v for k, v in config.items() if k in image_fields} - - state = torch.load(model_path) encoder = ImageEncoder(**config) + + # Load from disk + if isinstance(model, (PathLike, str)): + state = torch.load(model) + else: + state = model + if "image_encoder" in state: + state = state["image_encoder"] encoder.load_state_dict(state) return encoder