Skip to content

Commit

Permalink
Support graph traversal in dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelburger committed Nov 5, 2024
1 parent 402162d commit 3e214c0
Showing 1 changed file with 83 additions and 5 deletions.
88 changes: 83 additions & 5 deletions src/nanotron/data/petagraph_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from io import StringIO

from nanotron.logging import log_rank
from collections import deque
from collections import defaultdict, deque


# =============================================================================
Expand All @@ -41,6 +41,7 @@ def cyclic_iter(iter):
# Constants
# =============================================================================
KMER_LENGTH = 31 # overlap
ALPHABET = {"A", "T", "C", "G"}

# =============================================================================
# Dataset class
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(self,
self.create_attention_mask = create_attention_mask
self.debug = debug
self.sampling_seq_len_inflection = sampling_seq_len_inflection


self.logger = logger
self.logging_func = partial(log_rank, logger=logger, level=logging.INFO, rank=0)
Expand All @@ -87,6 +89,7 @@ def __init__(self,
# self.logging_func(f"[PetaGraphStreamDataset] Samples per epoch: {samples_per_epoch}")
self.logging_func(f"[PetaGraphStreamDataset] Num. URLs: {len(url_list)}")
self.logging_func(f"[PetaGraphStreamDataset] From Cloud: {from_cloud}")
self.logging_func(f"[PetaGraphStreamDataset] Sampling Seq. Len. Inflection: {self.sampling_seq_len_inflection}")

self.VOCAB = vocabulary
self._pad_token_id = self.VOCAB["PAD"]
Expand Down Expand Up @@ -259,6 +262,7 @@ def decompression_func(self, input_data: Tuple[str, bytes]):

return path, decompressed_data

@staticmethod
def chop_at_first_repeated_kmer(sequence: str, k: int):
"""Chop the sequence at the first repeated kmer
Expand All @@ -278,8 +282,73 @@ def chop_at_first_repeated_kmer(sequence: str, k: int):
if kmer in kmers:
return sequence[:i + k - 1]
kmers.add(kmer)

return sequence

@staticmethod
def find_overlaps_and_build_graph(sequences, k_mer: int = 31):
    """Reconstruct the assembly (overlap) graph over ``sequences``.

    Node ``i`` gets a directed edge to node ``j`` whenever the last
    ``k_mer - 1`` characters of ``sequences[i]`` equal the first
    ``k_mer - 1`` characters of ``sequences[j]`` and ``i != j``.

    Args:
        sequences: indexable collection of strings (unitigs).
        k_mer: k-mer length; the required overlap is ``k_mer - 1``.

    Returns:
        defaultdict mapping node index -> list of successor indices.
        Every index appears as a key, possibly with an empty list, so
        downstream traversal visits isolated nodes too.
    """
    min_overlap = k_mer - 1

    # Index sequences by their (k-1)-character *prefix* so the overlap
    # check below is a single dict lookup instead of an O(n) scan.
    # (The original comment said "suffixes" — it indexes prefixes.)
    prefix_dict = defaultdict(list)
    for j, seq in enumerate(sequences):
        # NOTE(review): a sequence shorter than min_overlap indexes its
        # full text here and can produce spurious overlaps — confirm
        # inputs are pre-chopped to at least k_mer - 1 characters.
        prefix_dict[seq[:min_overlap]].append(j)

    graph = defaultdict(list)
    for i, seq in tqdm(enumerate(sequences), total=len(sequences)):
        suffix = seq[-min_overlap:]
        # Explicit assignment (even when empty) keeps node i present as
        # a key for the random-walk step that iterates graph keys.
        graph[i] = [j for j in prefix_dict[suffix] if j != i]

    return graph

# Perform random walk on the graph
@staticmethod
def dfs_paths(graph, start, path = None, all_paths = None, depth: int = 10):
"""Perform a depth-first search on the graph"""
if path is None:
path = [start] # Initialize the path with the starting node
if all_paths is None:
all_paths = [] # Initialize the list to store all paths

# If we revisit a node in the current path, it's a cycle, so we stop
if start in path[:-1]:
all_paths.append(path[:-1])
return all_paths

# Check if the current node is a leaf (no neighbors)
if start not in graph or not graph[start]:
all_paths.append(path) # Add the current path as a complete path
return all_paths

if len(path) >= depth:
all_paths.append(path)
return all_paths
# Explore each neighbor recursively, ensuring no cycles
for neighbor in graph[start]:
PetaGraphStreamDataset.dfs_paths(graph, neighbor, path + [neighbor], all_paths)

return all_paths

@staticmethod
def random_walk_graph_sequences(graph, sequences, k_mer: int = 31) -> list[str]:
"""Perform random walk on the graph"""
random_walk_sequences = []
for node in graph:
paths = PetaGraphStreamDataset.dfs_paths(graph, node)
idx = np.random.randint(len(paths))
path = paths[idx]
seq = sequences[path[0]] + "".join([sequences[p][k_mer-1:] for p in path[1:]])
# seq = seq[:MAX_SEQ_LENGTH]
random_walk_sequences.append(seq)

return random_walk_sequences


def length_sampling_filter(self, sequence: str) -> bool:
seq_len = len(sequence)
Expand Down Expand Up @@ -307,16 +376,25 @@ def fasta_parsing_func(self, input_data: Tuple[str, bytes]):

sequences = []
decoded_lines = data.decode()
sequences = [(path, str(s.seq)) for s in SeqIO.parse(StringIO(decoded_lines), "fasta")]
sequences = [str(s.seq) for s in SeqIO.parse(StringIO(decoded_lines), "fasta")]

# make sure only ALPHABET
sequences = ["".join([c for c in s if c in ALPHABET]) for s in sequences]

# Chop sequences in preparation for graph traversal
# TODO
sequences = [self.chop_at_first_repeated_kmer(s, k=KMER_LENGTH) for s in sequences]

# Construct sequence graph and perform random walks
# TODO
sequences_arr = np.array(sequences)
sequence_graph = self.find_overlaps_and_build_graph(sequences_arr, k_mer=KMER_LENGTH)
random_walk_sequences = self.random_walk_graph_sequences(sequence_graph, sequences_arr, k_mer=KMER_LENGTH)

# Sample sequences for training
keep_sequences = list(filter(self.length_sampling_filter, sequences))
keep_sequences = list(filter(self.length_sampling_filter, random_walk_sequences))

# Test outputs
assert isinstance(keep_sequences, list)
assert isinstance(keep_sequences[0], str)

return keep_sequences

Expand Down

0 comments on commit 3e214c0

Please sign in to comment.