How to Plot A Decision Tree in Python Matplotlib


Last Updated on July 14, 2022 by Jay

Sometimes we might want to plot a decision tree in Python to understand how the algorithm splits the data. The decision tree is probably one of the most “easy to understand” machine learning algorithms because we can see how exactly decisions are being made.

Plot A Decision Tree In Python

This tutorial focuses on how to plot a decision tree in Python. If you want to learn more about the decision tree algorithm, check this tutorial here.

Library & Dataset

Below are the libraries we need to install for this tutorial. We can use pip to install all three at once:

  • scikit-learn (imported as sklearn) – a popular machine learning library for Python
  • matplotlib – a charting library
  • graphviz – another charting library that can also plot decision trees
pip install scikit-learn matplotlib graphviz

The Iris Flower Dataset

The Iris flower dataset is a popular dataset for studying machine learning. It was introduced by the British statistician and biologist Ronald Fisher in 1936 (source: Wikipedia).

The dataset contains the petal and sepal lengths and widths of 3 different types of Iris flowers (Setosa, Versicolor, and Virginica), with 50 samples for each type.

Iris Flowers

The sklearn library includes a few toy datasets for people to play around with, and Iris is one of them. We can import the Iris dataset as follows:

from sklearn.datasets import load_iris

iris = load_iris()
iris.keys()
dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename', 'data_module'])

The load_iris() call above actually returns a dictionary-like object that contains several pieces of relevant information about the Iris flower dataset:

  • data: the data itself – i.e. the 4 features
  • target: the labeling for each sample (0 – setosa, 1 – versicolor, 2 – virginica)
  • target_names: the actual flower names
  • feature_names: the names of the four features (‘sepal length (cm)’, ‘sepal width (cm)’, ‘petal length (cm)’, ‘petal width (cm)’)

To access each item in the iris dataset (dictionary), we can use either indexing or the “dot” notation.

iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

iris['feature_names']
['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

As shown below, the dataset contains 4 features and all values are numeric. By learning the patterns in the dataset, we hope to predict the Iris type when given the petal and sepal length and width. We will use a Decision Tree Classifier model here.

The Iris Dataset
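The screenshot above shows the data laid out as a table. If you'd like to reproduce a similar view yourself, below is a minimal sketch using pandas (an extra dependency, not among the three libraries we installed earlier):

import pandas as pd

# combine the features and the target into one table for easy viewing
df = pd.DataFrame(iris['data'], columns=iris['feature_names'])
df['target'] = iris['target']
print(df.head())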

We’ll assign the features to variable X and the target to y, then split the data into a training dataset and a test dataset. Setting random_state = 0 makes the results reproducible, meaning that running the code on your own computer will produce the same results we are showing here.

from sklearn.model_selection import train_test_split

X = iris['data']
y = iris['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)
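By default, train_test_split holds out 25% of the data for testing, so out of the 150 samples we get 112 for training and 38 for testing. We can confirm the shapes:

print(X_train.shape, X_test.shape)
(112, 4) (38, 4)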

The Decision Tree Classifier

A classifier is a type of machine learning algorithm used to assign class labels to input data. For example, if we input the four features into the classifier, then it will return one of the three Iris types to us.

The sklearn library makes it really easy to create a decision tree classifier. The fit() method is the “training” part, essentially using the features and target variables to build a decision tree and learn from the data patterns.

from sklearn.tree import DecisionTreeClassifier

tree_clf = DecisionTreeClassifier(random_state = 0)
tree_clf.fit(X_train, y_train)
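With the model trained, we can ask it to classify samples it has never seen. As a quick sanity check, the sketch below predicts the first few test samples and scores the model on the whole test set:

# predicted vs. actual Iris types (0, 1, or 2) for the first 5 test samples
print(tree_clf.predict(X_test[:5]))
print(y_test[:5])

# overall accuracy on the held-out test set
print(tree_clf.score(X_test, y_test))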

Now that we have a trained decision tree classifier, there are a few ways to visualize it.

Simple Visualization Using sklearn

The sklearn library provides a super simple text visualization of the decision tree. We can call the export_text() function from the sklearn.tree module. The output is the bare minimum and not that human-friendly to look at! Let’s make it a little bit easier to read.

from sklearn import tree
print(tree.export_text(tree_clf))

|--- feature_3 <= 0.80
|   |--- class: 0
|--- feature_3 >  0.80
|   |--- feature_2 <= 4.95
|   |   |--- feature_3 <= 1.65
|   |   |   |--- class: 1
|   |   |--- feature_3 >  1.65
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |--- feature_2 >  4.95
|   |   |--- feature_3 <= 1.75
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 1
|   |   |--- feature_3 >  1.75
|   |   |   |--- class: 2
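As a quick improvement, export_text() also accepts a feature_names argument, which replaces the generic feature_0 … feature_3 labels with the actual column names:

print(tree.export_text(tree_clf, feature_names=iris['feature_names']))

For a proper graphical chart, though, matplotlib is the way to go.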

Plot A Decision Tree Using Matplotlib

We are going to get some help from the matplotlib library. The sklearn.tree module has a plot_tree function, which actually uses matplotlib under the hood to draw the decision tree.

from sklearn import tree
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 10))
tree.plot_tree(tree_clf,
               feature_names=iris['feature_names'],
               class_names=iris['target_names'],
               filled=True,
               ax=ax)

The output below is much nicer to look at: a proper tree-shaped diagram with some useful data inside each node.

Plot A Decision Tree In Python
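Since plot_tree draws onto a regular matplotlib figure, we can also save the chart to a file like any other matplotlib plot (the file name and dpi below are arbitrary choices):

fig.savefig('decision_tree.png', dpi=150, bbox_inches='tight')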

How to Interpret the Decision Tree

Let’s start from the root:

  1. The first line “petal width (cm) <= 0.8” is the decision rule applied to the node. Note that the new node on the left-hand side represents samples meeting the decision rule from the parent node.
  2. gini: we will talk about this in another tutorial
  3. samples: there are 112 data records in the node
  4. value: data distribution – there are 37 setosa, 34 versicolor, and 41 virginica
  5. class: shows the majority class of the samples in the node (see the code sketch below for reading these values programmatically)
Interpreting The Decision Tree
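These values are also available programmatically through the fitted model's tree_ attribute. Below is a minimal sketch that reads the root node (node id 0); the attribute names (feature, threshold, impurity, n_node_samples) come from sklearn's underlying Tree object:

t = tree_clf.tree_
print(iris['feature_names'][t.feature[0]], '<=', t.threshold[0])  # the decision rule
print('gini:', t.impurity[0])                                     # impurity of the node
print('samples:', t.n_node_samples[0])                            # 112 training records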

After the first split based on petal width <= 0.8 cm, all samples meeting the criterion are placed on the left (the orange node), and they are all setosa samples. A node is called “pure” when it contains samples of only one target class. Those with petal width > 0.8 are put into the node on the right for further splits.
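We can double-check this “pure” node with a bit of numpy: every training sample with petal width <= 0.8 should belong to class 0 (setosa), and there should be 37 of them, matching the value shown in the orange node.

import numpy as np

# petal width (cm) is the 4th feature, i.e. column index 3
left = X_train[:, 3] <= 0.8
print(np.unique(y_train[left]))  # expect array([0]) - setosa only
print(left.sum())                # expect 37, the sample count of the orange node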

We also used the argument filled=True to color each node. The color intensity indicates how dominant the majority class is within a node. For example, the first (root) node has a faint purple color with class = virginica, whereas at the bottom, the two virginica nodes are in dark purple, meaning those nodes contain mostly virginica samples.

Plot Feature Importance

The model's feature importances tell us which features matter most when making these decision splits. We can see the importance ranking by calling the .feature_importances_ attribute. Note that the order of these values matches the order of feature_names. In our example, it appears petal width is the most important feature for splitting.

tree_clf.feature_importances_
array([0.        , 0.02014872, 0.39927524, 0.58057605])

iris['feature_names']
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
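To line the names up with the scores, we can zip the two lists together:

for name, score in zip(iris['feature_names'], tree_clf.feature_importances_):
    print(f'{name}: {score:.4f}')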

We can use a matplotlib horizontal bar chart to plot the feature importances and make them more visually pleasing.

fig, ax = plt.subplots(figsize=(10,10))

plt.barh(range(len(iris['feature_names'])), tree_clf.feature_importances_)
plt.xlabel('feature importance')
plt.ylabel('feature name')
plt.yticks(range(4), iris['feature_names'])
Plotting Feature Importance of A Decision Tree Classifier Model

Additional Resources

How To Make Waterfall Chart In Python Matplotlib

How to Make a WordCloud in Python
