Model interpretation (or explainability) has recently gained attention in the field of machine learning. ML models can be very accurate, but often it's also relevant to know what is happening behind the scenes when a model makes predictions. This topic involves both statistics and machine learning, so let's start with a theoretical discussion and then jump to a practical application.
Classical statistics can be categorized into descriptive and inferential. Descriptive statistics deals with exploring the observed sample data, while statistical inference is concerned with making propositions about a population using data drawn from that population with some form of sampling.
Machine learning is sometimes called applied or glorified statistics, and there are many more opinions. There is some overlap between statistics and machine learning, but the distinctions between these fields are mostly in application. In his 2001 paper Statistical Modeling: The Two Cultures, Leo Breiman separated the field of statistical modeling into two communities:
Machine learning cares about model performance, but explainability falls into the learning-information category (the first community, statistics). So maybe there is something to be learned from statisticians? R. A. Fisher specified three main aspects to consider for a valid inference:
In machine learning, fitting is typically done by calling the .fit() method of a model. In statistics the goal of modeling is approximating and understanding the data-generating process. It's clear that if different ML algorithms (SVM, k-NN, XGBoost, etc.) were fitted on the same data set, there would also be differences in how they explain the data. None of them is regarded as the "true model" which generated the data, but rather as an abstraction of empirical data which can help answer certain questions. Moreover:
“The model does not represent a belief about or a commitment to the data generation process. Its purpose is purely functional. No ML practitioner would be prepared to testify to the “validity” of a model; this has no meaning in Machine Learning, since the model is really only instrumental to its performance.”[4]
To sum up, machine learning cares about performance, and many times feature relationships with the target are treated as a black box. But as soon as model explainability becomes important, more care should be taken (put on your statistician's hat). For instance, some statistical models assume feature independence, while a data scientist may deem this unimportant. If there are 10 important X predictors with a high degree of multicollinearity (let's say Pearson correlation ρ > 0.9), then what may happen is that they will divide the importance of the same underlying driver, and it will be more complicated to discern the inner workings of such a model than of one with fewer features.
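As a quick illustration (not from the original post), here is a hedged sketch showing how two nearly identical predictors tend to split the importance of a single underlying driver; all names and data below are made up:

# Illustrative only: two almost perfectly correlated features share the importance
# that a single underlying driver would otherwise get.
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
x1 = rng.normal(size=5_000)
x2 = x1 + rng.normal(scale=0.05, size=5_000)   # Pearson correlation with x1 is ~0.999
noise = rng.normal(size=5_000)                 # irrelevant feature for comparison
y = (x1 + rng.normal(scale=0.5, size=5_000) > 0).astype(int)

X = pd.DataFrame({"x1": x1, "x2": x2, "noise": noise})
model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)
print(dict(zip(X.columns, model.feature_importances_.round(2))))
# x1 and x2 end up dividing most of the importance between them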
InterpretML is a new framework backed by Microsoft with a simple idea: bring existing model interpretation frameworks into one place and package them in a practical way. It works with both types of models: glassbox (interpretable: linear models, decision trees, etc.) and blackbox (non-interpretable: gradient boosting, SVM, etc.).
Another interesting thing InterpretML brings is an implementation of a glassbox model: the Explainable Boosting Machine (EBM). The authors claim that it is designed not only to be interpretable, but also to have accuracy comparable to popular algorithms like Random Forest and XGBoost. You can learn more about its mechanics in the InterpretML repository and the research paper [5, 6].
For the following experiment, the Predict Ad Clicks data set from HackerEarth was selected. To download the data, visit the link (login required) and clone this repository to access the code.
For the sake of the experiment, a smaller sample of 100,000 rows was selected. This data set is convenient to experiment on because it has only a handful of predictors (5 categorical and 4 numerical) for a binary classification task. More details:
- ID is completely unique and of object type, so it won't be used;
- datetime contains detailed timestamps and can be preprocessed in different ways to maximize its utility;
- siteid, offerid, category, merchant will be used as numeric columns (even though they actually represent high cardinality categories);
- countrycode, browserid, devid are categorical and therefore will be preprocessed and converted to numeric;
- click is the target variable to be classified; 1 refers to instances where the ad was clicked.

The data sample is randomly split into training (70%) and holdout (30%) sets. The random split maintains approximately the same proportion of the target variable in both sets (96% class 0 and 4% class 1).
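A minimal sketch of the sampling and split described above, assuming the raw file is called train.csv and that a stratified split was used (the exact code lives in the linked repository):

import pandas as pd
from sklearn.model_selection import train_test_split

# 100,000-row sample of the raw data; the file name is an assumption
df = pd.read_csv("train.csv").sample(n=100_000, random_state=42)

X = df.drop(columns=["ID", "click"])   # ID is unique, so it is not used as a feature
y = df["click"]

# stratify keeps the ~96% / 4% class balance in both sets
X_t, X_h, y_t, y_h = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)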
In this step let's just briefly check how the features vary against the target variable. Popular visualization libraries like matplotlib and seaborn could be used, but that would be too mainstream for a blog post, so let's try something different. For example, the Altair visualization library has somewhat intimidating syntax but produces good-looking plots.
For numerical features (again, they are actually just high cardinality categorical features) density plots should do the trick. With visualization libraries it’s sometimes simpler to write a loop than to make a facet plot:
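Below is a hedged sketch of such a loop (the encodings in the original code may differ; Altair limits charts to 5,000 rows by default, hence the sampling):

import altair as alt

num_cols = ["siteid", "offerid", "category", "merchant"]
for col in num_cols:
    # sample after dropping missing values to stay under Altair's default row limit
    plot_df = (
        X_t.assign(click=y_t.values)[[col, "click"]]
        .dropna()
        .sample(5_000, random_state=42)
    )
    chart = (
        alt.Chart(plot_df)
        .transform_density(col, groupby=["click"], as_=[col, "density"])
        .mark_area(opacity=0.5)
        .encode(x=alt.X(f"{col}:Q"), y="density:Q", color="click:N")
        .properties(width=400, height=150, title=col)
    )
    chart.display()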
Altair density plots produce quite choppy distributions (other libraries smooth them out more by default). Distributions are more or less aligned across the numerical features, but there is a notable bump in category between 35,000 and 50,000:
Categorical features with low cardinality can be conveyed with a heatmap. Since the target click is imbalanced (there are significantly fewer clicks than no clicks), the key here is to display relative frequency by group (the 0 and 1 click categories each sum to 1 if there are no missing values):
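A hedged sketch of how such a heatmap could be produced, normalizing frequencies within each click group (the original code may differ):

import pandas as pd
import altair as alt

cat_cols = ["countrycode", "browserid", "devid"]
for col in cat_cols:
    # relative frequency by group: each click class (0/1) sums to 1
    freq = (
        pd.crosstab(X_t[col], y_t, normalize="columns")
        .reset_index()
        .melt(id_vars=col, var_name="click", value_name="relative_frequency")
    )
    chart = (
        alt.Chart(freq)
        .mark_rect()
        .encode(x="click:N", y=f"{col}:N", color="relative_frequency:Q")
        .properties(width=120, title=col)
    )
    chart.display()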
In this case, categorical features seem to be really helpful in separating clicks vs no clicks. For instance, all these factors appear to contribute to more clicks:
- countrycode;
- browserid;
- devid.

Note that this data set is most likely not fully cleaned, because the same browsers appear as separate categories under different names (IE, Internet Explorer, etc.), but this will not be accounted for here.
Typically a model pipeline is more complex than just calling the .fit() and .predict() methods, so let's construct a simple but realistic pipeline. It's important to look ahead while doing that: if any methods in the model interpretation framework do not support missing values or categorical features, then this should be addressed before the model is built. For instance, InterpretML's ExplainableBoostingClassifier will raise the following error if any of the features contain missing values:
ValueError: Missing values are currently not supported.
The pipeline will consist of the following steps (a sketch of the resulting pipelines follows the list):
- Convert the datetime feature to numeric components (day, hour, etc.) so that a model can make full use of the date variable. For that, a custom DateTransformer() will be used (inspired by this blog post and a StackOverflow answer).
- OrdinalEncoder() was selected from the category_encoders library because the scikit-learn version doesn't support missing values or out-of-sample encoding.
- SimpleImputer(fill_value=0) should do the trick by replacing missing values with 0 (again, let the algorithm do all the thinking).
- Finally, the estimator: one pipeline ends with ExplainableBoostingClassifier and the other with LGBMClassifier.
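Here is a minimal sketch of the two pipelines. The DateTransformer implementation and the step names are assumptions for illustration; the actual code is in the linked repository:

import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from category_encoders import OrdinalEncoder
from interpret.glassbox import ExplainableBoostingClassifier
from lightgbm import LGBMClassifier

class DateTransformer(BaseEstimator, TransformerMixin):
    """Expand the datetime column into numeric components (month, day, hour, minute)."""
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        X = X.copy()
        dt = pd.to_datetime(X["datetime"])
        X["month"] = dt.dt.month
        X["day"] = dt.dt.day
        X["hour"] = dt.dt.hour
        X["minute"] = dt.dt.minute
        return X.drop(columns=["datetime"])

def make_pipeline(estimator):
    # dates -> ordinal encoding of categoricals -> constant imputation -> model
    return Pipeline([
        ("dates", DateTransformer()),
        ("encode", OrdinalEncoder(cols=["countrycode", "browserid", "devid"])),
        ("impute", SimpleImputer(strategy="constant", fill_value=0)),
        ("model", estimator),
    ])

pipeline_ebm = make_pipeline(ExplainableBoostingClassifier())
pipeline_lgb = make_pipeline(LGBMClassifier())

pipeline_ebm.fit(X_t, y_t)
pipeline_lgb.fit(X_t, y_t)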
After building the two pipelines, it's important to check whether the models are actually useful at prediction. The goal is to build explainable models which are not fitted to noise (overfit) and are able to generalize out of sample. Models which approximate the data well should also provide relevant insight into its underlying structure.
Without getting into accuracy metric details, one handy method can be used to validate the models: scikit-learn's classification_report(). In this case, it was used for both models on the training (70k) and holdout (30k) sets and compressed into a single data frame:
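One hedged way to assemble such a table (assuming the pipeline and split variable names from the sketches above):

import pandas as pd
from sklearn.metrics import classification_report

reports = []
for name, pipe in [("EBM", pipeline_ebm), ("LightGBM", pipeline_lgb)]:
    for split, X_, y_ in [("train", X_t, y_t), ("holdout", X_h, y_h)]:
        rep = classification_report(y_, pipe.predict(X_), output_dict=True)
        reports.append(pd.DataFrame(rep).T.assign(model=name, split=split))

report_df = pd.concat(reports)
print(report_df)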
Training set metrics are displayed here just to see how well the models learned the training data. LightGBM can easily overfit smaller data sets, and here it shows higher metric values on the training set than the Explainable Boosting Machine.
On the other hand, both models show very similar accuracy on the holdout set. This is an imbalanced classification task: the majority class (0) is predicted nearly perfectly, while the minority class (1) is predicted reasonably well.
In this section the InterpretML and SHAP Python libraries will be tested on the previously created pipelines. The unfortunate part is that model interpretation frameworks refuse to be fully compatible with pipelines and require workarounds.
Let's start with InterpretML. It has a useful ClassHistogram(), which enables doing some EDA on the data. There is a caveat: in order to use it, the data set can't contain missing values, which means the data has to go through the pipeline's preprocessing steps first. Therefore, let's create a preprocessed training set and then visualize it.
from interpret import show
from interpret.data import ClassHistogram

# one (assumed) way to get the preprocessed column names:
# feature_names = pipeline_ebm[0:2].transform(X_t.head()).columns
X_t_prep = pd.DataFrame(data=pipeline_ebm[0:3].transform(X_t), columns=feature_names)
hist = ClassHistogram().explain_data(X_t_prep, y_t, name = 'Train Data')
show(hist)
This creates a dashboard which displays Plotly interactive histograms colored by click counts. The supplied preprocessed data set contains only numeric features; for example, browserid is now in a numeric representation from 1 to 12. Two of the browserid categories show a noticeably high percentage of clicks (but with this encoding it's unclear which ones):
The next thing to check is global explanations, with both InterpretML and SHAP. Since the pipeline is not supported directly, the estimator has to be extracted in each case:
# InterpretML
ebm_global = pipeline_ebm['model'].explain_global()
show(ebm_global)
# SHAP
import shap

explainer = shap.TreeExplainer(pipeline_lgb['model'])
shap_values = explainer.shap_values(X_t_prep)
shap.summary_plot(shap_values, X_t_prep, plot_type="bar", plot_size=(10,5))
The feature importance summary shows that two categorical features, countrycode and browserid, are very important. Their ability to separate clicks was already seen in the EDA section. There is definitely a slight disagreement between the two estimators on feature importance. What they both agree on is that month is the worst performing feature, without any importance. There is a good reason for that: it has no variance (all clicks happened in the same month).
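A quick hedged check (assuming the month column name from the pipeline sketch above):

# a single unique value means zero variance, so the feature cannot help the model
print(X_t_prep["month"].nunique())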
Now let's look at the influence of a single feature on the target variable. Typically, clear relationships can be seen for the most important features. To view this, InterpretML requires calling the global dashboard again, while SHAP has a dependence_plot method:
# InterpretML
show(ebm_global)
# SHAP
shap.dependence_plot(ind="countrycode", shap_values=shap_values[0], features=X_t_prep, interaction_index=None)
In this particular case, InterpretML has a more appealing visualization, but they both tell the same story (just in opposite directions). The SHAP dependence plot shows that country codes 1 and 5 strongly push the click prediction down towards 0. InterpretML also shows that the same country codes have the strongest influence.
And finally, for local explanation comparison, one random observation was selected.
ind = [69950]
# InterpretML
ebm_local = pipeline_ebm['model'].explain_local(X_t_prep.iloc[ind], y_t.iloc[ind], name='Local')
show(ebm_local)
# SHAP
shap.initjs()
shap.force_plot(explainer.expected_value[0], shap_values[0][ind,:], X_t_prep.iloc[ind])
EBM predicted 0.35 and LightGBM 0.88, while the true value was 1 (clicked). Both SHAP and InterpretML plots display that countrycode was the main driver in their respective explanations. day = 17 has some effect on the decision, but this data sample covers only one month; in a more realistic application, constructing a variable like dayofweek should be more useful.
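As a hypothetical tweak (not in the original code), one extra line inside the date transformer's transform() method would add such a feature:

X["dayofweek"] = pd.to_datetime(X["datetime"]).dt.dayofweek   # 0 = Monday ... 6 = Sunday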
The downside of SHAP's so-called "force plot" is that the names of the features with the smallest impact are not visible.
This blog post briefly discussed what it takes to build an explainable pipeline. To summarize:
- month really didn't drive any decisions in the LightGBM or Explainable Boosting Machine (EBM) models. In this particular case the feature could be dropped, but it would probably become interesting in the real world once there are more months of data.
- InterpretML provides ClassHistogram, which can be useful for EDA, but you may need to sample the data. Since it's interactive and runs on Plotly, it may quickly become slow to load as the data grows.

Currently InterpretML is in an alpha release and already shows some compelling features. Some functionality is missing, but that shouldn't discourage anyone from testing it. The roadmap includes support for missing values, improvements to categorical encoding, an R language interface and more.
Thanks for reading! Do you have examples or suggestions for improvements (the code can be accessed here)? Share them in the comments section. Also, I am happy to connect on LinkedIn!