-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
剪枝分割网络报错-bisenetv2 #7
Comments
有解决的消息了吗?是跟分组卷积有关吗? |
呃这两天有别的事情,还没看多久,晚上我再自己测测 |
嗯 好的 辛苦大佬,有消息及时回复下,自己手动剪,感觉比较步骤比较繁琐 |
能贴你的 prune_by_class_bisenetv2.py 吗?我创建 |
import sys
sys.path.append("..")
import torch
import torchpruner
import torchvision
import numpy as np
from bisenetv2 import BiSeNetV2
#以下代码示例了对每一个BN层去除其weight系数绝对值前20%小的层
#加载模型
model = torchvision.models.vgg11_bn()
print('model-origin:', model)
#jzy
#model = BiSeNetV2(n_classes=9, aux_mode='pred')
model = BiSeNetV2(n_classes=9)
model.load_state_dict(torch.load('/home/zxz/torch-model-compression/examples/torchpruner/model_final.pth', map_location='cuda'), strict=False)
print('model-origin:', model)
# 创建ONNXGraph对象,绑定需要被剪枝的模型
graph = torchpruner.ONNXGraph(model)
##build ONNX静态图结构,需要指定输入的张量
graph.build_graph(inputs=(torch.zeros(8, 3, 320, 640),))
# 遍历所有的Module
for key in graph.modules:
module = graph.modules[key]
# 如果该module对应了BN层
if isinstance(module.nn_object, torch.nn.BatchNorm2d):
# 获取该对象
nn_object = module.nn_object
# 排序,取前20%小的权重值对应的index
weight = nn_object.weight.detach().cpu().numpy()
index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
print('index:', index)
result = module.cut_analysis("weight", index=index, dim=0)
model, context = torchpruner.set_cut(model, result)
# 新的model即为剪枝后的模型
print('model-pruned:', model) 这个代码。 |
comment框是markdown格式的,我的注释在这很奇怪 |
markdown的多行代码块语法是 三个
|
嗯嗯 有空看看mk语法,大佬看问题 嘿嘿 |
你用的 bisenetv2.py 是哪个文件?BiSeNet-master/lib/models/bisenetv2.py 还是 BiSeNet-master/old/bisenetv2/bisenetv2.py ? 另外我用官方仓库的模型(old/README.md的百度网盘文件 model_final.pth)好像加载不了,大概是参数不一致……厚颜来要你的原始模型了,能发的话,网盘或者发到 [email protected] 都行;不方便的话,我就再看看。我现在因为懒得自己在coco上训 练,缺少实际可用的参数张量,卡在前边某一个分组卷积的 |
用的是BiSeNet-master/lib/models/bisenetv2.py这个,模型的话在公司,可以给你发一个训练几个epoch的版本,因为是其他人训练的,也不方便给最新的。得明天给你了。 |
嗯谢谢! |
Traceback (most recent call last): 好几个都是 加载模型 时出现上述错误。我是刚入门的小白,请问这个怎么处理,程序跑不通,不知道怎么改这个。 |
@jzy-hxf 抱歉之前没注意消息。具体到你这个问题,是torch版本比较新(1.11还是多少以上)造成的。你把这个项目里出现的 |
@GeneralJing 抱歉我好久没注意这个。我确认了几遍代码,应该是 for key in list(graph.modules):
# ...
model, context = torchpruner.set_cut(model, result)
graph = torchpruner.ONNXGraph(model) # 本行可以省略
graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),)) 我一会去改一下示例代码。 |
“下个周末看一下”,重新定义了下个周末哈哈哈哈,好的,多谢。 |
参考你的例子,对BiSeNetv2模型剪枝,报如下错误,这个怎么解决,能看一下吗 index: [34 53 31 46 12 49 36 60 18 68 39 57 67 13 14]
key: self.segment.S3.0.dwconv1.1
Traceback (most recent call last):
File "model_pruning.py", line 33, in <module>
result = module.cut_analysis("weight", index=index, dim=0)
File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 333, in cut_analysis
File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 246, in cut_analysis
File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 269, in cut_analysis_with_mask
File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/operator/onnx_operator.py", line 489, in analysis
File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/mask_utils.py", line 276, in indexs
RuntimeError: All the data is masked |
我的代码如下 import imp
import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, '..'))
import torch
import torchpruner
import numpy as np
from models.bisenetv2 import BiSeNetV2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载模型
model = BiSeNetV2(n_classes=4)
checkpoint = torch.load('../results/05-05_01-59/checkpoint_best.pkl', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
graph = torchpruner.ONNXGraph(model)
graph.build_graph(inputs=(torch.zeros(1, 3, 640, 512),))
for key in list(graph.modules):
module = graph.modules[key]
if isinstance(module.nn_object, torch.nn.BatchNorm2d):
nn_object = module.nn_object
weight = nn_object.weight.detach().cpu().numpy()
index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
print('index:', index)
print('key:', key)
result = module.cut_analysis("weight", index=index, dim=0)
model, context = torchpruner.set_cut(model, result)
graph = torchpruner.ONNXGraph(model)
graph.build_graph(inputs=(torch.zeros(1, 3, 640, 512),))
print('model-pruned:', model)
torch.save(model, '../results/05-05_01-59/checkpoint_best_pruning.pkl') |
The text was updated successfully, but these errors were encountered: