From 09fbcbc749c335a5a11cb234b9dc9b89d31d9238 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Wed, 16 Oct 2024 10:52:04 -0700 Subject: [PATCH] Separate BasePredictor and BaseInput (#1993) * Currently these two classes live in predictor * Separate them out into their own classes to allow for easier reasoning about the predictor code. --- python/cog/__init__.py | 2 +- python/cog/base_input.py | 35 +++++++++++++++++++++++++ python/cog/base_predictor.py | 26 ++++++++++++++++++ python/cog/predictor.py | 51 ++---------------------------------- 4 files changed, 64 insertions(+), 50 deletions(-) create mode 100644 python/cog/base_input.py create mode 100644 python/cog/base_predictor.py diff --git a/python/cog/__init__.py b/python/cog/__init__.py index b8371e0f0..ed89b17ec 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from .predictor import BasePredictor +from .base_predictor import BasePredictor from .types import ConcatenateIterator, File, Input, Path, Secret try: diff --git a/python/cog/base_input.py b/python/cog/base_input.py new file mode 100644 index 000000000..02072a2e1 --- /dev/null +++ b/python/cog/base_input.py @@ -0,0 +1,35 @@ +from pathlib import Path + +import pydantic +from pydantic import BaseModel + +from .types import PYDANTIC_V2, URLPath + + +# Base class for inputs, constructed dynamically in get_input_type(). +# (This can't be a docstring or it gets passed through to the schema.) +class BaseInput(BaseModel): + if PYDANTIC_V2: + model_config = pydantic.ConfigDict(use_enum_values=True) # type: ignore + else: + + class Config: + # When using `choices`, the type is converted into an enum to validate + # But, after validation, we want to pass the actual value to predict(), not the enum object + use_enum_values = True + + def cleanup(self) -> None: + """ + Cleanup any temporary files created by the input. + """ + for _, value in self: + # Handle URLPath objects specially for cleanup. + # Also handle pathlib.Path objects, which cog.Path is a subclass of. + # A pathlib.Path object shouldn't make its way here, + # but both have an unlink() method, so we may as well be safe. + if isinstance(value, (URLPath, Path)): + # TODO: use unlink(missing_ok=...) when we drop Python 3.7 support. + try: + value.unlink() + except FileNotFoundError: + pass diff --git a/python/cog/base_predictor.py b/python/cog/base_predictor.py new file mode 100644 index 000000000..41393ae37 --- /dev/null +++ b/python/cog/base_predictor.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Union + +from .types import ( + File as CogFile, +) +from .types import ( + Path as CogPath, +) + + +class BasePredictor(ABC): + def setup( + self, + weights: Optional[Union[CogFile, CogPath, str]] = None, # pylint: disable=unused-argument + ) -> None: + """ + An optional method to prepare the model so multiple predictions run efficiently. + """ + return + + @abstractmethod + def predict(self, **kwargs: Any) -> Any: + """ + Run a single prediction on the model + """ diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 72d997cdf..3ac89e1cf 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -7,7 +7,6 @@ import sys import types import uuid -from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator from pathlib import Path from typing import ( @@ -38,13 +37,14 @@ # Added in Python 3.9. Can be from typing if we drop support for <3.9 from typing_extensions import Annotated +from .base_input import BaseInput +from .base_predictor import BasePredictor from .code_xforms import load_module_from_string, strip_model_source_code from .errors import ConfigDoesNotExist, PredictorNotSet from .types import ( PYDANTIC_V2, CogConfig, Input, - URLPath, ) from .types import ( File as CogFile, @@ -67,23 +67,6 @@ ] -class BasePredictor(ABC): - def setup( - self, - weights: Optional[Union[CogFile, CogPath, str]] = None, # pylint: disable=unused-argument - ) -> None: - """ - An optional method to prepare the model so multiple predictions run efficiently. - """ - return - - @abstractmethod - def predict(self, **kwargs: Any) -> Any: - """ - Run a single prediction on the model - """ - - def run_setup(predictor: BasePredictor) -> None: weights_type = get_weights_type(predictor.setup) @@ -257,36 +240,6 @@ def load_predictor_from_ref(ref: str) -> BasePredictor: return predictor -# Base class for inputs, constructed dynamically in get_input_type(). -# (This can't be a docstring or it gets passed through to the schema.) -class BaseInput(BaseModel): - if PYDANTIC_V2: - model_config = pydantic.ConfigDict(use_enum_values=True) # type: ignore - else: - - class Config: - # When using `choices`, the type is converted into an enum to validate - # But, after validation, we want to pass the actual value to predict(), not the enum object - use_enum_values = True - - def cleanup(self) -> None: - """ - Cleanup any temporary files created by the input. - """ - - for _, value in dict(self).items(): - # Handle URLPath objects specially for cleanup. - # Also handle pathlib.Path objects, which cog.Path is a subclass of. - # A pathlib.Path object shouldn't make its way here, - # but both have an unlink() method, so we may as well be safe. - if isinstance(value, (URLPath, Path)): - # TODO: use unlink(missing_ok=...) when we drop Python 3.7 support. - try: - value.unlink() - except FileNotFoundError: - pass - - def validate_input_type( type: Type[Any], # pylint: disable=redefined-builtin name: str,