Pandas Plotting: Scatter Matrix


Last Updated on July 14, 2022 by Jay

Did you know we can use the pandas Python library to create a scatter matrix plot? Yes! In addition to pandas’ powerful data-wrangling capabilities, it can do plotting too!


To install pandas, along with matplotlib (which pandas uses to draw its plots), type the following in a command prompt window:

pip install pandas matplotlib

What Is a Scatter Matrix Plot?

A scatter matrix plot is literally a matrix of scatter plots! Some people call it a “feature pair plot”.

Essentially, we create a scatter plot for every possible pair of features. This helps show whether, and how, the features are correlated with each other. Note, however, that a scatter matrix only shows relationships between pairs of features, not interactions among three or more features at once.
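Each panel of a scatter matrix lets you eyeball the correlation for one feature pair. As a rough numeric companion (our own addition, not part of the original walkthrough), the sketch below enumerates every unordered feature pair with itertools.combinations and computes the Pearson correlation for each one with Series.corr. It assumes the five sample rows from the fruits preview in this post as stand-in data; keep in mind that Pearson's r only captures linear relationships, which is one reason the plots are still worth looking at.

```python
from itertools import combinations

import pandas as pd

# Stand-in data: the five sample rows from the fruits preview in this post
X = pd.DataFrame({
    "mass": [192, 180, 176, 86, 84],
    "width": [8.4, 8.0, 7.4, 6.2, 6.0],
    "height": [7.3, 6.8, 7.2, 4.7, 4.6],
    "color_score": [0.55, 0.59, 0.60, 0.80, 0.79],
})

# Every unordered pair of features: n * (n - 1) / 2 pairs for n features
pairs = list(combinations(X.columns, 2))

# Pearson correlation for each pair, the number a single scatter panel hints at
pair_corr = {(a, b): X[a].corr(X[b]) for a, b in pairs}

for (a, b), r in pair_corr.items():
    print(f"{a:>12} vs {b:<12} r = {r:+.2f}")
```

With 4 features this gives 6 unique pairs; the scatter matrix draws each of them twice, once on each side of the diagonal.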

Fruits Dataset

We’ll use a “fruits” dataset created by Dr. Iain Murray from the University of Edinburgh. Dr. Murray bought a few dozen oranges, lemons, and apples of different varieties and recorded their measurements in a table. The dataset was later formatted by the University of Michigan for teaching purposes.

Run the following code to load the fruits dataset into pandas.

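%matplotlib notebook
import pandas as pd

fruits = pd.read_csv('')  # path to the fruits data file goes here
print(fruits.head())  # first five rows
print(fruits['fruit_name'].value_counts())  # number of samples per fruit type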

   fruit_label fruit_name fruit_subtype  mass  width  height  color_score
0            1      apple  granny_smith   192    8.4     7.3         0.55
1            1      apple  granny_smith   180    8.0     6.8         0.59
2            1      apple  granny_smith   176    7.4     7.2         0.60
3            2   mandarin      mandarin    86    6.2     4.7         0.80
4            2   mandarin      mandarin    84    6.0     4.6         0.79

apple       19
orange      19
lemon       16
mandarin     5
Name: fruit_name, dtype: int64
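If you don't have the data file at hand, you can still follow along. The sketch below is a hypothetical stand-in, not the real dataset: it types in only the five rows printed above, so the counts will differ from the full dataset's (3 apples and 2 mandarins instead of the totals above).

```python
import pandas as pd

# Hypothetical stand-in for the fruits file: only the five rows printed above
fruits = pd.DataFrame({
    "fruit_label": [1, 1, 1, 2, 2],
    "fruit_name": ["apple", "apple", "apple", "mandarin", "mandarin"],
    "fruit_subtype": ["granny_smith"] * 3 + ["mandarin"] * 2,
    "mass": [192, 180, 176, 86, 84],
    "width": [8.4, 8.0, 7.4, 6.2, 6.0],
    "height": [7.3, 6.8, 7.2, 4.7, 4.6],
    "color_score": [0.55, 0.59, 0.60, 0.80, 0.79],
})

print(fruits.head())                        # same columns as the real data
print(fruits["fruit_name"].value_counts())  # samples per fruit type
```

Swap this DataFrame for the real one as soon as you have the file; every later step works the same way.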

Prepare Features and Labels

A feature is an attribute of each sample in the data. The fruits example has four features: mass, width, height, and color_score. We use X to represent the features dataset.

A label is the category each sample belongs to, i.e. the value we would want a model to predict. In our example, the label is either fruit_label (a number) or fruit_name (text). We use y to represent the labels dataset.

X = fruits[['mass','width','height','color_score']]
y = fruits['fruit_label']
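Note that fruits[['mass', ...]] with double brackets returns a DataFrame, while fruits['fruit_label'] with single brackets returns a Series. If you would rather use fruit_name as the label, pd.factorize turns the text names into integer codes, which matplotlib can map to colors when passed as c. A minimal sketch on the five sample rows (stand-in data, not the full dataset):

```python
import pandas as pd

# Stand-in data: the five sample rows from the fruits preview in this post
fruits = pd.DataFrame({
    "fruit_label": [1, 1, 1, 2, 2],
    "fruit_name": ["apple", "apple", "apple", "mandarin", "mandarin"],
    "mass": [192, 180, 176, 86, 84],
    "width": [8.4, 8.0, 7.4, 6.2, 6.0],
    "height": [7.3, 6.8, 7.2, 4.7, 4.6],
    "color_score": [0.55, 0.59, 0.60, 0.80, 0.79],
})

X = fruits[["mass", "width", "height", "color_score"]]  # DataFrame of features
y = fruits["fruit_label"]                                # Series of numeric labels

# Alternative label: integer codes for the text names, plus the name for each code
codes, names = pd.factorize(fruits["fruit_name"])
```

Here names[codes[i]] recovers the fruit name for row i, which is handy when you later build a legend.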

Creating a Scatter Matrix Plot Using Pandas

Creating a scatter matrix plot with pandas is extremely easy; it takes just one line of code:

pd.plotting.scatter_matrix(X, c=y, marker='o', figsize=(9, 9))

The arguments are:

  • X contains all the features to plot
  • c=y colors each point by its label, so each fruit type gets its own color
  • marker='o' draws circles for the scatter plots; use marker='.' to draw small dots
  • figsize is optional; it makes the chart larger and easier to read
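
The one-liner can take a few more options. In the sketch below, cmap and hist_kwds are our own additions, not part of the original example, and the five sample rows from the fruits preview stand in for the full data. scatter_matrix passes extra keyword arguments such as c and cmap on to matplotlib's Axes.scatter, while marker, figsize and hist_kwds are its own parameters (hist_kwds goes to the histograms on the diagonal). It returns a NumPy array of the subplot axes.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so this runs as a plain script; omit in a notebook
import pandas as pd

# Stand-in data: the five sample rows from the fruits preview in this post
X = pd.DataFrame({
    "mass": [192, 180, 176, 86, 84],
    "width": [8.4, 8.0, 7.4, 6.2, 6.0],
    "height": [7.3, 6.8, 7.2, 4.7, 4.6],
    "color_score": [0.55, 0.59, 0.60, 0.80, 0.79],
})
y = pd.Series([1, 1, 1, 2, 2], name="fruit_label")

axes = pd.plotting.scatter_matrix(
    X,
    c=y,                    # color each point by its label (passed on to Axes.scatter)
    cmap="viridis",         # colormap for the label values (also passed to Axes.scatter)
    marker="o",             # circles; "." draws small dots
    figsize=(9, 9),         # figure size in inches
    hist_kwds={"bins": 5},  # options for the histograms on the diagonal
)

fig = axes[0, 0].get_figure()  # the Figure that holds all 16 subplots
```

In a script, fig.savefig("scatter_matrix.png") writes the chart to disk; in a notebook the chart appears on its own.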

[Figure: pandas scatter matrix plot]

Let’s try to understand the above chart:

  1. In total there are 16 charts: with 4 features, the grid has 4 × 4 = 16 cells. The 4 cells on the diagonal are histograms, and the remaining 12 are pair plots covering the 6 unique feature pairs, each shown twice.
  2. Charts on the diagonal are histograms of a single feature, not pair plots. For example, the top-left histogram shows the distribution of mass.
  3. All other charts are feature pair plots. Each dot represents one fruit from the fruits dataset, and each chart shows the relationship between a pair of features. For example, the third chart in the bottom row shows the relationship between color_score (y-axis) and height (x-axis).

Note that the fourth chart in the third row is the same color_score and height pair plot as the third chart in the bottom row, just with the axes swapped.
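To check this layout yourself, you can inspect the axes array that scatter_matrix returns. In pandas, axes[i, j] plots column j on the x-axis against column i on the y-axis, so axes[3, 2] and axes[2, 3] should hold the same points with x and y swapped. This sketch uses the five sample rows as stand-in data, and reading the points back through collections[0].get_offsets() assumes each pair plot contains a single matplotlib scatter collection.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; omit in a notebook
import numpy as np
import pandas as pd

# Stand-in data: the five sample rows from the fruits preview in this post
X = pd.DataFrame({
    "mass": [192, 180, 176, 86, 84],
    "width": [8.4, 8.0, 7.4, 6.2, 6.0],
    "height": [7.3, 6.8, 7.2, 4.7, 4.6],
    "color_score": [0.55, 0.59, 0.60, 0.80, 0.79],
})
axes = pd.plotting.scatter_matrix(X, marker="o")

bottom_third = axes[3, 2]  # bottom row (color_score), third column (height)
third_fourth = axes[2, 3]  # third row (height), fourth column (color_score)
print(bottom_third.get_xlabel(), "vs", bottom_third.get_ylabel())
print(third_fourth.get_xlabel(), "vs", third_fourth.get_ylabel())

# The (x, y) points of one plot are the (y, x) points of its mirror image
pts_a = bottom_third.collections[0].get_offsets()
pts_b = third_fourth.collections[0].get_offsets()
print(np.allclose(pts_a, pts_b[:, ::-1]))
```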

In this scatter matrix, the color_score and height pair plot looks especially interesting: dots of the same color seem to form several clusters with fairly clean boundaries. This observation helps us form a hypothesis about how to build machine learning models for this problem.
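A quick numeric check can back up what the eye sees in that plot. This sketch, our own addition, groups the data by fruit and prints each fruit's minimum and maximum height and color_score; when two fruits' ranges on a feature do not overlap, a simple threshold on that feature separates them. With only the five sample rows from this post (apples and mandarins) it is purely illustrative; run the same groupby on the full fruits DataFrame to see whether all four fruits separate as cleanly.

```python
import pandas as pd

# Stand-in data: the five sample rows from the fruits preview in this post
fruits = pd.DataFrame({
    "fruit_name": ["apple", "apple", "apple", "mandarin", "mandarin"],
    "height": [7.3, 6.8, 7.2, 4.7, 4.6],
    "color_score": [0.55, 0.59, 0.60, 0.80, 0.79],
})

# Per-fruit range of each feature; non-overlapping ranges hint at clean cluster boundaries
ranges = fruits.groupby("fruit_name")[["height", "color_score"]].agg(["min", "max"])
print(ranges)
```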

Additional Resources

How to Do Train Test Split in Sklearn

Least Squares Linear Regression With Python Example
