
Retrain a model

2024-12-19

Learn how to retrain a machine learning model in ML.NET.

The world and the data that describe it change constantly, so models need to be updated as well. ML.NET can retrain a model by using its learned model parameters as the starting point. Each round of training then builds on previous experience instead of starting from scratch.

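The following algorithms are retrainable in ML.NET:

- AveragedPerceptronTrainer
- FieldAwareFactorizationMachineTrainer
- LbfgsLogisticRegressionBinaryTrainer
- LbfgsMaximumEntropyMulticlassTrainer
- LbfgsPoissonRegressionTrainer
- LinearSvmTrainer
- OnlineGradientDescentTrainer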

Load pretrained model

First, load the pretrained model into your application. To learn more about loading training pipelines and models, see Save and load a trained model.

```csharp
// Namespaces used by the code samples in this article
using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

// Create MLContext
MLContext mlContext = new MLContext();

// Define DataViewSchema of data prep pipeline and trained model
DataViewSchema dataPrepPipelineSchema, modelSchema;

// Load data preparation pipeline
ITransformer dataPrepPipeline = mlContext.Model.Load("data_preparation_pipeline.zip", out dataPrepPipelineSchema);

// Load trained model
ITransformer trainedModel = mlContext.Model.Load("ogd_model.zip", out modelSchema);
```

Extract pretrained model parameters

Once the model is loaded, extract the learned model parameters by accessing the Model property of the pretrained model. The pretrained model was trained using the linear regression trainer OnlineGradientDescentTrainer. This trainer creates a RegressionPredictionTransformer that outputs LinearRegressionModelParameters. These model parameters contain the learned bias and weights (coefficients) of the model, and they serve as the starting point for the retrained model.

```csharp
// Extract trained model parameters
LinearRegressionModelParameters originalModelParameters =
    ((ISingleFeaturePredictionTransformer<object>)trainedModel).Model as LinearRegressionModelParameters;
```

Note

The type of model parameters a trainer outputs depends on the algorithm used. For example, OnlineGradientDescentTrainer outputs LinearRegressionModelParameters, while LbfgsMaximumEntropyMulticlassTrainer outputs MaximumEntropyModelParameters. When extracting model parameters, cast to the appropriate type.
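
To make "learned bias and weights" concrete, the following minimal sketch shows how a linear regression model uses them to score an input: the prediction is the bias plus the weighted sum of the features. It is plain C# without ML.NET. The bias, weights, and feature values are made-up numbers, and the `Predict` helper is hypothetical. None of these values come from `ogd_model.zip` or from the LinearRegressionModelParameters API.

```csharp
using System;
using System.Linq;

// Hypothetical learned parameters for a model with three features.
// These numbers are invented for illustration only.
float bias = 0.5f;
float[] weights = { 2.0f, -1.0f, 0.25f };

// A linear regression model scores an input as: bias + sum(weights[i] * features[i]).
float Predict(float b, float[] w, float[] features) =>
    b + w.Zip(features, (wi, xi) => wi * xi).Sum();

// 0.5 + (2.0 * 1.0) + (-1.0 * 2.0) + (0.25 * 4.0) = 1.5
float[] example = { 1.0f, 2.0f, 4.0f };
float score = Predict(bias, weights, example);
Console.WriteLine($"Predicted score: {score}");
```

Retraining changes only these numbers. The structure of the prediction stays the same, which is why the original bias and weights are a usable starting point.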

Retrain a model

The process for retraining a model is no different from that of training a model. The only difference is that you pass an additional argument to the Fit(IDataView, LinearModelParameters) method: the original learned model parameters. Fit() uses them as the starting point for the retraining process.

```csharp
// New data (HousingData is the input class used to train the original model, not shown here)
HousingData[] housingData = new HousingData[]
{
    new HousingData
    {
        Size = 850f,
        HistoricalPrices = new float[] { 150000f, 175000f, 210000f },
        CurrentPrice = 205000f
    },
    new HousingData
    {
        Size = 900f,
        HistoricalPrices = new float[] { 155000f, 190000f, 220000f },
        CurrentPrice = 210000f
    },
    new HousingData
    {
        Size = 550f,
        HistoricalPrices = new float[] { 99000f, 98000f, 130000f },
        CurrentPrice = 180000f
    }
};

// Load new data
IDataView newData = mlContext.Data.LoadFromEnumerable<HousingData>(housingData);

// Preprocess new data with the same data preparation pipeline as the original training data
IDataView transformedNewData = dataPrepPipeline.Transform(newData);

// Retrain model, starting from the original learned parameters
RegressionPredictionTransformer<LinearRegressionModelParameters> retrainedModel =
    mlContext.Regression.Trainers.OnlineGradientDescent()
        .Fit(transformedNewData, originalModelParameters);
```

At this point, you can save your retrained model and use it in your application. For more information, see Save and load a trained model and Make predictions with a trained model.
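
The reason passing the original parameters to Fit() helps can be illustrated without ML.NET. The following self-contained sketch uses plain online (per-example) gradient descent on squared error for a single-feature linear model. It compares one extra pass that starts from previously learned parameters (a warm start) with the same single pass started from zero (a cold start).

This is a toy illustration under stated assumptions, not how ML.NET's OnlineGradientDescentTrainer is implemented internally. The data (generated from y = 3x + 1), the learning rate of 0.05, the five initial passes, and the helper names `TrainOnePass` and `MeanSquaredError` are all hypothetical.

```csharp
using System;

// Hypothetical toy data generated from y = 3x + 1 (no noise), for illustration only.
double[] xs = { 0.0, 1.0, 2.0, 3.0 };
double[] ys = { 1.0, 4.0, 7.0, 10.0 };
const double learningRate = 0.05;

// One pass of online gradient descent on squared error: the weight and bias are
// updated after every example, starting from the values passed in.
(double W, double B) TrainOnePass(double w, double b, double[] x, double[] y)
{
    for (int i = 0; i < x.Length; i++)
    {
        double error = (w * x[i] + b) - y[i];
        w -= learningRate * error * x[i];
        b -= learningRate * error;
    }
    return (w, b);
}

// Mean squared error of the model y = W * x + B on the given data.
double MeanSquaredError((double W, double B) p, double[] x, double[] y)
{
    double sum = 0.0;
    for (int i = 0; i < x.Length; i++)
    {
        double error = (p.W * x[i] + p.B) - y[i];
        sum += error * error;
    }
    return sum / x.Length;
}

// "Original" model: five passes starting from zero.
(double W, double B) original = (0.0, 0.0);
for (int pass = 0; pass < 5; pass++)
{
    original = TrainOnePass(original.W, original.B, xs, ys);
}

// Warm start: one more pass that begins at the original learned parameters.
(double W, double B) warm = TrainOnePass(original.W, original.B, xs, ys);

// Cold start, for comparison: the same single pass, but beginning at zero.
(double W, double B) cold = TrainOnePass(0.0, 0.0, xs, ys);

double warmMse = MeanSquaredError(warm, xs, ys);
double coldMse = MeanSquaredError(cold, xs, ys);
Console.WriteLine($"Warm start: w = {warm.W:F3}, b = {warm.B:F3}, MSE = {warmMse:F4}");
Console.WriteLine($"Cold start: w = {cold.W:F3}, b = {cold.B:F3}, MSE = {coldMse:F4}");
```

With the same amount of new training work, the warm-started model ends up with a much lower error. It keeps what the original model already learned and refines it, rather than relearning it.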

Compare model parameters

How do you know whether retraining actually happened? One way is to check whether the retrained model's parameters differ from those of the original model. The following code sample compares the original and retrained model weights and writes them to the console.

```csharp
// Extract model parameters of the retrained model
// (retrainedModel.Model is already typed as LinearRegressionModelParameters)
LinearRegressionModelParameters retrainedModelParameters = retrainedModel.Model;

// Inspect change in weights
var weightDiffs =
    originalModelParameters.Weights.Zip(
        retrainedModelParameters.Weights, (original, retrained) => original - retrained).ToArray();

Console.WriteLine("Original | Retrained | Difference");
for (int i = 0; i < weightDiffs.Length; i++)
{
    Console.WriteLine($"{originalModelParameters.Weights[i]} | {retrainedModelParameters.Weights[i]} | {weightDiffs[i]}");
}
```

The following table shows what the output might look like.

| Original | Retrained | Difference |
|---|---|---|
| 33039.86 | 56293.76 | -23253.9 |
| 29099.14 | 49586.03 | -20486.89 |
| 28938.38 | 48609.23 | -19670.85 |
| 30484.02 | 53745.43 | -23261.41 |
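
As a check on the example output, the following sketch recomputes the Difference column (original minus retrained) from the table's Original and Retrained values. It also adds each weight's relative change. The sketch is plain C# with the table values copied in. The tolerance and the "any weight changed" criterion are illustrative assumptions, not an ML.NET API.

```csharp
using System;
using System.Linq;

// Weights copied from the example output table.
double[] original = { 33039.86, 29099.14, 28938.38, 30484.02 };
double[] retrained = { 56293.76, 49586.03, 48609.23, 53745.43 };

// Difference as printed by the comparison code: original - retrained.
double[] difference = original.Zip(retrained, (o, r) => o - r).ToArray();

// Relative change of each weight, as a fraction of the original weight's magnitude.
double[] relativeChange = original.Zip(retrained, (o, r) => (r - o) / Math.Abs(o)).ToArray();

// Hypothetical criterion: treat the model as retrained if any weight moved by more than a small tolerance.
const double tolerance = 1e-6;
bool changed = difference.Any(d => Math.Abs(d) > tolerance);

for (int i = 0; i < original.Length; i++)
{
    Console.WriteLine($"w[{i}]: difference {difference[i]:F2}, relative change {relativeChange[i]:P1}");
}
Console.WriteLine($"Weights changed: {changed}");
```

In this example, every weight grew by roughly 68% to 76% of its original value. This indicates that the new data moved the model substantially rather than leaving it untouched.
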

Collaborate with us on GitHub

The source for this content can be found on GitHub, where you can also create and review issues and pull requests. For more information, see our contributor guide.
