Add support for Transcoders #7

Closed · wants to merge 7 commits
130 changes: 103 additions & 27 deletions sae_training/activations_store.py
@@ -6,6 +6,7 @@
from tqdm import tqdm
from transformer_lens import HookedTransformer

#import gc

class ActivationsStore:
"""
@@ -56,7 +57,12 @@ def __init__(

if create_dataloader:
# fill half the buffer now, so we can mix it with a freshly generated buffer later
self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
self.storage_buffer_out = None
if self.cfg.is_transcoder:
# transcoders keep one buffer for input activations and one for output activations
self.storage_buffer, self.storage_buffer_out = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
else:
self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
self.dataloader = self.get_data_loader()

def get_batch_tokens(self):
@@ -145,25 +151,42 @@ def get_batch_tokens(self):
return batch_tokens[:batch_size]

def get_activations(self, batch_tokens, get_loss=False):
# TODO: get transcoders working with head indices
assert not (self.cfg.is_transcoder and self.cfg.hook_point_head_index is not None), "transcoders do not yet support hook_point_head_index"
act_name = self.cfg.hook_point
hook_point_layer = self.cfg.hook_point_layer
if self.cfg.hook_point_head_index is not None:
activations = self.model.run_with_cache(
batch_tokens, names_filter=act_name, stop_at_layer=hook_point_layer + 1
)[1][act_name][:, :, self.cfg.hook_point_head_index]
else:
activations = self.model.run_with_cache(
batch_tokens, names_filter=act_name, stop_at_layer=hook_point_layer + 1
)[1][act_name]
if not self.cfg.is_transcoder:
activations = self.model.run_with_cache(
batch_tokens,
names_filter=act_name,
stop_at_layer=hook_point_layer+1
)[1][act_name]
else:
cache = self.model.run_with_cache(
batch_tokens,
names_filter=[act_name, self.cfg.out_hook_point],
stop_at_layer=self.cfg.out_hook_point_layer+1
)[1]
activations = (cache[act_name], cache[self.cfg.out_hook_point])

return activations

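For orientation, a sketch (not code from this PR) of how a caller now has to unpack `get_activations` when `is_transcoder` is set; the hook names in the comment are illustrative placeholders, and `store` is assumed to be an ActivationsStore built from a transcoder config.

```python
# Sketch: unpacking the new (input, output) return value in transcoder mode.
batch_tokens = store.get_batch_tokens()
acts = store.get_activations(batch_tokens)
if store.cfg.is_transcoder:
    # e.g. input read at blocks.8.ln2.hook_normalized, target read at blocks.8.hook_mlp_out
    acts_in, acts_out = acts
    assert acts_in.shape[:2] == acts_out.shape[:2]  # same (batch, position) grid
else:
    acts_in = acts  # single tensor, as before
```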
def get_buffer(self, n_batches_in_buffer):
#gc.collect()
#torch.cuda.empty_cache()

context_size = self.cfg.context_size
batch_size = self.cfg.store_batch_size
d_in = self.cfg.d_in
total_size = batch_size * n_batches_in_buffer

# TODO: get transcoders working with cached activations
assert not (self.cfg.is_transcoder and self.cfg.use_cached_activations), "transcoders do not yet support cached activations"
if self.cfg.use_cached_activations:
# Load the activations from disk
buffer_size = total_size * context_size
@@ -230,21 +253,46 @@ def get_buffer(self, n_batches_in_buffer):
device=self.cfg.device,
)

new_buffer_out = None
if self.cfg.is_transcoder:
new_buffer_out = torch.zeros(
(total_size, context_size, self.cfg.d_out),
dtype=self.cfg.dtype,
device=self.cfg.device,
)

# Insert activations directly into pre-allocated buffer
# pbar = tqdm(total=n_batches_in_buffer, desc="Filling buffer")
for refill_batch_idx_start in refill_iterator:
refill_batch_tokens = self.get_batch_tokens()
refill_activations = self.get_activations(refill_batch_tokens)
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size
] = refill_activations
if not self.cfg.is_transcoder:
refill_activations = self.get_activations(refill_batch_tokens)
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size
] = refill_activations
else:
refill_activations_in, refill_activations_out = self.get_activations(refill_batch_tokens)
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size
] = refill_activations_in

new_buffer_out[
refill_batch_idx_start : refill_batch_idx_start + batch_size
] = refill_activations_out
# pbar.update(1)

new_buffer = new_buffer.reshape(-1, d_in)
new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]
randperm = torch.randperm(new_buffer.shape[0])
new_buffer = new_buffer[randperm]

if self.cfg.is_transcoder:
new_buffer_out = new_buffer_out.reshape(-1, self.cfg.d_out)
new_buffer_out = new_buffer_out[randperm]

return new_buffer
if self.cfg.is_transcoder:
return new_buffer, new_buffer_out
else:
return new_buffer
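Similarly, `get_buffer` now returns a pair in transcoder mode. A minimal sketch of the invariant the shared `randperm` establishes (again illustrative, not from the PR): both flattened buffers are shuffled with the same permutation, so paired rows stay aligned.

```python
# Sketch: the two buffers stay row-aligned after the shared shuffle.
buf_in, buf_out = store.get_buffer(store.cfg.n_batches_in_buffer // 2)
assert buf_in.shape == (buf_in.shape[0], store.cfg.d_in)
assert buf_out.shape == (buf_in.shape[0], store.cfg.d_out)
# buf_in[i] and buf_out[i] were taken from the same token position of the same sequence.
```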

def get_data_loader(
self,
@@ -258,25 +306,53 @@ def get_data_loader(
"""

batch_size = self.cfg.train_batch_size

if self.cfg.is_transcoder:
# transcoder path: mirrors the non-transcoder logic below, but shuffles the input and output buffers with the same permutation so they stay row-aligned
new_buffer, new_buffer_out = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
mixing_buffer = torch.cat(
[new_buffer,
self.storage_buffer]
)
mixing_buffer_out = torch.cat(
[new_buffer_out,
self.storage_buffer_out]
)

# 1. create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
[self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer]
)

mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

# 2. put 50 % in storage
self.storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]

# 3. put other 50 % in a dataloader
dataloader = iter(
DataLoader(
mixing_buffer[mixing_buffer.shape[0] // 2 :],
batch_size=batch_size,
shuffle=True,
assert(mixing_buffer.shape[0] == mixing_buffer_out.shape[0])
randperm = torch.randperm(mixing_buffer.shape[0])
mixing_buffer = mixing_buffer[randperm]
mixing_buffer_out = mixing_buffer_out[randperm]

self.storage_buffer = mixing_buffer[:mixing_buffer.shape[0]//2]
self.storage_buffer_out = mixing_buffer_out[:mixing_buffer_out.shape[0]//2]

# concatenate the input and output activations along dim=1 so a single DataLoader serves both; a consumer can split each batch back into [:, :d_in] and [:, d_in:]
"""stacked_buffers = torch.stack([
mixing_buffer[mixing_buffer.shape[0]//2:],
mixing_buffer_out[mixing_buffer.shape[0]//2:]
], dim=1)"""
catted_buffers = torch.cat([
mixing_buffer[mixing_buffer.shape[0]//2:],
mixing_buffer_out[mixing_buffer.shape[0]//2:]
], dim=1)

#dataloader = iter(DataLoader(stacked_buffers, batch_size=batch_size, shuffle=True))
dataloader = iter(DataLoader(catted_buffers, batch_size=batch_size, shuffle=True))
else:
# 1. create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
[self.get_buffer(self.cfg.n_batches_in_buffer // 2),
self.storage_buffer]
)
)

mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

# 2. put 50 % in storage
self.storage_buffer = mixing_buffer[:mixing_buffer.shape[0]//2]

# 3. put other 50 % in a dataloader
dataloader = iter(DataLoader(mixing_buffer[mixing_buffer.shape[0]//2:], batch_size=batch_size, shuffle=True))

return dataloader
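Since the transcoder branch concatenates input and output activations along dim=1 before building the DataLoader, whatever consumes this dataloader has to slice each batch apart again. A hedged sketch of that consumer side (the loop and names are illustrative, not taken from this PR):

```python
# Hypothetical training-side unpacking of a batch produced by get_data_loader().
d_in = cfg.d_in
for batch in dataloader:
    if cfg.is_transcoder:
        acts_in = batch[:, :d_in]   # (train_batch_size, d_in)
        acts_out = batch[:, d_in:]  # (train_batch_size, d_out)
    else:
        acts_in = batch
    # ... encode acts_in and train the decoder against acts_out ...
```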

14 changes: 13 additions & 1 deletion sae_training/config.py
@@ -39,6 +39,12 @@ class RunnerConfig(ABC):
seed: int = 42
dtype: torch.dtype = torch.float32

# transcoder stuff
is_transcoder: bool = False
out_hook_point: Optional[str] = None
out_hook_point_layer: Optional[int] = None
d_out: Optional[int] = None

def __post_init__(self):
# Autofill cached_activations_path unless the user overrode it
if self.cached_activations_path is None:
@@ -65,6 +71,12 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
lr_warm_up_steps: int = 500
train_batch_size: int = 4096

# transcoder stuff
is_transcoder: bool = False
out_hook_point: Optional[str] = None
out_hook_point_layer: Optional[int] = None
d_out: Optional[int] = None

# Resampling protocol args
use_ghost_grads: bool = False  # TODO: change the default to True at some point
feature_sampling_window: int = 2000
Expand Down Expand Up @@ -164,4 +176,4 @@ def __post_init__(self):
# this is a dummy property in this context; only here to avoid class compatibility headaches
raise ValueError(
"use_cached_activations should be False when running cache_activations_runner"
)
)
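To show how the new fields fit together, here is a hedged sketch of a transcoder configuration; the hook names, layer, and dimensions are illustrative and other fields are omitted, so treat this as an example of intent rather than a recipe from the PR.

```python
# Hypothetical transcoder config (hook names and sizes are illustrative; other fields omitted).
cfg = LanguageModelSAERunnerConfig(
    hook_point="blocks.8.ln2.hook_normalized",  # where input activations are read
    hook_point_layer=8,
    d_in=768,
    is_transcoder=True,
    out_hook_point="blocks.8.hook_mlp_out",     # where target activations are read
    out_hook_point_layer=8,
    d_out=768,
)
```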