Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitfe6960b

Browse files
FIX: Regression in DecisionBoundaryDisplay.from_estimator with colors and plot_method='contour' (#31553)
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parentbab34a0 commitfe6960b

File tree

3 files changed

+93
-55
lines changed

3 files changed

+93
-55
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
- Fix multiple issues in the multiclass setting of:class:`inspection.DecisionBoundaryDisplay`:
2+
3+
- `contour` plotting now correctly shows the decision boundary.
4+
- `cmap` and `colors` are now properly ignored in favor of `multiclass_colors`.
5+
- Linear segmented colormaps are now fully supported.
6+
7+
By:user:`Yunjie Lin <jshn9515>`

‎sklearn/inspection/_plot/decision_boundary.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -221,17 +221,22 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
221221
self.surface_=plot_func(self.xx0,self.xx1,self.response,**kwargs)
222222
else:# self.response.ndim == 3
223223
n_responses=self.response.shape[-1]
224-
if (
225-
isinstance(self.multiclass_colors,str)
226-
orself.multiclass_colorsisNone
224+
forkwargin ("cmap","colors"):
225+
ifkwarginkwargs:
226+
warnings.warn(
227+
f"'{kwarg}' is ignored in favor of 'multiclass_colors' "
228+
"in the multiclass case when the response method is "
229+
"'decision_function' or 'predict_proba'."
230+
)
231+
delkwargs[kwarg]
232+
233+
ifself.multiclass_colorsisNoneorisinstance(
234+
self.multiclass_colors,str
227235
):
228-
ifisinstance(self.multiclass_colors,str):
229-
cmap=self.multiclass_colors
236+
ifself.multiclass_colorsisNone:
237+
cmap="tab10"ifn_responses<=10else"gist_rainbow"
230238
else:
231-
ifn_responses<=10:
232-
cmap="tab10"
233-
else:
234-
cmap="gist_rainbow"
239+
cmap=self.multiclass_colors
235240

236241
# Special case for the tab10 and tab20 colormaps that encode a
237242
# discrete set of colors that are easily distinguishable
@@ -241,40 +246,41 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
241246
elifcmap=="tab20"andn_responses<=20:
242247
colors=plt.get_cmap("tab20",20).colors[:n_responses]
243248
else:
244-
colors=plt.get_cmap(cmap,n_responses).colors
245-
elifisinstance(self.multiclass_colors,str):
246-
colors=colors=plt.get_cmap(
247-
self.multiclass_colors,n_responses
248-
).colors
249-
else:
249+
cmap=plt.get_cmap(cmap,n_responses)
250+
ifnothasattr(cmap,"colors"):
251+
# For LinearSegmentedColormap
252+
colors=cmap(np.linspace(0,1,n_responses))
253+
else:
254+
colors=cmap.colors
255+
elifisinstance(self.multiclass_colors,list):
250256
colors= [mpl.colors.to_rgba(color)forcolorinself.multiclass_colors]
257+
else:
258+
raiseValueError("'multiclass_colors' must be a list or a str.")
251259

252260
self.multiclass_colors_=colors
253-
multiclass_cmaps= [
254-
mpl.colors.LinearSegmentedColormap.from_list(
255-
f"colormap_{class_idx}", [(1.0,1.0,1.0,1.0), (r,g,b,1.0)]
256-
)
257-
forclass_idx, (r,g,b,_)inenumerate(colors)
258-
]
259-
260-
self.surface_= []
261-
forclass_idx,cmapinenumerate(multiclass_cmaps):
262-
response=np.ma.array(
263-
self.response[:, :,class_idx],
264-
mask=~(self.response.argmax(axis=2)==class_idx),
261+
ifplot_method=="contour":
262+
# Plot only argmax map for contour
263+
class_map=self.response.argmax(axis=2)
264+
self.surface_=plot_func(
265+
self.xx0,self.xx1,class_map,colors=colors,**kwargs
265266
)
266-
# `cmap` should not be in kwargs
267-
safe_kwargs=kwargs.copy()
268-
if"cmap"insafe_kwargs:
269-
delsafe_kwargs["cmap"]
270-
warnings.warn(
271-
"Plotting max class of multiclass 'decision_function' or "
272-
"'predict_proba', thus 'multiclass_colors' used and "
273-
"'cmap' kwarg ignored."
267+
else:
268+
multiclass_cmaps= [
269+
mpl.colors.LinearSegmentedColormap.from_list(
270+
f"colormap_{class_idx}", [(1.0,1.0,1.0,1.0), (r,g,b,1.0)]
271+
)
272+
forclass_idx, (r,g,b,_)inenumerate(colors)
273+
]
274+
275+
self.surface_= []
276+
forclass_idx,cmapinenumerate(multiclass_cmaps):
277+
response=np.ma.array(
278+
self.response[:, :,class_idx],
279+
mask=~(self.response.argmax(axis=2)==class_idx),
280+
)
281+
self.surface_.append(
282+
plot_func(self.xx0,self.xx1,response,cmap=cmap,**kwargs)
274283
)
275-
self.surface_.append(
276-
plot_func(self.xx0,self.xx1,response,cmap=cmap,**safe_kwargs)
277-
)
278284

279285
ifxlabelisnotNoneornotax.get_xlabel():
280286
xlabel=self.xlabelifxlabelisNoneelsexlabel

‎sklearn/inspection/_plot/tests/test_boundary_decision_display.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf):
169169
@pytest.mark.parametrize(
170170
"kwargs, error_msg",
171171
[
172+
(
173+
{"multiclass_colors": {"dict":"not_list"}},
174+
"'multiclass_colors' must be a list or a str.",
175+
),
172176
({"multiclass_colors":"not_cmap"},"it must be a valid Matplotlib colormap"),
173177
({"multiclass_colors": ["red","green"]},"it must be of the same length"),
174178
(
@@ -617,6 +621,7 @@ def test_multiclass_plot_max_class(pyplot, response_method):
617621
"multiclass_colors",
618622
[
619623
"plasma",
624+
"Blues",
620625
["red","green","blue"],
621626
],
622627
)
@@ -642,31 +647,51 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors):
642647

643648
ifmulticlass_colors=="plasma":
644649
colors=mpl.pyplot.get_cmap(multiclass_colors,len(clf.classes_)).colors
650+
elifmulticlass_colors=="Blues":
651+
cmap=mpl.pyplot.get_cmap(multiclass_colors,len(clf.classes_))
652+
colors=cmap(np.linspace(0,1,len(clf.classes_)))
645653
else:
646654
colors= [mpl.colors.to_rgba(color)forcolorinmulticlass_colors]
647655

648-
cmaps= [
649-
mpl.colors.LinearSegmentedColormap.from_list(
650-
f"colormap_{class_idx}", [(1.0,1.0,1.0,1.0), (r,g,b,1.0)]
651-
)
652-
forclass_idx, (r,g,b,_)inenumerate(colors)
653-
]
654-
655-
foridx,quadinenumerate(disp.surface_):
656-
assertquad.cmap==cmaps[idx]
656+
ifplot_method!="contour":
657+
cmaps= [
658+
mpl.colors.LinearSegmentedColormap.from_list(
659+
f"colormap_{class_idx}", [(1.0,1.0,1.0,1.0), (r,g,b,1.0)]
660+
)
661+
forclass_idx, (r,g,b,_)inenumerate(colors)
662+
]
663+
foridx,quadinenumerate(disp.surface_):
664+
assertquad.cmap==cmaps[idx]
665+
else:
666+
assert_allclose(disp.surface_.colors,colors)
657667

658668

659-
deftest_multiclass_plot_max_class_cmap_kwarg(pyplot):
660-
"""Check`cmap` kwarg ignored when using plotting max multiclass class."""
669+
deftest_cmap_and_colors_logic(pyplot):
670+
"""Checkthe handling logic for `cmap` and `colors`."""
661671
X,y=load_iris_2d_scaled()
662672
clf=LogisticRegression().fit(X,y)
663673

664-
msg= (
665-
"Plotting max class of multiclass 'decision_function' or 'predict_proba', "
666-
"thus 'multiclass_colors' used and 'cmap' kwarg ignored."
667-
)
668-
withpytest.warns(UserWarning,match=msg):
669-
DecisionBoundaryDisplay.from_estimator(clf,X,cmap="viridis")
674+
withpytest.warns(
675+
UserWarning,
676+
match="'cmap' is ignored in favor of 'multiclass_colors'",
677+
):
678+
DecisionBoundaryDisplay.from_estimator(
679+
clf,
680+
X,
681+
multiclass_colors="plasma",
682+
cmap="Blues",
683+
)
684+
685+
withpytest.warns(
686+
UserWarning,
687+
match="'colors' is ignored in favor of 'multiclass_colors'",
688+
):
689+
DecisionBoundaryDisplay.from_estimator(
690+
clf,
691+
X,
692+
multiclass_colors="plasma",
693+
colors="blue",
694+
)
670695

671696

672697
deftest_subclass_named_constructors_return_type_is_subclass(pyplot):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp