From 8391af34005998a4416c9b735ca42c4c295c3ed3 Mon Sep 17 00:00:00 2001 From: awkrail Date: Wed, 2 Oct 2024 20:12:58 +0900 Subject: [PATCH] added finetune mode --- configs/dataset/tvsum.yml | 2 +- training/config.py | 9 +++++++-- training/train.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/configs/dataset/tvsum.yml b/configs/dataset/tvsum.yml index 035d4b5..246dc85 100644 --- a/configs/dataset/tvsum.yml +++ b/configs/dataset/tvsum.yml @@ -7,7 +7,7 @@ max_v_l: 1000 seed: 2018 lr: 0.001 lr_drop: 2000 -n_epoch: 2000 +n_epoch: 1000 bsz: 4 domains: - BK diff --git a/training/config.py b/training/config.py index 2575ec7..02a65c6 100755 --- a/training/config.py +++ b/training/config.py @@ -26,10 +26,11 @@ from easydict import EasyDict class BaseOptions(object): - def __init__(self, model, dataset, feature): + def __init__(self, model, dataset, feature, resume): self.model = model self.dataset = dataset self.feature = feature + self.resume = resume self.opt = {} @property @@ -55,7 +56,11 @@ def parse(self): self.opt = EasyDict(self.opt) # result directory - self.opt.results_dir = os.path.join(self.opt.results_dir, self.model, self.dataset, self.feature) + if self.resume: + self.opt.results_dir = os.path.join(self.opt.results_dir, self.model, f"{self.dataset}_finetune", self.feature) + else: + self.opt.results_dir = os.path.join(self.opt.results_dir, self.model, self.dataset, self.feature) + self.opt.ckpt_filepath = os.path.join(self.opt.results_dir, self.opt.ckpt_filename) self.opt.train_log_filepath = os.path.join(self.opt.results_dir, self.opt.train_log_filename) self.opt.eval_log_filepath = os.path.join(self.opt.results_dir, self.opt.eval_log_filename) diff --git a/training/train.py b/training/train.py index e675d74..0d3e46c 100755 --- a/training/train.py +++ b/training/train.py @@ -289,7 +289,7 @@ def check_valid_combination(dataset, feature): is_valid = check_valid_combination(args.dataset, args.feature) if is_valid: - option_manager = BaseOptions(args.model, args.dataset, args.feature) + option_manager = BaseOptions(args.model, args.dataset, args.feature, args.resume) option_manager.parse() option_manager.clean_and_makedirs()