

Introduction to PIE – A Partially Interpretable Model with Black-box Refinement

Tong Wang, Jingyi Yang, Yunyi Li and Boxiang Wang

2025-01-20

Introduction to PIE

The PIE package implements Partially Interpretable Estimators (PIE), a framework that jointly trains an interpretable model and a black-box model to achieve high predictive performance as well as partial model transparency.
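To build intuition for "an interpretable model plus a black-box refinement", the toy sketch below alternates between two fits: a linear model on one feature, then a flexible smoother on whatever the linear part leaves unexplained. This is plain backfitting with `lm()` and `loess()` on simulated data, chosen only because both ship with base R. It is an assumption-laden illustration, not the algorithm inside `PIE_fit()`, which uses a group lasso and XGBoost with its own objective.

```r
# Toy backfitting sketch: interpretable linear part + black-box refinement.
# Assumption: lm() stands in for the interpretable model and loess() for the
# black box; this is NOT the procedure implemented by PIE_fit().
set.seed(1)
n  <- 200
x1 <- runif(n)
x2 <- runif(n)
# Simulated signal: linear in x1, nonlinear in x2, plus small noise
y  <- 2 * x1 + sin(2 * pi * x2) + rnorm(n, sd = 0.1)

f_interp <- rep(0, n)  # current fitted values of the interpretable part
f_black  <- rep(0, n)  # current fitted values of the black-box part

for (iter in 1:5) {
  # 1. Refit the interpretable model on what the black box does not explain
  lin <- lm(I(y - f_black) ~ x1)
  f_interp <- fitted(lin)
  # 2. Refit the black box on the residuals of the interpretable part
  r  <- y - f_interp
  bb <- loess(r ~ x2, span = 0.3)
  f_black <- fitted(bb)
}

total <- f_interp + f_black
rpe <- sum((y - total)^2) / sum((y - mean(y))^2)
coef(lin)  # the transparent part: intercept and slope on x1
```

The slope on `x1` stays directly readable, while the nonlinear effect of `x2` is absorbed by the smoother; that split between a readable component and a flexible correction is the idea the PIE framework builds on.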

Installation

To install the released version from CRAN, run the following:

```r
# Install the R package from CRAN
install.packages("PIE")
```

Getting Started

This section demonstrates how to prepare a dataset for PIE, fit a PIE model, and evaluate its predictions.

Process Data

The function data_process() converts a dataset into the format expected by the PIE model. Its output includes the cross-validation splits (training, validation, and testing sets) and the group indicators used by the group lasso.

```r
library(PIE)

# Load the training data
data("winequality")

# Which columns are numerical?
num_col <- 1:11
# Which columns are categorical?
cat_col <- 12
# Which column is the response?
y_col <- ncol(winequality)

# Data processing
dat <- data_process(X = as.matrix(winequality[, -y_col]),
                    y = winequality[, y_col],
                    num_col = num_col,
                    cat_col = cat_col,
                    y_col = y_col)
```
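A group lasso needs to know which columns belong together, so that all columns derived from one original feature are kept or dropped as a unit. The hypothetical sketch below shows one common way to build such group indicators in base R: a numerical feature is expanded into a B-spline basis with `splines::bs()`, a categorical feature into dummy columns with `model.matrix()`, and each column is tagged with the index of the feature it came from. It also assigns rows to 5 folds. The toy columns `x_num` and `x_cat` are made up for illustration, and this is not necessarily the encoding that `data_process()` produces.

```r
# Hypothetical illustration of group indicators and fold assignment.
# NOT necessarily what data_process() does internally.
library(splines)

x_num <- c(8.5, 9.1, 9.8, 10.2, 10.9, 11.4, 12.0, 12.6, 13.1, 13.8)  # numerical
x_cat <- factor(rep(c("a", "b"), 5), levels = c("a", "b"))            # categorical

# Numerical feature -> 3-column cubic B-spline basis
num_basis <- bs(x_num, df = 3)
# Categorical feature -> dummy columns (first level dropped as the baseline)
cat_dummy <- model.matrix(~ x_cat)[, -1, drop = FALSE]

X_expanded <- cbind(num_basis, cat_dummy)
# Every column derived from the same original feature shares one group id
group <- c(rep(1, ncol(num_basis)), rep(2, ncol(cat_dummy)))

# Randomly assign the 10 rows to 5 cross-validation folds of equal size
set.seed(1)
fold_id <- sample(rep(1:5, length.out = length(x_num)))

dim(X_expanded)  # 10 rows, 4 columns
group            # 1 1 1 2
table(fold_id)   # 2 rows per fold
```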

Fitting PIE

Once the data are prepared, use the PIE_fit() function to train a PIE model. In this example, we fit for only 5 iterations, using group lasso and XGBoost models.

```r
# Fit a PIE model
fold <- 1
fit <- PIE_fit(X = dat$spl_train_X[[fold]],
               y = dat$train_y[[fold]],
               lasso_group = dat$lasso_group,
               X_orig = dat$orig_train_X[[fold]],
               lambda1 = 0.01,
               lambda2 = 0.01,
               iter = 5,
               eta = 0.05,
               nrounds = 200)
```
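The interpretable half of the model is a group lasso. In its standard form, the penalty is \(\lambda \sum_g \sqrt{p_g}\,\lVert \beta_g \rVert_2\), where \(\beta_g\) collects the \(p_g\) coefficients of group \(g\). Because each group enters through an unsquared Euclidean norm, a whole group can be shrunk exactly to zero, which removes that feature from the model. The helper below, `group_lasso_penalty()`, is a hypothetical function that evaluates this textbook penalty. Which of `lambda1` or `lambda2` in `PIE_fit()` plays the role of \(\lambda\) here is not stated in this document, so no mapping is assumed.

```r
# Standard group lasso penalty: lambda * sum_g sqrt(p_g) * ||beta_g||_2.
# Hypothetical helper for illustration; not a function exported by PIE.
group_lasso_penalty <- function(beta, group, lambda) {
  stopifnot(length(beta) == length(group))
  # For each group: sqrt(group size) times the Euclidean norm of its coefficients
  per_group <- tapply(beta, group, function(b) sqrt(length(b)) * sqrt(sum(b^2)))
  lambda * sum(per_group)
}

beta  <- c(0.5, -0.5, 0, 0, 2)   # group 2 is entirely zero, so it adds nothing
group <- c(1, 1, 2, 2, 3)
group_lasso_penalty(beta, group, lambda = 0.01)
# group 1: sqrt(2) * sqrt(0.5) = 1; group 2: 0; group 3: 1 * 2 = 2
# total: 0.01 * (1 + 0 + 2) = 0.03
```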

Predicting PIE

Once your PIE model is trained, use predict() to generate predictions on held-out data. Here we predict on the validation set of the same fold.

```r
# Prediction
pred <- predict(fit,
                X = dat$spl_validation_X[[fold]],
                X_orig = dat$orig_validation_X[[fold]])
```

Evaluate PIE

You can evaluate your PIE model's performance with RPE(), defined as \(RPE=\frac{\sum_{i=1}^n (y_i-\hat{y}_i)^2}{\sum_{i=1}^n (y_i-\bar{y})^2}\), where \(\bar{y} = \frac{1}{n}\sum_{i=1}^n y_i\).

```r
# Validation
val_rrmse_test <- RPE(pred$total, dat$validation_y[[fold]])
```
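To make the formula concrete, the sketch below writes it out in base R as a hypothetical helper `rpe()`, taking predictions first and observed values second, in the same argument order as the `RPE()` call above. It is an illustration of the formula, not the package's own implementation. Lower is better: perfect predictions give 0, and always predicting the sample mean gives exactly 1.

```r
# Relative prediction error written out from the formula.
# rpe() is a hypothetical helper for illustration, not PIE's RPE().
rpe <- function(pred, y) {
  sum((y - pred)^2) / sum((y - mean(y))^2)
}

y <- c(3, 5, 7, 9)                 # mean 6, total sum of squares 20
rpe(pred = c(3, 5, 7, 9), y = y)   # perfect predictions: 0
rpe(pred = rep(mean(y), 4), y = y) # always predict the mean: 1
rpe(pred = c(4, 5, 7, 8), y = y)   # squared errors sum to 2, so 2 / 20 = 0.1
```

On held-out data the value can exceed 1, which signals that the model predicts worse than the mean of the evaluation responses.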
