fix the bug of sparse momentum for partial fc and update readme (Padd…
GuoxiaWang authored Nov 22, 2022
1 parent e35e218 commit 9c0cb49
Showing 9 changed files with 291 additions and 78 deletions.
4 changes: 4 additions & 0 deletions plsc/engine/classification/train.py
@@ -30,6 +30,10 @@

def defualt_train_one_epoch(engine, epoch_id):
tic = time.time()

if hasattr(engine.train_dataloader.batch_sampler, "set_epoch"):
engine.train_dataloader.batch_sampler.set_epoch(epoch_id)

for iter_id, batch in enumerate(engine.train_dataloader):

if iter_id >= engine.max_iter:
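The new `set_epoch` guard reseeds the distributed sampler's shuffle for each epoch, so the data order changes across epochs while staying consistent across ranks. A minimal sketch of the pattern, assuming Paddle's `DistributedBatchSampler`; the toy dataset and sizes are illustrative, not taken from the repository:

```python
import paddle
from paddle.io import DataLoader, Dataset, DistributedBatchSampler


class ToyDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return paddle.to_tensor([float(idx)])


dataset = ToyDataset()
sampler = DistributedBatchSampler(dataset, batch_size=4, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler)

for epoch_id in range(2):
    # Without set_epoch the sampler reuses the same shuffled order every epoch.
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(epoch_id)
    for batch in loader:
        pass  # training step goes here
```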
7 changes: 5 additions & 2 deletions plsc/engine/recognition/train.py
@@ -31,10 +31,13 @@

def defualt_train_one_epoch(engine, epoch_id):
tic = time.time()
dev_id = paddle.distributed.ParallelEnv().dev_id

if hasattr(engine.train_dataloader.batch_sampler, "set_epoch"):
engine.train_dataloader.batch_sampler.set_epoch(epoch_id)

for iter_id, batch in enumerate(engine.train_dataloader):
for i in range(len(batch)):
batch[i] = batch[i].cuda(dev_id)
batch[i] = batch[i].cuda()

if iter_id >= engine.max_iter:
break
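This hunk also touches how each batch is moved to the GPU: with one process per card (as set up by `paddle.distributed.launch`), `Tensor.cuda()` without an explicit `dev_id` already copies to the process's current device. A minimal sketch of that placement, assuming a CUDA build of Paddle, with a toy batch standing in for the dataloader output:

```python
import paddle

# Toy stand-ins for what the dataloader yields (CPU tensors).
batch = [paddle.rand([4, 3, 112, 112]), paddle.randint(0, 10, [4])]

# Each launched worker is already bound to one GPU, so no dev_id is needed.
batch = [t.cuda() for t in batch]
print([str(t.place) for t in batch])
```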
55 changes: 41 additions & 14 deletions plsc/models/layers/partialfc.py
@@ -123,14 +123,15 @@ def __init__(self,
if not self.model_parallel:
self.group = False

self.num_local: int = (num_classes + world_size - 1) // world_size
if num_classes % world_size != 0 and rank == world_size - 1:
self.num_local = num_classes % self.num_local
self.num_sample: int = int(self.sample_ratio * self.num_local)

self.rank = rank
self.world_size = world_size

self.num_local: int = num_classes // self.world_size + int(
self.rank < num_classes % self.world_size)
self.class_start: int = num_classes // self.world_size * self.rank + min(
self.rank, num_classes % self.world_size)
self.num_sample: int = int(self.sample_ratio * self.num_local)

if model_parallel and world_size > 0:
if name is None:
name = 'dist@partialfc@rank@%05d.w' % rank
@@ -140,13 +141,14 @@
if name is None:
name = 'partialfc.w'

stddev = math.sqrt(2.0 / (self.embedding_size + self.num_local))
param_attr = paddle.ParamAttr(
name=name, initializer=paddle.nn.initializer.Normal(std=stddev))
name=name,
initializer=paddle.nn.initializer.Normal(
mean=0, std=0.01))

self.index = None
self.weight = self.create_parameter(
shape=[self.embedding_size, self.num_local],
shape=[self.num_local, self.embedding_size],
attr=param_attr,
is_bias=False)
self.weight.is_distributed = self.model_parallel
@@ -158,6 +160,31 @@
self.weight.stop_gradient = True
self.sub_weight = None

@paddle.no_grad()
def class_center_sample(self, labels):

labels = labels.reshape((-1, 1))
index_positive = (self.class_start <= labels) & (
labels < self.class_start + self.num_local)

local_label = labels[index_positive] - self.class_start

positive = paddle.unique(local_label)
if self.num_sample - positive.shape[0] >= 0:
perm = paddle.rand([self.num_local])
perm[positive] = 2.0
index = paddle.topk(perm, k=self.num_sample)[1]
index = paddle.sort(index)
else:
index = positive

local_sampled_ids = index + self.class_start
sampled_ids = all_gather(local_sampled_ids, axis=0)

labels = paddle.searchsorted(sampled_ids, labels)

return labels, index

def forward(self, feature, label):
if self.model_parallel:
total_feature = all_gather(feature, axis=0)
@@ -170,11 +197,10 @@ def forward(self, feature, label):

if self.sample_ratio < 1.0:
# partial fc sample process
total_label, self.index = paddle.nn.functional.class_center_sample(
total_label, self.num_local, self.num_sample, group=self.group)
total_label, self.index = self.class_center_sample(total_label)
total_label.stop_gradient = True
self.index.stop_gradient = True
self.sub_weight = paddle.gather(self.weight, self.index, axis=1)
self.sub_weight = paddle.gather(self.weight, self.index, axis=0)

# NOTE(GuoxiaWang): stop generate the full gradient
# when use partial fc in model parallel,
@@ -184,7 +210,7 @@

def sparse_grad_hook_fn():
setattr(self.weight, 'index', self.index)
setattr(self.weight, 'axis', 1)
setattr(self.weight, 'axis', 0)
self.weight._set_grad_ivar(self.sub_weight.grad)

self.sub_weight._register_backward_hook(sparse_grad_hook_fn)
@@ -193,7 +219,8 @@ def sparse_grad_hook_fn():
self.sub_weight = self.weight

norm_feature = paddle.fluid.layers.l2_normalize(total_feature, axis=1)
norm_weight = paddle.fluid.layers.l2_normalize(self.sub_weight, axis=0)
norm_weight = paddle.fluid.layers.l2_normalize(self.sub_weight, axis=1)

local_logit = paddle.matmul(norm_feature, norm_weight)
local_logit = paddle.matmul(
norm_feature, norm_weight, transpose_y=True)
return local_logit, total_label
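Two things change in `PartialFC` here: the class centers are now stored as `[num_local, embedding_size]` and gathered along axis 0, and the built-in `paddle.nn.functional.class_center_sample` is replaced by the in-layer `class_center_sample` above, which works from an explicit per-rank `class_start` offset. A single-process sketch of how the partition and the sampling fit together, using toy sizes (10 classes, 3 ranks, 2 sampled centers) that are purely illustrative; in the real layer the `sampled_ids` from all ranks are concatenated with `all_gather` before the `searchsorted` remapping:

```python
import paddle

# How the classes are split: rank r owns num_local consecutive ids starting at class_start.
num_classes, world_size = 10, 3
for rank in range(world_size):
    num_local = num_classes // world_size + int(rank < num_classes % world_size)
    class_start = num_classes // world_size * rank + min(rank, num_classes % world_size)
    print(rank, class_start, num_local)   # -> (0, 0, 4), (1, 4, 3), (2, 7, 3)

# Sampling on one rank (rank 0 above), with all_gather omitted.
class_start, num_local, num_sample = 0, 4, 2
labels = paddle.to_tensor([0, 3, 0]).reshape((-1, 1))

index_positive = (class_start <= labels) & (labels < class_start + num_local)
local_label = labels[index_positive] - class_start       # ids owned by this rank
positive = paddle.unique(local_label)                     # -> [0, 3]

if num_sample - positive.shape[0] >= 0:
    perm = paddle.rand([num_local])
    perm[positive] = 2.0                                  # positives are always kept
    index = paddle.sort(paddle.topk(perm, k=num_sample)[1])
else:
    index = positive                                      # more positives than the budget

sampled_ids = index + class_start            # would be all_gather-ed across ranks
labels = paddle.searchsorted(sampled_ids, labels)         # remap into the sampled sub-weight
sub_weight_rows = index                      # rows to gather from the [num_local, emb] weight
```

The `index` produced here is what the forward pass stores in `self.index` and attaches to the weight via `sparse_grad_hook_fn`, with `axis` now 0 to match the transposed weight layout.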
49 changes: 39 additions & 10 deletions plsc/optimizer/adamw.py
@@ -101,13 +101,42 @@ def step(self):
paddle.float16, paddle.bfloat16
}:
master_param = state['master_param']
_, _, _, _, _, _ = _C_ops.adamw(
p, grad,
paddle.to_tensor(lr), exp_avg, exp_avg_sq, beta1_pow,
beta2_pow, master_param, p, exp_avg, exp_avg_sq, beta1_pow,
beta2_pow, master_param, 'epsilon', group['eps'],
'lazy_mode', False, 'min_row_size_to_use_multithread',
1000, 'beta1', beta1, 'beta2', beta2, "with_decay",
with_decay, 'coeff', group['weight_decay'],
'multi_precision', master_param is not None, 'lr_ratio',
1.0)

if getattr(p, 'has_sparse_grad', None):
index = getattr(p, 'index', None)
axis = getattr(p, 'axis', None)
assert axis == 0, 'Only support axis=0 now!'
assert index is not None
assert axis is not None

sub_p = paddle.gather(p, index, axis=axis)
sub_exp_avg = paddle.gather(exp_avg, index, axis=axis)
sub_exp_avg_sq = paddle.gather(
exp_avg_sq, index, axis=axis)

_, _, _, _, _, _ = _C_ops.adamw(
sub_p, grad,
paddle.to_tensor(lr), sub_exp_avg, sub_exp_avg_sq,
beta1_pow, beta2_pow, master_param, sub_p, sub_exp_avg,
sub_exp_avg_sq, beta1_pow, beta2_pow, master_param,
'epsilon', group['eps'], 'lazy_mode', False,
'min_row_size_to_use_multithread', 1000, 'beta1',
beta1, 'beta2', beta2, "with_decay", with_decay,
'coeff', group['weight_decay'], 'multi_precision',
master_param is not None, 'lr_ratio', 1.0)

p.scatter_(index, sub_p)
exp_avg.scatter_(index, sub_exp_avg)
exp_avg_sq.scatter_(index, sub_exp_avg_sq)

else:
_, _, _, _, _, _ = _C_ops.adamw(
p, grad,
paddle.to_tensor(lr), exp_avg, exp_avg_sq, beta1_pow,
beta2_pow, master_param, p, exp_avg, exp_avg_sq,
beta1_pow, beta2_pow, master_param, 'epsilon',
group['eps'], 'lazy_mode', False,
'min_row_size_to_use_multithread', 1000, 'beta1',
beta1, 'beta2', beta2, "with_decay", with_decay,
'coeff', group['weight_decay'], 'multi_precision',
master_param is not None, 'lr_ratio', 1.0)
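For parameters flagged with `has_sparse_grad`, the optimizer now gathers only the sampled rows of the parameter and its moment buffers, runs the ordinary AdamW update on that slice, and scatters the results back, leaving un-sampled rows (and their moments) untouched. A minimal sketch of the same gather/update/scatter idea with a hand-written AdamW step instead of the fused `_C_ops.adamw` kernel (bias correction via a step counter rather than `beta1_pow`/`beta2_pow`; shapes and hyperparameters are hypothetical):

```python
import paddle


def sparse_adamw_step(p, grad, exp_avg, exp_avg_sq, index, step,
                      lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=5e-4):
    """Update only the rows of `p` selected by `index` (axis 0)."""
    sub_p = paddle.gather(p, index, axis=0)
    sub_m = paddle.gather(exp_avg, index, axis=0)
    sub_v = paddle.gather(exp_avg_sq, index, axis=0)

    # Decoupled weight decay, then the usual Adam moment updates on the slice.
    sub_p = sub_p * (1.0 - lr * weight_decay)
    sub_m = beta1 * sub_m + (1.0 - beta1) * grad
    sub_v = beta2 * sub_v + (1.0 - beta2) * grad * grad
    m_hat = sub_m / (1.0 - beta1 ** step)
    v_hat = sub_v / (1.0 - beta2 ** step)
    sub_p = sub_p - lr * m_hat / (paddle.sqrt(v_hat) + eps)

    # Write the updated rows and moments back in place; other rows are untouched.
    p.scatter_(index, sub_p)
    exp_avg.scatter_(index, sub_m)
    exp_avg_sq.scatter_(index, sub_v)


# Hypothetical shapes: 6 local class centers, embedding size 4, 3 sampled rows.
p = paddle.randn([6, 4])
exp_avg = paddle.zeros_like(p)
exp_avg_sq = paddle.zeros_like(p)
index = paddle.to_tensor([0, 2, 5])
grad = paddle.randn([3, 4])      # gradient exists only for the sampled rows
sparse_adamw_step(p, grad, exp_avg, exp_avg_sq, index, step=1)
```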
63 changes: 40 additions & 23 deletions plsc/optimizer/momentum.py
@@ -102,31 +102,48 @@ def step(self):
if getattr(p, 'has_sparse_grad', None):
index = getattr(p, 'index', None)
axis = getattr(p, 'axis', None)
assert axis == 0, 'Only support axis=0 now!'
assert index is not None
assert axis is not None
_, _, _ = _C_ops.sparse_momentum(
p,
grad,
exp_avg,
index,
paddle.to_tensor(
lr, dtype='float32'),
master_param,
p,
exp_avg,
master_param,
'mu',
momentum,
'use_nesterov',
False,
'regularization_method',
'l2_decay',
'regularization_coeff',
group['weight_decay'],
'axis',
axis,
'multi_precision',
master_param is not None)
sub_p = paddle.gather(p, index, axis=axis)
sub_exp_avg = paddle.gather(exp_avg, index, axis=axis)

if group['weight_decay'] != 0.0:
grad = (grad + group['weight_decay'] * sub_p
).astype(grad.dtype)

if initialized is False:
sub_exp_avg.copy_(grad, False)
else:
sub_exp_avg.copy_(sub_exp_avg * momentum + grad, False)
sub_p.copy_(sub_p - lr * sub_exp_avg, False)

p.scatter_(index, sub_p)
exp_avg.scatter_(index, sub_exp_avg)

# _, _, _ = _C_ops.sparse_momentum(
# p,
# grad,
# exp_avg,
# index,
# paddle.to_tensor(
# lr, dtype='float32'),
# master_param,
# p,
# exp_avg,
# master_param,
# 'mu',
# momentum,
# 'use_nesterov',
# False,
# 'regularization_method',
# 'l2_decay',
# 'regularization_coeff',
# group['weight_decay'],
# 'axis',
# axis,
# 'multi_precision',
# master_param is not None)
else:
p_fp32 = p
if group['use_master_param'] and p.dtype in {
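This is the fix named in the commit title: the `_C_ops.sparse_momentum` call (left commented out above) is replaced by an explicit gather/update/scatter, with L2 weight decay folded into the gradient as in the original `l2_decay` path. A minimal sketch of that recurrence on hypothetical tensors; the `initialized` flag mirrors the optimizer-state check from the surrounding code, which is not shown in this hunk:

```python
import paddle


def sparse_momentum_step(p, grad, velocity, index,
                         lr=0.1, momentum=0.9, weight_decay=5e-4,
                         initialized=True):
    """SGD-with-momentum update restricted to the rows selected by `index` (axis 0)."""
    sub_p = paddle.gather(p, index, axis=0)
    sub_v = paddle.gather(velocity, index, axis=0)

    # L2 regularization folded into the gradient.
    grad = grad + weight_decay * sub_p

    if not initialized:
        sub_v = grad                       # first step: seed the velocity buffer
    else:
        sub_v = momentum * sub_v + grad
    sub_p = sub_p - lr * sub_v

    p.scatter_(index, sub_p)
    velocity.scatter_(index, sub_v)


# Hypothetical shapes: 6 local class centers, embedding size 4, 3 sampled rows.
p = paddle.randn([6, 4])
velocity = paddle.zeros_like(p)
index = paddle.to_tensor([1, 3, 4])
grad = paddle.randn([3, 4])
sparse_momentum_step(p, grad, velocity, index)
```

The original diff writes the slice results back with in-place `copy_` before scattering; the sketch simply recomputes the slices, which yields the same parameter values.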
59 changes: 31 additions & 28 deletions plsc/utils/io.py
@@ -159,34 +159,37 @@ def save_checkpoint(net,

keep_prefixs = ['best', 'latest']

if all(p not in prefix for p in keep_prefixs) and max_num_checkpoint >= 0:
pdstates_list = glob.glob(os.path.join(model_dir, '*.pdstates'))

timestamp_to_path = {}
for path in pdstates_list:
if any(p in path for p in keep_prefixs):
continue
metric_dict = paddle.load(path)
timestamp_to_path[metric_dict['timestamp']] = path[:-9]

# sort by ascend
timestamps = list(timestamp_to_path.keys())
timestamps.sort()

if max_num_checkpoint > 0:
to_remove = timestamps[:-max_num_checkpoint]
else:
to_remove = timestamps
for timestamp in to_remove:
model_prefix = timestamp_to_path[timestamp]
for ext in ['.pdparams', '.pdopt', '.pdlr', '.pdstates']:
path = model_prefix + ext
_remove_if_exist(path)

if ext in ['.pdparams', '.pdopt']:
for rank_id in range(world_size):
path = model_prefix + "_rank{}".format(rank_id) + ext
_remove_if_exist(path)
if local_rank == 0:
if all(p not in prefix
for p in keep_prefixs) and max_num_checkpoint >= 0:
pdstates_list = glob.glob(os.path.join(model_dir, '*.pdstates'))

timestamp_to_path = {}
for path in pdstates_list:
if any(p in path for p in keep_prefixs):
continue
metric_dict = paddle.load(path)
timestamp_to_path[metric_dict['timestamp']] = path[:-9]

# sort by ascend
timestamps = list(timestamp_to_path.keys())
timestamps.sort()

if max_num_checkpoint > 0:
to_remove = timestamps[:-max_num_checkpoint]
else:
to_remove = timestamps
for timestamp in to_remove:
model_prefix = timestamp_to_path[timestamp]
for ext in ['.pdparams', '.pdopt', '.pdlr', '.pdstates']:
path = model_prefix + ext
_remove_if_exist(path)

if ext in ['.pdparams', '.pdopt']:
for rank_id in range(world_size):
path = model_prefix + "_rank{}".format(
rank_id) + ext
_remove_if_exist(path)


def export(config, net, path):
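The checkpoint-pruning block is now executed only when `local_rank == 0`, so a single process per node removes stale files instead of every rank racing to delete the same paths. A minimal sketch of the guard with a hypothetical `prune_old_checkpoints` helper; it uses the global rank and a filename sort as stand-ins for the `local_rank` check and timestamp ordering in `save_checkpoint` above:

```python
import glob
import os

import paddle.distributed as dist


def prune_old_checkpoints(model_dir, max_num_checkpoint=3):
    """Illustrative only: delete all but the newest checkpoints from one process."""
    if dist.get_rank() != 0:       # the original guards on local_rank instead
        return
    states = sorted(glob.glob(os.path.join(model_dir, "*.pdstates")))
    for path in states[:-max_num_checkpoint]:
        if os.path.exists(path):
            os.remove(path)
```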
3 changes: 2 additions & 1 deletion task/recognition/face/README.md
@@ -106,7 +106,8 @@ python onnx_ijbc.py \
| Datasets | Backbone | Config | Devices | PFC | IJB-C(1E-4) | IJB-C(1E-5) | checkpoint | log |
| :--------: | :---------------------- | ------------------------------------------------------------ | --------- | ---- | ----------- | :---------- | :----------------------------------------------------------- | ------------------------------------------------------------ |
| MS1MV3 | Res50 | [config](./configs/IResNet50_MS1MV3_ArcFace_pfc10_1n8c_dp_mp_fp16o1.yaml) | N1C8*A100 | 1.0 | 96.52 | 94.60 | [download](https://plsc.bj.bcebos.com/models/face/v2.4/IResNet50_MS1MV3_ArcFace_pfc10_1n8c_dp_mp_fp16o1.pdparams) | [download](https://plsc.bj.bcebos.com/models/face/v2.4/IResNet50_MS1MV3_ArcFace_pfc10_1n8c_dp_mp_fp16o1.log) |
| WebFace42M | FaceViT_tiny_patch9_112 | [config](./configs/FaceViT_tiny_patch9_112_WebFace42M_ArcFace_pfc10_droppath005_mask0_1n8c_dp_mp_fp16o1.yaml) | N1C8*A100 | 1.0 | 97.24 | 95.79 | [download](https://plsc.bj.bcebos.com/models/face/v2.4/FaceViT_tiny_patch9_112_WebFace42M_ArcFace_pfc10_droppath005_mask0_1n8c_dp_mp_fp16o1.pdparams) | [download](https://plsc.bj.bcebos.com/models/face/v2.4/FaceViT_tiny_patch9_112_WebFace42M_ArcFace_pfc10_droppath005_mask0_1n8c_dp_mp_fp16o1.log) |
| WebFace42M | FaceViT_tiny_patch9_112 | [config](./configs/FaceViT_tiny_patch9_112_WebFace42M_CosFace_pfc10_droppath005_mask0_1n8c_dp_mp_fp16o1.yaml) | N1C8*A100 | 1.0 | 97.24 | 95.79 | [download](https://plsc.bj.bcebos.com/models/face/v2.4/FaceViT_tiny_patch9_112_WebFace42M_CosFace_pfc10_droppath005_mask0_1n8c_dp_mp_fp16o1.pdparams) | [download](https://plsc.bj.bcebos.com/models/face/v2.4/FaceViT_tiny_patch9_112_WebFace42M_CosFace_pfc10_droppath005_mask0_1n8c_dp_mp_fp16o1.log) |
| WebFace42M | FaceViT_tiny_patch9_112 | [config](./configs/FaceViT_tiny_patch9_112_WebFace42M_CosFace_pfc02_droppath005_mask0_1n8c_dp_mp_fp16o1.yaml) | N1C8*A100 | 0.2 | 97.28 | 95.79 | [download](https://plsc.bj.bcebos.com/models/face/v2.4/FaceViT_tiny_patch9_112_WebFace42M_CosFace_pfc02_droppath005_mask0_1n8c_dp_mp_fp16o1.pdparams) | [download](https://plsc.bj.bcebos.com/models/face/v2.4/FaceViT_tiny_patch9_112_WebFace42M_CosFace_pfc02_droppath005_mask0_1n8c_dp_mp_fp16o1.log) |

## Citations
