Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit99d2a64

Browse files
author
Jesal Patel
committed
added plotting functions
1 parent8439c90 commit99d2a64

27 files changed

+2059
-0
lines changed

‎plotting_functions/make_blobs.py‎

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
importnumbers
2+
importnumpyasnp
3+
4+
fromsklearn.utilsimportcheck_array,check_random_state
5+
fromsklearn.utilsimportshuffleasshuffle_
6+
fromsklearn.utils.deprecationimportdeprecated
7+
8+
9+
@deprecated("Please import make_blobs directly from scikit-learn")
10+
defmake_blobs(n_samples=100,n_features=2,centers=2,cluster_std=1.0,
11+
center_box=(-10.0,10.0),shuffle=True,random_state=None):
12+
"""Generate isotropic Gaussian blobs for clustering.
13+
14+
Read more in the :ref:`User Guide <sample_generators>`.
15+
16+
Parameters
17+
----------
18+
n_samples : int, or tuple, optional (default=100)
19+
The total number of points equally divided among clusters.
20+
21+
n_features : int, optional (default=2)
22+
The number of features for each sample.
23+
24+
centers : int or array of shape [n_centers, n_features], optional
25+
(default=3)
26+
The number of centers to generate, or the fixed center locations.
27+
28+
cluster_std: float or sequence of floats, optional (default=1.0)
29+
The standard deviation of the clusters.
30+
31+
center_box: pair of floats (min, max), optional (default=(-10.0, 10.0))
32+
The bounding box for each cluster center when centers are
33+
generated at random.
34+
35+
shuffle : boolean, optional (default=True)
36+
Shuffle the samples.
37+
38+
random_state : int, RandomState instance or None, optional (default=None)
39+
If int, random_state is the seed used by the random number generator;
40+
If RandomState instance, random_state is the random number generator;
41+
If None, the random number generator is the RandomState instance used
42+
by `np.random`.
43+
44+
Returns
45+
-------
46+
X : array of shape [n_samples, n_features]
47+
The generated samples.
48+
49+
y : array of shape [n_samples]
50+
The integer labels for cluster membership of each sample.
51+
52+
Examples
53+
--------
54+
>>> from sklearn.datasets.samples_generator import make_blobs
55+
>>> X, y = make_blobs(n_samples=10, centers=3, n_features=2,
56+
... random_state=0)
57+
>>> print(X.shape)
58+
(10, 2)
59+
>>> y
60+
array([0, 0, 1, 0, 2, 2, 2, 1, 1, 0])
61+
62+
See also
63+
--------
64+
make_classification: a more intricate variant
65+
"""
66+
generator=check_random_state(random_state)
67+
68+
ifisinstance(centers,numbers.Integral):
69+
centers=generator.uniform(center_box[0],center_box[1],
70+
size=(centers,n_features))
71+
else:
72+
centers=check_array(centers)
73+
n_features=centers.shape[1]
74+
75+
ifisinstance(cluster_std,numbers.Real):
76+
cluster_std=np.ones(len(centers))*cluster_std
77+
78+
X= []
79+
y= []
80+
81+
n_centers=centers.shape[0]
82+
ifisinstance(n_samples,numbers.Integral):
83+
n_samples_per_center= [int(n_samples//n_centers)]*n_centers
84+
foriinrange(n_samples%n_centers):
85+
n_samples_per_center[i]+=1
86+
else:
87+
n_samples_per_center=n_samples
88+
89+
fori, (n,std)inenumerate(zip(n_samples_per_center,cluster_std)):
90+
X.append(centers[i]+generator.normal(scale=std,
91+
size=(n,n_features)))
92+
y+= [i]*n
93+
94+
X=np.concatenate(X)
95+
y=np.array(y)
96+
97+
ifshuffle:
98+
X,y=shuffle_(X,y,random_state=generator)
99+
100+
returnX,y
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
importnumpyasnp
2+
importmatplotlib.pyplotasplt
3+
from .plot_helpersimportcm2,cm3,discrete_scatter
4+
5+
6+
defplot_2d_classification(classifier,X,fill=False,ax=None,eps=None,
7+
alpha=1,cm=cm3):
8+
# multiclass
9+
ifepsisNone:
10+
eps=X.std()/2.
11+
12+
ifaxisNone:
13+
ax=plt.gca()
14+
15+
x_min,x_max=X[:,0].min()-eps,X[:,0].max()+eps
16+
y_min,y_max=X[:,1].min()-eps,X[:,1].max()+eps
17+
xx=np.linspace(x_min,x_max,1000)
18+
yy=np.linspace(y_min,y_max,1000)
19+
20+
X1,X2=np.meshgrid(xx,yy)
21+
X_grid=np.c_[X1.ravel(),X2.ravel()]
22+
decision_values=classifier.predict(X_grid)
23+
ax.imshow(decision_values.reshape(X1.shape),extent=(x_min,x_max,
24+
y_min,y_max),
25+
aspect='auto',origin='lower',alpha=alpha,cmap=cm)
26+
ax.set_xlim(x_min,x_max)
27+
ax.set_ylim(y_min,y_max)
28+
ax.set_xticks(())
29+
ax.set_yticks(())
30+
31+
32+
defplot_2d_scores(classifier,X,ax=None,eps=None,alpha=1,cm="viridis",
33+
function=None):
34+
# binary with fill
35+
ifepsisNone:
36+
eps=X.std()/2.
37+
38+
ifaxisNone:
39+
ax=plt.gca()
40+
41+
x_min,x_max=X[:,0].min()-eps,X[:,0].max()+eps
42+
y_min,y_max=X[:,1].min()-eps,X[:,1].max()+eps
43+
xx=np.linspace(x_min,x_max,100)
44+
yy=np.linspace(y_min,y_max,100)
45+
46+
X1,X2=np.meshgrid(xx,yy)
47+
X_grid=np.c_[X1.ravel(),X2.ravel()]
48+
iffunctionisNone:
49+
function=getattr(classifier,"decision_function",
50+
getattr(classifier,"predict_proba"))
51+
else:
52+
function=getattr(classifier,function)
53+
decision_values=function(X_grid)
54+
ifdecision_values.ndim>1anddecision_values.shape[1]>1:
55+
# predict_proba
56+
decision_values=decision_values[:,1]
57+
grr=ax.imshow(decision_values.reshape(X1.shape),
58+
extent=(x_min,x_max,y_min,y_max),aspect='auto',
59+
origin='lower',alpha=alpha,cmap=cm)
60+
61+
ax.set_xlim(x_min,x_max)
62+
ax.set_ylim(y_min,y_max)
63+
ax.set_xticks(())
64+
ax.set_yticks(())
65+
returngrr
66+
67+
68+
defplot_2d_separator(classifier,X,fill=False,ax=None,eps=None,alpha=1,
69+
cm=cm2,linewidth=None,threshold=None,
70+
linestyle="solid"):
71+
# binary?
72+
ifepsisNone:
73+
eps=X.std()/2.
74+
75+
ifaxisNone:
76+
ax=plt.gca()
77+
78+
x_min,x_max=X[:,0].min()-eps,X[:,0].max()+eps
79+
y_min,y_max=X[:,1].min()-eps,X[:,1].max()+eps
80+
xx=np.linspace(x_min,x_max,1000)
81+
yy=np.linspace(y_min,y_max,1000)
82+
83+
X1,X2=np.meshgrid(xx,yy)
84+
X_grid=np.c_[X1.ravel(),X2.ravel()]
85+
try:
86+
decision_values=classifier.decision_function(X_grid)
87+
levels= [0]ifthresholdisNoneelse [threshold]
88+
fill_levels= [decision_values.min()]+levels+ [
89+
decision_values.max()]
90+
exceptAttributeError:
91+
# no decision_function
92+
decision_values=classifier.predict_proba(X_grid)[:,1]
93+
levels= [.5]ifthresholdisNoneelse [threshold]
94+
fill_levels= [0]+levels+ [1]
95+
iffill:
96+
ax.contourf(X1,X2,decision_values.reshape(X1.shape),
97+
levels=fill_levels,alpha=alpha,cmap=cm)
98+
else:
99+
ax.contour(X1,X2,decision_values.reshape(X1.shape),levels=levels,
100+
colors="black",alpha=alpha,linewidths=linewidth,
101+
linestyles=linestyle,zorder=5)
102+
103+
ax.set_xlim(x_min,x_max)
104+
ax.set_ylim(y_min,y_max)
105+
ax.set_xticks(())
106+
ax.set_yticks(())
107+
108+
109+
if__name__=='__main__':
110+
fromsklearn.datasetsimportmake_blobs
111+
fromsklearn.linear_modelimportLogisticRegression
112+
X,y=make_blobs(centers=2,random_state=42)
113+
clf=LogisticRegression().fit(X,y)
114+
plot_2d_separator(clf,X,fill=True)
115+
discrete_scatter(X[:,0],X[:,1],y)
116+
plt.show()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
importmatplotlib.pyplotasplt
2+
importnumpyasnp
3+
fromsklearn.datasetsimportmake_blobs
4+
fromsklearn.clusterimportAgglomerativeClustering
5+
fromsklearn.neighborsimportKernelDensity
6+
7+
8+
defplot_agglomerative_algorithm():
9+
# generate synthetic two-dimensional data
10+
X,y=make_blobs(random_state=0,n_samples=12)
11+
12+
agg=AgglomerativeClustering(n_clusters=X.shape[0],compute_full_tree=True).fit(X)
13+
14+
fig,axes=plt.subplots(X.shape[0]//5,5,subplot_kw={'xticks': (),
15+
'yticks': ()},
16+
figsize=(20,8))
17+
18+
eps=X.std()/2
19+
20+
x_min,x_max=X[:,0].min()-eps,X[:,0].max()+eps
21+
y_min,y_max=X[:,1].min()-eps,X[:,1].max()+eps
22+
23+
xx,yy=np.meshgrid(np.linspace(x_min,x_max,100),np.linspace(y_min,y_max,100))
24+
gridpoints=np.c_[xx.ravel().reshape(-1,1),yy.ravel().reshape(-1,1)]
25+
26+
fori,axinenumerate(axes.ravel()):
27+
ax.set_xlim(x_min,x_max)
28+
ax.set_ylim(y_min,y_max)
29+
agg.n_clusters=X.shape[0]-i
30+
agg.fit(X)
31+
ax.set_title("Step %d"%i)
32+
ax.scatter(X[:,0],X[:,1],s=60,c='grey')
33+
bins=np.bincount(agg.labels_)
34+
forclusterinrange(agg.n_clusters):
35+
ifbins[cluster]>1:
36+
points=X[agg.labels_==cluster]
37+
other_points=X[agg.labels_!=cluster]
38+
39+
kde=KernelDensity(bandwidth=.5).fit(points)
40+
scores=kde.score_samples(gridpoints)
41+
score_inside=np.min(kde.score_samples(points))
42+
score_outside=np.max(kde.score_samples(other_points))
43+
levels=.8*score_inside+.2*score_outside
44+
ax.contour(xx,yy,scores.reshape(100,100),levels=[levels],
45+
colors='k',linestyles='solid',linewidths=2)
46+
47+
axes[0,0].set_title("Initialization")
48+
49+
50+
defplot_agglomerative():
51+
X,y=make_blobs(random_state=0,n_samples=12)
52+
agg=AgglomerativeClustering(n_clusters=3)
53+
54+
eps=X.std()/2.
55+
56+
x_min,x_max=X[:,0].min()-eps,X[:,0].max()+eps
57+
y_min,y_max=X[:,1].min()-eps,X[:,1].max()+eps
58+
59+
xx,yy=np.meshgrid(np.linspace(x_min,x_max,100),np.linspace(y_min,y_max,100))
60+
gridpoints=np.c_[xx.ravel().reshape(-1,1),yy.ravel().reshape(-1,1)]
61+
62+
ax=plt.gca()
63+
fori,xinenumerate(X):
64+
ax.text(x[0]+.1,x[1],"%d"%i,horizontalalignment='left',verticalalignment='center')
65+
66+
ax.scatter(X[:,0],X[:,1],s=60,c='grey')
67+
ax.set_xticks(())
68+
ax.set_yticks(())
69+
70+
foriinrange(11):
71+
agg.n_clusters=X.shape[0]-i
72+
agg.fit(X)
73+
74+
bins=np.bincount(agg.labels_)
75+
forclusterinrange(agg.n_clusters):
76+
ifbins[cluster]>1:
77+
points=X[agg.labels_==cluster]
78+
other_points=X[agg.labels_!=cluster]
79+
80+
kde=KernelDensity(bandwidth=.5).fit(points)
81+
scores=kde.score_samples(gridpoints)
82+
score_inside=np.min(kde.score_samples(points))
83+
score_outside=np.max(kde.score_samples(other_points))
84+
levels=.8*score_inside+.2*score_outside
85+
ax.contour(xx,yy,scores.reshape(100,100),levels=[levels],
86+
colors='k',linestyles='solid',linewidths=1)
87+
88+
ax.set_xlim(x_min,x_max)
89+
ax.set_ylim(y_min,y_max)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
fromimageioimportimread
2+
importmatplotlib.pyplotasplt
3+
4+
5+
defplot_animal_tree(ax=None):
6+
importgraphviz
7+
ifaxisNone:
8+
ax=plt.gca()
9+
mygraph=graphviz.Digraph(node_attr={'shape':'box'},
10+
edge_attr={'labeldistance':"10.5"},
11+
format="png")
12+
mygraph.node("0","Has feathers?")
13+
mygraph.node("1","Can fly?")
14+
mygraph.node("2","Has fins?")
15+
mygraph.node("3","Hawk")
16+
mygraph.node("4","Penguin")
17+
mygraph.node("5","Dolphin")
18+
mygraph.node("6","Bear")
19+
mygraph.edge("0","1",label="True")
20+
mygraph.edge("0","2",label="False")
21+
mygraph.edge("1","3",label="True")
22+
mygraph.edge("1","4",label="False")
23+
mygraph.edge("2","5",label="True")
24+
mygraph.edge("2","6",label="False")
25+
mygraph.render("tmp")
26+
ax.imshow(imread("tmp.png"))
27+
ax.set_axis_off()

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp