package edu.tufts.cs.hrilab.pinc.weka;

import edu.tufts.cs.hrilab.pinc.options.Option;
import edu.tufts.cs.hrilab.pinc.options.OptionSet;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Enumeration;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils;

/* loaded from: input_file:edu/tufts/cs/hrilab/pinc/weka/WekaLearner.class */
/**
 * Wraps a Weka {@link Classifier} that is either trained from a data file named by the
 * {@code classifier} option or loaded from a serialized model file, and classifies feature
 * vectors given as lists of strings.
 *
 * <p>NOTE(review): constructors call the overridable methods {@code initializeOptions},
 * {@code handleArgs} and {@code initialize}; subclasses overriding them will run before their own
 * fields are initialized.
 */
public class WekaLearner {
    /** Learner name; used as the option key and as the prefix of its command-line flags. */
    String name = "";
    private OptionSet options;
    /** Trained (or deserialized) classifier used for predictions. */
    private Classifier tree;
    /** Training data; its header defines the attribute layout used to build query instances. */
    private Instances instances;
    /** Algorithm used to create a fresh classifier when training; defaults to J48. */
    private Algorithm algorithm = new J48();
    /** When true, {@link #initialize} deserializes a model instead of training one. */
    private boolean serialized = false;
    /** Class-attribute values, indexed like the classifier's distribution array. */
    ArrayList<String> classValues = new ArrayList<>();

    /**
     * Registers the {@code classifier} option for this learner in the current option set.
     *
     * @return the option set that was updated
     */
    public OptionSet initializeOptions() {
        this.options.add(new Option("classifier", this.name));
        return this.options;
    }

    /**
     * Replaces the option set and registers the {@code classifier} option for this learner.
     *
     * @param optionSet option set to adopt
     * @return the adopted option set
     */
    public OptionSet initializeOptions(OptionSet optionSet) {
        this.options = optionSet;
        this.options.add(new Option("classifier", this.name));
        return this.options;
    }

    /**
     * Creates a learner, parses its command-line flags and loads or trains its classifier.
     *
     * @param str learner name
     * @param optionSet shared option set
     * @param strArr command-line arguments; recognized flag values may be blanked in place
     */
    public WekaLearner(String str, OptionSet optionSet, String[] strArr) {
        this.name = str;
        this.options = optionSet;
        initializeOptions();
        handleArgs(strArr);
        initialize(strArr);
    }

    /**
     * Creates a learner and registers its options without loading a classifier.
     *
     * @param str learner name
     * @param optionSet shared option set
     */
    public WekaLearner(String str, OptionSet optionSet) {
        this.name = str;
        this.options = optionSet;
        initializeOptions();
    }

    /**
     * Creates a bare learner with no option set; option-dependent methods will fail until
     * {@link #initializeOptions(OptionSet)} is called.
     *
     * @param str learner name
     */
    public WekaLearner(String str) {
        this.name = str;
    }

    /** @return the number of attributes in the training data (including the class), or 0 if none */
    public int getNumberOfFeatures() {
        if (this.instances != null) {
            return this.instances.numAttributes();
        }
        return 0;
    }

    /** Sets the class attribute index of the training data (requires training data to be loaded). */
    public void setClassIndex(int i) {
        this.instances.setClassIndex(i);
    }

    /** @return the class attribute index of the training data (requires training data to be loaded) */
    public int getClassIndex() {
        return this.instances.classIndex();
    }

    /**
     * Builds a new learner trained on a copy of this learner's data with attribute {@code i}
     * removed and attribute {@code i - 1} used as the class. This learner is left unchanged.
     *
     * @param str name of the new learner
     * @param i index of the attribute to delete; must satisfy {@code 1 <= i < getNumberOfFeatures()}
     * @return the new learner, or {@code null} if {@code i} is out of range
     * @throws Exception if training the new classifier fails
     */
    public WekaLearner deleteAttributeAt(String str, int i) throws Exception {
        // i == 0 would make the class index -1 (unset), so training could never succeed.
        if (i < 1 || i >= this.instances.numAttributes()) {
            System.out.println("Attribute not valid.");
            return null;
        }
        // Work on a copy so this learner's data stays consistent with the model trained on it.
        Instances reduced = new Instances(this.instances);
        // Move the class off attribute i first: Weka refuses to delete the class attribute.
        reduced.setClassIndex(i - 1);
        reduced.deleteAttributeAt(i);
        // NOTE(review): assumes getClassifier() returns a fresh classifier; if it returns a shared
        // instance, training here would also retrain this learner's tree — confirm in Algorithm.
        Classifier classifier = this.algorithm.getClassifier();
        classifier.buildClassifier(reduced);
        WekaLearner wekaLearner = new WekaLearner(str, this.options);
        wekaLearner.setInstances(reduced);
        wekaLearner.setClassifier(classifier);
        wekaLearner.setClassValues();
        return wekaLearner;
    }

    private void setInstances(Instances instances) {
        this.instances = instances;
    }

    private void setClassifier(Classifier classifier) {
        this.tree = classifier;
    }

    /** Refreshes {@link #classValues} from this learner's training data. */
    private void setClassValues() {
        this.classValues = setClassValues(this.instances);
    }

    /**
     * Parses the flags for this learner. Recognized flags: {@code -<name>algorithm <alg>},
     * {@code -<name>classifier <file>} (flag and value are blanked in {@code strArr} once consumed),
     * {@code -debug <value>} (skipped) and {@code -serialized}.
     *
     * @param strArr command-line arguments; may be {@code null}
     */
    public void handleArgs(String[] strArr) {
        if (strArr == null) {
            return;
        }
        String algorithmFlag = "-" + this.name + "algorithm";
        String classifierFlag = "-" + this.name + "classifier";
        for (int i = 0; i < strArr.length; i++) {
            String arg = strArr[i];
            if (arg == null) {
                continue;
            }
            boolean hasValue = i + 1 < strArr.length;
            if (arg.equalsIgnoreCase(algorithmFlag)) {
                if (!hasValue) {
                    System.err.println("Missing value for " + arg + ".");
                    continue;
                }
                i++;
                this.algorithm = Algorithm.fromString(strArr[i], strArr);
            } else if (arg.equalsIgnoreCase(classifierFlag)) {
                strArr[i] = "";
                if (!hasValue) {
                    System.err.println("Missing value for " + arg + ".");
                    continue;
                }
                i++;
                this.options.set("classifier", this.name, strArr[i]);
                strArr[i] = "";
            } else if (arg.equals("-debug")) {
                i++; // skip the debug flag's value
            } else if (arg.equalsIgnoreCase("-serialized")) {
                this.serialized = true;
            }
        }
    }

    /**
     * Returns the usage text.
     *
     * <p>NOTE(review): {@link #handleArgs} actually expects these flags prefixed with the learner
     * name ({@code -<name>classifier}) and also accepts {@code -<name>algorithm}.
     */
    public static String printHelp() {
        return "*** WekaLearner help ***\n  -classifier\t<filename of classifier>\n  -serialized\t<use if the classifier is already serialized>\n";
    }

    /**
     * Loads a serialized model when {@code -serialized} was given, otherwise trains from the data
     * file. Failures are reported on stderr rather than thrown.
     */
    public void initialize(String[] strArr) {
        try {
            if (this.serialized) {
                loadSerializedModel();
            } else {
                loadTrainingData(strArr);
            }
        } catch (Exception e) {
            System.err.println("Unable to load classifier " + this.options.get("classifier", this.name) + ".");
            e.printStackTrace();
        }
    }

    /**
     * Reads the data file named by the {@code classifier} option, uses its last attribute as the
     * class, and trains a fresh classifier from {@link #algorithm}.
     *
     * @param strArr unused; kept for interface compatibility
     * @throws Exception if the data cannot be read or training fails
     */
    public void loadTrainingData(String[] strArr) throws Exception {
        Object source = this.options.get("classifier", this.name);
        if (source == null) {
            System.err.println("No classifier file given.");
            return;
        }
        this.instances = null;
        Instances data;
        try {
            data = new ConverterUtils.DataSource((String) source).getDataSet();
        } catch (Exception e) {
            System.err.println("Training instances not read.");
            // Propagate the real cause instead of failing later with a NullPointerException.
            throw e;
        }
        if (data == null) {
            throw new IllegalStateException("Training instances not read from " + source + ".");
        }
        data.setClassIndex(data.numAttributes() - 1);
        this.instances = data;
        this.classValues = setClassValues(data);
        this.tree = this.algorithm.getClassifier();
        this.tree.buildClassifier(data);
    }

    /**
     * Lists the values of the class attribute of {@code instances}, in attribute-value order.
     *
     * @param instances data set whose class index is set
     * @return the class values; empty if the class attribute has no enumerable values (e.g. numeric)
     */
    public ArrayList<String> setClassValues(Instances instances) {
        ArrayList<String> arrayList = new ArrayList<>();
        Enumeration<?> enumerateValues = instances.classAttribute().enumerateValues();
        if (enumerateValues == null) {
            // Weka returns null for attributes without a value list (e.g. numeric).
            return arrayList;
        }
        while (enumerateValues.hasMoreElements()) {
            arrayList.add(enumerateValues.nextElement().toString());
        }
        return arrayList;
    }

    /**
     * Deserializes a classifier from the file named by the {@code classifier} option.
     *
     * <p>SECURITY: Java native deserialization can execute code embedded in the file; only load
     * model files from trusted sources.
     *
     * <p>NOTE(review): this sets only the classifier, not the training data or class values, so the
     * classification methods cannot use the model afterwards.
     *
     * @throws Exception if the file cannot be read or does not contain a {@link Classifier}
     */
    public void loadSerializedModel() throws Exception {
        String path = (String) this.options.get("classifier", this.name);
        if (path == null || path.isEmpty()) {
            System.err.println("No model file given.");
            return;
        }
        try (FileInputStream fileIn = new FileInputStream(path);
             ObjectInputStream objectInputStream = new ObjectInputStream(fileIn)) {
            this.tree = (Classifier) objectInputStream.readObject();
        }
    }

    /**
     * Builds a Weka instance attached to the training data from string feature values. Values that
     * Weka rejects (e.g. unknown nominal values) are left missing.
     */
    private Instance buildInstance(ArrayList<String> values) {
        // One slot more than the features, for the class attribute.
        Instance instance = new Instance(values.size() + 1);
        instance.setDataset(this.instances);
        int limit = Math.min(instance.numAttributes(), values.size());
        for (int i = 0; i < limit; i++) {
            try {
                instance.setValue(i, values.get(i));
            } catch (IllegalArgumentException e) {
                instance.setMissing(i);
            }
        }
        return instance;
    }

    /**
     * Classifies a feature vector.
     *
     * @param arrayList feature values in attribute order, excluding the class
     * @return the predicted label and its probability, or {@code null} if no prediction could be made
     */
    public Result classifyInstance(ArrayList<String> arrayList) {
        if (this.instances == null || this.tree == null) {
            System.err.println("I can't do that, Dave.");
            return null;
        }
        try {
            Instance instance = buildInstance(arrayList);
            double predicted = this.tree.classifyInstance(instance);
            // Weka returns a missing value (NaN) when no prediction is made; don't map it to class 0.
            if (Double.isNaN(predicted)) {
                System.err.println("Problem classifying instance.");
                return null;
            }
            int classIndex = (int) predicted;
            double probability = this.tree.distributionForInstance(instance)[classIndex];
            return new Result(this.classValues.get(classIndex), probability);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Ranks all class probability levels from highest to lowest. Labels that share a probability
     * are merged into one result, joined with {@code "-"}.
     *
     * @param arrayList feature values in attribute order, excluding the class
     * @return results in descending probability order; empty if classification is impossible
     */
    public ArrayList<Result> getLabelsInDescendingOrder(ArrayList<String> arrayList) {
        ArrayList<Result> ranked = new ArrayList<>();
        if (this.instances == null || this.tree == null) {
            System.err.println("I can't do that, Dave.");
            return ranked;
        }
        Instance instance;
        try {
            instance = buildInstance(arrayList);
        } catch (RuntimeException e) {
            e.printStackTrace();
            return ranked;
        }
        try {
            double[] distribution = this.tree.distributionForInstance(instance);
            Result level = findAnyTies(getMaximum(distribution), distribution);
            while (level != null) {
                ranked.add(level);
                Result next = getMaximumLessThan(distribution, level.probability);
                // Merge ties at every rank, not just the top one.
                level = next == null ? null : findAnyTies(next, distribution);
            }
        } catch (Exception e) {
            System.err.println("Problem classifying instance.");
        }
        return ranked;
    }

    /**
     * Returns the most probable label(s) for a feature vector as a single result, with tied labels
     * joined by {@code "-"}.
     *
     * @param arrayList feature values in attribute order, excluding the class
     * @return a list with at most one result; empty if classification is impossible
     */
    public ArrayList<Result> getTiedLabels(ArrayList<String> arrayList) {
        ArrayList<Result> tied = new ArrayList<>();
        if (this.instances == null || this.tree == null) {
            System.err.println("I can't do that, Dave.");
            return tied;
        }
        Instance instance;
        try {
            instance = buildInstance(arrayList);
        } catch (RuntimeException e) {
            e.printStackTrace();
            return tied;
        }
        try {
            double[] distribution = this.tree.distributionForInstance(instance);
            tied.add(findAnyTies(getMaximum(distribution), distribution));
        } catch (Exception e) {
            System.err.println("Problem classifying instance.");
        }
        return tied;
    }

    /**
     * Appends to {@code result.label} (mutating {@code result}) every other class label whose
     * probability exactly equals {@code result.probability}, separated by {@code "-"}.
     *
     * @return the same {@code result} object
     */
    public Result findAnyTies(Result result, double[] dArr) {
        for (int i = 0; i < dArr.length; i++) {
            if (result.probability == dArr[i]) {
                String str = this.classValues.get(i);
                if (!result.label.equals(str)) {
                    result.label += "-" + str;
                }
            }
        }
        return result;
    }

    /**
     * Returns the first label with the highest probability in {@code dArr}.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code dArr} is empty
     */
    public Result getMaximum(double[] dArr) {
        int bestIndex = 0;
        for (int i = 1; i < dArr.length; i++) {
            if (dArr[i] > dArr[bestIndex]) {
                bestIndex = i;
            }
        }
        return new Result(this.classValues.get(bestIndex), dArr[bestIndex]);
    }

    /**
     * Returns the first label with the highest probability strictly below {@code d}.
     *
     * @return the result, or {@code null} if no entry is below {@code d}
     */
    public Result getMaximumLessThan(double[] dArr, double d) {
        // Probabilities are never negative, so -1 acts as "nothing found yet".
        double best = -1.0d;
        int bestIndex = -1;
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] > best && dArr[i] < d) {
                best = dArr[i];
                bestIndex = i;
            }
        }
        return bestIndex < 0 ? null : new Result(this.classValues.get(bestIndex), best);
    }
}
