首页 > Java开发 > 决策树算法Java实现示例

决策树算法Java实现示例

  1. package xx;
  2. import java.util.HashMap;
  3. import java.util.LinkedList;
  4. import java.util.List;
  5. import java.util.Map;
  6. import java.util.Map.Entry;
  7. import java.util.Set;
  8. public class DicisionTree {
  9.     public static void main(String[] args) throws Exception {
  10.         String[] attrNames = new String[] { "AGE", "INCOME", "STUDENT",
  11.                 "CREDIT_RATING" };
  12.         // 读取样本集
  13.         Map<Object, List<Sample>> samples = readSamples(attrNames);
  14.         // 生成决策树
  15.         Object decisionTree = generateDecisionTree(samples, attrNames);
  16.         // 输出决策树
  17.         outputDecisionTree(decisionTree, 0, null);
  18.     }
  19.     /**
  20.      * 读取已分类的样本集,返回Map:分类 -> 属于该分类的样本的列表
  21.      */
  22.     static Map<Object, List<Sample>> readSamples(String[] attrNames) {
  23.         // 样本属性及其所属分类(数组中的最后一个元素为样本所属分类)
  24.         Object[][] rawData = new Object[][] {
  25.                 { "<30  ", "High  ", "No ", "Fair     ", "0" },
  26.                 { "<30  ", "High  ", "No ", "Excellent", "0" },
  27.                 { "30-40", "High  ", "No ", "Fair     ", "1" },
  28.                 { ">40  ", "Medium", "No ", "Fair     ", "1" },
  29.                 { ">40  ", "Low   ", "Yes", "Fair     ", "1" },
  30.                 { ">40  ", "Low   ", "Yes", "Excellent", "0" },
  31.                 { "30-40", "Low   ", "Yes", "Excellent", "1" },
  32.                 { "<30  ", "Medium", "No ", "Fair     ", "0" },
  33.                 { "<30  ", "Low   ", "Yes", "Fair     ", "1" },
  34.                 { ">40  ", "Medium", "Yes", "Fair     ", "1" },
  35.                 { "<30  ", "Medium", "Yes", "Excellent", "1" },
  36.                 { "30-40", "Medium", "No ", "Excellent", "1" },
  37.                 { "30-40", "High  ", "Yes", "Fair     ", "1" },
  38.                 { ">40  ", "Medium", "No ", "Excellent", "0" } };
  39.         // 读取样本属性及其所属分类,构造表示样本的Sample对象,并按分类划分样本集
  40.         Map<Object, List<Sample>> ret = new HashMap<Object, List<Sample>>();
  41.         for (Object[] row : rawData) {
  42.             Sample sample = new Sample();
  43.             int i = 0;
  44.             for (int n = row.length - 1; i < n; i++)
  45.                 sample.setAttribute(attrNames[i], row[i]);
  46.             sample.setCategory(row[i]);
  47.             List<Sample> samples = ret.get(row[i]);
  48.             if (samples == null) {
  49.                 samples = new LinkedList<Sample>();
  50.                 ret.put(row[i], samples);
  51.             }
  52.             samples.add(sample);
  53.         }
  54.         return ret;
  55.     }
  56.     /**
  57.      * 构造决策树
  58.      */
  59.     static Object generateDecisionTree(
  60.             Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {
  61.         // 如果只有一个样本,将该样本所属分类作为新样本的分类
  62.         if (categoryToSamples.size() == 1)
  63.             return categoryToSamples.keySet().iterator().next();
  64.         // 如果没有供决策的属性,则将样本集中具有最多样本的分类作为新样本的分类,即投票选举出分类
  65.         if (attrNames.length == 0) {
  66.             int max = 0;
  67.             Object maxCategory = null;
  68.             for (Entry<Object, List<Sample>> entry : categoryToSamples
  69.                     .entrySet()) {
  70.                 int cur = entry.getValue().size();
  71.                 if (cur > max) {
  72.                     max = cur;
  73.                     maxCategory = entry.getKey();
  74.                 }
  75.             }
  76.             return maxCategory;
  77.         }
  78.         // 选取测试属性
  79.         Object[] rst = chooseBestTestAttribute(categoryToSamples, attrNames);
  80.         // 决策树根结点,分支属性为选取的测试属性
  81.         Tree tree = new Tree(attrNames[(Integer) rst[0]]);
  82.         // 已用过的测试属性不应再次被选为测试属性
  83.         String[] subA = new String[attrNames.length - 1];
  84.         for (int i = 0, j = 0; i < attrNames.length; i++)
  85.             if (i != (Integer) rst[0])
  86.                 subA[j++] = attrNames[i];
  87.         // 根据分支属性生成分支
  88.         @SuppressWarnings("unchecked")
  89.         Map<Object, Map<Object, List<Sample>>> splits =
  90.         /* NEW LINE */(Map<Object, Map<Object, List<Sample>>>) rst[2];
  91.         for (Entry<Object, Map<Object, List<Sample>>> entry : splits.entrySet()) {
  92.             Object attrValue = entry.getKey();
  93.             Map<Object, List<Sample>> split = entry.getValue();
  94.             Object child = generateDecisionTree(split, subA);
  95.             tree.setChild(attrValue, child);
  96.         }
  97.         return tree;
  98.     }
  99.     /**
  100.      * 选取最优测试属性。最优是指如果根据选取的测试属性分支,则从各分支确定新样本
  101.      * 的分类需要的信息量之和最小,这等价于确定新样本的测试属性获得的信息增益最大
  102.      * 返回数组:选取的属性下标、信息量之和、Map(属性值->(分类->样本列表))
  103.      */
  104.     static Object[] chooseBestTestAttribute(
  105.             Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {
  106.         int minIndex = -1; // 最优属性下标
  107.         double minValue = Double.MAX_VALUE; // 最小信息量
  108.         Map<Object, Map<Object, List<Sample>>> minSplits = null; // 最优分支方案
  109.         // 对每一个属性,计算将其作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和,选取最小为最优
  110.         for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) {
  111.             int allCount = 0; // 统计样本总数的计数器
  112.             // 按当前属性构建Map:属性值->(分类->样本列表)
  113.             Map<Object, Map<Object, List<Sample>>> curSplits =
  114.             /* NEW LINE */new HashMap<Object, Map<Object, List<Sample>>>();
  115.             for (Entry<Object, List<Sample>> entry : categoryToSamples
  116.                     .entrySet()) {
  117.                 Object category = entry.getKey();
  118.                 List<Sample> samples = entry.getValue();
  119.                 for (Sample sample : samples) {
  120.                     Object attrValue = sample
  121.                             .getAttribute(attrNames[attrIndex]);
  122.                     Map<Object, List<Sample>> split = curSplits.get(attrValue);
  123.                     if (split == null) {
  124.                         split = new HashMap<Object, List<Sample>>();
  125.                         curSplits.put(attrValue, split);
  126.                     }
  127.                     List<Sample> splitSamples = split.get(category);
  128.                     if (splitSamples == null) {
  129.                         splitSamples = new LinkedList<Sample>();
  130.                         split.put(category, splitSamples);
  131.                     }
  132.                     splitSamples.add(sample);
  133.                 }
  134.                 allCount += samples.size();
  135.             }
  136.             // 计算将当前属性作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和
  137.             double curValue = 0.0; // 计数器:累加各分支
  138.             for (Map<Object, List<Sample>> splits : curSplits.values()) {
  139.                 double perSplitCount = 0;
  140.                 for (List<Sample> list : splits.values())
  141.                     perSplitCount += list.size(); // 累计当前分支样本数
  142.                 double perSplitValue = 0.0; // 计数器:当前分支
  143.                 for (List<Sample> list : splits.values()) {
  144.                     double p = list.size() / perSplitCount;
  145.                     perSplitValue -= p * (Math.log(p) / Math.log(2));
  146.                 }
  147.                 curValue += (perSplitCount / allCount) * perSplitValue;
  148.             }
  149.             // 选取最小为最优
  150.             if (minValue > curValue) {
  151.                 minIndex = attrIndex;
  152.                 minValue = curValue;
  153.                 minSplits = curSplits;
  154.             }
  155.         }
  156.         return new Object[] { minIndex, minValue, minSplits };
  157.     }
  158.     /**
  159.      * 将决策树输出到标准输出
  160.      */
  161.     static void outputDecisionTree(Object obj, int level, Object from) {
  162.         for (int i = 0; i < level; i++)
  163.             System.out.print("|-----");
  164.         if (from != null)
  165.             System.out.printf("(%s):", from);
  166.         if (obj instanceof Tree) {
  167.             Tree tree = (Tree) obj;
  168.             String attrName = tree.getAttribute();
  169.             System.out.printf("[%s = ?]\n", attrName);
  170.             for (Object attrValue : tree.getAttributeValues()) {
  171.                 Object child = tree.getChild(attrValue);
  172.                 outputDecisionTree(child, level + 1, attrName + " = "
  173.                         + attrValue);
  174.             }
  175.         } else {
  176.             System.out.printf("[CATEGORY = %s]\n", obj);
  177.         }
  178.     }
  179.     /**
  180.      * 样本,包含多个属性和一个指明样本所属分类的分类值
  181.      */
  182.     static class Sample {
  183.         private Map<String, Object> attributes = new HashMap<String, Object>();
  184.         private Object category;
  185.         public Object getAttribute(String name) {
  186.             return attributes.get(name);
  187.         }
  188.         public void setAttribute(String name, Object value) {
  189.             attributes.put(name, value);
  190.         }
  191.         public Object getCategory() {
  192.             return category;
  193.         }
  194.         public void setCategory(Object category) {
  195.             this.category = category;
  196.         }
  197.         public String toString() {
  198.             return attributes.toString();
  199.         }
  200.     }
  201.     /**
  202.      * 决策树(非叶结点),决策树中的每个非叶结点都引导了一棵决策树
  203.      * 每个非叶结点包含一个分支属性和多个分支,分支属性的每个值对应一个分支,该分支引导了一棵子决策树
  204.      */
  205.     static class Tree {
  206.         private String attribute;
  207.         private Map<Object, Object> children = new HashMap<Object, Object>();
  208.         public Tree(String attribute) {
  209.             this.attribute = attribute;
  210.         }
  211.         public String getAttribute() {
  212.             return attribute;
  213.         }
  214.         public Object getChild(Object attrValue) {
  215.             return children.get(attrValue);
  216.         }
  217.         public void setChild(Object attrValue, Object child) {
  218.             children.put(attrValue, child);
  219.         }
  220.         public Set<Object> getAttributeValues() {
  221.             return children.keySet();
  222.         }
  223.     }
  224. }

本文固定链接: http://www.devba.com/index.php/archives/3654.html | 开发吧

报歉!评论已关闭.