Reputation: 1307
I have a data frame that contains multiple variables where each variable is logically connected to a factor level of an additional group variable. I would like to plot a histogram of each variable in such a way that it is possible to show a grid of multiple histograms 'group-wise'.
Here's an example data frame df_melt (the variables var_1, var_2, var_3, var_4 are logically connected to the factor level 'foo'; the variables var_5, var_6, var_7 belong to factor level 'bar'):
import numpy as np
import pandas as pd
# simulate data and create plot-ready dataframe
np.random.seed(42)
var_values = np.random.randint(low=1,high=100,size=(100,7))
var_names = ['var_1','var_2','var_3','var_4','var_5','var_6','var_7']
group_names = ['foo','foo','foo','foo','bar','bar','bar']
df = pd.DataFrame(var_values,columns=var_names)
multi_index = pd.MultiIndex.from_arrays([df.columns,group_names],names=['variable','group'])
df.columns = multi_index
df_melt = pd.melt(df)
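For readers following along, here is a short sketch that rebuilds the same data frame and inspects its long format. It assumes that pd.melt takes the new column names from the MultiIndex level names ('variable' and 'group') and uses its default 'value' for the measurements. The list comprehension for var_names is only a shorthand for the list above.

```python
import numpy as np
import pandas as pd

# rebuild the example data exactly as in the question
np.random.seed(42)
var_names = [f'var_{i}' for i in range(1, 8)]
group_names = ['foo'] * 4 + ['bar'] * 3
df = pd.DataFrame(np.random.randint(low=1, high=100, size=(100, 7)), columns=var_names)
df.columns = pd.MultiIndex.from_arrays([df.columns, group_names], names=['variable', 'group'])
df_melt = pd.melt(df)

# the long format has one row per (variable, observation) pair
print(df_melt.columns.tolist())
print(df_melt.shape)

# which variables belong to which group
print(df_melt.groupby('group', sort=False)['variable'].unique())
```

Each histogram in the desired grid corresponds to one (group, variable) combination in this frame.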
The output should look like this:
These Stack Overflow posts might help with an answer, but I was not able to come up with a solution on my own:
Plotting a grouped pandas data in plotly
Plotly equivalent for pd.DataFrame.hist
Upvotes: 1
Views: 1882
Reputation: 956
The best I came up with is the following. It does not reproduce the grid layout you asked for, since it draws one separate figure per group with one histogram per variable, but I hope it gives you something to start from.
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# simulate data and create plot-ready dataframe
np.random.seed(42)
var_values = np.random.randint(low=1,high=100,size=(100,7))
var_names = ['var_1','var_2','var_3','var_4','var_5','var_6','var_7']
group_names = ['foo','foo','foo','foo','bar','bar','bar']
df = pd.DataFrame(var_values,columns=var_names)
multi_index = pd.MultiIndex.from_arrays([df.columns,group_names],names=['variable','group'])
df.columns = multi_index
df_melt = pd.melt(df)
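# unique group names, kept in their original order (a plain set has no stable order)
uniq_cols = list(dict.fromkeys(group_names))
for col in uniq_cols:
    # variables that belong to the current group
    rows = df_melt[df_melt['group']==col]['variable'].unique()
    # print(list(rows))
    num_vars = len(rows)
    # one figure per group, one subplot column per variable
    fig = make_subplots(rows=1, cols=num_vars, column_titles=list(rows))
    for i, row in enumerate(rows):
        fig.add_trace(go.Histogram(x=df_melt[(df_melt['group']==col) & (df_melt['variable']==row)]['value']),
                      row=1, col=i+1)
    fig.show()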
Upvotes: 0