Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit5e3f77d

Browse files
TST: Add edge case tests for _roc_collinear_free_mask_xp (n<=2) and minimal roc_curve points to cover drop_intermediate fixes
1 parent1cd9a68 commit5e3f77d

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

‎sklearn/metrics/tests/test_ranking.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
top_k_accuracy_score,
2626
)
2727
fromsklearn.metrics._rankingimport_dcg_sample_scores,_ndcg_sample_scores
28+
fromsklearn.metrics._rankingimport_roc_collinear_free_mask_xp
2829
fromsklearn.model_selectionimporttrain_test_split
2930
fromsklearn.preprocessingimportlabel_binarize
3031
fromsklearn.random_projectionimport_sparse_random_matrix
@@ -2268,3 +2269,46 @@ def test_roc_curve_with_probablity_estimates(global_random_seed):
22682269
y_score=rng.rand(10)
22692270
_,_,thresholds=roc_curve(y_true,y_score)
22702271
assertnp.isinf(thresholds[0])
2272+
2273+
2274+
deftest_roc_collinear_free_mask_xp_small_arrays():
2275+
"""Test the helper function _roc_collinear_free_mask_xp with small arrays.
2276+
2277+
This test covers the edge cases where the number of points is <= 2,
2278+
which triggers an early return in the function.
2279+
"""
2280+
2281+
# Case 1: Empty array (n=0)
2282+
fps=np.array([],dtype=float)
2283+
tps=np.array([],dtype=float)
2284+
mask=_roc_collinear_free_mask_xp(fps,tps,xp=np,device="cpu")
2285+
assertmask.shape== (0,)
2286+
2287+
# Case 2: Single point (n=1)
2288+
fps=np.array([0.0])
2289+
tps=np.array([0.0])
2290+
mask=_roc_collinear_free_mask_xp(fps,tps,xp=np,device="cpu")
2291+
assert_array_equal(mask, [0])
2292+
2293+
# Case 3: Two points (n=2)
2294+
fps=np.array([0.0,1.0])
2295+
tps=np.array([0.0,1.0])
2296+
mask=_roc_collinear_free_mask_xp(fps,tps,xp=np,device="cpu")
2297+
assert_array_equal(mask, [0,1])
2298+
2299+
2300+
deftest_roc_curve_minimal_points():
2301+
"""Test roc_curve with minimal points (no intermediates to drop).
2302+
2303+
This case triggers the early return in _roc_collinear_free_mask_xp
2304+
when called internally by roc_curve.
2305+
"""
2306+
# Case where we only have 2 points in the curve after prepending
2307+
y_true= [0,1]
2308+
y_score= [0,0]# All scores equal
2309+
fpr,tpr,thresholds=roc_curve(y_true,y_score,drop_intermediate=True)
2310+
2311+
# Expected: (0,0), (1,1) points with thresholds [inf, 0]
2312+
assert_array_equal(fpr, [0.0,1.0])
2313+
assert_array_equal(tpr, [0.0,1.0])
2314+
assert_array_equal(thresholds, [np.inf,0.0])

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp