diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py
index a177c92b59d..b7d2519eccf 100644
--- a/torch_xla/distributed/parallel_loader.py
+++ b/torch_xla/distributed/parallel_loader.py
@@ -7,6 +7,7 @@
 import torch_xla.utils.keyd_queue as kq
 import torch_xla.utils.utils as xu
 import torch_xla.core.xla_model as xm
+import torch_xla.debug.profiler as xp
 
 
 class PerDeviceQueue(object):
@@ -160,7 +161,8 @@ def _loader_worker(self):
     try:
       while not self._done:
         try:
-          _, data = next(data_iter)
+          with xp.Trace("cpu_loader.next"):
+            _, data = next(data_iter)
         except StopIteration:
           break
         batch.append(data)
@@ -227,12 +229,14 @@ def _worker(self, dqueue, host_to_device_transfer_threads):
 
     try:
       while True:
-        batch = self._get_batch(dqueue)
+        with xp.Trace("get_batch_from_cpu_queue"):
+          batch = self._get_batch(dqueue)
         if not batch:
           break
         with torch.no_grad():
           try:
-            batch = self.send_cpu_data_to_device(batch, device)
+            with xp.Trace("cpu_data_to_xla_device"):
+              batch = self.send_cpu_data_to_device(batch, device)
           except Exception as e:
             # _worker is being run in a daemon thread, raise the error
             # will not work. Put the error in an error queue instead.
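
For context beyond the patch itself, below is a minimal sketch of a training loop in which the new loader-side trace scopes would show up, assuming the usual xp.start_server / MpDeviceLoader profiling setup. The toy dataset, the port number, and the "train_step" scope name are illustrative placeholders and not part of this change.

import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl

# Serve profiler data so a trace can be captured while the loop runs.
# The port is arbitrary here.
server = xp.start_server(9012)

device = xm.xla_device()

# Toy dataset/loader standing in for a real input pipeline.
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset, batch_size=8)
mp_loader = pl.MpDeviceLoader(loader, device)

for step, (data, target) in enumerate(mp_loader):
  # The scopes added by this patch ("cpu_loader.next",
  # "get_batch_from_cpu_queue", "cpu_data_to_xla_device") appear in the
  # captured trace next to user-level scopes such as this one.
  with xp.Trace("train_step"):
    loss = (data * data).sum() + target.sum()

With the server running, a trace can be captured from another process via xp.trace("localhost:9012", logdir) and viewed in TensorBoard, where the three new scopes separate CPU-side batch fetching from the host-to-device transfer.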