
fixed multi-gpu in deep feat reg
jesus-villalba committed Jun 24, 2020
1 parent 7b638f1 · commit 9225c45
Showing 3 changed files with 32 additions and 5 deletions.
egs/sre19-cmn2/v2/path.sh: 2 additions & 0 deletions
@@ -52,6 +52,8 @@ export PYTHONPATH=$HYP_ROOT:$KERAS_PATH:$PYTHONPATH
export LD_LIBRARY_PATH
export LC_ALL=C

+export HDF5_USE_FILE_LOCKING=FALSE
+
wait_file() {
  local file="$1"; shift
  local wait_seconds="${2:-30}"; shift  # 30 seconds as default timeout
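The new export disables HDF5's file locking, which is known to fail on NFS-mounted filesystems and can hang jobs in which several data-loader workers read the same feature files. A minimal sketch of the same setting applied from Python rather than path.sh (the h5py usage and file name are illustrative assumptions, not hyperion code):

import os

# Must be set before the HDF5 library is first loaded; once HDF5 is
# initialized, changing the variable has no effect.
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'

import h5py  # imported after setting the env var on purpose

# Several reader processes can now open the same file concurrently
# without "unable to lock file" errors on NFS.
with h5py.File('feats.h5', 'r') as f:  # hypothetical feature file
    keys = list(f.keys())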
hyperion/bin/torch-finetune-xvec-dfr.py: 6 additions & 1 deletion
@@ -28,6 +28,7 @@


def train_xvec(data_rspec, train_list, val_list, exp_path, in_model_path,
+               prior_model_path,
               reg_layers_enc, reg_layers_classif,
               reg_weight_enc, reg_weight_classif, reg_loss,
               epochs, num_gpus, log_interval, resume, num_workers,
@@ -67,7 +68,10 @@ def train_xvec(data_rspec, train_list, val_list, exp_path, in_model_path,
    xvec_args['num_classes'] = train_data.num_classes
    model = TML.load(in_model_path)
    model.rebuild_output_layer(**xvec_args)
-    prior_model = model.copy()
+    if prior_model_path:
+        prior_model = TML.load(prior_model_path)
+    else:
+        prior_model = model.copy()
    prior_model.freeze()
    prior_model.eval()
    if train_mode == 'ft-embed-affine':
@@ -136,6 +140,7 @@ def train_xvec(data_rspec, train_list, val_list, exp_path, in_model_path,
                        help='number of epochs')

    parser.add_argument('--in-model-path', required=True)
+    parser.add_argument('--prior-model-path')
    XVec.add_argparse_finetune_args(parser)
    OF.add_argparse_args(parser, prefix='opt')
    LRSF.add_argparse_args(parser, prefix='lrsch')
hyperion/torch/trainers/xvector_trainer_deep_feat_reg.py: 24 additions & 4 deletions
@@ -15,6 +15,21 @@
from ..utils import MetricAcc
from .torch_trainer import TorchTrainer, TorchDataParallel

+class DFRModelWrapper(nn.Module):
+    """Wrapper class for the x-vector model that replaces the forward
+    method with the forward_hid_feats method. This is needed because
+    nn.DataParallel only supports multi-GPU execution when calling
+    forward, not the other methods of an nn.Module.
+    """
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, x, y=None, enc_layers=None, classif_layers=None, return_output=False):
+        return self.model.forward_hid_feats(x, y, enc_layers, classif_layers, return_output)
+
+
+
class XVectorTrainerDeepFeatReg(TorchTrainer):

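The wrapper pattern above is the standard workaround for this nn.DataParallel restriction: the replicas DataParallel creates only ever invoke forward, so any alternative inference path has to be routed through it. A self-contained sketch of the idea with a toy model (TwoLayerNet and HidFeatsWrapper are illustrative names, not hyperion code):

import torch
import torch.nn as nn

class TwoLayerNet(nn.Module):
    """Toy model with an extra, non-forward inference method."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

    def forward_hid_feats(self, x):
        # returns a hidden activation in addition to the output
        h = torch.relu(self.fc1(x))
        return h, self.fc2(h)

class HidFeatsWrapper(nn.Module):
    """Route DataParallel's forward call to the hidden-feature method."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model.forward_hid_feats(x)

model = TwoLayerNet()
# dp(x) scatters the batch over the visible GPUs and gathers the results;
# calling dp.module.forward_hid_feats(x) directly would bypass DataParallel
# and run on a single device only.
dp = nn.DataParallel(HidFeatsWrapper(model))
h, out = dp(torch.randn(16, 8))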
@@ -41,13 +56,18 @@ def __init__(self, model, prior_model, optimizer, epochs, exp_path, cur_epoch=0,
        self.reg_layers_classif = reg_layers_classif
        self.reg_weight_enc = reg_weight_enc
        self.reg_weight_classif = reg_weight_classif

+        self.model_wrapper = DFRModelWrapper(self.model)
+        self.prior_model_wrapper = DFRModelWrapper(self.prior_model)
+
        if device is not None:
-            self.prior_model.to(device)
+            self.model_wrapper.to(device)
+            self.prior_model_wrapper.to(device)
            self.reg_loss.to(device)

        if data_parallel:
-            self.prior_model = TorchDataParallel(self.prior_model)
+            self.model_wrapper = TorchDataParallel(self.model_wrapper)
+            self.prior_model_wrapper = TorchDataParallel(self.prior_model_wrapper)
            self.reg_loss = TorchDataParallel(self.reg_loss)


@@ -74,12 +94,12 @@ def train_epoch(self, data_loader):
            data, target = data.to(self.device), target.to(self.device)
            batch_size = data.shape[0]

-            h_enc, h_classif, output = self.model.forward_hid_feats(
+            h_enc, h_classif, output = self.model_wrapper(
                data, target, self.reg_layers_enc, self.reg_layers_classif, return_output=True)
            loss = self.loss(output, target).mean()  # with multi-GPU training, the gathered loss has one value per GPU, so reduce it to a scalar here
            batch_metrics['loss-classif'] = loss.item()

-            prior_h_enc, prior_h_classif = self.prior_model.forward_hid_feats(
+            prior_h_enc, prior_h_classif = self.prior_model_wrapper(
                data, target, self.reg_layers_enc, self.reg_layers_classif, return_output=False)

            n_enc = len(h_enc)
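The hunk is cut off just as the regularization term begins. For orientation, here is a minimal sketch of how a deep-feature-regularization penalty can be assembled from the quantities above; the reduction details are an assumption for illustration, not the actual hyperion implementation:

import torch.nn as nn

def deep_feat_reg(h_enc, prior_h_enc, h_classif, prior_h_classif,
                  reg_weight_enc, reg_weight_classif, reg_loss=nn.L1Loss()):
    """Penalize drift of the fine-tuned hidden features away from the
    features of the frozen prior model, layer by layer."""
    loss_reg = 0.0
    for h, h_prior in zip(h_enc, prior_h_enc):          # layers in reg_layers_enc
        loss_reg = loss_reg + reg_weight_enc * reg_loss(h, h_prior.detach())
    for h, h_prior in zip(h_classif, prior_h_classif):  # layers in reg_layers_classif
        loss_reg = loss_reg + reg_weight_classif * reg_loss(h, h_prior.detach())
    return loss_reg

# the total objective would then be something like:
#   loss = loss_classif + deep_feat_reg(...)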
