Skip to content

Commit

Permalink
remove unnecessary imports
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Aug 26, 2024
1 parent d93d3c4 commit c024661
Showing 1 changed file with 6 additions and 119 deletions.
125 changes: 6 additions & 119 deletions test/stateful_dataloader/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,134 +4,21 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ctypes
import errno
import faulthandler
import functools
import gc
import itertools

import math
import operator
import os
import signal
import sys
import tempfile
import time
import unittest
import warnings

import torch
import torch.utils.data.datapipes as dp
from torch import multiprocessing as mp
from torch._utils import ExceptionWrapper
from torch.testing._internal.common_device_type import instantiate_device_type_tests

from torch.testing._internal.common_utils import (
IS_CI,
IS_JETSON,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
load_tests,
NO_MULTIPROCESSING_SPAWN,
parametrize,
run_tests,
skipIfNoDill,
skipIfRocm,
slowTest,
TEST_CUDA,
TEST_NUMPY,
TEST_WITH_ASAN,
TEST_WITH_TSAN,
TestCase,
)

from torch.utils.data import (
_utils,
ChainDataset,
ConcatDataset,
Dataset,
IterableDataset,
IterDataPipe,
StackDataset,
Subset,
TensorDataset,
)
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.datapipes.iter import IterableWrapper
from torch.utils.data.dataset import random_split
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN, TestCase

from torch.utils.data import Dataset

from torchdata.stateful_dataloader import Stateful, StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler


try:
import psutil

HAS_PSUTIL = True
except ImportError:
HAS_PSUTIL = False
err_msg = (
"psutil not found. Some critical data loader tests relying on it "
"(e.g., TestDataLoader.test_proper_exit) will not run."
)
if IS_CI:
raise ImportError(err_msg) from None
else:
warnings.warn(err_msg)


try:
import numpy as np

HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
skipIfNoNumpy = unittest.skipIf(not HAS_NUMPY, "no NumPy")

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if TEST_CUDA:
torch.cuda.memory._set_allocator_settings("expandable_segments:False")

if not NO_MULTIPROCESSING_SPAWN:
# We want to use `spawn` if able because some of our tests check that the
# data loader terminiates gracefully. To prevent hanging in the testing
# process, such data loaders are run in a separate subprocess.
#
# We also want to test the `pin_memory=True` configuration, thus `spawn` is
# required to launch such processes and they initialize the CUDA context.
#
# Mixing different start method is a recipe for disaster (e.g., using a fork
# `mp.Event` with a spawn `mp.Process` segfaults). So we set this globally
# to avoid bugs.
#
# Get a multiprocessing context because some test / third party library will
# set start_method when imported, and setting again triggers `RuntimeError`.
mp = mp.get_context(method="spawn")


# 60s of timeout?
# Yes, in environments where physical CPU resources are shared, e.g., CI, the
# time for a inter-process communication can be highly varying. With 15~17s of
# timeout, we have observed flakiness in some CI builds (see
# pytorch/pytorch#14501, pytorch/pytorch#16608). We follow the CPython
# multiprocessing setup and set the timeout to 60s here:
#
# https://github.com/python/cpython/blob/e8113f51a8bdf33188ee30a1c038a298329e7bfa/Lib/test/_test_multiprocessing.py#L73
JOIN_TIMEOUT = 60.0 # seconds


supported_multiprocessing_contexts = [None] + list(torch.multiprocessing.get_all_start_methods())


# collate_fn that returns the batch cloned; defined globally here for pickle purposes.
def _clone_collate(b):
return [x.clone() for x in b]


class MockDataset(Dataset):
def __init__(self, size):
self.size = size
Expand Down Expand Up @@ -175,7 +62,7 @@ def test_dataloader_state_dict(self):
dataloader = StatefulDataLoader(self.dataset, batch_size=10, sampler=sampler)
# Partial iteration over the DataLoader
iter_count = 5
for i, data in enumerate(dataloader):
for i, _ in enumerate(dataloader):
if i == iter_count - 1:
break
state_dict = dataloader.state_dict()
Expand Down Expand Up @@ -290,7 +177,7 @@ def test_data_distribution_across_replicas(self):
all_data = []
for rank in range(num_replicas):
sampler = StatefulDistributedSampler(self.dataset, num_replicas=num_replicas, rank=rank, shuffle=False)
dataloader = torch.utils.data.DataLoader(self.dataset, sampler=sampler)
dataloader = StatefulDataLoader(self.dataset, sampler=sampler)
data_loaded = []
for batch in dataloader:
data_loaded.extend([int(x.item()) for x in batch])
Expand Down

0 comments on commit c024661

Please sign in to comment.