From 036d1ba2a721e5299325841049580646a660a024 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 27 Nov 2023 10:02:14 +0100 Subject: [PATCH] Add separate methods to return nodes/leaves/depth --- stree/Strees.py | 37 +++++++++++++++++++++++++++++++++++++ stree/_version.py | 2 +- stree/tests/Stree_test.py | 9 +++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/stree/Strees.py b/stree/Strees.py index d443ab8..fd618c4 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -484,6 +484,43 @@ def predict(self, X: np.array) -> np.array: X = self.check_predict(X) return self.classes_[np.argmax(self.__predict_class(X), axis=1)] + def get_nodes(self) -> int: + """Return the number of nodes in the tree + + Returns + ------- + int + number of nodes + """ + nodes = 0 + for _ in self: + nodes += 1 + return nodes + + def get_leaves(self) -> int: + """Return the number of leaves in the tree + + Returns + ------- + int + number of leaves + """ + leaves = 0 + for node in self: + if node.is_leaf(): + leaves += 1 + return leaves + + def get_depth(self) -> int: + """Return the depth of the tree + + Returns + ------- + int + depth of the tree + """ + return self.depth_ + def nodes_leaves(self) -> tuple: """Compute the number of nodes and leaves in the built tree diff --git a/stree/_version.py b/stree/_version.py index 9c73af2..f708a9b 100644 --- a/stree/_version.py +++ b/stree/_version.py @@ -1 +1 @@ -__version__ = "1.3.1" +__version__ = "1.3.2" diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 67f618e..9556264 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -239,6 +239,7 @@ def test_check_max_depth(self): ) tcl.fit(*load_dataset(self._random_state)) self.assertEqual(depth, tcl.depth_) + self.assertEqual(depth, tcl.get_depth()) def test_unfitted_tree_is_iterable(self): tcl = Stree() @@ -640,10 +641,12 @@ def test_depth(self): clf = Stree(random_state=self._random_state) clf.fit(X, y) self.assertEqual(6, clf.depth_) + self.assertEqual(6, clf.get_depth()) X, y = load_wine(return_X_y=True) clf = Stree(random_state=self._random_state) clf.fit(X, y) self.assertEqual(4, clf.depth_) + self.assertEqual(4, clf.get_depth()) def test_nodes_leaves(self): """Check number of nodes and leaves.""" @@ -657,13 +660,17 @@ def test_nodes_leaves(self): clf.fit(X, y) nodes, leaves = clf.nodes_leaves() self.assertEqual(31, nodes) + self.assertEqual(31, clf.get_nodes()) self.assertEqual(16, leaves) + self.assertEqual(16, clf.get_leaves()) X, y = load_wine(return_X_y=True) clf = Stree(random_state=self._random_state) clf.fit(X, y) nodes, leaves = clf.nodes_leaves() self.assertEqual(11, nodes) + self.assertEqual(11, clf.get_nodes()) self.assertEqual(6, leaves) + self.assertEqual(6, clf.get_leaves()) def test_nodes_leaves_artificial(self): """Check leaves of artificial dataset.""" @@ -682,7 +689,9 @@ def test_nodes_leaves_artificial(self): clf.tree_ = n1 nodes, leaves = clf.nodes_leaves() self.assertEqual(6, nodes) + self.assertEqual(6, clf.get_nodes()) self.assertEqual(2, leaves) + self.assertEqual(2, clf.get_leaves()) def test_bogus_multiclass_strategy(self): """Check invalid multiclass strategy."""