Last Updated on July 14, 2022 by Jay
Did you know that the plotly Python library can create a scatter matrix plot as well? A scatter matrix, or feature pair plot, is a useful visualization tool that helps us spot correlations in a dataset. Most people probably learned about the scatter matrix from matplotlib or pandas. However, I wasn't satisfied with the default look that matplotlib offers. Plotly offers an easy API, and its charts are interactive and modern-looking at the same time.
To install plotly, type the following into a command prompt window:
pip install plotly
Plotly is powerful, easy to use, and free!
The library has three main tools: plotly express, plotly graph_objects, and dash. We won't talk about dash since it's a web framework.
We'll use a "fruits" dataset created by Dr. Ian Murray from the University of Edinburgh. Dr. Murray bought a few dozen oranges, lemons, and apples of different varieties and recorded their measurements in a table. The dataset was later formatted by the University of Michigan for teaching purposes.
Copy and run the following code to load the fruits dataset into pandas.
```python
%matplotlib notebook
import pandas as pd

fruits = pd.read_csv('https://raw.githubusercontent.com/pythoninoffice/fruit-data-with-colours/master/fruit_data_with_colours.csv')
fruits.head()
```

```
   fruit_label  fruit_name  fruit_subtype  mass  width  height  color_score
0            1       apple   granny_smith   192    8.4     7.3         0.55
1            1       apple   granny_smith   180    8.0     6.8         0.59
2            1       apple   granny_smith   176    7.4     7.2         0.60
3            2    mandarin       mandarin    86    6.2     4.7         0.80
4            2    mandarin       mandarin    84    6.0     4.6         0.79
```

```python
fruits['fruit_name'].unique()
```

```
array(['apple', 'mandarin', 'orange', 'lemon'], dtype=object)
```
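A scatter matrix lets us spot correlations visually; pandas can also put a number on each feature pair. A minimal sketch using only the five rows printed by `fruits.head()` above, typed in so it runs without downloading the CSV. With the full dataset you would call `fruits[cols].corr()` instead, and five rows are far too few to draw real conclusions; the point here is only the mechanics.

```python
import pandas as pd

# The five rows shown by fruits.head() above, inline so the sketch runs offline
sample = pd.DataFrame({
    'mass': [192, 180, 176, 86, 84],
    'width': [8.4, 8.0, 7.4, 6.2, 6.0],
    'height': [7.3, 6.8, 7.2, 4.7, 4.6],
    'color_score': [0.55, 0.59, 0.60, 0.80, 0.79],
})

cols = ['mass', 'width', 'height', 'color_score']
# Pearson correlation coefficient for every pair of features;
# each off-diagonal cell corresponds to one panel of the scatter matrix
corr = sample[cols].corr()
print(corr.round(2))
```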
Create A Scatter Matrix Plot Using Plotly Express
As the name suggests, the plotly express module provides a quick and easy way to create a plot; usually it takes just one line of code!
The API is also easy to understand. We can pass an entire pandas DataFrame into plotly express, then refer to its column names directly in the function arguments to control how the graph looks.
With a scatter matrix created by matplotlib, showing a legend takes a bit of tweaking. With plotly, all we need is an argument or two:
- color='fruit_name': displays data points in different colors by fruit_name.
- symbol='fruit_name': displays data points with different symbols/shapes by fruit_name. This is optional if you only need different colors.
```python
import plotly.express as px

px.scatter_matrix(fruits,
                  dimensions=['mass', 'width', 'height', 'color_score'],
                  color='fruit_name',
                  symbol='fruit_name',
                  width=1000,
                  height=1000)
```
Create A Scatter Matrix Plot Using Plotly Graph_Objects
The plotly graph_objects module offers full functionality and allows fine-tuning of almost every aspect of the graph. For quick plotting, go with plotly express; if you need more fine-tuning, go with plotly graph_objects.
The API for graph_objects is of course more complex than plotly express. However, graph_objects lets us customize every detail of the chart.
We'll use the graph_objects.Splom() trace type to create a scatter matrix in three simple steps:
- Initialize/create a go.Figure object
- Add a trace (data) to the Figure object
- Update the layout (cosmetics) of the Figure
```python
import plotly.graph_objects as go

fig = go.Figure()
fig.add_trace(go.Splom(
    dimensions=[dict(label='mass', values=fruits['mass']),
                dict(label='width', values=fruits['width']),
                dict(label='height', values=fruits['height']),
                dict(label='color_score', values=fruits['color_score'])],
    text=fruits['fruit_name'],
    marker=dict(color=fruits['fruit_label'],
                opacity=0.8,
                line_color='white',
                line_width=1),
    showupperhalf=False,
))
fig.update_layout(width=1000, height=1000)
fig.show()
```
A few notable arguments in the code above:
- text=fruits['fruit_name'] displays the fruit name when you hover the mouse over a data point
- The marker dictionary controls how the dots appear; we can even add a thin white outline around each dot
- showupperhalf=False turns off the upper-half charts, which reduces distraction since the upper half mirrors the lower half (the same feature pairs with the axes swapped)
Pandas Plotting: Scatter Matrix