From df0d147e10625b00a2d6687c0214fb0fecd111b4 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 14 May 2024 21:28:02 -0700 Subject: [PATCH] reenable disabled pt2e test (#7059) Co-authored-by: Siyuan Liu --- test/stablehlo/test_pt2e_qdq.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py index c9e5f04af65..ea3f6cac067 100644 --- a/test/stablehlo/test_pt2e_qdq.py +++ b/test/stablehlo/test_pt2e_qdq.py @@ -116,7 +116,6 @@ def test_resnet18(self): save_torch_module_as_tf_saved_model(m, args, tmp_path) self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb'))) - @unittest.skip def test_resnet18_per_channel(self): # Step 1: export resnet18 args = (torch.randn(1, 3, 224, 224),) @@ -127,8 +126,10 @@ def test_resnet18_per_channel(self): quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_per_channel=True)) m = prepare_pt2e(m, quantizer) - - # Step 3: Quantize the model + # Step 3: Run through example inputs, otherwise per-channel + # quant may have scalar scale/zero_point + m(*args) + # Step 4: Quantize the model m = convert_pt2e(m, fold_quantize=False) # Trace with torch/xla and export stablehlo