remove UserWarning: masked_fill_
jiesutd committed Nov 10, 2020
1 parent ab82b68 commit 105a53a
Showing 3 changed files with 5 additions and 5 deletions.
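
For context, recent PyTorch releases warn when masked_fill_ receives a torch.uint8 (byte) mask instead of a torch.bool mask; every change below simply switches the mask dtype from byte to bool. A minimal sketch of the behavior, using hypothetical tensors that are not taken from the repository:

import torch

scores = torch.randn(2, 5)                   # hypothetical emission scores
pad = torch.tensor([[0, 0, 1, 1, 1],
                    [0, 1, 1, 1, 1]])

# Old form: passing pad.byte() as the mask triggers
# "UserWarning: masked_fill_ received a mask with dtype torch.uint8" on PyTorch >= 1.2
# (and newer releases may reject byte masks outright).
# New form: a bool mask is silent, which is what this commit switches to.
scores.masked_fill_(pad.bool(), -1e4)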
4 changes: 2 additions & 2 deletions main.py
@@ -225,7 +225,7 @@ def batchify_sequence_labeling_with_label(input_batch_list, gpu, if_train=True):
feature_seq_tensors = []
for idx in range(feature_num):
feature_seq_tensors.append(torch.zeros((batch_size, max_seq_len),requires_grad = if_train).long())
- mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).byte()
+ mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).bool()
for idx, (seq, label, seqlen) in enumerate(zip(words, labels, word_seq_lengths)):
seqlen = seqlen.item()
word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
@@ -304,7 +304,7 @@ def batchify_sentence_classification_with_label(input_batch_list, gpu, if_train=
feature_seq_tensors = []
for idx in range(feature_num):
feature_seq_tensors.append(torch.zeros((batch_size, max_seq_len),requires_grad = if_train).long())
- mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).byte()
+ mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).bool()
label_seq_tensor = torch.LongTensor(labels)
# exit(0)
for idx, (seq, seqlen) in enumerate(zip(words, word_seq_lengths)):
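
As a side note (not part of the commit), the same padding mask can be built as bool from the start instead of being cast after creation; boolean masks do not need, and cannot carry, requires_grad. A sketch with hypothetical sizes:

import torch

batch_size, max_seq_len = 2, 5                     # hypothetical sizes
word_seq_lengths = torch.tensor([5, 3])            # hypothetical lengths per sequence

# Build the mask directly as bool instead of calling .byte()/.bool() afterwards.
mask = torch.zeros((batch_size, max_seq_len), dtype=torch.bool)
for idx, seqlen in enumerate(word_seq_lengths):
    mask[idx, :seqlen.item()] = True               # True over real tokens, False over padding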
2 changes: 1 addition & 1 deletion main_parse.py
@@ -233,7 +233,7 @@ def batchify_with_label(input_batch_list, gpu, volatile_flag=False):
feature_seq_tensors = []
for idx in range(feature_num):
feature_seq_tensors.append(autograd.Variable(torch.zeros((batch_size, max_seq_len)),volatile = volatile_flag).long())
- mask = autograd.Variable(torch.zeros((batch_size, max_seq_len)),volatile = volatile_flag).byte()
+ mask = autograd.Variable(torch.zeros((batch_size, max_seq_len)),volatile = volatile_flag).bool()
for idx, (seq, label, seqlen) in enumerate(zip(words, labels, word_seq_lengths)):
word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
label_seq_tensor[idx, :seqlen] = torch.LongTensor(label)
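
main_parse.py still uses the pre-0.4 autograd.Variable(..., volatile=...) API; on current PyTorch the same bool mask is a plain tensor and volatile=True is replaced by a torch.no_grad() block. A sketch under that assumption, not taken from the repository:

import torch

batch_size, max_seq_len = 2, 5                     # hypothetical sizes
mask = torch.zeros((batch_size, max_seq_len), dtype=torch.bool)

with torch.no_grad():                              # modern stand-in for volatile=True
    scores = torch.randn(batch_size, max_seq_len)
    scores.masked_fill_(~mask, -1e4)               # masked_fill_ with a bool mask: no UserWarning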
4 changes: 2 additions & 2 deletions model/crf.py
@@ -133,7 +133,7 @@ def _viterbi_decode(self, feats, mask):
partition_history = list()
## reverse mask (bug for mask = 1- mask, use this as alternative choice)
# mask = 1 + (-1)*mask
- mask = (1 - mask.long()).byte()
+ mask = (1 - mask.long()).bool()
_, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size
# only need start from start_tag
partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size) # bat_size * to_target_size
@@ -297,7 +297,7 @@ def _viterbi_decode_nbest(self, feats, mask, nbest):
partition_history = list()
## reverse mask (bug for mask = 1- mask, use this as alternative choice)
# mask = 1 + (-1)*mask
- mask = (1 - mask.long()).byte()
+ mask = (1 - mask.long()).bool()
_, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size
# only need start from start_tag
partition = inivalues[:, START_TAG, :].clone() # bat_size * to_target_size
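
With the mask now bool, the (1 - mask.long()).bool() inversion in _viterbi_decode and _viterbi_decode_nbest is equivalent to a bitwise not; a small standalone check, not part of the commit:

import torch

mask = torch.tensor([[True, True, False],
                     [True, False, False]])        # hypothetical padding mask
inv_cast = (1 - mask.long()).bool()                # form used in the diff
inv_not = ~mask                                    # logical not on a bool mask
assert torch.equal(inv_cast, inv_not)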
