chat_1 = [
    {"role": "user", "content": "Hello! What's your name?"},
    {"role": "assistant", "content": "My name is InternLM2! A helpful AI assistant. What can I do for you?"}
]
chat_2 = [
    {"role": "user", "content": "Hello! What's your name?"},
    {"role": "assistant", "content": "I have no idea."}
]
After scoring these chats with my script, the scores come out as zero. What could be causing this? Here is my scoring code:
import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained(
    "/root/autodl-tmp/xtuner/work_dirs/internlm2_chat_1_8b_reward_qlora_varlenattn_ultrafeedback_copy/iter_15230_hf",
    device_map="cuda",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "/root/autodl-tmp/xtuner/work_dirs/internlm2_chat_1_8b_reward_qlora_varlenattn_ultrafeedback_copy/iter_15230_hf",
    trust_remote_code=True,
)
chat_1 = [
    {"role": "user", "content": "Hello! What's your name?"},
    {"role": "assistant", "content": "My name is InternLM2! A helpful AI assistant. What can I do for you?"}
]
chat_2 = [
    {"role": "user", "content": "Hello! What's your name?"},
    {"role": "assistant", "content": "I have no idea."}
]
# get reward score for a single chat
score1 = model.get_score(tokenizer, chat_1)
score2 = model.get_score(tokenizer, chat_2)
print("score1: ", score1)
print("score2: ", score2)
# >>> score1: 0.767578125
# >>> score2: -2.22265625
# batch inference, get multiple scores at once
scores = model.get_scores(tokenizer, [chat_1, chat_2])
print("scores: ", scores)
# >>> scores: [0.767578125, -2.22265625]
# compare whether chat_1 is better than chat_2
compare_res = model.compare(tokenizer, chat_1, chat_2)
print("compare_res: ", compare_res)
# >>> compare_res: True
# rank multiple chats, it will return the ranking index of each chat
# the chat with the highest score will have ranking index as 0
rank_res = model.rank(tokenizer, [chat_1, chat_2])
print("rank_res: ", rank_res) # lower index means higher score
# >>> rank_res: [0, 1]
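
For reference, here is a minimal sketch of how comparison and ranking results can be derived from raw scores. It shows why an all-zero score list makes `compare` and `rank` uninformative. The helpers `rank_from_scores` and `compare_from_scores` are hypothetical names used only for illustration; they are not the model's actual implementation, which may handle ties differently.

```python
def rank_from_scores(scores):
    """Return the ranking index of each item (0 = highest score).

    Hypothetical helper for illustration; not the model's own code.
    """
    # Positions of the items, ordered from highest to lowest score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranks = [0] * len(scores)
    for rank, idx in enumerate(order):
        ranks[idx] = rank
    return ranks


def compare_from_scores(score_a, score_b):
    """Return True when the first score is strictly higher than the second."""
    return score_a > score_b


# Scores from the commented expected output above.
expected = [0.767578125, -2.22265625]
print(rank_from_scores(expected))        # [0, 1]
print(compare_from_scores(*expected))    # True

# With all-zero scores the ranking only reflects input order.
zeros = [0.0, 0.0]
print(rank_from_scores(zeros))           # [0, 1]
print(compare_from_scores(*zeros))       # False
```

If every score is exactly zero, the ranking above falls back to input order and the comparison is always `False`, so the zero scores themselves need to be explained before `compare` or `rank` results can be trusted.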