diff --git a/examples/model_compression/ofa/export_model.py b/examples/model_compression/ofa/export_model.py index f6a5831b17d12..ec60a22bcc372 100644 --- a/examples/model_compression/ofa/export_model.py +++ b/examples/model_compression/ofa/export_model.py @@ -15,6 +15,7 @@ import argparse import logging import os +import math import random import time import json @@ -85,6 +86,11 @@ def parse_args(): type=float, default=1.0, help="width mult you want to export") + parser.add_argument( + '--depth_mult', + type=float, + default=1.0, + help="depth mult you want to export") args = parser.parse_args() return args @@ -106,6 +112,18 @@ def do_train(args): model_class, tokenizer_class = MODEL_CLASSES[args.model_type] config_path = os.path.join(args.model_name_or_path, 'model_config.json') cfg_dict = dict(json.loads(open(config_path).read())) + + if args.depth_mult < 1.0: + depth = round(cfg_dict["init_args"][0]['num_hidden_layers'] * args.depth_mult) + cfg_dict["init_args"][0]['num_hidden_layers'] = depth + kept_layers_index = {} + for idx, i in enumerate(range(1, depth+1)): + kept_layers_index[idx] = math.floor(i / args.depth_mult) - 1 + + os.rename(config_path, config_path+'_bak') + with open(config_path, "w", encoding="utf-8") as f: + f.write(json.dumps(cfg_dict, ensure_ascii=False)) + num_labels = cfg_dict['num_classes'] model = model_class.from_pretrained( @@ -114,6 +132,8 @@ def do_train(args): origin_model = model_class.from_pretrained( args.model_name_or_path, num_classes=num_labels) + os.rename(config_path+'_bak', config_path) + sp_config = supernet(expand_ratio=[1.0, args.width_mult]) model = Convert(sp_config).convert(model) @@ -121,7 +141,15 @@ def do_train(args): sd = paddle.load( os.path.join(args.model_name_or_path, 'model_state.pdparams')) - ofa_model.model.set_state_dict(sd) + + for name, params in ofa_model.model.named_parameters(): + if 'encoder' not in name: + params.set_value(sd[name]) + else: + idx = int(name.strip().split('.')[3]) + mapping_name = name.replace('.'+str(idx)+'.', '.'+str(kept_layers_index[idx])+'.') + params.set_value(sd[mapping_name]) + best_config = utils.dynabert_config(ofa_model, args.width_mult) ofa_model.export( best_config,