diff --git a/README.md b/README.md
index 5678c80..8b03b66 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@ Lastly, if you want to study the effect of multitask prompted training (a.k.a. i
 - T-Zero ++: https://huggingface.co/bigscience/T0pp
 - T-Zero Single Prompt: https://huggingface.co/bigscience/T0_single_prompt
 - T-Zero Original Task Only: https://huggingface.co/bigscience/T0_original_task_only
-- T-Zero 3B: https://huggingface.co/bigscience/T0_3B
+- T-Zero 3B: https://huggingface.co/bigscience/T0_3B
 ## Citation
diff --git a/evaluation/run_eval.py b/evaluation/run_eval.py
index 660e594..c8a7676 100644
--- a/evaluation/run_eval.py
+++ b/evaluation/run_eval.py
@@ -51,7 +51,6 @@ def parse_args():
     parser.add_argument(
         "--dataset_name",
         type=str,
-        default=None,
         help="The name of the dataset to use (via the datasets library).",
         required=True,
     )
@@ -61,12 +60,17 @@
         default=None,
         help="The configuration name of the dataset to use (via the datasets library).",
     )
+    parser.add_argument(
+        "--template_config_name",
+        type=str,
+        default=None,
+        help="The name of the dataset_config_name of the template we want to use, example: use XNLI En prompts for XNLI Fr",
+    )
     parser.add_argument(
         "--template_name",
         type=str,
         default=None,
-        help="The template/prompt name",
-        required=True,
+        help="The template/prompt name. If None, we run all templates.",
     )
     parser.add_argument(
         "--max_length",
@@ -128,115 +132,40 @@
         action="store_true",
         help="Activate debug mode and run training only with a subset of data.",
     )
-    parser.add_argument(
-        "--parallelize",
-        action="store_true",
-        help=(
-            "If passed, will call `model.parallelize` which splits the model on all GPUs available when applicable (model parallelism). "
-            "Note that this feature is still experimental in HF Transformers."
-        ),
-    )
-    args = parser.parse_args()
-
-    return args
+    args = parser.parse_args()
-def main():
-    args = parse_args()
+    # TODO @thomasw21 hack!
+    if args.dataset_config_name == "None":
+        args.dataset_config_name = None
+    if args.template_config_name == "None":
+        args.template_config_name = None
-    # Initialize the accelerator. We will let the accelerator handle device placement for us.
-    accelerator = Accelerator()
-    # Make one log on every process with the configuration for debugging.
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO,
-    )
-    logger.info(accelerator.state)
-
-    # Setup logging, we only want one process per machine to log things on the screen.
-    # accelerator.is_local_main_process is only True for one process per machine.
-    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
-    if accelerator.is_local_main_process:
-        datasets.utils.logging.set_verbosity_warning()
-        transformers.utils.logging.set_verbosity_info()
-    else:
-        datasets.utils.logging.set_verbosity_error()
-        transformers.utils.logging.set_verbosity_error()
+    return args
+def run_template(template_name, prompts, model, tokenizer, raw_datasets, accelerator: Accelerator, args):
     # Handle the output directory creation
-    if accelerator.is_main_process:
-        os.makedirs(args.output_dir, exist_ok=True)
-    accelerator.wait_for_everyone()
-
-    # In distributed evaluation, the load_dataset function guarantee that only one local process can concurrently
-    # download the dataset.
-    if args.dataset_name is not None:
-        # Downloading and loading a dataset from the hub.
-        if args.dataset_name == "anli":
-            raw_datasets = load_dataset(args.dataset_name, split=args.dataset_config_name)
-        else:
-            raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name, split="validation")
-        #TODO(Victor): enable loading pre-processed dataset from https://huggingface.co/datasets/bigscience/P3
-
-    # Trim a number of evaluation examples
-    if args.debug:
-        raw_datasets = raw_datasets.select(range(min(len(raw_datasets),100)))
-
-    column_names = raw_datasets.column_names
-
-
-    # Load pretrained model and tokenizer
-    #
-    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
-    # download model & vocab.
-    if args.config_name:
-        config = AutoConfig.from_pretrained(args.config_name)
-    elif args.model_name_or_path:
-        config = AutoConfig.from_pretrained(args.model_name_or_path)
-    else:
-        raise ValueError(
-            "Either `args.config_name` or `args.model_name_or_path` should be provided."
-        )
-
-    if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
-    elif args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
-    else:
-        raise ValueError(
-            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
-            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+    result_dir = None
+    if args.output_dir is not None and accelerator.is_main_process:
+        paths = [
+            args.dataset_name,
+            args.dataset_config_name,
+            template_name,
+        ]
+        result_dir = os.path.join(
+            args.output_dir,
+            *[path.replace(" ", "_").replace("/", "_") for path in paths if path is not None]
         )
+        os.makedirs(result_dir, exist_ok=True)
-    if tokenizer.pad_token is None:
-        for token in [tokenizer.eos_token, tokenizer.bos_token, tokenizer.sep_token]:
-            if token is not None:
-                tokenizer.pad_token = token
-        if tokenizer.pad_token is None:
-            raise ValueError("Please define a pad token id.")
-
+    template = prompts[template_name]
-    model = ModelBase.from_config(
-        config=config,
-        model_name_or_path=args.model_name_or_path,
-        parallelize=args.parallelize
-    )
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     padding = "max_length" if args.pad_to_max_length else False
-
-    # Get the prompt to apply and the possible targets.
-    # TODO(Victor): If pulling from pre-processed data, remove this logic.
-    prompts = DatasetTemplates(
-        f"{args.dataset_name}"
-        if args.dataset_config_name is None
-        else f"{args.dataset_name}/{args.dataset_config_name}"
-    )
-    template = prompts[args.template_name]
-
+    column_names = raw_datasets.column_names
     def preprocess_function(examples):
         bs = len(examples[column_names[0]])
@@ -265,8 +194,9 @@ def preprocess_function(examples):
         tokenized_targets = [
             tokenizer(
                 ans_choi,
-                padding=True,
-                max_length=args.target_max_length,
+                # padding is on the right here.
+                padding=False,
+                max_length=args.max_length,
                 truncation=True,
             )
             for ans_choi in answer_choices_texts
@@ -319,17 +249,16 @@ def preprocess_function(examples):
     eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
-
-    # Use the device given by the `accelerator` object.
-    if not args.parallelize:
-        model.to(accelerator.device)
-
     # Prepare everything with our `accelerator`.
     eval_dataloader = accelerator.prepare(eval_dataloader)
-
     # Metrics
-    metric = load_metric("accuracy")
+    metric = load_metric(
+        "accuracy",
+        process_id=accelerator.process_index,
+        num_process=accelerator.num_processes,
+        experiment_id=f"{args.dataset_name}_{args.dataset_config_name}_{args.template_name}"
+    )
     # Eval!
     total_batch_size = args.per_device_eval_batch_size * accelerator.num_processes
@@ -359,14 +288,119 @@ def preprocess_function(examples):
     results = {
         "dataset_name": args.dataset_name,
         "dataset_config_name": args.dataset_config_name,
-        "template_name": args.template_name,
-        "evaluation": eval_metric
+        "template_name": template_name,
+        "evaluation": eval_metric,
+        "arguments": str(args)
     }
     if accelerator.is_main_process:
-        if args.output_dir is not None:
-            with open(os.path.join(args.output_dir, "results.json"), "w") as f:
-                json.dump(results, f, indent=4)
+        if result_dir is not None:
+            with open(os.path.join(result_dir, "results.json"), "w") as f:
+                json.dump(results, f, indent=2)
+
+def main():
+    args = parse_args()
+
+    # Initialize the accelerator. We will let the accelerator handle device placement for us.
+    accelerator = Accelerator()
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state)
+
+    # Setup logging, we only want one process per machine to log things on the screen.
+    # accelerator.is_local_main_process is only True for one process per machine.
+    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    accelerator.wait_for_everyone()
+
+    # In distributed evaluation, the load_dataset function guarantee that only one local process can concurrently
+    # download the dataset.
+    # Downloading and loading a dataset from the hub.
+    if args.dataset_name == "anli":
+        raw_datasets = load_dataset(args.dataset_name, split=args.dataset_config_name)
+    else:
+        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name, split="validation")
+    #TODO(Victor): enable loading pre-processed dataset from https://huggingface.co/datasets/bigscience/P3
+
+    # Trim a number of evaluation examples
+    if args.debug:
+        raw_datasets = raw_datasets.select(range(min(len(raw_datasets),100)))
+
+    # Load pretrained model and tokenizer
+    #
+    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+    if args.config_name:
+        config = AutoConfig.from_pretrained(args.config_name)
+    elif args.model_name_or_path:
+        config = AutoConfig.from_pretrained(args.model_name_or_path)
+    else:
+        raise ValueError(
+            "Either `args.config_name` or `args.model_name_or_path` should be provided."
+        )
+
+    if args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, padding_side="left")
+    elif args.model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, padding_side="left")
+    else:
+        raise ValueError(
+            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if tokenizer.pad_token is None: + for token in [tokenizer.eos_token, tokenizer.bos_token, tokenizer.sep_token]: + if token is not None: + tokenizer.pad_token = token + if tokenizer.pad_token is None: + raise ValueError("Please define a pad token id.") + + + model = ModelBase.from_config( + config=config, + model_name_or_path=args.model_name_or_path + ) + model = accelerator.prepare_model(model) + # Get the prompt to apply and the possible targets. + # TODO(Victor): If pulling from pre-processed data, remove this logic. + + if (args.dataset_config_name is None and args.template_config_name is None) or args.dataset_name == "anli": + prompt_dataset_name = f"{args.dataset_name}" + elif args.template_config_name is not None: + prompt_dataset_name = f"{args.dataset_name}/{args.template_config_name}" + else: + prompt_dataset_name = f"{args.dataset_name}/{args.dataset_config_name}" + + prompts = DatasetTemplates( + prompt_dataset_name + ) + + if args.template_name is not None: + template_names = [args.template_name] + else: + template_names = prompts.all_template_names + + for template_name in template_names: + run_template( + template_name=template_name, + prompts=prompts, + model=model, + tokenizer=tokenizer, + raw_datasets=raw_datasets, + accelerator=accelerator, + args=args + ) if __name__ == "__main__": main() diff --git a/t0/model.py b/t0/model.py index 664a0a8..ada0d4f 100644 --- a/t0/model.py +++ b/t0/model.py @@ -27,7 +27,7 @@ def from_config(config, **kwargs) -> "ModelBase": raise NotImplementedError class EncoderDecoderModel(ModelBase): - def __init__(self, config, model_name_or_path: Optional[str], parallelize: bool, **kwargs): + def __init__(self, config, model_name_or_path: Optional[str], **kwargs): """ Args: @@ -46,11 +46,9 @@ def __init__(self, config, model_name_or_path: Optional[str], parallelize: bool, ) else: logger.info("Training new model from scratch") - self._model = AutoModelForSeq2SeqLM.from_config(config) - - if parallelize: - assert torch.cuda.is_available(), "You need at least 1 GPU to call `parallelize` (even though if there is only 1 GPU, there won't be any model parallelism)." - self._model.parallelize() + self._model = AutoModelForSeq2SeqLM.from_config( + config, + ) def forward(self, batch) -> torch.Tensor: @@ -78,19 +76,23 @@ def __init__(self, config, model_name_or_path: Optional[str], **kwargs): ) else: logger.info("Training new model from scratch") - self._model = AutoModelForCausalLM.from_config(config) + self._model = AutoModelForCausalLM.from_config( + config, + ) def forward(self, batch): + device = batch["input_ids"].device _, prefix_length = batch["input_ids"].shape + model_inputs = { "input_ids": torch.cat([batch["input_ids"], batch["labels"]], dim=-1), "attention_mask": torch.cat([batch["attention_mask"], batch["labels_attention_mask"]], dim=-1), } # Set position ids correctly to take care of padding tokens between inputs_ids and labels - # Empty attention_mask is a forbidden value, ie full of zeros. In fact the first element should be 1 as the input - # cannot be empty - assert torch.all(model_inputs["attention_mask"][:,0] == 1), "First element in the attention mask should be 1." 
-        position_ids = torch.cumsum(model_inputs["attention_mask"].to(torch.long), dim=-1) - 1
+        position_ids = torch.maximum(
+            torch.cumsum(model_inputs["attention_mask"].to(torch.long), dim=-1) - 1,
+            torch.zeros(1, dtype=torch.long, device=device)[None, None]
+        )
         model_inputs["position_ids"] = position_ids
         logits = self._model(**model_inputs).logits[:, prefix_length-1:-1]
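
Note on the `position_ids` change in `CausalModel.forward`: with the tokenizer now padding on the left (`padding_side="left"` above), `cumsum(attention_mask) - 1` yields -1 at padding positions, so the values are clamped at zero instead of asserting that the first token is non-padding. A minimal, self-contained sketch of that behaviour (the mask below is made up for illustration and is not the repo's batch format):

```python
import torch

# Hypothetical left-padded attention mask: 0 = padding, 1 = real token.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],   # two padding tokens on the left
    [1, 1, 1, 1, 1],   # no padding
])

# cumsum - 1 is -1 on the left padding; clamping at zero mirrors the torch.maximum call in the diff
# (a 1-element zeros tensor broadcasts the same way for this 2D case).
position_ids = torch.maximum(
    torch.cumsum(attention_mask.to(torch.long), dim=-1) - 1,
    torch.zeros(1, dtype=torch.long),
)
print(position_ids)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```

Padding positions end up sharing position id 0 with the first real token, which is harmless since those positions are masked out by the attention mask anyway.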