加载中...
决策树
发表于:2021-11-08 | 分类: 机器学习课程(魏)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier

# Fit a depth-limited decision tree on the first N_TRAIN MNIST samples and
# classify one sample taken from index SAMPLE_INDEX (outside the training slice).
N_TRAIN = 60000
SAMPLE_INDEX = 68888

mnist = loadmat('mnist-original.mat')
# The stored data matrix is transposed so rows index samples
# (presumably the file stores it features x samples -- confirm against the file).
# The label array is 2-D; its first row holds the label values.
x = mnist["data"].T
y = mnist["label"][0]

some_digit = x[SAMPLE_INDEX]
x_train, y_train = x[:N_TRAIN], y[:N_TRAIN]

model = DecisionTreeClassifier(max_depth=10)
model.fit(x_train, y_train)
# predict expects a 2-D input, so the single sample is wrapped in a list.
print(model.predict([some_digit]))

[8.]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score as cr  # cross-validation helper
from sklearn.tree import DecisionTreeClassifier

# Generate a noisy two-moons dataset and fit a Gini decision tree to it.
n_samples = 600
x, y = make_moons(n_samples=n_samples, noise=.1, random_state=8)
model = DecisionTreeClassifier(criterion='gini', max_depth=15)
model.fit(x, y)

# Bounding box of the data, padded by 1 unit on every side.
xmin, xmax = x[:, 0].min() - 1, x[:, 0].max() + 1
ymin, ymax = x[:, 1].min() - 1, x[:, 1].max() + 1

# Dense 0.02-spaced grid over the box; classify every grid point.
xx, yy = np.meshgrid(np.arange(xmin, xmax, 0.02), np.arange(ymin, ymax, 0.02))
xf = np.c_[xx.ravel(), yy.ravel()]
z = model.predict(xf).reshape(xx.shape)

# xx, yy and z all have the same shape, so an explicit shading mode is required:
# the implicit 'flat' shading for equal-shape inputs is deprecated since
# Matplotlib 3.3 and slated to become an error.
plt.pcolormesh(xx, yy, z, cmap=plt.cm.Pastel1, shading='auto')
plt.scatter(x[:, 0], x[:, 1], c=y)

# 5-fold cross-validated accuracy (cross_val_score clones the estimator,
# so the fitted `model` above is not modified).
print(cr(model, x, y, cv=5, scoring="accuracy"))

plt.show()

[0.99166667 1.         0.98333333 0.98333333 0.98333333]


c:\users\administrator\appdata\local\programs\python\python37\lib\site-packages\ipykernel_launcher.py:19: MatplotlibDeprecationWarning: shading='flat' when X and Y have the same dimensions as C is deprecated since 3.3.  Either specify the corners of the quadrilaterals with X and Y, or pass shading='auto', 'nearest' or 'gouraud', or set rcParams['pcolor.shading'].  This will become an error two minor releases later.

[图：决策树在 make_moons 数据上的决策边界（png）]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Train a shallow tree on the two petal features of the iris dataset and
# write its structure to a Graphviz .dot file.
iris = load_iris()
X = iris.data[:, 2:]  # columns 2 and 3: petal length and petal width
y = iris.target

# fit() returns the estimator itself, so construction and training chain.
tree_clf = DecisionTreeClassifier(max_depth=3).fit(X, y)

export_graphviz(
    tree_clf,
    out_file="tree.dot",
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)

上一篇:
主成分分析
下一篇:
拟合
本文目录
本文目录