Faiza

mlr3: How to extract predicted survival time? (to compare the model predictions with the real data)

I want to obtain the response prediction (predicted survival times) from a survival model so that I can compare the model's predictions with the true data, as described in the mlr3 book (https://mlr3book.mlr-org.com/chapters/chapter13/beyond_regression_and_classification.html#learnersurv-predictionsurv-and-predict-types).

Below, I train an xgboost survival model (lrn("surv.xgboost.cox")), predict with it, and pass type = "regression" to get survival time predictions.

However, it seems that type is not a valid argument for LearnerSurvXgboostCox.

**How can I compare the predicted survival times from the model with the true survival times?**

```r
library(mlr3proba)
#> Loading required package: mlr3
library(mlr3extralearners)
library(mlr3pipelines)
library(mlr3verse)
task = as_task_surv(x = survival::veteran, time = 'time', event = 'status')
# one-hot encode the factor features, since xgboost needs numeric input
poe = po('encode')
task = poe$train(list(task))[[1]]
set.seed(42)
# 80/20 train/test split
part = partition(task, ratio = 0.8)
pred_xgb = lrn("surv.xgboost.cox", type = "regression")$
  train(task, part$train)$predict(task, part$test)
#> Error: Cannot set argument 'type' for 'LearnerSurvXgboostCox' (not a constructor argument, not a parameter, not a field. Did you mean 'normalize_type' / 'process_type' / 'sample_type'?
data.frame(pred = pred_xgb$response[1:3],
           truth = pred_xgb$truth[1:3])
#> Error in eval(expr, envir, enclos): object 'pred_xgb' not found
```

Created on 2024-05-24 with reprex v2.1.0

Answers (1)

John

The prediction types of surv.xgboost.cox are clearly documented on the learner's help page.
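A quick way to check this from R is to inspect the learner object itself. This is a minimal sketch assuming mlr3extralearners is installed; predict_types and param_set$ids() are standard mlr3/paradox accessors, and the exact list of types returned depends on your installed version.

```r
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

learner = lrn("surv.xgboost.cox")

# Prediction types this learner can produce directly
learner$predict_types

# Hyperparameter names; "type" is not among them, which is why
# lrn("surv.xgboost.cox", type = "regression") fails
"type" %in% learner$param_set$ids()
```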

To get a response prediction from this learner, you have to use the crankcompositor (see example here); note that you will need to change it to method = median to get actual survival times. I will soon add a new method that estimates the response as the restricted mean survival time (RMST), which would give more sensible survival times. I have created an issue here.
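To make the two summaries concrete: given a predicted survival curve S(t), the median survival time is the first time at which S(t) drops to 0.5 or below, and the RMST up to a horizon tau is the area under S(t) on [0, tau]. The base-R sketch below computes both from a plain matrix of survival probabilities (rows = observations, columns = sorted time points). median_surv_time() and rmst() are hypothetical helpers, not mlr3proba functions, and mlr3proba's own implementation may handle details (for example curves that never reach 0.5) differently. In recent mlr3proba versions such a matrix can often be taken from a prediction's distr data (e.g. pred$data$distr, with the time points as column names), but treat that accessor as an assumption and check your version.

```r
# Hypothetical helpers (not part of mlr3proba).
# surv:  n x k matrix of survival probabilities; row i is S_i(t) at `times`
# times: k sorted, non-negative time points

# Median survival time: first time point where S(t) <= 0.5.
# Returns NA when the curve never drops to 0.5 (median not reached).
median_surv_time = function(surv, times) {
  apply(surv, 1, function(s) {
    idx = which(s <= 0.5)
    if (length(idx) == 0) NA_real_ else times[idx[1]]
  })
}

# Restricted mean survival time: integral of S(t) from 0 to tau, treating
# S(t) as a step function that equals 1 before the first time point.
rmst = function(surv, times, tau = max(times)) {
  keep = times <= tau
  # interval widths: [0, t_1), [t_1, t_2), ..., [t_m, tau]
  widths = diff(c(0, times[keep], tau))
  apply(surv[, keep, drop = FALSE], 1, function(s) sum(widths * c(1, s)))
}

# Toy example: two observations, four time points
times = c(1, 2, 3, 4)
surv = rbind(c(0.90, 0.60, 0.40, 0.20),
             c(0.95, 0.90, 0.80, 0.70))
median_surv_time(surv, times)  # 3 NA
rmst(surv, times)              # 2.90 3.65
```

The second observation shows why the RMST can be the more sensible summary: its curve never reaches 0.5, so it has no median survival time, while its RMST is still well defined.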

Lastly, there are measures in mlr3proba that use the response (survival time) as input, such as msr("surv.mae") and msr("surv.rmse").
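Putting this together, the sketch below adapts the question's example: it wraps surv.xgboost.cox in the crankcompositor pipeline with method = "median", compares predicted and observed times, and scores the response with surv.mae and surv.rmse. It assumes an mlr3proba release in which ppl("crankcompositor") accepts the response, method and graph_learner arguments (as in the 0.5.x series); these arguments have changed between releases, so check the crankcompositor help page of your installed version.

```r
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)
library(mlr3pipelines)

task = as_task_surv(x = survival::veteran, time = "time", event = "status")
# one-hot encode factor features for xgboost
task = po("encode")$train(list(task))[[1]]
set.seed(42)
part = partition(task, ratio = 0.8)

# Assumption: ppl("crankcompositor") accepts `response`, `method` and
# `graph_learner` (mlr3proba 0.5.x); newer releases may differ.
learner = ppl("crankcompositor",
  learner = lrn("surv.xgboost.cox", predict_type = "distr"),
  method = "median", response = TRUE, graph_learner = TRUE)

learner$train(task, part$train)
pred = learner$predict(task, part$test)

# Predicted survival times next to the observed (possibly censored) times
data.frame(pred = pred$response[1:3], truth = pred$truth[1:3])

# Error measures that take the response (survival time) prediction as input
scores = pred$score(msrs(c("surv.mae", "surv.rmse")))
scores
```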

BR, John.