Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitc9f2f89

Browse files
Remove trailing whitespace in sklearn/metrics/_ranking.py
Only drops truly redundant points where both FPR and TPR are unchangedFIX ensure proper point dropping in roc_curve by prepending (0,0) before drop_intermediate and maintaining current heuristic for test compatibilityUpdated the test caseUpdated commitImprove roc_curve's drop_intermediate with geometric collinearity
1 parent5430920 commitc9f2f89

File tree

2 files changed

+45
-22
lines changed

2 files changed

+45
-22
lines changed

‎sklearn/metrics/_ranking.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,24 @@ def precision_recall_curve(
10721072
returnnp.hstack((precision[sl],1)),np.hstack((recall[sl],0)),thresholds[sl]
10731073

10741074

1075+
defcollinear_free_mask(fps_np,tps_np,tolerance=1e-12):
1076+
"""Return indices of non-collinear points preserving endpoints."""
1077+
iflen(fps_np)<=2:
1078+
returnnp.arange(len(fps_np))
1079+
# Compute segment vectors
1080+
dx0=fps_np[1:-1]-fps_np[:-2]
1081+
dy0=tps_np[1:-1]-tps_np[:-2]
1082+
dx1=fps_np[2:]-fps_np[1:-1]
1083+
dy1=tps_np[2:]-tps_np[1:-1]
1084+
# Cross-product test
1085+
cross=dx0*dy1-dy0*dx1
1086+
is_collinear=np.abs(cross)<tolerance
1087+
# Always keep endpoints, drop only true collinear
1088+
keep=np.ones(len(fps_np),dtype=bool)
1089+
keep[1:-1]=~is_collinear
1090+
returnnp.flatnonzero(keep)
1091+
1092+
10751093
@validate_params(
10761094
{
10771095
"y_true": ["array-like"],
@@ -1190,27 +1208,32 @@ def roc_curve(
11901208
# _binary_clf_curve). This keeps all cases where the point should be kept,
11911209
# but does not drop more complicated cases like fps = [1, 3, 7],
11921210
# tps = [1, 2, 4]; there is no harm in keeping too many thresholds.
1193-
ifdrop_intermediateandfps.shape[0]>2:
1194-
optimal_idxs=xp.where(
1195-
xp.concat(
1196-
[
1197-
xp.asarray([True],device=device),
1198-
xp.logical_or(xp.diff(fps,2),xp.diff(tps,2)),
1199-
xp.asarray([True],device=device),
1200-
]
1201-
)
1202-
)[0]
1203-
fps=fps[optimal_idxs]
1204-
tps=tps[optimal_idxs]
1205-
thresholds=thresholds[optimal_idxs]
1206-
1207-
# Add an extra threshold position
1208-
# to make sure that the curve starts at (0, 0)
1209-
tps=xp.concat([xp.asarray([0.0],device=device),tps])
1211+
# Add an extra threshold position to make sure curve starts at (0, 0)
1212+
# Prepend start of curve
12101213
fps=xp.concat([xp.asarray([0.0],device=device),fps])
1211-
# get dtype of `y_score` even if it is an array-like
1212-
thresholds=xp.astype(thresholds,_max_precision_float_dtype(xp,device))
1213-
thresholds=xp.concat([xp.asarray([xp.inf],device=device),thresholds])
1214+
tps=xp.concat([xp.asarray([0.0],device=device),tps])
1215+
thresholds=xp.concatenate(
1216+
[
1217+
xp.asarray([xp.inf],device=device),
1218+
xp.astype(thresholds,_max_precision_float_dtype(xp,device)),
1219+
]
1220+
)
1221+
# Drop intermediate collinear points if requested
1222+
ifdrop_intermediateandfps.shape[0]>2:
1223+
# Convert to numpy arrays for mask computation
1224+
ifhasattr(xp,"asnumpy"):
1225+
fps_cpu=xp.asnumpy(fps)
1226+
tps_cpu=xp.asnumpy(tps)
1227+
else:
1228+
fps_cpu=fps.cpu().numpy()ifhasattr(fps,"cpu")elsenp.asarray(fps)
1229+
tps_cpu=tps.cpu().numpy()ifhasattr(tps,"cpu")elsenp.asarray(tps)
1230+
1231+
# Identify indices to keep
1232+
keep_idx=collinear_free_mask(fps_cpu,tps_cpu)
1233+
# Apply mask using original array API
1234+
fps=fps[keep_idx]
1235+
tps=tps[keep_idx]
1236+
thresholds=thresholds[keep_idx]
12141237

12151238
iffps[-1]<=0:
12161239
warnings.warn(

‎sklearn/metrics/tests/test_ranking.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def test_roc_curve_toydata():
320320
tpr,fpr,_=roc_curve(y_true,y_score)
321321
roc_auc=roc_auc_score(y_true,y_score)
322322
assert_array_almost_equal(tpr, [0,0,1])
323-
assert_array_almost_equal(fpr, [0,1,1])
323+
assert_array_almost_equal(fpr, [0,0,1])
324324
assert_almost_equal(roc_auc,1.0)
325325

326326
y_true= [0,1]
@@ -344,7 +344,7 @@ def test_roc_curve_toydata():
344344
tpr,fpr,_=roc_curve(y_true,y_score)
345345
roc_auc=roc_auc_score(y_true,y_score)
346346
assert_array_almost_equal(tpr, [0,0,1])
347-
assert_array_almost_equal(fpr, [0,1,1])
347+
assert_array_almost_equal(fpr, [0,0,1])
348348
assert_almost_equal(roc_auc,1.0)
349349

350350
y_true= [1,0]

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp