From 149b9505bc86ea6ab0105b017aff9c4dc8412d80 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Tue, 5 Mar 2024 14:48:51 +0800 Subject: [PATCH] Minor refine --- tests/test_optim/test_optimizer/test_optimizer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py index d388c95f12..113aacd6c8 100644 --- a/tests/test_optim/test_optimizer/test_optimizer.py +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -592,12 +592,16 @@ def test_default_optimizer_constructor_no_grad(self): dwconv_decay_mult=0.1, dcn_offset_lr_mult=0.1) - for param in self.model.parameters(): - param.requires_grad = False + self.model.conv1.requires_grad_(False) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) - with self.assertRaises(ValueError): - optim_constructor(self.model) + optim_wrapper = optim_constructor(self.model) + + all_params = [] + for pg in optim_wrapper.param_groups: + all_params.extend(map(id, pg['params'])) + self.assertNotIn(id(self.model.conv1.weight), all_params) + self.assertIn(id(self.model.conv2.weight), all_params) def test_default_optimizer_constructor_bypass_duplicate(self): # paramwise_cfg with bypass_duplicate option