数据挖掘决策树分类算法ID3的java实现

ID3的具体算法描述见上一篇博文!

本人对ID3的算法实现做了如下假设与处理:

1. 假设所有的属性值域都是分类型或名词离散型的

2. 求信息增益时,log 函数本应以 2 为底,但为了方便起见,直接调用了 java.lang.Math 类中以 e 为底的 log 函数。换底只会使所有熵值同乘一个常数,各属性信息增益的大小顺序不变,因此不会影响最佳分裂属性的选择

3.最后的输出并没有以树结构的形式给出,但是可以根据输出结果分析出决策树的结构

java实现代码如下:

package DecisionTree; import java.util.ArrayList; /** * 决策树结点类 * @author Rowen * @qq 443773264 * @mail [email protected] * @blog blog.csdn.net/luowen3405 * @data 2011.03.15 */ public class TreeNode { private String name; //节点名(分裂属性的名称) private ArrayList<String> rule; //结点的分裂规则 ArrayList<TreeNode> child; //子结点集合 private ArrayList<ArrayList<String>> datas; //划分到该结点的训练元组 private ArrayList<String> candAttr; //划分到该结点的候选属性 public TreeNode() { this.name = ""; this.rule = new ArrayList<String>(); this.child = new ArrayList<TreeNode>(); this.datas = null; this.candAttr = null; } public ArrayList<TreeNode> getChild() { return child; } public void setChild(ArrayList<TreeNode> child) { this.child = child; } public ArrayList<String> getRule() { return rule; } public void setRule(ArrayList<String> rule) { this.rule = rule; } public String getName() { return name; } public void setName(String name) { this.name = name; } public ArrayList<ArrayList<String>> getDatas() { return datas; } public void setDatas(ArrayList<ArrayList<String>> datas) { this.datas = datas; } public ArrayList<String> getCandAttr() { return candAttr; } public void setCandAttr(ArrayList<String> candAttr) { this.candAttr = candAttr; } }

package DecisionTree; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import javax.smartcardio.ATR; /** * 决策树构造类 * @author Rowen * @qq 443773264 * @mail [email protected] * @blog blog.csdn.net/luowen3405 * @data 2011.03.15 */ public class DecisionTree { private Integer attrSelMode; //最佳分裂属性选择模式,1表示以信息增益度量,2表示以信息增益率度量。暂未实现2 public DecisionTree(){ this.attrSelMode = 1; } public DecisionTree(int attrSelMode) { this.attrSelMode = attrSelMode; } public void setAttrSelMode(Integer attrSelMode) { this.attrSelMode = attrSelMode; } /** * 获取指定数据集中的类别及其计数 * @param datas 指定的数据集 * @return 类别及其计数的map */ public Map<String, Integer> classOfDatas(ArrayList<ArrayList<String>> datas){ Map<String, Integer> classes = new HashMap<String, Integer>(); String c = ""; ArrayList<String> tuple = null; for (int i = 0; i < datas.size(); i++) { tuple = datas.get(i); c = tuple.get(tuple.size() - 1); if (classes.containsKey(c)) { classes.put(c, classes.get(c) + 1); } else { classes.put(c, 1); } } return classes; } /** * 获取具有最大计数的类名,即求多数类 * @param classes 类的键值集合 * @return 多数类的类名 */ public String maxClass(Map<String, Integer> classes){ String maxC = ""; int max = -1; Iterator iter = classes.entrySet().iterator(); for(int i = 0; iter.hasNext(); i++) { Map.Entry entry = (Map.Entry) iter.next(); String key = (String)entry.getKey(); Integer val = (Integer) entry.getValue(); if(val > max){ max = val; maxC = key; } } return maxC; } /** * 构造决策树 * @param datas 训练元组集合 * @param attrList 候选属性集合 * @return 决策树根结点 */ public TreeNode buildTree(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList){ // System.out.print("候选属性列表: "); // for (int i = 0; i < attrList.size(); i++) { // System.out.print(" " + attrList.get(i) + " "); // } System.out.println(); TreeNode node = new TreeNode(); node.setDatas(datas); node.setCandAttr(attrList); Map<String, Integer> classes = classOfDatas(datas); String maxC = maxClass(classes); if (classes.size() == 1 || 
attrList.size() == 0) { node.setName(maxC); return node; } Gain gain = new Gain(datas, attrList); int bestAttrIndex = gain.bestGainAttrIndex(); ArrayList<String> rules = gain.getValues(datas, bestAttrIndex); node.setRule(rules); node.setName(attrList.get(bestAttrIndex)); if(rules.size() > 2){ //?此处有待商榷 attrList.remove(bestAttrIndex); } for (int i = 0; i < rules.size(); i++) { String rule = rules.get(i); ArrayList<ArrayList<String>> di = gain.datasOfValue(bestAttrIndex, rule); for (int j = 0; j < di.size(); j++) { di.get(j).remove(bestAttrIndex); } if (di.size() == 0) { TreeNode leafNode = new TreeNode(); leafNode.setName(maxC); leafNode.setDatas(di); leafNode.setCandAttr(attrList); node.getChild().add(leafNode); } else { TreeNode newNode = buildTree(di, attrList); node.getChild().add(newNode); } } return node; } }

package DecisionTree; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.Map; /** * 选择最佳分裂属性 * @author Rowen * @qq 443773264 * @mail [email protected] * @blog blog.csdn.net/luowen3405 * @data 2011.03.15 */ public class Gain { private ArrayList<ArrayList<String>> D = null; //训练元组 private ArrayList<String> attrList = null; //候选属性集 public Gain(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList) { this.D = datas; this.attrList = attrList; } /** * 获取最佳侯选属性列上的值域(假定所有属性列上的值都是有限的名词或分类类型的) * @param attrIndex 指定的属性列的索引 * @return 值域集合 */ public ArrayList<String> getValues(ArrayList<ArrayList<String>> datas, int attrIndex){ ArrayList<String> values = new ArrayList<String>(); String r = ""; for (int i = 0; i < datas.size(); i++) { r = datas.get(i).get(attrIndex); if (!values.contains(r)) { values.add(r); } } return values; } /** * 获取指定数据集中指定属性列索引的域值及其计数 * @param d 指定的数据集 * @param attrIndex 指定的属性列索引 * @return 类别及其计数的map */ public Map<String, Integer> valueCounts(ArrayList<ArrayList<String>> datas, int attrIndex){ Map<String, Integer> valueCount = new HashMap<String, Integer>(); String c = ""; ArrayList<String> tuple = null; for (int i = 0; i < datas.size(); i++) { tuple = datas.get(i); c = tuple.get(attrIndex); if (valueCount.containsKey(c)) { valueCount.put(c, valueCount.get(c) + 1); } else { valueCount.put(c, 1); } } return valueCount; } /** * 求对datas中元组分类所需的期望信息,即datas的熵 * @param datas 训练元组 * @return datas的熵值 */ public double infoD(ArrayList<ArrayList<String>> datas){ double info = 0.000; int total = datas.size(); Map<String, Integer> classes = valueCounts(datas, attrList.size()); Iterator iter = classes.entrySet().iterator(); Integer[] counts = new Integer[classes.size()]; for(int i = 0; iter.hasNext(); i++) { Map.Entry entry = (Map.Entry) iter.next(); Integer val = (Integer) entry.getValue(); counts[i] = val; } for (int i = 0; i < counts.length; i++) { double base = DecimalCalculate.div(counts[i], total, 3); info += 
(-1) * base * Math.log(base); } return info; } /** * 获取指定属性列上指定值域的所有元组 * @param attrIndex 指定属性列索引 * @param value 指定属性列的值域 * @return 指定属性列上指定值域的所有元组 */ public ArrayList<ArrayList<String>> datasOfValue(int attrIndex, String value){ ArrayList<ArrayList<String>> Di = new ArrayList<ArrayList<String>>(); ArrayList<String> t = null; for (int i = 0; i < D.size(); i++) { t = D.get(i); if(t.get(attrIndex).equals(value)){ Di.add(t); } } return Di; } /** * 基于按指定属性划分对D的元组分类所需要的期望信息 * @param attrIndex 指定属性的索引 * @return 按指定属性划分的期望信息值 */ public double infoAttr(int attrIndex){ double info = 0.000; ArrayList<String> values = getValues(D, attrIndex); for (int i = 0; i < values.size(); i++) { ArrayList<ArrayList<String>> dv = datasOfValue(attrIndex, values.get(i)); info += DecimalCalculate.mul(DecimalCalculate.div(dv.size(), D.size(), 3), infoD(dv)); } return info; } /** * 获取最佳分裂属性的索引 * @return 最佳分裂属性的索引 */ public int bestGainAttrIndex(){ int index = -1; double gain = 0.000; double tempGain = 0.000; for (int i = 0; i < attrList.size(); i++) { tempGain = infoD(D) - infoAttr(i); if (tempGain > gain) { gain = tempGain; index = i; } } return index; } }

package DecisionTree;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

/**
 * Console driver for the decision tree: reads the candidate attributes and the
 * training tuples from standard input, builds the tree and prints it.
 *
 * @author Rowen
 * @qq 443773264
 * @mail [email protected]
 * @blog blog.csdn.net/luowen3405
 * @date 2011.03.15
 */
public class TestDecisionTree {

    /**
     * Single reader shared by both read methods. Creating a second
     * BufferedReader on System.in would lose whatever the first one had
     * already buffered, which swallows the training data when the input is
     * piped or pasted in one go.
     */
    private final BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));

    /**
     * Reads the candidate attributes: whitespace-separated names on one or
     * more lines, terminated by an empty line or the end of input.
     *
     * @return the candidate attribute names
     * @throws IOException if reading standard input fails
     */
    public ArrayList<String> readCandAttr() throws IOException {
        ArrayList<String> candAttr = new ArrayList<String>();
        String line;
        // readLine() returns null at end of input, so test for that first.
        while ((line = stdin.readLine()) != null && !line.equals("")) {
            StringTokenizer tokenizer = new StringTokenizer(line);
            while (tokenizer.hasMoreTokens()) {
                candAttr.add(tokenizer.nextToken());
            }
        }
        return candAttr;
    }

    /**
     * Reads the training tuples: one tuple of whitespace-separated values per
     * line, terminated by an empty line or the end of input.
     *
     * @return the training tuples
     * @throws IOException if reading standard input fails
     */
    public ArrayList<ArrayList<String>> readData() throws IOException {
        ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();
        String line;
        while ((line = stdin.readLine()) != null && !line.equals("")) {
            StringTokenizer tokenizer = new StringTokenizer(line);
            ArrayList<String> tuple = new ArrayList<String>();
            while (tokenizer.hasMoreTokens()) {
                tuple.add(tokenizer.nextToken());
            }
            // A whitespace-only line has no values and no class label; skip it
            // rather than hand an empty tuple to the tree builder.
            if (!tuple.isEmpty()) {
                datas.add(tuple);
            }
        }
        return datas;
    }

    /**
     * Recursively prints the tree, starting at the given node.
     *
     * @param root the node whose information is printed next
     */
    public void printTree(TreeNode root) {
        System.out.println("name:" + root.getName());

        System.out.print("node rules: {");
        for (String rule : root.getRule()) {
            System.out.print(rule + " ");
        }
        System.out.print("}");
        System.out.println("");

        ArrayList<TreeNode> children = root.getChild();
        if (children.size() == 0) {
            System.out.println("-->leaf node!<--");
            return;
        }
        System.out.println("size of children:" + children.size());
        for (int i = 0; i < children.size(); i++) {
            System.out.print("child " + (i + 1) + " of node " + root.getName() + ": ");
            printTree(children.get(i));
        }
    }

    /**
     * Program entry point.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        TestDecisionTree tdt = new TestDecisionTree();
        ArrayList<String> candAttr;
        ArrayList<ArrayList<String>> datas;
        try {
            System.out.println("请输入候选属性");
            candAttr = tdt.readCandAttr();
            System.out.println("请输入训练数据");
            datas = tdt.readData();
        } catch (IOException e) {
            e.printStackTrace();
            return; // without input there is nothing to build a tree from
        }
        DecisionTree tree = new DecisionTree();
        TreeNode root = tree.buildTree(datas, candAttr);
        tdt.printTree(root);
    }
}

测试数据:

//属性列表
age income student credit_rating

//训练数据
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no

程序输出结果:

name:age
node rules: {youth middle_aged senior }
size of children:3
child 1 of node age: name:student
node rules: {no yes }
size of children:2
child 1 of node student: name:no
node rules: {}
-->leaf node!<--
child 2 of node student: name:yes
node rules: {}
-->leaf node!<--
child 2 of node age: name:yes
node rules: {}
-->leaf node!<--
child 3 of node age: name:credit_rating
node rules: {fair excellent }
size of children:2
child 1 of node credit_rating: name:yes
node rules: {}
-->leaf node!<--
child 2 of node credit_rating: name:no
node rules: {}
-->leaf node!<--

根据输出结果画出的决策树,如下图所示:

数据挖掘决策树分类算法ID3的java实现

转载请注明出处:http://blog.csdn.net/luowen3405/archive/2011/03/15/6250731.aspx,谢谢!