
Commit c1c0ce1
include gradient accumulation steps in wps
philippguevorguian committed Dec 12, 2023
1 parent 649fc83 commit c1c0ce1
Showing 1 changed file with 70 additions and 31 deletions.
src/callbacks.py: 101 changes (70 additions, 31 deletions)
@@ -24,6 +24,7 @@
 
 logger = get_logger(__name__)
 
+
 def calc_hash_for_binary_file(path):
     with open(path, "rb") as _file:
         file_content = _file.read()
@@ -111,7 +112,12 @@ def on_step_begin(self, args, state, control, model, **kwargs):
         if self._start_time is not None:
             batch_size = args.per_device_train_batch_size
             # Calculate tokens in batch
-            num_words = batch_size * self._block_size * args.world_size
+            num_words = (
+                batch_size
+                * self._block_size
+                * args.world_size
+                * args.gradient_accumulation_steps
+            )
             # Calculate time taken for this step
             elapsed_time = time.time() - self._start_time
             # Calculate words per second
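
Note: this hunk makes the words-per-second (wps) metric count every micro-batch processed between optimizer steps, multiplying the per-device batch size, block size, and world size by gradient_accumulation_steps. A minimal sketch of the resulting computation; the helper is hypothetical, args is assumed to carry the usual HuggingFace TrainingArguments fields, and block_size/start_time to come from the callback state:

import time

def words_per_second(args, block_size, start_time):
    # Tokens processed between two optimizer steps, across all devices
    # and all gradient-accumulation micro-batches.
    num_words = (
        args.per_device_train_batch_size
        * block_size
        * args.world_size
        * args.gradient_accumulation_steps
    )
    # Wall-clock time since the step began.
    elapsed_time = time.time() - start_time
    return num_words / elapsed_time
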
@@ -152,7 +158,7 @@ def __init__(self, accelerator, model_config, use_flash_attn=False):
         self.model_config = model_config
         self.train_config = model_train_configs[model_config]
         self.use_flash_attn = use_flash_attn
-
+
     def on_save(self, args, state, control, model, **kwargs):
         gc.collect()
         torch.cuda.empty_cache()
@@ -176,14 +182,16 @@ def on_save(self, args, state, control, model, **kwargs):
                 del inp["token_type_ids"]
             inp = {k: inp[k].unsqueeze(0).to(model.device) for k in inp.keys()}
             batches.append(inp)
-            if i == 20: break
+            if i == 20:
+                break
 
         checkpoint_dir = os.path.join(
             args.output_dir, f"checkpoint-{state.global_step}"
         )
 
         # accelerator = Accelerator()
-        # saved_model = load_model(f"facebook/galactica-{self.model_config}", use_flash_attn=self.use_flash_attn, dtype=torch.bfloat16)
+        # saved_model = load_model(f"facebook/galactica-{self.model_config}",
+        # use_flash_attn=self.use_flash_attn, dtype=torch.bfloat16)
         # saved_model.resize_token_embeddings(
         #     self.train_config["vocab_size"] + len(chemlactica_special_tokens)
         # )
@@ -194,25 +202,32 @@ def on_save(self, args, state, control, model, **kwargs):
         # config = AutoConfig.from_pretrained("facebook/galactica-125m")
         # config.vocab_size = 50028
         # state_dict = torch.load(f"{checkpoint_dir}/pytorch_model.bin")
-        # saved_model = CustomOPTForCausalLM.from_pretrained(None, config=config, state_dict=state_dict, use_flash_attn=self.use_flash_attn, torch_dtype=torch.bfloat16).to(model.device)
-
-        contexts = [
-            "[CLOGP 100][START_SMILES]",
-            "[SAS 1][START_SMILES]",
-            "[WEIGHT 41.123][START_SMILES]",
-            "random input",
-        ]
+        # saved_model = CustomOPTForCausalLM.from_pretrained(
+        #     None, config=config, state_dict=state_dict,
+        #     use_flash_attn=self.use_flash_attn,
+        #     torch_dtype=torch.bfloat16).to(model.device)
+
+        # contexts = [
+        #     "[CLOGP 100][START_SMILES]",
+        #     "[SAS 1][START_SMILES]",
+        #     "[WEIGHT 41.123][START_SMILES]",
+        #     "random input",
+        # ]
 
         model.eval()
         model_logits = []
-        model_gen_toks = {}
+        # model_gen_toks = {}
         with torch.no_grad():
             for i, batch in enumerate(batches):
                 model_logits.append(model(**batch))
 
         if torch.distributed.get_rank() == 0:
-            print(f"Loading from checkpoint: {checkpoint_dir} (process {torch.distributed.get_rank()})")
-            saved_model = load_model(checkpoint_dir, use_flash_attn=self.use_flash_attn, dtype=torch.bfloat16)
+            print(
+                f"Loading from checkpoint: {checkpoint_dir} (process {torch.distributed.get_rank()})"  # noqa
+            )
+            saved_model = load_model(
+                checkpoint_dir, use_flash_attn=self.use_flash_attn, dtype=torch.bfloat16
+            )
             saved_model.to(model.device)
 
             saved_model.eval()
@@ -224,13 +239,21 @@ def on_save(self, args, state, control, model, **kwargs):
 
                 logits_diff = torch.abs(out.logits - saved_md_out.logits)
                 if logits_diff.max() != 0:
-                    print(f"MISMATCH: logits difference {i} min {logits_diff.min()}, max {logits_diff.max()}, mean {logits_diff.mean()}, median {logits_diff.median()}")
+                    print(
+                        f"MISMATCH: logits difference {i} min {logits_diff.min()}, max {logits_diff.max()}, mean {logits_diff.mean()}, median {logits_diff.median()}"  # noqa
+                    )
                 loss_diff = torch.abs(out.loss - saved_md_out.loss)
                 if loss_diff != 0:
                     print(f"MISMATCH: loss difference {loss_diff}")
-                different_tokens_count = torch.sum(out.logits.softmax(-1).argmax(-1) != saved_md_out.logits.softmax(-1).argmax(-1))
+                different_tokens_count = torch.sum(
+                    out.logits.softmax(-1).argmax(-1)
+                    != saved_md_out.logits.softmax(-1).argmax(-1)
+                )
                 if different_tokens_count != 0:
-                    print("MISMATCH: different token count", different_tokens_count.item())
+                    print(
+                        "MISMATCH: different token count",
+                        different_tokens_count.item(),
+                    )
 
         # for cont in contexts:
         #     max_length = 400
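
Note: the on_save hook above is a checkpoint-consistency check. It caches roughly 21 batches, runs them through the in-memory model, reloads the just-saved checkpoint on rank 0, and prints a MISMATCH line whenever the two models disagree on logits, loss, or greedy token choices. (Softmax is monotonic, so comparing softmax(-1).argmax(-1) is equivalent to comparing argmax(-1) on the raw logits.) A condensed sketch of the idea, assuming cached batches as in the callback; the helper and its names are otherwise hypothetical:

import torch

def check_checkpoint_consistency(model, saved_model, batches):
    # Eval mode so dropout cannot cause spurious differences.
    model.eval()
    saved_model.eval()
    with torch.no_grad():
        for i, batch in enumerate(batches):
            out = model(**batch)
            saved_out = saved_model(**batch)
            logits_diff = torch.abs(out.logits - saved_out.logits)
            if logits_diff.max() != 0:
                print(f"MISMATCH: logits difference {i}, max {logits_diff.max()}")
            if out.loss is not None and torch.abs(out.loss - saved_out.loss) != 0:
                print(f"MISMATCH: loss difference {torch.abs(out.loss - saved_out.loss)}")
            # Greedy next-token picks should agree exactly after a lossless save/load.
            diff_tokens = (out.logits.argmax(-1) != saved_out.logits.argmax(-1)).sum()
            if diff_tokens != 0:
                print("MISMATCH: different token count", diff_tokens.item())
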
@@ -245,44 +268,60 @@ def on_save(self, args, state, control, model, **kwargs):
         #     saved_md_generated_toks = saved_md_generated_toks.squeeze()
         #     maximum = max(len(generated_toks), len(saved_md_generated_toks))
         #     print(len(saved_md_generated_toks), len(generated_toks), maximum)
-        #     generated_toks = F.pad(generated_toks, pad=(0, maximum - len(generated_toks)), mode='constant', value=0)
-        #     saved_md_generated_toks = F.pad(saved_md_generated_toks, pad=(0, maximum - len(saved_md_generated_toks)), mode='constant', value=0)
+        #     generated_toks = F.pad(generated_toks,
+        #     pad=(0, maximum - len(generated_toks)),
+        #     mode='constant', value=0)
+        #     saved_md_generated_toks = F.pad(saved_md_generated_toks,
+        #     pad=(0, maximum - len(saved_md_generated_toks)),
+        #     mode='constant', value=0)
         #     print(generated_toks.shape, saved_md_generated_toks.shape)
         #     diff_gen_tokens = torch.sum(generated_toks.squeeze() != saved_md_generated_toks.squeeze())
         #     print(f"Checking diff generated tokens (max_length={max_length}) '{cont}': count {diff_gen_tokens}")
 
 
         torch.distributed.barrier()
 
 
 # return the usual dataloader, no batches skipped
 accelerate.skip_first_batches = lambda dataloader, num_batches=0: dataloader
 
 
 class JsonlDatasetResumeCallback(TrainerCallback):
     def __init__(self, shared_jsonl_files):
         self.shared_jsonl_files = shared_jsonl_files
 
-    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-        if args.resume_from_checkpoint: # resume training
+    def on_train_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        if args.resume_from_checkpoint:  # resume training
             print("Resuming from saved jsonl states.")
-            with open(os.path.join(args.resume_from_checkpoint, "jsonl_states.json"), "r") as file:
+            with open(
+                os.path.join(args.resume_from_checkpoint, "jsonl_states.json"), "r"
+            ) as file:
                 jsonl_states = json.load(file)
 
             assert not self.shared_jsonl_files
             for name, state in jsonl_states.items():
                 print(f"loadeding state {name}: {state}")
                 self.shared_jsonl_files[name] = state
 
 
-    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_save(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
         assert self.shared_jsonl_files
         jsonl_states = {key: value for key, value in self.shared_jsonl_files.items()}
         print(jsonl_states)
 
         checkpoint_dir = os.path.join(
             args.output_dir, f"checkpoint-{state.global_step}"
         )
-        print(f"Saving jsonl states")
+        print("Saving jsonl states")
         for name, state in jsonl_states.items():
             print(name, state)
         with open(os.path.join(checkpoint_dir, "jsonl_states.json"), "w") as file:
-                json.dump(jsonl_states, file, indent=4)
+            json.dump(jsonl_states, file, indent=4)
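
Note: two behaviors in the last hunk deserve a gloss. First, accelerate.skip_first_batches is monkeypatched to return the dataloader untouched; since JsonlDatasetResumeCallback restores the exact jsonl read positions itself, presumably letting accelerate also fast-forward the dataloader on resume would skip data that was never seen. Second, the callback round-trips those positions through a jsonl_states.json file stored next to the checkpoint. A minimal sketch of that round-trip, with hypothetical helper names and the state shape assumed from the diff above:

import json
import os

def save_jsonl_states(shared_jsonl_files, checkpoint_dir):
    # Snapshot the shared dict of per-file read states and persist it
    # alongside the model weights.
    jsonl_states = dict(shared_jsonl_files)
    with open(os.path.join(checkpoint_dir, "jsonl_states.json"), "w") as file:
        json.dump(jsonl_states, file, indent=4)

def restore_jsonl_states(shared_jsonl_files, resume_dir):
    # Refill the (initially empty) shared dict from the saved snapshot.
    with open(os.path.join(resume_dir, "jsonl_states.json"), "r") as file:
        jsonl_states = json.load(file)
    assert not shared_jsonl_files
    for name, state in jsonl_states.items():
        shared_jsonl_files[name] = state
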
