Skip to content

Commit

Permalink
Separate BasePredictor and BaseInput (#1993)
Browse files Browse the repository at this point in the history
* Currently these two classes live in predictor
* Separate them out into their own classes to allow
for easier reasoning about the predictor code.
  • Loading branch information
8W9aG authored Oct 16, 2024
1 parent 87787f4 commit 09fbcbc
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 50 deletions.
2 changes: 1 addition & 1 deletion python/cog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel

from .predictor import BasePredictor
from .base_predictor import BasePredictor
from .types import ConcatenateIterator, File, Input, Path, Secret

try:
Expand Down
35 changes: 35 additions & 0 deletions python/cog/base_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pathlib import Path

import pydantic
from pydantic import BaseModel

from .types import PYDANTIC_V2, URLPath


# Base class for inputs, constructed dynamically in get_input_type().
# (This can't be a docstring or it gets passed through to the schema.)
class BaseInput(BaseModel):
if PYDANTIC_V2:
model_config = pydantic.ConfigDict(use_enum_values=True) # type: ignore
else:

class Config:
# When using `choices`, the type is converted into an enum to validate
# But, after validation, we want to pass the actual value to predict(), not the enum object
use_enum_values = True

def cleanup(self) -> None:
"""
Cleanup any temporary files created by the input.
"""
for _, value in self:
# Handle URLPath objects specially for cleanup.
# Also handle pathlib.Path objects, which cog.Path is a subclass of.
# A pathlib.Path object shouldn't make its way here,
# but both have an unlink() method, so we may as well be safe.
if isinstance(value, (URLPath, Path)):
# TODO: use unlink(missing_ok=...) when we drop Python 3.7 support.
try:
value.unlink()
except FileNotFoundError:
pass
26 changes: 26 additions & 0 deletions python/cog/base_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Union

from .types import (
File as CogFile,
)
from .types import (
Path as CogPath,
)


class BasePredictor(ABC):
def setup(
self,
weights: Optional[Union[CogFile, CogPath, str]] = None, # pylint: disable=unused-argument
) -> None:
"""
An optional method to prepare the model so multiple predictions run efficiently.
"""
return

@abstractmethod
def predict(self, **kwargs: Any) -> Any:
"""
Run a single prediction on the model
"""
51 changes: 2 additions & 49 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys
import types
import uuid
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -38,13 +37,14 @@
# Added in Python 3.9. Can be from typing if we drop support for <3.9
from typing_extensions import Annotated

from .base_input import BaseInput
from .base_predictor import BasePredictor
from .code_xforms import load_module_from_string, strip_model_source_code
from .errors import ConfigDoesNotExist, PredictorNotSet
from .types import (
PYDANTIC_V2,
CogConfig,
Input,
URLPath,
)
from .types import (
File as CogFile,
Expand All @@ -67,23 +67,6 @@
]


class BasePredictor(ABC):
def setup(
self,
weights: Optional[Union[CogFile, CogPath, str]] = None, # pylint: disable=unused-argument
) -> None:
"""
An optional method to prepare the model so multiple predictions run efficiently.
"""
return

@abstractmethod
def predict(self, **kwargs: Any) -> Any:
"""
Run a single prediction on the model
"""


def run_setup(predictor: BasePredictor) -> None:
weights_type = get_weights_type(predictor.setup)

Expand Down Expand Up @@ -257,36 +240,6 @@ def load_predictor_from_ref(ref: str) -> BasePredictor:
return predictor


# Base class for inputs, constructed dynamically in get_input_type().
# (This can't be a docstring or it gets passed through to the schema.)
class BaseInput(BaseModel):
if PYDANTIC_V2:
model_config = pydantic.ConfigDict(use_enum_values=True) # type: ignore
else:

class Config:
# When using `choices`, the type is converted into an enum to validate
# But, after validation, we want to pass the actual value to predict(), not the enum object
use_enum_values = True

def cleanup(self) -> None:
"""
Cleanup any temporary files created by the input.
"""

for _, value in dict(self).items():
# Handle URLPath objects specially for cleanup.
# Also handle pathlib.Path objects, which cog.Path is a subclass of.
# A pathlib.Path object shouldn't make its way here,
# but both have an unlink() method, so we may as well be safe.
if isinstance(value, (URLPath, Path)):
# TODO: use unlink(missing_ok=...) when we drop Python 3.7 support.
try:
value.unlink()
except FileNotFoundError:
pass


def validate_input_type(
type: Type[Any], # pylint: disable=redefined-builtin
name: str,
Expand Down

0 comments on commit 09fbcbc

Please sign in to comment.