Skip to content

Commit

Permalink
implemented methods to send reader to GPU or CPU inside QAPipeline (#143
Browse files Browse the repository at this point in the history
)
  • Loading branch information
andrelmfarias authored and fmikaelian committed May 21, 2019
1 parent 3b3ed13 commit 923091e
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions cdqa/pipeline/cdqa_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,30 @@ def predict(self, X=None):
else:
raise TypeError("The input is not a string or a list. \
Please provide a string or a list of strings as input")

def to(self, device):
''' Send reader to CPU if device=='cpu' or to GPU if device=='cuda'
'''
if device not in ('cpu', 'cuda'):
raise ValueError("Attribure device should be 'cpu' or 'cuda'.")

self.reader.model.to(device)
return self

def cpu(self):
''' Send reader to CPU
'''
self.reader.model.cpu()
return self

def cuda(self):
def cpu(self):
''' Send reader to GPU
'''
self.reader.model.cuda()
return self





0 comments on commit 923091e

Please sign in to comment.