Apache Beam RunInference for scikit-learn
This notebook demonstrates the use of the RunInference transform for scikit-learn, also called sklearn. Apache Beam RunInference has implementations of the ModelHandler class prebuilt for scikit-learn. For more information about using RunInference, see Get started with AI/ML pipelines in the Apache Beam documentation.
You can choose the appropriate model handler based on your input data type:

- NumPy model handler (SklearnModelHandlerNumpy)
- Pandas DataFrame model handler (SklearnModelHandlerPandas)

With RunInference, these model handlers manage batching, vectorization, and prediction optimization for your scikit-learn pipeline or model.
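For example, here is a minimal sketch of constructing each handler type. It assumes a pickled model file named sklearn_5x_model.pkl (created later in this notebook); the Pandas handler is shown only for comparison and isn't used elsewhere in this notebook.

```python
# Minimal sketch: choose the scikit-learn model handler that matches your input type.
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerPandas

# Examples passed to the model as numpy array-like values.
numpy_handler = SklearnModelHandlerNumpy(model_uri='sklearn_5x_model.pkl')

# Examples passed to the model as pandas DataFrames.
pandas_handler = SklearnModelHandlerPandas(model_uri='sklearn_5x_model.pkl')
```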
This notebook demonstrates the following common RunInference patterns:
- Generate predictions.
- Postprocess results after RunInference (a minimal sketch follows this list).
- Run inference with multiple models in the same pipeline.
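For instance, a minimal postprocessing sketch that pulls the predicted value out of each PredictionResult might look like the following. The extract_prediction helper is hypothetical and shown only to illustrate the pattern; it assumes unkeyed NumPy inputs like those used later in this notebook.

```python
# Minimal sketch (not from the original notebook): postprocess RunInference
# output by extracting the predicted value from each PredictionResult.
from apache_beam.ml.inference.base import PredictionResult

def extract_prediction(result: PredictionResult) -> float:
  # result.example holds the input example; result.inference holds the
  # model's prediction for that example.
  return float(result.inference[0])

# Used in a pipeline as, for example:
#   ... | RunInference(model_handler=sklearn_model_handler)
#       | beam.Map(extract_prediction)
```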
The linear regression models used in these samples are trained on data that corresponds to the 5 times and 10 times tables; that is, y = 5x and y = 10x, respectively.
Before you begin
Complete the following setup steps:
- Install dependencies for Apache Beam.
- Authenticate with Google Cloud.
- Specify your project and bucket. You use the project and bucket to save and load models.
```
pip install google-api-core --quiet
pip install google-cloud-pubsub google-cloud-bigquery-storage --quiet
pip install apache-beam[gcp,dataframe] --quiet
```
About scikit-learn versions
scikit-learn is a build dependency of Apache Beam. If you need to install a different version of sklearn, use %pip install scikit-learn==<version>.
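To confirm which scikit-learn version is active in the runtime before continuing, you can print it (a small check, not part of the original notebook):

```python
# Print the scikit-learn version installed in the current runtime.
import sklearn
print(sklearn.__version__)
```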
```python
from google.colab import auth
auth.authenticate_user()

import pickle
from sklearn import linear_model
from typing import Tuple

import numpy as np
import apache_beam as beam
from apache_beam.ml.inference.sklearn_inference import ModelFileType
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.options.pipeline_options import PipelineOptions

# NOTE: If an error occurs, restart your runtime.

import os

# Constants
project = "<PROJECT_ID>"  # @param {type:'string'}
bucket = "<BUCKET_NAME>"  # @param {type:'string'}

# To avoid warnings, set the project.
os.environ['GOOGLE_CLOUD_PROJECT'] = project
```

Create the data and the scikit-learn model
This section demonstrates the following steps:
- Create the data to train the scikit-learn linear regression model.
- Train the linear regression model.
- Save the scikit-learn model using pickle.
In this example, you create two models: one trained on the 5 times table and a second trained on the 10 times table.
```python
# Input data to train the sklearn model for the 5 times table.
x = np.arange(0, 100, dtype=np.float32).reshape(-1, 1)
y = (x * 5).reshape(-1, 1)

def train_and_save_model(x, y, model_file_name):
  regression = linear_model.LinearRegression()
  regression.fit(x, y)
  with open(model_file_name, 'wb') as f:
    pickle.dump(regression, f)

five_times_model_filename = 'sklearn_5x_model.pkl'
train_and_save_model(x, y, five_times_model_filename)

# Change y to be 10 times x, and train a model for the 10 times table.
y = (x * 10).reshape(-1, 1)
ten_times_model_filename = 'sklearn_10x_model.pkl'
train_and_save_model(x, y, ten_times_model_filename)
```
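As a quick sanity check, you can load one of the pickled models directly and confirm its predictions before wiring it into a pipeline (a small check, not part of the original notebook):

```python
# Optional sanity check (not in the original notebook): load the pickled
# 5 times model and confirm it predicts roughly y = 5x.
with open(five_times_model_filename, 'rb') as f:
  loaded_model = pickle.load(f)

# Expected output: approximately [[100.]]
print(loaded_model.predict(np.array([[20.0]], dtype=np.float32)))
```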
Create a scikit-learn RunInference pipeline

This section demonstrates how to do the following:
- Define a scikit-learn model handler that accepts an array_like object as input.
- Read the data from BigQuery.
- Use the scikit-learn trained model and the scikit-learn RunInference transform on unkeyed data.
```
%pip install --upgrade google-cloud-bigquery --quiet
gcloud config set project $project
```
```
Updated property [core/project].
```
```python
# Populated BigQuery table
from google.cloud import bigquery

client = bigquery.Client(project=project)

# Make sure the dataset_id is unique in your project.
dataset_id = '{project}.maths'.format(project=project)
dataset = bigquery.Dataset(dataset_id)

# Modify the location based on your project configuration.
dataset.location = 'US'
dataset = client.create_dataset(dataset, exists_ok=True)

# Table name in the BigQuery dataset.
table_name = 'maths_problems_1'

query = """
    CREATE OR REPLACE TABLE {project}.maths.{table} (
      key STRING OPTIONS(description="A unique key for the maths problem"),
      value FLOAT64 OPTIONS(description="Our maths problem")
    );
    INSERT INTO maths.{table}
    VALUES
      ("first_example", 105.00),
      ("second_example", 108.00),
      ("third_example", 1000.00),
      ("fourth_example", 1013.00)
""".format(project=project, table=table_name)

create_job = client.query(query)
create_job.result()
```
```
<google.cloud.bigquery.table._EmptyRowIterator at 0x7f97abb4e850>
```
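Optionally, you can read the rows back to confirm that the table was populated (a small check, not part of the original notebook):

```python
# Optional check (not in the original notebook): read the rows back to confirm
# the maths_problems_1 table contains the four example values.
verify_query = f"SELECT key, value FROM `{project}.maths.{table_name}` ORDER BY key"
for row in client.query(verify_query).result():
  print(row.key, row.value)
```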
```python
sklearn_model_handler = SklearnModelHandlerNumpy(model_uri=five_times_model_filename)

pipeline_options = PipelineOptions().from_dictionary(
    {'temp_location': f'gs://{bucket}/tmp'})

# Define the BigQuery table specification.
table_name = 'maths_problems_1'
table_spec = f'{project}:maths.{table_name}'

with beam.Pipeline(options=pipeline_options) as p:
  (
      p
      | "ReadFromBQ" >> beam.io.ReadFromBigQuery(table=table_spec)
      | "ExtractInputs" >> beam.Map(lambda x: [x['value']])
      | "RunInferenceSklearn" >> RunInference(model_handler=sklearn_model_handler)
      | beam.Map(print)
  )
```
```
PredictionResult(example=[1000.0], inference=array([5000.]))
PredictionResult(example=[1013.0], inference=array([5065.]))
PredictionResult(example=[108.0], inference=array([540.]))
PredictionResult(example=[105.0], inference=array([525.]))
```
Use sklearn RunInference on keyed inputs
This section demonstrates how to do the following:
- Wrap the SklearnModelHandlerNumpy object in a KeyedModelHandler to handle keyed data.
- Read the data from BigQuery.
- Use the sklearn trained model and the sklearn RunInference transform on keyed data.
```python
sklearn_model_handler = SklearnModelHandlerNumpy(model_uri=five_times_model_filename)
keyed_sklearn_model_handler = KeyedModelHandler(sklearn_model_handler)

pipeline_options = PipelineOptions().from_dictionary(
    {'temp_location': f'gs://{bucket}/tmp'})

with beam.Pipeline(options=pipeline_options) as p:
  (
      p
      | "ReadFromBQ" >> beam.io.ReadFromBigQuery(table=table_spec)
      | "ExtractInputs" >> beam.Map(lambda x: (x['key'], [x['value']]))
      | "RunInferenceSklearn" >> RunInference(model_handler=keyed_sklearn_model_handler)
      | beam.Map(print)
  )
```
```
('third_example', PredictionResult(example=[1000.0], inference=array([5000.])))
('fourth_example', PredictionResult(example=[1013.0], inference=array([5065.])))
('second_example', PredictionResult(example=[108.0], inference=array([540.])))
('first_example', PredictionResult(example=[105.0], inference=array([525.])))
```

Run multiple models
This code creates a pipeline that applies two RunInference transforms with different models to the same input and then flattens the combined output.
```python
from typing import Tuple

def format_output(run_inference_output) -> str:
  """Takes input from RunInference for scikit-learn and extracts the output."""
  key, prediction_result = run_inference_output
  example = prediction_result.example[0]
  prediction = prediction_result.inference[0]
  return f"key = {key}, example = {example} -> predictions {prediction}"

five_times_model_handler = KeyedModelHandler(
    SklearnModelHandlerNumpy(model_uri=five_times_model_filename))
ten_times_model_handler = KeyedModelHandler(
    SklearnModelHandlerNumpy(model_uri=ten_times_model_filename))

pipeline_options = PipelineOptions().from_dictionary(
    {'temp_location': f'gs://{bucket}/tmp'})

with beam.Pipeline(options=pipeline_options) as p:
  inputs = (p | "ReadFromBQ" >> beam.io.ReadFromBigQuery(table=table_spec))
  five_times = (
      inputs
      | "Extract For 5" >> beam.Map(lambda x: ('{} {}'.format(x['key'], '* 5'), [x['value']]))
      | "5 times" >> RunInference(model_handler=five_times_model_handler))
  ten_times = (
      inputs
      | "Extract For 10" >> beam.Map(lambda x: ('{} {}'.format(x['key'], '* 10'), [x['value']]))
      | "10 times" >> RunInference(model_handler=ten_times_model_handler))
  _ = ((five_times, ten_times)
       | "Flattened" >> beam.Flatten()
       | "format output" >> beam.Map(format_output)
       | "Print" >> beam.Map(print))
```
```
key = third_example * 10, example = 1000.0 -> predictions 10000.0
key = fourth_example * 10, example = 1013.0 -> predictions 10130.0
key = second_example * 10, example = 108.0 -> predictions 1080.0
key = first_example * 10, example = 105.0 -> predictions 1050.0
key = third_example * 5, example = 1000.0 -> predictions 5000.0
key = fourth_example * 5, example = 1013.0 -> predictions 5065.0
key = second_example * 5, example = 108.0 -> predictions 540.0
key = first_example * 5, example = 105.0 -> predictions 525.0
```