diff --git a/prover/proof_search.py b/prover/proof_search.py
index 347a978..e86aa6f 100644
--- a/prover/proof_search.py
+++ b/prover/proof_search.py
@@ -336,8 +336,8 @@ def initialize(self) -> None:
         engine_args = AsyncEngineArgs(
             model=self.model_path,
             tensor_parallel_size=self.num_gpus,
-            max_num_batched_tokens=8192,
-            enable_chunked_prefill=False,
+            max_num_batched_tokens=2048,
+            enable_chunked_prefill=True,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
 
@@ -394,17 +394,11 @@ def __init__(
             tac_gen = VllmGenerator(vllm_actor)
         elif indexed_corpus_path is not None:
             tac_gen = RetrievalAugmentedGenerator(
-                gen_ckpt_path,
-                ret_ckpt_path,
-                indexed_corpus_path,
-                device,
-                max_num_retrieved=100,
+                gen_ckpt_path, ret_ckpt_path, indexed_corpus_path, device, max_num_retrieved=100
             )
         else:
             device = torch.device("cuda") if num_gpus > 0 else torch.device("cpu")
-            tac_gen = HuggingFaceGenerator(
-                gen_ckpt_path, device, max_oup_seq_len, length_penalty
-            )
+            tac_gen = HuggingFaceGenerator(gen_ckpt_path, device, max_oup_seq_len, length_penalty)
 
         self.distributed = num_workers > 1
         if not self.distributed:
diff --git a/retrieval/index.py b/retrieval/index.py
index 783eb63..c7b51b7 100644
--- a/retrieval/index.py
+++ b/retrieval/index.py
@@ -30,7 +30,7 @@ def main() -> None:
         device = torch.device("cpu")
     else:
         device = torch.device("cuda")
-    model = PremiseRetriever.load_hf(args.ckpt_path, device, num_retrieved=100)
+    model = PremiseRetriever.load_hf(args.ckpt_path, device, max_seq_len=2048)
     model.load_corpus(args.corpus_path)
     model.reindex_corpus(batch_size=args.batch_size)
 
diff --git a/retrieval/model.py b/retrieval/model.py
index 972dc84..900427e 100644
--- a/retrieval/model.py
+++ b/retrieval/model.py
@@ -9,7 +9,7 @@
 from loguru import logger
 import pytorch_lightning as pl
 import torch.nn.functional as F
-from typing import List, Dict, Any, Tuple, Union
+from typing import List, Dict, Any, Tuple, Union, Optional
 from transformers import AutoModelForTextEncoding, AutoTokenizer
 
 from common import (
@@ -51,10 +51,12 @@ def load(cls, ckpt_path: str, device, freeze: bool) -> "PremiseRetriever":
 
     @classmethod
     def load_hf(
-        cls, ckpt_path: str, device, dtype, num_retrieved: int
+        cls, ckpt_path: str, device, dtype=None, max_seq_len: Optional[int] = None
     ) -> "PremiseRetriever":
+        if max_seq_len is None:
+            max_seq_len = 999999999999  # effectively unbounded
         model = (
-            PremiseRetriever(ckpt_path, 0.0, 0, 999999999999, num_retrieved)
+            PremiseRetriever(ckpt_path, 0.0, 0, max_seq_len, 100)
             .to(device)
             .eval()
         )
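
Usage note (not part of the patch): a minimal sketch of calling the updated `load_hf` API above, assuming the new signature with `dtype` defaulting to None. The checkpoint path, corpus path, and batch size are placeholders; `max_seq_len=2048` mirrors the value now used in retrieval/index.py, and omitting it falls back to the effectively unbounded default.

    import torch
    from retrieval.model import PremiseRetriever

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Cap premise tokenization at 2048 tokens during indexing; dtype no longer
    # needs to be passed explicitly.
    model = PremiseRetriever.load_hf("path/to/retriever-checkpoint", device, max_seq_len=2048)
    model.load_corpus("path/to/corpus")
    model.reindex_corpus(batch_size=32)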