Skip to content

Commit

Permalink
Fix some bugs (PaddlePaddle#140)
Browse files Browse the repository at this point in the history
* fix some bugs

* update configs
  • Loading branch information
LielinJiang authored Jan 6, 2021
1 parent cd642c0 commit 89dbb63
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 23 deletions.
2 changes: 1 addition & 1 deletion configs/cyclegan_cityscapes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ dataset:
batch_size: 1
max_size: inf
is_train: False
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
Expand Down
4 changes: 2 additions & 2 deletions configs/cyclegan_horse2zebra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dataset:
batch_size: 1
is_train: True
max_size: inf
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
Expand Down Expand Up @@ -67,7 +67,7 @@ dataset:
batch_size: 1
max_size: inf
is_train: False
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
Expand Down
2 changes: 1 addition & 1 deletion configs/pix2pix_cityscapes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ dataset:
dataroot: data/cityscapes/test
num_workers: 4
batch_size: 1
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: pair
- name: SplitPairedImage
Expand Down
2 changes: 1 addition & 1 deletion configs/pix2pix_cityscapes_2gpus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ dataset:
dataroot: data/cityscapes/test
num_workers: 4
batch_size: 1
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: pair
- name: Transforms
Expand Down
2 changes: 1 addition & 1 deletion configs/pix2pix_facades.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ dataset:
dataroot: data/facades/test
num_workers: 4
batch_size: 1
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: pair
- name: Transforms
Expand Down
3 changes: 2 additions & 1 deletion ppgan/apps/realsr_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def run_image(self, img):

img = self.norm(ori_img)
x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x)
with paddle.no_grad():
out = self.model(x)

pred_img = self.denorm(out.numpy()[0])
pred_img = Image.fromarray(pred_img)
Expand Down
37 changes: 24 additions & 13 deletions ppgan/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def __init__(self, cfg):
self.weight_interval = cfg.snapshot_config.interval
self.log_interval = cfg.log_config.interval
self.visual_interval = cfg.log_config.visiual_interval
if self.by_epoch:
self.weight_interval *= self.iters_per_epoch

self.validate_interval = -1
if cfg.get('validate', None) is not None:
self.validate_interval = cfg.validate.get('interval', -1)
Expand Down Expand Up @@ -177,16 +180,12 @@ def train(self):

self.model.lr_scheduler.step()

if self.by_epoch:
temp = self.current_epoch
else:
temp = self.current_iter
if self.validate_interval > -1 and temp % self.validate_interval == 0:
if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
self.test()

if temp % self.weight_interval == 0:
self.save(temp, 'weight', keep=-1)
self.save(temp)
if self.current_iter % self.weight_interval == 0:
self.save(self.current_iter, 'weight', keep=-1)
self.save(self.current_iter)

self.current_iter += 1

Expand Down Expand Up @@ -335,7 +334,12 @@ def save(self, epoch, name='checkpoint', keep=1):
assert name in ['checkpoint', 'weight']

state_dicts = {}
save_filename = 'epoch_%s_%s.pdparams' % (epoch, name)
if self.by_epoch:
save_filename = 'epoch_%s_%s.pdparams' % (
epoch // self.iters_per_epoch, name)
else:
save_filename = 'iter_%s_%s.pdparams' % (epoch, name)

save_path = os.path.join(self.output_dir, save_filename)
for net_name, net in self.model.nets.items():
state_dicts[net_name] = net.state_dict()
Expand All @@ -353,9 +357,16 @@ def save(self, epoch, name='checkpoint', keep=1):

if keep > 0:
try:
checkpoint_name_to_be_removed = os.path.join(
self.output_dir,
'epoch_%s_%s.pdparams' % (epoch - keep, name))
if self.by_epoch:
checkpoint_name_to_be_removed = os.path.join(
self.output_dir, 'epoch_%s_%s.pdparams' %
((epoch - keep * self.weight_interval) //
self.iters_per_epoch, name))
else:
checkpoint_name_to_be_removed = os.path.join(
self.output_dir, 'iter_%s_%s.pdparams' %
(epoch - keep * self.weight_interval, name))

if os.path.exists(checkpoint_name_to_be_removed):
os.remove(checkpoint_name_to_be_removed)

Expand All @@ -366,7 +377,7 @@ def resume(self, checkpoint_path):
state_dicts = load(checkpoint_path)
if state_dicts.get('epoch', None) is not None:
self.start_epoch = state_dicts['epoch'] + 1
self.global_steps = self.steps_per_epoch * state_dicts['epoch']
self.global_steps = self.iters_per_epoch * state_dicts['epoch']

for net_name, net in self.model.nets.items():
net.set_state_dict(state_dicts[net_name])
Expand Down
3 changes: 2 additions & 1 deletion ppgan/utils/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

def parse_args():
parser = argparse.ArgumentParser(description='PaddleGAN')
parser.add_argument('--config-file',
parser.add_argument('-c',
'--config-file',
metavar="FILE",
help='config file path')
# cuda setting
Expand Down
6 changes: 4 additions & 2 deletions ppgan/utils/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def setup(args, cfg):
cfg.is_train = True

cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
cfg.output_dir = os.path.join(cfg.output_dir,
str(cfg.model.name) + cfg.timestamp)
cfg.output_dir = os.path.join(
cfg.output_dir,
os.path.splitext(os.path.basename(str(args.config_file)))[0] +
cfg.timestamp)

logger = setup_logger(cfg.output_dir)

Expand Down

0 comments on commit 89dbb63

Please sign in to comment.