From ea41439873d600a1adef5ab819ab556a2c9bf511 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Mon, 7 Oct 2024 17:50:50 -0700
Subject: [PATCH] [Not for land] Util for saving quantized model

---
 torchchat.py            |  5 ++++
 torchchat/save_quant.py | 58 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 63 insertions(+)
 create mode 100644 torchchat/save_quant.py

diff --git a/torchchat.py b/torchchat.py
index 35cdcabae..1b719f002 100644
--- a/torchchat.py
+++ b/torchchat.py
@@ -49,6 +49,7 @@
     "where": "Return directory containing downloaded model artifacts",
     "server": "[WIP] Starts a locally hosted REST server for model interaction",
     "eval": "Evaluate a model via lm-eval",
+    "save_quant": "Quantize a model and save it to disk",
 }
 for verb, description in VERB_HELP.items():
     subparser = subparsers.add_parser(verb, help=description)
@@ -115,5 +116,9 @@
         from torchchat.cli.download import remove_main

         remove_main(args)
+    elif args.command == "save_quant":
+        from torchchat.save_quant import main as save_quant_main
+
+        save_quant_main(args)
     else:
         parser.print_help()
diff --git a/torchchat/save_quant.py b/torchchat/save_quant.py
new file mode 100644
index 000000000..0770fa1fa
--- /dev/null
+++ b/torchchat/save_quant.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from torchchat.cli.builder import (
+    _initialize_model,
+    BuilderArgs,
+)
+
+from torchchat.utils.build_utils import set_precision
+
+from torchao.quantization import quantize_, int8_weight_only
+
+"""
+Exporting Flow
+"""
+
+
+def main(args):
+    builder_args = BuilderArgs.from_args(args)
+    print(f"{builder_args=}")
+
+    quant_format = "int8_wo"
+    # Quant option from cli, can be None
+    model = _initialize_model(builder_args, args.quantize)
+    if not args.quantize:
+        # Not using quantization option from cli;
+        # Use quantize_() to quantize the model instead.
+        print("Quantizing model using torchao quantize_")
+        quantize_(model, int8_weight_only())
+    else:
+        print(f"{args.quantize=}")
+
+    print(f"Model: {model}")
+
+    # Save model
+    model_dir = os.path.dirname(builder_args.checkpoint_path)
+    model_dir = Path(model_dir + "-" + quant_format)
+    try:
+        os.mkdir(model_dir)
+    except FileExistsError:
+        pass
+    dest = model_dir / "model.pth"
+    state_dict = model.state_dict()
+    print(f"{state_dict.keys()=}")
+
+    print(f"Saving checkpoint to {dest}. This may take a while.")
+    torch.save(state_dict, dest)
+    print("Done.")
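
The core quantization step in this patch is torchao's in-place quantize_ API. A minimal standalone sketch of that flow, using a toy module in place of a torchchat model (the module and its sizes are illustrative, not from this patch):

    import torch
    import torch.nn as nn
    from torchao.quantization import quantize_, int8_weight_only

    # Toy module standing in for the torchchat model; sizes are arbitrary.
    toy = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

    # quantize_ mutates the module in place: by default it targets nn.Linear
    # layers, and int8_weight_only() replaces each weight with an int8
    # quantized tensor subclass (activations keep their original dtype).
    quantize_(toy, int8_weight_only())

    out = toy(torch.randn(1, 64))  # inference dequantizes weights on the fly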
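
Because torch.save() here serializes a state_dict whose weights are torchao tensor subclasses, loading model.pth back presumably requires re-creating the same quantized layout first. A hedged sketch of that load path, where build_model() is a hypothetical constructor for the same architecture:

    import torch
    from torchao.quantization import quantize_, int8_weight_only

    model = build_model()                 # hypothetical: same architecture as saved
    quantize_(model, int8_weight_only())  # recreate the quantized module layout

    # weights_only=False since the checkpoint contains tensor subclasses;
    # assign=True (PyTorch >= 2.1) adopts the loaded tensors directly.
    state_dict = torch.load("model.pth", mmap=True, weights_only=False)
    model.load_state_dict(state_dict, assign=True)
    model.eval()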