Logo Search packages:      
Sourcecode: weka version File versions

EM.java

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    EM.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.clusterers;

import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Attribute;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.estimators.DiscreteEstimator;
import weka.estimators.Estimator;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Simple EM (expectation maximisation) class.<br/>
 * <br/>
 * EM assigns a probability distribution to each instance which indicates the probability of it belonging to each of the clusters. EM can decide how many clusters to create by cross validation, or you may specify apriori how many clusters to generate.<br/>
 * <br/>
 * The cross validation performed to determine the number of clusters is done in the following steps:<br/>
 * 1. the number of clusters is set to 1<br/>
 * 2. the training set is split randomly into 10 folds.<br/>
 * 3. EM is performed 10 times using the 10 folds the usual CV way.<br/>
 * 4. the loglikelihood is averaged over all 10 results.<br/>
 * 5. if loglikelihood has increased the number of clusters is increased by 1 and the program continues at step 2. <br/>
 * <br/>
 * The number of folds is fixed to 10, as long as the number of instances in the training set is not smaller 10. If this is the case the number of folds is set equal to the number of instances.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 * 
 * <pre> -N &lt;num&gt;
 *  number of clusters. If omitted or -1 specified, then 
 *  cross validation is used to select the number of clusters.</pre>
 * 
 * <pre> -I &lt;num&gt;
 *  max iterations.
 * (default 100)</pre>
 * 
 * <pre> -V
 *  verbose.</pre>
 * 
 * <pre> -M &lt;num&gt;
 *  minimum allowable standard deviation for normal density
 *  computation
 *  (default 1e-6)</pre>
 * 
 * <pre> -O
 *  Display model in old format (good when there are many clusters)
 * </pre>
 * 
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 100)</pre>
 * 
 <!-- options-end -->
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.44 $
 */
00091 public class EM
  extends RandomizableDensityBasedClusterer
  implements NumberOfClustersRequestable, WeightedInstancesHandler {

  /** for serialization */
00096   static final long serialVersionUID = 8348181483812829475L;
  
  /** hold the discrete estimators for each cluster */
00099   private Estimator m_model[][];

  /** hold the normal estimators for each cluster */
00102   private double m_modelNormal[][][];

  /** default minimum standard deviation */
00105   private double m_minStdDev = 1e-6;

  /** optional per-attribute minimum standard deviations; when non-null,
      entry j is used instead of m_minStdDev for attribute j */
  private double [] m_minStdDevPerAtt;

  /** hold the weights of each instance for each cluster */
00110   private double m_weights[][];

  /** the prior probabilities for clusters */
00113   private double m_priors[];

  /** the loglikelihood of the data */
00116   private double m_loglikely;

  /** training instances */
00119   private Instances m_theInstances = null;

  /** number of clusters selected by the user or cross validation */
00122   private int m_num_clusters;

  /** the initial number of clusters requested by the user--- -1 if
      xval is to be used to find the number of clusters */
00126   private int m_initialNumClusters;

  /** number of attributes */
00129   private int m_num_attribs;

  /** number of training instances */
00132   private int m_num_instances;

  /** maximum iterations to perform */
00135   private int m_max_iterations;

  /** attribute min values */
00138   private double [] m_minValues;

  /** attribute max values */
00141   private double [] m_maxValues;

  /** random number generator */
00144   private Random m_rr;

  /** Verbose? */
00147   private boolean m_verbose;

 /** globally replace missing values */
00150   private ReplaceMissingValues m_replaceMissing;

  /** display model output in old-style format */
00153   private boolean m_displayModelInOldFormat;

  /**
   * Returns a string describing this clusterer
   * @return a description of the evaluator suitable for
   * displaying in the explorer/experimenter gui
   */
00160   public String globalInfo() {
    return
        "Simple EM (expectation maximisation) class.\n\n"
      + "EM assigns a probability distribution to each instance which "
      + "indicates the probability of it belonging to each of the clusters. "
      + "EM can decide how many clusters to create by cross validation, or you "
      + "may specify apriori how many clusters to generate.\n\n"
      + "The cross validation performed to determine the number of clusters "
      + "is done in the following steps:\n"
      + "1. the number of clusters is set to 1\n"
      + "2. the training set is split randomly into 10 folds.\n"
      + "3. EM is performed 10 times using the 10 folds the usual CV way.\n"
      + "4. the loglikelihood is averaged over all 10 results.\n"
      + "5. if loglikelihood has increased the number of clusters is increased "
      + "by 1 and the program continues at step 2. \n\n"
      + "The number of folds is fixed to 10, as long as the number of "
      + "instances in the training set is not smaller 10. If this is the case "
      + "the number of folds is set equal to the number of instances.";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
00185   public Enumeration listOptions () {
    Vector result = new Vector();
    
    result.addElement(new Option(
      "\tnumber of clusters. If omitted or -1 specified, then \n"
      + "\tcross validation is used to select the number of clusters.", 
      "N", 1, "-N <num>"));

    result.addElement(new Option(
      "\tmax iterations."
      + "\n(default 100)", 
      "I", 1, "-I <num>"));
    
    result.addElement(new Option(
      "\tverbose.",
      "V", 0, "-V"));
    
    result.addElement(new Option(
      "\tminimum allowable standard deviation for normal density\n"
      + "\tcomputation\n"
      + "\t(default 1e-6)",
      "M",1,"-M <num>"));

    result.addElement(
              new Option("\tDisplay model in old format (good when there are "
                         + "many clusters)\n",
                         "O", 0, "-O"));

    Enumeration en = super.listOptions();
    while (en.hasMoreElements())
      result.addElement(en.nextElement());
    
    return  result.elements();
  }


  /**
   * Parses a given list of options. <p/>
   * 
   <!-- options-start -->
   * Valid options are: <p/>
   * 
   * <pre> -N &lt;num&gt;
   *  number of clusters. If omitted or -1 specified, then 
   *  cross validation is used to select the number of clusters.</pre>
   * 
   * <pre> -I &lt;num&gt;
   *  max iterations.
   * (default 100)</pre>
   * 
   * <pre> -V
   *  verbose.</pre>
   * 
   * <pre> -M &lt;num&gt;
   *  minimum allowable standard deviation for normal density
   *  computation
   *  (default 1e-6)</pre>
   * 
   * <pre> -O
   *  Display model in old format (good when there are many clusters)
   * </pre>
   * 
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 100)</pre>
   * 
   <!-- options-end -->
   * 
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
00256   public void setOptions (String[] options)
    throws Exception {
    resetOptions();
    setDebug(Utils.getFlag('V', options));
    String optionString = Utils.getOption('I', options);

    if (optionString.length() != 0) {
      setMaxIterations(Integer.parseInt(optionString));
    }

    optionString = Utils.getOption('N', options);
    if (optionString.length() != 0) {
      setNumClusters(Integer.parseInt(optionString));
    }

    optionString = Utils.getOption('M', options);
    if (optionString.length() != 0) {
      setMinStdDev((new Double(optionString)).doubleValue());
    }

    setDisplayModelInOldFormat(Utils.getFlag('O', options));
    
    super.setOptions(options);
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
00286   public String displayModelInOldFormatTipText() {
    return "Use old format for model output. The old format is "
      + "better when there are many clusters. The new format "
      + "is better when there are fewer clusters and many attributes.";
  }

  /**
   * Set whether to display model output in the old, original
   * format.
   *
   * @param d true if model ouput is to be shown in the old format
   */
00298   public void setDisplayModelInOldFormat(boolean d) {
    m_displayModelInOldFormat = d;
  }

  /**
   * Get whether to display model output in the old, original
   * format.
   *
   * @return true if model ouput is to be shown in the old format
   */
00308   public boolean getDisplayModelInOldFormat() {
    return m_displayModelInOldFormat;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
00317   public String minStdDevTipText() {
    return "set minimum allowable standard deviation";
  }

  /**
   * Set the minimum value for standard deviation when calculating
   * normal density. Reducing this value can help prevent arithmetic
   * overflow resulting from multiplying large densities (arising from small
   * standard deviations) when there are many singleton or near singleton
   * values.
   * @param m minimum value for standard deviation
   */
00329   public void setMinStdDev(double m) {
    m_minStdDev = m;
  }

  /**
   * Set per-attribute minimum standard deviations. When non-null,
   * entry j is used for attribute j in place of the global minimum.
   * The array is stored as given (no copy is made).
   *
   * @param m the per-attribute minimum standard deviations, or null
   */
  public void setMinStdDevPerAtt(double [] m) {
    this.m_minStdDevPerAtt = m;
  }

  /**
   * Get the minimum allowable standard deviation.
   * @return the minumum allowable standard deviation
   */
00341   public double getMinStdDev() {
    return m_minStdDev;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
00350   public String numClustersTipText() {
    return "set number of clusters. -1 to select number of clusters "
      +"automatically by cross validation.";
  }

  /**
   * Set the number of clusters (-1 to select by CV).
   *
   * @param n the number of clusters
   * @throws Exception if n is 0
   */
00361   public void setNumClusters (int n)
    throws Exception {
    
    if (n == 0) {
      throw  new Exception("Number of clusters must be > 0. (or -1 to " 
                     + "select by cross validation).");
    }

    if (n < 0) {
      m_num_clusters = -1;
      m_initialNumClusters = -1;
    }
    else {
      m_num_clusters = n;
      m_initialNumClusters = n;
    }
  }


  /**
   * Get the number of clusters
   *
   * @return the number of clusters.
   */
00385   public int getNumClusters () {
    return  m_initialNumClusters;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
00394   public String maxIterationsTipText() {
    return "maximum number of iterations";
  }

  /**
   * Set the maximum number of iterations to perform
   *
   * @param i the number of iterations
   * @throws Exception if i is less than 1
   */
00404   public void setMaxIterations (int i)
    throws Exception {
    if (i < 1) {
      throw  new Exception("Maximum number of iterations must be > 0!");
    }

    m_max_iterations = i;
  }


  /**
   * Get the maximum number of iterations
   *
   * @return the number of iterations
   */
00419   public int getMaxIterations () {
    return  m_max_iterations;
  }

  
  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
00429   public String debugTipText() {
    return "If set to true, clusterer may output additional info to " +
      "the console.";
  }


  /**
   * Set debug mode - verbose output
   *
   * @param v true for verbose output
   */
00440   public void setDebug (boolean v) {
    m_verbose = v;
  }


  /**
   * Get debug mode
   *
   * @return true if debug mode is set
   */
00450   public boolean getDebug () {
    return  m_verbose;
  }


  /**
   * Gets the current settings of EM.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
00460   public String[] getOptions () {
    int           i;
    Vector        result;
    String[]      options;

    result = new Vector();

    result.add("-I");
    result.add("" + m_max_iterations);
    result.add("-N");
    result.add("" + getNumClusters());
    result.add("-M");
    result.add("" + getMinStdDev());
    if (m_displayModelInOldFormat) {
      result.add("-O");
    }

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    return (String[]) result.toArray(new String[result.size()]);    
  }

  /**
   * Initialise estimators and storage.
   *
   * @param inst the instances
   * @throws Exception if initialization fails
   **/
00490   private void EM_Init (Instances inst)
    throws Exception {
    int i, j, k;

    // run k means 10 times and choose best solution
    SimpleKMeans bestK = null;
    double bestSqE = Double.MAX_VALUE;
    for (i = 0; i < 10; i++) {
      SimpleKMeans sk = new SimpleKMeans();
      sk.setSeed(m_rr.nextInt());
      sk.setNumClusters(m_num_clusters);
      sk.setDisplayStdDevs(true);
      sk.buildClusterer(inst);
      if (sk.getSquaredError() < bestSqE) {
      bestSqE = sk.getSquaredError();
      bestK = sk;
      }
    }
    
    // initialize with best k-means solution
    m_num_clusters = bestK.numberOfClusters();
    m_weights = new double[inst.numInstances()][m_num_clusters];
    m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs];
    m_modelNormal = new double[m_num_clusters][m_num_attribs][3];
    m_priors = new double[m_num_clusters];
    Instances centers = bestK.getClusterCentroids();
    Instances stdD = bestK.getClusterStandardDevs();
    int [][][] nominalCounts = bestK.getClusterNominalCounts();
    int [] clusterSizes = bestK.getClusterSizes();

    for (i = 0; i < m_num_clusters; i++) {
      Instance center = centers.instance(i);
      for (j = 0; j < m_num_attribs; j++) {
      if (inst.attribute(j).isNominal()) {
        m_model[i][j] = new DiscreteEstimator(m_theInstances.
                                    attribute(j).numValues()
                                    , true);
        for (k = 0; k < inst.attribute(j).numValues(); k++) {
          m_model[i][j].addValue(k, nominalCounts[i][j][k]);
        }
      } else {
        double minStdD = (m_minStdDevPerAtt != null)
          ? m_minStdDevPerAtt[j]
          : m_minStdDev;
        double mean = (center.isMissing(j))
          ? inst.meanOrMode(j)
          : center.value(j);
        m_modelNormal[i][j][0] = mean;
        double stdv = (stdD.instance(i).isMissing(j))
          ? ((m_maxValues[j] - m_minValues[j]) / (2 * m_num_clusters))
          : stdD.instance(i).value(j);
        if (stdv < minStdD) {
          stdv = inst.attributeStats(j).numericStats.stdDev;
            if (Double.isInfinite(stdv)) {
              stdv = minStdD;
            }
          if (stdv < minStdD) {
            stdv = minStdD;
          }
        }
        if (stdv <= 0) {
          stdv = m_minStdDev;
        }

        m_modelNormal[i][j][1] = stdv;
        m_modelNormal[i][j][2] = 1.0;
      }
      } 
    }    
    
    
    for (j = 0; j < m_num_clusters; j++) {
      //      m_priors[j] += 1.0;
      m_priors[j] = clusterSizes[j];
    }
    Utils.normalize(m_priors);
  }


  /**
   * calculate prior probabilites for the clusters
   *
   * @param inst the instances
   * @throws Exception if priors can't be calculated
   **/
00575   private void estimate_priors (Instances inst)
    throws Exception {

    for (int i = 0; i < m_num_clusters; i++) {
      m_priors[i] = 0.0;
    }

    for (int i = 0; i < inst.numInstances(); i++) {
      for (int j = 0; j < m_num_clusters; j++) {
        m_priors[j] += inst.instance(i).weight() * m_weights[i][j];
      }
    }

    Utils.normalize(m_priors);
  }


  /** Constant for normal distribution. */
00593   private static double m_normConst = Math.log(Math.sqrt(2*Math.PI));

  /**
   * Density function of normal distribution.
   * @param x input value
   * @param mean mean of distribution
   * @param stdDev standard deviation of distribution
   * @return the density
   */
00602   private double logNormalDens (double x, double mean, double stdDev) {

    double diff = x - mean;
    //    System.err.println("x: "+x+" mean: "+mean+" diff: "+diff+" stdv: "+stdDev);
    //    System.err.println("diff*diff/(2*stdv*stdv): "+ (diff * diff / (2 * stdDev * stdDev)));
    
    return - (diff * diff / (2 * stdDev * stdDev))  - m_normConst - Math.log(stdDev);
  }

  /**
   * New probability estimators for an iteration
   */
00614   private void new_estimators () {
    for (int i = 0; i < m_num_clusters; i++) {
      for (int j = 0; j < m_num_attribs; j++) {
        if (m_theInstances.attribute(j).isNominal()) {
          m_model[i][j] = new DiscreteEstimator(m_theInstances.
                                    attribute(j).numValues()
                                    , true);
        }
        else {
          m_modelNormal[i][j][0] = m_modelNormal[i][j][1] = 
          m_modelNormal[i][j][2] = 0.0;
        }
      }
    }
  }


  /**
   * The M step of the EM algorithm.
   * @param inst the training instances
   * @throws Exception if something goes wrong
   */
00636   private void M (Instances inst)
    throws Exception {

    int i, j, l;

    new_estimators();

    for (i = 0; i < m_num_clusters; i++) {
      for (j = 0; j < m_num_attribs; j++) {
        for (l = 0; l < inst.numInstances(); l++) {
        Instance in = inst.instance(l);
          if (!in.isMissing(j)) {
            if (inst.attribute(j).isNominal()) {
              m_model[i][j].addValue(in.value(j), 
                             in.weight() * m_weights[l][i]);
            }
            else {
              m_modelNormal[i][j][0] += (in.value(j) * in.weight() *
                               m_weights[l][i]);
              m_modelNormal[i][j][2] += in.weight() * m_weights[l][i];
              m_modelNormal[i][j][1] += (in.value(j) * 
                               in.value(j) * in.weight() * m_weights[l][i]);
            }
          }
        }
      }
    }
    
    // calcualte mean and std deviation for numeric attributes
    for (j = 0; j < m_num_attribs; j++) {
      if (!inst.attribute(j).isNominal()) {
        for (i = 0; i < m_num_clusters; i++) {
          if (m_modelNormal[i][j][2] <= 0) {
            m_modelNormal[i][j][1] = Double.MAX_VALUE;
          //          m_modelNormal[i][j][0] = 0;
          m_modelNormal[i][j][0] = m_minStdDev;
          } else {
            
          // variance
          m_modelNormal[i][j][1] = (m_modelNormal[i][j][1] - 
                              (m_modelNormal[i][j][0] * 
                               m_modelNormal[i][j][0] / 
                               m_modelNormal[i][j][2])) / 
            (m_modelNormal[i][j][2]);
          
          if (m_modelNormal[i][j][1] < 0) {
            m_modelNormal[i][j][1] = 0;
          }
          
          // std dev      
          double minStdD = (m_minStdDevPerAtt != null)
          ? m_minStdDevPerAtt[j]
          : m_minStdDev;

          m_modelNormal[i][j][1] = Math.sqrt(m_modelNormal[i][j][1]);              

          if ((m_modelNormal[i][j][1] <= minStdD)) {
            m_modelNormal[i][j][1] = inst.attributeStats(j).numericStats.stdDev;
            if ((m_modelNormal[i][j][1] <= minStdD)) {
            m_modelNormal[i][j][1] = minStdD;
            }
          }
          if ((m_modelNormal[i][j][1] <= 0)) {
            m_modelNormal[i][j][1] = m_minStdDev;
          }
            if (Double.isInfinite(m_modelNormal[i][j][1])) {
              m_modelNormal[i][j][1] = m_minStdDev;
            }
          
          // mean
          m_modelNormal[i][j][0] /= m_modelNormal[i][j][2];
        }
        }
      }
    }
  }

  /**
   * The E step of the EM algorithm. Estimate cluster membership 
   * probabilities.
   *
   * @param inst the training instances
   * @param change_weights whether to change the weights
   * @return the average log likelihood
   * @throws Exception if computation fails
   */
00722   private double E (Instances inst, boolean change_weights)
    throws Exception {

    double loglk = 0.0, sOW = 0.0;

    for (int l = 0; l < inst.numInstances(); l++) {

      Instance in = inst.instance(l);

      loglk += in.weight() * logDensityForInstance(in);
      sOW += in.weight();

      if (change_weights) {
      m_weights[l] = distributionForInstance(in);
      }
    }
    
    // reestimate priors
    if (change_weights) {
      estimate_priors(inst);
    }
    return  loglk / sOW;
  }
  
  
  /**
   * Constructor. Sets the default seed to 100 and applies the default
   * options.
   **/
  public EM () {
    // the superclass constructor is invoked implicitly

    this.m_SeedDefault = 100;
    this.resetOptions();
  }


  /**
   * Reset to default options
   */
00762   protected void resetOptions () {
    m_minStdDev = 1e-6;
    m_max_iterations = 100;
    m_Seed = m_SeedDefault;
    m_num_clusters = -1;
    m_initialNumClusters = -1;
    m_verbose = false;
  }

  /**
   * Return the normal distributions for the cluster models
   *
   * @return a <code>double[][][]</code> value
   */
00776   public double [][][] getClusterModelsNumericAtts() {
    return m_modelNormal;
  }

  /**
   * Return the priors for the clusters
   *
   * @return a <code>double[]</code> value
   */
00785   public double [] getClusterPriors() {
    return m_priors;
  }

  /**
   * Outputs the generated clusters into a string.
   * 
   * @return the clusterer in string representation
   */
00794   public String toString() {
    if (m_displayModelInOldFormat) {
      return toStringOriginal();
    }

    if (m_priors == null) {
      return "No clusterer built yet!";
    }
    StringBuffer temp = new StringBuffer();
    temp.append("\nEM\n==\n");
    if (m_initialNumClusters == -1) {
      temp.append("\nNumber of clusters selected by cross validation: "
              +m_num_clusters+"\n");
    } else {
      temp.append("\nNumber of clusters: " + m_num_clusters + "\n");
    }

    int maxWidth = 0;
    int maxAttWidth = 0;
    boolean containsKernel = false;
    
    // set up max widths
    // attributes
    for (int i = 0; i < m_num_attribs; i++) {
      Attribute a = m_theInstances.attribute(i);
      if (a.name().length() > maxAttWidth) {
        maxAttWidth = m_theInstances.attribute(i).name().length();
      }
      if (a.isNominal()) {
        // check values
        for (int j = 0; j < a.numValues(); j++) {
          String val = a.value(j) + "  ";
          if (val.length() > maxAttWidth) {
            maxAttWidth = val.length();
          }
        }
      }
    }

    for (int i = 0; i < m_num_clusters; i++) {
      for (int j = 0; j < m_num_attribs; j++) {
        if (m_theInstances.attribute(j).isNumeric()) {
          // check mean and std. dev. against maxWidth
          double mean = Math.log(Math.abs(m_modelNormal[i][j][0])) / Math.log(10.0);
          double stdD = Math.log(Math.abs(m_modelNormal[i][j][1])) / Math.log(10.0);
          double width = (mean > stdD)
            ? mean
            : stdD;
          if (width < 0) {
            width = 1;
          }
          // decimal + # decimal places + 1
          width += 6.0;
          if ((int)width > maxWidth) {
            maxWidth = (int)width;
          }
        } else {
          // nominal distributions
          DiscreteEstimator d = (DiscreteEstimator)m_model[i][j];
          for (int k = 0; k < d.getNumSymbols(); k++) {
            String size = Utils.doubleToString(d.getCount(k), maxWidth, 4).trim();
            if (size.length() > maxWidth) {
              maxWidth = size.length();
            }
          }
          int sum = 
            Utils.doubleToString(d.getSumOfCounts(), maxWidth, 4).trim().length();
          if (sum > maxWidth) {
            maxWidth = sum;
          }
        }
      }
    }

    if (maxAttWidth < "Attribute".length()) {
      maxAttWidth = "Attribute".length();
    }    
    
    maxAttWidth += 2;

    temp.append("\n\n");
    temp.append(pad("Cluster", " ", 
                    (maxAttWidth + maxWidth + 1) - "Cluster".length(), 
                    true));
    
    temp.append("\n");
    temp.append(pad("Attribute", " ", maxAttWidth - "Attribute".length(), false));

    // cluster #'s
    for (int i = 0; i < m_num_clusters; i++) {
      String classL = "" + i;
      temp.append(pad(classL, " ", maxWidth + 1 - classL.length(), true));
    }
    temp.append("\n");

    // cluster priors
    temp.append(pad("", " ", maxAttWidth, true));
    for (int i = 0; i < m_num_clusters; i++) {
      String priorP = Utils.doubleToString(m_priors[i], maxWidth, 2).trim();
      priorP = "(" + priorP + ")";
      temp.append(pad(priorP, " ", maxWidth + 1 - priorP.length(), true));
    }

    temp.append("\n");
    temp.append(pad("", "=", maxAttWidth + 
                    (maxWidth * m_num_clusters)
                    + m_num_clusters + 1, true));
    temp.append("\n");

    for (int i = 0; i < m_num_attribs; i++) {
      String attName = m_theInstances.attribute(i).name();
      temp.append(attName + "\n");

      if (m_theInstances.attribute(i).isNumeric()) {
        String meanL = "  mean";
        temp.append(pad(meanL, " ", maxAttWidth + 1 - meanL.length(), false));
        for (int j = 0; j < m_num_clusters; j++) {
          // means
          String mean = 
            Utils.doubleToString(m_modelNormal[j][i][0], maxWidth, 4).trim();
          temp.append(pad(mean, " ", maxWidth + 1 - mean.length(), true));
        }
        temp.append("\n");            
        // now do std deviations
        String stdDevL = "  std. dev.";
        temp.append(pad(stdDevL, " ", maxAttWidth + 1 - stdDevL.length(), false));
        for (int j = 0; j < m_num_clusters; j++) {
          String stdDev = 
            Utils.doubleToString(m_modelNormal[j][i][1], maxWidth, 4).trim();
          temp.append(pad(stdDev, " ", maxWidth + 1 - stdDev.length(), true));
        }
        temp.append("\n\n");
      } else {
        Attribute a = m_theInstances.attribute(i);
        for (int j = 0; j < a.numValues(); j++) {
          String val = "  " + a.value(j);
          temp.append(pad(val, " ", maxAttWidth + 1 - val.length(), false));
          for (int k = 0; k < m_num_clusters; k++) {
            DiscreteEstimator d = (DiscreteEstimator)m_model[k][i];
            String count = Utils.doubleToString(d.getCount(j), maxWidth, 4).trim();
            temp.append(pad(count, " ", maxWidth + 1 - count.length(), true));
          }
          temp.append("\n");
        }
        // do the totals
        String total = "  [total]";
        temp.append(pad(total, " ", maxAttWidth + 1 - total.length(), false));
        for (int k = 0; k < m_num_clusters; k++) {
          DiscreteEstimator d = (DiscreteEstimator)m_model[k][i];
          String count = 
            Utils.doubleToString(d.getSumOfCounts(), maxWidth, 4).trim();
            temp.append(pad(count, " ", maxWidth + 1 - count.length(), true));
        }
        temp.append("\n");        
      }
    }

    return temp.toString();
  }

  /**
   * Pads a string with a repeated filler.
   *
   * @param source the string to pad
   * @param padChar the filler appended once per padding position
   * @param length the number of times the filler is added (nothing is
   * added if this is not positive)
   * @param leftPad true to put the filler before the string, false to
   * put it after
   * @return the padded string
   */
  private String pad(String source, String padChar, 
                     int length, boolean leftPad) {
    StringBuffer filler = new StringBuffer();
    for (int n = 0; n < length; n++) {
      filler.append(padChar);
    }

    return leftPad
      ? filler.toString() + source
      : source + filler.toString();
  }

  /**
   * Outputs the generated clusters into a string.
   * 
   * @return the clusterer in string representation
   */
00977   protected String toStringOriginal () {
    if (m_priors == null) {
      return "No clusterer built yet!";
    }
    StringBuffer temp = new StringBuffer();
    temp.append("\nEM\n==\n");
    if (m_initialNumClusters == -1) {
      temp.append("\nNumber of clusters selected by cross validation: "
              +m_num_clusters+"\n");
    } else {
      temp.append("\nNumber of clusters: " + m_num_clusters + "\n");
    }

    for (int j = 0; j < m_num_clusters; j++) {
      temp.append("\nCluster: " + j + " Prior probability: " 
              + Utils.doubleToString(m_priors[j], 4) + "\n\n");

      for (int i = 0; i < m_num_attribs; i++) {
        temp.append("Attribute: " + m_theInstances.attribute(i).name() + "\n");

        if (m_theInstances.attribute(i).isNominal()) {
          if (m_model[j][i] != null) {
            temp.append(m_model[j][i].toString());
          }
        }
        else {
          temp.append("Normal Distribution. Mean = " 
                  + Utils.doubleToString(m_modelNormal[j][i][0], 4) 
                  + " StdDev = " 
                  + Utils.doubleToString(m_modelNormal[j][i][1], 4) 
                  + "\n");
        }
      }
    }

    return  temp.toString();
  }


  /**
   * verbose output for debugging
   * @param inst the training instances
   */
01020   private void EM_Report (Instances inst) {
    int i, j, l, m;
    System.out.println("======================================");

    for (j = 0; j < m_num_clusters; j++) {
      for (i = 0; i < m_num_attribs; i++) {
      System.out.println("Clust: " + j + " att: " + i + "\n");

      if (m_theInstances.attribute(i).isNominal()) {
        if (m_model[j][i] != null) {
          System.out.println(m_model[j][i].toString());
        }
      }
      else {
        System.out.println("Normal Distribution. Mean = " 
                       + Utils.doubleToString(m_modelNormal[j][i][0]
                                        , 8, 4) 
                       + " StandardDev = " 
                       + Utils.doubleToString(m_modelNormal[j][i][1]
                                        , 8, 4) 
                       + " WeightSum = " 
                       + Utils.doubleToString(m_modelNormal[j][i][2]
                                        , 8, 4));
      }
      }
    }
    
    for (l = 0; l < inst.numInstances(); l++) {
      m = Utils.maxIndex(m_weights[l]);
      System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0) 
                   + " Class " + m + "\t");
      for (j = 0; j < m_num_clusters; j++) {
      System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + "  ");
      }
      System.out.println();
    }
  }


  /**
   * estimate the number of clusters by cross validation on the training
   * data.
   *
   * @throws Exception if something goes wrong
   */
01065   private void CVClusters ()
    throws Exception {
    double CVLogLikely = -Double.MAX_VALUE;
    double templl, tll;
    boolean CVincreased = true;
    m_num_clusters = 1;
    int num_clusters = m_num_clusters;
    int i;
    Random cvr;
    Instances trainCopy;
    int numFolds = (m_theInstances.numInstances() < 10) 
      ? m_theInstances.numInstances() 
      : 10;

    boolean ok = true;
    int seed = getSeed();
    int restartCount = 0;
    CLUSTER_SEARCH: while (CVincreased) {
      // theInstances.stratify(10);
        
      CVincreased = false;
      cvr = new Random(getSeed());
      trainCopy = new Instances(m_theInstances);
      trainCopy.randomize(cvr);
      templl = 0.0;
      for (i = 0; i < numFolds; i++) {
      Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
      if (num_clusters > cvTrain.numInstances()) {
        break CLUSTER_SEARCH;
      }
      Instances cvTest = trainCopy.testCV(numFolds, i);
      m_rr = new Random(seed);
        for (int z=0; z<10; z++) m_rr.nextDouble();
      m_num_clusters = num_clusters;
      EM_Init(cvTrain);
      try {
        iterate(cvTrain, false);
      } catch (Exception ex) {
        // catch any problems - i.e. empty clusters occuring
        ex.printStackTrace();
          //          System.err.println("Restarting after CV training failure ("+num_clusters+" clusters");
          seed++;
          restartCount++;
          ok = false;
          if (restartCount > 5) {
            break CLUSTER_SEARCH;
          }
        break;
      }
        try {
          tll = E(cvTest, false);
        } catch (Exception ex) {
          // catch any problems - i.e. empty clusters occuring
          //          ex.printStackTrace();
          ex.printStackTrace();
          //          System.err.println("Restarting after CV testing failure ("+num_clusters+" clusters");
          //          throw new Exception(ex); 
          seed++;
          restartCount++;
          ok = false;
          if (restartCount > 5) {
            break CLUSTER_SEARCH;
          }
          break;
        }

      if (m_verbose) {
        System.out.println("# clust: " + num_clusters + " Fold: " + i 
                       + " Loglikely: " + tll);
      }
      templl += tll;
      }

      if (ok) {
        restartCount = 0;
        seed = getSeed();
        templl /= (double)numFolds;
        
        if (m_verbose) {
          System.out.println("===================================" 
                             + "==============\n# clust: " 
                             + num_clusters 
                             + " Mean Loglikely: " 
                             + templl 
                             + "\n================================" 
                             + "=================");
        }
        
        if (templl > CVLogLikely) {
          CVLogLikely = templl;
          CVincreased = true;
          num_clusters++;
        }
      }
    }

    if (m_verbose) {
      System.out.println("Number of clusters: " + (num_clusters - 1));
    }

    m_num_clusters = num_clusters - 1;
  }


  /**
   * Returns the number of clusters.
   *
   * @return the number of clusters generated for a training dataset.
   * @throws Exception if number of clusters could not be returned
   * successfully
   */
01176   public int numberOfClusters ()
    throws Exception {
    if (m_num_clusters == -1) {
      throw  new Exception("Haven't generated any clusters!");
    }

    return  m_num_clusters;
  }

 /**
  * Updates the minimum and maximum values for all the attributes
  * based on a new instance.
  *
  * @param instance the new instance
  */
01191   private void updateMinMax(Instance instance) {
    
    for (int j = 0; j < m_theInstances.numAttributes(); j++) {
      if (!instance.isMissing(j)) {
      if (Double.isNaN(m_minValues[j])) {
        m_minValues[j] = instance.value(j);
        m_maxValues[j] = instance.value(j);
      } else {
        if (instance.value(j) < m_minValues[j]) {
          m_minValues[j] = instance.value(j);
        } else {
          if (instance.value(j) > m_maxValues[j]) {
            m_maxValues[j] = instance.value(j);
          }
        }
      }
      }
    }
  }

  /**
   * Returns default capabilities of the clusterer (i.e., the ones of 
   * SimpleKMeans).
   *
   * @return      the capabilities of this clusterer
   */
01217   public Capabilities getCapabilities() {
    Capabilities result = new SimpleKMeans().getCapabilities();
    result.setOwner(this);
    return result;
  }
  
  /**
   * Generates a clusterer. Has to initialize all fields of the clusterer
   * that are not being set via options.
   *
   * @param data set of instances serving as training data 
   * @throws Exception if the clusterer has not been 
   * generated successfully
   */
01231   public void buildClusterer (Instances data)
    throws Exception {
    
    // can clusterer handle the data?
    getCapabilities().testWithFail(data);

    m_replaceMissing = new ReplaceMissingValues();
    Instances instances = new Instances(data);
    instances.setClassIndex(-1);
    m_replaceMissing.setInputFormat(instances);
    data = weka.filters.Filter.useFilter(instances, m_replaceMissing);
    instances = null;
    
    m_theInstances = data;

    // calculate min and max values for attributes
    m_minValues = new double [m_theInstances.numAttributes()];
    m_maxValues = new double [m_theInstances.numAttributes()];
    for (int i = 0; i < m_theInstances.numAttributes(); i++) {
      m_minValues[i] = m_maxValues[i] = Double.NaN;
    }
    for (int i = 0; i < m_theInstances.numInstances(); i++) {
      updateMinMax(m_theInstances.instance(i));
    }

    doEM();
    
    // save memory
    m_theInstances = new Instances(m_theInstances,0);
  }

  /**
   * Returns the cluster priors.
   * 
   * @return the cluster priors
   */
01267   public double[] clusterPriors() {

    double[] n = new double[m_priors.length];
  
    System.arraycopy(m_priors, 0, n, 0, n.length);
    return n;
  }

  /**
   * Computes the log of the conditional density (per cluster) for a given instance.
   * 
   * @param inst the instance to compute the density for
   * @return an array containing the estimated densities
   * @throws Exception if the density could not be computed
   * successfully
   */
01283   public double[] logDensityPerClusterForInstance(Instance inst) throws Exception {

    int i, j;
    double logprob;
    double[] wghts = new double[m_num_clusters];
    
    m_replaceMissing.input(inst);
    inst = m_replaceMissing.output();

    for (i = 0; i < m_num_clusters; i++) {
      //      System.err.println("Cluster : "+i);
      logprob = 0.0;

      for (j = 0; j < m_num_attribs; j++) {
      if (!inst.isMissing(j)) {
        if (inst.attribute(j).isNominal()) {
          logprob += Math.log(m_model[i][j].getProbability(inst.value(j)));
        }
        else { // numeric attribute
          logprob += logNormalDens(inst.value(j), 
                             m_modelNormal[i][j][0], 
                             m_modelNormal[i][j][1]);
          /*          System.err.println(logNormalDens(inst.value(j), 
                             m_modelNormal[i][j][0], 
                             m_modelNormal[i][j][1]) + " "); */
        }
      }
      }
      //      System.err.println("");

      wghts[i] = logprob;
    }
    return  wghts;
  }


  /**
   * Perform the EM algorithm
   * 
   * @throws Exception if something goes wrong
   */
01324   private void doEM ()
    throws Exception {
    
    if (m_verbose) {
      System.out.println("Seed: " + getSeed());
    }

    m_rr = new Random(getSeed());

    // throw away numbers to avoid problem of similar initial numbers
    // from a similar seed
    for (int i=0; i<10; i++) m_rr.nextDouble();

    m_num_instances = m_theInstances.numInstances();
    m_num_attribs = m_theInstances.numAttributes();

    if (m_verbose) {
      System.out.println("Number of instances: " 
                   + m_num_instances 
                   + "\nNumber of atts: " 
                   + m_num_attribs 
                   + "\n");
    }

    // setDefaultStdDevs(theInstances);
    // cross validate to determine number of clusters?
    if (m_initialNumClusters == -1) {
      if (m_theInstances.numInstances() > 9) {
      CVClusters();
      m_rr = new Random(getSeed());
      for (int i=0; i<10; i++) m_rr.nextDouble();
      } else {
      m_num_clusters = 1;
      }
    }

    // fit full training set
    EM_Init(m_theInstances);
    m_loglikely = iterate(m_theInstances, m_verbose);
  }


  /**
   * iterates the E and M steps until the log likelihood of the data
   * converges.
   *
   * @param inst the training instances.
   * @param report be verbose.
   * @return the log likelihood of the data
   * @throws Exception if something goes wrong
   */
01375   private double iterate (Instances inst, boolean report)
    throws Exception {
    
    int i;
    double llkold = 0.0;
    double llk = 0.0;

    if (report) {
      EM_Report(inst);
    }

    boolean ok = false;
    int seed = getSeed();
    int restartCount = 0;
    while (!ok) {
      try {
        for (i = 0; i < m_max_iterations; i++) {
          llkold = llk;
          llk = E(inst, true);
          
          if (report) {
            System.out.println("Loglikely: " + llk);
          }
          
          if (i > 0) {
            if ((llk - llkold) < 1e-6) {
              break;
            }
          }
          M(inst);
        }
        ok = true;
      } catch (Exception ex) {
        //        System.err.println("Restarting after training failure");
        ex.printStackTrace();
        seed++;
        restartCount++;
        m_rr = new Random(seed);
        for (int z = 0; z < 10; z++) {
          m_rr.nextDouble(); m_rr.nextInt();
        }
        if (restartCount > 5) {
          //          System.err.println("Reducing the number of clusters");
          m_num_clusters--;
          restartCount = 0;
        }
        EM_Init(m_theInstances);
      }
    }
      
    if (report) {
      EM_Report(inst);
    }

    return  llk;
  }
  
  /**
   * Returns the revision string.
   * 
   * @return            the revision
   */
01437   public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.44 $");
  }

  // ============
  // Test method.
  // ============
  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments: <p>
   * -t training file [-T test file] [-N number of clusters] [-S random seed]
   */
01450   public static void main (String[] argv) {
    runClusterer(new EM(), argv);
  }
}


// Generated by Doxygen 1.6.0