-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
33 lines (27 loc) · 1.19 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from Dataset import Dataset
import pandas as pd
from transformers import BertTokenizer
from opencc import OpenCC
# Interactive news-category demo: load a saved classifier from ./models/,
# then repeatedly read a news text from stdin and print the model's raw
# output scores together with the predicted category label.

# Checkpoint name under ./models/ (without ".pt"); blank input -> "head_5".
model_name = input("請輸入希望使用的判斷模型(未輸入則預設使用head_5):")
model_name = model_name.strip() or "head_5"

# NOTE(review): torch.load unpickles a whole model object, which can run
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the
# default weights_only=True refuses full-module pickles; if loading fails
# there, pass weights_only=False explicitly.
model = torch.load(f"./models/{model_name}.pt", map_location=torch.device('cpu'))

cc = OpenCC('t2s')  # convert Traditional Chinese -> Simplified Chinese
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# Label names indexed by the model's output column.
# NOTE(review): assumes this order matches the label ids used in training.
types = ['財經(finance)', '國際(global)', '娛樂(star)', '體育(sport)', '社會(society)']

use_cuda = False  # set True to run inference on a GPU
device = torch.device("cuda" if use_cuda else "cpu")
model = model.to(device)
# Inference mode: disables dropout (etc.) so predictions are deterministic.
model.eval()

while True:
    try:
        text = input("請輸入希望判斷的新聞:")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C ends the session cleanly instead of a traceback.
        print()
        break
    if not text.strip():
        continue

    # The text is converted to Simplified Chinese before tokenization
    # (cc is 't2s'); padded/truncated to a fixed 512 tokens.
    encoded = tokenizer(cc.convert(text),
                        padding='max_length',
                        max_length=512,
                        truncation=True,
                        return_tensors="pt")
    mask = encoded['attention_mask'].to(device)
    input_id = encoded['input_ids'].to(device)

    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        output = model(input_id, mask)

    # argmax over the class dimension; .item() yields a plain int index.
    pred = int(output.argmax(dim=1).item())
    print(output.float(), types[pred])