
Commit

better code, more tests
vince62s committed Oct 26, 2023
1 parent 5dc76e5 commit ef7a6fd
Showing 13 changed files with 229 additions and 70 deletions.
88 changes: 88 additions & 0 deletions .github/workflows/push.yml
@@ -169,8 +169,96 @@ jobs:
-scoring_debug "true" \
-tensorboard_log_dir /tmp/logs_dynamic-scoring_and_copy \
-dump_preds /tmp/dump_preds \
-position_encoding \
-copy_attn
python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_copy -tensorboard_checks valid_metrics
- name : Test Transformer training and validation with dynamic scoring and maxrelative
run: |
python3 train.py \
-config data/data.yaml \
-src_vocab /tmp/onmt.vocab.src \
-tgt_vocab /tmp/onmt.vocab.tgt \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-encoder_type transformer \
-decoder_type transformer \
-layers 4 \
-word_vec_size 16 \
-hidden_size 16 \
-num_workers 0 -bucket_size 1024 \
-heads 2 \
-transformer_ff 64 \
-num_workers 0 -bucket_size 1024 \
-accum_count 2 4 8 \
-accum_steps 0 15000 30000 \
-save_model /tmp/onmt.model \
-train_steps 10 -valid_steps 5 \
-report_every 2 \
-valid_metrics "BLEU" "TER" \
-tensorboard "true" \
-scoring_debug "true" \
-tensorboard_log_dir /tmp/logs_dynamic-scoring_and_relative \
-dump_preds /tmp/dump_preds \
-max_relative_positions 8
python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_relative -tensorboard_checks valid_metrics
- name : Test Transformer training and validation with dynamic scoring and rotary
run: |
python3 train.py \
-config data/data.yaml \
-src_vocab /tmp/onmt.vocab.src \
-tgt_vocab /tmp/onmt.vocab.tgt \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-encoder_type transformer \
-decoder_type transformer \
-layers 4 \
-word_vec_size 16 \
-hidden_size 16 \
-num_workers 0 -bucket_size 1024 \
-heads 2 \
-transformer_ff 64 \
-num_workers 0 -bucket_size 1024 \
-accum_count 2 4 8 \
-accum_steps 0 15000 30000 \
-save_model /tmp/onmt.model \
-train_steps 10 -valid_steps 5 \
-report_every 2 \
-valid_metrics "BLEU" "TER" \
-tensorboard "true" \
-scoring_debug "true" \
-tensorboard_log_dir /tmp/logs_dynamic-scoring_and_rotary \
-dump_preds /tmp/dump_preds \
-max_relative_positions -1
python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_rotary -tensorboard_checks valid_metrics
- name : Test Transformer training and validation with dynamic scoring and alibi
run: |
python3 train.py \
-config data/data.yaml \
-src_vocab /tmp/onmt.vocab.src \
-tgt_vocab /tmp/onmt.vocab.tgt \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-encoder_type transformer \
-decoder_type transformer \
-layers 4 \
-word_vec_size 16 \
-hidden_size 16 \
-num_workers 0 -bucket_size 1024 \
-heads 2 \
-transformer_ff 64 \
-num_workers 0 -bucket_size 1024 \
-accum_count 2 4 8 \
-accum_steps 0 15000 30000 \
-save_model /tmp/onmt.model \
-train_steps 10 -valid_steps 5 \
-report_every 2 \
-valid_metrics "BLEU" "TER" \
-tensorboard "true" \
-scoring_debug "true" \
-tensorboard_log_dir /tmp/logs_dynamic-scoring_and_alibi \
-dump_preds /tmp/dump_preds \
-max_relative_positions -2
python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_alibi -tensorboard_checks valid_metrics
- name: Test LM training
run: |
python train.py \
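
Each of the scoring jobs above ends by running onmt/tests/test_events.py on the tensorboard log directory to confirm that validation metrics were actually logged. As a rough illustration only (the real script is not shown in this diff, and the tag naming below is an assumption), such a check could scan the event files like this:

# Hedged sketch, not the actual onmt/tests/test_events.py: walk a tensorboard
# logdir and assert that validation-metric scalars were written. Matching tags
# on the "valid" substring is an assumption for illustration.
import os
import sys

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


def has_valid_metric_scalars(logdir: str) -> bool:
    for root, _dirs, files in os.walk(logdir):
        if not any(name.startswith("events.out.tfevents") for name in files):
            continue
        acc = EventAccumulator(root)
        acc.Reload()
        if any("valid" in tag.lower() for tag in acc.Tags().get("scalars", [])):
            return True
    return False


if __name__ == "__main__":
    sys.exit(0 if has_valid_metric_scalars(sys.argv[1]) else 1)
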
7 changes: 5 additions & 2 deletions onmt/bin/translate.py
@@ -9,6 +9,7 @@
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed
from torch.profiler import profile, record_function, ProfilerActivity
import time


def translate(opt):
@@ -52,13 +53,15 @@ def main():
parser = _get_parser()
opt = parser.parse_args()
if opt.profile:

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
with record_function("Translate"):
translate(opt)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=40))
else:
init_time = time.time()
translate(opt)
print("Time w/o python interpreter load/terminate: ", time.time() - init_time)


if __name__ == "__main__":
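
The profiling branch above wraps translate(opt) in the standard torch.profiler context managers, and the new else-branch simply brackets the call with wall-clock timestamps. A standalone sketch of the same pattern, with a placeholder workload instead of OpenNMT-py's translate(opt):

# Sketch of the profiling/timing pattern used above; run_translation() is a
# placeholder workload, not part of OpenNMT-py.
import time

import torch
from torch.profiler import profile, record_function, ProfilerActivity


def run_translation():
    time.sleep(0.1)  # stand-in for the real translation loop


activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    with record_function("Translate"):
        run_translation()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=40))

init_time = time.time()
run_translation()
print("Time w/o python interpreter load/terminate: ", time.time() - init_time)
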
14 changes: 6 additions & 8 deletions onmt/decoders/transformer.py
@@ -581,20 +581,18 @@ def forward(self, tgt, enc_out=None, step=None, **kwargs):
{"keys": torch.tensor([]), "values": torch.tensor([])},
)

emb = self.embeddings(tgt, step=step)
dec_out = emb
assert emb.dim() == 3 # len x batch x embedding_dim
dec_out = self.embeddings(tgt, step=step)

pad_idx = self.embeddings.word_padding_idx
src_lens = kwargs["src_len"]
src_len = kwargs["src_len"]
src_max_len = self.state["src"].shape[1]
src_pad_mask = ~sequence_mask(src_lens, src_max_len) # [B x slen]
src_pad_mask = src_pad_mask.unsqueeze(1) # [B x 1 x slen]
src_pad_mask = sequence_mask(src_len, src_max_len).unsqueeze(
1
) # [B x 1 x slen]
tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]

with_align = kwargs.pop("with_align", False)
return_attn = kwargs.pop("return_attn", False)
return_attn = with_align or self._copy or return_attn
return_attn = with_align or self._copy or kwargs.pop("return_attn", False)

attn_aligns = []

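
This hunk, together with the encoder and attention hunks below, drops the ~ negations around sequence_mask, which only makes sense if sequence_mask now returns True at padding positions rather than at valid ones. A self-contained sketch of that assumed convention (this is not onmt.utils.misc.sequence_mask itself):

# Illustrative sequence_mask under the convention this commit appears to assume:
# True where a position is padding (index >= length).
import torch


def sequence_mask(lengths, max_len=None):
    max_len = max_len or int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)  # (batch, max_len)


src_len = torch.tensor([3, 1])
print(sequence_mask(src_len, 4))
# tensor([[False, False, False,  True],
#         [False,  True,  True,  True]])
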
2 changes: 1 addition & 1 deletion onmt/encoders/mean_encoder.py
@@ -30,7 +30,7 @@ def forward(self, src, src_len=None):

if src_len is not None:
# we avoid padding while mean pooling
mask = sequence_mask(src_len).float()
mask = (~sequence_mask(src_len)).float()
mask = mask / src_len.unsqueeze(1).float()
mean = torch.bmm(mask.unsqueeze(1), emb).squeeze(1)
else:
5 changes: 2 additions & 3 deletions onmt/encoders/transformer.py
@@ -228,10 +228,9 @@ def from_opt(cls, opt, embeddings):
def forward(self, src, src_len=None):
"""See :func:`EncoderBase.forward()`"""
enc_out = self.embeddings(src)
mask = ~sequence_mask(src_len).unsqueeze(1)
mask = mask.unsqueeze(1)
mask = sequence_mask(src_len).unsqueeze(1).unsqueeze(1)
mask = mask.expand(-1, -1, mask.size(3), -1)
# mask is now (batch x 1 x slen x slen)
# Padding mask is now (batch x 1 x slen x slen)
# 1 to be expanded to number of heads in MHA
# Run the forward pass of every layer of the transformer.

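
The two unsqueeze calls plus the expand grow the per-sequence padding mask into the (batch x 1 x slen x slen) shape mentioned in the comment, ready to broadcast over attention heads. A quick shape check with toy values, reusing the padding-True convention sketched above:

# Shape check only; pad mimics the padding-True mask sketched earlier.
import torch

src_len = torch.tensor([3, 1])
slen = 4
pad = torch.arange(slen).unsqueeze(0) >= src_len.unsqueeze(1)  # (batch, slen)
mask = pad.unsqueeze(1).unsqueeze(1)                           # (batch, 1, 1, slen)
mask = mask.expand(-1, -1, mask.size(3), -1)                   # (batch, 1, slen, slen)
print(mask.shape)  # torch.Size([2, 1, 4, 4])
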
51 changes: 17 additions & 34 deletions onmt/inputters/text_corpus.py
@@ -172,44 +172,22 @@ def __init__(
self.stride = stride
self.offset = offset

def _tokenize(self, stream):
for example in stream:
def _process(self, stream):
for i, example in enumerate(stream):
example["src"] = example["src"].strip("\n").split()
example["src_original"] = example["src_original"].strip("\n").split()
if "src_feats" in example:
example["src_feats"] = [
feat.strip("\n").split() for feat in example["src_feats"]
]
if example["tgt"] is not None:
example["tgt"] = example["tgt"].strip("\n").split()
example["tgt_original"] = example["tgt_original"].strip("\n").split()
if "align" in example:
example["align"] = example["align"].strip("\n").split()
yield example

def _transform(self, stream):
for example in stream:
# NOTE: moved to dynamic_iterator.py cf process()
# item = self.transform.apply(
# example, is_train=self.infinitely, corpus_name=self.cid)
item = (example, self.transform, self.cid)
if item is not None:
yield item
report_msg = self.transform.stats()
if report_msg != "":
logger.info(
"* Transform statistics for {}({:.2f}%):\n{}\n".format(
self.cid, 100 / self.stride, report_msg
)
)

def _add_index(self, stream):
for i, item in enumerate(stream):
example = item[0]
line_number = i * self.stride + self.offset
example["cid_line_number"] = line_number
example["cid"] = self.cid
if "align" in example:
example["align"] = example["align"].strip("\n").split()
if example["tgt"] is not None:
example["tgt"] = example["tgt"].strip("\n").split()
example["tgt_original"] = example["tgt_original"].strip("\n").split()
if (
len(example["src"]) == 0
or len(example["tgt"]) == 0
@@ -222,16 +200,21 @@ def _add_index(self, stream):
elif self.skip_empty_level == "warning":
logger.warning(empty_msg)
if len(example["src"]) == 0 and len(example["tgt"]) == 0:
yield item
yield (example, self.transform, self.cid)
continue
yield item
yield (example, self.transform, self.cid)
report_msg = self.transform.stats()
if report_msg != "":
logger.info(
"* Transform statistics for {}({:.2f}%):\n{}\n".format(
self.cid, 100 / self.stride, report_msg
)
)

def __iter__(self):
corpus_stream = self.corpus.load(stride=self.stride, offset=self.offset)
tokenized_corpus = self._tokenize(corpus_stream)
transformed_corpus = self._transform(tokenized_corpus)
indexed_corpus = self._add_index(transformed_corpus)
yield from indexed_corpus
corpus = self._process(corpus_stream)
yield from corpus


def build_corpora_iters(
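
The three chained generators (_tokenize, _transform, _add_index) are folded into the single _process pass above, which tokenizes, stamps each example with its corpus id and line number, filters empty pairs, and yields (example, transform, cid) tuples. A stripped-down sketch of that single-pass shape (field names follow the diff, but this is not the library's iterator class and the empty-pair handling is simplified):

# Illustrative single-pass generator mirroring the refactor above; the
# filtering policy is simplified to a plain skip.
def process(stream, transform, cid, stride=1, offset=0):
    for i, example in enumerate(stream):
        example["src"] = example["src"].strip("\n").split()
        if example.get("tgt") is not None:
            example["tgt"] = example["tgt"].strip("\n").split()
        example["cid"] = cid
        example["cid_line_number"] = i * stride + offset
        if not example["src"] or not example.get("tgt"):
            continue  # simplified: drop empty pairs instead of warning/erroring
        yield (example, transform, cid)


examples = [{"src": "a b c\n", "tgt": "x y\n"}, {"src": "\n", "tgt": "\n"}]
for ex, _transform, corpus_id in process(iter(examples), transform=None, cid="corpus_1"):
    print(corpus_id, ex["cid_line_number"], ex["src"], ex["tgt"])
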
2 changes: 2 additions & 0 deletions onmt/model_builder.py
@@ -123,6 +123,7 @@ def load_test_model(opt, device_id=0, model_path=None):
model_opt.attention_dropout = (
0.0 # required to force no dropout at inference with flash
)

model = build_base_model(model_opt, vocabs)

precision = torch.float32
@@ -162,6 +163,7 @@ def load_test_model(opt, device_id=0, model_path=None):
for name, module in model.named_modules():
if hasattr(module, "dropout_p"):
module.dropout_p = 0.0

return vocabs, model, model_opt


2 changes: 1 addition & 1 deletion onmt/modules/global_attention.py
@@ -169,7 +169,7 @@ def forward(self, src, enc_out, src_len=None, coverage=None):
align = self.score(src, enc_out)

if src_len is not None:
mask = sequence_mask(src_len, max_len=align.size(-1))
mask = ~sequence_mask(src_len, max_len=align.size(-1))
mask = mask.unsqueeze(1) # Make it broadcastable.
align.masked_fill_(~mask, -float("inf"))

2 changes: 1 addition & 1 deletion onmt/modules/multi_headed_attn.py
@@ -53,7 +53,7 @@ def relative_matmul(x: Tensor, z: Tensor, transpose: bool) -> Tensor:
https://arxiv.org/pdf/1803.02155.pdf
x shape [batch_size x heads x q_len x k_len]
"""
batch_size, heads, length = x.size()
batch_size, heads, length, _ = x.size()
x_t = x.permute(2, 0, 1, 3)
x_t_r = x_t.contiguous().view(length, heads * batch_size, -1)
if transpose:
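
The one-character change above matters because x is 4-D, so unpacking its size into three names raises a ValueError; the added underscore absorbs the trailing dimension. A tiny reproduction with illustrative shapes:

# Reproduction of the unpacking bug fixed above (shapes are illustrative).
import torch

x = torch.zeros(2, 4, 5, 8)  # (batch_size, heads, q_len, last dim)
try:
    batch_size, heads, length = x.size()   # old line: too many values to unpack
except ValueError as err:
    print("old unpacking fails:", err)
batch_size, heads, length, _ = x.size()    # fixed line
print(batch_size, heads, length)           # 2 4 5
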
