Commit 6820e2e: various fixes
vince62s committed Feb 19, 2024 (1 parent: a02330a)
Showing 5 changed files with 138 additions and 26 deletions.
eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py: 56 additions & 11 deletions
@@ -1,24 +1,70 @@
 import copy
 import numpy as np
 import time
+import re
 from onmt.inference_engine import InferenceEnginePY
 import onmt.opts as opts
 from onmt.utils.logging import init_logger
 from onmt.utils.parse import ArgumentParser
 from onmt.utils.misc import use_gpu, set_random_seed


+def wikitext_detokenizer(line):
+    string = line
+    # contractions
+    string = string.replace("s '", "s'")
+    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
+    # number separators
+    string = string.replace(" @-@ ", "-")
+    string = string.replace(" @,@ ", ",")
+    string = string.replace(" @.@ ", ".")
+    # punctuation
+    string = string.replace(" : ", ": ")
+    string = string.replace(" ; ", "; ")
+    string = string.replace(" . ", ". ")
+    string = string.replace(" ! ", "! ")
+    string = string.replace(" ? ", "? ")
+    string = string.replace(" , ", ", ")
+    # double brackets
+    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
+    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
+    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
+    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
+    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
+    # miscellaneous
+    string = string.replace("= = = =", "====")
+    string = string.replace("= = =", "===")
+    string = string.replace("= =", "==")
+    string = string.replace(" " + chr(176) + " ", chr(176))
+    string = string.replace(" \n", "\n")
+    string = string.replace("\n ", "\n")
+    string = string.replace(" N ", " 1 ")
+    string = string.replace(" 's", "'s")
+    return string
+
+
 def tokenize_dataset(opt, context_length):
     print("Tokenization...")
     # Clean and Concat the dataset
-    x = open(opt.src, "r").readlines()
-    xx = [_x for _x in x if _x != " \n"]
-    from onmt.transforms.tokenize import SentencePieceTransform
-
-    tokenizer = SentencePieceTransform(opt)
+    xx = open(opt.src, "r").readlines()
+    if "sentencepiece" in opt.transforms:
+        from onmt.transforms.tokenize import SentencePieceTransform
+
+        tokenizer = SentencePieceTransform(opt)
+    elif "onmt_tokenize" in opt.transforms:
+        from onmt.transforms.tokenize import ONMTTokenizerTransform
+
+        tokenizer = ONMTTokenizerTransform(opt)
+    else:
+        raise ValueError("No valid tokenizer found")
     tokenizer.warm_up()
-    tokens = tokenizer._tokenize(xx)
-    print("Done !")
+    print("warmup done")
+    # joiner = tokenizer._tokenize("\n")
+    tokens = []
+    for x in xx:
+        tokens += tokenizer._tokenize([wikitext_detokenizer(x)])
+        # tokens += tokenizer._tokenize([x])
+    print("Tokenization Done !")
     return tokens
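
For orientation (not part of the commit): the raw WikiText files escape punctuation inside numbers as " @.@ ", " @,@ " and " @-@ ", and pad punctuation and brackets with spaces; the new wikitext_detokenizer reverses exactly that before tokenization. A minimal sketch of its effect, assuming the function above is in scope:

    # Illustrative input, not taken from the dataset.
    raw = "it cost $ 3 @.@ 5 million ( approximately ) , he said"
    print(wikitext_detokenizer(raw))
    # -> "it cost $ 3.5 million (approximately), he said"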


@@ -38,7 +84,7 @@ def evaluate(opt):
     set_random_seed(opt.seed, use_gpu(opt))

     # Tokenize the dataset.
-    opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
+    opt.src = "eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
     tokens = tokenize_dataset(opt, context_length=512)

     # Build the translator (along with the model.
@@ -47,8 +93,8 @@ def evaluate(opt):
     engine = InferenceEnginePY(engine_opt)

     # Score the dataset.
-    stride = 512
-    max_seq_length = 4096
+    stride = 256
+    max_seq_length = 512

     seq_len = len(tokens)
     src = []
@@ -75,8 +121,7 @@
     end_time = time.time()
     logger.info("total run time %.2f" % (end_time - start_time))
     logger.info(
-        "wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f"  # noqa: E501
-        % (ppl)
+        "wikitext-2 perplexity with rolling likelihood: %.2f" % (ppl)  # noqa: E501
     )


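To make the new settings concrete: with max_seq_length = 512 and stride = 256, each 512-token window slides forward 256 tokens, and only the tokens not already scored are counted, conditioned on the overlap with the previous window. A minimal sketch of that rolling-likelihood perplexity, assuming a hypothetical nll_fn(window, n_scored) that returns the summed natural-log NLL of the last n_scored tokens of the window (not the engine's actual API):

    import math

    def rolling_ppl(nll_fn, tokens, max_seq_length=512, stride=256):
        # Slide a fixed-size window over the token stream; each step scores
        # only the tokens that entered since the previous window.
        total_nll, total_scored, prev_end = 0.0, 0, 0
        for begin in range(0, len(tokens), stride):
            end = min(begin + max_seq_length, len(tokens))
            n_scored = end - prev_end
            total_nll += nll_fn(tokens[begin:end], n_scored)
            total_scored += n_scored
            prev_end = end
            if end == len(tokens):
                break
        return math.exp(total_nll / total_scored)
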
onmt/inputters/text_corpus.py: 72 additions & 8 deletions
@@ -38,6 +38,63 @@ def exfile_open(filename, *args, **kwargs):
         _file.close()


+class BlockwiseCorpus(object):
+    """A corpus class for reading a single file block by block."""
+
+    def __init__(self, name, file_path, block_size=4096):
+        """Initialize file path and block size."""
+        self.id = name
+        self.file_path = file_path
+        self.block_size = block_size
+
+    def load(self, offset=0, stride=1):
+        """
+        Load file and iterate by blocks.
+        `offset` and `stride` allow iterating only on every
+        `stride` block, starting from `offset`.
+        """
+
+        def make_ex(block_content):
+            example = {
+                "src": block_content,
+                "tgt": block_content,
+                "src_original": block_content,
+                "tgt_original": block_content,
+            }
+            return example
+
+        with open(self.file_path, mode="r", encoding="utf-8") as file:
+            block_content = ""
+            block_index = 0
+
+            while True:
+                chunk = file.read(self.block_size)
+                if not chunk:
+                    break
+
+                if (block_index // stride) % stride == offset:
+                    block_content += chunk
+
+                if len(chunk) < self.block_size:
+                    # Reached end of file
+                    yield make_ex(block_content)
+                    break
+
+                if len(block_content) >= self.block_size:
+                    yield make_ex(block_content)
+                    block_content = ""
+                block_index += 1
+
+    def __str__(self):
+        cls_name = type(self).__name__
+        return (
+            f"{cls_name}({self.id}, {self.file_path}, {self.file_path}"
+            f"align={None}, "
+            f"n_src_feats={0}, "
+            f'src_feats_defaults="{None}")'
+        )
+
+
 class ParallelCorpus(object):
     """A parallel corpus file pair that can be loaded to iterate."""

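As a quick orientation (not part of the diff), the new class can be exercised on its own; the corpus name and path below are placeholders:

    # Hypothetical usage sketch of BlockwiseCorpus.
    corpus = BlockwiseCorpus("wiki", "wiki.train.raw", block_size=4096)
    for ex in corpus.load(offset=0, stride=1):
        # Each example maps src/tgt (and the *_original keys) to the same
        # text block, which is the shape a language-modeling task expects.
        print(len(ex["src"]))
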
@@ -117,14 +174,21 @@ def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
     if task == CorpusTask.TRAIN:
         for corpus_id, corpus_dict in opts.data.items():
             if corpus_id != CorpusName.VALID:
-                corpora_dict[corpus_id] = ParallelCorpus(
-                    corpus_id,
-                    corpus_dict["path_src"],
-                    corpus_dict["path_tgt"],
-                    corpus_dict["path_align"],
-                    n_src_feats=opts.n_src_feats,
-                    src_feats_defaults=opts.src_feats_defaults,
-                )
+                if corpus_dict.get("path_txt", None) is None:
+                    corpora_dict[corpus_id] = ParallelCorpus(
+                        corpus_id,
+                        corpus_dict["path_src"],
+                        corpus_dict["path_tgt"],
+                        corpus_dict["path_align"],
+                        n_src_feats=opts.n_src_feats,
+                        src_feats_defaults=opts.src_feats_defaults,
+                    )
+                else:
+                    corpora_dict[corpus_id] = BlockwiseCorpus(
+                        corpus_id,
+                        corpus_dict["path_txt"],
+                        block_size=8192,  # number of characters
+                    )
     elif task == CorpusTask.VALID:
         if CorpusName.VALID in opts.data.keys():
             corpora_dict[CorpusName.VALID] = ParallelCorpus(
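
At training time each corpus entry thus dispatches on whether path_txt is set. A sketch of what two entries mirroring opts.data might look like (corpus names and paths are placeholders, not from the commit):

    data = {
        "wiki": {"path_txt": "wiki.train.raw"},  # -> BlockwiseCorpus, 8192-char blocks
        "europarl": {  # -> ParallelCorpus, as before
            "path_src": "train.src",
            "path_tgt": "train.tgt",
            "path_align": None,
        },
    }
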
onmt/inputters/text_utils.py: 0 additions & 1 deletion
@@ -149,7 +149,6 @@ def numericalize(vocabs, example):
             for fv, feat in zip(vocabs["src_feats"], example["src"]["feats"]):
                 numeric_feats.append(fv(feat.split(" ")))
             numeric["src"]["feats"] = numeric_feats
-
     return numeric


onmt/opts.py: 1 addition & 2 deletions
@@ -872,8 +872,7 @@ def model_opts(parser):
     group.add(
         "--rotary_interleave",
         "-rotary_interleave",
-        type=bool,
-        default=True,
+        action="store_true",
         help="Interleave the head dimensions when rotary"
         " embeddings are applied."
         " Otherwise the head dimensions are sliced in half."
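The switch from type=bool to action="store_true" fixes a classic argparse pitfall: type=bool applies bool() to the raw string, so any non-empty value, including "False", parses as True. A standalone sketch of the difference (not OpenNMT-py's actual parser setup):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--buggy", type=bool, default=True)  # old pattern
    parser.add_argument("--fixed", action="store_true")  # new pattern

    args = parser.parse_args(["--buggy", "False"])
    print(args.buggy)  # True -- bool("False") is truthy, so the flag can never be unset
    print(args.fixed)  # False unless --fixed is passed

Note that store_true also flips the option's default from True to False.
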
onmt/utils/parse.py: 9 additions & 4 deletions
@@ -39,9 +39,10 @@ def _validate_data(cls, opt):
             # Check path
             path_src = corpus.get("path_src", None)
             path_tgt = corpus.get("path_tgt", None)
-            if path_src is None:
+            path_txt = corpus.get("path_txt", None)
+            if path_src is None and path_txt is None:
                 raise ValueError(
-                    f"Corpus {cname} src path is required."
+                    f"Corpus {cname} src/txt path is required."
                     "tgt path is also required for non language"
                     " modeling tasks."
                 )
@@ -57,8 +58,12 @@
                 corpus["path_tgt"] = path_src
                 corpora[cname] = corpus
                 path_tgt = path_src
-            cls._validate_file(path_src, info=f"{cname}/path_src")
-            cls._validate_file(path_tgt, info=f"{cname}/path_tgt")
+            if path_src is not None:
+                cls._validate_file(path_src, info=f"{cname}/path_src")
+            if path_txt is not None:
+                cls._validate_file(path_txt, info=f"{cname}/path_txt")
+            if path_tgt is not None:
+                cls._validate_file(path_tgt, info=f"{cname}/path_tgt")
             path_align = corpus.get("path_align", None)
             if path_align is None:
                 if hasattr(opt, "lambda_align") and opt.lambda_align > 0.0:
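
Net effect of the two parse.py hunks: a corpus entry is now accepted with either path_src or path_txt, and only the paths actually present are checked on disk. A hedged illustration with simplified corpus dicts (paths are placeholders):

    corpus_lm = {"path_txt": "wiki.train.raw"}  # accepted: language modeling
    corpus_mt = {"path_src": "train.src", "path_tgt": "train.tgt"}  # accepted
    corpus_bad = {"path_align": "train.align"}  # raises: "Corpus ... src/txt path is required."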
