diff --git a/forestatrisk/model/model_random_forest.py b/forestatrisk/model/model_random_forest.py index ce4f721..984cc6c 100644 --- a/forestatrisk/model/model_random_forest.py +++ b/forestatrisk/model/model_random_forest.py @@ -29,19 +29,8 @@ class model_random_forest(object): """ - def __init__( - self, # Observations - formula, - data, - # NA action - NA_action="drop", - # Environment - eval_env=0, - # Number of cores - n_estimators=500, - n_jobs=1, - **kwargs - ): + def __init__(self, formula, data, na_action="drop", eval_env=0, + n_jobs=1, n_estimators=100, max_depth=None, **kwargs): """Function to fit a random forest model. The function fits a random forest model using patsy formula. @@ -54,12 +43,14 @@ def __init__( # Patsy eval_env = EvalEnvironment.capture(eval_env, reference=1) - y, x = dmatrices(formula, data, eval_env, NA_action) + y, x = dmatrices(formula, data, eval_env, na_action) self._y_design_info = y.design_info self._x_design_info = x.design_info # Create and train Random Forest - rf = RandomForestClassifier(n_estimators=n_estimators, n_jobs=n_jobs, **kwargs) + rf = RandomForestClassifier( + n_estimators=n_estimators, + max_depth=max_depth, n_jobs=n_jobs, **kwargs) rf.fit(x, y) self.rf = rf