
Pruning a segmentation network fails - bisenetv2 #7

Open
GeneralJing opened this issue Oct 9, 2021 · 17 comments
@GeneralJing

GeneralJing commented Oct 9, 2021

File "examples/torchpruner/prune_by_class_bisenetv2.py", line 39, in <module>
    model, context = torchpruner.set_cut(model, result)
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/model_pruner.py", line 71, in set_cut
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/module_pruner/pruners.py", line 188, in set_cut
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/module_pruner/pruners.py", line 60, in set_cut
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/module_pruner/prune_function.py", line 42, in set_cut_tensor
IndexError: index 78 is out of bounds for dimension 0 with size 78
@gdh1995 gdh1995 self-assigned this Oct 9, 2021
@GeneralJing
Author

Any news on a fix? Is it related to grouped convolutions?

@gdh1995
Collaborator

gdh1995 commented Oct 12, 2021

Uh, I've had other things to deal with these past couple of days and haven't had much time to look at it yet. I'll test it myself tonight.

@GeneralJing
Author

OK, thanks for the effort. Please reply as soon as there's any news. Pruning by hand feels like a rather tedious process.

@gdh1995
Collaborator

gdh1995 commented Oct 12, 2021

Could you paste your prune_by_class_bisenetv2.py? I created a BiSeNetV2(19) from lib/models/bisenetv2.py, but graph.build_graph didn't run through. The BiSeNet code was just taken from https://github.com/CoinCheung/BiSeNet.

@GeneralJing
Author

GeneralJing commented Oct 12, 2021

import sys

sys.path.append("..")
import torch
import torchpruner
import torchvision
import numpy as np
from bisenetv2 import BiSeNetV2

# The following code demonstrates, for every BN layer, removing the channels whose weight coefficients are among the smallest 20% in absolute value

# load the model
model = torchvision.models.vgg11_bn()
print('model-origin:', model)
#jzy
#model = BiSeNetV2(n_classes=9, aux_mode='pred')
model = BiSeNetV2(n_classes=9)
model.load_state_dict(torch.load('/home/zxz/torch-model-compression/examples/torchpruner/model_final.pth', map_location='cuda'), strict=False)
print('model-origin:', model)

# create the ONNXGraph object and bind the model to be pruned
graph = torchpruner.ONNXGraph(model)
## build the static ONNX graph structure; the input tensors must be specified
graph.build_graph(inputs=(torch.zeros(8, 3, 320, 640),))

# iterate over all modules
for key in graph.modules:
    module = graph.modules[key]
    # if this module corresponds to a BN layer
    if isinstance(module.nn_object, torch.nn.BatchNorm2d):
        # get the underlying object
        nn_object = module.nn_object
        # sort and take the indices of the smallest 20% of weight values
        weight = nn_object.weight.detach().cpu().numpy()
        index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
        print('index:', index)
        result = module.cut_analysis("weight", index=index, dim=0)
        model, context = torchpruner.set_cut(model, result)

# the new model is the pruned model
print('model-pruned:', model)

This is the code.

@GeneralJing
Author

The comment box uses Markdown, so my comments look odd here.

@gdh1995
Collaborator

gdh1995 commented Oct 12, 2021

The Markdown syntax for a multi-line code block is three backticks in a row marking the beginning and the end, for example:

# // beginning
# ``` [ + space + language name]
# the actual content
# // end
# ```
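
For instance, a short Python snippet would be posted like this (the snippet itself is just a placeholder to show the fencing):

```python
import torch
print(torch.__version__)  # any code can go between the two fences
```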

@GeneralJing
Author

GeneralJing commented Oct 12, 2021

OK, I'll look at the Markdown syntax when I get a chance. Please take a look at the issue, heh.

@gdh1995
Collaborator

gdh1995 commented Oct 12, 2021

Which bisenetv2.py file are you using? BiSeNet-master/lib/models/bisenetv2.py or BiSeNet-master/old/bisenetv2/bisenetv2.py?

Also, I can't seem to load the official repo's model (the model_final.pth from the Baidu Netdisk link in old/README.md); the parameters probably don't match... So I'm shamelessly asking for your original model. If you can share it, a netdisk link or sending it to [email protected] both work; if that's inconvenient, I'll keep looking on my own. Since I don't feel like training on COCO myself, I lack actually usable parameter tensors, and I'm stuck at the cut_analysis step of one of the earlier grouped convolutions. Orz

@GeneralJing
Author

GeneralJing commented Oct 12, 2021

I'm using BiSeNet-master/lib/models/bisenetv2.py. As for the model, it's at the company; I can send you a version trained for a few epochs, since it was trained by someone else and it's not convenient to share the latest one. I'll have to send it to you tomorrow.

@gdh1995
Collaborator

gdh1995 commented Oct 12, 2021

OK, thank you!

@jzy-hxf

jzy-hxf commented Nov 18, 2021

Traceback (most recent call last):
File "G:/py/torch-model-compression-main/examples/torchpruner/prune_and_recovery.py", line 16, in
graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))
File "G:\py\torch-model-compression-main\torchpruner\graph.py", line 458, in build_graph
graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(
TypeError: _model_to_graph() got an unexpected keyword argument '_retain_param_name'

Several of the examples hit the above error at the "load the model" step. I'm a complete beginner; how should I handle this? The program won't run and I don't know how to change it. For example with:
model = torchvision.models.resnet50()

@gdh1995
Collaborator

gdh1995 commented Dec 7, 2021

@jzy-hxf Sorry, I didn't notice the message earlier. As for your problem specifically, it's caused by a newer torch version (1.11 or so and above). Just remove every occurrence of _retain_param_name in this project; it doesn't affect the results.
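
For reference, a minimal compatibility sketch (the wrapper name model_to_graph_compat and the rough version cutoff are assumptions, not part of torchpruner; the project itself simply deletes the argument):

import torch

def model_to_graph_compat(model, args, **kwargs):
    # Older torch accepts _retain_param_name; newer torch (roughly 1.10+)
    # rejects the unknown keyword with a TypeError, in which case we drop it.
    try:
        return torch.onnx.utils._model_to_graph(
            model, args, _retain_param_name=True, **kwargs)
    except TypeError:
        return torch.onnx.utils._model_to_graph(model, args, **kwargs)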

gdh1995 added a commit that referenced this issue Dec 7, 2021
@gdh1995
Collaborator

gdh1995 commented Dec 7, 2021

@GeneralJing Sorry, I haven't looked at this for quite a while. I went over the code several times; the problem should be in how examples/torchpruner/prune_by_class.py is written: every time torchpruner.set_cut runs, part of the information in graph becomes stale, so graph has to be recreated. graph.modules is stable, so the keys can be computed in advance:

for key in list(graph.modules):
    # ...
    model, context = torchpruner.set_cut(model, result)
    graph = torchpruner.ONNXGraph(model)  # this line can be omitted
    graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))

I'll go update the example code in a bit.
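
For completeness, a self-contained sketch of that corrected loop. It mirrors the vgg11_bn example (model choice, 20% ratio, and 224x224 input are taken from it); treat it as an illustration of the pattern rather than the final prune_by_class.py:

import numpy as np
import torch
import torchpruner
import torchvision

model = torchvision.models.vgg11_bn()
graph = torchpruner.ONNXGraph(model)
graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))

# the keys are snapshotted once up front; the graph is rebuilt after every cut
for key in list(graph.modules):
    if key not in graph.modules:
        continue
    module = graph.modules[key]
    if not isinstance(module.nn_object, torch.nn.BatchNorm2d):
        continue
    weight = module.nn_object.weight.detach().cpu().numpy()
    index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
    result = module.cut_analysis("weight", index=index, dim=0)
    model, context = torchpruner.set_cut(model, result)
    # recreating the ONNXGraph is optional per the note above, shown here for clarity
    graph = torchpruner.ONNXGraph(model)
    graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))

print('model-pruned:', model)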

@GeneralJing
Author

"I'll take a look next weekend" really redefined what "next weekend" means, hahaha. OK, thanks a lot.

gdh1995 added a commit that referenced this issue Dec 7, 2021
@Durobert

Following your example, I pruned a BiSeNetV2 model and got the error below. How can this be solved? Could you take a look?

index: [34 53 31 46 12 49 36 60 18 68 39 57 67 13 14]
key: self.segment.S3.0.dwconv1.1
Traceback (most recent call last):
  File "model_pruning.py", line 33, in <module>
    result = module.cut_analysis("weight", index=index, dim=0)
  File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 333, in cut_analysis
  File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 246, in cut_analysis
  File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 269, in cut_analysis_with_mask
  File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/operator/onnx_operator.py", line 489, in analysis
  File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/mask_utils.py", line 276, in indexs
RuntimeError: All the data is masked

@Durobert

My code is as follows:

import imp
import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, '..'))
import torch
import torchpruner
import numpy as np
from models.bisenetv2 import BiSeNetV2


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load the model
model = BiSeNetV2(n_classes=4)
checkpoint = torch.load('../results/05-05_01-59/checkpoint_best.pkl', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

graph = torchpruner.ONNXGraph(model)
graph.build_graph(inputs=(torch.zeros(1, 3, 640, 512),))

for key in list(graph.modules):
    module = graph.modules[key]
    if isinstance(module.nn_object, torch.nn.BatchNorm2d):
        nn_object = module.nn_object
        weight = nn_object.weight.detach().cpu().numpy()
        index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
        print('index:', index)
        print('key:', key)
        result = module.cut_analysis("weight", index=index, dim=0)
        model, context = torchpruner.set_cut(model, result)
        graph = torchpruner.ONNXGraph(model)
        graph.build_graph(inputs=(torch.zeros(1, 3, 640, 512),))

print('model-pruned:', model)
torch.save(model, '../results/05-05_01-59/checkpoint_best_pruning.pkl')
