From dc77c8ebce8ec4135f4e0c03a9d336b3f0957358 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 13 Jun 2023 12:01:46 +0900 Subject: [PATCH] chore: Refactor inf_kwargs out --- scripts/finetune.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 283100c8a..785f3cf23 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -63,7 +63,7 @@ def get_multi_line_input() -> Optional[str]: return instruction -def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): +def do_inference(cfg, model, tokenizer, prompter: Optional[str]): default_tokens = {"unk_token": "", "bos_token": "", "eos_token": ""} for token, symbol in default_tokens.items(): @@ -257,13 +257,13 @@ def train( if cfg.inference: logging.info("calling do_inference function") - inf_kwargs: Dict[str, Any] = {} + prompter: Optional[str] = "AlpacaPrompter" if "prompter" in kwargs: if kwargs["prompter"] == "None": - inf_kwargs["prompter"] = None + prompter = None else: - inf_kwargs["prompter"] = kwargs["prompter"] - do_inference(cfg, model, tokenizer, **inf_kwargs) + prompter = kwargs["prompter"] + do_inference(cfg, model, tokenizer, prompter=prompter) return if "shard" in kwargs: