[Not for land] Util for saving quantized model
kwen2501 committed Oct 8, 2024
1 parent 9fb7999 commit ea41439
Showing 2 changed files with 63 additions and 0 deletions.
5 changes: 5 additions & 0 deletions torchchat.py
@@ -49,6 +49,7 @@
         "where": "Return directory containing downloaded model artifacts",
         "server": "[WIP] Starts a locally hosted REST server for model interaction",
         "eval": "Evaluate a model via lm-eval",
+        "save_quant": "Quantize a model and save it to disk",
     }
     for verb, description in VERB_HELP.items():
         subparser = subparsers.add_parser(verb, help=description)
@@ -115,5 +116,9 @@
         from torchchat.cli.download import remove_main
 
         remove_main(args)
+    elif args.command == "save_quant":
+        from torchchat.save_quant import main as save_quant_main
+
+        save_quant_main(args)
     else:
         parser.print_help()
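
With the verb registered and dispatched, save_quant can be driven like any other torchchat subcommand. A hypothetical invocation (the model alias is illustrative, not part of this commit); when no --quantize option is given, save_quant.py falls back to torchao's int8 weight-only path:

    python3 torchchat.py save_quant llama3.1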
58 changes: 58 additions & 0 deletions torchchat/save_quant.py
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from pathlib import Path

import torch

from torchchat.cli.builder import (
    _initialize_model,
    BuilderArgs,
)

from torchao.quantization import quantize_, int8_weight_only

"""
Exporting Flow
"""


def main(args):
    builder_args = BuilderArgs.from_args(args)
    print(f"{builder_args=}")

    quant_format = "int8_wo"
    # Quantization options from the CLI; may be None.
    model = _initialize_model(builder_args, args.quantize)
    if not args.quantize:
        # No quantization options were given on the CLI;
        # quantize the model with torchao's quantize_ instead.
        print("Quantizing model using torchao quantize_")
        quantize_(model, int8_weight_only())
    else:
        print(f"{args.quantize=}")

    print(f"Model: {model}")

    # Save the checkpoint in a sibling directory suffixed with the
    # quantization format, e.g. <model_dir>-int8_wo/model.pth.
    model_dir = os.path.dirname(builder_args.checkpoint_path)
    model_dir = Path(model_dir + "-" + quant_format)
    model_dir.mkdir(exist_ok=True)
    dest = model_dir / "model.pth"
    state_dict = model.state_dict()
    print(f"{state_dict.keys()=}")

    print(f"Saving checkpoint to {dest}. This may take a while.")
    torch.save(state_dict, dest)
    print("Done.")
