From 35fdbd920f380c46a5c7dea3c43092d743cead6c Mon Sep 17 00:00:00 2001 From: Joaquin Anton Date: Fri, 19 Apr 2024 07:12:23 +0100 Subject: [PATCH] Refactoring of File Reader classes to accommodate for AWS SDK S3 integration (#5434) Signed-off-by: Joaquin Anton --- .../slice/slice_flip_normalize_gpu_test.cu | 4 +- dali/operators/imgcodec/decoder_test_helper.h | 2 +- dali/operators/reader/file_reader_op.cc | 6 +- dali/operators/reader/loader/CMakeLists.txt | 6 +- dali/operators/reader/loader/coco_loader.cc | 28 +-- dali/operators/reader/loader/coco_loader.h | 11 +- dali/operators/reader/loader/cufile_loader.h | 5 +- .../operators/reader/loader/discover_files.cc | 142 +++++++++++++++ dali/operators/reader/loader/discover_files.h | 57 ++++++ .../reader/loader/discover_files_test.cc | 154 ++++++++++++++++ .../reader/loader/file_label_loader.cc | 26 +-- .../reader/loader/file_label_loader.h | 48 ++--- dali/operators/reader/loader/file_loader.h | 31 ++-- dali/operators/reader/loader/filesystem.cc | 163 ++--------------- dali/operators/reader/loader/filesystem.h | 16 +- .../reader/loader/filesystem_test.cc | 146 ++-------------- dali/operators/reader/loader/fits_loader.h | 15 +- .../operators/reader/loader/fits_loader_gpu.h | 4 +- .../reader/loader/indexed_file_loader.h | 17 +- dali/operators/reader/loader/loader_test.cc | 11 +- dali/operators/reader/loader/numpy_loader.cc | 26 ++- dali/operators/reader/loader/numpy_loader.h | 4 +- .../reader/loader/numpy_loader_gpu.cc | 8 +- .../reader/loader/numpy_loader_gpu.h | 3 +- .../operators/reader/loader/recordio_loader.h | 10 +- .../reader/loader/sequence_loader.cc | 4 +- .../reader/loader/webdataset/tar_utils.cc | 7 +- .../loader/webdataset/tar_utils_test.cc | 23 +-- .../reader/loader/webdataset_loader.cc | 7 +- dali/operators/reader/numpy_reader_gpu_op.cc | 4 +- dali/util/CMakeLists.txt | 11 +- dali/util/cufile.cc | 7 +- dali/util/cufile.h | 4 +- dali/util/file.cc | 21 ++- dali/util/file.h | 32 +++- dali/util/fits_test.cc | 8 +- dali/util/mmaped_file.h | 5 +- dali/util/odirect_file.cc | 8 +- dali/util/odirect_file.h | 3 +- dali/util/std_cufile.cc | 7 - dali/util/std_cufile.h | 3 +- dali/util/std_file.cc | 8 +- dali/util/std_file.h | 3 +- dali/util/uri.cc | 165 ++++++++++++++++++ dali/util/uri.h | 91 ++++++++++ dali/util/uri_test.cc | 114 ++++++++++++ 46 files changed, 997 insertions(+), 481 deletions(-) create mode 100644 dali/operators/reader/loader/discover_files.cc create mode 100644 dali/operators/reader/loader/discover_files.h create mode 100644 dali/operators/reader/loader/discover_files_test.cc create mode 100644 dali/util/uri.cc create mode 100644 dali/util/uri.h create mode 100644 dali/util/uri_test.cc diff --git a/dali/kernels/slice/slice_flip_normalize_gpu_test.cu b/dali/kernels/slice/slice_flip_normalize_gpu_test.cu index 4d4b7aac59..01fb3a1fd6 100644 --- a/dali/kernels/slice/slice_flip_normalize_gpu_test.cu +++ b/dali/kernels/slice/slice_flip_normalize_gpu_test.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -79,7 +79,7 @@ class SliceFlipNormalizeGPUTest : public ::testing::Test { } void LoadTensor(Tensor &tensor, const std::string& path_npy) { - auto stream = FileStream::Open(path_npy, false, false); + auto stream = FileStream::Open(path_npy); tensor = ::dali::numpy::ReadTensor(stream.get(), true); } diff --git a/dali/operators/imgcodec/decoder_test_helper.h b/dali/operators/imgcodec/decoder_test_helper.h index f5353f6bec..d9775fd6a1 100644 --- a/dali/operators/imgcodec/decoder_test_helper.h +++ b/dali/operators/imgcodec/decoder_test_helper.h @@ -248,7 +248,7 @@ inline Tensor ReadReference(InputStream *src, TensorLayout layout = */ inline Tensor ReadReferenceFrom(const std::string &reference_path, TensorLayout layout = "HWC") { - auto src = FileStream::Open(reference_path, false, false); + auto src = FileStream::Open(reference_path); return ReadReference(src.get(), layout); } diff --git a/dali/operators/reader/file_reader_op.cc b/dali/operators/reader/file_reader_op.cc index 13bbc88f6c..3d52e1ee07 100644 --- a/dali/operators/reader/file_reader_op.cc +++ b/dali/operators/reader/file_reader_op.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -114,6 +114,10 @@ list of files in the sub-directories of the ``file_root``. This argument is ignored when file paths are taken from ``file_list`` or ``files``.)", kKnownExtensionsGlob) + .AddOptionalArg>("dir_filters", R"(A list of glob strings to filter the +list of sub-directories under ``file_root``. + +This argument is ignored when file paths are taken from ``file_list`` or ``files``.)", nullptr) .AddOptionalArg("case_sensitive_filter", R"(If set to True, the filter will be matched case-sensitively, otherwise case-insensitively.)", false) .AddParent("LoaderBase"); diff --git a/dali/operators/reader/loader/CMakeLists.txt b/dali/operators/reader/loader/CMakeLists.txt index 6b8b992690..69d9a47888 100644 --- a/dali/operators/reader/loader/CMakeLists.txt +++ b/dali/operators/reader/loader/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ collect_headers(DALI_INST_HDRS PARENT_SCOPE) set(DALI_OPERATOR_SRCS ${DALI_OPERATOR_SRCS} "${CMAKE_CURRENT_SOURCE_DIR}/filesystem.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/discover_files.cc" "${CMAKE_CURRENT_SOURCE_DIR}/file_label_loader.cc" "${CMAKE_CURRENT_SOURCE_DIR}/coco_loader.cc" "${CMAKE_CURRENT_SOURCE_DIR}/loader.cc" @@ -57,7 +58,8 @@ endif() set(DALI_OPERATOR_TEST_SRCS ${DALI_OPERATOR_TEST_SRCS} "${CMAKE_CURRENT_SOURCE_DIR}/loader_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/sequence_loader_test.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/filesystem_test.cc") + "${CMAKE_CURRENT_SOURCE_DIR}/filesystem_test.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/discover_files_test.cc") if (BUILD_LIBSND) set(DALI_OPERATOR_TEST_SRCS ${DALI_OPERATOR_TEST_SRCS} diff --git a/dali/operators/reader/loader/coco_loader.cc b/dali/operators/reader/loader/coco_loader.cc index b13f9baaa1..d284bb1eff 100644 --- a/dali/operators/reader/loader/coco_loader.cc +++ b/dali/operators/reader/loader/coco_loader.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -141,13 +141,13 @@ void SaveToFile(const std::vector > &input, const std::string pat } template <> -void SaveToFile(const ImageIdPairs &image_id_pairs, const std::string path) { - if (image_id_pairs.empty()) +void SaveToFile(const std::vector &entries, const std::string path) { + if (entries.empty()) return; std::ofstream file(path); DALI_ENFORCE(file, "CocoReader meta file error while saving: " + path); - for (const auto &p : image_id_pairs) { - file << p.first << std::endl; + for (const auto &p : entries) { + file << p.filename << std::endl; } DALI_ENFORCE(file.good(), make_string("Error writing to path: ", path)); } @@ -203,16 +203,16 @@ void LoadFromFile(std::vector > &output, const std::string path) } template <> -void LoadFromFile(ImageIdPairs &image_id_pairs, const std::string path) { +void LoadFromFile(std::vector &entries, const std::string path) { std::ifstream file(path); - image_id_pairs.clear(); + entries.clear(); if (!file.good()) return; int id = 0; std::string filename; while (file >> filename) { - image_id_pairs.emplace_back(std::move(filename), int{id}); + entries.push_back({std::move(filename), id}); ++id; } } @@ -417,14 +417,14 @@ void ParseJsonFile(const OpSpec &spec, std::vector &image_inf } // namespace detail -void CocoLoader::SavePreprocessedAnnotations(const std::string &path, - const ImageIdPairs &image_id_pairs) { +void CocoLoader::SavePreprocessedAnnotations( + const std::string &path, const std::vector &entries) { using detail::SaveToFile; SaveToFile(offsets_, path + "/offsets.dat"); SaveToFile(boxes_, path + "/boxes.dat"); SaveToFile(labels_, path + "/labels.dat"); SaveToFile(counts_, path + "/counts.dat"); - SaveToFile(image_id_pairs, path + "/filenames.dat"); + SaveToFile(entries, path + "/filenames.dat"); if (output_polygon_masks_ || output_pixelwise_masks_) { SaveToFile(polygon_data_, path + "/polygon_data.dat"); @@ -459,7 +459,7 @@ void CocoLoader::ParsePreprocessedAnnotations() { LoadFromFile(boxes_, path + "/boxes.dat"); LoadFromFile(labels_, path + "/labels.dat"); LoadFromFile(counts_, path + "/counts.dat"); - LoadFromFile(image_label_pairs_, path + "/filenames.dat"); + LoadFromFile(file_label_entries_, path + "/filenames.dat"); if (output_polygon_masks_ || output_pixelwise_masks_) { LoadFromFile(polygon_data_, path + "/polygon_data.dat"); @@ -628,7 +628,7 @@ void CocoLoader::ParseJsonAnnotations() { } } - image_label_pairs_.emplace_back(std::move(image_info.filename_), new_image_id); + file_label_entries_.push_back({std::move(image_info.filename_), new_image_id}); new_image_id++; } } @@ -639,7 +639,7 @@ void CocoLoader::ParseJsonAnnotations() { if (spec_.GetArgument("save_preprocessed_annotations")) { SavePreprocessedAnnotations( spec_.GetArgument("save_preprocessed_annotations_dir"), - image_label_pairs_); + file_label_entries_); } } diff --git a/dali/operators/reader/loader/coco_loader.h b/dali/operators/reader/loader/coco_loader.h index e3d0b70b17..af312eb872 100644 --- a/dali/operators/reader/loader/coco_loader.h +++ b/dali/operators/reader/loader/coco_loader.h @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -34,8 +34,6 @@ extern "C" { namespace dali { -using ImageIdPairs = std::vector>; - inline bool OutPolygonMasksEnabled(const OpSpec &spec) { return spec.GetArgument("polygon_masks") || (spec.HasArgument("masks") && spec.GetArgument("masks")); @@ -189,12 +187,12 @@ class DLL_PUBLIC CocoLoader : public FileLabelLoaderBase { // seeded with hardcoded value to get // the same sequence on every shard std::mt19937 g(kDaliDataloaderSeed); - std::shuffle(image_label_pairs_.begin(), image_label_pairs_.end(), g); + std::shuffle(file_label_entries_.begin(), file_label_entries_.end(), g); } if (IsCheckpointingEnabled() && shuffle_after_epoch_) { // save initial order - backup_image_label_pairs_ = image_label_pairs_; + backup_file_label_entries_ = file_label_entries_; } Reset(true); } @@ -203,7 +201,8 @@ class DLL_PUBLIC CocoLoader : public FileLabelLoaderBase { void ParseJsonAnnotations(); - void SavePreprocessedAnnotations(const std::string &path, const ImageIdPairs &image_id_pairs); + void SavePreprocessedAnnotations( + const std::string &path, const std::vector &image_id_pairs); private: const OpSpec spec_; diff --git a/dali/operators/reader/loader/cufile_loader.h b/dali/operators/reader/loader/cufile_loader.h index 848e9c10f9..e489841f8f 100755 --- a/dali/operators/reader/loader/cufile_loader.h +++ b/dali/operators/reader/loader/cufile_loader.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -38,8 +38,7 @@ namespace dali { template class CUFileLoader : public FileLoader { public: - explicit CUFileLoader(const OpSpec& spec, vector images = {}, - bool shuffle_after_epoch = false) + CUFileLoader(const OpSpec& spec, bool shuffle_after_epoch) : FileLoader(spec, shuffle_after_epoch) { } diff --git a/dali/operators/reader/loader/discover_files.cc b/dali/operators/reader/loader/discover_files.cc new file mode 100644 index 0000000000..260daba0f5 --- /dev/null +++ b/dali/operators/reader/loader/discover_files.cc @@ -0,0 +1,142 @@ +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dali/operators/reader/loader/discover_files.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "dali/core/call_at_exit.h" +#include "dali/core/error_handling.h" +#include "dali/operators/reader/loader/filesystem.h" +#include "dali/operators/reader/loader/utils.h" + +namespace dali { + +std::vector list_subdirectories(const std::string &parent_dir, + const std::vector dir_filters = {}, + bool case_sensitive_filter = true) { + // open the root + DIR *dir = opendir(parent_dir.c_str()); + DALI_ENFORCE(dir != nullptr, make_string("Failed to open ", parent_dir)); + auto cleanup = AtScopeExit([&dir] { + closedir(dir); + }); + + struct dirent *entry; + std::vector subdirs; + + while ((entry = readdir(dir))) { + struct stat s; + std::string entry_name(entry->d_name); + std::string full_path = filesystem::join_path(parent_dir, entry_name); + int ret = stat(full_path.c_str(), &s); + DALI_ENFORCE(ret == 0, "Could not access " + full_path + " during directory traversal."); + if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) + continue; + if (S_ISDIR(s.st_mode)) { + if (dir_filters.empty()) { + subdirs.push_back(entry_name); + } else { + for (auto &filter : dir_filters) { + if (fnmatch(filter.c_str(), entry_name.c_str(), + case_sensitive_filter ? 0 : FNM_CASEFOLD) == 0) { + subdirs.push_back(entry_name); + } + } + } + } + } + // sort directories to preserve class alphabetic order, as readdir could + // return unordered dir list. Otherwise file reader for training and validation + // could return directories with the same names in completely different order + std::sort(subdirs.begin(), subdirs.end()); + return subdirs; +} + +std::vector list_files(const std::string &parent_dir, + const std::vector filters = {}, + bool case_sensitive_filter = true) { + DIR *dir = opendir(parent_dir.c_str()); + DALI_ENFORCE(dir != nullptr, make_string("Failed to open ", parent_dir)); + auto cleanup = AtScopeExit([&dir] { + closedir(dir); + }); + + dirent *entry; + std::vector files; + while ((entry = readdir(dir))) { +#ifdef _DIRENT_HAVE_D_TYPE + /* + * we support only regular files and symlinks, if FS returns DT_UNKNOWN + * it doesn't mean anything and let us validate filename itself + */ + if (entry->d_type != DT_REG && entry->d_type != DT_LNK && entry->d_type != DT_UNKNOWN) { + continue; + } +#endif + std::string fname(entry->d_name); + for (auto &filter : filters) { + if (fnmatch(filter.c_str(), fname.c_str(), case_sensitive_filter ? 0 : FNM_CASEFOLD) == 0) { + files.push_back(fname); + break; + } + } + } + std::sort(files.begin(), files.end()); + return files; +} + +std::vector discover_files(const std::string &file_root, + const FileDiscoveryOptions &opts) { + bool is_s3 = starts_with(file_root, "s3://"); + if (is_s3) { + DALI_FAIL("This version of DALI was not built with AWS S3 storage support."); + } + + std::vector subdirs; + subdirs = list_subdirectories(file_root, opts.dir_filters, opts.case_sensitive_filter); + std::vector entries; + auto process_dir = [&](const std::string &rel_dirpath, std::optional label = {}) { + auto full_dirpath = filesystem::join_path(file_root, rel_dirpath); + auto tmp_files = list_files(full_dirpath, opts.file_filters, opts.case_sensitive_filter); + for (const auto &f : tmp_files) { + entries.push_back({filesystem::join_path(rel_dirpath, f), label}); + } + }; + + // if we are in "label_from_subdir" mode, we need a subdir to infer the label, therefore we don't + // visit the current directory + if (!opts.label_from_subdir) { + process_dir("."); + } + for (unsigned dir_idx = 0; dir_idx < subdirs.size(); ++dir_idx) { + process_dir(subdirs[dir_idx], + opts.label_from_subdir ? std::optional{dir_idx} : std::nullopt); + } + size_t total_dir_count = opts.label_from_subdir ? subdirs.size() : subdirs.size() + 1; + LOG_LINE << "read " << entries.size() << " files from " << total_dir_count << "directories\n"; + return entries; +} + +} // namespace dali diff --git a/dali/operators/reader/loader/discover_files.h b/dali/operators/reader/loader/discover_files.h new file mode 100644 index 0000000000..a66371171b --- /dev/null +++ b/dali/operators/reader/loader/discover_files.h @@ -0,0 +1,57 @@ +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_OPERATORS_READER_LOADER_DISCOVER_FILES_H_ +#define DALI_OPERATORS_READER_LOADER_DISCOVER_FILES_H_ + +#include +#include +#include +#include +#include "dali/core/common.h" + +namespace dali { + +inline bool starts_with(const std::string &str, const char *prefix) { + // TODO(janton): this is a substitute for C++20's string::starts_with + // trick: this only matches if the prefix is found at the beginning of the string + return str.rfind(prefix, 0) == 0; +} + +struct FileLabelEntry { + std::string filename; + // only if label_from_subdir==true + std::optional label = {}; + // only populated when size is known without opening (e.g. s3) + std::optional size = {}; +}; + +struct FileDiscoveryOptions { + bool label_from_subdir = true; // if true, the directory is expected to contain a subdirectory + // for each category. The traversal will assign ascending integers + // as labels for each of those + bool case_sensitive_filter = false; // whether the filter patterns are case-sensitive + std::vector file_filters; // pattern to apply to filenames + std::vector dir_filters; // pattern to apply to subdirectories +}; + +/** + * @brief Finds all (file, label, size) information, following the criteria given by opts. + */ +DLL_PUBLIC vector discover_files(const std::string &file_root, + const FileDiscoveryOptions &opts); + +} // namespace dali + +#endif // DALI_OPERATORS_READER_LOADER_DISCOVER_FILES_H_ diff --git a/dali/operators/reader/loader/discover_files_test.cc b/dali/operators/reader/loader/discover_files_test.cc new file mode 100644 index 0000000000..3327920aa5 --- /dev/null +++ b/dali/operators/reader/loader/discover_files_test.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "dali/core/error_handling.h" +#include "dali/operators/reader/loader/filesystem.h" +#include "dali/operators/reader/loader/discover_files.h" +#include "dali/operators/reader/loader/utils.h" +#include "dali/test/dali_test_config.h" + +namespace dali { + +class DiscoverFilesTest : public ::testing::Test { + std::vector> readFileLabelFile() { + std::vector> image_label_pairs; + std::string file_list = file_root + "/image_list.txt"; + std::ifstream s(file_list); + DALI_ENFORCE(s.is_open(), "Cannot open: " + file_list); + + std::vector line_buf(16 << 10); + char *line = line_buf.data(); + for (int n = 1; s.getline(line, line_buf.size()); n++) { + int i = strlen(line) - 1; + + for (; i >= 0 && isspace(line[i]); i--) {} + + int label_end = i + 1; + + if (i < 0) + continue; + + for (; i >= 0 && isdigit(line[i]); i--) {} + + int label_start = i + 1; + + for (; i >= 0 && isspace(line[i]); i--) {} + + int name_end = i + 1; + DALI_ENFORCE( + name_end > 0 && name_end < label_start && label_start >= 2 && label_end > label_start, + make_string("Incorrect format of the list file \"", file_list, "\":", n, + " expected file name followed by a label; got: ", line)); + + line[label_end] = 0; + line[name_end] = 0; + + image_label_pairs.emplace_back(line, std::atoi(line + label_start)); + } + std::sort(image_label_pairs.begin(), image_label_pairs.end()); + DALI_ENFORCE(s.eof(), "Wrong format of file_list: " + file_list); + + return image_label_pairs; + } + + protected: + DiscoverFilesTest() + : file_root(testing::dali_extra_path() + "/db/single/jpeg"), + file_label_pairs(readFileLabelFile()) {} + + std::vector globMatch(std::vector &filters, std::string path) { + std::vector correct_match; + glob_t pglob; + for (auto &filter : filters) { + std::string pattern = path + filesystem::dir_sep + '*' + filesystem::dir_sep + filter; + if (glob(pattern.c_str(), GLOB_TILDE, NULL, &pglob) == 0) { + for (unsigned int count = 0; count < pglob.gl_pathc; ++count) { + std::string match(pglob.gl_pathv[count]); + correct_match.push_back(match.substr(path.length() + 1, std::string::npos)); + } + globfree(&pglob); + } + } + std::sort(correct_match.begin(), correct_match.end()); + std::unique(correct_match.begin(), correct_match.end()); + return correct_match; + } + + std::string file_root; + std::vector> file_label_pairs; +}; + +TEST_F(DiscoverFilesTest, MatchAllFilter) { + auto file_label_pairs_filtered = + discover_files(file_root, {true, false, kKnownExtensionsGlob, {}}); + ASSERT_EQ(this->file_label_pairs.size(), file_label_pairs_filtered.size()); + for (size_t i = 0; i < file_label_pairs_filtered.size(); ++i) { + ASSERT_EQ(this->file_label_pairs[i].first, file_label_pairs_filtered[i].filename); + } +} + +TEST_F(DiscoverFilesTest, SingleFilter) { + std::vector filters{"dog*.jpg"}; + auto file_label_pairs_filtered = + discover_files(file_root, {true, false, filters}); + std::vector correct_match = globMatch(filters, file_root); + + + for (size_t i = 0; i < file_label_pairs_filtered.size(); ++i) { + ASSERT_EQ(correct_match[i], file_label_pairs_filtered[i].filename); + } +} + +TEST_F(DiscoverFilesTest, MultipleOverlappingFilters) { + std::vector filters{"dog*.jpg", "snail*.jpg", "*_1280.jpg"}; + auto file_label_pairs_filtered = + discover_files(file_root, {true, false, filters}); + std::vector correct_match = globMatch(filters, file_root); + + for (size_t i = 0; i < file_label_pairs_filtered.size(); ++i) { + EXPECT_EQ(correct_match[i], file_label_pairs_filtered[i].filename); + } +} + +TEST_F(DiscoverFilesTest, CaseSensitiveFilters) { + std::vector filters{"*.jPg"}; + std::string root = (testing::dali_extra_path() + "/db/single/case_sensitive"); + auto file_label_pairs_filtered = discover_files(root, {true, true, filters}); + std::vector correct_match = globMatch(filters, root); + + for (size_t i = 0; i < file_label_pairs_filtered.size(); ++i) { + EXPECT_EQ(correct_match[i], file_label_pairs_filtered[i].filename); + } +} + +TEST_F(DiscoverFilesTest, CaseInsensitiveFilters) { + std::vector filters{"*.jPg"}; + std::vector glob_filters{"*.jpg", "*.jpG", "*.jPg", "*.jPG", + "*.Jpg", "*.JpG", "*.JPg", "*.JPG"}; + std::string root = (testing::dali_extra_path() + "/db/single/case_sensitive"); + auto file_label_pairs_filtered = discover_files(root, {true, false, filters}); + std::vector correct_match = globMatch(glob_filters, root); + + for (size_t i = 0; i < file_label_pairs_filtered.size(); ++i) { + EXPECT_EQ(correct_match[i], file_label_pairs_filtered[i].filename); + } +} + +} // namespace dali diff --git a/dali/operators/reader/loader/file_label_loader.cc b/dali/operators/reader/loader/file_label_loader.cc index 46c67c4ada..08dcc6036c 100644 --- a/dali/operators/reader/loader/file_label_loader.cc +++ b/dali/operators/reader/loader/file_label_loader.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "dali/operators/reader/loader/file_label_loader.h" #include - #include "dali/core/common.h" -#include "dali/operators/reader/loader/file_label_loader.h" -#include "dali/util/file.h" +#include "dali/operators/reader/loader/filesystem.h" #include "dali/operators/reader/loader/utils.h" +#include "dali/util/file.h" namespace dali { @@ -30,19 +30,19 @@ void FileLabelLoaderBase::PrepareEmpty(ImageLabelWrappe template void FileLabelLoaderBase::ReadSample(ImageLabelWrapper &image_label) { - auto image_pair = image_label_pairs_[current_index_++]; + auto entry = file_label_entries_[current_index_++]; // handle wrap-around MoveToNextShard(current_index_); // copy the label - image_label.label = image_pair.second; + image_label.label = entry.label.value(); DALIMeta meta; - meta.SetSourceInfo(image_pair.first); + meta.SetSourceInfo(entry.filename); meta.SetSkipSample(false); // if image is cached, skip loading - if (ShouldSkipImage(image_pair.first)) { + if (ShouldSkipImage(entry.filename)) { meta.SetSkipSample(true); image_label.image.Reset(); image_label.image.SetMeta(meta); @@ -50,8 +50,8 @@ void FileLabelLoaderBase::ReadSample(ImageLabelWrapper return; } - auto current_image = FileStream::Open(filesystem::join_path(file_root_, image_pair.first), - read_ahead_, !copy_read_data_); + auto uri = filesystem::join_path(file_root_, entry.filename); + auto current_image = FileStream::Open(uri, {read_ahead_, !copy_read_data_, false}, entry.size); Index image_size = current_image->Size(); if (copy_read_data_) { @@ -61,10 +61,10 @@ void FileLabelLoaderBase::ReadSample(ImageLabelWrapper image_label.image.Resize({image_size}, DALI_UINT8); // copy the image Index ret = current_image->Read(image_label.image.mutable_data(), image_size); - DALI_ENFORCE(ret == image_size, make_string("Failed to read file: ", image_pair.first)); + DALI_ENFORCE(ret == image_size, make_string("Failed to read file: ", entry.filename)); } else { auto p = current_image->Get(image_size); - DALI_ENFORCE(p != nullptr, make_string("Failed to read file: ", image_pair.first)); + DALI_ENFORCE(p != nullptr, make_string("Failed to read file: ", entry.filename)); // Wrap the raw data in the Tensor object. image_label.image.ShareData(p, image_size, false, {image_size}, DALI_UINT8, CPU_ONLY_DEVICE_ID); } @@ -77,7 +77,7 @@ void FileLabelLoaderBase::ReadSample(ImageLabelWrapper template Index FileLabelLoaderBase::SizeImpl() { - return static_cast(image_label_pairs_.size()); + return static_cast(file_label_entries_.size()); } template class FileLabelLoaderBase; diff --git a/dali/operators/reader/loader/file_label_loader.h b/dali/operators/reader/loader/file_label_loader.h index b6b051b87d..52232b3ad5 100755 --- a/dali/operators/reader/loader/file_label_loader.h +++ b/dali/operators/reader/loader/file_label_loader.h @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -27,8 +27,9 @@ #include #include "dali/core/common.h" -#include "dali/operators/reader/loader/loader.h" +#include "dali/operators/reader/loader/discover_files.h" #include "dali/operators/reader/loader/filesystem.h" +#include "dali/operators/reader/loader/loader.h" #include "dali/util/file.h" namespace dali { @@ -45,7 +46,7 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader; explicit inline FileLabelLoaderBase( const OpSpec& spec, - bool shuffle_after_epoch = false) + bool shuffle_after_epoch) : Base(spec), shuffle_after_epoch_(shuffle_after_epoch), current_index_(0), @@ -58,11 +59,14 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader 0, - "``file_filters`` list cannot be empty."); + DALI_ENFORCE(!has_file_filters_arg || traverse_opts_.file_filters.size() > 0, + "``file_filters`` list cannot be empty."); + DALI_ENFORCE(!has_dir_filters_arg || traverse_opts_.dir_filters.size() > 0, + "``dir_filters`` list cannot be empty."); if (has_file_list_arg_) { DALI_ENFORCE(!file_list_.empty(), "``file_list`` argument cannot be empty"); @@ -94,10 +100,10 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader> image_label_pairs_; - vector> backup_image_label_pairs_; - vector filters_; + vector file_label_entries_; + vector backup_file_label_entries_; + FileDiscoveryOptions traverse_opts_; bool has_files_arg_ = false; bool has_labels_arg_ = false; bool has_file_list_arg_ = false; bool has_file_root_arg_ = false; - bool case_sensitive_filter_ = false; bool shuffle_after_epoch_; Index current_index_; diff --git a/dali/operators/reader/loader/file_loader.h b/dali/operators/reader/loader/file_loader.h index b7da55ea2a..bd0db9f8ab 100755 --- a/dali/operators/reader/loader/file_loader.h +++ b/dali/operators/reader/loader/file_loader.h @@ -28,6 +28,7 @@ #include "dali/core/common.h" #include "dali/operators/reader/loader/filesystem.h" +#include "dali/operators/reader/loader/discover_files.h" #include "dali/operators/reader/loader/loader.h" #include "dali/operators/reader/loader/utils.h" #include "dali/util/file.h" @@ -51,6 +52,10 @@ class FileLoader : public Loader { current_epoch_(0) { vector files; + traverse_opts_.label_from_subdir = false; + traverse_opts_.case_sensitive_filter = true; + traverse_opts_.file_filters.push_back(file_filter_); + has_files_arg_ = spec.TryGetRepeatedArgument(files, "files"); has_file_list_arg_ = spec.TryGetArgument(file_list_, "file_list"); has_file_root_arg_ = spec.TryGetArgument(file_root_, "file_root"); @@ -73,8 +78,11 @@ class FileLoader : public Loader { if (has_files_arg_) { DALI_ENFORCE(files.size() > 0, "``files`` specified an empty list."); - files_ = std::move(files); + for (auto& f : files) { + file_entries_.push_back({std::move(f)}); + } } + files.clear(); // we moved the elements /* * Those options are mutually exclusive as `shuffle_after_epoch` will make every shard looks @@ -102,13 +110,13 @@ class FileLoader : public Loader { protected: Index SizeImpl() override { - return static_cast(files_.size()); + return static_cast(file_entries_.size()); } void PrepareMetadataImpl() override { - if (files_.empty()) { + if (file_entries_.empty()) { if (!has_files_arg_ && !has_file_list_arg_) { - files_ = filesystem::traverse_directories(file_root_, file_filter_); + file_entries_ = discover_files(file_root_, traverse_opts_); } else if (has_file_list_arg_) { // load paths from list std::ifstream s(file_list_); @@ -118,7 +126,7 @@ class FileLoader : public Loader { char *line = line_buf.data(); while (s.getline(line, line_buf.size())) { if (line[0]) // skip empty lines - files_.emplace_back(line); + file_entries_.push_back({std::string(line)}); } DALI_ENFORCE(s.eof(), "Wrong format of file_list: " + file_list_); } @@ -126,13 +134,13 @@ class FileLoader : public Loader { DALI_ENFORCE(SizeImpl() > 0, "No files found."); if (IsCheckpointingEnabled()) { - backup_files_ = files_; + backup_file_entries_ = file_entries_; } if (shuffle_) { // seeded with hardcoded value to get // the same sequence on every shard std::mt19937 g(kDaliDataloaderSeed); - std::shuffle(files_.begin(), files_.end(), g); + std::shuffle(file_entries_.begin(), file_entries_.end(), g); } Reset(true); } @@ -151,10 +159,10 @@ class FileLoader : public Loader { // With checkpointing enabled dataset order must be easy to restore. // Shuffling is run with different seed every epoch, so this doesn't // reduce the randomness. - files_ = backup_files_; + file_entries_ = backup_file_entries_; } std::mt19937 g(kDaliDataloaderSeed + current_epoch_); - std::shuffle(files_.begin(), files_.end(), g); + std::shuffle(file_entries_.begin(), file_entries_.end(), g); } } @@ -178,8 +186,9 @@ class FileLoader : public Loader { using Loader::IsCheckpointingEnabled; string file_list_, file_root_, file_filter_; - vector files_; - vector backup_files_; + FileDiscoveryOptions traverse_opts_; + vector file_entries_; + vector backup_file_entries_; bool has_files_arg_ = false; bool has_file_list_arg_ = false; diff --git a/dali/operators/reader/loader/filesystem.cc b/dali/operators/reader/loader/filesystem.cc index 3b6e6f93d4..d4460a472e 100644 --- a/dali/operators/reader/loader/filesystem.cc +++ b/dali/operators/reader/loader/filesystem.cc @@ -13,168 +13,37 @@ // limitations under the License. #include "dali/operators/reader/loader/filesystem.h" -#include -#include -#include -#include -#include -#include #include #include -#include -#include -#include "dali/core/call_at_exit.h" -#include "dali/core/error_handling.h" -#include "dali/operators/reader/loader/utils.h" +#include +#include "dali/util/uri.h" namespace dali { namespace filesystem { -inline bool starts_with(const std::string &str, const char *prefix) { - // TODO(janton): this is a substitute for C++20's string::starts_with - // trick: this only matches if the prefix is found at the beginning of the string - return str.rfind(prefix, 0) == 0; -} - std::string join_path(const std::string &dir, const std::string &path) { if (dir.empty()) return path; if (path.empty()) return dir; - if (path[0] == dir_sep) // absolute path - return path; -#ifdef WINVER - if (path[1] == ':') - return path; -#endif - if (dir[dir.length() - 1] == dir_sep) - return dir + path; - else - return dir + dir_sep + path; -} - -std::vector list_subdirectories(const std::string &parent_dir, - const std::vector dir_filters = {}, - bool case_sensitive_filter = true) { - // open the root - DIR *dir = opendir(parent_dir.c_str()); - DALI_ENFORCE(dir != nullptr, make_string("Failed to open ", parent_dir)); - auto cleanup = AtScopeExit([&dir] { - closedir(dir); - }); - - struct dirent *entry; - std::vector subdirs; - - while ((entry = readdir(dir))) { - struct stat s; - std::string entry_name(entry->d_name); - std::string full_path = join_path(parent_dir, entry_name); - int ret = stat(full_path.c_str(), &s); - DALI_ENFORCE(ret == 0, "Could not access " + full_path + " during directory traversal."); - if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) - continue; - if (S_ISDIR(s.st_mode)) { - if (dir_filters.empty()) { - subdirs.push_back(entry_name); - } else { - for (auto &filter : dir_filters) { - if (fnmatch(filter.c_str(), entry_name.c_str(), - case_sensitive_filter ? 0 : FNM_CASEFOLD) == 0) { - subdirs.push_back(entry_name); - } - } - } - } - } - // sort directories to preserve class alphabetic order, as readdir could - // return unordered dir list. Otherwise file reader for training and validation - // could return directories with the same names in completely different order - std::sort(subdirs.begin(), subdirs.end()); - return subdirs; -} -std::vector list_files(const std::string &parent_dir, - const std::vector filters = {}, - bool case_sensitive_filter = true) { - DIR *dir = opendir(parent_dir.c_str()); - DALI_ENFORCE(dir != nullptr, make_string("Failed to open ", parent_dir)); - auto cleanup = AtScopeExit([&dir] { - closedir(dir); - }); - - dirent *entry; - std::vector files; - while ((entry = readdir(dir))) { -#ifdef _DIRENT_HAVE_D_TYPE - /* - * we support only regular files and symlinks, if FS returns DT_UNKNOWN - * it doesn't mean anything and let us validate filename itself - */ - if (entry->d_type != DT_REG && entry->d_type != DT_LNK && entry->d_type != DT_UNKNOWN) { - continue; - } + auto uri = URI::Parse(dir); + if (uri.valid()) { + const char *separators = "/"; + // TODO(janton): In case we ever support Windows +#ifdef _WINVER + if (uri.scheme() == "file:") + separators = "/\\"; #endif - std::string fname(entry->d_name); - for (auto &filter : filters) { - if (fnmatch(filter.c_str(), fname.c_str(), case_sensitive_filter ? 0 : FNM_CASEFOLD) == 0) { - files.push_back(fname); - break; - } - } - } - std::sort(files.begin(), files.end()); - return files; -} -vector> traverse_directories(const std::string &file_root, - const std::vector &filters, - bool case_sensitive_filter, - const std::vector &dir_filters) { - std::vector subdirs; - bool is_s3 = starts_with(file_root, "s3://"); - if (is_s3) - DALI_FAIL("This version of DALI was not built with AWS S3 storage support.") - subdirs = list_subdirectories(file_root, dir_filters, case_sensitive_filter); - - std::vector> file_label_pairs; - for (unsigned dir_count = 0; dir_count < subdirs.size(); ++dir_count) { - std::vector tmp_files; - const auto &rel_dirpath = subdirs[dir_count]; - auto full_dirpath = join_path(file_root, rel_dirpath); - tmp_files = list_files(full_dirpath, filters, case_sensitive_filter); - for (const auto &f : tmp_files) { - file_label_pairs.push_back({join_path(rel_dirpath, f), dir_count}); - } + if (strchr(separators, path[0])) // absolute path + return std::string(uri.scheme_authority()) + path; + else if (strchr(separators, dir[dir.length() - 1])) // dir ends with a separator + return dir + path; + else // basic case + return dir + '/' + path; } - LOG_LINE << "read " << file_label_pairs.size() << " files from " << subdirs.size() - << "directories\n"; - return file_label_pairs; -} - - -vector traverse_directories(const std::string &file_root, const std::string &filter) { - std::vector subdirs; - bool is_s3 = starts_with(file_root, "s3://"); - if (is_s3) - DALI_FAIL("This version of DALI was not built with AWS S3 storage support."); - subdirs = list_subdirectories(file_root); - - std::vector files; - auto process_dir = [&](const std::string &rel_dirpath) { - auto full_dirpath = join_path(file_root, rel_dirpath); - auto tmp_files = list_files(full_dirpath, {filter}, true); - for (const auto &f : tmp_files) { - files.push_back(join_path(rel_dirpath, f)); - } - }; - - process_dir("."); // process current dir as well - for (const auto &subdir : subdirs) - process_dir(subdir); - - LOG_LINE << "read " << files.size() << " files from " << subdirs.size() << "directories\n"; - return files; + return std::filesystem::path(dir) / std::filesystem::path(path); } } // namespace filesystem diff --git a/dali/operators/reader/loader/filesystem.h b/dali/operators/reader/loader/filesystem.h index f1644f4ef2..12871d71b2 100644 --- a/dali/operators/reader/loader/filesystem.h +++ b/dali/operators/reader/loader/filesystem.h @@ -17,29 +17,15 @@ #include #include -#include #include "dali/core/common.h" namespace dali { namespace filesystem { -DLL_PUBLIC vector traverse_directories(const string &path, const string &filter); - -/** - * @brief Finds all (file, label) pairs matching any filter from the list. - */ -DLL_PUBLIC vector> traverse_directories( - const string &file_root, - const vector &filters, - bool case_sensitive_filter = false, - const vector& dir_filters = {}); - /** * @brief Prepends dir to a relative path and keeps absolute path unchanged. */ -DLL_PUBLIC string join_path(const string &dir, const string &path); - -DLL_PUBLIC string dir_path(const string &path); +DLL_PUBLIC string join_path(const std::string &dir, const std::string &path); #ifdef WINVER constexpr char dir_sep = '\\'; diff --git a/dali/operators/reader/loader/filesystem_test.cc b/dali/operators/reader/loader/filesystem_test.cc index ce20782acf..3d8fbe7242 100644 --- a/dali/operators/reader/loader/filesystem_test.cc +++ b/dali/operators/reader/loader/filesystem_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,139 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include -#include -#include - -#include "dali/core/error_handling.h" #include "dali/operators/reader/loader/filesystem.h" -#include "dali/operators/reader/loader/utils.h" -#include "dali/test/dali_test_config.h" +#include namespace dali { -class FilesystemTest : public ::testing::Test { - std::vector> readFileLabelFile() { - std::vector> image_label_pairs; - std::string file_list = file_root + "/image_list.txt"; - std::ifstream s(file_list); - DALI_ENFORCE(s.is_open(), "Cannot open: " + file_list); - - std::vector line_buf(16 << 10); - char *line = line_buf.data(); - for (int n = 1; s.getline(line, line_buf.size()); n++) { - int i = strlen(line) - 1; - - for (; i >= 0 && isspace(line[i]); i--) {} - - int label_end = i + 1; - - if (i < 0) - continue; - - for (; i >= 0 && isdigit(line[i]); i--) {} - - int label_start = i + 1; - - for (; i >= 0 && isspace(line[i]); i--) {} - - int name_end = i + 1; - DALI_ENFORCE( - name_end > 0 && name_end < label_start && label_start >= 2 && label_end > label_start, - make_string("Incorrect format of the list file \"", file_list, "\":", n, - " expected file name followed by a label; got: ", line)); - - line[label_end] = 0; - line[name_end] = 0; - - image_label_pairs.emplace_back(line, std::atoi(line + label_start)); - } - std::sort(image_label_pairs.begin(), image_label_pairs.end()); - DALI_ENFORCE(s.eof(), "Wrong format of file_list: " + file_list); - - return image_label_pairs; - } - - protected: - FilesystemTest() - : file_root(testing::dali_extra_path() + "/db/single/jpeg"), - file_label_pairs(readFileLabelFile()) {} - - std::vector globMatch(std::vector &filters, std::string path) { - std::vector correct_match; - glob_t pglob; - for (auto &filter : filters) { - std::string pattern = path + filesystem::dir_sep + '*' + filesystem::dir_sep + filter; - if (glob(pattern.c_str(), GLOB_TILDE, NULL, &pglob) == 0) { - for (unsigned int count = 0; count < pglob.gl_pathc; ++count) { - std::string match(pglob.gl_pathv[count]); - correct_match.push_back(match.substr(path.length() + 1, std::string::npos)); - } - globfree(&pglob); - } - } - std::sort(correct_match.begin(), correct_match.end()); - std::unique(correct_match.begin(), correct_match.end()); - return correct_match; - } - - std::string file_root; - std::vector> file_label_pairs; -}; - -TEST_F(FilesystemTest, MatchAllFilter) { - auto file_label_pairs_filtered = - filesystem::traverse_directories(file_root, kKnownExtensionsGlob); - ASSERT_EQ(this->file_label_pairs.size(), file_label_pairs_filtered.size()); - for (size_t i = 0; i < file_label_pairs_filtered.size(); ++i) { - ASSERT_EQ(this->file_label_pairs[i].first, file_label_pairs_filtered[i].first); - } -} - -TEST_F(FilesystemTest, SingleFilter) { - std::vector filters{"dog*.jpg"}; - auto file_label_pairs_filtered = filesystem::traverse_directories(file_root, filters); - std::vector correct_match = globMatch(filters, file_root); - - - for (size_t i = 0; i < file_label_pairs_filtered.size(); ++i) { - ASSERT_EQ(correct_match[i], file_label_pairs_filtered[i].first); - } +TEST(JoinPath, File) { + EXPECT_EQ("/path/dir/path2", + filesystem::join_path("/path/dir", "path2")); + EXPECT_EQ("/path/dir/path2", + filesystem::join_path("/path/dir/", "path2")); + EXPECT_EQ("/path2", + filesystem::join_path("/path/dir", "/path2")); } -TEST_F(FilesystemTest, MultipleOverlappingFilters) { - std::vector filters{"dog*.jpg", "snail*.jpg", "*_1280.jpg"}; - auto file_label_pairs_filtered = filesystem::traverse_directories(file_root, filters); - std::vector correct_match = globMatch(filters, file_root); - - for (size_t i = 0; i < file_label_pairs_filtered.size(); ++i) { - EXPECT_EQ(correct_match[i], file_label_pairs_filtered[i].first); - } +TEST(JoinPath, URI) { + EXPECT_EQ("s3://my_bucket/mypath/path2", + filesystem::join_path("s3://my_bucket/mypath", "path2")); + EXPECT_EQ("s3://my_bucket/mypath/path2", + filesystem::join_path("s3://my_bucket/mypath/", "path2")); + EXPECT_EQ("s3://my_bucket/path2", + filesystem::join_path("s3://my_bucket/mypath", "/path2")); + EXPECT_EQ("s3://my_bucket/path2", + filesystem::join_path("s3://my_bucket/mypath/", "/path2")); } -TEST_F(FilesystemTest, CaseSensitiveFilters) { - std::vector filters{"*.jPg"}; - std::string root = (testing::dali_extra_path() + "/db/single/case_sensitive"); - auto file_label_pairs_filtered = filesystem::traverse_directories(root, filters, true); - std::vector correct_match = globMatch(filters, root); - - for (size_t i = 0; i < file_label_pairs_filtered.size(); ++i) { - EXPECT_EQ(correct_match[i], file_label_pairs_filtered[i].first); - } -} - -TEST_F(FilesystemTest, CaseInsensitiveFilters) { - std::vector filters{"*.jPg"}; - std::vector glob_filters{"*.jpg", "*.jpG", "*.jPg", "*.jPG", - "*.Jpg", "*.JpG", "*.JPg", "*.JPG"}; - std::string root = (testing::dali_extra_path() + "/db/single/case_sensitive"); - auto file_label_pairs_filtered = filesystem::traverse_directories(root, filters); - std::vector correct_match = globMatch(glob_filters, root); - - for (size_t i = 0; i < file_label_pairs_filtered.size(); ++i) { - EXPECT_EQ(correct_match[i], file_label_pairs_filtered[i].first); - } -} } // namespace dali diff --git a/dali/operators/reader/loader/fits_loader.h b/dali/operators/reader/loader/fits_loader.h index f0ed30dccf..a2bd05910c 100644 --- a/dali/operators/reader/loader/fits_loader.h +++ b/dali/operators/reader/loader/fits_loader.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,13 +19,12 @@ #include #include #include - #include #include #include - #include "dali/core/common.h" #include "dali/operators/reader/loader/file_loader.h" +#include "dali/operators/reader/loader/filesystem.h" #include "dali/pipeline/data/types.h" #include "dali/util/file.h" #include "dali/util/fits.h" @@ -41,7 +40,7 @@ struct FitsFileWrapper { template class FitsLoader : public FileLoader { public: - explicit FitsLoader(const OpSpec& spec, bool shuffle_after_epoch = false) + explicit FitsLoader(const OpSpec& spec, bool shuffle_after_epoch) : FileLoader(spec, shuffle_after_epoch), hdu_indices_(spec.GetRepeatedArgument("hdu_indices")) { // default to DALI_UINT8, if argument dtypes not provided @@ -65,7 +64,7 @@ class FitsLoader : public FileLoader { } void ReadSample(Target& target) override { - auto filename = files_[current_index_++]; + auto filename = file_entries_[current_index_++].filename; int status = 0, num_hdus = 0; // handle wrap-around @@ -77,6 +76,8 @@ class FitsLoader : public FileLoader { meta.SetSkipSample(false); auto path = filesystem::join_path(file_root_, filename); + bool is_s3 = path.rfind("s3://", 0) == 0; + DALI_ENFORCE(!is_s3, "S3 storage not supported for FITS reader"); auto current_file = fits::FitsHandle::OpenFile(path.c_str(), READONLY); FITS_CALL(fits_get_num_hdus(current_file, &num_hdus, &status)); @@ -120,7 +121,7 @@ class FitsLoader : public FileLoader { private: using FileLoader::MoveToNextShard; - using FileLoader::files_; + using FileLoader::file_entries_; using FileLoader::current_index_; using FileLoader::file_root_; std::vector hdu_indices_; @@ -129,7 +130,7 @@ class FitsLoader : public FileLoader { class FitsLoaderCPU : public FitsLoader { public: - explicit FitsLoaderCPU(const OpSpec& spec, bool shuffle_after_epoch = false) + explicit FitsLoaderCPU(const OpSpec& spec, bool shuffle_after_epoch) : FitsLoader(spec, shuffle_after_epoch) {} protected: diff --git a/dali/operators/reader/loader/fits_loader_gpu.h b/dali/operators/reader/loader/fits_loader_gpu.h index 0869b950ee..e04415e101 100644 --- a/dali/operators/reader/loader/fits_loader_gpu.h +++ b/dali/operators/reader/loader/fits_loader_gpu.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -47,7 +47,7 @@ struct FitsFileWrapperGPU { class FitsLoaderGPU : public FitsLoader { public: - explicit FitsLoaderGPU(const OpSpec& spec, bool shuffle_after_epoch = false) + explicit FitsLoaderGPU(const OpSpec& spec, bool shuffle_after_epoch) : FitsLoader(spec, shuffle_after_epoch) {} protected: diff --git a/dali/operators/reader/loader/indexed_file_loader.h b/dali/operators/reader/loader/indexed_file_loader.h index dbaf25bd95..20197a769a 100755 --- a/dali/operators/reader/loader/indexed_file_loader.h +++ b/dali/operators/reader/loader/indexed_file_loader.h @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -62,10 +62,13 @@ class IndexedFileLoader : public Loader, true> { meta.SetSourceInfo(image_key); meta.SetSkipSample(false); + bool is_s3 = uris_[file_index].rfind("s3://", 0) == 0; + bool use_o_direct = !is_s3 && use_o_direct_; + bool use_mmap = !is_s3 && !copy_read_data_; + if (file_index != current_file_index_) { current_file_.reset(); - current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_, - use_o_direct_); + current_file_ = FileStream::Open(uris_[file_index], {read_ahead_, use_mmap, use_o_direct}); current_file_index_ = file_index; // invalidate the buffer if (use_o_direct_) read_buffer_.reset(); @@ -87,7 +90,7 @@ class IndexedFileLoader : public Loader, true> { } next_seek_pos_ = seek_pos + size; - if (!copy_read_data_) { + if (use_mmap && current_file_->CanMemoryMap()) { auto p = current_file_->Get(size); DALI_ENFORCE(p != nullptr, "Error reading from a file " + uris_[current_file_index_]); // Wrap the raw data in the Tensor object. @@ -96,7 +99,7 @@ class IndexedFileLoader : public Loader, true> { if (tensor.shares_data()) { tensor.Reset(); } - if (use_o_direct_) { + if (use_o_direct) { /* * ** - sample data * XX - buffer padding, data of other samples @@ -224,8 +227,8 @@ class IndexedFileLoader : public Loader, true> { std::tie(seek_pos, size, file_index) = indices_[current_index_]; if (file_index != current_file_index_) { current_file_.reset(); - current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_, - use_o_direct_); + current_file_ = + FileStream::Open(uris_[file_index], {read_ahead_, !copy_read_data_, use_o_direct_}); current_file_index_ = file_index; // invalidate the buffer if (use_o_direct_) read_buffer_.reset(); diff --git a/dali/operators/reader/loader/loader_test.cc b/dali/operators/reader/loader/loader_test.cc index ff2f787ab9..a4af828dab 100644 --- a/dali/operators/reader/loader/loader_test.cc +++ b/dali/operators/reader/loader/loader_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -65,6 +65,7 @@ TYPED_TEST(DataLoadStoreTest, LMDBTest) { #endif TYPED_TEST(DataLoadStoreTest, FileLabelLoaderMmmap) { + bool shuffle_after_epoch = false; for (bool dont_use_mmap : {true, false}) { shared_ptr reader( new FileLabelLoader( @@ -72,7 +73,7 @@ TYPED_TEST(DataLoadStoreTest, FileLabelLoaderMmmap) { .AddArg("file_root", loader_test_image_folder) .AddArg("max_batch_size", 32) .AddArg("device_id", 0) - .AddArg("dont_use_mmap", dont_use_mmap))); + .AddArg("dont_use_mmap", dont_use_mmap), shuffle_after_epoch)); reader->PrepareMetadata(); auto sample = reader->ReadOne(false, false); @@ -136,12 +137,13 @@ TYPED_TEST(DataLoadStoreTest, CocoLoaderMmmap) { } TYPED_TEST(DataLoadStoreTest, LoaderTest) { + bool shuffle_after_epoch = false; shared_ptr reader( new FileLabelLoader( OpSpec("FileReader") .AddArg("file_root", loader_test_image_folder) .AddArg("max_batch_size", 32) - .AddArg("device_id", 0))); + .AddArg("device_id", 0), shuffle_after_epoch)); reader->PrepareMetadata(); @@ -154,11 +156,12 @@ TYPED_TEST(DataLoadStoreTest, LoaderTest) { } TYPED_TEST(DataLoadStoreTest, LoaderTestFail) { + bool shuffle_after_epoch = false; shared_ptr reader( new FileLabelLoader(OpSpec("FileReader") .AddArg("file_root", loader_test_image_folder + "/does_not_exist") .AddArg("max_batch_size", 32) - .AddArg("device_id", 0))); + .AddArg("device_id", 0), shuffle_after_epoch)); ASSERT_THROW(reader->PrepareMetadata(), std::runtime_error); } diff --git a/dali/operators/reader/loader/numpy_loader.cc b/dali/operators/reader/loader/numpy_loader.cc index c49ceab5de..9b6f37f978 100644 --- a/dali/operators/reader/loader/numpy_loader.cc +++ b/dali/operators/reader/loader/numpy_loader.cc @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "dali/operators/reader/loader/numpy_loader.h" #include #include #include #include - #include "dali/core/common.h" -#include "dali/operators/reader/loader/numpy_loader.h" -#include "dali/util/file.h" +#include "dali/operators/reader/loader/filesystem.h" #include "dali/operators/reader/loader/utils.h" +#include "dali/util/file.h" +#include "dali/util/uri.h" namespace dali { namespace detail { @@ -49,7 +50,9 @@ void NumpyHeaderCache::UpdateCache(const string &file_name, const numpy::HeaderD } // namespace detail void NumpyLoader::ReadSample(NumpyFileWrapper& target) { - auto filename = files_[current_index_++]; + const auto& entry = file_entries_[current_index_++]; + auto filename = entry.filename; + auto size = entry.size; // handle wrap-around MoveToNextShard(current_index_); @@ -69,15 +72,24 @@ void NumpyLoader::ReadSample(NumpyFileWrapper& target) { return; } + FileStream::Options opts; + opts.read_ahead = read_ahead_; + opts.use_mmap = !copy_read_data_; + opts.use_odirect = use_o_direct_; auto path = filesystem::join_path(file_root_, filename); - auto current_file = FileStream::Open(path, read_ahead_, !copy_read_data_, use_o_direct_); + auto uri = URI::Parse(path); + if (uri.valid()) { + opts.use_mmap = false; + opts.use_odirect = false; + } + auto current_file = FileStream::Open(path, opts, size); // read the header numpy::HeaderData header; auto ret = header_cache_.GetFromCache(filename, header); try { if (!ret) { - if (use_o_direct_) { + if (opts.use_odirect) { numpy::ParseODirectHeader(header, current_file.get(), o_direct_alignm_, o_direct_read_len_alignm_); } else { @@ -98,7 +110,7 @@ void NumpyLoader::ReadSample(NumpyFileWrapper& target) { target.nbytes = nbytes; target.filename = std::move(path); - if (copy_read_data_) { + if (!opts.use_mmap || !current_file->CanMemoryMap()) { target.current_file = std::move(current_file); } else { auto p = current_file->Get(nbytes); diff --git a/dali/operators/reader/loader/numpy_loader.h b/dali/operators/reader/loader/numpy_loader.h index b627d0bcf7..b1518b27a0 100755 --- a/dali/operators/reader/loader/numpy_loader.h +++ b/dali/operators/reader/loader/numpy_loader.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -84,7 +84,7 @@ class NumpyLoader : public FileLoader { public: explicit inline NumpyLoader( const OpSpec& spec, - bool shuffle_after_epoch = false, + bool shuffle_after_epoch, bool use_o_direct = false, size_t o_direct_alignm = 512, size_t o_direct_read_len_alignm = 512) diff --git a/dali/operators/reader/loader/numpy_loader_gpu.cc b/dali/operators/reader/loader/numpy_loader_gpu.cc index 1df0f68be6..a18aedaf4d 100644 --- a/dali/operators/reader/loader/numpy_loader_gpu.cc +++ b/dali/operators/reader/loader/numpy_loader_gpu.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "dali/operators/reader/loader/numpy_loader_gpu.h" #include #include #include #include - #include "dali/core/common.h" -#include "dali/operators/reader/loader/numpy_loader_gpu.h" +#include "dali/operators/reader/loader/filesystem.h" namespace dali { @@ -27,7 +27,7 @@ void NumpyLoaderGPU::PrepareEmpty(NumpyFileWrapperGPU& target) { } void NumpyFileWrapperGPU::Reopen() { - file_stream_ = CUFileStream::Open(filename, read_ahead, false); + file_stream_ = CUFileStream::Open(filename, {read_ahead, false, false}); } void NumpyFileWrapperGPU::ReadHeader(detail::NumpyHeaderCache &cache) { @@ -61,7 +61,7 @@ void NumpyLoaderGPU::ReadSample(NumpyFileWrapperGPU& target) { DeviceGuard g(device_id_); // extract image file - auto filename = files_[current_index_++]; + auto filename = file_entries_[current_index_++].filename; // handle wrap-around MoveToNextShard(current_index_); diff --git a/dali/operators/reader/loader/numpy_loader_gpu.h b/dali/operators/reader/loader/numpy_loader_gpu.h index 951316b38b..266ab17c07 100755 --- a/dali/operators/reader/loader/numpy_loader_gpu.h +++ b/dali/operators/reader/loader/numpy_loader_gpu.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -77,7 +77,6 @@ struct NumpyFileWrapperGPU { class NumpyLoaderGPU : public CUFileLoader { public: using CUFileLoader::CUFileLoader; - void PrepareEmpty(NumpyFileWrapperGPU& tensor) override; void ReadSample(NumpyFileWrapperGPU& tensor) override; }; diff --git a/dali/operators/reader/loader/recordio_loader.h b/dali/operators/reader/loader/recordio_loader.h index bbcdef8c42..b592ec2d0e 100644 --- a/dali/operators/reader/loader/recordio_loader.h +++ b/dali/operators/reader/loader/recordio_loader.h @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ class RecordIOLoader : public IndexedFileLoader { std::vector file_offsets; file_offsets.push_back(0); for (std::string& path : uris_) { - auto tmp = FileStream::Open(path, read_ahead_, !copy_read_data_); + auto tmp = FileStream::Open(path, {read_ahead_, !copy_read_data_, false}); file_offsets.push_back(tmp->Size() + file_offsets.back()); tmp->Close(); } @@ -110,7 +110,7 @@ class RecordIOLoader : public IndexedFileLoader { shared_ptr p = nullptr; int64 n_read = 0; - bool use_read = copy_read_data_; + bool use_read = copy_read_data_ || !current_file_->CanMemoryMap(); if (use_read) { tensor.Resize({size}); } @@ -140,8 +140,8 @@ class RecordIOLoader : public IndexedFileLoader { DALI_ENFORCE(current_file_index_ + 1 < uris_.size(), "Incomplete or corrupted record files"); // Release previously opened file - current_file_ = FileStream::Open(uris_[++current_file_index_], read_ahead_, - !copy_read_data_); + current_file_ = + FileStream::Open(uris_[++current_file_index_], {read_ahead_, !copy_read_data_, false}); next_seek_pos_ = 0; continue; } diff --git a/dali/operators/reader/loader/sequence_loader.cc b/dali/operators/reader/loader/sequence_loader.cc index 2d96dc59ea..3e23f901a3 100644 --- a/dali/operators/reader/loader/sequence_loader.cc +++ b/dali/operators/reader/loader/sequence_loader.cc @@ -137,10 +137,10 @@ void SequenceLoader::LoadFrame(const std::vector &s, Index frame_id return; } - auto frame = FileStream::Open(frame_filename, read_ahead_, !copy_read_data_); + auto frame = FileStream::Open(frame_filename, {read_ahead_, !copy_read_data_, false}); Index frame_size = frame->Size(); // Release and unmap memory previously obtained by Get call - if (copy_read_data_) { + if (copy_read_data_ || !frame->CanMemoryMap()) { if (target->shares_data()) { target->Reset(); } diff --git a/dali/operators/reader/loader/webdataset/tar_utils.cc b/dali/operators/reader/loader/webdataset/tar_utils.cc index ea8df51804..588d7896e7 100644 --- a/dali/operators/reader/loader/webdataset/tar_utils.cc +++ b/dali/operators/reader/loader/webdataset/tar_utils.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -183,8 +183,9 @@ TarArchive::EntryType TarArchive::GetFileType() const { std::shared_ptr TarArchive::ReadFile() { stream_->SeekRead(stream_->TellRead() - readoffset_); - - auto out = stream_->Get(filesize_); + std::shared_ptr out; + if (stream_->CanMemoryMap()) + out = stream_->Get(filesize_); if (out != nullptr) { readoffset_ = filesize_; } diff --git a/dali/operators/reader/loader/webdataset/tar_utils_test.cc b/dali/operators/reader/loader/webdataset/tar_utils_test.cc index 7d287d8c89..a0613276be 100644 --- a/dali/operators/reader/loader/webdataset/tar_utils_test.cc +++ b/dali/operators/reader/loader/webdataset/tar_utils_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,10 +22,11 @@ #include #include #include -#include "dali/operators/reader/loader/filesystem.h" -#include "dali/util/file.h" #include "dali/core/util.h" +#include "dali/operators/reader/loader/discover_files.h" +#include "dali/operators/reader/loader/filesystem.h" #include "dali/test/dali_test_config.h" +#include "dali/util/file.h" namespace dali { namespace detail { @@ -36,9 +37,9 @@ TEST(LibTarUtilsTestSimple, Interface) { std::string dummy_filepath(dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/MNIST/devel-1.tar")); - TarArchive dummy_archive(FileStream::Open(dummy_filepath, false, false)); + TarArchive dummy_archive(FileStream::Open(dummy_filepath)); TarArchive intermediate_archive(std::move(dummy_archive)); - TarArchive archive(FileStream::Open(filepath, false, false)); + TarArchive archive(FileStream::Open(filepath)); archive = std::move(intermediate_archive); ASSERT_FALSE(dummy_archive.NextFile()); @@ -69,7 +70,7 @@ TEST(LibTarUtilsTestSimple, Interface) { TEST(LibTarUtilsTestSimple, LongNameIndexing) { std::string filepath(dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/sample-tar/gnu.tar")); - TarArchive archive(FileStream::Open(filepath, false, false)); + TarArchive archive(FileStream::Open(filepath)); std::string name_prefix(128, '#'); for (int idx = 0; idx < 1000; idx++) { ASSERT_EQ(archive.GetFileName(), name_prefix + to_string(idx)); @@ -88,7 +89,7 @@ TEST(LibTarUtilsTestSimple, Types) { TarArchive::ENTRY_DIR, TarArchive::ENTRY_FIFO, TarArchive::ENTRY_FILE, TarArchive::ENTRY_SYMLINK, TarArchive::ENTRY_HARDLINK}; - TarArchive archive(FileStream::Open(filepath, false, true)); + TarArchive archive(FileStream::Open(filepath, {false, true, false})); for (size_t i = 0; i < types.size(); i++) { ASSERT_EQ(archive.GetFileType(), types[i]); ASSERT_EQ(archive.GetFileName(), to_string(i) + (types[i] == TarArchive::ENTRY_DIR ? "/" : "")); @@ -103,7 +104,7 @@ TEST(LibTarUtilsTestSimple, Offset) { std::string filepath(dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/sample-tar/types.tar")); - TarArchive archive(FileStream::Open(filepath, false, true)); + TarArchive archive(FileStream::Open(filepath, {false, true, false})); archive.SeekArchive(7 * T_BLOCKSIZE); ASSERT_EQ(archive.TellArchive(), 7 * T_BLOCKSIZE); for (int i = 7; i < 14; i++) { @@ -149,8 +150,8 @@ class SimpleTarTests : public ::testing::TestWithParam { protected: TarArchive archive; SimpleTarTests() - : archive(FileStream::Open(GetParam().filepath, GetParam().read_ahead, GetParam().use_mmap)) { - } + : archive(FileStream::Open(GetParam().filepath, + {GetParam().read_ahead, GetParam().use_mmap, false})) {} }; TEST_P(SimpleTarTests, Index) { @@ -253,7 +254,7 @@ class MultiTarTests : public ::testing::TestWithParam { filepath_prefix + "0.tar", filepath_prefix + "1.tar", filepath_prefix + "2.tar"}; for (int idx = 0; idx < kMultithreadedSamples; idx++) { - archives[idx] = std::make_unique(FileStream::Open(filepaths[idx], false, false)); + archives[idx] = std::make_unique(FileStream::Open(filepaths[idx])); } } }; diff --git a/dali/operators/reader/loader/webdataset_loader.cc b/dali/operators/reader/loader/webdataset_loader.cc index 0a986d4140..954ff2c87f 100644 --- a/dali/operators/reader/loader/webdataset_loader.cc +++ b/dali/operators/reader/loader/webdataset_loader.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -314,11 +314,12 @@ void WebdatasetLoader::ReadSample(vector>& sample) { continue; } // Reading Data - if (copy_read_data_) { + if (copy_read_data_ || !current_wds_shard->CanMemoryMap()) { uint8_t* shared_tensor_data = nullptr; bool shared_tensor_is_pinned = false; int device_id = CPU_ONLY_DEVICE_ID; for (auto& output : component.outputs) { + sample[output].SetMeta(meta); if (!shared_tensor_data) { if (sample[output].shares_data()) { sample[output].Reset(); @@ -380,7 +381,7 @@ void WebdatasetLoader::PrepareMetadataImpl() { // initializing all the readers wds_shards_.reserve(paths_.size()); for (auto& uri : paths_) { - wds_shards_.emplace_back(FileStream::Open(uri, read_ahead_, !copy_read_data_)); + wds_shards_.emplace_back(FileStream::Open(uri, {read_ahead_, !copy_read_data_, false})); } // preparing the map from extensions to outputs diff --git a/dali/operators/reader/numpy_reader_gpu_op.cc b/dali/operators/reader/numpy_reader_gpu_op.cc index 245a300a7f..0a541ff499 100644 --- a/dali/operators/reader/numpy_reader_gpu_op.cc +++ b/dali/operators/reader/numpy_reader_gpu_op.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ NumpyReaderGPU::NumpyReaderGPU(const OpSpec& spec) // init loader bool shuffle_after_epoch = spec.GetArgument("shuffle_after_epoch"); - loader_ = InitLoader(spec, std::vector(), shuffle_after_epoch); + loader_ = InitLoader(spec, shuffle_after_epoch); this->SetInitialSnapshot(); kmgr_transpose_.Resize(1); diff --git a/dali/util/CMakeLists.txt b/dali/util/CMakeLists.txt index 986e6431b9..aa677f3691 100644 --- a/dali/util/CMakeLists.txt +++ b/dali/util/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,8 @@ set(DALI_INST_HDRS ${DALI_INST_HDRS} "${CMAKE_CURRENT_SOURCE_DIR}/random_crop_generator.h" "${CMAKE_CURRENT_SOURCE_DIR}/thread_safe_queue.h" "${CMAKE_CURRENT_SOURCE_DIR}/numpy.h" - "${CMAKE_CURRENT_SOURCE_DIR}/user_stream.h") + "${CMAKE_CURRENT_SOURCE_DIR}/user_stream.h" + "${CMAKE_CURRENT_SOURCE_DIR}/uri.h") set(DALI_SRCS ${DALI_SRCS} "${CMAKE_CURRENT_SOURCE_DIR}/file.cc" @@ -34,7 +35,8 @@ set(DALI_SRCS ${DALI_SRCS} "${CMAKE_CURRENT_SOURCE_DIR}/ocv.cc" "${CMAKE_CURRENT_SOURCE_DIR}/random_crop_generator.cc" "${CMAKE_CURRENT_SOURCE_DIR}/user_stream.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/numpy.cc") + "${CMAKE_CURRENT_SOURCE_DIR}/numpy.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/uri.cc") if (BUILD_CUFILE) set(DALI_INST_HDRS ${DALI_INST_HDRS} @@ -57,7 +59,8 @@ endif() set(DALI_TEST_SRCS ${DALI_TEST_SRCS} "${CMAKE_CURRENT_SOURCE_DIR}/random_crop_generator_test.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/numpy_test.cc") + "${CMAKE_CURRENT_SOURCE_DIR}/numpy_test.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/uri_test.cc") # transform a list of paths into a list of include directives DETERMINE_GCC_SYSTEM_INCLUDE_DIRS("c++" "${CMAKE_CXX_COMPILER}" "${CMAKE_CXX_FLAGS}" INFERED_COMPILER_INCLUDE) diff --git a/dali/util/cufile.cc b/dali/util/cufile.cc index 43cb45d33e..91afbe6539 100644 --- a/dali/util/cufile.cc +++ b/dali/util/cufile.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,8 +19,7 @@ namespace dali { -std::unique_ptr CUFileStream::Open(const std::string& uri, bool read_ahead, - bool use_mmap) { +std::unique_ptr CUFileStream::Open(const std::string& uri, FileStream::Options opts) { std::string processed_uri; const char prefix[] = "file://"; @@ -30,7 +29,7 @@ std::unique_ptr CUFileStream::Open(const std::string& uri, bool re processed_uri = uri; } - DALI_ENFORCE(!use_mmap, "mmap not implemented with cuFile yet."); + DALI_ENFORCE(!opts.use_mmap, "mmap not implemented with cuFile yet."); return std::make_unique(processed_uri); } diff --git a/dali/util/cufile.h b/dali/util/cufile.h index ecc6fac081..3758d8c441 100644 --- a/dali/util/cufile.h +++ b/dali/util/cufile.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ class DLL_PUBLIC CUFileStream : public FileStream { } }; - static std::unique_ptr Open(const std::string& uri, bool read_ahead, bool use_mmap); + static std::unique_ptr Open(const std::string& uri, FileStream::Options opts); /** * @brief Reads `n_bytes` to the buffer at position `offset` * diff --git a/dali/util/file.cc b/dali/util/file.cc index 3e0d31d773..dc05615f51 100644 --- a/dali/util/file.cc +++ b/dali/util/file.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,24 +16,29 @@ #include "dali/util/file.h" #include "dali/util/mmaped_file.h" -#include "dali/util/std_file.h" #include "dali/util/odirect_file.h" +#include "dali/util/std_file.h" +#include "dali/util/uri.h" namespace dali { -std::unique_ptr FileStream::Open(const std::string& uri, bool read_ahead, - bool use_mmap, bool use_odirect) { - std::string processed_uri; +std::unique_ptr FileStream::Open(const std::string& uri, FileStream::Options opts, + std::optional size) { + bool is_s3 = uri.rfind("s3://", 0) == 0; + if (is_s3) { + throw std::runtime_error("This version of DALI was not built with AWS S3 storage support."); + } + std::string processed_uri; if (uri.find("file://") == 0) { processed_uri = uri.substr(std::string("file://").size()); } else { processed_uri = uri; } - if (use_mmap) { - return std::unique_ptr(new MmapedFileStream(processed_uri, read_ahead)); - } else if (use_odirect) { + if (opts.use_mmap) { + return std::unique_ptr(new MmapedFileStream(processed_uri, opts.read_ahead)); + } else if (opts.use_odirect) { return std::unique_ptr(new ODirectFileStream(processed_uri)); } else { return std::unique_ptr(new StdFileStream(processed_uri)); diff --git a/dali/util/file.h b/dali/util/file.h index 85e3c23d65..bf1c7defc0 100644 --- a/dali/util/file.h +++ b/dali/util/file.h @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,10 +18,11 @@ #include #include #include - +#include #include "dali/core/api_helper.h" #include "dali/core/common.h" #include "dali/core/stream.h" +#include "dali/core/format.h" namespace dali { @@ -69,11 +70,32 @@ class DLL_PUBLIC FileStream : public InputStream { private: unsigned int reserved; }; - static std::unique_ptr Open(const std::string &uri, bool read_ahead, bool use_mmap, - bool use_odirect = false); + + struct Options { + bool read_ahead; + bool use_mmap; + bool use_odirect; + }; + + /** + * @brief Opens file stream + * + * @param uri URI to open + * @param opts options + * @param size If provided, we can defer the actual reading of the stream until it needs to be + * read (e.g. especially useful for remote storage) + * @return std::unique_ptr + */ + static std::unique_ptr Open(const std::string &uri, + Options opts = {false, false, false}, + std::optional size = std::nullopt); virtual void Close() = 0; - virtual shared_ptr Get(size_t n_bytes) = 0; + virtual bool CanMemoryMap() { return false; } + virtual shared_ptr Get(size_t n_bytes) { + throw std::logic_error( + make_string("memory mapping is not supported for this stream type. uri=", path_)); + } virtual ~FileStream() {} protected: diff --git a/dali/util/fits_test.cc b/dali/util/fits_test.cc index ed80acec72..5794025672 100644 --- a/dali/util/fits_test.cc +++ b/dali/util/fits_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -39,11 +39,11 @@ struct test_sample { std::string ref_tile_sizes_path) : path(img_path), ref_undecoded_data( - ReadVector(FileStream::Open(ref_data_path, false, false).get())), + ReadVector(FileStream::Open(ref_data_path).get())), ref_offset_sizes( - ReadVector(FileStream::Open(ref_offset_sizes_path, false, false).get())), + ReadVector(FileStream::Open(ref_offset_sizes_path).get())), ref_tile_sizes( - ReadVector(FileStream::Open(ref_tile_sizes_path, false, false).get())) {} + ReadVector(FileStream::Open(ref_tile_sizes_path).get())) {} std::string path; vector ref_undecoded_data; diff --git a/dali/util/mmaped_file.h b/dali/util/mmaped_file.h index 4b1aef407b..34c5a06277 100644 --- a/dali/util/mmaped_file.h +++ b/dali/util/mmaped_file.h @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,7 +28,10 @@ class MmapedFileStream : public FileStream { public: explicit MmapedFileStream(const std::string& path, bool read_ahead); void Close() override; + + bool CanMemoryMap() override { return true; } shared_ptr Get(size_t n_bytes) override; + static bool ReserveFileMappings(unsigned int num); static void FreeFileMappings(unsigned int num); size_t Read(void *buffer, size_t n_bytes) override; diff --git a/dali/util/odirect_file.cc b/dali/util/odirect_file.cc index 10701d7245..7a4c4acc15 100644 --- a/dali/util/odirect_file.cc +++ b/dali/util/odirect_file.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -110,12 +110,6 @@ size_t ODirectFileStream::Read(void *buffer, size_t n_bytes) { return n_read; } -shared_ptr ODirectFileStream::Get(size_t /*n_bytes*/) { - // this unction should return a pointer inside mmaped file - // it doesn't make sense in case of StdFileStream - return {}; -} - size_t ODirectFileStream::Size() const { struct stat sb; if (stat(path_.c_str(), &sb) == -1) { diff --git a/dali/util/odirect_file.h b/dali/util/odirect_file.h index 072b8ad53e..49814faf45 100644 --- a/dali/util/odirect_file.h +++ b/dali/util/odirect_file.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,7 +28,6 @@ class DLL_PUBLIC ODirectFileStream : public FileStream { public: explicit ODirectFileStream(const std::string& path); void Close() override; - shared_ptr Get(size_t n_bytes) override; size_t Read(void * buffer, size_t n_bytes) override; size_t ReadAt(void * buffer, size_t n_bytes, off_t offset); static size_t GetAlignment(); diff --git a/dali/util/std_cufile.cc b/dali/util/std_cufile.cc index e8d3337662..aed4beeaae 100644 --- a/dali/util/std_cufile.cc +++ b/dali/util/std_cufile.cc @@ -182,13 +182,6 @@ size_t StdCUFileStream::Read(void *cpu_buffer, size_t n_bytes) { return n_bytes; } -// disable this function -shared_ptr StdCUFileStream::Get(size_t n_bytes) { - // this function should return a pointer inside mmaped file - // it doesn't make sense in case of StdCUFileStream - return {}; -} - size_t StdCUFileStream::Size() const { return length_; } diff --git a/dali/util/std_cufile.h b/dali/util/std_cufile.h index 6d57e55810..24fc93b4fd 100644 --- a/dali/util/std_cufile.h +++ b/dali/util/std_cufile.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -32,7 +32,6 @@ class StdCUFileStream : public CUFileStream { public: explicit StdCUFileStream(const std::string& path); void Close() override; - shared_ptr Get(size_t n_bytes) override; size_t ReadAtGPU(void *gpu_buffer, size_t n_bytes, ptrdiff_t buffer_offset, int64 file_offset) override; size_t ReadGPU(void *buffer, size_t n_bytes, ptrdiff_t offset = 0) override; diff --git a/dali/util/std_file.cc b/dali/util/std_file.cc index 890f7d22e8..e39b808c66 100644 --- a/dali/util/std_file.cc +++ b/dali/util/std_file.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -54,12 +54,6 @@ size_t StdFileStream::Read(void *buffer, size_t n_bytes) { return n_read; } -shared_ptr StdFileStream::Get(size_t /*n_bytes*/) { - // this unction should return a pointer inside mmaped file - // it doesn't make sense in case of StdFileStream - return {}; -} - size_t StdFileStream::Size() const { struct stat sb; if (stat(path_.c_str(), &sb) == -1) { diff --git a/dali/util/std_file.h b/dali/util/std_file.h index 802bc33842..fb3e0e5e75 100644 --- a/dali/util/std_file.h +++ b/dali/util/std_file.h @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,7 +28,6 @@ class StdFileStream : public FileStream { public: explicit StdFileStream(const std::string& path); void Close() override; - shared_ptr Get(size_t n_bytes) override; size_t Read(void * buffer, size_t n_bytes) override; void SeekRead(ptrdiff_t pos, int whence = SEEK_SET) override; ptrdiff_t TellRead() const override; diff --git a/dali/util/uri.cc b/dali/util/uri.cc new file mode 100644 index 0000000000..70df1f8921 --- /dev/null +++ b/dali/util/uri.cc @@ -0,0 +1,165 @@ +// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dali/util/uri.h" +#include +#include + +namespace dali { + +inline bool allowed_scheme_char(char c) { + return std::isalnum(c) || c == '.' || c == '+' || c == '-'; +} + +inline bool allowed_char(char c) { + // See https://en.wikipedia.org/wiki/Uniform_Resource_Identifier + // gen-delims: :, /, ?, #, [, ], and @ + // sub-delims: !, $, &, ', (, ), *, +, ,, ;, and = + // unreserved characters (uppercase and lowercase letters, decimal digits, -, ., _, and ~) + // the character % + static const std::string gen_delims = ":/?#[]@"; + static const std::string sub_delims = "!$&'()*+,;="; + static const std::string unreserved = "-._~"; + return (std::isalnum(c) + || gen_delims.find(c) != std::string::npos + || sub_delims.find(c) != std::string::npos + || unreserved.find(c) != std::string::npos); +} + +std::string display_char(char c) { + switch (c) { + case '\a': + return "\\a"; + case '\b': + return "\\b"; + case '\t': + return "\\t"; + case '\n': + return "\\n"; + case '\v': + return "\\v"; + case '\f': + return "\\f"; + case '\r': + return "\\r"; + case '\?': + return "\\\?"; + default: + return {c}; + } +} + +URI URI::Parse(std::string uri) { + // See https://en.wikipedia.org/wiki/Uniform_Resource_Identifier + URI parsed; + parsed.uri_ = std::move(uri); + parsed.valid_ = true; + size_t len = parsed.uri_.length(); + const char* p_start = parsed.uri_.data(); + const char* p_end = parsed.uri_.data() + len; + const char* p = p_start; + + // Scheme + parsed.scheme_start_ = p - p_start; + + if (!std::isalpha(*p)) { + parsed.valid_ = false; + parsed.err_msg_ = "First character should be a letter"; + return parsed; + } + while (*p != '\0' && *p != ':') { + if (!allowed_scheme_char(*p)) { + parsed.valid_ = false; + parsed.err_msg_ = "Invalid character found (" + display_char(*p) + ") in scheme"; + return parsed; + } + p++; + } + parsed.scheme_end_ = p - p_start; + if (parsed.scheme_end_ <= parsed.scheme_start_) { + parsed.valid_ = false; + parsed.err_msg_ = "Empty scheme"; + return parsed; + } + + if (*p != ':') { + parsed.valid_ = false; + parsed.err_msg_ = "Expected a colon after the URI scheme"; + return parsed; + } + p++; + + // Authority + if (*p == '/' && *(p + 1) == '/') { + p += 2; + parsed.authority_start_ = p - p_start; + while (*p != '\0' && *p != '/') { + if (!allowed_char(*p)) { + parsed.valid_ = false; + parsed.err_msg_ = "Invalid character found (" + display_char(*p) + ") in authority"; + return parsed; + } + p++; + } + parsed.authority_end_ = p - p_start; + } + + if (*p == '\0') + return parsed; + + // Path + parsed.path_start_ = p - p_start; + while (*p != '\0' && *p != '?') { + if (!allowed_char(*p)) { + parsed.valid_ = false; + parsed.err_msg_ = "Invalid character found (" + display_char(*p) + ") in path"; + return parsed; + } + p++; + } + parsed.path_end_ = p - p_start; + if (*p == '\0') + return parsed; + + // Query + p++; + parsed.query_start_ = p - p_start; + while (*p != '\0' && *p != '#') { + if (!allowed_char(*p)) { + parsed.valid_ = false; + parsed.err_msg_ = "Invalid character found (" + display_char(*p) + ") in query"; + return parsed; + } + p++; + } + parsed.query_end_ = p - p_start; + if (*p == '\0') + return parsed; + + // Fragment + p++; + parsed.fragment_start_ = p - p_start; + while (*p != '\0') { + if (!allowed_char(*p)) { + parsed.valid_ = false; + parsed.err_msg_ = "Invalid character found (" + display_char(*p) + ") in fragment"; + return parsed; + } + p++; + } + parsed.fragment_end_ = p - p_start; + return parsed; +} + +} // namespace dali diff --git a/dali/util/uri.h b/dali/util/uri.h new file mode 100644 index 0000000000..0db4d21527 --- /dev/null +++ b/dali/util/uri.h @@ -0,0 +1,91 @@ +// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_UTIL_URI_H_ +#define DALI_UTIL_URI_H_ + +#include +#include +#include +#include +#include +#include "dali/core/api_helper.h" + +namespace dali { + +class URI { + private: + std::string uri_; // the original URI string + std::ptrdiff_t scheme_start_ = 0, scheme_end_ = 0; + std::ptrdiff_t authority_start_ = 0, authority_end_ = 0; + std::ptrdiff_t path_start_ = 0, path_end_ = 0; + std::ptrdiff_t query_start_ = 0, query_end_ = 0; + std::ptrdiff_t fragment_start_ = 0, fragment_end_ = 0; + bool valid_ = false; + std::string err_msg_; + + void enforce_valid() const { + if (!valid_) + throw std::runtime_error(uri_ + " is not a valid URI: " + err_msg_); + } + + std::string_view uri_part(ptrdiff_t start, ptrdiff_t end) const { + enforce_valid(); + assert(end >= start); + return std::string_view{uri_.c_str() + start, static_cast(end - start)}; + } + + public: + static DLL_PUBLIC URI Parse(std::string uri); + + bool valid() const { + return valid_; + } + + std::string_view scheme() const { + return uri_part(scheme_start_, scheme_end_); + } + + std::string_view authority() const { + return uri_part(authority_start_, authority_end_); + } + + std::string_view scheme_authority() const { + return uri_part(scheme_start_, authority_end_); + } + + std::string_view path() const { + return uri_part(path_start_, path_end_); + } + + std::string_view scheme_authority_path() const { + return uri_part(scheme_start_, path_end_); + } + + std::string_view query() const { + return uri_part(query_start_, query_end_); + } + + std::string_view path_and_query() const { + return uri_part(path_start_, std::max(path_end_, query_end_)); + } + + std::string_view fragment() const { + return uri_part(fragment_start_, fragment_end_); + } +}; + +} // namespace dali + +#endif // DALI_UTIL_URI_H_ diff --git a/dali/util/uri_test.cc b/dali/util/uri_test.cc new file mode 100644 index 0000000000..5c71bd7d87 --- /dev/null +++ b/dali/util/uri_test.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dali/util/uri.h" +#include + +namespace dali { + +TEST(URI, Parse_1) { + auto uri = URI::Parse( + "https://john.doe@www.example.com:123/forum/questions/?tag=networking&order=newest#top"); + EXPECT_EQ("https", uri.scheme()); + EXPECT_EQ("john.doe@www.example.com:123", uri.authority()); + EXPECT_EQ("/forum/questions/", uri.path()); + EXPECT_EQ("tag=networking&order=newest", uri.query()); + EXPECT_EQ("/forum/questions/?tag=networking&order=newest", uri.path_and_query()); + EXPECT_EQ("top", uri.fragment()); +} + +TEST(URI, Parse_2) { + auto uri = URI::Parse( + "ldap://[2001:db8::7]/c=GB?objectClass?one"); + EXPECT_EQ("ldap", uri.scheme()); + EXPECT_EQ("[2001:db8::7]", uri.authority()); + EXPECT_EQ("/c=GB", uri.path()); + EXPECT_EQ("objectClass?one", uri.query()); + EXPECT_EQ("/c=GB?objectClass?one", uri.path_and_query()); + EXPECT_EQ("", uri.fragment()); +} + +TEST(URI, Parse_3) { + auto uri = URI::Parse( + "mailto:John.Doe@example.com"); + EXPECT_EQ("mailto", uri.scheme()); + EXPECT_EQ("", uri.authority()); + EXPECT_EQ("John.Doe@example.com", uri.path()); + EXPECT_EQ("", uri.query()); + EXPECT_EQ("John.Doe@example.com", uri.path_and_query()); + EXPECT_EQ("", uri.fragment()); +} + +TEST(URI, Parse_4) { + auto uri = URI::Parse( + "news:comp.infosystems.www.servers.unix"); + EXPECT_EQ("news", uri.scheme()); + EXPECT_EQ("", uri.authority()); + EXPECT_EQ("comp.infosystems.www.servers.unix", uri.path()); + EXPECT_EQ("", uri.query()); + EXPECT_EQ("comp.infosystems.www.servers.unix", uri.path_and_query()); + EXPECT_EQ("", uri.fragment()); +} + +TEST(URI, Parse_5) { + auto uri = URI::Parse( + "tel:+1-816-555-1212"); + EXPECT_EQ("tel", uri.scheme()); + EXPECT_EQ("", uri.authority()); + EXPECT_EQ("+1-816-555-1212", uri.path()); + EXPECT_EQ("", uri.query()); + EXPECT_EQ("+1-816-555-1212", uri.path_and_query()); + EXPECT_EQ("", uri.fragment()); +} + +TEST(URI, Parse_6) { + auto uri = URI::Parse( + "telnet://192.0.2.16:80/"); + EXPECT_EQ("telnet", uri.scheme()); + EXPECT_EQ("192.0.2.16:80", uri.authority()); + EXPECT_EQ("/", uri.path()); + EXPECT_EQ("", uri.query()); + EXPECT_EQ("/", uri.path_and_query()); + EXPECT_EQ("", uri.fragment()); +} + +TEST(URI, Parse_7) { + auto uri = URI::Parse( + "urn:oasis:names:specification:docbook:dtd:xml:4.1.2"); + EXPECT_EQ("urn", uri.scheme()); + EXPECT_EQ("", uri.authority()); + EXPECT_EQ("oasis:names:specification:docbook:dtd:xml:4.1.2", uri.path()); + EXPECT_EQ("", uri.query()); + EXPECT_EQ("oasis:names:specification:docbook:dtd:xml:4.1.2", uri.path_and_query()); + EXPECT_EQ("", uri.fragment()); +} + +TEST(URI, Parse_Error1) { + auto uri = URI::Parse( + "telnet://192. 0.2.16:80/"); + EXPECT_FALSE(uri.valid()); +} + +TEST(URI, Parse_Error2) { + auto uri = URI::Parse( + "telnet://192.\n0.2.16:80/"); + EXPECT_FALSE(uri.valid()); +} + +TEST(URI, Parse_Error3) { + auto uri = URI::Parse("noscheme"); + EXPECT_FALSE(uri.valid()); +} + +} // namespace dali