add clip_grad #469

Merged 1 commit on Feb 28, 2023
3 changes: 3 additions & 0 deletions config.py
@@ -156,6 +156,9 @@ def create_parser():
group.add_argument('--use_ema', type=str2bool, nargs='?', const=True, default=False,
help='training with ema (default=False)')
group.add_argument('--ema_decay', type=float, default=0.9999, help='ema decay')
group.add_argument('--use_clip_grad', type=str2bool, nargs='?', const=True, default=False,
help='Whether to use gradient clipping (default=False)')
group.add_argument('--clip_value', type=float, default=15.0, help='Gradient clipping value')

# Optimize parameters
group = parser.add_argument_group('Optimizer parameters')
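The two new flags follow the file's existing str2bool pattern: with nargs='?' and const=True, a bare --use_clip_grad enables clipping, while an explicit value such as --use_clip_grad false disables it. A minimal sketch of such a helper (the diff references a str2bool already defined for config.py; this version is an assumption for illustration):

import argparse

def str2bool(v):
    """Parse flexible boolean CLI values such as 'true', 'yes', or '0'."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")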
4 changes: 3 additions & 1 deletion mindcv/__init__.py
@@ -1,6 +1,7 @@
"""mindcv init"""
from . import data, loss, models, optim, scheduler
from . import data, engine, loss, models, optim, scheduler, utils
from .data import *
from .engine import *
from .loss import *
from .models import *
from .optim import *
@@ -10,6 +11,7 @@

__all__ = []
__all__.extend(data.__all__)
__all__.extend(engine.__all__)
__all__.extend(loss.__all__)
__all__.extend(models.__all__)
__all__.extend(optim.__all__)
6 changes: 6 additions & 0 deletions mindcv/engine/__init__.py
@@ -0,0 +1,6 @@
"""Engine Tools"""

from .callbacks import *
from .train_step import *

__all__ = []
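With the new package in place, the training and validation entry points can import engine components directly (these names match the imports used in train.py and validate.py below):

from mindcv.engine import StateMonitor, TrainStep, ValCallback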
77 changes: 1 addition & 76 deletions mindcv/utils/callbacks.py → mindcv/engine/callbacks.py
@@ -9,11 +9,9 @@
from mindspore import ParameterTuple, SummaryRecord, Tensor, load_param_into_net
from mindspore import log as logger
from mindspore import ops, save_checkpoint
from mindspore.train._utils import _make_directory
from mindspore.train.callback import Callback

from .checkpoint_manager import CheckpointManager
from .reduce_manager import AllReduceSum
from mindcv.utils import AllReduceSum, CheckpointManager


class StateMonitor(Callback):
@@ -305,79 +303,6 @@ def remove_oldest_ckpoint_file(self):
self.remove_ckpoint_file(ckpoint_files[0])


class LossAccSummary(Callback):
"""A callback for recording loss and acc during training"""

def __init__(
self,
summary_dir,
model,
dataset_val,
val_interval=1,
val_start_epoch=1,
metric_name="accuracy",
):
super().__init__()
self._summary_dir = _make_directory(summary_dir, "summary_dir")
self.model = model
self.dataset_val = dataset_val
self.val_start_epoch = val_start_epoch
self.metric_name = metric_name
self.val_interval = val_interval

def __enter__(self):
self.summary_record = SummaryRecord(self._summary_dir)
return self

def __exit__(self, *exc_args):
self.summary_record.close()

def on_train_epoch_end(self, run_context):
cb_params = run_context.original_args()
loss = self._get_loss(cb_params)
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.val_start_epoch and (cur_epoch - self.val_start_epoch) % self.val_interval == 0:
val_acc = self.model.eval(self.dataset_val)[self.metric_name]
if not isinstance(val_acc, Tensor):
val_acc = Tensor(val_acc)
self.summary_record.add_value("scalar", "test_dataset_" + self.metric_name, val_acc)

self.summary_record.add_value("scalar", "loss/auto", loss)
self.summary_record.record(cb_params.cur_step_num)

def _get_loss(self, cb_params):
"""
Get loss from the network output.
Args:
cb_params (_InternalCallbackParam): Callback parameters.
Returns:
Union[Tensor, None], if parse loss success, will return a Tensor value(shape is [1]), else return None.
"""
output = cb_params.net_outputs
if output is None:
logger.warning("Can not find any output by this network, so SummaryCollector will not collect loss.")
return None

if isinstance(output, (int, float, Tensor)):
loss = output
elif isinstance(output, (list, tuple)) and output:
# If the output is a list, since the default network returns loss first,
# we assume that the first one is loss.
loss = output[0]
else:
logger.warning(
"The output type could not be identified, expect type is one of "
"[int, float, Tensor, list, tuple], so no loss was recorded in SummaryCollector."
)
return None

if not isinstance(loss, Tensor):
loss = Tensor(loss)

loss = Tensor(np.mean(loss.asnumpy()))
return loss


class ValCallback(Callback):
def __init__(self, log_step_interval=100):
super().__init__()
27 changes: 20 additions & 7 deletions mindcv/utils/ema.py → mindcv/engine/train_step.py
@@ -1,7 +1,6 @@
"""ema define"""

"""Ema define"""
import mindspore as ms
from mindspore import Parameter, Tensor, nn
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common import RowTensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@@ -32,14 +31,26 @@ def tensor_grad_scale_row_tensor(scale, grad):
)


class TrainOneStepWithEMA(nn.TrainOneStepWithLossScaleCell):
"""TrainOneStepWithEMA"""
class TrainStep(nn.TrainOneStepWithLossScaleCell):
"""TrainStep with ema and clip grad."""

def __init__(self, network, optimizer, scale_sense=1.0, use_ema=False, ema_decay=0.9999, updates=0):
super(TrainOneStepWithEMA, self).__init__(network, optimizer, scale_sense)
def __init__(
self,
network,
optimizer,
scale_sense=1.0,
use_ema=False,
ema_decay=0.9999,
updates=0,
use_clip_grad=False,
clip_value=15.0,
):
super(TrainStep, self).__init__(network, optimizer, scale_sense)
self.use_ema = use_ema
self.ema_decay = ema_decay
self.updates = Parameter(Tensor(updates, ms.float32))
self.use_clip_grad = use_clip_grad
self.clip_value = clip_value
if self.use_ema:
self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
self.ema_weight = self.weights_all.clone("ema", init="same")
@@ -62,6 +73,8 @@ def construct(self, *inputs):
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
if self.use_clip_grad:
grads = ops.clip_by_global_norm(grads, clip_norm=self.clip_value)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))
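The clipping step relies on ops.clip_by_global_norm, which rescales every gradient by clip_value / global_norm whenever the global L2 norm across all gradient tensors exceeds clip_value, and leaves them untouched otherwise. Note the ordering in construct: gradients are unscaled from the loss scale first, then clipped, and only then passed to the distributed grad_reducer, so the norm is computed on each worker's local gradients. A NumPy sketch of the clipping semantics, for intuition only (the class itself calls MindSpore's built-in op):

import numpy as np

def clip_by_global_norm(grads, clip_norm=15.0):
    """Rescale a list of gradient arrays so their joint L2 norm is at most clip_norm."""
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    if global_norm > clip_norm:
        grads = [g * (clip_norm / global_norm) for g in grads]
    return grads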
4 changes: 2 additions & 2 deletions mindcv/utils/__init__.py
@@ -1,8 +1,8 @@
"""Utility Tools"""
from .amp import *
from .callbacks import *
from .checkpoint_manager import *
from .download import *
from .ema import *
from .path import *
from .random import *
from .reduce_manager import *
from .utils import *
2 changes: 1 addition & 1 deletion tests/tasks/test_train_val_imagenet_subset.py
@@ -17,7 +17,7 @@

@pytest.mark.parametrize("mode", ["GRAPH", "PYNATIVE_FUNC"])
@pytest.mark.parametrize("val_while_train", [True, False])
def test_train(mode, val_while_train, model="resnet18", opt="adamw", scheduler="polynomial"):
def test_train(mode, val_while_train, model="resnet18"):
"""train on a imagenet subset dataset"""
# prepare data
data_dir = "data/Canidae"
17 changes: 12 additions & 5 deletions train.py
@@ -8,12 +8,12 @@
from mindspore.communication import get_group_size, get_rank, init

from mindcv.data import create_dataset, create_loader, create_transforms
from mindcv.engine import StateMonitor, TrainStep
from mindcv.loss import create_loss
from mindcv.models import create_model
from mindcv.optim import create_optimizer
from mindcv.scheduler import create_scheduler
from mindcv.utils import AllReduceSum, StateMonitor, TrainOneStepWithEMA
from mindcv.utils.random import set_seed
from mindcv.utils import AllReduceSum, set_seed

from config import parse_args # isort: skip

@@ -227,13 +227,20 @@ def train(args):
eval_metrics = {"Top_1_Accuracy": nn.Top1CategoricalAccuracy()}

# init model
if args.use_ema:
if args.use_ema or args.use_clip_grad:
net_with_loss = nn.WithLossCell(network, loss)
loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale_value=args.loss_scale)
ms.amp.auto_mixed_precision(net_with_loss, amp_level=args.amp_level)
net_with_loss = TrainOneStepWithEMA(
net_with_loss, optimizer, scale_sense=loss_scale_manager, use_ema=args.use_ema, ema_decay=args.ema_decay
net_with_loss = TrainStep(
net_with_loss,
optimizer,
scale_sense=loss_scale_manager,
use_ema=args.use_ema,
ema_decay=args.ema_decay,
use_clip_grad=args.use_clip_grad,
clip_value=args.clip_value,
)

eval_network = nn.WithEvalCell(network, loss, args.amp_level in ["O2", "O3", "auto"])
model = Model(net_with_loss, eval_network=eval_network, metrics=eval_metrics, eval_indexes=[0, 1, 2])
else:
8 changes: 2 additions & 6 deletions train_with_func.py
@@ -7,6 +7,7 @@

import mindspore as ms
from mindspore import SummaryRecord, Tensor, nn, ops
from mindspore.amp import StaticLossScaler
from mindspore.communication import get_group_size, get_rank, init
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean

@@ -199,14 +200,9 @@ def train(args):
checkpoint_path=opt_ckpt_path,
)

from mindspore.amp import DynamicLossScaler, StaticLossScaler

# set loss scale for mixed precision training
if args.amp_level != "O0":
if args.dynamic_loss_scale:
loss_scaler = DynamicLossScaler(args.loss_scale, 2, 1000)
else:
loss_scaler = StaticLossScaler(args.loss_scale)
loss_scaler = StaticLossScaler(args.loss_scale)
else:
loss_scaler = NoLossScaler()

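With the DynamicLossScaler branch removed, functional training always uses a fixed scale whenever amp_level != "O0". StaticLossScaler exposes scale and unscale for wrapping the forward loss and the raw gradients; a minimal sketch of how it typically fits a functional train step (net_with_loss, optimizer, and params are placeholders, not names from this file):

import mindspore as ms
from mindspore.amp import StaticLossScaler

loss_scaler = StaticLossScaler(1024)  # fixed scale value

def train_step(data, label):
    def forward_fn(data, label):
        # scale the loss so small fp16 gradients do not underflow
        return loss_scaler.scale(net_with_loss(data, label))

    grad_fn = ms.value_and_grad(forward_fn, None, params)
    loss, grads = grad_fn(data, label)
    loss = loss_scaler.unscale(loss)    # recover the true loss value
    grads = loss_scaler.unscale(grads)  # unscale gradients before the update
    optimizer(grads)
    return loss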
4 changes: 2 additions & 2 deletions validate.py
@@ -3,10 +3,10 @@
from mindspore import Model

from mindcv.data import create_dataset, create_loader, create_transforms
from mindcv.engine import ValCallback
from mindcv.loss import create_loss
from mindcv.models import create_model
from mindcv.utils.callbacks import ValCallback
from mindcv.utils.utils import check_batch_size
from mindcv.utils import check_batch_size

from config import parse_args # isort: skip
