2008-04-01

决策树ID3算法

package graph;      
     
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;      
     
/**
 * ID3 decision-tree learner.
 * <p>
 * Every training row is a {@code String[]}. One column (the {@code index}
 * argument of {@link #create(Object[], int)}) holds the class label; every other
 * column is a candidate split attribute. At each node the attribute, not yet
 * used on the path from the root, with the highest information gain is chosen.
 * A branch becomes a leaf when no remaining attribute yields a positive gain.
 * <p>
 * Implementation follows
 * http://www.blog.edu.cn/user2/huangbo929/archives/2006/1533249.shtml
 *
 * @author Leon.Chen
 */
public class DTree {

    /**
     * Display names for the data columns; entry i names column i.
     * NOTE(review): data sets with more columns than names are not supported
     * (getNodeName returns null for them).
     */
    private static final String[] NODE_NAMES =
            new String[] {"Outlook", "Temperature", "Humidity", "Wind", "Play ball"};

    /**
     * Gains at or below this threshold are treated as zero, so floating-point
     * rounding noise cannot trigger a split on an uninformative attribute.
     */
    private static final double GAIN_EPSILON = 1e-12;

    /**
     * Root of the most recently built tree.
     */
    TreeNode root;

    /**
     * usedAttributes[i] is true when column i may not be chosen for a split on
     * the path currently being built: the class column, plus every attribute
     * already split on between the root and the current node.
     */
    private boolean[] usedAttributes;

    /** The complete training data; attribute value domains are read from it. */
    private Object[] array;

    /** Column that holds the class label. */
    private int index;

    /**
     * Builds and prints a tree for the 14-row weather example, using column 4
     * ("Play ball") as the class label.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Training data
        Object[] array = new Object[] {
                        new String[]{ "Sunny"    ,"Hot"   ,"High"    ,"Weak"    ,"No" },
                        new String[]{ "Sunny"    ,"Hot"   ,"High"    ,"Strong"  ,"No" },
                        new String[]{ "Overcast" ,"Hot"   ,"High"    ,"Weak"    ,"Yes"},
                        new String[]{ "Rain"     ,"Mild"  ,"High"    ,"Weak"    ,"Yes"},
                        new String[]{ "Rain"     ,"Cool"  ,"Normal"  ,"Weak"    ,"Yes"},
                        new String[]{ "Rain"     ,"Cool"  ,"Normal"  ,"Strong"  ,"No" },
                        new String[]{ "Overcast" ,"Cool"  ,"Normal"  ,"Strong"  ,"Yes"},
                        new String[]{ "Sunny"    ,"Mild"  ,"High"    ,"Weak"    ,"No" },
                        new String[]{ "Sunny"    ,"Cool"  ,"Normal"  ,"Weak"    ,"Yes"},
                        new String[]{ "Rain"     ,"Mild"  ,"Normal"  ,"Weak"    ,"Yes"},
                        new String[]{ "Sunny"    ,"Mild"  ,"Normal"  ,"Strong"  ,"Yes"},
                        new String[]{ "Overcast" ,"Mild"  ,"High"    ,"Strong"  ,"Yes"},
                        new String[]{ "Overcast" ,"Hot"   ,"Normal"  ,"Weak"    ,"Yes"},
                        new String[]{ "Rain"     ,"Mild"  ,"High"    ,"Strong"  ,"No" },
                        };

        DTree tree = new DTree();
        tree.create(array, 4);
    }

    /**
     * Builds a new tree for {@code array}, stores it in {@link #root} and prints it.
     *
     * @param array training rows, each a {@code String[]} of equal length
     * @param index column holding the class label
     * @throws IllegalArgumentException if {@code array} is null or empty, or
     *         {@code index} is not a valid column
     */
    public void create(Object[] array, int index) {
        this.array = array;
        // createDTree only builds when root is null; discard any earlier tree so
        // a second call really rebuilds instead of printing stale results.
        this.root = null;
        init(array, index);
        createDTree(array);
        printDTree(root);
    }

    /**
     * Finds the not-yet-used attribute with the highest information gain on
     * {@code array} and marks it as used.
     *
     * @param array rows ({@code String[]}) of the current subset
     * @return {@code {Double gain, Integer column}}. When no attribute has a gain
     *         above {@link #GAIN_EPSILON}, the result is {@code {0.0, 0}} and
     *         nothing is marked as used.
     */
    public Object[] getMaxGain(Object[] array) {
        double bestGain = 0;
        int bestColumn = -1;
        for (int i = 0; i < usedAttributes.length; i++) {
            if (usedAttributes[i]) {
                continue;
            }
            double value = gain(array, i);
            // Strict '>' keeps the lowest column on ties.
            if (value > bestGain && value > GAIN_EPSILON) {
                bestGain = value;
                bestColumn = i;
            }
        }
        if (bestColumn < 0) {
            // No useful split: do not mark a column that was never chosen.
            return new Object[] {Double.valueOf(0.0), Integer.valueOf(0)};
        }
        usedAttributes[bestColumn] = true;
        return new Object[] {Double.valueOf(bestGain), Integer.valueOf(bestColumn)};
    }

    /**
     * Builds {@link #root} from {@code array} unless a tree already exists.
     * {@link #init(Object[], int)} must have been called first. When no attribute
     * has a positive gain on the whole data set, the root is a leaf that carries
     * the majority class.
     *
     * @param array training rows
     */
    public void createDTree(Object[] array) {
        if (root != null) {
            return;
        }
        Object[] maxGain = getMaxGain(array);
        double gain = ((Double) maxGain[0]).doubleValue();
        if (gain > 0) {
            int column = ((Integer) maxGain[1]).intValue();
            root = newNode(null, null, getNodeName(column), getArrtibutes(column));
            insertTree(array, root);
        } else {
            root = newNode(null, null, getLeafNodeName(array), new String[0]);
        }
    }

    /**
     * Recursively creates one child of {@code parentNode} per value of its split
     * attribute. Each child is either another split node or a leaf.
     *
     * @param array rows that reached {@code parentNode}
     * @param parentNode a split node whose {@code arrtibutes} and
     *        {@code childNodes} are already populated
     */
    public void insertTree(Object[] array, TreeNode parentNode) {
        String[] values = parentNode.arrtibutes;
        int splitColumn = getNodeIndex(parentNode.nodeName);
        for (int i = 0; i < values.length; i++) {
            Object[] subset = pickUpAndCreateArray(array, values[i], splitColumn);
            if (subset.length == 0) {
                // The value occurs in the full data but not in this branch:
                // predict the majority class of the parent's rows.
                parentNode.childNodes[i] =
                        newNode(parentNode, values[i], getLeafNodeName(array), new String[0]);
                continue;
            }
            // Attributes chosen inside this branch must not be excluded in the
            // sibling branches, so snapshot the usage state and restore it afterwards.
            boolean[] usedBeforeBranch = usedAttributes.clone();
            Object[] info = getMaxGain(subset);
            double gain = ((Double) info[0]).doubleValue();
            if (gain > 0) {
                int column = ((Integer) info[1]).intValue();
                System.out.println("gain = " + gain + " ,node name = " + getNodeName(column));
                TreeNode currentNode =
                        newNode(parentNode, values[i], getNodeName(column), getArrtibutes(column));
                parentNode.childNodes[i] = currentNode;
                insertTree(subset, currentNode);
            } else {
                parentNode.childNodes[i] =
                        newNode(parentNode, values[i], getLeafNodeName(subset), new String[0]);
            }
            usedAttributes = usedBeforeBranch;
        }
    }

    /**
     * Creates a node with one (initially empty) child slot per value in {@code values}.
     */
    private static TreeNode newNode(TreeNode parent, String parentValue, String name,
            String[] values) {
        TreeNode node = new TreeNode();
        node.parent = parent;
        node.parentArrtibute = parentValue;
        node.nodeName = name;
        node.arrtibutes = values;
        node.childNodes = new TreeNode[values.length];
        return node;
    }

    /**
     * Prints the tree depth-first. Each node's name is printed, then, for each
     * non-null child, the branch value followed by that child's subtree.
     *
     * @param node subtree root; nothing is printed when null
     */
    public void printDTree(TreeNode node) {
        if (node == null) {
            return;
        }
        System.out.println(node.nodeName);
        TreeNode[] children = node.childNodes;
        if (children == null) {
            return;
        }
        for (int i = 0; i < children.length; i++) {
            if (children[i] != null) {
                System.out.println(children[i].parentArrtibute);
                printDTree(children[i]);
            }
        }
    }

    /**
     * Records the class column and resets the attribute-usage state so that only
     * the class column is excluded from splitting.
     *
     * @param dataArray training rows; the first row determines the column count
     * @param index column holding the class label
     * @throws IllegalArgumentException if {@code dataArray} is null or empty, or
     *         {@code index} is outside the first row
     */
    public void init(Object[] dataArray, int index) {
        if (dataArray == null || dataArray.length == 0) {
            throw new IllegalArgumentException("dataArray must contain at least one row");
        }
        int columns = ((String[]) dataArray[0]).length;
        if (index < 0 || index >= columns) {
            throw new IllegalArgumentException(
                    "class column " + index + " out of range [0, " + columns + ")");
        }
        this.index = index;
        usedAttributes = new boolean[columns];
        // The class column is never a split candidate.
        usedAttributes[index] = true;
    }

    /**
     * Selects the rows whose value in column {@code index} equals {@code arrtibute}.
     *
     * @param array rows to filter
     * @param arrtibute attribute value to keep
     * @param index column to test
     * @return matching rows, in their original order
     */
    public Object[] pickUpAndCreateArray(Object[] array, String arrtibute, int index) {
        List<String[]> list = new ArrayList<String[]>();
        for (int i = 0; i < array.length; i++) {
            String[] strs = (String[]) array[i];
            if (strs[index].equals(arrtibute)) {
                list.add(strs);
            }
        }
        return list.toArray();
    }

    /**
     * Information gain of splitting {@code array} on column {@code index}:
     * Gain(S, A) = Entropy(S) - sum over values v of (|Sv| / |S|) * Entropy(Sv).
     *
     * @param array rows of the current subset S
     * @param index candidate attribute column A
     * @return the information gain in bits
     */
    public double gain(Object[] array, int index) {
        String[] labels = getArrtibutes(this.index);
        int[] counts = new int[labels.length];
        for (int i = 0; i < array.length; i++) {
            String label = ((String[]) array[i])[this.index];
            for (int j = 0; j < labels.length; j++) {
                if (label.equals(labels[j])) {
                    counts[j]++;
                }
            }
        }
        // Entropy(S) = sum over classes of -p(I) * log2 p(I)
        double entropyS = 0;
        for (int i = 0; i < counts.length; i++) {
            entropyS += DTreeUtil.sigma(counts[i], array.length);
        }
        // sum over values v of (|Sv| / |S|) * Entropy(Sv)
        String[] values = getArrtibutes(index);
        double weightedEntropy = 0;
        for (int i = 0; i < values.length; i++) {
            weightedEntropy += entropySv(array, index, values[i], array.length);
        }
        return entropyS - weightedEntropy;
    }

    /**
     * Weighted entropy term (|Sv| / |S|) * Entropy(Sv) for the rows whose column
     * {@code index} equals {@code arrtibute}.
     *
     * @param array rows of S
     * @param index attribute column
     * @param arrtibute attribute value v selecting Sv
     * @param allTotal |S|
     * @return the weighted entropy of Sv
     */
    public double entropySv(Object[] array, int index, String arrtibute, int allTotal) {
        String[] labels = getArrtibutes(this.index);
        int[] counts = new int[labels.length];
        for (int i = 0; i < array.length; i++) {
            String[] strs = (String[]) array[i];
            if (strs[index].equals(arrtibute)) {
                for (int k = 0; k < labels.length; k++) {
                    if (strs[this.index].equals(labels[k])) {
                        counts[k]++;
                    }
                }
            }
        }
        int total = 0;
        for (int i = 0; i < counts.length; i++) {
            total += counts[i];
        }
        double entropySv = 0;
        for (int i = 0; i < counts.length; i++) {
            entropySv += DTreeUtil.sigma(counts[i], total);
        }
        return DTreeUtil.getPi(total, allTotal) * entropySv;
    }

    /**
     * Distinct values of column {@code index} across the full training data.
     *
     * @param index column
     * @return the values in natural (lexicographic) String order
     */
    public String[] getArrtibutes(int index) {
        TreeSet<String> values = new TreeSet<String>();
        for (int i = 0; i < array.length; i++) {
            values.add(((String[]) array[i])[index]);
        }
        return values.toArray(new String[values.size()]);
    }

    /**
     * Display name of column {@code index}.
     *
     * @param index column
     * @return the name, or null when {@code index} has no name
     */
    public String getNodeName(int index) {
        return (index >= 0 && index < NODE_NAMES.length) ? NODE_NAMES[index] : null;
    }

    /**
     * Majority class label among {@code array}. On a tie with the first row's
     * label, the first row's label wins.
     *
     * @param array rows
     * @return the majority label, or null when {@code array} is null or empty
     */
    public String getLeafNodeName(Object[] array) {
        if (array == null || array.length == 0) {
            return null;
        }
        String[] labels = getArrtibutes(this.index);
        int[] counts = new int[labels.length];
        for (int i = 0; i < array.length; i++) {
            String label = ((String[]) array[i])[this.index];
            for (int j = 0; j < labels.length; j++) {
                if (label.equals(labels[j])) {
                    counts[j]++;
                    break;
                }
            }
        }
        String best = ((String[]) array[0])[this.index];
        int bestCount = 0;
        for (int j = 0; j < labels.length; j++) {
            if (labels[j].equals(best)) {
                bestCount = counts[j];
            }
        }
        for (int j = 0; j < labels.length; j++) {
            if (counts[j] > bestCount) {
                best = labels[j];
                bestCount = counts[j];
            }
        }
        return best;
    }

    /**
     * Column index for a display name; the inverse of {@link #getNodeName(int)}.
     *
     * @param name display name
     * @return the column, or -1 when {@code name} is unknown or null
     */
    public int getNodeIndex(String name) {
        for (int i = 0; i < NODE_NAMES.length; i++) {
            if (NODE_NAMES[i].equals(name)) {
                return i;
            }
        }
        return -1;
    }
}

package graph;      
     
/**
 * One node of a decision tree.
 * <p>
 * A split node names an attribute and has one child per attribute value:
 * {@code childNodes[i]} is reached via {@code arrtibutes[i]}. A leaf names a
 * class label and has empty {@code arrtibutes}/{@code childNodes} arrays.
 *
 * @author B.Chen
 */
public class TreeNode {

    /** Attribute name for a split node, or class label for a leaf. */
    String nodeName;

    /** Distinct values of this node's split attribute, one per child. */
    String[] arrtibutes;

    /** Children, aligned index-by-index with {@link #arrtibutes}. */
    TreeNode[] childNodes;

    /** Enclosing node; null for the root. */
    TreeNode parent;

    /** Value of the parent's split attribute that leads to this node; null for the root. */
    String parentArrtibute;

}

package graph;   
  
/**
 * Numeric helpers for the entropy computations of the ID3 decision tree.
 */
public class DTreeUtil {

    /** Natural logarithm of 2, cached for base-2 logarithms. */
    private static final double LN_2 = Math.log(2);

    /**
     * Entropy contribution of one class value: -p * log2(p), with p = x / total.
     * Summing over all k classes gives Info(T) = -sum_{i=1..k} p_i * log2(p_i).
     * <p>
     * Returns 0 when {@code x} or {@code total} is 0. This follows the convention
     * 0 * log2(0) = 0; an empty population contributes nothing.
     *
     * @param x number of occurrences of the class value
     * @param total size of the population
     * @return the entropy contribution in bits
     */
    public static double sigma(int x, int total) {
        if (x == 0 || total == 0) {
            return 0;
        }
        double p = getPi(x, total);
        return -(p * logYBase2(p));
    }

    /**
     * Base-2 logarithm.
     *
     * @param y argument
     * @return log2(y)
     */
    public static double logYBase2(double y) {
        return Math.log(y) / LN_2;
    }

    /**
     * Relative frequency of a value: occurrences divided by population size.
     *
     * @param x number of occurrences
     * @param total size of the population
     * @return {@code x / total}, or 0 when {@code total} is 0. This avoids a
     *         NaN/Infinity that would poison entropy sums.
     */
    public static double getPi(int x, int total) {
        if (total == 0) {
            return 0;
        }
        return (double) x / total;
    }

}

package graph;

import java.util.Comparator;

/**
 * Orders strings lexicographically using {@link String#compareTo(String)}.
 * Typed as {@code Comparator<String>}, so sorted collections can use it without
 * unchecked-conversion warnings.
 */
public class SequenceComparator implements Comparator<String> {

    /**
     * Compares two strings lexicographically.
     *
     * @param o1 first string
     * @param o2 second string
     * @return negative, zero or positive as {@code o1} sorts before, equal to,
     *         or after {@code o2}
     * @throws NullPointerException if either argument is null
     */
    @Override
    public int compare(String o1, String o2) {
        return o1.compareTo(o2);
    }

}

评论
发表评论

您还没有登录,请登录后发表评论

leon_a
  • 浏览: 2062 次
  • 性别: Icon_minigender_1
  • 来自: 那美克星
  • 详细资料
搜索本博客
博客分类
存档
最新评论