Train a model with TabNet
To see an example of how to train a model with TabNet, run the "Tabular Workflows: TabNet Pipeline" notebook in one of the following environments:
Open in Colab | Open in Colab Enterprise | Open in Vertex AI Workbench | View on GitHub
This page shows you how to train a classification or regression model from a tabular dataset with the Tabular Workflow for TabNet.
Two versions of the Tabular Workflow for TabNet are available:
- HyperparameterTuningJob searches for the best set of hyperparameter values to use for model training.
- CustomJob lets you specify the hyperparameter values to use for model training. If you know exactly which hyperparameter values you need, specify them instead of searching for them and save on training resources.
To learn about the service accounts this workflow uses, see Service accounts for Tabular Workflows.
Workflow APIs
This workflow uses the following APIs:
- Vertex AI
- Dataflow
- Compute Engine
- Cloud Storage
Train a model with HyperparameterTuningJob
The following sample code demonstrates how to run a HyperparameterTuningJob pipeline:
```python
pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
```
The optional `service_account` parameter in `pipeline_job.run()` lets you set the Vertex AI Pipelines service account to an account of your choice.
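For orientation, a minimal sketch of a complete job submission might look like the following. The project, region, staging bucket, display name, and service account values are placeholders, not values from this page:

```python
from google.cloud import aiplatform

# Placeholder values; replace with your own project, region, bucket, and service account.
PROJECT_ID = "your-project"
REGION = "us-central1"
PIPELINE_ROOT = "gs://your-bucket/pipeline_root"
SERVICE_ACCOUNT = "vertex-pipelines-sa@your-project.iam.gserviceaccount.com"

aiplatform.init(project=PROJECT_ID, location=REGION)

pipeline_job = aiplatform.PipelineJob(
    display_name="tabnet-hyperparameter-tuning",  # any descriptive name
    template_path=template_path,        # returned by the helper function described below
    parameter_values=parameter_values,  # returned by the helper function described below
    pipeline_root=PIPELINE_ROOT,        # Cloud Storage location for pipeline artifacts
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
```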
The pipeline and the parameter values are defined by the following function. The training data can be either a CSV file in Cloud Storage or a table in BigQuery.
```python
template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)
```
The following is a subset of the `get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters` parameters:
| Parameter name | Type | Definition |
|---|---|---|
| data_source_csv_filenames | String | A URI for a CSV stored in Cloud Storage. |
| data_source_bigquery_table_path | String | A URI for a BigQuery table. |
| dataflow_service_account | String | (Optional) Custom service account to run Dataflow jobs. The Dataflow job can be configured to use private IPs and a specific VPC subnet. This parameter acts as an override for the default Dataflow worker service account. |
| study_spec_parameters_override | List[Dict[String, Any]] | (Optional) An override for tuning hyperparameters. This parameter can be empty or contain one or more of the possible hyperparameters. If a hyperparameter value is not set, Vertex AI uses the default tuning range for the hyperparameter. |
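As a rough sketch, a call to this helper might look like the following. Only data_source_csv_filenames and dataflow_service_account are taken from the table above; the placeholder constants come from the earlier sketch, and the remaining argument names (root_dir, target_column, prediction_type, and so on) are assumptions for illustration and may differ in your SDK version:

```python
# A minimal sketch, not a complete argument list. Argument names other than
# data_source_csv_filenames and dataflow_service_account (documented above)
# are assumptions and may differ in your version of the SDK.
template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(
    project=PROJECT_ID,                # assumed argument name
    location=REGION,                   # assumed argument name
    root_dir=PIPELINE_ROOT,            # assumed: Cloud Storage root for pipeline artifacts
    target_column="target",            # assumed: the column the model predicts
    prediction_type="classification",  # assumed: "classification" or "regression"
    data_source_csv_filenames="gs://your-bucket/data/train.csv",                  # documented above
    dataflow_service_account="dataflow-sa@your-project.iam.gserviceaccount.com",  # documented above (optional)
)
```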
To configure the hyperparameters using the `study_spec_parameters_override` parameter, use Vertex AI's helper function `get_tabnet_study_spec_parameters_override`. The function has the following inputs:
- dataset_size_bucket: A bucket for the dataset size
  - 'small': < 1M rows
  - 'medium': 1M - 100M rows
  - 'large': > 100M rows
- training_budget_bucket: A bucket for the training budget
  - 'small': < $600
  - 'medium': $600 - $2400
  - 'large': > $2400
- prediction_type: The desired inference type
The `get_tabnet_study_spec_parameters_override` function returns a list of hyperparameters and ranges.
The following is an example of how to use the `get_tabnet_study_spec_parameters_override` function:
```python
study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)
```
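The returned list can then be passed as the `study_spec_parameters_override` argument of `get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters` described earlier, for example:

```python
# Forward the generated hyperparameter ranges to the tuning pipeline helper;
# the other arguments shown earlier are omitted here.
template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(
    # ...other arguments shown earlier...
    study_spec_parameters_override=study_spec_parameters_override,
)
```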
Train a model with CustomJob
The following sample code demonstrates how to run a CustomJob pipeline:
```python
pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
```
The optional `service_account` parameter in `pipeline_job.run()` lets you set the Vertex AI Pipelines service account to an account of your choice.
The pipeline and the parameter values are defined by the following function. The training data can be either a CSV file in Cloud Storage or a table in BigQuery.
```python
template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)
```
The following is a subset of the `get_tabnet_trainer_pipeline_and_parameters` parameters:
| Parameter name | Type | Definition |
|---|---|---|
| data_source_csv_filenames | String | A URI for a CSV stored in Cloud Storage. |
| data_source_bigquery_table_path | String | A URI for a BigQuery table. |
| dataflow_service_account | String | (Optional) Custom service account to run Dataflow jobs. The Dataflow job can be configured to use private IPs and a specific VPC subnet. This parameter acts as an override for the default Dataflow worker service account. |
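As a rough sketch, a call to this helper might look like the following. Only the parameters documented in the table above come from this page; the remaining argument names, including the explicit hyperparameter value learning_rate, are assumptions for illustration and may differ in your SDK version:

```python
# A minimal sketch, not a complete argument list. Argument names other than
# data_source_bigquery_table_path and dataflow_service_account (documented
# above) are assumptions and may differ in your version of the SDK.
template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(
    project=PROJECT_ID,                # assumed argument name
    location=REGION,                   # assumed argument name
    root_dir=PIPELINE_ROOT,            # assumed: Cloud Storage root for pipeline artifacts
    target_column="target",            # assumed: the column the model predicts
    prediction_type="classification",  # assumed: "classification" or "regression"
    data_source_bigquery_table_path="bq://your-project.your_dataset.training_data",  # documented above
    dataflow_service_account="dataflow-sa@your-project.iam.gserviceaccount.com",     # documented above (optional)
    learning_rate=0.01,                # assumed: example of an explicit hyperparameter value
)
```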
What's next
Once you're ready to make inferences with your classification or regression model, you have two options:
- Make online (real-time) inferences using your model
- Get batch inferences directly from your model.
- Learn about pricing for model training.