Commit fb43cf3

Simplify example code

Signed-off-by: Joaquin Anton Guirao <[email protected]>
jantonguirao committed Dec 24, 2024
1 parent 18e0781 commit fb43cf3

Showing 7 changed files with 221 additions and 170 deletions.
@@ -289,6 +289,12 @@ def _schedule_batch(self, pipe_run_ref):
         pipe_run_ref.is_scheduled = True
         _torch.cuda.nvtx.range_pop()
 
+    def __call__(self, *args, **kwargs):
+        """
+        To be implemented by the child class
+        """
+        raise RuntimeError("Not implemented")
+
 
 class DALIServer:
     def __init__(self, pipeline, input_names=None, deterministic=False):
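The base proxy class now declares __call__ as a stub that raises; a concrete, pipeline-specific __call__ is attached to a per-server subclass inside the proxy property (sketched after that hunk below).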
@@ -438,13 +444,17 @@ def my_pipe():
         self._dali_output_q = queue.Queue()
         # Thread
         self._thread = None
+        self._thread_stop_event = None
         # Cache
         self._cache_outputs = dict()
         # Whether we want the order of DALI execution to be reproducible
         self._deterministic = deterministic
         # Proxy
         self._proxy = None
 
+    def __del__(self):
+        self.stop_thread()
+
     @property
     def proxy(self):
         """
@@ -458,6 +468,7 @@ def proxy(self):
             parameters.append(
                 inspect.Parameter(input_name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
             )
+
         signature = inspect.Signature(parameters)
 
         def call_impl(self, *args, **kwargs):
@@ -469,9 +480,15 @@ def call_impl(self, *args, **kwargs):
             )
             return self._add_sample(bound_args)
 
-        call_impl.__signature__ = inspect.Signature(parameters)
-        _DALIProxy.__call__ = call_impl
-        self._proxy = _DALIProxy(
+        call_impl.__signature__ = signature
+
+        class _DALIProxyImpl(_DALIProxy):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+
+        _DALIProxyImpl.__call__ = call_impl
+
+        self._proxy = _DALIProxyImpl(
             self._dali_input_names,
             self._dali_input_q,
             self._pipe.num_outputs,
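Why a subclass? Attaching call_impl to _DALIProxy itself would mutate a class shared by every server, so the commit builds a throwaway _DALIProxyImpl per server instead. A minimal, self-contained sketch of this dynamic-signature pattern; the names _Base, make_proxy_class, _Impl and the "image" input are illustrative, not DALI's:

import inspect

class _Base:
    # stands in for _DALIProxy: __call__ stays a stub until a subclass attaches one
    def __call__(self, *args, **kwargs):
        raise RuntimeError("Not implemented")

def make_proxy_class(input_names):
    # build a (self, <inputs...>) signature, as the proxy property does above
    parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
    for name in input_names:
        parameters.append(inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD))
    signature = inspect.Signature(parameters)

    def call_impl(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)  # validates the call
        return dict(bound.arguments)  # the real code hands these to _add_sample

    call_impl.__signature__ = signature

    class _Impl(_Base):  # fresh subclass per pipeline, so _Base is never mutated
        pass

    _Impl.__call__ = call_impl
    return _Impl

Proxy = make_proxy_class(["image"])
p = Proxy()
print(inspect.signature(p.__call__))  # (image)
print(p(image=42)["image"])           # 42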
@@ -598,13 +615,13 @@ def _thread_fn(self):
         Asynchronous DALI thread that gets iteration data from the queue and schedules it
         for execution
         """
-        fed_batches = queue.Queue()
+        fed_batches = []
         while not self._thread_stop_event.is_set():
             _torch.cuda.nvtx.range_push("get_input_batches")
-            timeout = 5 if fed_batches.empty() else None
+            timeout = 5 if len(fed_batches) == 0 else None
             # We try to feed as many batches as the prefetch queue (if available)
             batches = self._get_input_batches(
-                self._pipe.prefetch_queue_depth - fed_batches.qsize(), timeout=timeout
+                self._pipe.prefetch_queue_depth - len(fed_batches), timeout=timeout
             )
             _torch.cuda.nvtx.range_pop()
             if batches is not None and len(batches) > 0:
@@ -613,15 +630,15 @@ def _thread_fn(self):
                 for input_name, input_data in inputs.items():
                     self._pipe.feed_input(input_name, input_data)
                 self._pipe._run_once()
-                fed_batches.put(batch_id)
+                fed_batches.append(batch_id)
                 _torch.cuda.nvtx.range_pop()
 
             # If no batches to consume, continue
-            if fed_batches.qsize() == 0:
+            if len(fed_batches) == 0:
                 continue
 
             _torch.cuda.nvtx.range_push("outputs")
-            batch_id = fed_batches.get_nowait()  # we are sure there's at least one
+            batch_id = fed_batches.pop(0)  # we are sure there's at least one
             err = None
             torch_outputs = None
             try:
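fed_batches changes from a queue.Queue to a plain list: it is filled and drained by _thread_fn alone, so a synchronized queue buys nothing, and append/pop(0) preserves the same FIFO order. A self-contained toy of the same bookkeeping; compute and prefetch_depth are stand-ins, not DALI's API:

def run(inputs, prefetch_depth=2, compute=lambda x: x * 2):
    fed_batches = []  # batches scheduled but not yet consumed, oldest first
    results = []
    it = iter(inputs)
    exhausted = False
    while not exhausted or fed_batches:
        # feed until the prefetch queue is full (feed_input + _run_once in the real code)
        while not exhausted and len(fed_batches) < prefetch_depth:
            try:
                fed_batches.append(next(it))
            except StopIteration:
                exhausted = True
        if fed_batches:
            results.append(compute(fed_batches.pop(0)))  # consume the oldest batch
    return results

assert run([1, 2, 3]) == [2, 4, 6]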
@@ -691,12 +708,15 @@ class _Iter(_torchdata.dataloader._MultiProcessingDataLoaderIter):
         def __init__(self, loader):
             super().__init__(loader)
             self.loader = loader
-            if self.loader.dali_server._thread is None:
-                raise RuntimeError("DALI server is not running")
 
         def _next_data(self):
-            data = super()._next_data()
-            return self.loader.dali_server.produce_data(data)
+            self.loader.dali_server.start_thread()
+            try:
+                data = super()._next_data()
+                return self.loader.dali_server.produce_data(data)
+            except StopIteration:
+                self.loader.dali_server.stop_thread()
+                raise
 
     def __init__(self, dali_server, *args, **kwargs):
         """
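The iterator no longer requires the server thread to be running up front; it starts it lazily on the first batch and stops it when the epoch raises StopIteration, which is what makes the server restartable across epochs. A minimal sketch of the pattern under assumed start/stop semantics; Server and EpochIter are hypothetical names, not DALI's classes:

import threading

class Server:
    def __init__(self):
        self._thread = None

    def start_thread(self):  # idempotent, like the lazy start above
        if self._thread is None:
            self._thread = threading.Thread(target=lambda: None)
            self._thread.start()

    def stop_thread(self):
        if self._thread is not None:
            self._thread.join()
            self._thread = None

class EpochIter:
    def __init__(self, server, data):
        self.server = server
        self.it = iter(data)

    def __iter__(self):
        return self

    def __next__(self):
        self.server.start_thread()     # started lazily on first use
        try:
            return next(self.it)
        except StopIteration:
            self.server.stop_thread()  # stopped at epoch end, so the same
            raise                      # server can serve the next epoch

srv = Server()
for _ in range(2):  # two epochs against one server
    assert list(EpochIter(srv, [1, 2])) == [1, 2]
    assert srv._thread is None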
134 changes: 86 additions & 48 deletions dali/test/python/test_dali_proxy.py
@@ -121,59 +121,57 @@ def test_dali_proxy_torch_data_loader(device, include_decoder, debug=False):
         prefetch_queue_depth=1,
     )
 
-    # Run the server (it also cleans up on scope exit)
-    with dali_proxy.DALIServer(pipe) as dali_server:
+    dali_server = dali_proxy.DALIServer(pipe)
+    if include_decoder:
+        dataset = datasets.ImageFolder(jpeg, transform=dali_server.proxy, loader=read_filepath)
+        dataset_ref = datasets.ImageFolder(jpeg, transform=lambda x: x.copy(), loader=read_filepath)
+    else:
+        dataset = datasets.ImageFolder(jpeg, transform=dali_server.proxy)
+        dataset_ref = datasets.ImageFolder(jpeg, transform=lambda x: x.copy())
 
-        if include_decoder:
-            dataset = datasets.ImageFolder(jpeg, transform=dali_server.proxy, loader=read_filepath)
-            dataset_ref = datasets.ImageFolder(
-                jpeg, transform=lambda x: x.copy(), loader=read_filepath
-            )
-        else:
-            dataset = datasets.ImageFolder(jpeg, transform=dali_server.proxy)
-            dataset_ref = datasets.ImageFolder(jpeg, transform=lambda x: x.copy())
+    loader = dali_proxy.DataLoader(
+        dali_server,
+        dataset,
+        batch_size=batch_size,
+        num_workers=nworkers,
+        drop_last=True,
+    )
 
-        loader = dali_proxy.DataLoader(
-            dali_server,
-            dataset,
-            batch_size=batch_size,
-            num_workers=nworkers,
-            drop_last=True,
-        )
+    def ref_collate_fn(batch):
+        filepaths, labels = zip(*batch)  # Separate the inputs and labels
+        # Just return the batch as they are, a list of individual tensors
+        return filepaths, labels
 
-        def ref_collate_fn(batch):
-            filepaths, labels = zip(*batch)  # Separate the inputs and labels
-            # Just return the batch as they are, a list of individual tensors
-            return filepaths, labels
+    loader_ref = torchdata.dataloader.DataLoader(
+        dataset_ref,
+        batch_size=batch_size,
+        num_workers=1,
+        collate_fn=ref_collate_fn,
+        shuffle=False,
+    )
 
-        loader_ref = torchdata.dataloader.DataLoader(
-            dataset_ref,
-            batch_size=batch_size,
-            num_workers=1,
-            collate_fn=ref_collate_fn,
-            shuffle=False,
+    for _, ((data, target), (ref_data, ref_target)) in enumerate(zip(loader, loader_ref)):
+        np.testing.assert_equal([batch_size, 3, 224, 224], data.shape)
+        np.testing.assert_equal(
+            [
+                batch_size,
+            ],
+            target.shape,
         )
 
-        for _, ((data, target), (ref_data, ref_target)) in enumerate(zip(loader, loader_ref)):
-            np.testing.assert_equal([batch_size, 3, 224, 224], data.shape)
-            np.testing.assert_equal(
-                [
-                    batch_size,
-                ],
-                target.shape,
-            )
-            np.testing.assert_array_equal(target, ref_target)
-            ref_data_nparrays = [
-                np.array(obj) if isinstance(obj, PIL.Image.Image) else obj for obj in ref_data
-            ]
-            ref_data_tensors = [TensorCPU(arr) for arr in ref_data_nparrays]
-            pipe_ref.feed_input("images", ref_data_tensors)
-            (ref_data,) = pipe_ref.run()
-            for sample_idx in range(batch_size):
-                ref_tensor = ref_data[sample_idx]
-                if isinstance(ref_tensor, TensorGPU):
-                    ref_tensor = ref_tensor.as_cpu()
-                np.testing.assert_array_equal(ref_tensor, data[sample_idx].cpu())
+        np.testing.assert_array_equal(target, ref_target)
+        ref_data_nparrays = [
+            np.array(obj) if isinstance(obj, PIL.Image.Image) else obj for obj in ref_data
+        ]
+        ref_data_tensors = [TensorCPU(arr) for arr in ref_data_nparrays]
+        pipe_ref.feed_input("images", ref_data_tensors)
+        (ref_data,) = pipe_ref.run()
+        for sample_idx in range(batch_size):
+            ref_tensor = ref_data[sample_idx]
+            if isinstance(ref_tensor, TensorGPU):
+                ref_tensor = ref_tensor.as_cpu()
+            np.testing.assert_array_equal(ref_tensor, data[sample_idx].cpu())
+
+    dali_server.stop_thread()  # make sure we stop the thread before leaving the test
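With the with-block gone, the test owns the server lifetime explicitly; the context-manager form from the removed lines remains the scope-managed alternative. Both styles, sketched using only what this diff shows:

# explicit lifetime, as in the updated test
dali_server = dali_proxy.DALIServer(pipe)
...  # build datasets/loaders on dali_server.proxy and iterate
dali_server.stop_thread()  # explicit cleanup

# scope-managed lifetime, as in the removed lines
with dali_proxy.DALIServer(pipe) as dali_server:
    ...  # cleans up on scope exit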


@attr("pytorch")
@@ -579,3 +577,43 @@ def test_dali_proxy_proxy_callable(named_arguments, debug=False):
 
     np.testing.assert_array_almost_equal(a_plus_b, a + b)
     np.testing.assert_array_almost_equal(a_minus_b, b - a)
+
+
+@attr("pytorch")
+@params(("cpu",), ("gpu",))
+def test_dali_proxy_restart_server(device, debug=False):
+    from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
+    from torch.utils import data as torchdata
+
+    @pipeline_def
+    def square(device):
+        a = fn.external_source(name="a", no_copy=True)
+        if device == "gpu":
+            a = a.gpu()
+        return a**2
+
+    class MyDataset(torchdata.Dataset):
+        def __init__(self, transform_fn):
+            self.transform_fn = transform_fn
+
+        def __len__(self):
+            return 10
+
+        def __getitem__(self, idx):
+            return np.array(idx), self.transform_fn(np.array(idx))
+
+    batch_size = 4
+    dali_server = dali_proxy.DALIServer(
+        square(device="cpu", batch_size=batch_size, num_threads=3, device_id=None)
+    )
+
+    dataset = MyDataset(dali_server.proxy)
+    loader = dali_proxy.DataLoader(
+        dali_server, dataset, batch_size=batch_size, num_workers=2, drop_last=True
+    )
+    for _ in range(3):  # 3 epochs
+        assert dali_server._thread is None
+        for data0, data1 in iter(loader):
+            np.testing.assert_array_almost_equal(data0**2, data1.cpu())
+            assert dali_server._thread is not None
+        assert dali_server._thread is None
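The new test pins down the lazy lifecycle introduced in the DataLoader iterator: _thread is None before each epoch, alive while batches are drawn (hence the assert inside the inner loop), and None again once the loader raises StopIteration, for three consecutive epochs over the same server.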
@@ -165,11 +165,9 @@ def gdtl(
             pipe, reader_name="Reader", fill_last_batch=False
         )
 
-        dali_server = None
         return (
             DALIWrapper(train_loader, num_classes, one_hot, memory_format),
             int(pipe.epoch_size("Reader") / (world_size * batch_size)),
-            dali_server,
         )
 
     return gdtl
@@ -222,11 +220,9 @@ def gdvl(
             pipe, reader_name="Reader", fill_last_batch=False
         )
 
-        dali_server = None
         return (
             DALIWrapper(val_loader, num_classes, one_hot, memory_format),
             int(pipe.epoch_size("Reader") / (world_size * batch_size)),
-            dali_server,
         )
 
     return gdvl
@@ -378,11 +374,9 @@ def get_pytorch_train_loader(
         persistent_workers=True,
         prefetch_factor=prefetch_factor,
     )
-    dali_server = None
     return (
         PrefetchedWrapper(train_loader, start_epoch, num_classes, one_hot, True, memory_format, "CHW"),
         len(train_loader),
-        dali_server,
     )
 
 
@@ -435,11 +429,9 @@ def get_pytorch_val_loader(
         persistent_workers=True,
         prefetch_factor=prefetch_factor,
     )
-    dali_server = None
     return (
         PrefetchedWrapper(val_loader, 0, num_classes, one_hot, True, memory_format, "CHW"),
         len(val_loader),
-        dali_server
     )
 
 def read_file(path):
@@ -519,7 +511,6 @@ def get_impl(data_path,
         return (
             PrefetchedWrapper(train_loader, start_epoch, num_classes, one_hot, False, memory_format, output_layout),
             len(train_loader),
-            dali_server,
         )
     return get_impl
 
@@ -593,7 +584,6 @@ def get_impl(data_path,
         return (
             PrefetchedWrapper(val_loader, 0, num_classes, one_hot, False, memory_format, output_layout),
             len(val_loader),
-            dali_server,
         )
     return get_impl
 
@@ -644,7 +634,6 @@ def get_synthetic_loader(
     memory_format=torch.contiguous_format,
     **kwargs,
 ):
-    dali_server = None
     return (
         SynteticDataLoader(
             batch_size,
@@ -656,5 +645,4 @@ def get_synthetic_loader(
             memory_format=memory_format,
         ),
         -1,
-        dali_server,
     )
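Each of these loader getters previously padded its return value with a dali_server element that was always None; dropping it means call sites unpack two values instead of three. A hedged sketch of the call-site change (the surrounding training script is assumed):

# before: loader, steps_per_epoch, _ = get_pytorch_train_loader(...)
# after:
loader, steps_per_epoch = get_pytorch_train_loader(...)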