From c0ac20ddc775248e4d4b141ab26aca7867727290 Mon Sep 17 00:00:00 2001
From: Hao Zhang
Date: Fri, 5 Jul 2024 10:46:22 +0800
Subject: [PATCH] Fix cuda mpi issue

---
 tetragono/tetragono/utility.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/tetragono/tetragono/utility.py b/tetragono/tetragono/utility.py
index e85bcb7e..8a848d2e 100644
--- a/tetragono/tetragono/utility.py
+++ b/tetragono/tetragono/utility.py
@@ -64,7 +64,13 @@ def allgather_array(array):
 
 
 def allreduce_buffer(buffer):
-    mpi_comm.Allreduce(MPI.IN_PLACE, buffer)
+    import torch
+    if isinstance(buffer, torch.Tensor) and buffer.device.type == "cuda":
+        cbuffer = buffer.cpu()
+        mpi_comm.Allreduce(MPI.IN_PLACE, cbuffer)
+        buffer.copy_(cbuffer)
+    else:
+        mpi_comm.Allreduce(MPI.IN_PLACE, buffer)
 
 
 def allreduce_iterator_buffer(iterator):
@@ -88,7 +94,13 @@ def bcast_number(number, *, root=0, dtype=np.float64):
 
 
 def bcast_buffer(buffer, root=0):
-    mpi_comm.Bcast(buffer, root=root)
+    import torch
+    if isinstance(buffer, torch.Tensor) and buffer.device.type == "cuda":
+        cbuffer = buffer.cpu()
+        mpi_comm.Bcast(cbuffer, root=root)
+        buffer.copy_(cbuffer)
+    else:
+        mpi_comm.Bcast(buffer, root=root)
 
 
 def bcast_iterator_buffer(iterator, root=0):
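
Note (not part of the patch): the fix stages CUDA tensors through a host copy so mpi4py is never handed a device pointer, which avoids failures on MPI builds that are not CUDA-aware, while plain NumPy buffers keep taking the original path. Below is a minimal usage sketch under the assumptions that the helpers are importable as tetragono.utility.allreduce_buffer and bcast_buffer (inferred from the file path) and that the script is launched under MPI, e.g. mpirun -n 4 python example.py; the buffer contents and the print are purely illustrative.

    import numpy as np
    import torch

    from tetragono.utility import allreduce_buffer, bcast_buffer

    # Use a CUDA tensor when a GPU is present (exercises the new staging path),
    # otherwise fall back to a plain NumPy buffer (the unchanged code path).
    if torch.cuda.is_available():
        buf = torch.ones(8, device="cuda")
    else:
        buf = np.ones(8)

    allreduce_buffer(buf)      # in-place sum of the buffer across all ranks
    bcast_buffer(buf, root=0)  # broadcast rank 0's buffer to every rank
    print(float(buf[0]))       # equals the number of ranks on every process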