"""Fit a decision tree on the two-moons dataset, plot its decision regions,
and print 5-fold cross-validated accuracy.

Run as a script: ``python this_file.py``.
"""
import numpy as np

N_SAMPLES = 600      # number of points drawn by make_moons
NOISE = 0.1          # Gaussian noise added to the moons
DATA_SEED = 8        # seed for the dataset, so runs are comparable
TREE_SEED = 0        # seed for the tree, so CV scores are reproducible
MAX_DEPTH = 15
GRID_STEP = 0.02     # resolution of the decision-region mesh
GRID_MARGIN = 1.0    # padding around the data's bounding box


def make_grid(x, step=GRID_STEP, margin=GRID_MARGIN):
    """Return ``(xx, yy)`` mesh arrays covering the 2-D points *x*.

    The mesh spans the bounding box of columns 0 and 1 of *x*, padded by
    *margin* on every side, sampled every *step* units (``np.arange``
    semantics, so the upper bound is exclusive).

    Args:
        x: array-like of shape (n_samples, >=2); only columns 0 and 1 are used.
        step: spacing between neighbouring mesh points.
        margin: padding added below the minimum and above the maximum.

    Returns:
        The two arrays produced by ``np.meshgrid``, each of shape
        ``(len(y_axis), len(x_axis))``.
    """
    x = np.asarray(x, dtype=float)
    xmin, xmax = x[:, 0].min() - margin, x[:, 0].max() + margin
    ymin, ymax = x[:, 1].min() - margin, x[:, 1].max() + margin
    return np.meshgrid(np.arange(xmin, xmax, step), np.arange(ymin, ymax, step))


def plot_decision_regions(model, x, y, step=GRID_STEP, margin=GRID_MARGIN):
    """Draw *model*'s predicted class over a mesh, then scatter the samples.

    *model* must expose ``predict`` on an (n, 2) array. Nothing is shown
    here; the caller decides when to call ``plt.show()``.
    """
    # Imported here so make_grid can be used without matplotlib installed.
    import matplotlib.pyplot as plt

    xx, yy = make_grid(x, step, margin)
    grid_points = np.c_[xx.ravel(), yy.ravel()]
    z = model.predict(grid_points).reshape(xx.shape)
    # shading="auto" avoids the deprecation warning that matplotlib 3.3-3.4
    # emit when X, Y and C have the same shape.
    plt.pcolormesh(xx, yy, z, cmap=plt.cm.Pastel1, shading="auto")
    plt.scatter(x[:, 0], x[:, 1], c=y)


def main():
    """Generate the data, fit and plot the tree, and print CV accuracy."""
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_moons
    from sklearn.model_selection import cross_val_score
    from sklearn.tree import DecisionTreeClassifier

    x, y = make_moons(n_samples=N_SAMPLES, noise=NOISE, random_state=DATA_SEED)
    # A fixed random_state makes tie-breaking between candidate splits
    # deterministic, so the printed CV scores are reproducible.
    model = DecisionTreeClassifier(
        criterion="gini", max_depth=MAX_DEPTH, random_state=TREE_SEED
    )
    model.fit(x, y)

    plot_decision_regions(model, x, y)
    # cross_val_score clones the estimator, so the fitted model above is
    # untouched; scores are per-fold accuracies on held-out folds.
    print(cross_val_score(model, x, y, cv=5, scoring="accuracy"))
    plt.show()


if __name__ == "__main__":
    main()