diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 01b3f8316..8e6855a5d 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -1,8 +1,20 @@ import torch +import unittest from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase from torch.testing._internal.common_utils import run_tests +from torch.testing._internal import common_utils from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight from torchao.quantization.observer import PerRow, PerTensor +import torch.distributed as dist +from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, + NUM_DEVICES, +) +from torchao.quantization.quant_api import quantize_ +from torchao.dtypes import AffineQuantizedTensor +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): QUANT_METHOD_FN = staticmethod(int8_weight_only) @@ -16,17 +28,131 @@ class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): # Run only on H100 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): - class TestFloat8dqTensorAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): + class TestFloat8dqAffineQuantizedTensorParallel(DTensorTestBase): + """Basic test case for tensor subclasses + """ + COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] + TENSOR_SUBCLASS = AffineQuantizedTensor + QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) + QUANT_METHOD_KWARGS = {} + + @staticmethod + def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in column-wise fashion + """ + # Column-wise is wrt to A^T, so for A it is row-wise. + # Number of rows per rank + orig_weight = m.linear.weight + n_local_rows = orig_weight.size(0) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + + @staticmethod + def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in row-wise fashion + """ + # Row-wise is wrt to A^T, so for A it is column-wise. + # Number of rows per rank + orig_weight = m.linear.weight + n_local_cols = orig_weight.size(1) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + + def quantize(self, m: torch.nn.Module) -> torch.nn.Module: + """ + Quantize the model + """ + quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS)) + return m + + def _test_tp(self, dtype): + device = "cuda" + # To make sure different ranks create the same module + torch.manual_seed(5) + + class M(torch.nn.Module): + def __init__(self, in_features, out_features, **kwargs) -> None: + super().__init__(**kwargs) + self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + # Get rank and device + device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") + + # Original model + proj_up = M(1024, 2048).to(device).to(dtype) + proj_dn = M(2048, 1024).to(device).to(dtype) + example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) + y = proj_dn(proj_up(example_input)) + # Quantize the model + up_quant = self.quantize(proj_up) + dn_quant = self.quantize(proj_dn) + y_q = dn_quant(up_quant(example_input)) + + mesh = self.build_device_mesh() + mesh.device_type = "cuda" + + # Shard the models + up_dist = self.colwise_shard(up_quant, mesh) + dn_dist = self.rowwise_shard(dn_quant, mesh) + + # We need to turn inputs into DTensor form as well -- just a format change + input_dtensor = DTensor.from_local( + example_input, mesh, [Replicate()] + ) + + y_d = dn_dist(up_dist(input_dtensor)) + + if not TORCH_VERSION_AT_LEAST_2_5: + # Need torch 2.5 to support compiled tensor parallelism + return + + up_compiled = torch.compile(up_dist) + y_up = up_compiled(input_dtensor) + dn_compiled = torch.compile(dn_dist) + y_dn = dn_compiled(y_up) + + class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel): QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) QUANT_METHOD_KWARGS = {"granularity": PerTensor()} - copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqTensorAffineQuantizedTensorParallel, "fp8dqt_tp") + COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] -# Run only on H100 -if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): - class TestFloat8dqRowAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): + @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tp(self, dtype): + return self._test_tp(dtype) + + class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel): QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) QUANT_METHOD_KWARGS = {"granularity": PerRow()} - copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqRowAffineQuantizedTensorParallel, "fp8dqr_tp") + COMMON_DTYPES = [torch.bfloat16] + @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tp(self, dtype): + return self._test_tp(dtype) + + common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel) + common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel) if __name__ == "__main__": run_tests() diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 417412d18..256700ba8 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1062,9 +1062,12 @@ def __init__( def _apply_fn_to_data(self, fn): """ Applys a fn to all tensor components stored on this class""" - fn(self.float8_data) - fn(self.scale) - return self + return self.__class__( + fn(self.float8_data), + fn(self.scale), + self.transposed, + self._layout, + ) def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) @@ -1109,19 +1112,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if dim == 0: #TODO: scale replecation should be dependent on block size if self.scale.ndim == 1: - print("slice for dim 0, scale is 1") return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) ) else: - print("slice for dim 0, scale != 1") return return_and_correct_aliasing( func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout) ) elif dim == 1: - print("slice for dim 1") return return_and_correct_aliasing( - func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout) + func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(), self.scale, None, self._layout) ) else: raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") @@ -1653,10 +1653,6 @@ def _linear_fp8_act_fp8_weight_impl( # Preprocess data inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) - - print(f"out_shape: {out_shape}") - print(f"input_tensor: {input_tensor.shape}, weight_tensor: {weight_tensor.shape}") - print(f"inpt_data: {inpt_data.shape}, w_data: {w_data.shape}") print(f"out_shape: {out_shape}") @@ -1877,17 +1873,12 @@ def _(func, types, args, kwargs): end = self.shape[dim] shape = list(self.shape) shape[dim] = end - start - print(f"Shape: {self.shape} -> {shape}") - print(f"Block size: {self.block_size} -> {self.block_size}") - print(f"end: {end}, start: {start}") block_size = self.block_size assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}" # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure there is no overflow block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) - print(f"slice (Outer tensor shape): {self.shape} -> {new.shape}") - print(f"slice (Inner data shape): {self.tensor_impl.float8_data.shape} -> {new.tensor_impl.float8_data.shape}") return return_and_correct_aliasing(func, args, kwargs, new) # this is needed for DTensor.from_local() and for flattening tensor diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index 4c2825552..4c0f41b49 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -124,7 +124,6 @@ def _(func, types, args, kwargs): return func(bias, aqt, original_weight_tensor) else: # aten.mm.default - print('Args: ', args[0].shape, args[1].shape, type(args[0]), type(args[1])) assert args[0].shape[-1] == args[1].shape[0], ( f"need mat1 shape: {args[0].shape} final dim" f"to match mat2 shape: {args[1].shape} first dim" @@ -168,24 +167,17 @@ def _(func, types, args, kwargs): @implements(aten.slice.Tensor) def _(func, types, args, kwargs): - print('Input quant func: ', args[0].input_quant_func) - x = return_and_correct_aliasing( + return return_and_correct_aliasing( func, args, kwargs, LinearActivationQuantizedTensor( func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func) ) - print(f'Linear act Post slice: {x.original_weight_tensor.shape} {x.original_weight_tensor.tensor_impl.float8_data.shape}') - return x # this is needed for DTensor.from_local() and for flattening tensor @implements(aten.view.default) def _(func, types, args, kwargs): - print('Linear view args:', args[1:]) - print('Device: ', args[0].original_weight_tensor.device) - x= return_and_correct_aliasing( + return return_and_correct_aliasing( func, args, kwargs, LinearActivationQuantizedTensor(func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func) ) - print(f'Linear act Post view: {x.original_weight_tensor.shape} {x.original_weight_tensor.tensor_impl.float8_data.shape}') - return x to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 1e6e7aeb0..92fd6896f 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -228,7 +228,6 @@ class TorchAOTensorParallelTestCase(DTensorTestBase): """Basic test case for tensor subclasses """ COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] - TENSOR_SUBCLASS = AffineQuantizedTensor QUANT_METHOD_FN = staticmethod(int8_weight_only) QUANT_METHOD_KWARGS = {} @@ -301,14 +300,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: proj_up = M(1024, 2048).to(device).to(dtype) proj_dn = M(2048, 1024).to(device).to(dtype) example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) - print('Run y') y = proj_dn(proj_up(example_input)) - # Quantize the model up_quant = self.quantize(proj_up) dn_quant = self.quantize(proj_dn) y_q = dn_quant(up_quant(example_input)) - mesh = self.build_device_mesh() mesh.device_type = "cuda"