
Commit

Added support for PerRow granularity
jainapurva committed Oct 15, 2024
1 parent 26d84b5 commit 3b5c8f9
Showing 4 changed files with 141 additions and 36 deletions.
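For context, this change wires the PerRow granularity option through the float8 dynamic-quantization tensor-parallel path, alongside the existing PerTensor option. A minimal usage sketch, not taken from this commit and assuming an H100-class GPU (the toy module and shapes below are made up for illustration):

import torch
from torchao.quantization import float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow, PerTensor
from torchao.quantization.quant_api import quantize_

# Hypothetical toy module; the name and shapes are chosen only for illustration.
model = torch.nn.Sequential(torch.nn.Linear(1024, 2048, bias=False)).to("cuda", torch.bfloat16)

# Existing behavior: one float8 scale for the whole weight.
# quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))

# What the new tests exercise: one float8 scale per weight row
# (the per-row test class below runs in bfloat16 only).
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))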
138 changes: 132 additions & 6 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -1,8 +1,20 @@
import torch
import unittest
from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal import common_utils
from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow, PerTensor
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
NUM_DEVICES,
)
from torchao.quantization.quant_api import quantize_
from torchao.dtypes import AffineQuantizedTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
QUANT_METHOD_FN = staticmethod(int8_weight_only)
@@ -16,17 +28,131 @@ class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
class TestFloat8dqTensorAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
class TestFloat8dqAffineQuantizedTensorParallel(DTensorTestBase):
"""Basic test case for tensor subclasses
"""
COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
TENSOR_SUBCLASS = AffineQuantizedTensor
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_KWARGS = {}

@staticmethod
def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
"""
Shard linear layer of the model in column-wise fashion
"""
# Column-wise is with respect to A^T, so for A it is row-wise.
# Number of rows per rank
orig_weight = m.linear.weight
n_local_rows = orig_weight.size(0) // mesh.size()
rank = mesh.get_local_rank()
local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
# Construct DTensor from local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
return m

@staticmethod
def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
"""
Shard linear layer of the model in row-wise fashion
"""
# Row-wise is with respect to A^T, so for A it is column-wise.
# Number of columns per rank
orig_weight = m.linear.weight
n_local_cols = orig_weight.size(1) // mesh.size()
rank = mesh.get_local_rank()
local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
# Construct DTensor from local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
return m

def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
"""
Quantize the model
"""
quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
return m

def _test_tp(self, dtype):
device = "cuda"
# To make sure different ranks create the same module
torch.manual_seed(5)

class M(torch.nn.Module):
def __init__(self, in_features, out_features, **kwargs) -> None:
super().__init__(**kwargs)
self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)

# Get rank and device
device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")

# Original model
proj_up = M(1024, 2048).to(device).to(dtype)
proj_dn = M(2048, 1024).to(device).to(dtype)
example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
y = proj_dn(proj_up(example_input))
# Quantize the model
up_quant = self.quantize(proj_up)
dn_quant = self.quantize(proj_dn)
y_q = dn_quant(up_quant(example_input))

mesh = self.build_device_mesh()
mesh.device_type = "cuda"

# Shard the models
up_dist = self.colwise_shard(up_quant, mesh)
dn_dist = self.rowwise_shard(dn_quant, mesh)

# We need to turn inputs into DTensor form as well -- just a format change
input_dtensor = DTensor.from_local(
example_input, mesh, [Replicate()]
)

y_d = dn_dist(up_dist(input_dtensor))

if not TORCH_VERSION_AT_LEAST_2_5:
# Need torch 2.5 to support compiled tensor parallelism
return

up_compiled = torch.compile(up_dist)
y_up = up_compiled(input_dtensor)
dn_compiled = torch.compile(dn_dist)
y_dn = dn_compiled(y_up)

class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqTensorAffineQuantizedTensorParallel, "fp8dqt_tp")
COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
class TestFloat8dqRowAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
@common_utils.parametrize("dtype", COMMON_DTYPES)
@with_comms
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tp(self, dtype):
return self._test_tp(dtype)

class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_KWARGS = {"granularity": PerRow()}
copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqRowAffineQuantizedTensorParallel, "fp8dqr_tp")
COMMON_DTYPES = [torch.bfloat16]

@common_utils.parametrize("dtype", COMMON_DTYPES)
@with_comms
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tp(self, dtype):
return self._test_tp(dtype)

common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel)
if __name__ == "__main__":
run_tests()
23 changes: 7 additions & 16 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -1062,9 +1062,12 @@ def __init__(

def _apply_fn_to_data(self, fn):
""" Applys a fn to all tensor components stored on this class"""
fn(self.float8_data)
fn(self.scale)
return self
return self.__class__(
fn(self.float8_data),
fn(self.scale),
self.transposed,
self._layout,
)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -1109,19 +1112,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
if dim == 0:
#TODO: scale replication should be dependent on block size
if self.scale.ndim == 1:
print("slice for dim 0, scale is 1")
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
)
else:
print("slice for dim 0, scale != 1")
return return_and_correct_aliasing(
func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
)
elif dim == 1:
print("slice for dim 1")
return return_and_correct_aliasing(
func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(), self.scale, None, self._layout)
)
else:
raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported")
@@ -1653,10 +1653,6 @@ def _linear_fp8_act_fp8_weight_impl(

# Preprocess data
inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

print(f"out_shape: {out_shape}")
print(f"input_tensor: {input_tensor.shape}, weight_tensor: {weight_tensor.shape}")
print(f"inpt_data: {inpt_data.shape}, w_data: {w_data.shape}")


print(f"out_shape: {out_shape}")
@@ -1877,17 +1873,12 @@ def _(func, types, args, kwargs):
end = self.shape[dim]
shape = list(self.shape)
shape[dim] = end - start
print(f"Shape: {self.shape} -> {shape}")
print(f"Block size: {self.block_size} -> {self.block_size}")
print(f"end: {end}, start: {start}")
block_size = self.block_size
assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}"
# with slice, some shape dimension might be smaller than block_size dimension, so
# we need to make sure there is no overflow
block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1]))
new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())
print(f"slice (Outer tensor shape): {self.shape} -> {new.shape}")
print(f"slice (Inner data shape): {self.tensor_impl.float8_data.shape} -> {new.tensor_impl.float8_data.shape}")
return return_and_correct_aliasing(func, args, kwargs, new)

# this is needed for DTensor.from_local() and for flattening tensor
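As a side note on the slice handling above: when an AffineQuantizedTensor is sliced, the block size is clamped so that no block dimension exceeds the sliced shape. A standalone sketch of that rule follows; the helper name is hypothetical and not part of torchao.

# Hypothetical helper mirroring the clamping done in the slice override above:
# after a slice, no block dimension may exceed the corresponding tensor dimension.
def clamp_block_size(shape, block_size):
    assert len(block_size) == 2, "slice currently only handles 2D block_size"
    return (min(shape[0], block_size[0]), min(shape[1], block_size[1]))

# A per-row float8 weight of shape (2048, 1024) uses block_size (1, 1024); slicing the
# rows down to (1024, 1024) keeps (1, 1024). A per-tensor weight with block_size
# (2048, 1024) is clamped to (1024, 1024) by the same rule.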
12 changes: 2 additions & 10 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -124,7 +124,6 @@ def _(func, types, args, kwargs):
return func(bias, aqt, original_weight_tensor)
else:
# aten.mm.default
print('Args: ', args[0].shape, args[1].shape, type(args[0]), type(args[1]))
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
@@ -168,24 +167,17 @@ def _(func, types, args, kwargs):

@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
print('Input quant func: ', args[0].input_quant_func)
x = return_and_correct_aliasing(
return return_and_correct_aliasing(
func, args, kwargs, LinearActivationQuantizedTensor(
func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func)
)
print(f'Linear act Post slice: {x.original_weight_tensor.shape} {x.original_weight_tensor.tensor_impl.float8_data.shape}')
return x

# this is needed for DTensor.from_local() and for flattening tensor
@implements(aten.view.default)
def _(func, types, args, kwargs):
print('Linear view args:', args[1:])
print('Device: ', args[0].original_weight_tensor.device)
x= return_and_correct_aliasing(
return return_and_correct_aliasing(
func, args, kwargs, LinearActivationQuantizedTensor(func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func)
)
print(f'Linear act Post view: {x.original_weight_tensor.shape} {x.original_weight_tensor.tensor_impl.float8_data.shape}')
return x

to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float

4 changes: 0 additions & 4 deletions torchao/testing/utils.py
@@ -228,7 +228,6 @@ class TorchAOTensorParallelTestCase(DTensorTestBase):
"""Basic test case for tensor subclasses
"""
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

TENSOR_SUBCLASS = AffineQuantizedTensor
QUANT_METHOD_FN = staticmethod(int8_weight_only)
QUANT_METHOD_KWARGS = {}
@@ -301,14 +300,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
proj_up = M(1024, 2048).to(device).to(dtype)
proj_dn = M(2048, 1024).to(device).to(dtype)
example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
print('Run y')
y = proj_dn(proj_up(example_input))

# Quantize the model
up_quant = self.quantize(proj_up)
dn_quant = self.quantize(proj_dn)
y_q = dn_quant(up_quant(example_input))

mesh = self.build_device_mesh()
mesh.device_type = "cuda"

