Last Updated on July 14, 2022 by Jay
We are going to make a waterfall chart in Python, specifically using the matplotlib library. A waterfall chart shows the running total along with the additions and subtractions that produce it, which makes it a good choice for attribution analysis.
The Concept
Matplotlib doesn’t have a magical function like ‘waterfall_chart()’ that makes a waterfall chart in one line of code. However, we can build our own waterfall chart in Python using a little trick.
The concept is simple:
- Create a normal bar chart
- Create another bar chart, layer it on top of the first one, and set the color of the new bars to the background color so that they hide the bottom section of the first bar chart
Essentially, because we can’t see the second set of bars, we can use them to “hide” parts of the other set! As the following illustration shows, there is a set of “invisible bars” on the chart, but we won’t see them.
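The two-layer trick can be checked on its own before bringing in the real data. Below is a minimal sketch with made-up steps of +100, +10 and -20, so the running totals are 100, 110 and 90. The `upper`/`lower` lists are hand-computed for this toy example. The `Agg` backend is an assumption so the script runs without a display, and you can drop it in a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs without a display; omit in a notebook
import matplotlib.pyplot as plt

# Toy steps: +100, +10, -20, so the running totals are 100, 110, 90
categories = ["Step 1", "Step 2", "Step 3"]
upper = [100, 110, 110]  # top of each visible slice
lower = [0, 100, 90]     # everything below this should disappear

fig, ax = plt.subplots()
ax.bar(categories, upper, color="tab:blue")  # first bar chart: visible bars
ax.bar(categories, lower, color="white")     # second bar chart: background-colored mask on top

print(len(ax.patches))  # 6 rectangles: two stacked per category
```

Each category now has two overlapping rectangles. Only the part of the blue bar between `lower` and `upper` stays visible.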
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
df = pd.DataFrame({'category':['Sales','Service','Expenses','Taxes','Interest'],
'num':[100,10,-20,-30,60]})
df
   category  num
0     Sales  100
1   Service   10
2  Expenses  -20
3     Taxes  -30
4  Interest   60
The Two Bar Charts
Our task now becomes creating two bar charts. One of them should keep track of the running total. The other one is just a variation of the running total, which we’ll see in a second.
We can use the cumsum() method to calculate a running total (tot), then shift it down by one row to get tot1. Together, these two new columns give the beginning (tot1) and ending (tot) points of each waterfall bar. For example, in row 2 (Expenses), the beginning point is 110 and the end point is 90.
df['tot'] = df['num'].cumsum()
df['tot1']= df['tot'].shift(1).fillna(0)
   category  num  tot   tot1
0     Sales  100  100    0.0
1   Service   10  110  100.0
2  Expenses  -20   90  110.0
3     Taxes  -30   60   90.0
4  Interest   60  120   60.0
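The shift-by-one logic does not depend on pandas. Here is a sketch of the same beginning/end points in plain NumPy. `np.cumsum` gives the end of each bar, and prepending 0 to all but the last total gives the start.

```python
import numpy as np

num = np.array([100, 10, -20, -30, 60])   # same values as the df['num'] column

end = np.cumsum(num)                      # running total after each step ("tot")
start = np.concatenate(([0], end[:-1]))   # running total before each step ("tot1")

for s, e, n in zip(start, end, num):
    # every step satisfies start + change == end
    assert s + n == e

print(start.tolist())  # [0, 100, 110, 90, 60]
print(end.tolist())    # [100, 110, 90, 60, 120]
```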
Since the beginning and end points can be in either of the two new columns (depending on the sign of the values), we can create two more Series to capture the lower and upper points of each bar:
lower = df[['tot','tot1']].min(axis=1)
upper = df[['tot','tot1']].max(axis=1)
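Taking the row-wise min and max works because the visible slice of each bar, from `lower` to `upper`, must be exactly as long as the change in that row, whatever its sign. This sketch rebuilds the example DataFrame and checks that property.

```python
import pandas as pd

df = pd.DataFrame({'category': ['Sales', 'Service', 'Expenses', 'Taxes', 'Interest'],
                   'num': [100, 10, -20, -30, 60]})
df['tot'] = df['num'].cumsum()
df['tot1'] = df['tot'].shift(1).fillna(0)

lower = df[['tot', 'tot1']].min(axis=1)
upper = df[['tot', 'tot1']].max(axis=1)

# The visible slice runs from lower to upper, so its length
# equals the absolute size of the change in that row.
visible = upper - lower
assert (visible == df['num'].abs()).all()

print(pd.DataFrame({'category': df['category'], 'lower': lower, 'upper': upper}))
```

For Expenses (-20), for example, the slice runs from 90 up to 110.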
We plot the first set of bars using the upper points. These bars use a color that’s different from the background color. Then we plot the second set of bars using the lower points and set their color to the background color, which is white by default.
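fig, ax = plt.subplots()
# first set: visible bars from 0 up to the upper point
ax.bar(x=df['category'], height=upper)
# second set: background-colored bars that hide everything below the lower point
ax.bar(x=df['category'], height=lower, color='white')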
We now get a chart that looks like the one below. The bars that use the lower points as their height are invisible because they have the same color as the background!
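A different technique reaches the same picture without painting over anything: `ax.bar` accepts a `bottom` argument, so each bar can start at its lower point and be (upper - lower) tall. This is a swap for the background-color trick described above, not the article’s method. It is sketched here because it still works on non-white backgrounds or with gridlines. It also handles running totals that go below zero, where a white mask drawn from 0 would cover the colored bar entirely.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame({'category': ['Sales', 'Service', 'Expenses', 'Taxes', 'Interest'],
                   'num': [100, 10, -20, -30, 60]})
tot = np.cumsum(df['num'].to_numpy())       # running total after each step
tot1 = np.concatenate(([0], tot[:-1]))      # running total before each step
lower = np.minimum(tot, tot1)
upper = np.maximum(tot, tot1)

fig, ax = plt.subplots()
# A single set of bars: each one floats from `lower` up to `upper`,
# so no second "invisible" bar chart is needed.
bars = ax.bar(df['category'], upper - lower, bottom=lower)
```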
Add Colors To Waterfall Chart
Now we have a basic waterfall chart. Let’s add some colors to it. We’ll use green for increases, and red for decreases.
The data is readily available in the num column, so let’s create a new color column to store the appropriate color for each category.
df.loc[df['num'] >= 0, 'color'] = 'green'
df.loc[df['num'] < 0, 'color'] = 'red'
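The two `.loc` assignments can also be collapsed into one vectorized expression. Here is a sketch using `np.where`, which picks 'green' where the condition holds and 'red' elsewhere.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'category': ['Sales', 'Service', 'Expenses', 'Taxes', 'Interest'],
                   'num': [100, 10, -20, -30, 60]})
# 'green' for zero or positive changes, 'red' for negative ones
df['color'] = np.where(df['num'] >= 0, 'green', 'red')
print(df['color'].tolist())  # ['green', 'green', 'red', 'red', 'green']
```

The column is created in one step, so every row gets a color at once rather than the column being filled across two partial assignments.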
Re-plotting the first set of bars with the new colors (and then drawing the white bars on top again, as before) gives the following:
ax.bar(x=df['category'], height=upper, color=df['color'])
Add Labels To Waterfall Chart
A waterfall chart shows how much each category contributes to the total, so let’s add that information to the chart. There are several ways to do this, but I like to add the labels in the middle of each bar. So we need to first calculate the y-axis position for the middle of each bar (only on the visible sections).
plt.text(x, y, s) is a convenient function for adding text anywhere on the plot. We just need to loop through each category and add the relevant number. Using x = i - 0.15 places each label at roughly the middle of its bar. If your chart has a different number of bars, you might need to tweak this offset a little to find the perfect spot for your chart.
mid = (lower + upper) / 2
for i in range(len(df)):
    # i - 0.15 nudges the text left so it sits near the bar's center
    plt.text(i - 0.15, mid[i], f"{df['num'][i]:,.0f}")
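The hand-tuned offset is one option. Another is to anchor each label at the bar’s own position and let matplotlib center the text with the `ha` and `va` (horizontal/vertical alignment) text properties. This sketch rebuilds the two-bar chart and places one centered label per bar.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame({'category': ['Sales', 'Service', 'Expenses', 'Taxes', 'Interest'],
                   'num': [100, 10, -20, -30, 60]})
tot = np.cumsum(df['num'].to_numpy())
tot1 = np.concatenate(([0], tot[:-1]))
lower, upper = np.minimum(tot, tot1), np.maximum(tot, tot1)
mid = (lower + upper) / 2

fig, ax = plt.subplots()
ax.bar(df['category'], upper)
ax.bar(df['category'], lower, color='white')

for i, (y, n) in enumerate(zip(mid, df['num'])):
    # Bar i is centered at x = i; ha/va center the text on that point,
    # so no hand-tuned offset such as i - 0.15 is needed.
    ax.text(i, y, f"{n:,.0f}", ha='center', va='center')
```

Because alignment is handled by the text itself, the labels stay centered when the number of bars or the label widths change.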
Add Connectors To Waterfall Chart
Some waterfall charts have “connectors” that link the end point of each bar to the beginning point of the next. This is a personal preference and totally optional. It is also a little tricky to do.
The goal is to have the same y-axis position for the previous bar endpoint and the current bar beginning point.
Another thing to be aware of: although the x-axis shows categorical labels, behind the scenes matplotlib still places the bars at integer positions 0, 1, 2, 3 … on the x-axis.
So for the connectors we want some points like these:
- x = (0,1), y = 100
- x = (1,2), y = 110
- x = (2,3), y = 90
- etc…
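The list follows a simple rule: connector i spans x = i to x = i + 1 at the running total after step i, and the last bar has no connector. Here is a short sketch that generates the full list for the sample data.

```python
import pandas as pd

num = pd.Series([100, 10, -20, -30, 60])  # the 'num' column from the example
tot = num.cumsum()                        # running total after each step

# Connector i spans x = i to x = i + 1 at the height where bar i ends.
# The last bar has nothing to its right, so there are len(num) - 1 connectors.
segments = [((i, i + 1), int(tot[i])) for i in range(len(tot) - 1)]
for (x0, x1), y in segments:
    print(f"x = ({x0},{x1}), y = {y}")
```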
We can produce these points as a single Series with the following code. Note that the NaN values between the connections are essential: they “break” what would otherwise be one continuous line into separate pieces.
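# each bar gets three points: where the incoming connector ends,
# a NaN break, and where the outgoing connector starts
connect = df['tot1'].repeat(3).shift(-1)
connect.iloc[1::3] = np.nan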
0 0.0
0 NaN
0 100.0
1 100.0
1 NaN
1 110.0
2 110.0
2 NaN
2 90.0
3 90.0
3 NaN
3 60.0
4 60.0
4 NaN
4 NaN
Name: tot1, dtype: float64
plt.plot(connect.index,connect.values, 'k' )
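A different technique from the NaN-separated line above draws the connectors as independent segments with `Axes.hlines`, which takes arrays of `y`, `xmin` and `xmax`. Below is a sketch under the same data, with the bars drawn with the background-color trick.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame({'category': ['Sales', 'Service', 'Expenses', 'Taxes', 'Interest'],
                   'num': [100, 10, -20, -30, 60]})
tot = np.cumsum(df['num'].to_numpy())      # running total after each step
tot1 = np.concatenate(([0], tot[:-1]))     # running total before each step

fig, ax = plt.subplots()
ax.bar(df['category'], np.maximum(tot, tot1))
ax.bar(df['category'], np.minimum(tot, tot1), color='white')

# Connector i goes from bar i (x = i) to bar i + 1 (x = i + 1)
# at the height where bar i ends; the last bar has no connector.
n = len(df)
connectors = ax.hlines(y=tot[:-1], xmin=np.arange(n - 1), xmax=np.arange(1, n), colors='k')
```

Each segment is explicit here, so no NaN placeholders are needed to separate them.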
Create A Waterfall Chart In Python Matplotlib
We can convert the above waterfall chart code into a convenient Python function so we can reuse it later. The function takes three arguments: a DataFrame that contains the data, the name of the column to use for the x-axis, and the name of the column to use for the y-axis.
def waterfall(df, x, y):
    # work on a copy with a 0..n-1 index so the caller's DataFrame is untouched
    # and the positional lookups below line up with the bar positions
    df = df.reset_index(drop=True)
    # calculate running totals
    df['tot'] = df[y].cumsum()
    df['tot1'] = df['tot'].shift(1).fillna(0)
    # lower and upper points for the bar charts
    lower = df[['tot', 'tot1']].min(axis=1)
    upper = df[['tot', 'tot1']].max(axis=1)
    # mid-point for label position
    mid = (lower + upper) / 2
    # positive number shows green, negative number shows red
    df.loc[df[y] >= 0, 'color'] = 'green'
    df.loc[df[y] < 0, 'color'] = 'red'
    # calculate connection points
    connect = df['tot1'].repeat(3).shift(-1)
    connect.iloc[1::3] = np.nan
    fig, ax = plt.subplots()
    # plot first bar with colors
    ax.bar(x=df[x], height=upper, color=df['color'])
    # plot second bar - invisible
    ax.bar(x=df[x], height=lower, color='white')
    # plot connectors
    ax.plot(connect.index, connect.values, 'k')
    # plot bar labels
    for i in range(len(df)):
        ax.text(i - .15, mid[i], f"{df[y][i]:,.0f}")
    return fig, ax

waterfall(df, 'category', 'num')
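One possible refinement is to separate the arithmetic from the drawing, so the geometry can be checked without rendering a chart. The helper below, `waterfall_geometry`, is a hypothetical name, not part of matplotlib or pandas. It is a sketch of the same calculations the function above performs, returned as a new DataFrame without modifying the input.

```python
import numpy as np
import pandas as pd

def waterfall_geometry(values):
    """Hypothetical helper: bar geometry for a waterfall chart.

    Returns one row per step with the visible segment (lower, upper),
    the label position (mid) and the bar color. The input is not modified.
    """
    num = pd.Series(values, dtype=float).reset_index(drop=True)
    tot = num.cumsum()                  # running total after each step
    tot1 = tot.shift(1).fillna(0)       # running total before each step
    return pd.DataFrame({
        'lower': np.minimum(tot, tot1),
        'upper': np.maximum(tot, tot1),
        'mid': (tot + tot1) / 2,        # same as (lower + upper) / 2
        'color': np.where(num >= 0, 'green', 'red'),
    })

geom = waterfall_geometry([100, 10, -20, -30, 60])
print(geom)
```

A plotting function could then loop over `geom` to draw bars, labels and connectors, keeping the chart code free of index bookkeeping.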