Nicolas Gervais

Reputation: 36624

How do I make a simple, multi-level Sankey diagram with Plotly?

I have a DataFrame like this that I'm trying to describe with a Sankey diagram:

import pandas as pd

df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})
    animal  sex     status          count
0   dog     male    wild            8
1   cat     female  domesticated    10
2   cat     female  domesticated    11
3   dog     male    wild            14
4   cat     male    domesticated    6

I'm trying to follow the steps in the documentation, but I can't get it to work: I don't understand which branch goes where. Here's the example code:

import plotly.graph_objects as go

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = ["A1", "A2", "B1", "B2", "C1", "C2"],
      color = "blue"
    ),
    link = dict(
      source = [0, 1, 0, 2, 3, 3], 
      target = [2, 3, 3, 4, 4, 5],
      value = [8, 4, 2, 8, 4, 2]
  ))])

fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()

Here's what I'm trying to achieve: (image of the desired diagram not preserved)

Upvotes: 4

Views: 17200

Answers (2)

Conchylicultor

Reputation: 5709

I find parallel categories diagrams (either px.parallel_categories or go.Parcats) easier to work with than go.Sankey, and the results are very similar.

For your example, this would be:

import pandas as pd
import plotly.graph_objects as go


df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})

fig = go.Figure(go.Parcats(
    dimensions=[
        {'label': 'animal', 'values': df['animal']},
        {'label': 'sex', 'values': df['sex']},
        {'label': 'status', 'values': df['status']},
    ],
    counts=df['count'],
))
fig.show()


Or, if your df contains individual rows (before aggregating them into count), this can be as simple as:

import plotly.express as px

fig = px.parallel_categories(df, dimensions=['animal', 'sex', 'status'])
fig.show()

Upvotes: 1

Pascalco

Reputation: 2826

You can create a Sankey diagram with Plotly as follows:

import pandas as pd
import plotly.graph_objects as go

label_list = ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']
# cat: 0, dog: 1, domesticated: 2, female: 3, male: 4, wild: 5
source = [0, 0, 1, 3, 4, 4]
target = [3, 4, 4, 2, 2, 5]
count = [21, 6, 22, 21, 6, 22]

fig = go.Figure(data=[go.Sankey(
    node = {"label": label_list},
    link = {"source": source, "target": target, "value": count}
    )])
fig.show()

How it works: the lists source, target and count all have length 6, and the Sankey diagram has 6 arrows. The elements of source and target are indexes into label_list. So the first element of source is 0, which means "cat", and the first element of target is 3, which means "female". The first element of count is 21. Therefore, the first arrow of the diagram goes from cat to female and has size 21. Correspondingly, the second elements of source, target and count define the second arrow, and so on.
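To check which arrow is which, you can decode the three lists back into labels. A small sketch in plain Python reusing the lists from above (arrows is an illustrative name):

```python
label_list = ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']
source = [0, 0, 1, 3, 4, 4]
target = [3, 4, 4, 2, 2, 5]
count = [21, 6, 22, 21, 6, 22]

# The i-th arrow goes from label_list[source[i]] to label_list[target[i]]
# and its width is count[i]
arrows = [(label_list[s], label_list[t], c) for s, t, c in zip(source, target, count)]

for s, t, c in arrows:
    print(f"{s} -> {t}: {c}")
# cat -> female: 21
# cat -> male: 6
# dog -> male: 22
# female -> domesticated: 21
# male -> domesticated: 6
# male -> wild: 22
```

Note that what flows into each middle node equals what flows out of it: female receives 21 and sends 21, male receives 6 + 22 = 28 and sends 6 + 22 = 28.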


Possibly you want to create a bigger Sankey diagram as in this example. Defining the source, target and count lists manually then becomes very tedious, so here's code that creates these lists from a dataframe in your format.

import pandas as pd
import numpy as np

df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})

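# Levels of the diagram, from left to right
categories = ['animal', 'sex', 'status']

# Stack one (source, target, count) table per pair of adjacent levels
newDf = pd.DataFrame()
for i in range(len(categories)-1):
    tempDf = df[[categories[i], categories[i+1], 'count']].copy()
    tempDf.columns = ['source', 'target', 'count']
    newDf = pd.concat([newDf, tempDf])
# Sum the counts of identical links (e.g. the two cat -> female rows)
newDf = newDf.groupby(['source', 'target']).agg({'count': 'sum'}).reset_index()

# Node labels, and each link's source/target as an index into label_list
label_list = list(np.unique(df[categories].values))
source = newDf['source'].apply(lambda x: label_list.index(x))
target = newDf['target'].apply(lambda x: label_list.index(x))
count = newDf['count']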

Upvotes: 7
