基于 sklearn 包自带的 iris 数据集,了解一下分类树的各种参数设置以及代表的意义。
iris 数据集介绍iris 数据集包含 150 个样本,对应数据集的每行数据,每行数据包含每个样本的四个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度)和样本的类别信息,所以 iris 数据集是一个 150 行 5 列的二维表。
iris 数据集总共有三类:Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾),以及 Iris Virginica(维吉尼亚鸢尾),每类 50 个数据。其中的一个种类与另外两个种类是线性可分离的,后两个种类是非线性可分离的,具体地看后续实验分析。
下边这张图片是在网上找的数据集示例,单位是 cm.
对其有了大概的认识,下边就进行深入探究吧。
sklearn.tree.DecisionTreeClassifier 函数参数
该函数包含很多参数,具体如下:
下面一一解释。
参数 criterion 表示选择特征的准则,默认是 \'gini\',也就是基尼系数了,sklearn 库是使用的改良后的 CART 算法。当然你也可以设置成 \'entropy\',即信息增益,具体使用哪个,那就实验看看模型效果呗。
下面我们使用默认的 \'gini\' 准则来生成决策树:
# -*- coding: utf-8 -*- """ Created on Wed Apr 18 11:33:09 2018 @author: zhoukui """ from sklearn.datasets import load_iris from sklearn import tree from sklearn.externals.six import StringIO import pydotplus \'\'\' StringIO 经常被用来作字符串的缓存,它的部分接口跟文件一样,可以 认为是作为"内存文件对象",简而言之,就是为了方便 \'\'\' dot_data = StringIO() iris = load_iris() clf = tree.DecisionTreeClassifier() # 如果改用信息增益准则,就在括号中添加: criterion=\'entropy\' clf = clf.fit(iris.data, iris.target) # print(clf.max_features_) # 输出拟合树的属性 dot_data = tree.export_graphviz(clf, out_file=None, feature_names=iris.feature_names, class_names=iris.target_names, filled=True,rounded=True, impurity=False) graph = pydotplus.graph_from_dot_data(dot_data) graph.write_pdf("iris.pdf")可视化如下:
图 1:采用 \'gini\' 标准生成的决策树
由图可知,第一个选择的特征是花瓣宽度(petal width),只要其值小于 0.8,就把所有的 setosa 类共 50 个全都找出来了,其它的分类可以自己继续分析。
如果你选择 \'entropy\' 准则,就会得到下面这样的树:
图 2:采用 \'entropy\' 标准生成的决策树
由图可知,此时第一个选择的特征改为了花瓣长度(petal length),其值小于 2.45,也能够把 50 个 setosa 类全部分出来。仔细对比一下,这两种准则生成的树几乎一样,看来对于这个 iris 数据集对于准则的选择并不敏感。
splitter该参数是设置划分点的选择标准,比方说图 1 中第一个节点的 0.8 的选择,默认是 \'best\',表示在所有特征中找最好的切分点。也可以选择 \'random\',表示随机地在部分特征中选择最好的切分点(数据量大的时候),所以默认的 \'best\' 适合样本量不大的时候,而如果样本数据量非常大,可尝试使用 \'random\' . 在本例中,如果选择后者,会产生一个更加复杂的树,显然没有 \'best\' 好。
我们知道决策树的一个大缺点就是过拟合,max_depth 参数就是设置决策树的最大层数的,默认值是 None. 如果树形比较简单的话一般选择默认就行了,如果树形太过于复杂,可以设置一下最大层数。比方说对于 iris 数据,我们设置 max_depth = 3,就会得到一个深度为 3 的树,看下图它的层数没有算上根节点。anyway,知道什么意思就行了。实践中如果模型样本量、特征非常多的情况下,推荐限制这个最大深度,具体的取值取决于数据的分布,常用的可以取值10-100之间。
图 3:限制 `max_depth = 3` 生成的决策树 ### min_samples_split