Train a model with TabNet

This page shows you how to train a classification or regression model from a tabular dataset with the Tabular Workflow for TabNet.

Two versions of the Tabular Workflow for TabNet are available:

  • HyperparameterTuningJob searches for the best set of hyperparameter values to use for model training.
  • CustomJob lets you specify the hyperparameter values to use for model training. If you already know which hyperparameter values you need, you can specify them instead of searching for them, which saves training resources.

To learn about the service accounts this workflow uses, see Service accounts for Tabular Workflows.

Workflow APIs

This workflow uses the following APIs:

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

Train a model with HyperparameterTuningJob

The following sample code demonstrates how to run a HyperparameterTuningJob pipeline:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

The optional service_account parameter in pipeline_job.run() lets you set the Vertex AI Pipelines service account to an account of your choice.
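For illustration, the following is a minimal, more complete sketch of this job submission. It assumes template_path and parameter_values have already been returned by the helper function described below; the project, region, pipeline root, and service account values are placeholders, not values from this page.

from google.cloud import aiplatform

# Placeholder values for illustration; replace with your own.
PROJECT_ID = "my-project"
REGION = "us-central1"
PIPELINE_ROOT = "gs://my-bucket/pipeline_root"  # Cloud Storage staging path
SERVICE_ACCOUNT = "my-sa@my-project.iam.gserviceaccount.com"

aiplatform.init(project=PROJECT_ID, location=REGION)

pipeline_job = aiplatform.PipelineJob(
    display_name="tabnet-hyperparameter-tuning",
    template_path=template_path,        # returned by the helper function below
    parameter_values=parameter_values,  # returned by the helper function below
    pipeline_root=PIPELINE_ROOT,
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)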

The pipeline and the parameter values are defined by the following function. The training data can be either a CSV file in Cloud Storage or a table in BigQuery.

template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)

The following is a subset of get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters parameters:

  • data_source_csv_filenames (String): A URI for a CSV stored in Cloud Storage.
  • data_source_bigquery_table_path (String): A URI for a BigQuery table.
  • dataflow_service_account (String): (Optional) Custom service account to run Dataflow jobs. The Dataflow job can be configured to use private IPs and a specific VPC subnet. This parameter acts as an override for the default Dataflow worker service account.
  • study_spec_parameters_override (List[Dict[String, Any]]): (Optional) An override for tuning hyperparameters. This parameter can be empty or contain one or more of the possible hyperparameters. If a hyperparameter value is not set, Vertex AI uses the default tuning range for the hyperparameter.
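For example, a call to this function might look like the following sketch. Of the arguments shown, only data_source_csv_filenames comes from the table above; the import path and the remaining argument names (project, location, root_dir, target_column, prediction_type) are assumptions based on typical Tabular Workflows usage and may differ in your version of google-cloud-pipeline-components.

# Import path is an assumption; it varies across google-cloud-pipeline-components versions.
from google_cloud_pipeline_components.experimental.automl.tabular import utils as automl_tabular_utils

template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(
    project="my-project",                     # assumption: your project ID
    location="us-central1",                   # assumption: your region
    root_dir="gs://my-bucket/pipeline_root",  # assumption: Cloud Storage staging directory
    target_column="label",                    # assumption: the column to predict
    prediction_type="classification",         # assumption: "classification" or "regression"
    data_source_csv_filenames="gs://my-bucket/data/train.csv",  # documented in the table above
)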

To configure the hyperparameters using the study_spec_parameters_override parameter, use Vertex AI's helper function get_tabnet_study_spec_parameters_override. The function has the following inputs:

  • dataset_size_bucket: A bucket for the dataset size
    • 'small': < 1M rows
    • 'medium': 1M - 100M rows
    • 'large': > 100M rows
  • training_budget_bucket: A bucket for the training budget
    • 'small': < $600
    • 'medium': $600 - $2400
    • 'large': > $2400
  • prediction_type: The desired inference type ('classification' or 'regression')

The get_tabnet_study_spec_parameters_override function returns a list of hyperparameters and ranges.

The following is an example of how to use the get_tabnet_study_spec_parameters_override function:

study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)
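The returned list can then be passed to the pipeline helper through the study_spec_parameters_override parameter described in the table above:

template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(
    ...
    study_spec_parameters_override=study_spec_parameters_override,
)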

Train a model with CustomJob

The following sample code demonstrates how to run a CustomJob pipeline:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

The optional service_account parameter in pipeline_job.run() lets you set the Vertex AI Pipelines service account to an account of your choice.

The pipeline and the parameter values are defined by the following function. The training data can be either a CSV file in Cloud Storage or a table in BigQuery.

template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)

The following is a subset of get_tabnet_trainer_pipeline_and_parameters parameters:

  • data_source_csv_filenames (String): A URI for a CSV stored in Cloud Storage.
  • data_source_bigquery_table_path (String): A URI for a BigQuery table.
  • dataflow_service_account (String): (Optional) Custom service account to run Dataflow jobs. The Dataflow job can be configured to use private IPs and a specific VPC subnet. This parameter acts as an override for the default Dataflow worker service account.
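As with the tuning pipeline, a sketch of a call to this helper is shown below. Because a CustomJob takes explicit hyperparameter values, one is passed directly here; apart from data_source_bigquery_table_path, every argument name (including the learning_rate hyperparameter) is an assumption, so check the helper's signature in your version of google-cloud-pipeline-components.

template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(
    project="my-project",                     # assumption: your project ID
    location="us-central1",                   # assumption: your region
    root_dir="gs://my-bucket/pipeline_root",  # assumption: Cloud Storage staging directory
    target_column="label",                    # assumption: the column to predict
    prediction_type="regression",             # assumption: "classification" or "regression"
    learning_rate=0.01,                       # assumption: an explicitly chosen hyperparameter value
    data_source_bigquery_table_path="bq://my-project.my_dataset.train",  # documented in the table above
)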

What's next

Once you're ready to make inferences with your classification or regression model, you have two options:

  • Request online inferences from your model.
  • Request batch inferences from your model.
