Skip to content

Commit

Permalink
Added caching to leaderboard
Browse files Browse the repository at this point in the history
  • Loading branch information
x-tabdeveloping committed Oct 30, 2024
1 parent 9a80941 commit cf32f23
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 1 deletion.
9 changes: 8 additions & 1 deletion mteb/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Sequence
from dataclasses import dataclass
from functools import lru_cache
from typing import Annotated

from pydantic import AnyUrl, BeforeValidator, TypeAdapter
Expand Down Expand Up @@ -57,9 +58,15 @@ def __getitem__(self, index):
def load_results(
self, base_results: None | BenchmarkResults = None
) -> BenchmarkResults:
if not hasattr(self, "results_cache"):
self.results_cache = {}
if base_results in self.results_cache:
return self.results_cache[base_results]
if base_results is None:
base_results = load_results()
return base_results.select_tasks(self.tasks)
results = base_results.select_tasks(self.tasks)
self.results_cache[base_results] = results
return results


MTEB_EN = Benchmark(
Expand Down
18 changes: 18 additions & 0 deletions mteb/caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import json
from functools import lru_cache
from typing import Callable


def json_cache(function: Callable):
"""Caching decorator that can deal with anything json serializable"""
cached_results = {}

def wrapper(*args, **kwargs):
key = json.dumps({"__args": args, **kwargs})
if key in cached_results:
return cached_results[key]
result = function(*args, **kwargs)
cached_results[key] = result
return result

return wrapper
3 changes: 3 additions & 0 deletions mteb/leaderboard/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from gradio_rangeslider import RangeSlider

import mteb
from mteb.caching import json_cache
from mteb.leaderboard.table import scores_to_tables


Expand Down Expand Up @@ -209,6 +210,7 @@ def update_tables(scores):
domain_select,
],
)
@json_cache
def on_select_benchmark(benchmark_name):
benchmark = mteb.get_benchmark(benchmark_name)
benchmark_results = benchmark.load_results(base_results=all_results)
Expand All @@ -222,6 +224,7 @@ def on_select_benchmark(benchmark_name):
inputs=[benchmark_select, lang_select, type_select, domain_select],
outputs=[task_select],
)
@json_cache
def update_task_list(benchmark_name, languages, task_types, domains):
benchmark = mteb.get_benchmark(benchmark_name)
benchmark_results = benchmark.load_results(base_results=all_results)
Expand Down
3 changes: 3 additions & 0 deletions mteb/load_results/benchmark_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def __repr__(self) -> str:
n_models = len(self.model_results)
return f"BenchmarkResults(model_results=[...](#{n_models}))"

def __hash__(self) -> int:
return id(self)

def filter_tasks(
self,
task_names: list[str] | None = None,
Expand Down

0 comments on commit cf32f23

Please sign in to comment.