Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit6343cd7

Browse files
authored
DOC Usefrom_cv_results inplot_roc_crossval.py (#31455)
1 parent4560abc commit6343cd7

File tree

1 file changed

+33
-23
lines changed

1 file changed

+33
-23
lines changed

‎examples/model_selection/plot_roc_crossval.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,46 +62,56 @@
6262
# Classification and ROC analysis
6363
# -------------------------------
6464
#
65-
# Here we run a :class:`~sklearn.svm.SVC` classifier with cross-validation and
66-
# plot the ROC curves fold-wise. Notice that the baseline to define the chance
65+
# Here we run :func:`~sklearn.model_selection.cross_validate` on a
66+
# :class:`~sklearn.svm.SVC` classifier, then use the computed cross-validation results
67+
# to plot the ROC curves fold-wise. Notice that the baseline to define the chance
6768
# level (dashed ROC curve) is a classifier that would always predict the most
6869
# frequent class.
6970

7071
importmatplotlib.pyplotasplt
7172

7273
fromsklearnimportsvm
7374
fromsklearn.metricsimportRocCurveDisplay,auc
74-
fromsklearn.model_selectionimportStratifiedKFold
75+
fromsklearn.model_selectionimportStratifiedKFold,cross_validate
7576

7677
n_splits=6
7778
cv=StratifiedKFold(n_splits=n_splits)
7879
classifier=svm.SVC(kernel="linear",probability=True,random_state=random_state)
80+
cv_results=cross_validate(
81+
classifier,X,y,cv=cv,return_estimator=True,return_indices=True
82+
)
83+
84+
prop_cycle=plt.rcParams["axes.prop_cycle"]
85+
colors=prop_cycle.by_key()["color"]
86+
curve_kwargs_list= [
87+
dict(alpha=0.3,lw=1,color=colors[fold%len(colors)])forfoldinrange(n_splits)
88+
]
89+
names= [f"ROC fold{idx}"foridxinrange(n_splits)]
7990

80-
tprs= []
81-
aucs= []
8291
mean_fpr=np.linspace(0,1,100)
92+
interp_tprs= []
93+
94+
_,ax=plt.subplots(figsize=(6,6))
95+
viz=RocCurveDisplay.from_cv_results(
96+
cv_results,
97+
X,
98+
y,
99+
ax=ax,
100+
name=names,
101+
curve_kwargs=curve_kwargs_list,
102+
plot_chance_level=True,
103+
)
83104

84-
fig,ax=plt.subplots(figsize=(6,6))
85-
forfold, (train,test)inenumerate(cv.split(X,y)):
86-
classifier.fit(X[train],y[train])
87-
viz=RocCurveDisplay.from_estimator(
88-
classifier,
89-
X[test],
90-
y[test],
91-
name=f"ROC fold{fold}",
92-
curve_kwargs=dict(alpha=0.3,lw=1),
93-
ax=ax,
94-
plot_chance_level=(fold==n_splits-1),
95-
)
96-
interp_tpr=np.interp(mean_fpr,viz.fpr,viz.tpr)
105+
foridxinrange(n_splits):
106+
interp_tpr=np.interp(mean_fpr,viz.fpr[idx],viz.tpr[idx])
97107
interp_tpr[0]=0.0
98-
tprs.append(interp_tpr)
99-
aucs.append(viz.roc_auc)
108+
interp_tprs.append(interp_tpr)
100109

101-
mean_tpr=np.mean(tprs,axis=0)
110+
mean_tpr=np.mean(interp_tprs,axis=0)
102111
mean_tpr[-1]=1.0
103112
mean_auc=auc(mean_fpr,mean_tpr)
104-
std_auc=np.std(aucs)
113+
std_auc=np.std(viz.roc_auc)
114+
105115
ax.plot(
106116
mean_fpr,
107117
mean_tpr,
@@ -111,7 +121,7 @@
111121
alpha=0.8,
112122
)
113123

114-
std_tpr=np.std(tprs,axis=0)
124+
std_tpr=np.std(interp_tprs,axis=0)
115125
tprs_upper=np.minimum(mean_tpr+std_tpr,1)
116126
tprs_lower=np.maximum(mean_tpr-std_tpr,0)
117127
ax.fill_between(

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp