Create a training script

To create a custom model, you need a Python training script that creates and trains the custom model. You initialize your training job with the Python training script, then invoke the training job's `run` method to run the script.

In this topic, you create the training script, then specify command arguments for your training script.

Create a training script

In this section, you create a training script. This script is a new file in your notebook environment named `task.py`. Later in this tutorial, you pass this script to the `aiplatform.CustomTrainingJob` constructor. When the script runs, it does the following:

  • Loads the data in the BigQuery dataset you created.

  • Uses the TensorFlow Keras API to build, compile, and train your model.

  • Specifies the number of epochs and the batch size to use when the Keras `Model.fit` method is invoked.

  • Specifies where to save model artifacts using the `AIP_MODEL_DIR` environment variable. `AIP_MODEL_DIR` is set by Vertex AI and contains the URI of a directory for saving model artifacts. For more information, see Environment variables for special Cloud Storage directories.

  • Exports a TensorFlow SavedModel to the model directory. For more information, see Using the SavedModel format on the TensorFlow website.
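A script that runs both on Vertex AI and locally can read these variables with `os.getenv`, supplying fallbacks when they are unset. A minimal sketch (the fallback URIs and path below are hypothetical placeholders, not real tables or directories):

```python
import os

# Vertex AI sets these variables inside the training container.
# The fallback values are hypothetical, for local smoke tests only.
training_data_uri = os.getenv("AIP_TRAINING_DATA_URI", "bq://project.dataset.training")
validation_data_uri = os.getenv("AIP_VALIDATION_DATA_URI", "bq://project.dataset.validation")
model_dir = os.getenv("AIP_MODEL_DIR", "/tmp/model")
```

On Vertex AI the real values are injected automatically, so the fallbacks are never used there.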

To create your training script, run the following code in your notebook:

```python
%%writefile task.py

import argparse
import numpy as np
import os
import pandas as pd
import tensorflow as tf

from google.cloud import bigquery
from google.cloud import storage

# Read environmental variables
training_data_uri = os.getenv("AIP_TRAINING_DATA_URI")
validation_data_uri = os.getenv("AIP_VALIDATION_DATA_URI")
test_data_uri = os.getenv("AIP_TEST_DATA_URI")

# Read args
parser = argparse.ArgumentParser()
parser.add_argument('--label_column', required=True, type=str)
parser.add_argument('--epochs', default=10, type=int)
parser.add_argument('--batch_size', default=10, type=int)
args = parser.parse_args()

# Set up training variables
LABEL_COLUMN = args.label_column

# See https://cloud.google.com/vertex-ai/docs/workbench/managed/executor#explicit-project-selection for issues regarding permissions.
PROJECT_NUMBER = os.environ["CLOUD_ML_PROJECT_ID"]
bq_client = bigquery.Client(project=PROJECT_NUMBER)


# Download a table
def download_table(bq_table_uri: str):
    # Remove bq:// prefix if present
    prefix = "bq://"
    if bq_table_uri.startswith(prefix):
        bq_table_uri = bq_table_uri[len(prefix):]

    # Download the BigQuery table as a dataframe
    # This requires the "BigQuery Read Session User" role on the custom training service account.
    table = bq_client.get_table(bq_table_uri)
    return bq_client.list_rows(table).to_dataframe()


# Download dataset splits
df_train = download_table(training_data_uri)
df_validation = download_table(validation_data_uri)
df_test = download_table(test_data_uri)


def convert_dataframe_to_dataset(
    df_train: pd.DataFrame,
    df_validation: pd.DataFrame,
):
    df_train_x, df_train_y = df_train, df_train.pop(LABEL_COLUMN)
    df_validation_x, df_validation_y = df_validation, df_validation.pop(LABEL_COLUMN)

    y_train = tf.convert_to_tensor(np.asarray(df_train_y).astype("float32"))
    y_validation = tf.convert_to_tensor(np.asarray(df_validation_y).astype("float32"))

    # Convert to numpy representation
    x_train = tf.convert_to_tensor(np.asarray(df_train_x).astype("float32"))
    x_test = tf.convert_to_tensor(np.asarray(df_validation_x).astype("float32"))

    # Convert to one-hot representation
    num_species = len(df_train_y.unique())
    y_train = tf.keras.utils.to_categorical(y_train, num_classes=num_species)
    y_validation = tf.keras.utils.to_categorical(y_validation, num_classes=num_species)

    dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset_validation = tf.data.Dataset.from_tensor_slices((x_test, y_validation))
    return (dataset_train, dataset_validation)


# Create datasets
dataset_train, dataset_validation = convert_dataframe_to_dataset(df_train, df_validation)

# Shuffle train set
dataset_train = dataset_train.shuffle(len(df_train))


def create_model(num_features):
    # Create model
    Dense = tf.keras.layers.Dense
    model = tf.keras.Sequential(
        [
            Dense(
                100,
                activation=tf.nn.relu,
                kernel_initializer="uniform",
                input_dim=num_features,
            ),
            Dense(75, activation=tf.nn.relu),
            Dense(50, activation=tf.nn.relu),
            Dense(25, activation=tf.nn.relu),
            Dense(3, activation=tf.nn.softmax),
        ]
    )

    # Compile Keras model ("learning_rate" replaces the deprecated "lr" argument)
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
    model.compile(
        loss="categorical_crossentropy", metrics=["accuracy"], optimizer=optimizer
    )

    return model


# Create the model
model = create_model(num_features=dataset_train._flat_shapes[0].dims[0].value)

# Set up datasets
dataset_train = dataset_train.batch(args.batch_size)
dataset_validation = dataset_validation.batch(args.batch_size)

# Train the model
model.fit(dataset_train, epochs=args.epochs, validation_data=dataset_validation)

tf.saved_model.save(model, os.getenv("AIP_MODEL_DIR"))
```
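A detail worth noting in the script: Vertex AI passes the dataset splits as `bq://PROJECT.DATASET.TABLE` URIs, but the BigQuery client's `get_table` expects the bare `PROJECT.DATASET.TABLE` ID, so `download_table` strips the scheme first. The same stripping logic in isolation (the helper name here is hypothetical):

```python
def strip_bq_prefix(bq_table_uri: str) -> str:
    # Remove the bq:// scheme so the ID can be passed to bigquery.Client.get_table
    prefix = "bq://"
    if bq_table_uri.startswith(prefix):
        bq_table_uri = bq_table_uri[len(prefix):]
    return bq_table_uri

print(strip_bq_prefix("bq://my-project.my_dataset.train"))  # my-project.my_dataset.train
```

URIs that arrive without the scheme pass through unchanged.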

After you create the script, it appears in the root folder of your notebook.

Define arguments for your training script

You pass the following command-line arguments to your training script:

  • `label_column` - This identifies the column in your data that contains what you want to predict. In this case, that column is `species`. You defined this in a variable named `LABEL_COLUMN` when you processed your data. For more information, see Download, preprocess, and split the data.

  • `epochs` - This is the number of epochs used when you train your model. An epoch is an iteration over the data when training your model. This tutorial uses 20 epochs.

  • `batch_size` - This is the number of samples that are processed before your model updates. This tutorial uses a batch size of 10.
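These flags correspond to the `argparse` definitions near the top of `task.py`. As a standalone sketch, the same parser can be exercised with the flags this tutorial passes:

```python
import argparse

# Mirror the parser defined in task.py
parser = argparse.ArgumentParser()
parser.add_argument('--label_column', required=True, type=str)
parser.add_argument('--epochs', default=10, type=int)
parser.add_argument('--batch_size', default=10, type=int)

# Parse the same flags this tutorial passes on the command line
args = parser.parse_args(
    ["--label_column=species", "--epochs=20", "--batch_size=10"]
)
print(args.label_column, args.epochs, args.batch_size)  # species 20 10
```

Because `--label_column` is marked `required=True`, omitting it makes the script exit with a usage error, while `--epochs` and `--batch_size` fall back to their defaults.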

To define the arguments that are passed to your script, run the following code:

```python
JOB_NAME = "custom_job_unique"

EPOCHS = 20
BATCH_SIZE = 10

CMDARGS = [
    "--label_column=" + LABEL_COLUMN,
    "--epochs=" + str(EPOCHS),
    "--batch_size=" + str(BATCH_SIZE),
]
```
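With `LABEL_COLUMN` set to `"species"` (from the earlier preprocessing step), `CMDARGS` resolves to a list of plain flag strings, which is the form the training container passes to `task.py` on its command line. A quick standalone check:

```python
LABEL_COLUMN = "species"  # defined earlier in this tutorial
EPOCHS = 20
BATCH_SIZE = 10

CMDARGS = [
    "--label_column=" + LABEL_COLUMN,
    "--epochs=" + str(EPOCHS),
    "--batch_size=" + str(BATCH_SIZE),
]
print(CMDARGS)  # ['--label_column=species', '--epochs=20', '--batch_size=10']
```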

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.

Last updated 2026-02-19 UTC.