Fix the doc-string for RFE #21

Merged: 4 commits, merged Jul 25, 2024
Changes from 3 commits
59 changes: 30 additions & 29 deletions src/models/rfe.jl
@@ -41,15 +41,11 @@

# Common keyword constructor for both model types
"""
-    RecursiveFeatureElimination(model, n_features, step)
+    RecursiveFeatureElimination(model; n_features=0, step=1)

This model implements a recursive feature elimination algorithm for feature selection.
It recursively removes features, training a base model on the remaining features and
evaluating their importance until the desired number of features is selected.
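
Schematically, the elimination loop looks like this (an illustrative sketch only, not the
package's implementation; `fit_and_rank` is a hypothetical callback that fits the base
model on the given features and returns `feature => importance` pairs):

```
function rfe_sketch(fit_and_rank, features::Vector{Symbol}; n_features::Int=1, step::Int=1)
    features = copy(features)
    while length(features) > n_features
        ranked = fit_and_rank(features)          # e.g. [:x3 => 0.9, :x1 => 0.2, ...]
        sort!(ranked, by = p -> abs(last(p)), rev = true)
        n_drop = min(step, length(features) - n_features)
        features = first.(ranked[1:end-n_drop])  # drop the `n_drop` least important
    end
    return features
end
```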

-Construct an instance with default hyper-parameters using the syntax
-`rfe_model = RecursiveFeatureElimination(model=...)`. Provide keyword arguments to override
-hyper-parameter defaults.
-
# Training data
In MLJ or MLJBase, bind an instance `rfe_model` to data with
@@ -92,12 +88,11 @@
# Operations

- `transform(mach, X)`: transform the input table `X` into a new table containing only
-  columns corresponding to features gotten from the RFE algorithm.
+  columns corresponding to features accepted by the RFE algorithm.

- `predict(mach, X)`: transform the input table `X` into a new table same as in
-
-- `transform(mach, X)` above and predict using the fitted base model on the
-transformed table.
+  `transform(mach, X)` above and predict using the fitted base model on the transformed
+  table.

# Fitted parameters
The fields of `fitted_params(mach)` are:
@@ -108,37 +103,43 @@
# Report
The fields of `report(mach)` are:
- `scores`: dictionary of scores for each feature in the training dataset.
  The model deems highly scored variables more significant.

- `model_report`: report for the fitted base model.
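
For instance, after fitting (a sketch: the field names are as documented above, but the
values shown are illustrative, not from an actual run):

```
r = report(mach)
r.scores        # e.g. Dict(:x1 => 4, :x5 => 3, ...); higher scores mean more significant
r.model_report  # the report of the underlying base model
```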


# Examples

+The following example assumes you have MLJDecisionTreeInterface in the active package
+environment.

```
-using FeatureSelection, MLJ, StableRNGs
+using MLJ

RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree

# Creates a dataset where the target only depends on the first 5 columns of the input table.
-A = rand(rng, 50, 10);
+A = rand(50, 10);
y = 10 .* sin.(
pi .* A[:, 1] .* A[:, 2]
-) + 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5]);
+) + 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5];
X = MLJ.table(A);

-# fit a rfe model
+# fit a rfe model:
rf = RandomForestRegressor()
-selector = RecursiveFeatureElimination(model = rf)
+selector = RecursiveFeatureElimination(rf, n_features=2)
mach = machine(selector, X, y)
fit!(mach)

# view the feature importances
feature_importances(mach)

-# predict using the base model
-Xnew = MLJ.table(rand(rng, 50, 10));
+# predict using the base model trained on the reduced feature set:
+Xnew = MLJ.table(rand(50, 10));
predict(mach, Xnew)

+# transform data with all features to the reduced feature set:
+transform(mach, Xnew)
```
"""
function RecursiveFeatureElimination(
@@ -173,7 +174,7 @@
# This branch is hit just in case there are any models that support
# feature importance but aren't `<:Probabilistic` or `<:Deterministic`,
# which is rare.
throw(ERR_MODEL_TYPE)
end
message = MMI.clean!(selector)
isempty(message) || @warn(message)
@@ -214,19 +215,19 @@
"""
score_features!(scores_dict, features, importances, n_features_to_score)

Internal method that updates the `scores_dict` by increasing the score of each feature
based on its importance, and stores the features in the `features` array.

# Arguments
- `scores_dict::Dict{Symbol, Int}`: A dictionary where the keys are features and
the values are their corresponding scores.
- `features::Vector{Symbol}`: An array to store the top features based on importance.
- `importances::Vector{<:Pair{Symbol, <:Real}}`: An array of pairs where each pair
  contains a feature and its importance score.
- `n_features_to_score::Int`: The number of top features to score and store.

# Notes
Ensure that `n_features_to_score` is less than or equal to the minimum of the
lengths of `features` and `importances`.

# Example
@@ -244,7 +245,7 @@
function score_features!(scores_dict, features, importances, n_features_to_score)
for i in Base.OneTo(n_features_to_score)
ftr = first(importances[i])
features[i] = ftr
scores_dict[ftr] += 1
end
end
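
For instance (made-up importances, standing in for the collapsed `# Example` block above):

```
scores_dict = Dict(:f1 => 0, :f2 => 0, :f3 => 0)
features = Vector{Symbol}(undef, 2)
importances = [:f2 => 0.8, :f1 => 0.3, :f3 => 0.1]
score_features!(scores_dict, features, importances, 2)
# Now features == [:f2, :f1] and scores_dict == Dict(:f1 => 1, :f2 => 1, :f3 => 0)
```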
@@ -273,7 +274,7 @@
"n_features > number of features in training data, "*
"hence no feature will be eliminated."
)
end
end

_step = selector.step
@@ -296,17 +297,17 @@
verbosity > 0 && @info("Fitting estimator with $(n_features_to_keep) features.")
data = MMI.reformat(model, MMI.selectcols(X, features_left), args...)
fitresult, _, report = MMI.fit(model, verbosity - 1, data...)
# Note that the MLJ feature importance API does not impose any restrictions on the
# ordering of `feature => score` pairs in the `importances` vector.
# Therefore, the order of `feature => score` pairs in the `importances` vector
# might differ from the order of features in the `features` vector, which is
# extracted from the feature matrix `X` above. Hence the need for a dictionary
# implementation.
importances = MMI.feature_importances(
selector.model,
fitresult,
report
)

# Eliminate the worst features and increase the score of the remaining features
sort!(importances, by=abs_last, rev = true)
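
The sort key `abs_last` is not shown in this diff; presumably it ranks `feature => score`
pairs by the magnitude of the score, so that the sign convention of a given importance
measure does not matter. A sketch under that assumption:

```
abs_last(pair) = abs(last(pair))  # assumed definition, not shown in the diff

importances = [:a => -0.9, :b => 0.2, :c => 0.5]
sort!(importances, by = abs_last, rev = true)
# importances == [:a => -0.9, :c => 0.5, :b => 0.2]
```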