[AFQ] Optimize tensor_flatten for runtime
ghstack-source-id: f028ae1d0afb3aaea9a4afebe29b114de80b5d9e
Pull Request resolved: #1114
IvanKobzarev committed Oct 18, 2024
1 parent 3475aed commit 36d7beb
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -232,7 +232,11 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
         raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")
 
     def __tensor_flatten__(self):
-        return ["tensor_impl"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
+        # This is used at runtime to unwrap AffineQuantizedTensor activations.
+        # AffineQuantizedTensor has a __torch_function__ override:
+        # every getattr goes through it, which is up to 10x slower than default attribute access.
+        with torch._C.DisableTorchFunctionSubclass():
+            return ["tensor_impl"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
 
     @classmethod
     def __tensor_unflatten__(
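
For context on why the context manager helps, below is a minimal sketch (not part of this commit; `CountingTensor` is a hypothetical subclass invented for illustration). It shows that even plain attribute access such as `.shape` on a `torch.Tensor` subclass is routed through its `__torch_function__` override, and that `torch._C.DisableTorchFunctionSubclass()` skips that dispatch entirely:

    import torch

    # Hypothetical minimal subclass (not part of torchao) that counts how
    # often __torch_function__ fires, to make the dispatch overhead visible.
    class CountingTensor(torch.Tensor):
        calls = 0

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            cls.calls += 1
            # Standard fallback pattern: run the default behavior without
            # re-entering this override.
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **(kwargs or {}))

    x = torch.randn(2, 3).as_subclass(CountingTensor)
    CountingTensor.calls = 0

    _ = x.shape                   # routed through __torch_function__
    print(CountingTensor.calls)   # > 0: even .shape paid the dispatch cost

    with torch._C.DisableTorchFunctionSubclass():
        _ = x.shape               # same access with subclass dispatch disabled
    print(CountingTensor.calls)   # unchanged: the override was skipped

This is the same reasoning as the committed change: `__tensor_flatten__` reads several attributes in a row, so wrapping it in the context manager avoids paying the `__torch_function__` dispatch cost on each one.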
