# Demo: fit a k-nearest-neighbours classifier on synthetic 2-D blob data and
# plot its decision regions.
#
# NOTE(review): np, plt, make_blobs and KNeighborsClassifier are not imported
# in this snippet; presumably numpy, matplotlib.pyplot, sklearn.datasets and
# sklearn.neighbors are imported earlier in the file — confirm.

# 1000 two-feature samples drawn around 5 centres; fixed seed so the plot is
# reproducible between runs.
data = make_blobs(n_samples=1000, n_features=2, centers=5, random_state=1)
x, y = data  # x: samples (one row each, two columns), y: cluster label per row

# Figure 1: the raw samples coloured by their true label.
plt.figure()
plt.scatter(x[:, 0], x[:, 1], c=y, cmap=plt.cm.spring, edgecolor='k')

# Fit using the classifier's library-default hyperparameters.
model = KNeighborsClassifier()
model.fit(x, y)

# Evaluation grid covering the data plus a margin on every side.
margin = 1
step = 0.02  # grid resolution; halving it quadruples the number of predictions
xmin, xmax = x[:, 0].min() - margin, x[:, 0].max() + margin
ymin, ymax = x[:, 1].min() - margin, x[:, 1].max() + margin
x1, y1 = np.meshgrid(np.arange(xmin, xmax, step), np.arange(ymin, ymax, step))

# Classify every grid point, then restore the grid's 2-D shape for plotting.
z = model.predict(np.c_[x1.ravel(), y1.ravel()])
z = z.reshape(x1.shape)

# Figure 2: decision regions with the training samples overlaid.
plt.figure()
# z has the same shape as x1/y1 (one value per grid point), whereas 'flat'
# shading expects one fewer row/column; 'auto' handles same-shape input.
plt.pcolormesh(x1, y1, z, cmap=plt.cm.Pastel1, shading='auto')
plt.scatter(x[:, 0], x[:, 1], c=y, cmap=plt.cm.spring, edgecolor='k')
plt.show()