From 52775076f833e7301fbac29eb2f4938461f1c94c Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Thu, 10 Oct 2024 17:47:27 -0700 Subject: [PATCH] [Tp Test] Fixe the placment of the device tensor (#1054) --- torchao/testing/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 7fa4ba4a6..5211065e1 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -309,6 +309,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: y_q = dn_quant(up_quant(example_input)) mesh = self.build_device_mesh() + mesh.device_type = "cuda" + # Shard the models up_dist = self.colwise_shard(up_quant, mesh) dn_dist = self.rowwise_shard(dn_quant, mesh)