Last Updated on July 14, 2022 by Jay
Sometimes we might want to plot a decision tree in Python to understand how the algorithm splits the data. The decision tree is probably one of the most “easy to understand” machine learning algorithms because we can see how exactly decisions are being made.
This tutorial focuses on how to plot a decision tree in Python. If you want to learn more about the decision tree algorithm, check this tutorial here.
Library & Dataset
Below are the libraries we need to install for this tutorial. We can use pip to install all three at once:
- sklearn – a popular machine learning library for Python (the pip package is called scikit-learn)
- matplotlib – a charting library
- graphviz – a graph visualization library we can also use to draw the decision tree
pip install scikit-learn matplotlib graphviz
The Iris Flower Dataset
The Iris flower dataset is a popular dataset for studying machine learning. The dataset was introduced by a British statistician and biologist called Ronald Fisher in 1936. (source: Wikipedia)
The dataset contains the petal and sepal lengths and widths of 3 different types of Iris flowers (Setosa, Versicolor, and Virginica). There are 50 samples for each type of Iris.
The sklearn library includes a few toy datasets for people to play around with, and Iris is one of them. We can import the Iris dataset as follows:
from sklearn.datasets import load_iris
iris = load_iris()
iris.keys()
dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename', 'data_module'])
The load_iris() call above actually returns a dictionary-like object (a Bunch) that contains several pieces of relevant information about the Iris flower dataset:
- data: the data itself – i.e. the 4 features
- target: the labeling for each sample (0 – setosa, 1 – versicolor, 2 – virginica)
- target_names: the actual flower names
- feature_names: the names of the four features, (‘sepal length (cm)’, ‘sepal width (cm)’, ‘petal length (cm)’, ‘petal width (cm)’)
To access each item in the iris dataset (dictionary), we can use either indexing or the “dot” notation.
iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
iris['feature_names']
['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)']
As shown in the preview below, the dataset contains 4 features and all values are numeric. By learning the patterns in the dataset, we hope to predict the Iris type when given the petal and sepal length and width. We will use a Decision Tree Classifier model here.
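For a quick look at the raw numbers (a small optional check, not part of the original output), we can print the shape of the data and its first few rows:
print(iris['data'].shape)   # (150, 4) – 150 samples, 4 numeric features
print(iris['data'][:5])     # first five rows of sepal/petal measurements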
We’ll assign the variable X to the features and y to the target, then split the data into a training dataset and a test dataset. Setting random_state = 0 makes the split reproducible, meaning that running the code on your own computer will produce the same results we are showing here.
from sklearn.model_selection import train_test_split
X = iris['data']
y = iris['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)
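As a quick sanity check (an optional addition), we can confirm the split. By default, train_test_split holds out 25% of the data, giving us 112 training samples and 38 test samples:
print(X_train.shape, X_test.shape)   # (112, 4) (38, 4)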
The Decision Tree Classifier
A classifier is a type of machine learning algorithm used to assign class labels to input data. For example, if we input the four features into the classifier, then it will return one of the three Iris types to us.
The sklearn library makes it really easy to create a decision tree classifier. The fit() method is the “training” part, essentially using the features and target variables to build a decision tree and learn from the data patterns.
from sklearn.tree import DecisionTreeClassifier
tree_clf = DecisionTreeClassifier(random_state = 0)
tree_clf.fit(X_train, y_train)
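Before we visualize the tree, a quick optional detour (not required for the rest of the tutorial): the fitted model can already make predictions with the standard sklearn predict() and score() methods.
# Predict the Iris type (0, 1, or 2) for the first few test samples
print(tree_clf.predict(X_test[:5]))
# Overall accuracy on the held-out test set (a value between 0 and 1)
print(tree_clf.score(X_test, y_test))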
Now that we have a trained decision tree classifier, there are a few ways to visualize it.
Simple Visualization Using sklearn
The sklearn library provides a super simple text view of the decision tree. We can call the export_text() function in the sklearn.tree module. The output below is a bare minimum and not that human-friendly to look at! In the next section we’ll make it a little bit easier to read.
from sklearn import tree
print(tree.export_text(tree_clf))
|--- feature_3 <= 0.80
| |--- class: 0
|--- feature_3 > 0.80
| |--- feature_2 <= 4.95
| | |--- feature_3 <= 1.65
| | | |--- class: 1
| | |--- feature_3 > 1.65
| | | |--- feature_1 <= 3.10
| | | | |--- class: 2
| | | |--- feature_1 > 3.10
| | | | |--- class: 1
| |--- feature_2 > 4.95
| | |--- feature_3 <= 1.75
| | | |--- feature_3 <= 1.65
| | | | |--- class: 2
| | | |--- feature_3 > 1.65
| | | | |--- class: 1
| | |--- feature_3 > 1.75
| | | |--- class: 2
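As a side note, export_text() also accepts a feature_names argument, which replaces the generic feature_0 … feature_3 labels with the real column names. This is an optional tweak not shown above:
print(tree.export_text(tree_clf, feature_names=iris['feature_names']))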
Plot A Decision Tree Using Matplotlib
We are going to use some help from the matplotlib library. The sklearn.tree module has a plot_tree function, which actually uses matplotlib under the hood for plotting a decision tree.
from sklearn import tree
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10, 10))
tree.plot_tree(tree_clf, feature_names=iris['feature_names'], class_names=iris['target_names'], filled=True, ax=ax)
plt.show()
The chart below is much nicer to look at. There’s already a tree-looking diagram with some useful data inside each node.
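Because plot_tree draws onto a regular matplotlib figure, we can also save the chart to an image file using matplotlib’s standard savefig (an optional step; the file name is just an example):
fig.savefig('iris_decision_tree.png', dpi=150, bbox_inches='tight')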
How to Interpret the Decision Tree
Let’s start from the root:
- The first line “petal width (cm) <= 0.8” is the decision rule applied to the node. Note that the new node on the left-hand side represents samples meeting the decision rule from the parent node.
- gini: a measure of node impurity – we will talk about this in another tutorial
- samples: there are 112 data records in the node
- value: data distribution – there are 37 setosa, 34 versicolor, and 41 virginica
- class: shows the majority class of the samples in the node
After the first split based on petal width <= 0.8 cm, all samples meeting the criterion are placed in the left (orange) node, which contains only setosa examples. A node is called “pure” when all of its samples have the same target value. Samples with petal width > 0.8 go into the node on the right for further splits.
We also used the argument filled=True to color each node. The color intensity indicates how strongly the majority class dominates a node. For example, the first (root) node has a faint purple color with class = virginica, whereas at the bottom, the two virginica nodes are dark purple, meaning almost all of the samples inside those nodes are virginica.
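If you’d rather read these numbers programmatically instead of off the chart, the fitted classifier exposes them through its tree_ attribute (a quick optional check; how the values are stored can vary by sklearn version):
print(tree_clf.tree_.n_node_samples[0])   # number of samples in the root node – 112 here
print(tree_clf.tree_.value[0])            # class distribution of the root node (counts or fractions, depending on the sklearn version)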
Plot Feature Importance
The model’s feature importances tell us which features matter most when making these decision splits. We can see the importance scores by calling the .feature_importances_ attribute. Note that the order of these scores matches the order of feature_names. In our example, petal width appears to be the most important feature for splitting.
tree_clf.feature_importances_
array([0. , 0.02014872, 0.39927524, 0.58057605])
iris['feature_names']
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
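To make the ranking easier to read, we can pair each feature name with its importance score and sort them (a small helper snippet, not part of the original):
for name, score in sorted(zip(iris['feature_names'], tree_clf.feature_importances_),
                          key=lambda pair: pair[1], reverse=True):
    print(f'{name}: {score:.3f}')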
We can use a matplotlib horizontal bar chart to plot the feature importances and make them more visually pleasing.
fig, ax = plt.subplots(figsize=(10,10))
plt.barh(range(len(iris['feature_names'])), tree_clf.feature_importances_)
plt.xlabel('feature importance')
plt.ylabel('feature name')
plt.yticks(range(4), iris['feature_names'])
plt.show()