决策树(二):CART回归树与Python代码

上一篇介绍了决策树的基本概念,特征划分标准及ID3、C4.5和CART分类树的算法,本文着重对CART回归树的内容进行补充。

本文概览
首先介绍CART回归树的算法,然后是创建CART回归树的主要步骤,最后是实现该过程的Python代码。

一、CART回归树算法

CART回归树处理的是回归问题,数据集的标签不再是离散的类别值,而是一系列的连续值的集合。
CART回归树不同于线性回归模型,不是通过拟合所有的样本点来得到一个最终模型进行预测,它是一类基于局部的回归算法,通过采用一种二分递归分割的技术将数据集切分成多份,每份子数据集中的标签值分布得比较集中(比如以数据集的方差作为数据分布比较集中的指标),然后采用该数据集的平均值作为其预测值。这样,CART回归树算法也可以较好地拟合非线性数据。

假如数据集的标签(目标值)的集合呈现如下非线性目标函数的值,CART回归树算法将数据集切分成很多份,即将如下函数切成一小段一小段的,对于每一小段的值是较为接近的,可以每一小段的平均值作为该小段的目标值。
决策树(二):CART回归树与Python代码

二、CART回归树生成

1. CART回归树的划分

在CART分类树中,是利用Gini指数作为划分的指标,通过样本中的特征对样本进行划分,直到所有的叶节点中的所有样本均为一个类别为止。其中,Gini指数表示的是数据的混乱程度,对于回归树,样本标签是连续数据,当数据分布比较分散时,各个数据与平均值的差的平方和较大,方差就较大;当数据分布比较集中时,各个数据与平均值的差的平方和较小。方差越大,数据的波动越大;方差越小,数据的波动就越小。因此,对于连续的数据,可以使用样本与平均值的差的平方和作为划分回归树的指标。

假设,有m个训练样本,{(X(1),y(1)),(X(2),y(2)), …, (X(m),y(m))}, 则划分CART回归树的指标为:
ms2=i=1m(y(i)y)2 m*s^2 = \sum_{i=1}^m (y^{(i)}-\overline{y})^2

下面是用Python实现CART回归树的划分指标:

import numpy as np
def calculate_err(data):
	"""
	input: data(list)
	output: m*s^2(float)
	"""
	data = np.mat(data)
	return np.var(data[:,-1]) * data.shape[0]
有了划分的标准,那么应该如何对样本进行划分呢?与CART分类树的划分一样,遍历各特征的所有取值,尝试将样本划分到树节点的左右子树中。只是因为不同的划分标准,在选择划分特征和特征值时的比较会有差异而已。 下面是左右子树划分的代码:
def split_tree(data, fea, value):
    '''根据特征fea中的值value将数据集data划分成左右子树
    input:  data(list):数据集
            fea(int):待分割特征的索引
            value(float):待分割的特征的具体值
    output: (set1,set2)(tuple):分割后的左右子树
    '''
    set_1 = []
    set_2 = []
    for x in data:
        if x[fea] >= value:
            set_1.append(x)
        else:
            set_2.append(x)
    return (set_1, set_2)

2. CART回归树的构建

CART回归树的构建也类似于CART分类树,主要的不同有三方面:

  1. 在选择划分特征与特征值的比较时,不是计算Gini指数,而是计算被划分后两个子数据集中各样本与平均值的差的平方和,选择此值较小的情况对数据集进行划分。
  2. 针对每一个叶节点,不是取样本的类别,而是各样本的标签值的平均平均值作为预测结果。
  3. 最后,CART回归树可通过设置参数进行前剪枝操作,此次构建中有设置了min_sample和min_err来控制树的节点是否需要进一步划分。
class node:
    '''树的节点的类
    '''
    def __init__(self, fea=-1, value=None, results=None, right=None, left=None):
        self.fea = fea  # 用于切分数据集的属性的列索引值
        self.value = value  # 设置划分的值
        self.results = results  # 存储叶节点所属的类别
        self.right = right  # 右子树
        self.left = left  # 左子树

def build_tree(data, min_sample, min_err):
    '''构建树
    input:  data(list):训练样本
    		min_sample(int): 叶子节点中最少的样本数
    		min_err(float): 最小的error
    output: node:树的根结点
    '''
    # 构建决策树,函数返回该决策树的根节点
    if len(data) <= min_sample:
        return node(results=leaf(data)
    
    # 1、初始化
    bestError = calculate_err(data)
    bestCriteria = None  # 存储最佳切分属性以及最佳切分点
    bestSets = None  # 存储切分后的两个数据集
    
     # 2、构建回归树 
    feature_num = len(data[0]) - 1  # 样本中特征的个数
    for fea in range(0, feature_num):
    	feature_values = {}
    	for sample in data:
    		feature_values[sample[fea]] = 1
    	
    	for value in feature_values.keys():
    		# 2.1 尝试划分
    		(set_1, set_2) = split_tree(data, fea, value)
    		if len(set_1) < 2 or len(set_2) < 2:
    			continue
    		# 2.2 计算划分后的error
    		nowError = calculate_err(set_1) + calculate_err(set_2)
			if nowError < bestError and len(set_1) > 0 and len(set_2) > 0:
				bestError = nowError
				bestCriteria = (fea, value)
				bestSets = (set_1, set_2)
    
    # 3、判断划分是否结束
    if bestError > min_err:
        right = build_tree(bestSets[0],min_sample, min_err)
        left = build_tree(bestSets[1],min_sample, min_err)
        return node(fea=bestCriteria[0], value=bestCriteria[1], right=right, left=left)
    else:
        return node(results=leaf(data))  


def leaf(data):
	"""
	计算叶节点的平均值
	"""
	data = np.mat(data)
	return np.mean(data[:,-1])

3. CART回归树的剪枝

在CART回归树中,当树中的节点对样本一直划分下去时,会出现最极端的情况:每一个叶子节点中仅包含一个样本,此时,叶子节点的值即为该样本的标签均值。这种情况极易对数据过拟合,为防止发生过拟合,需要对CART回归树进行剪枝,以防止生成过多的叶子节点。
在剪枝中主要分为:前剪枝和后剪枝。

  1. 前剪枝是指在生成树的过程中对树的深度进行控制,防止生成过多的叶子节点。在build_tree函数中就使用了min_sample和min_err来控制树中的节点是否需要进行更多的划分。通过不断调节这两个参数来找到合适的CART树模型。

  2. 后剪枝
    后剪枝是指将训练样本分成两个部分,一部分用来训练CART树模型,这部分数据被称为训练数据,另一部分用来对生成的树模型进行剪枝,称为验证数据。
    在后剪枝的过程中,通过验证生成的CART树模型是否在验证数据集上发生过拟合,如果出现过拟合的现象,则合并一些叶子节点来达到CART树模型的剪枝。

本文中主要使用的是前剪枝,通过调整min_sample和min_err参数的方式。

4. 数据预测

CART回归树模型构建好后,利用训练数据来训练该模型,最后训练好的回归树模型需要进行评估,了解预测值与实际值间的差距是否在接受范围内。
对CART回归树进行评估时,因需要对数据集中各样本进行预测,然后利用预测值与原始样本的标签值计算残差,所以,首先要建立predict函数。

def predict(sample,tree):
	"""对每一个样本sample进行预测
	input: sample(list)
	output: results(float)
	"""
	# 如果只是树根
	if tree.results != None:
		return tree.results
	else:
		# 有子树
		val_sample = sample(tree.fea)
		branch = None
		if val_sample >= tree.value:
			branch = tree.right
		else:
			branch = tree.left
		return predict(sample, branch)

接下来,对数据集进行预测,并计算残差,代码如下:

def evaluate_error(data, tree):
	"""评估CART回归树模型
	input: data(list)
			tree: 训练好的CART回归树模型
	output: total_error/m(float): 均方误差
	"""
	m = len(data)
	total_error = 0.0
	for i in range(m):
		sample = data[i,:-1]
		pred = predict(sample, tree)
		total_error += np.square(data[i,-1] - pred)
	return total_error / m

总结

最后总结一下,CART回归树算法是用来预测目标为连续值的算法,是一类基于局部的回归算法。CART回归树的构建是先利用类似CART分类树的方法将数据集进行划分,但划分的标准不同,本文使用的指标是各数值与平均值的差的平方和,在划分时选择使划分后的左右子树的该指标之和较小的特征与特征值,直到最后叶子节点中的样本个数达到min_sample或者各数值与平均值的差的平方和达到min_err为止,最后基于该叶子节点中的平均值作为预测值。在训练模型中,为避免过拟合,使用了参数min_sample和min_err来控制CART回归树模型的生成。