kafka

Reputation: 129

python plotly: unfixed number of traces

My code reads data from an .xlsx file and plots a bubble diagram with plotly (image: bubble diagram). The task is easy when I know how many traces need to be plotted. However, I am stuck when the number of traces is not fixed, because the number of rows in the file is variable.


       1991  1992  1993  1994  1995  1996  1997
US       10    14    16    18    20    42    64
JAPAN   100    30    70    85    30    42    64
CN       50    22    30    65    70    66    60

Here is my incomplete code:

# Version 2 could read data from .xlsx file.
import plotly as py
import plotly.graph_objs as go
import openpyxl

wb = openpyxl.load_workbook(('grape output.xlsx'))     
sheet = wb['Sheet1']       
row_max = sheet.max_row
col_max = sheet.max_column
l=[]

for row_n in range(row_max-1):
    l.append([])
    for col_n in range(col_max-1):
        l[row_n].append(sheet.cell(row=row_n+2, column=col_n+2).value)

trace0 = go.Scatter(
    x=[1991, 1992, 1993, 1994, 1995, 1996, 1997],
    y=['US', 'US', 'US', 'US', 'US', 'US', 'US'],
    mode='markers+text',
    marker=dict(
        color='rgb(150,204,90)',
        size= l[0],
        showscale = False,
        ),
    text=list(map(str, l[0])),     
    textposition='middle center',   
)

trace1 = go.Scatter(
    x=[1991, 1992, 1993, 1994, 1995, 1996, 1997],
    y=['JAPAN', 'JAPAN', 'JAPAN', 'JAPAN', 'JAPAN', 'JAPAN', 'JAPAN'],
    mode='markers+text',

    marker=dict(
        color='rgb(255, 130, 71)',
        size=l[1],
        showscale=False,
    ),
    text=list(map(str,l[1])),
    textposition='middle center',
)

trace2 = go.Scatter(
    x=[1991, 1992, 1993, 1994, 1995, 1996, 1997],
    y=['CN', 'CN', 'CN', 'CN', 'CN', 'CN', 'CN'],
    mode='markers+text',

    marker=dict(
        color='rgb(255, 193, 37)',
        size=l[2],
        showscale=False,
    ),
    text=list(map(str,l[2])),
    textposition='middle center',
)

layout = go.Layout(plot_bgcolor='rgb(10, 10, 10)',  
                   paper_bgcolor='rgb(20, 55, 100)',  
                   font={               
                       'size': 15,
                       'family': 'sans-serif',
                       'color': 'rgb(255, 255, 255)'  
                   },
                   width=1000,
                   height=500,
                   xaxis=dict(title='Output of grapes per year in US, JAPAN and CN', ),  
                   showlegend=False,
                   margin=dict(l=100, r=100, t=100, b=100),
                   hovermode = False,       
                   )

data = [trace0, trace1, trace2]
fig = go.Figure(data=data, layout=layout)


py.offline.init_notebook_mode()
py.offline.plot(fig, filename='basic-scatter.html')

Could you please show me how to draw the traces when their number is not known in advance? Thanks.

Upvotes: 1

Views: 398

Answers (3)

kafka

Reputation: 129

The code has been updated to Version 2, which reads the data from an .xlsx file and plots the bubble diagram. The raw data file, 'grape output.xlsx', now contains more rows and columns than before:

             1991  1992  1993  1994  1995  1996  1997  1998  1999
         US    10    14    16    18    20    42    64   100    50
      JAPAN   100    30    70    85    30    42    64    98    24
         CN    50    22    30    65    70    66    60    45    45
      INDIA    90    88    35    50    90    60    40    66    76
         UK    40    50    70    50    25    30    22    40    60

Here is the code:

# Version 2
import plotly as py
import plotly.graph_objs as go
import pandas as pd

# Read the sheet with pandas: the first row holds the years (column labels)
# and the first column holds the country names (row index).
df = pd.read_excel('grape output.xlsx', sheet_name='Sheet1', index_col=0)
colors = ['rgb(150,204,90)','rgb(255, 130, 71)','rgb(255, 193, 37)','rgb(180,240,190)','rgb(255, 10, 1)',
          'rgb(25, 19, 3)','rgb(100, 100, 100)','rgb(45,24,200)','rgb(33, 58, 108)','rgb(35, 208, 232)']

data = [go.Scatter(
    x=df.columns,
    y=[country]*len(df.columns),
    mode='markers+text',
    marker=dict(
        color=colors[num],
        size= df.loc[country],
        showscale = False,
        ),
    text=list(map(str, df.loc[country])),
    textposition='middle center',
    )
    for num, country in enumerate(reversed(df.index))
]

layout = go.Layout(plot_bgcolor='rgb(10, 10, 10)',
                   paper_bgcolor='rgb(20, 55, 100)',
                   font={
                       'size': 15,
                       'family': 'sans-serif',
                       'color': 'rgb(255, 255, 255)'
                   },
                   width=1000,
                   height=800,
                   xaxis=dict(title='Output of grapes per year in US, JAPAN and CN'),
                   showlegend=False,
                   margin=dict(l=100, r=100, t=100, b=100),
                   hovermode = False,
                   )

fig = go.Figure(data=data, layout=layout)
py.offline.plot(fig, filename='basic-scatter.html')

The result now looks like this (image: updated bubble diagram). A few small problems remain:

  1. How do I get rid of the tick labels 1990 and 2000, as well as the white vertical grid lines at 1990 and 2000?
  2. How do I draw white grid lines at 1991, 1993, 1995, 1997 and 1999, and display all of these years on the x-axis?

Please suggest corrections to improve Version 2 of the code. Thank you!

Upvotes: 0

rpanai

Reputation: 13437

Derek O's answer is perfect, but I think there is a more flexible way to do it using plotly.express, in particular if you don't want to define the colors yourself.

The idea is to properly transform the data.

Data

import pandas as pd
df = pd.DataFrame({1991:[10,100,50], 1992:[14,30,22], 1993:[16,70,30], 1994:[18,85,65], 1995:[20,30,70], 1996:[42,42,66], 1997:[64,64,60]})
df.index = ['US','JAPAN','CN']
df = df.T.unstack()\
      .reset_index()\
      .rename(columns={"level_0": "country",
                       "level_1": "year",
                       0: "n"})
print(df)
   country  year    n
0       US  1991   10
1       US  1992   14
2       US  1993   16
3       US  1994   18
4       US  1995   20
5       US  1996   42
6       US  1997   64
7    JAPAN  1991  100
8    JAPAN  1992   30
9    JAPAN  1993   70
10   JAPAN  1994   85
11   JAPAN  1995   30
12   JAPAN  1996   42
13   JAPAN  1997   64
14      CN  1991   50
15      CN  1992   22
16      CN  1993   30
17      CN  1994   65
18      CN  1995   70
19      CN  1996   66
20      CN  1997   60
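
For what it's worth, the same wide-to-long reshape can also be written with stack. This is only a sketch under the assumption that the wide frame looks like the one above (countries as the index, years as the columns). The names wide and long_df are made up for illustration, and only three of the years are used to keep it short. Naming the two axes first means the level_0/level_1 rename is not needed.

```python
import pandas as pd

# Wide table: one row per country, one column per year (shortened sample).
wide = pd.DataFrame(
    {1991: [10, 100, 50], 1992: [14, 30, 22], 1993: [16, 70, 30]},
    index=['US', 'JAPAN', 'CN'],
)

# Name the row and column axes, then stack the year columns into rows.
long_df = (
    wide.rename_axis(index='country', columns='year')
        .stack()          # Series indexed by (country, year)
        .rename('n')      # name the value column
        .reset_index()    # back to a flat DataFrame
)

print(long_df)
```

The resulting frame has the columns country, year and n, so it can be passed to px.scatter in the same way as the frame built with unstack.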

Using plotly.express

Now that your data is in long format, you can use plotly.express as follows:

import plotly.express as px
fig = px.scatter(df,
                 x="year",
                 y="country",
                 size="n",
                 color="country",
                 text="n",
                 size_max=50  # needed, otherwise the bubbles are too small
                )

fig.update_layout(plot_bgcolor='rgb(10, 10, 10)',  
                  paper_bgcolor='rgb(20, 55, 100)',  
                  font={'size': 15,
                        'family': 'sans-serif',
                        'color': 'rgb(255, 255, 255)'
                       },
                  width=1000,
                  height=500,
                  xaxis=dict(title='Output of grapes per year in selected countries', ),  
                  showlegend=False,
                  margin=dict(l=100, r=100, t=100, b=100),
                  hovermode = False,)
# Uncomment this if you don't want country as the y-axis title
# fig.layout.yaxis.title.text = None
fig.show()


Upvotes: 3

Derek O

Reputation: 19545

I should point out that your code would be more reproducible if you attached your raw data as text, or in some other form that can be easily copied and pasted. However, I can still answer your question and point you in the right direction.

What you should do is use a loop, and start by looking at the line data = [trace0, trace1, trace2]. As you noticed, this method won't scale up if you have 100 countries instead of 3.

Instead, you can build the data as a list with a list comprehension and change only the parts of each trace that differ. trace0, trace1 and trace2 are identical except for the country, values, and colors. To show you what I mean, I recreated your data using a DataFrame, then created a list of colors to go with the countries in its index.

# Version 2 could read data from .xlsx file.
import plotly as py
import plotly.graph_objs as go
import openpyxl

# wb = openpyxl.load_workbook(('grape output.xlsx'))     
# sheet = wb['Sheet1']       
# row_max = sheet.max_row
# col_max = sheet.max_column
# l=[]

# for row_n in range(row_max-1):
#     l.append([])
#     for col_n in range(col_max-1):
#         l[row_n].append(sheet.cell(row=row_n+2, column=col_n+2).value)

import pandas as pd

df = pd.DataFrame({1991:[10,100,50], 1992:[14,30,22], 1993:[16,70,30], 1994:[18,85,65], 1995:[20,30,70], 1996:[42,42,66], 1997:[64,64,60]})
df.index = ['US','JAPAN','CN']
colors = ['rgb(150,204,90)','rgb(255, 130, 71)','rgb(255, 193, 37)']

data = [go.Scatter(
    x=df.columns,
    y=[country]*len(df.columns),
    mode='markers+text',
    marker=dict(
        color=colors[num],
        size= df.loc[country],
        showscale = False,
        ),
    text=list(map(str, df.loc[country])),     
    textposition='middle center',   
    )
    for num, country in enumerate(df.index)
]

layout = go.Layout(plot_bgcolor='rgb(10, 10, 10)',  
                   paper_bgcolor='rgb(20, 55, 100)',  
                   font={               
                       'size': 15,
                       'family': 'sans-serif',
                       'color': 'rgb(255, 255, 255)'  
                   },
                   width=1000,
                   height=500,
                   xaxis=dict(title='Output of grapes per year in US, JAPAN and CN', ),  
                   showlegend=False,
                   margin=dict(l=100, r=100, t=100, b=100),
                   hovermode = False,       
                   )

# data = [trace0, trace1, trace2]
fig = go.Figure(data=data, layout=layout)
fig.show()

# py.offline.init_notebook_mode()
# py.offline.plot(fig, filename='basic-scatter.html')


If I then add a test country to the DataFrame with values for 1991-1997, I don't need to change the rest of the code and the bubble plot will update accordingly.

# I added a test country with data
df = pd.DataFrame({1991:[10,100,50,10], 1992:[14,30,22,20], 1993:[16,70,30,30], 1994:[18,85,65,40], 1995:[20,30,70,50], 1996:[42,42,66,60], 1997:[64,64,60,70]})
df.index = ['US','JAPAN','CN','TEST']
colors = ['rgb(150,204,90)','rgb(255, 130, 71)','rgb(255, 193, 37)','rgb(100, 100, 100)']
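
One caveat: colors[num] raises an IndexError as soon as there are more countries than colors. If you'd rather not extend the list by hand each time, one option is to reuse the palette cyclically. A minimal sketch, where assign_colors is a hypothetical helper name (not part of plotly) and the country list is made up:

```python
from itertools import cycle

def assign_colors(labels, palette):
    """Map each label to a color, starting over when the palette runs out."""
    # zip stops at the end of labels, so the endless cycle is safe here.
    return dict(zip(labels, cycle(palette)))

palette = ['rgb(150,204,90)', 'rgb(255, 130, 71)', 'rgb(255, 193, 37)']
countries = ['US', 'JAPAN', 'CN', 'TEST', 'INDIA']  # illustrative labels

color_by_country = assign_colors(countries, palette)
print(color_by_country['TEST'])  # the fourth label wraps around to the first color
```

In the list comprehension you would then write color=color_by_country[country] in place of colors[num], with countries taken from df.index.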


Upvotes: 2
