Commit

Add lock
manuelburger committed Nov 6, 2024
1 parent 3b46e69 commit e91c17d
Showing 2 changed files with 11 additions and 3 deletions.
petagraph/run_train.py (2 changes: 1 addition & 1 deletion)
@@ -197,7 +197,7 @@ def get_dataloader_from_data_stage(

# Set or read from config dataloader workers
num_dl_workers = data.num_loading_workers
- assert num_dl_workers == 0, "num_dl_workers must be 0 for the current implementation for robust data loading under streaming from AWS"
+ # assert num_dl_workers == 0, "num_dl_workers must be 0 for the current implementation for robust data loading under streaming from AWS"
log_rank(f"Using {num_dl_workers} dataloader workers", logger=logger, level=logging.INFO, rank=0)

# Set logging directories
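Note: this hunk drops the hard requirement that data loading run in the main process; with the consumed-files lock added below, num_loading_workers can now be set above zero. A minimal sketch of how a non-zero worker count is typically passed to a PyTorch DataLoader (the build_dataloader helper and its arguments are hypothetical, not part of this repository):

from torch.utils.data import DataLoader

def build_dataloader(dataset, num_dl_workers: int, micro_batch_size: int) -> DataLoader:
    # With num_workers > 0, PyTorch spawns/forks worker processes that each
    # iterate the dataset independently, which is why any shared log file
    # needs its own synchronization.
    return DataLoader(
        dataset,
        batch_size=micro_batch_size,
        num_workers=num_dl_workers,
        persistent_workers=num_dl_workers > 0,
    )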
src/nanotron/data/petagraph_dataset.py (12 changes: 10 additions & 2 deletions)
Expand Up @@ -16,6 +16,7 @@
import numpy as np
from typing import Dict, Optional, Tuple
import json
+ import multiprocessing as mp

# import zstd
import zstandard
@@ -105,7 +106,7 @@ def __init__(self,
self.log_directory = log_directory
self.num_consumed_sequences = 0
self.consumed_files_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt"

+ self.consumed_files_lock = mp.Lock()

# Save the vocabulary as json on head node
if self.rank == 0:
@@ -192,7 +193,7 @@ def __init__(self,
# sequences_unbatched = sequences_unbatched.prefetch(self.prefetch_sequences)

self.logging_func(f"Prefetching and shuffling {self.prefetch_sequences} unbatched sequences")
- sequences_unbatched = Shuffler(sequences_unbatched, buffer_size=self.prefetch_sequences)
+ sequences_unbatched = Shuffler(sequences_unbatched, buffer_size=self.prefetch_sequences).prefetch(16_000)

# sequences_crop = Mapper(sequences_unbatched, self.crop_maxlen)
# sequences_tokenized = Mapper(sequences_crop, self.tokenize_and_pad)
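Note: the changed line chains a prefetch stage (16,000 elements) onto the shuffle buffer so that sequence decoding can run ahead of training. Shuffler and .prefetch come from the surrounding data-pipeline code; the sketch below only illustrates the buffered-shuffle idea in plain Python, under the assumption that the real stage behaves similarly:

import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def buffered_shuffle(stream: Iterable[T], buffer_size: int) -> Iterator[T]:
    # Keep at most buffer_size items in memory and emit a random one each
    # time the buffer is full; this decorrelates nearby stream items without
    # materializing the whole (potentially unbounded) stream.
    buffer: list = []
    for item in stream:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            idx = random.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    random.shuffle(buffer)
    yield from buffer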
@@ -396,6 +397,9 @@ def fasta_parsing_func(self, input_data: Tuple[str, bytes]):
keep_sequences = [(path, s) for s in filter(self.length_sampling_filter, random_walk_sequences)]

# Test outputs
+ if len(keep_sequences) == 0:
+     return [[]]

assert isinstance(keep_sequences, list)
assert isinstance(keep_sequences[0], tuple) and len(keep_sequences[0]) == 2
assert isinstance(keep_sequences[0][0], str) and isinstance(keep_sequences[0][1], str)
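Note: the new early return covers the case where the length filter rejects every random-walk sequence from a file; without it, keep_sequences[0] in the assertions above would raise an IndexError. A small illustration of the same guard with a placeholder filter (the real criterion is self.length_sampling_filter):

def filter_sequences(path: str, sequences: list, min_len: int = 32) -> list:
    # Placeholder length filter; the real filter lives in the dataset class.
    kept = [(path, s) for s in sequences if len(s) >= min_len]
    if len(kept) == 0:
        # Mirror the commit: return a single empty element instead of an
        # empty list so downstream stages always receive something iterable.
        return [[]]
    return kept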
@@ -447,8 +451,12 @@ def generate(self):
# Log the consumed files
if self.log_directory is not None:
if source_path not in self.consumed_files:

+ self.consumed_files_lock.acquire()
with open(self.consumed_files_path, "a") as f:
f.write(f"{self.current_epoch}_{source_path}\n")
+ self.consumed_files_lock.release()

self.consumed_files.add(source_path)
if len(self.consumed_files) == self.num_files:
self.current_epoch += 1
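Note: the acquire/release pair serializes appends to the per-rank consumed-files log so that concurrent loader workers cannot interleave their writes. A condensed sketch of the same pattern (class and method names here are illustrative; only mp.Lock() and the write format mirror the diff), using the context-manager form, which also releases the lock if the write raises:

import multiprocessing as mp
from pathlib import Path

class ConsumedFilesLog:
    def __init__(self, log_path: Path):
        self.log_path = log_path
        # A lock created before workers are forked is inherited by them,
        # so every process synchronizes on the same underlying semaphore.
        self.lock = mp.Lock()
        self.consumed_files: set = set()

    def record(self, epoch: int, source_path: str) -> None:
        if source_path in self.consumed_files:
            return
        with self.lock:
            with open(self.log_path, "a") as f:
                f.write(f"{epoch}_{source_path}\n")
        self.consumed_files.add(source_path)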
