Skip to content

Commit

Permalink
reenable disabled pt2e test (#7059)
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Liu <[email protected]>
  • Loading branch information
lsy323 and Siyuan Liu authored May 15, 2024
1 parent cbb9e21 commit df0d147
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions test/stablehlo/test_pt2e_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def test_resnet18(self):
save_torch_module_as_tf_saved_model(m, args, tmp_path)
self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))

@unittest.skip
def test_resnet18_per_channel(self):
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
Expand All @@ -127,8 +126,10 @@ def test_resnet18_per_channel(self):
quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(is_per_channel=True))
m = prepare_pt2e(m, quantizer)

# Step 3: Quantize the model
# Step 3: Run through example inputs, otherwise per-channel
# quant may have scalar scale/zero_point
m(*args)
# Step 4: Quantize the model
m = convert_pt2e(m, fold_quantize=False)

# Trace with torch/xla and export stablehlo
Expand Down

0 comments on commit df0d147

Please sign in to comment.