From d44ddab2682fb883e46e0781e0c520e31f6a9252 Mon Sep 17 00:00:00 2001
From: IvanKobzarev
Date: Fri, 18 Oct 2024 05:26:33 -0700
Subject: [PATCH] [AFQ] Optimize tensor_flatten for runtime

[ghstack-poisoned]
---
 torchao/dtypes/affine_quantized_tensor.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index d14a5dd17..3eb4025d9 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -232,7 +232,11 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
         raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")
 
     def __tensor_flatten__(self):
-        return ["tensor_impl"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
+        # This is used at runtime to unwrap AffineQuantizedTensor activations.
+        # AffineQuantizedTensor has a __torch_function__ override:
+        # every getattr goes through it, which is up to 10x slower than
+        # default attribute access.
+        with torch._C.DisableTorchFunctionSubclass():
+            return ["tensor_impl"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
 
     @classmethod
     def __tensor_unflatten__(
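
Note (illustration, not part of the patch): a minimal sketch of the effect this
change relies on. Any torch.Tensor subclass that overrides __torch_function__
pays that dispatch cost on every attribute access (e.g. .shape), and
torch._C.DisableTorchFunctionSubclass skips the override so access takes the
default fast path. The Wrapper class and bench helper below are illustrative
stand-ins, not torchao code.

    import time

    import torch


    class Wrapper(torch.Tensor):
        # With this override present, attribute access on instances
        # (e.g. .shape, .dtype) is routed through __torch_function__.
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            # Delegate to the default implementation.
            return super().__torch_function__(func, types, args, kwargs)


    t = torch.randn(4, 4).as_subclass(Wrapper)


    def bench(label, fn, n=100_000):
        start = time.perf_counter()
        for _ in range(n):
            fn()
        print(f"{label}: {time.perf_counter() - start:.3f}s")


    # Each .shape access dispatches through __torch_function__ here.
    bench("getattr via __torch_function__", lambda: t.shape)


    def fast_getattr():
        # Inside this context the override is skipped, so attribute
        # access takes the plain, much faster path.
        with torch._C.DisableTorchFunctionSubclass():
            return t.shape


    bench("getattr with DisableTorchFunctionSubclass", fast_getattr)

Since __tensor_flatten__ reads six attributes per call and is on the hot path
when unwrapping activations, wrapping its body in this context avoids the
per-getattr override cost without changing the returned values.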