# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
# https://github.com/replicate/cog/blob/main/docs/getting-started-own-model.md
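#
# Cog wrapper around FoleyCrafter: given an input video and a text prompt, it
# generates a matching Foley soundtrack and returns both the audio file and the
# video with the new track (the heavy lifting happens in inference.run_inference).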
from typing import List
import os

import torch
from cog import BasePredictor, Input, Path

from foleycrafter.models.onset import torch_utils
from foleycrafter.utils.util import build_foleycrafter
from foleycrafter.pipelines.auffusion_pipeline import Generator
from foleycrafter.models.time_detector.model import VideoOnsetNet
from inference import run_inference

class Config:
    """Default arguments mirroring the namespace that inference.run_inference expects."""

    def __init__(self):
        self.prompt = ""
        self.nprompt = ""  # negative prompt; optional
        self.seed = 42
        self.semantic_scale = 1.0
        self.temporal_scale = 0.2
        self.input = ""
        self.ckpt = "checkpoints/"
        self.save_dir = "output/"
        self.pretrain = "auffusion/auffusion-full-no-adapter"
        self.device = "cuda"
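
# Expected layout of the checkpoint directory, inferred from the paths used in
# setup() below. This listing is an assumption based on this script, not an
# authoritative manifest of the FoleyCrafter release:
#
#   checkpoints/
#   ├── temporal_adapter.ckpt        # temporal ControlNet adapter weights
#   ├── semantic/
#   │   └── semantic_adapter.bin     # semantic IP-Adapter weights
#   ├── vocoder/                     # Auffusion vocoder config + weights
#   └── timestamp_detector.pth.tar   # VideoOnsetNet onset detector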

class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the models into memory to make running multiple predictions efficient."""
        config = Config()

        # Build the FoleyCrafter pipeline and move it to the target device
        self.pipe = build_foleycrafter().to(config.device)

        # Load the temporal adapter weights into the pipeline's ControlNet
        temporal_ckpt_path = os.path.join(config.ckpt, "temporal_adapter.ckpt")
        ckpt = torch.load(temporal_ckpt_path)
        if "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]
        # Strip any DataParallel "module." prefix before loading
        load_gligen_ckpt = {}
        for key, value in ckpt.items():
            if key.startswith("module."):
                load_gligen_ckpt[key[len("module."):]] = value
            else:
                load_gligen_ckpt[key] = value
        m, u = self.pipe.controlnet.load_state_dict(load_gligen_ckpt, strict=False)
        print(f"### ControlNet missing keys: {len(m)};\n### unexpected keys: {len(u)};")

        # Load the semantic adapter (IP-Adapter) and set its strength
        self.pipe.load_ip_adapter(
            os.path.join(config.ckpt, "semantic"),
            subfolder="",
            weight_name="semantic_adapter.bin",
            image_encoder_folder=None,
        )
        self.pipe.set_ip_adapter_scale(config.semantic_scale)

        # Load the vocoder that turns generated spectrograms into waveforms
        self.vocoder = Generator.from_pretrained(config.ckpt, subfolder="vocoder").to(config.device)

        # Load the onset (timestamp) detector used for temporal alignment
        time_detector_ckpt = os.path.join(config.ckpt, "timestamp_detector.pth.tar")
        time_detector = VideoOnsetNet(False)
        self.time_detector, _ = torch_utils.load_model(
            time_detector_ckpt, time_detector, device=config.device, strict=True
        )
    def predict(
        self,
        video: Path = Input(description="Input video"),
        prompt: str = Input(description="Prompt to guide the audio generation"),
        nprompt: str = Input(description="Negative prompt for the audio generation", default=""),
    ) -> List[Path]:
        """Run a single prediction on the model."""
        print(f"video={video} prompt={prompt!r} nprompt={nprompt!r}")
        # Fill in the per-request fields on top of the defaults
        config = Config()
        config.input = str(video)
        config.prompt = prompt
        config.nprompt = nprompt
        # run_inference returns the paths of the generated audio and output video
        out_audio, output_video = run_inference(config, self.pipe, self.vocoder, self.time_detector)
        return [Path(out_audio), Path(output_video)]
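
# A minimal local smoke test — a sketch assuming the checkpoints/ directory
# described above is in place and a CUDA device is available; "sample.mp4" and
# the prompt are placeholder values:
#
#     predictor = Predictor()
#     predictor.setup()
#     outputs = predictor.predict(video=Path("sample.mp4"), prompt="footsteps on gravel", nprompt="")
#     print(outputs)  # [<path to .wav>, <path to output video>]
#
# Via the Cog CLI (https://github.com/replicate/cog), the equivalent call is:
#
#     cog predict -i video=@sample.mp4 -i prompt="footsteps on gravel"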