
Commit

add clip_grad
songyuanwei committed Feb 16, 2023
1 parent 4b578f1 commit 8e7edf0
Showing 8 changed files with 32 additions and 83 deletions.
3 changes: 3 additions & 0 deletions config.py
@@ -140,6 +140,9 @@ def create_parser():
group.add_argument('--use_ema', type=str2bool, nargs='?', const=True, default=False,
help='training with ema (default=False)')
group.add_argument('--ema_decay', type=float, default=0.9999, help='ema decay')
group.add_argument('--use_clip_grad', type=str2bool, nargs='?', const=True, default=False,
help='Whether use clip grad (default=False)')
group.add_argument('--clip_value', type=float, default=15.0, help='clip value')


# Optimize parameters
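Note: the two new flags use argparse's `nargs='?', const=True` idiom, so passing the bare flag enables clipping while an explicit value still parses. A minimal, self-contained sketch of that behavior (the `str2bool` converter below is a stand-in for the one config.py is assumed to provide):

```python
import argparse

def str2bool(value):
    """Stand-in for the str2bool converter assumed to exist in config.py."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--use_clip_grad", type=str2bool, nargs="?", const=True, default=False)
parser.add_argument("--clip_value", type=float, default=15.0)

print(parser.parse_args([]))                                  # clipping off by default
print(parser.parse_args(["--use_clip_grad"]))                 # bare flag -> True
print(parser.parse_args(["--use_clip_grad", "false"]))        # explicit value still parses
print(parser.parse_args(["--use_clip_grad", "--clip_value", "5.0"]))
```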
4 changes: 3 additions & 1 deletion mindcv/__init__.py
@@ -1,16 +1,18 @@
"""mindcv init"""
from .data import *
from .engine import *
from .loss import *
from .models import *
from .optim import *
from .scheduler import *
from .utils import *
from .version import __version__

from . import data, loss, models, optim, scheduler
from . import data, engine, loss, models, optim, scheduler

__all__ = []
__all__.extend(data.__all__)
__all__.extend(engine.__all__)
__all__.extend(loss.__all__)
__all__.extend(models.__all__)
__all__.extend(optim.__all__)
6 changes: 6 additions & 0 deletions mindcv/engine/__init__.py
@@ -0,0 +1,6 @@
"""Engine Tools"""

from .callbacks import *
from .trainer import *

__all__ = []
70 changes: 1 addition & 69 deletions mindcv/utils/callbacks.py → mindcv/engine/callbacks.py
@@ -1,7 +1,6 @@
"""Callbacks for mindspore.Model"""
import os
from time import time
#import stat
import numpy as np

import mindspore as ms
@@ -10,8 +9,7 @@
from mindspore.train.callback import Callback
from mindspore.train._utils import _make_directory

from .checkpoint_manager import CheckpointManager
from .reduce_manager import AllReduceSum
from mindcv.utils import CheckpointManager, AllReduceSum

class StateMonitor(Callback):
"""
@@ -291,72 +289,6 @@ def remove_oldest_ckpoint_file(self):
ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
self.remove_ckpoint_file(ckpoint_files[0])

class LossAccSummary(Callback):
''' A callback for recording loss and acc during training '''
def __init__(self,
summary_dir,
model,
dataset_val,
val_interval=1,
val_start_epoch=1,
metric_name="accuracy"):
super().__init__()
self._summary_dir = _make_directory(summary_dir, "summary_dir")
self.model = model
self.dataset_val = dataset_val
self.val_start_epoch = val_start_epoch
self.metric_name = metric_name
self.val_interval = val_interval

def __enter__(self):
self.summary_record = SummaryRecord(self._summary_dir)
return self

def __exit__(self, *exc_args):
self.summary_record.close()

def on_train_epoch_end(self, run_context):
cb_params = run_context.original_args()
loss = self._get_loss(cb_params)
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.val_start_epoch and (cur_epoch - self.val_start_epoch) % self.val_interval == 0:
val_acc = self.model.eval(self.dataset_val)[self.metric_name]
if not isinstance(val_acc, Tensor):
val_acc = Tensor(val_acc)
self.summary_record.add_value('scalar', 'test_dataset_' + self.metric_name, val_acc)

self.summary_record.add_value('scalar', 'loss/auto', loss)
self.summary_record.record(cb_params.cur_step_num)

def _get_loss(self, cb_params):
"""
Get loss from the network output.
Args:
cb_params (_InternalCallbackParam): Callback parameters.
Returns:
Union[Tensor, None], if parse loss success, will return a Tensor value(shape is [1]), else return None.
"""
output = cb_params.net_outputs
if output is None:
logger.warning("Can not find any output by this network, so SummaryCollector will not collect loss.")
return None

if isinstance(output, (int, float, Tensor)):
loss = output
elif isinstance(output, (list, tuple)) and output:
# If the output is a list, since the default network returns loss first,
# we assume that the first one is loss.
loss = output[0]
else:
logger.warning("The output type could not be identified, expect type is one of "
"[int, float, Tensor, list, tuple], so no loss was recorded in SummaryCollector.")
return None

if not isinstance(loss, Tensor):
loss = Tensor(loss)

loss = Tensor(np.mean(loss.asnumpy()))
return loss

class ValCallback(Callback):
def __init__(self, log_step_interval=100):
13 changes: 9 additions & 4 deletions mindcv/utils/ema.py → mindcv/engine/trainer.py
@@ -1,7 +1,7 @@
"""ema define"""

import mindspore as ms
from mindspore import nn, Tensor, Parameter, ParameterTuple
from mindspore import nn, Tensor, Parameter, ParameterTuple, ops
from mindspore.common import RowTensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@@ -29,14 +29,17 @@ def tensor_grad_scale_row_tensor(scale, grad):
grad.dense_shape)


class TrainOneStepWithEMA(nn.TrainOneStepWithLossScaleCell):
class TrainOneStep(nn.TrainOneStepWithLossScaleCell):
"""TrainOneStepWithEMA"""

def __init__(self, network, optimizer, scale_sense=1.0, use_ema=False, ema_decay=0.9999, updates=0):
super(TrainOneStepWithEMA, self).__init__(network, optimizer, scale_sense)
def __init__(self, network, optimizer, scale_sense=1.0, use_ema=False, ema_decay=0.9999, updates=0,
use_clip_grad=False, clip_value=15):
super(TrainOneStep, self).__init__(network, optimizer, scale_sense)
self.use_ema = use_ema
self.ema_decay = ema_decay
self.updates = Parameter(Tensor(updates, ms.float32))
self.use_clip_grad = use_clip_grad
self.clip_value = clip_value
if self.use_ema:
self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
self.ema_weight = self.weights_all.clone("ema", init='same')
@@ -61,6 +64,8 @@ def construct(self, *inputs):
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
if self.use_clip_grad:
grads = ops.clip_by_global_norm(grads, clip_norm=self.clip_value)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))
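Note: the clipping added in `construct` relies on MindSpore's global-norm clipping, which rescales the whole gradient tuple so its combined L2 norm does not exceed `clip_value`. A small sketch of that behavior with made-up gradient values:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

# Two toy "gradients" whose combined global norm (sqrt(84) ~ 9.17) exceeds the threshold.
grads = (Tensor(np.full((2, 2), 3.0), ms.float32),
         Tensor(np.full((3,), 4.0), ms.float32))

clip_value = 5.0
clipped = ops.clip_by_global_norm(grads, clip_norm=clip_value)

# Each tensor is scaled by clip_value / max(global_norm, clip_value),
# so the clipped global norm is at most clip_value.
global_norm = np.sqrt(sum(np.sum(g.asnumpy() ** 2) for g in clipped))
print(global_norm)  # ~5.0
```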
4 changes: 2 additions & 2 deletions mindcv/utils/__init__.py
@@ -1,8 +1,8 @@
"""Utility Tools"""
from .path import *
from .download import *
from .callbacks import *
from .checkpoint_manager import *
from .reduce_manager import *
from .ema import *
from .amp import *
from .utils import *
from .random import *
11 changes: 6 additions & 5 deletions train.py
@@ -12,11 +12,11 @@

from mindcv.models import create_model
from mindcv.data import create_dataset, create_transforms, create_loader
from mindcv.engine import TrainOneStep, StateMonitor
from mindcv.loss import create_loss
from mindcv.optim import create_optimizer
from mindcv.scheduler import create_scheduler
from mindcv.utils import StateMonitor, AllReduceSum, TrainOneStepWithEMA
from mindcv.utils.random import set_seed
from mindcv.utils import AllReduceSum, set_seed
from config import parse_args


@@ -216,12 +216,13 @@ def train(args):
eval_metrics = {'Top_1_Accuracy': nn.Top1CategoricalAccuracy()}

# init model
if args.use_ema:
if args.use_ema or args.use_clip_grad:
net_with_loss = nn.WithLossCell(network, loss)
loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale_value=args.loss_scale)
ms.amp.auto_mixed_precision(net_with_loss, amp_level=args.amp_level)
net_with_loss = TrainOneStepWithEMA(net_with_loss, optimizer, scale_sense=loss_scale_manager,
use_ema=args.use_ema, ema_decay=args.ema_decay)
net_with_loss = TrainOneStep(net_with_loss, optimizer, scale_sense=loss_scale_manager,
use_ema=args.use_ema, ema_decay=args.ema_decay,
use_clip_grad=args.use_clip_grad, clip_value=args.clip_value)
eval_network = nn.WithEvalCell(network, loss, args.amp_level in ["O2", "O3", "auto"])
model = Model(net_with_loss, eval_network=eval_network, metrics=eval_metrics, eval_indexes=[0, 1, 2])
else:
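Note: putting the pieces together, a minimal, self-contained sketch of how the new flags reach the wrapper; the toy network, loss, and optimizer stand in for create_model/create_loss/create_optimizer, and `mindcv.engine.TrainOneStep` is the class introduced by this commit:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from mindcv.engine import TrainOneStep  # module layout introduced by this commit

# Toy components, illustrative only.
network = nn.Dense(4, 2)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_with_loss = nn.WithLossCell(network, loss)
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale_value=1024)

train_step = TrainOneStep(net_with_loss, optimizer, scale_sense=loss_scale_manager,
                          use_clip_grad=True, clip_value=15.0)

data = Tensor(np.random.randn(8, 4), ms.float32)
label = Tensor(np.random.randint(0, 2, (8,)), ms.int32)
out = train_step(data, label)  # one step with global-norm gradient clipping applied
```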
4 changes: 2 additions & 2 deletions validate.py
@@ -6,8 +6,8 @@
from mindcv.data import create_dataset, create_transforms, create_loader
from mindcv.loss import create_loss
from config import parse_args
from mindcv.utils.utils import check_batch_size
from mindcv.utils.callbacks import ValCallback
from mindcv.utils import check_batch_size
from mindcv.engine import ValCallback


def validate(args):
