From c1c0ce1633044de4ac70e64a09fc7ed8dc3b6443 Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Tue, 12 Dec 2023 15:20:11 +0400
Subject: [PATCH] include gradient accumulation steps in wps

---
 src/callbacks.py | 101 ++++++++++++++++++++++++++++++++---------------
 1 file changed, 70 insertions(+), 31 deletions(-)

diff --git a/src/callbacks.py b/src/callbacks.py
index a91c633..4639ce3 100644
--- a/src/callbacks.py
+++ b/src/callbacks.py
@@ -24,6 +24,7 @@
 logger = get_logger(__name__)
 
 
+
 def calc_hash_for_binary_file(path):
     with open(path, "rb") as _file:
         file_content = _file.read()
@@ -111,7 +112,12 @@ def on_step_begin(self, args, state, control, model, **kwargs):
         if self._start_time is not None:
             batch_size = args.per_device_train_batch_size
             # Calculate tokens in batch
-            num_words = batch_size * self._block_size * args.world_size
+            num_words = (
+                batch_size
+                * self._block_size
+                * args.world_size
+                * args.gradient_accumulation_steps
+            )
             # Calculate time taken for this step
             elapsed_time = time.time() - self._start_time
             # Calculate words per second
@@ -152,7 +158,7 @@ def __init__(self, accelerator, model_config, use_flash_attn=False):
         self.model_config = model_config
         self.train_config = model_train_configs[model_config]
         self.use_flash_attn = use_flash_attn
-
+
     def on_save(self, args, state, control, model, **kwargs):
         gc.collect()
         torch.cuda.empty_cache()
@@ -176,14 +182,16 @@ def on_save(self, args, state, control, model, **kwargs):
                 del inp["token_type_ids"]
             inp = {k: inp[k].unsqueeze(0).to(model.device) for k in inp.keys()}
             batches.append(inp)
-            if i == 20: break
+            if i == 20:
+                break
 
         checkpoint_dir = os.path.join(
             args.output_dir, f"checkpoint-{state.global_step}"
         )
 
         # accelerator = Accelerator()
-        # saved_model = load_model(f"facebook/galactica-{self.model_config}", use_flash_attn=self.use_flash_attn, dtype=torch.bfloat16)
+        # saved_model = load_model(f"facebook/galactica-{self.model_config}",
+        #                          use_flash_attn=self.use_flash_attn, dtype=torch.bfloat16)
         # saved_model.resize_token_embeddings(
         #     self.train_config["vocab_size"] + len(chemlactica_special_tokens)
         # )
@@ -194,25 +202,32 @@ def on_save(self, args, state, control, model, **kwargs):
         # config = AutoConfig.from_pretrained("facebook/galactica-125m")
         # config.vocab_size = 50028
         # state_dict = torch.load(f"{checkpoint_dir}/pytorch_model.bin")
-        # saved_model = CustomOPTForCausalLM.from_pretrained(None, config=config, state_dict=state_dict, use_flash_attn=self.use_flash_attn, torch_dtype=torch.bfloat16).to(model.device)
-
-        contexts = [
-            "[CLOGP 100][START_SMILES]",
-            "[SAS 1][START_SMILES]",
-            "[WEIGHT 41.123][START_SMILES]",
-            "random input",
-        ]
+        # saved_model = CustomOPTForCausalLM.from_pretrained(
+        #     None, config=config, state_dict=state_dict,
+        #     use_flash_attn=self.use_flash_attn,
+        #     torch_dtype=torch.bfloat16).to(model.device)
+
+        # contexts = [
+        #     "[CLOGP 100][START_SMILES]",
+        #     "[SAS 1][START_SMILES]",
+        #     "[WEIGHT 41.123][START_SMILES]",
+        #     "random input",
+        # ]
 
         model.eval()
         model_logits = []
-        model_gen_toks = {}
+        # model_gen_toks = {}
         with torch.no_grad():
             for i, batch in enumerate(batches):
                 model_logits.append(model(**batch))
-
+
         if torch.distributed.get_rank() == 0:
-            print(f"Loading from checkpoint: {checkpoint_dir} (process {torch.distributed.get_rank()})")
-            saved_model = load_model(checkpoint_dir, use_flash_attn=self.use_flash_attn, dtype=torch.bfloat16)
+            print(
+                f"Loading from checkpoint: {checkpoint_dir} (process {torch.distributed.get_rank()})"  # noqa
+            )
+            saved_model = load_model(
+                checkpoint_dir, use_flash_attn=self.use_flash_attn, dtype=torch.bfloat16
+            )
             saved_model.to(model.device)
             saved_model.eval()
 
@@ -224,13 +239,21 @@ def on_save(self, args, state, control, model, **kwargs):
                 logits_diff = torch.abs(out.logits - saved_md_out.logits)
                 if logits_diff.max() != 0:
-                    print(f"MISMATCH: logits difference {i} min {logits_diff.min()}, max {logits_diff.max()}, mean {logits_diff.mean()}, median {logits_diff.median()}")
+                    print(
+                        f"MISMATCH: logits difference {i} min {logits_diff.min()}, max {logits_diff.max()}, mean {logits_diff.mean()}, median {logits_diff.median()}"  # noqa
+                    )
 
                 loss_diff = torch.abs(out.loss - saved_md_out.loss)
                 if loss_diff != 0:
                     print(f"MISMATCH: loss difference {loss_diff}")
-                different_tokens_count = torch.sum(out.logits.softmax(-1).argmax(-1) != saved_md_out.logits.softmax(-1).argmax(-1))
+                different_tokens_count = torch.sum(
+                    out.logits.softmax(-1).argmax(-1)
+                    != saved_md_out.logits.softmax(-1).argmax(-1)
+                )
                 if different_tokens_count != 0:
-                    print("MISMATCH: different token count", different_tokens_count.item())
+                    print(
+                        "MISMATCH: different token count",
+                        different_tokens_count.item(),
+                    )
 
             # for cont in contexts:
             #     max_length = 400
@@ -245,26 +268,37 @@ def on_save(self, args, state, control, model, **kwargs):
             # saved_md_generated_toks = saved_md_generated_toks.squeeze()
             # maximum = max(len(generated_toks), len(saved_md_generated_toks))
             # print(len(saved_md_generated_toks), len(generated_toks), maximum)
-            # generated_toks = F.pad(generated_toks, pad=(0, maximum - len(generated_toks)), mode='constant', value=0)
-            # saved_md_generated_toks = F.pad(saved_md_generated_toks, pad=(0, maximum - len(saved_md_generated_toks)), mode='constant', value=0)
+            # generated_toks = F.pad(generated_toks,
+            #                        pad=(0, maximum - len(generated_toks)),
+            #                        mode='constant', value=0)
+            # saved_md_generated_toks = F.pad(saved_md_generated_toks,
+            #                                 pad=(0, maximum - len(saved_md_generated_toks)),
+            #                                 mode='constant', value=0)
             # print(generated_toks.shape, saved_md_generated_toks.shape)
-            # diff_gen_tokens = torch.sum(generated_toks.squeeze() != saved_md_generated_toks.squeeze())
-            # print(f"Checking diff generated tokens (max_length={max_length}) '{cont}': count {diff_gen_tokens}")
-
+
         torch.distributed.barrier()
 
 
 # return the usual dataloader, no batches skipped
 accelerate.skip_first_batches = lambda dataloader, num_batches=0: dataloader
 
+
 class JsonlDatasetResumeCallback(TrainerCallback):
     def __init__(self, shared_jsonl_files):
         self.shared_jsonl_files = shared_jsonl_files
 
-    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-        if args.resume_from_checkpoint: # resume training
+    def on_train_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        if args.resume_from_checkpoint:  # resume training
             print("Resuming from saved jsonl states.")
-            with open(os.path.join(args.resume_from_checkpoint, "jsonl_states.json"), "r") as file:
+            with open(
+                os.path.join(args.resume_from_checkpoint, "jsonl_states.json"), "r"
+            ) as file:
                 jsonl_states = json.load(file)
 
             assert not self.shared_jsonl_files
@@ -272,8 +306,13 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
                 print(f"loadeding state {name}: {state}")
                 self.shared_jsonl_files[name] = state
 
-
-    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_save(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
         assert self.shared_jsonl_files
         jsonl_states = {key: value for key, value in self.shared_jsonl_files.items()}
         print(jsonl_states)
@@ -281,8 +320,8 @@ def on_save(self, args: TrainingArguments, state: TrainerState, control: Trainer
         checkpoint_dir = os.path.join(
             args.output_dir, f"checkpoint-{state.global_step}"
        )
-        print(f"Saving jsonl states")
+        print("Saving jsonl states")
         for name, state in jsonl_states.items():
             print(name, state)
         with open(os.path.join(checkpoint_dir, "jsonl_states.json"), "w") as file:
-            json.dump(jsonl_states, file, indent=4)
\ No newline at end of file
+            json.dump(jsonl_states, file, indent=4)
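
As a rough illustration of the arithmetic behind the on_step_begin change above, here is a minimal
sketch, not part of the patch. The concrete numbers are assumptions chosen only for the example;
per_device_train_batch_size, block_size, world_size and grad_accum_steps stand in for the
corresponding args/self attributes used in src/callbacks.py.

    import time

    # Assumed, illustrative values (not taken from any real training config).
    per_device_train_batch_size = 4
    block_size = 2048        # tokens per sample
    world_size = 8           # data-parallel processes
    grad_accum_steps = 16    # micro-batches folded into one optimizer step

    # Tokens counted per optimizer step, mirroring the patched num_words formula:
    # batch_size * block_size * world_size * gradient_accumulation_steps
    num_words = per_device_train_batch_size * block_size * world_size * grad_accum_steps

    start_time = time.time()
    # ... one optimizer step of training would run here ...
    elapsed_time = max(time.time() - start_time, 1e-9)  # guard against division by zero

    words_per_second = num_words / elapsed_time
    print(f"{num_words} tokens per optimizer step -> {words_per_second:.0f} wps")

With these assumed values each optimizer step accounts for 4 * 2048 * 8 * 16 = 1,048,576 tokens,
which is what the callback now divides by the measured step time.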
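Along the same lines, a self-contained sketch of the save-then-reload consistency check that the
on_save callback performs, using a toy torch module in place of the checkpoint reloaded via
load_model(); the file name "toy_checkpoint.pt" and the nn.Linear stand-in are hypothetical and
exist only for this example.

    import torch
    import torch.nn as nn

    # Toy stand-in for the trained model; the real callback reloads a full
    # checkpoint directory with load_model() rather than load_state_dict().
    model = nn.Linear(16, 32)
    torch.save(model.state_dict(), "toy_checkpoint.pt")

    saved_model = nn.Linear(16, 32)
    saved_model.load_state_dict(torch.load("toy_checkpoint.pt"))

    model.eval()
    saved_model.eval()
    with torch.no_grad():
        batch = torch.randn(4, 16)
        out = model(batch)
        saved_md_out = saved_model(batch)

    # Mirror the callback's check: any nonzero difference means the weights on
    # disk do not reproduce the in-memory model.
    logits_diff = torch.abs(out - saved_md_out)
    if logits_diff.max() != 0:
        print(
            f"MISMATCH: min {logits_diff.min()}, max {logits_diff.max()}, "
            f"mean {logits_diff.mean()}"
        )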