aghtaal

Reputation: 307

Plotly: How to show grouped text elements from a dataframe as multiline hoverinfos?

I have a dataframe like this:

df = 
                               time_id gt_class  num_missed_base  num_missed_feature  num_objects_base  num_objects_feature
   5G21A6P00L4100023:1566617404450336      CAR               11                   4                27                   30
   5G21A6P00L4100023:1566617404450336  BICYCLE                4                   6                27                   30
   5G21A6P00L4100023:1566617404450336   PERSON                2                   3                27                   30
   5G21A6P00L4100023:1566617404450336    TRUCK                1                   0                27                   30
   5G21A6P00L4100023:1566617428450689      CAR               25                  14                60                   67
   5G21A6P00L4100023:1566617428450689   PERSON                7                   6                60                   67
   5G21A6P00L4100023:1566617515950900  BICYCLE                1                   1                59                   65
   5G21A6P00L4100023:1566617515950900      CAR               20                   9                59                   65
   5G21A6P00L4100023:1566617515950900   PERSON               10                   2                59                   65
   5G21A6P00L4100037:1567169649450046      CAR                8                   0                29                   32
   5G21A6P00L4100037:1567169649450046   PERSON                1                   0                29                   32
   5G21A6P00L4100037:1567169649450046    TRUCK                1                   0                29                   32

For each time_id, it shows how many objects the base model misses (num_missed_base), how many the feature model misses (num_missed_feature), and how many objects exist at that time in base and feature (num_objects_base, num_objects_feature).

I need to draw a scatter plot of time_id using plotly.graph_objs and FigureWidget, such that when the user hovers over a point (each point represents a unique time_id), it shows the following for time_id == 5G21A6P00L4100023:1566617404450336: [image: desired hover label]

What should be the hover_text in the code below?

import plotly.graph_objs as go
hover_text = ????
df_agg = df.groupby("time_id").sum().reset_index()
error_trace = go.Scattergl(
        x=df_agg["num_missed_base"].tolist(),
        y=df_agg["num_missed_feature"].tolist(),
        text=hover_text,
        mode="markers",
        marker=dict(cmax=50, cmin=-50, opacity=0.3),
    )
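Side note on the aggregation above: `df.groupby("time_id").sum()` also adds up `num_objects_base` and `num_objects_feature`, which are repeated on every class row of a time_id (for the first time_id that gives 4 × 27 = 108 instead of 27). A minimal sketch, assuming pandas ≥ 0.25 for named aggregation, that sums only the per-class miss counts and keeps the first value of the per-frame object counts:

```python
import pandas as pd

# Sample rows from the question (first two time_ids only)
df = pd.DataFrame({
    "time_id": ["5G21A6P00L4100023:1566617404450336"] * 4
               + ["5G21A6P00L4100023:1566617428450689"] * 2,
    "gt_class": ["CAR", "BICYCLE", "PERSON", "TRUCK", "CAR", "PERSON"],
    "num_missed_base": [11, 4, 2, 1, 25, 7],
    "num_missed_feature": [4, 6, 3, 0, 14, 6],
    "num_objects_base": [27, 27, 27, 27, 60, 60],
    "num_objects_feature": [30, 30, 30, 30, 67, 67],
})

# Miss counts are per class, so they are summed; object counts are per
# time_id (repeated on every class row), so only the first value is kept.
df_agg = df.groupby("time_id", as_index=False).agg(
    num_missed_base=("num_missed_base", "sum"),
    num_missed_feature=("num_missed_feature", "sum"),
    num_objects_base=("num_objects_base", "first"),
    num_objects_feature=("num_objects_feature", "first"),
)
```

The x/y columns used by the trace are the same as with `.sum()`; only the object-count columns change.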

Upvotes: 1

Views: 434

Answers (2)

vestland

Reputation: 61074

A pandas professional would certainly be able to make the code snippet below a bit more elegant and efficient. But my work-arounds will do the job as well. The main challenge is to turn your source dataframe into a grouped version like this:

   time_id                             gt_class                  num_missed_base  base_str  num_missed_feature  feature_str
0  5G21A6P00L4100023:1566617404450336  CAR,BICYCLE,PERSON,TRUCK               18  11,4,2,1                  13  4,6,3,0
1  5G21A6P00L4100023:1566617428450689  CAR,PERSON                             32  25,7                      20  14,6
2  5G21A6P00L4100023:1566617515950900  BICYCLE,CAR,PERSON                     31  1,20,10                   12  1,9,2
3  5G21A6P00L4100037:1567169649450046  CAR,PERSON,TRUCK                       10  8,1,1                      0  0,0,0
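The intermediate comma-joined strings can also be skipped by building one hover line per row first and then joining the lines of each time_id with `<br>`, the line break Plotly uses in hover text. A minimal sketch, assuming pandas ≥ 0.25 (named aggregation); the `line` and `hover` column names are illustrative:

```python
import pandas as pd

# Sample rows from the question (first two time_ids only)
df = pd.DataFrame({
    "time_id": ["5G21A6P00L4100023:1566617404450336"] * 4
               + ["5G21A6P00L4100023:1566617428450689"] * 2,
    "gt_class": ["CAR", "BICYCLE", "PERSON", "TRUCK", "CAR", "PERSON"],
    "num_missed_base": [11, 4, 2, 1, 25, 7],
    "num_missed_feature": [4, 6, 3, 0, 14, 6],
})

# One "CLASS: count" string per row
df["line"] = df["gt_class"] + ": " + df["num_missed_base"].astype(str)

# One row per time_id: summed counts plus the <br>-joined lines
df2 = df.groupby("time_id", as_index=False).agg(
    num_missed_base=("num_missed_base", "sum"),
    num_missed_feature=("num_missed_feature", "sum"),
    hover=("line", "<br>".join),
)
```

`df2["hover"]` can then be passed as `hovertext=` together with `hoverinfo="text"`, with no need to strip brackets and quotes from a stringified list.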

The bad news is that this is not nearly enough. The good news is that the snippet below will handle it all and give you this plot:

[image: resulting scatter plot]

What you see here is a plot with one marker per timestamp: its position shows the totals over all classes (num_missed_base on the x-axis, num_missed_feature on the y-axis), and the hoverinfo lists the num_missed_base count for each underlying class. With a little further tweaking I may be able to include the sums as well, but this is all I have time for right now.
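The sums can go into the same label, for example as a first line followed by one line per class. A sketch under the same assumptions (pandas only; the helper name `make_label` and the label wording are hypothetical):

```python
import pandas as pd

# Sample rows from the question (first two time_ids only)
df = pd.DataFrame({
    "time_id": ["5G21A6P00L4100023:1566617404450336"] * 4
               + ["5G21A6P00L4100023:1566617428450689"] * 2,
    "gt_class": ["CAR", "BICYCLE", "PERSON", "TRUCK", "CAR", "PERSON"],
    "num_missed_base": [11, 4, 2, 1, 25, 7],
    "num_missed_feature": [4, 6, 3, 0, 14, 6],
})


def make_label(group: pd.DataFrame) -> str:
    """Hypothetical helper: totals on the first line, then one line per class."""
    header = (
        f"total: base={group['num_missed_base'].sum()}, "
        f"feature={group['num_missed_feature'].sum()}"
    )
    rows = [
        f"{cls}: base={base}, feature={feat}"
        for cls, base, feat in zip(
            group["gt_class"], group["num_missed_base"], group["num_missed_feature"]
        )
    ]
    return "<br>".join([header] + rows)


# Map each time_id to its label, then attach it to the aggregated frame
labels = {tid: make_label(g) for tid, g in df.groupby("time_id")}
df_agg = df.groupby("time_id", as_index=False)[["num_missed_base", "num_missed_feature"]].sum()
df_agg["hover"] = df_agg["time_id"].map(labels)
```

Building the labels from the ungrouped rows keeps the per-class values and their totals in one place, so the label and the marker position cannot disagree.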

Complete code:

import pandas as pd
import re
import plotly.graph_objects as go 

smpl = {'index': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
         'columns': ['time_id',
          'gt_class',
          'num_missed_base',
          'num_missed_feature',
          'num_objects_base',
          'num_objects_feature'],
         'data': [['5G21A6P00L4100023:1566617404450336', 'CAR', 11, 4, 27, 30],
          ['5G21A6P00L4100023:1566617404450336', 'BICYCLE', 4, 6, 27, 30],
          ['5G21A6P00L4100023:1566617404450336', 'PERSON', 2, 3, 27, 30],
          ['5G21A6P00L4100023:1566617404450336', 'TRUCK', 1, 0, 27, 30],
          ['5G21A6P00L4100023:1566617428450689', 'CAR', 25, 14, 60, 67],
          ['5G21A6P00L4100023:1566617428450689', 'PERSON', 7, 6, 60, 67],
          ['5G21A6P00L4100023:1566617515950900', 'BICYCLE', 1, 1, 59, 65],
          ['5G21A6P00L4100023:1566617515950900', 'CAR', 20, 9, 59, 65],
          ['5G21A6P00L4100023:1566617515950900', 'PERSON', 10, 2, 59, 65],
          ['5G21A6P00L4100037:1567169649450046', 'CAR', 8, 0, 29, 32],
          ['5G21A6P00L4100037:1567169649450046', 'PERSON', 1, 0, 29, 32],
          ['5G21A6P00L4100037:1567169649450046', 'TRUCK', 1, 0, 29, 32]]}

df = pd.DataFrame(index=smpl['index'], columns = smpl['columns'], data=smpl['data'])
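# string copies of the per-class counts, so they can be comma-joined per time_id
df['base_str'] = df['num_missed_base'].astype(str)
df['feature_str'] = df['num_missed_feature'].astype(str)
# one row per time_id: class names and counts joined, totals summed
df2=df.groupby(['time_id'], as_index = False).agg({'gt_class': ','.join,
                                                   'num_missed_base': 'sum',
                                                   'base_str': ','.join,
                                                   'num_missed_feature': 'sum',
                                                   'feature_str': ','.join})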
# pair each class with its base count, e.g. 'CAR: 11', one list per time_id
col_elem=[]
row_elem=[]
for i in df2.index:
    gt_class = df2['gt_class'].loc[i].split(',')
    base_str = df2['base_str'].loc[i].split(',')
    for j, elem in enumerate(gt_class):
        new_elem = elem+": "+base_str[j]
        row_elem.append(new_elem)
    col_elem.append(row_elem)
    row_elem=[]

df2['hover']=col_elem
# turn the string form of each list, e.g. "['CAR: 11', 'BICYCLE: 4']",
# into one class per line ('<br>' is a line break in Plotly hover text)
df2['hover'] = df2['hover'].astype(str)
df2['hover2'] = df2['hover'].map(lambda x: x.lstrip('[]').rstrip(']'))
df2['hover2']=df2['hover2'].replace("'",'', regex=True)
df2['hover2']=df2['hover2'].replace(',','<br>', regex=True)

# plotly
fig = go.Figure()
fig.add_traces(go.Scatter(x=df2['num_missed_base'], y=df2['num_missed_feature'],
                          mode='markers', marker=dict(color='red',
                                                      line=dict(color='black', width=1),
                                                      size=14),
                          #hovertext=df2["hover"],
                          hovertext=df2['hover2'],
                          hoverinfo="text",
                          
                         ))

fig.update_xaxes(showspikes=True, linecolor='black', title='Base',
                 spikecolor='black', spikethickness=0.5, spikedash='solid')
fig.update_yaxes(showspikes=True, linecolor='black', title = 'Feature',
                 spikecolor='black', spikethickness=0.5, spikedash='solid')
fig.update_layout(
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)'
)

fig.show()

Upvotes: 1

aghtaal

Reputation: 307

Based on @vestland's answer, I came up with this:

import pandas as pd
import plotly.graph_objects as go

smpl = {'index': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
         'columns': ['time_id',
          'gt_class',
          'num_missed_base',
          'num_missed_feature',
          'num_objects_base',
          'num_objects_feature'],
         'data': [['5G21A6P00L4100023:1566617404450336', 'CAR', 11, 4, 27, 30],
          ['5G21A6P00L4100023:1566617404450336', 'BICYCLE', 4, 6, 27, 30],
          ['5G21A6P00L4100023:1566617404450336', 'PERSON', 2, 3, 27, 30],
          ['5G21A6P00L4100023:1566617404450336', 'TRUCK', 1, 0, 27, 30],
          ['5G21A6P00L4100023:1566617428450689', 'CAR', 25, 14, 60, 67],
          ['5G21A6P00L4100023:1566617428450689', 'PERSON', 7, 6, 60, 67],
          ['5G21A6P00L4100023:1566617515950900', 'BICYCLE', 1, 1, 59, 65],
          ['5G21A6P00L4100023:1566617515950900', 'CAR', 20, 9, 59, 65],
          ['5G21A6P00L4100023:1566617515950900', 'PERSON', 10, 2, 59, 65],
          ['5G21A6P00L4100037:1567169649450046', 'CAR', 8, 0, 29, 32],
          ['5G21A6P00L4100037:1567169649450046', 'PERSON', 1, 0, 29, 32],
          ['5G21A6P00L4100037:1567169649450046', 'TRUCK', 1, 0, 29, 32]]}

df = pd.DataFrame(index=smpl['index'], columns = smpl['columns'], data=smpl['data'])

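# joins the class names of one time_id into 'CAR,BICYCLE,...'
def func(row):
    return ','.join(row.tolist())

# for one time_id, build a 'CLASS : [base = x, feature = y]' string per class;
# the values are looked up in the global df through the group's row labels
def multi_column1(row):
    l = []
    for n in row.index:
        x = df.loc[n, 'gt_class']
        y = df.loc[n, 'num_missed_base']
        z = df.loc[n, 'num_missed_feature']
        w = '{} : [base = {}, feature = {}]'.format(x, y, z)
        l.append(w)
    return l

# dummy column: only its row labels are used by multi_column1
if "hover_text" not in df.columns:
    df.insert(0, "hover_text", range(len(df)))
df = df.groupby('time_id').agg({'gt_class': func, 'num_missed_base': 'sum', 'num_missed_feature': 'sum', 'hover_text': multi_column1})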
df.reset_index(inplace=True)
df['hover_text'] = df['hover_text'].astype(str)
df['hover_text'] = df['hover_text'].map(lambda x: x.lstrip('[]').rstrip(']'))
df['hover_text'] = df['hover_text'].replace("'",'', regex=True)
df['hover_text'] = df['hover_text'].replace('],',']<br>', regex=True)

# plotly
fig = go.Figure()
fig.add_traces(go.Scatter(x=df['num_missed_base'], y=df['num_missed_feature'],
                          mode='markers', marker=dict(color='red',
                                                      line=dict(color='black', width=1),
                                                      size=14),
                          hovertext=df['hover_text'],
                          hoverinfo="text",
                          
                         ))

fig.update_xaxes(showspikes=True, linecolor='black', title='Base',
                 spikecolor='black', spikethickness=0.5, spikedash='solid')
fig.update_yaxes(showspikes=True, linecolor='black', title = 'Feature',
                 spikecolor='black', spikethickness=0.5, spikedash='solid')
fig.update_layout(
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)'
)

fig.show()
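`multi_column1` only works because it looks rows up in the global `df` by index, which is why the dummy `hover_text` column is needed. A sketch of the same kind of `CLASS : [base = x, feature = y]` labels that iterates over each group's own rows instead (pandas only; `labels` is an illustrative name):

```python
import pandas as pd

# Sample rows from the question (first two time_ids only)
df = pd.DataFrame({
    "time_id": ["5G21A6P00L4100023:1566617404450336"] * 4
               + ["5G21A6P00L4100023:1566617428450689"] * 2,
    "gt_class": ["CAR", "BICYCLE", "PERSON", "TRUCK", "CAR", "PERSON"],
    "num_missed_base": [11, 4, 2, 1, 25, 7],
    "num_missed_feature": [4, 6, 3, 0, 14, 6],
})

# One <br>-separated label per time_id, built from that group's rows only
labels = {
    tid: "<br>".join(
        f"{r.gt_class} : [base = {r.num_missed_base}, feature = {r.num_missed_feature}]"
        for r in group.itertuples(index=False)
    )
    for tid, group in df.groupby("time_id")
}

df_agg = df.groupby("time_id", as_index=False).agg(
    gt_class=("gt_class", ",".join),
    num_missed_base=("num_missed_base", "sum"),
    num_missed_feature=("num_missed_feature", "sum"),
)
df_agg["hover_text"] = df_agg["time_id"].map(labels)
```

`df_agg["hover_text"]` can be passed to `hovertext=` exactly as in the trace above; no dummy column or string cleanup is required.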

[image: resulting plot]

Upvotes: 1
