Plotting trees with TreeRecipe.jl #56

Open

ablaom opened this issue Feb 22, 2024 · 3 comments

Comments
ablaom (Member) commented Feb 22, 2024

The following example shows how to manually plot the trees learned in DecisionTree.jl:

https://github.com/JuliaAI/TreeRecipe.jl/blob/master/examples/DecisionTree_iris.jl
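
For orientation, the linked example follows roughly this pattern (a hedged sketch, not a copy of the file; the exact helper names and the keys of the info NamedTuple passed to wrap are assumptions):

using DecisionTree   # provides build_tree, load_data, wrap
using TreeRecipe     # provides the plot recipe for wrapped trees
using Plots

# load the iris data that ships with DecisionTree.jl
features, labels = DecisionTree.load_data("iris")
features = float.(features)
labels = string.(labels)

# fit a plain DecisionTree.jl tree
dtree = build_tree(labels, features)

# wrap the tree together with display information (here, feature names)
# so the TreeRecipe.jl recipe can label the nodes
wt = DecisionTree.wrap(dtree, (featurenames = [:sepal_length, :sepal_width, :petal_length, :petal_width],))

# plot via the recipe provided by TreeRecipe.jl
plot(wt; size = (1400, 600))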

Currently, the way to integrate a plot recipe in MLJ.jl is not documented, but is sketched in this comment.

So, can we put this together so that a workflow like the following generates a plot of the learned decision tree?

edited again (x2):

using MLJBase
using Plots                 # <---- added in edit
import MLJDecisionTreeInterface
tree = MLJDecisionTreeInterface.DecisionTreeClassifier()
X, y = @load_iris
mach = machine(tree, X, y) |> fit!
plot(mach, 0.8, 0.7; size = (1400, 600))   # <---- added in edit

Note: It used to be that you made RecipesBase.jl your dependency to avoid a full Plots.jl dependency, but now the recipes live in Plots.jl and you are expected to make Plots.jl a weak dependency. You can see an example of this here.
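
For instance, the weak-dependency setup would look something like this (a sketch only; the extension and module names are placeholders, not the actual MLJDecisionTreeInterface.jl code):

# Project.toml of the interface package (TOML shown here as comments):
#
#   [weakdeps]
#   Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
#
#   [extensions]
#   PlotsExt = "Plots"

# ext/PlotsExt.jl -- Julia loads this module automatically once both the
# interface package and Plots are present in the session
module PlotsExt

using Plots                        # the weak dependency
import MLJDecisionTreeInterface    # the parent package

# plot recipes that need Plots (or RecipesBase, which Plots re-exports)
# would be defined in this module

end # module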

adarshpalaskar1 (Contributor) commented Feb 23, 2024

Hello, I went through the RecipesBase documentation and need some help understanding how the plot recipe is integrated. I have some questions:

  1. Should I add the code for the recipe in the MLJDecisionTree.jl file itself, or somewhere else? If the former, should I use the example code you mentioned above for plotting directly in the recipe? (I have not been able to adapt the example code, which works for the DecisionTreeClassifier model, to the MLJ machine.)

  2. I cannot pass the Machine as the recipe argument, as it is not part of the current dependencies (if the recipe is written in the MLJDecisionTree.jl file). What do you think should be done here?

Also, please let me know if these questions make sense or if I'm thinking in the wrong direction😅

ablaom (Member, Author) commented Feb 26, 2024

I've looked into this a bit further. I have an idea of how to do it, but it's a bit involved. The first step is to replace the current fitresult output of the fit methods for the DecisionTreeClassifier and DecisionTreeRegressor models with wrapped versions. We need this because we are going to overload Plots.plot(fitresult, ...) for the appropriate fitresult types.

So, we create a new struct

struct DecisionTreeClassifierFitResult{T,C,I}
    tree::T
    classes_seen::C
    integers_seen::I
    features::Vector{Symbol}
end

and instead of fit(::DecisionTreeClassifier, ...) returning fitresult = (tree, classes_seen, integers_seen, features), we will return
DecisionTreeClassifierFitResult(tree, classes_seen, integers_seen, features).

We will have to modify the definitions of predict(::DecisionTreeClassifier, fitresult, ...), fitted_params(::DecisionTreeClassifier, fitresult) and feature_importances(::TreeModel, ...) accordingly, so that they first unwrap the fitresult.

We do something similar for DecisionTreeRegressor, whose fitresult has a different form.

We should be careful that none of these changes breaks anything. Since fitresult is private (public access is through the fitted_params method we are fixing) this should not be a problem.
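
Concretely, the change would look something like the following (a minimal sketch; I'm assuming the usual const MMI = MLJModelInterface alias, and the placeholder comments stand in for the existing logic, so variable names and NamedTuple keys are illustrative only):

import MLJModelInterface
const MMI = MLJModelInterface

function MMI.fit(model::DecisionTreeClassifier, verbosity, X, y)
    # ... build tree, classes_seen, integers_seen, features, cache, report as before ...
    fitresult = DecisionTreeClassifierFitResult(
        tree, classes_seen, integers_seen, features)
    return fitresult, cache, report
end

# existing methods first unwrap the struct, then proceed as before:
function MMI.predict(model::DecisionTreeClassifier, fitresult, Xnew)
    (; tree, classes_seen, integers_seen) = fitresult
    # ... same prediction logic as before, using the unwrapped fields ...
end

MMI.fitted_params(model::DecisionTreeClassifier, fitresult) =
    (tree = fitresult.tree, features = fitresult.features)  # keys as before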

@adarshpalaskar1 You want to have a go at a PR to do this internal wrapping?

adarshpalaskar1 (Contributor) commented Mar 1, 2024

Sorry for the delayed response. I went through the implementation steps you provided, and I am eager to work on a PR.

I added the above-mentioned changes for the DecisionTreeClassifier, but I am facing an issue while plotting.

In the example mentioned above (https://github.com/JuliaAI/TreeRecipe.jl/blob/master/examples/DecisionTree_iris.jl), we have:

julia> typeof(dtree)
Node{Float64, String}

wrapped tree in example:

julia> typeof(wt)
InfoNode{Float64, String}


Whereas in the case of DecisionTreeInterface, we have:

julia> typeof(fitted_params(mach).raw_tree.node)
DecisionTree.Node{Float64, UInt32}

wrapped tree in fitted_params:

julia> typeof(fitted_params(mach).tree)
DecisionTree.InfoNode{Float64, UInt32}

As a result, we get:

julia> plot(fitted_params(mach).tree)
ERROR: Cannot convert DecisionTree.InfoNode{Float64, UInt32} to series data for plotting

and similarly, after adding a recipe for wrapping:

julia> plot(mach)
ERROR: Cannot convert DecisionTree.InfoNode{Float64, UInt32} to series data for plotting

How can I handle this mismatch between the data types?
