Code review fixes
Signed-off-by: Joaquin Anton Guirao <[email protected]>
jantonguirao committed Dec 24, 2024
1 parent 659d0f1 commit 18e0781
Showing 8 changed files with 241 additions and 92 deletions.
9 changes: 4 additions & 5 deletions dali/python/nvidia/dali/external_source.py
@@ -28,13 +28,12 @@


def _get_shape(data):
-    if isinstance(data, (_tensors.TensorCPU, _tensors.TensorGPU)):
-        if callable(data.shape):
-            return data.shape()
-        else:
-            return data.shape
+    if hasattr(data, "shape"):
+        return data.shape() if callable(data.shape) else data.shape
    elif hasattr(data, "__array_interface__"):
        return data.__array_interface__["shape"]
    elif hasattr(data, "__cuda_array_interface__"):
        return data.__cuda_array_interface__["shape"]
    elif hasattr(data, "__array__"):
        return data.__array__().shape
    else:
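The rewritten helper duck-types instead of special-casing DALI tensor classes: attribute-style shapes (NumPy, PyTorch) and method-style shapes (DALI's TensorCPU/TensorGPU) both work. A minimal runnable sketch of the same dispatch; the LegacyTensor class and the ValueError fallback (the original's else branch is truncated above) are illustrative assumptions:

.. code-block:: python

    import numpy as np

    def get_shape(data):
        if hasattr(data, "shape"):
            # NumPy/PyTorch expose shape as an attribute,
            # DALI tensors expose it as a method
            return data.shape() if callable(data.shape) else data.shape
        elif hasattr(data, "__array_interface__"):
            return data.__array_interface__["shape"]
        elif hasattr(data, "__cuda_array_interface__"):
            return data.__cuda_array_interface__["shape"]
        elif hasattr(data, "__array__"):
            return data.__array__().shape
        else:
            raise ValueError("Unsupported data type")  # assumed fallback

    class LegacyTensor:  # hypothetical stand-in for a DALI tensor
        def shape(self):
            return (4, 4)

    print(get_shape(np.zeros((2, 3))))  # (2, 3) -- attribute style
    print(get_shape(LegacyTensor()))    # (4, 4) -- method style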
1 change: 1 addition & 0 deletions dali/python/nvidia/dali/pipeline.py
@@ -481,6 +481,7 @@ def is_restored_from_checkpoint(self):
@property
def num_outputs(self):
"""Number of pipeline outputs."""
self.build()
return self._num_outputs

def output_dtype(self) -> list:
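The added ``self.build()`` makes the property usable before the pipeline has been built explicitly. A small sketch of the effect, assuming a trivial CPU-only pipeline (the operator and its arguments are illustrative):

.. code-block:: python

    from nvidia.dali import pipeline_def, fn

    @pipeline_def(batch_size=4, num_threads=2, device_id=None)
    def two_outputs():
        data = fn.random.uniform(range=[0, 1], shape=[2])
        return data, data + 1.0

    pipe = two_outputs()
    # no explicit pipe.build() needed; the property builds on first access
    print(pipe.num_outputs)  # 2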
165 changes: 117 additions & 48 deletions dali/python/nvidia/dali/plugin/pytorch/experimental/proxy/__init__.py
@@ -33,9 +33,9 @@
#
# The diagram below shows how the different processes and threads interact with each other
# via shared queues. req_n_k represents the k-th processing request from data worker n,
-# consisting of a batch identifier (n, k) and a set of inputs. data_n_k represents the
-# outputs of a DALI pipeline corresponding to the same batch identifier, consisting of the
-# batch identifier and a set of outputs.
+# an instance of `DALIPipelineRunRef`, consisting of a batch identifier (n, k) and a set of
+# inputs. `data_n_k` represents the outputs of a DALI pipeline corresponding to the same batch
+# identifier, consisting of the batch identifier and a set of outputs.
#
# +-------+ +---------------+ +-------------+ +---------------+ +-----------+ +-----------+
# | main | | dali_output_q | | data_thread | | dali_input_q | | worker_0 | | worker_1 |
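A bare-bones sketch of the handshake the diagram describes, with standard-library queues and a thread standing in for the DALI machinery (the doubling "pipeline" and the sentinel protocol are illustrative assumptions):

.. code-block:: python

    import queue
    import threading

    def data_thread_fn(dali_input_q, dali_output_q):
        # stand-in for the DALI run: consume req_n_k, produce data_n_k
        while True:
            batch_id, inputs = dali_input_q.get()
            if batch_id is None:  # sentinel: stop
                break
            outputs = [x * 2 for x in inputs]  # pretend pipeline work
            dali_output_q.put((batch_id, outputs))

    dali_input_q, dali_output_q = queue.Queue(), queue.Queue()
    t = threading.Thread(target=data_thread_fn, args=(dali_input_q, dali_output_q))
    t.start()
    dali_input_q.put(((0, 0), [1, 2, 3]))  # req_0_0 from worker_0
    print(dali_output_q.get())             # ((0, 0), [2, 4, 6]), i.e. data_0_0
    dali_input_q.put((None, None))         # shut the thread down
    t.join()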
@@ -128,6 +128,13 @@ class DALIOutputSampleRef:
"""

def __init__(self, proxy, pipe_run_ref, output_idx, sample_idx):
"""
Args:
proxy (_DALIProxy): The proxy object used for communication or data handling.
pipe_run_ref (DALIPipelineRunRef): A reference to the pipeline run.
output_idx (int): The index of the output in the pipeline.
sample_idx (int): The index of the sample within the batch.
"""
self.proxy = proxy
self.pipe_run_ref = pipe_run_ref
self.output_idx = output_idx
@@ -146,6 +153,12 @@ class DALIPipelineRunRef:
"""

def __init__(self, proxy, batch_id):
"""
Args:
proxy (_DALIProxy): The proxy object used for communication or data handling.
batch_id (tuple(int, int)): A tuple that uniquely identifies the batch. The first
element represents the worker, and the second the batch index for that worker.
"""
self.batch_id = batch_id
self.inputs = {name: [] for name in proxy._dali_input_names}
self.is_scheduled = False
@@ -164,6 +177,11 @@ class DALIOutputBatchRef:
"""

def __init__(self, pipe_run_ref, output_idx):
"""
Args:
pipe_run_ref (DALIPipelineRunRef): A reference to the pipeline run.
output_idx (int): The index of the output in the pipeline.
"""
self.pipe_run_ref = pipe_run_ref
self.output_idx = output_idx

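The three reference types form a small object graph: a pipeline-run ref accumulates the inputs for one batch, each proxy call yields sample refs pointing into that run, and after collation each output slot of the batch is represented by a batch ref. A schematic of the relationships (simplified dataclasses, not DALI's actual classes):

.. code-block:: python

    from dataclasses import dataclass, field

    @dataclass
    class RunRef:               # schematic DALIPipelineRunRef
        batch_id: tuple         # (worker_id, worker_batch_idx)
        inputs: dict = field(default_factory=dict)
        is_complete: bool = False
        is_scheduled: bool = False

    @dataclass
    class SampleRef:            # schematic DALIOutputSampleRef
        run: RunRef
        output_idx: int
        sample_idx: int

    @dataclass
    class BatchRef:             # schematic DALIOutputBatchRef
        run: RunRef
        output_idx: int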
@@ -213,7 +231,7 @@ def __init__(self, dali_input_names, dali_input_q, dali_num_outputs, determinist
def _init_worker_data(self):
    self._worker_data = {
        "worker_id": self._get_worker_id(),
-        "data_idx": 0,
+        "worker_batch_idx": 0,
        "pipe_run_ref": None,
        "batch_sample_idx": 0,
    }
@@ -232,12 +250,16 @@ def _get_worker_id(self):
return worker_info.id if worker_info else threading.get_ident()

def _add_sample(self, bound_args):
"""
Adds a sample to the current batch. The collate function marks the batch as
complete; when a completed batch is encountered, a new batch is started.
"""
state = self._get_worker_data()
if state["pipe_run_ref"] is None or state["pipe_run_ref"].is_complete:
    state["pipe_run_ref"] = DALIPipelineRunRef(
-        self, (state["worker_id"], state["data_idx"])
+        self, (state["worker_id"], state["worker_batch_idx"])
    )
-    state["data_idx"] += 1
+    state["worker_batch_idx"] += 1
    state["batch_sample_idx"] = 0

for name, value in bound_args.arguments.items():
@@ -257,6 +279,9 @@ def _add_sample(self, bound_args):
return ret[0] if len(ret) == 1 else ret

def _schedule_batch(self, pipe_run_ref):
"""
Schedules a batch for execution by appending it to the DALI input queue.
"""
if pipe_run_ref.inputs is None:
raise RuntimeError("No inputs for the pipeline to run (was it already scheduled?)")
_torch.cuda.nvtx.range_push(f"dali_proxy.dali_input_q.put {pipe_run_ref.batch_id}")
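Together, ``_add_sample`` and ``_schedule_batch`` implement an accumulate-then-flush pattern: a worker appends per-sample inputs to its current run until collation marks it complete, and a completed run can be pushed to the input queue exactly once. A simplified, runnable sketch of that pattern (class and field names are hypothetical):

.. code-block:: python

    import queue

    class BatchAccumulator:
        def __init__(self, input_q, worker_id=0):
            self.input_q = input_q
            self.worker_id = worker_id
            self.batch_idx = 0
            self.current = None

        def add_sample(self, sample):
            # start a new batch lazily, as _add_sample does
            if self.current is None or self.current["complete"]:
                self.current = {
                    "batch_id": (self.worker_id, self.batch_idx),
                    "samples": [],
                    "complete": False,
                }
                self.batch_idx += 1
            self.current["samples"].append(sample)

        def schedule(self):
            # guard against double scheduling, as _schedule_batch does
            if self.current["samples"] is None:
                raise RuntimeError("No inputs (was it already scheduled?)")
            self.input_q.put((self.current["batch_id"], self.current["samples"]))
            self.current["samples"] = None
            self.current["complete"] = True

    q = queue.Queue()
    acc = BatchAccumulator(q)
    acc.add_sample({"a": 1})
    acc.add_sample({"a": 2})
    acc.schedule()
    print(q.get())  # ((0, 0), [{'a': 1}, {'a': 2}])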
@@ -285,7 +310,7 @@ def __init__(self, pipeline, input_names=None, deterministic=False):
loader has returned the batch information, and not as soon
as the data worker collates the batch.
Example:
Example 1 - Full integration with PyTorch via DALI proxy DataLoader:
.. code-block:: python
Expand Down Expand Up @@ -343,17 +368,70 @@ def read_filepath(path):
for data, target in loader:
# consume it
pass
Example 2 - Manual execution using DALI proxy / DALI server and PyTorch's default_collate:
.. code-block:: python
@pipeline_def
def my_pipe():
a = fn.external_source(name="a", no_copy=True)
b = fn.external_source(name="b", no_copy=True)
return a + b, a - b
with dali_proxy.DALIServer(
my_pipe(device='cpu', batch_size=batch_size,
num_threads=3, device_id=None)) as dali_server:
outs = []
for _ in range(batch_size):
a = np.array(np.random.rand(3, 3), dtype=np.float32)
b = np.array(np.random.rand(3, 3), dtype=np.float32)
# inputs are passed by name, matching the external_source nodes
out0, out1 = dali_server.proxy(a=a, b=b)
outs.append((a, b, out0, out1))
outs = torch.utils.data.dataloader.default_collate(outs)
a, b, a_plus_b, a_minus_b = dali_server.produce_data(outs)
Example 3 - Full integration with PyTorch but using the original PyTorch DataLoader
.. code-block:: python
pipe = rn50_train_pipe(...)
with dali_proxy.DALIServer(pipe) as dali_server:
dataset = torchvision.datasets.ImageFolder(
jpeg, transform=dali_server.proxy, loader=read_filepath)
# Using PyTorch DataLoader directly
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
num_workers=nworkers,
drop_last=True,
)
for data, target in loader:
# replaces the output reference with actual data
data = dali_server.produce_data(data)
...
"""
if not isinstance(pipeline, _Pipeline):
raise RuntimeError(f"Expected an NVIDIA DALI pipeline, got: {pipeline}")
else:
self._pipe = pipeline
+# get and validate dali pipeline input names
+self._dali_input_names, self._allow_positional_args = self._check_dali_input_names(
+    input_names
+)

-# get the dali pipeline input names
-self._dali_input_names = _external_source_node_names(self._pipe)
-if len(self._dali_input_names) == 0:
-    raise RuntimeError("The provided pipeline doesn't have any inputs")

# Multi-process queue used to transfer data from the pytorch workers to the main process
self._dali_input_q = _mp.Queue()
# Multi-process queue used by the main process to consume outputs from the DALI pipeline
@@ -367,47 +445,28 @@ def read_filepath(path):
# Proxy
self._proxy = None

def _check_dali_input_names(self, input_names):
pipe_input_names = _external_source_node_names(self._pipe)
if len(pipe_input_names) == 0:
raise RuntimeError("The provided pipeline doesn't have any inputs")
pipe_input_names_set = set(pipe_input_names)
input_names_set = set(input_names or [])
if len(input_names_set) != len(input_names or []):
raise RuntimeError("``input_names`` argument should not contain any duplicated values")

if len(input_names_set) == 0:
allow_positional_args = len(pipe_input_names) == 1
return pipe_input_names, allow_positional_args

if input_names_set != pipe_input_names_set:
raise RuntimeError(
"The set of DALI input names provided should match exactly the "
"ones provided by the pipeline. "
f"\nProvided input names are: {input_names}"
f"\nPipeline input names are: {pipe_input_names}"
)
return input_names, True

@property
def proxy(self):
"""
DALI proxy callable instance, which has the signature defined by the input/output
configuration of the DALI pipeline bound to this DALI server instance.
See :class:`nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer` for a full example.
"""
if self._proxy is None:
parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
for input_name in self._dali_input_names:
+if self._allow_positional_args:
+    parameters.append(
+        inspect.Parameter(input_name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+    )
+else:
+    parameters.append(inspect.Parameter(input_name, inspect.Parameter.KEYWORD_ONLY))

-parameters.append(
-    inspect.Parameter(input_name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
-)
signature = inspect.Signature(parameters)

def call_impl(self, *args, **kwargs):
try:
bound_args = signature.bind(self, *args, **kwargs)
except Exception as exc:
raise RuntimeError(f"{exc}. Signature is {signature}")
raise RuntimeError(
f"{exc}. Signature is {signature}. Got args={args} kwargs={kwargs}"
)
return self._add_sample(bound_args)

call_impl.__signature__ = inspect.Signature(parameters)
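The property synthesizes a callable whose signature mirrors the pipeline's ``external_source`` inputs, so a wrong call fails at bind time with a readable message. A self-contained sketch of the ``inspect``-based technique (``make_callable`` is a hypothetical name):

.. code-block:: python

    import inspect

    def make_callable(input_names, allow_positional):
        kind = (inspect.Parameter.POSITIONAL_OR_KEYWORD if allow_positional
                else inspect.Parameter.KEYWORD_ONLY)
        parameters = [inspect.Parameter(name, kind) for name in input_names]
        signature = inspect.Signature(parameters)

        def call_impl(*args, **kwargs):
            try:
                bound = signature.bind(*args, **kwargs)
            except TypeError as exc:
                raise RuntimeError(f"{exc}. Signature is {signature}") from exc
            return bound.arguments

        call_impl.__signature__ = signature
        return call_impl

    proxy_like = make_callable(["a", "b"], allow_positional=False)
    print(inspect.signature(proxy_like))  # (*, a, b)
    print(proxy_like(a=1, b=2))           # {'a': 1, 'b': 2}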
@@ -497,13 +556,15 @@ def produce_data(self, obj):
"""
A generic function that recursively visits all elements in a nested structure and replaces
instances of DALIOutputBatchRef with the actual data provided by the DALI server.
See :class:`nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer` for a full example.
Args:
obj: The object to map (can be an instance of any class).
Returns:
A new object where any instance of DALIOutputBatchRef has been replaced with actual
data.
"""
cache = dict()
ret = self._produce_data_impl(obj, cache)
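``produce_data`` is essentially a structure-preserving map with an identity cache, so an object that appears several times in the nested structure is replaced only once. A generic sketch of that traversal (hypothetical helper, not DALI's implementation; namedtuples and custom containers would need extra handling):

.. code-block:: python

    def map_structure(obj, replace, cache):
        # replace leaves via `replace`, reusing results for shared objects
        if id(obj) in cache:
            return cache[id(obj)]
        if isinstance(obj, dict):
            ret = {k: map_structure(v, replace, cache) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            ret = type(obj)(map_structure(v, replace, cache) for v in obj)
        else:
            ret = replace(obj)  # e.g. swap a DALIOutputBatchRef for real data
        cache[id(obj)] = ret
        return ret

    data = {"x": [1, 2], "y": (3, {"z": 4})}
    print(map_structure(data, lambda v: v * 10, {}))
    # {'x': [10, 20], 'y': (30, {'z': 40})}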
@@ -537,8 +598,6 @@ def _thread_fn(self):
Asynchronous DALI thread that gets iteration data from the queue and schedules it
for execution
"""
-self._pipe.build() # just in case
-
fed_batches = queue.Queue()
while not self._thread_stop_event.is_set():
_torch.cuda.nvtx.range_push("get_input_batches")
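The loop follows the usual stop-event worker pattern: poll the input queue with a timeout so the stop event is noticed promptly even when no work arrives. A stripped-down skeleton of that pattern (NVTX ranges and error handling omitted; names are illustrative):

.. code-block:: python

    import queue

    def thread_fn(stop_event, input_q, output_q, run_batch):
        while not stop_event.is_set():
            try:
                work = input_q.get(timeout=0.1)  # don't block forever
            except queue.Empty:
                continue  # re-check the stop event
            output_q.put(run_batch(work))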
@@ -581,7 +640,7 @@ def _thread_fn(self):

def start_thread(self):
"""
-Starts the DALI pipeline thread
+Starts the DALI pipeline thread. Note: using the context manager
+(``__enter__``/``__exit__``) is preferred.
"""
if self._thread is not None:
return
@@ -591,7 +650,7 @@ def stop_thread(self):

def stop_thread(self):
"""
-Stops the DALI pipeline thread
+Stops the DALI pipeline thread. Note: using the context manager
+(``__enter__``/``__exit__``) is preferred.
"""
if self._thread_stop_event is None:
return
@@ -601,10 +660,16 @@ def __enter__(self):
self._thread_stop_event = None

def __enter__(self):
"""
Starts the DALI pipeline thread
"""
self.start_thread()
return self

def __exit__(self, exc_type, exc_value, tb):
"""
Stops the DALI pipeline thread
"""
self.stop_thread()
if exc_type is not None:
warnings.warn(f"An exception occurred: {exc_value}", category=UserWarning)
@@ -613,8 +678,9 @@ def __exit__(self, exc_type, exc_value, tb):

class DataLoader(_torchdata.dataloader.DataLoader):
"""
-DALI data loader to be used in the main loop, which runs the DALI pipeline doing the
-processing asynchronously with regards to the training.
+DALI data loader to be used in the main loop, which replaces the pipeline run references
+with actual data produced by the DALI server.
+See :class:`nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer` for a full example.
"""

class _Iter(_torchdata.dataloader._MultiProcessingDataLoaderIter):
@@ -633,6 +699,9 @@ def _next_data(self):
return self.loader.dali_server.produce_data(data)

def __init__(self, dali_server, *args, **kwargs):
"""
Same interface as PyTorch's DataLoader, except for the extra DALIServer argument.
"""
super().__init__(*args, **kwargs)
self.dali_server = dali_server

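Usage mirrors Example 1 above: construct it like a regular PyTorch DataLoader, but pass the DALIServer first; the batches it yields already have the reference placeholders resolved, so no explicit ``produce_data()`` call is needed (``dataset``, ``batch_size`` and ``nworkers`` as in that example):

.. code-block:: python

    with dali_proxy.DALIServer(pipe) as dali_server:
        loader = dali_proxy.DataLoader(
            dali_server,
            dataset,
            batch_size=batch_size,
            num_workers=nworkers,
            drop_last=True,
        )
        for data, target in loader:
            pass  # `data` is already actual tensors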