From 7fc6a12ad0d5b8858466b5adf34821d41d0801eb Mon Sep 17 00:00:00 2001
From: Isabelle Mohr
Date: Thu, 26 Sep 2024 11:12:17 +0200
Subject: [PATCH] feat: allow loading weights from local

---
 chunked_pooling/wrappers.py | 6 +++++-
 run_chunked_eval.py         | 8 +++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/chunked_pooling/wrappers.py b/chunked_pooling/wrappers.py
index 59f5f36..e1ec46f 100644
--- a/chunked_pooling/wrappers.py
+++ b/chunked_pooling/wrappers.py
@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Union
 
 import torch
@@ -138,7 +139,7 @@ def wrapper(self, *args, **kwargs):
     return wrapper
 
 
-def load_model(model_name, **model_kwargs):
+def load_model(model_name, model_weights=None, **model_kwargs):
     if model_name in MODEL_WRAPPERS:
         model = MODEL_WRAPPERS[model_name](model_name, **model_kwargs)
         if hasattr(MODEL_WRAPPERS[model_name], 'has_instructions'):
@@ -149,6 +150,9 @@ def load_model(model_name, **model_kwargs):
         model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
         has_instructions = False
 
+    if model_weights and os.path.exists(model_weights):
+        model._model.load_state_dict(torch.load(model_weights, map_location=model.device))
+
     # encode functions of various models do not support all sentence transformers kwargs parameter
     if model_name in MODELS_WITHOUT_PROMPT_NAME_ARG:
         ENCODE_FUNC_NAMES = ['encode', 'encode_queries', 'encode_corpus']
diff --git a/run_chunked_eval.py b/run_chunked_eval.py
index 88494bd..857462e 100644
--- a/run_chunked_eval.py
+++ b/run_chunked_eval.py
@@ -18,6 +18,11 @@
     default='jinaai/jina-embeddings-v2-small-en',
     help='The name of the model to use.',
 )
+@click.option(
+    '--model-weights',
+    default=None,
+    help='The path to the model weights to use, e.g. in case of finetuning.',
+)
 @click.option(
     '--strategy',
     default=DEFAULT_CHUNKING_STRATEGY,
@@ -55,6 +60,7 @@
 )
 def main(
     model_name,
+    model_weights,
     strategy,
     task_name,
     eval_split,
@@ -68,7 +74,7 @@
     except:
         raise ValueError(f'Unknown task name: {task_name}')
 
-    model, has_instructions = load_model(model_name)
+    model, has_instructions = load_model(model_name, model_weights)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)