Skip to content

Commit

Permalink
Add separate methods to return nodes/leaves/depth
Browse files Browse the repository at this point in the history
  • Loading branch information
rmontanana committed Nov 27, 2023
1 parent 4de7497 commit 036d1ba
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
37 changes: 37 additions & 0 deletions stree/Strees.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,43 @@ def predict(self, X: np.array) -> np.array:
X = self.check_predict(X)
return self.classes_[np.argmax(self.__predict_class(X), axis=1)]

def get_nodes(self) -> int:
"""Return the number of nodes in the tree
Returns
-------
int
number of nodes
"""
nodes = 0
for _ in self:
nodes += 1
return nodes

def get_leaves(self) -> int:
"""Return the number of leaves in the tree
Returns
-------
int
number of leaves
"""
leaves = 0
for node in self:
if node.is_leaf():
leaves += 1
return leaves

def get_depth(self) -> int:
"""Return the depth of the tree
Returns
-------
int
depth of the tree
"""
return self.depth_

def nodes_leaves(self) -> tuple:
"""Compute the number of nodes and leaves in the built tree
Expand Down
2 changes: 1 addition & 1 deletion stree/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.3.1"
__version__ = "1.3.2"
9 changes: 9 additions & 0 deletions stree/tests/Stree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def test_check_max_depth(self):
)
tcl.fit(*load_dataset(self._random_state))
self.assertEqual(depth, tcl.depth_)
self.assertEqual(depth, tcl.get_depth())

def test_unfitted_tree_is_iterable(self):
tcl = Stree()
Expand Down Expand Up @@ -640,10 +641,12 @@ def test_depth(self):
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
self.assertEqual(6, clf.depth_)
self.assertEqual(6, clf.get_depth())
X, y = load_wine(return_X_y=True)
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
self.assertEqual(4, clf.depth_)
self.assertEqual(4, clf.get_depth())

def test_nodes_leaves(self):
"""Check number of nodes and leaves."""
Expand All @@ -657,13 +660,17 @@ def test_nodes_leaves(self):
clf.fit(X, y)
nodes, leaves = clf.nodes_leaves()
self.assertEqual(31, nodes)
self.assertEqual(31, clf.get_nodes())
self.assertEqual(16, leaves)
self.assertEqual(16, clf.get_leaves())
X, y = load_wine(return_X_y=True)
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
nodes, leaves = clf.nodes_leaves()
self.assertEqual(11, nodes)
self.assertEqual(11, clf.get_nodes())
self.assertEqual(6, leaves)
self.assertEqual(6, clf.get_leaves())

def test_nodes_leaves_artificial(self):
"""Check leaves of artificial dataset."""
Expand All @@ -682,7 +689,9 @@ def test_nodes_leaves_artificial(self):
clf.tree_ = n1
nodes, leaves = clf.nodes_leaves()
self.assertEqual(6, nodes)
self.assertEqual(6, clf.get_nodes())
self.assertEqual(2, leaves)
self.assertEqual(2, clf.get_leaves())

def test_bogus_multiclass_strategy(self):
"""Check invalid multiclass strategy."""
Expand Down

0 comments on commit 036d1ba

Please sign in to comment.