Update ML models in running pipelines
This notebook demonstrates how to perform automatic model updates without stopping your Apache Beam pipeline. You can use side inputs to update your model in real time, even while the Apache Beam pipeline is running. The side input is passed in a ModelHandler configuration object. You can update the model either by leveraging one of Apache Beam's provided patterns, such as the WatchFilePattern, or by configuring a custom side input PCollection that defines the logic for the model update.

The pipeline in this notebook uses a RunInference PTransform with TensorFlow machine learning (ML) models to run inference on images. To update the model, it uses a side input PCollection that emits ModelMetadata. For more information about side inputs, see the Side inputs section in the Apache Beam Programming Guide.

This example uses WatchFilePattern as a side input. WatchFilePattern is used to watch for file updates that match the file_pattern based on timestamps. It emits the latest ModelMetadata, which is used in the RunInference PTransform to automatically update the ML model without stopping the Apache Beam pipeline.
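This notebook relies on WatchFilePattern, but if you need custom update logic, you can build the side input PCollection yourself. The following is a minimal, hypothetical sketch (not part of this notebook's pipeline) of a custom side input that emits ModelMetadata on a timer. The bucket path, polling interval, and update logic are placeholders, and the side input must still satisfy Beam's side-input windowing requirements.

# Hypothetical sketch: a custom side input that emits ModelMetadata.
# Not part of this notebook's pipeline; paths and intervals are placeholders.
import apache_beam as beam
from apache_beam.ml.inference.base import ModelMetadata
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.periodicsequence import PeriodicImpulse

def custom_model_update_side_input(pipeline):
  return (
      pipeline
      # Fire every 5 minutes (assumed polling interval).
      | "ModelUpdateImpulse" >> PeriodicImpulse(fire_interval=5 * 60)
      # In a real pipeline, this Map would resolve the current model path,
      # for example by reading a config file or listing a bucket.
      | "ToModelMetadata" >> beam.Map(
          lambda _: ModelMetadata(
              model_id='gs://YOUR_BUCKET/models/latest.keras',  # placeholder path
              model_name='latest'))
      # Re-window into the global window with a repeated trigger so that the
      # side input refreshes as new ModelMetadata elements arrive
      # (the slowly updating side input pattern).
      | "GlobalWindow" >> beam.WindowInto(
          window.GlobalWindows(),
          trigger=trigger.Repeatedly(trigger.AfterCount(1)),
          accumulation_mode=trigger.AccumulationMode.DISCARDING))

# The resulting PCollection would be passed to RunInference as model_metadata_pcoll,
# the same way the WatchFilePattern output is used later in this notebook.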
Before you begin
Install the dependencies required to run this notebook.
To use RunInference with side inputs for automatic model updates, use Apache Beam version 2.46.0 or later.
!pip install 'apache_beam[interactive,gcp]>=2.46.0' tensorflow==2.15.0 tensorflow_hub==0.16.1 keras==2.15.0 Pillow==11.0.0 --quiet

# Imports required for the notebook.
import logging
import time
import os
from typing import Iterable
from typing import Tuple

import apache_beam as beam
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import WorkerOptions
from apache_beam.transforms.periodicsequence import PeriodicImpulse
import numpy
from PIL import Image
import tensorflow as tf

# Authenticate to your Google Cloud account.
def auth_to_colab():
  from google.colab import auth
  auth.authenticate_user()

auth_to_colab()

Configure the runner
This pipeline uses the Dataflow Runner. To run the pipeline, you need to complete the following tasks:
- Ensure that you have all the required permissions to run the pipeline on Dataflow.
- Configure the pipeline options for the pipeline to run on Dataflow. Make sure the pipeline is using streaming mode.
In the following code, replace BUCKET_NAME with the name of your Cloud Storage bucket.
options = PipelineOptions()
options.view_as(StandardOptions).streaming = True

# Replace with your bucket name.
BUCKET_NAME = '<BUCKET_NAME>'  # @param {type:'string'}
os.environ['BUCKET_NAME'] = BUCKET_NAME

# Provide required pipeline options for the Dataflow Runner.
options.view_as(StandardOptions).runner = "DataflowRunner"

# Set the project to the default project in your current Google Cloud environment.
PROJECT_NAME = '<PROJECT_NAME>'  # @param {type:'string'}
options.view_as(GoogleCloudOptions).project = PROJECT_NAME

# Set the Google Cloud region that you want to run Dataflow in.
options.view_as(GoogleCloudOptions).region = 'us-central1'

# IMPORTANT: Replace BUCKET_NAME with the name of your Cloud Storage bucket.
dataflow_gcs_location = "gs://%s/dataflow" % BUCKET_NAME

# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.
options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location

# The Dataflow temp location. This location is used to store temporary files or intermediate results before outputting to the sink.
options.view_as(GoogleCloudOptions).temp_location = '%s/temp' % dataflow_gcs_location

options.view_as(SetupOptions).save_main_session = True

# Launching Dataflow with only one worker might result in processing delays due to
# initial input processing. This could further postpone the side input model updates.
# To expedite the model update process, it's recommended to set num_workers > 1.
# https://github.com/apache/beam/issues/28776
options.view_as(WorkerOptions).num_workers = 5

Install the tensorflow and tensorflow_hub dependencies on Dataflow. Use the requirements_file pipeline option to pass these dependencies.
# In a requirements file, define the dependencies required for the pipeline.
!printf 'tensorflow==2.15.0\ntensorflow_hub==0.16.1\nkeras==2.15.0\nPillow==11.0.0' > ./requirements.txt

# Install the pipeline dependencies on Dataflow.
options.view_as(SetupOptions).requirements_file = './requirements.txt'

Use the TensorFlow model handler
This example uses TFModelHandlerTensor as the model handler and the resnet_101 model trained on ImageNet.

For the Dataflow runner, you need to store the model in a remote location that the Apache Beam pipeline can access. For this example, download the ResNet101 model and upload it to the Google Cloud Storage bucket.
model = tf.keras.applications.resnet.ResNet101()
model.save('resnet101_weights_tf_dim_ordering_tf_kernels.keras')

# After saving the model locally, upload the model to the Cloud Storage bucket and
# provide that bucket URI as `model_uri` to the `TFModelHandler`.
!gsutil cp resnet101_weights_tf_dim_ordering_tf_kernels.keras gs://${BUCKET_NAME}/dataflow/resnet101_weights_tf_dim_ordering_tf_kernels.keras

model_handler = TFModelHandlerTensor(
    model_uri=dataflow_gcs_location + "/resnet101_weights_tf_dim_ordering_tf_kernels.keras")

Preprocess images
Use the preprocess_image function to read each image and convert it to a TensorFlow tensor to use for inference.
def preprocess_image(image_name, image_dir):
  img = tf.keras.utils.get_file(image_name, image_dir + image_name)
  img = Image.open(img).resize((224, 224))
  img = numpy.array(img) / 255.0
  img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
  return img_tensor


class PostProcessor(beam.DoFn):
  """Process the PredictionResult to get the predicted label.

  Returns predicted label.
  """
  def process(self, element: PredictionResult) -> Iterable[Tuple[str, str]]:
    predicted_class = numpy.argmax(element.inference, axis=-1)
    labels_path = tf.keras.utils.get_file(
        'ImageNetLabels.txt',
        'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'  # pylint: disable=line-too-long
    )
    imagenet_labels = numpy.array(open(labels_path).read().splitlines())
    predicted_class_name = imagenet_labels[predicted_class]
    yield predicted_class_name.title(), element.model_id


# Define the pipeline object.
pipeline = beam.Pipeline(options=options)
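As an optional local sanity check (assuming the public sample bucket is reachable from your environment), you can call preprocess_image directly and confirm the tensor shape and dtype that the model handler will receive:

# Optional local check: preprocess_image should return a 224 x 224 x 3 float32 tensor.
sample_tensor = preprocess_image(
    image_name='Cat-with-beanie.jpg',
    image_dir='https://storage.googleapis.com/apache-beam-samples/image_captioning/')
print(sample_tensor.shape, sample_tensor.dtype)  # Expected: (224, 224, 3) float32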
Next, review the pipeline steps and examine the code.

Pipeline steps
- Create a PeriodicImpulse transform, which emits output every n seconds. The PeriodicImpulse transform generates an infinite sequence of elements with a given runtime interval. In this example, PeriodicImpulse mimics the Pub/Sub source. Because the inputs in a streaming pipeline arrive in intervals, use PeriodicImpulse to output elements at m intervals. To learn more about PeriodicImpulse, see the PeriodicImpulse code.
start_timestamp = time.time()  # start timestamp of the periodic impulse
end_timestamp = start_timestamp + 60 * 20  # end timestamp of the periodic impulse (runs for 20 minutes)
main_input_fire_interval = 60  # interval in seconds at which the main input PCollection is emitted
side_input_fire_interval = 60  # interval in seconds at which the side input PCollection is emitted

periodic_impulse = (
    pipeline
    | "MainInputPcoll" >> PeriodicImpulse(
        start_timestamp=start_timestamp,
        stop_timestamp=end_timestamp,
        fire_interval=main_input_fire_interval))

- To read and preprocess the images, use the preprocess_image function. This example uses Cat-with-beanie.jpg for all inferences.

Note: The image used for prediction is licensed in CC-BY. The creator is listed in the LICENSE.txt file.
image_data = (
    periodic_impulse
    | beam.Map(lambda x: "Cat-with-beanie.jpg")
    | "ReadImage" >> beam.Map(
        lambda image_name: preprocess_image(
            image_name=image_name,
            image_dir='https://storage.googleapis.com/apache-beam-samples/image_captioning/')))

- Pass the images to the RunInference PTransform. RunInference takes model_handler and model_metadata_pcoll as input parameters. model_metadata_pcoll is a side input PCollection to the RunInference PTransform. This side input updates the model_uri in the model_handler while the Apache Beam pipeline runs.
- Use WatchFilePattern as a side input to watch for a file_pattern that matches .keras files. In this case, the file_pattern is 'gs://BUCKET_NAME/dataflow/*.keras'.
# The side input used to watch for the .keras file and update the model_uri of the TFModelHandlerTensor.
file_pattern = dataflow_gcs_location + '/*.keras'

side_input_pcoll = (
    pipeline
    | "WatchFilePattern" >> WatchFilePattern(
        file_pattern=file_pattern,
        interval=side_input_fire_interval,
        stop_timestamp=end_timestamp))

inferences = (
    image_data
    | "ApplyWindowing" >> beam.WindowInto(beam.window.FixedWindows(10))
    | "RunInference" >> RunInference(
        model_handler=model_handler,
        model_metadata_pcoll=side_input_pcoll))

- Post-process the PredictionResult object. When the inference is complete, RunInference outputs a PredictionResult object that contains the fields example, inference, and model_id. The model_id field identifies the model used to run the inference. The PostProcessor returns the predicted label and the model ID used to run the inference on the predicted label.
post_processor = (
    inferences
    | "PostProcessResults" >> beam.ParDo(PostProcessor())
    | "LogResults" >> beam.Map(logging.info))

Watch for the model update
After the pipeline starts processing data and you see output emitted from the RunInference PTransform, upload a resnet152 model saved in the .keras format to a Google Cloud Storage bucket location that matches the file_pattern you defined earlier.
model = tf.keras.applications.resnet.ResNet152()
model.save('resnet152_weights_tf_dim_ordering_tf_kernels.keras')
!gsutil cp resnet152_weights_tf_dim_ordering_tf_kernels.keras gs://${BUCKET_NAME}/dataflow/resnet152_weights_tf_dim_ordering_tf_kernels.keras
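Optionally, you can confirm that the uploaded file is visible under the watched file_pattern (this assumes gsutil is available in your notebook environment):

# Optional: list the .keras files that match the watched file_pattern.
!gsutil ls gs://${BUCKET_NAME}/dataflow/*.keras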
Run the pipeline

Use the following code to run the pipeline.
# Run the pipeline.
result = pipeline.run().wait_until_finish()