Xin lưu ý rằng các tác giả của seaborn
chỉ muốn seaborn.heatmap
làm việc với các khung dữ liệu phân loại. Nó không chung chung.
Nếu chỉ mục và cột của bạn là giá trị số và / hoặc ngày giờ, mã này sẽ phục vụ bạn tốt.
Chức năng ánh xạ nhiệt Matplotlib pcolormesh
yêu cầu các thùng thay vì các chỉ mục , vì vậy có một số mã ưa thích để tạo các thùng từ các chỉ số khung dữ liệu của bạn (ngay cả khi chỉ mục của bạn không cách đều nhau!).
Phần còn lại đơn giản là np.meshgrid
và plt.pcolormesh
.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
def conv_index_to_bins(index):
"""Calculate bins to contain the index values.
The start and end bin boundaries are linearly extrapolated from
the two first and last values. The middle bin boundaries are
midpoints.
Example 1: [0, 1] -> [-0.5, 0.5, 1.5]
Example 2: [0, 1, 4] -> [-0.5, 0.5, 2.5, 5.5]
Example 3: [4, 1, 0] -> [5.5, 2.5, 0.5, -0.5]"""
assert index.is_monotonic_increasing or index.is_monotonic_decreasing
# the beginning and end values are guessed from first and last two
start = index[0] - (index[1]-index[0])/2
end = index[-1] + (index[-1]-index[-2])/2
# the middle values are the midpoints
middle = pd.DataFrame({'m1': index[:-1], 'p1': index[1:]})
middle = middle['m1'] + (middle['p1']-middle['m1'])/2
if isinstance(index, pd.DatetimeIndex):
idx = pd.DatetimeIndex(middle).union([start,end])
elif isinstance(index, (pd.Float64Index,pd.RangeIndex,pd.Int64Index)):
idx = pd.Float64Index(middle).union([start,end])
else:
print('Warning: guessing what to do with index type %s' %
type(index))
idx = pd.Float64Index(middle).union([start,end])
return idx.sort_values(ascending=index.is_monotonic_increasing)
def calc_df_mesh(df):
"""Calculate the two-dimensional bins to hold the index and
column values."""
return np.meshgrid(conv_index_to_bins(df.index),
conv_index_to_bins(df.columns))
def heatmap(df):
"""Plot a heatmap of the dataframe values using the index and
columns"""
X,Y = calc_df_mesh(df)
c = plt.pcolormesh(X, Y, df.values.T)
plt.colorbar(c)
Gọi nó bằng cách sử dụng heatmap(df)
, và xem nó bằng cách sử dụng plt.show()
.