add profiler hints in paralloader (#8244)
JackCaoG authored Oct 9, 2024
1 parent 07d0823 commit 8ec552d
Showing 1 changed file with 7 additions and 3 deletions.
torch_xla/distributed/parallel_loader.py (10 changes: 7 additions & 3 deletions)
@@ -7,6 +7,7 @@
 import torch_xla.utils.keyd_queue as kq
 import torch_xla.utils.utils as xu
 import torch_xla.core.xla_model as xm
+import torch_xla.debug.profiler as xp
 
 
 class PerDeviceQueue(object):
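
Note: torch_xla.debug.profiler (imported above as xp) provides xp.Trace, a context manager that records the wrapped region as a named span in the XLA profiler timeline. As a rough sketch of the pattern (the function and span name below are made up for illustration and are not part of this commit):

    import torch_xla.debug.profiler as xp

    def load_one_batch(data_iter):
      # Everything executed inside this block is reported as a span named
      # "my_span" in a captured profile, which makes it easy to attribute
      # host-side time to this code path.
      with xp.Trace("my_span"):
        return next(data_iter)
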
@@ -160,7 +161,8 @@ def _loader_worker(self):
     try:
       while not self._done:
         try:
-          _, data = next(data_iter)
+          with xp.Trace("cpu_loader.next"):
+            _, data = next(data_iter)
         except StopIteration:
           break
         batch.append(data)
@@ -227,12 +229,14 @@ def _worker(self, dqueue, host_to_device_transfer_threads):
 
     try:
       while True:
-        batch = self._get_batch(dqueue)
+        with xp.Trace("get_batch_from_cpu_queue"):
+          batch = self._get_batch(dqueue)
         if not batch:
           break
         with torch.no_grad():
           try:
-            batch = self.send_cpu_data_to_device(batch, device)
+            with xp.Trace("cpu_data_to_xla_device"):
+              batch = self.send_cpu_data_to_device(batch, device)
           except Exception as e:
             # _worker is being run in a daemon thread, raise the error
             # will not work. Put the error in an error queue instead.
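
With these hints in place, the data-loading path can be attributed in a captured profile: "cpu_loader.next" covers pulling the next batch from the underlying loader, "get_batch_from_cpu_queue" covers dequeuing a batch on the per-device worker thread, and "cpu_data_to_xla_device" covers the host-to-device transfer. A minimal capture sketch, assuming a training script on localhost; the port 9012 and logdir /tmp/xla_profile are placeholders chosen for illustration, not part of this commit:

    # In the training script: start the profiler server once at startup.
    import torch_xla.debug.profiler as xp
    server = xp.start_server(9012)

    # From a separate process, capture a trace for a few seconds; the spans
    # added in this commit appear under their names in the TensorBoard
    # trace viewer for the captured profile.
    xp.trace('localhost:9012', '/tmp/xla_profile', duration_ms=5000)
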