feat: Add kl eval (jbloomAus#124)
* add kl divergence to evals.py

* fix linter
tomMcGrath authored May 7, 2024
1 parent fc770b1 commit 2aa2ddd
Showing 1 changed file with 29 additions and 7 deletions.
36 changes: 29 additions & 7 deletions sae_lens/training/evals.py
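For context (this note and its notation are mine, not part of the commit): the new metric reported as metrics/kl_div measures how far the next-token distribution of the model with the SAE reconstruction spliced in drifts from the intact model's distribution. Per position, the added code computes

D_{\mathrm{KL}}\bigl(P_{\text{model}} \,\|\, P_{\text{recons}}\bigr) = \sum_{x} P_{\text{model}}(x)\,\bigl[\log P_{\text{model}}(x) - \log P_{\text{recons}}(x)\bigr]

where P_model and P_recons are the softmax distributions over the vocabulary; kl_div with reduction="batchmean" sums this over positions and vocabulary and divides the total by the batch size.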
@@ -39,6 +39,7 @@ def run_evals(
ntp_loss = losses_df["loss"].mean()
recons_loss = losses_df["recons_loss"].mean()
zero_abl_loss = losses_df["zero_abl_loss"].mean()
+ d_kl = losses_df["d_kl"].mean()

# get cache
_, cache = model.run_with_cache(
@@ -84,6 +85,8 @@ def run_evals(
f"metrics/ce_loss_without_sae{suffix}": ntp_loss,
f"metrics/ce_loss_with_sae{suffix}": recons_loss,
f"metrics/ce_loss_with_ablation{suffix}": zero_abl_loss,
+ # KL divergence against intact model
+ f"metrics/kl_div{suffix}": d_kl,
}

if wandb.run is not None:
@@ -104,7 +107,7 @@ def recons_loss_batched(
losses = []
for _ in range(n_batches):
batch_tokens = activation_store.get_batch_tokens()
- score, loss, recons_loss, zero_abl_loss = get_recons_loss(
+ score, loss, recons_loss, zero_abl_loss, d_kl = get_recons_loss(
sparse_autoencoder, model, batch_tokens
)
losses.append(
@@ -113,11 +116,13 @@
loss.mean().item(),
recons_loss.mean().item(),
zero_abl_loss.mean().item(),
+ d_kl.mean().item(),
)
)

losses = pd.DataFrame(
- losses, columns=cast(Any, ["score", "loss", "recons_loss", "zero_abl_loss"])
+ losses,
+ columns=cast(Any, ["score", "loss", "recons_loss", "zero_abl_loss", "d_kl"]),
)

return losses
@@ -130,10 +135,11 @@ def get_recons_loss(
batch_tokens: torch.Tensor,
):
hook_point = sparse_autoencoder.cfg.hook_point
- loss = model(
- batch_tokens, return_type="loss", **sparse_autoencoder.cfg.model_kwargs
+ model_outs = model(
+ batch_tokens, return_type="both", **sparse_autoencoder.cfg.model_kwargs
)
head_index = sparse_autoencoder.cfg.hook_point_head_index
+ loss = model_outs.loss

def standard_replacement_hook(activations: torch.Tensor, hook: Any):
activations = sparse_autoencoder.forward(activations).sae_out.to(
@@ -166,12 +172,13 @@ def single_head_replacement_hook(activations: torch.Tensor, hook: Any):
else:
replacement_hook = standard_replacement_hook

- recons_loss = model.run_with_hooks(
+ recons_outs = model.run_with_hooks(
batch_tokens,
return_type="loss",
return_type="both",
fwd_hooks=[(hook_point, partial(replacement_hook))],
**sparse_autoencoder.cfg.model_kwargs,
)
+ recons_loss = recons_outs.loss

zero_abl_loss = model.run_with_hooks(
batch_tokens,
@@ -185,7 +192,22 @@ def single_head_replacement_hook(activations: torch.Tensor, hook: Any):

score = (zero_abl_loss - recons_loss) / div_val

- return score, loss, recons_loss, zero_abl_loss
+ # KL divergence
+ model_logits = model_outs.logits # [batch, pos, d_vocab]
+ model_logprobs = torch.nn.functional.log_softmax(model_logits, dim=-1)
+ recons_logits = recons_outs.logits
+ recons_logprobs = torch.nn.functional.log_softmax(recons_logits, dim=-1)
+ # Note: PyTorch KL is backwards compared to the mathematical definition
+ # target distribution comes second, see
+ # https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html
+ d_kl = torch.nn.functional.kl_div(
+ recons_logprobs,
+ model_logprobs,
+ reduction="batchmean",
+ log_target=True, # for numerics
+ )
+
+ return score, loss, recons_loss, zero_abl_loss, d_kl


def zero_ablate_hook(activations: torch.Tensor, hook: Any):

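For readers who want to try the same measurement outside of run_evals, here is a minimal, self-contained sketch of the computation this commit adds. It assumes a TransformerLens-style HookedTransformer (as used by sae_lens) and some replacement_hook callable that splices the SAE reconstruction in at hook_point; the function name compute_kl_to_intact_model is illustrative, not part of the library.

import torch.nn.functional as F

def compute_kl_to_intact_model(model, batch_tokens, hook_point, replacement_hook):
    # Intact forward pass: log-probabilities of the unmodified model.
    model_logits = model(batch_tokens, return_type="logits")
    model_logprobs = F.log_softmax(model_logits, dim=-1)

    # Forward pass with the SAE reconstruction spliced in at hook_point.
    recons_logits = model.run_with_hooks(
        batch_tokens,
        return_type="logits",
        fwd_hooks=[(hook_point, replacement_hook)],
    )
    recons_logprobs = F.log_softmax(recons_logits, dim=-1)

    # F.kl_div(input, target) computes KL(target || input), so the intact
    # model's distribution goes second; log_target=True keeps the target in
    # log space for numerical stability.
    return F.kl_div(
        recons_logprobs,
        model_logprobs,
        reduction="batchmean",
        log_target=True,
    )

The argument order matters because KL divergence is asymmetric: with this ordering, positions where the intact model is confident but the patched model is not are penalised heavily, which is what the in-code note about PyTorch's kl_div being "backwards" relative to the usual D_KL(P || Q) notation refers to.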