Skip to content

Commit

Permalink
add api example, why deleted?
Browse files Browse the repository at this point in the history
  • Loading branch information
awkrail committed Oct 21, 2024
1 parent d05f29c commit 004ed73
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
47 changes: 47 additions & 0 deletions api_example/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Copyright $today.year LY Corporation
LY Corporation licenses this file to you under the Apache License,
version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at:
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations
under the License.
"""
import os
import subprocess
import torch

from lighthouse.models import CGDETRPredictor
from typing import Dict, List, Optional

def load_weights(weight_dir: str) -> None:
if not os.path.exists(os.path.join(weight_dir, 'clip_slowfast_pann_cg_detr_qvhighlight.ckpt')):
command = 'wget -P gradio_demo/weights/ https://zenodo.org/records/13960580/files/clip_slowfast_pann_cg_detr_qvhighlight.ckpt'
subprocess.run(command, shell=True)

if not os.path.exists('SLOWFAST_8x8_R50.pkl'):
subprocess.run('wget https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/SLOWFAST_8x8_R50.pkl', shell=True)

if not os.path.exists('Cnn14_mAP=0.431.pth'):
subprocess.run('wget https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth', shell=True)

# use GPU if available
device: str = 'cpu'
weight_dir: str = 'gradio_demo/weights'
weight_path: str = os.path.join(weight_dir, 'clip_slowfast_cg_detr_qvhighlight.ckpt')
model: CGDETRPredictor = CGDETRPredictor(weight_path, device=device, feature_name='clip_slowfast',
slowfast_path='SLOWFAST_8x8_R50.pkl', pann_path=None)

# encode video features
model.encode_video('api_example/RoripwjYFp8_60.0_210.0.mp4')

# moment retrieval & highlight detection
query: str = 'A woman wearing a glass is speaking in front of the camera'
prediction: Optional[Dict[str, List[float]]] = model.predict(query)
print(prediction)
2 changes: 1 addition & 1 deletion gradio_demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def load_pretrained_weights():
for model_name in MODEL_NAMES:
for feature in FEATURES:
file_urls.append(
"https://zenodo.org/records/13639198/files/{}_{}_qvhighlight.ckpt".format(feature, model_name)
"https://zenodo.org/records/13960580/files/{}_{}_qvhighlight.ckpt".format(feature, model_name)
)
for file_url in tqdm(file_urls):
if not os.path.exists('gradio_demo/weights/' + os.path.basename(file_url)):
Expand Down

0 comments on commit 004ed73

Please sign in to comment.