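# xgboost_cls.py — TF-IDF features + XGBoost binary text classifier
# (206 lines, 150 loc, 6.74 KB in the original repository).
# NOTE: this copy was captured from a GitHub page; the page's navigation text and the
# bare line-number gutter that follows are capture residue, not part of the program.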
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import time
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import metrics
import xgboost as xgb
from nltk.corpus import stopwords
from collections import Counter
import pandas as pd
import matplotlib.pyplot as plt
from xgboost import plot_importance
import numpy as np
plt.rc('font',family='Times New Roman')
def processing_sentence(x, stop_words):
    """Remove stop words from a whitespace-tokenised piece of text.

    Args:
        x: the raw text of one document. Missing values (NaN/None, as produced
            by ``pd.read_csv`` for empty cells) are treated as empty text.
        stop_words: container of stop words (a set is fastest; a list works).

    Returns:
        The remaining tokens joined with single spaces (original casing kept).
    """
    if x is None or pd.isna(x):
        # Empty CSV cells arrive as NaN floats, which have no .split().
        return ''
    # Compare case-insensitively: NLTK's stop-word list is lowercase, and the
    # downstream TfidfVectorizer lowercases tokens anyway, so "The" must be
    # dropped just like "the". The exact-match test keeps any mixed-case
    # entries a caller may have put in stop_words working as before.
    # str.split() already discards all whitespace, so no blank tokens remain.
    return ' '.join(
        word for word in str(x).split()
        if word not in stop_words and word.lower() not in stop_words
    )
def data_processing(train_path, test_path):
    """Load the train/test CSVs, drop English stop words and TF-IDF encode the text.

    Both CSV files must provide a ``text`` column and a ``labels`` column.

    Args:
        train_path: path of the training CSV; the TF-IDF vocabulary is fitted on it.
        test_path: path of the test CSV; transformed with the fitted vocabulary.

    Returns:
        ``(x_train_weight, x_test_weight, y_train, y_test, feature_names)``:
        dense TF-IDF matrices (one row per document, one column per vocabulary
        term), the ``labels`` column of each split, and the vocabulary terms
        listed in column order.
    """
    train_data = pd.read_csv(train_path)
    test_data = pd.read_csv(test_path)
    y_train = train_data.labels
    y_test = test_data.labels

    # NLTK returns a list; a set makes each per-word membership test O(1).
    stop_words = set(stopwords.words('english'))
    x_train = train_data.text.apply(lambda x: processing_sentence(x, stop_words))
    x_test = test_data.text.apply(lambda x: processing_sentence(x, stop_words))

    # Fit the vocabulary on the training split only; the test split reuses it.
    tf = TfidfVectorizer()
    # NOTE(review): .toarray() makes dense copies, which can be large for a big
    # vocabulary; XGBoost also accepts the sparse matrices directly.
    x_train_weight = tf.fit_transform(x_train).toarray()
    x_test_weight = tf.transform(x_test).toarray()

    # vocabulary_ maps term -> column index; ordering terms by that index
    # yields the name of each matrix column.
    vocabulary = tf.vocabulary_
    feature_names = sorted(vocabulary, key=vocabulary.get)
    return x_train_weight, x_test_weight, y_train, y_test, feature_names
# Directory holding the prompt{p}_seed{s}_{train,test}.csv splits.
# Switch to './data/MiMic' to reproduce the MiMic results recorded below.
DATA_DIR = './data/medical_text'
# Size of the feature-importance figure that plot_importance creates.
plt.rcParams["figure.figsize"] = (5, 5)

# Train and evaluate one model per (prompt, seed) split, in the order
# (1, 1), (1, 2), (2, 1), (2, 2).
for prompt in [1, 2]:
    for seed in [1, 2]:
        print(f'=== prompt{prompt} seed{seed} ===')
        train_path = f'{DATA_DIR}/prompt{prompt}_seed{seed}_train.csv'
        test_path = f'{DATA_DIR}/prompt{prompt}_seed{seed}_test.csv'
        x_train_weight, x_test_weight, y_train, y_test, feature_names = data_processing(train_path, test_path)

        start = time.time()
        print("start time is: ", start)
        # `silent` is deprecated (and later removed) in XGBoost; `verbosity`
        # replaces it, and 1 (warnings) corresponds to the old silent=False.
        model = xgb.XGBClassifier(max_depth=4, learning_rate=0.1, n_estimators=50, n_jobs=2,
                                  verbosity=1, objective='binary:logistic')
        model.fit(x_train_weight, y_train)
        # Name the booster's features after the TF-IDF vocabulary terms so the
        # importance plot shows words rather than positional feature ids.
        model.get_booster().feature_names = feature_names
        end = time.time()
        print("end time is: ", end)
        print("cost time is: ", (end - start))

        y_predict = model.predict(x_test_weight)
        confusion_mat = metrics.confusion_matrix(y_test, y_predict)
        # Printed labels: '准确率' = accuracy, '分类报告' = classification report.
        print('准确率:', metrics.accuracy_score(y_test, y_predict))
        print("confusion_matrix is: ", confusion_mat)
        print('分类报告:', metrics.classification_report(y_test, y_predict, digits=3))

        # Top-15 most important features; the title identifies the run.
        plot_importance(model,
                        max_num_features=15,
                        height=0.3,
                        grid=False,
                        show_values=False,
                        title=f'Feature importance (prompt{prompt}, seed{seed})',
                        )
        # With an interactive backend this waits for the window to be closed.
        plt.show()
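# Recorded results for ./data/medical_text: one report per run, in loop order
# (prompt1/seed1, prompt1/seed2, prompt2/seed1, prompt2/seed2).
# Printed labels: '准确率' = accuracy, '分类报告' = classification report.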
# 准确率: 0.9625
# confusion_matrix is: [[431 9]
# [ 24 416]]
# 分类报告: precision recall f1-score support
# 0 0.947 0.980 0.963 440
# 1 0.979 0.945 0.962 440
# accuracy 0.963 880
# macro avg 0.963 0.962 0.962 880
# weighted avg 0.963 0.963 0.962 880
# 准确率: 0.9613636363636363
# confusion_matrix is: [[433 7]
# [ 27 413]]
# 分类报告: precision recall f1-score support
# 0 0.941 0.984 0.962 440
# 1 0.983 0.939 0.960 440
# accuracy 0.961 880
# macro avg 0.962 0.961 0.961 880
# weighted avg 0.962 0.961 0.961 880
# 准确率: 0.9590909090909091
# confusion_matrix is: [[431 9]
# [ 27 413]]
# 分类报告: precision recall f1-score support
# 0 0.941 0.980 0.960 440
# 1 0.979 0.939 0.958 440
# accuracy 0.959 880
# macro avg 0.960 0.959 0.959 880
# weighted avg 0.960 0.959 0.959 880
# 准确率: 0.9454545454545454
# confusion_matrix is: [[431 9]
# [ 39 401]]
# 分类报告: precision recall f1-score support
# 0 0.917 0.980 0.947 440
# 1 0.978 0.911 0.944 440
# accuracy 0.945 880
# macro avg 0.948 0.945 0.945 880
# weighted avg 0.948 0.945 0.945 880
# Summary: per-run values (accuracy; macro-avg precision/recall/f1), followed by their mean.
# acc: [0.963,0.961,0.959,0.945] 0.957
# precision: [0.963,0.962,0.960,0.948] 0.95825
# recall: [0.962,0.961,0.959,0.945] 0.95675
# f1: [0.962,0.961,0.959,0.945] 0.95675
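# Recorded results for ./data/MiMic: one report per run, in loop order
# (prompt1/seed1, prompt1/seed2, prompt2/seed1, prompt2/seed2).
# Printed labels: '准确率' = accuracy, '分类报告' = classification report.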
# 准确率: 0.928409090909091
# confusion_matrix is: [[417 23]
# [ 40 400]]
# 分类报告: precision recall f1-score support
# 0 0.912 0.948 0.930 440
# 1 0.946 0.909 0.927 440
# accuracy 0.928 880
# macro avg 0.929 0.928 0.928 880
# weighted avg 0.929 0.928 0.928 880
# 准确率: 0.9261363636363636
# confusion_matrix is: [[416 24]
# [ 41 399]]
# 分类报告: precision recall f1-score support
# 0 0.910 0.945 0.928 440
# 1 0.943 0.907 0.925 440
# accuracy 0.926 880
# macro avg 0.927 0.926 0.926 880
# weighted avg 0.927 0.926 0.926 880
# 准确率: 0.928409090909091
# confusion_matrix is: [[419 21]
# [ 42 398]]
# 分类报告: precision recall f1-score support
# 0 0.909 0.952 0.930 440
# 1 0.950 0.905 0.927 440
# accuracy 0.928 880
# macro avg 0.929 0.928 0.928 880
# weighted avg 0.929 0.928 0.928 880
# 准确率: 0.9125
# confusion_matrix is: [[418 22]
# [ 55 385]]
# 分类报告: precision recall f1-score support
# 0 0.884 0.950 0.916 440
# 1 0.946 0.875 0.909 440
# accuracy 0.912 880
# macro avg 0.915 0.912 0.912 880
# weighted avg 0.915 0.912 0.912 880
# Summary: per-run values (accuracy; macro-avg precision/recall/f1), followed by their mean.
# acc: [0.928,0.926,0.928,0.912] 0.9235
# precision: [0.929,0.927,0.929,0.915] 0.925
# recall: [0.928,0.926,0.928,0.912] 0.9235
# f1: [0.928,0.926,0.928,0.912] 0.9235