Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

ENH allow any dtype in input from RandomSampler #1004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
glemaitre merged 7 commits into scikit-learn-contrib:master from glemaitre:is/970
Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v0.11.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -43,3 +43,8 @@ Enhancements
parameters. A new fitted parameter `categorical_encoder_` is exposed to access the
fitted encoder.
:pr:`1001` by :user:`Guillaume Lemaitre <glemaitre>`.

- :class:`~imblearn.under_sampling.RandomUnderSampler` and
:class:`~imblearn.over_sampling.RandomOverSampler` (when `shrinkage is not
None`) now accept any data types and will not attempt any data conversion.
:pr:`1004` by :user:`Guillaume Lemaitre <glemaitre>`.
3 changes: 1 addition & 2 deletions examples/api/plot_sampling_strategy_usage.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -59,10 +59,9 @@
# resampling and the number of samples in the minority class, respectively.

# %%
import numpy as np

# select only 2 classes since the ratio make sense in this case
binary_mask =np.bitwise_or(y ==0,y == 2)
binary_mask =y.isin([0,1])
binary_y = y[binary_mask]
binary_X = X[binary_mask]

Expand Down
7 changes: 5 additions & 2 deletions imblearn/datasets/tests/test_imbalance.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -67,11 +67,14 @@ def test_make_imbalance_dict(iris, sampling_strategy, expected_counts):
],
)
def test_make_imbalanced_iris(as_frame, sampling_strategy, expected_counts):
pytest.importorskip("pandas")
iris = load_iris(as_frame=True)
pd =pytest.importorskip("pandas")
iris = load_iris(as_frame=as_frame)
X, y = iris.data, iris.target
y = iris.target_names[iris.target]
if as_frame:
y = pd.Series(iris.target_names[iris.target], name="target")
X_res, y_res = make_imbalance(X, y, sampling_strategy=sampling_strategy)
if as_frame:
assert hasattr(X_res, "loc")
pd.testing.assert_index_equal(X_res.index, y_res.index)
assert Counter(y_res) == expected_counts
3 changes: 2 additions & 1 deletion imblearn/ensemble/tests/test_bagging.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -572,11 +572,12 @@ def roughly_balanced_bagging(X, y, replace=False):

# Roughly Balanced Bagging
rbb = BalancedBaggingClassifier(
estimator=CountDecisionTreeClassifier(),
estimator=CountDecisionTreeClassifier(random_state=0),
n_estimators=2,
sampler=FunctionSampler(
func=roughly_balanced_bagging, kw_args={"replace": replace}
),
random_state=0,
)
rbb.fit(X, y)

Expand Down
15 changes: 7 additions & 8 deletions imblearn/over_sampling/_random_over_sampler.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -15,6 +15,7 @@
from ..utils import Substitution, check_target_type
from ..utils._docstring import _random_state_docstring
from ..utils._param_validation import Interval
from ..utils._validation import _check_X
from .base import BaseOverSampler


Expand DownExpand Up@@ -154,14 +155,9 @@ def __init__(

def _check_X_y(self, X, y):
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = self._validate_data(
X,
y,
reset=True,
accept_sparse=["csr", "csc"],
dtype=None,
force_all_finite=False,
)
X = _check_X(X)
self._check_n_features(X, reset=True)
self._check_feature_names(X, reset=True)
return X, y, binarize_y

def _fit_resample(self, X, y):
Expand DownExpand Up@@ -258,4 +254,7 @@ def _more_tags(self):
"X_types": ["2darray", "string", "sparse", "dataframe"],
"sample_indices": True,
"allow_nan": True,
"_xfail_checks": {
"check_complex_data": "Robust to this type of data.",
},
}
5 changes: 3 additions & 2 deletions imblearn/over_sampling/_smote/base.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -27,6 +27,7 @@
from ...utils import Substitution, check_neighbors_object, check_target_type
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
from ...utils._param_validation import HasMethods, Interval
from ...utils._validation import _check_X
from ...utils.fixes import _mode
from ..base import BaseOverSampler

Expand DownExpand Up@@ -559,9 +560,9 @@ def _check_X_y(self, X, y):
features.
"""
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
if not (hasattr(X, "__array__") or sparse.issparse(X)):
X = check_array(X, dtype=object)
X = _check_X(X)
self._check_n_features(X, reset=True)
self._check_feature_names(X, reset=True)
return X, y, binarize_y

def _validate_estimator(self):
Expand Down
14 changes: 14 additions & 0 deletions imblearn/over_sampling/tests/test_random_over_sampler.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,6 +4,7 @@
# License: MIT

from collections import Counter
from datetime import datetime

import numpy as np
import pytest
Expand DownExpand Up@@ -273,3 +274,16 @@ def test_random_over_sampler_strings(sampling_strategy):
random_state=0,
)
RandomOverSampler(sampling_strategy=sampling_strategy).fit_resample(X, y)


def test_random_over_sampling_datetime():
    """Over-sampling a frame with a datetime column must not alter its dtypes.

    The sampler is expected to only pick rows from the input: the resampled
    frame keeps the per-column dtypes of the original, shares its index with
    the resampled target, and the minority class ends up with as many
    samples as the majority class.
    """
    pd = pytest.importorskip("pandas")

    # A single timestamp is replicated over the four rows.
    timestamp = datetime.now()
    frame = pd.DataFrame({"label": [0, 0, 0, 1], "td": [timestamp] * 4})
    target = frame["label"]

    sampler = RandomOverSampler(random_state=0)
    frame_res, target_res = sampler.fit_resample(frame, target)

    # No dtype conversion may have happened on any column.
    pd.testing.assert_series_equal(frame_res.dtypes, frame.dtypes)
    # Data and target must remain aligned on the same index.
    pd.testing.assert_index_equal(frame_res.index, target_res.index)
    expected_target = np.array([0, 0, 0, 1, 1, 1])
    assert_array_equal(target_res.to_numpy(), expected_target)
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -9,6 +9,7 @@

from ...utils import Substitution, check_target_type
from ...utils._docstring import _random_state_docstring
from ...utils._validation import _check_X
from ..base import BaseUnderSampler


Expand DownExpand Up@@ -97,14 +98,9 @@ def __init__(

def _check_X_y(self, X, y):
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = self._validate_data(
X,
y,
reset=True,
accept_sparse=["csr", "csc"],
dtype=None,
force_all_finite=False,
)
X = _check_X(X)
self._check_n_features(X, reset=True)
self._check_feature_names(X, reset=True)
return X, y, binarize_y

def _fit_resample(self, X, y):
Expand DownExpand Up@@ -140,4 +136,7 @@ def _more_tags(self):
"X_types": ["2darray", "string", "sparse", "dataframe"],
"sample_indices": True,
"allow_nan": True,
"_xfail_checks": {
"check_complex_data": "Robust to this type of data.",
},
}
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,6 +4,7 @@
# License: MIT

from collections import Counter
from datetime import datetime

import numpy as np
import pytest
Expand DownExpand Up@@ -148,3 +149,16 @@ def test_random_under_sampler_strings(sampling_strategy):
random_state=0,
)
RandomUnderSampler(sampling_strategy=sampling_strategy).fit_resample(X, y)


def test_random_under_sampling_datetime():
    """Under-sampling a frame with a datetime column must not alter its dtypes.

    The sampler is expected to only pick rows from the input: the resampled
    frame keeps the per-column dtypes of the original, shares its index with
    the resampled target, and one sample per class remains.
    """
    pd = pytest.importorskip("pandas")

    # A single timestamp is replicated over the four rows.
    timestamp = datetime.now()
    frame = pd.DataFrame({"label": [0, 0, 0, 1], "td": [timestamp] * 4})
    target = frame["label"]

    sampler = RandomUnderSampler(random_state=0)
    frame_res, target_res = sampler.fit_resample(frame, target)

    # No dtype conversion may have happened on any column.
    pd.testing.assert_series_equal(frame_res.dtypes, frame.dtypes)
    # Data and target must remain aligned on the same index.
    pd.testing.assert_index_equal(frame_res.index, target_res.index)
    expected_target = np.array([0, 1])
    assert_array_equal(target_res.to_numpy(), expected_target)
26 changes: 25 additions & 1 deletion imblearn/utils/_validation.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -12,8 +12,11 @@
import numpy as np
from sklearn.base import clone
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import column_or_1d
from sklearn.utils importcheck_array,column_or_1d
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import _num_samples

from .fixes import _is_pandas_df

SAMPLING_KIND = (
"over-sampling",
Expand All@@ -35,6 +38,12 @@ def __init__(self, X, y):
def transform(self, X, y):
    """Convert X and y with the stored properties and realign their indices.

    Both inputs go through ``self._transfrom_one`` with ``self.x_props`` and
    ``self.y_props`` respectively. When X is recorded as a dataframe and y
    as a series or dataframe, the index of y is overwritten with the index
    of X before both are returned.
    """
    X = self._transfrom_one(X, self.x_props)
    y = self._transfrom_one(y, self.y_props)

    x_is_frame = self.x_props["type"].lower() == "dataframe"
    pandas_targets = {"series", "dataframe"}
    if x_is_frame and self.y_props["type"].lower() in pandas_targets:
        # We lost the y.index during resampling. We can safely use X.index to
        # align them.
        y.index = X.index
    return X, y

def _gets_props(self, array):
Expand DownExpand Up@@ -607,3 +616,18 @@ def inner_f(*args, **kwargs):
return f(**kwargs)

return inner_f


def _check_X(X):
    """Validate X, returning pandas dataframes as-is.

    Parameters
    ----------
    X : object
        Input data; its number of samples is obtained via ``_num_samples``.

    Returns
    -------
    X : object
        The input itself when it is a pandas dataframe, otherwise the
        output of ``check_array`` called with ``dtype=None``,
        ``accept_sparse=["csr", "csc"]`` and ``force_all_finite=False``.

    Raises
    ------
    ValueError
        If X contains fewer than one sample.
    """
    n_samples = _num_samples(X)
    if n_samples < 1:
        raise ValueError(
            f"Found array with {n_samples} sample(s) while a minimum of 1 is "
            "required."
        )
    if not _is_pandas_df(X):
        # Only non-dataframe inputs are routed through check_array.
        X = check_array(
            X, dtype=None, accept_sparse=["csr", "csc"], force_all_finite=False
        )
    return X
16 changes: 16 additions & 0 deletions imblearn/utils/fixes.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -5,6 +5,7 @@
which the fix is no longer needed.
"""
import functools
import sys

import numpy as np
import scipy
Expand DownExpand Up@@ -132,3 +133,18 @@ def _is_fitted(estimator, attributes=None, all_or_any=all):

else:
from sklearn.utils.validation import _is_fitted # type: ignore[no-redef]

try:
    from sklearn.utils.validation import _is_pandas_df
except ImportError:
    # Fallback used when the installed scikit-learn does not expose the helper.

    def _is_pandas_df(X):
        """Return True if the X is a pandas dataframe.

        pandas is looked up in ``sys.modules`` rather than imported, so the
        check never triggers an import of pandas by itself.
        """
        looks_like_frame = hasattr(X, "columns") and hasattr(X, "iloc")
        if not looks_like_frame:
            return False
        # Likely a pandas DataFrame, we explicitly check the type to confirm.
        try:
            pandas_module = sys.modules["pandas"]
        except KeyError:
            # pandas was never imported, so X cannot be one of its dataframes.
            return False
        return isinstance(X, pandas_module.DataFrame)

[8]ページ先頭

©2009-2025 Movatter.jp