Skip to content

Commit

Permalink
Fix depth/leaves/nodes no longer return average
Browse files Browse the repository at this point in the history
  • Loading branch information
rmontanana committed Nov 27, 2023
1 parent 52d1095 commit 02e75b3
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 21 deletions.
23 changes: 13 additions & 10 deletions odte/Odte.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,20 @@ def fit(
return self

def _compute_metrics(self) -> None:
tdepth = tnodes = tleaves = 0.0
tdepth = tnodes = tleaves = 0
for estimator in self.estimators_:
if hasattr(estimator, "nodes_leaves"):
nodes, leaves = estimator.nodes_leaves()
depth = estimator.depth_
tdepth += depth
tnodes += nodes
tleaves += leaves
self.depth_ = tdepth / self.n_estimators
self.leaves_ = tleaves / self.n_estimators
self.nodes_ = tnodes / self.n_estimators
# self.depth_ = tdepth / self.n_estimators
# self.leaves_ = tleaves / self.n_estimators
# self.nodes_ = tnodes / self.n_estimators
self.depth_ = tdepth
self.leaves_ = tleaves
self.nodes_ = tnodes

def _train(
self, X: np.ndarray, y: np.ndarray, weights: np.ndarray
Expand Down Expand Up @@ -250,16 +253,16 @@ def predict_proba(self, X: np.ndarray) -> np.ndarray:
for i in range(n_samples):
result[i, predictions[i]] += 1
return result / self.n_estimators
def get_nodes(self) ->int:

def get_nodes(self) -> int:
check_is_fitted(self, "estimators_")
return self.nodes_
def get_leaves(self) ->int:

def get_leaves(self) -> int:
check_is_fitted(self, "estimators_")
return self.leaves_
def get_depth(self) ->int:

def get_depth(self) -> int:
check_is_fitted(self, "estimators_")
return self.depth_

Expand Down
2 changes: 1 addition & 1 deletion odte/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.5"
__version__ = "0.3.6"
20 changes: 10 additions & 10 deletions odte/tests/Odte_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,17 +214,17 @@ def test_nodes_leaves_depth(self):
tclf.fit(X, y)
tclf_p.fit(X, y)
for clf in [tclf, tclf_p]:
self.assertAlmostEqual(5.8, clf.depth_)
self.assertAlmostEqual(5.8, clf.get_depth())
self.assertAlmostEqual(9.4, clf.leaves_)
self.assertAlmostEqual(9.4, clf.get_leaves())
self.assertAlmostEqual(17.8, clf.nodes_)
self.assertAlmostEqual(17.8, clf.get_nodes())
self.assertEqual(29, clf.depth_)
self.assertEqual(29, clf.get_depth())
self.assertEqual(47, clf.leaves_)
self.assertEqual(47, clf.get_leaves())
self.assertEqual(89, clf.nodes_)
self.assertEqual(89, clf.get_nodes())
nodes, leaves = clf.nodes_leaves()
self.assertAlmostEqual(9.4, leaves)
self.assertAlmostEqual(9.4, clf.get_leaves())
self.assertAlmostEqual(17.8, nodes)
self.assertAlmostEqual(17.8, clf.get_nodes())
self.assertEqual(47, leaves)
self.assertEqual(47, clf.get_leaves())
self.assertEqual(89, nodes)
self.assertEqual(89, clf.get_nodes())

def test_nodes_leaves_SVC(self):
tclf = Odte(
Expand Down

0 comments on commit 02e75b3

Please sign in to comment.