Some decision tree branching may be unnecessary #226

Open
kevinrobinson opened this issue Aug 19, 2019 · 3 comments
@kevinrobinson

@dalelane #221 is amazing, awesome work! 👍 Thanks as always for sharing such great work in the open! ❤️

I made some limited test data just to check this out, and noticed that the explanation had some parts of the tree that didn't seem necessary since they all result in the same classification. Here's an example, with the parts highlighted that seem unnecessary:

[screenshot: the generated decision tree, with the branches that seem unnecessary highlighted]

Maybe the way the tree is constructed can lead to this? I'm not sure. I tried to see where the tree is constructed, guessing that this was done by a service, but only got as far as

export function getModelVisualisation(project: Objects.Project): Promise<NumbersModelDescriptionResponse> {
and wasn't quite sure where to go from there.

I'm not sure if this is an actual issue, as in something isn't working as it should, or just something that might be worth adding to your great explanation up top.

Also, http://www.r2d3.us/visual-intro-to-machine-learning-part-1/ has some awesome visuals of decision trees that might be a good fit as a "learn more" link in the explanation up to as well.

@dalelane
Member

I'll have a look, thanks!

In the meantime, the code where this is done lives here:

# (imports assumed for context; graph_from_dot_data here is the pydotplus helper)
from sklearn.feature_extraction import DictVectorizer
from sklearn import tree
from pydotplus import graph_from_dot_data

# building decision tree classifier
vec = DictVectorizer(sparse=False)
dt = tree.DecisionTreeClassifier()
dt.fit(vec.fit_transform(examples), labels)

# creating decision tree visualization
dot_data = tree.export_graphviz(dt,
                                feature_names=vec.feature_names_,
                                class_names=dt.classes_,
                                impurity=False,
                                filled=True,
                                rounded=True)
graph = graph_from_dot_data(dot_data)
graph.set_size('"70"')

@kevinrobinson
Author

@dalelane awesome, thanks!

Yeah, it seems like the scikit-learn algorithm is an approximation, since for larger real-world data sets it can be intractable to compute the optimal tree (link):

The problem of learning an optimal decision tree is known to be NP-complete under several aspects of optimality and even for simple concepts. Consequently, practical decision-tree learning algorithms are based on heuristic algorithms such as the greedy algorithm where locally optimal decisions are made at each node. Such algorithms cannot guarantee to return the globally optimal decision tree. This can be mitigated by training multiple trees in an ensemble learner, where the features and samples are randomly sampled with replacement.

I suppose to simplify things for folks it'd be possible to write a function that does that simplification on the tree that's returned, since for ML for Kids models the trees wouldn't be expensive to prune after the fact (a rough sketch is below). scikit-learn exposes the internals of the decision tree structure (link), but the graphviz export function takes a DecisionTreeClassifier, so that might involve some hacking rather than just walking and modifying the tree. Alternately, it might be worth just adding a line of explanation alerting learners and teachers to this, and calling it out as an example of the kind of approximation that is common in ML.
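One rough way to do the detection part, for what it's worth: walk the fitted tree_ arrays and collect the internal nodes whose leaves all predict the same class, since those are the subtrees that could be collapsed. A minimal sketch, assuming a fitted DecisionTreeClassifier named dt (the helper name redundant_subtrees is made up):

import numpy as np

def redundant_subtrees(dt):
    """Return ids of internal nodes whose whole subtree predicts one class."""
    t = dt.tree_
    redundant = []

    def subtree_classes(node):
        if t.children_left[node] == -1:  # scikit-learn marks leaf children with -1
            # value[node] holds the per-class sample counts at this leaf
            return {int(np.argmax(t.value[node]))}
        classes = (subtree_classes(t.children_left[node]) |
                   subtree_classes(t.children_right[node]))
        if len(classes) == 1:
            redundant.append(node)  # every leaf below agrees, so this split adds nothing
        return classes

    subtree_classes(0)  # node 0 is the root
    return redundant

Actually collapsing those nodes in the rendered graph is the fiddlier part, since export_graphviz works from the classifier itself rather than from a modified copy of the tree arrays.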

Relatedly, it looks like subsequent training runs are non-deterministic (link):

The features are always randomly permuted at each split. Therefore, the best found split may vary, even with the same training data and max_features=n_features, if the improvement of the criterion is identical for several splits enumerated during the search of the best split. To obtain a deterministic behaviour during fitting, random_state has to be fixed.

It might be good either to add a line about this in the explanation, or to sidestep it entirely by fixing a random seed, so that folks aren't confused if they get different trees from different training runs on the same data set (e.g. within a class).
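For the determinism point, that looks like it'd just be a one-line change where the classifier is constructed (the seed value here is arbitrary):

# fix the seed so repeated training runs on the same data produce the same tree
dt = tree.DecisionTreeClassifier(random_state=0)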

@dalelane
Copy link
Member

Related issues:
scikit-learn/scikit-learn#10810
scikit-learn/scikit-learn#6557
scikit-learn/scikit-learn#4630
