Skip to content

Commit

Permalink
Add separate methods to return nodes/leaves/depth
Browse files Browse the repository at this point in the history
  • Loading branch information
rmontanana committed Nov 27, 2023
1 parent f9b83ad commit 52d1095
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
12 changes: 12 additions & 0 deletions odte/Odte.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,18 @@ def predict_proba(self, X: np.ndarray) -> np.ndarray:
for i in range(n_samples):
result[i, predictions[i]] += 1
return result / self.n_estimators

def get_nodes(self) ->int:
check_is_fitted(self, "estimators_")
return self.nodes_

def get_leaves(self) ->int:
check_is_fitted(self, "estimators_")
return self.leaves_

def get_depth(self) ->int:
check_is_fitted(self, "estimators_")
return self.depth_

def nodes_leaves(self) -> Tuple[float, float]:
check_is_fitted(self, "estimators_")
Expand Down
2 changes: 1 addition & 1 deletion odte/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.4"
__version__ = "0.3.5"
15 changes: 15 additions & 0 deletions odte/tests/Odte_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ def test_nodes_leaves_not_fitted(self):
)
with self.assertRaises(NotFittedError):
tclf.nodes_leaves()
with self.assertRaises(NotFittedError):
tclf.get_nodes()
with self.assertRaises(NotFittedError):
tclf.get_leaves()
with self.assertRaises(NotFittedError):
tclf.get_depth()

def test_nodes_leaves_depth(self):
tclf = Odte(
Expand All @@ -209,11 +215,16 @@ def test_nodes_leaves_depth(self):
tclf_p.fit(X, y)
for clf in [tclf, tclf_p]:
self.assertAlmostEqual(5.8, clf.depth_)
self.assertAlmostEqual(5.8, clf.get_depth())
self.assertAlmostEqual(9.4, clf.leaves_)
self.assertAlmostEqual(9.4, clf.get_leaves())
self.assertAlmostEqual(17.8, clf.nodes_)
self.assertAlmostEqual(17.8, clf.get_nodes())
nodes, leaves = clf.nodes_leaves()
self.assertAlmostEqual(9.4, leaves)
self.assertAlmostEqual(9.4, clf.get_leaves())
self.assertAlmostEqual(17.8, nodes)
self.assertAlmostEqual(17.8, clf.get_nodes())

def test_nodes_leaves_SVC(self):
tclf = Odte(
Expand All @@ -224,10 +235,14 @@ def test_nodes_leaves_SVC(self):
X, y = load_dataset(self._random_state, n_features=16, n_samples=500)
tclf.fit(X, y)
self.assertAlmostEqual(0.0, tclf.leaves_)
self.assertAlmostEqual(0.0, tclf.get_leaves())
self.assertAlmostEqual(0.0, tclf.nodes_)
self.assertAlmostEqual(0.0, tclf.get_nodes())
nodes, leaves = tclf.nodes_leaves()
self.assertAlmostEqual(0.0, leaves)
self.assertAlmostEqual(0.0, tclf.get_leaves())
self.assertAlmostEqual(0.0, nodes)
self.assertAlmostEqual(0.0, tclf.get_nodes())

def test_estimator_hyperparams(self):
data = [
Expand Down

0 comments on commit 52d1095

Please sign in to comment.