
Setting a legend for only one of the marginal plots in seaborn

I am creating a JointGrid plot using seaborn.

import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
mydataset=pd.DataFrame(data=np.random.rand(50,2),columns=['a','b'])
g = sns.JointGrid(x=mydataset['a'], y=mydataset['b'])
g=g.plot_marginals(sns.distplot,color='black',kde=True,hist=False,rug=True,bins=20,label='X')
g=g.plot_joint(plt.scatter,label='X')        


legend_properties = {'weight':'bold','size':8}
legendMain=g.ax_joint.legend(prop=legend_properties,loc='upper right')


legendSide=g.ax_marg_x.legend(prop=legend_properties,loc='upper right')

I get this:

[Image: resulting plot]

I would like to get rid of the legend in the vertical marginal plot (the one on the right side) but keep the one in the horizontal marginal plot. How can I achieve that?


EDIT: The solution from @ImportanceOfBeingErnest works fine for one plot. However, if I repeat it in a for loop, something unexpected happens: I still get a legend in the upper plot. How do I get rid of it?

The following code:

import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
mydataset=pd.DataFrame(data=np.random.rand(50,2),columns=['a','b'])
g = sns.JointGrid(x=mydataset['a'], y=mydataset['b'])
LABEL_LIST=['x','Y','Z']
for n in range(0,3):

    g=g.plot_marginals(sns.distplot,color='black',kde=True,hist=False,rug=True,bins=20,label=LABEL_LIST[n])
    g=g.plot_joint(plt.scatter,label=LABEL_LIST[n])        


    legend_properties = {'weight':'bold','size':8}
    legendMain=g.ax_joint.legend(prop=legend_properties,loc='upper right')


    legendSide=g.ax_marg_y.legend(labels=[LABEL_LIST[n]],prop=legend_properties,loc='upper right')

gives:

[Image: resulting plot]

which is almost perfect, but I need to get rid of the last legend entry in the plot on the right.
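If the aim is literally to drop the last entry from an existing legend, one general matplotlib approach (not the one used in the answer below) is to collect the labelled artists with `Axes.get_legend_handles_labels()` and rebuild the legend from all but the last. This sketch uses a plain matplotlib axes as a stand-in for `g.ax_marg_y`, with one labelled curve per loop iteration; `legend_without_last` is a hypothetical helper name.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def legend_without_last(ax, **legend_kwargs):
    """Hypothetical helper: rebuild the legend of `ax` without its last entry.

    get_legend_handles_labels() skips artists whose label starts with "_",
    so only explicitly labelled artists are considered.
    """
    handles, labels = ax.get_legend_handles_labels()
    # Calling ax.legend again replaces any legend already on the axes
    return ax.legend(handles[:-1], labels[:-1], **legend_kwargs)


LABEL_LIST = ["x", "Y", "Z"]
fig, ax = plt.subplots()
y = np.linspace(0, 1, 20)
for n, label in enumerate(LABEL_LIST):
    # Stand-in for one labelled marginal curve per loop iteration
    ax.plot(np.sin(y + n), y, label=label)

legend_properties = {"weight": "bold", "size": 8}
leg = legend_without_last(ax, prop=legend_properties, loc="upper right")
```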


ImportanceOfBeingErnest


You can leave the marginal plots unlabelled and instead supply the label only when creating the legend inside the top marginal axes.

import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

mydataset=pd.DataFrame(data=np.random.rand(50,2),columns=['a','b'])
g = sns.JointGrid(x=mydataset['a'], y=mydataset['b'])
g=g.plot_marginals(sns.distplot,color='black',
                   kde=True,hist=False,rug=True,bins=20)
g=g.plot_joint(plt.scatter,label='X')        


legend_properties = {'weight':'bold','size':8}
legendMain=g.ax_joint.legend(prop=legend_properties,loc='upper right')


legendSide=g.ax_marg_x.legend(labels=["x"], 
                              prop=legend_properties,loc='upper right')

plt.show()

[Image: resulting plot]

The solution is the same for a plot in a loop.

import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

mydataset=pd.DataFrame(data=np.random.rand(50,2),columns=['a','b'])
g = sns.JointGrid(x=mydataset['a'], y=mydataset['b'])
LABEL_LIST=['x','Y','Z']
for n in range(0,3):
    g=g.plot_marginals(sns.distplot,color='black',kde=True,hist=False,rug=True,bins=20)
    g=g.plot_joint(plt.scatter,label=LABEL_LIST[n])        

legend_properties = {'weight':'bold','size':8}
legendMain=g.ax_joint.legend(prop=legend_properties,loc='upper right')
legendSide=g.ax_marg_x.legend(labels=LABEL_LIST,prop=legend_properties,loc='upper right')

plt.show()

[Image: resulting plot]
