/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.IteratedSingleClassifierEnhancer;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Meta classifier for numeric (and date) classes. It starts from a ZeroR model,
 * then fits each base classifier to the residuals left by the models before it.
 * A prediction is the ZeroR prediction plus the shrinkage-scaled predictions of
 * every base model that was built.
 */
public class AdditiveRegression
extends IteratedSingleClassifierEnhancer
implements OptionHandler,
AdditionalMeasureProducer,
WeightedInstancesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -2368937577670527151L;

    /** Factor applied to every base model's prediction; 1.0 means no shrinkage. */
    protected double m_shrinkage = 1.0;
    /** Number of base models actually built by the last call to buildClassifier. */
    protected int m_NumIterationsPerformed;
    /** Model whose prediction the (shrunken) base-model predictions are added to. */
    protected ZeroR m_zeroR;
    /** False when the training data held only the class attribute, so only ZeroR is used. */
    protected boolean m_SuitableData = true;

    /**
     * Returns a description of this classifier, followed by its technical reference.
     *
     * @return the description text
     */
    public String globalInfo() {
        return " Meta classifier that enhances the performance of a regression base classifier. Each iteration fits a model to the residuals left by the classifier on the previous iteration. Prediction is accomplished by adding the predictions of each classifier. Reducing the shrinkage (learning rate) parameter helps prevent overfitting and has a smoothing effect but increases the learning time.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference this method is based on.
     *
     * @return the technical information (Friedman, 1999 tech report)
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.TECHREPORT);
        info.setValue(TechnicalInformation.Field.AUTHOR, "J.H. Friedman");
        info.setValue(TechnicalInformation.Field.YEAR, "1999");
        info.setValue(TechnicalInformation.Field.TITLE, "Stochastic Gradient Boosting");
        info.setValue(TechnicalInformation.Field.INSTITUTION, "Stanford University");
        info.setValue(TechnicalInformation.Field.PS, "http://www-stat.stanford.edu/~jhf/ftp/stobst.ps");
        return info;
    }

    /** Creates an instance that uses a DecisionStump as the base classifier. */
    public AdditiveRegression() {
        this(new DecisionStump());
    }

    /**
     * Creates an instance that uses the given base classifier.
     *
     * @param classifier the base classifier to boost
     */
    public AdditiveRegression(Classifier classifier) {
        this.m_Classifier = classifier;
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return "weka.classifiers.trees.DecisionStump"
     */
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Lists the -S (shrinkage) option followed by all options of the superclass.
     *
     * @return an enumeration of Option objects
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>(4);
        options.addElement(new Option("\tSpecify shrinkage rate. (default = 1.0, ie. no shrinkage)\n", "S", 1, "-S"));
        Enumeration superOptions = super.listOptions();
        while (superOptions.hasMoreElements()) {
            options.addElement((Option) superOptions.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the -S (shrinkage) option, then passes the remaining options to the superclass.
     *
     * @param options the option array
     * @throws Exception if the shrinkage value is not a number, or the superclass rejects an option
     */
    public void setOptions(String[] options) throws Exception {
        String shrinkage = Utils.getOption('S', options);
        if (shrinkage.length() != 0) {
            this.setShrinkage(Double.parseDouble(shrinkage));
        }
        super.setOptions(options);
    }

    /**
     * Returns the current options: "-S shrinkage" followed by the superclass options.
     *
     * @return the option array
     */
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 2];
        options[0] = "-S";
        options[1] = "" + this.getShrinkage();
        System.arraycopy(superOptions, 0, options, 2, superOptions.length);
        return options;
    }

    /**
     * Returns the tip text for the shrinkage property.
     *
     * @return the tip text
     */
    public String shrinkageTipText() {
        return "Shrinkage rate. Smaller values help prevent overfitting and have a smoothing effect (but increase learning time). Default = 1.0, ie. no shrinkage.";
    }

    /**
     * Sets the shrinkage (learning-rate) factor. No range validation is performed.
     *
     * @param d the new shrinkage factor
     */
    public void setShrinkage(double d) {
        this.m_shrinkage = d;
    }

    /**
     * Returns the shrinkage (learning-rate) factor.
     *
     * @return the shrinkage factor
     */
    public double getShrinkage() {
        return this.m_shrinkage;
    }

    /**
     * Restricts the inherited capabilities to numeric and date class attributes.
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAllClasses();
        capabilities.disableAllClassDependencies();
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.DATE_CLASS);
        return capabilities;
    }

    /**
     * Builds the ensemble.
     *
     * Steps:
     * <ol>
     * <li>Fit ZeroR on the training instances that have a class value.</li>
     * <li>Replace the class values with residuals from ZeroR.</li>
     * <li>Fit each base classifier in turn to the current residuals.</li>
     * <li>After each fit, replace the residuals again using that model's shrunken
     *     predictions.</li>
     * </ol>
     *
     * Stops when all base classifiers have been built, or when an iteration fails
     * to lower the weighted SSE by more than Utils.SMALL. The model built in that
     * last, non-improving iteration is still kept.
     *
     * @param instances the training data (not modified)
     * @throws Exception if the data is not supported or a model cannot be built
     */
    public void buildClassifier(Instances instances) throws Exception {
        // Reset first so that a failed or ZeroR-only build never leaves a stale count
        // behind for classifyInstance() / measureNumIterations().
        this.m_NumIterationsPerformed = 0;
        // NOTE(review): presumably this (re)creates m_Classifiers as copies of the base
        // classifier, one per configured iteration -- confirm against the superclass.
        super.buildClassifier(instances);
        this.getCapabilities().testWithFail(instances);

        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        this.m_zeroR = new ZeroR();
        this.m_zeroR.buildClassifier(data);
        if (data.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            this.m_SuitableData = false;
            return;
        }
        this.m_SuitableData = true;

        data = this.residualReplace(data, this.m_zeroR, false);
        double sse = sumOfSquaredResiduals(data);
        if (this.m_Debug) {
            System.err.println("Sum of squared residuals (predicting the mean) : " + sse);
        }

        // A guarded while (rather than do/while) so an empty m_Classifiers array is not
        // indexed out of bounds; the result is then a ZeroR-only model.
        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            double previousSse = sse;
            Classifier current = this.m_Classifiers[this.m_NumIterationsPerformed];
            current.buildClassifier(data);
            data = this.residualReplace(data, current, true);
            sse = sumOfSquaredResiduals(data);
            if (this.m_Debug) {
                System.err.println("Sum of squared residuals : " + sse);
            }
            ++this.m_NumIterationsPerformed;
            // Written as !(x > SMALL) instead of x <= SMALL so that a NaN SSE also stops.
            if (!(previousSse - sse > Utils.SMALL)) {
                break;
            }
        }
    }

    /**
     * Returns the weighted sum of squared class values. This is the weighted SSE when
     * the class values hold residuals.
     *
     * @param data the instances whose class values are summed
     * @return the sum over instances of weight * classValue * classValue
     */
    private static double sumOfSquaredResiduals(Instances data) {
        double sum = 0.0;
        for (int i = 0; i < data.numInstances(); ++i) {
            Instance instance = data.instance(i);
            sum += instance.weight() * instance.classValue() * instance.classValue();
        }
        return sum;
    }

    /**
     * Predicts the class value. The prediction is the ZeroR prediction plus the
     * shrinkage-scaled predictions of each base model built, or the ZeroR
     * prediction alone when the data was unsuitable.
     *
     * @param instance the instance to classify
     * @return the predicted numeric class value
     * @throws Exception if a component model fails to classify the instance
     */
    public double classifyInstance(Instance instance) throws Exception {
        double prediction = this.m_zeroR.classifyInstance(instance);
        if (!this.m_SuitableData) {
            return prediction;
        }
        for (int i = 0; i < this.m_NumIterationsPerformed; ++i) {
            prediction += this.m_Classifiers[i].classifyInstance(instance) * this.getShrinkage();
        }
        return prediction;
    }

    /**
     * Returns a copy of the data whose class values are replaced by residuals:
     * class value minus the classifier's prediction.
     *
     * @param instances the data to copy (not modified)
     * @param classifier the model whose predictions are subtracted
     * @param shrink whether to scale the predictions by the shrinkage factor first
     * @return a new Instances object holding the residuals
     * @throws Exception if the classifier fails on an instance
     */
    private Instances residualReplace(Instances instances, Classifier classifier, boolean shrink) throws Exception {
        Instances residuals = new Instances(instances);
        for (int i = 0; i < residuals.numInstances(); ++i) {
            Instance instance = residuals.instance(i);
            double predicted = classifier.classifyInstance(instance);
            if (shrink) {
                predicted *= this.getShrinkage();
            }
            instance.setClassValue(instance.classValue() - predicted);
        }
        return residuals;
    }

    /**
     * Lists the names of the additional measures this classifier provides.
     *
     * @return an enumeration containing "measureNumIterations"
     */
    public Enumeration enumerateMeasures() {
        Vector<String> measures = new Vector<String>(1);
        measures.addElement("measureNumIterations");
        return measures.elements();
    }

    /**
     * Returns the value of a named additional measure (case-insensitive).
     *
     * @param string the measure name
     * @return the measure value
     * @throws IllegalArgumentException if the measure is not supported
     */
    public double getMeasure(String string) {
        if (string.compareToIgnoreCase("measureNumIterations") == 0) {
            return this.measureNumIterations();
        }
        throw new IllegalArgumentException(string + " not supported (AdditiveRegression)");
    }

    /**
     * Returns the number of base models built by the last call to buildClassifier.
     *
     * @return the number of iterations performed
     */
    public double measureNumIterations() {
        return this.m_NumIterationsPerformed;
    }

    /**
     * Returns a textual description of the built model.
     *
     * @return the model description, or a notice if no model has been built
     */
    @Override
    public String toString() {
        // No ZeroR model means buildClassifier has not completed successfully yet.
        if (this.m_zeroR == null) {
            return "Classifier hasn't been built yet!";
        }
        if (!this.m_SuitableData) {
            String simpleName = this.getClass().getName().replaceAll(".*\\.", "");
            StringBuilder text = new StringBuilder();
            text.append(simpleName + "\n");
            // Underline the class name with one '=' per character.
            text.append(simpleName.replaceAll(".", "=") + "\n\n");
            text.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            text.append(this.m_zeroR.toString());
            return text.toString();
        }
        StringBuilder text = new StringBuilder();
        text.append("Additive Regression\n\n");
        text.append("ZeroR model\n\n" + this.m_zeroR + "\n\n");
        text.append("Base classifier " + this.getClassifier().getClass().getName() + "\n\n");
        text.append("" + this.m_NumIterationsPerformed + " models generated.\n");
        for (int i = 0; i < this.m_NumIterationsPerformed; ++i) {
            text.append("\nModel number " + i + "\n\n" + this.m_Classifiers[i] + "\n");
        }
        return text.toString();
    }

    /**
     * Command-line entry point.
     *
     * @param stringArray the command-line options
     */
    public static void main(String[] stringArray) {
        AdditiveRegression.runClassifier(new AdditiveRegression(), stringArray);
    }
}

