
Feature: implement video and image classification tool #57

Draft · wants to merge 1 commit into `main` from `feat/video-image-classification-tool`

Conversation

@hugohonda (Contributor) commented on Sep 9, 2024:

Do it similarly to this tutorial on Hugging Face: https://huggingface.co/docs/transformers/en/model_doc/clip
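For reference, the zero-shot classification flow from that tutorial looks roughly like this (checkpoint name and API taken from the Hugging Face CLIP docs; the image URL is the docs' example):

```python
import requests
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"],
    images=image,
    return_tensors="pt",
    padding=True,
)
outputs = model(**inputs)
# logits_per_image holds image-text similarity scores;
# softmax over the text dimension gives per-label probabilities.
probs = outputs.logits_per_image.softmax(dim=1)
```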

Convert it into a tool called `TextToClassification` using the CLIP model:

- Base it on https://github.com/landing-ai/vision-agent-tools/blob/main/vision_agent_tools/models/clip_media_sim.py
- Suggested implementation (the import paths and model checkpoints below are assumptions; pin them to the project's versions):

```python
import numpy as np
import torch
import torch.nn.functional as F
from pydantic import validate_call
from transformers import AutoModel, AutoProcessor

# Import paths assumed from the vision-agent-tools layout.
from vision_agent_tools.shared_types import BaseTool, VideoNumpy

# Example checkpoints (assumptions; adjust to the pinned versions).
_HF_SIGLIP_MODEL = "google/siglip-so400m-patch14-384"
_HF_CLIP_MODEL = "openai/clip-vit-large-patch14"


class TextVideoClassifier(BaseTool):
    """
    Takes in a list of texts and a video and classifies each frame in the video
    according to the given texts.
    """

    def __init__(self, device: str = "cuda", model: str = "siglip"):
        """Initializes the TextVideoClassifier object with a pre-trained SigLIP or CLIP model."""
        if model == "siglip":
            model_key = _HF_SIGLIP_MODEL
        elif model == "clip":
            model_key = _HF_CLIP_MODEL
        else:
            raise ValueError("Unknown model type, only accepts ['siglip', 'clip']")
        self.model_name = model
        self.model_key = model_key
        self.model = AutoModel.from_pretrained(model_key).eval().to(device)
        self.processor = AutoProcessor.from_pretrained(model_key)
        self.device = device

    @validate_call(config={"arbitrary_types_allowed": True})
    @torch.inference_mode()
    def __call__(
        self,
        video: VideoNumpy[np.uint8],
        target_text: list[str],
    ) -> list[list[float]]:
        """
        Receives a video and target texts and returns a probability score for each
        frame over each target_text element.

        Args:
            video (VideoNumpy[np.uint8]): The input video to be processed.
            target_text (list[str]): The target texts used to classify each frame.
        """
        if len(target_text) < 2:
            raise ValueError("Must contain at least 2 classes")

        # Embed the target texts once.
        inputs = self.processor(text=target_text, return_tensors="pt", padding=True)
        with torch.autocast(self.device):
            inputs = inputs.to(self.device)
            outputs = self.model.get_text_features(**inputs)
        target = outputs.detach()

        # Embed every frame of the video.
        frame_embs = []
        for frame in video:
            inputs = self.processor(images=frame, return_tensors="pt")
            with torch.autocast(self.device):
                inputs = inputs.to(self.device)
                outputs = self.model.get_image_features(**inputs)
            frame_embs.append(outputs.squeeze().detach())
        frame_embs = torch.stack(frame_embs)

        # First dim is frame count, second dim is target classes.
        # SigLIP adds a learned logit bias on top of the scaled similarity.
        probs = (
            (
                F.cosine_similarity(target, frame_embs.unsqueeze(1), dim=-1)
                * self.model.logit_scale.exp()
                + (self.model.logit_bias if self.model_name == "siglip" else 0)
            )
            .softmax(dim=1)
            .detach()
            .cpu()
            .numpy()
        )
        return probs.tolist()
```
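A minimal usage sketch (the random test video and label texts are placeholders):

```python
import numpy as np

# Hypothetical example: classify 4 random frames against two labels.
classifier = TextVideoClassifier(device="cuda", model="clip")
video = np.random.randint(0, 255, size=(4, 224, 224, 3), dtype=np.uint8)
probs = classifier(video=video, target_text=["a photo of a cat", "a photo of a dog"])
# probs[i][j] is the probability of class j for frame i.
print(probs[0])
```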

@hugohonda self-assigned this on Sep 9, 2024
@hugohonda force-pushed the feat/video-image-classification-tool branch from f0c4a78 to 50f801a on September 9, 2024 at 20:27
@hugohonda force-pushed the feat/video-image-classification-tool branch from 50f801a to d304881 on September 12, 2024 at 14:54