I have a DataFrame like this that I'm trying to describe with a Sankey diagram:
```python
import pandas as pd

df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})
```
```
  animal     sex        status  count
0    dog    male          wild      8
1    cat  female  domesticated     10
2    cat  female  domesticated     11
3    dog    male          wild     14
4    cat    male  domesticated      6
```
I'm trying to follow the steps in the Plotly documentation, but I can't make it work: I don't understand which node branches into which. Here's the example code:
```python
import plotly.graph_objects as go

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = ["A1", "A2", "B1", "B2", "C1", "C2"],
      color = "blue"
    ),
    link = dict(
      source = [0, 1, 0, 2, 3, 3],
      target = [2, 3, 3, 4, 4, 5],
      value = [8, 4, 2, 8, 4, 2]
  ))])

fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()
```
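To see what branches where in the documentation example, it can help to decode the index lists back into labels. The sketch below is plain Python that reuses only the `label`, `source`, `target` and `value` lists from the example; `describe_links` is a hypothetical helper name introduced here for illustration.

```python
# Lists copied from the basic Sankey example.
labels = ["A1", "A2", "B1", "B2", "C1", "C2"]
source = [0, 1, 0, 2, 3, 3]
target = [2, 3, 3, 4, 4, 5]
value = [8, 4, 2, 8, 4, 2]


def describe_links(labels, source, target, value):
    """Return one 'source -> target: value' string per link (hypothetical helper)."""
    # The i-th link is defined by the i-th element of each of the three lists;
    # source and target hold positions in the labels list.
    return [
        f"{labels[s]} -> {labels[t]}: {v}"
        for s, t, v in zip(source, target, value)
    ]


for line in describe_links(labels, source, target, value):
    print(line)
# A1 -> B1: 8
# A2 -> B2: 4
# A1 -> B2: 2
# B1 -> C1: 8
# B2 -> C1: 4
# B2 -> C2: 2
```

Read this way, A1 and A2 feed the B nodes, which in turn feed the C nodes. In this example there is no separate declaration of columns; Plotly lays the nodes out from the links.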
Here's what I'm trying to achieve:
I find parallel-categories diagrams (either `px.parallel_categories` or `go.Parcats`) easier to work with than `go.Sankey`, and the results are very similar.
For this example, that would be:
```python
import pandas as pd
import plotly.graph_objects as go

df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})

fig = go.Figure(go.Parcats(
    dimensions=[
        {'label': 'animal', 'values': df['animal']},
        {'label': 'sex', 'values': df['sex']},
        {'label': 'status', 'values': df['status']},
    ],
    counts=df['count'],
))

fig.show()
```
Or, if your `df` contains individual rows (before aggregation into `count`), this can be as short as:
```python
import plotly.express as px

px.parallel_categories(df, dimensions=['animal', 'sex', 'status'])
```
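The `px.parallel_categories` one-liner expects one row per individual. If you only have the aggregated frame, a minimal sketch (assuming the `count` column holds non-negative integers) is to repeat each row `count` times with pandas' `Index.repeat`; the name `expanded` is just an illustrative choice.

```python
import pandas as pd

df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})

# Repeat each row 'count' times so every row stands for one individual,
# then drop the now-redundant count column and renumber the index.
expanded = (
    df.loc[df.index.repeat(df['count'])]
      .drop(columns='count')
      .reset_index(drop=True)
)

print(len(expanded))  # 49 rows, one per animal
```

The result can then be passed as `px.parallel_categories(expanded, dimensions=['animal', 'sex', 'status'])`. For large counts this multiplies memory use, so passing `counts=` to `go.Parcats` is the lighter option.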
You can create a Sankey diagram with Plotly as follows:
```python
import pandas as pd
import plotly.graph_objects as go

label_list = ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']
# cat: 0, dog: 1, domesticated: 2, female: 3, male: 4, wild: 5
source = [0, 0, 1, 3, 4, 4]
target = [3, 4, 4, 2, 2, 5]
count = [21, 6, 22, 21, 6, 22]

fig = go.Figure(data=[go.Sankey(
    node = {"label": label_list},
    link = {"source": source, "target": target, "value": count}
)])
fig.show()
```
How it works: the lists `source`, `target` and `count` all have length 6, and the Sankey diagram has 6 arrows. The elements of `source` and `target` are indexes into `label_list`. So the first element of `source` is 0, which means "cat"; the first element of `target` is 3, which means "female"; and the first element of `count` is 21. Therefore the first arrow of the diagram goes from cat to female and has size 21. Likewise, the second elements of `source`, `target` and `count` define the second arrow, and so on.
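Going the other way, from readable label pairs to index lists, avoids counting positions by hand. A minimal sketch, assuming each label is unique across columns; the `links` list simply restates the six arrows from the code above, and `index_of` is a name introduced here for illustration.

```python
label_list = ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']

# The six arrows written as (source label, target label, count).
links = [
    ('cat', 'female', 21),
    ('cat', 'male', 6),
    ('dog', 'male', 22),
    ('female', 'domesticated', 21),
    ('male', 'domesticated', 6),
    ('male', 'wild', 22),
]

# Map each label to its position in label_list (cat: 0, dog: 1, ...).
index_of = {label: i for i, label in enumerate(label_list)}

source = [index_of[s] for s, _, _ in links]
target = [index_of[t] for _, t, _ in links]
count = [c for _, _, c in links]

print(source)  # [0, 0, 1, 3, 4, 4]
print(target)  # [3, 4, 4, 2, 2, 5]
```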
You may want to create a bigger Sankey diagram, as in this example. Defining the `source`, `target` and `count` lists manually then becomes very tedious, so here's code that creates these lists from a DataFrame in your format:
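```python
import numpy as np
import pandas as pd
import plotly.graph_objects as go

df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})

categories = ['animal', 'sex', 'status']

# Turn each pair of adjacent columns into (source, target, count) rows
frames = []
for i in range(len(categories) - 1):
    tempDf = df[[categories[i], categories[i + 1], 'count']].copy()
    tempDf.columns = ['source', 'target', 'count']
    frames.append(tempDf)
newDf = pd.concat(frames)

# Sum the counts of identical source -> target links
newDf = newDf.groupby(['source', 'target']).agg({'count': 'sum'}).reset_index()

# Every distinct category value becomes one node; its position is its index
label_list = list(np.unique(df[categories].values))
source = newDf['source'].apply(lambda x: label_list.index(x))
target = newDf['target'].apply(lambda x: label_list.index(x))
count = newDf['count']

fig = go.Figure(data=[go.Sankey(
    node={"label": label_list},
    link={"source": source, "target": target, "value": count}
)])
fig.show()
```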
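The same steps can be wrapped in a reusable function. This is a sketch under the same assumption as the code above (each category value is unique across columns, otherwise nodes with the same name are merged); `sankey_lists` and its `value_col` parameter are hypothetical names, and the dict lookup replaces `label_list.index`, which rescans the list for every row.

```python
import pandas as pd


def sankey_lists(df, categories, value_col='count'):
    """Build (labels, source, target, value) lists for go.Sankey.

    Hypothetical helper; assumes every category value is unique across columns.
    """
    # One (source, target, value) frame per pair of adjacent columns.
    pairs = [
        df[[a, b, value_col]].set_axis(['source', 'target', 'value'], axis=1)
        for a, b in zip(categories[:-1], categories[1:])
    ]
    # Sum the values of identical source -> target links.
    links = (
        pd.concat(pairs, ignore_index=True)
          .groupby(['source', 'target'], as_index=False)['value']
          .sum()
    )
    # Sorted unique labels; the dict gives constant-time label -> index lookups.
    labels = sorted(pd.unique(df[categories].to_numpy().ravel()))
    index_of = {label: i for i, label in enumerate(labels)}
    source = links['source'].map(index_of).tolist()
    target = links['target'].map(index_of).tolist()
    value = links['value'].tolist()
    return labels, source, target, value


df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})

labels, source, target, value = sankey_lists(df, ['animal', 'sex', 'status'])
print(labels)  # ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']
print(source)  # [0, 0, 1, 3, 4, 4]
print(target)  # [3, 4, 4, 2, 2, 5]
print(value)   # [21, 6, 22, 21, 6, 22]
```

The four returned lists match the manually written `label_list`, `source`, `target` and `count` in the first code block of this answer, so they can be passed to `go.Sankey` the same way.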