From 09d2811656362c32b7e47cd9c40cc14f8fe989e0 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Fri, 25 Aug 2023 16:38:46 +0200 Subject: [PATCH] feat: add pushing the model to hf --- textbook/train.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/textbook/train.py b/textbook/train.py index 223b1f9..5f10a1f 100644 --- a/textbook/train.py +++ b/textbook/train.py @@ -49,6 +49,7 @@ def train( wandb_log_model: Optional[ bool ] = None, # will be true by default if use_wandb is true + push_model_to_hf: bool = False, # if set, will push the model to hf local_rank: Annotated[int, typer.Option("--local_rank")] = 0, deepspeed: Optional[str] = None, debug: bool = False, @@ -127,12 +128,13 @@ def train( trainer.train() - model.save_pretrained("jinaai/starcoder-1b-textbook") - tokenizer.save_pretrained("jinaai/starcoder-1b-textbook") + if push_model_to_hf: + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) - # Push to the hub - model.push_to_hub("jinaai/starcoder-1b-textbook") - tokenizer.push_to_hub("jinaai/starcoder-1b-textbook") + # Push to the hub + model.push_to_hub('jinaai/starcoder-1b-textbook') + tokenizer.push_to_hub('jinaai/starcoder-1b-textbook') accuracy_results, sample_results = evaluate( model, tokenizer, eval_size=eval_size, max_new_tokens=eval_max_new_tokens