700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 【机器学习实战】决策树算法:预测隐形眼镜类型

【机器学习实战】决策树算法:预测隐形眼镜类型

时间:2024-01-11 14:27:51

相关推荐

【机器学习实战】决策树算法:预测隐形眼镜类型

【机器学习实战】决策树算法:预测隐形眼镜类型

0.收集数据

这里采用的数据集是《机器学习实战》提供的lenses.txt文件,该文件内容如下:

youngmyopenoreducedno lensesyoungmyopenonormalsoftyoungmyopeyesreducedno lensesyoungmyopeyesnormalhardyounghypernoreducedno lensesyounghypernonormalsoftyounghyperyesreducedno lensesyounghyperyesnormalhardpremyopenoreducedno lensespremyopenonormalsoftpremyopeyesreducedno lensespremyopeyesnormalhardprehypernoreducedno lensesprehypernonormalsoftprehyperyesreducedno lensesprehyperyesnormalno lensespresbyopicmyopenoreducedno lensespresbyopicmyopenonormalno lensespresbyopicmyopeyesreducedno lensespresbyopicmyopeyesnormalhardpresbyopichypernoreducedno lensespresbyopichypernonormalsoftpresbyopichyperyesreducedno lensespresbyopichyperyesnormalno lenses

每列数据类型分别是 age、prescript、astigmatic、tearRateage、prescript、astigmatic、tearRateage、prescript、astigmatic、tearRate ,而最后一列的类型是隐形眼镜的类型。

1.准备数据:解析tab键分隔的数据行

首先由于我们的数据文件是以 TabTabTab 分割开各列之间的数据的,所以我们首先需要获取被分隔的数据行。

代码如下,其中 strip()strip()strip() 表示删除掉数据中的换行符,则split('\t')是数据中遇到'\t'(既 TabTabTab) 就隔开。

fr = open('lenses.txt') # 打开数据集文件lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 解析tab键分割的数据行

由于 lenses.txtlenses.txtlenses.txt 文件中并没有对每列数据进行命名,这里我将每列数据的名称准备在 lensesLabelslensesLabelslensesLabels 变量中。

lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

数据都准备好了,接下来就可以开始我们的决策树构造了。

2.决策树的构造

决策树算法(DecisionTreeDecision TreeDecisionTree):决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。

缺点:可能会产生过度匹配问题。

适用数据类型:数值型和标称型。

2.1 信息增益

划分数据集的大原则是:将无序的数据变得更加有序。在划分数据集之前之后信息发生的变化称为信息增益,这里我们采用香农熵来计算信息的增益。

如果待分类的事务可能划分在多个分类中,则符号 xix_ixi​ 的信息定义为:l(xi)=−log2p(xi)l(x_i)=-log_2p(x_i)l(xi​)=−log2​p(xi​)

其中 p(xi)p(x_i)p(xi​) 是选择该分类的概率。

为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到(其中 nnn 是分类的数目):

H=−∑i=1np(xi)log2p(xi)H=-\sum^{n}_{i=1}p(x_i)log_2p(x_i)H=−∑i=1n​p(xi​)log2​p(xi​)

from math import log#计算给定数据集的香农熵def calcShannonEnt(dataSet):numEntries = len(dataSet) # 获取数据集中实例的总数labelCounts = {}for featVec in dataSet:currentLabel = featVec[-1] # featVec[-1]是指获取最后一个数值if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0 # 新添加的值,所以计数为 0labelCounts[currentLabel] += 1shannonEnt = 0.0 # shannonEnt用于记录计算的香农熵for key in labelCounts:prob = float(labelCounts[key])/numEntries # 计算P(xi)的概率shannonEnt -= prob * log(prob, 2) # 计算香农熵return shannonEnt

由于熵越高,则混合的数据也越多,因此我们可以通过计算香农熵来划分数据集。

2.2 划分数据集

首先先把当作特征值的属性进行抽取。

# 输入参数分别是:待划分的数据集、划分数据集的特征,需要返回的特征的值def splitDataSet(dataSet, axis, value):retDataSet = [] # 创建新的list对象for featVec in dataSet:if featVec[axis] == value:reducedFeatVec = featVec[:axis] # 获取关键特征前面的属性reducedFeatVec.extend(featVec[axis + 1 :]) # 填加关键特征后面的属性retDataSet.append(reducedFeatVec) # 以上步骤相当于对特征值进行抽取return retDataSet # 返回抽取特征后的数据集

然后再依次计算以不同属性值为特征值时的香农熵,判断以何种属性为特征值时是最优的数据划分。

# 选择最好的数据集划分方式def chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0]) - 1 #获取每个数据集拥有几个特征(排除最后一个)beseEntropy = calcShannonEnt(dataSet) # 计算以最后一个数值为特征的香农熵bestInfoGain = 0.0;bestFeature = -1for i in range(numFeatures):featList = [example[i] for example in dataSet]# 将dataSet中的数据先按行依次放入example中,然后取得example中的example[i]元素,放入列表featList中uniqueVals = set(featList) # set() 函数创建一个无序不重复元素集newEntropy = 0.0for value in uniqueVals: # 计算每种划分方式的信息熵subDataSet = splitDataSet(dataSet, i, value) # 按照给定特征划分数据集prob = len(subDataSet) / float(len(dataSet)) # 计算当前结果的可能性newEntropy += prob * calcShannonEnt(subDataSet) # 不同可能性的香农熵的和infoGain = beseEntropy - newEntropyif(infoGain > bestInfoGain): # 判断是否是当前最小香农熵,计算出最好的信息增益bestInfoGain = infoGainbestFeature = ireturn bestFeature

到这里,我们已经可以计算当前数据的最好划分方式了,但决策树不是只划分一次就好了,而是层层递进的划分下去,因此接下来就开始实现递归构建决策树。

2.3 递归构建决策树

工作原理:得到原始数据,然后基于最好的属性值划分数据集,由于特征值可能多余两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,再这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。

递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。

首先使用分类名称的列表,然后创建值为 classListclassListclassList 中唯一值的数据字典,字典对象存储了 classListclassListclassList 中每个类标签出现的频率,最后利用 operatoroperatoroperator 操作键值排序字典,并返回出现次数最多的分类名称。

import operatordef majorityCnt(classList):classCount = {}for vote in classList:if vote not in classCount.keys(): classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse = True)return sortedClassCount # 返回出现次数最多的分类名称

接着就可以创建树了,其中变量 myTreemyTreemyTree 包含了很多代表树结构信息的嵌套字典,至此我们已经正确的构建好了树。

# 创建树的函数代码,两个输入参数:数据集和标签列表def creatTree(dataSet, labels):classList = [example[-1] for example in dataSet]# 将dataSet中的数据先按行依次放入example中,然后取得example中的example[-1]元素,放入列表classList中if classList.count(classList[0]) == len(classList): # 类别完全相同则停止继续划分return classList[0]if len(dataSet[0]) == 1: # 遍历完所有特征时返回出现次数最多的类别return majorityCnt(classList)bestFeat = chooseBestFeatureToSplit(dataSet) # 选择最好的数据集划分方式bestFeatLabel = labels[bestFeat] # 获取属性文字标签myTree = {bestFeatLabel : {}}# 得到列表包含的所有属性值del(labels[bestFeat])featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)for value in uniqueVals:subLabels = labels[:]myTree[bestFeatLabel][value] = creatTree(splitDataSet(dataSet, bestFeat, value), subLabels)return myTree

3.在Python中使用Matplotlib注解绘制树形图

由于这里使用的主要是 MatplotlibMatplotlibMatplotlib 绘图的知识,与机器学习关系不大,故这里不对代码进行详细讲解。

import matplotlib.pyplot as pltimport matplotlib# 定义文本框和箭头格式decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")leafNode = dict(boxstyle = "round4", fc = "0.8")arrow_args = dict(arrowstyle = "<-")# 绘制带箭头的注解def plotNode(nodeTxt, centerPt, parentPt, nodeType):createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction',xytext = centerPt, textcoords = 'axes fraction',va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)# 获取叶节点的数目和树的层数def getNumLeafs(myTree):numLeafs = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumLeafs(secondDict[key])else: numLeafs += 1return numLeafsdef getTreeDepth(myTree):maxDepth = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':thisDepth = 1 + getTreeDepth(secondDict[key])else: thisDepth = 1if thisDepth > maxDepth: maxDepth = thisDepthreturn maxDepth# plotTree函数# 在父子节点间填充文本信息def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)# 计算宽与高def plotTree(myTree, parentPt, nodeTxt):numLeafs = getNumLeafs(myTree)depth = getTreeDepth(myTree)firstStr = list(myTree.keys())[0]cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)#标记子节点属性值plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodesplotTree(secondDict[key], cntrPt, str(key)) # recursionelse: # it's a leaf node print the leaf nodeplotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD# 这个是真正的绘制,上边是逻辑的绘制def createPlot(inTree):fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # no ticksplotTree.totalW = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5 / plotTree.totalW;plotTree.yOff = 1.0;plotTree(inTree, (0.5, 1.0), '')plt.axis('off') # 去掉坐标轴plt.show()

4.使用算法

主函数代码:

if __name__ == "__main__":fr = open('lenses.txt') # 打开数据集文件lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 解析tab键分割的数据行lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']lensesTree = creatTree(lenses, lensesLabels)createPlot(lensesTree)

运行过后就可以得到我们的结果,如下图片:

5.总结

这个算法的思想本质其实并不复杂,但我在阅读代码的过程中却是困难重重😭,在针对最后的绘图部分代码时,对代码的理解也并不是很好,看来这就是基础不牢,地动山摇(╯‵□′)╯︵┻━┻

但总的来说,这个算法也算是理解了,那么这一章就这么结束了吧 !

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。