package edu.cornell.cs.nlp.spf.learn.situated.perceptron;

import edu.cornell.cs.nlp.spf.base.hashvector.HashVectorFactory;
import edu.cornell.cs.nlp.spf.base.hashvector.IHashVector;
import edu.cornell.cs.nlp.spf.ccg.categories.ICategoryServices;
import edu.cornell.cs.nlp.spf.data.ILabeledDataItem;
import edu.cornell.cs.nlp.spf.data.collection.IDataCollection;
import edu.cornell.cs.nlp.spf.data.sentence.Sentence;
import edu.cornell.cs.nlp.spf.data.situated.ISituatedDataItem;
import edu.cornell.cs.nlp.spf.data.utils.IValidator;
import edu.cornell.cs.nlp.spf.explat.IResourceRepository;
import edu.cornell.cs.nlp.spf.explat.ParameterizedExperiment;
import edu.cornell.cs.nlp.spf.explat.resources.IResourceObjectCreator;
import edu.cornell.cs.nlp.spf.explat.resources.usage.ResourceUsage;
import edu.cornell.cs.nlp.spf.genlex.ccg.ILexiconGenerator;
import edu.cornell.cs.nlp.spf.learn.situated.AbstractSituatedLearner;
import edu.cornell.cs.nlp.spf.parser.IParserOutput;
import edu.cornell.cs.nlp.spf.parser.ccg.model.IDataItemModel;
import edu.cornell.cs.nlp.spf.parser.ccg.rules.skolem.SkolemIDRule;
import edu.cornell.cs.nlp.spf.parser.joint.IJointDerivation;
import edu.cornell.cs.nlp.spf.parser.joint.IJointOutput;
import edu.cornell.cs.nlp.spf.parser.joint.IJointOutputLogger;
import edu.cornell.cs.nlp.spf.parser.joint.IJointParser;
import edu.cornell.cs.nlp.spf.parser.joint.model.IJointDataItemModel;
import edu.cornell.cs.nlp.spf.parser.joint.model.IJointModelImmutable;
import edu.cornell.cs.nlp.spf.parser.joint.model.JointModel;
import edu.cornell.cs.nlp.utils.composites.Pair;
import edu.cornell.cs.nlp.utils.log.ILogger;
import edu.cornell.cs.nlp.utils.log.LoggerFactory;
import gnu.trove.impl.PrimeFinder;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/cornell/cs/nlp/spf/learn/situated/perceptron/SituatedValidationPerceptron.class */
/**
 * Validation-sensitive perceptron for situated learning of models with situated (joint) inference
 * (cite: Artzi and Zettlemoyer 2013).
 * <p>
 * For every training item, the joint parser output is split into derivations whose result passes
 * the configured {@link IValidator} ("valid") and derivations whose result does not ("invalid").
 * A valid/invalid pair violates the margin when the model score of their feature difference is
 * below {@code margin} times the L1 norm of that difference. The update adds the mean features of
 * the violating valid derivations and subtracts the mean features of the violating invalid ones.
 *
 * @param <SAMPLE>
 *            situated data item over a {@link Sentence}; the sample of each training item.
 * @param <MR>
 *            meaning representation type (presumably the logical form type -- confirm).
 * @param <ESTEP>
 *            joint inference step type (NOTE(review): exact meaning defined by the joint parser).
 * @param <ERESULT>
 *            result type of a joint derivation; this is what the validator checks.
 * @param <DI>
 *            labeled training data item whose sample is a {@code SAMPLE}.
 */
public class SituatedValidationPerceptron<SAMPLE extends ISituatedDataItem<Sentence, ?>, MR, ESTEP, ERESULT, DI extends ILabeledDataItem<SAMPLE, ?>> extends AbstractSituatedLearner<SAMPLE, MR, ESTEP, ERESULT, DI> {
    public static final ILogger LOG = LoggerFactory.create(SituatedValidationPerceptron.class);

    /**
     * When true, only the highest-Viterbi-score valid derivations (ties kept) are positives.
     * When false, every valid derivation is a positive.
     */
    private final boolean hardUpdates;
    /** Margin multiplier applied to the L1 norm of a feature difference. */
    private final double margin;
    /** Decides whether a derivation result is valid for a data item. */
    private final IValidator<DI, ERESULT> validator;

    /**
     * Builder for {@link SituatedValidationPerceptron}. Training data, parser and validator are
     * mandatory; every other setting has a default.
     */
    public static class Builder<SAMPLE extends ISituatedDataItem<Sentence, ?>, MR, ESTEP, ERESULT, DI extends ILabeledDataItem<SAMPLE, ?>> {
        private final IJointParser<SAMPLE, MR, ESTEP, ERESULT> parser;
        private final IDataCollection<DI> trainingData;
        private final IValidator<DI, ERESULT> validator;
        private ICategoryServices<MR> categoryServices = null;
        private ILexiconGenerator<DI, MR, IJointModelImmutable<SAMPLE, MR, ESTEP>> genlex = null;
        private boolean hardUpdates = false;
        private int lexiconGenerationBeamSize = 20;
        private double margin = 1.0d;
        /** Default: {@link Integer#MAX_VALUE}, i.e. effectively no sentence length limit. */
        private int maxSentenceLength = Integer.MAX_VALUE;
        private int numIterations = 4;
        /** Default: a logger that ignores everything. */
        private IJointOutputLogger<MR, ESTEP, ERESULT> parserOutputLogger = new IJointOutputLogger<MR, ESTEP, ERESULT>() {
            private static final long serialVersionUID = 4342845964338126692L;

            @Override
            public void log(IJointOutput<MR, ERESULT> output, IJointDataItemModel<MR, ESTEP> dataItemModel, String tag) {
                // Intentionally a no-op.
            }

            @Override
            public void log(IParserOutput<MR> output, IDataItemModel<MR> dataItemModel, String tag) {
                // Intentionally a no-op.
            }
        };
        private Map<DI, Pair<MR, ERESULT>> trainingDataDebug = new HashMap<>();

        /**
         * @param trainingData
         *            training data collection.
         * @param parser
         *            joint parser used for model inference.
         * @param validator
         *            validation function for derivation results.
         */
        public Builder(IDataCollection<DI> trainingData, IJointParser<SAMPLE, MR, ESTEP, ERESULT> parser, IValidator<DI, ERESULT> validator) {
            this.trainingData = trainingData;
            this.parser = parser;
            this.validator = validator;
        }

        /** Creates the learner from the current builder settings. */
        public SituatedValidationPerceptron<SAMPLE, MR, ESTEP, ERESULT, DI> build() {
            return new SituatedValidationPerceptron<>(numIterations, margin, trainingData, trainingDataDebug, maxSentenceLength, lexiconGenerationBeamSize, parser, hardUpdates, parserOutputLogger, validator, categoryServices, genlex);
        }

        /** Sets the lexicon generation procedure and the category services it requires. */
        public Builder<SAMPLE, MR, ESTEP, ERESULT, DI> setGenlex(ILexiconGenerator<DI, MR, IJointModelImmutable<SAMPLE, MR, ESTEP>> genlex, ICategoryServices<MR> categoryServices) {
            this.genlex = genlex;
            this.categoryServices = categoryServices;
            return this;
        }

        /** Enables or disables hard updates. */
        public Builder<SAMPLE, MR, ESTEP, ERESULT, DI> setHardUpdates(boolean hardUpdates) {
            this.hardUpdates = hardUpdates;
            return this;
        }

        /** Sets the beam size used for lexicon generation inference. */
        public Builder<SAMPLE, MR, ESTEP, ERESULT, DI> setLexiconGenerationBeamSize(int lexiconGenerationBeamSize) {
            this.lexiconGenerationBeamSize = lexiconGenerationBeamSize;
            return this;
        }

        /** Sets the update margin. */
        public Builder<SAMPLE, MR, ESTEP, ERESULT, DI> setMargin(double margin) {
            this.margin = margin;
            return this;
        }

        /** Sets the maximum sentence length to process. */
        public Builder<SAMPLE, MR, ESTEP, ERESULT, DI> setMaxSentenceLength(int maxSentenceLength) {
            this.maxSentenceLength = maxSentenceLength;
            return this;
        }

        /** Sets the number of training iterations. */
        public Builder<SAMPLE, MR, ESTEP, ERESULT, DI> setNumTrainingIterations(int numIterations) {
            this.numIterations = numIterations;
            return this;
        }

        /** Sets the parser output logger used for detailed debug logging. */
        public Builder<SAMPLE, MR, ESTEP, ERESULT, DI> setParserOutputLogger(IJointOutputLogger<MR, ESTEP, ERESULT> parserOutputLogger) {
            this.parserOutputLogger = parserOutputLogger;
            return this;
        }

        /** Sets the debug map from training items to their gold (MR, result) pairs. */
        public Builder<SAMPLE, MR, ESTEP, ERESULT, DI> setTrainingDataDebug(Map<DI, Pair<MR, ERESULT>> trainingDataDebug) {
            this.trainingDataDebug = trainingDataDebug;
            return this;
        }
    }

    /**
     * Experiment-platform creator. It reads the learner configuration from experiment parameters
     * and resources.
     */
    public static class Creator<SAMPLE extends ISituatedDataItem<Sentence, ?>, MR, ESTEP, ERESULT, DI extends ILabeledDataItem<SAMPLE, ?>> implements IResourceObjectCreator<SituatedValidationPerceptron<SAMPLE, MR, ESTEP, ERESULT, DI>> {
        /**
         * Parameter key for the GENLEX resource. This is a plain config key; it must not be tied
         * to the lexical-origin label of the lexicon generator.
         */
        private static final String GENLEX_PARAM = "genlex";
        /** Usage type descriptor for parameters that name a resource id. */
        private static final String RESOURCE_ID_TYPE = "id";

        private final String name;

        public Creator() {
            this("learner.weakp.valid.situated");
        }

        public Creator(String name) {
            this.name = name;
        }

        // Repository lookups are untyped, so the casts below are unchecked by construction.
        @SuppressWarnings("unchecked")
        @Override
        public SituatedValidationPerceptron<SAMPLE, MR, ESTEP, ERESULT, DI> create(ParameterizedExperiment.Parameters parameters, IResourceRepository repository) {
            final Builder<SAMPLE, MR, ESTEP, ERESULT, DI> builder = new Builder<>(
                    (IDataCollection<DI>) repository.get(parameters.get("data")),
                    (IJointParser<SAMPLE, MR, ESTEP, ERESULT>) repository.get(ParameterizedExperiment.PARSER_RESOURCE),
                    (IValidator<DI, ERESULT>) repository.get(parameters.get("validator")));

            if ("true".equals(parameters.get("hard"))) {
                builder.setHardUpdates(true);
            }
            if (parameters.contains("parseLogger")) {
                builder.setParserOutputLogger((IJointOutputLogger<MR, ESTEP, ERESULT>) repository.get(parameters.get("parseLogger")));
            }
            if (parameters.contains(GENLEX_PARAM)) {
                builder.setGenlex(
                        (ILexiconGenerator<DI, MR, IJointModelImmutable<SAMPLE, MR, ESTEP>>) repository.get(parameters.get(GENLEX_PARAM)),
                        (ICategoryServices<MR>) repository.get(ParameterizedExperiment.CATEGORY_SERVICES_RESOURCE));
            }
            if (parameters.contains("genlexbeam")) {
                builder.setLexiconGenerationBeamSize(Integer.parseInt(parameters.get("genlexbeam")));
            }
            if (parameters.contains("margin")) {
                builder.setMargin(Double.parseDouble(parameters.get("margin")));
            }
            if (parameters.contains("maxSentenceLength")) {
                builder.setMaxSentenceLength(Integer.parseInt(parameters.get("maxSentenceLength")));
            }
            if (parameters.contains("iter")) {
                builder.setNumTrainingIterations(Integer.parseInt(parameters.get("iter")));
            }
            return builder.build();
        }

        @Override
        public String type() {
            return name;
        }

        @Override
        public ResourceUsage usage() {
            return new ResourceUsage.Builder(type(), SituatedValidationPerceptron.class)
                    .setDescription("Validation sensitive perceptron for situated learning of models with situated inference (cite: Artzi and Zettlemoyer 2013)")
                    .addParam("data", RESOURCE_ID_TYPE, "Training data")
                    .addParam("hard", "boolean", "Use hard updates (i.e., only use max scoring valid parses/evaluation as positive samples). Options: true, false. Default: false")
                    .addParam("parseLogger", RESOURCE_ID_TYPE, "Parse logger for debug detailed logging of parses")
                    .addParam(GENLEX_PARAM, "ILexiconGenerator", "GENLEX procedure")
                    .addParam("genlexbeam", "int", "Beam to use for GENLEX inference (parsing).")
                    .addParam("margin", "double", "Margin to use for updates. Updates will be done when this margin is violated.")
                    .addParam("maxSentenceLength", "int", "Max sentence length to process")
                    .addParam("iter", "int", "Number of training iterations")
                    .addParam("validator", "IValidator", "Validation function")
                    .build();
        }
    }

    private SituatedValidationPerceptron(int numIterations, double margin, IDataCollection<DI> trainingData, Map<DI, Pair<MR, ERESULT>> trainingDataDebug, int maxSentenceLength, int lexiconGenerationBeamSize, IJointParser<SAMPLE, MR, ESTEP, ERESULT> parser, boolean hardUpdates, IJointOutputLogger<MR, ESTEP, ERESULT> parserOutputLogger, IValidator<DI, ERESULT> validator, ICategoryServices<MR> categoryServices, ILexiconGenerator<DI, MR, IJointModelImmutable<SAMPLE, MR, ESTEP>> genlex) {
        super(numIterations, trainingData, trainingDataDebug, maxSentenceLength, lexiconGenerationBeamSize, parser, parserOutputLogger, categoryServices, genlex);
        this.margin = margin;
        this.hardUpdates = hardUpdates;
        this.validator = validator;
        LOG.info("Init SituatedValidationSensitivePerceptron: numIterations=%d, margin=%f, trainingData.size()=%d, trainingDataDebug.size()=%d, maxSentenceLength=%d ...", numIterations, margin, trainingData.size(), trainingDataDebug.size(), maxSentenceLength);
        LOG.info("Init SituatedValidationSensitivePerceptron: ... lexiconGenerationBeamSize=%d", lexiconGenerationBeamSize);
    }

    /**
     * Builds the parameter update: the mean features of {@code violatingValid} minus the mean
     * features of {@code violatingInvalid}, with noise dropped.
     *
     * @throws IllegalStateException
     *             if the model rejects the resulting vector as a weight vector.
     */
    private IHashVector constructUpdate(List<IJointDerivation<MR, ERESULT>> violatingValid, List<IJointDerivation<MR, ERESULT>> violatingInvalid, JointModel<SAMPLE, MR, ESTEP> model) {
        final IHashVector update = HashVectorFactory.create();

        final double positiveWeight = 1.0 / violatingValid.size();
        for (final IJointDerivation<MR, ERESULT> derivation : violatingValid) {
            derivation.getMeanMaxFeatures().addTimesInto(positiveWeight, update);
        }

        final double negativeWeight = -1.0 / violatingInvalid.size();
        for (final IJointDerivation<MR, ERESULT> derivation : violatingInvalid) {
            derivation.getMeanMaxFeatures().addTimesInto(negativeWeight, update);
        }

        update.dropNoise();
        if (!model.isValidWeightVector(update)) {
            throw new IllegalStateException("invalid update: " + update);
        }
        return update;
    }

    /**
     * Splits derivations into valid and invalid lists according to {@link #validate}.
     * <p>
     * With hard updates, only the valid derivations that share the maximal Viterbi score are kept.
     *
     * @return pair of (valid, invalid) derivation lists.
     */
    private Pair<List<IJointDerivation<MR, ERESULT>>, List<IJointDerivation<MR, ERESULT>>> createValidInvalidSets(DI dataItem, Collection<? extends IJointDerivation<MR, ERESULT>> derivations) {
        final List<IJointDerivation<MR, ERESULT>> valid = new ArrayList<>();
        final List<IJointDerivation<MR, ERESULT>> invalid = new ArrayList<>();
        // NEGATIVE_INFINITY, not -MAX_VALUE: a valid derivation scored -Infinity must still be
        // accepted (via the tie branch) instead of being silently dropped.
        double maxValidScore = Double.NEGATIVE_INFINITY;
        for (final IJointDerivation<MR, ERESULT> derivation : derivations) {
            if (!validate(dataItem, derivation.getResult())) {
                invalid.add(derivation);
            } else if (!hardUpdates) {
                valid.add(derivation);
            } else {
                final double score = derivation.getViterbiScore();
                if (score > maxValidScore) {
                    maxValidScore = score;
                    valid.clear();
                    valid.add(derivation);
                } else if (score == maxValidScore) {
                    valid.add(derivation);
                }
            }
        }
        return Pair.of(valid, invalid);
    }

    /**
     * Collects the valid and invalid derivations that take part in at least one margin-violating
     * valid/invalid pair.
     * <p>
     * A pair violates the margin when {@code score(valid - invalid) < margin * |valid - invalid|_1},
     * with both sides computed on the mean max features. A pair is evaluated only when at least
     * one of its members has not been marked as violating yet.
     *
     * @return pair of (violating valid, violating invalid) derivation lists.
     */
    private Pair<List<IJointDerivation<MR, ERESULT>>, List<IJointDerivation<MR, ERESULT>>> marginViolatingSets(JointModel<SAMPLE, MR, ESTEP> model, List<IJointDerivation<MR, ERESULT>> validDerivations, List<IJointDerivation<MR, ERESULT>> invalidDerivations) {
        final List<IJointDerivation<MR, ERESULT>> violatingValid = new ArrayList<>();
        final List<IJointDerivation<MR, ERESULT>> violatingInvalid = new ArrayList<>();
        final boolean[] validViolates = new boolean[validDerivations.size()];
        final boolean[] invalidViolates = new boolean[invalidDerivations.size()];

        int validIndex = 0;
        for (final IJointDerivation<MR, ERESULT> validDerivation : validDerivations) {
            int invalidIndex = 0;
            for (final IJointDerivation<MR, ERESULT> invalidDerivation : invalidDerivations) {
                // Skip the (relatively costly) delta computation when both are already marked.
                if (!validViolates[validIndex] || !invalidViolates[invalidIndex]) {
                    final IHashVector delta = validDerivation.getMeanMaxFeatures().addTimes(-1.0d, invalidDerivation.getMeanMaxFeatures());
                    if (model.score(delta) < margin * delta.l1Norm()) {
                        if (!validViolates[validIndex]) {
                            violatingValid.add(validDerivation);
                            validViolates[validIndex] = true;
                        }
                        if (!invalidViolates[invalidIndex]) {
                            violatingInvalid.add(invalidDerivation);
                            invalidViolates[invalidIndex] = true;
                        }
                    }
                }
                ++invalidIndex;
            }
            ++validIndex;
        }
        return Pair.of(violatingValid, violatingInvalid);
    }

    /**
     * Parses the item's sample, splits the derivations into valid and invalid ones, and adds the
     * margin-violation update to the model's theta.
     * <p>
     * Nothing is updated when:
     * <ul>
     * <li>there are no derivations;</li>
     * <li>either the valid or the invalid set is empty;</li>
     * <li>no valid/invalid pair violates the margin.</li>
     * </ul>
     *
     * @param itemCounter
     *            used for the parse-logger tag and as the first argument to
     *            {@code appendSampleStat} (presumably the item index -- confirm in the base class).
     * @param epochNumber
     *            second argument to {@code appendSampleStat} and to {@code count} (presumably the
     *            training epoch -- confirm in the base class).
     */
    @Override
    protected void parameterUpdate(DI dataItem, IJointDataItemModel<MR, ESTEP> dataItemModel, JointModel<SAMPLE, MR, ESTEP> model, int itemCounter, int epochNumber) {
        final IJointOutput<MR, ERESULT> output = parser.parse(dataItem.getSample(), dataItemModel);
        stats.mean("model parse", output.getInferenceTime() / 1000.0d, "sec");
        parserOutputLogger.log(output, dataItemModel, String.format("%d-update", itemCounter));

        final List<? extends IJointDerivation<MR, ERESULT>> derivations = output.getDerivations();
        final List<? extends IJointDerivation<MR, ERESULT>> maxDerivations = output.getMaxDerivations();
        if (derivations.isEmpty()) {
            LOG.info("No parses for: %s", dataItem);
            LOG.info("Skipping parameter update");
            return;
        }
        LOG.info("Created %d model parses for training sample", derivations.size());
        LOG.info("Model parsing time: %.4fsec", output.getInferenceTime() / 1000.0d);
        LOG.info("Output is %s", output.isExact() ? "exact" : "approximate");

        final Pair<List<IJointDerivation<MR, ERESULT>>, List<IJointDerivation<MR, ERESULT>>> validInvalid = createValidInvalidSets(dataItem, derivations);
        final List<IJointDerivation<MR, ERESULT>> valid = validInvalid.first();
        final List<IJointDerivation<MR, ERESULT>> invalid = validInvalid.second();
        LOG.info("%d valid parses, %d invalid parses", valid.size(), invalid.size());
        LOG.info("Valid parses:");
        for (final IJointDerivation<MR, ERESULT> derivation : valid) {
            logParse(dataItem, derivation, true, true, dataItemModel);
        }

        // "G": a unique max derivation matches the debug gold; "V": at least one valid derivation.
        if (maxDerivations.size() == 1 && isGoldDebugCorrect(dataItem, maxDerivations.get(0).getResult())) {
            stats.appendSampleStat(itemCounter, epochNumber, "G");
        } else if (!valid.isEmpty()) {
            stats.appendSampleStat(itemCounter, epochNumber, "V");
        }
        if (!valid.isEmpty()) {
            stats.count("valid", epochNumber);
        }

        if (valid.isEmpty() || invalid.isEmpty()) {
            LOG.info("No valid/invalid parses -- skipping");
            return;
        }

        final Pair<List<IJointDerivation<MR, ERESULT>>, List<IJointDerivation<MR, ERESULT>>> violating = marginViolatingSets(model, valid, invalid);
        final List<IJointDerivation<MR, ERESULT>> violatingValid = violating.first();
        final List<IJointDerivation<MR, ERESULT>> violatingInvalid = violating.second();
        LOG.info("%d violating valid parses, %d violating invalid parses", violatingValid.size(), violatingInvalid.size());
        // Violations are found pairwise, so an empty valid set implies an empty invalid set.
        if (violatingValid.isEmpty()) {
            LOG.info("There are no violating valid/invalid parses -- skipping");
            return;
        }
        LOG.info("Violating valid parses: ");
        for (final IJointDerivation<MR, ERESULT> derivation : violatingValid) {
            logParse(dataItem, derivation, true, true, dataItemModel);
        }
        LOG.info("Violating invalid parses: ");
        for (final IJointDerivation<MR, ERESULT> derivation : violatingInvalid) {
            logParse(dataItem, derivation, false, true, dataItemModel);
        }

        final IHashVector update = constructUpdate(violatingValid, violatingInvalid, model);
        LOG.info("Update: %s", update);
        update.addTimesInto(1.0d, model.getTheta());
        stats.appendSampleStat(itemCounter, epochNumber, "U");
        stats.count("update", epochNumber);
    }

    /** Delegates to the configured {@link IValidator}. */
    @Override
    protected boolean validate(DI dataItem, ERESULT result) {
        return validator.isValid(dataItem, result);
    }
}
