决策树生成算法与ID3算法java实现
算法过程
数据
最终需要分类的属性为“电脑”,它有2个不同值0和1,1有4个样本,0有2个样本。
为计算每个属性的信息增益,我们首先给定样本电脑分类所需的期望信息:
I(4,2)=-4/6log2(4/6)-2/6log2(2/6)=0.918 |
---|
从“性别”属性开始。 “性别”=1,有3个“电脑”=1,2个“电脑”=0; “性别”=0,有1个“电脑”=1,没有“电脑”=0。
i= -3/5log2(3/5)-2/5log2(2/5)-1log2(1)=0.971 |
---|
按“性别”划分后,样本的熵(期望信息)为
e=5/6(-3/5log2(3/5)-2/5log2(2/5))+1/6(-1log2(1))=0.809 |
---|
信息增益是
Gain(性别)=I(4,2)-e=0.918-0.809=0.109 |
---|
同理
Gain(学生)=0.459;
Gain(民族)=0.316;
决策树生成过程
在集合中找到信息增益最大的
{computer=1, gender=1, student=1, nation=0}
{computer=1, gender=0, student=0, nation=0}
{computer=1, gender=1, student=1, nation=0}
{computer=1, gender=1, student=1, nation=0}
{computer=0, gender=1, student=0, nation=0}
{computer=0, gender=1, student=0, nation=1}
Gain(性别)=0.109
Gain(学生)=0.459;
Gain(民族)=0.316;
选择学生分类
学生(“1”)
{computer=1, gender=1, student=1, nation=0}
{computer=1, gender=1, student=1, nation=0}
{computer=1, gender=1, student=1, nation=0}
各个样本均相同则熵为0 分类结束
学生(“0”)
{computer=1, gender=0, student=0, nation=0}
{computer=0, gender=1, student=0, nation=0}
{computer=0, gender=1, student=0, nation=1}
再次计算信息增益
Gain(性别)=0.9182958340544896
Gain(民族)=0.2516291673878229
选择性别分类
性别(“0”)
{computer=1, gender=0, student=0, nation=0}
熵为0 分类结束
性别(“1”)
{computer=0, gender=1, student=0, nation=0}
{computer=0, gender=1, student=0, nation=1}
熵为0 分类结束
决策树生成完毕
输出结果
[root<-(student:0)<-(gender:0)]:[{computer=1, gender=0, student=0, nation=0}]
[root<-(student:0)<-(gender:1)]:[{computer=0, gender=1, student=0, nation=0},{computer=0, gender=1, student=0, nation=1}]
[root<-(student:1)]:[{computer=1, gender=1, student=1, nation=0}, {computer=1, gender=1, student=1, nation=0}, {computer=1, gender=1, student=1, nation=0}]
图示
代码
package decisiontree;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
 * Small ID3 decision-tree demo.
 *
 * <p>Constructing an instance builds a fixed six-sample data set, prints the samples,
 * the target attribute and the information gain of each attribute, and then prints
 * the leaves of the decision tree grown by {@link #tree}.
 */
public class ID3 {
    /** Gains below this value are treated as zero: no attribute separates the node further. */
    private static final double MIN_GAIN = 0.001;

    /** Name of the target (class label) attribute; set by {@code createProperty}. */
    public String nameProperty;
    /** Attribute names, taken from the first sample of the data set. */
    public Set<String> nameSet;
    /** Distinct values the target attribute takes in the data set. */
    public Set<String> setProperty;
    /** The training samples. */
    ArrayList<SampleInter> data = new ArrayList<SampleInter>();

    public static void main(String[] args) {
        // The constructor does all the work (build data, print gains, print tree).
        new ID3();
    }

    /** Fills {@link #data} with the six demo samples (gender, student, nation, computer). */
    private void createData() {
        data.add(new Sample("1", "1", "0", "1"));
        data.add(new Sample("0", "0", "0", "1"));
        data.add(new Sample("1", "1", "0", "1"));
        data.add(new Sample("1", "1", "0", "1"));
        data.add(new Sample("1", "0", "0", "0"));
        data.add(new Sample("1", "0", "1", "0"));
    }

    /** Builds the demo data, prints the per-attribute gains and then the decision tree. */
    public ID3() {
        createData();
        data.forEach(System.out::println);
        createProperty("computer");
        System.out.println(nameProperty + " " + setProperty);
        nameSet = data.get(0).getKeys();
        System.out.println("nameSet" + nameSet);
        System.out.println(getGain(data, "gender"));
        System.out.println(getGain(data, "student"));
        System.out.println(getGain(data, "nation"));
        // Candidate split attributes: everything except the target itself.
        Set<String> names = new HashSet<String>(nameSet);
        names.remove(nameProperty);
        tree(new ArrayList<SampleInter>(data), names, "root");
    }

    /**
     * Recursively grows the tree and prints every leaf as {@code [path]:samples}.
     *
     * <p>A node becomes a leaf when it holds a single sample, when all of its samples
     * have the same string form, or when no candidate attribute yields a gain of at
     * least {@code MIN_GAIN}. Otherwise the node is split on the attribute with the
     * largest gain and each distinct value of that attribute becomes a child.
     *
     * @param data  samples that reached this node; an empty list prints nothing
     * @param names candidate attribute names for splitting; not modified
     * @param root  textual path from the root to this node
     */
    public void tree(ArrayList<SampleInter> data, Set<String> names, String root) {
        if (data.isEmpty()) {
            // Nothing reached this node; data.get(0) below would throw.
            return;
        }
        if (data.size() == 1) {
            printLeaf(root, data);
            return;
        }
        // Identical samples cannot be separated by any attribute.
        String first = data.get(0).toString();
        boolean allIdentical = data.stream().allMatch(sample -> first.equals(sample.toString()));
        if (allIdentical) {
            printLeaf(root, data);
            return;
        }
        String maxName = "";
        double maxGain = 0;
        for (String name : names) {
            double gain = getGain(data, name);
            if (maxGain < gain) {
                maxGain = gain;
                maxName = name;
            }
        }
        // Only a (near-)zero gain ends the branch. A gain close to 1 bit is a perfect
        // split of a 50/50 node and must still be carried out.
        if (maxGain < MIN_GAIN) {
            printLeaf(root, data);
            return;
        }
        // Work on a copy so the caller's attribute set is left untouched.
        Set<String> remaining = new HashSet<String>(names);
        remaining.remove(maxName);
        for (String att : getSet(data, maxName)) {
            ArrayList<SampleInter> subset = new ArrayList<SampleInter>();
            for (SampleInter sample : data) {
                if (sample.getValue(maxName).equals(att)) {
                    subset.add(sample);
                }
            }
            tree(subset, new HashSet<String>(remaining), root + "<-(" + maxName + ":" + att + ")");
        }
    }

    /** Prints one leaf in the {@code [path]:samples} format. */
    private void printLeaf(String root, ArrayList<SampleInter> data) {
        System.out.println("[" + root + "]:" + data);
    }

    /**
     * Information gain obtained by splitting {@code data} on attribute {@code name}:
     * the entropy of the target attribute minus the size-weighted entropy of the
     * target within each value group of {@code name}.
     *
     * @param data samples to evaluate
     * @param name attribute to split on
     * @return the gain in bits
     */
    public double getGain(ArrayList<SampleInter> data, String name) {
        // Use the configured target attribute rather than a hard-coded name.
        double targetEntropy = getI(data, nameProperty);
        int size = data.size();
        double conditionalEntropy = 0;
        for (String value : getSet(data, name)) {
            long count = getPropertyCount(data, name, value);
            if (count == 0) {
                continue;
            }
            for (String label : setProperty) {
                long matches = getPropertyCount(data, name, value, nameProperty, label);
                if (matches == 0) {
                    // An absent label contributes 0 and would otherwise evaluate log2(0).
                    continue;
                }
                double px = 1.0 * matches / count;
                conditionalEntropy -= px * log2(px) * count / size;
            }
        }
        return targetEntropy - conditionalEntropy;
    }

    /**
     * Entropy (in bits) of the value distribution of attribute {@code name} in {@code data}.
     *
     * @param data samples to evaluate
     * @param name attribute whose distribution is measured
     * @return the entropy; 0 for an empty list
     */
    public double getI(ArrayList<SampleInter> data, String name) {
        int size = data.size();
        double entropy = 0;
        for (String value : getSet(data, name)) {
            long count = getPropertyCount(data, name, value);
            double px = 1.0 * count / size;
            entropy -= px * log2(px);
        }
        return entropy;
    }

    /** Counts the samples whose attribute {@code name} equals {@code attribute}. */
    public long getPropertyCount(ArrayList<SampleInter> data, String name, String attribute) {
        return data.stream().filter(p -> p.getValue(name).equals(attribute)).count();
    }

    /**
     * Counts the samples where {@code name1} equals {@code attribute1} and
     * {@code name2} equals {@code attribute2}.
     */
    public long getPropertyCount(ArrayList<SampleInter> data, String name1, String attribute1,
            String name2, String attribute2) {
        return data.stream()
                .filter(p -> p.getValue(name1).equals(attribute1)
                        && p.getValue(name2).equals(attribute2))
                .count();
    }

    /** Base-2 logarithm of {@code x}. */
    public static double log2(double x) {
        return Math.log(x) / Math.log(2);
    }

    /** Records {@code string} as the target attribute and collects its distinct values. */
    private void createProperty(String string) {
        nameProperty = string;
        setProperty = getSet(data, string);
    }

    /** Returns the distinct values attribute {@code string} takes across {@code data}. */
    public Set<String> getSet(ArrayList<SampleInter> data, String string) {
        return data.stream().map(m -> m.getValue(string)).collect(Collectors.toSet());
    }
}
/**
 * One data sample, seen as a collection of named attributes with string values.
 */
interface SampleInter {
    /** Returns the names of the attributes this sample carries. */
    Set<String> getKeys();

    /** Returns the value held by the attribute called {@code string}. */
    String getValue(String string);

    /** Tells whether some attribute of this sample holds {@code value}. */
    boolean containsValue(String value);
}
class Sample implements SampleInter{
Map<String,String> map =new HashMap<String,String>();
public Sample(String gender,String student,String nation,String computer ) {
map.put("gender",gender);
map.put("student",student);
map.put("nation",nation);
map.put("computer",computer);
}
public Set<String> getKeys() {
return map.keySet();
}
public String getValue(String string) {
return map.get(string);
}
public boolean containsValue(String value) {
return map.containsValue(value);
}
public String toString() {
return map.toString();
}
}