Skip to content

Commit

Permalink
add clip_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
songyuanwei committed Feb 28, 2023
1 parent d9e6ca0 commit 8aa9500
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 100 deletions.
3 changes: 3 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ def create_parser():
group.add_argument('--use_ema', type=str2bool, nargs='?', const=True, default=False,
help='training with ema (default=False)')
group.add_argument('--ema_decay', type=float, default=0.9999, help='ema decay')
group.add_argument('--use_clip_grad', type=str2bool, nargs='?', const=True, default=False,
help='Whether use clip grad (default=False)')
group.add_argument('--clip_value', type=float, default=15.0, help='clip value')

# Optimize parameters
group = parser.add_argument_group('Optimizer parameters')
Expand Down
4 changes: 3 additions & 1 deletion mindcv/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""mindcv init"""
from . import data, loss, models, optim, scheduler
from . import data, engine, loss, models, optim, scheduler, utils
from .data import *
from .engine import *
from .loss import *
from .models import *
from .optim import *
Expand All @@ -10,6 +11,7 @@

__all__ = []
__all__.extend(data.__all__)
__all__.extend(engine.__all__)
__all__.extend(loss.__all__)
__all__.extend(models.__all__)
__all__.extend(optim.__all__)
Expand Down
6 changes: 6 additions & 0 deletions mindcv/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Engine Tools"""

from .callbacks import *
from .train_step import *

__all__ = []
77 changes: 1 addition & 76 deletions mindcv/utils/callbacks.py → mindcv/engine/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
from mindspore import ParameterTuple, SummaryRecord, Tensor, load_param_into_net
from mindspore import log as logger
from mindspore import ops, save_checkpoint
from mindspore.train._utils import _make_directory
from mindspore.train.callback import Callback

from .checkpoint_manager import CheckpointManager
from .reduce_manager import AllReduceSum
from mindcv.utils import AllReduceSum, CheckpointManager


class StateMonitor(Callback):
Expand Down Expand Up @@ -305,79 +303,6 @@ def remove_oldest_ckpoint_file(self):
self.remove_ckpoint_file(ckpoint_files[0])


class LossAccSummary(Callback):
    """A callback for recording loss and acc during training.

    Writes the latest training loss to a MindSpore summary file at the end of
    every epoch; once ``val_start_epoch`` is reached, also evaluates the model
    on ``dataset_val`` every ``val_interval`` epochs and records the resulting
    metric. Intended to be used as a context manager so the underlying
    ``SummaryRecord`` is opened and closed deterministically.
    """

    def __init__(
        self,
        summary_dir,
        model,
        dataset_val,
        val_interval=1,
        val_start_epoch=1,
        metric_name="accuracy",
    ):
        super().__init__()
        # Resolve/create the summary directory up front so failures surface early.
        self._summary_dir = _make_directory(summary_dir, "summary_dir")
        self.model = model
        self.dataset_val = dataset_val
        self.val_interval = val_interval
        self.val_start_epoch = val_start_epoch
        self.metric_name = metric_name

    def __enter__(self):
        # The writer is opened lazily, only when entering the context.
        self.summary_record = SummaryRecord(self._summary_dir)
        return self

    def __exit__(self, *exc_args):
        self.summary_record.close()

    def on_train_epoch_end(self, run_context):
        cb_params = run_context.original_args()
        loss = self._get_loss(cb_params)
        cur_epoch = cb_params.cur_epoch_num

        epochs_since_start = cur_epoch - self.val_start_epoch
        should_validate = cur_epoch >= self.val_start_epoch and epochs_since_start % self.val_interval == 0
        if should_validate:
            val_acc = self.model.eval(self.dataset_val)[self.metric_name]
            if not isinstance(val_acc, Tensor):
                val_acc = Tensor(val_acc)
            self.summary_record.add_value("scalar", "test_dataset_" + self.metric_name, val_acc)

        self.summary_record.add_value("scalar", "loss/auto", loss)
        self.summary_record.record(cb_params.cur_step_num)

    def _get_loss(self, cb_params):
        """
        Get loss from the network output.
        Args:
            cb_params (_InternalCallbackParam): Callback parameters.
        Returns:
            Union[Tensor, None], if parse loss success, will return a Tensor value(shape is [1]), else return None.
        """
        net_output = cb_params.net_outputs
        if net_output is None:
            logger.warning("Can not find any output by this network, so SummaryCollector will not collect loss.")
            return None

        if isinstance(net_output, (int, float, Tensor)):
            loss_value = net_output
        elif isinstance(net_output, (list, tuple)) and net_output:
            # Networks conventionally return the loss as the first element
            # when the output is a sequence.
            loss_value = net_output[0]
        else:
            logger.warning(
                "The output type could not be identified, expect type is one of "
                "[int, float, Tensor, list, tuple], so no loss was recorded in SummaryCollector."
            )
            return None

        if not isinstance(loss_value, Tensor):
            loss_value = Tensor(loss_value)

        # Collapse a possibly multi-element loss to its scalar mean.
        return Tensor(np.mean(loss_value.asnumpy()))


class ValCallback(Callback):
def __init__(self, log_step_interval=100):
super().__init__()
Expand Down
27 changes: 20 additions & 7 deletions mindcv/utils/ema.py → mindcv/engine/train_step.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""ema define"""

"""Ema define"""
import mindspore as ms
from mindspore import Parameter, Tensor, nn
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common import RowTensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
Expand Down Expand Up @@ -32,14 +31,26 @@ def tensor_grad_scale_row_tensor(scale, grad):
)


class TrainOneStepWithEMA(nn.TrainOneStepWithLossScaleCell):
"""TrainOneStepWithEMA"""
class TrainStep(nn.TrainOneStepWithLossScaleCell):
"""TrainStep with ema and clip grad."""

def __init__(self, network, optimizer, scale_sense=1.0, use_ema=False, ema_decay=0.9999, updates=0):
super(TrainOneStepWithEMA, self).__init__(network, optimizer, scale_sense)
def __init__(
self,
network,
optimizer,
scale_sense=1.0,
use_ema=False,
ema_decay=0.9999,
updates=0,
use_clip_grad=False,
clip_value=15.0,
):
super(TrainStep, self).__init__(network, optimizer, scale_sense)
self.use_ema = use_ema
self.ema_decay = ema_decay
self.updates = Parameter(Tensor(updates, ms.float32))
self.use_clip_grad = use_clip_grad
self.clip_value = clip_value
if self.use_ema:
self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
self.ema_weight = self.weights_all.clone("ema", init="same")
Expand All @@ -62,6 +73,8 @@ def construct(self, *inputs):
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
if self.use_clip_grad:
grads = ops.clip_by_global_norm(grads, clip_norm=self.clip_value)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))
Expand Down
4 changes: 2 additions & 2 deletions mindcv/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Utility Tools"""
from .amp import *
from .callbacks import *
from .checkpoint_manager import *
from .download import *
from .ema import *
from .path import *
from .random import *
from .reduce_manager import *
from .utils import *
2 changes: 1 addition & 1 deletion tests/tasks/test_train_val_imagenet_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@pytest.mark.parametrize("mode", ["GRAPH", "PYNATIVE_FUNC"])
@pytest.mark.parametrize("val_while_train", [True, False])
def test_train(mode, val_while_train, model="resnet18", opt="adamw", scheduler="polynomial"):
def test_train(mode, val_while_train, model="resnet18"):
"""train on a imagenet subset dataset"""
# prepare data
data_dir = "data/Canidae"
Expand Down
17 changes: 12 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from mindspore.communication import get_group_size, get_rank, init

from mindcv.data import create_dataset, create_loader, create_transforms
from mindcv.engine import StateMonitor, TrainStep
from mindcv.loss import create_loss
from mindcv.models import create_model
from mindcv.optim import create_optimizer
from mindcv.scheduler import create_scheduler
from mindcv.utils import AllReduceSum, StateMonitor, TrainOneStepWithEMA
from mindcv.utils.random import set_seed
from mindcv.utils import AllReduceSum, set_seed

from config import parse_args # isort: skip

Expand Down Expand Up @@ -227,13 +227,20 @@ def train(args):
eval_metrics = {"Top_1_Accuracy": nn.Top1CategoricalAccuracy()}

# init model
if args.use_ema:
if args.use_ema or args.use_clip_grad:
net_with_loss = nn.WithLossCell(network, loss)
loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale_value=args.loss_scale)
ms.amp.auto_mixed_precision(net_with_loss, amp_level=args.amp_level)
net_with_loss = TrainOneStepWithEMA(
net_with_loss, optimizer, scale_sense=loss_scale_manager, use_ema=args.use_ema, ema_decay=args.ema_decay
net_with_loss = TrainStep(
net_with_loss,
optimizer,
scale_sense=loss_scale_manager,
use_ema=args.use_ema,
ema_decay=args.ema_decay,
use_clip_grad=args.use_clip_grad,
clip_value=args.clip_value,
)

eval_network = nn.WithEvalCell(network, loss, args.amp_level in ["O2", "O3", "auto"])
model = Model(net_with_loss, eval_network=eval_network, metrics=eval_metrics, eval_indexes=[0, 1, 2])
else:
Expand Down
8 changes: 2 additions & 6 deletions train_with_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import mindspore as ms
from mindspore import SummaryRecord, Tensor, nn, ops
from mindspore.amp import StaticLossScaler
from mindspore.communication import get_group_size, get_rank, init
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean

Expand Down Expand Up @@ -199,14 +200,9 @@ def train(args):
checkpoint_path=opt_ckpt_path,
)

from mindspore.amp import DynamicLossScaler, StaticLossScaler

# set loss scale for mixed precision training
if args.amp_level != "O0":
if args.dynamic_loss_scale:
loss_scaler = DynamicLossScaler(args.loss_scale, 2, 1000)
else:
loss_scaler = StaticLossScaler(args.loss_scale)
loss_scaler = StaticLossScaler(args.loss_scale)
else:
loss_scaler = NoLossScaler()

Expand Down
4 changes: 2 additions & 2 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from mindspore import Model

from mindcv.data import create_dataset, create_loader, create_transforms
from mindcv.engine import ValCallback
from mindcv.loss import create_loss
from mindcv.models import create_model
from mindcv.utils.callbacks import ValCallback
from mindcv.utils.utils import check_batch_size
from mindcv.utils import check_batch_size

from config import parse_args # isort: skip

Expand Down

0 comments on commit 8aa9500

Please sign in to comment.