From e78207d086cb3372dc805cbb4c87b694749cd905 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Wed, 17 Apr 2024 16:43:20 -0700 Subject: [PATCH] Add API for getting Neuronpedia feature --- sae_lens/analysis/neuronpedia_integration.py | 22 ++++++++++++++++++- .../analysis/test_neuronpedia_integration.py | 11 ++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 tests/unit/analysis/test_neuronpedia_integration.py diff --git a/sae_lens/analysis/neuronpedia_integration.py b/sae_lens/analysis/neuronpedia_integration.py index 84f628ad..37d80e95 100644 --- a/sae_lens/analysis/neuronpedia_integration.py +++ b/sae_lens/analysis/neuronpedia_integration.py @@ -1,6 +1,7 @@ import json import urllib.parse import webbrowser +import requests def get_neuronpedia_quick_list( @@ -14,10 +15,29 @@ def get_neuronpedia_quick_list( name = urllib.parse.quote(name) url = url + "?name=" + name list_feature = [ - {"modelId": model, "layer": f"{layer}-{dataset}", "index": str(feature)} + { + "modelId": model, + "layer": f"{layer}-{dataset}", + "index": str(feature), + } for feature in features ] url = url + "&features=" + urllib.parse.quote(json.dumps(list_feature)) webbrowser.open(url) return url + + +def get_neuronpedia_feature( + feature: int, + layer: int, + model: str = "gpt2-small", + dataset: str = "res-jb", +): + url = "https://neuronpedia.org/api/feature/" + url = url + f"{model}/{layer}-{dataset}/{feature}" + + result = requests.get(url).json() + result["index"] = int(result["index"]) + + return result diff --git a/tests/unit/analysis/test_neuronpedia_integration.py b/tests/unit/analysis/test_neuronpedia_integration.py new file mode 100644 index 00000000..6b6c9c28 --- /dev/null +++ b/tests/unit/analysis/test_neuronpedia_integration.py @@ -0,0 +1,11 @@ +from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_feature + + +def test_get_neuronpedia_feature(): + result = get_neuronpedia_feature( + feature=0, layer=0, model="gpt2-small", dataset="res-jb" + ) + + assert result["modelId"] == "gpt2-small" + assert result["layer"] == "0-res-jb" + assert result["index"] == 0