
Commit

Updated BertQA to enable multiple trainings and handled some errors (#130)

* modified BertQA class to enable multiple calls to fit()

* corrected typo

* Deleted tokenizer saving inside BertQA.fit

* handled problem with self.output_dir
andrelmfarias authored and fmikaelian committed May 14, 2019
1 parent 0dce89f commit 29da9b6
Showing 1 changed file with 16 additions and 31 deletions.
47 changes: 16 additions & 31 deletions cdqa/reader/bertqa_sklearn.py
@@ -871,8 +871,6 @@ class BertQA(BaseEstimator):
         Bert pre-trained model selected in the list: bert-base-uncased,
         bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased,
         bert-base-multilingual-cased, bert-base-chinese.
-    custom_weights : bool, optional
-        [description] (the default is False)
     train_batch_size : int, optional
         Total batch size for training. (the default is 32)
     predict_batch_size : int, optional
@@ -947,7 +945,6 @@ class BertQA(BaseEstimator):
 
     def __init__(self,
                  bert_model='bert-base-uncased',
-                 custom_weights=False,
                  train_batch_size=32,
                  predict_batch_size=8,
                  learning_rate=5e-5,
@@ -970,7 +967,6 @@ def __init__(self,
                  server_port=''):
 
         self.bert_model = bert_model
-        self.custom_weights = custom_weights
         self.train_batch_size = train_batch_size
         self.predict_batch_size = predict_batch_size
         self.learning_rate = learning_rate
@@ -992,6 +988,10 @@ def __init__(self,
         self.server_ip = server_ip
         self.server_port = server_port
 
+        # Prepare model
+        self.model = BertForQuestionAnswering.from_pretrained(self.bert_model,
+            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(self.local_rank)))
+
         if self.server_ip and self.server_port:
             # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
             import ptvsd
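
Note: the "# Prepare model" block added above is the core of this change. Building the model once in __init__, instead of inside fit(), means each later fit() call keeps training the same weights rather than reloading the pretrained checkpoint. A minimal sketch of the pattern, assuming the pytorch_pretrained_bert imports that bertqa_sklearn.py already uses (BertQASketch is an illustrative stand-in, not the real class):

    import os
    from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
    from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

    class BertQASketch:
        def __init__(self, bert_model='bert-base-uncased', local_rank=-1):
            self.bert_model = bert_model
            self.local_rank = local_rank
            # Created once here; fit() reuses self.model on every call.
            self.model = BertForQuestionAnswering.from_pretrained(
                self.bert_model,
                cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                       'distributed_{}'.format(self.local_rank)))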
@@ -1033,8 +1033,6 @@ def fit(self, X, y=None):
         if self.n_gpu > 0:
             torch.cuda.manual_seed_all(self.seed)
 
-        if os.path.exists(self.output_dir) and os.listdir(self.output_dir):
-            raise ValueError("Output directory () already exists and is not empty.")
         if not os.path.exists(self.output_dir):
             os.makedirs(self.output_dir)
 
@@ -1043,26 +1041,22 @@ def fit(self, X, y=None):
         if self.local_rank != -1:
             num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
 
-        # Prepare model
-        model = BertForQuestionAnswering.from_pretrained(self.bert_model,
-            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(self.local_rank)))
-
         if self.fp16:
-            model.half()
-        model.to(self.device)
+            self.model.half()
+        self.model.to(self.device)
         if self.local_rank != -1:
             try:
                 from apex.parallel import DistributedDataParallel as DDP
             except ImportError:
                 raise ImportError(
                     "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
 
-            model = DDP(model)
+            self.model = DDP(self.model)
         elif self.n_gpu > 1:
-            model = torch.nn.DataParallel(model)
+            self.model = torch.nn.DataParallel(self.model)
 
         # Prepare optimizer
-        param_optimizer = list(model.named_parameters())
+        param_optimizer = list(self.model.named_parameters())
 
         # hack to remove pooler, which is not used
         # thus it produce None grad that break apex
@@ -1123,14 +1117,15 @@ def fit(self, X, y=None):
         train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                       batch_size=self.train_batch_size)
 
-        model.train()
+        self.model.train()
         for _ in trange(int(self.num_train_epochs), desc="Epoch"):
             for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=self.local_rank not in [-1, 0])):
                 if self.n_gpu == 1:
                     batch = tuple(t.to(self.device)
                                   for t in batch)  # multi-gpu does scattering it-self
                 input_ids, input_mask, segment_ids, start_positions, end_positions = batch
-                loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
+                loss = self.model(input_ids, segment_ids, input_mask,
+                                  start_positions, end_positions)
                 if self.n_gpu > 1:
                     loss = loss.mean()  # mean() to average on multi-gpu.
                 if self.gradient_accumulation_steps > 1:
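
The "if self.gradient_accumulation_steps > 1:" branch is truncated by the diff view, but it is the standard gradient-accumulation pattern: scale the loss down so gradients summed over several small batches match one large batch, then step the optimizer every N batches. A sketch of the conventional loop (an assumption about the elided lines, not a copy of them):

    for step, batch in enumerate(train_dataloader):
        loss = self.model(*batch)
        if self.gradient_accumulation_steps > 1:
            # Scale so N accumulated backward passes equal one large-batch gradient.
            loss = loss / self.gradient_accumulation_steps
        loss.backward()  # gradients accumulate in-place between optimizer steps
        if (step + 1) % self.gradient_accumulation_steps == 0:
            optimizer.step()       # one update per effective batch
            optimizer.zero_grad()
            global_step += 1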
@@ -1152,28 +1147,18 @@ def fit(self, X, y=None):
                     optimizer.zero_grad()
                     global_step += 1
 
-        # Save a trained model, configuration and tokenizer
-        model_to_save = model.module if hasattr(
-            model, 'module') else model  # Only save the model it-self
+        # Save a trained model and configuration
+        model_to_save = self.model.module if hasattr(
+            self.model, 'module') else self.model  # Only save the model it-self
 
         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(self.output_dir, WEIGHTS_NAME)
         output_config_file = os.path.join(self.output_dir, CONFIG_NAME)
 
         torch.save(model_to_save.state_dict(), output_model_file)
         model_to_save.config.to_json_file(output_config_file)
-        tokenizer.save_vocabulary(self.output_dir)
-
-        if self.custom_weights:
-            model = BertForQuestionAnswering.from_pretrained(self.bert_model)
-        else:
-            # Load a trained model and vocabulary that you have fine-tuned
-            model = BertForQuestionAnswering.from_pretrained(self.output_dir)
-        tokenizer = BertTokenizer.from_pretrained(
-            self.output_dir, do_lower_case=self.do_lower_case)
-
-        model.to(self.device)
-        self.model = model
+        self.model.to(self.device)
 
         return self

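Net effect: a single BertQA instance can now be fitted more than once, e.g. to keep fine-tuning on additional data, since fit() no longer raises on a non-empty output_dir and no longer rebuilds the model. A hypothetical usage sketch (X_train_a and X_train_b are placeholder datasets):

    from cdqa.reader.bertqa_sklearn import BertQA

    reader = BertQA(bert_model='bert-base-uncased')
    reader.fit(X_train_a)  # first fine-tuning run
    reader.fit(X_train_b)  # continues training the same self.model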
