2008-04-01
决策树ID3算法
package graph;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
/**
 * ID3 decision-tree learner for rows of categorical {@code String} values.
 * <p>
 * Every row of the training data is a {@code String[]}. The column passed as {@code index} to
 * {@link #create(Object[], int)} holds the class label. Every other column is a candidate
 * splitting attribute. At each node the candidate with the highest information gain is chosen.
 * An attribute is never tested twice on the same root-to-leaf path, but it stays available in
 * sibling subtrees.
 * <p>
 * Node names come from the hard-coded {@link #COLUMN_NAMES}. The data must therefore have at most
 * that many columns, in that order; a column without a name cannot be used as a split.
 * <p>
 * Implementation follows
 * http://www.blog.edu.cn/user2/huangbo929/archives/2006/1533249.shtml
 *
 * @author Leon.Chen
 */
public class DTree {
    /** Column headers; entry i names column i of each row. */
    private static final String[] COLUMN_NAMES =
            {"Outlook", "Temperature", "Humidity", "Wind", "Play ball"};

    /** Gains at or below this value count as zero, so floating-point noise never forces a split. */
    private static final double GAIN_EPSILON = 1e-10;

    /** Root of the most recently built tree. */
    TreeNode root;

    /**
     * Per-column exclusion flags used by {@link #getMaxGain(Object[])}. A {@code true} entry means
     * the column is not a split candidate. {@link #init} marks only the target column.
     * This is instance state; it used to be static, which let two trees clobber each other.
     */
    private boolean[] visable;

    /** Full training data; the distinct values of each column are enumerated from it. */
    private Object[] array;

    /** Index of the target (class-label) column. */
    private int index;

    /**
     * Builds and prints the tree for the classic "play ball" weather data set.
     *
     * @param args unused
     */
    @SuppressWarnings("boxing")
    public static void main(String[] args) {
        // Columns: Outlook, Temperature, Humidity, Wind, Play ball (the label)
        Object[] array = new Object[] {
            new String[] {"Sunny", "Hot", "High", "Weak", "No"},
            new String[] {"Sunny", "Hot", "High", "Strong", "No"},
            new String[] {"Overcast", "Hot", "High", "Weak", "Yes"},
            new String[] {"Rain", "Mild", "High", "Weak", "Yes"},
            new String[] {"Rain", "Cool", "Normal", "Weak", "Yes"},
            new String[] {"Rain", "Cool", "Normal", "Strong", "No"},
            new String[] {"Overcast", "Cool", "Normal", "Strong", "Yes"},
            new String[] {"Sunny", "Mild", "High", "Weak", "No"},
            new String[] {"Sunny", "Cool", "Normal", "Weak", "Yes"},
            new String[] {"Rain", "Mild", "Normal", "Weak", "Yes"},
            new String[] {"Sunny", "Mild", "Normal", "Strong", "Yes"},
            new String[] {"Overcast", "Mild", "High", "Strong", "Yes"},
            new String[] {"Overcast", "Hot", "Normal", "Weak", "Yes"},
            new String[] {"Rain", "Mild", "High", "Strong", "No"},
        };
        DTree tree = new DTree();
        tree.create(array, 4);
    }

    /**
     * Builds a tree from {@code array}, using column {@code index} as the class label, and prints
     * it. Any tree built by an earlier call on this instance is discarded.
     *
     * @param array training rows, each a {@code String[]}; must be non-empty
     * @param index target (class-label) column
     */
    public void create(Object[] array, int index) {
        this.array = array;
        // Without this reset, createDTree() would keep the previous tree and ignore the new data.
        this.root = null;
        init(array, index);
        createDTree(array);
        printDTree(root);
    }

    /**
     * Finds the candidate column with the highest information gain on {@code array}. A candidate
     * is a column whose {@link #visable} flag is false. The winning column is marked as used only
     * when its gain is positive.
     *
     * @param array rows to evaluate
     * @return {@code {Double gain, Integer column}}; gain is 0.0 and column is 0 when no candidate
     *         has a positive gain (nothing is marked in that case)
     */
    public Object[] getMaxGain(Object[] array) {
        Object[] result = findBestSplit(array, visable);
        if (((Double) result[0]).doubleValue() > 0) {
            visable[((Integer) result[1]).intValue()] = true;
        }
        return result;
    }

    /**
     * Returns {@code {Double gain, Integer column}} for the column with the highest gain among
     * those not flagged in {@code used}. Does not modify {@code used}. Ties go to the lowest
     * column index. Gains within {@link #GAIN_EPSILON} of zero are reported as 0.0 with column 0.
     */
    private Object[] findBestSplit(Object[] array, boolean[] used) {
        double bestGain = 0;
        int bestColumn = 0;
        for (int i = 0; i < used.length; i++) {
            if (!used[i]) {
                double value = gain(array, i);
                if (value > bestGain) {
                    bestGain = value;
                    bestColumn = i;
                }
            }
        }
        if (bestGain <= GAIN_EPSILON) {
            bestGain = 0;
            bestColumn = 0;
        }
        return new Object[] {Double.valueOf(bestGain), Integer.valueOf(bestColumn)};
    }

    /**
     * Builds the tree rooted at {@link #root} from {@code array} unless a root already exists.
     * When no attribute has positive gain, the root is a single leaf holding the majority label.
     *
     * @param array training rows
     */
    public void createDTree(Object[] array) {
        if (root != null) {
            return;
        }
        Object[] maxgain = getMaxGain(array);
        if (((Double) maxgain[0]).doubleValue() > 0) {
            int column = ((Integer) maxgain[1]).intValue();
            root = new TreeNode();
            root.parent = null;
            root.parentArrtibute = null;
            root.arrtibutes = getArrtibutes(column);
            root.nodeName = getNodeName(column);
            root.childNodes = new TreeNode[root.arrtibutes.length];
            insertTree(array, root);
        } else {
            root = newLeaf(null, null, getLeafNodeName(array));
        }
    }

    /**
     * Recursively creates the children of {@code parentNode}, one per value of its attribute.
     * Attributes already tested on the path from the root down to {@code parentNode} are not
     * reused below it.
     *
     * @param array rows that reach {@code parentNode}
     * @param parentNode internal node whose {@code arrtibutes} and {@code childNodes} are set
     */
    public void insertTree(Object[] array, TreeNode parentNode) {
        String[] arrtibutes = parentNode.arrtibutes;
        // NOTE(review): -1 when the node name is not in COLUMN_NAMES; then the filtering below fails.
        int splitColumn = getNodeIndex(parentNode.nodeName);
        boolean[] used = usedOnPath(parentNode);
        for (int i = 0; i < arrtibutes.length; i++) {
            Object[] pickArray = pickUpAndCreateArray(array, arrtibutes[i], splitColumn);
            if (pickArray.length == 0) {
                // No training row takes this branch: predict the parent's majority label.
                parentNode.childNodes[i] = newLeaf(parentNode, arrtibutes[i], getLeafNodeName(array));
                continue;
            }
            Object[] info = findBestSplit(pickArray, used);
            double gain = ((Double) info[0]).doubleValue();
            if (gain > 0) {
                int column = ((Integer) info[1]).intValue();
                System.out.println("gain = "+gain+" ,node name = "+getNodeName(column));
                TreeNode currentNode = new TreeNode();
                currentNode.parent = parentNode;
                currentNode.parentArrtibute = arrtibutes[i];
                currentNode.arrtibutes = getArrtibutes(column);
                currentNode.nodeName = getNodeName(column);
                currentNode.childNodes = new TreeNode[currentNode.arrtibutes.length];
                parentNode.childNodes[i] = currentNode;
                insertTree(pickArray, currentNode);
            } else {
                parentNode.childNodes[i] = newLeaf(parentNode, arrtibutes[i], getLeafNodeName(pickArray));
            }
        }
    }

    /**
     * Returns a fresh mask that marks the target column and every column tested at {@code node}
     * or at one of its ancestors.
     */
    private boolean[] usedOnPath(TreeNode node) {
        boolean[] used = new boolean[visable.length];
        if (index >= 0 && index < used.length) {
            used[index] = true;
        }
        for (TreeNode n = node; n != null; n = n.parent) {
            int column = getNodeIndex(n.nodeName);
            if (column >= 0 && column < used.length) {
                used[column] = true;
            }
        }
        return used;
    }

    /** Creates a leaf reached from {@code parent} via {@code branchValue}, named by {@code label}. */
    private TreeNode newLeaf(TreeNode parent, String branchValue, String label) {
        TreeNode leaf = new TreeNode();
        leaf.parent = parent;
        leaf.parentArrtibute = branchValue;
        leaf.arrtibutes = new String[0];
        leaf.nodeName = label;
        leaf.childNodes = new TreeNode[0];
        return leaf;
    }

    /**
     * Prints the subtree rooted at {@code node} depth-first. Each node name is printed, and before
     * each child the branch value leading to it is printed.
     *
     * @param node subtree root; must not be null
     */
    public void printDTree(TreeNode node){
        System.out.println(node.nodeName);
        TreeNode[] childs = node.childNodes;
        for (int i = 0; i < childs.length; i++) {
            if (childs[i] != null) {
                System.out.println(childs[i].parentArrtibute);
                printDTree(childs[i]);
            }
        }
    }

    /**
     * Records the target column and resets {@link #visable} so that only the target column is
     * excluded from splitting.
     *
     * @param dataArray training rows; the first row determines the number of columns
     * @param index target (class-label) column
     */
    public void init(Object[] dataArray, int index) {
        this.index = index;
        visable = new boolean[((String[]) dataArray[0]).length];
        for (int i = 0; i < visable.length; i++) {
            visable[i] = (i == index);
        }
    }

    /**
     * Returns the rows of {@code array} whose value in column {@code index} equals {@code arrtibute}.
     *
     * @param array rows to filter
     * @param arrtibute value to match
     * @param index column to test
     * @return the matching rows, in their original order
     */
    public Object[] pickUpAndCreateArray(Object[] array, String arrtibute, int index){
        List<String[]> list = new ArrayList<String[]>();
        for (int i = 0; i < array.length; i++) {
            String[] strs = (String[]) array[i];
            if (strs[index].equals(arrtibute)) {
                list.add(strs);
            }
        }
        return list.toArray();
    }

    /**
     * Computes the information gain of splitting {@code array} on column {@code index}:
     * Gain(S, A) = Entropy(S) - sum over values v of A of (|S_v| / |S|) * Entropy(S_v).
     * Entropies are measured on the target column.
     *
     * @param array rows forming the set S
     * @param index column of attribute A
     * @return the gain in bits; 0 for an empty set
     */
    public double gain(Object[] array, int index) {
        if (array.length == 0) {
            return 0; // no rows carry no information, and 0/0 below would be NaN
        }
        String[] playBalls = getArrtibutes(this.index);
        int[] counts = new int[playBalls.length];
        for (int i = 0; i < array.length; i++) {
            int k = labelPosition(playBalls, ((String[]) array[i])[this.index]);
            if (k >= 0) {
                counts[k]++;
            }
        }
        // Entropy(S) = -sum p(I) log2 p(I)
        double entropyS = 0;
        for (int i = 0; i < counts.length; i++) {
            entropyS += DTreeUtil.sigma(counts[i], array.length);
        }
        // sum over v of (|Sv| / |S|) * Entropy(Sv)
        String[] arrtibutes = getArrtibutes(index);
        double svTotal = 0;
        for (int i = 0; i < arrtibutes.length; i++) {
            svTotal += entropySv(array, index, arrtibutes[i], array.length);
        }
        return entropyS - svTotal;
    }

    /**
     * Computes (|Sv| / |S|) * Entropy(Sv), where Sv is the set of rows whose column {@code index}
     * equals {@code arrtibute}.
     *
     * @param array rows forming S
     * @param index attribute column
     * @param arrtibute attribute value v
     * @param allTotal |S|
     * @return the weighted entropy; 0 when Sv is empty
     */
    public double entropySv(Object[] array, int index, String arrtibute, int allTotal) {
        String[] playBalls = getArrtibutes(this.index);
        int[] counts = new int[playBalls.length];
        int total = 0;
        for (int i = 0; i < array.length; i++) {
            String[] strs = (String[]) array[i];
            if (strs[index].equals(arrtibute)) {
                int k = labelPosition(playBalls, strs[this.index]);
                if (k >= 0) {
                    counts[k]++;
                    total++;
                }
            }
        }
        if (total == 0) {
            return 0; // an empty Sv contributes nothing; also avoids 0/0 = NaN
        }
        double entropySv = 0;
        for (int i = 0; i < counts.length; i++) {
            entropySv += DTreeUtil.sigma(counts[i], total);
        }
        return DTreeUtil.getPi(total, allTotal) * entropySv;
    }

    /**
     * Returns the distinct values of column {@code index} across the full training data, sorted.
     *
     * @param index column to enumerate
     * @return sorted distinct values
     */
    @SuppressWarnings("unchecked")
    public String[] getArrtibutes(int index) {
        TreeSet<String> set = new TreeSet<String>(new SequenceComparator());
        for (int i = 0; i < array.length; i++) {
            String[] strs = (String[]) array[i];
            set.add(strs[index]);
        }
        String[] result = new String[set.size()];
        return set.toArray(result);
    }

    /**
     * @param index column index
     * @return the header name of column {@code index}, or null when it has none
     */
    public String getNodeName(int index) {
        return (index >= 0 && index < COLUMN_NAMES.length) ? COLUMN_NAMES[index] : null;
    }

    /**
     * Returns the most frequent target label in {@code array}. Ties go to the label that occurs
     * first in {@code array}.
     *
     * @param array rows to vote over
     * @return the majority label, or null when {@code array} is null or empty
     */
    public String getLeafNodeName(Object[] array){
        if (array == null || array.length == 0) {
            return null;
        }
        String[] labels = getArrtibutes(this.index);
        int[] counts = new int[labels.length];
        for (int i = 0; i < array.length; i++) {
            int k = labelPosition(labels, ((String[]) array[i])[index]);
            if (k >= 0) {
                counts[k]++;
            }
        }
        int max = 0;
        for (int i = 0; i < counts.length; i++) {
            max = Math.max(max, counts[i]);
        }
        for (int i = 0; i < array.length; i++) {
            String label = ((String[]) array[i])[index];
            int k = labelPosition(labels, label);
            if (k >= 0 && counts[k] == max) {
                return label;
            }
        }
        // Labels not present in the training data: fall back to the first row's label.
        return ((String[]) array[0])[index];
    }

    /** Position of {@code value} in {@code values}, or -1 when absent. */
    private static int labelPosition(String[] values, String value) {
        for (int k = 0; k < values.length; k++) {
            if (values[k].equals(value)) {
                return k;
            }
        }
        return -1;
    }

    /**
     * @param name a column header
     * @return the column index named {@code name}, or -1 when unknown (including null)
     */
    public int getNodeIndex(String name) {
        for (int i = 0; i < COLUMN_NAMES.length; i++) {
            if (COLUMN_NAMES[i].equals(name)) {
                return i;
            }
        }
        return -1;
    }
}
package graph;
/**
 * One node of the decision tree assembled by {@code DTree}.
 * <p>
 * As populated by {@code DTree}, an internal node is named after the attribute it splits on and
 * owns one child per attribute value. A leaf is named after its predicted label and has
 * zero-length arrays.
 *
 * @author B.Chen
 */
public class TreeNode {
    /** Attribute name for an internal node, or label for a leaf. */
    String nodeName;

    /** Enclosing node; null at the root. */
    TreeNode parent;

    /** The parent's attribute value on the edge that leads to this node. */
    String parentArrtibute;

    /** Distinct values of this node's attribute; entry i labels the edge to childNodes[i]. */
    String[] arrtibutes;

    /** Child nodes, position-aligned with {@code arrtibutes}. */
    TreeNode[] childNodes;
}
package graph;
/**
 * Entropy helpers for the ID3 decision tree.
 */
public class DTreeUtil {
    /** Natural log of 2, computed once instead of on every logYBase2 call. */
    private static final double LN_2 = Math.log(2);

    /**
     * Computes one term of the entropy sum Info(T) = -sum_{i=1..k} p_i * log2(p_i), with
     * p_i = x / total.
     *
     * @param x number of items in class i
     * @param total total number of items
     * @return -p_i * log2(p_i); 0 when x or total is 0 (by the convention 0 * log 0 = 0)
     */
    public static double sigma(int x, int total) {
        if (x == 0 || total == 0) {
            return 0;
        }
        double p = getPi(x, total);
        return -(p * logYBase2(p));
    }

    /**
     * Computes the base-2 logarithm of {@code y}.
     *
     * @param y argument
     * @return log2(y)
     */
    public static double logYBase2(double y) {
        return Math.log(y) / LN_2;
    }

    /**
     * Computes p_i, the relative frequency of a class (occurrences / total).
     *
     * @param x number of occurrences
     * @param total total count
     * @return x / total as a double; 0 when total is 0 (avoids NaN/Infinity)
     */
    public static double getPi(int x, int total) {
        if (total == 0) {
            return 0;
        }
        return (double) x / total;
    }
}
package graph;
import java.util.Comparator;
/**
 * Orders attribute values lexicographically using {@link String#compareTo}. {@code DTree} uses it
 * to sort the distinct values of a column, which fixes the order of a node's branches.
 */
public class SequenceComparator implements Comparator<String> {
    /**
     * Compares two strings lexicographically.
     *
     * @param str1 first value
     * @param str2 second value
     * @return negative, zero or positive as {@code str1} sorts before, equal to or after {@code str2}
     */
    @Override
    public int compare(String str1, String str2) {
        return str1.compareTo(str2);
    }
}
发表评论
- 浏览: 2062 次
- 性别:

- 来自: 那美克星

- 详细资料
搜索本博客
最近加入圈子
最新评论
-
红黑树初版
引用原来是树???。小看了一下。。懂了。是啊,是树~我家门前有好几棵大的!
-- by pf_miles -
红黑树初版
原来是树???。 小看了一下。。懂了。
-- by Wallian_hua -
红黑树初版
没懂。。 这是什么
-- by Wallian_hua -
四则运算的中缀转后缀,逆 ...
龙书一本.
-- by dogstar -
四则运算的中缀转后缀,逆 ...
引用我想说,如果你看了 编译原理,会有更大的收获。 努力! 最近也想往编译原理 ...
-- by leon_a






评论排行榜