Question about split_gain #12

Open
Aliang-CN opened this issue Nov 21, 2019 · 4 comments

Comments

@Aliang-CN

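Hello,
While reading your code I noticed a possible problem. The line self.gain = getItemByTree(self, 'split_gain') should retrieve the information gain of each node split, but getFeature inside getItemByTree does not seem to have any corresponding handling for it.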
def getItemByTree(tree, item='split_feature'):
    root = tree.raw['tree_structure']
    split_nodes = tree.split_nodes
    res = np.zeros(split_nodes+tree.raw['num_leaves'], dtype=np.int32)
    if 'value' in item or 'threshold' in item or 'split_gain' in item:
        res = res.astype(np.float64)
    def getFeature(root, res):
        if 'child' in item:
            if 'split_index' in root:
                node = root[item]
                if 'split_index' in node:
                    res[root['split_index']] = node['split_index']
                else:
                    res[root['split_index']] = node['leaf_index'] + split_nodes # need to check
            else:
                res[root['leaf_index'] + split_nodes] = -1
        elif 'value' in item:
            if 'split_index' in root:
                res[root['split_index']] = root['internal_'+item]
            else:
                res[root['leaf_index'] + split_nodes] = root['leaf_'+item]
        else:
            if 'split_index' in root:
                res[root['split_index']] = root[item]
            else:
                res[root['leaf_index'] + split_nodes] = -2
        if 'left_child' in root:
            getFeature(root['left_child'], res)
        if 'right_child' in root:
            getFeature(root['right_child'], res)
    getFeature(root, res)
    return res

@Aliang-CN
Author

Hello, could you update the repository with the latest model?

@Aliang-CN
Author

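Hello!
I have another question. In the SubGBDTLeaf_cls function, what information does all_hav record in the code below? Also, it seems to me that treeI[tree].gain[kdx] actually records split_feature. Is my understanding correct?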
all_hav = {} # set([i for i in range(MAX)])
for jdx, tree in enumerate(tree_indices):
    for kdx, f in enumerate(treeI[tree].feature):
        if f == -2:
            continue
        if f not in all_hav:
            all_hav[f] = 0
        all_hav[f] += treeI[tree].gain[kdx]
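
One way to read this loop, as a minimal sketch: assuming `feature` holds each node's `split_feature` (with -2 in the leaf slots, as `getItemByTree` writes them) and `gain` holds the matching `split_gain`, `all_hav` ends up mapping each feature index to the summed gain of all splits that use it. The `treeI` contents below are made up for illustration and are not taken from the repository.

```python
import numpy as np
from types import SimpleNamespace

# Hypothetical per-tree arrays; -2 marks leaf slots in both arrays.
treeI = [
    SimpleNamespace(feature=np.array([3, 0, -2, -2, -2]),
                    gain=np.array([12.5, 4.0, -2.0, -2.0, -2.0])),
    SimpleNamespace(feature=np.array([3, -2, -2]),
                    gain=np.array([1.5, -2.0, -2.0])),
]
tree_indices = [0, 1]

all_hav = {}
for tree in tree_indices:
    for kdx, f in enumerate(treeI[tree].feature):
        if f == -2:          # leaf slot, no split feature to credit
            continue
        f = int(f)           # plain int keys for readable output
        all_hav[f] = all_hav.get(f, 0.0) + float(treeI[tree].gain[kdx])

print(all_hav)  # feature index -> total split gain across the selected trees
```

Under these assumptions the gain array is indexed by node, and the feature array only decides which dictionary key receives that node's gain.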

@Aliang-CN
Author

Hello!
One more question: the line vectors[idx] = set(features[np.where(features>0)]) keeps only features with split_feature > 0, so the feature with split_feature = 0 gets left out.
def EqualGroup(self, n_clusters, args):
    vectors = {}
    # n_feature = 256
    for idx,features in enumerate(self.featurelist):
        vectors[idx] = set(features[np.where(features>0)])
    keys = random.sample(vectors.keys(), len(vectors))
    clusterIdx = np.zeros(len(vectors))
    groups = [[] for i in range(n_clusters)]
    trees_per_cluster = len(vectors)//n_clusters
    mod_per_cluster = len(vectors) % n_clusters
    begin = 0
    for idx in range(n_clusters):
        for jdx in range(trees_per_cluster):
            clusterIdx[keys[begin]] = idx
            begin += 1
        if idx < mod_per_cluster:
            clusterIdx[keys[begin]] = idx
            begin += 1
    print([np.where(clusterIdx==i)[0].shape for i in range(n_clusters)])
    return clusterIdx
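
A small check of the filter, using a made-up `features` array and assuming leaf slots carry -2 as in `getItemByTree`: `features > 0` drops the leaf markers but also drops feature index 0, whereas `features >= 0` would keep it. Whether excluding feature 0 is intentional is not stated in this thread.

```python
import numpy as np

# Hypothetical split_feature array for one tree; -2 marks leaf slots.
features = np.array([3, 0, -2, -2, -2])

# The filter as written in EqualGroup: feature index 0 is excluded.
strict = set(features[np.where(features > 0)].tolist())

# Alternative filter (an assumption, not the repository's code):
# keeps feature 0 and still excludes the negative leaf markers.
inclusive = set(features[np.where(features >= 0)].tolist())

print(strict, inclusive)
```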

@motefly
Owner

motefly commented Nov 21, 2019

Hello, I misjudged this earlier. The gain should in fact be retrievable here: https://github.com/motefly/DeepGBM/blob/master/tree_model_interpreter.py#L36
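
This can be checked with a trimmed-down sketch of the traversal. The key 'split_gain' contains neither 'child' nor 'value', so it falls through to the final else branch of getFeature, which copies root[item] for every split node. The tree dictionary and the helper name `collect` below are hypothetical; the dictionary only imitates the shape of the tree_structure that getItemByTree reads.

```python
import numpy as np

# Hypothetical tree with two split nodes and three leaves.
tree_structure = {
    "split_index": 0, "split_feature": 3, "split_gain": 12.5,
    "left_child": {"leaf_index": 0, "leaf_value": 0.1},
    "right_child": {
        "split_index": 1, "split_feature": 0, "split_gain": 4.0,
        "left_child": {"leaf_index": 1, "leaf_value": -0.2},
        "right_child": {"leaf_index": 2, "leaf_value": 0.3},
    },
}
split_nodes, num_leaves = 2, 3

def collect(item):
    # Mirrors only the final `else` branch of getFeature.
    res = np.zeros(split_nodes + num_leaves, dtype=np.float64)

    def visit(node):
        if "split_index" in node:
            res[node["split_index"]] = node[item]       # split node: copy the value
        else:
            res[node["leaf_index"] + split_nodes] = -2  # leaf slot marker
        for side in ("left_child", "right_child"):
            if side in node:
                visit(node[side])

    visit(tree_structure)
    return res

gain = collect("split_gain")
feature = collect("split_feature")
print(gain, feature)
```

The same branch serves 'split_feature' and 'threshold', which is why no dedicated 'split_gain' case appears in the function.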
