Last Updated on July 14, 2022 by Jay
Did you know we can use the pandas Python library to create a scatter matrix plot? Yes! In addition to pandas’ powerful data-wrangling capabilities, it can do plotting too!
Library
pandas draws its plots with matplotlib, so install both by typing the following in a command prompt window:
pip install pandas matplotlib
What Is a Scatter Matrix Plot
A scatter matrix plot is literally a matrix of scatter plots! Some people also call it a “feature pair plot”.
Essentially, we create a scatter plot for every possible pair of features. This makes it easy to see whether, and how, the features are correlated with each other. Note, however, that a scatter matrix only shows relationships between pairs of features, not interactions among three or more features at once.
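To see what such a plot is made of, here is a minimal hand-rolled version built directly with matplotlib. It is only a sketch of the idea, not how pandas implements it internally. Like pandas' default, it puts a histogram on the diagonal, where a feature would otherwise be plotted against itself. The function name manual_scatter_matrix and the random demo data are hypothetical, for illustration only.

```python
import itertools

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs outside a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def manual_scatter_matrix(df):
    """Hypothetical helper: histograms on the diagonal, scatter plots elsewhere."""
    cols = list(df.columns)
    n = len(cols)
    fig, axes = plt.subplots(n, n, figsize=(2 * n, 2 * n))
    for i, j in itertools.product(range(n), range(n)):
        ax = axes[i, j]
        if i == j:
            ax.hist(df[cols[i]])  # diagonal: distribution of a single feature
        else:
            # row feature on the y-axis, column feature on the x-axis
            ax.scatter(df[cols[j]], df[cols[i]], s=5)
        if i == n - 1:
            ax.set_xlabel(cols[j])  # x labels only along the bottom row
        if j == 0:
            ax.set_ylabel(cols[i])  # y labels only along the left column
    return fig, axes


# Synthetic data for illustration only (not the fruits dataset)
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(50, 3)), columns=["a", "b", "c"])
fig, axes = manual_scatter_matrix(demo)
print(axes.shape)  # (3, 3)
```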
Fruits Dataset
We’ll use a “fruits” dataset created by Dr. Ian Murray from the University of Edinburgh. Dr. Murray bought a few dozen oranges, lemons, and apples of different varieties, and recorded their measurements in a table. The dataset was later formatted by the University of Michigan for teaching purposes.
Run the following code to load the fruits dataset into pandas.
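# Jupyter-only magic that renders interactive plots; omit it in a plain .py script
%matplotlib notebook
import pandas as pd
# Download the fruits table from GitHub into a DataFrame
fruits = pd.read_csv('https://raw.githubusercontent.com/pythoninoffice/fruit-data-with-colours/master/fruit_data_with_colours.csv')
fruits.head()  # show the first five rows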
   fruit_label fruit_name fruit_subtype  mass  width  height  color_score
0            1      apple  granny_smith   192    8.4     7.3         0.55
1            1      apple  granny_smith   180    8.0     6.8         0.59
2            1      apple  granny_smith   176    7.4     7.2         0.60
3            2   mandarin      mandarin    86    6.2     4.7         0.80
4            2   mandarin      mandarin    84    6.0     4.6         0.79
fruits['fruit_name'].value_counts()
apple 19
orange 19
lemon 16
mandarin 5
Name: fruit_name, dtype: int64
Prepare Features and Labels
A feature is a measured attribute of each sample. The fruits example has four features: mass, width, height and color_score. We use X to represent the feature dataset.
A label is the class we want to predict for each sample. In our example, the label is either fruit_label or fruit_name; fruit_label is simply a numeric code for fruit_name. We use y to represent the labels.
X = fruits[['mass','width','height','color_score']]
y = fruits['fruit_label']
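Because the CSV above is downloaded from GitHub, here is an offline sketch of the same feature/label split, applied to the five rows that fruits.head() printed, typed in by hand. The names fruits_head, X_head, y_head and label_map are our own, and the mapping only covers the two labels visible in those rows.

```python
import pandas as pd

# The five rows printed by fruits.head() above, typed in so this runs offline
fruits_head = pd.DataFrame({
    "fruit_label": [1, 1, 1, 2, 2],
    "fruit_name": ["apple", "apple", "apple", "mandarin", "mandarin"],
    "fruit_subtype": ["granny_smith"] * 3 + ["mandarin"] * 2,
    "mass": [192, 180, 176, 86, 84],
    "width": [8.4, 8.0, 7.4, 6.2, 6.0],
    "height": [7.3, 6.8, 7.2, 4.7, 4.6],
    "color_score": [0.55, 0.59, 0.60, 0.80, 0.79],
})

feature_cols = ["mass", "width", "height", "color_score"]
X_head = fruits_head[feature_cols]   # features: one column per measured attribute
y_head = fruits_head["fruit_label"]  # labels: the numeric class of each fruit

# fruit_label is a numeric code for fruit_name; list the pairs present here
label_map = (fruits_head[["fruit_label", "fruit_name"]]
             .drop_duplicates()
             .set_index("fruit_label")["fruit_name"])

print(X_head.shape)         # (5, 4)
print(label_map.to_dict())  # {1: 'apple', 2: 'mandarin'}
```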
Creating a Scatter Matrix Plot Using Pandas
It’s extremely easy to create a scatter matrix plot using pandas. It takes just one line of code:
pd.plotting.scatter_matrix(X, c = y, marker = 'o', figsize=(9,9))
The arguments are:
- X contains all the features to plot; each column becomes one row and one column of the grid
- c = y colors each dot by its label, so each fruit type gets its own color
- marker = 'o' draws circles for the scatter plots; use marker = '.' (the default) to draw small dots
- figsize is optional; it just makes the chart larger and easier to see
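The call returns a NumPy array of matplotlib Axes, one per cell of the grid, and any extra keyword arguments (such as c or cmap) are passed on to matplotlib's scatter function. Below is a sketch with a few more options, assuming you run a plain script rather than a notebook (hence the Agg backend and saving to a file). X_demo and y_demo are random stand-ins, not the fruits data, and the file name is arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a plain script
import numpy as np
import pandas as pd

# Synthetic stand-in for the fruits features and labels (illustration only)
rng = np.random.default_rng(42)
X_demo = pd.DataFrame(rng.normal(size=(60, 4)),
                      columns=["mass", "width", "height", "color_score"])
y_demo = pd.Series(rng.integers(1, 5, size=60))

axes = pd.plotting.scatter_matrix(
    X_demo,
    c=y_demo,                # forwarded to scatter(): color each dot by its label
    cmap="viridis",          # also forwarded: colormap for the numeric labels
    marker="o",
    alpha=0.8,               # scatter_matrix's own transparency parameter (default 0.5)
    hist_kwds={"bins": 15},  # keyword arguments for the diagonal histograms
    figsize=(9, 9),
)
print(axes.shape)  # (4, 4): one Axes per grid cell

# Save the whole figure to disk instead of displaying it inline
axes[0, 0].get_figure().savefig("scatter_matrix.png")
```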
Let’s try to understand the above chart:
- In total there are 16 charts: with 4 features, the grid has 4 × 4 = 16 cells.
- Charts on the diagonal are histograms of a single feature, not pair plots. E.g., the top-left histogram shows the distribution of mass.
- All other charts are feature pair plots. Each dot represents one fruit from the fruits dataset, and each chart shows the relationship between a pair of features. For example, the 3rd chart in the bottom row shows the relationship between color_score (y-axis) and height (x-axis).
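Note that the 4th chart in the 3rd row plots the same color_score and height pair, just with the axes swapped relative to the 3rd chart in the bottom row. In general, the chart in row i, column j mirrors the one in row j, column i, so the 12 off-diagonal charts show only 6 distinct pairs.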
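The chart counts can be checked with the standard library: of the 4 × 4 = 16 cells, 4 are diagonal histograms, and the remaining 12 show each of the 6 distinct feature pairs twice, once per axis orientation. The convention in the comments (row feature on the y-axis, column feature on the x-axis) matches the reading of the chart given above.

```python
from itertools import combinations, product

features = ["mass", "width", "height", "color_score"]
n = len(features)

cells = list(product(range(n), repeat=2))             # every (row, column) cell in the grid
diagonal = [(i, j) for i, j in cells if i == j]       # histogram cells
off_diagonal = [(i, j) for i, j in cells if i != j]   # scatter-plot cells
unique_pairs = list(combinations(features, 2))        # distinct feature pairs

print(len(cells), len(diagonal), len(off_diagonal), len(unique_pairs))  # 16 4 12 6

# Cell (i, j) plots features[i] on the y-axis against features[j] on the x-axis,
# so cell (j, i) shows the same pair with the axes swapped.
i, j = 3, 2  # bottom row, 3rd column
print(features[i], "vs", features[j])  # color_score vs height
print(features[j], "vs", features[i])  # height vs color_score (3rd row, 4th column)
```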
From this scatter matrix plot, we can see that the color_score and height pair plot shows something interesting: dots of the same color seem to form several clusters with fairly clean boundaries between them. This observation will help us form a hypothesis about how to build machine learning models for this problem.
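To look at the color_score vs. height pair more closely, one option is to draw just that pair with one legend entry per fruit, which the scatter matrix itself does not provide. plot_pair_by_class is a hypothetical helper name; the demo below uses only the five rows from fruits.head() so it runs offline, and therefore shows just the apple and mandarin groups.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line inside a notebook
import matplotlib.pyplot as plt
import pandas as pd


def plot_pair_by_class(df, x, y, label_col):
    """Hypothetical helper: scatter one feature pair, one color and legend entry per class."""
    fig, ax = plt.subplots(figsize=(6, 5))
    for name, group in df.groupby(label_col):  # one scatter call per class
        ax.scatter(group[x], group[y], label=name)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.legend(title=label_col)
    return fig, ax


# Demo on the five rows shown by fruits.head(); with the full dataset, call
# plot_pair_by_class(fruits, "height", "color_score", "fruit_name") instead.
sample = pd.DataFrame({
    "fruit_name": ["apple", "apple", "apple", "mandarin", "mandarin"],
    "height": [7.3, 6.8, 7.2, 4.7, 4.6],
    "color_score": [0.55, 0.59, 0.60, 0.80, 0.79],
})
fig, ax = plot_pair_by_class(sample, "height", "color_score", "fruit_name")
```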