-
Notifications
You must be signed in to change notification settings - Fork 42
/
train_flert_model.py
92 lines (73 loc) · 3.18 KB
/
train_flert_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from argparse import ArgumentParser
import torch, flair
# dataset, model and embedding imports
from flair.datasets import UniversalDependenciesCorpus, XTREME
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
def build_parser():
    """Build the command-line parser for the fine-tuning experiments.

    Returns:
        An ``ArgumentParser`` exposing ``--seeds``, ``--cuda``, ``--model``
        and ``--dataset``.
    """
    parser = ArgumentParser()
    # pass list of seeds for experiments; the default must itself be a list:
    # with nargs='+', a string default such as '42' is type-converted to the
    # bare int 42, which the per-seed loop cannot iterate over.
    parser.add_argument("-s", "--seeds", nargs="+", type=int, default=[42])
    # which cuda device to use
    parser.add_argument("-c", "--cuda", type=int, default=0, help="CUDA device")
    # required: without a model name there is nothing to load
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        required=True,
        help="Model name (such as Hugging Face model hub name",
    )
    # restricted to the datasets load_corpus() knows how to build, so a typo
    # is rejected here instead of surfacing later as an undefined corpus
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        required=True,
        choices=["imst", "boun", "xtreme"],
        help="Defines dataset, choose between imst, boun or xtreme",
    )
    return parser


def load_corpus(dataset):
    """Select the corpus and the tag type to predict for ``dataset``.

    Args:
        dataset: One of ``"imst"``, ``"boun"`` or ``"xtreme"``.

    Returns:
        A ``(tag_type, corpus)`` tuple: ``"upos"`` with a
        ``UniversalDependenciesCorpus`` read from ``./data`` for imst/boun,
        or ``"ner"`` with the ``XTREME`` corpus for language ``"tr"``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset in ("imst", "boun"):
        corpus = UniversalDependenciesCorpus(
            data_folder="./data",
            train_file=f"tr_{dataset}-ud-train.conllu",
            dev_file=f"tr_{dataset}-ud-dev.conllu",
            test_file=f"tr_{dataset}-ud-test.conllu",
        )
        return "upos", corpus
    if dataset == "xtreme":
        return "ner", XTREME(languages="tr")
    raise ValueError(
        f"Unknown dataset {dataset!r}, choose between imst, boun or xtreme"
    )


def train_single_run(hf_model, dataset, seed):
    """Fine-tune one tagger on ``dataset`` with ``hf_model`` for a given seed.

    Args:
        hf_model: Model name handed to ``TransformerWordEmbeddings``.
        dataset: Dataset name understood by ``load_corpus``.
        seed: Random seed passed to ``flair.set_seed`` before anything else
            is initialized.
    """
    # imported here (once per run instead of inline in the training code) so
    # the block does not depend on a new file-level import
    from torch.optim.lr_scheduler import OneCycleLR

    flair.set_seed(seed)

    # initialize embeddings
    embeddings = TransformerWordEmbeddings(
        model=hf_model,
        layers="-1",
        subtoken_pooling="first",
        fine_tune=True,
        use_context=False,
        respect_document_boundaries=False,
    )

    # select dataset depending on which dataset name is passed
    tag_type, corpus = load_corpus(dataset)

    # make the dictionary of tags to predict
    tag_dictionary = corpus.make_tag_dictionary(tag_type)

    # init bare-bones sequence tagger (no reprojection, LSTM or CRF)
    tagger: SequenceTagger = SequenceTagger(
        hidden_size=256,
        embeddings=embeddings,
        tag_dictionary=tag_dictionary,
        tag_type=tag_type,
        use_crf=False,
        use_rnn=False,
        reproject_embeddings=False,
    )

    # init the model trainer
    trainer = ModelTrainer(tagger, corpus, optimizer=torch.optim.AdamW)

    # make string for output folder
    # NOTE(review): hub names containing "/" presumably turn into nested
    # directories here -- confirm this is the intended layout.
    output_folder = f"flert-{dataset}-{hf_model}-{seed}"

    # fine-tune with AdamW, a small learning rate and a OneCycleLR schedule
    # for 10 epochs; the model of the final epoch is the one evaluated
    trainer.train(
        output_folder,
        learning_rate=5.0e-5,
        mini_batch_size=16,
        mini_batch_chunk_size=1,
        max_epochs=10,
        scheduler=OneCycleLR,
        embeddings_storage_mode='none',
        weight_decay=0.,
        train_with_dev=False,
        use_final_model_for_eval=True
    )


def main():
    """Parse the experimental arguments and run one training per seed."""
    # Parse experimental arguments
    args = build_parser().parse_args()

    # use cuda device as passed; a torch.device object rather than a bare
    # string, so attribute access on flair.device keeps working
    flair.device = torch.device(f"cuda:{args.cuda}")

    # for each passed seed, do one experimental run
    for seed in args.seeds:
        train_single_run(args.model, args.dataset, seed)


if __name__ == "__main__":
    main()