Skip to content

Commit

Permalink
Use fastchat for loading chat templates instead of HF
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 committed Feb 23, 2024
1 parent 17cdfcb commit ede8a2f
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions analysis/get_per_token_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

# Script to output the per-token reward across a piece of text given a reward model

import os
import argparse
import hashlib
import json
import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
Expand All @@ -28,6 +28,7 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import Dataset
from fastchat.conversation import get_conv_template
from huggingface_hub import upload_file
from tqdm import tqdm
from transformers import AutoTokenizer, pipeline
Expand Down Expand Up @@ -68,8 +69,8 @@ def get_args():
parser.add_argument(
"--chat_template",
type=str,
default="natolambert/gpt2-dummy-rm",
help="Path to the chat template.",
default="tulu",
help="Path to the chat template. Will be loaded using fastchat",
)
parser.add_argument(
"--output_dir",
Expand Down Expand Up @@ -143,11 +144,11 @@ def _tokenify_string(string):
# If chat_template exists
if args.chat_template:
print(f"Applying chat template: {args.chat_template}")
templater = AutoTokenizer.from_pretrained(args.chat_template)
chat = [{"role": "user", "content": args.text}]
text = templater.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
conv = get_conv_template(args.chat_template)
conv.append_message(role=conv.roles[0], message=args.text)
text = conv.get_prompt()
else:
print("No chat template applied.")
print("No chat template supplied.")
text = args.text

substrings, tokens = _tokenify_string(text)
Expand Down

0 comments on commit ede8a2f

Please sign in to comment.