Official implementation for the paper "A Deep Learning Approach for Overall Survival Prediction in Lung Cancer with Missing Values"


cosbidev/OSTransformer


This document describes the PyTorch implementation of "A Deep Learning Approach for Overall Survival Prediction in Lung Cancer with Missing Values". The proposed approach is an architecture designed specifically for survival analysis, with a focus on handling missing values in clinical data without any imputation strategy.

Our approach adapts the transformer encoder architecture to tabular data through a novel positional encoding for tabular features, and uses padding to mask missing features within the attention module, so that the model effectively ignores them.
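The masking idea can be illustrated with a minimal NumPy sketch of single-head self-attention in which missing features are excluded as keys. This is not the repository's implementation: the function name `masked_self_attention`, the zero placeholder for missing tokens, and the toy shapes are assumptions made for illustration only.

```python
import numpy as np

def masked_self_attention(x, missing):
    """Single-head scaled dot-product self-attention that ignores masked keys.

    x:       (n_tokens, d) token embeddings, one token per tabular feature
    missing: (n_tokens,) boolean array, True where the feature is missing
    Assumes at least one feature is observed (otherwise every score is -inf).
    """
    d = x.shape[-1]
    scores = x @ x.T / np.sqrt(d)                  # (n_tokens, n_tokens)
    scores[:, missing] = -np.inf                   # masked keys receive zero weight
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x, weights

rng = np.random.default_rng(0)
tokens = rng.normal(size=(5, 4))
missing = np.array([False, True, False, False, True])
tokens[missing] = 0.0  # placeholder value (assumption); NaN here would propagate

out, weights = masked_self_attention(tokens, missing)
```

Because the attention weights on the masked columns are exactly zero, the outputs for the observed features do not depend on whatever placeholder is stored at the missing positions.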

Usage

Here we provide a brief guide to using the code, covering the label encoding, the model, the losses employed during training, and the final metric used to measure performance.

The user can use a dataset of clinical features, but as an example we generate random data.

    import numpy as np
    import torch

    ### DATA
    n_samples = 100
    n_features = 37
    data = np.random.rand(n_samples, n_features)

To simulate a realistic scenario, we then introduce missing values into the data, setting each entry to NaN with probability 0.2.

    # Introduce missing values: each entry becomes NaN with probability 0.2
    data[np.random.choice((0, 1), (n_samples, n_features), p=(0.8, 0.2)) == 1] = np.nan
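As a quick sanity check, one might inspect how many values are actually missing before feeding the data to a model. The sketch below regenerates comparable random data with a seeded generator (an assumption, for reproducibility) and builds a boolean mask with `np.isnan`; how the repository itself derives its padding mask is not shown in this guide.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 100, 37
data = rng.random((n_samples, n_features))
data[rng.random((n_samples, n_features)) < 0.2] = np.nan  # ~20% missing entries

missing_mask = np.isnan(data)                   # True where a feature is missing
missing_per_sample = missing_mask.sum(axis=1)   # number of missing features per row

print(f"overall missing fraction: {missing_mask.mean():.2f}")
print(f"samples with no observed feature: {(missing_per_sample == n_features).sum()}")
```

A sample with no observed feature at all would leave nothing to attend to, so it is worth checking for such rows in real data.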

We also generate random labels for the survival analysis. Each label is a pair containing the event and its associated time.

    ### LABELS
    events = ("censored", "uncensored")
    num_events = len(events) - 1  # The first event is the censored one, so it is not counted
    max_time = 72  # Maximum time considered in the survival analysis
    # max_survival is only used to generate labels; the analysis considers times up to
    # max_time and sets patients who survived longer than max_time to "censored".
    max_survival = 100
    # Note: the dtype argument of np.hstack requires NumPy >= 1.24
    labels = np.hstack(
        (
            np.random.choice(events, (n_samples, 1)),
            np.random.rand(n_samples, 1) * max_survival,
        ),
        dtype=object,
    )

Given the labels, we encode them in a format suitable for survival analysis. Events are encoded as integers starting from 0, which denotes the censored event, and the event time is floored. Patients who survived longer than max_time have their time set to max_time and their event set to 0, i.e. they are treated as censored.

    # label_to_survival is provided by this repository
    survival_label_function = np.vectorize(
        lambda label: label_to_survival(label, events, max_time),
        signature="(n)->(m)",
    )
    survival_labels = survival_label_function(labels)
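The encoding rule described above can be sketched in plain Python. The function `encode_survival_label_sketch` below is hypothetical and is not the repository's `label_to_survival`; in particular, the return format and the handling of the boundary (comparing the raw time against `max_time` before flooring) are assumptions.

```python
import math

def encode_survival_label_sketch(label, events, max_time):
    """Encode an (event_name, time) pair as (event_index, floored_time).

    Hypothetical illustration of the rule in the text, not the repository's code.
    """
    event_name, time = label
    time = float(time)
    if time > max_time:
        # Survived beyond the horizon: clip the time and treat as censored (index 0)
        return 0, max_time
    return events.index(event_name), math.floor(time)

events = ("censored", "uncensored")
encoded = [
    encode_survival_label_sketch(("uncensored", 12.7), events, 72),
    encode_survival_label_sketch(("censored", 30.2), events, 72),
    encode_survival_label_sketch(("uncensored", 95.0), events, 72),
]
```

The third example shows an observed event that falls beyond the horizon and is therefore re-labelled as censored at `max_time`.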

Next, we show code snippets that can be included in a training script and in the evaluation of the model. To instantiate the model, we first define the parameters of the shared net, which is the transformer encoder, and then the parameters of the cause-specific subnets, which are the MLPs that output the risk probabilities for each event under consideration.

    ### MODEL
    ## OSTransformer (shared net)
    emb_dim = n_features + 1
    n_heads = emb_dim // 2
    shared_net_params = dict(emb_dim=emb_dim, num_heads=n_heads, output_size=emb_dim)

    ## CustomMLP (CS subnets)
    hidden_sizes = [400, 200]
    cs_subnet_params = dict(hidden_sizes=hidden_sizes)

    model = SurvivalWrapper(
        num_events=num_events,
        max_time=max_time,
        shared_net_params=shared_net_params,
        cs_subnets_params=cs_subnet_params,
    )

We now feed the data to the model and obtain the predictions, which are the risk probabilities for each event under consideration. Note that no predictions are produced for the censored event, since that event is not observed and its time is not relevant for the analysis.

    ### OUTPUTS
    ## Forward pass
    outputs = model(torch.from_numpy(data).float())

    ## Predictions
    predictions = survival_prediction(outputs)

We can now compute the losses for the survival analysis, which are the survival log-likelihood loss and the survival ranking loss.

    ### Survival Losses
    targets = torch.from_numpy(survival_labels).float().unsqueeze(dim=1)

    loss = 0
    criterion1 = SurvivalLogLikelihoodLoss(num_events=num_events, max_time=max_time)
    loss += criterion1(outputs, targets)
    criterion2 = SurvivalRankingLoss(num_events=num_events, max_time=max_time)
    loss += criterion2(outputs, targets)
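This guide does not spell out the formula behind the log-likelihood loss. As a rough illustration only, the NumPy sketch below assumes a common discrete-time formulation: an uncensored sample contributes the negative log of the probability assigned to its event and time bin, while a censored sample contributes the negative log of the probability of no event up to its censoring time. The function name, the `(n_samples, num_events, num_time_bins)` layout, and the formulation itself are assumptions and may differ from `SurvivalLogLikelihoodLoss` in the repository.

```python
import numpy as np

def discrete_log_likelihood_sketch(probs, events, times, eps=1e-8):
    """Negative log-likelihood for discrete-time survival outputs (illustrative).

    probs:  (n_samples, num_events, num_time_bins), sums to 1 per sample (assumed)
    events: (n_samples,) ints, 0 = censored, k >= 1 = event k
    times:  (n_samples,) integer time bins
    """
    losses = np.empty(len(times))
    for i, (event, t) in enumerate(zip(events, times)):
        if event > 0:
            # Observed event: probability mass at (event, time bin)
            likelihood = probs[i, event - 1, t]
        else:
            # Censored: probability that no event occurred up to and including t
            likelihood = 1.0 - probs[i, :, : t + 1].sum()
        losses[i] = -np.log(likelihood + eps)
    return losses.mean()

probs = np.array([
    [[0.10, 0.20, 0.30, 0.40]],   # sample 0: one event type, four time bins
    [[0.25, 0.25, 0.25, 0.25]],   # sample 1
])
events = np.array([1, 0])          # sample 0 uncensored, sample 1 censored
times = np.array([2, 1])

loss = discrete_log_likelihood_sketch(probs, events, times)
```

In this toy case the loss is the mean of -log(0.3) for the observed event and -log(0.5) for the censored sample.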

Finally, we can compute the performance of the model using the Ct-index, a time-dependent variant of the C-index. We take the cumulative sum of the risk probabilities over time to obtain the cumulative incidence function (CIF) and then compute the Ct-index.

    ### Performance
    # Compute the cumulative incidence function (CIF) by cumulatively summing the output probabilities
    outputs = torch.cumsum(outputs, dim=-1)
    performance = Ct_index(survival_labels, outputs.detach().numpy(), num_events)
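To make the metric more concrete, here is a simplified single-event sketch of a time-dependent concordance computed from a CIF. It assumes the usual pairwise definition: a pair (i, j) is comparable when sample i experienced the event before the observed time of sample j, and it is concordant when the CIF of i at that time exceeds the CIF of j. The function `ct_index_sketch` is hypothetical; the repository's `Ct_index` may handle ties, multiple events, and weighting differently.

```python
import numpy as np

def ct_index_sketch(events, times, cif):
    """Time-dependent concordance for a single event type (illustrative).

    events: (n,) ints, 0 = censored, 1 = event observed
    times:  (n,) integer time bins
    cif:    (n, n_bins) cumulative incidence per sample
    """
    concordant, comparable = 0, 0
    for i in range(len(times)):
        if events[i] == 0:
            continue  # censored samples cannot anchor a comparable pair
        t_i = times[i]
        for j in range(len(times)):
            if times[j] > t_i:
                comparable += 1
                if cif[i, t_i] > cif[j, t_i]:
                    concordant += 1
    return concordant / comparable if comparable else float("nan")

probs = np.array([
    [0.60, 0.20, 0.10, 0.10],
    [0.10, 0.20, 0.30, 0.40],
    [0.05, 0.05, 0.10, 0.80],
])
cif = np.cumsum(probs, axis=-1)    # CIF as the cumulative sum over time bins
events = np.array([1, 1, 0])
times = np.array([0, 2, 3])

score = ct_index_sketch(events, times, cif)
```

In this toy example every comparable pair is ordered correctly, so the sketch yields a perfect score; a value of 0.5 would correspond to random ordering.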

Contact

For any questions, please contact camillomaria.caruso@unicampus.it and valerio.guarrasi@unicampus.it.


Citation

    @article{CARUSO2024108308,
      title = {A Deep Learning Approach for Overall Survival Prediction in Lung Cancer with Missing Values},
      journal = {Computer Methods and Programs in Biomedicine},
      volume = {254},
      pages = {108308},
      year = {2024},
      issn = {0169-2607},
      doi = {https://doi.org/10.1016/j.cmpb.2024.108308},
      url = {https://www.sciencedirect.com/science/article/pii/S016926072400302X},
      author = {Camillo Maria Caruso and Valerio Guarrasi and Sara Ramella and Paolo Soda},
      keywords = {Survival analysis, Missing data, Precision medicine, Oncology},
    }
