
State_dict on dataset seems to be called more often than expected #1268

Open
gokulavasan opened this issue Jun 10, 2024 · 2 comments

@gokulavasan
Contributor

🐛 Describe the bug

Consider the following code:

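from unittest import TestCase, main

import torch
from torchdata.stateful_dataloader import Stateful, StatefulDataLoader


class DatasetStateIterable(torch.utils.data.IterableDataset, Stateful):
    def __init__(self, length):
        self.length = length

    def __iter__(self):
        return iter(list(range(self.length)))

    def state_dict(self):
        # Prints every time the loader asks the dataset for its state
        print("Calling state dict")
        return {"key": "value"}

    def load_state_dict(self, state_dict):
        pass


class TestSimple(TestCase):
    def test(self):
        dataset = DatasetStateIterable(100)
        dl = StatefulDataLoader(
            dataset=dataset,
            num_workers=1,
            snapshot_every_n_steps=10,
        )
        it = iter(dl)
        for _ in range(30):
            next(it)
        # Fails on purpose so the test runner displays the captured output
        self.assertTrue(False)


if __name__ == "__main__":
    main()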

Here the snapshot frequency is set to every 10 steps and the loop runs for 30 steps, so I would expect roughly 3 calls (one per snapshot). Instead, state_dict is called on the dataset 12 times:

Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict

Versions

Latest git commit - 82918dd

@andrewkho
Contributor

This is expected. We need to eagerly request state_dict from the workers, and since we have no way of knowing whether other workers are about to send StopIteration, we have to ask for more state_dicts than the snapshot frequency alone would suggest.
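To make the reasoning above concrete, here is a deliberately simplified, hypothetical model (not the actual torchdata implementation). It assumes the main process flags a batch request for a state snapshot whenever that batch *might* be yielded on a snapshot step, and that the yield step of a request dispatched at position p is only known to lie somewhere in [p - uncertainty, p]. The function name count_state_requests and the uncertainty parameter are made up for illustration; uncertainty stands in for not knowing whether other workers will stop early.

```python
def count_state_requests(num_steps: int, snapshot_every_n_steps: int, uncertainty: int) -> int:
    """Toy model (hypothetical): count batch requests that get flagged for a state snapshot.

    A request dispatched at 1-based position p may be yielded at any step in
    [p - uncertainty, p]. It is flagged if any of those steps is a snapshot step.
    """
    requests = 0
    for p in range(1, num_steps + 1):
        possible_steps = range(max(1, p - uncertainty), p + 1)
        if any(step % snapshot_every_n_steps == 0 for step in possible_steps):
            requests += 1
    return requests


if __name__ == "__main__":
    # No uncertainty: exactly one request per snapshot (steps 10, 20, 30)
    print(count_state_requests(30, 10, uncertainty=0))
    # Some uncertainty: several requests around each snapshot boundary
    print(count_state_requests(30, 10, uncertainty=3))
```

With uncertainty=0 the model gives 3 requests for 30 steps; with uncertainty=3 it gives 9, because every request whose yield step could land on a boundary is flagged. In this toy model, a single worker with nothing else that can stop early corresponds to uncertainty=0.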

@gokulavasan
Contributor Author

@andrewkho In the above example there is only 1 multiprocessing worker. Is that still expected?
