From ca3a179e52c950e7d7ae11e611470300d2c738a9 Mon Sep 17 00:00:00 2001
From: muyo8692
Date: Sun, 17 Nov 2024 13:30:40 +0900
Subject: [PATCH] update test setting

---
 tests/unit/training/test_coefficient_scheduler.py |  2 +-
 tests/unit/training/test_gated_sae.py             |  6 +++---
 tests/unit/training/test_jumprelu_sae.py          |  3 +--
 tests/unit/training/test_sae_training.py          | 10 +++++-----
 tests/unit/training/test_training_sae.py          |  2 +-
 5 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/tests/unit/training/test_coefficient_scheduler.py b/tests/unit/training/test_coefficient_scheduler.py
index a4be95e0..828c37d3 100644
--- a/tests/unit/training/test_coefficient_scheduler.py
+++ b/tests/unit/training/test_coefficient_scheduler.py
@@ -30,7 +30,7 @@ def test_coefficient_scheduler_initialization_no_warmup():
     cfg = build_sae_cfg(
         sparsity_coefficient=5,
         training_tokens=100 * 4,  # train batch size (so 100 steps)
-        coefficient_warm_up_steps=10,
+        coefficient_warm_up_steps=0,
     )
 
     coefficient_scheduler = CoefficientScheduler(
diff --git a/tests/unit/training/test_gated_sae.py b/tests/unit/training/test_gated_sae.py
index 03dd7183..36543ef9 100644
--- a/tests/unit/training/test_gated_sae.py
+++ b/tests/unit/training/test_gated_sae.py
@@ -68,7 +68,7 @@ def test_gated_sae_loss():
 
     train_step_output = sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=sae.cfg.sparsity_coefficient,
     )
 
     assert train_step_output.sae_out.shape == (batch_size, d_in)
@@ -77,7 +77,7 @@ def test_gated_sae_loss():
     sae_in_centered = x - sae.b_dec
     via_gate_feature_magnitudes = torch.relu(sae_in_centered @ sae.W_enc + sae.b_gate)
     preactivation_l1_loss = (
-        sae.cfg.l1_coefficient * torch.sum(via_gate_feature_magnitudes, dim=-1).mean()
+        sae.cfg.sparsity_coefficient * torch.sum(via_gate_feature_magnitudes, dim=-1).mean()
     )
 
     via_gate_reconstruction = (
@@ -122,7 +122,7 @@ def test_gated_sae_training_forward_pass():
     x = torch.randn(batch_size, d_in)
     train_step_output = sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=sae.cfg.sparsity_coefficient,
     )
 
     assert train_step_output.sae_out.shape == (batch_size, d_in)
diff --git a/tests/unit/training/test_jumprelu_sae.py b/tests/unit/training/test_jumprelu_sae.py
index 6940bf89..d8fc6f59 100644
--- a/tests/unit/training/test_jumprelu_sae.py
+++ b/tests/unit/training/test_jumprelu_sae.py
@@ -40,8 +40,7 @@ def test_jumprelu_sae_training_forward_pass():
     x = torch.randn(batch_size, d_in)
     train_step_output = sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=sae.cfg.l1_coefficient,
-        current_l0_lambda=sae.cfg.l0_lambda,
+        current_sparsity_coefficient=sae.cfg.sparsity_coefficient,
     )
 
     assert train_step_output.sae_out.shape == (batch_size, d_in)
diff --git a/tests/unit/training/test_sae_training.py b/tests/unit/training/test_sae_training.py
index 70bef679..bade2675 100644
--- a/tests/unit/training/test_sae_training.py
+++ b/tests/unit/training/test_sae_training.py
@@ -149,7 +149,7 @@ def test_sae_forward(training_sae: TrainingSAE):
     x = torch.randn(batch_size, d_in)
     train_step_output = training_sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=training_sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=training_sae.cfg.sparsity_coefficient,
     )
 
     assert train_step_output.sae_out.shape == (batch_size, d_in)
@@ -188,7 +188,7 @@ def test_sae_forward(training_sae: TrainingSAE):
     )
     assert (
         pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3)  # type: ignore
-        == training_sae.cfg.l1_coefficient * expected_l1_loss.detach().float()
+        == training_sae.cfg.sparsity_coefficient * expected_l1_loss.detach().float()
     )
 
 
@@ -206,7 +206,7 @@ def test_sae_forward_with_mse_loss_norm(
     x = torch.randn(batch_size, d_in)
     train_step_output = training_sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=training_sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=training_sae.cfg.sparsity_coefficient,
     )
 
     assert train_step_output.sae_out.shape == (batch_size, d_in)
@@ -248,7 +248,7 @@ def test_sae_forward_with_mse_loss_norm(
     )
     assert (
         pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3)  # type: ignore
-        == training_sae.cfg.l1_coefficient * expected_l1_loss.detach().float()
+        == training_sae.cfg.sparsity_coefficient * expected_l1_loss.detach().float()
     )
 
 
@@ -262,7 +262,7 @@ def test_SparseAutoencoder_forward_ghost_grad_loss_non_zero(
     x = torch.randn(batch_size, d_in)
     train_step_output = training_sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=training_sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=training_sae.cfg.sparsity_coefficient,
         dead_neuron_mask=torch.ones_like(
             training_sae.b_enc
         ).bool(),  # all neurons are dead.
diff --git a/tests/unit/training/test_training_sae.py b/tests/unit/training/test_training_sae.py
index 16a84594..a61ebd50 100644
--- a/tests/unit/training/test_training_sae.py
+++ b/tests/unit/training/test_training_sae.py
@@ -22,7 +22,7 @@ def test_TrainingSAE_training_forward_pass_can_scale_sparsity_penalty_by_decoder
     x = torch.randn(32, 3)
     train_step_output = training_sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=2.0,
+        current_sparsity_coefficient=2.0,
     )
     feature_acts = train_step_output.feature_acts
     decoder_norm = training_sae.W_dec.norm(dim=1)