H.Down

Reputation: 23

How to keep the ones on the diagonal of sns.heatmap?

I want to plot a correlation matrix with sns.heatmap and have some questions. This is my code:

plt.figure(figsize=(8,8))
mask = np.zeros_like(data.corr())
mask[np.triu_indices_from(mask)] = True
sns.heatmap(data.corr(), mask=mask, linewidth=1, annot=True, fmt=".2f", cmap='coolwarm', vmin=-1, vmax=1)
plt.show()

and this is what I get: [Correlation Matrix][1]

[1]: https://i.sstatic.net/DX2oN.png

Now I have some questions:

1) How can I keep the ones on the diagonal?

2) How can I change the position of the x-axis?

3) I want the colorbar to go from 1 to -1, but the code is not working.

I hope someone can help.

Thx

Upvotes: 1

Views: 2864

Answers (3)

Andre Goulart

Reputation: 596

mask[np.triu_indices_from(mask)] will define the triangle (including diagonal)

mask[np.eye(mask.shape[0], dtype=bool)] will define the diagonal.

If you put those together, you can control them independently. (Be aware you need to set the triangle before the diagonal).

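import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def plot_correlation_matrix(df, remove_diagonal=True, remove_triangle=False, **kwargs):
    corr = df.corr()
    # Apply mask (cells where the mask is True are hidden by sns.heatmap)
    # Use the builtin bool: the np.bool alias was removed in NumPy 1.24
    mask = np.zeros_like(corr, dtype=bool)
    mask[np.triu_indices_from(mask)] = remove_triangle
    mask[np.eye(mask.shape[0], dtype=bool)] = remove_diagonal
    # Plot
    # plt.figure(figsize=(8,8))
    sns.heatmap(corr, mask=mask, **kwargs)
    plt.show()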

So this command will generate the matrix, removing the upper triangle, but keeping the diagonal:

plot_correlation_matrix(df[colunas_notas], remove_diagonal=False, remove_triangle=True)
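
To see what the two assignments do without plotting anything, here is a minimal NumPy-only sketch. The helper name build_mask and the 4x4 size are my own choices for illustration, not part of the function above. Keep in mind that sns.heatmap hides the cells where the mask is True.

```python
import numpy as np

def build_mask(n, remove_diagonal=False, remove_triangle=True):
    """Hypothetical helper: same two-step mask logic, for an n x n matrix."""
    mask = np.zeros((n, n), dtype=bool)
    # Step 1: upper triangle, diagonal included
    mask[np.triu_indices_from(mask)] = remove_triangle
    # Step 2: overwrite the diagonal only (this is why the order matters)
    mask[np.eye(n, dtype=bool)] = remove_diagonal
    return mask

mask = build_mask(4)
print(mask.astype(int))
# [[0 1 1 1]
#  [0 0 1 1]
#  [0 0 0 1]
#  [0 0 0 0]]
```

The zeros on the diagonal mean those cells stay visible, so the ones of the correlation matrix are kept.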

Upvotes: 1

pyano

Reputation: 1978

Change of the position of the x-axis

Since I'm not experienced with seaborn, I would use matplotlib to plot the heat map and then use matplotlib's twinx() or twiny() to place the axis where you want it.

(I think that can be done with seaborn too - I just do not know it)
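
As a rough sketch of the matplotlib route, and assuming that "change the position" means moving the x-axis ticks to the top, the axis methods tick_top() and set_label_position() might be enough without a twin axis. The random data below is a stand-in of my own, not the asker's data. Since sns.heatmap returns a matplotlib Axes, the same two calls should apply to its return value as well.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np

# Stand-in data (assumption): correlation matrix of 4 random variables
rng = np.random.default_rng(0)
corr = np.corrcoef(rng.normal(size=(4, 50)))

fig, ax = plt.subplots(figsize=(4, 4))
im = ax.imshow(corr, cmap="coolwarm", vmin=-1, vmax=1)

# Move the x-axis ticks and the axis label to the top edge
ax.xaxis.tick_top()
ax.xaxis.set_label_position("top")

# vmin/vmax above pin the colorbar range to [-1, 1]
fig.colorbar(im, ax=ax)
```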

Upvotes: 0

pyano

Reputation: 1978

I think you have to check data.corr(), because your code is correct and gives the diagonal (see below). One question: you use np.triu, but the picture you show displays np.tril.

Here is the code I tested; the diagonal is there:

import numpy as np

N = 5
A = np.arange(N*N).reshape(N,N)

B = np.tril(A)

mask = np.zeros_like(A)
mask[np.triu_indices_from(mask)] = True

print('A'); print(A); print()
print('tril(A)'); print(B); print()
print('mask'); print(mask); print()

gives

A
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]

tril(A)
[[ 0  0  0  0  0]
 [ 5  6  0  0  0]
 [10 11 12  0  0]
 [15 16 17 18  0]
 [20 21 22 23 24]]

mask
[[1 1 1 1 1]
 [0 1 1 1 1]
 [0 0 1 1 1]
 [0 0 0 1 1]
 [0 0 0 0 1]]
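
A side note on reading this mask, as a sketch under one assumption about how it is used: sns.heatmap does not draw the cells where the mask is True, so ones on the diagonal of the mask mean the diagonal is hidden. Passing k=1 to np.triu_indices_from selects only the cells strictly above the diagonal. The masked arrays below are just my way of counting what would stay visible.

```python
import numpy as np

N = 5
A = np.arange(N * N).reshape(N, N)

# Mask as in the test above: upper triangle including the diagonal
mask_with_diag = np.zeros_like(A, dtype=bool)
mask_with_diag[np.triu_indices_from(mask_with_diag)] = True

# Variant: k=1 starts one step above the diagonal, so the diagonal stays unmasked
mask_above_diag = np.zeros_like(A, dtype=bool)
mask_above_diag[np.triu_indices_from(mask_above_diag, k=1)] = True

# Masked arrays stand in for "what a masked heatmap would still show"
visible_a = np.ma.masked_array(A, mask=mask_with_diag)
visible_b = np.ma.masked_array(A, mask=mask_above_diag)

print(visible_a.count(), visible_b.count())  # 10 15
```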

Edit: supplement

You could refine the mask, e.g.

C = A * mask
D = np.where(C > 1, 1, C)
print('D'); print(D)

gives

D
[[0 1 1 1 1]
 [0 1 1 1 1]
 [0 0 1 1 1]
 [0 0 0 1 1]
 [0 0 0 0 1]]

The first element of the diagonal of D is now zero, since the first element of the diagonal of A is zero too.

Edit: supplement 2

F = np.tril(A,-1)
E = np.eye(N)
G = E + F

print('F'); print(F); print()
print('E'); print(E); print()
print('G'); print(G); print()

gives

F
[[ 0  0  0  0  0]
 [ 5  0  0  0  0]
 [10 11  0  0  0]
 [15 16 17  0  0]
 [20 21 22 23  0]]

E
[[1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]]

G
[[ 1.  0.  0.  0.  0.]
 [ 5.  1.  0.  0.  0.]
 [10. 11.  1.  0.  0.]
 [15. 16. 17.  1.  0.]
 [20. 21. 22. 23.  1.]]

Upvotes: 1
