决策树算法 MATLAB 简单实现

决策树算法

前言

最近在数据挖掘与机器学习的课程上刚刚学到了决策树算法,于是,想自己用 MATLAB 简单实现一下。虽然拿其中最简单算法的进行实现,但是,从构思–编写–初步完成,也花费了不少时间,毕竟只有动手编写,才能真正体会到算法的内涵。

1 算法流程

通过阅读机器学习的书籍首先了解决策树算法的基本思想:通过递归的方式构建一棵树,子树是通过选取某一属性,按照其属性值进行划分产生的。其算法伪代码如下:

决策树算法 MATLAB 简单实现

2 程序设计

程序设计必须对算法的每个细节都要搞清楚,有时可能要实现一个健全完善的算法很困难,我们可以对算法进行简化,忽略复杂的情况,比如,在上面的构建决策树算法的步骤中,子树的划分可能有多个输出,连续属性和无序离散属性的划分的方法也有所不同,如果都要将这些考虑进去程序的设计难度会很大。作为初学者,可以对问题进行简化:

  • 假设无序离散属性都只是二元属性,属性值用0或1表示
  • 类别只有两类,用0或1表示
  • 每个节点只有两个输出

在明确了细节之后,还需考虑另外一个问题:数据结构。在程序中用什么数据结构来描述所构建的“树”?这一步很关键,因为在对训练集之外的记录进行测试的时候要用到该数据结构。

由于自己实现决策树算法的目的只是加深对算法的理解,并不是实际开发,因此,只是将“树”的结构和参数打印出来。

function build_tree(x, y, L, level, parent_y, sig, p_value)
% 自编的用于构建决策树的简单程序,适用于属性为二元属性,二分类情况。(也可对程序进行修改以适用连续属性)。
% 输入:
% x:数值矩阵,样本属性记录(每一行为一个样本)
% y:数值向量,样本对应的labels
% 其它参数调用时可以忽略,在递归时起作用。
% 输出:打印决策树。
    if nargin == 2
       level = 0; 
       parent_y = -1;
       L = 1:size(x, 2);
       sig = -1;
       p_value = [];
%        bin_f = zeros(size(x, 2), 1);
%        for k=1:size(x, 2)
%            if length(unique(x(:,k))) == 2
%               bin_f(k) = 1; 
%            end
%        end
    end
    class = [0, 1];
    [r, label] = is_leaf(x, y, parent_y); % 判断是否是叶子节点
    if r   
        if sig ==-1
            disp([repmat('     ', 1, level), 'leaf (', num2str(label), ')']);
        elseif sig ==0
            disp([repmat('     ', 1, level), '<', num2str(p_value),' leaf (', num2str(label), ')']);
        else
            disp([repmat('     ', 1, level), '>', num2str(p_value),' leaf (', num2str(label), ')']);
        end
    else
        [ind, value, i_] = find_best_test(x, y, L); % 找出最佳的测试值
%         
%         if ind ==1
%            keyboard; 
%         end
        
        [x1, y1, x2, y2] = split_(x, y, i_, value); % 实施划分
        if sig ==-1
            disp([repmat('     ', 1, level), 'node (', num2str(ind), ', ', num2str(value), ')']);
        elseif sig ==0
            disp([repmat('     ', 1, level), '<', num2str(p_value),' node (', num2str(ind), ', ', num2str(value), ')']);
        else
            disp([repmat('     ', 1, level), '>', num2str(p_value),' node (', num2str(ind), ', ', num2str(value), ')']);
        end
%         if bin_f(i_) == 1
            x1(:,i_) = []; 
            x2(:,i_) = [];
            L(:,i_) = [];
%             bin_f(i_) = [];
%         end
        build_tree(x1, y1, L, level+1, y, 0, value); % 地柜调用
        build_tree(x2, y2, L, level+1, y, 1, value);
    end

    function [ind, value, i_] = find_best_test(xx, yy, LL) % 子函数:找出最佳测试值(可以对连续属性适用)
        imp_min = inf;
        i_ = 1;
        ind = LL(i_);
        for i=1:size(xx,2);
            if length(unique(xx(:,i))) ==1
                continue;
            end
%            [xx_sorted, ii] = sortrows(xx, i); 
%            yy_sorted = yy(ii, :);
           vv = unique(xx(:,i));
           imp_min_i = inf;
           best_point = mean([vv(1), vv(2)]);
           value = best_point;
           for j = 1:length(vv)-1
               point = mean([vv(j), vv(j+1)]);               
               [xx1, yy1, xx2, yy2] = split_(xx, yy, i, point);
               imp = calc_imp(yy1, yy2);
               if imp<imp_min_i
                   best_point = point;
                   imp_min_i = imp;
               end
           end
           if imp_min_i < imp_min
              value = best_point;
              imp_min = imp_min_i;
              i_ = i;
              ind = LL(i_);
           end
        end
    end
    
    function imp = calc_imp(y1, y2) % 子函数:计算熵
        p11 = sum(y1==class(1))/length(y1);
        p12 = sum(y1==class(2))/length(y1);
        p21 = sum(y2==class(1))/length(y2);
        p22 = sum(y2==class(2))/length(y2);
        if p11==0
            t11 = 0;
        else
           t11 = p11*log2(p11); 
        end
        if p12==0
            t12 = 0;
        else
           t12 = p12*log2(p12); 
        end
        if p21==0
            t21 = 0;
        else
           t21 = p21*log2(p21); 
        end
        if p22==0
            t22 = 0;
        else
           t22 = p22*log2(p22); 
        end
        
        imp = -t11-t12-t21-t22;
    end

    function [x1, y1, x2, y2] = split_(x, y, i, point) % 子函数:实施划分
       index = (x(:,i)<point);
       x1 = x(index,:);
       y1 = y(index,:);
       x2 = x(~index,:);
       y2 = y(~index,:);
    end
    
    function [r, label] = is_leaf(xx, yy, parent_yy) % 子函数:判断是否是叶子节点
        if isempty(xx)
            r = true;
            label = mode(parent_yy);
        elseif length(unique(yy)) == 1
            r = true;
            label = unique(yy);
        else
            t = xx - repmat(xx(1,:),size(xx, 1), 1);
            if all(all(t ==0))
                r = true;
                label = mode(yy);
            else
                r = false;
                label = [];
            end
        end
    end
end

利用MATLAB提供的数据集进行测试,并与 MATLAB 自身提供的决策树分类的函数进行对比。

clc
clear all
load ionosphere % contains X and Y variables
x = X(:,1:3);
ind = x(:,3)>0;
x(ind,3) = 1;
x(~ind,3) = 0;

y = zeros(size(Y));
y(ismember(Y, 'b')) = 1;

ctree = fitctree(x, y);
view(ctree,'mode','graph') % graphic description
% [label, score] = predict(ctree, X(5,:))

build_tree(x, y);

自编程序运行结果

含义说明:

node(属性序号, 划分点)

leaf(类别)

决策树算法 MATLAB 简单实现

MATLAB提供的函数的运行结果

决策树算法 MATLAB 简单实现

结果与MATLAB中自己实现的函数运行结果相同。

3 MATLAB 中的调用

自己对算法的实现的目的主要还是用于加深对算法的理解,但是在实际应用时,还得借助成熟的机器学习工具包,比如MATLAB或Python提供的机器学习工具包。下面介绍一下MATLAB中决策树算法的相关函数的调用方法。

tree = fitctree(x,y) 
tree = fitctree(x,y,Name,Value)

根据给定的记录的属性x,对应类别y,构造决策树(二叉树)。要求x为数值矩阵,y为数值向量或cell数组。name-value pair 为可选参数,用于指定算法的参数(划分准则,叶子节点最少记录值等)。x, y 每一行为一个样本。

返回tree为决策树的数据结构。

利用tree进行分类:

label = predict(tree, x)

4 Python 中的调用

scikit-learn 库提供了决策树分类和回归的方法.

训练

>>> from sklearn import tree
>>> X = [[0, 0], [1, 1]]
>>> Y = [0, 1]
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(X, Y)

分类

>>> clf.predict([[2., 2.]])
array([1])

DecisionTreeClassifier能够进行二元分类(标签为[- 1,1])和多类分类(标签为[0,…,K-1])。