Update ML models in running pipelines


This notebook demonstrates how to perform automatic model updates without stopping your Apache Beam pipeline. You can use side inputs to update your model in real time, even while the Apache Beam pipeline is running. The side input is passed in a ModelHandler configuration object. You can update the model either by leveraging one of Apache Beam's provided patterns, such as the WatchFilePattern, or by configuring a custom side input PCollection that defines the logic for the model update.

The pipeline in this notebook uses a RunInference PTransform with TensorFlow machine learning (ML) models to run inference on images. To update the model, it uses a side input PCollection that emits ModelMetadata. For more information about side inputs, see the Side inputs section in the Apache Beam Programming Guide.

This example uses WatchFilePattern as a side input. WatchFilePattern watches for file updates that match the file_pattern, based on timestamps. It emits the latest ModelMetadata, which the RunInference PTransform uses to automatically update the ML model without stopping the Apache Beam pipeline.
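The selection rule described above (match a pattern, then prefer the newest timestamp) can be illustrated without Beam. The following is a minimal sketch in plain Python, not the WatchFilePattern implementation: the names `ModelMetadataSketch` and `latest_matching_model`, the `my-bucket` paths, and the in-memory file listing are all assumptions made for illustration.

```python
import fnmatch
from typing import Iterable, NamedTuple, Optional, Tuple


class ModelMetadataSketch(NamedTuple):
    """Hypothetical stand-in for the metadata that the side input emits."""
    model_id: str    # path that the model handler should load
    model_name: str  # short, human-readable name


def latest_matching_model(
    files: Iterable[Tuple[str, float]], file_pattern: str
) -> Optional[ModelMetadataSketch]:
    """Return metadata for the most recently modified file that matches."""
    matches = [
        (path, timestamp)
        for path, timestamp in files
        if fnmatch.fnmatchcase(path, file_pattern)
    ]
    if not matches:
        return None
    # The newest timestamp wins.
    path, _ = max(matches, key=lambda item: item[1])
    return ModelMetadataSketch(model_id=path, model_name=path.rsplit("/", 1)[-1])


# Assumed listing of (path, last-modified timestamp) pairs.
listing = [
    ("gs://my-bucket/dataflow/resnet101.keras", 100.0),
    ("gs://my-bucket/dataflow/resnet152.keras", 200.0),
    ("gs://my-bucket/dataflow/notes.txt", 300.0),
]
print(latest_matching_model(listing, "gs://my-bucket/dataflow/*.keras"))
```

In this sketch, `notes.txt` is ignored even though it is the newest file, because it does not match the pattern.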

Before you begin

Install the dependencies required to run this notebook.

To use RunInference with side inputs for automatic model updates, use Apache Beam version 2.46.0 or later.

pip install "apache_beam[interactive,gcp]>=2.46.0" tensorflow==2.15.0 tensorflow_hub==0.16.1 keras==2.15.0 Pillow==11.0.0 --quiet
# Imports required for the notebook.
import logging
import time
import os
from typing import Iterable
from typing import Tuple

import apache_beam as beam
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import WorkerOptions
from apache_beam.transforms.periodicsequence import PeriodicImpulse
import numpy
from PIL import Image
import tensorflow as tf

# Authenticate to your Google Cloud account.
def auth_to_colab():
  from google.colab import auth
  auth.authenticate_user()

auth_to_colab()

Configure the runner

This pipeline uses the Dataflow Runner. To run the pipeline, you need to complete the following tasks:

  • Ensure that you have all the required permissions to run the pipeline on Dataflow.
  • Configure the pipeline options for the pipeline to run on Dataflow. Make sure the pipeline is using streaming mode.

In the following code, replace BUCKET_NAME with the name of your Cloud Storage bucket.

options = PipelineOptions()
options.view_as(StandardOptions).streaming = True

# Replace with your bucket name.
BUCKET_NAME = '<BUCKET_NAME>'  # @param {type:'string'}
os.environ['BUCKET_NAME'] = BUCKET_NAME

# Provide required pipeline options for the Dataflow Runner.
options.view_as(StandardOptions).runner = "DataflowRunner"

# Set the project to the default project in your current Google Cloud environment.
PROJECT_NAME = '<PROJECT_NAME>'  # @param {type:'string'}
options.view_as(GoogleCloudOptions).project = PROJECT_NAME

# Set the Google Cloud region that you want to run Dataflow in.
options.view_as(GoogleCloudOptions).region = 'us-central1'

# IMPORTANT: Replace BUCKET_NAME with the name of your Cloud Storage bucket.
dataflow_gcs_location = "gs://%s/dataflow" % BUCKET_NAME

# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.
options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location

# The Dataflow temp location. This location is used to store temporary files or intermediate results before outputting to the sink.
options.view_as(GoogleCloudOptions).temp_location = '%s/temp' % dataflow_gcs_location

options.view_as(SetupOptions).save_main_session = True

# Launching Dataflow with only one worker might result in processing delays due to
# initial input processing. This could further postpone the side input model updates.
# To expedite the model update process, it's recommended to set num_workers>1.
# https://github.com/apache/beam/issues/28776
options.view_as(WorkerOptions).num_workers = 5

Install the tensorflow and tensorflow_hub dependencies on Dataflow. Use the requirements_file pipeline option to pass these dependencies.

# In a requirements file, define the dependencies required for the pipeline.
!printf 'tensorflow==2.15.0\ntensorflow_hub==0.16.1\nkeras==2.15.0\nPillow==11.0.0' > ./requirements.txt

# Install the pipeline dependencies on Dataflow.
options.view_as(SetupOptions).requirements_file = './requirements.txt'

Use the TensorFlow model handler

This example uses TFModelHandlerTensor as the model handler and the resnet_101 model trained on ImageNet.

For the Dataflow runner, you need to store the model in a remote location that the Apache Beam pipeline can access. For this example, download the ResNet101 model, and upload it to the Google Cloud Storage bucket.

model = tf.keras.applications.resnet.ResNet101()
model.save('resnet101_weights_tf_dim_ordering_tf_kernels.keras')
# After saving the model locally, upload the model to the GCS bucket and provide that bucket URI as `model_uri` to the `TFModelHandler`.
!gsutil cp resnet101_weights_tf_dim_ordering_tf_kernels.keras gs://${BUCKET_NAME}/dataflow/resnet101_weights_tf_dim_ordering_tf_kernels.keras

model_handler = TFModelHandlerTensor(
    model_uri=dataflow_gcs_location + "/resnet101_weights_tf_dim_ordering_tf_kernels.keras")

Preprocess images

Before running inference, use preprocess_image to read the image and convert it to a TensorFlow tensor.

def preprocess_image(image_name, image_dir):
  img = tf.keras.utils.get_file(image_name, image_dir + image_name)
  img = Image.open(img).resize((224, 224))
  img = numpy.array(img) / 255.0
  img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
  return img_tensor

class PostProcessor(beam.DoFn):
  """Process the PredictionResult to get the predicted label.
  Returns predicted label.
  """
  def process(self, element: PredictionResult) -> Iterable[Tuple[str, str]]:
    predicted_class = numpy.argmax(element.inference, axis=-1)
    labels_path = tf.keras.utils.get_file(
        'ImageNetLabels.txt',
        'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'  # pylint: disable=line-too-long
    )
    imagenet_labels = numpy.array(open(labels_path).read().splitlines())
    predicted_class_name = imagenet_labels[predicted_class]
    yield predicted_class_name.title(), element.model_id

# Define the pipeline object.
pipeline = beam.Pipeline(options=options)

Next, review the pipeline steps and examine the code.

Pipeline steps

  1. Create a PeriodicImpulse transform, which emits output every n seconds. The PeriodicImpulse transform generates an infinite sequence of elements with a given runtime interval.

    In this example, PeriodicImpulse mimics the Pub/Sub source. Because the inputs in a streaming pipeline arrive in intervals, use PeriodicImpulse to output elements at regular intervals. To learn more about PeriodicImpulse, see the PeriodicImpulse code.

start_timestamp = time.time()  # start timestamp of the periodic impulse
end_timestamp = start_timestamp + 60 * 20  # end timestamp of the periodic impulse (will run for 20 minutes).
main_input_fire_interval = 60  # interval in seconds at which the main input PCollection is emitted.
side_input_fire_interval = 60  # interval in seconds at which the side input PCollection is emitted.

periodic_impulse = (
    pipeline
    | "MainInputPcoll" >> PeriodicImpulse(
        start_timestamp=start_timestamp,
        stop_timestamp=end_timestamp,
        fire_interval=main_input_fire_interval))
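To make the timing concrete, the following plain-Python sketch lists the timestamps that a source firing every `fire_interval` seconds would produce between a start and a stop time. It only mirrors the arithmetic; it is not how PeriodicImpulse is implemented. The helper name `periodic_timestamps` is hypothetical, and treating the stop timestamp as exclusive is an assumption of this sketch, not a statement about Beam's behavior.

```python
from typing import Iterator


def periodic_timestamps(
    start: float, stop: float, fire_interval: float
) -> Iterator[float]:
    """Yield start, start + interval, ... while the timestamp is before stop."""
    if fire_interval <= 0:
        raise ValueError("fire_interval must be positive")
    tick = 0
    # Assumption: the stop boundary is treated as exclusive in this sketch.
    while start + tick * fire_interval < stop:
        yield start + tick * fire_interval
        tick += 1


# A 20-minute run that fires once per minute, starting at t=0 for readability.
ticks = list(periodic_timestamps(start=0, stop=60 * 20, fire_interval=60))
print(len(ticks), ticks[:3])
```

With the values used in this notebook (a 60-second interval over 20 minutes), the main input would carry roughly 20 elements for the whole run.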
  2. To read and preprocess the images, use the preprocess_image function. This example uses Cat-with-beanie.jpg for all inferences.

    Note: The image used for prediction is licensed in CC-BY. The creator is listed in the LICENSE.txt file.

image_data = (
    periodic_impulse
    | beam.Map(lambda x: "Cat-with-beanie.jpg")
    | "ReadImage" >> beam.Map(lambda image_name: preprocess_image(
        image_name=image_name,
        image_dir='https://storage.googleapis.com/apache-beam-samples/image_captioning/')))
  3. Pass the images to the RunInference PTransform. RunInference takes model_handler and model_metadata_pcoll as input parameters.
    • model_metadata_pcoll is a side input PCollection to the RunInference PTransform. This side input updates the model_uri in the model_handler while the Apache Beam pipeline runs.
    • Use WatchFilePattern as a side input to watch a file_pattern that matches .keras files. In this case, the file_pattern is 'gs://BUCKET_NAME/dataflow/*.keras'.
# The side input used to watch for the .keras file and update the model_uri of the TFModelHandlerTensor.
file_pattern = dataflow_gcs_location + '/*.keras'
side_input_pcoll = (
    pipeline
    | "WatchFilePattern" >> WatchFilePattern(
        file_pattern=file_pattern,
        interval=side_input_fire_interval,
        stop_timestamp=end_timestamp))

inferences = (
    image_data
    | "ApplyWindowing" >> beam.WindowInto(beam.window.FixedWindows(10))
    | "RunInference" >> RunInference(
        model_handler=model_handler,
        model_metadata_pcoll=side_input_pcoll))
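The code above places the main input into 10-second fixed windows before inference. As a rough illustration of what fixed windowing does to an element's timestamp, the following plain-Python sketch computes the window bounds for a given timestamp. The helper `fixed_window_for` is a hypothetical name, and the sketch assumes windows aligned to an offset of zero.

```python
from typing import Tuple


def fixed_window_for(
    timestamp: float, size: float, offset: float = 0.0
) -> Tuple[float, float]:
    """Return the [start, end) bounds of the fixed window containing timestamp."""
    if size <= 0:
        raise ValueError("size must be positive")
    start = timestamp - (timestamp - offset) % size
    return start, start + size


# Elements that arrive 60 seconds apart land in different 10-second windows.
for ts in (60.0, 65.5, 120.0):
    print(ts, fixed_window_for(ts, size=10))
```

Because the main input fires once every 60 seconds and the window is 10 seconds wide, each image in this example ends up in a window of its own.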
  4. Post-process the PredictionResult object. When the inference is complete, RunInference outputs a PredictionResult object that contains the fields example, inference, and model_id. The model_id field identifies the model used to run the inference. The PostProcessor returns the predicted label and the model ID used to run the inference on the predicted label.
post_processor = (
    inferences
    | "PostProcessResults" >> beam.ParDo(PostProcessor())
    | "LogResults" >> beam.Map(logging.info))
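The post-processing step reduces a vector of class scores to one label and pairs it with the model that produced it. The following plain-Python sketch shows that mapping on toy data. `PredictionSketch` is a hypothetical stand-in with the same three field names that the text lists for PredictionResult; the three-entry label list, the scores, and the `my-bucket` path are made up for illustration.

```python
from typing import NamedTuple, Sequence, Tuple


class PredictionSketch(NamedTuple):
    """Hypothetical stand-in with the fields example, inference, and model_id."""
    example: str
    inference: Sequence[float]
    model_id: str


def top_label(
    result: PredictionSketch, labels: Sequence[str]
) -> Tuple[str, str]:
    """Return the highest-scoring label and the ID of the model that scored it."""
    best_index = max(
        range(len(result.inference)), key=lambda i: result.inference[i]
    )
    return labels[best_index].title(), result.model_id


# Toy data: three classes instead of the ImageNet label set.
labels = ["background", "tabby cat", "beanie"]
result = PredictionSketch(
    example="Cat-with-beanie.jpg",
    inference=[0.1, 0.7, 0.2],
    model_id="gs://my-bucket/dataflow/resnet101.keras",
)
print(top_label(result, labels))
```

Keeping the model ID next to the label is what makes a model update visible in the logs: the label may stay the same while the second element of the tuple changes.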

Watch for the model update

After the pipeline starts processing data, when you see output emitted from the RunInference PTransform, upload a resnet152 model saved in the .keras format to a Google Cloud Storage bucket location that matches the file_pattern you defined earlier.

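model = tf.keras.applications.resnet.ResNet152()
model.save('resnet152_weights_tf_dim_ordering_tf_kernels.keras')
# Upload under the dataflow/ prefix so that the file matches the file_pattern gs://BUCKET_NAME/dataflow/*.keras.
!gsutil cp resnet152_weights_tf_dim_ordering_tf_kernels.keras gs://${BUCKET_NAME}/dataflow/resnet152_weights_tf_dim_ordering_tf_kernels.keras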
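Conceptually, once the newer file is detected, the running transform swaps the loaded model for the new one and later inferences report the new model ID. The following plain-Python sketch shows that reload-on-change idea in isolation. It is an illustration under stated assumptions, not RunInference's internals: `HotSwapModel`, `fake_loader`, and the `my-bucket` paths are hypothetical, and the "models" are plain functions.

```python
from typing import Callable, Optional, Tuple

Model = Callable[[float], float]


class HotSwapModel:
    """Holds one loaded model and reloads it only when the model ID changes."""

    def __init__(self, load_fn: Callable[[str], Model], initial_model_id: str):
        self._load_fn = load_fn
        self.model_id = initial_model_id
        self._model = load_fn(initial_model_id)
        self.load_count = 1

    def maybe_update(self, latest_model_id: Optional[str]) -> bool:
        """Reload if a different model ID arrives; return True when a swap happened."""
        if latest_model_id is None or latest_model_id == self.model_id:
            return False
        self._model = self._load_fn(latest_model_id)
        self.model_id = latest_model_id
        self.load_count += 1
        return True

    def predict(self, example: float) -> Tuple[float, str]:
        """Return the prediction together with the ID of the model that made it."""
        return self._model(example), self.model_id


def fake_loader(model_id: str) -> Model:
    # Hypothetical loader: the "model" just scales its input.
    scale = 152 if "resnet152" in model_id else 101
    return lambda x: x * scale


runner = HotSwapModel(fake_loader, "gs://my-bucket/dataflow/resnet101.keras")
print(runner.predict(2))
runner.maybe_update("gs://my-bucket/dataflow/resnet152.keras")
print(runner.predict(2))
```

The comparison against the current model ID keeps repeated emissions of the same metadata from triggering a reload.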

Run the pipeline

Use the following code to run the pipeline.

# Run the pipeline.
result = pipeline.run().wait_until_finish()

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.

Last updated 2025-10-22 UTC.