jar

Is it possible to sort the columns of an Altair grouped bar chart based on the value of one of the categories?

I have the following chart:
[Image: the current grouped bar chart]
I'd like to sort the columns themselves (NOT the individual bars within a group; I already know how to do that), i.e. order the three sub-charts, if you will, by the value of whichever category (a, b or c) I choose.

I tried alt.SortField and alt.EncodingSortField. They move the columns around a bit, but when I change the target category to check, the column order doesn't follow it.

Code:

```python
import altair as alt
import pandas as pd

dummy = pd.DataFrame({
    'place': ['Asia', 'Antarctica', 'Africa', 'Antarctica', 'Asia', 'Africa', 'Africa', 'Antarctica', 'Asia'],
    'category': ['a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c'],
    'value': [5, 2, 3, 4, 3, 5, 6, 9, 5],
})
alt.Chart(dummy).mark_bar().encode(
    x=alt.X('category'),
    y='value',
    column=alt.Column('place:N', sort=alt.SortField(field='value', order='descending')),
    color='category',
)
```

I know that alt.Column('place:N', sort=alt.SortField(field='value', order='descending')) can't be right, since it doesn't target any category, so I also tried x=alt.X('category', sort=alt.SortField(field='c', order='descending')), but that doesn't work either.

Expected output (assuming descending order):
[Image: the expected chart, with the columns reordered]
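For reference, the expected column order can be computed directly from the data. This is a plain-Python sketch, not Altair code; ROWS and expected_column_order are illustrative names that are not part of the original question, and the sketch assumes each place has exactly one row per category (true for this data).

```python
# Same rows as the question's `dummy` DataFrame, as (place, category, value) tuples.
ROWS = [
    ("Asia", "a", 5), ("Antarctica", "a", 2), ("Africa", "a", 3),
    ("Antarctica", "b", 4), ("Asia", "b", 3), ("Africa", "b", 5),
    ("Africa", "c", 6), ("Antarctica", "c", 9), ("Asia", "c", 5),
]


def expected_column_order(category, descending=True):
    """Return the places ordered by the value of `category` within each place."""
    # One value per place for the chosen category.
    value_by_place = {place: value for place, cat, value in ROWS if cat == category}
    return sorted(value_by_place, key=value_by_place.get, reverse=descending)


print(expected_column_order("c"))  # ['Antarctica', 'Africa', 'Asia']
print(expected_column_order("a"))  # ['Asia', 'Africa', 'Antarctica']
```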


Answers (1)

jakevdp

This is a bit involved, but you can do this with a series of transforms:

  • a Calculate Transform that flags the rows of the category you want to sort on
  • a Join-Aggregate Transform with argmax that attaches the flagged row to every row of its group (here, each place)
  • another Calculate Transform to pull out the specific field of that row that you would like to sort by
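To see what each of these steps computes, here is a plain-Python emulation of the three transforms on the same data. It is only an illustration: Vega-Lite performs these transforms itself when the chart renders, and add_sort_val is a hypothetical helper name.

```python
rows = [
    {"place": p, "category": c, "value": v}
    for p, c, v in [
        ("Asia", "a", 5), ("Antarctica", "a", 2), ("Africa", "a", 3),
        ("Antarctica", "b", 4), ("Asia", "b", 3), ("Africa", "b", 5),
        ("Africa", "c", 6), ("Antarctica", "c", 9), ("Asia", "c", 5),
    ]
]


def add_sort_val(rows, target):
    # Step 1 (calculate): flag the rows that belong to the target category.
    flagged = [dict(r, key=(r["category"] == target)) for r in rows]

    # Step 2 (joinaggregate with argmax, grouped by place): for each place, find
    # the row whose `key` is largest (True > False) and attach that whole row
    # to every row of the same place as `sort_key`.
    best = {}
    for r in flagged:
        current = best.get(r["place"])
        if current is None or r["key"] > current["key"]:
            best[r["place"]] = r
    joined = [dict(r, sort_key=best[r["place"]]) for r in flagged]

    # Step 3 (calculate): pull the `value` field out of the attached row.
    return [dict(r, sort_val=r["sort_key"]["value"]) for r in joined]


result = add_sort_val(rows, "c")
print({r["place"]: r["sort_val"] for r in result})
# {'Asia': 5, 'Antarctica': 9, 'Africa': 6}
```

Every row of a place ends up with the same sort_val, which is why sorting the column facet on it orders whole sub-charts rather than individual bars.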

It looks like this, first sorting by "c":

```python
import altair as alt
import pandas as pd

dummy = pd.DataFrame({
    'place': ['Asia', 'Antarctica', 'Africa', 'Antarctica', 'Asia', 'Africa', 'Africa', 'Antarctica', 'Asia'],
    'category': ['a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c'],
    'value': [5, 2, 3, 4, 3, 5, 6, 9, 5],
})
alt.Chart(dummy).transform_calculate(
    # True for the rows of the category to sort by ('c' here)
    key="datum.category == 'c'"
).transform_joinaggregate(
    # for each place, attach the whole row where key is largest (the 'c' row)
    sort_key="argmax(key)", groupby=['place']
).transform_calculate(
    # extract that row's value to use as the sort field
    sort_val='datum.sort_key.value'
).mark_bar().encode(
    x=alt.X('category'),
    y='value',
    column=alt.Column('place:N', sort=alt.SortField("sort_val", order="descending")),
    color='category',
)
```

[Image: resulting chart, columns sorted by the value of category "c" in descending order]

Then sorting by "a":

```python
alt.Chart(dummy).transform_calculate(
    key="datum.category == 'a'"  # only change: target category 'a'
).transform_joinaggregate(
    sort_key="argmax(key)", groupby=['place']
).transform_calculate(
    sort_val='datum.sort_key.value'
).mark_bar().encode(
    x=alt.X('category'),
    y='value',
    column=alt.Column('place:N', sort=alt.SortField("sort_val", order="descending")),
    color='category',
)
```

[Image: resulting chart, columns sorted by the value of category "a" in descending order]
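Since the two charts differ only in the target category, one option is to parametrize it. In Altair that would be a function that interpolates the category into the first transform_calculate; the sketch below instead builds the equivalent raw Vega-Lite JSON with only the standard library, assuming Vega-Lite v5. sorted_facet_spec is a hypothetical helper, and the dict mirrors the transforms above rather than reproducing Altair's exact output.

```python
import json


def sorted_facet_spec(rows, category, order="descending"):
    """Build a Vega-Lite spec whose columns are sorted by `category`'s value."""
    return {
        "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
        "data": {"values": rows},
        "transform": [
            # json.dumps quotes the category as a string literal in the expression.
            {"calculate": f"datum.category == {json.dumps(category)}", "as": "key"},
            {
                "joinaggregate": [{"op": "argmax", "field": "key", "as": "sort_key"}],
                "groupby": ["place"],
            },
            {"calculate": "datum.sort_key.value", "as": "sort_val"},
        ],
        "mark": "bar",
        "encoding": {
            "x": {"field": "category", "type": "nominal"},
            "y": {"field": "value", "type": "quantitative"},
            "color": {"field": "category", "type": "nominal"},
            "column": {
                "field": "place",
                "type": "nominal",
                "sort": {"field": "sort_val", "order": order},
            },
        },
    }


spec = sorted_facet_spec([{"place": "Asia", "category": "c", "value": 5}], "c")
print(json.dumps(spec, indent=2))
```

The printed JSON can be pasted into the online Vega Editor to check the ordering; passing order="ascending" flips the column order.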
