Skip to content

Commit

Permalink
Tweak logged progress to support continuation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelburger committed Aug 1, 2024
1 parent a9f30c2 commit fa7e26c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
2 changes: 1 addition & 1 deletion petagraph/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def get_dataloader_from_data_stage(
num_dl_workers = data.num_loading_workers
log_rank(f"Using {num_dl_workers} dataloader workers", logger=logger, level=logging.INFO, rank=0)

# Set loggin directories
# Set logging directories
logging_directory = Path(trainer.config.checkpoints.checkpoints_path)
consumed_files_directory = logging_directory / "consumed_files"
with main_rank_first(trainer.parallel_context.world_pg):
Expand Down
41 changes: 38 additions & 3 deletions src/nanotron/data/petagraph_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(self,
log_directory: Path = None,
rank: int = 0,
packed: bool = False,
restart_consumed_files: list[str] = None,
restart_epoch: int = 0,
):

self.samples_per_epoch = samples_per_epoch
Expand All @@ -67,7 +69,7 @@ def __init__(self,
self.logging_func = partial(log_rank, logger=logger, level=logging.INFO, rank=0)
self.logging_func("=====================================")
self.logging_func(f"[PetaGraphStreamDataset] Creating PetaGraphStreamDataset with maxlen {maxlen}")
self.logging_func(f"[PetaGraphStreamDataset] Samples per epoch: {samples_per_epoch}")
# self.logging_func(f"[PetaGraphStreamDataset] Samples per epoch: {samples_per_epoch}")
self.logging_func(f"[PetaGraphStreamDataset] Num. URLs: {len(url_list)}")
self.logging_func(f"[PetaGraphStreamDataset] From Cloud: {from_cloud}")

Expand All @@ -77,16 +79,40 @@ def __init__(self,
self._bos_token_id = self.VOCAB["BOS"]
self._unk_token_id = self.VOCAB["UNK"]

self.num_files = len(url_list)
self.current_epoch = 0

# TODO: Take list of already consumed lists and remove them from the
# url list, to continue training from the last checkpoint properly
if restart_consumed_files is not None:

# All files in restart_consumed_files should be present in the url_list
for f in restart_consumed_files:
assert f in url_list, f"File {f} from restart not found in the url_list"

# Remove those files from the url list and append them to the end
# of the url list
for f in restart_consumed_files:
url_list.remove(f)
url_list.append(f)

# Add the consumed files to the consumed files set
self.consumed_files = set(restart_consumed_files)

# Set the current epoch to the restart epoch
self.current_epoch = restart_epoch


if from_cloud:
# In order to make sure data are shuffled and sharded in the
# distributed environment, `shuffle` and `sharding_filter`
# are required. For detail, please check our tutorial in:
# https://pytorch.org/data/main/tutorial.html#working-with-dataloader
dp_s3_urls = IterableWrapper(url_list) # .list_files_by_s3()
sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter().cycle()

# Sharding filter sets each n-th element to be processed by the current worker
# in case of multiple workers. Should maintain the same order of elements.
sharded_s3_urls = dp_s3_urls.sharding_filter().cycle()

# opened_files = S3FileLoader(sharded_s3_urls)
opened_files = FSSpecFileOpener(sharded_s3_urls, mode="rb")
Expand Down Expand Up @@ -148,6 +174,11 @@ def __init__(self,

self.logging_func("=====================================")


@staticmethod
def load_restart_consumed_files(restart_file: Path):
    """Load the list of already-consumed data files from a restart log.

    Placeholder for checkpoint-continuation support: presumably this will
    parse the per-rank ``consumed_files_rank_<rank>.txt`` log (appended to
    during ``generate``) so a resumed run can pass the result in as
    ``restart_consumed_files`` — TODO confirm intended format.

    Args:
        restart_file: Path to the consumed-files log to parse.

    Raises:
        NotImplementedError: always; loading is not implemented yet.
    """
    raise NotImplementedError("Loading restart files not implemented yet")

def decompression_func(self, input_data):
path, data = input_data
# if self.debug:
Expand Down Expand Up @@ -244,8 +275,12 @@ def generate(self):
if source_path not in self.consumed_files:
out_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt"
with open(out_path, "a") as f:
f.write(f"{source_path}\n")
f.write(f"{self.current_epoch}_{source_path}\n")
self.consumed_files.add(source_path)
if len(self.consumed_files) == self.num_files:
self.current_epoch += 1
self.logging_func(f"Epoch {self.current_epoch} completed")
self.consumed_files = set()

except StopIteration:
self.logger.warning(f"Reached end of dataset")
Expand Down
3 changes: 3 additions & 0 deletions src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,9 @@ def train_step_logs(
num_consumed_files = len(current_dataset.consumed_files)
log_entries.append(LogItem("num_consumed_files", num_consumed_files, "human_format"))

if hasattr(current_dataset, "current_epoch"):
log_entries.append(LogItem("current_epoch", current_dataset.current_epoch, "human_format"))

if self.config.optimizer.clip_grad is not None:
log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f"))

Expand Down

0 comments on commit fa7e26c

Please sign in to comment.