Make build() optional (#5754)
Instead of checking whether the pipeline is built and raising an error,
we now call build() for the user, reducing the steps necessary
to use DALI to instantiating the pipeline and running it.
We already build automatically in the DLFW plugins.

Adjust the tests accordingly.
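
For illustration, a minimal sketch of the simplified flow this change enables (the pipeline
definition below is hypothetical and uses a constant in place of real data loading):

    from nvidia.dali import pipeline_def
    import nvidia.dali.fn as fn
    import nvidia.dali.types as types

    @pipeline_def(batch_size=8, num_threads=2, device_id=0)
    def my_pipe():
        # A constant batch stands in for an actual reader/decoder chain.
        return fn.constant(idata=[1, 2, 3], dtype=types.INT32)

    pipe = my_pipe()
    (out,) = pipe.run()  # no pipe.build() needed; run() builds on first use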

Signed-off-by: Krzysztof Lecki <[email protected]>
klecki authored Dec 20, 2024
1 parent 076d1a2 commit 6530ad9
Showing 157 changed files with 167 additions and 935 deletions.
8 changes: 3 additions & 5 deletions dali/python/nvidia/dali/_debug_mode.py
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -790,8 +790,7 @@ def run(self):
"""Run the pipeline and return the result."""
import numpy as np

if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()

self._debug_on = True
self._cur_operator_id = -1
@@ -834,8 +833,7 @@ def feed_input(self, data_node, data, **kwargs):
"""Pass data to an ExternalSource operator inside the pipeline.
Refer to :meth:`Pipeline.feed_input() <nvidia.dali.Pipeline.feed_input>` for details."""
if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()
if isinstance(data_node, str):
name = data_node
else:
74 changes: 42 additions & 32 deletions dali/python/nvidia/dali/pipeline.py
@@ -162,9 +162,9 @@ class Pipeline(object):
If ``spawn`` method is used, ExternalSource's callback must be picklable.
In order to use ``fork``, there must be no CUDA contexts acquired at the moment of starting
the workers. For this reason, if you need to build multiple pipelines that use Python
workers, you will need to call :meth:`start_py_workers` before calling :meth:`build` of any
of the pipelines. You can find more details and caveats of both methods in Python's
``multiprocessing`` module documentation.
workers, you will need to call :meth:`start_py_workers` before building or running
any of the pipelines (see :meth:`build` for details). You can find more details and caveats
of both methods in Python's ``multiprocessing`` module documentation.
py_callback_pickler : module or tuple, default = None
If `py_start_method` is set to *spawn*, callback passed to parallel
ExternalSource must be picklable.
@@ -480,10 +480,12 @@ def is_restored_from_checkpoint(self):

def output_dtype(self) -> list:
"""Data types expected at the outputs."""
self.build()
return [elem if elem != types.NO_TYPE else None for elem in self._pipe.output_dtype()]

def output_ndim(self) -> list:
"""Number of dimensions expected at the outputs."""
self.build()
return [elem if elem != -1 else None for elem in self._pipe.output_ndim()]

def epoch_size(self, name=None):
@@ -500,8 +502,7 @@ def epoch_size(self, name=None):
The reader which should be used to obtain epoch size.
"""

if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()
if name is not None:
return self._pipe.reader_meta(name)["epoch_size_padded"]
meta = self._pipe.reader_meta()
@@ -529,8 +530,7 @@ def executor_statistics(self):
.. note::
Executor statistics are not available when using ``exec_dynamic=True``.
"""
if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()
return self._pipe.executor_statistics()

def external_source_shm_statistics(self):
@@ -590,8 +590,7 @@ def reader_meta(self, name=None):
name : str, optional, default = None
The reader which should be used to obtain shards_number.
"""
if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()
if name is not None:
return self._pipe.reader_meta(name)
return self._pipe.reader_meta()
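
For context, a hedged sketch of how the metadata accessors behave after this change (the
file_root path and reader name below are assumptions for illustration, not part of the commit):

    from nvidia.dali import pipeline_def
    import nvidia.dali.fn as fn

    @pipeline_def(batch_size=4, num_threads=2, device_id=0)
    def reader_pipe():
        files, labels = fn.readers.file(file_root="/data/images", name="Reader")
        images = fn.decoders.image(files, device="mixed")
        return images, labels

    pipe = reader_pipe()
    # Previously these raised "Pipeline must be built first."; now they build on demand.
    print(pipe.epoch_size("Reader"))
    print(pipe.reader_meta("Reader"))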
@@ -1005,8 +1004,8 @@ def start_py_workers(self):
If you are going to build more than one pipeline that starts Python workers by forking
the process then you need to call :meth:`start_py_workers` method on all those pipelines
before calling :meth:`build` method of any pipeline, as build acquires CUDA context
for current process.
before calling any method that builds or runs the pipeline (see :meth:`build` for details),
as building acquires a CUDA context for the current process.
The same applies to using any other functionality that would create CUDA context -
for example, initializing a framework that uses CUDA or creating CUDA tensors with it.
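
A hedged sketch of the fork scenario described above (the callback and data shapes below are
illustrative assumptions, not part of the commit):

    import numpy as np
    from nvidia.dali import pipeline_def
    import nvidia.dali.fn as fn

    def make_sample(sample_info):
        # Illustrative per-sample callback for a parallel external source.
        return np.full((2, 2), sample_info.idx_in_epoch, dtype=np.int32)

    @pipeline_def(batch_size=4, num_threads=2, device_id=0,
                  py_num_workers=2, py_start_method="fork")
    def parallel_pipe():
        return fn.external_source(source=make_sample, parallel=True, batch=False)

    pipe1 = parallel_pipe()
    pipe2 = parallel_pipe()
    # Fork the Python workers for every pipeline before anything acquires a CUDA context.
    pipe1.start_py_workers()
    pipe2.start_py_workers()
    (out1,) = pipe1.run()  # run() still builds each pipeline implicitly
    (out2,) = pipe2.run()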
@@ -1042,10 +1041,23 @@ def _next_op_id(self):
return i

def build(self):
"""Build the pipeline.
Pipeline needs to be built in order to run it standalone.
Framework-specific plugins handle this step automatically.
"""Build the pipeline (optional step).
Instantiates the pipeline's backend objects and starts processing threads. If the pipeline
uses multi-processing ``external_source``, the worker processes are also started.
In most cases, there is no need to call build manually. When multi-processing is used,
it may be necessary to call :meth:`build` or :meth:`start_py_workers` before the main
process interacts with the GPU. If needed, :meth:`build` can be used before
running the pipeline to separate the graph building and the processing steps.
The pipeline is built automatically when:
* it is run, either via the run APIs (:meth:`run`, :meth:`schedule_run`)
  or the framework-specific plugins,
* inputs are provided via :meth:`feed_input`,
* pipeline metadata is accessed (:meth:`epoch_size`, :meth:`reader_meta`),
* outputs are accessed, including :meth:`output_stream`,
* the graph needs to be otherwise materialized, e.g. by :meth:`save_graph_to_dot_file`.
"""
if self._built:
return
@@ -1065,6 +1077,7 @@ def build(self):
self._built = True

def input_feed_count(self, input_name):
self.build()
return self._pipe.InputFeedCount(input_name)

def _feed_input(self, name, data, layout=None, cuda_stream=None, use_copy_kernel=False):
@@ -1142,8 +1155,7 @@ def feed_input(self, data_node, data, layout=None, cuda_stream=None, use_copy_ke
If set to True, DALI will use a CUDA kernel to feed the data (only applicable
when copying data to/from GPU memory) instead of ``cudaMemcpyAsync`` (default).
"""
if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()
if isinstance(data_node, str):
name = data_node
else:
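
A hedged sketch of feeding an external source after this change (the input name and data
below are illustrative):

    import numpy as np
    from nvidia.dali import pipeline_def
    import nvidia.dali.fn as fn
    import nvidia.dali.types as types

    @pipeline_def(batch_size=2, num_threads=1, device_id=0)
    def es_pipe():
        return fn.external_source(name="ext_in", dtype=types.INT32)

    pipe = es_pipe()
    # feed_input() now builds the pipeline implicitly instead of raising an error.
    pipe.feed_input("ext_in", [np.array([1, 2, 3], dtype=np.int32)] * 2)
    (out,) = pipe.run()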
@@ -1218,6 +1230,7 @@ def schedule_run(self):
Needs to be used together with :meth:`release_outputs`
and :meth:`share_outputs`.
Should not be mixed with :meth:`run` in the same pipeline"""
self.build()
with self._check_api_type_scope(types.PipelineAPIType.SCHEDULED):
if self._first_iter and self._exec_pipelined:
self._prefetch()
@@ -1226,8 +1239,7 @@

def output_stream(self):
"""Returns the internal CUDA stream on which the outputs are produced."""
if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()
return self._pipe.GetOutputStream()

# for the backward compatibility
@@ -1296,8 +1308,7 @@ def release_outputs(self):
When using dynamic executor (``exec_dynamic=True``), the buffers are not invalidated.
"""
with self._check_api_type_scope(types.PipelineAPIType.SCHEDULED):
if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()
ret = self._pipe.ReleaseOutputs()
return ret

@@ -1312,8 +1323,7 @@ def _outputs(self, cuda_stream=None):
Calling this function is equivalent to calling release_outputs
then calling share_outputs"""
if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()
return self._pipe.Outputs(types._raw_cuda_stream_ptr(cuda_stream))

def _are_pipeline_inputs_possible(self):
@@ -1361,7 +1371,6 @@ def my_pipe():
:meth:`run()` function::
p = my_pipe(prefetch_queue_depth=1, ...)
p.build()
p.run(my_inp=np.random.random((2,3,2)))
Such keyword argument specified in the :meth:`run()` function has to have a
@@ -1392,6 +1401,7 @@ def my_pipe():
(e.g. `feed_input` function or `source` argument in the `fn.external_source`
operator.)"""
)
self.build()
for inp_name, inp_value in pipeline_inputs.items():
self.feed_input(inp_name, inp_value)
with self._check_api_type_scope(types.PipelineAPIType.BASIC):
@@ -1400,8 +1410,7 @@

def _prefetch(self):
"""Executes pipeline to fill executor's pipeline."""
if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()
if not self._pipe:
raise RuntimeError("The pipeline was destroyed.")
self._schedule_py_workers()
@@ -1466,6 +1475,7 @@ def _run_once(self):
If the pipeline was created with `exec_async` option set to `True`,
this function will return without waiting for the execution to end."""
self.build()
try:
if not self._last_iter:
self._iter_setup()
@@ -1665,8 +1675,7 @@ def save_graph_to_dot_file(
use_colors : bool
Whether use color to distinguish stages
"""
if not self._built:
raise RuntimeError("Pipeline must be built first.")
self.build()
if show_ids is not None:
with warnings.catch_warnings():
warnings.simplefilter("default")
@@ -1708,7 +1717,7 @@ def checkpoint(self, filename=None):
filename : str
The file that the serialized pipeline will be written to.
"""

self.build()
cpt = self._get_checkpoint()
if filename is not None:
with open(filename, "wb") as checkpoint_file:
@@ -1735,6 +1744,7 @@ def define_graph(self):
raise NotImplementedError

def _iter_setup(self):
self.build()
iters, success = self._run_input_callbacks()
if not success:
raise StopIteration
@@ -2042,7 +2052,7 @@ def my_pipe(flip_vertical, flip_horizontal):
The decorated function returns a DALI Pipeline object::
pipe = my_pipe(True, False)
# pipe.build() # the pipeline is not configured properly yet
# pipe.run() # the pipeline is not configured properly yet
A pipeline requires additional parameters such as batch size, number of worker threads,
GPU device id and so on (see :meth:`nvidia.dali.Pipeline()` for a
@@ -2051,9 +2061,9 @@ def my_pipe(flip_vertical, flip_horizontal):
passed to the decorated function::
pipe = my_pipe(True, False, batch_size=32, num_threads=1, device_id=0)
pipe.build() # the pipeline is properly configured, we can build it now
The outputs from the original function became the outputs of the Pipeline::
The pipeline is properly configured, so we can run it now. The outputs from the original
function become the outputs of the Pipeline::
flipped, img = pipe.run()
7 changes: 1 addition & 6 deletions dali/test/python/auto_aug/test_augmentation_decorator.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -77,7 +77,6 @@ def pipeline():
return const_mag, dyn_mag

p = pipeline()
p.build()
const_mag, dyn_mag = p.run()
const_mag_ref = ref_param(mag_range, 5, [const_bin] * batch_size)
dyn_mag_ref = ref_param(mag_range, 11, list(range(batch_size)))
@@ -102,7 +101,6 @@ def pipeline():
return const_mag, dyn_mag

p = pipeline()
p.build()
const_mag, dyn_mag = p.run()
const_mag_ref = ref_param(mag_range, None, [const_bin] * batch_size)
dyn_mag_ref = ref_param(mag_range, None, list(range(batch_size)))
@@ -144,7 +142,6 @@ def pipeline():
warn_glob = "but unsigned `magnitude_bin` was passed to the augmentation call"
with assert_warns(Warning, glob=warn_glob):
p = pipeline()
p.build()
(magnitudes,) = p.run()
magnitudes = [np.array(el) for el in magnitudes]
if use_implicit_sign:
@@ -192,7 +189,6 @@ def pipeline():
)

p = pipeline()
p.build()
(magnitudes,) = p.run()
magnitude_bin = (
[const_mag] * batch_size
@@ -235,7 +231,6 @@ def pipeline():
)

p = pipeline()
p.build()
(magnitudes,) = p.run()
if param_device == "cpu":
assert isinstance(magnitudes, _tensors.TensorListCPU)
2 changes: 0 additions & 2 deletions dali/test/python/auto_aug/test_augmentations.py
@@ -83,7 +83,6 @@ def pipeline():
return output, data

p = pipeline()
p.build()
(
output,
data,
@@ -141,7 +140,6 @@ def pipeline():
return fn.resize(image, size=size)

p = pipeline()
p.build()
(out,) = p.run()

out = [np.array(sample) for sample in out]
7 changes: 0 additions & 7 deletions dali/test/python/auto_aug/test_auto_augment.py
@@ -118,9 +118,7 @@ def pipeline():

# run the pipeline twice to make sure instantiation preserves determinism
p1 = pipeline()
p1.build()
p2 = pipeline()
p2.build()
for _ in range(3):
(out1,) = p1.run()
(out2,) = p2.run()
@@ -153,7 +151,6 @@ def pipeline(size):
cls.vid_files = []
for size in (size_1, size_2):
p = pipeline(size=size)
p.build()
(out,) = p.run()
cls.vid_files.extend(np.array(sample) for sample in out.as_cpu())

@@ -197,9 +194,7 @@ def pipeline():

# run the pipeline twice to make sure instantiation preserves determinism
p1 = pipeline()
p1.build()
p2 = pipeline()
p2.build()

for _ in range(num_iterations):
(out1,) = p1.run()
@@ -256,7 +251,6 @@ def third(data, op_id_mag_id):

policy = Policy("MyPolicy", num_magnitude_bins=num_magnitude_bins, sub_policies=sub_policies)
p = concat_aug_pipeline(batch_size=batch_size, dev=dev, policy=policy)
p.build()

sub_policy_outputs = collect_sub_policy_outputs(sub_policies, num_magnitude_bins)
# magnitudes are chosen so that the magnitude of the first op in
@@ -397,7 +391,6 @@ def second_stage_only(data, op_id_mag_id):

policy = Policy("MyPolicy", num_magnitude_bins=num_magnitude_bins, sub_policies=sub_policies)
p = concat_aug_pipeline(batch_size=batch_size, dev=dev, policy=policy, seed=1234)
p.build()

for _ in range(5):
(output,) = p.run()