Reputation: 189
I would like to make a Kaplan-Meier plot with multiple groups. The code below draws the curves for two groups in one plot, but the number of groups is hard-coded to two. I would like to use a for loop that runs over a list of all the groups, where the list has no fixed length and may contain more than two groups. How can I achieve this with lifelines?
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
from lifelines.datasets import load_waltons
waltons = load_waltons()
ix = waltons['group'] == 'control'
ax = plt.subplot(111)
kmf_control = KaplanMeierFitter()
ax = kmf_control.fit(waltons.loc[ix]['T'], waltons.loc[ix]['E'],label='control').plot_survival_function(ax=ax)
kmf_exp = KaplanMeierFitter()
ax = kmf_exp.fit(waltons.loc[~ix]['T'], waltons.loc[~ix]['E'], label='exp').plot_survival_function(ax=ax)
from lifelines.plotting import add_at_risk_counts
add_at_risk_counts(kmf_exp, kmf_control, ax=ax)
plt.tight_layout()
Thank you in advance.
Upvotes: 0
Views: 341
Reputation: 45
Instead of giving each curve its own uniquely named variable, the key is to append each KM fitter to a list and access it through the for loop index.
Note that add_at_risk_counts takes the fitters unpacked as "*list_of_fits", an idea taken from this example, which uses "enumerate" to drive the iteration.
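To make the star-unpacking concrete, here is a minimal sketch using only plain Python. The function count_fits is a hypothetical stand-in for a variadic function such as add_at_risk_counts, and the strings are placeholders for fitted objects; neither comes from lifelines.

```python
def count_fits(*fits, ax=None):
    """Hypothetical stand-in for a variadic function such as add_at_risk_counts."""
    return len(fits)


list_of_fits = ["kmf_a", "kmf_b", "kmf_c"]  # placeholders for fitted objects

# Without the star, the whole list arrives as one positional argument.
without_star = count_fits(list_of_fits)

# With the star, each element becomes its own positional argument.
with_star = count_fits(*list_of_fits)

print(without_star, with_star)
```

This is why the list of fitters can grow to any length without changing the call.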
If you know the number of data sets in a list of dataframes:
from lifelines import KaplanMeierFitter
from lifelines.utils import median_survival_times
from lifelines.plotting import add_at_risk_counts
from lifelines.datasets import load_waltons
# import pandas as pd
import matplotlib.pyplot as plt
waltons = load_waltons()
ix = waltons['group'] == 'control'
SD = 2 # Number of data sets
CI = False # True or False to show confidence intervals
SC = True # True or False to show censor tick marks
kmf_ = []
data_ = []
data_.append(waltons.loc[ix])
data_.append(waltons.loc[~ix])
Set_ = ['control', 'exp']
ax = plt.subplot(111)
for i in range(0, SD): # index for loop
    # print("set : ", i)
    kmf_.append(KaplanMeierFitter())
    kmf_[i].fit_right_censoring(data_[i]['T'], data_[i]['E'],
                                label=Set_[i])
    ax = kmf_[i].plot_survival_function(ax=ax, ci_show=CI, show_censors=SC)
add_at_risk_counts(*kmf_, ax=ax, labels=[Set_[l] for l in range(0, SD)])
plt.tight_layout()
Or, if the number of data sets is unknown but the data has a column that identifies the group (a key/level), you can use groupby and increment an index on each iteration.
from lifelines import KaplanMeierFitter
from lifelines.utils import median_survival_times
from lifelines.plotting import add_at_risk_counts
from lifelines.datasets import load_waltons
# import pandas as pd
import matplotlib.pyplot as plt
waltons = load_waltons()
CI = False # True or False to show confidence intervals
SC = True # True or False to show censor tick marks
# print(waltons)
ax = plt.subplot(111)
kmf_ = []
i = 0 # initialise index
for name, grouped_df in waltons.groupby('group'):
    kmf_.append(KaplanMeierFitter())
    kmf_[i].fit(grouped_df["T"], grouped_df["E"], label=name)
    kmf_[i].plot_survival_function(ax=ax, ci_show=CI, show_censors=SC)
    print(kmf_[i])
    i = i + 1 # increment index
add_at_risk_counts(*kmf_,
                   ax=ax,
                   rows_to_show=["At risk"])
plt.tight_layout()
Upvotes: 0