diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py
index 9df590194..18c856c0a 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py
@@ -3,17 +3,15 @@
 import torch
 from transformers import pipeline, Pipeline
 from aiconfig_extension_hugging_face.local_inference.util import get_hf_model
-
+from aiconfig import ModelParser, InferenceOptions
 from aiconfig.callback import CallbackEvent
-from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
-from aiconfig.model_parser import InferenceOptions
 from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment
 
 if TYPE_CHECKING:
     from aiconfig import AIConfigRuntime
 
 
-class HuggingFaceAutomaticSpeechRecognitionTransformer(ParameterizedModelParser):
+class HuggingFaceAutomaticSpeechRecognitionTransformer(ModelParser):
     """
     Model Parser for HuggingFace ASR (Automatic Speech Recognition) models.
     """
@@ -87,7 +85,7 @@ async def deserialize(
         await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
         return completion_data
 
-    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]:
+    async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any], **kwargs) -> list[Output]:
         await aiconfig.callback_manager.run_callbacks(
             CallbackEvent(
                 "on_run_start",
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
index 787e86212..f9e2af99e 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
@@ -10,9 +10,8 @@
 
 from aiconfig_extension_hugging_face.local_inference.util import get_hf_model
 
+from aiconfig import ModelParser, InferenceOptions
 from aiconfig.callback import CallbackEvent
-from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
-from aiconfig.model_parser import InferenceOptions
 from aiconfig.schema import (
     Attachment,
     ExecuteResult,
@@ -25,7 +24,7 @@
     from aiconfig import AIConfigRuntime
 
 
-class HuggingFaceImage2TextTransformer(ParameterizedModelParser):
+class HuggingFaceImage2TextTransformer(ModelParser):
     def __init__(self):
         """
         Returns:
@@ -118,7 +117,7 @@ async def deserialize(
         await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
         return completion_params
 
-    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]:
+    async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any], **kwargs) -> list[Output]:
         await aiconfig.callback_manager.run_callbacks(
             CallbackEvent(
                 "on_run_start",
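Reviewer note: across both parsers above, the override point is renamed from `run_inference` to `run`, the base class changes from `ParameterizedModelParser` to `ModelParser`, and a `**kwargs` catch-all is added so arguments the runtime still forwards (e.g. `run_with_dependencies`) don't break the call. Below is a minimal sketch of the new shape a parser is expected to take, assuming only the imports shown in the diff; `MyParser` and its stub body are illustrative, not part of this patch.

```python
from typing import TYPE_CHECKING, Any, Dict

from aiconfig import ModelParser, InferenceOptions
from aiconfig.schema import Output, Prompt

if TYPE_CHECKING:
    from aiconfig import AIConfigRuntime


class MyParser(ModelParser):  # hypothetical parser, for illustration only
    # Old entry point: `async def run_inference(self, prompt, aiconfig, options, parameters)`.
    # New entry point: `run`, with **kwargs absorbing any extra runtime arguments.
    async def run(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        options: InferenceOptions,
        parameters: Dict[str, Any],
        **kwargs,  # passthrough slated for removal; see the TODOs later in this patch
    ) -> list[Output]:
        raise NotImplementedError  # model-specific inference goes here
```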
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py
index 14db41a3a..a7ff6d307 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py
@@ -14,7 +14,7 @@ def get_hf_model(aiconfig: "AIConfigRuntime", prompt: Prompt, model_parser: Para
     """
     model_name: str | None = aiconfig.get_model_name(prompt)
     model_settings = model_parser.get_model_settings(prompt, aiconfig)
-    hf_model = model_settings.get("model", None)
+    hf_model = model_settings.get("model") or None  # Replace "" with a None value
 
     if hf_model is not None and isinstance(hf_model, str):
         # If the model property is set in the model settings, use that.
diff --git a/python/src/aiconfig/Config.py b/python/src/aiconfig/Config.py
index 27fb71aa3..29ea822f9 100644
--- a/python/src/aiconfig/Config.py
+++ b/python/src/aiconfig/Config.py
@@ -268,7 +268,7 @@ async def run(
             options,
             params,
             callback_manager=self.callback_manager,
-            **kwargs,
+            **kwargs,  # TODO: We should remove this and make the argument explicit
         )
 
         event = CallbackEvent("on_run_complete", __name__, {"result": response})
diff --git a/python/src/aiconfig/default_parsers/parameterized_model_parser.py b/python/src/aiconfig/default_parsers/parameterized_model_parser.py
index d7c414458..374339939 100644
--- a/python/src/aiconfig/default_parsers/parameterized_model_parser.py
+++ b/python/src/aiconfig/default_parsers/parameterized_model_parser.py
@@ -45,7 +45,7 @@ async def run(
         aiconfig: AIConfig,
         options: Optional[InferenceOptions] = None,
         parameters: Dict = {},
-        **kwargs,
+        **kwargs,  # TODO: We should remove this and make the arguments explicit
     ) -> List[Output]:
         # maybe use prompt metadata instead of kwargs?
         if kwargs.get("run_with_dependencies", False):
diff --git a/python/src/aiconfig/model_parser.py b/python/src/aiconfig/model_parser.py
index 6ec455a75..006afbeae 100644
--- a/python/src/aiconfig/model_parser.py
+++ b/python/src/aiconfig/model_parser.py
@@ -66,6 +66,7 @@ async def run(
         aiconfig: AIConfig,
         options: Optional["InferenceOptions"] = None,
         parameters: Dict = {},
+        **kwargs,  # TODO: Remove this; it is just a hack for now to ensure nothing breaks
     ) -> ExecuteResult:
         """
         Execute model inference based on completion data to be constructed in deserialize(), which includes the input prompt and
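Side note on the one-line util.py change: `dict.get("model", None)` still returns `""` when the key is present but empty, whereas `.get("model") or None` coerces the empty string to `None`, which is what the `hf_model is not None` check downstream expects. A quick illustration with a throwaway dict (not code from this repo):

```python
settings = {"model": ""}  # "model" key present but empty

settings.get("model", None)    # -> "" (empty string leaks through)
settings.get("model") or None  # -> None (empty string normalized away)
```

One trade-off worth flagging: `or None` also maps other falsy values (e.g. `0` or `False`) to `None`. That is acceptable here because the model name is expected to be a non-empty string or absent.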