Merge pull request PaddlePaddle#52 from chengduoZH/use_multi_processes_run_rcnn

[multi process] Use multiple processes to run Mask R-CNN
chengduo authored May 21, 2019
2 parents cd098d8 + f14fe81 commit 0c12ff5
Showing 3 changed files with 111 additions and 9 deletions.
45 changes: 45 additions & 0 deletions Mask-RCNN/paddle/rcnn/dist_utils.py
@@ -0,0 +1,45 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid

def nccl2_prepare(trainer_id, startup_prog, main_prog):
    config = fluid.DistributeTranspilerConfig()
    config.mode = "nccl2"
    t = fluid.DistributeTranspiler(config=config)
    t.transpile(trainer_id,
                trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'),
                current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'),
                startup_program=startup_prog,
                program=main_prog)

def prepare_for_multi_process(exe, build_strategy, train_prog, startup_prog):
    # Prepare for multi-process training.
    trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
    num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    print("PADDLE_TRAINERS_NUM", num_trainers)
    print("PADDLE_TRAINER_ID", trainer_id)
    build_strategy.num_trainers = num_trainers
    build_strategy.trainer_id = trainer_id
    # NOTE(zcd): use multiple processes to train the model,
    # and each process uses one GPU card.
    if num_trainers > 1:
        nccl2_prepare(trainer_id,
                      startup_prog, train_prog)
    # The startup_prog may be run twice this way, but it doesn't matter.
    exe.run(startup_prog)
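
For orientation, here is a minimal usage sketch of these helpers (not part of the diff). It assumes the per-process environment variables shown in the comments are supplied by paddle.distributed.launch; the endpoint and GPU values are illustrative only.

# Hypothetical usage sketch; the environment values below are illustrative,
# not taken from this commit.
import os
import paddle.fluid as fluid
import dist_utils

# Per-process environment, as the launcher is expected to provide, e.g.:
#   PADDLE_TRAINER_ID=0
#   PADDLE_TRAINERS_NUM=2
#   PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
#   PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
#   FLAGS_selected_gpus=0
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
exe = fluid.Executor(fluid.CUDAPlace(gpu_id))

# ... build the model into fluid.default_main_program() here ...

build_strategy = fluid.BuildStrategy()
# Transpiles the main program for NCCL2 mode when PADDLE_TRAINERS_NUM > 1,
# then runs the startup program on this process's device.
dist_utils.prepare_for_multi_process(
    exe, build_strategy,
    fluid.default_main_program(),
    fluid.default_startup_program())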
24 changes: 24 additions & 0 deletions Mask-RCNN/paddle/rcnn/run_multi_process.sh
@@ -0,0 +1,24 @@
#!/bin/bash
set -xe

#export FLAGS_cudnn_deterministic=true
#export FLAGS_enable_parallel_graph=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_memory_fraction_of_eager_deletion=1.0
export FLAGS_conv_workspace_size_limit=1500

base_batch_size=1

export CUDA_VISIBLE_DEVICES=0,1

device=${CUDA_VISIBLE_DEVICES//,/ }
arr=($device)
num_gpu_devices=${#arr[*]}

python -m paddle.distributed.launch --gpus ${num_gpu_devices} train.py \
    --model_save_dir=output/ \
    --pretrained_model=../imagenet_resnet50_fusebn/ \
    --data_dir=./dataset/coco \
    --im_per_batch=${base_batch_size} \
    --MASK_ON=True
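
For reference, the value passed to --gpus above is simply the number of comma-separated entries in CUDA_VISIBLE_DEVICES. A Python equivalent of that bash arithmetic (a hypothetical helper, not part of the change) would be:

import os

def count_visible_gpus(default=1):
    # Mirror the bash logic: split CUDA_VISIBLE_DEVICES on ',' and count the entries.
    visible = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    devices = [d for d in visible.split(',') if d]
    return len(devices) if devices else default

# With CUDA_VISIBLE_DEVICES=0,1 this returns 2, matching num_gpu_devices in the script.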
51 changes: 42 additions & 9 deletions Mask-RCNN/paddle/rcnn/train.py
@@ -30,9 +30,30 @@
import models.resnet as resnet
from learning_rate import exponential_with_warmup_decay
from config import cfg
import dist_utils

def get_device_num():
    visible_device = os.getenv('CUDA_VISIBLE_DEVICES')
    # NOTE(zcd): use multiple processes to train the model,
    # and each process uses one GPU card.
    num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    if num_trainers > 1: return 1
    if visible_device:
        device_num = len(visible_device.split(','))
    else:
        device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n')
    return device_num

def update_lr(args):
    num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    args.learning_rate = args.learning_rate / num_trainers
    # TODO(zcd): The loss_cls or loss may become NaN, so we decrease the learning rate here.
    # The reason for this should be analyzed in depth.
    if num_trainers > 1:
        args.learning_rate = args.learning_rate / 10

def train():
    update_lr(cfg)
    learning_rate = cfg.learning_rate
    image_shape = [3, cfg.TRAIN.max_size, cfg.TRAIN.max_size]

@@ -43,8 +64,7 @@ def train():
    random.seed(0)
    np.random.seed(0)

    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))
    devices_num = get_device_num()
    total_batch_size = devices_num * cfg.TRAIN.im_per_batch

    use_random = True
@@ -82,28 +102,40 @@ def train():
            var.persistable = True

    #fluid.memory_optimize(fluid.default_main_program(), skip_opt_set=set(fetch_list))

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    if not cfg.parallel: exe.run(fluid.default_startup_program())

    if cfg.pretrained_model:

        def if_exist(var):
            return os.path.exists(os.path.join(cfg.pretrained_model, var.name))

        fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist)

    if cfg.parallel:
        build_strategy = fluid.BuildStrategy()
        build_strategy.memory_optimize = False
        build_strategy.enable_inplace = False

        if cfg.use_gpu:
            dist_utils.prepare_for_multi_process(
                exe,
                build_strategy,
                fluid.default_main_program(),
                fluid.default_startup_program())

        trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
        num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))

        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.use_experimental_executor = True
        exec_strategy.num_iteration_per_drop_scope = 100
        train_exe = fluid.ParallelExecutor(
            use_cuda=bool(cfg.use_gpu), loss_name=loss.name, build_strategy=build_strategy, exec_strategy=exec_strategy)
        train_exe = fluid.ParallelExecutor(use_cuda=bool(cfg.use_gpu),
                                           loss_name=loss.name,
                                           build_strategy=build_strategy,
                                           exec_strategy=exec_strategy,
                                           num_trainers=num_trainers,
                                           trainer_id=trainer_id)
    else:
        train_exe = exe

@@ -209,3 +241,4 @@ def train_loop():
    args = parse_args()
    print_arguments(args)
    train()
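
As a quick check of the learning-rate adjustment introduced by update_lr above (the base value of 0.01 is illustrative; the real value comes from config.py):

# Illustrative numbers only; cfg.learning_rate is defined in config.py in practice.
base_lr = 0.01
num_trainers = 2

lr = base_lr / num_trainers   # scale per trainer -> 0.005
if num_trainers > 1:
    lr = lr / 10              # extra damping against NaN losses -> 0.0005
print(lr)                     # 0.0005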
