/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.functions.Logistic;
import weka.classifiers.rules.ZeroR;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Range;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.MakeIndicator;
import weka.filters.unsupervised.instance.RemoveWithValues;

public class MultiClassClassifier
extends RandomizableSingleClassifierEnhancer
implements OptionHandler {
    static final long serialVersionUID = -3879602011542849141L;
    protected Classifier[] m_Classifiers;
    protected boolean m_pairwiseCoupling = false;
    protected double[] m_SumOfWeights;
    protected Filter[] m_ClassFilters;
    private ZeroR m_ZeroR;
    protected Attribute m_ClassAttribute;
    protected Instances m_TwoClassDataset;
    private double m_RandomWidthFactor = 2.0;
    protected boolean m_logLossDecoding = false;
    protected int m_Method = 0;
    public static final int METHOD_1_AGAINST_ALL = 0;
    public static final int METHOD_ERROR_RANDOM = 1;
    public static final int METHOD_ERROR_EXHAUSTIVE = 2;
    public static final int METHOD_1_AGAINST_1 = 3;
    public static final Tag[] TAGS_METHOD = new Tag[]{new Tag(0, "1-against-all"), new Tag(1, "Random correction code"), new Tag(2, "Exhaustive correction code"), new Tag(3, "1-against-1")};

    public MultiClassClassifier() {
        this.m_Classifier = new Logistic();
    }

    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.functions.Logistic";
    }

    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        return result;
    }

    @Override
    public void buildClassifier(Instances insts) throws Exception {
        this.getCapabilities().testWithFail(insts);
        boolean zeroTrainingInstances = insts.numInstances() == 0;
        insts = new Instances(insts);
        insts.deleteWithMissingClass();
        if (this.m_Classifier == null) {
            throw new Exception("No base classifier has been set!");
        }
        this.m_ZeroR = new ZeroR();
        this.m_ZeroR.buildClassifier(insts);
        this.m_TwoClassDataset = null;
        int numClassifiers = insts.numClasses();
        if (numClassifiers <= 2) {
            this.m_Classifiers = AbstractClassifier.makeCopies(this.m_Classifier, 1);
            this.m_Classifiers[0].buildClassifier(insts);
            this.m_ClassFilters = null;
        } else if (this.m_Method == 3) {
            int i;
            ArrayList<int[]> pairs = new ArrayList<int[]>();
            for (i = 0; i < insts.numClasses(); ++i) {
                for (int j = 0; j < insts.numClasses(); ++j) {
                    if (j <= i) continue;
                    int[] pair = new int[]{i, j};
                    pairs.add(pair);
                }
            }
            numClassifiers = pairs.size();
            this.m_Classifiers = AbstractClassifier.makeCopies(this.m_Classifier, numClassifiers);
            this.m_ClassFilters = new Filter[numClassifiers];
            this.m_SumOfWeights = new double[numClassifiers];
            for (i = 0; i < numClassifiers; ++i) {
                RemoveWithValues classFilter = new RemoveWithValues();
                classFilter.setAttributeIndex("" + (insts.classIndex() + 1));
                classFilter.setModifyHeader(true);
                classFilter.setInvertSelection(true);
                classFilter.setNominalIndicesArr((int[])pairs.get(i));
                Instances tempInstances = new Instances(insts, 0);
                tempInstances.setClassIndex(-1);
                classFilter.setInputFormat(tempInstances);
                Instances newInsts = Filter.useFilter(insts, classFilter);
                if (newInsts.numInstances() > 0 || zeroTrainingInstances) {
                    newInsts.setClassIndex(insts.classIndex());
                    this.m_Classifiers[i].buildClassifier(newInsts);
                    this.m_ClassFilters[i] = classFilter;
                    this.m_SumOfWeights[i] = newInsts.sumOfWeights();
                    continue;
                }
                this.m_Classifiers[i] = null;
                this.m_ClassFilters[i] = null;
            }
            this.m_TwoClassDataset = new Instances(insts, 0);
            int classIndex = this.m_TwoClassDataset.classIndex();
            this.m_TwoClassDataset.setClassIndex(-1);
            ArrayList<String> classLabels = new ArrayList<String>();
            classLabels.add("class0");
            classLabels.add("class1");
            this.m_TwoClassDataset.replaceAttributeAt(new Attribute("class", classLabels), classIndex);
            this.m_TwoClassDataset.setClassIndex(classIndex);
        } else {
            Code code = null;
            switch (this.m_Method) {
                case 2: {
                    code = new ExhaustiveCode(numClassifiers);
                    break;
                }
                case 1: {
                    code = new RandomCode(numClassifiers, (int)((double)numClassifiers * this.m_RandomWidthFactor), insts);
                    break;
                }
                case 0: {
                    code = new StandardCode(numClassifiers);
                    break;
                }
                default: {
                    throw new Exception("Unrecognized correction code type");
                }
            }
            numClassifiers = code.size();
            this.m_Classifiers = AbstractClassifier.makeCopies(this.m_Classifier, numClassifiers);
            this.m_ClassFilters = new MakeIndicator[numClassifiers];
            for (int i = 0; i < this.m_Classifiers.length; ++i) {
                this.m_ClassFilters[i] = new MakeIndicator();
                MakeIndicator classFilter = (MakeIndicator)this.m_ClassFilters[i];
                classFilter.setAttributeIndex("" + (insts.classIndex() + 1));
                classFilter.setValueIndices(code.getIndices(i));
                classFilter.setNumeric(false);
                classFilter.setInputFormat(insts);
                Instances newInsts = Filter.useFilter(insts, this.m_ClassFilters[i]);
                this.m_Classifiers[i].buildClassifier(newInsts);
            }
        }
        this.m_ClassAttribute = insts.classAttribute();
    }

    public double[] individualPredictions(Instance inst) throws Exception {
        double[] result = null;
        if (this.m_Classifiers.length == 1) {
            result = new double[]{this.m_Classifiers[0].distributionForInstance(inst)[1]};
        } else {
            result = new double[this.m_ClassFilters.length];
            for (int i = 0; i < this.m_ClassFilters.length; ++i) {
                if (this.m_Classifiers[i] == null) continue;
                if (this.m_Method == 3) {
                    Instance tempInst = (Instance)inst.copy();
                    tempInst.setDataset(this.m_TwoClassDataset);
                    result[i] = this.m_Classifiers[i].distributionForInstance(tempInst)[1];
                    continue;
                }
                this.m_ClassFilters[i].input(inst);
                this.m_ClassFilters[i].batchFinished();
                result[i] = this.m_Classifiers[i].distributionForInstance(this.m_ClassFilters[i].output())[1];
            }
        }
        return result;
    }

    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        if (this.m_Classifiers.length == 1) {
            return this.m_Classifiers[0].distributionForInstance(inst);
        }
        double[] probs = new double[inst.numClasses()];
        if (this.m_Method == 3) {
            double[][] r = new double[inst.numClasses()][inst.numClasses()];
            double[][] n = new double[inst.numClasses()][inst.numClasses()];
            for (int i = 0; i < this.m_ClassFilters.length; ++i) {
                if (this.m_Classifiers[i] == null) continue;
                Instance tempInst = (Instance)inst.copy();
                tempInst.setDataset(this.m_TwoClassDataset);
                double[] current = this.m_Classifiers[i].distributionForInstance(tempInst);
                Range range = new Range(((RemoveWithValues)this.m_ClassFilters[i]).getNominalIndices());
                range.setUpper(this.m_ClassAttribute.numValues());
                int[] pair = range.getSelection();
                if (this.m_pairwiseCoupling && inst.numClasses() > 2) {
                    r[pair[0]][pair[1]] = current[0];
                    n[pair[0]][pair[1]] = this.m_SumOfWeights[i];
                    continue;
                }
                if (current[0] > current[1]) {
                    int n2 = pair[0];
                    probs[n2] = probs[n2] + 1.0;
                    continue;
                }
                int n3 = pair[1];
                probs[n3] = probs[n3] + 1.0;
            }
            if (this.m_pairwiseCoupling && inst.numClasses() > 2) {
                return MultiClassClassifier.pairwiseCoupling(n, r);
            }
        } else if (this.m_Method == 0) {
            for (int i = 0; i < this.m_ClassFilters.length; ++i) {
                this.m_ClassFilters[i].input(inst);
                this.m_ClassFilters[i].batchFinished();
                probs[i] = this.m_Classifiers[i].distributionForInstance(this.m_ClassFilters[i].output())[1];
            }
        } else if (this.getLogLossDecoding()) {
            Arrays.fill(probs, 1.0);
            for (int i = 0; i < this.m_ClassFilters.length; ++i) {
                this.m_ClassFilters[i].input(inst);
                this.m_ClassFilters[i].batchFinished();
                double[] current = this.m_Classifiers[i].distributionForInstance(this.m_ClassFilters[i].output());
                for (int j = 0; j < this.m_ClassAttribute.numValues(); ++j) {
                    if (((MakeIndicator)this.m_ClassFilters[i]).getValueRange().isInRange(j)) {
                        int n = j;
                        probs[n] = probs[n] + Math.log(Utils.SMALL + (1.0 - 2.0 * Utils.SMALL) * current[1]);
                        continue;
                    }
                    int n = j;
                    probs[n] = probs[n] + Math.log(Utils.SMALL + (1.0 - 2.0 * Utils.SMALL) * current[0]);
                }
            }
            probs = Utils.logs2probs(probs);
        } else {
            for (int i = 0; i < this.m_ClassFilters.length; ++i) {
                this.m_ClassFilters[i].input(inst);
                this.m_ClassFilters[i].batchFinished();
                double[] current = this.m_Classifiers[i].distributionForInstance(this.m_ClassFilters[i].output());
                for (int j = 0; j < this.m_ClassAttribute.numValues(); ++j) {
                    if (((MakeIndicator)this.m_ClassFilters[i]).getValueRange().isInRange(j)) {
                        int n = j;
                        probs[n] = probs[n] + current[1];
                        continue;
                    }
                    int n = j;
                    probs[n] = probs[n] + current[0];
                }
            }
        }
        if (Utils.gr(Utils.sum(probs), 0.0)) {
            Utils.normalize(probs);
            return probs;
        }
        return this.m_ZeroR.distributionForInstance(inst);
    }

    public String toString() {
        if (this.m_Classifiers == null) {
            return "MultiClassClassifier: No model built yet.";
        }
        StringBuffer text = new StringBuffer();
        text.append("MultiClassClassifier\n\n");
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            text.append("Classifier ").append(i + 1);
            if (this.m_Classifiers[i] != null) {
                if (this.m_ClassFilters != null && this.m_ClassFilters[i] != null) {
                    if (this.m_ClassFilters[i] instanceof RemoveWithValues) {
                        Range range = new Range(((RemoveWithValues)this.m_ClassFilters[i]).getNominalIndices());
                        range.setUpper(this.m_ClassAttribute.numValues());
                        int[] pair = range.getSelection();
                        text.append(", " + (pair[0] + 1) + " vs " + (pair[1] + 1));
                    } else if (this.m_ClassFilters[i] instanceof MakeIndicator) {
                        text.append(", using indicator values: ");
                        text.append(((MakeIndicator)this.m_ClassFilters[i]).getValueRange());
                    }
                }
                text.append('\n');
                text.append(this.m_Classifiers[i].toString() + "\n\n");
                continue;
            }
            text.append(" Skipped (no training examples)\n");
        }
        return text.toString();
    }

    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> vec = new Vector<Option>(3);
        vec.addElement(new Option("\tSets the method to use. Valid values are 0 (1-against-all),\n\t1 (random codes), 2 (exhaustive code), and 3 (1-against-1). (default 0)\n", "M", 1, "-M <num>"));
        vec.addElement(new Option("\tSets the multiplier when using random codes. (default 2.0)", "R", 1, "-R <num>"));
        vec.addElement(new Option("\tUse pairwise coupling (only has an effect for 1-against1)", "P", 0, "-P"));
        vec.addElement(new Option("\tUse log loss decoding for random and exhaustive codes", "L", 0, "-L"));
        vec.addAll(Collections.list(super.listOptions()));
        return vec.elements();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        String errorString = Utils.getOption('M', options);
        if (errorString.length() != 0) {
            this.setMethod(new SelectedTag(Integer.parseInt(errorString), TAGS_METHOD));
        } else {
            this.setMethod(new SelectedTag(0, TAGS_METHOD));
        }
        String rfactorString = Utils.getOption('R', options);
        if (rfactorString.length() != 0) {
            this.setRandomWidthFactor(new Double(rfactorString));
        } else {
            this.setRandomWidthFactor(2.0);
        }
        this.setUsePairwiseCoupling(Utils.getFlag('P', options));
        this.setLogLossDecoding(Utils.getFlag('L', options));
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-M");
        options.add("" + this.m_Method);
        if (this.getUsePairwiseCoupling()) {
            options.add("-P");
        }
        if (this.getLogLossDecoding()) {
            options.add("-L");
        }
        options.add("-R");
        options.add("" + this.m_RandomWidthFactor);
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    public String globalInfo() {
        return "A metaclassifier for handling multi-class datasets with 2-class classifiers. This classifier is also capable of applying error correcting output codes for increased accuracy.";
    }

    public String logLossDecodingTipText() {
        return "Use log loss decoding for random or exhaustive codes.";
    }

    public boolean getLogLossDecoding() {
        return this.m_logLossDecoding;
    }

    public void setLogLossDecoding(boolean newlogLossDecoding) {
        this.m_logLossDecoding = newlogLossDecoding;
    }

    public String randomWidthFactorTipText() {
        return "Sets the width multiplier when using random codes. The number of codes generated will be thus number multiplied by the number of classes.";
    }

    public double getRandomWidthFactor() {
        return this.m_RandomWidthFactor;
    }

    public void setRandomWidthFactor(double newRandomWidthFactor) {
        this.m_RandomWidthFactor = newRandomWidthFactor;
    }

    public String methodTipText() {
        return "Sets the method to use for transforming the multi-class problem into several 2-class ones.";
    }

    public SelectedTag getMethod() {
        return new SelectedTag(this.m_Method, TAGS_METHOD);
    }

    public void setMethod(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_METHOD) {
            this.m_Method = newMethod.getSelectedTag().getID();
        }
    }

    public void setUsePairwiseCoupling(boolean p) {
        this.m_pairwiseCoupling = p;
    }

    public boolean getUsePairwiseCoupling() {
        return this.m_pairwiseCoupling;
    }

    public String usePairwiseCouplingTipText() {
        return "Use pairwise coupling (only has an effect for 1-against-1).";
    }

    public static double[] pairwiseCoupling(double[][] n, double[][] r) {
        boolean changed;
        double[] p = new double[r.length];
        for (int i = 0; i < p.length; ++i) {
            p[i] = 1.0 / (double)p.length;
        }
        double[][] u = new double[r.length][r.length];
        for (int i = 0; i < r.length; ++i) {
            for (int j = i + 1; j < r.length; ++j) {
                u[i][j] = 0.5;
            }
        }
        double[] firstSum = new double[p.length];
        for (int i = 0; i < p.length; ++i) {
            for (int j = i + 1; j < p.length; ++j) {
                int n2 = i;
                firstSum[n2] = firstSum[n2] + n[i][j] * r[i][j];
                int n3 = j;
                firstSum[n3] = firstSum[n3] + n[i][j] * (1.0 - r[i][j]);
            }
        }
        do {
            int i;
            changed = false;
            double[] secondSum = new double[p.length];
            for (i = 0; i < p.length; ++i) {
                for (int j = i + 1; j < p.length; ++j) {
                    int n4 = i;
                    secondSum[n4] = secondSum[n4] + n[i][j] * u[i][j];
                    int n5 = j;
                    secondSum[n5] = secondSum[n5] + n[i][j] * (1.0 - u[i][j]);
                }
            }
            for (i = 0; i < p.length; ++i) {
                if (firstSum[i] == 0.0 || secondSum[i] == 0.0) {
                    if (p[i] > 0.0) {
                        changed = true;
                    }
                    p[i] = 0.0;
                    continue;
                }
                double factor = firstSum[i] / secondSum[i];
                double pOld = p[i];
                int n6 = i;
                p[n6] = p[n6] * factor;
                if (!(Math.abs(pOld - p[i]) > 0.001)) continue;
                changed = true;
            }
            Utils.normalize(p);
            for (i = 0; i < r.length; ++i) {
                for (int j = i + 1; j < r.length; ++j) {
                    u[i][j] = p[i] / (p[i] + p[j]);
                }
            }
        } while (changed);
        return p;
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 11890 $");
    }

    public static void main(String[] argv) {
        MultiClassClassifier.runClassifier(new MultiClassClassifier(), argv);
    }

    private class ExhaustiveCode
    extends Code {
        static final long serialVersionUID = 8090991039670804047L;

        public ExhaustiveCode(int numClasses) {
            int width = (int)Math.pow(2.0, numClasses - 1) - 1;
            this.m_Codebits = new boolean[width][numClasses];
            for (int j = 0; j < width; ++j) {
                this.m_Codebits[j][0] = true;
            }
            for (int i = 1; i < numClasses; ++i) {
                int skip = (int)Math.pow(2.0, numClasses - (i + 1));
                for (int j = 0; j < width; ++j) {
                    this.m_Codebits[j][i] = j / skip % 2 != 0;
                }
            }
        }

        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 11890 $");
        }
    }

    private class RandomCode
    extends Code {
        static final long serialVersionUID = 4413410540703926563L;
        Random r;

        public RandomCode(int numClasses, int numCodes, Instances data) {
            this.r = null;
            this.r = data.getRandomNumberGenerator(MultiClassClassifier.this.m_Seed);
            numCodes = Math.max(2, numCodes);
            this.m_Codebits = new boolean[numCodes][numClasses];
            int i = 0;
            do {
                this.randomize();
            } while (!this.good() && i++ < 100);
        }

        private boolean good() {
            int i;
            boolean[] ninClass = new boolean[this.m_Codebits[0].length];
            boolean[] ainClass = new boolean[this.m_Codebits[0].length];
            for (i = 0; i < ainClass.length; ++i) {
                ainClass[i] = true;
            }
            for (i = 0; i < this.m_Codebits.length; ++i) {
                boolean ninCode = false;
                boolean ainCode = true;
                for (int j = 0; j < this.m_Codebits[i].length; ++j) {
                    boolean current = this.m_Codebits[i][j];
                    ninCode = ninCode || current;
                    ainCode = ainCode && current;
                    ninClass[j] = ninClass[j] || current;
                    ainClass[j] = ainClass[j] && current;
                }
                if (ninCode && !ainCode) continue;
                return false;
            }
            for (int j = 0; j < ninClass.length; ++j) {
                if (ninClass[j] && !ainClass[j]) continue;
                return false;
            }
            return true;
        }

        private void randomize() {
            for (int i = 0; i < this.m_Codebits.length; ++i) {
                for (int j = 0; j < this.m_Codebits[i].length; ++j) {
                    double temp = this.r.nextDouble();
                    this.m_Codebits[i][j] = !(temp < 0.5);
                }
            }
        }

        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 11890 $");
        }
    }

    private class StandardCode
    extends Code {
        static final long serialVersionUID = 3707829689461467358L;

        public StandardCode(int numClasses) {
            this.m_Codebits = new boolean[numClasses][numClasses];
            for (int i = 0; i < numClasses; ++i) {
                this.m_Codebits[i][i] = true;
            }
        }

        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 11890 $");
        }
    }

    private abstract class Code
    implements Serializable,
    RevisionHandler {
        static final long serialVersionUID = 418095077487120846L;
        protected boolean[][] m_Codebits;

        private Code() {
        }

        public int size() {
            return this.m_Codebits.length;
        }

        public String getIndices(int which) {
            StringBuffer sb = new StringBuffer();
            for (int i = 0; i < this.m_Codebits[which].length; ++i) {
                if (!this.m_Codebits[which][i]) continue;
                if (sb.length() != 0) {
                    sb.append(',');
                }
                sb.append(i + 1);
            }
            return sb.toString();
        }

        public String toString() {
            StringBuffer sb = new StringBuffer();
            for (int i = 0; i < this.m_Codebits[0].length; ++i) {
                for (int j = 0; j < this.m_Codebits.length; ++j) {
                    sb.append(this.m_Codebits[j][i] ? " 1" : " 0");
                }
                sb.append('\n');
            }
            return sb.toString();
        }

        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 11890 $");
        }
    }
}

