Skip to content

Commit

Permalink
fixed a bug for clip+slowfast
Browse files Browse the repository at this point in the history
  • Loading branch information
awkrail committed Sep 18, 2024
1 parent aa87601 commit c6366ac
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
7 changes: 6 additions & 1 deletion lighthouse/feature_extractor/vision_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,18 @@ def _select_visual_encoders(self) -> List[Any]:
model_path_dict[self._feature_name])]
return visual_encoders

def _trim_shorter_length(self, visual_features):
min_length = min([x.shape[0] for x in visual_features])
trimmed_visual_features = [x[:min_length] for x in visual_features]
return trimmed_visual_features

def encode(
    self,
    input_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Extract concatenated visual features and a mask for one input.

    Every frame loader is called with ``input_path`` and its output is
    passed to the visual encoder at the same position. The resulting
    features are trimmed to a common first-axis length and concatenated
    along the last axis.

    Args:
        input_path: Path handed unchanged to each frame loader.

    Returns:
        A ``(concat_features, visual_mask)`` tuple. ``visual_mask`` is an
        all-ones tensor of shape ``(1, len(concat_features))`` moved to
        ``self._device``.

    Raises:
        AssertionError: If the numbers of frame loaders and visual
            encoders differ, or if any loader returns ``None``.
    """
    assert len(self._frame_loaders) == len(self._visual_encoders), 'the number of frame_loaders and visual_encoders is different.'
    frame_inputs = [loader(input_path) for loader in self._frame_loaders]
    assert not any(item is None for item in frame_inputs), 'one of the loaders return None object.'
    visual_features = [encoder(frames) for encoder, frames in zip(self._visual_encoders, frame_inputs)]
    # Encoders can return different numbers of rows for the same input;
    # concatenating along the last axis requires equal first-axis lengths,
    # so the features are trimmed to the shortest one first.
    concat_features = torch.concat(self._trim_shorter_length(visual_features), dim=-1)
    visual_mask = torch.ones(1, len(concat_features)).to(self._device)
    return concat_features, visual_mask
4 changes: 2 additions & 2 deletions lighthouse/frame_loaders/slowfast_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ def _pad_frames(self, tensor, value=0):
if n == self._target_fps:
return tensor
if self._padding_mode == "constant":
z = torch.ones(n, tensor.shape[1], tensor.shape[2], tensor.shape[3], dtype=torch.uint8)
z = torch.ones(int(n), tensor.shape[1], tensor.shape[2], tensor.shape[3], dtype=torch.uint8)
z *= value
return torch.cat((tensor, z), 0)
elif self._padding_mode == "tile":
z = torch.cat(n * [tensor[-1:, :, :, :]])
z = torch.cat(int(n) * [tensor[-1:, :, :, :]])
return torch.cat((tensor, z), 0)
else:
raise NotImplementedError(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import pytest
import subprocess
from lighthouse.models import (MomentDETRPredictor, QDDETRPredictor, EaTRPredictor,
CGDETRPredictor, TRDETRPredictor, UVCOMPredictor)
CGDETRPredictor, UVCOMPredictor)


FEATURES = ['clip', 'clip_slowfast']
MODELS = ['moment_detr', 'qd_detr', 'eatr', 'cg_detr', 'uvcom']
Expand Down Expand Up @@ -81,7 +82,6 @@ def test_model_prediction():
for second in range(MIN_DURATION, MAX_DURATION):
video_path = f'tests/test_videos/video_duration_{second}.mp4'
model.encode_video(video_path)

query = 'A woman wearing a glass is speaking in front of the camera'
prediction = model.predict(query)
assert len(prediction['pred_relevant_windows']) == MOMENT_NUM, \
Expand Down

0 comments on commit c6366ac

Please sign in to comment.