add n_samples property
Samoed committed Nov 3, 2024
1 parent 2711bd0 commit ce3edad
Showing 24 changed files with 1,310 additions and 857 deletions.
7 changes: 4 additions & 3 deletions docs/create_tasks_table.py
@@ -51,16 +51,17 @@ def task_to_markdown_row(task: mteb.AbsTask) -> str:
     domains = (
         "[" + ", ".join(task.metadata.domains) + "]" if task.metadata.domains else ""
     )
+    n_samples = task.metadata.n_samples
     dataset_statistics = round_floats_in_dict(task.metadata.descriptive_stats)
     name_w_reference += author_from_bibtex(task.metadata.bibtex_citation)

-    return f"| {name_w_reference} | {task.metadata.languages} | {task.metadata.type} | {task.metadata.category} | {domains} | {dataset_statistics} |"
+    return f"| {name_w_reference} | {task.metadata.languages} | {task.metadata.type} | {task.metadata.category} | {domains} | {n_samples} | {dataset_statistics} |"


 def create_tasks_table(tasks: list[mteb.AbsTask]) -> str:
     table = """
-| Name | Languages | Type | Category | Domains | Dataset statistics |
-|------|-----------|------|----------|---------|--------------------|
+| Name | Languages | Type | Category | Domains | # Samples | Dataset statistics |
+|------|-----------|------|----------|---------|-----------|--------------------|
 """
     for task in tasks:
         table += task_to_markdown_row(task) + "\n"
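As a quick illustration of the new column, here is a minimal sketch of rendering the updated table. It assumes the snippet is run from the `docs/` directory so that `create_tasks_table.py` is importable, and it uses `mteb.get_tasks()` to collect tasks; that collection step is an assumption, not part of this diff.

```python
# Minimal sketch (assumptions noted above): render the tasks table, which now
# includes the "# Samples" column populated from task.metadata.n_samples.
import mteb

from create_tasks_table import create_tasks_table  # docs/create_tasks_table.py

tasks = list(mteb.get_tasks())
print(create_tasks_table(tasks))
```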
3 changes: 3 additions & 0 deletions mteb/abstasks/AbsTaskInstructionRetrieval.py
@@ -223,6 +223,7 @@ class InstructionRetrievalDescriptiveStatistics(DescriptiveStatistics):
"""Descriptive statistics for Instruction Retrieval tasks
Attributes:
num_samples: Number of samples
num_queries: Number of queries
num_docs: Number of documents
total_symbols: Total number of symbols in the dataset
@@ -234,6 +235,7 @@ class InstructionRetrievalDescriptiveStatistics(DescriptiveStatistics):
         average_top_ranked_per_query: Average number of top ranked docs per query
     """

+    num_samples: int
     num_queries: int
     num_docs: int
     total_symbols: int
@@ -684,6 +686,7 @@ def _calculate_metrics_from_split(
             else 0
         )
         return InstructionRetrievalDescriptiveStatistics(
+            num_samples=len(queries) + len(corpus),
             num_docs=len(corpus),
             num_queries=len(queries),
             total_symbols=total_corpus_len
13 changes: 8 additions & 5 deletions mteb/abstasks/AbsTaskRetrieval.py
@@ -197,19 +197,21 @@ class RetrievalDescriptiveStatistics(DescriptiveStatistics):
"""Descriptive statistics for Retrieval
Attributes:
num_queries: number of samples in the dataset
num_samples: Number of samples in the dataset
num_queries: number of queries in the dataset
num_documents: Number of documents
total_symbols: Total number of symbols in the dataset
average_document_length: Average length of documents
average_query_length: Average length of queries
num_documents: Number of documents
average_relevant_docs_per_query: Average number of relevant documents per query
"""

num_samples: int
num_queries: int
num_documents: int
total_symbols: int
average_document_length: float
average_query_length: float
num_documents: int
average_relevant_docs_per_query: float


@@ -434,10 +436,11 @@ def _calculate_metrics_from_split(
         qrels_per_doc = num_qrels_non_zero / len(relevant_docs) if num_queries else 0
         return RetrievalDescriptiveStatistics(
             total_symbols=query_len + doc_len,
+            num_samples=num_documents + num_queries,
+            num_queries=num_queries,
+            num_documents=num_documents,
             average_document_length=doc_len / num_documents,
             average_query_length=query_len / num_queries,
-            num_documents=num_documents,
-            num_queries=num_queries,
             average_relevant_docs_per_query=qrels_per_doc,
         )

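For retrieval tasks, the new `num_samples` field is simply the document count plus the query count. A small illustrative check, using the figures from the AppsRetrieval statistics updated later in this commit:

```python
# Illustrative check of num_samples = num_documents + num_queries, using the
# values from mteb/descriptive_stats/Retrieval/AppsRetrieval.json below.
num_documents, num_queries = 8765, 3765
assert num_documents + num_queries == 12530  # the "num_samples" written for the test split
```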
18 changes: 16 additions & 2 deletions mteb/abstasks/TaskMetadata.py
@@ -225,8 +225,6 @@ class TaskMetadata(BaseModel):
"machine-translated and localized".
prompt: The prompt used for the task. Can be a string or a dictionary containing the query and passage prompts.
bibtex_citation: The BibTeX citation for the dataset. Should be an empty string if no citation is available.
n_samples: The number of samples in the dataset. This should only be for the splits evaluated on. For retrieval tasks, this should be the
number of query-document pairs.
"""

dataset: dict
@@ -394,13 +392,15 @@ def intext_citation(self, include_cite: bool = True) -> str:

     @property
     def descriptive_stats(self) -> dict[str, DescriptiveStatistics] | None:
+        """Return the descriptive statistics for the dataset."""
         if self.descriptive_stat_path.exists():
             with self.descriptive_stat_path.open("r") as f:
                 return json.load(f)
         return None

     @property
     def descriptive_stat_path(self) -> Path:
+        """Return the path to the descriptive statistics file."""
         descriptive_stat_base_dir = Path(__file__).parent.parent / "descriptive_stats"
         if not descriptive_stat_base_dir.exists():
             descriptive_stat_base_dir.mkdir()
@@ -409,5 +409,19 @@ def descriptive_stat_path(self) -> Path:
             task_type_dir.mkdir()
         return task_type_dir / f"{self.name}.json"

+    @property
+    def n_samples(self) -> dict[str, int] | None:
+        """Returns the number of samples in the dataset"""
+        stats = self.descriptive_stats
+        if not stats:
+            return None
+
+        n_samples = {}
+        for subset, subset_value in stats.items():
+            if subset == "hf_subset_descriptive_stats":
+                continue
+            n_samples[subset] = subset_value["num_samples"]
+        return n_samples
+
     def __hash__(self) -> int:
         return hash(self.model_dump_json())
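A minimal usage sketch of the new `n_samples` property. It assumes the task is loaded via `mteb.get_task` and that cached descriptive statistics exist for it; the task name and the printed value are illustrative, taken from the AppsRetrieval stats updated in this commit.

```python
# Minimal sketch: per-split sample counts via the new property.
# Returns None when no cached descriptive stats are available.
import mteb

task = mteb.get_task("AppsRetrieval")
print(task.metadata.n_samples)  # e.g. {"test": 12530}
```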
(file header not captured in this view; the hunk below is from an InstructionRetrieval descriptive-stats JSON)
@@ -1,5 +1,6 @@
 {
     "test": {
+        "num_samples": 19919,
         "num_docs": 19899,
         "num_queries": 20,
         "total_symbols": 44450333,
5 changes: 3 additions & 2 deletions mteb/descriptive_stats/Retrieval/AppsRetrieval.json
@@ -1,10 +1,11 @@
 {
     "test": {
         "total_symbols": 2245.837090504686,
+        "num_samples": 12530,
+        "num_queries": 3765,
+        "num_documents": 8765,
         "average_document_length": 0.0657169048317138,
         "average_query_length": 0.4435135244766838,
-        "num_documents": 8765,
-        "num_queries": 3765,
         "average_relevant_docs_per_query": 1.0
     }
 }
