Skip to content

Commit

Permalink
Unskip test_choose_qparams_token_asym in 2.6 (#1004)
Browse files Browse the repository at this point in the history
* Unskip `test_choose_qparams_token_asym` in 2.6

Summary:
Fixes: #970

The test was broken by a recent refactor in pytorch: pytorch/pytorch#136807

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* fix
  • Loading branch information
jerryzh168 authored Oct 4, 2024
1 parent 0cb91ea commit 5dd0132
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,16 @@ def test_choose_qparams_group_sym_no_clipping_err(self):
self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower")
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or higher")
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_token_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (1, 10)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
if TORCH_VERSION_AT_LEAST_2_6:
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float64, zero_point_dtype=torch.int64)
else:
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype)
scale_ref = scale_ref.squeeze()
Expand Down

0 comments on commit 5dd0132

Please sign in to comment.