forked from fastnlp/fastNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
71 lines (65 loc) · 3.08 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from typing import Union, Dict
def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
"""
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果
{
'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。
'test': 'xxx' # 可能有,也可能没有
...
}
如果paths为不合法的,将直接进行raise相应的错误
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
:return:
"""
if isinstance(paths, str):
if os.path.isfile(paths):
return {'train': paths}
elif os.path.isdir(paths):
filenames = os.listdir(paths)
files = {}
for filename in filenames:
path_pair = None
if 'train' in filename:
path_pair = ('train', filename)
if 'dev' in filename:
if path_pair:
raise Exception("File:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
path_pair = ('dev', filename)
if 'test' in filename:
if path_pair:
raise Exception("File:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
path_pair = ('test', filename)
if path_pair:
if path_pair[0] in files:
raise RuntimeError(f"Multiple file under {paths} have '{path_pair[0]}' in their filename.")
files[path_pair[0]] = os.path.join(paths, path_pair[1])
return files
else:
raise FileNotFoundError(f"{paths} is not a valid file path.")
elif isinstance(paths, dict):
if paths:
if 'train' not in paths:
raise KeyError("You have to include `train` in your dict.")
for key, value in paths.items():
if isinstance(key, str) and isinstance(value, str):
if not os.path.isfile(value):
raise TypeError(f"{value} is not a valid file.")
else:
raise TypeError("All keys and values in paths should be str.")
return paths
else:
raise ValueError("Empty paths is not allowed.")
else:
raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
def get_tokenizer():
try:
import spacy
spacy.prefer_gpu()
en = spacy.load('en')
print('use spacy tokenizer')
return lambda x: [w.text for w in en.tokenizer(x)]
except Exception as e:
print('use raw tokenizer')
return lambda x: x.split()