Skip to content

Commit

Permalink
Fixed CUDA/CPU device issues; commented out some unnecessary and buggy lines
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpra committed Jul 8, 2020
1 parent f3e8920 commit c795b88
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 20 deletions.
2 changes: 1 addition & 1 deletion parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def work(self, data, beam_size, max_time_step, min_time_step=1):
init_state_dict = {}
init_hyp = Hypothesis(init_state_dict, [DUM], 0.)
bsz = word_repr.size(1)
beams = [ Beam(beam_size, min_time_step, max_time_step, [init_hyp]) for i in range(bsz)]
beams = [ Beam(beam_size, min_time_step, max_time_step, [init_hyp], device=self.device) for i in range(bsz)]
search_by_batch(self, beams, mem_dict)
return beams

Expand Down
9 changes: 6 additions & 3 deletions parser/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ def __len__(self):

class Beam(object):
"""each beam for a test instance"""
def __init__(self, beam_size, min_time_step, max_time_step, hypotheses):
def __init__(self, beam_size, min_time_step, max_time_step, hypotheses, device=torch.device('cpu')):
self.beam_size = beam_size
self.min_time_step = min_time_step
self.max_time_step = max_time_step
self.completed_hypotheses = []
self.steps = 0
self.hypotheses = hypotheses # hypotheses are the collection of *alive* hypotheses only
self.device = device

def merge_score(self, prev_hyp, step):
# step has two attributes: token and score
Expand Down Expand Up @@ -81,7 +82,7 @@ def update(self, new_states, next_steps):

# collect new states for selected top candidates
_split_state = dict() # key => list of length live_nyp_num (number of selected top candidates)
_prev_hyp_idx = torch.tensor([ x[0] for x in candidates]).cuda()
_prev_hyp_idx = torch.tensor([ x[0] for x in candidates]).to(self.device) # cuda()
for k, v in new_states.items():
split_dim = 1 if len(v.size()) >= 3 else 0
_split_state[k] = v.index_select(split_dim, _prev_hyp_idx).split(1, dim=split_dim)
Expand Down Expand Up @@ -149,6 +150,8 @@ def ready_to_submit(hypotheses):
concat_hyps[k] = torch.cat(v, 0)
return concat_hyps, inp

device = beams[0].device

while True:
# collect incomplete beams and put all hypotheses together
hypotheses = []
Expand All @@ -167,7 +170,7 @@ def ready_to_submit(hypotheses):

# collect mem_dict
cur_mem_dict = dict()
indices = torch.tensor(indices).cuda()
indices = torch.tensor(indices).to(device) # cuda()
for k, v in mem_dict.items():
if isinstance(v, list):
cur_mem_dict[k] = [v[i] for i in indices]
Expand Down
9 changes: 6 additions & 3 deletions parser/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def get_mask(size):
def forward(self, size):
if self.weights is None or size > self.weights.size(0):
self.weights = SelfAttentionMask.get_mask(size)
res = self.weights[:size,:size].cuda(self.device).detach()
# res = self.weights[:size,:size].cuda(self.device).detach()
res = self.weights[:size,:size].to(self.device).detach()
return res

class LearnedPositionalEmbedding(nn.Module):
Expand All @@ -251,7 +252,8 @@ def reset_parameters(self):
def forward(self, input, offset=0):
"""Input is expected to be of size [seq_len x bsz]."""
seq_len, bsz = input.size()
positions = (offset + torch.arange(seq_len)).cuda(self.device)
# positions = (offset + torch.arange(seq_len)).cuda(self.device)
positions = (offset + torch.arange(seq_len)).to(self.device)
res = self.weights(positions).unsqueeze(1)
return res

Expand Down Expand Up @@ -295,5 +297,6 @@ def forward(self, input, offset=0):
)

positions = offset + torch.arange(seq_len)
res = self.weights.index_select(0, positions).unsqueeze(1).cuda(self.device).detach()
# res = self.weights.index_select(0, positions).unsqueeze(1).cuda(self.device).detach()
res = self.weights.index_select(0, positions).unsqueeze(1).to(self.device).detach()
return res
27 changes: 14 additions & 13 deletions parser/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def parse_data(model, pp, data, input_file, output_file, beam_size=8, alpha=0.6,
match(output_file, input_file)
print ('write down %d amrs'%tot)

def load_ckpt_without_bert(model, test_model):
ckpt = torch.load(test_model)['model']
def load_ckpt_without_bert(model, test_model, device=torch.device('cpu')):
ckpt = torch.load(test_model, map_location=device)['model']
for k, v in model.state_dict().items():
if k.startswith('bert_encoder'):
ckpt[k] = v
Expand All @@ -85,16 +85,21 @@ def load_ckpt_without_bert(model, test_model):

args = parse_config()

if torch.cuda.is_available():
device = torch.device('cuda', args.device)
else:
device = torch.device('cpu')

test_models = []
if os.path.isdir(args.load_path):
for file in os.listdir(args.load_path):
fname = os.path.join(args.load_path, file)
if os.path.isfile(fname):
test_models.append(fname)
model_args = torch.load(fname)['args']
model_args = torch.load(fname, map_location=device)['args']
else:
test_models.append(args.load_path)
model_args = torch.load(args.load_path)['args']
model_args = torch.load(args.load_path, map_location=device)['args']

vocabs = dict()

Expand All @@ -116,11 +121,7 @@ def load_ckpt_without_bert(model, test_model):
bert_encoder = BertEncoder.from_pretrained(model_args.bert_path)
vocabs['bert_tokenizer'] = bert_tokenizer

if args.device < 0:
device = torch.device('cpu')
else:
device = torch.device('cuda', args.device)


model = Parser(vocabs,
model_args.word_char_dim, model_args.word_dim, model_args.pos_dim, model_args.ner_dim,
model_args.concept_char_dim, model_args.concept_dim,
Expand All @@ -133,11 +134,11 @@ def load_ckpt_without_bert(model, test_model):
another_test_data = DataLoader(vocabs, lexical_mapping, args.test_data, args.test_batch_size, for_train=False)
for test_model in test_models:
print (test_model)
batch = int(re.search(r'batch([0-9])+', test_model)[0][5:])
epoch = int(re.search(r'epoch([0-9])+', test_model)[0][5:])
# batch = int(re.search(r'batch([0-9])+', test_model)[0][5:])
# epoch = int(re.search(r'epoch([0-9])+', test_model)[0][5:])

load_ckpt_without_bert(model, test_model)
model = model.cuda()
load_ckpt_without_bert(model, test_model, device=device)
model = model.to(device) # cuda()
model.eval()

#loss = show_progress(model, test_data)
Expand Down

0 comments on commit c795b88

Please sign in to comment.