Skip to content

Commit

Permalink
add EDVR predictor dynamic (PaddlePaddle#315)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangna11BD authored May 18, 2021
1 parent d81d9cc commit 61d4939
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
2 changes: 0 additions & 2 deletions applications/tools/video-enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,8 @@
weight_path=args.RealSR_weight)
frames_path, temp_video_path = predictor.run(temp_video_path)
elif order == 'EDVR':
paddle.enable_static()
predictor = EDVRPredictor(args.output, weight_path=args.EDVR_weight)
frames_path, temp_video_path = predictor.run(temp_video_path)
paddle.disable_static()

print('Model {} output frames path:'.format(order), frames_path)
print('Model {} output video path:'.format(order), temp_video_path)
Expand Down
44 changes: 23 additions & 21 deletions ppgan/apps/edvr_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import numpy as np
from tqdm import tqdm

import paddle
from paddle.io import Dataset, DataLoader

from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames

from ppgan.models.generators import EDVRNet
from .base_predictor import BasePredictor

EDVR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
EDVR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams'


def get_img(pred):
Expand Down Expand Up @@ -110,7 +113,7 @@ def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'):
return return_l


class EDVRDataset:
class EDVRDataset(Dataset):
def __init__(self, frame_paths):
self.frames = frame_paths

Expand All @@ -133,16 +136,15 @@ def __len__(self):


class EDVRPredictor(BasePredictor):
def __init__(self, output='output', weight_path=None):
def __init__(self, output='output', weight_path=None, bs=1):
self.input = input
self.output = os.path.join(output, 'EDVR')

self.bs = bs
self.model = EDVRNet(nf=128, back_RBs=40)
if weight_path is None:
weight_path = get_path_from_url(EDVR_WEIGHT_URL)

self.weight_path = weight_path

self.build_inference_model()
self.model.set_dict(paddle.load(weight_path)['generator'])
self.model.eval()

def run(self, video_path):
vid = video_path
Expand All @@ -163,23 +165,23 @@ def run(self, video_path):

frames = sorted(glob.glob(os.path.join(out_path, '*.png')))

dataset = EDVRDataset(frames)
test_dataset = EDVRDataset(frames)
dataset = DataLoader(test_dataset, batch_size=self.bs, num_workers=2)

periods = []
cur_time = time.time()
for infer_iter, data in enumerate(tqdm(dataset)):
data_feed_in = [data[0]]

outs = self.base_forward(np.array(data_feed_in))

infer_result_list = [item for item in outs]

data_feed_in = paddle.to_tensor(data[0])
with paddle.no_grad():
outs = self.model(data_feed_in).numpy()
infer_result_list = [outs[i, :, :, :] for i in range(self.bs)]
frame_path = data[1]

img_i = get_img(infer_result_list[0])
save_img(
img_i,
os.path.join(pred_frame_path, os.path.basename(frame_path)))
for i in range(self.bs):
img_i = get_img(infer_result_list[i])
save_img(
img_i,
os.path.join(pred_frame_path,
os.path.basename(frame_path[i])))

prev_time = cur_time
cur_time = time.time()
Expand Down

0 comments on commit 61d4939

Please sign in to comment.