
How to create Heatmaps using matplotlib.pyplot

Figure 1: Nature scenery

In this tutorial we will go over how to create heatmaps like the one below using NumPy and matplotlib.pyplot:

Figure 2: Example of a heatmap

What is matplotlib.pyplot?

matplotlib.pyplot is a submodule of Matplotlib, a popular plotting library for Python.

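The pyplot submodule provides a state-based, MATLAB-like interface to Matplotlib and is mainly intended for interactive plots and simple cases of programmatic plot generation.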

Creating heatmaps using the imshow() function:

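The imshow() function (available as plt.imshow() and as the Axes method ax.imshow(), which this tutorial uses) displays a 2D array as an image in which each value is mapped to a color, which makes it well suited to heatmaps. Let's go over a basic example: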
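As a minimal sketch of what imshow() does on its own, the snippet below uses pyplot's implicit interface (plt.imshow() plus plt.colorbar()) on a small made-up 3x3 array. The data and variable names are hypothetical and not part of the NHL example.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless; remove it to open a window
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical 3x3 grid of values: each cell becomes one colored square
data = np.array([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])

im = plt.imshow(data)  # maps each value to a color from the default colormap (viridis)
plt.colorbar(im)       # adds a scale showing which color corresponds to which value
# plt.show()           # uncomment in an interactive session to display the plot
```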

First, we create the data to plot and the labels for the two axes:

import numpy as np
import matplotlib.pyplot as plt

# Column labels (x-axis): one per team
nhl_teams = ["Bruins", "Maple Leafs", "Lightning", "Panthers",
             "Sabres", "Senators", "Red Wings"]
# Row labels (y-axis): one per year
nhl_team_stats = ["2022", "2021", "2020", "2019", "2018", "2017", "2016"]

# 7x7 grid: row i belongs to nhl_team_stats[i], column j to nhl_teams[j]
nhl_games_won = np.array([[82, 63, 83, 92, 70, 45, 64],
                          [86, 48, 72, 67, 46, 42, 71],
                          [76, 89, 45, 43, 51, 38, 53],
                          [54, 56, 78, 76, 72, 80, 65],
                          [67, 49, 91, 56, 68, 40, 87],
                          [45, 70, 53, 86, 59, 63, 97],
                          [97, 67, 62, 90, 67, 78, 39]])
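imshow() draws row i of the array at y = i and column j at x = j, so the number of rows has to match the y-axis labels and the number of columns the x-axis labels. A small check like the hypothetical check_heatmap_labels() helper below (not part of the original code) can catch a mismatch before plotting; the toy numbers are made up.

```python
import numpy as np

def check_heatmap_labels(data, row_labels, col_labels):
    """Raise ValueError if the label lists do not match the array's shape.

    imshow() draws rows along the y-axis and columns along the x-axis.
    """
    n_rows, n_cols = np.shape(data)
    if n_rows != len(row_labels):
        raise ValueError(f"{n_rows} rows but {len(row_labels)} row (y-axis) labels")
    if n_cols != len(col_labels):
        raise ValueError(f"{n_cols} columns but {len(col_labels)} column (x-axis) labels")

# Hypothetical 2x3 example: 2 years (rows) by 3 teams (columns)
toy_data = np.array([[50, 41, 37],
                     [48, 44, 30]])
check_heatmap_labels(toy_data, ["2022", "2021"], ["Bruins", "Sabres", "Senators"])  # passes

try:
    # Labels swapped: the teams are passed as row labels by mistake
    check_heatmap_labels(toy_data, ["Bruins", "Sabres", "Senators"], ["2022", "2021"])
except ValueError as err:
    print(err)  # 2 rows but 3 row (y-axis) labels
```

With the tutorial's arrays, the call would be check_heatmap_labels(nhl_games_won, nhl_team_stats, nhl_teams), since the years label the rows and the teams label the columns.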

Then we create a figure containing a single set of axes with plt.subplots(), and use the Axes method imshow() to display the nhl_games_won NumPy array as a heatmap:

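fig, ax = plt.subplots()        # fig is the Figure, ax the single Axes (subplot) inside it
im = ax.imshow(nhl_games_won)   # im is the AxesImage object that imshow() returns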
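By default imshow() uses the viridis colormap. As a sketch, the cmap argument selects a different colormap and fig.colorbar() adds a scale so readers can map colors back to values. The "YlGn" colormap and the colorbar label are illustrative choices, not part of the original tutorial, and the data is a 3x3 slice of nhl_games_won.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

# First three rows and columns of the tutorial's nhl_games_won array
data = np.array([[82, 63, 83],
                 [86, 48, 72],
                 [76, 89, 45]])

fig, ax = plt.subplots()
im = ax.imshow(data, cmap="YlGn")  # light yellow = low values, dark green = high values

# The colorbar is drawn in its own small Axes next to the heatmap
cbar = fig.colorbar(im, ax=ax)
cbar.ax.set_ylabel("Games won", rotation=-90, va="bottom")
```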

Next, we add one tick per column and per row and label them with the team names and years:

# One tick per column (teams) on the x-axis and one per row (years) on the y-axis
ax.set_xticks(np.arange(len(nhl_teams)), labels=nhl_teams)
ax.set_yticks(np.arange(len(nhl_team_stats)), labels=nhl_team_stats)
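The labels= keyword of set_xticks() and set_yticks() used above was added in Matplotlib 3.5. On older versions, a sketch of the equivalent is to set the tick positions and the tick labels in two separate calls; the short label lists and data here are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

teams = ["Bruins", "Maple Leafs", "Lightning"]  # x-axis labels (columns)
years = ["2022", "2021"]                       # y-axis labels (rows)
data = np.array([[82, 63, 83],
                 [86, 48, 72]])

fig, ax = plt.subplots()
ax.imshow(data)

# Two-step equivalent of ax.set_xticks(positions, labels=...) for Matplotlib < 3.5:
# first place one tick per column/row, then attach the text labels to those ticks
ax.set_xticks(np.arange(len(teams)))
ax.set_xticklabels(teams)
ax.set_yticks(np.arange(len(years)))
ax.set_yticklabels(years)
```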

To improve the readability of the x-axis, we rotate its tick labels by 45 degrees and right-align them so that each label ends at its tick:

plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")
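plt.setp() calls the matching set_*() method for each keyword on every object in the list it receives. As a sketch, the loop below does the same thing explicitly on each tick-label Text object; the label list and data are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

teams = ["Bruins", "Maple Leafs", "Lightning"]
fig, ax = plt.subplots()
ax.imshow(np.arange(6).reshape(2, 3))
ax.set_xticks(np.arange(len(teams)))
ax.set_xticklabels(teams)

# Explicit equivalent of:
# plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
for label in ax.get_xticklabels():
    label.set_rotation(45)                  # tilt the text by 45 degrees
    label.set_horizontalalignment("right")  # the right end of the text sits at the tick
    label.set_rotation_mode("anchor")       # align the unrotated text first, then rotate about that point
```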

Now, to add text annotations that show the values stored in the nhl_games_won NumPy array, we use a nested for loop and the ax.text() method:

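# Rows (i) run along the y-axis and columns (j) along the x-axis,
# so ax.text() takes the column index as x and the row index as y
for i in range(len(nhl_team_stats)):
    for j in range(len(nhl_teams)):
        text = ax.text(j, i, nhl_games_won[i, j],
                       ha="center", va="center", color="w")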
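White text can be hard to read on the bright cells of the default viridis colormap. One common refinement, sketched here with an assumed threshold of 0.5 and made-up data, is to choose the text color per cell using im.norm, which rescales each value into the 0 to 1 range that the colormap uses.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

data = np.array([[10, 40, 90],
                 [30, 70, 20]])  # hypothetical values

fig, ax = plt.subplots()
im = ax.imshow(data)  # default viridis: dark purple for low values, yellow for high values

threshold = 0.5  # assumed cut-off on the normalized (0 to 1) scale
texts = []
for i in range(data.shape[0]):      # rows -> y positions
    for j in range(data.shape[1]):  # columns -> x positions
        # im.norm rescales the value using the data's min (10) and max (90)
        color = "black" if im.norm(data[i, j]) > threshold else "white"
        texts.append(ax.text(j, i, data[i, j], ha="center", va="center", color=color))
```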

Finally, we set the title of the heatmap, adjust the padding around the subplot with fig.tight_layout(pad=0.5) (pad is a fraction of the font size), and call plt.show() to display the figure:

ax.set_title("NHL Games Won By Year")
fig.tight_layout(pad=0.5)
plt.show()
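plt.show() needs a display. When a script runs somewhere without one (a server or a CI job, for example), a sketch of an alternative is to write the figure to an image file with fig.savefig(). The filename "heatmap.png", the dpi value, the title and the tiny data below are placeholders.

```python
import io

import matplotlib
matplotlib.use("Agg")  # render to memory or files only, no window
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.imshow(np.array([[1, 2], [3, 4]]))
ax.set_title("Example heatmap")
fig.tight_layout(pad=0.5)

# To write a file next to the script: fig.savefig("heatmap.png", dpi=150)
# Here the PNG goes into an in-memory buffer instead
buffer = io.BytesIO()
fig.savefig(buffer, format="png", dpi=150)
png_bytes = buffer.getvalue()
plt.close(fig)  # free the figure once it has been saved
```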

And here is the final result, a heatmap generated using matplotlib.pyplot:

Figure 3: A heatmap generated using matplotlib.pyplot

Here is the final version of the above code example on GitHub
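To reuse these steps with other data, they can be wrapped in a function. The plot_heatmap() helper below is a hypothetical sketch, not the GitHub version mentioned above. Its outer loop runs over the rows (y-axis labels) and its inner loop over the columns (x-axis labels), so it also works when the grid is not square.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

def plot_heatmap(data, row_labels, col_labels, title="", text_color="w"):
    """Draw `data` as an annotated heatmap: rows on the y-axis, columns on the x-axis."""
    data = np.asarray(data)
    fig, ax = plt.subplots()
    im = ax.imshow(data)

    ax.set_xticks(np.arange(len(col_labels)))
    ax.set_xticklabels(col_labels)
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(row_labels)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # One text label per cell: row index i is the y position, column index j the x position
    for i in range(len(row_labels)):
        for j in range(len(col_labels)):
            ax.text(j, i, data[i, j], ha="center", va="center", color=text_color)

    ax.set_title(title)
    fig.tight_layout(pad=0.5)
    return fig, ax, im

# Hypothetical non-square example: 2 years (rows) by 3 teams (columns)
fig, ax, im = plot_heatmap([[50, 41, 37], [48, 44, 30]],
                           row_labels=["2022", "2021"],
                           col_labels=["Bruins", "Sabres", "Senators"],
                           title="Games won (made-up numbers)")
```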

Here are some definitions of the concepts that we covered above, in case you are unfamiliar with them:

  • Figure: The top-level container in Matplotlib; it holds every element of the plot, including one or more subplots, titles and colorbars
  • Subplot: A single chart (an Axes object) placed within a Matplotlib figure
  • Ticks: Markers along the x or y axis that indicate coordinate values; each tick can carry a text label
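The relationship between these three terms can be seen in a few lines of code. In this sketch (the two-panel layout and the tick positions are arbitrary), a single Figure holds two Axes (subplots), and each Axes manages its own ticks.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt

# One Figure containing a row of two subplots (Axes objects)
fig, (left, right) = plt.subplots(1, 2)
print(len(fig.axes))       # 2: the Figure keeps track of every Axes it contains
print(left.figure is fig)  # True: each Axes knows which Figure it belongs to

# Ticks belong to an individual Axes: setting them on one subplot leaves the other alone
left.set_xticks([0, 1, 2])
print(left.get_xticks())   # the positions 0, 1 and 2 set above
print(right.get_xticks())  # still the automatically chosen ticks
```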

Conclusion

Well, that's it for this post! Thanks for following along. If you have any questions or concerns, please feel free to leave a comment on this post and I will get back to you when I find the time.

If you found this article helpful, please share it, and make sure to follow me on Twitter and GitHub, connect with me on LinkedIn and subscribe to my YouTube channel.