This is a PyTorch implementation of *GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks*, an algorithm that automatically balances training in deep multitask models by dynamically tuning the magnitudes of per-task gradients.
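For context, one balancing step of the algorithm can be sketched as follows. This is a minimal illustration of the update described in the paper, not the code in `gradnorm.py`; the names (`gradnorm_step`, `shared_layer`, `task_weights`, `initial_losses`, `weight_optimizer`) are assumptions made for the example.

```python
import torch

def gradnorm_step(losses, initial_losses, task_weights, shared_layer, alpha, weight_optimizer):
    # losses         : list of per-task losses for the current batch
    # initial_losses : per-task losses recorded at the first training step
    # task_weights   : learnable 1-D nn.Parameter of loss weights w_i
    # shared_layer   : shared layer whose weights play the role of W in the paper
    T = len(losses)

    # G_i = ||grad_W (w_i * L_i)||_2 for each task i
    norms = []
    for i, L_i in enumerate(losses):
        grads = torch.autograd.grad(task_weights[i] * L_i,
                                    tuple(shared_layer.parameters()),
                                    retain_graph=True, create_graph=True)
        norms.append(torch.cat([g.flatten() for g in grads]).norm())
    norms = torch.stack(norms)

    # Relative inverse training rates r_i and the constant per-task targets
    with torch.no_grad():
        loss_ratios = torch.stack([losses[i] / initial_losses[i] for i in range(T)])
        r = loss_ratios / loss_ratios.mean()
        targets = norms.mean() * (r ** alpha)

    # L1 distance between actual and target gradient norms,
    # differentiated with respect to the loss weights only
    gradnorm_loss = (norms - targets).abs().sum()
    weight_grad = torch.autograd.grad(gradnorm_loss, task_weights, retain_graph=True)[0]

    weight_optimizer.zero_grad()
    task_weights.grad = weight_grad
    weight_optimizer.step()

    # Renormalize so the loss weights sum to the number of tasks
    with torch.no_grad():
        task_weights.data *= T / task_weights.data.sum()
```

In this sketch the loss weights would be an `nn.Parameter(torch.ones(T))` trained by their own optimizer (learning rate `lr2` below), while the network parameters are trained on the weighted multitask loss with `lr1`.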
The toy example can be found here.
The `gradNorm` function takes the following arguments:

- net: a multitask network that computes the per-task losses
- layer: the network layer on whose weights GradNorm is applied
- alpha: hyperparameter controlling the strength of the restoring force
- dataloader: training dataloader
- num_epochs: number of training epochs
- lr1: learning rate for the multitask loss (network parameters)
- lr2: learning rate for the loss weights
- log: flag controlling result logging
```python
from gradnorm import gradNorm

log_weights, log_loss = gradNorm(net=mtlnet, layer=net.fc4, alpha=0.12, dataloader=dataloader,
                                 num_epochs=100, lr1=1e-5, lr2=1e-4, log=False)
```
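The returned `log_weights` and `log_loss` can be used to inspect how the loss weights and task losses evolve during training. A minimal sketch, assuming the logs were recorded (e.g. with `log=True`) and that each holds one value per task per recorded step:

```python
import matplotlib.pyplot as plt
import numpy as np

# Assumption: each log has shape (num_steps, n_tasks)
weights_hist = np.asarray(log_weights)
loss_hist = np.asarray(log_loss)

fig, (ax_w, ax_l) = plt.subplots(1, 2, figsize=(10, 4))
for t in range(weights_hist.shape[1]):
    ax_w.plot(weights_hist[:, t], label=f"w{t}")
    ax_l.plot(loss_hist[:, t], label=f"task {t}")
ax_w.set_title("loss weights")
ax_l.set_title("task losses")
ax_w.legend()
ax_l.legend()
plt.show()
```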
Consider a set of regression tasks as in the toy example of the GradNorm paper. Inputs have dimension 250 and outputs have dimension 100, and each task's targets are scaled by a fixed scalar (here 1 and 100), so the per-task losses differ widely in magnitude:
```python
from data import toyDataset

dataset = toyDataset(num_data=10000, dim_features=250, dim_labels=100, scalars=[1,100])
```
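The `gradNorm` call above expects a dataloader. A minimal sketch of wrapping the toy dataset, assuming `toyDataset` is a standard `torch.utils.data.Dataset` (the batch size here is an arbitrary choice):

```python
from torch.utils.data import DataLoader

# Batch and shuffle the toy dataset for training
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)
```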
A 4-layer fully-connected ReLU-activated network with 100 neurons per layer is used as a common trunk for the toy example. A final affine transformation layer for each task gives the T final predictions.
```python
from model import fcNet, mtlNet

net = fcNet(dim_features=250, dim_labels=100, n_tasks=2)  # fc net with multiple heads
mtlnet = mtlNet(net)  # multitask net with task loss
```
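For reference, a trunk-plus-heads architecture of the kind described above can be sketched as follows. This is only an illustration of the described topology, not the repo's `fcNet`/`mtlNet` implementation; the class name `TrunkWithHeads` is an assumption made for the example.

```python
import torch.nn as nn

class TrunkWithHeads(nn.Module):
    """Illustrative 4-layer ReLU trunk with one affine head per task."""

    def __init__(self, dim_features=250, dim_labels=100, n_tasks=2, hidden=100):
        super().__init__()
        # Common fully-connected trunk shared by all tasks
        self.trunk = nn.Sequential(
            nn.Linear(dim_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        # One affine transformation per task gives the T final predictions
        self.heads = nn.ModuleList(nn.Linear(hidden, dim_labels) for _ in range(n_tasks))

    def forward(self, x):
        h = self.trunk(x)
        return [head(h) for head in self.heads]
```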