diff --git a/chunked_pooling/chunked_eval_tasks.py b/chunked_pooling/chunked_eval_tasks.py
index 23dbcbf..7417fa7 100644
--- a/chunked_pooling/chunked_eval_tasks.py
+++ b/chunked_pooling/chunked_eval_tasks.py
@@ -454,3 +454,69 @@ def load_data(self, **kwargs):
             self.relevant_docs[split] = qrels
 
         self.data_loaded = True
+
+
+class TriviaQAChunked(AbsTaskChunkedRetrieval):
+
+    _EVAL_SPLIT = ["test"]
+
+    metadata = TaskMetadata(
+        name='TriviaQAChunked',
+        description=('Retrieval dataset derived from TriviaQA for chunked evaluation.'),
+        reference='https://nlp.cs.washington.edu/triviaqa',
+        dataset={
+            'path': 'mandarjoshi/trivia_qa',
+            'revision': '0f7faf33a3908546c6fd5b73a660e0f8ff173c2f',
+            'name': 'rc',
+        },
+        type='Retrieval',
+        category='s2p',
+        eval_splits=['test'],
+        eval_langs=['eng-Latn'],
+        main_score='ndcg_at_10',
+        date=None,
+        form=None,
+        domains=None,
+        task_subtypes=None,
+        license=None,
+        socioeconomic_status=None,
+        annotations_creators=None,
+        dialect=None,
+        text_creation=None,
+        bibtex_citation=None,
+        n_samples=None,
+        avg_character_length=None,
+    )
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def load_data(self, **kwargs):
+        if self.data_loaded:
+            return
+        self.corpus = {}
+        self.queries = {}
+        self.relevant_docs = {}
+
+        data = datasets.load_dataset(**self.metadata_dict["dataset"])
+        for split in self._EVAL_SPLIT:
+            corpus = {}
+            queries = {}
+            relevant_docs = {}
+            for i, row in enumerate(data[split]):
+                if len(row['entity_pages']['wiki_context']) < 1:
+                    continue
+                queries[f'q{i}'] = row['question']
+                relevant_docs[f'q{i}'] = dict()
+                for j, (title, content) in enumerate(
+                    zip(
+                        row['entity_pages']['title'],
+                        row['entity_pages']['wiki_context'],
+                    )
+                ):
+                    corpus[f'c{i}-{j}'] = {'title': title, 'text': content}
+                    relevant_docs[f'q{i}'][f'c{i}-{j}'] = 1
+            self.corpus[split] = corpus
+            self.queries[split] = queries
+            self.relevant_docs[split] = relevant_docs
+        self.data_loaded = True
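
As a quick sanity check for reviewers, here is a minimal smoke-test sketch.
It assumes the module path chunked_pooling.chunked_eval_tasks from this
repository and that the task constructor needs no extra arguments; it only
exercises load_data(), so no embedding model is involved. Note that the 'rc'
configuration of TriviaQA is a sizable download.

    from chunked_pooling.chunked_eval_tasks import TriviaQAChunked

    task = TriviaQAChunked()
    task.load_data()

    split = task._EVAL_SPLIT[0]  # 'test'
    print(len(task.queries[split]), 'queries')
    print(len(task.corpus[split]), 'evidence-page documents')

    # Each question maps to all of its Wikipedia evidence pages, each one
    # labeled relevant with score 1, exactly as load_data() builds above.
    qid = next(iter(task.relevant_docs[split]))
    print(qid, task.queries[split][qid])
    print(task.relevant_docs[split][qid])

Because every evidence page paired with a question is scored 1, the qrels are
binary; ndcg_at_10 then effectively measures whether chunked retrieval ranks a
question's own source pages near the top of the corpus.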