Skip to content

Commit

Permalink
add epochs_no_optarch to replace method in darts/train_search (Paddle…
Browse files Browse the repository at this point in the history
  • Loading branch information
baiyfbupt authored Apr 8, 2020
1 parent 7d1ec56 commit 84d7653
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
4 changes: 2 additions & 2 deletions demo/darts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
``` bash
python search.py # DARTS一阶近似搜索方法
python search.py --unrolled=True # DARTS的二阶近似搜索方法
python search.py --method='PC-DARTS' # PC-DARTS搜索方法
python search.py --method='PC-DARTS' --batch_size=256 --learning_rate=0.1 --arch_learning_rate=6e-4 --epochs_no_archopt=15 # PC-DARTS搜索方法
```

模型结构随搜索轮数的变化如图1所示。需要注意的是,图中准确率Acc并不代表该结构最终准确率,为了获得当前结构的最佳准确率,请对得到的genotype做网络结构评估训练。
Expand Down Expand Up @@ -86,4 +86,4 @@ def train_search(batch_size, train_portion, is_shuffle, args):
python visualize.py PC-DARTS
```

`PC-DARTS`代表某个Genotype结构,需要预先添加到genotype.py中
`PC-DARTS`代表某个Genotype结构,需要预先添加到genotype.py中
6 changes: 4 additions & 2 deletions demo/darts/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
add_arg('model_save_dir', str, 'search_cifar', "The path to save model.")
add_arg('grad_clip', float, 5, "Gradient clipping.")
add_arg('arch_learning_rate',float, 3e-4, "Learning rate for arch encoding.")
add_arg('method', str, 'DARTS', "The search method you would like to use")
add_arg('method', str, 'DARTS', "The search method you would like to use")
add_arg('epochs_no_archopt', int, 0, "Epochs not optimize the arch params")
add_arg('cutout_length', int, 16, "Cutout length.")
add_arg('cutout', ast.literal_eval, False, "Whether use cutout.")
add_arg('unrolled', ast.literal_eval, False, "Use one-step unrolled validation loss")
Expand Down Expand Up @@ -84,8 +85,9 @@ def main(args):
num_imgs=args.trainset_num,
arch_learning_rate=args.arch_learning_rate,
unrolled=args.unrolled,
method=args.method,
num_epochs=args.epochs,
epochs_no_archopt=args.epochs_no_archopt,
use_gpu=args.use_gpu,
use_data_parallel=args.use_data_parallel,
log_freq=args.log_freq)
searcher.train()
Expand Down
10 changes: 3 additions & 7 deletions paddleslim/nas/darts/train_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
from .architect import Architect
logger = get_logger(__name__, level=logging.INFO)

SUPPORTED_METHODS = ["PC-DARTS", "DARTS"]


def count_parameters_in_MB(all_params):
parameters_number = 0
Expand All @@ -47,8 +45,8 @@ def __init__(self,
num_imgs=50000,
arch_learning_rate=3e-4,
unrolled='False',
method='DARTS',
num_epochs=50,
epochs_no_archopt=0,
use_gpu=True,
use_data_parallel=False,
log_freq=50):
Expand All @@ -60,9 +58,7 @@ def __init__(self,
self.num_imgs = num_imgs
self.arch_learning_rate = arch_learning_rate
self.unrolled = unrolled
self.method = method
assert (self.method in SUPPORTED_METHODS
), "Currently only support PC-DARTS, DARTS two methods"
self.epochs_no_archopt = epochs_no_archopt
self.num_epochs = num_epochs
self.use_gpu = use_gpu
self.use_data_parallel = use_data_parallel
Expand Down Expand Up @@ -94,7 +90,7 @@ def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
valid_label.stop_gradient = True
n = train_image.shape[0]

if not (self.method == "PC-DARTS" and epoch < 15):
if epoch >= self.epochs_no_archopt:
architect.step(train_image, train_label, valid_image,
valid_label)

Expand Down

0 comments on commit 84d7653

Please sign in to comment.