MNT Add _check_sample_weights to classification metrics #31701


Open

lucyleeow wants to merge 13 commits into scikit-learn:main from lucyleeow:class_checks

Conversation

@lucyleeow (Member)

Reference Issues/PRs

Follow up to #30886

What does this implement/fix? Explain your changes.

  • perform check_consistent_length on y_true, y_prob and sample_weight in _check_targets - this avoids the second check_consistent_length, which means that all length checks occur at the start and you know who is raising errors (note this is not about avoiding the double checking, as they are not expensive checks)
  • adds _check_sample_weight to _check_targets, _validate_multiclass_probabilistic_prediction and _validate_binary_probabilistic_prediction - I am not 100% sure on this. Currently this check is only being done in d2_log_loss_score. This check does the following:
    • ensures all values are finite
    • ensures the data is not complex (i.e. converts ComplexWarning to an error, though I am not sure whether this warning is only raised for numpy arrays or also for other array API arrays)
    • ensures array.ndim is not greater than 3

These seem like reasonable checks to have. The only potential downside is that they would take a bit more time, but I don't think this is really a problem.
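For illustration only, here is a rough sketch of the intended check flow, not the actual diff in this PR (the helper name _check_targets_sketch and the exact call sites are made up/simplified):

```python
from sklearn.utils.validation import _check_sample_weight, check_consistent_length


def _check_targets_sketch(y_true, y_pred, sample_weight=None):
    # All length checks happen up front, so any "inconsistent numbers of
    # samples" error is raised here rather than by a later, separate
    # check_consistent_length call (None entries are simply ignored).
    check_consistent_length(y_true, y_pred, sample_weight)
    if sample_weight is not None:
        # _check_sample_weight ensures finite, non-complex values with a
        # supported number of dimensions, and returns a validated array.
        sample_weight = _check_sample_weight(sample_weight, y_true)
    return y_true, y_pred, sample_weight
```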

cc @ogrisel

Any other comments?

lucyleeow marked this pull request as draft on July 4, 2025, 12:12
github-actions bot commented Jul 4, 2025 (edited)

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: ced0f12. Link to the linter CI: here

@@ -596,7 +596,7 @@ def test_multilabel_confusion_matrix_errors():
     # Bad sample_weight
     with pytest.raises(ValueError, match="inconsistent numbers of samples"):
         multilabel_confusion_matrix(y_true, y_pred, sample_weight=[1, 2])
-    with pytest.raises(ValueError, match="should be a 1d array"):
+    with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
@lucyleeow (Member, Author) commented Jul 7, 2025 (edited)

Note that this is because this error is now being raised by _check_sample_weight:

if sample_weight.ndim != 1:
    raise ValueError("Sample weights must be 1D array or scalar")

Instead of column_or_1d, which is called after _check_sample_weight:

sample_weight = column_or_1d(sample_weight, device=device_)

This is actually a fix, because the old error message was: "y should be a 1d array, got an array of shape (3, 3) instead." - which is misleading, as the offending argument was actually sample_weight.

For reference, here is some discussion in the original PR adding _check_sample_weight: https://github.com/scikit-learn/scikit-learn/pull/14307/files#r302938269
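To make the change concrete, a small made-up reproduction (the error strings come from the test diff and the snippets quoted above):

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

y_true = np.array([0, 1, 2])
y_pred = np.array([0, 2, 1])

# A 2D sample_weight is invalid either way, but the error differs:
# - before: "y should be a 1d array, got an array of shape (3, 3) instead."
#   (misleading, since the offending argument is sample_weight, not y)
# - with this branch: "Sample weights must be 1D array or scalar"
multilabel_confusion_matrix(y_true, y_pred, sample_weight=np.ones((3, 3)))
```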

@lucyleeow (Member, Author)

I'm not 100% sure about the addition of _check_sample_weight, so I won't add tests until we decide we are happy about this change. Thanks!

lucyleeow marked this pull request as ready for review on July 7, 2025, 04:12
@jeremiedbb (Member) left a comment

Adding sample weight checking to _check_targets looks good because it allows removing a redundant check_consistent_length call. The name _check_targets feels a bit off now, though.

The check_sample_weight calls added in the PR are only used for the checks, but not for the conversions or generations that check_sample_weight is able to do. It would be interesting to check whether that's enough - passing an int as sample_weight, for instance. We might have to make _check_targets return the validated sample weights like it does for y_true and y_pred.
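For reference, a quick illustration of the conversions/generations _check_sample_weight performs beyond plain validation, as I understand the behaviour on current main (outputs noted as comments):

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.zeros((4, 2))

# None is turned into an array of ones (one weight per sample) ...
_check_sample_weight(None, X)           # array([1., 1., 1., 1.])
# ... a scalar is broadcast to every sample ...
_check_sample_weight(2, X)              # array([2., 2., 2., 2.])
# ... and a list is converted to a validated float ndarray.
_check_sample_weight([1, 2, 3, 4], X)   # array([1., 2., 3., 4.])
```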

Comment on lines 491 to 492
else:
    sample_weight = np.asarray(sample_weight)
@jeremiedbb (Member)

We could replace these lines by a call to _check_sample_weight, no?

@lucyleeow (Member, Author) commented Jul 7, 2025 (edited)

> The name _check_targets feels a bit off now though.

That is a very good point 🤔 I don't know (we kept _check_reg_targets in #30886, but maybe we could consider changing the name...?)

> The check_sample_weight added in the PR are only used for the checks but not for the conversions or generations that check_sample_weight is able to do.

🤦 I meant to return sample_weight like I did in #30886, sorry, that was a brain fart. Thanks for your patience.

Let me check what can be deleted with that.

@jeremiedbb (Member)

> we kept _check_reg_targets in #30886,

Let's keep the name as is then :)



xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)

if sample_weight is None:
    weight_average = 1.0
else:
    sample_weight = xp.asarray(sample_weight, device=device)
@lucyleeow (Member, Author)

Although check_array inside _check_sample_weight does not specify device, because we use xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight), I think sample_weight should be on the right device.

@jeremiedbb (Member)

I thought that if the 3 arrays are not on the same device an error is raised, so indeed if we get there, they should be on the same device.

@lucyleeow (Member, Author) commented Jul 9, 2025 (edited)

I've hit a snag. _check_sample_weight forces sample_weight to be a float, even if you specifically pass an int dtype:

if dtype is not None and dtype not in float_dtypes:
    dtype = max_float_type

This was not a problem for regression targets because we mostly wanted it to be a float (i.e., we use _check_reg_targets_with_floating_dtype), and when _check_reg_targets_with_floating_dtype wasn't explicitly used, we were passing sample_weight to _averaged_weighted_percentile, where we upcast sample_weight anyway.

For classification metrics, it is fine for sample_weight to be an int or bool (and indeed we check for this in test_confusion_matrix_dtype).

Looking at the uses of _check_sample_weight where we have specifically specified a dtype (see search), it has typically always been a float (because we did an array check/validation previously, and ensured y or X is float32 or float64).

I am wondering if we could remove the dtype not in float_dtypes part...?

It is also confusing that if you specifically specify an int dtype, _check_sample_weight will still upcast to float64.
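For illustration, a minimal sketch of the surprising behaviour (assuming _check_sample_weight as it currently is on main; the expected result is noted as a comment):

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.arange(6).reshape(3, 2)
sw = np.array([1, 2, 3], dtype=np.int64)

# Even with an explicitly requested integer dtype, the non-float dtype is
# replaced by the max float type, so the result comes back as float64.
_check_sample_weight(sw, X, dtype=np.int64).dtype  # dtype('float64')
```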

The safer option would be to add another parameter to avoid upcasting to float when you specify a non-float dtype in _check_sample_weight.

WDYT?

@lucyleeow (Member, Author)

Actually, I see two cases where I think it is possible that it is passed an int dtype:

sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

and

sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

@jeremiedbb (Member)

> I am wondering if we could remove the dtype not in float_dtypes part...?

It's not straightforward. IIRC this upcast was added to avoid converting float sample weights to integers, possibly losing precision in the process.
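For example, a tiny sketch of the precision issue being guarded against:

```python
import numpy as np

# Converting float weights to a requested integer dtype would silently
# truncate them, which is what the forced upcast to float protects against.
np.asarray([0.25, 1.5, 2.0], dtype=np.int64)  # array([0, 1, 2])
```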

> The safer option would be to add another parameter to avoid upcasting to float when you specify a non-float dtype in _check_sample_weight.

+1 to try that approach


Comment on lines 671 to 672
if sample_weight is not None:
    sample_weight = column_or_1d(sample_weight, device=device_)
@jeremiedbb (Member)

We keep that for the device, right?
It feels a bit weird. Do you think check_targets should be responsible for this?

Comment on lines +2198 to +2199
if not allow_non_float:
    if dtype is not None and dtype not in float_dtypes:
@jeremiedbb (Member)

Should be a single if condition
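i.e., a sketch of the suggested collapse into one condition, using the names from the diff above:

```python
if not allow_non_float and dtype is not None and dtype not in float_dtypes:
    dtype = max_float_type
```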

@jeremiedbb (Member)

Something that is not clear to me: when sample_weight is not on the same device as y_true/y_pred, do we want to raise an error, or move it to the device? It seems that there are cases where we do one and cases where we do the other, but maybe I missed something. Maybe @lesteve has more insights?

Reviewers

@jeremiedbb left review comments

At least 1 approving review is required to merge this pull request.


2 participants: @lucyleeow, @jeremiedbb
