diff --git a/test/test_utils.py b/test/test_utils.py
index 5a43691ba..15f6bfe3c 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,6 +1,9 @@
 import unittest
 from unittest.mock import patch
+
+import torch
 from torchao.utils import torch_version_at_least
+from torchao.utils import TorchAOBaseTensor
 
 class TestTorchVersionAtLeast(unittest.TestCase):
     def test_torch_version_at_least(self):
@@ -20,7 +23,24 @@ def test_torch_version_at_least(self):
             result = torch_version_at_least(compare_version)
             self.assertEqual(result, expected_result, f"Failed for torch.__version__={torch_version}, comparing with {compare_version}")
-            print(f"{torch_version}: {result}")
+
+
+class TestTorchAOBaseTensor(unittest.TestCase):
+
+    def test_print_arg_types(self):
+        class MyTensor(TorchAOBaseTensor):
+            def __new__(cls, data):
+                shape = data.shape
+                return torch.Tensor._make_wrapper_subclass(cls, shape)  # type: ignore[attr-defined]
+
+            def __init__(self, data):
+                self.data = data
+
+
+        l = torch.nn.Linear(10, 10)
+        with self.assertRaisesRegex(NotImplementedError, "arg_types"):
+            l.weight = torch.nn.Parameter(MyTensor(l.weight))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchao/utils.py b/torchao/utils.py
index c0b79fa71..4b5409e65 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -388,7 +388,9 @@ class MyTensor(torch.Tensor):
            func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
             return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
 
-        raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")
+        arg_types = tuple(type(arg) for arg in args)
+        kwarg_types = {k: type(arg) for k, arg in kwargs.items()}
+        raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}")
 
 def _register_layout_cls(cls: Callable, layout_type_class: Callable):
     """Helper function for layout registrations, this is used to implement
diff --git a/tutorials/developer_api_guide/print_op_and_shapes.py b/tutorials/developer_api_guide/print_op_and_shapes.py
index 0be26fd94..9f485a4fc 100644
--- a/tutorials/developer_api_guide/print_op_and_shapes.py
+++ b/tutorials/developer_api_guide/print_op_and_shapes.py
@@ -1,5 +1,6 @@
 import torch
 
+PRINT_ARGS = False
 linear_shapes = []
 from torch.overrides import TorchFunctionMode
 class TorchFunctionLoggingMode(TorchFunctionMode):
@@ -16,11 +17,28 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
             M, K = flattened_input_tensor.shape[0], flattened_input_tensor.shape[1]
             assert K == weight_tensor.shape[1]
             N = weight_tensor.shape[0]
-            print(f"TORCH_FUNC={str(func)} (M, K, N):", M, K, N)
+            print(f"TORCH_FUNC {func=} (M, K, N):", M, K, N)
             linear_shapes.append((M, K, N))
         else:
             arg_shape = args[0].shape if len(args) > 0 and isinstance(args[0], torch.Tensor) else None
-            print(f"TORCH_FUNC={str(func)} args[0] shape:", arg_shape)
+            if PRINT_ARGS:
+                print(f"TORCH_FUNC {func=}, {types=}, {args=}, {kwargs=}, args[0] shape: {arg_shape}")
+            else:
+                print(f"TORCH_FUNC {func=}, {types=}, args[0] shape: {arg_shape}")
         return func(*args, **kwargs)
+
+
+from torch.utils._python_dispatch import TorchDispatchMode
+class TorchDispatchLoggingMode(TorchDispatchMode):
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        arg_shape = args[0].shape if len(args) > 0 and isinstance(args[0], torch.Tensor) else None
+        if PRINT_ARGS:
+            print(f"ATEN_FUNC {func=}, {types=}, {args=}, {kwargs=}, args[0] shape: {arg_shape}")
+        else:
+            print(f"ATEN_FUNC {func=}, {types=}, args[0] shape: {arg_shape}")
+        return func(*args, **kwargs)
+
 
 # NOTE: Modify this with your own model
@@ -33,3 +51,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
 
 print()
 print("all linear shapes (M, K, N):", linear_shapes)
+
+# check all the aten ops that are called in the model
+# with TorchDispatchLoggingMode():
+#     m(*example_inputs)