Commit 66ffb67

fixed hf pipeline script
onadegibert committed Sep 24, 2024
1 parent 176c55b commit 66ffb67
Showing 1 changed file with 63 additions and 27 deletions.
pipeline/translate/translate_hf.py (90 changes: 63 additions & 27 deletions)
@@ -36,14 +36,39 @@ def convert_simple_dict(d):
    return {key: ast.literal_eval(value) if isinstance(value, str) and value.isdigit() else value for key, value in d.items()}

class TokenizedDataset(Dataset):
-    def __init__(self, tokenized_inputs):
+    def __init__(self, tokenized_inputs, sentence_ids):
        self.tokenized_inputs = tokenized_inputs
+        self.sentence_ids = sentence_ids

    def __len__(self):
        return len(self.tokenized_inputs['input_ids'])

    def __getitem__(self, idx):
-        return {key: val[idx] for key, val in self.tokenized_inputs.items()}
+        # Return both the tokenized input and the sentence ID
+        item = {key: val[idx] for key, val in self.tokenized_inputs.items()}
+        item['sentence_id'] = self.sentence_ids[idx]
+        return item

+def merge_temp_files(output_file, num_processes, num_return_sequences):
+    from collections import defaultdict
+
+    all_sentences = defaultdict(list)
+
+    # Collect sentences by their ID
+    for rank in range(num_processes):
+        temp_file = f"{output_file}.rank{rank}.tmp"
+        with open(temp_file, 'r', encoding='utf-8') as infile:
+            for line in infile:
+                sentence_id = line.split(' ||| ')[0]  # Get the sentence ID
+                all_sentences[sentence_id].append(line.strip())
+        os.remove(f"{output_file}.rank{rank}.tmp")
+
+    # Write the sorted sentences to the output file, keeping only the first num_return_sequences for each ID
+    with open(output_file, 'w', encoding='utf-8') as outfile:
+        for sentence_id, sentences in sorted(all_sentences.items(), key=lambda x: int(x[0])):
+            for sentence in sentences[:num_return_sequences]:  # Keep only the first num_return_sequences sentences for each ID
+                outfile.write(f"{sentence}\n")
+
+
def main():
    #os.environ['HF_HOME'] = args.modeldir
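As an aside (not part of the commit): a minimal sketch of the temp-file format the new merge step consumes, with hypothetical file names and contents. It assumes merge_temp_files as defined above is in scope, and that each rank has written "<id> ||| <text>" lines.

import os

# Two hypothetical per-rank temp files, each holding "<id> ||| <text>" lines:
os.makedirs("demo", exist_ok=True)
with open("demo/out.txt.rank0.tmp", "w", encoding="utf-8") as f:
    f.write("1 ||| hyp C\n0 ||| hyp A\n")
with open("demo/out.txt.rank1.tmp", "w", encoding="utf-8") as f:
    f.write("0 ||| hyp B\n1 ||| hyp D\n")

merge_temp_files("demo/out.txt", num_processes=2, num_return_sequences=2)

# demo/out.txt now holds the lines regrouped and sorted by integer ID
# (the rank files themselves are deleted by the merge):
#   0 ||| hyp A
#   0 ||| hyp B
#   1 ||| hyp C
#   1 ||| hyp D

The cap of num_return_sequences lines per ID presumably guards against duplicates when the number of samples does not divide evenly across processes.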
@@ -52,9 +77,6 @@ def main():
    # Create a DataLoaderConfiguration object
    dataloader_config = DataLoaderConfiguration(split_batches=True)

-    # Pass the config to the Accelerator
-    accelerator = Accelerator(device_placement=True, dataloader_config=dataloader_config)
-
    print(f"Translating {args.filein} from {args.src} to {args.trg} with {args.modelname}...")

    print("PyTorch version:", torch.__version__)
@@ -73,7 +95,6 @@ def main():
    model_class = getattr(module, class_name)

    model = model_class.from_pretrained(model_name, trust_remote_code=True)
-    model = accelerator.prepare(model) # Prepare model for distributed inference

    # Mapping target languages
    src_lang = lang_tags.get(args.src, None)
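The two deletions above pair with additions in the next hunk: the commit moves Accelerator creation and both prepare() calls to after the dataset is built. A hedged sketch of the resulting setup order, not part of the commit (prepare_for_inference is a hypothetical name; the split_batches config and batch size mirror the diff):

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration
from torch.utils.data import DataLoader

def prepare_for_inference(dataset, model, batch_size=32):
    # split_batches=True treats batch_size as the global batch:
    # each process receives batch_size // num_processes items per step
    accelerator = Accelerator(
        device_placement=True,
        dataloader_config=DataLoaderConfiguration(split_batches=True),
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    dataloader = accelerator.prepare(dataloader)  # shard batches across ranks
    model = accelerator.prepare(model)  # move to device; wrapped (e.g. DDP) when multi-GPU
    return accelerator, dataloader, model

The DDP wrapping is why the generation loop below calls model.module.generate rather than model.generate.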
@@ -96,63 +117,70 @@ def main():
    # Read the input text
    with open(args.filein, 'r', encoding='utf-8') as infile:
        text = infile.readlines()

+    sentence_ids = list(range(len(text)))  # Create a list of sentence IDs (0, 1, 2, ...)
+
    # Format sentences with prompt
    formatted_text = [prompt.format(src_lang=src_lang, tgt_lang=tgt_lang, source=t) for t in text]

    # Tokenize all the inputs at once
    tokenized_inputs = tokenizer(formatted_text, return_tensors='pt', padding=True)

-    # Prepare dataset and dataloader
-    dataset = TokenizedDataset(tokenized_inputs)
-    batch_size = 32
+    # Create the dataset with tokenized inputs and sentence IDs
+    dataset = TokenizedDataset(tokenized_inputs, sentence_ids)
+
+    # Pass the config to the Accelerator
+    accelerator = Accelerator(device_placement=True, dataloader_config=dataloader_config)
+
+    # Get the rank of the process (0 to 3 for 4 GPUs)
+    rank = accelerator.process_index
+    temp_file = f"{args.fileout}.rank{rank}.tmp"  # Create a temporary file for each process
+
+    batch_size = 32  # one batch at a time, 8 sentences per GPU
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
-    dataloader = accelerator.prepare(dataloader)

+    dataloader = accelerator.prepare(dataloader)  # Prepare dataloader for distributed inference
+    model = accelerator.prepare(model)  # Prepare model for distributed inference

    print("Starting translations...")

    # Accumulate multiple sentences in memory and write them to the file in larger batches
    buffer_size = 1000000
    buffer = []

-    with open(args.fileout, 'w', encoding='utf-8') as outfile:
+    with open(temp_file, 'w', encoding='utf-8') as outfile:
        start_time = time.time()
        sentence_counter = 0

        for batch in dataloader:

+            sentence_ids = batch.pop('sentence_id').tolist()  # Extract sentence IDs from the batch
+
+            augmented_sentence_ids = [x for x in sentence_ids for _ in range(num_return_sequences)]
+
            # Generate output
            translated_batch = model.module.generate(
                **batch,
                num_return_sequences=num_return_sequences,
                num_beams=num_return_sequences,
                **config,
            )

            # Decode the output
            translated_batch = tokenizer.batch_decode(translated_batch, skip_special_tokens=True)

            # Write each translated sentence to the buffer
-            for i, sentence in enumerate(translated_batch):
-                curr_prompt = prompt.format(src_lang=src_lang, tgt_lang=tgt_lang, source=text[i])
+            for id, sentence in zip(augmented_sentence_ids, translated_batch):
+                curr_prompt = prompt.format(src_lang=src_lang, tgt_lang=tgt_lang, source=text[id])
                sentence = sentence.replace(curr_prompt, "")

-                # Add to buffer
-                buffer.append(f"{sentence_counter} ||| {sentence}\n")
-
-                # Increment sentence counter every num_return_sequences sentences
-                if (i + 1) % num_return_sequences == 0:
-                    sentence_counter += 1
+                buffer.append(f"{id} ||| {sentence}\n")

                # When buffer is full, write it to file and clear the buffer
                if len(buffer) >= buffer_size:
                    outfile.writelines(buffer)  # Write buffer to file
                    buffer = []  # Clear the buffer

            # Print progress every 50 sentences
            if sentence_counter % 50 == 0:
                print(f"Translated {sentence_counter} sentences...")

        # If there are any remaining sentences in the buffer, flush them to the file
        if buffer:
            outfile.writelines(buffer)
@@ -165,5 +193,13 @@ def main():
print(f"Translation complete. Translating {len(text)} sentences took {total_time:.2f} seconds.")
print(f"{translations_per_second:.2f} translations/second")

accelerator.wait_for_everyone()

# Ensure all processes are done, then merge
if accelerator.is_main_process: # Only do the merging from the main process
merge_temp_files(args.fileout, accelerator.num_processes, num_return_sequences)
print(f"Merged all files into {args.fileout}")


if __name__ == "__main__":
main()
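A small worked example (hypothetical values, not part of the commit) of the ID bookkeeping that replaced the old sentence_counter: with num_beams == num_return_sequences, generate() returns num_return_sequences hypotheses per input in input order, so repeating each sentence ID that many times aligns IDs with the decoded outputs.

num_return_sequences = 8
sentence_ids = [0, 1]  # IDs popped from one batch by the generation loop

augmented_sentence_ids = [x for x in sentence_ids for _ in range(num_return_sequences)]
assert augmented_sentence_ids == [0] * 8 + [1] * 8

# zip(augmented_sentence_ids, translated_batch) then tags every hypothesis
# with its source sentence ID, producing the "<id> ||| <text>" lines that
# merge_temp_files later regroups and sorts.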
