From 05a49eaafe7eb7a22701242cca9c6e5ed43f172c Mon Sep 17 00:00:00 2001
From: Felix MIKAELIAN <39884124+fmikaelian@users.noreply.github.com>
Date: Fri, 8 Mar 2019 16:00:07 +0100
Subject: [PATCH] NameError: name 'device' is not defined in predict() method
 #68 (#69)

---
 cdqa/reader/bertqa_sklearn.py | 38 +++++++++++++++++------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/cdqa/reader/bertqa_sklearn.py b/cdqa/reader/bertqa_sklearn.py
index 7f8a12f..3375d53 100644
--- a/cdqa/reader/bertqa_sklearn.py
+++ b/cdqa/reader/bertqa_sklearn.py
@@ -855,21 +855,21 @@ def __init__(self,
         self.null_score_diff_threshold = null_score_diff_threshold
         self.output_dir = output_dir
 
-    def fit(self, X, y=None):
-
-        train_examples, train_features = X
-
         if self.local_rank == -1 or self.no_cuda:
-            device = torch.device("cuda" if torch.cuda.is_available() and not self.no_cuda else "cpu")
-            n_gpu = torch.cuda.device_count()
+            self.device = torch.device("cuda" if torch.cuda.is_available() and not self.no_cuda else "cpu")
+            self.n_gpu = torch.cuda.device_count()
         else:
             torch.cuda.set_device(self.local_rank)
-            device = torch.device("cuda", self.local_rank)
-            n_gpu = 1
+            self.device = torch.device("cuda", self.local_rank)
+            self.n_gpu = 1
             # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
             torch.distributed.init_process_group(backend='nccl')
         logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
-            device, n_gpu, bool(self.local_rank != -1), self.fp16))
+            self.device, self.n_gpu, bool(self.local_rank != -1), self.fp16))
+
+    def fit(self, X, y=None):
+
+        train_examples, train_features = X
 
         if self.gradient_accumulation_steps < 1:
             raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
@@ -880,7 +880,7 @@ def fit(self, X, y=None):
         random.seed(self.seed)
         np.random.seed(self.seed)
         torch.manual_seed(self.seed)
-        if n_gpu > 0:
+        if self.n_gpu > 0:
             torch.cuda.manual_seed_all(self.seed)
 
         if os.path.exists(self.output_dir) and os.listdir(self.output_dir):
@@ -899,7 +899,7 @@ def fit(self, X, y=None):
 
         if self.fp16:
             model.half()
-        model.to(device)
+        model.to(self.device)
         if self.local_rank != -1:
             try:
                 from apex.parallel import DistributedDataParallel as DDP
@@ -907,7 +907,7 @@ def fit(self, X, y=None):
                 raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
 
             model = DDP(model)
-        elif n_gpu > 1:
+        elif self.n_gpu > 1:
             model = torch.nn.DataParallel(model)
 
         # Prepare optimizer
@@ -967,11 +967,11 @@ def fit(self, X, y=None):
         model.train()
         for _ in trange(int(self.num_train_epochs), desc="Epoch"):
             for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
-                if n_gpu == 1:
-                    batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
+                if self.n_gpu == 1:
+                    batch = tuple(t.to(self.device) for t in batch) # multi-gpu does scattering it-self
                 input_ids, input_mask, segment_ids, start_positions, end_positions = batch
                 loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
-                if n_gpu > 1:
+                if self.n_gpu > 1:
                     loss = loss.mean() # mean() to average on multi-gpu.
                 if self.gradient_accumulation_steps > 1:
                     loss = loss / self.gradient_accumulation_steps
@@ -1007,7 +1007,7 @@ def fit(self, X, y=None):
 
         model = BertForQuestionAnswering(config)
         model.load_state_dict(torch.load(output_model_file))
-        model.to(device)
+        model.to(self.device)
 
         self.model = model
         return self
@@ -1036,9 +1036,9 @@ def predict(self, X):
         for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
             if len(all_results) % 1000 == 0:
                 logger.info("Processing example: %d" % (len(all_results)))
-            input_ids = input_ids.to(device)
-            input_mask = input_mask.to(device)
-            segment_ids = segment_ids.to(device)
+            input_ids = input_ids.to(self.device)
+            input_mask = input_mask.to(self.device)
+            segment_ids = segment_ids.to(self.device)
             with torch.no_grad():
                 batch_start_logits, batch_end_logits = self.model(input_ids, segment_ids, input_mask)
             for i, example_index in enumerate(example_indices):
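
A minimal usage sketch of the fixed behaviour, assuming the sklearn-style estimator defined in cdqa/reader/bertqa_sklearn.py is exported as BertQA, that its constructor accepts output_dir, and that the example/feature tuples are built beforehand (predict()'s exact input format is not shown in these hunks): because __init__ now resolves the device and keeps it as self.device / self.n_gpu, predict() can move tensors onto the device even though the fit()-local variables of the old code no longer exist.

    # Hypothetical names: BertQA, the train_/eval_ tuples and the output_dir
    # value are illustrative, not taken verbatim from the patch.
    from cdqa.reader.bertqa_sklearn import BertQA

    reader = BertQA(output_dir='models/bert_qa')    # device/n_gpu resolved in __init__
    reader.fit(X=(train_examples, train_features))  # batches moved to reader.device
    predictions = reader.predict(X=(eval_examples, eval_features))  # no NameError (#68)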