Developing with the Plotting API
Scikit-learn defines a simple API for creating visualizations for machine learning. The key features of this API are to run calculations once and to have the flexibility to adjust the visualizations after the fact. This section is intended for developers who wish to develop or maintain plotting tools. For usage, users should refer to the User Guide.
Plotting API Overview
This logic is encapsulated into a display object where the computed data is stored and the plotting is done in a `plot` method. The display object's `__init__` method contains only the data needed to create the visualization. The `plot` method takes in parameters that only have to do with visualization, such as a matplotlib axes. The `plot` method will store the matplotlib artists as attributes, allowing for style adjustments through the display object.

The `Display` class should define one or both class methods: `from_estimator` and `from_predictions`. These methods allow creating the `Display` object from the estimator and some data, or from the true and predicted values. After these class methods create the display object with the computed values, they call the display's `plot` method. Note that the `plot` method defines attributes related to matplotlib, such as the line artist. This allows for customizations after calling the `plot` method.
For example, the `RocCurveDisplay` defines the following methods and attributes:
```python
class RocCurveDisplay:
    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        ...
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    @classmethod
    def from_estimator(cls, estimator, X, y):
        # get the predictions
        y_pred = estimator.predict_proba(X)[:, 1]
        return cls.from_predictions(y, y_pred, estimator.__class__.__name__)

    @classmethod
    def from_predictions(cls, y, y_pred, estimator_name):
        # do ROC computation from y and y_pred
        fpr, tpr, roc_auc = ...
        viz = RocCurveDisplay(fpr, tpr, roc_auc, estimator_name)
        return viz.plot()

    def plot(self, ax=None, name=None, **kwargs):
        ...
        self.line_ = ...
        self.ax_ = ax
        self.figure_ = ax.figure
        # return the display so that from_predictions returns it
        return self
```
Read more in ROC Curve with Visualization API and the User Guide.
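The pseudocode above leaves the ROC computation and the drawing as `...`. The sketch below fills them in for a hypothetical `ToyRocDisplay` class so the pattern can be run end to end. It is an illustration under stated assumptions, not scikit-learn's implementation: it uses `sklearn.metrics.roc_curve` and `sklearn.metrics.auc` for the computation and a plain `ax.plot` call for the drawing.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve


class ToyRocDisplay:
    """Hypothetical display class following the pattern described above."""

    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        # __init__ only stores the data needed to draw the plot.
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    @classmethod
    def from_estimator(cls, estimator, X, y):
        # Score with the estimator, then delegate to from_predictions.
        y_pred = estimator.predict_proba(X)[:, 1]
        return cls.from_predictions(y, y_pred, estimator.__class__.__name__)

    @classmethod
    def from_predictions(cls, y, y_pred, estimator_name):
        # The computation happens exactly once, here.
        fpr, tpr, _ = roc_curve(y, y_pred)
        viz = cls(fpr, tpr, auc(fpr, tpr), estimator_name)
        return viz.plot()

    def plot(self, ax=None, name=None, **kwargs):
        # Only visualization parameters are accepted; nothing is recomputed.
        if ax is None:
            _, ax = plt.subplots()
        name = self.estimator_name if name is None else name
        label = f"{name} (AUC = {self.roc_auc:.2f})"
        # Keep the artist as an attribute so users can restyle it later.
        (self.line_,) = ax.plot(self.fpr, self.tpr, label=label, **kwargs)
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.legend(loc="lower right")
        self.ax_ = ax
        self.figure_ = ax.figure
        return self


# Usage on synthetic data.
X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)
disp = ToyRocDisplay.from_estimator(clf, X, y)
```

Because `from_predictions` builds the display and then calls `plot`, one call yields a drawn figure, while `plot(ax=...)` can later redraw the stored `fpr` and `tpr` on other axes without recomputing them.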
Plotting with Multiple Axes
Some of the plotting tools like `PartialDependenceDisplay.from_estimator` and `PartialDependenceDisplay` support plotting on multiple axes. Two different scenarios are supported:
1. If a list of axes is passed in, `plot` will check if the number of axes is consistent with the number of axes it expects and then draw on those axes.
2. If a single axes is passed in, that axes defines a space for multiple axes to be placed. In this case, we suggest using matplotlib's `matplotlib.gridspec.GridSpecFromSubplotSpec` to split up the space:
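```python
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec

fig, ax = plt.subplots()
# Hide the bounding axes; it only defines the space to split up.
ax.set_axis_off()
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec())

# Create one axes per region of the 2x2 grid.
ax_top_left = fig.add_subplot(gs[0, 0])
ax_top_right = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, :])
```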
By default, the `ax` keyword in `plot` is `None`. In this case, a single axes is created and the gridspec API is used to create the regions to plot in.
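Both scenarios, plus the default `ax=None`, can be folded into one helper. The sketch below is a minimal illustration under our own assumptions: `setup_axes` and its `n_cols` parameter are hypothetical names, not scikit-learn API, and the layout (row-major filling, unused positions left as `None`) is chosen to mirror the conventions described in this section.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpecFromSubplotSpec


def setup_axes(ax, n_plots, n_cols=2):
    """Hypothetical helper: resolve `ax` into (bounding_ax, axes) for n_plots."""
    if ax is not None and not isinstance(ax, Axes):
        # Scenario 1: a list of axes; check the count, then use them as-is.
        axes = np.asarray(ax, dtype=object)
        if axes.size != n_plots:
            raise ValueError(f"Expected {n_plots} axes, got {axes.size}")
        return None, axes.ravel()

    if ax is None:
        # Default: create the single bounding axes ourselves.
        _, ax = plt.subplots()

    # Scenario 2: split the single axes into a grid of sub-axes.
    n_rows = int(np.ceil(n_plots / n_cols))
    ax.set_axis_off()  # the bounding axes only defines the space
    gs = GridSpecFromSubplotSpec(n_rows, n_cols, subplot_spec=ax.get_subplotspec())
    axes = np.full((n_rows, n_cols), None, dtype=object)  # unused slots stay None
    for i in range(n_plots):
        row, col = divmod(i, n_cols)
        axes[row, col] = ax.figure.add_subplot(gs[row, col])
    return ax, axes


bounding_ax, grid_axes = setup_axes(None, n_plots=3)
```

With three plots on a two-column grid, the helper returns a 2x2 object array whose last position is `None`, while passing a list of axes skips the grid entirely and returns no bounding axes.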
See, for example, `PartialDependenceDisplay.from_estimator`, which plots multiple lines and contours using this API. The axes defining the bounding box is saved in a `bounding_ax_` attribute. The individual axes created are stored in an `axes_` ndarray, corresponding to the axes position on the grid. Positions that are not used are set to `None`. Furthermore, the matplotlib Artists are stored in `lines_` and `contours_`, where the key is the position on the grid. When a list of axes is passed in, `axes_`, `lines_`, and `contours_` are 1d ndarrays corresponding to the list of axes passed in.