Skip to content

Commit

Permalink
Apply code review suggestions
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton Guirao <[email protected]>
  • Loading branch information
jantonguirao committed Dec 24, 2024
1 parent 926b3bc commit 826f463
Show file tree
Hide file tree
Showing 8 changed files with 627 additions and 246 deletions.
1 change: 1 addition & 0 deletions dali/python/nvidia/dali/external_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SourceKind as _SourceKind,
)


def _get_shape(data):
if isinstance(data, (_tensors.TensorCPU, _tensors.TensorGPU)):
if callable(data.shape):
Expand Down
6 changes: 6 additions & 0 deletions dali/python/nvidia/dali/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,11 @@ def is_restored_from_checkpoint(self):
"""If True, this pipeline was restored from checkpoint."""
return self._is_restored_from_checkpoint

@property
def num_outputs(self):
    """Return the number of outputs recorded for this pipeline.

    The value is read from ``self._num_outputs``.
    NOTE(review): presumably only meaningful once the Python graph has been
    built and the output count stored - confirm against the build path.
    """
    output_count = self._num_outputs
    return output_count

def output_dtype(self) -> list:
"""Data types expected at the outputs."""
self.build()
Expand Down Expand Up @@ -854,6 +859,7 @@ def contains_nested_datanode(nested):
self._require_no_foreign_ops("The pipeline does not support checkpointing")

self._graph_outputs = outputs
self._num_outputs = len(self._graph_outputs)
self._setup_input_callbacks()
self._disable_pruned_external_source_instances()
self._py_graph_built = True
Expand Down
438 changes: 280 additions & 158 deletions dali/python/nvidia/dali/plugin/pytorch/experimental/proxy/__init__.py

Large diffs are not rendered by default.

180 changes: 167 additions & 13 deletions dali/test/python/test_dali_proxy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvidia.dali import pipeline_def, fn, types
import numpy as np
import os
from nose2.tools import params
from nose_utils import attr, assert_raises
import PIL.Image


def read_file(path):
    """Load the raw bytes of the file at ``path`` as a ``uint8`` NumPy array."""
    raw_bytes = np.fromfile(path, dtype=np.uint8)
    return raw_bytes

Expand Down Expand Up @@ -49,9 +64,7 @@ def image_pipe(dali_device="gpu", include_decoder=True, random_pipe=True):
if random_pipe:
shapes = images.shape()
crop_anchor, crop_shape = fn.random_crop_generator(
shapes,
random_aspect_ratio=[0.75, 4.0 / 3.0],
random_area=[0.08, 1.0]
shapes, random_aspect_ratio=[0.75, 4.0 / 3.0], random_area=[0.08, 1.0]
)
images = fn.slice(images, start=crop_anchor, shape=crop_shape, axes=[0, 1])

Expand All @@ -75,7 +88,7 @@ def image_pipe(dali_device="gpu", include_decoder=True, random_pipe=True):


@attr("pytorch")
@params(("cpu", False), ("cpu", True), ("cpu", False), ("gpu", True))
@params(("cpu", False), ("cpu", True), ("gpu", False), ("gpu", True))
def test_dali_proxy_torch_data_loader(device, include_decoder, debug=False):
# Shows how DALI proxy is used in practice with a PyTorch data loader

Expand Down Expand Up @@ -114,7 +127,9 @@ def test_dali_proxy_torch_data_loader(device, include_decoder, debug=False):

if include_decoder:
dataset = datasets.ImageFolder(jpeg, transform=dali_server.proxy, loader=read_filepath)
dataset_ref = datasets.ImageFolder(jpeg, transform=lambda x: x.copy(), loader=read_filepath)
dataset_ref = datasets.ImageFolder(
jpeg, transform=lambda x: x.copy(), loader=read_filepath
)
else:
dataset = datasets.ImageFolder(jpeg, transform=dali_server.proxy)
dataset_ref = datasets.ImageFolder(jpeg, transform=lambda x: x.copy())
Expand Down Expand Up @@ -149,7 +164,9 @@ def ref_collate_fn(batch):
target.shape,
)
np.testing.assert_array_equal(target, ref_target)
ref_data_nparrays = [np.array(obj) if isinstance(obj, PIL.Image.Image) else obj for obj in ref_data]
ref_data_nparrays = [
np.array(obj) if isinstance(obj, PIL.Image.Image) else obj for obj in ref_data
]
ref_data_tensors = [TensorCPU(arr) for arr in ref_data_nparrays]
pipe_ref.feed_input("images", ref_data_tensors)
(ref_data,) = pipe_ref.run()
Expand All @@ -161,7 +178,8 @@ def ref_collate_fn(batch):


@attr("pytorch")
def test_dali_proxy_torch_data_loader_manual_integration(device="gpu", debug=False):
@params(("gpu",))
def test_dali_proxy_manual_integration(device, debug=False):
# Shows how to integrate with DALI proxy manually with an existing data loader

from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
Expand Down Expand Up @@ -251,18 +269,16 @@ def __getitem__(self, idx):
return img2, other

# This is just for educational purposes. It is recommended to rely on
# default_collate_fn_map, which is updated to handle DALIProcessedSampleRef
# default_collate_fn_map, which is updated to handle DALIOutputSampleRef
def custom_collate_fn(batch):
images, labels = zip(*batch)
return dali_proxy._collate_dali_processed_sample_ref_fn(images), torch.tensor(
return dali_proxy._collate_dali_output_sample_ref_fn(images), torch.tensor(
labels, dtype=torch.long
)

# Run the server (it also cleans up on scope exit)
with dali_proxy.DALIServer(pipe) as dali_server:

dataset = CustomDatasetDALI(plain_dataset, dali_server.proxy)

loader = torchdata.dataloader.DataLoader(
dataset,
batch_size=batch_size,
Expand All @@ -272,9 +288,8 @@ def custom_collate_fn(batch):
)

assert len(loader) > 0

for next_input, next_target in loader:
assert isinstance(next_input, dali_proxy.DALIPipelineRunRef)
assert isinstance(next_input, dali_proxy.DALIOutputBatchRef)
next_input = dali_server.produce_data(next_input)
assert isinstance(next_input, torch.Tensor)
np.testing.assert_equal([batch_size, 3, 224, 224], next_input.shape)
Expand Down Expand Up @@ -392,3 +407,142 @@ def pipe_with_error():
# messages in the next test
pipe._shutdown()
del pipe


@attr("pytorch")
@params(("cpu",), ("gpu",))
def test_dali_proxy_duplicated_outputs(device, debug=False):
    """Check that one sample may carry the same DALI proxy output twice.

    The dataset defined below returns the very same transformed image object in
    two positions of every sample; after going through the DALI proxy data
    loader both positions must compare equal.

    Args:
        device: DALI device forwarded to ``image_pipe`` ("cpu" or "gpu").
        debug: Not used by this test.
    """
    from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
    from torch.utils import data as torchdata
    from PIL import Image

    batch_size = 4
    num_threads = 3
    device_id = 0
    nworkers = 4
    pipe = image_pipe(
        dali_device=device,
        include_decoder=False,
        random_pipe=False,
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=device_id,
        prefetch_queue_depth=2 + nworkers,
    )

    class MyDataset(torchdata.Dataset):
        """Dataset over every image file found under ``folder_path``."""

        # File name suffixes (compared case-insensitively) treated as images.
        _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif")

        def __init__(self, folder_path, transform):
            self.folder_path = folder_path
            self.image_files = self._find_images_in_folder(folder_path)
            self.transform = transform

        def _find_images_in_folder(self, folder_path):
            """
            Recursively find all image files in the folder and its subdirectories.

            Each returned entry is already joined with the directory it was
            found in, so it can be opened as is.
            """
            image_files = []

            # Walk through all directories and subdirectories
            for root, _, files in os.walk(folder_path):
                for file in files:
                    if file.lower().endswith(self._IMAGE_EXTENSIONS):
                        image_files.append(os.path.join(root, file))

            return image_files

        def __len__(self):
            """Returns the number of images in the folder."""
            return len(self.image_files)

        def __getitem__(self, idx):
            """Return ``(img, 1, img)`` where both ``img`` entries are the same object."""
            # The stored entry already includes the walk root; joining it with
            # folder_path again would duplicate the prefix for relative folders.
            img_path = self.image_files[idx]
            img = Image.open(img_path).convert("RGB")  # Convert image to RGB (3 channels)
            img = self.transform(img)
            return img, 1, img

    with dali_proxy.DALIServer(pipe) as dali_server:
        dataset = MyDataset(jpeg, transform=dali_server.proxy)
        loader = dali_proxy.DataLoader(
            dali_server,
            dataset,
            batch_size=batch_size,
            num_workers=nworkers,
            drop_last=True,
        )

        for data1, _, data2 in loader:
            # Bring the outputs to host memory before handing them to NumPy;
            # assumes the loader yields torch tensors, which cannot be
            # converted to NumPy arrays while they live on a CUDA device.
            np.testing.assert_array_equal(data1.cpu(), data2.cpu())


@attr("pytorch")
@params(("cpu",), ("gpu",))
def test_dali_proxy_rearrange_output_order_and_positional_args(device, debug=False):
    """Check that proxy outputs can be placed in any position of a sample.

    Two identical two-output pipelines (``a + b`` and ``b - a``) are served by
    two DALI servers. One dataset returns ``(a + b, 1, b - a)``, the other the
    mirrored ``(b - a, 1, a + b)``; iterating both loaders in lockstep, the
    mirrored positions must match.

    Args:
        device: "gpu" moves both external-source inputs to the GPU inside the
            pipeline; any other value leaves them where they are.
        debug: Not used by this test.
    """
    from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
    from torch.utils import data as torchdata

    batch_size = 4
    num_threads = 3
    device_id = 0
    nworkers = 4
    arrs = np.random.rand(20, 3)

    @pipeline_def
    def pipe_2_outputs():
        a = fn.external_source(name="a", no_copy=True)
        b = fn.external_source(name="b", no_copy=True)
        if device == "gpu":
            a = a.gpu()
            b = b.gpu()
        return a + b, b - a

    # Both pipelines are built from exactly the same arguments.
    pipe_kwargs = dict(
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=device_id,
        prefetch_queue_depth=2 + nworkers,
    )
    pipe1 = pipe_2_outputs(**pipe_kwargs)
    pipe2 = pipe_2_outputs(**pipe_kwargs)

    class MyDataset(torchdata.Dataset):
        """Pairs every row of ``arrs`` with the following row (wrapping around)."""

        def __init__(self, arrs, transform, reverse_order):
            self.arrs = arrs
            self.n = len(arrs)
            self.transform = transform
            self.reverse_order = reverse_order

        def __len__(self):
            """Returns the number of rows in ``arrs``."""
            return self.n

        def __getitem__(self, idx):
            """Return the two pipeline outputs around a constant ``1``."""
            a = self.arrs[idx]
            # The last row is paired with the first one.
            if idx < self.n - 1:
                neighbour_idx = idx + 1
            else:
                neighbour_idx = 0
            b = self.arrs[neighbour_idx]
            # Inputs are passed by name, deliberately in reversed order.
            a_plus_b, b_minus_a = self.transform(b=b, a=a)
            if self.reverse_order:
                return b_minus_a, 1, a_plus_b
            return a_plus_b, 1, b_minus_a

    def make_loader(server, reverse_order):
        # One DALI proxy data loader per server; only the output order differs.
        return dali_proxy.DataLoader(
            server,
            MyDataset(arrs, server.proxy, reverse_order=reverse_order),
            batch_size=batch_size,
            num_workers=nworkers,
            drop_last=True,
        )

    with dali_proxy.DALIServer(pipe1) as dali_server1, dali_proxy.DALIServer(pipe2) as dali_server2:
        loader1 = make_loader(dali_server1, False)
        loader2 = make_loader(dali_server2, True)

        for data1, data2 in zip(loader1, loader2):
            # Position k of the straight sample mirrors position 2 - k of the reversed one.
            np.testing.assert_array_equal(data1[0].cpu(), data2[2].cpu())
            np.testing.assert_array_equal(data1[1].cpu(), data2[1].cpu())
            np.testing.assert_array_equal(data1[2].cpu(), data2[0].cpu())
Loading

0 comments on commit 826f463

Please sign in to comment.