
SimpleCart.java

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * SimpleCart.java
 * Copyright (C) 2007 Haijian Shi
 *
 */

package weka.classifiers.trees;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.matrix.Matrix;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Class implementing minimal cost-complexity pruning.<br/>
 * Note: when dealing with missing values, the "fractional instances" method is used instead of the surrogate split method.<br/>
 * <br/>
 * For more information, see:<br/>
 * <br/>
 * Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984). Classification and Regression Trees. Wadsworth International Group, Belmont, California.
 * <p/>
 <!-- globalinfo-end -->      
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;book{Breiman1984,
 *    address = {Belmont, California},
 *    author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
 *    publisher = {Wadsworth International Group},
 *    title = {Classification and Regression Trees},
 *    year = {1984}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 * 
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 * 
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 * 
 * <pre> -M &lt;min no&gt;
 *  The minimal number of instances at the terminal nodes.
 *  (default 2)</pre>
 * 
 * <pre> -N &lt;num folds&gt;
 *  The number of folds used in the minimal cost-complexity pruning.
 *  (default 5)</pre>
 * 
 * <pre> -U
 *  Don't use the minimal cost-complexity pruning.
 *  (default yes).</pre>
 * 
 * <pre> -H
 *  Don't use the heuristic method for binary split.
 *  (default true).</pre>
 * 
 * <pre> -A
 *  Use 1 SE rule to make pruning decision.
 *  (default no).</pre>
 * 
 * <pre> -C
 *  Percentage of training data size (0-1].
 *  (default 1).</pre>
 * 
 <!-- options-end -->
 *
 * @author Haijian Shi (hs69@cs.waikato.ac.nz)
 * @version $Revision: 1.4 $
 */
public class SimpleCart
  extends RandomizableClassifier
  implements AdditionalMeasureProducer, TechnicalInformationHandler {

  /** For serialization. */
  private static final long serialVersionUID = 4154189200352566053L;

  /** Training data. */
  protected Instances m_train;

  /** Successor nodes. */
  protected SimpleCart[] m_Successors;

  /** Attribute used to split data. */
  protected Attribute m_Attribute;

  /** Split point for a numeric attribute. */
  protected double m_SplitValue;

  /** Split subset used to split data for nominal attributes. */
  protected String m_SplitString;

  /** Class value if the node is a leaf. */
  protected double m_ClassValue;

  /** Class attribute of the data. */
  protected Attribute m_ClassAttribute;

  /** Minimum number of instances at the terminal nodes. */
  protected double m_minNumObj = 2;

  /** Number of folds for minimal cost-complexity pruning. */
  protected int m_numFoldsPruning = 5;

  /** Alpha-value (for pruning) at the node. */
  protected double m_Alpha;

  /** Number of training examples misclassified by the node when treated as a
      leaf (i.e. when the subtree rooted at it is pruned). */
  protected double m_numIncorrectModel;

  /** Number of training examples misclassified by the subtree rooted at the node. */
  protected double m_numIncorrectTree;

  /** Indicates whether the node is a leaf node. */
  protected boolean m_isLeaf;

  /** Whether to use minimal cost-complexity pruning. */
  protected boolean m_Prune = true;

  /** Total number of instances used to build the classifier. */
  protected int m_totalTrainInstances;

  /** Proportion for each branch. */
  protected double[] m_Props;

  /** Class probabilities. */
  protected double[] m_ClassProbs = null;

  /** Distribution of the leaf node (or temporary leaf node in minimal cost-complexity pruning). */
  protected double[] m_Distribution;

  /** Whether to use heuristic search for nominal attributes in multi-class problems (default true). */
  protected boolean m_Heuristic = true;

  /** Whether to use the 1SE rule to select the final decision tree. */
  protected boolean m_UseOneSE = false;

  /** Training data size (as a proportion of the full training set). */
  protected double m_SizePer = 1;

  /**
   * Return a description suitable for displaying in the explorer/experimenter.
   * 
   * @return            a description suitable for displaying in the 
   *              explorer/experimenter
   */
  public String globalInfo() {
    return  
        "Class implementing minimal cost-complexity pruning.\n"
      + "Note when dealing with missing values, use \"fractional "
      + "instances\" method instead of surrogate split method.\n\n"
      + "For more information, see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing 
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   * 
   * @return            the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation      result;
    
    result = new TechnicalInformation(Type.BOOK);
    result.setValue(Field.AUTHOR, "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
    result.setValue(Field.YEAR, "1984");
    result.setValue(Field.TITLE, "Classification and Regression Trees");
    result.setValue(Field.PUBLISHER, "Wadsworth International Group");
    result.setValue(Field.ADDRESS, "Belmont, California");
    
    return result;
  }

  /**
   * Returns default capabilities of the classifier.
   * 
   * @return            the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);

    return result;
  }

  /**
   * Build the classifier.
   * 
   * @param data  the training instances
   * @throws Exception  if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    getCapabilities().testWithFail(data);
    data = new Instances(data);        
    data.deleteWithMissingClass();

    // unpruned CART decision tree
    if (!m_Prune) {

      // calculate sorted indices and weights, and compute initial class counts.
      int[][] sortedIndices = new int[data.numAttributes()][0];
      double[][] weights = new double[data.numAttributes()][0];
      double[] classProbs = new double[data.numClasses()];
      double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);

      makeTree(data, data.numInstances(),sortedIndices,weights,classProbs,
        totalWeight,m_minNumObj, m_Heuristic);
      return;
    }

    Random random = new Random(m_Seed);
    Instances cvData = new Instances(data);
    cvData.randomize(random);
    cvData = new Instances(cvData,0,(int)(cvData.numInstances()*m_SizePer)-1);
    cvData.stratify(m_numFoldsPruning);

    double[][] alphas = new double[m_numFoldsPruning][];
    double[][] errors = new double[m_numFoldsPruning][];

    // calculate errors and alphas for each fold
    for (int i = 0; i < m_numFoldsPruning; i++) {

      // for every fold, grow a tree on the training set and record its error on the test set.
      Instances train = cvData.trainCV(m_numFoldsPruning, i);
      Instances test = cvData.testCV(m_numFoldsPruning, i);

      // calculate sorted indices and weights, and compute initial class counts for each fold
      int[][] sortedIndices = new int[train.numAttributes()][0];
      double[][] weights = new double[train.numAttributes()][0];
      double[] classProbs = new double[train.numClasses()];
      double totalWeight = computeSortedInfo(train,sortedIndices, weights,classProbs);

      makeTree(train, train.numInstances(),sortedIndices,weights,classProbs,
        totalWeight,m_minNumObj, m_Heuristic);

      int numNodes = numInnerNodes();
      alphas[i] = new double[numNodes + 2];
      errors[i] = new double[numNodes + 2];

      // prune back and log alpha-values and errors on test set
      prune(alphas[i], errors[i], test);
    }

    // calculate sorted indices and weights, and compute initial class counts on all training instances
    int[][] sortedIndices = new int[data.numAttributes()][0];
    double[][] weights = new double[data.numAttributes()][0];
    double[] classProbs = new double[data.numClasses()];
    double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);

    //build tree using all the data
    makeTree(data, data.numInstances(),sortedIndices,weights,classProbs,
      totalWeight,m_minNumObj, m_Heuristic);

    int numNodes = numInnerNodes();

    double[] treeAlphas = new double[numNodes + 2];

    // prune back and log alpha-values
    int iterations = prune(treeAlphas, null, null);

    double[] treeErrors = new double[numNodes + 2];

    // for each pruned subtree, find the cross-validated error
    for (int i = 0; i <= iterations; i++){
      // compute midpoint alphas (geometric mean of successive alpha-values)
      double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i+1]);
      double error = 0;
      for (int k = 0; k < m_numFoldsPruning; k++) {
      int l = 0;
      while (alphas[k][l] <= alpha) l++;
      error += errors[k][l - 1];
      }
      treeErrors[i] = error/m_numFoldsPruning;
    }

    // find best alpha
    int best = -1;
    double bestError = Double.MAX_VALUE;
    for (int i = iterations; i >= 0; i--) {
      if (treeErrors[i] < bestError) {
      bestError = treeErrors[i];
      best = i;
      }
    }

    // 1 SE rule to choose the final tree
    if (m_UseOneSE) {
      double oneSE = Math.sqrt(bestError*(1-bestError)/(data.numInstances()));
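      // (the standard error of the cross-validated error rate, treated as a
      // binomial proportion: SE = sqrt(err * (1 - err) / N))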
      for (int i = iterations; i >= 0; i--) {
      if (treeErrors[i] <= bestError+oneSE) {
        best = i;
        break;
      }
      }
    }

    double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

    //"unprune" final tree (faster than regrowing it)
    unprune();
    prune(bestAlpha);        
  }

  /**
   * Make binary decision tree recursively.
   * 
   * @param data        the training instances
   * @param totalInstances    total number of instances
   * @param sortedIndices     sorted indices of the instances
   * @param weights           weights of the instances
   * @param classProbs        class probabilities
   * @param totalWeight       total weight of instances
   * @param minNumObj         minimal number of instances at leaf nodes
   * @param useHeuristic      whether to use heuristic search for nominal 
   *              attributes in multi-class problems
   * @throws Exception        if something goes wrong
   */
  protected void makeTree(Instances data, int totalInstances, int[][] sortedIndices,
      double[][] weights, double[] classProbs, double totalWeight, double minNumObj,
      boolean useHeuristic) throws Exception {

    // if no instances have reached this node (normally won't happen)
    if (totalWeight == 0){
      m_Attribute = null;
      m_ClassValue = Instance.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }

    m_totalTrainInstances = totalInstances;
    m_isLeaf = true;

    m_ClassProbs = new double[classProbs.length];
    m_Distribution = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
    System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
    if (Utils.sum(m_ClassProbs)!=0) Utils.normalize(m_ClassProbs);

    // Compute class distributions and value of splitting
    // criterion for each attribute
    double[][][] dists = new double[data.numAttributes()][0][0];
    double[][] props = new double[data.numAttributes()][0];
    double[][] totalSubsetWeights = new double[data.numAttributes()][2];
    double[] splits = new double[data.numAttributes()];
    String[] splitString = new String[data.numAttributes()];
    double[] giniGains = new double[data.numAttributes()];

    // for each attribute find split information
    for (int i = 0; i < data.numAttributes(); i++) {
      Attribute att = data.attribute(i);
      if (i==data.classIndex()) continue;
      if (att.isNumeric()) {
      // numeric attribute
      splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
          weights[i], totalSubsetWeights, giniGains, data);
      } else {
      // nominal attribute
      splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
          weights[i], totalSubsetWeights, giniGains, data, useHeuristic);
      }
    }

    // Find best attribute (split with maximum Gini gain)
    int attIndex = Utils.maxIndex(giniGains);
    m_Attribute = data.attribute(attIndex);

    m_train = new Instances(data, sortedIndices[attIndex].length);
    for (int i=0; i<sortedIndices[attIndex].length; i++) {
      Instance inst = data.instance(sortedIndices[attIndex][i]);
      Instance instCopy = (Instance)inst.copy();
      instCopy.setWeight(weights[attIndex][i]);
      m_train.add(instCopy);
    }

    // If the node does not contain enough instances, cannot be split,
    // or is pure, make it a leaf.
    if (totalWeight < 2 * minNumObj || giniGains[attIndex]==0 ||
      props[attIndex][0]==0 || props[attIndex][1]==0) {
      makeLeaf(data);
    }

    else {            
      m_Props = props[attIndex];
      int[][][] subsetIndices = new int[2][data.numAttributes()][0];
      double[][][] subsetWeights = new double[2][data.numAttributes()][0];

      // numeric split
      if (m_Attribute.isNumeric()) m_SplitValue = splits[attIndex];

      // nominal split
      else m_SplitString = splitString[attIndex];

      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,
        m_SplitString, sortedIndices, weights, data);

      // If the split produces a branch with fewer than the minimal number of instances, 
      // make the node a leaf node.
      if (subsetIndices[0][attIndex].length<minNumObj ||
        subsetIndices[1][attIndex].length<minNumObj) {
      makeLeaf(data);
      return;
      }

      // Otherwise, split the node.
      m_isLeaf = false;
      m_Successors = new SimpleCart[2];
      for (int i = 0; i < 2; i++) {
      m_Successors[i] = new SimpleCart();
      m_Successors[i].makeTree(data, m_totalTrainInstances, subsetIndices[i],
          subsetWeights[i],dists[attIndex][i], totalSubsetWeights[attIndex][i],
          minNumObj, useHeuristic);
      }
    }
  }

  /**
   * Prunes the original tree using the CART pruning scheme, given a 
   * cost-complexity parameter alpha.
   * 
   * @param alpha       the cost-complexity parameter
   * @throws Exception  if something goes wrong
   */
  public void prune(double alpha) throws Exception {

    Vector nodeList;

    // determine training error of pruned subtrees (both with and without replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      // select node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // stop once the smallest alpha in the tree exceeds the given alpha
      if (nodeToPrune.m_Alpha > alpha) {
      break;
      }

      nodeToPrune.makeLeaf(nodeToPrune.m_train);

      // normally would not happen
      if (nodeToPrune.m_Alpha==preAlpha) {
      nodeToPrune.makeLeaf(nodeToPrune.m_train);
      treeErrors();
      calculateAlphas();
      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
      continue;
      }
      preAlpha = nodeToPrune.m_Alpha;

      //update tree errors and alphas
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }
  }

  /**
   * Method for performing one fold in the cross-validation of minimal 
   * cost-complexity pruning. Generates a sequence of alpha-values with error 
   * estimates for the corresponding (partially pruned) trees, given the test 
   * set of that fold.
   *
   * @param alphas      array to hold the generated alpha-values
   * @param errors      array to hold the corresponding error estimates
   * @param test  test set of that fold (to obtain error estimates)
   * @return            the number of pruning iterations performed
   * @throws Exception  if something goes wrong
   */
  public int prune(double[] alphas, double[] errors, Instances test) 
    throws Exception {

    Vector nodeList;

    // determine training error of subtrees (both with and without replacing a subtree), 
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);

    //alpha_0 is always zero (unpruned tree)
    alphas[0] = 0;

    Evaluation eval;

    // error of unpruned tree
    if (errors != null) {
      eval = new Evaluation(test);
      eval.evaluateModel(this, test);
      errors[0] = eval.errorRate();
    }

    int iteration = 0;
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      iteration++;

      // get node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // do not set m_sons null, want to unprune
      nodeToPrune.m_isLeaf = true;

      // normally would not happen
      if (nodeToPrune.m_Alpha==preAlpha) {
      iteration--;
      treeErrors();
      calculateAlphas();
      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
      continue;
      }

      // get alpha-value of node
      alphas[iteration] = nodeToPrune.m_Alpha;

      // log error
      if (errors != null) {
      eval = new Evaluation(test);
      eval.evaluateModel(this, test);
      errors[iteration] = eval.errorRate();
      }
      preAlpha = nodeToPrune.m_Alpha;

      //update errors/alphas
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }

    // set the last alpha to 1 as a sentinel marking the end of the sequence
    alphas[iteration + 1] = 1.0;
    return iteration;
  }

  /**
   * Method to "unprune" the CART tree. Sets all leaf-fields to false.
   * Faster than re-growing the tree because CART do not have to be fit again.
   */
00610   protected void unprune() {
    if (m_Successors != null) {
      m_isLeaf = false;
      for (int i = 0; i < m_Successors.length; i++) m_Successors[i].unprune();
    }
  }

  /**
   * Compute distributions, proportions and total weights of two successor 
   * nodes for a given numeric attribute.
   * 
   * @param props             proportions of the two branches for each attribute
   * @param dists             class distributions of the two branches for each attribute
   * @param att         numeric attribute to split on
   * @param sortedIndices     sorted indices of instances for the attribute
   * @param weights           weights of instances for the attribute
   * @param subsetWeights     total weight of the two branches split on the attribute
   * @param giniGains         Gini gains for each attribute 
   * @param data        training instances
   * @return                  Gini gain for the given numeric attribute
   * @throws Exception        if something goes wrong
   */
  protected double numericDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] giniGains, Instances data)
    throws Exception {

    double splitPoint = Double.NaN;
    double[][] dist = null;
    int numClasses = data.numClasses();
    int i; // index into sortedIndices; instances with missing values come last

    double[][] currDist = new double[2][numClasses];
    dist = new double[2][numClasses];

    // Move all instances without missing values into second subset
    double[] parentDist = new double[numClasses];
    int missingStart = 0;
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (!inst.isMissing(att)) {
      missingStart ++;
      currDist[1][(int)inst.classValue()] += weights[j];
      }
      parentDist[(int)inst.classValue()] += weights[j];
    }
    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

    // Try all possible split points
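    // (Instances are processed in sorted order; a candidate split point is
    // placed at the midpoint between each pair of adjacent distinct values.)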
    double currSplit = data.instance(sortedIndices[0]).value(att);
    double currGiniGain;
    double bestGiniGain = -Double.MAX_VALUE;

    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
      break;
      }
      if (inst.value(att) > currSplit) {

      double[][] tempDist = new double[2][numClasses];
      for (int k=0; k<2; k++) {
        System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
      }

      double[] tempProps = new double[2];
      for (int k=0; k<2; k++) {
        tempProps[k] = Utils.sum(tempDist[k]);
      }

      if (Utils.sum(tempProps) !=0) Utils.normalize(tempProps);

      // distribute missing-value instances across both branches, weighted by tempProps
      int index = missingStart;
      while (index < sortedIndices.length) {
        Instance insta = data.instance(sortedIndices[index]);
        for (int j = 0; j < 2; j++) {
          tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
        }
        index++;
      }

      currGiniGain = computeGiniGain(parentDist,tempDist);

      if (currGiniGain > bestGiniGain) {
        bestGiniGain = currGiniGain;

        // split at the midpoint between adjacent values, rounded to 5 decimal places
        splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0; 

        for (int j = 0; j < currDist.length; j++) {
          System.arraycopy(tempDist[j], 0, dist[j], 0,
            dist[j].length);
        }
      }
      }
      currSplit = inst.value(att);
      currDist[0][(int)inst.classValue()] += weights[i];
      currDist[1][(int)inst.classValue()] -= weights[i];
    }

    // Compute weights
    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }
    if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // round the Gini gain to 7 decimal places
    giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
    dists[attIndex] = dist;

    return splitPoint;
  }

  /**
   * Compute distributions, proportions and total weights of two successor 
   * nodes for a given nominal attribute.
   * 
   * @param props             proportions of the two branches for each attribute
   * @param dists             class distributions of the two branches for each attribute
   * @param att         nominal attribute to split on
   * @param sortedIndices     sorted indices of instances for the attribute
   * @param weights           weights of instances for the attribute
   * @param subsetWeights     total weight of the two branches split on the attribute
   * @param giniGains         Gini gains for each attribute 
   * @param data        training instances
   * @param useHeuristic      whether to use heuristic search
   * @return                  Gini gain for the given nominal attribute
   * @throws Exception        if something goes wrong
   */
  protected String nominalDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] giniGains, Instances data, boolean useHeuristic)
    throws Exception {

    String[] values = new String[att.numValues()];
    int numCat = values.length; // number of values of the attribute
    int numClasses = data.numClasses();

    String bestSplitString = "";
    double bestGiniGain = -Double.MAX_VALUE;

    // class frequency for each value
    int[] classFreq = new int[numCat];
    for (int j=0; j<numCat; j++) classFreq[j] = 0;

    double[] parentDist = new double[numClasses];
    double[][] currDist = new double[2][numClasses];
    double[][] dist = new double[2][numClasses];
    int missingStart = 0;

    for (int i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (!inst.isMissing(att)) {
      missingStart++;
      classFreq[(int)inst.value(att)] ++;
      }
      parentDist[(int)inst.classValue()] += weights[i];
    }

    // count the number of values whose class frequency is not 0
    int nonEmpty = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]!=0) nonEmpty ++;
    }

    // attribute values whose class frequency is not 0
    String[] nonEmptyValues = new String[nonEmpty];
    int nonEmptyIndex = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]!=0) {
      nonEmptyValues[nonEmptyIndex] = att.value(j);
      nonEmptyIndex ++;
      }
    }

    // attribute values whose class frequency is 0
    int empty = numCat - nonEmpty;
    String[] emptyValues = new String[empty];
    int emptyIndex = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]==0) {
      emptyValues[emptyIndex] = att.value(j);
      emptyIndex ++;
      }
    }

    if (nonEmpty<=1) {
      giniGains[att.index()] = 0;
      return "";
    }

    // for two-class problems
    if (data.numClasses()==2) {

      // First, handle attribute values whose class frequency is not zero

      // probability of class 0 for each attribute value
      double[] pClass0 = new double[nonEmpty];
      // class distribution for each attribute value
      double[][] valDist = new double[nonEmpty][2];

      for (int j=0; j<nonEmpty; j++) {
      for (int k=0; k<2; k++) {
        valDist[j][k] = 0;
      }
      }

      for (int i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }

      for (int j=0; j<nonEmpty; j++) {
        if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j])==0) {
          valDist[j][(int)inst.classValue()] += inst.weight();
          break;
        }
      }
      }

      for (int j=0; j<nonEmpty; j++) {
      double distSum = Utils.sum(valDist[j]);
      if (distSum==0) pClass0[j]=0;
      else pClass0[j] = valDist[j][0]/distSum;
      }

      // sort categories according to the probability of the first class
      String[] sortedValues = new String[nonEmpty];
      for (int j=0; j<nonEmpty; j++) {
      sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
      pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
      }
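      // For two-class problems, the optimal subset split is consistent with
      // this ordering (Breiman et al. 1984), so only nonEmpty-1 candidate
      // splits need to be evaluated below.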

      // Find a subset of attribute values that maximize Gini decrease

      // for the attribute values whose class frequency is not 0
      String tempStr = "";

      for (int j=0; j<nonEmpty-1; j++) {
      currDist = new double[2][numClasses];
      if (tempStr=="") tempStr="(" + sortedValues[j] + ")";
      else tempStr += "|"+ "(" + sortedValues[j] + ")";
      for (int i=0; i<sortedIndices.length;i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }

        if (tempStr.indexOf
            ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
          currDist[0][(int)inst.classValue()] += weights[i];
        } else currDist[1][(int)inst.classValue()] += weights[i];
      }

      double[][] tempDist = new double[2][numClasses];
      for (int kk=0; kk<2; kk++) {
        tempDist[kk] = currDist[kk];
      }

      double[] tempProps = new double[2];
      for (int kk=0; kk<2; kk++) {
        tempProps[kk] = Utils.sum(tempDist[kk]);
      }

      if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

      // distribute missing-value instances across both branches, weighted by tempProps
      int mstart = missingStart;
      while (mstart < sortedIndices.length) {
        Instance insta = data.instance(sortedIndices[mstart]);
        for (int jj = 0; jj < 2; jj++) {
          tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
        }
        mstart++;
      }

      double currGiniGain = computeGiniGain(parentDist,tempDist);

      if (currGiniGain>bestGiniGain) {
        bestGiniGain = currGiniGain;
        bestSplitString = tempStr;
        for (int jj = 0; jj < 2; jj++) {
          System.arraycopy(tempDist[jj], 0, dist[jj], 0,
            dist[jj].length);
        }
      }
      }
    }

    // multi-class problems - exhaustive search
    else if (!useHeuristic || nonEmpty<=4) {

      // First, handle attribute values whose class frequency is not zero
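      // Each integer i in [0, 2^(nonEmpty-1)) encodes one candidate subset:
      // each bit of i decides whether the corresponding value goes into the
      // left branch. A subset and its complement define the same split, so
      // only half of all subsets need to be enumerated.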
      for (int i=0; i<(int)Math.pow(2,nonEmpty-1); i++) {
      String tempStr="";
      currDist = new double[2][numClasses];
      int mod;
      int bit10 = i;
      for (int j=nonEmpty-1; j>=0; j--) {
        mod = bit10%2; // take the lowest remaining bit (decimal-to-binary conversion)
        if (mod==1) {
          if (tempStr=="") tempStr = "("+nonEmptyValues[j]+")";
          else tempStr += "|" + "("+nonEmptyValues[j]+")";
        }
        bit10 = bit10/2;
      }
      for (int j=0; j<sortedIndices.length;j++) {
        Instance inst = data.instance(sortedIndices[j]);
        if (inst.isMissing(att)) {
          break;
        }

        if (tempStr.indexOf("("+att.value((int)inst.value(att))+")")!=-1) {
          currDist[0][(int)inst.classValue()] += weights[j];
        } else currDist[1][(int)inst.classValue()] += weights[j];
      }

      double[][] tempDist = new double[2][numClasses];
      for (int k=0; k<2; k++) {
        tempDist[k] = currDist[k];
      }

      double[] tempProps = new double[2];
      for (int k=0; k<2; k++) {
        tempProps[k] = Utils.sum(tempDist[k]);
      }

      if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

      // distribute missing-value instances across both branches, weighted by tempProps
      int index = missingStart;
      while (index < sortedIndices.length) {
        Instance insta = data.instance(sortedIndices[index]);
        for (int j = 0; j < 2; j++) {
          tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
        }
        index++;
      }

      double currGiniGain = computeGiniGain(parentDist,tempDist);

      if (currGiniGain>bestGiniGain) {
        bestGiniGain = currGiniGain;
        bestSplitString = tempStr;
        for (int j = 0; j < 2; j++) {
          System.arraycopy(tempDist[j], 0, dist[j], 0,
            dist[j].length);
        }
      }
      }
    }

    // heuristic search for multi-class problems
    else {
      // First, handle attribute values whose class frequency is not zero
      int n = nonEmpty;
      int k = data.numClasses();  // number of classes of the data
      double[][] P = new double[n][k];      // class probability matrix
      int[] numInstancesValue = new int[n]; // number of instances for an attribute value
      double[] meanClass = new double[k];   // vector of mean class probability
      int numInstances = data.numInstances(); // total number of instances

      // initialize the vector of mean class probability
      for (int j=0; j<meanClass.length; j++) meanClass[j]=0;

      for (int j=0; j<numInstances; j++) {
      Instance inst = (Instance)data.instance(j);
      int valueIndex = 0; // attribute value index in nonEmptyValues
      for (int i=0; i<nonEmpty; i++) {
        if (att.value((int)inst.value(att)).compareToIgnoreCase(nonEmptyValues[i])==0){
          valueIndex = i;
          break;
        }
      }
      P[valueIndex][(int)inst.classValue()]++;
      numInstancesValue[valueIndex]++;
      meanClass[(int)inst.classValue()]++;
      }

      // calculate the class probability matrix
      for (int i=0; i<P.length; i++) {
      for (int j=0; j<P[0].length; j++) {
        if (numInstancesValue[i]==0) P[i][j]=0;
        else P[i][j]/=numInstancesValue[i];
      }
      }

      //calculate the vector of mean class probability
      for (int i=0; i<meanClass.length; i++) {
      meanClass[i]/=numInstances;
      }

      // calculate the covariance matrix
      double[][] covariance = new double[k][k];
      for (int i1=0; i1<k; i1++) {
      for (int i2=0; i2<k; i2++) {
        double element = 0;
        for (int j=0; j<n; j++) {
          element += (P[j][i2]-meanClass[i2])*(P[j][i1]-meanClass[i1])
          *numInstancesValue[j];
        }
        covariance[i1][i2] = element;
      }
      }

      Matrix matrix = new Matrix(covariance);
      weka.core.matrix.EigenvalueDecomposition eigen =
      new weka.core.matrix.EigenvalueDecomposition(matrix);
      double[] eigenValues = eigen.getRealEigenvalues();

      // find index of the largest eigenvalue
      int index=0;
      double largest = eigenValues[0];
      for (int i=1; i<eigenValues.length; i++) {
      if (eigenValues[i]>largest) {
        index=i;
        largest = eigenValues[i];
      }
      }

      // calculate the first principal component
      double[] FPC = new double[k];
      Matrix eigenVector = eigen.getV();
      double[][] vectorArray = eigenVector.getArray();
      for (int i=0; i<FPC.length; i++) {
      FPC[i] = vectorArray[i][index];
      }

      // calculate the first principal component scores
      double[] Sa = new double[n];
      for (int i=0; i<Sa.length; i++) {
      Sa[i]=0;
      for (int j=0; j<k; j++) {
        Sa[i] += FPC[j]*P[i][j];
      }
      }
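      // The attribute values are now ordered by their first-principal-component
      // scores and treated like an ordered attribute, reducing the search to
      // n-1 candidate splits.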

      // sort categories according to their scores Sa
      double[] pCopy = new double[n];
      System.arraycopy(Sa,0,pCopy,0,n);
      String[] sortedValues = new String[n];
      Arrays.sort(Sa);

      for (int j=0; j<n; j++) {
      sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
      pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
      }

      // for the attribute values whose class frequency is not 0
      String tempStr = "";

      for (int j=0; j<nonEmpty-1; j++) {
      currDist = new double[2][numClasses];
      if (tempStr=="") tempStr="(" + sortedValues[j] + ")";
      else tempStr += "|"+ "(" + sortedValues[j] + ")";
      for (int i=0; i<sortedIndices.length;i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }

        if (tempStr.indexOf
            ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
          currDist[0][(int)inst.classValue()] += weights[i];
        } else currDist[1][(int)inst.classValue()] += weights[i];
      }

      double[][] tempDist = new double[2][numClasses];
      for (int kk=0; kk<2; kk++) {
        tempDist[kk] = currDist[kk];
      }

      double[] tempProps = new double[2];
      for (int kk=0; kk<2; kk++) {
        tempProps[kk] = Utils.sum(tempDist[kk]);
      }

      if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

      // distribute missing-value instances across both branches, weighted by tempProps
      int mstart = missingStart;
      while (mstart < sortedIndices.length) {
        Instance insta = data.instance(sortedIndices[mstart]);
        for (int jj = 0; jj < 2; jj++) {
          tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
        }
        mstart++;
      }

      double currGiniGain = computeGiniGain(parentDist,tempDist);

      if (currGiniGain>bestGiniGain) {
        bestGiniGain = currGiniGain;
        bestSplitString = tempStr;
        for (int jj = 0; jj < 2; jj++) {
          System.arraycopy(tempDist[jj], 0, dist[jj], 0,
            dist[jj].length);
        }
      }
      }
    }

    // Compute weights
    int attIndex = att.index();        
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }

    if (!(Utils.sum(props[attIndex]) > 0)) {
      for (int k = 0; k < props[attIndex].length; k++) {
      props[attIndex][k] = 1.0 / (double)props[attIndex].length;
      }
    } else {
      Utils.normalize(props[attIndex]);
    }


    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // Then, send the attribute values whose class frequency is 0 into the
    // more frequent branch
    for (int j=0; j<empty; j++) {
      if (props[attIndex][0]>=props[attIndex][1]) {
        if (bestSplitString.equals("")) bestSplitString = "(" + emptyValues[j] + ")";
        else bestSplitString += "|" + "(" + emptyValues[j] + ")";
      }
    }

    // round the Gini gain for the attribute to 7 decimal places
    giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;

    dists[attIndex] = dist;
    return bestSplitString;
  }


  /**
   * Split data into two subsets and store sorted indices and weights for two
   * successor nodes.
   * 
   * @param subsetIndices     sorted indices of instances for each attribute 
   *                    for the two successor nodes
   * @param subsetWeights     weights of instances for each attribute for 
   *                    the two successor nodes
   * @param att         attribute the split is based on
   * @param splitPoint        split point of the split if att is numeric
   * @param splitStr          split subset of the split if att is nominal
   * @param sortedIndices     sorted indices of the instances to be split
   * @param weights           weights of the instances to be split
   * @param data        training data
   * @throws Exception        if something goes wrong  
   */
  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
      Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
      double[][] weights, Instances data) throws Exception {
      Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
      double[][] weights, Instances data) throws Exception {

    int j;
    // For each attribute
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i==data.classIndex()) continue;
      int[] num = new int[2];
      for (int k = 0; k < 2; k++) {
      subsetIndices[k][i] = new int[sortedIndices[i].length];
      subsetWeights[k][i] = new double[weights[i].length];
      }

      for (j = 0; j < sortedIndices[i].length; j++) {
      Instance inst = data.instance(sortedIndices[i][j]);
      if (inst.isMissing(att)) {
        // Split instance up
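        // ("fractional instances": the instance is sent down both branches,
        // weighted by the branch proportions m_Props)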
        for (int k = 0; k < 2; k++) {
          if (m_Props[k] > 0) {
            subsetIndices[k][i][num[k]] = sortedIndices[i][j];
            subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
            num[k]++;
          }
        }
      } else {
        int subset;
        if (att.isNumeric())  {
          subset = (inst.value(att) < splitPoint) ? 0 : 1;
        } else { // nominal attribute
          if (splitStr.indexOf
            ("(" + att.value((int)inst.value(att.index()))+")")!=-1) {
            subset = 0;
          } else subset = 1;
        }
        subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
        subsetWeights[subset][i][num[subset]] = weights[i][j];
        num[subset]++;
      }
      }

      // Trim arrays
      for (int k = 0; k < 2; k++) {
      int[] copy = new int[num[k]];
      System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
      subsetIndices[k][i] = copy;
      double[] copyWeights = new double[num[k]];
      System.arraycopy(subsetWeights[k][i], 0 ,copyWeights, 0, num[k]);
      subsetWeights[k][i] = copyWeights;
      }
    }
  }

  /**
   * Updates the numIncorrectModel field for all nodes, i.e. the training 
   * errors made when the node is temporarily treated as a leaf (the subtree 
   * rooted at it is pruned). This is needed for calculating the alpha-values.
   * 
   * @throws Exception  if something goes wrong
   */
  public void modelErrors() throws Exception{
    Evaluation eval = new Evaluation(m_train);

    if (!m_isLeaf) {
      m_isLeaf = true; //temporarily make leaf

      // calculate distribution for evaluation
      eval.evaluateModel(this, m_train);
      m_numIncorrectModel = eval.incorrect();

      m_isLeaf = false;

      for (int i = 0; i < m_Successors.length; i++)
      m_Successors[i].modelErrors();

    } else {
      eval.evaluateModel(this, m_train);
      m_numIncorrectModel = eval.incorrect();
    }       
  }

  /**
   * Updates the numIncorrectTree field for all nodes. This is needed for
   * calculating the alpha-values.
   * 
   * @throws Exception  if something goes wrong
   */
  public void treeErrors() throws Exception {
    if (m_isLeaf) {
      m_numIncorrectTree = m_numIncorrectModel;
    } else {
      m_numIncorrectTree = 0;
      for (int i = 0; i < m_Successors.length; i++) {
      m_Successors[i].treeErrors();
      m_numIncorrectTree += m_Successors[i].m_numIncorrectTree;
      }
    }
  }

  /**
   * Updates the alpha field for all nodes.
   * 
   * @throws Exception  if something goes wrong
   */
  public void calculateAlphas() throws Exception {

    if (!m_isLeaf) {
      double errorDiff = m_numIncorrectModel - m_numIncorrectTree;
      if (errorDiff <=0) {
      //split increases training error (should not normally happen).
      //prune it instantly.
      makeLeaf(m_train);
      m_Alpha = Double.MAX_VALUE;
      } else {
      //compute alpha
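      // alpha = (R(t) - R(T_t)) / (|leaves(T_t)| - 1): the per-leaf increase
      // in training error if the subtree T_t is collapsed into the single
      // node t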
      errorDiff /= m_totalTrainInstances;
      m_Alpha = errorDiff / (double)(numLeaves() - 1);
      long alphaLong = Math.round(m_Alpha*Math.pow(10,10));
      m_Alpha = (double)alphaLong/Math.pow(10,10);
      for (int i = 0; i < m_Successors.length; i++) {
        m_Successors[i].calculateAlphas();
      }
      }
    } else {
      //alpha = infinite for leaves (do not want to prune)
      m_Alpha = Double.MAX_VALUE;
    }
  }

  /**
   * Find the node with minimal alpha value. If two nodes have the same alpha, 
   * choose the one with more leaf nodes.
   * 
   * @param nodeList    list of inner nodes
   * @return            the node to be pruned
   */
  protected SimpleCart nodeToPrune(Vector nodeList) {
    if (nodeList.size()==0) return null;
    if (nodeList.size()==1) return (SimpleCart)nodeList.elementAt(0);
    SimpleCart returnNode = (SimpleCart)nodeList.elementAt(0);
    double baseAlpha = returnNode.m_Alpha;
    for (int i=1; i<nodeList.size(); i++) {
      SimpleCart node = (SimpleCart)nodeList.elementAt(i);
      if (node.m_Alpha < baseAlpha) {
      baseAlpha = node.m_Alpha;
      returnNode = node;
      } else if (node.m_Alpha == baseAlpha) { // break tie
      if (node.numLeaves()>returnNode.numLeaves()) {
        returnNode = node;
      }
      }
    }
    return returnNode;
  }

  /**
   * Compute sorted indices, weights and class probabilities for a given 
   * dataset. Return total weights of the data at the node.
   * 
   * @param data        training data
   * @param sortedIndices     sorted indices of instances at the node
   * @param weights           weights of instances at the node
   * @param classProbs        class probabilities at the node
   * @return            total weight of instances at the node
   * @throws Exception        if something goes wrong
   */
  protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights,
      double[] classProbs) throws Exception {

    // Create array of sorted indices and weights
    double[] vals = new double[data.numInstances()];
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j==data.classIndex()) continue;
      weights[j] = new double[data.numInstances()];

      if (data.attribute(j).isNominal()) {

      // Handling nominal attributes. Putting indices of
      // instances with missing values at the end.
      sortedIndices[j] = new int[data.numInstances()];
      int count = 0;
      for (int i = 0; i < data.numInstances(); i++) {
        Instance inst = data.instance(i);
        if (!inst.isMissing(j)) {
          sortedIndices[j][count] = i;
          weights[j][count] = inst.weight();
          count++;
        }
      }
      for (int i = 0; i < data.numInstances(); i++) {
        Instance inst = data.instance(i);
        if (inst.isMissing(j)) {
          sortedIndices[j][count] = i;
          weights[j][count] = inst.weight();
          count++;
        }
      }
      } else {

      // Sorted indices are computed for numeric attributes;
      // instances with missing values are put at the end
      for (int i = 0; i < data.numInstances(); i++) {
        Instance inst = data.instance(i);
        vals[i] = inst.value(j);
      }
      sortedIndices[j] = Utils.sort(vals);
      for (int i = 0; i < data.numInstances(); i++) {
        weights[j][i] = data.instance(sortedIndices[j][i]).weight();
      }
      }
    }

    // Compute initial class counts
    double totalWeight = 0;
    for (int i = 0; i < data.numInstances(); i++) {
      Instance inst = data.instance(i);
      classProbs[(int)inst.classValue()] += inst.weight();
      totalWeight += inst.weight();
    }

    return totalWeight;
  }

  /**
   * Compute and return gini gain for given distributions of a node and its 
   * successor nodes.
   * 
   * @param parentDist  class distributions of parent node
   * @param childDist   class distributions of successor nodes
   * @return            Gini gain computed
   */
  protected double computeGiniGain(double[] parentDist, double[][] childDist) {
    double totalWeight = Utils.sum(parentDist);
    if (totalWeight==0) return 0;
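    // GiniGain = Gini(parent) - (w_L/w) * Gini(left) - (w_R/w) * Gini(right)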

    double leftWeight = Utils.sum(childDist[0]);
    double rightWeight = Utils.sum(childDist[1]);

    double parentGini = computeGini(parentDist, totalWeight);
    double leftGini = computeGini(childDist[0],leftWeight);
    double rightGini = computeGini(childDist[1], rightWeight);

    return parentGini - leftWeight/totalWeight*leftGini -
    rightWeight/totalWeight*rightGini;
  }

  /**
   * Compute and return gini index for a given distribution of a node.
   * 
   * @param dist  class distributions
   * @param total       total weight of the class distributions
   * @return            Gini index of the class distributions
   */
  protected double computeGini(double[] dist, double total) {
    if (total==0) return 0;
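    // Gini(D) = 1 - sum_i (p_i)^2, where p_i = dist[i] / total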
    double val = 0;
    for (int i=0; i<dist.length; i++) {
      val += (dist[i]/total)*(dist[i]/total);
    }
    return 1- val;
  }

  /**
   * Computes class probabilities for instance using the decision tree.
   * 
   * @param instance    the instance for which class probabilities are to be computed
   * @return            the class probabilities for the given instance
   * @throws Exception  if something goes wrong
   */
  public double[] distributionForInstance(Instance instance)
  throws Exception {
    if (!m_isLeaf) {
      // value of split attribute is missing
      if (instance.isMissing(m_Attribute)) {
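      // combine the predictions of both branches, weighted by the training
      // proportions m_Props ("fractional instances" at prediction time)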
      double[] returnedDist = new double[m_ClassProbs.length];

      for (int i = 0; i < m_Successors.length; i++) {
        double[] help =
          m_Successors[i].distributionForInstance(instance);
        if (help != null) {
          for (int j = 0; j < help.length; j++) {
            returnedDist[j] += m_Props[i] * help[j];
          }
        }
      }
      return returnedDist;
      }

      // split attribute is nominal
      else if (m_Attribute.isNominal()) {
      if (m_SplitString.indexOf("(" +
          m_Attribute.value((int)instance.value(m_Attribute)) + ")")!=-1)
        return  m_Successors[0].distributionForInstance(instance);
      else return  m_Successors[1].distributionForInstance(instance);
      }

      // split attribute is numeric
      else {
      if (instance.value(m_Attribute) < m_SplitValue)
        return m_Successors[0].distributionForInstance(instance);
      else
        return m_Successors[1].distributionForInstance(instance);
      }
    }

    // leaf node
    else return m_ClassProbs;
  }

  /**
   * Make the node a leaf node.
   * 
   * @param data  training data
   */
  protected void makeLeaf(Instances data) {
    m_Attribute = null;
    m_isLeaf = true;
    m_ClassValue=Utils.maxIndex(m_ClassProbs);
    m_ClassAttribute = data.classAttribute();
  }

  /**
   * Prints the decision tree using the protected toString method from below.
   * 
   * @return            a textual description of the classifier
   */
  public String toString() {
    if ((m_ClassProbs == null) && (m_Successors == null)) {
      return "CART Tree: No model built yet.";
    }

    return "CART Decision Tree\n" + toString(0)+"\n\n"
    +"Number of Leaf Nodes: "+numLeaves()+"\n\n" +
    "Size of the Tree: "+numNodes();
  }

  /**
   * Outputs a tree at a certain level.
   * 
   * @param level       the level at which the tree is to be printed
   * @return            a tree at a certain level
   */
  protected String toString(int level) {

    StringBuffer text = new StringBuffer();
    // if leaf nodes
    if (m_Attribute == null) {
      if (Instance.isMissingValue(m_ClassValue)) {
      text.append(": null");
      } else {
      double correctNum = (int)(m_Distribution[Utils.maxIndex(m_Distribution)]*100)/
      100.0;
      double wrongNum = (int)((Utils.sum(m_Distribution) -
          m_Distribution[Utils.maxIndex(m_Distribution)])*100)/100.0;
      String str = "("  + correctNum + "/" + wrongNum + ")";
      text.append(": " + m_ClassAttribute.value((int) m_ClassValue)+ str);
      }
    } else {
      for (int j = 0; j < 2; j++) {
      text.append("\n");
      for (int i = 0; i < level; i++) {
        text.append("|  ");
      }
      if (j==0) {
        if (m_Attribute.isNumeric())
          text.append(m_Attribute.name() + " < " + m_SplitValue);
        else
          text.append(m_Attribute.name() + "=" + m_SplitString);
      } else {
        if (m_Attribute.isNumeric())
          text.append(m_Attribute.name() + " >= " + m_SplitValue);
        else
          text.append(m_Attribute.name() + "!=" + m_SplitString);
      }
      text.append(m_Successors[j].toString(level + 1));
      }
    }
    return text.toString();
  }

  /**
   * Compute size of the tree.
   * 
   * @return            size of the tree
   */
  public int numNodes() {
    if (m_isLeaf) {
      return 1;
    } else {
      int size =1;
      for (int i=0;i<m_Successors.length;i++) {
      size+=m_Successors[i].numNodes();
      }
      return size;
    }
  }

  /**
   * Method to count the number of inner nodes in the tree.
   * 
   * @return            the number of inner nodes
   */
  public int numInnerNodes(){
    if (m_Attribute==null) return 0;
    int numNodes = 1;
    for (int i = 0; i < m_Successors.length; i++)
      numNodes += m_Successors[i].numInnerNodes();
    return numNodes;
  }

  /**
   * Return a list of all inner nodes in the tree.
   * 
   * @return            the list of all inner nodes
   */
  protected Vector getInnerNodes(){
    Vector nodeList = new Vector();
    fillInnerNodes(nodeList);
    return nodeList;
  }

  /**
   * Fills a list with all inner nodes in the tree.
   * 
   * @param nodeList    the list to be filled
   */
  protected void fillInnerNodes(Vector nodeList) {
    if (!m_isLeaf) {
      nodeList.add(this);
      for (int i = 0; i < m_Successors.length; i++)
      m_Successors[i].fillInnerNodes(nodeList);
    }
  }

  /**
   * Compute number of leaf nodes.
   * 
   * @return            number of leaf nodes
   */
  public int numLeaves() {
    if (m_isLeaf) return 1;
    else {
      int size=0;
      for (int i=0;i<m_Successors.length;i++) {
      size+=m_Successors[i].numLeaves();
      }
      return size;
    }
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return            an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector  result;
    Enumeration   en;
    
    result = new Vector();
    
    en = super.listOptions();
    while (en.hasMoreElements())
      result.addElement(en.nextElement());

    result.addElement(new Option(
      "\tThe minimal number of instances at the terminal nodes.\n" 
      + "\t(default 2)",
      "M", 1, "-M <min no>"));
    
    result.addElement(new Option(
      "\tThe number of folds used in the minimal cost-complexity pruning.\n"
      + "\t(default 5)",
      "N", 1, "-N <num folds>"));
    
    result.addElement(new Option(
      "\tDon't use the minimal cost-complexity pruning.\n"
      + "\t(default yes).",
      "U", 0, "-U"));
    
    result.addElement(new Option(
      "\tDon't use the heuristic method for binary split.\n"
      + "\t(default true).",
      "H", 0, "-H"));
    
    result.addElement(new Option(
      "\tUse 1 SE rule to make pruning decision.\n"
      + "\t(default no).",
      "A", 0, "-A"));
    
    result.addElement(new Option(
      "\tPercentage of training data size (0-1].\n" 
      + "\t(default 1).",
      "C", 1, "-C"));

    return result.elements();
  }

  /**
   * Parses a given list of options. <p/>
   * 
   <!-- options-start -->
   * Valid options are: <p/>
   * 
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   * 
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   * 
   * <pre> -M &lt;min no&gt;
   *  The minimal number of instances at the terminal nodes.
   *  (default 2)</pre>
   * 
   * <pre> -N &lt;num folds&gt;
   *  The number of folds used in the minimal cost-complexity pruning.
   *  (default 5)</pre>
   * 
   * <pre> -U
   *  Don't use the minimal cost-complexity pruning.
   *  (default yes).</pre>
   * 
   * <pre> -H
   *  Don't use the heuristic method for binary split.
   *  (default true).</pre>
   * 
   * <pre> -A
   *  Use 1 SE rule to make pruning decision.
   *  (default no).</pre>
   * 
   * <pre> -C
   *  Percentage of training data size (0-1].
   *  (default 1).</pre>
   * 
   <!-- options-end -->
   * 
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String  tmpStr;
    
    super.setOptions(options);
    
    tmpStr = Utils.getOption('M', options);
    if (tmpStr.length() != 0)
      setMinNumObj(Double.parseDouble(tmpStr));
    else
      setMinNumObj(2);

    tmpStr = Utils.getOption('N', options);
    if (tmpStr.length()!=0)
      setNumFoldsPruning(Integer.parseInt(tmpStr));
    else
      setNumFoldsPruning(5);

    setUsePrune(!Utils.getFlag('U',options));
    setHeuristic(!Utils.getFlag('H',options));
    setUseOneSE(Utils.getFlag('A',options));

    tmpStr = Utils.getOption('C', options);
    if (tmpStr.length()!=0)
      setSizePer(Double.parseDouble(tmpStr));
    else
      setSizePer(1);

    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of the classifier.
   * 
   * @return            the current setting of the classifier
   */
  public String[] getOptions() {
    int           i;
    Vector        result;
    String[]      options;

    result = new Vector();

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    result.add("-M");
    result.add("" + getMinNumObj());
    
    result.add("-N");
    result.add("" + getNumFoldsPruning());
    
    if (!getUsePrune())
      result.add("-U");
    
    if (!getHeuristic())
      result.add("-H");
    
    if (getUseOneSE())
      result.add("-A");
    
    result.add("-C");
    result.add("" + getSizePer());

    return (String[]) result.toArray(new String[result.size()]);    
  }

  /**
   * Return an enumeration of the measure names.
   * 
   * @return            an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector result = new Vector();
    
    result.addElement("measureTreeSize");
    
    return result.elements();
  }

  /**
   * Return the size of the tree.
   * 
   * @return            the size of the tree
   */
  public double measureTreeSize() {
    return numNodes();
  }

  /**
   * Returns the value of the named measure.
   * 
   * @param additionalMeasureName   the name of the measure to query for its value
   * @return                        the value of the named measure
   * @throws IllegalArgumentException     if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
      return measureTreeSize();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
        + " not supported (Cart pruning)");
    }
  }

  /**
   * Returns the tip text for this property
   * 
   * @return            tip text for this property suitable for
   *              displaying in the explorer/experimenter gui
   */
  public String minNumObjTipText() {
    return "The minimal number of observations at the terminal nodes (default 2).";
  }

  /**
   * Set minimal number of instances at the terminal nodes.
   * 
   * @param value       minimal number of instances at the terminal nodes
   */
  public void setMinNumObj(double value) {
    m_minNumObj = value;
  }

  /**
   * Get minimal number of instances at the terminal nodes.
   * 
   * @return            minimal number of instances at the terminal nodes
   */
  public double getMinNumObj() {
    return m_minNumObj;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return            tip text for this property suitable for
   *              displaying in the explorer/experimenter gui
   */
  public String numFoldsPruningTipText() {
    return "The number of folds in the internal cross-validation (default 5).";
  }

  /** 
   * Set number of folds in internal cross-validation.
   * 
   * @param value       number of folds in internal cross-validation.
   */
  public void setNumFoldsPruning(int value) {
    m_numFoldsPruning = value;
  }

  /**
   * Get number of folds in internal cross-validation.
   * 
   * @return            number of folds in internal cross-validation.
   */
  public int getNumFoldsPruning() {
    return m_numFoldsPruning;
  }

  /**
   * Return the tip text for this property
   * 
   * @return            tip text for this property suitable for displaying in 
   *              the explorer/experimenter gui.
   */
  public String usePruneTipText() {
    return "Use minimal cost-complexity pruning (default yes).";
  }

  /** 
   * Set whether to use minimal cost-complexity pruning.
   * 
   * @param value       whether to use minimal cost-complexity pruning
   */
  public void setUsePrune(boolean value) {
    m_Prune = value;
  }

  /** 
   * Get whether minimal cost-complexity pruning is used.
   * 
   * @return            whether minimal cost-complexity pruning is used
   */
  public boolean getUsePrune() {
    return m_Prune;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return            tip text for this property suitable for
   *              displaying in the explorer/experimenter gui.
   */
  public String heuristicTipText() {
    return 
        "If heuristic search is used for binary split for nominal attributes "
      + "in multi-class problems (default yes).";
  }

  /**
   * Set whether to use heuristic search for nominal attributes in multi-class problems.
   * 
   * @param value       whether to use heuristic search for nominal attributes in 
   *              multi-class problems
   */
  public void setHeuristic(boolean value) {
    m_Heuristic = value;
  }

  /** 
   * Get whether heuristic search is used for nominal attributes in multi-class problems.
   * 
   * @return            whether heuristic search is used for nominal attributes in 
   *              multi-class problems
   */
  public boolean getHeuristic() {return m_Heuristic;}

  /**
   * Returns the tip text for this property
   * 
   * @return            tip text for this property suitable for
   *              displaying in the explorer/experimenter gui.
   */
  public String useOneSETipText() {
    return "Use the 1SE rule to make the pruning decision.";
  }

  /** 
   * Set whether to use the 1SE rule to choose the final model.
   * 
   * @param value       whether to use the 1SE rule to choose the final model
   */
  public void setUseOneSE(boolean value) {
    m_UseOneSE = value;
  }

  /**
   * Get whether the 1SE rule is used to choose the final model.
   * 
   * @return            whether the 1SE rule is used to choose the final model
   */
  public boolean getUseOneSE() {
    return m_UseOneSE;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return            tip text for this property suitable for
   *              displaying in the explorer/experimenter gui.
   */
  public String sizePerTipText() {
    return "The percentage of the training set size (0-1, 0 not included).";
  }

  /** 
   * Set the training set size (as a proportion of the full training set).
   * 
   * @param value       training set size
   */  
  public void setSizePer(double value) {
    if ((value <= 0) || (value > 1))
      System.err.println(
        "The percentage of the training set size must be in range 0 to 1 "
        + "(0 not included) - ignored!");
    else
      m_SizePer = value;
  }

  /**
   * Get the training set size (as a proportion of the full training set).
   * 
   * @return            training set size
   */
  public double getSizePer() {
    return m_SizePer;
  }
  
  /**
   * Returns the revision string.
   * 
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.4 $");
  }

  /**
   * Main method.
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    runClassifier(new SimpleCart(), args);
  }
}
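
/*
 * Illustrative programmatic usage (a sketch, not part of the original source;
 * the dataset path is a placeholder):
 *
 *   Instances data = new Instances(
 *       new java.io.BufferedReader(new java.io.FileReader("/path/to/data.arff")));
 *   data.setClassIndex(data.numAttributes() - 1);
 *
 *   SimpleCart cart = new SimpleCart();
 *   cart.setMinNumObj(2);        // -M: minimal instances at terminal nodes
 *   cart.setNumFoldsPruning(5);  // -N: folds for cost-complexity pruning
 *   cart.setUseOneSE(false);     // -A: 1 SE rule off
 *   cart.buildClassifier(data);
 *   System.out.println(cart);
 */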
