# VisTabNet
VisTabNet is a Vision Transformer-based tabular data classifier that leverages the strengths of transformer architectures for tabular classification tasks. It is a proof of concept accompanying the publication [VisTabNet: Adapting Vision Transformers for Tabular Data](https://arxiv.org/abs/2501.00057). Key features:
- Vision Transformer architecture adapted for tabular data
- Simple and intuitive API similar to scikit-learn
- GPU acceleration support
- Automatic handling of numerical features
- Built-in evaluation metrics
- Compatible with pandas DataFrames and numpy arrays (see the sketch after this list)
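Since the interface mirrors scikit-learn and accepts both pandas and numpy input, the two can be used interchangeably. Below is a minimal sketch; the toy columns, labels, and `device="cpu"` setting are illustrative, and the constructor and `fit`/`predict` calls follow the quick-start API shown further down.

```python
import numpy as np
import pandas as pd

from vistabnet import VisTabNetClassifier

# Illustrative toy table; any numeric DataFrame or array works the same way.
X = pd.DataFrame({"age": [25, 32, 47, 51], "income": [30.0, 52.0, 61.0, 58.0]})
y = np.array([0, 1, 1, 0])  # integer (label-encoded) targets

model = VisTabNetClassifier(input_features=X.shape[1], classes=2, device="cpu")

# A DataFrame and its numpy view (X.to_numpy()) carry the same values;
# training and predicting on the same toy rows is purely for illustration.
model.fit(X.to_numpy(), y, eval_X=X.to_numpy(), eval_y=y)
print(model.predict(X.to_numpy()))
```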
You can install VisTabNet using pip:
```bash
pip install vistabnet
```
Here's a simple example to get you started:
```python
from vistabnet import VisTabNetClassifier
import numpy as np
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split

# Prepare your data
X_train, y_train, X_test, y_test = ...  # Load your data here
# Note: y should be label encoded, not one-hot encoded

# Initialize the model
model = VisTabNetClassifier(
    input_features=X_train.shape[1],
    classes=len(np.unique(y_train)),
    device="cuda",  # Use "cpu" if no GPU is available
)

# Train the model
model.fit(X_train, y_train, eval_X=X_test, eval_y=y_test)

# Make predictions
y_pred = model.predict(X_test)

# Evaluate the model
accuracy = balanced_accuracy_score(y_test, y_pred)
print(f"Balanced accuracy: {accuracy}")
```
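For a runnable end-to-end variant, here is a minimal sketch on scikit-learn's breast-cancer dataset. The dataset choice, split settings, and `device="cpu"` are illustrative; the `LabelEncoder` step shows one way to satisfy the label-encoding note above, and the VisTabNet calls are the same as in the quick-start example.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

from vistabnet import VisTabNetClassifier

X, y = load_breast_cancer(return_X_y=True)
y = LabelEncoder().fit_transform(y)  # ensure integer labels, not one-hot

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

model = VisTabNetClassifier(
    input_features=X_train.shape[1],
    classes=len(np.unique(y_train)),
    device="cpu",  # switch to "cuda" if a GPU is available
)
model.fit(X_train, y_train, eval_X=X_test, eval_y=y_test)

y_pred = model.predict(X_test)
print(f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred):.3f}")
```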
You can customize the VisTabNet model by adjusting various parameters:
```python
model = VisTabNetClassifier(
    input_features=X_train.shape[1],
    classes=len(np.unique(y_train)),
    hidden_dim=256,
    num_layers=6,
    num_heads=8,
    device="cuda",
)
```
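`hidden_dim`, `num_layers`, and `num_heads` control the capacity of the transformer backbone. As a hedged sketch of how you might search over them, the loop below reuses the data split from the end-to-end example above; the candidate values are illustrative, only constructor arguments already shown in this README are used, and in practice you would score configurations on a held-out validation split rather than the test set.

```python
from itertools import product

import numpy as np
from sklearn.metrics import balanced_accuracy_score

# Assumes VisTabNetClassifier, X_train, y_train, X_test, y_test
# from the end-to-end example above.
best_score, best_config = 0.0, None
for num_layers, num_heads in product((4, 6), (4, 8)):
    model = VisTabNetClassifier(
        input_features=X_train.shape[1],
        classes=len(np.unique(y_train)),
        num_layers=num_layers,
        num_heads=num_heads,
        device="cuda",  # use "cpu" if no GPU is available
    )
    model.fit(X_train, y_train, eval_X=X_test, eval_y=y_test)
    score = balanced_accuracy_score(y_test, model.predict(X_test))
    if score > best_score:
        best_score, best_config = score, (num_layers, num_heads)

print(f"Best (num_layers, num_heads): {best_config}, "
      f"balanced accuracy: {best_score:.3f}")
```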
VisTabNet requires:

- Python ≥ 3.9
- PyTorch ≥ 2.0
- torchvision ≥ 0.15.0
- tqdm ≥ 4.65.0
- focal-loss-torch ≥ 0.1.2
Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
If you use VisTabNet in your research, please cite:
```bibtex
@misc{wydmański2024vistabnetadaptingvisiontransformers,
  title={VisTabNet: Adapting Vision Transformers for Tabular Data},
  author={Witold Wydmański and Ulvi Movsum-zada and Jacek Tabor and Marek Śmieja},
  year={2024},
  eprint={2501.00057},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2501.00057},
}
```
For questions and support, please open an issue in the GitHub repository.