决策树分类算法的实现及其应用ID3算法

实验二  决策树分类算法的实现及其应用

【实验目的】

  1. 掌握决策树分类算法ID3的概念,理解算法的步骤。
  2. 加深对ID3算法的理解,逐步培养解决实际问题的能力。

【实验性质】

设计型实验

【实验内容】

 使用ID3算法来实现决策树分类

【实验环境】

Python 2

【实验结果】

决策树分类算法的实现及其应用ID3算法

【实验步骤】

程序清单3-1 计算给定数据集的香农熵

程序代码:

import operator

from math import log

def calcShannonEnt(dataSet):

    numEntries=len(dataSet)

    labelCounts={}

    for featVec in dataSet:

        currentLabel=featVec[-1]

        if currentLabel not in labelCounts.keys():

            labelCounts[currentLabel]=0

            labelCounts[currentLabel]+=1

        shannonEnt=0.0

        for key in labelCounts:

            prob=float(labelCounts[key]) / numEntries

            shannonEnt -= prob*log(prob,2)

        return shannonEnt

def createDataSet():

    dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]

    labels =['no surfacing','flippers']

    return dataSet,labels

 

运行结果;

决策树分类算法的实现及其应用ID3算法

 

程序清单3-2 按照给定特征划分数据集

程序代码:

 

def splitDataSet(dataSet,axis,value):

    retDataSet=[]

    for featVec in dataSet:

        if featVec[axis] == value:

            reducedFeatVec=featVec[:axis]

            reducedFeatVec.extend(featVec[axis+1:])

            retDataSet.append(reducedFeatVec)

    return retDataSet

 

运行结果:

决策树分类算法的实现及其应用ID3算法

 

程序清单3-3 选择最好的数据集划分方式

程序代码:

def chooseBestFeatureToSplit(dataSet):

  numFeatures = len(dataSet[0])-1

  baseEntropy = calcShannonEnt(dataSet)

  bestInfoGain = 0.0;bestFeature = -1

  for i in range(numFeatures):

    featList = [example[i] for example in dataSet]

    uniqueVals = set(featList)

    newEntropy = 0.0

    for value in uniqueVals:

      subDataSet = splitDataSet(dataSet,i,value)

      prob = len(subDataSet)/float(len(dataSet))

      newEntropy += prob * calcShannonEnt(subDataSet)

    infoGain = baseEntropy - newEntropy

    if (infoGain > bestInfoGain):

      bestinfoGain = infoGain

      bestFeature = i

  return bestFeature

 

运行结果:

决策树分类算法的实现及其应用ID3算法

 

3.1.3递归构建决策树

程序代码:

def majortyCnt(classList):

  classCount = {}

  for vote in classList:

    if vote not in classCount.keys():

      classCount[vote] = 0

    classCount[vote] += 1

  sortedClassCount = sorted(classCount.items(),

                            key = operator.itemgetter(1),reverse = True)

  return sortedClassCount[0][0]

 

程序清单3-4 创建树的函数代码

程序代码:

 

def createTree(dataSet,labels):

  classList = [example[-1] for example in dataSet]

  if classList.count(classList[0]) == len(classList):

    return classList[0]

  if len(dataSet[0]) == 1:

    return majorityCnt(classList)

  bestFeat = chooseBestFeatureToSplit(dataSet)

  bestFeatLable = labels[bestFeat]

  myTree = {bestFeatLable:{}}

  del(labels[bestFeat])

  featValues = [example[bestFeat] for example in dataSet]

  uniqueVals = set(featValues)

  for value in uniqueVals:

    subLabels = labels[:]

    myTree[bestFeatLable][value] = createTree(splitDataSet\

                                              (dataSet,bestFeat,value),subLabels)

  return myTree

 

运行结果:

决策树分类算法的实现及其应用ID3算法

 

3.2 在python中使用Matplotlib注解绘制树形图

3.2.1Matplotlib注解

程序清单3-5 使用文本注解绘制树节点

程序代码;

import matplotlib.pyplot as plt

 

decisionNode = dict(boxstyle = "sawtooth",fc = "0.8")

leafNode = dict(boxstyle = "round4",fc = "0.8")

arrow_args = dict(arrowstyle = "<-")

def plotNode (nodeTxt,centerPt,parentPt,nodeType):

  createPlot.ax1.annotate(nodeTxt,xy = parentPt,xycoords = 'axes fraction',

                          xytext = centerPt,textcoords = 'axes fraction',

                          va = "center",ha = "center",bbox = nodeType,

                          arrowprops = arrow_args)

def createPlot():

  fig = plt.figure(1,facecolor = 'white')

  fig.clf()

  createPlot.ax1 = plt.subplot(111,frameon = False)

  plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode)

  plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode)

  plt.show()

 

运行结果:

决策树分类算法的实现及其应用ID3算法

 

程序清单3-6 获取叶节点的数目和树的层次

程序代码:

 

def getNumLeafs(myTree):

  numLeafs = 0

  firstStr = myTree.keys()[0]

  secondDict = myTree[firstStr]

  for key in secondDict.keys():

    if type(secondDict[key]).__name__=='dict':

      numLeafs += getNumLeafs(secondDict[key])

    else:

      numLeafs += 1

  return numLeafs

def getTreeDepth(myTree):

  maxDepth = 0

  firstStr = myTree.keys()[0]

  secondDict = myTree[firstStr]

  for key in secondDict.keys():

    if type(secondDict[key]).__name__=='dict':

      thisDepth = 1 + getTreeDepth(secondDict[key])

    else:

      thisDepth = 1

    if thisDepth > maxDepth:

      maxDepth = thisDepth

  return maxDepth

def retrieveTree(i):

  listOfTree = [{'no surfacing':{0:'no',1:{'flippers':\

                                           {0:'no',1:'yes'}}}},

                {'no surfacing':{0:'no',1:{'flippers':\

                                           {0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]

  return listOfTree[i]

运行结果:

决策树分类算法的实现及其应用ID3算法

 

程序清单3-7 plotTree函数

程序代码:

def plotMidText(cntrPt,parentPt,txtString):

    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]

    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]

    createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):

    numLeafs=getNumLeafs(myTree)

    depth=getTreeDepth(myTree)

    firstStr=myTree.keys()[0]

    cntrPt=(plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,\

            plotTree.yOff)

    plotMidText(cntrPt,parentPt,nodeTxt)

    plotNode(firstStr,cntrPt,parentPt,decisionNode)

    secondDict=myTree[firstStr]

    plotTree.yOff=plotTree.yOff - 1.0/plotTree.totalD

    for key in secondDict.keys():

        if type(secondDict[key]).__name__=='dict':

            plotTree(secondDict[key],cntrPt,str(key))

        else:

            plotTree.xOff=plotTree.xOff + 1.0/plotTree.totalW

            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),

                     cntrPt,leafNode)

            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))

    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD

def createPlot(inTree):

    fig=plt.figure(1,facecolor='white')

    fig.clf()

    axprops=dict(xticks=[],yticks=[])

    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)

    plotTree.totalW=float(getNumLeafs(inTree))

    plotTree.totalD=float(getTreeDepth(inTree))

    plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;

    plotTree(inTree,(0.5,1.0),'')

    plt.show()

运行结果:

决策树分类算法的实现及其应用ID3算法

 

最终结果:

 

决策树分类算法的实现及其应用ID3算法

我想能看到这里的同学,无外乎两种人:来拷贝代码的人 和 来拷贝代码的人。

 

但,在拷贝走的时候,你要想清楚一件事,把代码拷走之后有个蛋用,搞明白对你来说才是最重要的。

 

好了,就酱紫。

 

 

 

老铁,这要是都不赞,说不过去吧!!!

 

 

最后对自己说:

你现在所遭遇的每一个不幸,都来自一个不肯努力的曾经。