Provide safeguards during training (#168)
* fix: add safeguards during data processing

Signed-off-by: Oleg S <[email protected]>

* fix: add a safeguard for max_batch_len & max_seq_len in training

Some configuration values need to be validated against one another, but there is currently no logic
to enforce this. This commit adds a pre-training check that errors out if max_batch_len is smaller
than max_seq_len, since that combination makes it impossible to generate training batches.

Signed-off-by: Oleg S <[email protected]>
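
As context for the check described above, here is a minimal, self-contained sketch of sequence packing under a per-batch token budget. It is illustrative only (the helper `pack_first_fit` is made up and this is not the repository's multipack sampler), but it shows why a `max_batch_len` below `max_seq_len` breaks batch generation: a sample that is legal under `max_seq_len` can be larger than any batch is allowed to be.

```python
# Illustrative sketch only -- not the library's multipack implementation.
# Batches are filled under a token budget of max_batch_len, so a sample that
# max_seq_len permits but that exceeds the budget can never be placed.
def pack_first_fit(sample_lens: list[int], max_batch_len: int) -> list[list[int]]:
    batches: list[list[int]] = []
    for length in sample_lens:
        if length > max_batch_len:
            raise ValueError(
                f"a {length}-token sample can never fit a {max_batch_len}-token batch"
            )
        for batch in batches:  # first-fit: reuse the first batch with room left
            if sum(batch) + length <= max_batch_len:
                batch.append(length)
                break
        else:
            batches.append([length])
    return batches


if __name__ == "__main__":
    # max_seq_len = 4096 would allow the 3000-token sample below,
    # but max_batch_len = 2048 makes it unpackable, so the sketch raises.
    print(pack_first_fit([512, 1024, 900], max_batch_len=2048))  # packs fine
    pack_first_fit([512, 3000, 1024], max_batch_len=2048)        # raises ValueError
```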

* fix: add fallback logic to use the distributed sampler

When we use the multipack sampler, the dataset must be large enough relative to the number of
GPUs for the sampler to distribute samples across all of the nodes. When the dataset is too small,
the train loader becomes empty, which prevents us from being able to train.

This commit resolves that issue by falling back to the distributed sampler when multipack
fails.

Signed-off-by: Oleg S <[email protected]>
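
The fallback described above can be pictured with a toy, self-contained example. The helper below is hypothetical (it is not the repository's `setup_dataloader`, and real multipack distribution is more involved than integer division), but it captures the failure mode: when the packed batches cannot cover every rank, the loader comes up empty and we switch samplers.

```python
# Toy illustration only -- `batches_per_rank` is a made-up helper.
# Multipack hands out whole batches per rank, so a tiny dataset spread over
# many GPUs can leave a rank with zero batches to iterate.
def batches_per_rank(num_packed_batches: int, world_size: int) -> int:
    return num_packed_batches // world_size  # remainders are dropped


world_size = 8           # number of GPUs / ranks
num_packed_batches = 5   # a small dataset packs into only 5 batches

sampler = "multipack"
if batches_per_rank(num_packed_batches, world_size) == 0:
    print("multipack would give this rank no batches; falling back")
    sampler = "distributed"  # the distributed sampler can still serve every rank

print(f"using the {sampler!r} sampler")
```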

---------

Signed-off-by: Oleg S <[email protected]>
RobotSail authored Aug 14, 2024
1 parent b1a54c1 commit 9e2ac74
Showing 2 changed files with 40 additions and 1 deletion.
17 changes: 16 additions & 1 deletion src/instructlab/training/data_process.py
@@ -204,7 +204,18 @@ def main(args: DataProcessArgs):
         {"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>"]}
     )

-    data = load_dataset("json", data_files=args.data_path, split="train")
+    try:
+        data = load_dataset("json", data_files=args.data_path, split="train")
+    except:
+        # pylint: disable=raise-missing-from,broad-exception-raised
+        raise Exception(
+            "Malformed or missing data, please ensure that your dataset is not empty and correctly formatted"
+        )
+
+    if data.num_rows == 0:
+        raise ValueError(
+            "The provided dataset is empty, please make sure that your dataset contains samples and try again."
+        )

     print(f"\033[92mtokenizing the dataset with {args.model_path} tokenizer...\033[0m")
     data_with_input_ids = data.map(
@@ -230,6 +241,10 @@ def main(args: DataProcessArgs):
         f"\033[36mat {args.max_seq_len} max sequence length, the number of samples to be dropped is {num_dropped_samples}\033[0m"
     )
     print(f"\033[36m({((num_dropped_samples / len(lens)) * 100):.2f}% of total)\033[0m")
+    if num_dropped_samples == len(data):
+        raise RuntimeError(
+            f"Dataset does not contain any samples containing less than {args.max_seq_len=} tokens.\nPlease consider increasing your `max_seq_len` value, or adding more samples."
+        )

     lowest_10_percent = np.quantile(lens, (0 + np.arange(11)) / 100.0)
     for i, q in enumerate(lowest_10_percent):
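For quick reference, the two guards added to data_process.py above can be restated as a short standalone snippet with toy values. This is illustrative only and not the module's code; it simply walks through the same empty-dataset and all-samples-dropped conditions.

```python
# Standalone restatement of the data_process.py guards, using toy numbers.
max_seq_len = 1024
lens = [1500, 2048, 4096]  # token count of each sample after tokenization

if len(lens) == 0:
    raise ValueError("The provided dataset is empty; add samples and try again.")

num_dropped = sum(length > max_seq_len for length in lens)
print(f"dropping {num_dropped} of {len(lens)} samples ({num_dropped / len(lens):.0%})")

if num_dropped == len(lens):
    # every sample exceeds max_seq_len, so nothing would survive processing
    raise RuntimeError(
        "No sample fits within max_seq_len; increase max_seq_len or add shorter samples."
    )
```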
24 changes: 24 additions & 0 deletions src/instructlab/training/main_ds.py
@@ -555,6 +555,25 @@ def main(args):
         sampler=args.sampler,
         seed=args.seed,
     )
+    if len(train_loader) == 0:
+        # this happens sometimes when we have more GPUs than data to process. In this case
+        # we should either alert the user to switch samplers, or do it automatically and
+        # warn them about it happening
+        print(
+            "\033[93mThe dataset is too small for multipack to distribute all of the samples across GPUs. Falling back to the distributed sampler!\033[0m"
+        )
+        args.sampler = "distributed"
+        train_loader = setup_dataloader(
+            dataset,
+            tokenizer.pad_token_id,
+            num_workers=8,
+            is_granite=args.is_granite,
+            max_batch_len=args.max_batch_len,
+            packing_max_batch_len=packing_max_batch_len,
+            samples_per_gpu=args.samples_per_gpu,
+            sampler=args.sampler,
+            seed=args.seed,
+        )

     if args.local_rank == 0:
         metric_logger.log_sync(
@@ -585,6 +604,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     """
     Wrapper around the main training job that calls torchrun.
     """
+    # early validation logic here
+    if train_args.max_batch_len < train_args.max_seq_len:
+        raise ValueError(
+            f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
+        )

     # process the training data
     if not os.path.exists(train_args.data_output_dir):
