决策树算法
采用ID3算法实现决策树分类。下面以天气、是否周末、是否有促销三个因素对销量高低的影响为例进行说明。
使用ID3算法建立决策树的MATLAB代码如下:
ID3_decision_tree.m
%% Predict high/low sales with an ID3 decision tree
% Pipeline: preprocess the raw table into 0/1 data -> build the tree with
% the custom id3() -> print the tree breadth-first -> draw it.
clear;

fprintf('%s\n', '正在进行数据预处理...');
% matrix: 0/1 data (label in the last column); attributes_label: column
% names; attributes: all-ones "still active" mask for the attributes.
[matrix,attributes_label,attributes] = id3_preprocess();

fprintf('%s\n', '数据预处理完成,正在进行构造树...');
tree = id3(matrix,attributes_label,attributes);

% Parent-pointer vector plus per-node labels, then draw them.
[nodeids,nodevalues] = print_tree(tree);
tree_plot(nodeids,nodevalues);
fprintf('%s\n', 'ID3算法构建决策树完成!');
id3_preprocess.m
function [ matrix,attributes,activeAttributes ] = id3_preprocess( )
%% ID3 data preprocessing: convert the string table into a 0/1 matrix
% Outputs:
%   matrix           - 34-by-4 0/1 matrix whose columns are 天气, 是否周末,
%                      是否有促销 and 销量 (the label, last column).
%                      '坏', '否' and '低' become 0; every other value becomes 1
%                      (see trans2onezero).
%   attributes       - 1-by-4 cell array of the column names, label included.
%   activeAttributes - 1-by-3 vector of ones: every attribute starts active.

%% Raw data: first row is the header; the 序号 column is empty and dropped below.
% Terminated with ';' so the whole table is not echoed on every run.
txt = { '序号' '天气' '是否周末' '是否有促销' '销量'
'' '坏' '是' '是' '高'
'' '坏' '是' '是' '高'
'' '坏' '是' '是' '高'
'' '坏' '否' '是' '高'
'' '坏' '是' '是' '高'
'' '坏' '否' '是' '高'
'' '坏' '是' '否' '高'
'' '好' '是' '是' '高'
'' '好' '是' '否' '高'
'' '好' '是' '是' '高'
'' '好' '是' '是' '高'
'' '好' '是' '是' '高'
'' '好' '是' '是' '高'
'' '坏' '是' '是' '低'
'' '好' '否' '是' '高'
'' '好' '否' '是' '高'
'' '好' '否' '是' '高'
'' '好' '否' '是' '高'
'' '好' '否' '否' '高'
'' '坏' '否' '否' '低'
'' '坏' '否' '是' '低'
'' '坏' '否' '是' '低'
'' '坏' '否' '是' '低'
'' '坏' '否' '否' '低'
'' '坏' '是' '否' '低'
'' '好' '否' '是' '低'
'' '好' '否' '是' '低'
'' '坏' '否' '否' '低'
'' '坏' '否' '否' '低'
'' '好' '否' '否' '低'
'' '坏' '是' '否' '低'
'' '好' '否' '是' '低'
'' '好' '否' '否' '低'
'' '好' '否' '否' '低' };

% Drop the 序号 column from both the header and the data.
attributes = txt(1,2:end);
% The last name is the label, so only length-1 attributes can be split on.
activeAttributes = ones(1,length(attributes)-1);
data = txt(2:end,2:end);

%% Convert each column of strings to 0/1
[rows,cols] = size(data);
matrix = zeros(rows,cols);
for j = 1:cols
    matrix(:,j) = cellfun(@trans2onezero,data(:,j));
end
end
function flag = trans2onezero(data)
%% Map one table cell to 0/1
% '坏' (bad), '否' (no) and '低' (low) map to 0; any other value maps to 1.
% Returns a double so cellfun can collect the results into a numeric column.
% (The original closed this function with 'End', which MATLAB rejects.)
flag = double(~any(strcmp(data, {'坏', '否', '低'})));
end
id3.m
function [ tree ] = id3( examples, attributes, activeAttributes )
%% ID3 algorithm: build a binary ID3 decision tree
% Reference: https://github.com/gwheaton/ID3-Decision-Tree
% Inputs:
%   examples         - N-by-(M+1) 0/1 matrix; columns 1..M are attribute
%                      values, column M+1 is the label (1 = true).
%   attributes       - cell array of names; attributes{i} names column i
%                      (the label name may follow at position M+1).
%   activeAttributes - 1-by-M vector; a nonzero entry marks an attribute that
%                      may still be chosen for a split, 0 marks one already
%                      used on the current path.
% Output:
%   tree - struct with fields value/left/right. A leaf has value 'true' or
%          'false' and left = right = 'null'. An internal node has value =
%          the chosen attribute name, left = subtree for attribute value 0,
%          right = subtree for attribute value 1.
% Errors if EXAMPLES is empty.

%% Reject empty input
if isempty(examples)
    error('必须提供数据!');
end

%% Constants
numberAttributes = length(activeAttributes);
numberExamples = size(examples, 1);
labelColumn = examples(:, numberAttributes + 1);
labelIsTrue = labelColumn ~= 0;
lastColumnSum = sum(labelColumn);

% New node; it turns into a leaf if a stopping rule below fires.
tree = struct('value', 'null', 'left', 'null', 'right', 'null');

%% Stopping rules
if lastColumnSum == numberExamples      % every label is 1
    tree.value = 'true';
    return
end
if lastColumnSum == 0                   % every label is 0
    tree.value = 'false';
    return
end
if ~any(activeAttributes)               % nothing left to split on: majority vote
    tree.value = majority_value(lastColumnSum, numberExamples);
    return
end

%% Information gain of every still-active attribute
currentEntropy = binary_entropy(lastColumnSum, numberExamples);
% Inactive attributes keep gain -1 so max() never selects them
% (information gain of an active attribute is always >= 0).
gains = -1 * ones(1, numberAttributes);
for i = 1:numberAttributes
    if activeAttributes(i)
        isOne = examples(:, i) ~= 0;
        s1 = sum(isOne);
        s1_and_true = sum(isOne & labelIsTrue);
        s0 = numberExamples - s1;
        s0_and_true = sum(~isOne & labelIsTrue);
        entropy_s1 = binary_entropy(s1_and_true, s1);
        entropy_s0 = binary_entropy(s0_and_true, s0);
        gains(i) = currentEntropy - ((s1/numberExamples)*entropy_s1) ...
            - ((s0/numberExamples)*entropy_s0);
    end
end

%% Split on the attribute with the largest gain
[~, bestAttribute] = max(gains);
tree.value = attributes{bestAttribute};
activeAttributes(bestAttribute) = 0;    % never reuse it further down this path
examples_0 = examples(examples(:, bestAttribute) == 0, :);
examples_1 = examples(examples(:, bestAttribute) == 1, :);
% A branch that receives no examples becomes a leaf with this node's majority label.
parentMajority = majority_value(lastColumnSum, numberExamples);
tree.left = grow_branch(examples_0, attributes, activeAttributes, parentMajority);   % value 0
tree.right = grow_branch(examples_1, attributes, activeAttributes, parentMajority);  % value 1
end

function branch = grow_branch(subset, attributes, activeAttributes, fallbackValue)
% Build one child: recurse into id3 on SUBSET, or return a leaf labelled
% FALLBACKVALUE when SUBSET is empty.
if isempty(subset)
    branch = struct('value', fallbackValue, 'left', 'null', 'right', 'null');
else
    branch = id3(subset, attributes, activeAttributes);
end
end

function value = majority_value(numTrue, total)
% 'true' when at least half of TOTAL labels are 1 (ties go to 'true'), else 'false'.
if numTrue >= total / 2
    value = 'true';
else
    value = 'false';
end
end

function h = binary_entropy(numTrue, total)
% Entropy in bits of TOTAL labels of which NUMTRUE are 1; 0 for an empty set.
if total == 0
    h = 0;
    return
end
h = entropy_term(numTrue / total) + entropy_term((total - numTrue) / total);
end

function t = entropy_term(p)
% One entropy term -p*log2(p), taken as 0 when p == 0.
if p == 0
    t = 0;
else
    t = -1*p*log2(p);
end
end
print_tree.m
function [nodeids_,nodevalue_] = print_tree(tree)
%% Print TREE breadth-first and return the data needed to draw it
% TREE is a struct with fields value/left/right; a missing child is the
% string 'null'.
% Outputs:
%   nodeids_   - parent-pointer vector: nodeids_(k) is the BFS id of node k's
%                parent, 0 for the root (the form treelayout expects).
%   nodevalue_ - 1-by-K cell array; nodevalue_{1,k} is node k's value.
% For an empty TREE it prints '空树!' and returns nodeids_ = 0, nodevalue_ = {}.
global nodeid nodeids nodevalue;
% Reset ALL shared state. Globals survive a plain `clear` in the calling
% script, so resetting only nodeids(1) would leave stale parent links from a
% previously printed, larger tree.
nodeids = 0;    % the root (node 1) has parent 0
nodeid = 0;
nodevalue = {};
if isempty(tree)
    disp('空树!');
    % Assign outputs so callers do not hit "output argument not assigned".
    nodeids_ = nodeids;
    nodevalue_ = nodevalue;
    return;
end
queue = queue_push([],tree);
while ~isempty(queue)                                 % queue not empty
    [node,queue] = queue_pop(queue);                  % dequeue
    visit(node,queue_curr_size(queue));
    if ~strcmp(node.left,'null')                      % left subtree present
        queue = queue_push(queue,node.left);          % enqueue
    end
    if ~strcmp(node.right,'null')                     % right subtree present
        queue = queue_push(queue,node.right);         % enqueue
    end
end
%% Return the node relations for drawing with treeplot/treelayout
nodeids_ = nodeids;
nodevalue_ = nodevalue;
end
function visit(node,length_)
%% Give NODE the next breadth-first id, record its label and print it
% LENGTH_ is the number of nodes still waiting in the BFS queue after NODE
% was dequeued. For an internal node the parent links of its two children
% are written in advance. This relies on every internal node having both
% children, which id3 always builds.
global nodeid nodeids nodevalue;
nodeid = nodeid + 1;
nodevalue{1,nodeid} = node.value;
if ~isleaf(node)
    % The LENGTH_ queued nodes take ids nodeid+1..nodeid+length_, so the two
    % children of this node will be dequeued right after them.
    leftChildId = nodeid + length_ + 1;
    rightChildId = leftChildId + 1;
    nodeids([leftChildId rightChildId]) = nodeid;
    fprintf('node: %d\t属性值: %s\t,左子树为节点:node%d,右子树为节点:node%d\n', ...
        nodeid, node.value, leftChildId, rightChildId);
else
    fprintf('叶子节点,node: %d\t,属性值: %s\n', nodeid, node.value);
end
end
function flag = isleaf(node)
%% Return 1 (double) when NODE has neither a left nor a right child, else 0
% A missing child is stored as the string 'null'.
flag = double(strcmp(node.left,'null') && strcmp(node.right,'null'));
end
function tree_plot( p ,nodevalues)
%% Draw a tree from a parent-pointer vector and label each node (adapted from treeplot)
% NOTE(review): ID3_decision_tree.m calls tree_plot directly. If this
% function is kept inside print_tree.m it is only a local function there and
% the script cannot see it; save it as its own file tree_plot.m.
% Inputs:
%   p          - parent-pointer vector: p(k) is node k's parent, 0 for the root.
%   nodevalues - cell array with one entry per node; nodevalues{1,k} is the
%                text drawn next to node k.
% (The original closed this function with 'End', which MATLAB rejects.)
[x,y,h] = treelayout(p);
f = find(p~=0);
pp = p(f);
% One (child, parent, NaN) triple per edge so a single plot call draws every edge.
X = [x(f); x(pp); NaN(size(f))];
Y = [y(f); y(pp); NaN(size(f))];
X = X(:);
Y = Y(:);
n = length(p);
if n < 500
    hold on;
    plot(x, y, 'ro', X, Y, 'r-');
    nodesize = length(x);
    for i = 1:nodesize
        text(x(i)+0.01, y(i), nodevalues{1,i});
    end
    hold off;
else
    % Large trees: draw the edges only, without node markers or labels.
    plot(X, Y, 'r-');
end
xlabel(['height = ' int2str(h)]);
axis([0 1 0 1]);
end
queue_push.m
function [ newqueue ] = queue_push( queue,item )
%% Enqueue: append ITEM to the end of QUEUE and return the new queue
% QUEUE is a row array ([] for an empty queue); ITEM must be concatenable
% with it (e.g. a struct with the same fields as the queued structs).
% (The original closed this function with 'End', which MATLAB rejects.)
newqueue = [queue, item];
end
queue_pop.m
function [ item,newqueue ] = queue_pop( queue )
%% Dequeue: remove and return the first element of QUEUE
% Outputs:
%   item     - QUEUE(1), or [] when QUEUE is empty.
%   newqueue - QUEUE without its first element (QUEUE itself when empty).
if isempty(queue)
    disp('队列为空,不能访问!');
    % Assign both outputs; returning without them raises
    % "output argument not assigned" in the caller.
    item = [];
    newqueue = queue;
    return;
end
item = queue(1);            % front element
newqueue = queue(2:end);    % remaining elements shift forward by one
end
queue_curr_size.m
function [ length_ ] = queue_curr_size( queue )
%% Number of elements currently waiting in QUEUE
% Spelled out as length() is defined: 0 for an empty array, otherwise the
% size of its largest dimension.
if isempty(queue)
    length_ = 0;
else
    length_ = max(size(queue));
end
end
MATLAB运行结果如下: