【Python】相関行列 (Correlation Matrix)とヒートマップ (Heat Map)の作り方


1. 目的
2. シンプルなヒートマップ
2.1. ライブラリのインポート
2.2. データの読み込み
2.3. 相関行列の計算
2.4. ヒートマップの作成
3. プロットの大きさを相関係数に応じて変える
3.1. heatmapzのインストール
3.2. ライブラリのインポート
3.3. データの読み込み
3.4. ヒートマップの作成


1. 目的

  • Pythonで相関行列 (Correlation Matrix)を作成。
  • ヒートマップ (Heat Map)で図示。

Better Heatmaps and Correlation Matrix Plots in Pythonを参考にしました。

2. シンプルなヒートマップ

まずは、基本的な相関行列のヒートマップを作成します。

2.1. ライブラリのインポート

1
2
3
4
import seaborn as sns
import pandas as pd
from matplotlib import pyplot as plt
from pylab import rcParams

2.2. データの読み込み

Automobile Data Setから自動車のデータを取得します。

Drazen ZaricさんがAutomobile Data Setを使いやすいように整えてっくださっているのでそちらを使わせていただきます。

1
data = pd.read_csv('https://raw.githubusercontent.com/drazenz/heatmap/master/autos.clean.csv')

2.3. 相関行列の計算

相関行列の計算は、pandascorrメソッドを使うと簡単に計算できます。

1
corr = data.corr()

2.4. ヒートマップの作成

計算した相関行列corrをヒートマップで図示していきます。

01
02
03
04
05
06
07
08
09
10
11
12
13
14
15
16
17
18
19
20
# 図の設定
rcParams['figure.figsize'] = 7,7
sns.set(color_codes=True, font_scale=1.2)
 
# ヒートマップの作成
ax = sns.heatmap(
    corr,
    vmin=-1, vmax=1, center=0,
    cmap=sns.diverging_palette(20, 220, n=200),
    square=True
)
ax.set_xticklabels(
    ax.get_xticklabels(),
    rotation=45,
    horizontalalignment='right'
)
 
# 図の保存と図示
plt.savefig('simple_heatmap.png')
plt.show()

Out:

3. プロットの大きさを相関係数に応じて変える

シンプルな相関行列のヒートマップでは、

  • 最初にどこを見るべきか
  • 最も強い相関、最も弱い相関はどこか
  • 値段 (price)と強く相関している変数ベスト3はどれか
    がわかりません。

そこで、各プロット(マス目)ごとに相関係数に応じてプロットの大きさを変えるようにして表示します。
こちらは、heatmapz パッケージを使うと簡単にできます。

3.1. heatmapzのインストール

1
pip3 install heatmapz

3.2. ライブラリのインポート

1
2
3
4
5
6
from heatmap import corrplot
import numpy as np
import seaborn as sns
import pandas as pd
from matplotlib import pyplot as plt
from pylab import rcParams

3.3. データの読み込み

先程使用したデータと同じデータを使います。

1
data = pd.read_csv('https://raw.githubusercontent.com/drazenz/heatmap/master/autos.clean.csv')

3.4. ヒートマップの作成

01
02
03
04
05
06
07
08
09
10
11
# 図の設定
rcParams['figure.figsize'] = 7, 7
sns.set(color_codes=True, font_scale=1.2)
 
# ヒートマップの作成
plt.figure(figsize=(8, 8))
corrplot(data.corr(), size_scale=300)
 
# 図の保存と表示
plt.savefig('advanced_heatmap.png')
plt.show()

Out:

コメントを残す

This site uses Akismet to reduce spam. Learn how your comment data is processed.