pytorch_lightning.Trainer.fit() error #6

Closed
sipercai opened this issue Nov 12, 2024 · 3 comments · May be fixed by pyannote/pyannote-audio#1787

Comments

@sipercai

When I ran the 1_training_from_scratch script, trainer.validate(mamba_diar) completed without problems. However, trainer.fit(mamba_diar) failed with the error below: the task's prepared_data has no "metadata" key when it evaluates *[self.prepared_data["metadata"][key] for key in balance]. Have you encountered a similar problem? Do you have any suggestions for solving it?

Details

{
"name": "KeyError",
"message": "Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
data.append(next(self.dataset_iter))
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pyannote/audio/tasks/segmentation/mixins.py", line 187, in train__iter__
*[self.prepared_data["metadata"][key] for key in balance]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pyannote/audio/tasks/segmentation/mixins.py", line 187, in <listcomp>
*[self.prepared_data["metadata"][key] for key in balance]
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'metadata'
",
"stack": "---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[17], line 1
----> 1 trainer.fit(mamba_diar)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
536 self.state.status = TrainerStatus.RUNNING
537 self.training = True
--> 538 call._call_and_handle_interrupt(
539 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
540 )

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
45 if trainer.strategy.launcher is not None:
46 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47 return trainer_fn(*args, **kwargs)
49 except _TunerExitException:
50 _call_teardown_hook(trainer)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
567 assert self.state.fn is not None
568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
569 self.state.fn,
570 ckpt_path,
571 model_provided=True,
572 model_connected=self.lightning_module is not None,
573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
576 assert self.state.stopped
577 self.training = False

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
976 self._signal_connector.register_signal_handlers()
978 # ----------------------------
979 # RUN THE TRAINER
980 # ----------------------------
--> 981 results = self._run_stage()
983 # ----------------------------
984 # POST-Training CLEAN UP
985 # ----------------------------
986 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:1025, in Trainer._run_stage(self)
1023 self._run_sanity_check()
1024 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1025 self.fit_loop.run()
1026 return None
1027 raise RuntimeError(f"Unexpected state {self.state}")

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
203 try:
204 self.on_advance_start()
--> 205 self.advance()
206 self.on_advance_end()
207 self._restarting = False

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
361 with self.trainer.profiler.profile("run_training_epoch"):
362 assert self._data_fetcher is not None
--> 363 self.epoch_loop.run(self._data_fetcher)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140, in _TrainingEpochLoop.run(self, data_fetcher)
138 while not self.done:
139 try:
--> 140 self.advance(data_fetcher)
141 self.on_advance_end(data_fetcher)
142 self._restarting = False

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:212, in _TrainingEpochLoop.advance(self, data_fetcher)
210 else:
211 dataloader_iter = None
--> 212 batch, _, __ = next(data_fetcher)
213 # TODO: we should instead use the batch_idx returned by the fetcher, however, that will require saving the
214 # fetcher state so that the batch_idx is correct after restarting
215 batch_idx = self.batch_idx + 1

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fetchers.py:133, in _PrefetchDataFetcher.__next__(self)
130 self.done = not self.batches
131 elif not self.done:
132 # this will run only when no pre-fetching was done.
--> 133 batch = super().__next__()
134 else:
135 # the iterator is empty
136 raise StopIteration

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fetchers.py:60, in _DataFetcher.__next__(self)
58 self._start_profiler()
59 try:
---> 60 batch = next(self.iterator)
61 except StopIteration:
62 self.done = True

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/utilities/combined_loader.py:341, in CombinedLoader.__next__(self)
339 def __next__(self) -> _ITERATOR_RETURN:
340 assert self._iterator is not None
--> 341 out = next(self._iterator)
342 if isinstance(self._iterator, _Sequential):
343 return out

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/utilities/combined_loader.py:78, in _MaxSizeCycle.__next__(self)
76 for i in range(n):
77 try:
---> 78 out[i] = next(self.iterators[i])
79 except StopIteration:
80 self._consumed[i] = True

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(pytorch/pytorch#76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1344, in _MultiProcessingDataLoaderIter._next_data(self)
1342 else:
1343 del self._task_info[idx]
-> 1344 return self._process_data(data)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1370, in _MultiProcessingDataLoaderIter._process_data(self, data)
1368 self._try_put_index()
1369 if isinstance(data, ExceptionWrapper):
-> 1370 data.reraise()
1371 return data

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/_utils.py:706, in ExceptionWrapper.reraise(self)
702 except TypeError:
703 # If the exception takes multiple arguments, don't try to
704 # instantiate since we don't know how to
705 raise RuntimeError(msg) from None
--> 706 raise exception

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
data.append(next(self.dataset_iter))
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pyannote/audio/tasks/segmentation/mixins.py", line 187, in train__iter__
*[self.prepared_data["metadata"][key] for key in balance]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pyannote/audio/tasks/segmentation/mixins.py", line 187, in <listcomp>
*[self.prepared_data["metadata"][key] for key in balance]
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'metadata'
"
}
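
Stripped of the Lightning and DataLoader frames, the traceback reduces to a dictionary lookup on a key that was never stored. The sketch below reproduces that failure mode in isolation; `prepared_data` here is a hypothetical stand-in (the real structure inside pyannote.audio may differ), and `balance_columns` is an illustrative helper, not part of any library. Since the failing expression sits in the training iterator, this would also be consistent with trainer.validate running fine.

```python
# Hypothetical stand-in for the task's prepared data: no "metadata" entry.
prepared_data = {"audio-path": ["a.wav", "b.wav"]}
balance = ["database"]


def balance_columns(prepared_data, balance):
    """Illustrative helper: fetch one metadata column per balance key,
    failing early with a readable message instead of a bare KeyError."""
    if "metadata" not in prepared_data:
        raise KeyError(
            f"balance={balance!r} was requested, "
            "but prepared_data has no 'metadata' entry"
        )
    return [prepared_data["metadata"][key] for key in balance]


# Same expression shape as the failing line in the traceback.
try:
    columns = [prepared_data["metadata"][key] for key in balance]
except KeyError as error:
    raised_key = error.args[0]

# The guarded version fails too, but says why.
try:
    balance_columns(prepared_data, balance)
except KeyError as error:
    guarded_message = error.args[0]
```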

@FrenchKrab
Collaborator

Thanks for the issue and the detailed logs! I pushed a PR that should fix this issue, provided you can install pyannote from git.

You can also work around this issue by not using balance (do not pass balance to the task, or pass None); maybe this is OK for your case. The balance parameter is there to sample uniformly according to a criterion (in the case of this paper, to sample uniformly from each dataset).
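
To illustrate what sampling uniformly according to a criterion means, here is a minimal, library-independent sketch. It is not pyannote's implementation; `balanced_sampler` and the toy file list are made up for the example. The idea is to first pick a group (here, a database) uniformly at random and only then pick a file inside it, so a small database is drawn about as often as a large one.

```python
import random
from collections import defaultdict


def balanced_sampler(items, key, rng):
    """Yield items forever: pick a group uniformly, then an item within it."""
    groups = defaultdict(list)
    for item in items:
        groups[key(item)].append(item)
    group_names = sorted(groups)
    while True:
        name = rng.choice(group_names)
        yield rng.choice(groups[name])


# Toy data: database "A" has 9 files, database "B" has only 1.
files = [{"uri": f"a{i}", "database": "A"} for i in range(9)]
files.append({"uri": "b0", "database": "B"})

rng = random.Random(0)
sampler = balanced_sampler(files, key=lambda f: f["database"], rng=rng)

# Count how often each database is drawn.
counts = defaultdict(int)
for _ in range(10_000):
    counts[next(sampler)["database"]] += 1
```

Without balancing, "A" would be drawn roughly nine times as often as "B"; with it, both counts end up close to half of the draws.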

@FrenchKrab
Collaborator

Actually, I realized there have been some breaking changes in the latest pyannote versions, so I'm not 100% confident this PR will work right away, especially with this repository (and I don't have the time to really test things for now). But if you apply the commit from this PR to pyannote 3.1 (git cherry-pick?), it should work. Or again, if you are OK with not using balance, that's probably the easiest option.

Sorry for the inconvenience!

@sipercai
Author

Thank you for taking the time to reply amidst your busy schedule.

I tried both of your suggestions. First, removing balance=['database'] from the task worked: after removing it, training ran successfully!
Your pull request (PR) was also effective. I didn't know how to apply your PR quickly, so I copied the corresponding changes to task.py and mixins.py into site-packages/pyannote/audio in my conda environment (under anaconda3/envs). After rerunning, training also worked normally!
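
For anyone patching files in site-packages by hand in the same way, a small sketch of how to find the file Python actually imports and keep a backup before overwriting it. `installed_path` and `backup_then_replace` are illustrative helpers written for this example, and the demo runs on throwaway files rather than on a real installation.

```python
import importlib.util
import shutil
import tempfile
from pathlib import Path


def installed_path(module_name):
    """Return the source file Python would load for module_name."""
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(module_name)
    return Path(spec.origin)


def backup_then_replace(target, patched):
    """Save target as <name>.bak, then overwrite it with the patched file."""
    backup = target.with_name(target.name + ".bak")
    shutil.copy2(target, backup)
    shutil.copy2(patched, target)
    return backup


# Demo on throwaway files so nothing in site-packages is touched.
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "mixins.py"
    patched = Path(tmp) / "mixins_patched.py"
    target.write_text("old\n")
    patched.write_text("new\n")
    backup = backup_then_replace(target, patched)
    replaced_text = target.read_text()
    backup_text = backup.read_text()

# Works for any importable module; "json" is used because it ships with Python.
json_path = installed_path("json")
```

In the environment from the traceback, calling installed_path("pyannote.audio.tasks.segmentation.mixins") should point at the mixins.py shown there. Note that looking up a dotted name imports its parent packages, and that manual edits like this are lost when the package is reinstalled or upgraded.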

Thank you very much for your reply! Best regards!
