[Bug Report] #337

Open
Yoon-Jeong-ho opened this issue Oct 18, 2024 · 6 comments

Comments

@Yoon-Jeong-ho

If you are submitting a bug report, please fill in the following details and use the tag [bug].

Describe the bug
I encountered a RuntimeError during training with sae_lens. The error appears to be caused by a device mismatch between the indexed tensor and the indices used to index it (CPU vs. CUDA, per the message below).

Error message

 3 Training SAE:   0%|                                                                                                                           | 0/2048000000 [00:00<?, ?it/s]/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/activations_store.py:283: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
 4   yield torch.tensor(ng factor:   0%|                                                                                                               | 0/1000 [00:00<?, ?it/s]
 5 Estimating norm scaling factor: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [08:32<00:00,  1.95it/s]
 6 89900| MSE Loss 15.375 | L1 51.068:  18%|█████████████▊                                                               | 368230400/2048000000 [13:27:20<66:45:33, 6989.32it/s]Traceback (most recent call last):
 7   File "/data_x/aa007878/SAE/sae_training.py", line 95, in <module>
 8     sparse_autoencoder = SAETrainingRunner(cfg).run()
 9                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
10   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 106, in run
11     sae = self.run_trainer_with_interruption_handling(trainer)
12           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
13   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 150, in run_trainer_with_interruption_handling
14     sae = trainer.fit()
15           ^^^^^^^^^^^^^
16   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 176, in fit
17     self._run_and_log_evals()
18   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
19     return func(*args, **kwargs)
20            ^^^^^^^^^^^^^^^^^^^^^
21   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 333, in _run_and_log_evals
22     eval_metrics = run_evals(
23                    ^^^^^^^^^^
24   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
25     return func(*args, **kwargs)
26            ^^^^^^^^^^^^^^^^^^^^^
27   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 105, in run_evals
28     metrics |= get_sparsity_and_variance_metrics(
29                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
30   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 289, in get_sparsity_and_variance_metrics
31     flattened_sae_input = flattened_sae_input[flattened_mask]
32                           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
33 RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:2)
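
For context, the failure in the last frame is a device mismatch on boolean-mask indexing: a tensor on one CUDA device is indexed with a mask that lives on a different device. A minimal sketch that reproduces it outside sae_lens (assumes at least two visible GPUs; the tensor names are illustrative only):

import torch

# Minimal reproduction of the device-mismatch indexing error (assumes >= 2 GPUs).
x = torch.randn(8, 4, device="cuda:0")                    # indexed tensor on one GPU
mask = torch.ones(8, dtype=torch.bool, device="cuda:1")   # boolean mask on another GPU

try:
    _ = x[mask]  # RuntimeError: indices should be either on cpu or on the same device ...
except RuntimeError as err:
    print(err)

_ = x[mask.to(x.device)]  # moving the mask onto the indexed tensor's device works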

Code example

import os
from setproctitle import setproctitle

setproctitle("aa007878")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 4"
gpu_num = 4
import torch

from huggingface_hub import login


# Log in with a HuggingFace API token



if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
    
    
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
scale_factor = 1
total_training_steps = int(500000 * scale_factor)  # probably we should do more
batch_size = int(4096 / scale_factor)
total_training_tokens = total_training_steps * batch_size

context_size = 1024
latent_size = 8
layer = 10
l1_coefficient = 0.05

lr_warm_up_steps = 200
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="meta-llama/Llama-3.2-1B",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name=f"blocks.{layer}.hook_resid_pre",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=layer,  # the layer we hook into.
    d_in=2048,  # the width of the residual stream at that hook point.
    dataset_path=f"yoonLM/llama3.2_org_1b_tokenizingdata_{context_size}", 
    is_dataset_tokenized=True,
    streaming=False,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=latent_size,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,  # We won't apply the decoder weights to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-6,  # the lower the better; the tutorial default was higher only to speed things up.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup; there may be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=l1_coefficient,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=context_size,  # controls the length of the prompts we feed to the model. Larger is better but slower.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # about 2 billion tokens here, so this will take a while.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using them.
    dead_feature_threshold=1e-8,  # would affect resampling or ghost grads if we were using them.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project=f"sae_LLaMa3.2_1B_{context_size}_{latent_size}_{l1_coefficient}",
    wandb_log_frequency=300,
    eval_every_n_wandb_logs=300,
    model_from_pretrained_kwargs={"n_devices": gpu_num},
    # Misc
    device=device,
    seed=42,
    n_checkpoints=5,
    checkpoint_path=f"checkpoints_LLama3.2_1B_{context_size}_{latent_size}_{l1_coefficient}",
    dtype="float32"
)
# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()   


from sae_lens import upload_saes_to_huggingface

layer_sae_path = f"layer_{layer}_sae"
sparse_autoencoder.save_model(layer_sae_path)

saes_dict = {
    f"blocks.{layer}.hook_resid_pre": layer_sae_path, # values can be an SAE object
}

upload_saes_to_huggingface(
    saes_dict,
    # change this to your own huggingface username and repo
    hf_repo_id=f"yoonLM/sae_llama3.2org_1B_{context_size}_{latent_size}_l1_{l1_coefficient}",
)

System Info
Python : 3.11.9
CUDA : 12.4
GPU : NVIDIA RTX A6000
PyTorch : 2.0.1
Ubuntu : 20.04.1 LTS
sae-lens : 3.22.2
torch : 2.4.1
transformer-lens : 2.7.0

Checklist

- [x] I have checked that there is no similar issue in the repo (required)
@chanind
Collaborator

chanind commented Oct 18, 2024

I can't reproduce this on my local machine, but I also don't have multiple GPUs. Does this only happen when using multiple GPUs?

@Yoon-Jeong-ho
Author

Yes. I didn't encounter this error with a single GPU, but it occurs when using multiple GPUs with a larger context size and latent size (and therefore higher GPU memory usage).

@Yoon-Jeong-ho
Author

The same error also occurred previously when training a sparse autoencoder of the same size.

self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
 3 Training SAE:   0%|                                                                                                                           | 0/2048000000 [00:00<?, ?it/s]/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/activations_store.py:283: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
 4   yield torch.tensor(ng factor:   0%|                                                                                                               | 0/1000 [00:00<?, ?it/s]
 5 Estimating norm scaling factor: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [08:33<00:00,  1.95it/s]
 6 249900| MSE Loss 7.922 | L1 72.222:  50%|█████████████████████████████████████▉                                       | 1023590400/2048000000 [37:11:49<40:19:52, 7055.51it/s]Traceback (most recent call last):
 7   File "/data_x/aa007878/SAE/sae_training.py", line 95, in <module>
 8     sparse_autoencoder = SAETrainingRunner(cfg).run()
 9                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
10   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 106, in run
11     sae = self.run_trainer_with_interruption_handling(trainer)
12           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
13   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 150, in run_trainer_with_interruption_handling
14     sae = trainer.fit()
15           ^^^^^^^^^^^^^
16   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 176, in fit
17     self._run_and_log_evals()
18   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
19     return func(*args, **kwargs)
20            ^^^^^^^^^^^^^^^^^^^^^
21   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 333, in _run_and_log_evals
22     eval_metrics = run_evals(
23                    ^^^^^^^^^^
24   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
25     return func(*args, **kwargs)
26            ^^^^^^^^^^^^^^^^^^^^^
27   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 105, in run_evals
28     metrics |= get_sparsity_and_variance_metrics(
29                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
30   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 289, in get_sparsity_and_variance_metrics
31     flattened_sae_input = flattened_sae_input[flattened_mask]
32                           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
33 RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)

@Yoon-Jeong-ho
Author

I tried to lower the learning rate under the same conditions, but the same error occurred in the same place.

self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
 3 Estimating norm scaling factor: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [08:41<00:00,  1.92it/s]
 4 89900| MSE Loss 141.638 | L1 62.357:  18%|████████████████████████▎                                                                                                          | 368230400/2048000000 [13:34:58<67:16:47, 6935.25it/s]Traceback (most recent call last):
 5   File "/data_x/aa007878/SAE/sae_training.py", line 95, in <module>
 6     sparse_autoencoder = SAETrainingRunner(cfg).run()
 7                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 8   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 106, in run
 9     sae = self.run_trainer_with_interruption_handling(trainer)
10           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
11   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 150, in run_trainer_with_interruption_handling
12     sae = trainer.fit()
13           ^^^^^^^^^^^^^
14   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 176, in fit
15     self._run_and_log_evals()
16   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
17     return func(*args, **kwargs)
18            ^^^^^^^^^^^^^^^^^^^^^
19   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 333, in _run_and_log_evals
20     eval_metrics = run_evals(
21                    ^^^^^^^^^^
22   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
23     return func(*args, **kwargs)
24            ^^^^^^^^^^^^^^^^^^^^^
25   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 105, in run_evals
26     metrics |= get_sparsity_and_variance_metrics(
27                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
28   File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 289, in get_sparsity_and_variance_metrics
29     flattened_sae_input = flattened_sae_input[flattened_mask]
30                           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
31 RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:2)
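
For anyone hitting this before an upstream fix lands, a local workaround would be to move the mask onto the same device as the tensor it indexes right before the masking step. A minimal sketch of the pattern (variable names mirror the traceback; this is not a patch to the actual sae_lens source):

import torch

# Sketch of the workaround pattern only; not the real evals.py code.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
flattened_sae_input = torch.randn(16, 2048, device=device)
flattened_mask = torch.ones(16, dtype=torch.bool)  # may live on a different device

# Align the mask's device with the indexed tensor's device before indexing.
flattened_sae_input = flattened_sae_input[flattened_mask.to(flattened_sae_input.device)]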

@callummcdougall
Contributor

Update here: I believe the root cause of this issue might be last week's PR that added get_sae_config, which introduced a bug where SAEs wouldn't be loaded onto the specified device. I also had a device-related bug, and making this change has fixed things for me. See link to my PR.
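
For readers following along, the kind of fix described above amounts to making sure a freshly constructed or loaded SAE actually ends up on the device its config asks for. A toy sketch of that pattern (ToySAE and build_sae are illustrative names, not the sae_lens API or the actual PR diff):

import torch
import torch.nn as nn

# Toy module standing in for an SAE; the real class lives in sae_lens.
class ToySAE(nn.Module):
    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_in, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_in))

def build_sae(d_in: int, d_sae: int, device: str) -> ToySAE:
    sae = ToySAE(d_in, d_sae)
    # Without this explicit move, the module stays on the default (CPU) device,
    # which is the kind of mismatch the comment above describes.
    return sae.to(device)

sae = build_sae(2048, 2048 * 8, "cuda:0" if torch.cuda.is_available() else "cpu")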

@chanind
Collaborator

chanind commented Oct 24, 2024

@Yoon-Jeong-ho Is this fixed in the most recent version of SAELens (4.0.9)? Thanks for the fix @callummcdougall!
