Scatter Matrix With Plotly


Last Updated on July 14, 2022 by Jay

Did you know that the plotly Python library can create a scatter matrix plot as well? A scatter matrix, or feature pair plot, is a useful visualization tool that helps us spot correlations in a dataset. Most people probably learned about the scatter matrix from matplotlib or pandas. However, I wasn’t satisfied with the default look that matplotlib offers. Plotly offers an easy API, and its charts are interactive and modern looking at the same time.

Library

To install plotly, type the following into a command prompt window:

pip install plotly

The plotly library is available in many programming languages, including Python, R, Julia, JavaScript, and others. The original library is written in JavaScript; what we’ll be using is its Python version, a wrapper around the JavaScript library.

Plotly is powerful, easy to use, and free!

The library has three main tools: plotly express, plotly graph_objects, and Dash. We won’t talk about Dash since it’s a web framework for building apps.

Dataset

We’ll use a “fruits” dataset created by Dr. Iain Murray from the University of Edinburgh. Dr. Murray bought a few dozen oranges, lemons, and apples of different varieties and recorded their measurements in a table. The dataset was later formatted by the University of Michigan for teaching purposes.

Copy and run the following code to load the fruits dataset into pandas.

import pandas as pd

fruits = pd.read_csv('https://raw.githubusercontent.com/pythoninoffice/fruit-data-with-colours/master/fruit_data_with_colours.csv')

fruits.head()
   fruit_label fruit_name fruit_subtype  mass  width  height  color_score
0            1      apple  granny_smith   192    8.4     7.3         0.55
1            1      apple  granny_smith   180    8.0     6.8         0.59
2            1      apple  granny_smith   176    7.4     7.2         0.60
3            2   mandarin      mandarin    86    6.2     4.7         0.80
4            2   mandarin      mandarin    84    6.0     4.6         0.79

fruits['fruit_name'].unique()
array(['apple', 'mandarin', 'orange', 'lemon'], dtype=object)
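The graph_objects example later in this post colors points by the numeric fruit_label column, so it helps to know which number stands for which fruit. A sketch that recovers the mapping from the DataFrame, shown here on the five rows printed above (on the full dataset, the same code also lists orange and lemon):

```python
import pandas as pd

# The fruit_label/fruit_name columns of the five rows printed by fruits.head().
fruits = pd.DataFrame({
    "fruit_label": [1, 1, 1, 2, 2],
    "fruit_name": ["apple", "apple", "apple", "mandarin", "mandarin"],
})

# Keep one row per (label, name) pair and turn it into a {label: name} dict.
label_map = (fruits[["fruit_label", "fruit_name"]]
             .drop_duplicates()
             .set_index("fruit_label")["fruit_name"]
             .to_dict())
print(label_map)  # {1: 'apple', 2: 'mandarin'}
```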

Create A Scatter Matrix Plot Using Plotly Express

As the name suggests, the plotly express module provides a quick and easy way to create a plot; usually it takes just one line of code!

The API is also easy to understand. We can pass an entire pandas DataFrame into plotly.express, then refer to its column names directly in the function arguments to control how the graph looks.

To show a legend on a scatter matrix created with matplotlib, we need a bit of tweaking. With plotly, all we need to do is add an argument or two:

  • color='fruit_name': displays data points in different colors by fruit_name.
  • symbol='fruit_name': displays data points with different symbols (shapes) by fruit_name. This is optional if you only need different colors.
import plotly.express as px

# One call builds the whole matrix; color and symbol both group the points by fruit
fig = px.scatter_matrix(fruits,
                        dimensions=['mass', 'width', 'height', 'color_score'],
                        color='fruit_name',
                        symbol='fruit_name',
                        width=1000,
                        height=1000)
fig.show()
A scatter matrix by plotly.express

Create A Scatter Matrix Plot Using Plotly Graph_Objects

The plotly graph_objects module offers full functionality and allows fine-tuning of almost every aspect of the graph. For quick plotting, go with plotly express; if you need more fine-tuning, go with plotly graph_objects.

The graph_objects API is, of course, more complex than plotly express. In exchange, graph_objects lets us customize the chart down to the last detail.

We’ll use the go.Splom trace (splom stands for scatter plot matrix) to create a scatter matrix in three simple steps:

  1. Initialize/create a go.Figure object
  2. Add trace (data) to the Figure object
  3. Update the layout (cosmetics) of the Figure
import plotly.graph_objects as go

# Step 1: create an empty figure
fig = go.Figure()

# Step 2: add a Splom trace; each dict is one feature (one row/column of the matrix)
fig.add_trace(go.Splom(dimensions=[dict(label='mass', values=fruits['mass']),
                                   dict(label='width', values=fruits['width']),
                                   dict(label='height', values=fruits['height']),
                                   dict(label='color_score', values=fruits['color_score'])],
                       text=fruits['fruit_name'],                # shown on hover
                       marker=dict(color=fruits['fruit_label'],  # numeric label mapped to color
                                   opacity=0.8,
                                   line_color='white', line_width=1),
                       showupperhalf=False))                     # hide the mirrored upper half

# Step 3: adjust the layout, then render
fig.update_layout(width=1000, height=1000)
fig.show()

Several notable arguments above:

  • text=fruits['fruit_name'] displays the fruit name when you hover the mouse over a data point.
  • The marker dictionary controls how the dots appear. We can even add a thin white outline around each dot.
  • showupperhalf=False turns off the upper-half charts. This reduces distraction, since the upper and lower halves show the same feature pairs, just with the axes swapped.
A Scatter matrix plot by plotly graph_objects

Additional Resources

Pandas Plotting: Scatter Matrix

Python Data Visualization & Exploration With Plotly
