Skip to content

Commit

Permalink
Float8 tensor parallel for aqt_dynamic_act_weight
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva committed Oct 15, 2024
1 parent 6314d88 commit 26d84b5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
4 changes: 4 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,10 @@ def _linear_fp8_act_fp8_weight_impl(

# Preprocess data
inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

print(f"out_shape: {out_shape}")
print(f"input_tensor: {input_tensor.shape}, weight_tensor: {weight_tensor.shape}")
print(f"inpt_data: {inpt_data.shape}, w_data: {w_data.shape}")


print(f"out_shape: {out_shape}")
Expand Down
14 changes: 4 additions & 10 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
print('colwise shard Shapeof m.linear.weight : ', m.linear.weight.shape)
return m

@staticmethod
Expand All @@ -265,15 +264,11 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
rank = mesh.get_local_rank()
local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
# Construct DTensor from local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
print(f'dtensor shape: {dtensor.shape}')
print(f'Other dtensor values: {local_shard.original_weight_tensor.tensor_impl.float8_data.shape}, {mesh}, {[Shard(1)]}')
dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
print('rowwise shard Shapeof m.linear.weight : ', m.linear.weight.shape)

return m

def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
Expand Down Expand Up @@ -306,15 +301,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
proj_up = M(1024, 2048).to(device).to(dtype)
proj_dn = M(2048, 1024).to(device).to(dtype)
example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
print('Run y')
y = proj_dn(proj_up(example_input))
print('Run before y')

# Quantize the model
up_quant = self.quantize(proj_up)
dn_quant = self.quantize(proj_dn)
print('Run before y_q')
y_q = dn_quant(up_quant(example_input))
print('Executed y_q')


mesh = self.build_device_mesh()
mesh.device_type = "cuda"

Expand Down

0 comments on commit 26d84b5

Please sign in to comment.