forked from Determined22/zh-NER-TF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
executable file
·161 lines (132 loc) · 3.73 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import sys, pickle, os, random
import numpy as np
# Tag inventory in the BIO scheme: "O" is outside any entity, "B-X"/"I-X" mark
# the beginning/inside of an entity of type X (presumably person, location and
# organization for PER/LOC/ORG). The integer values are the label ids that
# batch_yield substitutes for the tag strings.
tag2label = {
    "O": 0,
    "B-PER": 1,
    "I-PER": 2,
    "B-LOC": 3,
    "I-LOC": 4,
    "B-ORG": 5,
    "I-ORG": 6,
}
def read_corpus(corpus_path):
    """Read a character-level tagged corpus into a list of samples.

    The file holds one ``<char> <label>`` pair per line; sentences are
    separated by one or more blank (or whitespace-only) lines.

    :param corpus_path: path of the UTF-8 encoded corpus file.
    :return: list of ``(chars, labels)`` tuples, one per non-empty sentence,
        where ``chars`` and ``labels`` are parallel lists of strings.
    :raises ValueError: if a non-blank line does not consist of exactly two
        whitespace-separated fields.
    """
    data = []
    sent_, tag_ = [], []
    with open(corpus_path, encoding='utf-8') as fr:
        # Iterate lazily instead of materializing the whole file.
        for line_no, line in enumerate(fr, start=1):
            fields = line.split()
            if not fields:
                # Sentence separator. Only emit non-empty sentences so that
                # repeated blank lines do not create empty samples.
                if sent_:
                    data.append((sent_, tag_))
                    sent_, tag_ = [], []
                continue
            try:
                char, label = fields
            except ValueError:
                raise ValueError(
                    '{}:{}: expected "<char> <label>", got {!r}'.format(
                        corpus_path, line_no, line.rstrip('\n'))) from None
            sent_.append(char)
            tag_.append(label)
    # A file that does not end with a blank line leaves its final sentence
    # pending here; keep it instead of silently dropping it.
    if sent_:
        data.append((sent_, tag_))
    return data
def vocab_build(vocab_path, corpus_path, min_count):
    """Build a character vocabulary from a corpus and pickle it to disk.

    Each character returned by ``read_corpus`` is normalized first.
    ``str.isdigit`` tokens become ``'<NUM>'``. Tokens within ``'A'``-``'Z'`` or
    ``'a'``-``'z'`` become ``'<ENG>'``; this is a lexicographic range test, which
    means ASCII letters for single characters. ``sentence2id`` applies the same
    rule at lookup time. Tokens seen fewer than ``min_count`` times are dropped,
    except the ``'<NUM>'``/``'<ENG>'`` placeholders, which are always kept.

    Kept tokens get ids 1..N in order of first occurrence in the corpus.
    ``'<UNK>'`` gets N + 1 and ``'<PAD>'`` gets 0.

    :param vocab_path: output path for the pickled ``{token: id}`` dict.
    :param corpus_path: path of the tagged corpus to read.
    :param min_count: minimum frequency for a token to be kept.
    :return: None. The vocabulary size is printed and the dict is written to
        ``vocab_path``.
    """
    from collections import Counter

    def normalize(word):
        if word.isdigit():
            return '<NUM>'
        if ('\u0041' <= word <= '\u005a') or ('\u0061' <= word <= '\u007a'):
            return '<ENG>'
        return word

    data = read_corpus(corpus_path)
    # Counter keeps first-insertion order, which fixes the id assignment below.
    counts = Counter(normalize(word) for sent_, _ in data for word in sent_)
    kept = [word for word, freq in counts.items()
            if freq >= min_count or word in ('<NUM>', '<ENG>')]
    word2id = {word: idx for idx, word in enumerate(kept, start=1)}
    word2id['<UNK>'] = len(kept) + 1
    word2id['<PAD>'] = 0
    print(len(word2id))
    with open(vocab_path, 'wb') as fw:
        pickle.dump(word2id, fw)
def sentence2id(sent, word2id):
    """Convert a sequence of characters to vocabulary ids.

    Digit tokens are looked up as ``'<NUM>'`` and tokens in the ``'A'``-``'Z'``
    or ``'a'``-``'z'`` range as ``'<ENG>'``. Any key missing from ``word2id``
    falls back to the id of ``'<UNK>'``.

    :param sent: iterable of character strings.
    :param word2id: ``{token: id}`` vocabulary mapping; must contain ``'<UNK>'``
        whenever an unknown token is encountered.
    :return: list of ids, one per element of ``sent``.
    """
    def lookup(token):
        if token.isdigit():
            key = '<NUM>'
        elif '\u0041' <= token <= '\u005a' or '\u0061' <= token <= '\u007a':
            key = '<ENG>'
        else:
            key = token
        return word2id[key] if key in word2id else word2id['<UNK>']

    return [lookup(token) for token in sent]
def read_dictionary(vocab_path):
    """Load a pickled vocabulary and print its size.

    NOTE(security): ``pickle.load`` can execute arbitrary code embedded in the
    file. Only load vocabulary files you produced yourself, for example with
    ``vocab_build``; never load ones obtained from an untrusted source.

    :param vocab_path: path (str or path-like) of the pickled vocabulary.
    :return: the unpickled object; ``vocab_build`` writes a ``{token: id}``
        dict.
    """
    with open(vocab_path, 'rb') as fr:
        word2id = pickle.load(fr)
    print('vocab_size:', len(word2id))
    return word2id
def random_embedding(vocab, embedding_dim):
    """Create a randomly initialized embedding matrix.

    Entries are drawn uniformly from [-0.25, 0.25) using NumPy's global random
    state. There is one row per entry of ``vocab``.

    :param vocab: sized container of tokens; only ``len(vocab)`` is used.
    :param embedding_dim: number of columns (the embedding size).
    :return: ``float32`` array of shape ``(len(vocab), embedding_dim)``.
    """
    shape = (len(vocab), embedding_dim)
    samples = np.random.uniform(low=-0.25, high=0.25, size=shape)
    return samples.astype(np.float32)
def pad_sequences(sequences, pad_mark=0, max_len=None):
    """Pad, and optionally truncate, sequences to a common length.

    :param sequences: iterable of sequences (lists, tuples, ...) of ids. It is
        consumed once, so generators are accepted.
    :param pad_mark: value appended to sequences shorter than the target
        length.
    :param max_len: target length. Defaults to the length of the longest
        sequence (0 for empty input). Longer sequences are truncated to it.
    :return: ``(seq_list, seq_len_list)``. ``seq_list`` holds the padded
        sequences as lists. ``seq_len_list`` holds each sequence's true length
        after truncation.
    :raises ValueError: if ``max_len`` is negative.
    """
    # Materialize once: the original iterated ``sequences`` twice, which
    # silently produced nothing for one-shot iterables.
    sequences = [list(seq) for seq in sequences]
    if max_len is None:
        # ``default`` keeps an empty batch from raising inside ``max``.
        max_len = max((len(seq) for seq in sequences), default=0)
    elif max_len < 0:
        raise ValueError('max_len must be non-negative, got {}'.format(max_len))
    seq_list, seq_len_list = [], []
    for seq in sequences:
        seq_list.append(seq[:max_len] + [pad_mark] * max(max_len - len(seq), 0))
        seq_len_list.append(min(len(seq), max_len))
    return seq_list, seq_len_list
def batch_yield(data, batch_size, vocab, tag2label, shuffle=False):
    """Generate ``(seqs, labels)`` mini-batches of id-encoded samples.

    Each ``(chars, tags)`` sample in ``data`` is encoded as
    ``sentence2id(chars, vocab)`` and ``[tag2label[t] for t in tags]``. Batches
    contain ``batch_size`` samples; the final batch may be shorter.

    :param data: sequence of ``(chars, tags)`` samples.
    :param batch_size: number of samples per batch.
    :param vocab: ``{token: id}`` mapping passed to ``sentence2id``.
    :param tag2label: ``{tag: label_id}`` mapping.
    :param shuffle: if true, ``data`` is shuffled *in place* with
        ``random.shuffle`` once iteration starts. This is a generator, so the
        caller's list is reordered on the first ``next()``.
    :return: generator of ``(list of id lists, list of label lists)`` pairs.
    """
    def split(batch):
        return [ids for ids, _ in batch], [labels for _, labels in batch]

    if shuffle:
        random.shuffle(data)

    encoded = ((sentence2id(chars, vocab), [tag2label[t] for t in tags])
               for chars, tags in data)
    pending = []
    for sample in encoded:
        # A full batch is flushed only after the next sample has been encoded,
        # so an encoding error surfaces before that batch is yielded.
        if len(pending) == batch_size:
            yield split(pending)
            pending = []
        pending.append(sample)
    if pending:
        yield split(pending)