前言
决策树,顾名思义,就是用来产生决策的树,我们通过数据的属性特征来构造这棵树。相比于K近邻算法,决策树的主要优势在于数据形式非常容易理解。决策树的一个很重要的任务就是理解数据中所蕴含的知识信息,从数据中提取出一系列规则的过程,也就是决策树构建的过程、机器学习的过程。
决策树算法是随机森林等集成学习算法的基础,最基本的决策树没有反馈,没有修正,就是简单的输入训练集,得出一个可以应对其他不同数据集的分类方法。目前常用的决策树算法有ID3算法、改进的C4.5算法和CART算法。在本篇blog中,我着重介绍决策树的思想和实现过程,以及实现过程中的一些思考。代码学习自《机器学习实战》,代码很规范很高效,我感觉值得深入挖掘其思想。
想必大家都知道周志华老师的《机器学习》一书,俗称西瓜书,书中在讲决策树的时候,用的也是如何买瓜的数据,如下图所示。树中的内部节点表示某个属性,节点引出的分支表示此属性的所有可能的值,叶子节点表示最终的判断结果也就是分类后的类型。
那么在考虑一棵决策树的时候,我认为有以下两点需要注意一下:
递归返回
因为决策树的生成是一个递归的过程,那么终止递归的条件需要思考:
- 当前结点包含的样本全属于同一类别,无需划分;
- 当前属性集为空,或者所有样本在所有属性集上取值相同,无法划分;
- 当前结点包含的样本集合为空,不能划分。
划分选择
决策树学习的关键是如何选择最优化分属性。一般而言,随着划分过程的不断进行,我们希望决策树的分支节点所包含的样本尽量属于同一类别,即结点的“纯度”越来越高。度量节点“纯度”的方法有多种,我们先介绍信息增益。
信息增益
信息增益是信息熵的有效减少量,那信息熵又是什么呢?信息熵定义为信息的期望值,那信息又是怎么定义的呢?让我们一步一步来解释。
如果待分类的事务可能划分在多个分类之中,则符号的信息定义为:
其中 是选择该分类的概率。
信息熵(information entropy)是度量样本集合纯度最常用的一种指标,若$ x_i $ 构成样本集合D,那么,D的信息熵定义如下,Ent值越小,纯度越高。计算信息熵时约定:若$ p=0 $,则$ plog_2p = 0$。
我们要计算出当前属性集合中每一个属性的信息增益,选出最大的那一个作为当前的结点,然后再对该结点进行划分。
ID3算法就是利用信息增益进行属性集合的划分,C4.5算法则是使用了信息增益率,CART算法使用基尼指数来选择划分属性。
python代码实现
声明测试数据
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
导入相应包
from math import log
import operator
计算给定数据集的信息熵
def compute_Shang(dataSet):
num = len(dataSet) # 由于代码中多次用到该值,为提高效率,显式地声明一个变量保存实例总数
labelCounts = {} # 定义一个字典
for featVec in dataSet:
currentLabel = featVec[-1] # 该条数据的标签
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0 # 在label统计数组中加一条总数值为0的记录
labelCounts[currentLabel] += 1 # 统计各个标签的总次数
Shang = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/num
Shang -= prob * log(prob, 2)
return Shang
按照给定特征划分数据集
跟每个介绍决策树的章节一样,我们需要选定某一个特征,然后对其他剩下的信息,根据label进行划分,从而计算熵,在下面这个函数中,我们只做了把该特征拿出来,返回其他的信息。
def splitDataSet(dataSet, axis, value):
"""
para:带划分的数据集、划分数据集的特征、需要返回的特征值
Note that:python在函数中传递的是列表的引用,在函数内部对列表对象的修改,将会影响该列表对象的
整个生命周期。为了消除这个不良影响,我们需要在函数的开始声明一个新列表对象。因为该函数代码在
同一个数据集上被调用多次,为了不修改原始数据集,创建一个新的列表对象。
"""
retDataSet = [] #
for featVec in dataSet:
if featVec[axis] == value:
reduceFeatVec = featVec[:axis]
reduceFeatVec.extend(featVec[axis+1:]) # 将要寻找的索引栏空出来,输出其他特征及标签
retDataSet.append(reduceFeatVec) # [1,2,3],[4,5,6] extend:[1,2,3,4,5,6] append:[1,2,3,[4,5,6]]
return retDataSet
选择最好的数据集划分方式
在这个函数中,我们算出在一个层次中的最好的数据集划分方式,也就是找出最合适的特征。set(featList)
选择出所有的不同特征,循环遍历,计算在这个特征充当划分结点时,整体的信息熵,最后比较出最合适的特征并返回。
def chooseBestFeatureToSplit(dataSet):
num_Features = len(dataSet[0]) - 1 # 定义特征的数量
base_Entropy = compute_Shang(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(num_Features): # 迭代所有的特征
featList = [example[i] for example in dataSet] # 这个特征下所有的样例 [1,1,1,1,0,0] [1,1,1,0,1,1]
uniqueVals = set(featList) # set去掉重复
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value) # i + value : 00 01 10 11
# print(subDataSet)
prob = len(subDataSet)/float(len(dataSet)) # 对应公式
newEntropy += prob * compute_Shang(subDataSet) # 对应公式
infoGain = base_Entropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
投票表决
在训练过程中我们有时会遇到这样的情况:如果数据集处理完了所有的属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点。这是,我们采用“多数表决”的方法决定该叶子节点的分类。
def majorityCnt(classList):
classCount = {}
for vote in classList:
# classList存储了每个类标签出现的频率
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
# python3中,要将iteritems变为items
return sortedClassCount[0][0]
创建树
def createTree(dataSet, labels):
"""
Note that:del是python内置的关键字(比如import, return等都是python的关键字),并不是python的内置函数
(内置函数有range(), sorted()等等),del的作用是删除一个对象,不仅可以删除list中的某一个元素,
也可以删除一个list,一个变量,或者类的实例
"""
classList = [example[-1] for example in dataSet]
# 递归停止的第一个条件:所有类标签完全相同
if classList.count(classList[0]) == len(classList):
return classList[0]
# 递归停止的第二个条件是:使用完了所有特征,仍然不能把所有数据集划分成仅包含唯一类别的分组
if len(dataSet[0]) == 1:
return majorityCnt(classList) # 返回出现次数最多的类别
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
测试案例
mydata, labels = trees.createDataSet()
print(trees.createTree(mydata, labels))
结果:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
Matplotlib绘制决策树
matplotlib并没有提供决策树绘制接口,所以需要自己来实现,因为决策树的主要优势是直观,而且易于理解,如果不能把它直观的显示出来,就无法发挥其优势。而且我们需要写一个通用的代码,能一直为不同决策树提供接口的代码。
首先我们先来了解一下matplotlib的注解工具annotation
,它可以在数据图形上添加文本注解。
import matplotlib.pyplot as plt
# 定义文本框和箭头格式,dict用来创建空字典
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
绘制树的主函数
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf()
createPlot.axl = plt.subplot(111, frameon=False)
plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
运行示例:createPlot()
结果如下:
构造注解树
初始化一个树结构
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
获取叶节点数目和树的深度
我们需要知道有多少叶节点,以便确定x轴的长度。需要知道有多少层,以便确定y轴长度。 用递归思想来得到叶节点个数和树的深度,递归停止的条件是输入到函数中的参数不再是dict类型,意味着不再拥有子节点,即为叶节点。所以一旦到达叶节点,则从递归调用中返回执行else语句,num加一。同理树的深度也是这样。
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0] # py2: myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
# test to see if the nodes are dictonaires, if not they are leaf nodes
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
plotTree函数
# 在父节点和子节点之间添加文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt): # if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree) # this determines the x width of this tree
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0] # the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':# test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key], cntrPt, str(key)) # recursion
else: # it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
# if you do get a dictonary you know it's a tree, and the first element will be another dict
使用函数得到最终效果如下:
决策树用于示例测试
测试集模块
其实上述分析与代码构建就是决策树源码的一部分构建过程,下面我们来构建测试集的测试代码。其实测试的过程很简单,就是将待测试的数据,沿着建好的树遍历,直到到达叶节点,返回求证。具体的代码解释见注释。
# 测试 使用特征标签特征列表
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0] # 取出当前树(json)中的第一个key
secondDict = inputTree[firstStr] # 第一个key对应的值(key-value)
featIndex = featLabels.index(firstStr) # 第一个key是第几个特征?返回index
key = testVec[featIndex] # 取出待分类行数据中的该特征的具体值
valueOfFeat = secondDict[key] # 这一层的value,也就是准备下一层的key
if isinstance(valueOfFeat, dict): # 比较testVec中的值与树节点的值,如果到达叶子节点,则返回当前节点的分类标签
classLabel = classify(valueOfFeat, featLabels, testVec)
else:
classLabel = valueOfFeat
return classLabel
测试demo
myDat, labels = trees.createDataSet()
myTree = treePlotter.retrieveTree(0)
print(trees.classify(myTree, labels, [1,0]))
综合案例
本数据集包含90个数据(训练集),分为2类,每类45个数据,每个数据4个属性:
- Sepal.Length(花萼长度),单位是cm;
- Sepal.Width(花萼宽度),单位是cm;
- Petal.Length(花瓣长度),单位是cm;
- Petal.Width(花瓣宽度),单位是cm;
分类种类: Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾),部分数据如下:
注意:需要重新定义一下labels,因为前面的labels在经过createTree函数时,del了递归中的所有最优label,剩下的就是没有用到的标签,所以需要重新定义一下。
import trees
import treePlotter
a = []
train_data = []
count = 0
# 切分训练集
with open(r"C:\\Users\\Administrator\\Desktop\\第一次作业 (1)\\第一次作业\\Iris.txt", "r") as f:
for line in f.readlines():
count = count + 1
if 41 < count < 52:
pass
else:
line = line.strip('\n')
train_data.append(line)
train_data[-1] = train_data[-1].split(",")
f.close()
# 总的数据集
with open(r"C:\\Users\\Administrator\\Desktop\\第一次作业 (1)\\第一次作业\\Iris.txt", "r") as f:
for line in f.readlines():
line = line.strip('\n')
a.append(line)
a[-1] = a[-1].split(",")
print("train_data:")
print(train_data)
# 切分测试集
test_data = a[41:51]
# 定义标签
labels = ['Sepal.Length', 'Sepal.Width', 'Petal.Length', 'Petal.Width']
myTree = trees.createTree(train_data, labels)
print(labels)
# 输出树的结构
print(myTree)
# 树结构反馈在图表中
treePlotter.createPlot(myTree)
# 测试如下
# classify(inputTree, featLabels, testVec)
test_data_feat = []
test_data_label = []
for i in range(len(test_data)):
test_data_feat.append(test_data[i][0:4])
test_data_label.append(test_data[i][-1])
test_labels = []
# 需要重新定义一下labels,因为前面的labels在经过createTree函数时,del了递归中的所有最优label,剩下的就是没有用到的标签,所以需要重新定义一下。
labels = ['Sepal.Length', 'Sepal.Width', 'Petal.Length', 'Petal.Width']
for i in range(len(test_data)):
test_labels.append(trees.classify(myTree, labels, test_data_feat[i]))
count2 = 0
for i in range(len(test_data)):
if test_labels[i] != test_data_label[i]:
count2 = count2 + 1
print("Accuracy:")
print((len(test_data) - count2) / len(test_data))