From 5dd01329e8706e67acff88129a7f7a4a9caac71c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 3 Oct 2024 22:50:43 -0700 Subject: [PATCH] Unskip `test_choose_qparams_token_asym` in 2.6 (#1004) * Unskip `test_choose_qparams_token_asym` in 2.6 Summary: Fixes: https://github.com/pytorch/ao/issues/970 The test was broken by a recent refactor in pytorch: https://github.com/pytorch/pytorch/pull/136807 Test Plan: CI Reviewers: Subscribers: Tasks: Tags: * fix --- test/quantization/test_quant_primitives.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index fa66867f3..4e0663eb8 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -202,14 +202,16 @@ def test_choose_qparams_group_sym_no_clipping_err(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or higher") @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_token_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (1, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + if TORCH_VERSION_AT_LEAST_2_6: + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float64, zero_point_dtype=torch.int64) + else: + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype) scale_ref = scale_ref.squeeze()