package edu.cornell.cs.nlp.spf.learn.validation.perceptron;

import edu.cornell.cs.nlp.spf.base.hashvector.HashVectorFactory;
import edu.cornell.cs.nlp.spf.base.hashvector.IHashVector;
import edu.cornell.cs.nlp.spf.ccg.categories.ICategoryServices;
import edu.cornell.cs.nlp.spf.ccg.lexicon.ILexiconImmutable;
import edu.cornell.cs.nlp.spf.data.IDataItem;
import edu.cornell.cs.nlp.spf.data.ILabeledDataItem;
import edu.cornell.cs.nlp.spf.data.collection.IDataCollection;
import edu.cornell.cs.nlp.spf.data.utils.IValidator;
import edu.cornell.cs.nlp.spf.explat.IResourceRepository;
import edu.cornell.cs.nlp.spf.explat.ParameterizedExperiment;
import edu.cornell.cs.nlp.spf.explat.resources.IResourceObjectCreator;
import edu.cornell.cs.nlp.spf.explat.resources.usage.ResourceUsage;
import edu.cornell.cs.nlp.spf.genlex.ccg.ILexiconGenerator;
import edu.cornell.cs.nlp.spf.learn.validation.AbstractLearner;
import edu.cornell.cs.nlp.spf.parser.IDerivation;
import edu.cornell.cs.nlp.spf.parser.IOutputLogger;
import edu.cornell.cs.nlp.spf.parser.IParser;
import edu.cornell.cs.nlp.spf.parser.IParserOutput;
import edu.cornell.cs.nlp.spf.parser.ParsingOp;
import edu.cornell.cs.nlp.spf.parser.ccg.model.IDataItemModel;
import edu.cornell.cs.nlp.spf.parser.ccg.model.IModelImmutable;
import edu.cornell.cs.nlp.spf.parser.ccg.model.Model;
import edu.cornell.cs.nlp.spf.parser.ccg.rules.skolem.SkolemIDRule;
import edu.cornell.cs.nlp.spf.parser.filter.IParsingFilterFactory;
import edu.cornell.cs.nlp.spf.parser.filter.StubFilterFactory;
import edu.cornell.cs.nlp.utils.composites.Pair;
import edu.cornell.cs.nlp.utils.filter.IFilter;
import edu.cornell.cs.nlp.utils.log.ILogger;
import edu.cornell.cs.nlp.utils.log.LoggerFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;

/* loaded from: input_file:edu/cornell/cs/nlp/spf/learn/validation/perceptron/ValidationPerceptron.class */
/**
 * Validation-based perceptron learner.
 *
 * <p>For every training item, the derivations of the two parser outputs passed
 * to {@link #parameterUpdate} are split into valid and invalid sets using an
 * {@link IValidator}. Every (valid, invalid) pair whose model-score difference
 * is below {@code margin} times the L1 norm of their feature-vector difference
 * marks both derivations as margin-violating. The model weights are then
 * updated by adding the mean feature vector of the violating valid derivations
 * minus the mean feature vector of the violating invalid derivations.
 *
 * @param <SAMPLE> type of the sample that is parsed
 * @param <DI> type of the labeled training item wrapping a {@code SAMPLE}
 * @param <MR> meaning representation type
 */
public class ValidationPerceptron<SAMPLE extends IDataItem<?>, DI extends ILabeledDataItem<SAMPLE, ?>, MR>
        extends AbstractLearner<SAMPLE, DI, IParserOutput<MR>, MR> {
    public static final ILogger LOG = LoggerFactory.create(ValidationPerceptron.class);

    /** If true, only the highest-scoring valid derivations are used as positive samples. */
    private final boolean hardUpdates;
    /** Margin multiplier used to decide whether a (valid, invalid) pair is violating. */
    private final double margin;
    private final IParser<SAMPLE, MR> parser;
    private final IValidator<DI, MR> validator;

    /**
     * Fluent builder for {@link ValidationPerceptron}.
     *
     * <p>Defaults: 4 training iterations, GENLEX beam size 20, margin 1.0,
     * soft updates, no GENLEX procedure, a parser-output logger that does
     * nothing, a {@link StubFilterFactory} parsing-filter factory, a processing
     * filter that accepts every item, and an empty debug map.
     */
    public static class Builder<SAMPLE extends IDataItem<?>, DI extends ILabeledDataItem<SAMPLE, ?>, MR> {
        private final IParser<SAMPLE, MR> parser;
        private final IDataCollection<DI> trainingData;
        private final IValidator<DI, MR> validator;
        private ICategoryServices<MR> categoryServices = null;
        private boolean conflateGenlexAndPrunedParses = false;
        private boolean errorDriven = false;
        private ILexiconGenerator<DI, MR, IModelImmutable<SAMPLE, MR>> genlex = null;
        private boolean hardUpdates = false;
        private int lexiconGenerationBeamSize = 20;
        private double margin = 1.0d;
        private int numIterations = 4;
        /** Default parser-output logger: discards everything. */
        private IOutputLogger<MR> parserOutputLogger = new IOutputLogger<MR>() {
            private static final long serialVersionUID = 6377014574873324695L;

            @Override
            public void log(IParserOutput<MR> output, IDataItemModel<MR> dataItemModel, String text) {
                // Intentionally empty: no parser-output logging unless configured.
            }
        };
        private IParsingFilterFactory<DI, MR> parsingFilterFactory = new StubFilterFactory();
        /** Default processing filter: every data item is processed. */
        private IFilter<DI> processingFilter = dataItem -> true;
        private Map<DI, MR> trainingDataDebug = new HashMap<>();

        /**
         * @param trainingData training items to learn from
         * @param parser parser used for all inference
         * @param validator decides whether a meaning representation is valid for an item
         */
        public Builder(IDataCollection<DI> trainingData, IParser<SAMPLE, MR> parser,
                IValidator<DI, MR> validator) {
            this.trainingData = trainingData;
            this.parser = parser;
            this.validator = validator;
        }

        /** Creates the learner from the current builder settings. */
        public ValidationPerceptron<SAMPLE, DI, MR> build() {
            return new ValidationPerceptron<>(this.numIterations, this.trainingData, this.trainingDataDebug,
                    this.lexiconGenerationBeamSize, this.parser, this.parserOutputLogger,
                    this.conflateGenlexAndPrunedParses, this.errorDriven, this.categoryServices, this.genlex,
                    this.margin, this.hardUpdates, this.validator, this.processingFilter,
                    this.parsingFilterFactory);
        }

        public Builder<SAMPLE, DI, MR> setConflateGenlexAndPrunedParses(boolean conflate) {
            this.conflateGenlexAndPrunedParses = conflate;
            return this;
        }

        public Builder<SAMPLE, DI, MR> setErrorDriven(boolean errorDriven) {
            this.errorDriven = errorDriven;
            return this;
        }

        /** Sets the GENLEX procedure together with the category services it requires. */
        public Builder<SAMPLE, DI, MR> setGenlex(ILexiconGenerator<DI, MR, IModelImmutable<SAMPLE, MR>> genlex,
                ICategoryServices<MR> categoryServices) {
            this.genlex = genlex;
            this.categoryServices = categoryServices;
            return this;
        }

        public Builder<SAMPLE, DI, MR> setHardUpdates(boolean hardUpdates) {
            this.hardUpdates = hardUpdates;
            return this;
        }

        public Builder<SAMPLE, DI, MR> setLexiconGenerationBeamSize(int beamSize) {
            this.lexiconGenerationBeamSize = beamSize;
            return this;
        }

        public Builder<SAMPLE, DI, MR> setMargin(double margin) {
            this.margin = margin;
            return this;
        }

        public Builder<SAMPLE, DI, MR> setNumTrainingIterations(int numIterations) {
            this.numIterations = numIterations;
            return this;
        }

        public Builder<SAMPLE, DI, MR> setParserOutputLogger(IOutputLogger<MR> parserOutputLogger) {
            this.parserOutputLogger = parserOutputLogger;
            return this;
        }

        public Builder<SAMPLE, DI, MR> setParsingFilterFactory(IParsingFilterFactory<DI, MR> parsingFilterFactory) {
            this.parsingFilterFactory = parsingFilterFactory;
            return this;
        }

        public Builder<SAMPLE, DI, MR> setProcessingFilter(IFilter<DI> processingFilter) {
            this.processingFilter = processingFilter;
            return this;
        }

        public Builder<SAMPLE, DI, MR> setTrainingDataDebug(Map<DI, MR> trainingDataDebug) {
            this.trainingDataDebug = trainingDataDebug;
            return this;
        }
    }

    /**
     * Experiment-resource creator for {@link ValidationPerceptron}. The default
     * resource type is {@code "learner.validation.perceptron"}.
     *
     * <p>NOTE(review): this file was decompiled, and some parameter names and
     * type labels below appear as unrelated constants, for example
     * {@code ILexiconGenerator.GENLEX_LEXICAL_ORIGIN} and
     * {@code SkolemIDRule.RULE_LABEL}. They are kept as-is; verify the intended
     * literals against the original source.
     */
    public static class Creator<SAMPLE extends IDataItem<?>, DI extends ILabeledDataItem<SAMPLE, ?>, MR>
            implements IResourceObjectCreator<ValidationPerceptron<SAMPLE, DI, MR>> {
        private final String type;

        public Creator() {
            this("learner.validation.perceptron");
        }

        public Creator(String type) {
            this.type = type;
        }

        /**
         * Builds a learner from experiment parameters. Only parameters that are
         * present override the {@link Builder} defaults.
         *
         * @throws NumberFormatException if {@code genlexbeam}, {@code margin} or
         *             {@code iter} is not a valid number
         */
        @SuppressWarnings("unchecked")
        @Override
        public ValidationPerceptron<SAMPLE, DI, MR> create(ParameterizedExperiment.Parameters parameters,
                IResourceRepository repository) {
            final Builder<SAMPLE, DI, MR> builder = new Builder<>(
                    (IDataCollection<DI>) repository.get(parameters.get("data")),
                    (IParser<SAMPLE, MR>) repository.get(ParameterizedExperiment.PARSER_RESOURCE),
                    (IValidator<DI, MR>) repository.get(parameters.get("validator")));
            if ("true".equals(parameters.get("hard"))) {
                builder.setHardUpdates(true);
            }
            if (parameters.contains("parseLogger")) {
                builder.setParserOutputLogger((IOutputLogger<MR>) repository.get(parameters.get("parseLogger")));
            }
            if (parameters.contains(ILexiconGenerator.GENLEX_LEXICAL_ORIGIN)) {
                builder.setGenlex(
                        (ILexiconGenerator<DI, MR, IModelImmutable<SAMPLE, MR>>) repository
                                .get(parameters.get(ILexiconGenerator.GENLEX_LEXICAL_ORIGIN)),
                        (ICategoryServices<MR>) repository.get(ParameterizedExperiment.CATEGORY_SERVICES_RESOURCE));
            }
            if (parameters.contains("genlexbeam")) {
                builder.setLexiconGenerationBeamSize(Integer.parseInt(parameters.get("genlexbeam")));
            }
            if (parameters.contains("conflateParses")) {
                builder.setConflateGenlexAndPrunedParses("true".equals(parameters.get("conflateParses")));
            }
            if (parameters.contains("errorDriven")) {
                builder.setErrorDriven("true".equals(parameters.get("errorDriven")));
            }
            if (parameters.contains("margin")) {
                builder.setMargin(Double.parseDouble(parameters.get("margin")));
            }
            if (parameters.contains("filterFactory")) {
                builder.setParsingFilterFactory(
                        (IParsingFilterFactory<DI, MR>) repository.get(parameters.get("filterFactory")));
            }
            if (parameters.contains("filter")) {
                builder.setProcessingFilter((IFilter<DI>) repository.get(parameters.get("filter")));
            }
            if (parameters.contains("iter")) {
                builder.setNumTrainingIterations(Integer.parseInt(parameters.get("iter")));
            }
            return builder.build();
        }

        @Override
        public String type() {
            return this.type;
        }

        /**
         * Describes the accepted parameters.
         *
         * <p>NOTE(review): {@code "tester"} is advertised here but is not read
         * by {@link #create}.
         */
        @Override
        public ResourceUsage usage() {
            return new ResourceUsage.Builder(type(), ValidationPerceptron.class)
                    .setDescription("Validation-based perceptron")
                    .addParam("data", SkolemIDRule.RULE_LABEL, "Training data")
                    .addParam(ILexiconGenerator.GENLEX_LEXICAL_ORIGIN, "ILexiconGenerator", "GENLEX procedure")
                    .addParam("filterFactory", IParsingFilterFactory.class, "Factory to create parsing filters (optional).")
                    .addParam("hard", "boolean", "Use hard updates (i.e., only use max scoring valid parses/evaluation as positive samples). Options: true, false. Default: false")
                    .addParam("parseLogger", SkolemIDRule.RULE_LABEL, "Parse logger for debug detailed logging of parses")
                    .addParam("tester", "ITester", "Intermediate tester to use between epochs")
                    .addParam("genlexbeam", "int", "Beam to use for GENLEX inference (parsing).")
                    .addParam("margin", "double", "Margin to use for updates. Updates will be done when this margin is violated.")
                    .addParam("filter", "IFilter", "Processing filter")
                    .addParam("iter", "int", "Number of training iterations")
                    .addParam("validator", "IValidator", "Validation function")
                    .addParam("conflateParses", "boolean", "Recyle lexical induction parsing output as pruned parsing output")
                    .addParam("errorDriven", "boolean", "Error driven lexical generation, if the can generate a valid parse, skip lexical induction")
                    .build();
        }
    }

    /** Use {@link Builder} to create instances. */
    private ValidationPerceptron(int numIterations, IDataCollection<DI> trainingData, Map<DI, MR> trainingDataDebug,
            int lexiconGenerationBeamSize, IParser<SAMPLE, MR> parser, IOutputLogger<MR> parserOutputLogger,
            boolean conflateGenlexAndPrunedParses, boolean errorDriven, ICategoryServices<MR> categoryServices,
            ILexiconGenerator<DI, MR, IModelImmutable<SAMPLE, MR>> genlex, double margin, boolean hardUpdates,
            IValidator<DI, MR> validator, IFilter<DI> processingFilter,
            IParsingFilterFactory<DI, MR> parsingFilterFactory) {
        super(numIterations, trainingData, trainingDataDebug, lexiconGenerationBeamSize, parserOutputLogger,
                conflateGenlexAndPrunedParses, errorDriven, categoryServices, genlex, processingFilter,
                parsingFilterFactory);
        this.margin = margin;
        this.parser = parser;
        this.hardUpdates = hardUpdates;
        this.validator = validator;
        LOG.info("Init ValidationPerceptron: numIterations=%d, margin=%f, trainingData.size()=%d, trainingDataDebug.size()=%d  ...",
                numIterations, margin, trainingData.size(), trainingDataDebug.size());
        LOG.info("Init ValidationPerceptron: ... lexiconGenerationBeamSize=%d, hardUpdates=%s",
                lexiconGenerationBeamSize, hardUpdates);
        LOG.info("Init ValidationPerceptron: ... conflateParses=%s, errorDriven=%s",
                conflateGenlexAndPrunedParses, errorDriven);
        LOG.info("Init ValidationPerceptron: ... parsingFilterFactory=%s", parsingFilterFactory);
    }

    /**
     * Builds the update vector: the mean average-max feature vector of
     * {@code positives} minus the mean of {@code negatives}, with
     * {@code dropNoise()} applied.
     *
     * @throws IllegalStateException if the model rejects the resulting vector
     */
    private static <MR, P extends IDerivation<MR>, MODEL extends IModelImmutable<?, MR>> IHashVector constructUpdate(
            List<P> positives, List<P> negatives, MODEL model) {
        final IHashVector update = HashVectorFactory.create();
        for (final P parse : positives) {
            parse.getAverageMaxFeatureVector().addTimesInto(1.0 / positives.size(), update);
        }
        for (final P parse : negatives) {
            parse.getAverageMaxFeatureVector().addTimesInto(-1.0 / negatives.size(), update);
        }
        update.dropNoise();
        if (!model.isValidWeightVector(update)) {
            throw new IllegalStateException("invalid update: " + update);
        }
        return update;
    }

    /**
     * Collects the margin-violating derivations.
     *
     * <p>A pair (valid v, invalid u) violates the margin when
     * {@code score(phi(v) - phi(u)) < margin * |phi(v) - phi(u)|_1}. Both
     * members of a violating pair are added once to their respective result
     * list. Pairs whose members are both already added are skipped, so no delta
     * vector is computed for them.
     *
     * @return pair of (violating valid derivations, violating invalid derivations)
     */
    private static <LF, P extends IDerivation<LF>, MODEL extends IModelImmutable<?, LF>> Pair<List<P>, List<P>> marginViolatingSets(
            MODEL model, double margin, List<P> validParses, List<P> invalidParses) {
        final List<P> violatingValid = new ArrayList<>();
        final List<P> violatingInvalid = new ArrayList<>();
        final boolean[] validAdded = new boolean[validParses.size()];
        final boolean[] invalidAdded = new boolean[invalidParses.size()];
        int validIndex = 0;
        for (final P validParse : validParses) {
            int invalidIndex = 0;
            for (final P invalidParse : invalidParses) {
                if (!validAdded[validIndex] || !invalidAdded[invalidIndex]) {
                    final IHashVector delta = validParse.getAverageMaxFeatureVector()
                            .addTimes(-1.0, invalidParse.getAverageMaxFeatureVector());
                    final boolean violated = model.score(delta) < margin * delta.l1Norm();
                    if (violated) {
                        if (!validAdded[validIndex]) {
                            violatingValid.add(validParse);
                            validAdded[validIndex] = true;
                        }
                        if (!invalidAdded[invalidIndex]) {
                            violatingInvalid.add(invalidParse);
                            invalidAdded[invalidIndex] = true;
                        }
                    }
                }
                invalidIndex++;
            }
            validIndex++;
        }
        return Pair.of(violatingValid, violatingInvalid);
    }

    /**
     * Splits derivations into valid and invalid sets.
     *
     * <p>Invalid derivations come from both outputs. A derivation already taken
     * from {@code modelOutput} is not added again. Valid derivations come only
     * from {@code conditionedOutput}. With hard updates, only the valid
     * derivations tied for the highest score are kept.
     *
     * @param dataItem the training item used for validation
     * @param modelOutput first parser output given to {@link #parameterUpdate}
     *            (presumably the unconstrained model parse; confirm in AbstractLearner)
     * @param conditionedOutput second parser output given to {@link #parameterUpdate}
     * @return pair of (valid derivations, invalid derivations)
     */
    private Pair<List<IDerivation<MR>>, List<IDerivation<MR>>> createValidInvalidSets(DI dataItem,
            IParserOutput<MR> modelOutput, IParserOutput<MR> conditionedOutput) {
        final List<IDerivation<MR>> validParses = new ArrayList<>();
        final List<IDerivation<MR>> invalidParses = new ArrayList<>();
        // Tracks invalid derivations already taken from modelOutput to avoid duplicates.
        final Set<IDerivation<MR>> seenInvalid = new HashSet<>();
        for (final IDerivation<MR> derivation : modelOutput.getAllDerivations()) {
            if (!validate(dataItem, derivation.getSemantics())) {
                invalidParses.add(derivation);
                seenInvalid.add(derivation);
            }
        }

        // NEGATIVE_INFINITY (not -MAX_VALUE) so a valid derivation scoring
        // -Infinity can still be kept under hard updates.
        double bestValidScore = Double.NEGATIVE_INFINITY;
        for (final IDerivation<MR> derivation : conditionedOutput.getAllDerivations()) {
            if (validate(dataItem, derivation.getSemantics())) {
                if (!this.hardUpdates) {
                    validParses.add(derivation);
                } else if (derivation.getScore() > bestValidScore) {
                    bestValidScore = derivation.getScore();
                    validParses.clear();
                    validParses.add(derivation);
                } else if (derivation.getScore() == bestValidScore) {
                    validParses.add(derivation);
                }
            } else if (!seenInvalid.contains(derivation)) {
                invalidParses.add(derivation);
            }
        }
        return Pair.of(validParses, invalidParses);
    }

    /**
     * Performs one validation-based perceptron step for a data item.
     *
     * <p>The step records sample statistics. "G" is recorded when
     * {@code modelOutput} has exactly one best derivation and it matches the
     * debug gold. Otherwise "V" is recorded when any valid derivation exists.
     * The weights ({@code model.getTheta()}) are updated only when both valid
     * and invalid derivations exist and at least one pair violates the margin.
     *
     * @param itemCounter index passed by AbstractLearner, used for per-sample stats
     * @param epochNumber epoch passed by AbstractLearner, used for per-epoch counts
     */
    @Override
    protected void parameterUpdate(DI dataItem, IParserOutput<MR> modelOutput, IParserOutput<MR> conditionedOutput,
            Model<SAMPLE, MR> model, int itemCounter, int epochNumber) {
        final IDataItemModel<MR> dataItemModel = model.createDataItemModel(dataItem.getSample());

        final Pair<List<IDerivation<MR>>, List<IDerivation<MR>>> validAndInvalid = createValidInvalidSets(dataItem,
                modelOutput, conditionedOutput);
        final List<IDerivation<MR>> validParses = validAndInvalid.first();
        final List<IDerivation<MR>> invalidParses = validAndInvalid.second();
        LOG.info("%d valid parses, %d invalid parses", validParses.size(), invalidParses.size());
        LOG.info("Valid parses:");
        for (final IDerivation<MR> parse : validParses) {
            logParse(dataItem, parse, true, true, dataItemModel);
        }

        if (modelOutput.getBestDerivations().size() == 1
                && isGoldDebugCorrect(dataItem, modelOutput.getBestDerivations().get(0).getSemantics())) {
            this.stats.appendSampleStat(itemCounter, epochNumber, "G");
        } else if (!validParses.isEmpty()) {
            this.stats.appendSampleStat(itemCounter, epochNumber, "V");
        }
        if (!validParses.isEmpty()) {
            this.stats.count("Valid", epochNumber);
        }

        if (validParses.isEmpty() || invalidParses.isEmpty()) {
            LOG.info("No valid/invalid parses -- skipping");
            return;
        }

        final Pair<List<IDerivation<MR>>, List<IDerivation<MR>>> violating = marginViolatingSets(model, this.margin,
                validParses, invalidParses);
        final List<IDerivation<MR>> violatingValid = violating.first();
        final List<IDerivation<MR>> violatingInvalid = violating.second();
        LOG.info("%d violating valid parses, %d violating invalid parses", violatingValid.size(),
                violatingInvalid.size());
        // Each violation marks both members of a pair, so an empty valid set
        // implies an empty invalid set.
        if (violatingValid.isEmpty()) {
            LOG.info("There are no violating valid/invalid parses -- skipping");
            return;
        }
        LOG.info("Violating valid parses: ");
        for (final IDerivation<MR> parse : violatingValid) {
            logParse(dataItem, parse, true, true, dataItemModel);
        }
        LOG.info("Violating invalid parses: ");
        for (final IDerivation<MR> parse : violatingInvalid) {
            logParse(dataItem, parse, false, true, dataItemModel);
        }

        final IHashVector update = constructUpdate(violatingValid, violatingInvalid, model);
        LOG.info("Update: %s", update);
        update.addTimesInto(1.0d, model.getTheta());
        this.stats.appendSampleStat(itemCounter, epochNumber, "U");
    }

    /** Parses the item's sample with the given data-item model. */
    @Override
    protected IParserOutput<MR> parse(DI dataItem, IDataItemModel<MR> dataItemModel) {
        return this.parser.parse(dataItem.getSample(), dataItemModel);
    }

    /** Parses the item's sample, restricting parsing operations with {@code filter}. */
    @Override
    protected IParserOutput<MR> parse(DI dataItem, Predicate<ParsingOp<MR>> filter,
            IDataItemModel<MR> dataItemModel) {
        return this.parser.parse(dataItem.getSample(), filter, dataItemModel);
    }

    /**
     * Parses the item's sample with a filter, a temporary lexicon and a beam
     * size. Word skipping is disabled ({@code false}).
     */
    @Override
    protected IParserOutput<MR> parse(DI dataItem, Predicate<ParsingOp<MR>> filter, IDataItemModel<MR> dataItemModel,
            ILexiconImmutable<MR> tempLexicon, Integer beamSize) {
        return this.parser.parse(dataItem.getSample(), filter, dataItemModel, false, tempLexicon, beamSize);
    }

    /** Delegates validity checking to the configured {@link IValidator}. */
    @Override
    protected boolean validate(DI dataItem, MR semantics) {
        return this.validator.isValid(dataItem, semantics);
    }
}
