diff --git a/contrib/rag/README.md b/contrib/rag/README.md new file mode 100644 index 000000000..29866cf44 --- /dev/null +++ b/contrib/rag/README.md @@ -0,0 +1,44 @@ +# Retrieval-augmented generation + +LMFlow now supports retrieval-augmented generation. We offer four different retrieval methods which include DPR embeddings and BM25. Also, any model supported by LMFlow can be used for generation. + +* DPR(Dense Passage Retrieval) Embeddings: \ +https://arxiv.org/pdf/2004.04906 +* BM25 retriever: \ +https://python.langchain.com/v0.2/docs/integrations/retrievers/bm25/ + +## Requirements +Faiss library is required for dataset indexing. +``` +pip install faiss-cpu pickle rank_bm25 +``` + +## Build indexing for custom corpus for retrieval +If you want to use your own corpus for retrieval, first use `build_corpus_index.py` to build an index of the corpus embeddings. We offer one type of embedding method `dpr`and one retrieval method, `bm25`, which also requires indexing. + +Below is an example that utilizes OpenAI embedding to index a corpus using '\n\n' as the splitter. + +``` +python ./scripts/build_corpus_index --corpus_path='corpus.txt' --splitter='\n\n' --embedding_type='dpr' --data_index_path='corpus' +``` +Then it would save corpus and corpus index to ```corpus.dpr_index```. + +## Inference and Evaluation + +After building indexing of corpus, you can run the script `run_rag_inference.sh` that user can directly input question, and the script `run_rag_evaluation.sh` that user can input the path of dataset. + +Here are two examples of each script. + +``` +bash ./scripts/run_rag_inference.sh --retriever_type='dpr' --corpus_index_path='corpus.dpr_index' --top_k_retrieve=5 +``` + +``` +bash ./scripts/run_rag_evaluation.sh --retriever_type='dpr' --corpus_index_path='corpus.dpr_index' --top_k_retrieve=5 +``` + +## Known issue + +Current `build_corpus_index.py` has memory issue, since it would load all corpus into memory at once, so if the size of corpus is larger than your memory, the process would be broken. Our next step is to enable our program to load corpus piece by piece, so that memory would not be an issue. Also, + + diff --git a/contrib/rag/build_corpus_index.py b/contrib/rag/build_corpus_index.py new file mode 100644 index 000000000..027a662ce --- /dev/null +++ b/contrib/rag/build_corpus_index.py @@ -0,0 +1,87 @@ +import pickle +import os +from transformers import AutoTokenizer, AutoModel +from transformers import HfArgumentParser +from dataclasses import dataclass, field +from typing import Optional +import torch +import faiss +import numpy as np +@dataclass +class RetrieverArguments: + corpus_path: str = field( + metadata={ + "help": "Please specify the path to the document corpus." + } + ) + + embedding_type: Optional[str] = field( + default="dpr", + metadata={ + "help": "Please specify the type of retriever: bm25, or dpr" + } + ) + splitter: Optional[str] = field( + default="\n\n", + metadata={ + "help": "Please specify the splitter of your document." + } + ) + + data_index_path: Optional[str] = field( + default = './data/corpus', + metadata={ + "help": "Please specify the name of data index name." + } + ) + + device: int = field( + default=0, + metadata={ + "help": "The machine rank of gpu is used." + } + ) + + + +parser = HfArgumentParser((RetrieverArguments)) +retriever_args = parser.parse_args_into_dataclasses()[0] +with open(retriever_args.corpus_path) as f: + text = f.read() +texts = text.split(retriever_args.splitter) + +if retriever_args.embedding_type == 'dpr': + model_name = 'sentence-transformers/facebook-dpr-question_encoder-single-nq-base' + device = torch.device(f'cuda:{retriever_args.device}') + + tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/facebook-dpr-question_encoder-single-nq-base') + model = AutoModel.from_pretrained('sentence-transformers/facebook-dpr-question_encoder-single-nq-base').to(device) + encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device) + + with torch.no_grad(): + model_output = model(**encoded_input) + def cls_pooling(model_output): + return model_output[0][:,0] + embeddings = cls_pooling(model_output) + + + dim = 768 + index = faiss.IndexFlatL2(dim) + index.add(embeddings.cpu().numpy()) + chunks = faiss.serialize_index(index) + with open(retriever_args.data_index_path+'.dpr_index', "wb") as fp: + pickle.dump(texts, fp) + pickle.dump(chunks, fp) + +elif retriever_args.embedding_type == 'bm25': + with open(retriever_args.data_index_path+'.bm25_index', "wb") as fp: + pickle.dump(texts, fp) +else: + raise ValueError('The embedded method is not implemented. \ + Please specify the type of document embedding as one of the choices, [dpr, bm25].') + + + + + + diff --git a/contrib/rag/corpus.dpr_index b/contrib/rag/corpus.dpr_index new file mode 100644 index 000000000..8f0757846 Binary files /dev/null and b/contrib/rag/corpus.dpr_index differ diff --git a/contrib/rag/corpus.txt b/contrib/rag/corpus.txt new file mode 100644 index 000000000..a87d14ef5 --- /dev/null +++ b/contrib/rag/corpus.txt @@ -0,0 +1 @@ +In recent years, the rise of Large Language Models (LLMs) has spurred a growing demand for plug-and-play AI systems. Among the various AI techniques, prompt engineering stands out as particularly significant. However, users often face challenges in writing prompts due to the steep learning curve and significant time investment, and existing automatic prompt engineering (APE) models can be difficult to use. To address this issue, we propose PAS, an LLM-based plug-and-play APE system. PAS utilizes LLMs trained on high-quality, automatically generated prompt complementary datasets, resulting in exceptional performance. In comprehensive benchmarks, PAS achieves state-of-the-art (SoTA) results compared to previous APE models, with an average improvement of 6.09 points. Moreover, PAS is highly efficient, achieving SoTA performance with only 9000 data points. Additionally, PAS can autonomously generate prompt augmentation data without requiring additional human labor. Its flexibility also allows it to be compatible with all existing LLMs and applicable to a wide range of tasks. PAS excels in human evaluations, underscoring its suitability as a plug-in for users. This combination of high performance, efficiency, and flexibility makes PAS a valuable system for enhancing the usability and effectiveness of LLMs through improved prompt engineering. \ No newline at end of file diff --git a/contrib/rag/rag_evaluation.py b/contrib/rag/rag_evaluation.py new file mode 100644 index 000000000..2bad8071d --- /dev/null +++ b/contrib/rag/rag_evaluation.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. +"""A one-line summary of the module or program, terminated by a period. + +Leave one blank line. The rest of this docstring should contain an +overall description of the module or program. Optionally, it may also +contain a brief description of exported classes and functions and/or usage +examples. + +Typical usage example: + + foo = ClassFoo() + bar = foo.FunctionBar() +""" +import json +import os +import sys +sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) +from transformers import HfArgumentParser +from transformers import AutoModel as AM +from transformers import AutoTokenizer as AT +from dataclasses import dataclass, field +from typing import Optional +import torch + +import faiss +import pickle +from rank_bm25 import BM25Okapi + + +from lmflow.datasets.dataset import Dataset +from lmflow.pipeline.auto_pipeline import AutoPipeline +from lmflow.models.auto_model import AutoModel +from lmflow.args import ModelArguments, DatasetArguments, AutoArguments + +@dataclass +class RAGEvalArguments: + retriever_type: Optional[str] = field( + default='dpr', + metadata={ + "help": "Please specify the type of document embedding: dpr, and bm25" + } + ) + corpus_index_path: Optional[str] = field( + default="corpus.dpr_index", + metadata={ + "help": "Please specify the path of corpus index. If you select wiki search, you do not need specify it." + } + ) + prompt_structure: Optional[str] = field( + default="Answer the following question based on the background information.\n\nQuestion:{input_text}\n\nBackground:{background}\n\nAnswer:", + metadata={ + "help": "prompt structure given user's input text." + }, + ) + top_k_retrieve: int = field( + default=5, + metadata={ + "help": "Please specify the number of the most relevant documents to be retrieved." + } + ) + +pipeline_name = "evaluator" +PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) + +parser = HfArgumentParser((ModelArguments, DatasetArguments, PipelineArguments, RAGEvalArguments)) +model_args, data_args, pipeline_args, rag_args = parser.parse_args_into_dataclasses() + +with open (pipeline_args.deepspeed, "r") as f: + ds_config = json.load(f) + +model = AutoModel.get_model( + model_args, + tune_strategy='none', + ds_config=ds_config, + use_accelerator=pipeline_args.use_accelerator_for_evaluator +) +top_k = rag_args.top_k_retrieve +if rag_args.retriever_type == 'dpr': + with open(rag_args.corpus_index_path, 'rb') as fp: + corpus = pickle.load(fp) + index = faiss.deserialize_index(pickle.load(fp)) + + model_name = 'sentence-transformers/facebook-dpr-question_encoder-single-nq-base' + tokenizer = AT.from_pretrained(model_name) + embed = AM.from_pretrained(model_name) + def cls_pooling(model_output): + return model_output[0][:,0] + +elif rag_args.retriever_type == 'bm25': + with open(rag_args.corpus_index_path, "rb") as fp: + corpus = pickle.load(fp) + tokenized_corpus = [doc.split(" ") for doc in corpus] + bm25 = BM25Okapi(tokenized_corpus) +else: + raise ValueError('The type of retriever you specify is not implemented. Please specify it as one of [openai_embed, dpr_embed, wiki]') + +dataset = Dataset(data_args) + +data_dict = dataset.to_dict() + +for i, instance in enumerate(data_dict["instances"]): + input_text = instance['input'] + + if rag_args.retriever_type == 'dpr': + encoded_input = tokenizer(input_text, padding=True, truncation=True, return_tensors='pt') + with torch.no_grad(): + model_output = embed(**encoded_input) + embeddings = cls_pooling(model_output).numpy() + + _, ids = index.search(embeddings, k=top_k) + docs = [corpus[int(id)] for id in ids[0]] + elif rag_args.retriever_type == 'bm25': + tokenized_query = input_text.split() + docs = bm25.get_top_n(tokenized_query, corpus, n=top_k) + + background = '\n'.join(docs) + all_input = rag_args.prompt_structure.format(input_text=input_text, background=background) + data_dict["instances"][i] = all_input + +dataset = dataset.from_dict(data_dict) + +evaluator = AutoPipeline.get_pipeline( + pipeline_name=pipeline_name, + model_args=model_args, + data_args=data_args, + pipeline_args=pipeline_args, +) +evaluator.evaluate(model=model, dataset=dataset, metric=pipeline_args.metric) diff --git a/contrib/rag/rag_inference.py b/contrib/rag/rag_inference.py new file mode 100644 index 000000000..6e7a31ff0 --- /dev/null +++ b/contrib/rag/rag_inference.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. +"""A simple shell chatbot implemented with lmflow APIs. +""" +import logging +import json +import os +import sys +sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) +import warnings +import pickle +import faiss +import pickle +import torch +from rank_bm25 import BM25Okapi + +from dataclasses import dataclass, field +from transformers import HfArgumentParser +from transformers import AutoModel as AM +from transformers import AutoTokenizer as AT +from typing import Optional + +from lmflow.datasets.dataset import Dataset +from lmflow.pipeline.auto_pipeline import AutoPipeline +from lmflow.models.auto_model import AutoModel +from lmflow.args import ModelArguments, DatasetArguments, AutoArguments + + + + +logging.disable(logging.ERROR) +warnings.filterwarnings("ignore") + + +@dataclass +class RAGInferArguments: + retriever_type: Optional[str] = field( + default='dpr', + metadata={ + "help": "Please specify the type of document embedding: dpr, or bm25" + } + ) + corpus_index_path: Optional[str] = field( + default="corpus.dpr_index", + metadata={ + "help": "Please specify the path of corpus index. If you select wiki search, you do not need specify it." + } + ) + prompt_structure: Optional[str] = field( + default="Answer the following question based on the background information.\n\nQuestion:{input_text}\n\nBackground:{background}\n\nAnswer:", + metadata={ + "help": "prompt structure given user's input text and background information." + }, + ) + top_k_retrieve: int = field( + default=5, + metadata={ + "help": "Please specify the number of the most relevant documents to be retrieved." + } + ) + +logging.disable(logging.ERROR) +warnings.filterwarnings("ignore") + +def main(): + pipeline_name = "inferencer" + PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) + + parser = HfArgumentParser(( + ModelArguments, + PipelineArguments, + RAGInferArguments + )) + model_args, pipeline_args, rag_args = parser.parse_args_into_dataclasses() + inferencer_args = pipeline_args + + with open (pipeline_args.deepspeed, "r") as f: + ds_config = json.load(f) + + model = AutoModel.get_model( + model_args, + tune_strategy='none', + ds_config=ds_config, + device=pipeline_args.device, + use_accelerator=True, + ) + + # We don't need input data, we will read interactively from stdin + data_args = DatasetArguments(dataset_path=None) + dataset = Dataset(data_args) + + inferencer = AutoPipeline.get_pipeline( + pipeline_name=pipeline_name, + model_args=model_args, + data_args=data_args, + pipeline_args=pipeline_args, + ) + + # Inferences + model_name = model_args.model_name_or_path + if model_args.lora_model_path is not None: + model_name += f" + {model_args.lora_model_path}" + + top_k = rag_args.top_k_retrieve + if rag_args.retriever_type == 'dpr': + with open(rag_args.corpus_index_path, 'rb') as fp: + corpus = pickle.load(fp) + index = faiss.deserialize_index(pickle.load(fp)) + + model_name = 'sentence-transformers/facebook-dpr-question_encoder-single-nq-base' + tokenizer = AT.from_pretrained(model_name) + embed = AM.from_pretrained(model_name) + def cls_pooling(model_output): + return model_output[0][:,0] + + elif rag_args.retriever_type == 'bm25': + with open(rag_args.corpus_index_path, "rb") as fp: + corpus = pickle.load(fp) + tokenized_corpus = [doc.split(" ") for doc in corpus] + bm25 = BM25Okapi(tokenized_corpus) + else: + raise ValueError('The type of retriever you specify is not implemented. Please specify it as one of [openai_embed, dpr_embed, wiki]') + + while True: + input_text = input("User >>> ") + prompt = rag_args.prompt_structure + + if rag_args.retriever_type == 'dpr': + encoded_input = tokenizer(input_text, padding=True, truncation=True, return_tensors='pt') + with torch.no_grad(): + model_output = embed(**encoded_input) + embeddings = cls_pooling(model_output).numpy() + + _, ids = index.search(embeddings, k=top_k) + docs = [corpus[int(id)] for id in ids[0]] + elif rag_args.retriever_type == 'bm25': + tokenized_query = input_text.split() + docs = bm25.get_top_n(tokenized_query, corpus, n=top_k) + + background = '\n'.join(docs) + background = background[-(model.get_max_length()-len(input_text)-len(prompt)):] + all_input = prompt.format(input_text=input_text, background=background) + input_dataset = dataset.from_dict({ + "type": "text_only", + "instances": [ { "text": all_input } ] + }) + output_dataset = inferencer.inference( + model=model, + dataset=input_dataset, + max_new_tokens=inferencer_args.max_new_tokens, + temperature=inferencer_args.temperature, + ) + output = output_dataset.to_dict()["instances"][0]["text"] + print('Bot:') + print(output) + + +if __name__ == "__main__": + main() diff --git a/contrib/rag/run_rag_evaluation.sh b/contrib/rag/run_rag_evaluation.sh new file mode 100644 index 000000000..34615eefb --- /dev/null +++ b/contrib/rag/run_rag_evaluation.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +if [ ! -d data/MedQA-USMLE ]; then + cd data && ./download.sh MedQA-USMLE && cd - +fi + +lora_args="" +retriever_type="" +corpus_index_path="" +prompt_structure="" + +top_k="" +while [[ $# -ge 1 ]]; do + key="$1" + case ${key} in + -m|--model_name_or_path) + model="$2" + shift + ;; + --lora_model_path) + lora_args="--lora_model_path $2" + shift + ;; + -r|--retriever_type) + retriever_type="--retriever_type $2" + shift + ;; + --corpus_index_path) + corpus_index_path="--corpus_index_path $2" + shift + ;; + --prompt_structure) + prompt_structure="--prompt_structure $2" + shift + ;; + --top_k_retrieve) + top_k="--top_k_retrieve $2" + shift + ;; + *) + echo "error: unknown option \"${key}\"" 1>&2 + exit 1 + esac + shift +done + +CUDA_VISIBLE_DEVICES=0 \ + deepspeed examples/rag_evaluation.py \ + --answer_type medmcqa \ + --model_name_or_path gpt2 \ + --dataset_path data/MedQA-USMLE/validation \ + --deepspeed examples/ds_config.json \ + --inference_batch_size_per_device 1 \ + --metric accuracy + ${retriever_type} \ + ${corpus_index_path} \ + ${prompt_structure} \ + ${top_k} diff --git a/contrib/rag/run_rag_inference.sh b/contrib/rag/run_rag_inference.sh new file mode 100644 index 000000000..73373b1e7 --- /dev/null +++ b/contrib/rag/run_rag_inference.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# An interactive inference script without context history, i.e. the chatbot +# won't have conversation memory. + +model=gpt2 +lora_args="" +retriever_type="" +corpus_index_path="" +prompt_structure="" +top_k="" +while [[ $# -ge 1 ]]; do + key="$1" + case ${key} in + -m|--model_name_or_path) + model="$2" + shift + ;; + --lora_model_path) + lora_args="--lora_model_path $2" + shift + ;; + -r|--retriever_type) + retriever_type="--retriever_type $2" + shift + ;; + --corpus_index_path) + corpus_index_path="--corpus_index_path $2" + shift + ;; + --prompt_structure) + prompt_structure="--prompt_structure $2" + shift + ;; + --top_k_retrieve) + top_k="--top_k_retrieve $2" + shift + ;; + *) + echo "error: unknown option \"${key}\"" 1>&2 + exit 1 + esac + shift +done + +accelerate launch --config_file ../../configs/accelerator_singlegpu_config.yaml \ + rag_inference.py \ + --deepspeed ../../configs/ds_config_chatbot.json \ + --model_name_or_path ${model} \ + --use_accelerator True \ + --max_new_tokens 256 \ + --temperature 1.0 \ + ${lora_args} \ + ${retriever_type} \ + ${corpus_index_path} \ + ${prompt_structure} \ + ${top_k}