Skip to content

Commit

Permalink
[Tp Test] Fixe the placment of the device tensor (#1054)
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg authored Oct 11, 2024
1 parent ec860a1 commit 5277507
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
y_q = dn_quant(up_quant(example_input))

mesh = self.build_device_mesh()
mesh.device_type = "cuda"

# Shard the models
up_dist = self.colwise_shard(up_quant, mesh)
dn_dist = self.rowwise_shard(dn_quant, mesh)
Expand Down

0 comments on commit 5277507

Please sign in to comment.