MNT Add `_check_sample_weight` to classification metrics #31701
Conversation
github-actions bot commented Jul 4, 2025 (edited)
```diff
@@ -596,7 +596,7 @@ def test_multilabel_confusion_matrix_errors():
     # Bad sample_weight
     with pytest.raises(ValueError, match="inconsistent numbers of samples"):
         multilabel_confusion_matrix(y_true, y_pred, sample_weight=[1, 2])
-    with pytest.raises(ValueError, match="should be a 1d array"):
+    with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
```
Note that this is because this error is now being raised by `_check_sample_weight`:

scikit-learn/sklearn/utils/validation.py, lines 2207 to 2208 in 9489ee6:

```python
if sample_weight.ndim != 1:
    raise ValueError("Sample weights must be 1D array or scalar")
```

instead of `column_or_1d`, which is called after `_check_sample_weight`:

```python
sample_weight = column_or_1d(sample_weight, device=device_)
```

This is actually a fix, because the old error message was "y should be a 1d array, got an array of shape (3, 3) instead." - which is misleading, as it was actually `sample_weight`.

For reference, here is some discussion in the original PR adding `_check_sample_weight`: https://github.com/scikit-learn/scikit-learn/pull/14307/files#r302938269
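To illustrate the behaviour described above, here is a minimal sketch (assuming a scikit-learn version where `_check_sample_weight` performs this `ndim` check) of triggering the error from a 2-D `sample_weight`:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

# X only provides the expected number of samples (3 here).
X = np.zeros((3, 2))

try:
    # A (3, 3) sample_weight is not 1-D, so the ndim check fires.
    _check_sample_weight(np.ones((3, 3)), X)
except ValueError as exc:
    message = str(exc)

print(message)  # mentions "1D array or scalar"
```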
I'm not 100% sure about the addition of […]
Adding sample weight checking to `_check_targets` looks good because it allows removing a redundant `check_consistent_length` call. The name `_check_targets` feels a bit off now, though.

The `_check_sample_weight` calls added in the PR are only used for the checks, but not for the conversions or generation that `_check_sample_weight` is able to do. It would be interesting to check whether that is enough - passing an int as `sample_weight`, for instance. We might have to make `_check_targets` return the validated sample weights, like it does for `y_true` and `y_pred`.
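For the scalar case mentioned above, a quick sketch (assuming the current `_check_sample_weight` signature) of the conversion step whose result would be lost unless `_check_targets` returns it:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.zeros((4, 2))

# A scalar sample_weight is broadcast to one weight per sample;
# this converted array is only useful if the caller keeps it.
sw = _check_sample_weight(2, X)

print(sw.shape)  # (4,)
```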
sklearn/metrics/_classification.py (outdated)

```python
else:
    sample_weight = np.asarray(sample_weight)
```
We could replace these lines by a call to `_check_sample_weight`, no?
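As a sketch of the difference (assuming `_check_sample_weight` from `sklearn.utils.validation`): unlike a bare `np.asarray`, `_check_sample_weight` also validates the weight vector's length against the number of samples:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.zeros((3, 2))  # 3 samples expected

# np.asarray silently accepts a wrong-length weight vector.
np.asarray([1, 2])

# _check_sample_weight rejects it because 2 != 3 samples.
raised = False
try:
    _check_sample_weight([1, 2], X)
except ValueError:
    raised = True

print(raised)  # True
```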
lucyleeow commented Jul 7, 2025 (edited)
That is a very good point 🤔 I don't know (we kept […]

🤦 I meant to return […]

Let me check what can be deleted with that.

Let's keep the name as is then :)
```python
xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)
if sample_weight is None:
    weight_average = 1.0
else:
    sample_weight = xp.asarray(sample_weight, device=device)
```
Although `check_array` inside `_check_sample_weight` does not specify `device`, because we use `xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)`, I think `sample_weight` should be on the right device.
I thought that if the 3 arrays are not on the same device an error is raised, so indeed if we get there, they should be on the same device.
lucyleeow commented Jul 9, 2025 (edited)
I've hit a snag.

scikit-learn/sklearn/utils/validation.py, lines 2188 to 2189 in 0872e9a

This was not a problem for regression targets because we mostly wanted it to be a float (i.e., we use […]). For classification metrics, it is fine for […]

Looking at the uses of […], I am wondering if we could remove the […]

It is also confusing that if you specifically specify an int dtype, […]

The safer option would be to add another parameter to avoid upcasting to float when you specify a non-float dtype in […]

WDYT?
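The upcasting behaviour at issue can be seen with integer weights (a sketch, assuming the released `_check_sample_weight` behaviour): with the default `dtype=None`, an int array is converted to `float64`:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.zeros((3, 2))

# Integer weights are upcast to float64 even though no float
# dtype was explicitly requested.
sw = _check_sample_weight(np.array([1, 2, 3]), X)

print(sw.dtype)  # float64
```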
Actually, I see two cases where I think it is possible that it is passed an int dtype: […] and […]
It's not straightforward. IIRC this upcast was added to avoid converting float sample weights to integers, possibly losing precision in the process.
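A small illustration of the precision concern (plain numpy, not scikit-learn code): casting float weights to an int dtype truncates them:

```python
import numpy as np

w = np.array([0.2, 0.8, 1.5])

# Casting to int truncates toward zero, losing the fractional weights.
w_int = w.astype(np.int64)

print(w_int)  # [0 0 1]
```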
+1 to try that approach
```python
if sample_weight is not None:
    sample_weight = column_or_1d(sample_weight, device=device_)
```
We keep that for the device, right?

It feels a bit weird. Do you think `_check_targets` should be responsible for this?
```python
if not allow_non_float:
    if dtype is not None and dtype not in float_dtypes:
```
Should be a single if condition
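A sketch of the suggested merge (the names `allow_non_float` and `float_dtypes` come from the diff; the concrete values here are made up for illustration):

```python
import numpy as np

# Illustrative values only.
allow_non_float = False
float_dtypes = [np.float64, np.float32]
dtype = np.int64

# Nested form from the diff:
nested = False
if not allow_non_float:
    if dtype is not None and dtype not in float_dtypes:
        nested = True

# Equivalent single condition:
merged = not allow_non_float and dtype is not None and dtype not in float_dtypes

print(nested == merged)  # True
```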
Something that is not clear to me is when […]
Reference Issues/PRs

Follow up to #30886

What does this implement/fix? Explain your changes.

- `check_consistent_length` on `y_true`, `y_prob` and `sample_weight` in `_check_targets` - this avoids the second `check_consistent_length` call, which means that all length checks occur at the start and you know who is raising errors (note this is not about avoiding the double checking, as they are not expensive checks).
- `_check_sample_weight` to `_check_targets`, `_validate_multiclass_probabilistic_prediction` and `_validate_binary_probabilistic_prediction` - I am not 100% sure on this. Currently this check is only being done in `d2_log_loss_score`. This check does the following:
  - turns `ComplexWarning` into an error (though I am not sure if this warning is only raised for numpy arrays or other array API arrays)
  - checks `array.ndim` is not greater than 3

These seem like reasonable checks to have. The only potential downside is that these checks would take a bit more time, but I don't think this is really a problem.
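The consolidated length check in the first bullet relies on `check_consistent_length`; a small sketch of the error it raises (using the public `sklearn.utils.check_consistent_length`):

```python
import numpy as np
from sklearn.utils import check_consistent_length

try:
    # y_true, y_prob and sample_weight with mismatched lengths.
    check_consistent_length(np.ones(3), np.ones(3), np.ones(2))
except ValueError as exc:
    message = str(exc)

print(message)  # mentions "inconsistent numbers of samples"
```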
cc @ogrisel
Any other comments?