# OSTransformer
This document describes the PyTorch implementation of "A Deep Learning Approach for Overall Survival Prediction in Lung Cancer with Missing Values". The proposed approach is an architecture specifically designed for survival analysis, with a focus on handling missing values in clinical data without the need for any imputation strategy.
Our approach adapts the transformer's encoder architecture to tabular data via a novel positional encoding for tabular features, and uses padding to mask any missing features within the attention module, enabling the model to ignore them effectively.
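As a rough illustration of the masking idea only (not the repository's actual OSTransformer implementation), missing features can be flagged and passed as a key padding mask to a standard PyTorch transformer encoder so that attention ignores them; the per-feature embedding below is a hypothetical placeholder:

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical sketch: one token per tabular feature, with NaN features masked out
x = torch.from_numpy(np.random.rand(8, 37)).float()   # (batch, n_features)
x[torch.rand_like(x) < 0.2] = float("nan")             # simulate missing values
padding_mask = torch.isnan(x)                           # True where a feature is missing
tokens = torch.nan_to_num(x, nan=0.0).unsqueeze(-1)     # (batch, n_features, 1)

embed = nn.Linear(1, 16)                                # toy per-feature embedding
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True),
    num_layers=1,
)
# Masked (missing) features are excluded from attention via src_key_padding_mask
out = encoder(embed(tokens), src_key_padding_mask=padding_mask)
```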
Here we provide a brief guide on how to use the code, covering the label encoding, the model with the losses employed during training, and the final metric used to compute the performance.
You can use your own dataset of clinical features; as an example, we generate random data here.
### DATA

```python
import numpy as np
import torch

n_samples = 100
n_features = 37
data = np.random.rand(n_samples, n_features)
```
To simulate a real-world scenario, we then introduce some missing values into the data.
```python
data[np.random.choice((0, 1), (n_samples, n_features), p=(0.8, 0.2)) == 1] = np.nan  # Introduce missing values
```
We also generate some random labels to be used for the survival analysis. In particular, each label is a tuple with the event and its respective time.
### LABELS

```python
events = ("censored", "uncensored")
num_events = len(events) - 1  # The first event is the censored one, so we do not consider it
max_time = 72                 # Maximum time to consider for the survival analysis
max_survival = 100            # Used to generate labels; the analysis will consider only the time up to max_time, setting those patients who survived longer than max_time to "censored"
labels = np.hstack(
    (np.random.choice(events, (n_samples, 1)), np.random.rand(n_samples, 1) * max_survival),
    dtype=object,
)
```
Given the labels, we can encode them in a format suitable for the survival analysis. More specifically, we encode the events with integers starting from 0, which corresponds to the `censored` event, and we floor the time of the event, also setting those patients who survived longer than `max_time` to `max_time`. In the latter case, we also set the event to 0, i.e. the `censored` event.
```python
survival_label_function = np.vectorize(
    lambda label: label_to_survival(label, events, max_time),
    signature="(n)->(m)",
)
survival_labels = survival_label_function(labels)
```
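`label_to_survival` is provided by this repository. Purely to illustrate the encoding described above, a minimal hypothetical sketch of such a function (not the repository's implementation) could look like this:

```python
def label_to_survival_sketch(label, events, max_time):
    """Hypothetical sketch of the label encoding described above."""
    event, time = label
    event = events.index(event)        # "censored" -> 0, "uncensored" -> 1
    time = int(np.floor(float(time)))  # floor the survival time
    if time > max_time:                # survived beyond max_time: censor at max_time
        event, time = 0, max_time
    return np.array([event, time])
```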
Afterward, we report some pieces of code to include in a training script and in the evaluation script to test the model. To instantiate the model, we first define the parameters for the shared net, which is the transformer's encoder, and then the parameters for the cause-specific subnets, i.e. the MLPs that output the risk probabilities for each event under consideration.
### MODEL

```python
## OSTransformer (shared net)
emb_dim = n_features + 1
n_heads = emb_dim // 2
shared_net_params = dict(emb_dim=emb_dim, num_heads=n_heads, output_size=emb_dim)

## CustomMLP (CS subnets)
hidden_sizes = [400, 200]
cs_subnet_params = dict(hidden_sizes=hidden_sizes)

model = SurvivalWrapper(
    num_events=num_events,
    max_time=max_time,
    shared_net_params=shared_net_params,
    cs_subnets_params=cs_subnet_params,
)
```
Now we feed the data to the model and obtain the predictions, which are the risk probabilities for each event under consideration. Note that the predictions for the `censored` event are not considered, since the event is not observed and its time is not relevant for the analysis.
### OUTPUTS

```python
## Forward pass
outputs = model(torch.from_numpy(data).float())
## Predictions
predictions = survival_prediction(outputs)
```
We can now compute the losses for the survival analysis, which are the survival log-likelihood loss and the survival ranking loss.
### Survival Losses

```python
loss = 0
criterion1 = SurvivalLogLikelihoodLoss(num_events=num_events, max_time=max_time)
loss += criterion1(outputs, torch.from_numpy(survival_labels).float().unsqueeze(dim=1))
criterion2 = SurvivalRankingLoss(num_events=num_events, max_time=max_time)
loss += criterion2(outputs, torch.from_numpy(survival_labels).float().unsqueeze(dim=1))
```
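For intuition only, the snippet below sketches a common DeepHit-style discrete-time log-likelihood for competing risks with censoring; it is an assumption shown for illustration, and the repository's `SurvivalLogLikelihoodLoss` and `SurvivalRankingLoss` may differ in their exact formulation:

```python
def log_likelihood_sketch(probs, events, times):
    """Hypothetical sketch. probs: (N, num_events, max_time) risk probabilities;
    events, times: (N,) integer tensors, with 0 = censored."""
    eps = 1e-8
    cif = torch.cumsum(probs, dim=-1)              # cumulative incidence per event
    terms = []
    for i in range(len(events)):
        k, t = int(events[i]), min(int(times[i]), probs.shape[-1] - 1)
        if k > 0:                                  # event k observed at time t
            terms.append(torch.log(probs[i, k - 1, t] + eps))
        else:                                      # censored at t: no event observed up to t
            terms.append(torch.log(torch.clamp(1.0 - cif[i, :, t].sum(), min=eps)))
    return -torch.stack(terms).mean()
```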
Finally, we can compute the performance of the model using the Ct-index, a time-dependent variant of the C-index. We perform the cumulative sum of the risk probabilities to obtain the cumulative incidence function (CIF) and then compute the Ct-index.
### Performance

```python
# Compute the cumulative incidence function (CIF) by cumulatively summing the output probabilities
outputs = torch.cumsum(outputs, dim=-1)
performance = Ct_index(survival_labels, outputs.detach().numpy(), num_events)
```
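For reference, a time-dependent concordance index counts, among comparable pairs, how often the subject who experiences the event earlier also has the higher predicted CIF at that time. The single-event sketch below only illustrates this idea under that assumption; the repository's `Ct_index` (e.g. its handling of ties, censoring and multiple events) may differ:

```python
def ct_index_sketch(events, times, cif):
    """Hypothetical sketch. events, times: (N,) arrays with 0 = censored;
    cif: (N, max_time) cumulative incidence for a single event."""
    concordant, comparable = 0, 0
    for i in range(len(times)):
        if events[i] == 0:                         # pairs are anchored on observed events only
            continue
        t_i = min(int(times[i]), cif.shape[1] - 1)
        for j in range(len(times)):
            if times[j] > times[i]:                # subject j is still event-free at time t_i
                comparable += 1
                if cif[i, t_i] > cif[j, t_i]:      # higher predicted risk for the earlier event
                    concordant += 1
    return concordant / max(comparable, 1)
```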
For any questions, please contact camillomaria.caruso@unicampus.it and valerio.guarrasi@unicampus.it.
```bibtex
@article{CARUSO2024108308,
  title    = {A Deep Learning Approach for Overall Survival Prediction in Lung Cancer with Missing Values},
  journal  = {Computer Methods and Programs in Biomedicine},
  volume   = {254},
  pages    = {108308},
  year     = {2024},
  issn     = {0169-2607},
  doi      = {https://doi.org/10.1016/j.cmpb.2024.108308},
  url      = {https://www.sciencedirect.com/science/article/pii/S016926072400302X},
  author   = {Camillo Maria Caruso and Valerio Guarrasi and Sara Ramella and Paolo Soda},
  keywords = {Survival analysis, Missing data, Precision medicine, Oncology},
}
```