Skip to content

Commit

Permalink
Fix cuda mpi issue
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Sep 11, 2024
1 parent 2e6c101 commit 78a5159
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions tetragono/tetragono/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,13 @@ def allgather_array(array):


def allreduce_buffer(buffer):
    """
    All-reduce ``buffer`` in place across every rank of ``mpi_comm``.

    The reduction uses ``Allreduce`` with ``MPI.IN_PLACE`` and the default
    reduction operation, so on return ``buffer`` holds the reduced result.

    A torch tensor living on a CUDA device is not handed to MPI directly:
    it is staged through a host copy, the host copy is reduced, and the
    result is copied back into the original device tensor.
    NOTE(review): presumably needed because the MPI build in use cannot
    operate on device memory (commit title: "Fix cuda mpi issue") - confirm.

    Parameters
    ----------
    buffer
        The buffer to reduce; it is modified in place. Either a torch tensor
        (CPU or CUDA) or any other object accepted by ``mpi_comm.Allreduce``.
    """
    # Imported locally so that torch is only loaded when this helper is used.
    import torch

    if isinstance(buffer, torch.Tensor) and buffer.device.type == "cuda":
        # Stage through host memory, reduce there, then write the result back
        # into the caller's device tensor so the in-place contract still holds.
        host_buffer = buffer.cpu()
        mpi_comm.Allreduce(MPI.IN_PLACE, host_buffer)
        buffer.copy_(host_buffer)
    else:
        mpi_comm.Allreduce(MPI.IN_PLACE, buffer)


def allreduce_iterator_buffer(iterator):
Expand All @@ -88,7 +94,13 @@ def bcast_number(number, *, root=0, dtype=np.float64):


def bcast_buffer(buffer, root=0):
    """
    Broadcast the content of ``buffer`` from rank ``root`` to every rank, in place.

    A torch tensor living on a CUDA device is not handed to MPI directly:
    it is staged through a host copy, the host copy takes part in the
    broadcast, and the result is copied back into the original device tensor.
    NOTE(review): presumably needed because the MPI build in use cannot
    operate on device memory (commit title: "Fix cuda mpi issue") - confirm.

    Parameters
    ----------
    buffer
        The buffer to broadcast; it is modified in place. Either a torch
        tensor (CPU or CUDA) or any other object accepted by ``mpi_comm.Bcast``.
    root : int, default=0
        The rank passed to ``mpi_comm.Bcast`` as the broadcast root.
    """
    # Imported locally so that torch is only loaded when this helper is used.
    import torch

    if isinstance(buffer, torch.Tensor) and buffer.device.type == "cuda":
        # Every rank stages through host memory: the root to provide the data,
        # the others to have a host buffer of matching size to receive into.
        host_buffer = buffer.cpu()
        mpi_comm.Bcast(host_buffer, root=root)
        buffer.copy_(host_buffer)
    else:
        mpi_comm.Bcast(buffer, root=root)


def bcast_iterator_buffer(iterator, root=0):
Expand Down

0 comments on commit 78a5159

Please sign in to comment.