
Deep Neural Networks for Survival Analysis Using `torch`
`survdnn` implements neural network-based models for right-censored survival analysis using the native `torch` backend in R. It supports multiple loss functions, including the Cox partial likelihood, L2-penalized Cox, and Accelerated Failure Time (AFT) objectives, as well as time-dependent extensions such as Cox-Time. The package provides a formula interface and supports model evaluation with time-dependent metrics (e.g., C-index, Brier score, IBS), cross-validation, and hyperparameter tuning.
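For reference, the `"cox"` objective is usually written as the negative log partial likelihood of the network's risk score $f_\theta(x_i)$, where $t_i$ is the observed time, $\delta_i$ the event indicator, and $R(t_i)$ the risk set at $t_i$. This is a sketch of the standard form, assuming Breslow-style handling of tied times and averaging over the number of events $N_E$; `survdnn`'s exact normalization and tie handling may differ. The `"cox_l2"` line shows a generic L2 penalty on the network weights with a hypothetical weight $\lambda$.

$$
\mathcal{L}_{\text{cox}}(\theta) = -\frac{1}{N_E} \sum_{i:\,\delta_i = 1} \left[ f_\theta(x_i) - \log \sum_{j \in R(t_i)} \exp\big(f_\theta(x_j)\big) \right],
\qquad R(t_i) = \{\, j : t_j \ge t_i \,\}
$$

$$
\mathcal{L}_{\text{cox\_l2}}(\theta) = \mathcal{L}_{\text{cox}}(\theta) + \lambda \lVert \theta \rVert_2^2
$$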
- Formula interface for `Surv() ~ .` models
- Loss functions:
  - `"cox"`: Cox partial likelihood
  - `"cox_l2"`: penalized Cox
  - `"aft"`: Accelerated Failure Time
  - `"coxtime"`: deep time-dependent Cox (Cox-Time, a time-dependent extension of DeepSurv-style models)
- Cross-validation and hyperparameter tuning with `cv_survdnn()` and `tune_survdnn()`
- Survival curve prediction and plotting with `predict()` and `plot()`

```r
# Install from CRAN
install.packages("survdnn")

# Install from GitHub
install.packages("remotes")
remotes::install_github("ielbadisy/survdnn")
```

```bash
# Or clone and install locally
git clone https://github.com/ielbadisy/survdnn.git
```

```r
setwd("survdnn")
devtools::install()
```

```r
library(survdnn)
library(survival, quietly = TRUE)
library(ggplot2)

veteran <- survival::veteran

mod <- survdnn(
  Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  hidden = c(32, 16),
  epochs = 100,
  loss = "cox",
  verbose = TRUE
)
## Epoch 50 - Loss: 3.898330
## Epoch 100 - Loss: 3.834461

summary(mod)
## 
## ── Summary of survdnn model ─────────────────────────────────────────────────────────────────────
## 
## Formula:
## Surv(time, status) ~ age + karno + celltype
## <environment: 0x57f5687daa00>
## 
## Model architecture:
## Hidden layers: 32 : 16
## Activation: relu
## Dropout: 0.3
## Final loss: 3.834461
## 
## Training summary:
## Epochs: 100
## Learning rate: 1e-04
## Loss function: cox
## 
## Data summary:
## Observations: 137
## Predictors: age, karno, celltypesmallcell, celltypeadeno, celltypelarge
## Time range: [ 1, 999 ]
## Event rate: 93.4%

plot(mod, group_by = "celltype", times = 1:300)
```

```r
# Cox partial likelihood
mod1 <- survdnn(
  Surv(time, status) ~ age + karno,
  data = veteran,
  loss = "cox",
  epochs = 100
)
## Epoch 50 - Loss: 3.991873
## Epoch 100 - Loss: 3.937163

# Accelerated Failure Time
mod2 <- survdnn(
  Surv(time, status) ~ age + karno,
  data = veteran,
  loss = "aft",
  epochs = 100
)
## Epoch 50 - Loss: 18.660992
## Epoch 100 - Loss: 18.260056

# Deep time-dependent Cox (Coxtime)
mod3 <- survdnn(
  Surv(time, status) ~ age + karno,
  data = veteran,
  loss = "coxtime",
  epochs = 100
)
## Epoch 50 - Loss: 4.899240
## Epoch 100 - Loss: 4.835490
```

```r
# Cross-validation
cv_results <- cv_survdnn(
  Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  times = c(30, 90, 180),
  metrics = c("cindex", "ibs"),
  folds = 3,
  hidden = c(16, 8),
  loss = "cox",
  epochs = 100
)
print(cv_results)
```

```r
# Hyperparameter tuning
grid <- list(
  hidden = list(c(16), c(32, 16)),
  lr = c(1e-3),
  activation = c("relu"),
  epochs = c(100, 300),
  loss = c("cox", "aft", "coxtime")
)

tune_res <- tune_survdnn(
  formula = Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  times = c(90, 300),
  metrics = "cindex",
  param_grid = grid,
  folds = 3,
  refit = FALSE,
  return = "summary"
)
print(tune_res)
```

```r
plot(mod1, group_by = "celltype", times = 1:300)
```
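Both `cv_survdnn()` and `tune_survdnn()` above score models with `metrics = "cindex"`. As a rough guide to what that metric measures, here is a base-R sketch of Harrell's concordance index: among comparable pairs (the earlier time is an observed event), it is the fraction in which the subject who failed first received the higher predicted risk. `harrell_cindex()` is a hypothetical helper, not part of `survdnn`, and the package's own implementation may differ (for example, in tie handling or in time-dependent weighting at the supplied `times`).

```r
# Hypothetical helper: Harrell's C-index for right-censored data.
# time:   observed follow-up times
# status: event indicator (1 = event, 0 = censored)
# risk:   predicted risk scores (higher = earlier expected event)
harrell_cindex <- function(time, status, risk) {
  concordant <- 0
  comparable <- 0
  n <- length(time)
  for (i in seq_len(n)) {
    if (status[i] != 1) next            # only observed events anchor a pair
    for (j in seq_len(n)) {
      if (time[i] < time[j]) {          # i failed while j was still at risk
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1  # correctly ordered pair
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5  # tied predictions count as half
        }
      }
    }
  }
  concordant / comparable
}

# Toy data (made up for illustration)
time   <- c(5, 8, 12, 20)
status <- c(1, 1, 0, 1)
risk   <- c(0.9, 0.7, 0.4, 0.1)

harrell_cindex(time, status, risk)  # returns 1: every comparable pair is ordered correctly
```

A value of 0.5 corresponds to random ranking and 1 to perfect ranking; reversing the sign of `risk` in the toy data gives 0.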
```r
plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
```
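`plot_mean_only = TRUE` presumably reduces each `celltype` group to a single average curve. The base-R sketch below illustrates that averaging step on a small, made-up matrix of survival probabilities (rows = subjects, columns = evaluation times). The matrix, the group labels, and the assumption that the mean is taken pointwise over subjects are illustrative only and are not taken from `survdnn`'s internals.

```r
# Made-up predicted survival probabilities:
# one row per subject, one column per evaluation time
times <- c(30, 90, 180)
surv_mat <- rbind(
  c(0.90, 0.70, 0.50),
  c(0.80, 0.60, 0.40),
  c(0.95, 0.85, 0.75),
  c(0.85, 0.65, 0.55)
)
group <- c("squamous", "squamous", "adeno", "adeno")

# Pointwise mean over the subjects in each group (rows = groups)
mean_curves <- t(sapply(
  split(seq_len(nrow(surv_mat)), group),
  function(idx) colMeans(surv_mat[idx, , drop = FALSE])
))

# Step plot with one mean curve per group
matplot(times, t(mean_curves), type = "s", lty = 1, ylim = c(0, 1),
        xlab = "Time", ylab = "Mean survival probability")
legend("topright", legend = rownames(mean_curves),
       col = seq_len(nrow(mean_curves)), lty = 1)
```

Averaging the curves, rather than plotting every subject, keeps the group comparison readable when groups contain many observations.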
```r
help(package = "survdnn")
?survdnn
?tune_survdnn
?cv_survdnn
?plot.survdnn
```

```r
# Run all tests
devtools::test()
```

By default, Torch initializes model weights and shuffles minibatches with random draws, so results may differ from run to run.
Unlike `set.seed()`, which only controls R’s RNG, `{torch}` uses its own RNG implemented in C++/CUDA. To make results reproducible, set the Torch seed before training:
```r
torch::torch_manual_seed(123)
```

The `survdnn` R package is available on CRAN or at: https://github.com/ielbadisy/survdnn
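To check the Torch seeding behaviour described in the reproducibility note, the sketch below uses only `{torch}` (assumed installed, since `survdnn` builds on it) and does not call `survdnn` itself: re-seeding with `torch_manual_seed()` before each draw reproduces the same random tensor.

```r
library(torch)

# Seed Torch's own generator, then draw a random tensor
torch_manual_seed(123)
a <- torch_randn(5)

# Re-seed with the same value and draw again
torch_manual_seed(123)
b <- torch_randn(5)

# Convert to plain R arrays and compare: the draws are identical
identical(as_array(a), as_array(b))
```

The same principle applies to training: calling `torch_manual_seed()` with a fixed value before each `survdnn()` fit makes the weight initialization and minibatch shuffling repeatable on the same hardware and software versions.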
Contributions, issues, and feature requests are welcome. Open an issue or submit a pull request!
MIT © Imad El Badisy