diff --git a/src/sparseml/transformers/finetune/data/data_args.py b/src/sparseml/transformers/finetune/data/data_args.py
index c332ac65bb7..9517a19e4de 100644
--- a/src/sparseml/transformers/finetune/data/data_args.py
+++ b/src/sparseml/transformers/finetune/data/data_args.py
@@ -118,6 +118,12 @@ class DataTrainingArguments(CustomDataTrainingArguments):
         default=512,
         metadata={"help": "Number of samples to use for one-shot calibration"},
     )
+    shuffle_calibration_samples: Optional[bool] = field(
+        default=True,
+        metadata={
+            "help": "whether to shuffle the dataset before selecting calibration data"
+        },
+    )
     streaming: Optional[bool] = field(
         default=False,
         metadata={"help": "True to stream data from a cloud dataset"},
diff --git a/src/sparseml/transformers/finetune/data/data_helpers.py b/src/sparseml/transformers/finetune/data/data_helpers.py
index 243f4085023..8fa8eb9bca3 100644
--- a/src/sparseml/transformers/finetune/data/data_helpers.py
+++ b/src/sparseml/transformers/finetune/data/data_helpers.py
@@ -18,7 +18,7 @@

 import torch
 from datasets import Dataset, load_dataset
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.data import default_data_collator


@@ -36,6 +36,7 @@
 def format_calibration_data(
     tokenized_dataset: Dataset,
     num_calibration_samples: Optional[int] = None,
+    do_shuffle: bool = True,
     collate_fn: Callable = default_data_collator,
     accelerator: Optional[Any] = None,
 ) -> List[torch.Tensor]:
@@ -45,6 +46,8 @@ def format_calibration_data(

     :param tokenized_dataset: dataset to convert to dataloader
     :param num_calibration_samples: number of data samples to convert
+    :param do_shuffle: whether to shuffle the dataset before selecting calibration
+        samples, true by default
     :param collate_fn: optional custom collate function, or use default
     :param accelerator: optional accelerator for if preparing in FSDP mode
     :return: list of trimmed calibration data tensors
@@ -58,17 +61,20 @@ def format_calibration_data(
                 f"the provided dataset only has {safe_calibration_samples}. "
" ) - shuffled_calibration = tokenized_dataset.shuffle() - shuffled_calibration = shuffled_calibration.select(range(safe_calibration_samples)) + if do_shuffle: + tokenized_dataset = tokenized_dataset.shuffle() + tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples)) dataloader_params = { "batch_size": 1, - "sampler": RandomSampler(shuffled_calibration), + "sampler": RandomSampler(tokenized_calibration) + if do_shuffle + else SequentialSampler(tokenized_calibration), "collate_fn": collate_fn, "pin_memory": True, } - calib_dataloader = DataLoader(shuffled_calibration, **dataloader_params) + calib_dataloader = DataLoader(tokenized_calibration, **dataloader_params) if accelerator: calib_dataloader = accelerator.prepare(calib_dataloader) diff --git a/src/sparseml/transformers/finetune/runner.py b/src/sparseml/transformers/finetune/runner.py index e970e3b7264..df1aa0ca967 100644 --- a/src/sparseml/transformers/finetune/runner.py +++ b/src/sparseml/transformers/finetune/runner.py @@ -19,7 +19,6 @@ from typing import List, Optional import torch -from torch.nn import Module from torch.utils.data import Dataset from transformers import AutoTokenizer @@ -72,7 +71,6 @@ def __init__( data_args: "DataTrainingArguments", model_args: "ModelArguments", training_args: "TrainingArguments", - model: Module, ): self._data_args = data_args self._model_args = model_args @@ -121,9 +119,15 @@ def _get_split_name(inp_str): tokenizer=tokenizer, ) - raw_dataset = dataset_manager.get_raw_dataset(self._model_args.cache_dir) - tokenized_dataset = dataset_manager.tokenize_and_process(raw_dataset) - tokenized_datasets[split_name] = tokenized_dataset + dataset = self._data_args.dataset + if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: + # dataset is already tokenized + tokenized_datasets[split_name] = dataset + else: + # dataset needs to be tokenized + raw_dataset = dataset_manager.get_raw_dataset() + tokenized_dataset = dataset_manager.tokenize_and_process(raw_dataset) + tokenized_datasets[split_name] = tokenized_dataset self.datasets = make_dataset_splits( tokenized_datasets, @@ -154,6 +158,7 @@ def one_shot(self, stage: Optional[str] = None): calib_data = format_calibration_data( tokenized_dataset=self.get_dataset_split("calibration"), num_calibration_samples=self._data_args.num_calibration_samples, + do_shuffle=self._data_args.shuffle_calibration_samples, accelerator=self.trainer.accelerator, ) diff --git a/src/sparseml/transformers/finetune/text_generation.py b/src/sparseml/transformers/finetune/text_generation.py index 6005c26f034..a25778aa5fa 100644 --- a/src/sparseml/transformers/finetune/text_generation.py +++ b/src/sparseml/transformers/finetune/text_generation.py @@ -319,10 +319,7 @@ def main( # Load datasets stage_runner = StageRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args, - model=model, + model_args=model_args, data_args=data_args, training_args=training_args ) stage_runner.populate_datasets(tokenizer=tokenizer) train_dataset = stage_runner.get_dataset_split("train") diff --git a/tests/sparseml/transformers/finetune/data/test_dataset_loading.py b/tests/sparseml/transformers/finetune/data/test_dataset_loading.py index 6493689416f..cd2c230b581 100644 --- a/tests/sparseml/transformers/finetune/data/test_dataset_loading.py +++ b/tests/sparseml/transformers/finetune/data/test_dataset_loading.py @@ -14,10 +14,12 @@ # limitations under the License. 

 import pytest
-from datasets import IterableDataset
+import torch
+from datasets import IterableDataset, load_dataset

 from sparseml.transformers.finetune.data import TextGenerationDataset
 from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
+from sparseml.transformers.finetune.data.data_helpers import format_calibration_data
 from sparseml.transformers.finetune.model_args import ModelArguments
 from sparseml.transformers.finetune.runner import StageRunner
 from sparseml.transformers.finetune.training_args import TrainingArguments
@@ -229,13 +231,54 @@ def test_split_loading(split_def, tiny_llama_tokenizer):
     training_args = TrainingArguments(do_train=True, output_dir="dummy")
     model_args = ModelArguments(model=None)
     stage_runner = StageRunner(
-        model_args=model_args,
-        data_args=data_args,
-        training_args=training_args,
-        model=None,
+        model_args=model_args, data_args=data_args, training_args=training_args
     )
     stage_runner.populate_datasets(tokenizer=tiny_llama_tokenizer)
     train_dataset = stage_runner.get_dataset_split("train")

     assert train_dataset is not None
     assert isinstance(train_dataset[0], dict)
+
+
+def test_load_tokenized_data(tiny_llama_tokenizer):
+    dataset = load_dataset("garage-bAInd/Open-Platypus")["train"]
+    NUM_CALIB_SAMPS = 256
+    MAX_SEQ_LEN = 512
+    dataset = dataset.shuffle(seed=42).select(range(NUM_CALIB_SAMPS))
+
+    def preprocess(sample):
+        concat_text = "INPUT: " + sample.get("input", "")
+        concat_text += "INSTRUCTIONS: " + sample.get("instruction", "")
+        concat_text += "OUTPUT: " + sample.get("output", "")
+
+        return tiny_llama_tokenizer(
+            concat_text, padding=False, max_length=MAX_SEQ_LEN, truncation=True
+        )
+
+    tokenized_dataset = dataset.map(
+        preprocess, remove_columns=["input", "output", "instruction", "data_source"]
+    )
+    stage_runner = StageRunner(
+        model_args=None,
+        data_args=DataTrainingArguments(
+            dataset=tokenized_dataset, shuffle_calibration_samples=False
+        ),
+        training_args=TrainingArguments(do_oneshot=True),
+    )
+    stage_runner.populate_datasets(tokenizer=None)
+    calib_dataset = stage_runner.get_dataset_split("calibration")
+    assert len(calib_dataset) == NUM_CALIB_SAMPS
+    data_cols = calib_dataset.column_names
+    assert len(data_cols) == 2
+    assert "input_ids" in data_cols and "attention_mask" in data_cols
+
+    # confirm turning shuffle off works
+    calib_dataloader = format_calibration_data(
+        tokenized_dataset=calib_dataset,
+        num_calibration_samples=NUM_CALIB_SAMPS,
+        do_shuffle=stage_runner._data_args.shuffle_calibration_samples,
+    )
+    assert len(calib_dataloader) == NUM_CALIB_SAMPS
+    dataloader_sample = next(iter(calib_dataloader))["input_ids"]
+    diff = dataloader_sample - torch.Tensor(calib_dataset[0]["input_ids"])
+    assert torch.sum(diff) == 0
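
Usage sketch (illustration only, not part of the patch): the snippet below shows how the pieces added above fit together, tokenizing a dataset up front, passing it through DataTrainingArguments with shuffle_calibration_samples=False, and building a deterministic calibration dataloader via format_calibration_data with do_shuffle=False. The tokenizer checkpoint name is a placeholder assumption; the dataset name, argument names, and call signatures come from the changes in this diff.

    from datasets import load_dataset
    from transformers import AutoTokenizer

    from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
    from sparseml.transformers.finetune.data.data_helpers import format_calibration_data

    # placeholder checkpoint; any tokenizer matching the target model works here
    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # small slice of the same dataset used in the new test
    raw = load_dataset("garage-bAInd/Open-Platypus")["train"].select(range(256))

    # tokenize ahead of time so the runner takes the "already tokenized" branch
    tokenized = raw.map(
        lambda sample: tokenizer(sample["instruction"], truncation=True, max_length=512),
        remove_columns=raw.column_names,
    )

    # shuffle_calibration_samples=False keeps calibration selection deterministic
    data_args = DataTrainingArguments(
        dataset=tokenized,
        num_calibration_samples=256,
        shuffle_calibration_samples=False,
    )

    # with do_shuffle=False the first N samples are selected and read sequentially
    calib_dataloader = format_calibration_data(
        tokenized_dataset=tokenized,
        num_calibration_samples=data_args.num_calibration_samples,
        do_shuffle=data_args.shuffle_calibration_samples,
    )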