Skip to content

Commit

Permalink
Apply lintrunner
Browse files Browse the repository at this point in the history
  • Loading branch information
George Hong committed Apr 18, 2024
1 parent ea873fd commit 3460375
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
10 changes: 7 additions & 3 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.
import argparse
import itertools

import logging
import os
import sys
import time
Expand All @@ -27,11 +29,11 @@
from cli import add_arguments_for_generate, arg_init, check_args
from quantize import set_precision

import logging
logger = logging.getLogger(__name__)

B_INST, E_INST = "[INST]", "[/INST]"


@dataclass
class GeneratorArgs:
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
Expand Down Expand Up @@ -348,14 +350,16 @@ def _main(
is_speculative = speculative_builder_args.checkpoint_path is not None

if generator_args.chat_mode and not builder_args.is_chat_model:
logging.warning("""
logging.warning(
"""
*******************************************************
This model is not known to support the chat function.
We will enable chat mode based on your instructions.
If the model is not trained to support chat, it will
produce nonsensical or false output.
*******************************************************
""")
"""
)
# raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")

tokenizer = _initialize_tokenizer(tokenizer_args)
Expand Down
8 changes: 5 additions & 3 deletions torchchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import argparse

import logging

from cli import (
add_arguments_for_eval,
add_arguments_for_export,
Expand All @@ -14,8 +16,6 @@
check_args,
)

import logging

default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'


Expand All @@ -37,7 +37,9 @@

args = parser.parse_args()
args = arg_init(args)
logging.basicConfig(format='%(message)s', level=logging.DEBUG if args.verbose else logging.INFO)
logging.basicConfig(
format="%(message)s", level=logging.DEBUG if args.verbose else logging.INFO
)

if args.subcommand == "generate":
check_args(args, "generate")
Expand Down

0 comments on commit 3460375

Please sign in to comment.