diff --git a/chunked_pooling/wrappers.py b/chunked_pooling/wrappers.py
index 59f5f36..e1ec46f 100644
--- a/chunked_pooling/wrappers.py
+++ b/chunked_pooling/wrappers.py
@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Union
 
 import torch
@@ -138,7 +139,7 @@ def wrapper(self, *args, **kwargs):
     return wrapper
 
 
-def load_model(model_name, **model_kwargs):
+def load_model(model_name, model_weights=None, **model_kwargs):
     if model_name in MODEL_WRAPPERS:
         model = MODEL_WRAPPERS[model_name](model_name, **model_kwargs)
         if hasattr(MODEL_WRAPPERS[model_name], 'has_instructions'):
@@ -149,6 +150,9 @@
         model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
         has_instructions = False
 
+    if model_weights and os.path.exists(model_weights):
+        model._model.load_state_dict(torch.load(model_weights, map_location=model.device))
+
     # encode functions of various models do not support all sentence transformers kwargs parameter
     if model_name in MODELS_WITHOUT_PROMPT_NAME_ARG:
         ENCODE_FUNC_NAMES = ['encode', 'encode_queries', 'encode_corpus']
diff --git a/run_chunked_eval.py b/run_chunked_eval.py
index 7712bad..95de94a 100644
--- a/run_chunked_eval.py
+++ b/run_chunked_eval.py
@@ -21,6 +21,11 @@
     default='jinaai/jina-embeddings-v2-small-en',
     help='The name of the model to use.',
 )
+@click.option(
+    '--model-weights',
+    default=None,
+    help='The path to the model weights to use, e.g. in case of finetuning.',
+)
 @click.option(
     '--strategy',
     default=DEFAULT_CHUNKING_STRATEGY,
@@ -70,6 +75,7 @@
 )
 def main(
     model_name,
+    model_weights,
     strategy,
     task_name,
     eval_split,
@@ -91,7 +97,7 @@
             f'Truncation is disabled because Long Late Chunking algorithm is enabled.'
         )
 
-    model, has_instructions = load_model(model_name)
+    model, has_instructions = load_model(model_name, model_weights)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
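
With the patch applied, a minimal invocation exercising the new option might look like the sketch below. It assumes a fine-tuned state dict previously saved with torch.save(model.state_dict(), path); the checkpoint path is a hypothetical example, and the model falls back to the script's default ('jinaai/jina-embeddings-v2-small-en').

    python run_chunked_eval.py --model-weights checkpoints/finetuned.pt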