From 6bed4fbd9cd2c9d4cfea8b14cea975668ca193cd Mon Sep 17 00:00:00 2001 From: "Rossdan Craig rossdan@lastmileai.dev" <> Date: Thu, 11 Jan 2024 05:03:39 -0500 Subject: [PATCH] [HF][fix] Allow `kwargs` in `ModelParser.run()` See https://github.com/lastmile-ai/aiconfig/pull/877 for context, and https://github.com/lastmile-ai/aiconfig/pull/880 for why that original fix needed to be reverted Just a quick hack to get unblocked. I tried originally to make the ASR and Image2Text ParameterizedModel but that caused other errors. ## Test Plan Before https://github.com/lastmile-ai/aiconfig/assets/151060367/a23a5d5c-d9a2-415b-8a6e-9826da56e985 After https://github.com/lastmile-ai/aiconfig/assets/151060367/f29580e9-5cf6-43c5-b848-bb525eb368f2 --- .../local_inference/automatic_speech_recognition.py | 2 +- .../local_inference/image_2_text.py | 2 +- python/src/aiconfig/Config.py | 2 +- .../src/aiconfig/default_parsers/parameterized_model_parser.py | 2 +- python/src/aiconfig/model_parser.py | 1 + 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py index 54830533a..18c856c0a 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py @@ -85,7 +85,7 @@ async def deserialize( await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})) return completion_data - async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]: + async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any], **kwargs) -> list[Output]: await aiconfig.callback_manager.run_callbacks( CallbackEvent( "on_run_start", diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py index 12455f6d0..f9e2af99e 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py @@ -117,7 +117,7 @@ async def deserialize( await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params})) return completion_params - async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]: + async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any], **kwargs) -> list[Output]: await aiconfig.callback_manager.run_callbacks( CallbackEvent( "on_run_start", diff --git a/python/src/aiconfig/Config.py b/python/src/aiconfig/Config.py index 27fb71aa3..29ea822f9 100644 --- a/python/src/aiconfig/Config.py +++ b/python/src/aiconfig/Config.py @@ -268,7 +268,7 @@ async def run( options, params, callback_manager=self.callback_manager, - **kwargs, + **kwargs, # TODO: We should remove and make argument explicit ) event = CallbackEvent("on_run_complete", __name__, {"result": response}) diff --git a/python/src/aiconfig/default_parsers/parameterized_model_parser.py b/python/src/aiconfig/default_parsers/parameterized_model_parser.py index d7c414458..374339939 100644 --- a/python/src/aiconfig/default_parsers/parameterized_model_parser.py +++ b/python/src/aiconfig/default_parsers/parameterized_model_parser.py @@ -45,7 +45,7 @@ async def run( aiconfig: AIConfig, options: Optional[InferenceOptions] = None, parameters: Dict = {}, - **kwargs, + **kwargs, #TODO: We should remove and make arguments explicit ) -> List[Output]: # maybe use prompt metadata instead of kwargs? if kwargs.get("run_with_dependencies", False): diff --git a/python/src/aiconfig/model_parser.py b/python/src/aiconfig/model_parser.py index 6ec455a75..006afbeae 100644 --- a/python/src/aiconfig/model_parser.py +++ b/python/src/aiconfig/model_parser.py @@ -66,6 +66,7 @@ async def run( aiconfig: AIConfig, options: Optional["InferenceOptions"] = None, parameters: Dict = {}, + **kwargs, # TODO: Remove this, just a hack for now to ensure that it doesn't break ) -> ExecuteResult: """ Execute model inference based on completion data to be constructed in deserialize(), which includes the input prompt and