机器学习实战决策树画图理解

第二章决策树用matplotlib画图的理解

作为一个小白呢,确实对于我们来说第二章画图部分有很大的难度,个人也是花了很多时间在****上反正是各种找,最后基本上弄明白了,就是想给同样是小白的人节约一点时间,但是深入理解的时间不能少哦。

决策树matplotlib画图代码

结果图

图片
机器学习实战决策树画图理解
机器学习实战决策树画图理解机器学习实战决策树画图理解

具体代码

# -*- coding: utf-8 -*-
"""
Created on Mon Apr  1 18:56:26 2019

@author: 风飘  小谭谭
"""

import matplotlib.pyplot as plt


#这里是对绘制是图形属性的一些定义
#boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细  
decisionNode = dict(boxstyle = 'sawtooth',fc = '0.8') #定义decision节点的属性
leafNode = dict(boxstyle='round4',fc='0.8')           #定义leaf节点的属性  
arrow_args = dict(arrowstyle='<-')                    #定义箭头方向 与常规的方向相反   

#声明绘制一个节点的函数
'''
annotate是关于一个数据点的文本 相当于注释的作用 
nodeTxt:即为文本框或锯齿形里面的文本内容
centerPt:即为子节点的坐标
parentPt:即为父节点的坐标
nodeType:是判断节点还是叶子节点
'''
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy = parentPt,xycoords='axes fraction',xytext = centerPt,
                            textcoords='axes fraction',va = 'center',ha='center',bbox = nodeType,arrowprops=arrow_args)
'''
#声明绘制图像函数,调用绘制节点函数
    
def createPlot():
    fig = plt.figure(1,facecolor='white')  #新建绘画窗口  窗口名为figure1  背景颜色为白色
    fig.clf()           #清空绘图区
     #创建了属性ax1  functionname.attribute的形式是在定义函数的属性,且该属性必须初始化,否则不能进行其他操作。
    createPlot.ax1 = plt.subplot(111,frameon=False)        #创建11列新的绘图区 且图绘于第一个区域 frameon表示是否绘制坐标轴矩形 True代表绘制 False代表不绘制
    plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode)#调用画节点的函数
    plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode )
    plt.show()   #画图
'''
#--------鉴于python3与python2的不同  python3 dict_keys支持iterable,而不支持indexable
#故先将myTree.keys()返回得到的dict_keys对象转化为列表
#否则会报TypeError: 'dict_keys' object does not support indexing的错误

#获得叶节点的数目
#是一个累加的过程
def getNumLeafs(myTree):
    numLeafs = 0                #声明叶节点的个数

    #python3
    firstSides = list(myTree.keys())    #得到树所有键的列表
    firstStr = firstSides[0]            #取第一个键
    #pyton2:firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]       #得到第一个键所对应的的值

    for key in secondDict.keys():              #循环遍历secondDict的键
        if type(secondDict[key]).__name__ == 'dict': #判断该键对应的值是否是字典类型
            numLeafs += getNumLeafs(secondDict[key]) #若是则使用递归进行计算
        else:
            numLeafs += 1                           #不是则代表当前就是叶子节点,进行加1即可

    return numLeafs                                 #返回叶子结点数目


#获得叶节点的深度
#因为某一层不一定是最深的,所以引入thisDepth
#是一个求最值的过程
def getTreeDepth(myTree):
    maxDepth = 0                        #声明最大深度并赋值为0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]

    for key in secondDict:
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth =1 + getTreeDepth(secondDict[key]) 
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth

    return maxDepth

#预先存储树的信息
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]
#plotTree函数

#在父子节点间填充文本
'''
cntrPt:子节点位置坐标
parentPt:父节点位置坐标
txtString:文本信息即为图中的0,1
'''
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]      #文本填充的x坐标
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]       #文本填充的y坐标
    createPlot.ax1.text(xMid,yMid,txtString)#在(xMid,yMid)位置填充txtString文本



#画树的函数
'''
myTree: 要进行绘制的树
parentPt:父节点位置坐标
nodeTxt:文本内容

plotTree.totalW: 整棵树的叶子节点数(常量)
plotTree.totalD : 整棵树的深度(常量)


'''
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree)   #求得myTree的叶子的个数  注意这可不是我们之前所说的那颗最大的树 谁调用它谁是myTree
    depth = getTreeDepth(myTree)     #求得myTree的深度 
    #python3.6的原因,与书中有两行不一样
    #-----
    firstSides = list(myTree.keys())  #即为['no surfacing']
    firstStr = firstSides[0]        #得到第一个键 也就是第一个判断节点 即myTree的根节点即为'no surfacing'
    #-----

    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2/plotTree.totalW,plotTree.yOff) #计算子节点的坐标       
    plotMidText(cntrPt,parentPt,nodeTxt)  #对判断节点进行的绘制其与其父节点之间的文本信息   此处第一个节点与父节点重合(0.5,1.0)的设置 所以会没有效果 也恰好符合题意
    plotNode(firstStr,cntrPt,parentPt,decisionNode)     #绘制子节点
    secondDict = myTree[firstStr]                       #得到该节点以下的子树
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #深度改变 纵坐标进行减一层

    #循环遍历各个子树的键值
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict': #如果该子树下边仍为一颗树(即字典类型)
            plotTree(secondDict[key],cntrPt,str(key))#进行递归绘制
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)#绘制叶子节点 (plotTree.xOff,plotTree.yOff)代表叶子节点(子节点)坐标,cntrPt代表判断节点父节点坐标
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))#在父子节点之间填充文本信息

    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #循环结束  



#声明绘制图像函数,调用绘制节点函数
def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')  #新建绘画窗口  窗口名为figure1  背景颜色为白色
    fig.clf()           #清空绘图区
    axprops = dict(xticks=[],yticks=[]) #定义横纵坐标轴
     #创建了属性ax1  functionname.attribute的形式是在定义函数的属性,且该属性必须初始化,否则不能进行其他操作。
    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)        #创建11列新的绘图区 且图绘于第一个区域 frameon表示不绘制坐标轴矩形 定义坐标轴为二维坐标轴
    plotTree.totalW = float(getNumLeafs(inTree))  #计算树的叶子数即为3
    plotTree.totalD = float(getTreeDepth(inTree)) #计算树的深度即为2

    plotTree.xOff = -0.5/plotTree.totalW    #赋值给绘制叶子节点的变量为-0.5/plotTree.totalW 
    plotTree.yOff = 1.0                     #赋值给绘制节点的初始值为1.0 

    plotTree(inTree,(0.5,1.0),'')              #调用函数plotTree 且开始父节点的位置为(0.5,1.0) 

    plt.show()   #画图

前面代码的注释的相对较为详细,其中对于难的代码的简单解释
1.cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2/plotTree.totalW,plotTree.yOff)
其实plotTree.xOff第一次为-1/6,即为向右平移1/6,plotTree.yOff第一次为1,经过上面这个cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2/plotTree.totalW,plotTree.yOff)
代码可得cntrPt =(-1/6+(1+3)*1/2 *1/3,1)即为(1/2,1)对照即可得出。
2.parentPt的坐标已经在代码 plotTree(inTree,(0.5,1.0),’’)中得出为(0.5,1.0)后面根据代码以此类推
整个过程相当于是已知决策树来画,就是指已经知道树节点和叶子节点的数目,根据它们的数目来划分整个画布,并且在x轴以1/叶子节点数目来表示每个叶子之间的x轴距离,为什么plotTree.xOff = -0.5/plotTree.totalW 其中的负号指向右平移,半个叶子节点的距离而叶子之间的距离计算也是以半个叶子节点的距离为单位,看第一个叶子结点的距离与第二个叶子节点的距离相差几个,即第二个就向右平移几个,在叶子节点中间的根节点,即在它们的中间。对于y轴,是根据循环的层数来决定,在同一个循环里即为同一层。

以上解释有一点混乱,建议结合图和代码一步一步把坐标算出来基本上就理解了。谢谢!