From 9e2ac746a877e2b2ed6ff2ba54a04c1a22dadb84 Mon Sep 17 00:00:00 2001 From: Oleg <97077423+RobotSail@users.noreply.github.com> Date: Wed, 14 Aug 2024 16:55:29 -0400 Subject: [PATCH] Provide safeguards during training (#168) * fix: add safeguards during data processing Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com> * fix: add a safeguard for max_batch_len & max_seq_len in training We currently have certain values that need to be validated against others, but no logic to ensure that this works adequately. This commit provides a pre-training check that errors out if the value of max_batch_len is smaller than max_seq_len, since this breaks our ability to generate training batches Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com> * fix: add fallback logic to use the distributed sampler When we use the multipack sampler, it requires a certain shape of the dataset relative to the GPUs to be able to sufficiently distribute all of the samples across different nodes. When this happens, the train loaderbecomes empty which prevents us from being able to train. This commit resolves that issue by falling back to the distributed sampler when the multipack fails. Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com> --------- Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com> --- src/instructlab/training/data_process.py | 17 ++++++++++++++++- src/instructlab/training/main_ds.py | 24 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index bca8875f..739fd93a 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -204,7 +204,18 @@ def main(args: DataProcessArgs): {"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>"]} ) - data = load_dataset("json", data_files=args.data_path, split="train") + try: + data = load_dataset("json", data_files=args.data_path, split="train") + except: + # pylint: disable=raise-missing-from,broad-exception-raised + raise Exception( + "Malformed or missing data, please ensure that your dataset is not empty and correctly formatted" + ) + + if data.num_rows == 0: + raise ValueError( + "The provided dataset is empty, please make sure that your dataset contains samples and try again." + ) print(f"\033[92mtokenizing the dataset with {args.model_path} tokenizer...\033[0m") data_with_input_ids = data.map( @@ -230,6 +241,10 @@ def main(args: DataProcessArgs): f"\033[36mat {args.max_seq_len} max sequence length, the number of samples to be dropped is {num_dropped_samples}\033[0m" ) print(f"\033[36m({((num_dropped_samples / len(lens)) * 100):.2f}% of total)\033[0m") + if num_dropped_samples == len(data): + raise RuntimeError( + f"Dataset does not contain any samples containing less than {args.max_seq_len=} tokens.\nPlease consider increasing your `max_seq_len` value, or adding more samples." + ) lowest_10_percent = np.quantile(lens, (0 + np.arange(11)) / 100.0) for i, q in enumerate(lowest_10_percent): diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 0c5aa004..55aa1d67 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -555,6 +555,25 @@ def main(args): sampler=args.sampler, seed=args.seed, ) + if len(train_loader) == 0: + # this happens sometimes when we have more GPUs than data to process. In this case + # we should either alert the user to switch samplers, or do it automatically and + # warn them about it happening + print( + "\033[93mThe dataset is too small for multipack to distribute all of the samples across GPUs. Falling back to the distributed sampler!\033[0m" + ) + args.sampler = "distributed" + train_loader = setup_dataloader( + dataset, + tokenizer.pad_token_id, + num_workers=8, + is_granite=args.is_granite, + max_batch_len=args.max_batch_len, + packing_max_batch_len=packing_max_batch_len, + samples_per_gpu=args.samples_per_gpu, + sampler=args.sampler, + seed=args.seed, + ) if args.local_rank == 0: metric_logger.log_sync( @@ -585,6 +604,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: """ Wrapper around the main training job that calls torchrun. """ + # early validation logic here + if train_args.max_batch_len < train_args.max_seq_len: + raise ValueError( + f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}" + ) # process the training data if not os.path.exists(train_args.data_output_dir):