diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py index d3ad74ba2..6b4319843 100644 --- a/test/smoke_test/smoke_test.py +++ b/test/smoke_test/smoke_test.py @@ -16,6 +16,10 @@ def s3_test(): from torchdata._torchdata import S3Handler +def stateful_dataloader_test(): + from torchdata.stateful_dataloader import StatefulDataLoader + + if __name__ == "__main__": r""" TorchData Smoke Test @@ -26,3 +30,5 @@ def s3_test(): options = parser.parse_args() if options.s3: s3_test() + + stateful_dataloader_test()