Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

MAINT add parameter validation framework#955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit · Hold shift + click to select a range
6ee5f39
MAINT add parameter validation framework
glemaitreDec 4, 2022
b50985b
MAINT add parameter validation for ensemble models
glemaitreDec 4, 2022
e531adc
MAINT redundant tests
glemaitreDec 4, 2022
986752b
MAINT parameter validation ADASYN
glemaitreDec 4, 2022
0280fae
MAINT parameter validation for over-sampler
glemaitreDec 4, 2022
c5a138d
MAINT all under-sampler
glemaitreDec 4, 2022
1a1d75f
iter
glemaitreDec 4, 2022
f89d69f
iter
glemaitreDec 4, 2022
7316b15
iter
glemaitreDec 4, 2022
7565072
iter
glemaitreDec 4, 2022
5c9e6b3
iter
glemaitreDec 4, 2022
af01062
iter
glemaitreDec 4, 2022
4437e80
iter
glemaitreDec 4, 2022
db75eb1
add test files
glemaitreDec 4, 2022
3fc78b3
iter
glemaitreDec 4, 2022
663869b
iter
glemaitreDec 4, 2022
e1d4702
TST add coverage for _is_neighbors_object
glemaitreDec 4, 2022
82e6790
TST cover more lines
glemaitreDec 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions — doc/whats_new/v0.10.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -19,6 +19,9 @@ Compatibility
- Maintenance release for be compatible with scikit-learn >= 1.0.2.
:pr:`946`, :pr:`947`, :pr:`949` by :user:`Guillaume Lemaitre <glemaitre>`.

- Add support for automatic parameters validation as in scikit-learn >= 1.2.
:pr:`955` by :user:`Guillaume Lemaitre <glemaitre>`.

Deprecation
...........

Expand Down
77 changes: 76 additions & 1 deletion — imblearn/base.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -12,6 +12,7 @@
from sklearn.utils.multiclass import check_classification_targets

from .utils import check_sampling_strategy, check_target_type
from .utils._param_validation import validate_parameter_constraints
from .utils._validation import ArraysTransformer


Expand DownExpand Up@@ -113,7 +114,26 @@ def _fit_resample(self, X, y):
pass


class BaseSampler(SamplerMixin):
class _ParamsValidationMixin:
"""Mixin class to validate parameters."""

def _validate_params(self):
"""Validate types and values of constructor parameters.

The expected type and values must be defined in the `_parameter_constraints`
class attribute, which is a dictionary `param_name: list of constraints`. See
the docstring of `validate_parameter_constraints` for a description of the
accepted constraints.
"""
if hasattr(self, "_parameter_constraints"):
validate_parameter_constraints(
self._parameter_constraints,
self.get_params(deep=False),
caller_name=self.__class__.__name__,
)


class BaseSampler(SamplerMixin, _ParamsValidationMixin):
"""Base class for sampling algorithms.

Warning: This class should not be used directly. Use the derive classes
Expand All@@ -130,6 +150,52 @@ def _check_X_y(self, X, y, accept_sparse=None):
X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
return X, y, binarize_y

def fit(self, X, y):
    """Check inputs and statistics of the sampler.

    You should use ``fit_resample`` in all cases.

    Parameters
    ----------
    X : {array-like, dataframe, sparse matrix} of shape \
        (n_samples, n_features)
        Data array.

    y : array-like of shape (n_samples,)
        Target array.

    Returns
    -------
    self : object
        Return the instance itself.
    """
    # Validate constructor parameters (against `_parameter_constraints`,
    # when declared) before delegating the actual fitting to the parent.
    self._validate_params()
    return super().fit(X, y)

def fit_resample(self, X, y):
    """Resample the dataset.

    Parameters
    ----------
    X : {array-like, dataframe, sparse matrix} of shape \
        (n_samples, n_features)
        Matrix containing the data which have to be sampled.

    y : array-like of shape (n_samples,)
        Corresponding label for each sample in X.

    Returns
    -------
    X_resampled : {array-like, dataframe, sparse matrix} of shape \
        (n_samples_new, n_features)
        The array containing the resampled data.

    y_resampled : array-like of shape (n_samples_new,)
        The corresponding label of `X_resampled`.
    """
    # Validate constructor parameters (against `_parameter_constraints`,
    # when declared) before delegating the resampling to the parent.
    self._validate_params()
    return super().fit_resample(X, y)

def _more_tags(self):
    # Declare, for scikit-learn's estimator-tags machinery, the input
    # container types this sampler accepts.
    return {"X_types": ["2darray", "sparse", "dataframe"]}

Expand DownExpand Up@@ -241,6 +307,13 @@ class FunctionSampler(BaseSampler):

_sampling_type = "bypass"

_parameter_constraints: dict = {
"func": [callable, None],
"accept_sparse": ["boolean"],
"kw_args": [dict, None],
"validate": ["boolean"],
}

def __init__(self, *, func=None, accept_sparse=True, kw_args=None, validate=True):
super().__init__()
self.func = func
Expand All@@ -267,6 +340,7 @@ def fit(self, X, y):
self : object
Return the instance itself.
"""
self._validate_params()
# we need to overwrite SamplerMixin.fit to bypass the validation
if self.validate:
check_classification_targets(y)
Expand DownExpand Up@@ -298,6 +372,7 @@ def fit_resample(self, X, y):
y_resampled : array-like of shape (n_samples_new,)
The corresponding label of `X_resampled`.
"""
self._validate_params()
arrays_transformer = ArraysTransformer(X, y)

if self.validate:
Expand Down
27 changes: 11 additions & 16 deletions — imblearn/combine/_smote_enn.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,6 +4,8 @@
# Christos Aridas
# License: MIT

import numbers

from sklearn.base import clone
from sklearn.utils import check_X_y

Expand DownExpand Up@@ -102,6 +104,13 @@ class SMOTEENN(BaseSampler):

_sampling_type = "over-sampling"

_parameter_constraints: dict = {
**BaseOverSampler._parameter_constraints,
"smote": [SMOTE, None],
"enn": [EditedNearestNeighbours, None],
"n_jobs": [numbers.Integral, None],
}

def __init__(
self,
*,
Expand All@@ -121,14 +130,7 @@ def __init__(
def _validate_estimator(self):
"Private function to validate SMOTE and ENN objects"
if self.smote is not None:
if isinstance(self.smote, SMOTE):
self.smote_ = clone(self.smote)
else:
raise ValueError(
f"smote needs to be a SMOTE object."
f"Got {type(self.smote)} instead."
)
# Otherwise create a default SMOTE
self.smote_ = clone(self.smote)
else:
self.smote_ = SMOTE(
sampling_strategy=self.sampling_strategy,
Expand All@@ -137,14 +139,7 @@ def _validate_estimator(self):
)

if self.enn is not None:
if isinstance(self.enn, EditedNearestNeighbours):
self.enn_ = clone(self.enn)
else:
raise ValueError(
f"enn needs to be an EditedNearestNeighbours."
f" Got {type(self.enn)} instead."
)
# Otherwise create a default EditedNearestNeighbours
self.enn_ = clone(self.enn)
else:
self.enn_ = EditedNearestNeighbours(
sampling_strategy="all", n_jobs=self.n_jobs
Expand Down
27 changes: 11 additions & 16 deletions — imblearn/combine/_smote_tomek.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -5,6 +5,8 @@
# Christos Aridas
# License: MIT

import numbers

from sklearn.base import clone
from sklearn.utils import check_X_y

Expand DownExpand Up@@ -100,6 +102,13 @@ class SMOTETomek(BaseSampler):

_sampling_type = "over-sampling"

_parameter_constraints: dict = {
**BaseOverSampler._parameter_constraints,
"smote": [SMOTE, None],
"tomek": [TomekLinks, None],
"n_jobs": [numbers.Integral, None],
}

def __init__(
self,
*,
Expand All@@ -120,14 +129,7 @@ def _validate_estimator(self):
"Private function to validate SMOTE and ENN objects"

if self.smote is not None:
if isinstance(self.smote, SMOTE):
self.smote_ = clone(self.smote)
else:
raise ValueError(
f"smote needs to be a SMOTE object."
f"Got {type(self.smote)} instead."
)
# Otherwise create a default SMOTE
self.smote_ = clone(self.smote)
else:
self.smote_ = SMOTE(
sampling_strategy=self.sampling_strategy,
Expand All@@ -136,14 +138,7 @@ def _validate_estimator(self):
)

if self.tomek is not None:
if isinstance(self.tomek, TomekLinks):
self.tomek_ = clone(self.tomek)
else:
raise ValueError(
f"tomek needs to be a TomekLinks object."
f"Got {type(self.tomek)} instead."
)
# Otherwise create a default TomekLinks
self.tomek_ = clone(self.tomek)
else:
self.tomek_ = TomekLinks(sampling_strategy="all", n_jobs=self.n_jobs)

Expand Down
14 changes: 0 additions & 14 deletions — imblearn/combine/tests/test_smote_enn.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,7 +4,6 @@
# License: MIT

import numpy as np
import pytest
from sklearn.utils._testing import assert_allclose, assert_array_equal

from imblearn.combine import SMOTEENN
Expand DownExpand Up@@ -156,16 +155,3 @@ def test_parallelisation():
assert smt.n_jobs == 8
assert smt.smote_.n_jobs == 8
assert smt.enn_.n_jobs == 8


@pytest.mark.parametrize(
"smote_params, err_msg",
[
({"smote": "rnd"}, "smote needs to be a SMOTE"),
({"enn": "rnd"}, "enn needs to be an "),
],
)
def test_error_wrong_object(smote_params, err_msg):
smt = SMOTEENN(**smote_params)
with pytest.raises(ValueError, match=err_msg):
smt.fit_resample(X, Y)
14 changes: 0 additions & 14 deletions — imblearn/combine/tests/test_smote_tomek.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,7 +4,6 @@
# License: MIT

import numpy as np
import pytest
from sklearn.utils._testing import assert_allclose, assert_array_equal

from imblearn.combine import SMOTETomek
Expand DownExpand Up@@ -166,16 +165,3 @@ def test_parallelisation():
assert smt.n_jobs == 8
assert smt.smote_.n_jobs == 8
assert smt.tomek_.n_jobs == 8


@pytest.mark.parametrize(
"smote_params, err_msg",
[
({"smote": "rnd"}, "smote needs to be a SMOTE"),
({"tomek": "rnd"}, "tomek needs to be a TomekLinks"),
],
)
def test_error_wrong_object(smote_params, err_msg):
smt = SMOTETomek(**smote_params)
with pytest.raises(ValueError, match=err_msg):
smt.fit_resample(X, Y)
40 changes: 27 additions & 13 deletions — imblearn/ensemble/_bagging.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,6 +4,7 @@
# Christos Aridas
# License: MIT

import copy
import inspect
import numbers
import warnings
Expand All@@ -18,21 +19,23 @@
from sklearn.utils.fixes import delayed
from sklearn.utils.validation import check_is_fitted

from ..base import _ParamsValidationMixin
from ..pipeline import Pipeline
from ..under_sampling import RandomUnderSampler
from ..under_sampling.base import BaseUnderSampler
from ..utils import Substitution, check_sampling_strategy, check_target_type
from ..utils._available_if import available_if
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ._common import _estimator_has
from ..utils._param_validation import HasMethods, Interval, StrOptions
from ._common import _bagging_parameter_constraints, _estimator_has


@Substitution(
sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
n_jobs=_n_jobs_docstring,
random_state=_random_state_docstring,
)
class BalancedBaggingClassifier(BaggingClassifier):
class BalancedBaggingClassifier(BaggingClassifier, _ParamsValidationMixin):
"""A Bagging classifier with additional balancing.

This implementation of Bagging is similar to the scikit-learn
Expand DownExpand Up@@ -252,6 +255,26 @@ class BalancedBaggingClassifier(BaggingClassifier):
[ 2 225]]
"""

# make a deepcopy to not modify the original dictionary
if hasattr(BaggingClassifier, "_parameter_constraints"):
# scikit-learn >= 1.2
_parameter_constraints = copy.deepcopy(BaggingClassifier._parameter_constraints)
else:
_parameter_constraints = copy.deepcopy(_bagging_parameter_constraints)

_parameter_constraints.update(
{
"sampling_strategy": [
Interval(numbers.Real, 0, 1, closed="right"),
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
dict,
callable,
],
"replacement": ["boolean"],
"sampler": [HasMethods(["fit_resample"]), None],
}
)

def __init__(
self,
estimator=None,
Expand DownExpand Up@@ -316,17 +339,7 @@ def _validate_y(self, y):

def _validate_estimator(self, default=DecisionTreeClassifier()):
"""Check the estimator and the n_estimator attribute, set the
`base_estimator_` attribute."""
if not isinstance(self.n_estimators, (numbers.Integral, np.integer)):
raise ValueError(
f"n_estimators must be an integer, " f"got {type(self.n_estimators)}."
)

if self.n_estimators <= 0:
raise ValueError(
f"n_estimators must be greater than zero, " f"got {self.n_estimators}."
)

`estimator_` attribute."""
if self.estimator is not None and (
self.base_estimator not in [None, "deprecated"]
):
Expand DownExpand Up@@ -395,6 +408,7 @@ def fit(self, X, y):
Fitted estimator.
"""
# overwrite the base class method by disallowing `sample_weight`
self._validate_params()
return super().fit(X, y)

def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
Expand Down
Loading

[8]ページ先頭

©2009-2025 Movatter.jp