diff --git a/petagraph/run_train.py b/petagraph/run_train.py
index 2a581647..a9d3db21 100644
--- a/petagraph/run_train.py
+++ b/petagraph/run_train.py
@@ -203,7 +203,7 @@ def get_dataloader_from_data_stage(
     num_dl_workers = data.num_loading_workers
     log_rank(f"Using {num_dl_workers} dataloader workers", logger=logger, level=logging.INFO, rank=0)

-    # Set loggin directories
+    # Set logging directories
     logging_directory = Path(trainer.config.checkpoints.checkpoints_path)
     consumed_files_directory = logging_directory / "consumed_files"
     with main_rank_first(trainer.parallel_context.world_pg):
diff --git a/src/nanotron/data/petagraph_dataset.py b/src/nanotron/data/petagraph_dataset.py
index ab55f25f..96af2197 100644
--- a/src/nanotron/data/petagraph_dataset.py
+++ b/src/nanotron/data/petagraph_dataset.py
@@ -56,6 +56,8 @@ def __init__(self,
                  log_directory: Path = None,
                  rank: int = 0,
                  packed: bool = False,
+                 restart_consumed_files: list[str] = None,
+                 restart_epoch: int = 0,
                  ):

         self.samples_per_epoch = samples_per_epoch
@@ -67,7 +69,7 @@ def __init__(self,
         self.logging_func = partial(log_rank, logger=logger, level=logging.INFO, rank=0)
         self.logging_func("=====================================")
         self.logging_func(f"[PetaGraphStreamDataset] Creating PetaGraphStreamDataset with maxlen {maxlen}")
-        self.logging_func(f"[PetaGraphStreamDataset] Samples per epoch: {samples_per_epoch}")
+        # self.logging_func(f"[PetaGraphStreamDataset] Samples per epoch: {samples_per_epoch}")
         self.logging_func(f"[PetaGraphStreamDataset] Num. URLs: {len(url_list)}")
         self.logging_func(f"[PetaGraphStreamDataset] From Cloud: {from_cloud}")
@@ -77,8 +79,29 @@ def __init__(self,
         self._bos_token_id = self.VOCAB["BOS"]
         self._unk_token_id = self.VOCAB["UNK"]

+        self.num_files = len(url_list)
+        self.current_epoch = 0
+
         # TODO: Take list of already consumed lists and remove them from the
         # url list, to continue training from the last checkpoint properly
+        if restart_consumed_files is not None:
+
+            # All files in restart_consumed_files should be present in the url_list
+            for f in restart_consumed_files:
+                assert f in url_list, f"File {f} from restart not found in the url_list"
+
+            # Remove those files from the url list and append them to the end
+            # of the url list
+            for f in restart_consumed_files:
+                url_list.remove(f)
+                url_list.append(f)
+
+            # Add the consumed files to the consumed files set
+            self.consumed_files = set(restart_consumed_files)
+
+            # Set the current epoch to the restart epoch
+            self.current_epoch = restart_epoch
+
         if from_cloud:

             # In order to make sure data are shuffled and sharded in the
@@ -86,7 +109,10 @@ def __init__(self,
             # are required. For detail, please check our tutorial in:
             # https://pytorch.org/data/main/tutorial.html#working-with-dataloader
             dp_s3_urls = IterableWrapper(url_list)  # .list_files_by_s3()
-            sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter().cycle()
+
+            # Sharding filter sets each n-th element to be processed by the current worker
+            # in case of multiple workers. Should maintain the same order of elements.
+            sharded_s3_urls = dp_s3_urls.sharding_filter().cycle()

             # opened_files = S3FileLoader(sharded_s3_urls)
             opened_files = FSSpecFileOpener(sharded_s3_urls, mode="rb")
@@ -148,6 +174,11 @@ def __init__(self,

         self.logging_func("=====================================")

+
+    @staticmethod
+    def load_restart_consumed_files(restart_file: Path):
+        raise NotImplementedError("Loading restart files not implemented yet")
+
     def decompression_func(self, input_data):
         path, data = input_data
         # if self.debug:
@@ -244,8 +275,12 @@ def generate(self):
                     if source_path not in self.consumed_files:
                         out_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt"
                         with open(out_path, "a") as f:
-                            f.write(f"{source_path}\n")
+                            f.write(f"{self.current_epoch}_{source_path}\n")
                         self.consumed_files.add(source_path)
+                        if len(self.consumed_files) == self.num_files:
+                            self.current_epoch += 1
+                            self.logging_func(f"Epoch {self.current_epoch} completed")
+                            self.consumed_files = set()

             except StopIteration:
                 self.logger.warning(f"Reached end of dataset")
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 060837f8..73988f9c 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -609,6 +609,9 @@ def train_step_logs(
             num_consumed_files = len(current_dataset.consumed_files)
             log_entries.append(LogItem("num_consumed_files", num_consumed_files, "human_format"))

+            if hasattr(current_dataset, "current_epoch"):
+                log_entries.append(LogItem("current_epoch", current_dataset.current_epoch, "human_format"))
+
             if self.config.optimizer.clip_grad is not None:
                 log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format"))  # , ".3f"))
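
The load_restart_consumed_files stub in the diff is left as NotImplementedError. Below is a minimal sketch, not part of the diff, of how it could parse the per-rank consumed-files log, assuming the "{epoch}_{source_path}" line format that generate() now writes to consumed_files_rank_{rank}.txt. Returning only the files from the latest epoch is an assumption, based on consumed_files being reset whenever an epoch completes; lines written before this change (plain paths without an epoch prefix) are simply skipped.

from pathlib import Path

def load_restart_consumed_files(restart_file: Path) -> tuple[list[str], int]:
    # Sketch: parse a consumed_files_rank_{rank}.txt log into
    # (consumed_files, epoch) for restarting from a checkpoint.
    consumed_by_epoch: dict[int, list[str]] = {}
    with open(restart_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if "_" not in line:
                # Lines written before the epoch prefix was added are skipped here.
                continue
            # The epoch prefix is an integer, so splitting on the first "_"
            # is safe even when the source path itself contains underscores.
            epoch_str, source_path = line.split("_", 1)
            consumed_by_epoch.setdefault(int(epoch_str), []).append(source_path)
    if not consumed_by_epoch:
        return [], 0
    latest_epoch = max(consumed_by_epoch)
    return consumed_by_epoch[latest_epoch], latest_epoch

The two returned values would then be passed to PetaGraphStreamDataset as restart_consumed_files and restart_epoch when resuming training.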