survdnn

Deep Neural Networks for Survival Analysis Using torch

License: MIT


survdnn implements neural network-based models for right-censored survival analysis using the native torch backend in R. It supports multiple loss functions, including the Cox partial likelihood, an L2-penalized Cox loss, and Accelerated Failure Time (AFT) objectives, as well as time-dependent extensions such as Cox-Time. The package provides a formula interface, model evaluation with time-dependent metrics (e.g., C-index, Brier score, integrated Brier score (IBS)), cross-validation, and hyperparameter tuning.


Features


Installation

```r
# Install from CRAN
install.packages("survdnn")

# Install from GitHub
install.packages("remotes")
remotes::install_github("ielbadisy/survdnn")
```

Or clone the repository and install it locally:

```sh
git clone https://github.com/ielbadisy/survdnn.git
```

```r
setwd("survdnn")
devtools::install()
```

Quick Example

```r
library(survdnn)
library(survival, quietly = TRUE)
library(ggplot2)

veteran <- survival::veteran

mod <- survdnn(
  Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  hidden = c(32, 16),
  epochs = 100,
  loss = "cox",
  verbose = TRUE
)
## Epoch 50 - Loss: 3.898330
## Epoch 100 - Loss: 3.834461
```
```r
summary(mod)
##
## ── Summary of survdnn model ─────────────────────────────────────────────────────────────────────
##
## Formula:
##   Surv(time, status) ~ age + karno + celltype
## <environment: 0x57f5687daa00>
##
## Model architecture:
##   Hidden layers:  32 : 16
##   Activation:  relu
##   Dropout:  0.3
##   Final loss:  3.834461
##
## Training summary:
##   Epochs:  100
##   Learning rate:  1e-04
##   Loss function:  cox
##
## Data summary:
##   Observations:  137
##   Predictors:  age, karno, celltypesmallcell, celltypeadeno, celltypelarge
##   Time range: [ 1, 999 ]
##   Event rate:  93.4%
```
```r
plot(mod, group_by = "celltype", times = 1:300)
```

Loss Functions

```r
# Cox partial likelihood
mod1 <- survdnn(
  Surv(time, status) ~ age + karno,
  data = veteran,
  loss = "cox",
  epochs = 100
)
## Epoch 50 - Loss: 3.991873
## Epoch 100 - Loss: 3.937163

# Accelerated Failure Time
mod2 <- survdnn(
  Surv(time, status) ~ age + karno,
  data = veteran,
  loss = "aft",
  epochs = 100
)
## Epoch 50 - Loss: 18.660992
## Epoch 100 - Loss: 18.260056

# Deep time-dependent Cox (Coxtime)
mod3 <- survdnn(
  Surv(time, status) ~ age + karno,
  data = veteran,
  loss = "coxtime",
  epochs = 100
)
## Epoch 50 - Loss: 4.899240
## Epoch 100 - Loss: 4.835490
```
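The base-R sketch below makes concrete what the `loss = "cox"` objective measures. It computes the negative Cox partial log-likelihood for a vector of risk scores, with one score per subject, as a network would output. This is an illustration under assumptions, not survdnn's torch code. The helper name `neg_cox_partial_loglik` is hypothetical, ties follow the Breslow convention, and the result is averaged over events. The package may normalize the loss or handle ties differently.

```r
# Hypothetical helper (not part of survdnn): negative Cox partial
# log-likelihood with Breslow ties, averaged over the number of events.
#   eta    : risk scores (higher = higher hazard), one per subject
#   time   : observed follow-up times
#   status : 1 = event, 0 = right-censored
neg_cox_partial_loglik <- function(eta, time, status) {
  stopifnot(length(eta) == length(time), length(time) == length(status))
  events <- which(status == 1)
  if (length(events) == 0) stop("at least one event is required")
  total <- 0
  for (i in events) {
    risk_set <- eta[time >= time[i]]              # subjects still at risk at time[i]
    m <- max(risk_set)                            # log-sum-exp for numerical stability
    log_denom <- m + log(sum(exp(risk_set - m)))
    total <- total + (eta[i] - log_denom)
  }
  -total / length(events)
}

# With all scores equal, each event contributes -log(size of its risk set),
# so this equals (log(3) + log(2) + log(1)) / 3 = log(6) / 3.
neg_cox_partial_loglik(c(0, 0, 0), time = c(1, 2, 3), status = c(1, 1, 1))
```

The loss depends on the scores only through their differences. Adding a constant to every score leaves it unchanged, so an intercept is not identifiable under this objective. For a linear predictor, multiplying the result by the number of events recovers the negated Breslow log partial likelihood that `survival::coxph(..., ties = "breslow")` reports.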

Cross-Validation

```r
cv_results <- cv_survdnn(
  Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  times = c(30, 90, 180),
  metrics = c("cindex", "ibs"),
  folds = 3,
  hidden = c(16, 8),
  loss = "cox",
  epochs = 100
)
print(cv_results)
```
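The `metrics` argument above requests the C-index and the integrated Brier score. The base-R sketch below shows one standard way to compute each from predictions on a held-out fold:

- Harrell's C-index on risk scores.
- The Brier score at a time t, with inverse-probability-of-censoring weights (IPCW) taken from a Kaplan-Meier estimate of the censoring distribution.
- An IBS, obtained by integrating Brier scores over a time grid with the trapezoidal rule.

All helper names (`harrell_cindex`, `km_censoring`, `brier_ipcw`, `ibs_trapezoid`) are hypothetical. `cv_survdnn()` may use different estimators or conventions, for example a time-dependent C-index evaluated at each of `times`.

```r
# Hypothetical helper: Harrell's C-index. A pair (i, j) is comparable when
# subject i has an observed event and time[i] < time[j]; it is concordant when
# risk[i] > risk[j]. Ties in risk count 1/2.
harrell_cindex <- function(risk, time, status) {
  concordant <- 0
  comparable <- 0
  for (i in which(status == 1)) {
    j <- which(time > time[i])
    comparable <- comparable + length(j)
    concordant <- concordant + sum(risk[i] > risk[j]) + 0.5 * sum(risk[i] == risk[j])
  }
  concordant / comparable
}

# Hypothetical helper: Kaplan-Meier estimate of the censoring survival
# G(u) = P(C > u), obtained by treating censoring (status == 0) as the "event".
km_censoring <- function(time, status) {
  cens_times <- sort(unique(time[status == 0]))
  g <- numeric(length(cens_times))
  current <- 1
  for (k in seq_along(cens_times)) {
    n_risk <- sum(time >= cens_times[k])
    n_cens <- sum(time == cens_times[k] & status == 0)
    current <- current * (1 - n_cens / n_risk)
    g[k] <- current
  }
  # Returns G(u), or the left limit G(u-) when left = TRUE.
  function(u, left = FALSE) {
    vapply(u, function(x) {
      k <- if (left) sum(cens_times < x) else sum(cens_times <= x)
      if (k == 0) 1 else g[k]
    }, numeric(1))
  }
}

# Hypothetical helper: IPCW Brier score at a single time t.
#   surv_t : predicted survival probabilities S(t | x_i), one per subject
brier_ipcw <- function(surv_t, t, time, status) {
  G <- km_censoring(time, status)
  event_by_t <- time <= t & status == 1   # failed by t: target is 0
  alive_at_t <- time > t                  # still event-free after t: target is 1
  w <- numeric(length(time))              # censored before t: weight 0
  w[event_by_t] <- 1 / G(time[event_by_t], left = TRUE)
  w[alive_at_t] <- 1 / G(t)
  mean(w * ifelse(alive_at_t, (1 - surv_t)^2, surv_t^2))
}

# Hypothetical helper: integrated Brier score over a time grid
# (trapezoidal rule), normalized by the length of the grid.
ibs_trapezoid <- function(bs, times) {
  stopifnot(length(bs) == length(times), length(times) >= 2)
  area <- sum(diff(times) * (head(bs, -1) + tail(bs, -1)) / 2)
  area / (max(times) - min(times))
}

# Toy usage without censoring, where G = 1 and the Brier score is a plain mean.
time <- c(1, 2, 3, 4)
status <- c(1, 1, 1, 1)
harrell_cindex(risk = c(4, 3, 2, 1), time, status)        # 1
brier_ipcw(c(0.2, 0.3, 0.6, 0.9), t = 2.5, time, status)  # 0.075
```

If a model outputs survival probabilities rather than risk scores, `1 - S(t | x)` at a fixed t is a common choice of risk score for `harrell_cindex`.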

Hyperparameter Tuning

```r
grid <- list(
  hidden     = list(c(16), c(32, 16)),
  lr         = c(1e-3),
  activation = c("relu"),
  epochs     = c(100, 300),
  loss       = c("cox", "aft", "coxtime")
)

tune_res <- tune_survdnn(
  formula = Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  times = c(90, 300),
  metrics = "cindex",
  param_grid = grid,
  folds = 3,
  refit = FALSE,
  return = "summary"
)
print(tune_res)
```
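This sketch assumes `tune_survdnn()` runs a full grid search, evaluating every combination of the values in `param_grid`; check `?tune_survdnn` to confirm. Under that assumption, the grid above defines 2 × 1 × 1 × 2 × 3 = 12 configurations, each cross-validated over 3 folds, so on the order of 36 model fits. The base-R sketch below enumerates those combinations. It expands over value indices rather than values, so list-valued entries such as `hidden` keep their vector shape. The object names `idx` and `configs` are illustrative only.

```r
grid <- list(
  hidden     = list(c(16), c(32, 16)),
  lr         = c(1e-3),
  activation = c("relu"),
  epochs     = c(100, 300),
  loss       = c("cox", "aft", "coxtime")
)

# One row per combination of value indices (the first parameter varies fastest).
idx <- expand.grid(lapply(grid, seq_along), KEEP.OUT.ATTRS = FALSE)

# Turn each row of indices into a named list of parameter values.
configs <- lapply(seq_len(nrow(idx)), function(r) {
  Map(function(values, k) values[[k]], grid, as.list(idx[r, ]))
})

length(configs)       # 12
configs[[2]]$hidden   # 32 16 (all other parameters at their first value)
```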

Plot Survival Curves

```r
plot(mod1, group_by = "celltype", times = 1:300)
```

```r
plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
```


Documentation

```r
help(package = "survdnn")
?survdnn
?tune_survdnn
?cv_survdnn
?plot.survdnn
```

Testing

```r
# Run all tests
devtools::test()
```

Reproducibility

By default, torch initializes model weights and shuffles minibatches with random draws, so results may differ from run to run.
Unlike `set.seed()`, which only controls R's RNG, {torch} uses its own RNG implemented in C++/CUDA. To ensure reproducibility, set the torch seed before training:

```r
torch::torch_manual_seed(123)
```
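A workflow may also make R-level random draws, for example a custom train/test split with `sample()`. In that case, seeding torch alone is not enough, and vice versa. The sketch below wraps both calls in one hypothetical helper, `seed_everything()`. It calls `torch::torch_manual_seed()` only when the torch package and its backend are installed, so it also runs on machines without torch. Whether survdnn itself draws from R's RNG (for instance, when assigning cross-validation folds) is an open assumption here, so seeding both is the safer default.

```r
# Hypothetical helper (not part of survdnn): seed torch's RNG when available,
# then R's RNG, so R-level draws right after the call are reproducible.
seed_everything <- function(seed) {
  if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
    torch::torch_manual_seed(seed)  # torch draws: weight initialization, minibatch shuffling
  }
  set.seed(seed)                    # R draws: sample(), runif(), rnorm(), ...
  invisible(seed)
}

seed_everything(123)
split_1 <- sample(nrow(survival::veteran), 100)
seed_everything(123)
split_2 <- sample(nrow(survival::veteran), 100)
identical(split_1, split_2)  # TRUE
```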

Availability

The survdnn R package is available on CRAN and at https://github.com/ielbadisy/survdnn


Contributions

Contributions, issues, and feature requests are welcome. Open an issue or submit a pull request!


License

MIT © Imad El Badisy

