Create a training script
In this topic, you create the training script, then specify command arguments for your training script. In a later topic, you use the run method to run the script.
Create a training script
In this section, you create a training script. This script is a new file in your notebook environment named task.py. Later in this tutorial, you pass this script to the aiplatform.CustomTrainingJob constructor. When the script runs, it does the following:
Loads the data in the BigQuery dataset you created.
Uses the TensorFlow Keras API to build, compile, and train your model.
Specifies the number of epochs and the batch size to use when the Keras Model.fit method is invoked.
Specifies where to save model artifacts using the AIP_MODEL_DIR environment variable. AIP_MODEL_DIR is set by Vertex AI and contains the URI of a directory for saving model artifacts. For more information, see Environment variables for special Cloud Storage directories.
Exports a TensorFlow SavedModel to the model directory. For more information, see Using the SavedModel format on the TensorFlow website. (A minimal sketch of this export pattern follows the list.)
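To make that export contract concrete, here is a minimal, self-contained sketch of the pattern the script uses. The local fallback directory model_output and the stand-in model are illustrative assumptions, not part of the tutorial:

```python
import os

import tensorflow as tf

# Vertex AI sets AIP_MODEL_DIR to a Cloud Storage URI (gs://...). The local
# fallback "model_output" is an assumption so this sketch also runs outside
# a training job.
model_dir = os.getenv("AIP_MODEL_DIR", "model_output")

# A trivial stand-in model, just to demonstrate the export call.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])

# Exporting in the SavedModel format to AIP_MODEL_DIR is what lets
# Vertex AI locate the trained artifacts after the job finishes.
tf.saved_model.save(model, model_dir)
```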
To create your training script, run the following code in your notebook:
```python
%%writefile task.py

import argparse
import numpy as np
import os
import pandas as pd
import tensorflow as tf

from google.cloud import bigquery
from google.cloud import storage

# Read environmental variables
training_data_uri = os.getenv("AIP_TRAINING_DATA_URI")
validation_data_uri = os.getenv("AIP_VALIDATION_DATA_URI")
test_data_uri = os.getenv("AIP_TEST_DATA_URI")

# Read args
parser = argparse.ArgumentParser()
parser.add_argument('--label_column', required=True, type=str)
parser.add_argument('--epochs', default=10, type=int)
parser.add_argument('--batch_size', default=10, type=int)
args = parser.parse_args()

# Set up training variables
LABEL_COLUMN = args.label_column

# See https://cloud.google.com/vertex-ai/docs/workbench/managed/executor#explicit-project-selection for issues regarding permissions.
PROJECT_NUMBER = os.environ["CLOUD_ML_PROJECT_ID"]
bq_client = bigquery.Client(project=PROJECT_NUMBER)


# Download a table
def download_table(bq_table_uri: str):
    # Remove bq:// prefix if present
    prefix = "bq://"
    if bq_table_uri.startswith(prefix):
        bq_table_uri = bq_table_uri[len(prefix):]

    # Download the BigQuery table as a dataframe
    # This requires the "BigQuery Read Session User" role on the custom training service account.
    table = bq_client.get_table(bq_table_uri)
    return bq_client.list_rows(table).to_dataframe()


# Download dataset splits
df_train = download_table(training_data_uri)
df_validation = download_table(validation_data_uri)
df_test = download_table(test_data_uri)


def convert_dataframe_to_dataset(
    df_train: pd.DataFrame,
    df_validation: pd.DataFrame,
):
    df_train_x, df_train_y = df_train, df_train.pop(LABEL_COLUMN)
    df_validation_x, df_validation_y = df_validation, df_validation.pop(LABEL_COLUMN)

    y_train = tf.convert_to_tensor(np.asarray(df_train_y).astype("float32"))
    y_validation = tf.convert_to_tensor(np.asarray(df_validation_y).astype("float32"))

    # Convert to numpy representation
    x_train = tf.convert_to_tensor(np.asarray(df_train_x).astype("float32"))
    x_test = tf.convert_to_tensor(np.asarray(df_validation_x).astype("float32"))

    # Convert to one-hot representation
    num_species = len(df_train_y.unique())
    y_train = tf.keras.utils.to_categorical(y_train, num_classes=num_species)
    y_validation = tf.keras.utils.to_categorical(y_validation, num_classes=num_species)

    dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset_validation = tf.data.Dataset.from_tensor_slices((x_test, y_validation))
    return (dataset_train, dataset_validation)


# Create datasets
dataset_train, dataset_validation = convert_dataframe_to_dataset(df_train, df_validation)

# Shuffle train set
dataset_train = dataset_train.shuffle(len(df_train))


def create_model(num_features):
    # Create model
    Dense = tf.keras.layers.Dense
    model = tf.keras.Sequential(
        [
            Dense(
                100,
                activation=tf.nn.relu,
                kernel_initializer="uniform",
                input_dim=num_features,
            ),
            Dense(75, activation=tf.nn.relu),
            Dense(50, activation=tf.nn.relu),
            Dense(25, activation=tf.nn.relu),
            Dense(3, activation=tf.nn.softmax),
        ]
    )

    # Compile Keras model
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
    model.compile(
        loss="categorical_crossentropy", metrics=["accuracy"], optimizer=optimizer
    )

    return model


# Create the model
model = create_model(num_features=dataset_train._flat_shapes[0].dims[0].value)

# Set up datasets
dataset_train = dataset_train.batch(args.batch_size)
dataset_validation = dataset_validation.batch(args.batch_size)

# Train the model
model.fit(dataset_train, epochs=args.epochs, validation_data=dataset_validation)

tf.saved_model.save(model, os.getenv("AIP_MODEL_DIR"))
```

After you create the script, it appears in the root folder of your notebook.
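If you want to sanity-check the DataFrame-to-Dataset conversion without touching BigQuery, the following sketch applies the same pop-the-label, one-hot, and tf.data pattern to a small synthetic DataFrame. The column names and values here are made up for illustration:

```python
import numpy as np
import pandas as pd
import tensorflow as tf

# Hypothetical stand-in for a downloaded BigQuery split: four numeric
# features plus an integer-coded label column named "species".
df = pd.DataFrame(
    {
        "feature_a": np.random.rand(8).astype("float32"),
        "feature_b": np.random.rand(8).astype("float32"),
        "feature_c": np.random.rand(8).astype("float32"),
        "feature_d": np.random.rand(8).astype("float32"),
        "species": np.random.randint(0, 3, size=8),
    }
)

# Same pattern as task.py: pop the label, one-hot encode it, build a Dataset.
labels = df.pop("species")
y = tf.keras.utils.to_categorical(labels, num_classes=3)
x = np.asarray(df).astype("float32")

dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(len(df)).batch(4)
for features, one_hot in dataset.take(1):
    print(features.shape, one_hot.shape)  # expected: (4, 4) and (4, 3)
```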
Define arguments for your training script
You pass the following command-line arguments to your training script:
label_column - This identifies the column in your data that contains what you want to predict. In this case, that column is species. You defined this in a variable named LABEL_COLUMN when you processed your data. For more information, see Download, preprocess, and split the data.
epochs - This is the number of epochs used when you train your model. An epoch is an iteration over the data when training your model. This tutorial uses 20 epochs.
batch_size - This is the number of samples that are processed before your model updates. This tutorial uses a batch size of 10.
To define the arguments that are passed to your script, run the following code:
```python
JOB_NAME = "custom_job_unique"

EPOCHS = 20
BATCH_SIZE = 10

CMDARGS = [
    "--label_column=" + LABEL_COLUMN,
    "--epochs=" + str(EPOCHS),
    "--batch_size=" + str(BATCH_SIZE),
]
```
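If you want to confirm that these arguments parse the way task.py will see them, you can run the same parser against CMDARGS in a notebook cell. This is a quick sanity check; it duplicates the parser from task.py rather than importing it, because importing task.py would trigger its BigQuery calls:

```python
import argparse

# Mirror of the argument parser defined in task.py.
parser = argparse.ArgumentParser()
parser.add_argument('--label_column', required=True, type=str)
parser.add_argument('--epochs', default=10, type=int)
parser.add_argument('--batch_size', default=10, type=int)

# Parse the notebook's CMDARGS list exactly as the training job would.
args = parser.parse_args(CMDARGS)
print(args.label_column, args.epochs, args.batch_size)
# Expected: species 20 10, assuming LABEL_COLUMN is "species".
```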