Skip to content

Commit

Permalink
Modifying random forest to accept number of trees and max depth
Browse files Browse the repository at this point in the history
  • Loading branch information
ghislainv committed Jun 4, 2024
1 parent b73df4c commit b73dbed
Showing 1 changed file with 6 additions and 15 deletions.
21 changes: 6 additions & 15 deletions forestatrisk/model/model_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,8 @@ class model_random_forest(object):
"""

def __init__(
self, # Observations
formula,
data,
# NA action
NA_action="drop",
# Environment
eval_env=0,
# Number of cores
n_estimators=500,
n_jobs=1,
**kwargs
):
def __init__(self, formula, data, na_action="drop", eval_env=0,
n_jobs=1, n_estimators=100, max_depth=None, **kwargs):
"""Function to fit a random forest model.
The function fits a random forest model using patsy formula.
Expand All @@ -54,12 +43,14 @@ def __init__(

# Patsy
eval_env = EvalEnvironment.capture(eval_env, reference=1)
y, x = dmatrices(formula, data, eval_env, NA_action)
y, x = dmatrices(formula, data, eval_env, na_action)
self._y_design_info = y.design_info
self._x_design_info = x.design_info

# Create and train Random Forest
rf = RandomForestClassifier(n_estimators=n_estimators, n_jobs=n_jobs, **kwargs)
rf = RandomForestClassifier(
n_estimators=n_estimators,
max_depth=max_depth, n_jobs=n_jobs, **kwargs)
rf.fit(x, y)
self.rf = rf

Expand Down

0 comments on commit b73dbed

Please sign in to comment.