New indexing (OpenNMT#2496)
* New indexing of data
* better code
vince62s authored Oct 26, 2023
1 parent aa06c4c commit be13d12
Showing 23 changed files with 469 additions and 208 deletions.
88 changes: 88 additions & 0 deletions .github/workflows/push.yml
@@ -169,8 +169,96 @@ jobs:
-scoring_debug "true" \
-tensorboard_log_dir /tmp/logs_dynamic-scoring_and_copy \
-dump_preds /tmp/dump_preds \
-position_encoding \
-copy_attn
python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_copy -tensorboard_checks valid_metrics
- name : Test Transformer training and validation with dynamic scoring and maxrelative
run: |
python3 train.py \
-config data/data.yaml \
-src_vocab /tmp/onmt.vocab.src \
-tgt_vocab /tmp/onmt.vocab.tgt \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-encoder_type transformer \
-decoder_type transformer \
-layers 4 \
-word_vec_size 16 \
-hidden_size 16 \
-num_workers 0 -bucket_size 1024 \
-heads 2 \
-transformer_ff 64 \
-num_workers 0 -bucket_size 1024 \
-accum_count 2 4 8 \
-accum_steps 0 15000 30000 \
-save_model /tmp/onmt.model \
-train_steps 10 -valid_steps 5 \
-report_every 2 \
-valid_metrics "BLEU" "TER" \
-tensorboard "true" \
-scoring_debug "true" \
-tensorboard_log_dir /tmp/logs_dynamic-scoring_and_relative \
-dump_preds /tmp/dump_preds \
-max_relative_positions 8
python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_relative -tensorboard_checks valid_metrics
- name : Test Transformer training and validation with dynamic scoring and rotary
run: |
python3 train.py \
-config data/data.yaml \
-src_vocab /tmp/onmt.vocab.src \
-tgt_vocab /tmp/onmt.vocab.tgt \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-encoder_type transformer \
-decoder_type transformer \
-layers 4 \
-word_vec_size 16 \
-hidden_size 16 \
-num_workers 0 -bucket_size 1024 \
-heads 2 \
-transformer_ff 64 \
-num_workers 0 -bucket_size 1024 \
-accum_count 2 4 8 \
-accum_steps 0 15000 30000 \
-save_model /tmp/onmt.model \
-train_steps 10 -valid_steps 5 \
-report_every 2 \
-valid_metrics "BLEU" "TER" \
-tensorboard "true" \
-scoring_debug "true" \
-tensorboard_log_dir /tmp/logs_dynamic-scoring_and_rotary \
-dump_preds /tmp/dump_preds \
-max_relative_positions -1
python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_rotary -tensorboard_checks valid_metrics
- name : Test Transformer training and validation with dynamic scoring and alibi
run: |
python3 train.py \
-config data/data.yaml \
-src_vocab /tmp/onmt.vocab.src \
-tgt_vocab /tmp/onmt.vocab.tgt \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-encoder_type transformer \
-decoder_type transformer \
-layers 4 \
-word_vec_size 16 \
-hidden_size 16 \
-num_workers 0 -bucket_size 1024 \
-heads 2 \
-transformer_ff 64 \
-num_workers 0 -bucket_size 1024 \
-accum_count 2 4 8 \
-accum_steps 0 15000 30000 \
-save_model /tmp/onmt.model \
-train_steps 10 -valid_steps 5 \
-report_every 2 \
-valid_metrics "BLEU" "TER" \
-tensorboard "true" \
-scoring_debug "true" \
-tensorboard_log_dir /tmp/logs_dynamic-scoring_and_alibi \
-dump_preds /tmp/dump_preds \
-max_relative_positions -2
python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_alibi -tensorboard_checks valid_metrics
- name: Test LM training
run: |
python train.py \
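Each of the new workflow steps above trains a tiny Transformer with a different position-encoding setting and then checks the TensorBoard logs with onmt/tests/test_events.py. As a rough illustration only (this is not the actual test script, and the scalar tag names are assumptions), a check for validation-metric events could be sketched with TensorBoard's EventAccumulator:

import glob
import os

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


def has_valid_metric_scalars(logdir, tags=("valid/BLEU", "valid/TER")):
    # Sketch only: NOT onmt/tests/test_events.py. The tag names above are
    # assumptions; the real script may use different tags and checks.
    for event_file in glob.glob(os.path.join(logdir, "**", "events.*"), recursive=True):
        acc = EventAccumulator(event_file)
        acc.Reload()
        scalars = set(acc.Tags().get("scalars", []))
        if all(tag in scalars for tag in tags):
            return True
    return False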
14 changes: 7 additions & 7 deletions data/data_lm/gen-nucleus-sampling-sol2.txt
@@ -1,7 +1,7 @@
like your equipment !
in your Presidency or registering erratic full detailed table since Waddington , as well handled separately .
the importance 's future .
there is survivors , under new public beforehand ?
bankers sit on the honour and the old , accusations by Cubase for political leadership in the fifth generation of 0.5 , 1 January in serving the old , 1 .
can do planet what Mr Titley , should like What is really do so done guarantee - ™ your Gore .
we can restore confidence in Switzerland .
well ?
liberty to decide in favor of the Stability tattoo .
a moment of the top of the importance of the importance 's future .
I think that received negative perception that it .
heard from India majority , Romagna .
We move should do any experiences .
our website , you invite shop in the year .
14 changes: 7 additions & 7 deletions data/data_lm/gen-sampling-beams-sol2.txt
@@ -1,7 +1,7 @@
you ! your staff !
in German Presidency or rather than Andalusia .
a moment , one century .
" s operational from 06: 00 on weekdays , clothes hang thick from the and 00 until about 00 until the sea .
fine words has underlined the fine vegetarian ...
service are buying any of my property .
we are much running behind the facts .
you ! Next to your luck .
inspired by the absolut well-beeing-feeling of the Tauern Spa and ...
the top of the top of the top of the moment , health of the importance of capital , the world .
" Austrian " sold with " Delta , received from twenty years .
administered by the usefulness of the interinstitutional coherence of the recorded by the present in the young majority of renewable energies .
do so would like this subsidy is requested .
800 m2 can 't be seen on payments .
7 changes: 5 additions & 2 deletions onmt/bin/translate.py
@@ -9,6 +9,7 @@
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed
from torch.profiler import profile, record_function, ProfilerActivity
import time


def translate(opt):
@@ -52,13 +53,15 @@ def main():
parser = _get_parser()
opt = parser.parse_args()
if opt.profile:

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
with record_function("Translate"):
translate(opt)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=40))
else:
init_time = time.time()
translate(opt)
print("Time w/o python interpreter load/terminate: ", time.time() - init_time)


if __name__ == "__main__":
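The translate.py change above wraps the call in torch.profiler when -profile is set and otherwise reports wall-clock time excluding interpreter startup and teardown. A minimal standalone sketch of the same pattern (run_translation is a placeholder workload, not an OpenNMT function):

import time

import torch
from torch.profiler import ProfilerActivity, profile, record_function


def run_translation():
    # Placeholder standing in for translate(opt).
    torch.randn(256, 256) @ torch.randn(256, 256)


def main(profile_run=False):
    if profile_run:
        activities = [ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(ProfilerActivity.CUDA)
        with profile(activities=activities) as prof:
            with record_function("Translate"):
                run_translation()
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=40))
    else:
        # Time only the call itself, excluding interpreter load/terminate.
        init_time = time.time()
        run_translation()
        print("Time w/o python interpreter load/terminate: ", time.time() - init_time)


if __name__ == "__main__":
    main()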
14 changes: 6 additions & 8 deletions onmt/decoders/transformer.py
@@ -581,20 +581,18 @@ def forward(self, tgt, enc_out=None, step=None, **kwargs):
{"keys": torch.tensor([]), "values": torch.tensor([])},
)

emb = self.embeddings(tgt, step=step)
dec_out = emb
assert emb.dim() == 3 # len x batch x embedding_dim
dec_out = self.embeddings(tgt, step=step)

pad_idx = self.embeddings.word_padding_idx
src_lens = kwargs["src_len"]
src_len = kwargs["src_len"]
src_max_len = self.state["src"].shape[1]
src_pad_mask = ~sequence_mask(src_lens, src_max_len) # [B x slen]
src_pad_mask = src_pad_mask.unsqueeze(1) # [B x 1 x slen]
src_pad_mask = sequence_mask(src_len, src_max_len).unsqueeze(
1
) # [B x 1 x slen]
tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]

with_align = kwargs.pop("with_align", False)
return_attn = kwargs.pop("return_attn", False)
return_attn = with_align or self._copy or return_attn
return_attn = with_align or self._copy or kwargs.pop("return_attn", False)

attn_aligns = []

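Dropping the ~ in front of sequence_mask here (and in the encoders below) suggests that this commit flips sequence_mask to return True at padded positions. A small self-contained sketch of that convention (the sequence_mask below is my own stand-in, not onmt.utils.misc.sequence_mask):

import torch


def sequence_mask(lengths, max_len=None):
    # Assumed convention after this commit: True marks PADDED positions.
    max_len = max_len or int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)  # [B, max_len]


src_len = torch.tensor([3, 5])  # two source sentences in the batch
src_pad_mask = sequence_mask(src_len, 5).unsqueeze(1)  # [B, 1, slen]
# tensor([[[False, False, False,  True,  True]],
#         [[False, False, False, False, False]]])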
2 changes: 1 addition & 1 deletion onmt/encoders/mean_encoder.py
@@ -30,7 +30,7 @@ def forward(self, src, src_len=None):

if src_len is not None:
# we avoid padding while mean pooling
mask = sequence_mask(src_len).float()
mask = (~sequence_mask(src_len)).float()
mask = mask / src_len.unsqueeze(1).float()
mean = torch.bmm(mask.unsqueeze(1), emb).squeeze(1)
else:
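With that padding-mask convention, mean pooling has to invert the mask so padded positions get zero weight, which is exactly what the added ~ does. A minimal sketch of the same computation on dummy tensors (not the actual encoder code):

import torch

batch, slen, dim = 2, 5, 4
emb = torch.randn(batch, slen, dim)
src_len = torch.tensor([3, 5])

# True at padded positions (assumed sequence_mask convention).
pad_mask = torch.arange(slen).unsqueeze(0) >= src_len.unsqueeze(1)

# Zero weight on padding, 1/length on real tokens, then average via bmm.
weights = (~pad_mask).float() / src_len.unsqueeze(1).float()  # [B, slen]
mean = torch.bmm(weights.unsqueeze(1), emb).squeeze(1)        # [B, dim]

# Matches an explicit mean over the unpadded positions.
ref = torch.stack([emb[b, : src_len[b]].mean(0) for b in range(batch)])
assert torch.allclose(mean, ref, atol=1e-6)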
5 changes: 2 additions & 3 deletions onmt/encoders/transformer.py
@@ -228,10 +228,9 @@ def from_opt(cls, opt, embeddings):
def forward(self, src, src_len=None):
"""See :func:`EncoderBase.forward()`"""
enc_out = self.embeddings(src)
mask = ~sequence_mask(src_len).unsqueeze(1)
mask = mask.unsqueeze(1)
mask = sequence_mask(src_len).unsqueeze(1).unsqueeze(1)
mask = mask.expand(-1, -1, mask.size(3), -1)
# mask is now (batch x 1 x slen x slen)
# Padding mask is now (batch x 1 x slen x slen)
# 1 to be expanded to number of heads in MHA
# Run the forward pass of every layer of the transformer.

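In the encoder the padding mask is then broadcast to a square (batch x 1 x slen x slen) attention mask, leaving the singleton dimension for the attention heads. A quick shape-only sketch of the expansion (standalone, not the actual forward pass):

import torch

src_len = torch.tensor([3, 5])
slen = 5

pad = torch.arange(slen).unsqueeze(0) >= src_len.unsqueeze(1)  # [B, slen], True = pad
mask = pad.unsqueeze(1).unsqueeze(1)                           # [B, 1, 1, slen]
mask = mask.expand(-1, -1, mask.size(3), -1)                   # [B, 1, slen, slen]
# dim=1 is broadcast against the number of heads inside multi-head attention.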
2 changes: 1 addition & 1 deletion onmt/inference_engine.py
@@ -255,7 +255,7 @@ def translate_batch(self, batch, opt):
def _translate(self, infer_iter):
scores = []
preds = []
for batch in infer_iter:
for batch, bucket_idx in infer_iter:
_scores, _preds = self.translate_batch(batch, self.opt)
scores += _scores
preds += _preds
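Since the data iterator now yields (batch, bucket_idx) pairs (see the dynamic_iterator changes below), every consumer has to unpack the tuple even if it ignores the index. A toy version of that consumer-side loop (translate_one_batch is a hypothetical stand-in, not an OpenNMT method):

def collect_predictions(infer_iter, translate_one_batch):
    # Mirrors the unpacking above; bucket_idx is simply ignored here.
    scores, preds = [], []
    for batch, bucket_idx in infer_iter:
        _scores, _preds = translate_one_batch(batch)
        scores += _scores
        preds += _preds
    return scores, preds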
41 changes: 25 additions & 16 deletions onmt/inputters/dynamic_iterator.py
@@ -163,6 +163,7 @@ def __init__(
raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}")
self.skip_empty_level = skip_empty_level
self.random_shuffler = RandomShuffler()
self.bucket_idx = 0

@classmethod
def from_opt(
@@ -247,6 +248,14 @@ def _tuple_to_json_with_tokIDs(self, tuple_bucket):
bucket.append(numericalize(self.vocabs, example))
return bucket

def _add_indice(self, bucket):
indice = 0
indexed_bucket = []
for ex in bucket:
indexed_bucket.append((ex, indice))
indice += 1
return indexed_bucket

def _bucketing(self):
"""
Add up to bucket_size examples from the mixed corpora according
@@ -258,17 +267,19 @@ def _bucketing(self):
_bucket_size = self.bucket_size_init
else:
_bucket_size = self.bucket_size

for ex in self.mixer:
bucket.append(ex)
if len(bucket) == _bucket_size:
yield self._tuple_to_json_with_tokIDs(bucket)
yield (self._tuple_to_json_with_tokIDs(bucket), self.bucket_idx)
self.bucket_idx += 1
bucket = []
if _bucket_size < self.bucket_size:
_bucket_size += self.bucket_size_increment
else:
_bucket_size = self.bucket_size
if bucket:
yield self._tuple_to_json_with_tokIDs(bucket)
yield (self._tuple_to_json_with_tokIDs(bucket), self.bucket_idx)

def batch_iter(self, data, batch_size, batch_type="sents", batch_size_multiple=1):
"""Yield elements from data in chunks of batch_size,
Expand All @@ -290,11 +301,11 @@ def max_src_tgt(ex):
return len(ex["src"]["src_ids"])

minibatch, maxlen, size_so_far, seen = [], 0, 0, set()
for ex in data:
for ex, indice in data:
src = ex["src"]["src"]
if src not in seen or (self.task != CorpusTask.TRAIN):
seen.add(src)
minibatch.append(ex)
minibatch.append((ex, indice))
nbsents = len(minibatch)
maxlen = max(max_src_tgt(ex), maxlen)
size_so_far = batch_size_fn(nbsents, maxlen)
@@ -315,19 +326,17 @@ def max_src_tgt(ex):
else:
yield minibatch[:-overflowed]
minibatch = minibatch[-overflowed:]
maxlen = max([max_src_tgt(ex) for ex in minibatch])
maxlen = max([max_src_tgt(ex) for ex, ind in minibatch])
size_so_far = batch_size_fn(len(minibatch), maxlen)
seen = set()

if minibatch:
yield minibatch

def __iter__(self):
for bucket in self._bucketing():
# For TRAIN we need to group examples by length
# for faster performance, but otherwise, sequential.
if self.task == CorpusTask.TRAIN:
bucket = sorted(bucket, key=self.sort_key)
for bucket, bucket_idx in self._bucketing():
bucket = self._add_indice(bucket)
bucket = sorted(bucket, key=lambda x: self.sort_key(x[0]))
p_batch = list(
self.batch_iter(
bucket,
@@ -340,13 +349,13 @@ def __iter__(self):
# otherwise sequential
if self.task == CorpusTask.TRAIN:
p_batch = self.random_shuffler(p_batch)
for minibatch in p_batch:
for i, minibatch in enumerate(p_batch):
# for specific case of rnn_packed need to be sorted
# within the batch
if self.task == CorpusTask.TRAIN:
minibatch.sort(key=self.sort_key, reverse=True)
minibatch.sort(key=lambda x: self.sort_key(x[0]), reverse=True)
tensor_batch = tensorify(self.vocabs, minibatch, self.device)
yield tensor_batch
yield (tensor_batch, bucket_idx)


class OnDeviceDatasetIter:
@@ -355,11 +364,11 @@ def __init__(self, data_iter, device):
self.device = device

def __iter__(self):
for tensor_batch in self.data_iter:
for (tensor_batch, bucket_idx) in self.data_iter:
for key in tensor_batch.keys():
if key != "src_ex_vocab":
if key not in ["src_ex_vocab", "cid"]:
tensor_batch[key] = tensor_batch[key].to(self.device)
yield tensor_batch
yield (tensor_batch, bucket_idx)


def build_dynamic_dataset_iter(
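Taken together, the iterator changes implement the "new indexing": each example is tagged with its position inside its bucket before length-sorting, the tag travels through batching, and every batch is yielded together with its bucket_idx so downstream code can restore the original corpus order. A compact sketch of that bookkeeping under simplified assumptions (plain strings instead of onmt examples, string length as sort key):

def make_batches(bucket, bucket_idx, batch_size):
    # 1) tag each example with its position inside the bucket: (ex, indice)
    indexed = [(ex, i) for i, ex in enumerate(bucket)]
    # 2) sort by length for padding efficiency, keeping the tag attached
    indexed.sort(key=lambda pair: len(pair[0]))
    # 3) cut into minibatches; every element keeps its (ex, indice) tag
    for start in range(0, len(indexed), batch_size):
        yield indexed[start : start + batch_size], bucket_idx


def iterate_buckets(examples, bucket_size=4, batch_size=2):
    # Fill fixed-size buckets and emit batches tagged with their bucket index.
    bucket, bucket_idx = [], 0
    for ex in examples:
        bucket.append(ex)
        if len(bucket) == bucket_size:
            yield from make_batches(bucket, bucket_idx, batch_size)
            bucket, bucket_idx = [], bucket_idx + 1
    if bucket:
        yield from make_batches(bucket, bucket_idx, batch_size)


for batch, b_idx in iterate_buckets(["a bb ccc", "dd", "e f g h", "ii jj", "k"]):
    print(b_idx, batch)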
(diffs for the remaining 14 changed files not loaded)
