1. 目的
2. シンプルなヒートマップ
2.1. ライブラリのインポート
2.2. データの読み込み
2.3. 相関行列の計算
2.4. ヒートマップの作成
3. プロットの大きさを相関係数に応じて変える
3.1. heatmapzのインストール
3.2. ライブラリのインポート
3.3. データの読み込み
3.4. ヒートマップの作成
1. 目的
- Pythonで相関行列 (Correlation Matrix)を作成。
- ヒートマップ (Heat Map)で図示。
Better Heatmaps and Correlation Matrix Plots in Pythonを参考にしました。
2. シンプルなヒートマップ
まずは、基本的な相関行列のヒートマップを作成します。
2.1. ライブラリのインポート
1 2 3 4 | import seaborn as sns import pandas as pd from matplotlib import pyplot as plt from pylab import rcParams |
2.2. データの読み込み
Automobile Data Setから自動車のデータを取得します。
Drazen ZaricさんがAutomobile Data Setを使いやすいように整えてっくださっているのでそちらを使わせていただきます。
1 | data = pd.read_csv( 'https://raw.githubusercontent.com/drazenz/heatmap/master/autos.clean.csv' ) |
2.3. 相関行列の計算
相関行列の計算は、pandas
のcorr
メソッドを使うと簡単に計算できます。
1 | corr = data.corr() |
2.4. ヒートマップの作成
計算した相関行列corr
をヒートマップで図示していきます。
01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19 20 | # 図の設定 rcParams[ 'figure.figsize' ] = 7 , 7 sns. set (color_codes = True , font_scale = 1.2 ) # ヒートマップの作成 ax = sns.heatmap( corr, vmin = - 1 , vmax = 1 , center = 0 , cmap = sns.diverging_palette( 20 , 220 , n = 200 ), square = True ) ax.set_xticklabels( ax.get_xticklabels(), rotation = 45 , horizontalalignment = 'right' ) # 図の保存と図示 plt.savefig( 'simple_heatmap.png' ) plt.show() |
Out:
3. プロットの大きさを相関係数に応じて変える
シンプルな相関行列のヒートマップでは、
- 最初にどこを見るべきか
- 最も強い相関、最も弱い相関はどこか
- 値段 (price)と強く相関している変数ベスト3はどれか
がわかりません。
そこで、各プロット(マス目)ごとに相関係数に応じてプロットの大きさを変えるようにして表示します。
こちらは、heatmapz パッケージを使うと簡単にできます。
3.1. heatmapzのインストール
1 | pip3 install heatmapz |
3.2. ライブラリのインポート
1 2 3 4 5 6 | from heatmap import corrplot import numpy as np import seaborn as sns import pandas as pd from matplotlib import pyplot as plt from pylab import rcParams |
3.3. データの読み込み
先程使用したデータと同じデータを使います。
1 | data = pd.read_csv( 'https://raw.githubusercontent.com/drazenz/heatmap/master/autos.clean.csv' ) |
3.4. ヒートマップの作成
01 02 03 04 05 06 07 08 09 10 11 | # 図の設定 rcParams[ 'figure.figsize' ] = 7 , 7 sns. set (color_codes = True , font_scale = 1.2 ) # ヒートマップの作成 plt.figure(figsize = ( 8 , 8 )) corrplot(data.corr(), size_scale = 300 ) # 図の保存と表示 plt.savefig( 'advanced_heatmap.png' ) plt.show() |
Out: