From 82cd43a5704bfbbfc38bfefa8a8ec4c8532b2493 Mon Sep 17 00:00:00 2001 From: Tommi Nieminen Date: Sat, 12 Oct 2024 21:42:29 +0300 Subject: [PATCH] more rat work for nodalida --- Snakefile | 2 +- data.smk | 26 ++- eval.smk | 22 ++- pipeline/data/clean-tcdev.py | 79 +++++++++ pipeline/eval/score-domeval.py | 155 ++++++++++++------ .../{eval-domains.sh => translate-domeval.sh} | 2 +- pipeline/rat/augment.py | 130 +++++++++++++++ pipeline/rat/build_index.sh | 6 +- pipeline/rat/find_matches.sh | 4 +- pipeline/rat/get_matches.py | 43 ++++- pipeline/train/ensemble.py | 44 +++++ profiles/slurm-lumi/config.yaml | 4 +- rat.smk | 96 +++++++---- train.smk | 34 +++- 14 files changed, 522 insertions(+), 125 deletions(-) create mode 100644 pipeline/data/clean-tcdev.py rename pipeline/eval/{eval-domains.sh => translate-domeval.sh} (98%) create mode 100644 pipeline/rat/augment.py create mode 100644 pipeline/train/ensemble.py diff --git a/Snakefile b/Snakefile index 0fbca7bf6..bad0d2ab2 100644 --- a/Snakefile +++ b/Snakefile @@ -47,7 +47,7 @@ use rule * from rat as * vocab_config = { "spm-train": f"{marian_dir}/spm_train", - "user-defined-symbols":"FUZZY_BREAK", + "user-defined-symbols": ",".join(["FUZZY_BREAK","SRC_FUZZY_BREAK"] + [f"FUZZY_BREAK_{bucket}" for bucket in range(0,10)]), "spm-sample-size": 1000000, "spm-character-coverage": 1.0 } diff --git a/data.smk b/data.smk index b30f757b4..e257dc842 100644 --- a/data.smk +++ b/data.smk @@ -65,15 +65,15 @@ rule baseline_preprocessing: input: train_source="{project_name}/{src}-{trg}/{preprocessing}/train.{src}.gz", train_target="{project_name}/{src}-{trg}/{preprocessing}/train.{trg}.gz", - dev_source="{project_name}/{src}-{trg}/{preprocessing}/dev.{src}.gz", - dev_target="{project_name}/{src}-{trg}/{preprocessing}/dev.{trg}.gz", + dev_source="{project_name}/{src}-{trg}/{preprocessing}/cleandev.{src}.gz", + dev_target="{project_name}/{src}-{trg}/{preprocessing}/cleandev.{trg}.gz", 
eval_source="{project_name}/{src}-{trg}/{preprocessing}/eval.{src}.gz", eval_target="{project_name}/{src}-{trg}/{preprocessing}/eval.{trg}.gz" output: - train_source="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/train.{src}.gz", - train_target="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/train.{trg}.gz", - dev_source="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/dev.{src}.gz", - dev_target="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/dev.{trg}.gz", + train_source="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/train-train.{src}.gz", + train_target="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/train-train.{trg}.gz", + dev_source="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/train-cleandev.{src}.gz", + dev_target="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/train-cleandev.{trg}.gz", eval_source="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/eval.{src}.gz", eval_target="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/eval.{trg}.gz" params: @@ -81,7 +81,10 @@ rule baseline_preprocessing: output_dir="{project_name}/{src}-{trg}/{preprocessing}/baseline_preprocessing_{max_dev_sents}/" shell: """ - ln {params.input_dir}/{{eval,train}}.*.gz {params.output_dir} >> {log} 2>&1 && \ + ln {input.train_source} {output.train_source} >> {log} 2>&1 && \ + ln {input.train_target} {output.train_target} >> {log} 2>&1 && \ + ln {input.eval_source} {output.eval_source} >> {log} 2>&1 && \ + ln {input.eval_target} {output.eval_target} >> {log} 2>&1 && \ {{ pigz -dc {input.dev_source} | head -n {wildcards.max_dev_sents} | pigz -c > {output.dev_source} ; }} 2>> {log} && \ {{ pigz -dc {input.dev_target} | head -n 
{wildcards.max_dev_sents} | pigz -c > {output.dev_target} ; }} 2>> {log} """ @@ -97,6 +100,8 @@ rule subset_corpus: input: train_source="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/train.{src}.gz", train_target="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/train.{trg}.gz", + dev_source="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/dev.{src}.gz", + dev_target="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/dev.{trg}.gz", domeval_src="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/domeval.{src}.gz", domeval_trg="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/domeval.{trg}.gz", domeval_ids="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/domeval.ids.gz" @@ -105,6 +110,8 @@ rule subset_corpus: train_target="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/train.{trg}.gz", dev_source="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/dev.{src}.gz", dev_target="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/dev.{trg}.gz", + cleandev_source="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/cleandev.{src}.gz", + cleandev_target="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/cleandev.{trg}.gz", eval_source="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/eval.{src}.gz", eval_target="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/eval.{trg}.gz", all_filtered_source="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/all_filtered.{src}.gz", @@ -114,7 +121,7 @@ rule 
subset_corpus: domeval_ids="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/domeval.ids.gz" params: input_dir="{project_name}/{src}-{trg}/{download_tc_dir}/", - output_dir="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/" + output_dir="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/subset_{max_train_sents}/", shell: """ ln {params.input_dir}/{{eval,dev}}.*.gz {params.output_dir} >> {log} 2>&1 && \ @@ -124,7 +131,8 @@ rule subset_corpus: ln {input.domeval_trg} {output.domeval_trg} >> {log} 2>&1 && \ ln {input.domeval_ids} {output.domeval_ids} >> {log} 2>&1 && \ {{ pigz -dc {input.train_source} | head -n {wildcards.max_train_sents}B | pigz -c > {output.train_source} ; }} 2>> {log} && \ - {{ pigz -dc {input.train_target} | head -n {wildcards.max_train_sents}B | pigz -c > {output.train_target} ; }} 2>> {log} + {{ pigz -dc {input.train_target} | head -n {wildcards.max_train_sents}B | pigz -c > {output.train_target} ; }} 2>> {log} && \ + python pipeline/data/clean-tcdev.py --source {input.dev_source} --target {input.dev_target} --prefix {params.output_dir}clean 2>> {log} """ rule use_custom_corpus: diff --git a/eval.smk b/eval.smk index 5d71a36a6..c1d84ff5e 100644 --- a/eval.smk +++ b/eval.smk @@ -52,6 +52,7 @@ checkpoint translate_domeval: log: "{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}/{train_vocab}/{train_model}/eval/translate_domeval.log" conda: None container: None + resources: gpu=gpus_num envmodules: "LUMI/22.08", "partition/G", @@ -59,19 +60,18 @@ checkpoint translate_domeval: threads: 1 priority: 50 wildcard_constraints: - min_score="0\.\d+", - model="[\w-]+" + min_score="0\.\d+" input: decoder=ancient(config["marian-decoder"]), - domain_index_src=lambda wildcards: 
expand("{{project_name}}/{{src}}-{{trg}}/{{download_tc_dir}}/extract_tc_scored_{{min_score}}/{{preprocessing}}/{domain}-domeval.{{src}}.gz", domain=find_domain_sets(wildcards, checkpoints.extract_tc_scored)), - train_index_src="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}/train-domeval.{src}.gz", - all_filtered_index_src="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}/all_filtered-domeval.{src}.gz" + domain_src=lambda wildcards: expand("{{project_name}}/{{src}}-{{trg}}/{{download_tc_dir}}/extract_tc_scored_{{min_score}}/{{preprocessing}}/{domain}-domeval.{{src}}.gz", domain=find_domain_sets(wildcards, checkpoints.extract_tc_scored)), + train_src="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}/train-domeval.{src}.gz", + all_filtered_src="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}/all_filtered-domeval.{src}.gz", + decoder_config=f'{{project_name}}/{{src}}-{{trg}}/{{download_tc_dir}}/extract_tc_scored_{{min_score}}/{{preprocessing}}/{{train_vocab}}/{{train_model}}/final.model.npz.best-{config["best-model-metric"]}.npz.decoder.yml' output: output_dir=directory("{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}/{train_vocab}/{train_model}/eval/domeval") params: - domain_index_src_dir="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}", - decoder_config=f'{{project_name}}/{{src}}-{{trg}}/{{download_tc_dir}}/extract_tc_scored_{{min_score}}/{{preprocessing}}/{{train_vocab}}/{{train_model}}/final.model.npz.best-{config["best-model-metric"]}.npz.decoder.yml' - shell: '''pipeline/eval/eval-domains.sh {params.domain_index_src_dir} {output.output_dir} {src} {trg} {input.decoder} params.decoder_config --mini-batch 128 --workspace 20000 >> {log} 2>&1''' + 
domain_index_src_dir="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}" + shell: '''pipeline/eval/translate-domeval.sh {params.domain_index_src_dir} {output.output_dir} {wildcards.src} {wildcards.trg} {input.decoder} {input.decoder_config} --mini-batch 128 --workspace 20000 >> {log} 2>&1''' # This evaluates the translations generated with translate_domeval rule eval_domeval: @@ -89,7 +89,11 @@ rule eval_domeval: output: report('{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}/{train_vocab}/{train_model}/eval/domeval.done', category='evaluation', subcategory='{model}', caption='reports/evaluation.rst') - shell: '''touch {output} >> {log} 2>&1''' + params: + input_dir="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}/{train_vocab}/{train_model}/eval/domeval", + domeval_ids="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/domeval.ids.gz", + system_id="{project_name}/{src}-{trg}/{download_tc_dir}/extract_tc_scored_{min_score}/{preprocessing}/{train_vocab}/{train_model}" + shell: '''python pipeline/eval/score-domeval.py --input_dir {params.input_dir} --report {output} --src_lang {wildcards.src} --trg_lang {wildcards.trg} --system_id {params.system_id} --domeval_ids {params.domeval_ids} >> {log} 2>&1''' rule evaluate: message: "Evaluating a model" diff --git a/pipeline/data/clean-tcdev.py b/pipeline/data/clean-tcdev.py new file mode 100644 index 000000000..75abe2691 --- /dev/null +++ b/pipeline/data/clean-tcdev.py @@ -0,0 +1,79 @@ +import argparse +import gzip +import random +import os +import re + +# Define sentence-ending punctuation +SENTENCE_ENDINGS = re.compile(r'[.!?]') + +def is_valid_line(source_line, target_line, seen_lines): + """Check if the source line is valid based on conditions: + - Source line must be longer than 5 words. + - Source line must not have occurred before. 
+ - Source and target lines must not both contain more than one sentence-ending punctuation. + """ + # Check if the source line is longer than 5 words + if len(source_line.split()) <= 5: + return False + + # Check if the source line has occurred before + if source_line in seen_lines: + return False + + # Check if there is more than one sentence-ending punctuation in both lines + if len(SENTENCE_ENDINGS.findall(source_line)) > 1 and len(SENTENCE_ENDINGS.findall(target_line)) > 1: + return False + + # Add the source line to the set of seen lines + seen_lines.add(source_line) + return True + +def process_files(source_path, target_path, prefix): + """Process the gzipped source and target files, shuffle them, filter, sort by length, and write results to gzipped files.""" + seen_lines = set() + + # Read source and target files into memory, aligned by line + with gzip.open(source_path, 'rt', encoding='utf-8') as src_file, \ + gzip.open(target_path, 'rt', encoding='utf-8') as tgt_file: + lines = [(src_line.strip(), tgt_line.strip()) for src_line, tgt_line in zip(src_file, tgt_file)] + + # Shuffle the lines + random.shuffle(lines) + + # Filter the lines based on conditions and store valid ones + filtered_lines = [ + (src_line, tgt_line) for src_line, tgt_line in lines + if is_valid_line(src_line, tgt_line, seen_lines) + ] + + # Sort the filtered lines by the length of the source line (longest first) + sorted_lines = sorted(filtered_lines, key=lambda x: len(x[0]), reverse=True) + + # Generate output file paths by prefixing the file names + source_output_path = prefix + os.path.basename(source_path) + target_output_path = prefix + os.path.basename(target_path) + + # Write filtered and sorted lines to output files + with gzip.open(source_output_path, 'wt', encoding='utf-8') as src_out_file, \ + gzip.open(target_output_path, 'wt', encoding='utf-8') as tgt_out_file: + + for src_line, tgt_line in sorted_lines: + src_out_file.write(src_line + '\n') + tgt_out_file.write(tgt_line + '\n') + +def
main(): + # Argument parsing + parser = argparse.ArgumentParser(description='Process gzipped files, shuffle lines, filter, sort, and output them.') + parser.add_argument('--source', required=True, help='Path to the gzipped source file') + parser.add_argument('--target', required=True, help='Path to the gzipped target file') + parser.add_argument('--prefix', required=True, help='Prefix to be added to output file names') + + args = parser.parse_args() + + # Process the input files and generate output + process_files(args.source, args.target, args.prefix) + +if __name__ == "__main__": + main() + diff --git a/pipeline/eval/score-domeval.py b/pipeline/eval/score-domeval.py index 61bcb2e2e..525a6196d 100644 --- a/pipeline/eval/score-domeval.py +++ b/pipeline/eval/score-domeval.py @@ -2,15 +2,17 @@ import argparse import sacrebleu import csv +import gzip def parse_args(): parser = argparse.ArgumentParser(description="Generate BLEU and chrF scores for domain-specific translations.") - parser.add_argument("input_dir", help="Input directory containing the test files.") - parser.add_argument("report_file", help="Output report file.") - parser.add_argument("src_lang", help="Three-letter source language code.") - parser.add_argument("trg_lang", help="Three-letter target language code.") - parser.add_argument("domeval_ids", help="Path to TSV file containing domain evaluation IDs (with domain names).") + parser.add_argument("--input_dir", help="Input directory containing the test files.") + parser.add_argument("--report_file", help="Output report file.") + parser.add_argument("--src_lang", help="Three-letter source language code.") + parser.add_argument("--trg_lang", help="Three-letter target language code.") + parser.add_argument("--system_id", help="ID of the system used to translate domeval.") + parser.add_argument("--domeval_ids", help="Path to TSV file containing domain evaluation IDs (with domain names).") return parser.parse_args() @@ -18,93 +20,138 @@ def 
read_file_lines(file_path): with open(file_path, 'r', encoding='utf-8') as f: return [line.strip() for line in f.readlines()] -def write_report(report_file, report_content): - with open(report_file, 'w', encoding='utf-8') as f: - f.write(report_content) - def calculate_sacrebleu(reference, translations): bleu_score = sacrebleu.corpus_bleu(translations, [reference]) chrf_score = sacrebleu.corpus_chrf(translations, [reference]) return bleu_score, chrf_score -def process_domain_files(domain, input_dir, trg_lang, ref_lines, report_content): +def process_domain_files(domain, input_dir, trg_lang, ref_lines, report_file, system_id): # Filenames - fin_file = os.path.join(input_dir, f"{domain}.domeval.fin") - translated_fuzzies_file = os.path.join(input_dir, f"{domain}.domeval.translated_fuzzies") - linenum_file = os.path.join(input_dir, f"{domain}.domeval.linenum") + trg_file = os.path.join(input_dir, f"{domain}-domeval.{trg_lang}") + translated_fuzzies_file = os.path.join(input_dir, f"{domain}-domeval.translated_fuzzies") + linenum_file = os.path.join(input_dir, f"{domain}-domeval.linenum") - # Process fin file (full test set translation) - fin_lines = read_file_lines(fin_file) - bleu_fin, chrf_fin = calculate_sacrebleu(ref_lines, fin_lines) - report_content += f"Domain: {domain} - Full test set (fin)\n" - report_content += f"BLEU: {bleu_fin.score}\nchrF: {chrf_fin.score}\n\n" + # Process trg file (full test set translation) + #trg_lines = read_file_lines(trg_file) + #bleu_trg, chrf_trg = calculate_sacrebleu(ref_lines, trg_lines) + #report_file.write(f"full_domeval\t{domain}\t{bleu_trg.score}\t{chrf_trg.score}\n") # Process translated fuzzies fuzzy_lines = read_file_lines(translated_fuzzies_file) - linenum_lines = [int(line.strip()) for line in read_file_lines(linenum_file)] - fuzzy_ref_lines = [ref_lines[linenum] for linenum in linenum_lines] + linenum_lines = [int(line.split(":")[0].strip()) for line in read_file_lines(linenum_file)] + fuzzy_ref_lines = 
[ref_lines[linenum-1] for linenum in linenum_lines] bleu_fuzzy, chrf_fuzzy = calculate_sacrebleu(fuzzy_ref_lines, fuzzy_lines) - report_content += f"Domain: {domain} - Fuzzy subset\n" - report_content += f"BLEU: {bleu_fuzzy.score}\nchrF: {chrf_fuzzy.score}\n\n" + report_file.write(f"all_fuzzies\t{domain}\t{bleu_fuzzy.score}\t{chrf_fuzzy.score}\tonly_fuzzies\t{len(fuzzy_lines)}\t{system_id}\n") - return report_content - -def process_domeval_ids(input_dir, trg_lang, tsv_file, ref_file_path, report_content): +def process_domeval_ids(input_dir, trg_lang, tsv_file, ref_file_path, report_file, system_id): # Read the TSV file into a dictionary mapping each line number to its domain - domain_dict = {} - with open(tsv_file, 'r', encoding='utf-8') as f: + id_to_domain_dict = {} + domain_to_id_dict = {} + with gzip.open(tsv_file, 'rt', encoding='utf-8') as f: reader = csv.reader(f, delimiter='\t') for idx, row in enumerate(reader): - domain_dict[idx] = row[0] + domain = row[0] + id_to_domain_dict[idx] = domain + if domain not in domain_to_id_dict: + domain_to_id_dict[domain] = set() + domain_to_id_dict[domain].add(idx) + + # Read the nofuzzies.[trg] file + nofuzzies_trg_path = os.path.join(input_dir, f"nofuzzies.{trg_lang}") + nofuzzies_trg_lines = read_file_lines(nofuzzies_trg_path) # Read the aligned files: reference translations and domain-specific translations - ref_lines = read_file_lines(ref_file_path) - all_fin_lines = {} + all_ref_lines = read_file_lines(ref_file_path) + all_trg_lines = {} - for domain in set(domain_dict.values()): - domain_fin_file = os.path.join(input_dir, f"{domain}.domeval.fin") - if os.path.exists(domain_fin_file): - all_fin_lines[domain] = read_file_lines(domain_fin_file) + #add train and all_filtered here + domains = set(id_to_domain_dict.values()) + + for domain in domains: + domain_trg_file = os.path.join(input_dir, f"{domain}-domeval.{trg_lang}") + if os.path.exists(domain_trg_file): + all_trg_lines[domain] = read_file_lines(domain_trg_file) 
# Initialize a dictionary to hold domain-specific sentences for evaluation domain_specific_refs = {} domain_specific_trans = {} - - for idx, domain in domain_dict.items(): + domain_specific_nofuzzies = {} + + index_domains = all_trg_lines.keys() + + # open fuzzy linenum file to get domain-specific fuzzy counts + domain_to_fuzzy_id = {} + + for index_domain in index_domains: + domain_to_fuzzy_id[index_domain] = {} + with open(os.path.join(input_dir, f"{index_domain}-domeval.linenum"),'r') as linenum_file: + linenum_lines = {int(line.split(":")[0].strip())-1 for line in linenum_file.readlines()} + for domain in domains: + domain_to_fuzzy_id[index_domain][domain] = linenum_lines.intersection(domain_to_id_dict[domain]) + + for idx, domain in id_to_domain_dict.items(): + if domain not in all_trg_lines.keys(): + continue + #new domain initialization if domain not in domain_specific_refs: domain_specific_refs[domain] = [] - domain_specific_trans[domain] = [] - - domain_specific_refs[domain].append(ref_lines[idx]) - domain_specific_trans[domain].append(all_fin_lines[domain][idx]) + domain_specific_nofuzzies[domain] = [] + for index_domain in index_domains: + if index_domain not in domain_specific_trans: + domain_specific_trans[index_domain] = {} + if domain not in domain_specific_trans[index_domain]: + domain_specific_trans[index_domain][domain] = [] + + domain_specific_refs[domain].append(all_ref_lines[idx]) + domain_specific_nofuzzies[domain].append(nofuzzies_trg_lines[idx]) + + for index_domain in index_domains: + domain_specific_trans[index_domain][domain].append(all_trg_lines[index_domain][idx]) + # Now calculate sacrebleu for domain-specific translations for domain in domain_specific_refs: - ref_lines = domain_specific_refs[domain] - fin_lines = domain_specific_trans[domain] + print(f"processing {domain}") + ref_lines = domain_specific_refs[domain] - bleu_domain, chrf_domain = calculate_sacrebleu(ref_lines, fin_lines) - report_content += f"Domain: {domain} - 
Domain-specific subset\n" - report_content += f"BLEU: {bleu_domain.score}\nchrF: {chrf_domain.score}\n\n" - - return report_content + for index_domain in index_domains: + domain_fuzzies = domain_to_fuzzy_id[index_domain][domain] + fuzzy_count = len(domain_fuzzies) + if fuzzy_count < 20: + continue + trg_lines = domain_specific_trans[index_domain][domain] + bleu_domain, chrf_domain = calculate_sacrebleu(ref_lines, trg_lines) + report_file.write(f"{domain}\t{index_domain}\t{bleu_domain.score}\t{chrf_domain.score}\tall\t{fuzzy_count}\t{system_id}\n") + + fuzzy_ref_lines = [all_ref_lines[linenum] for linenum in domain_fuzzies] + fuzzy_trg_lines = [all_trg_lines[index_domain][linenum] for linenum in domain_fuzzies] + + bleu_domain_fuzzy, chrf_domain_fuzzy = calculate_sacrebleu(fuzzy_ref_lines, fuzzy_trg_lines) + report_file.write(f"{domain}\t{index_domain}\t{bleu_domain_fuzzy.score}\t{chrf_domain_fuzzy.score}\tonly_fuzzies\t{fuzzy_count}\t{system_id}\n") + + nofuz_trg_lines = domain_specific_nofuzzies[domain] + no_fuz_bleu_domain, no_fuz_chrf_domain = calculate_sacrebleu(ref_lines, nofuz_trg_lines) + report_file.write(f"{domain}\tnofuzzies\t{no_fuz_bleu_domain.score}\t{no_fuz_chrf_domain.score}\tall\t0\t{system_id}\n") def main(): # Parse the arguments args = parse_args() # Prepare the report content - report_content = "" - # Read the reference file domeval.[trg].ref - ref_file_path = os.path.join(args.input_dir, f"domeval.{args.trg_lang}.ref") + with open(args.report_file,'wt') as report_file: + + # Read the reference file domeval.[trg].ref + ref_file_path = os.path.join(args.input_dir, f"domeval.{args.trg_lang}.ref") - # Process domeval_ids TSV file (and align domeval.[trg].ref with [domain].domeval.fin) - report_content = process_domeval_ids(args.input_dir, args.trg_lang, args.domeval_ids, ref_file_path, report_content) + for file_name in os.listdir(args.input_dir): + if file_name.endswith(f"-domeval.{args.trg_lang}"): + domain =
file_name.replace(f"-domeval.{args.trg_lang}","") + process_domain_files(domain, args.input_dir, args.trg_lang, read_file_lines(ref_file_path), report_file, args.system_id) - # Write report to file - write_report(args.report_file, report_content) + # Process domeval_ids TSV file (and align domeval.[trg].ref with [domain].domeval.[trg]) + process_domeval_ids(args.input_dir, args.trg_lang, args.domeval_ids, ref_file_path, report_file, args.system_id) if __name__ == "__main__": main() diff --git a/pipeline/eval/eval-domains.sh b/pipeline/eval/translate-domeval.sh similarity index 98% rename from pipeline/eval/eval-domains.sh rename to pipeline/eval/translate-domeval.sh index 22e533061..afb78f11f 100755 --- a/pipeline/eval/eval-domains.sh +++ b/pipeline/eval/translate-domeval.sh @@ -48,7 +48,7 @@ translate() { } -domeval_dir="$result_directory/domeval" +domeval_dir="$result_directory" # Create the domeval subdirectory in the output directory mkdir -p "$domeval_dir" diff --git a/pipeline/rat/augment.py b/pipeline/rat/augment.py new file mode 100644 index 000000000..395b26729 --- /dev/null +++ b/pipeline/rat/augment.py @@ -0,0 +1,130 @@ +import argparse +import gzip +import os +import re +from random import shuffle + +def get_fuzzy_bucket(score): + return int(score*10) + +def main(args): + print("Augmenting sentences") + + # open all the files for line-by-line processing + with gzip.open(args.src_file_path,'rt') as src_input_file, \ + gzip.open(args.trg_file_path,'rt') as trg_input_file, \ + gzip.open(args.src_output_path,'wt') as src_output_file, \ + gzip.open(args.trg_output_path,'wt') as trg_output_file, \ + gzip.open(args.score_file_path,'rt') as score_file: + + fuzzy_buckets = {} + mix_score_line = None + + if not f"{args.src_lang}-{args.trg_lang}" in os.path.basename(args.score_file_path): + reverse_sents = True + else: + reverse_sents = False + + mix_score_file = None + if args.mix_score_file_path: + mix_score_file = gzip.open(args.mix_score_file_path,'rt') + if 
not f"{args.src_lang}-{args.trg_lang}" in os.path.basename(args.mix_score_file_path): + reverse_mix_sents = True + else: + reverse_mix_sents = False + + + augmented_count = 0 + for index, src_sentence in enumerate(src_input_file): + if args.lines_to_augment and augmented_count == args.lines_to_augment: + break + + src_sentence = src_sentence.strip() + trg_sentence = trg_input_file.readline().strip() + score_line = score_file.readline().strip() + + if mix_score_file: + mix_score_line = mix_score_file.readline().strip() + + if score_line: + matches = re.findall("(?P<score>\d\.\d+)\t\d+=(?P<src>.+?) \|\|\| (?P<trg>[^\t]+)",score_line) + # shuffle to avoid too many high fuzzies + shuffle(matches) + # if index lang pair is different from the args lang pair, switch source and target + + if reverse_sents: + filtered_matches = [(float(score),trg,src) for (score,src,trg) in matches if float(score) >= args.min_score and float(score) <= args.max_score][0:args.max_fuzzies] + else: + filtered_matches = [(float(score),src,trg) for (score,src,trg) in matches if float(score) >= args.min_score and float(score) <= args.max_score][0:args.max_fuzzies] + else: + filtered_matches = [] + + # mix means using one match from an alternative source (used with targetsim to have one match that is guaranteed to be reflected on the target side) + if mix_score_line: + mix_matches = re.findall("(?P<score>\d\.\d+)\t\d+=(?P<src>.+?)
\|\|\| (?P<trg>[^\t]+)",mix_score_line) + + if reverse_mix_sents: + filtered_mix_matches = [(float(score),trg,src) for (score,src,trg) in mix_matches if float(score) >= args.min_score and float(score) <= args.max_score][0:args.max_fuzzies] + else: + filtered_mix_matches = [(float(score),src,trg) for (score,src,trg) in mix_matches if float(score) >= args.min_score and float(score) <= args.max_score][0:args.max_fuzzies] + + if filtered_mix_matches: + if filtered_matches: + filtered_matches[0] = filtered_mix_matches[0] + else: + filtered_matches = filtered_mix_matches + + #mix up matches to prevent the model from learning an order of sourcesim and targetsim + shuffle(filtered_matches) + + # keep track of fuzzy counts + for fuzzy in filtered_matches: + fuzzy_bucket = get_fuzzy_bucket(fuzzy[0]) + if fuzzy_bucket in fuzzy_buckets: + fuzzy_buckets[fuzzy_bucket] += 1 + else: + fuzzy_buckets[fuzzy_bucket] = 1 + + if len(filtered_matches) >= args.min_fuzzies: + if args.include_source: + fuzzy_string = "".join([f"{match[1]}{args.source_separator}{match[2]}{args.target_separator}_{get_fuzzy_bucket(match[0])}" for match in filtered_matches]) + else: + fuzzy_string = "".join([f"{match[2]}{args.target_separator}_{get_fuzzy_bucket(match[0])}" for match in filtered_matches]) + src_output_file.write(f"{fuzzy_string}{src_sentence}\n") + trg_output_file.write(trg_sentence+"\n") + augmented_count += 1 + else: + if not args.exclude_non_augmented: + src_output_file.write(f"{src_sentence}\n") + trg_output_file.write(trg_sentence+"\n") + augmented_count += 1 + + if args.mix_score_file_path: + mix_score_file.close() + + print(fuzzy_buckets) + +if __name__ == "__main__": + # Set up argument parsing + parser = argparse.ArgumentParser(description="Augment data with fuzzies from index.") + parser.add_argument("--src_file_path", help="Path to the file containing the source sentences that should be augmented with fuzzies.") + parser.add_argument("--trg_file_path", help="Path to the file containing the
target sentences that should be augmented with fuzzies.") + parser.add_argument("--src_lang", help="Source lang code.") + parser.add_argument("--trg_lang", help="Target lang code.") + parser.add_argument("--score_file_path", help="Path to the file containing the indices of fuzzies found for each sentence in the sentence file") + parser.add_argument("--mix_score_file_path", help="Path to the file containing the indices of the mix fuzzies found for each sentence in the sentence file. One of these fuzzies is used alongside the normal fuzzies when augmenting a sentence") + parser.add_argument("--src_output_path", help="Path to save the source file augmented with fuzzies.") + parser.add_argument("--trg_output_path", help="Path to save the target file augmented with fuzzies.") + parser.add_argument("--source_separator", default="SRC_FUZZY_BREAK", help="Separator token that separates the source side of fuzzies from other fuzzies and the source sentence") + parser.add_argument("--target_separator", default="FUZZY_BREAK", help="Separator token that separates the target side of fuzzies from other fuzzies and the source sentence") + parser.add_argument("--min_score", type=float, help="Only consider fuzzies that have a score equal or higher than this") + parser.add_argument("--max_score", type=float, help="Only consider fuzzies that have a score equal or lower than this") + parser.add_argument("--min_fuzzies", type=int, help="Augment sentence if it has at least this many fuzzies") + parser.add_argument("--max_fuzzies", type=int, help="Augment the sentence with at most this many fuzzies (use n best matches if more than max fuzzies found)") + parser.add_argument("--lines_to_augment", type=int, default=-1, help="Augment this many lines, default is all lines") + parser.add_argument("--include_source", action="store_true", help="Also include source in the augmented line") + parser.add_argument("--exclude_non_augmented", action="store_true", help="Do not include non-augmented in the 
output") + # Parse the arguments + args = parser.parse_args() + print(args) + main(args) diff --git a/pipeline/rat/build_index.sh b/pipeline/rat/build_index.sh index 0449cf093..4a7c15705 100755 --- a/pipeline/rat/build_index.sh +++ b/pipeline/rat/build_index.sh @@ -14,10 +14,10 @@ index_file=$4 echo "##### Building a fuzzy match index" # index building runs on single thread, --nthreads is only for matching -${fuzzy_match_cli} --action index --corpus ${src_corpus} +# ${fuzzy_match_cli} --action index --corpus ${src_corpus} -# TODO: test if string can stored efficiently in the index -# ${fuzzy_match_cli} --action index --corpus ${src_corpus},${trg_corpus} --add-target +# Store strings in index. I've also modified fuzzy-match to store source in the index +${fuzzy_match_cli} --action index --corpus ${src_corpus},${trg_corpus} --add-target # index is saved as src_corpus.fmi, move it to the correct place mv "${src_corpus}.fmi" "${index_file}" diff --git a/pipeline/rat/find_matches.sh b/pipeline/rat/find_matches.sh index 426d09f8c..6b7cdd55a 100755 --- a/pipeline/rat/find_matches.sh +++ b/pipeline/rat/find_matches.sh @@ -13,7 +13,7 @@ index_file=$4 output_file=$5 contrastive_factor=$6 -echo "##### Building a fuzzy match index" +echo "##### Finding matches" -zcat ${source_corpus} | ${fuzzy_match_cli} --contrast ${contrastive_factor} --no-perfect --index ${index_file} --fuzzy 0.5 --action match --nthreads ${threads} > $output_file +zcat ${source_corpus} | ${fuzzy_match_cli} --contrast ${contrastive_factor} --no-perfect --index ${index_file} --fuzzy 0.5 --action match --nthreads ${threads} | gzip > $output_file diff --git a/pipeline/rat/get_matches.py b/pipeline/rat/get_matches.py index ff3b1befe..3ada931a6 100644 --- a/pipeline/rat/get_matches.py +++ b/pipeline/rat/get_matches.py @@ -1,6 +1,7 @@ import argparse import gzip import os +from random import shuffle def read_sentence_file(filepath): ext = os.path.splitext(filepath)[-1].lower() @@ -27,6 +28,10 @@ def main(args): 
src_sentences = read_sentence_file(args.src_sentence_file) trg_sentences = read_sentence_file(args.trg_sentence_file) scores = read_score_file(args.score_file) + mix_scores = None + fuzzy_buckets = {} + if args.mix_score_file: + mix_scores = read_score_file(args.mix_score_file) #if index set is the same as test, don't read it twice if (args.src_sentence_file == args.index_src_sentence_file): @@ -39,16 +44,42 @@ def main(args): with gzip.open(args.src_augmented_file,'wt') as src_output_file, \ gzip.open(args.trg_augmented_file,'wt') as trg_output_file: print("Augmenting with fuzzies") + augmented_count = 0 for index, sentence in enumerate(src_sentences): - if args.lines_to_augment and index == args.lines_to_augment-1: + if args.lines_to_augment and augmented_count == args.lines_to_augment: break if index in scores: score_indices = scores[index] + # shuffle to avoid too many high fuzzies + shuffle(score_indices) corresponding_sentences = [ - (index_src_sentences[i-1],index_trg_sentences[i-1]) for score,i in + (index_src_sentences[i-1],index_trg_sentences[i-1],score) for score,i in score_indices if score >= args.min_score and score <= args.max_score][0:args.max_fuzzies] else: corresponding_sentences = [] + # mix means using one match from an alternative source (used with targetsim to have one match that is guaranteed to be reflected on the target side) + if mix_scores and index in mix_scores: + mix_score_indices = mix_scores[index] + mix_corresponding_sentences = [ + (index_src_sentences[i-1],index_trg_sentences[i-1],score) for score,i in + mix_score_indices if score >= args.min_score and score <= args.max_score][0:1] + if corresponding_sentences and mix_corresponding_sentences: + corresponding_sentences[0] = mix_corresponding_sentences[0] + else: + corresponding_sentences += mix_corresponding_sentences + + #mix up matches to prevent the model from learning an implicit order + if len(corresponding_sentences) > 1: + shuffle(corresponding_sentences) + + # keep track of fuzzy counts + for fuzzy in 
corresponding_sentences: + fuzzy_bucket = int(fuzzy[2]*10) + if fuzzy_bucket in fuzzy_buckets: + fuzzy_buckets[fuzzy_bucket] += 1 + else: + fuzzy_buckets[fuzzy_bucket] = 1 + if len(corresponding_sentences) >= args.min_fuzzies: if args.include_source: fuzzies = [f"{x[0]}{args.source_separator}{x[1]}{args.target_separator}" for x in corresponding_sentences] @@ -58,11 +89,14 @@ def main(args): target_fuzzies = [x[1] for x in corresponding_sentences] src_output_file.write(f"{args.target_separator.join(target_fuzzies)}{args.target_separator}{sentence}\n") trg_output_file.write(trg_sentences[index]+"\n") - + augmented_count += 1 else: if not args.exclude_non_augmented: src_output_file.write(f"{sentence}\n") trg_output_file.write(trg_sentences[index]+"\n") + augmented_count += 1 + + print(fuzzy_buckets) if __name__ == "__main__": # Set up argument parsing @@ -70,6 +104,7 @@ def main(args): parser.add_argument("--src_sentence_file", help="Path to the file containing the source sentences that should be augmented with fuzzies.") parser.add_argument("--trg_sentence_file", help="Path to the file containing the target sentences that should be augmented with fuzzies.") parser.add_argument("--score_file", help="Path to the file containing the indices of fuzzies found for each sentence in the sentence file") + parser.add_argument("--mix_score_file", help="Path to the file containing the indices of the mix fuzzies found for each sentence in the sentence file. 
One of these fuzzies is used alongside the normal fuzzies when augmenting a sentence") parser.add_argument("--src_augmented_file", help="Path to save the source file augmented with fuzzies.") parser.add_argument("--trg_augmented_file", help="Path to save the target file augmented with fuzzies.") parser.add_argument("--index_src_sentence_file", help="Path to the file containing the source sentences corresponding to the fuzzy indices in the score file") @@ -80,7 +115,7 @@ def main(args): parser.add_argument("--max_score", type=float, help="Only consider fuzzies that have a score equal or lower than this") parser.add_argument("--min_fuzzies", type=int, help="Augment sentence if it has at least this many fuzzies") parser.add_argument("--max_fuzzies", type=int, help="Augment the sentence with at most this many fuzzies (use n best matches if more than max fuzzies found)") - parser.add_argument("--lines_to_augment", type=int, help="Augment this many lines, default is all lines") + parser.add_argument("--lines_to_augment", type=int, default=-1, help="Augment this many lines, default is all lines") parser.add_argument("--include_source", action="store_true", help="Also include source in the augmented line") parser.add_argument("--exclude_non_augmented", action="store_true", help="Do not include non-augmented in the output") # Parse the arguments diff --git a/pipeline/train/ensemble.py b/pipeline/train/ensemble.py new file mode 100644 index 000000000..e70bdac30 --- /dev/null +++ b/pipeline/train/ensemble.py @@ -0,0 +1,44 @@ +import yaml +import argparse + +def merge_decoders(decoder_file_1, decoder_file_2, output_decoder_file, vocab_file, decoder_1_weight): + # Load the YAML files + with open(decoder_file_1, 'r') as f1, open(decoder_file_2, 'r') as f2: + decoder_1 = yaml.safe_load(f1) + decoder_2 = yaml.safe_load(f2) + + # Retain all keys from decoder 1 except 'models' + output_data = {key: value for key, value in decoder_1.items() if key != 'models'} + + # Merge the models 
from both decoders + merged_models = decoder_1['models'] + decoder_2['models'] + output_data['models'] = merged_models + + # Add the weights for decoder 1 and decoder 2 + output_data['weights'] = [decoder_1_weight, 1 - decoder_1_weight] + + # Save the merged data to the output YAML file + with open(output_decoder_file, 'w') as outfile: + yaml.safe_dump(output_data, outfile) + +def main(): + # Argument parser setup + parser = argparse.ArgumentParser(description="Merge two decoder YAML files") + parser.add_argument('--decoder_file_1', type=str, help="Path to the first decoder YAML file") + parser.add_argument('--decoder_file_2', type=str, help="Path to the second decoder YAML file") + parser.add_argument('--output_decoder_file', type=str, help="Path to the output decoder YAML file") + parser.add_argument('--vocab_file', type=str, help="Path to the vocabulary file") + parser.add_argument('--decoder_1_weight', type=float, help="Weight for the first decoder (0 <= weight <= 1)") + + args = parser.parse_args() + + # Ensure the decoder 1 weight is between 0 and 1 + if not (0 <= args.decoder_1_weight <= 1): + raise ValueError("The weight for the first decoder must be between 0 and 1.") + + # Merge the decoders + merge_decoders(args.decoder_file_1, args.decoder_file_2, args.output_decoder_file, args.vocab_file, args.decoder_1_weight) + +if __name__ == '__main__': + main() + diff --git a/profiles/slurm-lumi/config.yaml b/profiles/slurm-lumi/config.yaml index 0a8490845..03177f5e8 100755 --- a/profiles/slurm-lumi/config.yaml +++ b/profiles/slurm-lumi/config.yaml @@ -24,7 +24,7 @@ config: #- root=/pfs/lustrep1/scratch/project_462000088/members/niemine1/data #- rocm=/opt/rocm - workspace=40000 - - numgpus=4 + - numgpus=8 - mariancmake="" - - gpus="0 1 2 3" + - gpus="0 1 2 3 4 5 6 7" - marianversion="lumi-marian" diff --git a/rat.smk b/rat.smk index 48d1a87d7..f9b36e776 100644 --- a/rat.smk +++ b/rat.smk @@ -7,7 +7,7 @@ wildcard_constraints: index_type="[\.\-\w\d_]+" 
#index_type="(all_filtered|train)" -ruleorder: build_fuzzy_domain_index > build_fuzzy_index +ruleorder: build_reverse_fuzzy_index > build_fuzzy_domain_index > build_fuzzy_index # TODO: all_filtered index should be built earlier, like domain indexes rule build_fuzzy_index: @@ -16,7 +16,7 @@ rule build_fuzzy_index: conda: None priority: 100 container: None - resources: mem_mb=lambda wildcards, input, attempt: (input.size//1000000) * attempt * 10 + resources: mem_mb=lambda wildcards, input, attempt: (input.size//1000000) * attempt * 20 envmodules: "LUMI/22.12", "Boost" @@ -42,7 +42,7 @@ use rule build_fuzzy_index as build_fuzzy_domain_index with: output: index="{project_name}/{src}-{trg}/{preprocessing}/domeval_indexes/index.{index_type}.{src}-{trg}.fmi" ruleorder: find_reverse_fuzzy_matches > find_domain_fuzzy_matches > find_fuzzy_matches - + rule find_fuzzy_matches: message: "Finding fuzzies" log: "{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/find_{index_type}-{set}_matches.log" @@ -60,7 +60,7 @@ rule find_fuzzy_matches: target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{set}.{trg}.gz", index="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/index.{index_type}.{src}-{trg}.fmi" output: - matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/{index_type}-{set}.{src}-{trg}.matches" + matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/{index_type}-{set}.{src}-{trg}.matches.gz" shell: f'''bash pipeline/rat/find_matches.sh "{fuzzy_match_cli}" "{{input.source}}" {{threads}} "{{input.index}}" "{{output.matches}}" {{wildcards.contrast_factor}} >> {{log}} 2>&1''' use rule find_fuzzy_matches as find_reverse_fuzzy_matches with: @@ -70,7 +70,7 @@ use rule find_fuzzy_matches as find_reverse_fuzzy_matches with: 
target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{set}.{src}.gz", index="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/index.targetsim_{index_type}.{trg}-{src}.fmi" output: - matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/targetsim_{index_type}-{set}.{trg}-{src}.matches" + matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/targetsim_{index_type}-{set}.{trg}-{src}.matches.gz" use rule find_fuzzy_matches as find_domain_fuzzy_matches with: input: @@ -92,52 +92,78 @@ rule augment_data_with_fuzzies: conda: None container: None priority: 100 - group: "augment" + #group: "augment" threads: 1 resources: mem_mb=60000 input: - index_source="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{index_type}.{src}.gz", - index_target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{index_type}.{trg}.gz", augment_source="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{set}.{src}.gz", augment_target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{set}.{trg}.gz", - matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/{index_type}-{set}.{src}-{trg}.matches" + matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/{index_type}-{set}.{src}-{trg}.matches.gz" output: source="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/{index_type}-{set}.{src}.gz", - target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/{index_type}-{set}.{trg}.gz" + 
target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/{index_type}-{set}.{trg}.gz", + source_nobands="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/nobands_{index_type}-{set}.{src}.gz", + target_nobands="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/nobands_{index_type}-{set}.{trg}.gz" params: - max_sents=lambda wildcards: 2000 if wildcards.set == "dev" else -1, - fuzzy_max_score=lambda wildcards: "1" if wildcards.fuzzy_max_score else wildcards.fuzzy_max_score - shell: f'''python pipeline/rat/get_matches.py \ - --src_sentence_file "{{input.augment_source}}" \ - --trg_sentence_file "{{input.augment_target}}" \ - --score_file "{{input.matches}}" \ - --src_augmented_file "{{output.source}}" \ - --trg_augmented_file "{{output.target}}" \ - --index_src_sentence_file "{{input.index_source}}" \ - --index_trg_sentence_file "{{input.index_target}}" \ - --lines_to_augment {{params.max_sents}} \ - --min_score {{wildcards.fuzzy_min_score}} \ - --max_score {{params.fuzzy_max_score}} \ - --min_fuzzies {{wildcards.min_fuzzies}} \ - --max_fuzzies {{wildcards.max_fuzzies}} >> {{log}} 2>&1''' + fuzzy_max_score=lambda wildcards: "1" if wildcards.fuzzy_max_score == "" else wildcards.fuzzy_max_score.replace("-",""), + mix_matches="", + extra_args=lambda wildcards: "--lines_to_augment 2000 --exclude_non_augmented" if wildcards.set == "cleandev" else "" + shell: '''python pipeline/rat/augment.py \ + --src_file_path "{input.augment_source}" \ + --trg_file_path "{input.augment_target}" \ + --src_lang {wildcards.src} \ + --trg_lang {wildcards.trg} \ + --score_file "{input.matches}" \ + 
--mix_score_file "{params.mix_matches}" \ + --src_output_path "{output.source}" \ + --trg_output_path "{output.target}" \ + --min_score {wildcards.fuzzy_min_score} \ + --max_score {params.fuzzy_max_score} \ + --min_fuzzies {wildcards.min_fuzzies} \ + --max_fuzzies {wildcards.max_fuzzies} \ + {params.extra_args} >> {log} 2>&1 && \ + {{ zcat {output.source} | sed "s/FUZZY_BREAK_[0-9]/FUZZY_BREAK/g" | gzip > {output.source_nobands} ; }} >> {log} 2>&1 &&\ + ln {output.target} {output.target_nobands} >> {log} 2>&1''' #TODO: reorganize the concept of augmentation. The training data can be augmented with train, all_filtered, train_targetsim, all_filtered_target_sim use rule augment_data_with_fuzzies as augment_data_with_reverse_fuzzies with: log: "{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/augment_targetsim_{index_type}-{set}_matches.log" input: - index_source="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{index_type}.{src}.gz", - index_target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{index_type}.{trg}.gz", augment_source="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{set}.{src}.gz", augment_target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{set}.{trg}.gz", - matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/targetsim_{index_type}-{set}.{trg}-{src}.matches" + matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/targetsim_{index_type}-{set}.{trg}-{src}.matches.gz" output: source="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/targetsim_{index_type}-{set}.{src}.gz", - 
target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/targetsim_{index_type}-{set}.{trg}.gz" - + target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/targetsim_{index_type}-{set}.{trg}.gz", + source_nobands="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/nobands_targetsim_{index_type}-{set}.{src}.gz", + target_nobands="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/nobands_targetsim_{index_type}-{set}.{trg}.gz" + +#TODO mixed sourcesim and targetsim augmentation with at least one targetsim match per fuzzy set (random order). The idea being that the model will learn that at least one of the matches is used in the target. 
This only makes sense if more than one matches are used +use rule augment_data_with_fuzzies as augment_data_with_mixed_fuzzies with: + log: "{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/augment_mixedsim_{index_type}-{set}_matches.log" + input: + augment_source="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{set}.{src}.gz", + augment_target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/{set}.{trg}.gz", + matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/{index_type}-{set}.{src}-{trg}.matches.gz", + mix_matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/targetsim_{index_type}-{set}.{trg}-{src}.matches.gz" + output: + source="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/mixedsim_{index_type}-{set}.{src}.gz", + target="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/mixedsim_{index_type}-{set}.{trg}.gz", + source_nobands="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/nobands_mixedsim_{index_type}-{set}.{src}.gz", + target_nobands="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/augment_train_{fuzzy_min_score}{fuzzy_max_score}_{min_fuzzies}_{max_fuzzies}/nobands_mixedsim_{index_type}-{set}.{trg}.gz" + params: + extra_args=lambda wildcards: "--lines_to_augment 2000 --exclude_non_augmented" if wildcards.set == "cleandev" else "", + 
fuzzy_max_score=lambda wildcards: "1" if wildcards.fuzzy_max_score == "" else wildcards.fuzzy_max_score.replace("-",""), + mix_matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/targetsim_{index_type}-{set}.{trg}-{src}.matches.gz" + use rule augment_data_with_fuzzies as augment_data_with_domain_fuzzies with: + wildcard_constraints: + fuzzy_min_score="0\.\d+", + fuzzy_max_score="(-0\.\d+|)", + min_fuzzies="\d+", + max_fuzzies="\d+", + set="domeval", input: - index_source="{project_name}/{src}-{trg}/{tc_processing}/subcorpora/{index_type}.{src}.gz", - index_target="{project_name}/{src}-{trg}/{tc_processing}/subcorpora/{index_type}.{trg}.gz", - augment_source="{project_name}/{src}-{trg}/{tc_processing}/domeval.{src}.gz", - augment_target="{project_name}/{src}-{trg}/{tc_processing}/domeval.{trg}.gz", - matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/{index_type}-{set}.{src}-{trg}.matches" + augment_source="{project_name}/{src}-{trg}/{tc_processing}/{set}.{src}.gz", + augment_target="{project_name}/{src}-{trg}/{tc_processing}/{set}.{trg}.gz", + matches="{project_name}/{src}-{trg}/{tc_processing}/{preprocessing}/build_index/find_matches_{contrast_factor}/{index_type}-{set}.{src}-{trg}.matches.gz" \ No newline at end of file diff --git a/train.smk b/train.smk index 8a3504907..bdcdbc18d 100644 --- a/train.smk +++ b/train.smk @@ -1,12 +1,36 @@ +localrules: ensemble_models + +ruleorder: ensemble_models > train_model + wildcard_constraints: src="\w{2,3}", trg="\w{2,3}", train_vocab="train_joint_spm_vocab[^/]+", training_type="[^/]+", - model_type="[^/_]+" + model_type="[^/_]+", + index_type="[^-]+" gpus_num=config["gpus-num"] +rule ensemble_models: + wildcard_constraints: + model1="train_model[^\+]+(?=\+)", + model2="(?<=\+)train_model[^\+/]+" + message: "Creating an ensemble model decoder config" + log: 
"{project_name}/{src}-{trg}/{preprocessing}/{train_vocab}/{model1}+{model2}.log" + conda: "envs/base.yml" + threads: 1 + input: + model1_decoder_config=f'{{project_name}}/{{src}}-{{trg}}/{{preprocessing}}/{{train_vocab}}/{{model1}}/final.model.npz.best-{config["best-model-metric"]}.npz.decoder.yml', + model2_decoder_config=f'{{project_name}}/{{src}}-{{trg}}/{{preprocessing}}/{{train_vocab}}/{{model2}}/final.model.npz.best-{config["best-model-metric"]}.npz.decoder.yml', + vocab="{project_name}/{src}-{trg}/{preprocessing}/{train_vocab}/vocab.spm" + output: + decoder_config=f'{{project_name}}/{{src}}-{{trg}}/{{preprocessing}}/{{train_vocab}}/{{model1}}+{{model2}}/final.model.npz.best-{config["best-model-metric"]}.npz.decoder.yml' + params: + args=config["training-teacher-args"], + decoder_1_weight=0.5 + shell: '''python pipeline/train/ensemble.py --decoder_file_1 "{input.model1_decoder_config}" --decoder_file_2 "{input.model2_decoder_config}" --output_decoder_file {output.decoder_config} --vocab_file {input.vocab} --decoder_1_weight {params.decoder_1_weight} >> {log} 2>&1''' + rule train_model: message: "Training a model" log: "{project_name}/{src}-{trg}/{preprocessing}/{train_vocab}/train_model_{index_type}-{model_type}-{training_type}_train_model.log" @@ -18,20 +42,20 @@ rule train_model: threads: gpus_num*3 resources: gpu=gpus_num,mem_mb=64000 input: - dev_source="{project_name}/{src}-{trg}/{preprocessing}/{index_type}-dev.{src}.gz", - dev_target="{project_name}/{src}-{trg}/{preprocessing}/{index_type}-dev.{trg}.gz", + dev_source="{project_name}/{src}-{trg}/{preprocessing}/{index_type}-cleandev.{src}.gz", + dev_target="{project_name}/{src}-{trg}/{preprocessing}/{index_type}-cleandev.{trg}.gz", train_source="{project_name}/{src}-{trg}/{preprocessing}/{index_type}-train.{src}.gz", train_target="{project_name}/{src}-{trg}/{preprocessing}/{index_type}-train.{trg}.gz", marian=ancient(config["marian"]), 
vocab="{project_name}/{src}-{trg}/{preprocessing}/{train_vocab}/vocab.spm", output: - model=f'{{project_name}}/{{src}}-{{trg}}/{{preprocessing}}/{{train_vocab}}/train_model_{{index_type}}-{{model_type}}-{{training_type}}/final.model.npz.best-{config["best-model-metric"]}.npz' + model=f'{{project_name}}/{{src}}-{{trg}}/{{preprocessing}}/{{train_vocab}}/train_model_{{index_type}}-{{model_type}}-{{training_type}}/final.model.npz.best-{config["best-model-metric"]}.npz', + decoder_config=f'{{project_name}}/{{src}}-{{trg}}/{{preprocessing}}/{{train_vocab}}/train_model_{{index_type}}-{{model_type}}-{{training_type}}/final.model.npz.best-{config["best-model-metric"]}.npz.decoder.yml' params: args=config["training-teacher-args"] shell: f'''bash pipeline/train/train.sh \ {{wildcards.model_type}} {{wildcards.training_type}} {{wildcards.src}} {{wildcards.trg}} "{{input.train_source}}" "{{input.train_target}}" "{{input.dev_source}}" "{{input.dev_target}}" "{{output.model}}" "{{input.vocab}}" "{config["best-model-metric"]}" {{params.args}} >> {{log}} 2>&1''' - use rule train_model as train_student_model with: message: "Training a student model" log: "{project_name}/{src}-{trg}/{preprocessing}/{train_vocab}/{postprocessing}/train_model_{model_type}-{training_type}/train_model.log"