package edu.cornell.cs.nlp.spf.learn.validation;

import edu.cornell.cs.nlp.spf.ccg.categories.ICategoryServices;
import edu.cornell.cs.nlp.spf.ccg.lexicon.ILexiconImmutable;
import edu.cornell.cs.nlp.spf.ccg.lexicon.LexicalEntry;
import edu.cornell.cs.nlp.spf.data.IDataItem;
import edu.cornell.cs.nlp.spf.data.ILabeledDataItem;
import edu.cornell.cs.nlp.spf.data.collection.IDataCollection;
import edu.cornell.cs.nlp.spf.genlex.ccg.ILexiconGenerator;
import edu.cornell.cs.nlp.spf.genlex.ccg.LexiconGenerationServices;
import edu.cornell.cs.nlp.spf.learn.ILearner;
import edu.cornell.cs.nlp.spf.learn.LearningStats;
import edu.cornell.cs.nlp.spf.parser.IDerivation;
import edu.cornell.cs.nlp.spf.parser.IOutputLogger;
import edu.cornell.cs.nlp.spf.parser.IParserOutput;
import edu.cornell.cs.nlp.spf.parser.ParsingOp;
import edu.cornell.cs.nlp.spf.parser.ccg.IWeightedParseStep;
import edu.cornell.cs.nlp.spf.parser.ccg.model.IDataItemModel;
import edu.cornell.cs.nlp.spf.parser.ccg.model.IModelImmutable;
import edu.cornell.cs.nlp.spf.parser.ccg.model.Model;
import edu.cornell.cs.nlp.spf.parser.filter.IParsingFilterFactory;
import edu.cornell.cs.nlp.utils.collections.CollectionUtils;
import edu.cornell.cs.nlp.utils.filter.IFilter;
import edu.cornell.cs.nlp.utils.log.ILogger;
import edu.cornell.cs.nlp.utils.log.LoggerFactory;
import edu.cornell.cs.nlp.utils.system.MemoryReport;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

/* loaded from: input_file:edu/cornell/cs/nlp/spf/learn/validation/AbstractLearner.class */
/**
 * Base class for online learners that, for every training sample, parse with the current model,
 * optionally induce new lexical entries through a lexicon generator (GENLEX), and then delegate the
 * parameter update to the subclass.
 *
 * <p>Subclasses provide the parsing procedures, the validation of a logical form against a labeled
 * data item, and the parameter update rule.
 *
 * @param <SAMPLE> the unlabeled sample type the model operates on
 * @param <DI> the labeled training data item type
 * @param <PO> the parser output type produced by the subclass' parse methods
 * @param <MR> the meaning representation type
 */
public abstract class AbstractLearner<SAMPLE extends IDataItem<?>, DI extends ILabeledDataItem<SAMPLE, ?>, PO extends IParserOutput<MR>, MR> implements ILearner<SAMPLE, DI, Model<SAMPLE, MR>> {
    public static final ILogger LOG = LoggerFactory.create((Class<?>) AbstractLearner.class);
    /** Stat key: the best-scoring LF equals the gold debug LF. */
    protected static final String GOLD_LF_IS_MAX = "G";
    /** Stat key: the sample had at least one valid parse. */
    protected static final String HAS_VALID_LF = "V";
    /** Stat key: the sample triggered a parameter update. */
    protected static final String TRIGGERED_UPDATE = "U";
    private final ICategoryServices<MR> categoryServices;
    /** When true, the GENLEX parser output is reused as the conditioned output for the update. */
    private final boolean conflateGenlexAndPrunedParses;
    private final int epochs;
    /** When true, samples that already have a valid model parse skip lexicon induction. */
    private final boolean errorDriven;
    /** Lexicon generator; may be null, in which case no lexical induction is performed. */
    private final ILexiconGenerator<DI, MR, IModelImmutable<SAMPLE, MR>> genlex;
    private final Integer lexiconGenerationBeamSize;
    private final IParsingFilterFactory<DI, MR> parsingFilterFactory;
    /** Samples rejected by this filter are skipped entirely. */
    private final IFilter<DI> processingFilter;
    private final IDataCollection<DI> trainingData;
    /** Optional gold logical forms, used only to mark derivations in the logs. */
    private final Map<DI, MR> trainingDataDebug;
    protected final IOutputLogger<MR> parserOutputLogger;
    protected final LearningStats stats;

    /**
     * @param epochs number of passes over the training data
     * @param trainingData the training data
     * @param trainingDataDebug gold debug logical forms keyed by data item
     * @param lexiconGenerationBeamSize beam size for the GENLEX parse
     * @param parserOutputLogger logger for parser outputs
     * @param conflateGenlexAndPrunedParses reuse the GENLEX output as the conditioned parse
     * @param errorDriven skip lexical induction when the model already has a valid parse
     * @param categoryServices category services passed to the lexicon generator
     * @param genlex lexicon generator, or null to disable lexical induction
     * @param processingFilter filter deciding which samples are processed
     * @param parsingFilterFactory creates the conditioning parse filter for a data item
     */
    public AbstractLearner(int epochs, IDataCollection<DI> trainingData, Map<DI, MR> trainingDataDebug, int lexiconGenerationBeamSize, IOutputLogger<MR> parserOutputLogger, boolean conflateGenlexAndPrunedParses, boolean errorDriven, ICategoryServices<MR> categoryServices, ILexiconGenerator<DI, MR, IModelImmutable<SAMPLE, MR>> genlex, IFilter<DI> processingFilter, IParsingFilterFactory<DI, MR> parsingFilterFactory) {
        this.epochs = epochs;
        this.trainingData = trainingData;
        this.trainingDataDebug = trainingDataDebug;
        this.lexiconGenerationBeamSize = lexiconGenerationBeamSize;
        this.parserOutputLogger = parserOutputLogger;
        this.conflateGenlexAndPrunedParses = conflateGenlexAndPrunedParses;
        this.errorDriven = errorDriven;
        this.categoryServices = categoryServices;
        this.genlex = genlex;
        this.processingFilter = processingFilter;
        this.parsingFilterFactory = parsingFilterFactory;
        this.stats = new LearningStats.Builder(trainingData.size()).addStat(HAS_VALID_LF, "Has a valid parse").addStat(TRIGGERED_UPDATE, "Sample triggered update").addStat(GOLD_LF_IS_MAX, "The best-scoring LF equals the provided GOLD debug LF").setNumberStat("Number of new lexical entries added").build();
    }

    /**
     * Trains the given model for the configured number of epochs over the training data.
     *
     * <p>Each processed sample is timed exactly once: the timing is recorded in a {@code finally}
     * block that runs on every exit path of {@link #processSample}.
     */
    @Override // edu.cornell.cs.nlp.spf.learn.ILearner
    public void train(Model<SAMPLE, MR> model) {
        // genlex is optional (processSample handles null), so only initialize it when present.
        if (genlex != null) {
            LOG.info("Initializing GENLEX ...");
            genlex.init(model);
        }
        for (int epochNumber = 0; epochNumber < epochs; ++epochNumber) {
            LOG.info("=========================");
            LOG.info("Training epoch %d", epochNumber);
            LOG.info("=========================");
            int itemCounter = -1;
            for (final DI dataItem : trainingData) {
                final long startTime = System.currentTimeMillis();
                ++itemCounter;
                LOG.info("%d : ================== [%d]", itemCounter, epochNumber);
                LOG.info("Sample type: %s", dataItem.getClass().getSimpleName());
                LOG.info("%s", dataItem);

                if (!processingFilter.test(dataItem)) {
                    LOG.info("Skipped training sample, due to processing filter");
                    continue;
                }

                stats.count("Processed", epochNumber);
                try {
                    processSample(dataItem, model, itemCounter, epochNumber);
                } finally {
                    final double elapsedSec = (System.currentTimeMillis() - startTime) / 1000.0d;
                    stats.mean("Sample processing", elapsedSec, "sec");
                    LOG.info("Total sample handling time: %.4fsec", elapsedSec);
                }
            }
            LOG.info("System memory: %s", MemoryReport.generate());
            LOG.info("Epoch stats:");
            LOG.info(stats);
        }
    }

    /** Handles a single, already-filtered training sample: parse, optional GENLEX, update. */
    private void processSample(DI dataItem, Model<SAMPLE, MR> model, int itemCounter, int epochNumber) {
        final IDataItemModel<MR> dataItemModel = model.createDataItemModel(dataItem.getSample());

        // Step I: parse with the current model.
        final PO parserOutput = parse(dataItem, dataItemModel);
        stats.mean("Model parse", parserOutput.getParsingTime() / 1000.0d, "sec");
        parserOutputLogger.log(parserOutput, dataItemModel, String.format("train-%d-%d", epochNumber, itemCounter));
        final List<? extends IDerivation<MR>> modelParses = parserOutput.getAllDerivations();
        LOG.info("Model parsing time: %.4fsec", parserOutput.getParsingTime() / 1000.0d);
        LOG.info("Output is %s", parserOutput.isExact() ? "exact" : "approximate");
        LOG.info("Created %d model parses for training sample:", modelParses.size());
        for (final IDerivation<MR> derivation : modelParses) {
            logParse(dataItem, derivation, validate(dataItem, derivation.getSemantics()), true, dataItemModel);
        }

        // In error-driven mode, a valid model parse is enough: update and stop here.
        if (!getValidParses(parserOutput, dataItem).isEmpty() && errorDriven) {
            parameterUpdate(dataItem, parserOutput, parserOutput, model, itemCounter, epochNumber);
            return;
        }

        // Without a lexicon generator there is nothing more to do for this sample.
        if (genlex == null) {
            return;
        }

        // Step II: generate new lexical entries and add the useful ones to the model.
        final PO generationParserOutput = lexicalInduction(dataItem, itemCounter, dataItemModel, model, epochNumber);

        // Step III: update parameters against either the GENLEX output or a conditioned parse.
        if (conflateGenlexAndPrunedParses && generationParserOutput != null) {
            parameterUpdate(dataItem, parserOutput, generationParserOutput, model, itemCounter, epochNumber);
        } else {
            final PO prunedParserOutput = parse(dataItem, parsingFilterFactory.create(dataItem), dataItemModel);
            LOG.info("Conditioned parsing time: %.4fsec", prunedParserOutput.getParsingTime() / 1000.0d);
            parserOutputLogger.log(prunedParserOutput, dataItemModel, String.format("train-%d-%d-conditioned", epochNumber, itemCounter));
            parameterUpdate(dataItem, parserOutput, prunedParserOutput, model, itemCounter, epochNumber);
        }
    }

    /** Returns a new list holding only the derivations whose semantics validate against the item. */
    private List<? extends IDerivation<MR>> getValidParses(PO parserOutput, DI dataItem) {
        final List<IDerivation<MR>> validParses = new LinkedList<>(parserOutput.getAllDerivations());
        CollectionUtils.filterInPlace(validParses, derivation -> validate(dataItem, derivation.getSemantics()));
        return validParses;
    }

    /**
     * Generates candidate lexical entries and parses with them. Then adds to the model the
     * generated entries (and their linked entries) used by the max-scoring valid derivations.
     *
     * @return the GENLEX parser output, or null when the generator produced no entries
     */
    private PO lexicalInduction(DI dataItem, int dataItemNumber, IDataItemModel<MR> dataItemModel, Model<SAMPLE, MR> model, int epochNumber) {
        final ILexiconImmutable<MR> generatedLexicon = genlex.generate(dataItem, model, categoryServices);
        LOG.info("Generated lexicon size = %d", generatedLexicon.size());
        if (generatedLexicon.size() <= 0) {
            LOG.info("Skipped GENLEX step. No generated lexical items.");
            return null;
        }

        final PO generationParserOutput = parse(dataItem, parsingFilterFactory.create(dataItem), dataItemModel, generatedLexicon, lexiconGenerationBeamSize);
        stats.mean("genlex parse", generationParserOutput.getParsingTime() / 1000.0d, "sec");
        LOG.info("Lexicon induction parsing time: %.4fsec", generationParserOutput.getParsingTime() / 1000.0d);
        LOG.info("Output is %s", generationParserOutput.isExact() ? "exact" : "approximate");
        parserOutputLogger.log(generationParserOutput, dataItemModel, String.format("train-%d-%d-genlex", epochNumber, dataItemNumber));
        LOG.info("Created %d lexicon generation parses for training sample", generationParserOutput.getAllDerivations().size());

        final List<? extends IDerivation<MR>> validParses = getValidParses(generationParserOutput, dataItem);
        LOG.info("Removed %d invalid parses", generationParserOutput.getAllDerivations().size() - validParses.size());

        // Collect all valid derivations that share the maximal score.
        final List<IDerivation<MR>> bestGenerationParses = new LinkedList<>();
        double bestScore = -Double.MAX_VALUE;
        for (final IDerivation<MR> derivation : validParses) {
            if (derivation.getScore() > bestScore) {
                bestScore = derivation.getScore();
                bestGenerationParses.clear();
                bestGenerationParses.add(derivation);
            } else if (derivation.getScore() == bestScore) {
                bestGenerationParses.add(derivation);
            }
        }
        LOG.info("%d valid best parses for lexical generation:", bestGenerationParses.size());
        for (final IDerivation<MR> derivation : bestGenerationParses) {
            logParse(dataItem, derivation, true, true, dataItemModel);
        }

        int newLexicalEntries = 0;
        for (final IDerivation<MR> derivation : bestGenerationParses) {
            for (final LexicalEntry<MR> entry : derivation.getMaxLexicalEntries()) {
                if (!genlex.isGenerated(entry)) {
                    continue;
                }
                if (model.addLexEntry(LexiconGenerationServices.unmark(entry))) {
                    ++newLexicalEntries;
                    LOG.info("Added LexicalEntry to model: %s [%s]", entry, model.getTheta().printValues(model.computeFeatures(entry)));
                }
                // Generators may link related entries; when adding a generated entry, add those too.
                for (final LexicalEntry<MR> linkedEntry : entry.getLinkedEntries()) {
                    if (model.addLexEntry(LexiconGenerationServices.unmark(linkedEntry))) {
                        ++newLexicalEntries;
                        LOG.info("Added (linked) LexicalEntry to model: %s [%s]", linkedEntry, model.getTheta().printValues(model.computeFeatures(linkedEntry)));
                    }
                }
            }
        }
        if (newLexicalEntries > 0) {
            stats.appendSampleStat(dataItemNumber, epochNumber, newLexicalEntries);
        }
        return generationParserOutput;
    }

    /**
     * Returns true if a gold debug logical form is registered for the data item and equals the given
     * one. Returns false when no (or a null) debug form is registered.
     */
    public boolean isGoldDebugCorrect(DI di, MR mr) {
        final MR gold = trainingDataDebug.get(di);
        return gold != null && gold.equals(mr);
    }

    /** Logs a derivation without a tag; see the six-argument overload. */
    public void logParse(DI di, IDerivation<MR> iDerivation, Boolean bool, boolean z, IDataItemModel<MR> iDataItemModel) {
        logParse(di, iDerivation, bool, z, null, iDataItemModel);
    }

    /**
     * Logs a derivation: "* " marks a match with the gold debug LF, followed by the optional tag, the
     * score and an optional validity flag (V/X). When {@code z} is true, also logs every max step.
     */
    protected void logParse(DI di, IDerivation<MR> iDerivation, Boolean bool, boolean z, String str, IDataItemModel<MR> iDataItemModel) {
        final String goldMark = isGoldDebugCorrect(di, iDerivation.getSemantics()) ? "* " : "  ";
        final String tag = str == null ? "" : str + " ";
        final String validity = bool == null ? "" : bool ? ", V" : ", X";
        LOG.info("%s%s[%.2f%s] %s", goldMark, tag, iDerivation.getScore(), validity, iDerivation);
        if (z) {
            for (final IWeightedParseStep<MR> step : iDerivation.getMaxSteps()) {
                LOG.info("\t%s", step.toString(false, false, iDataItemModel.getTheta()));
            }
        }
    }

    /**
     * Updates model parameters for a sample, given the model parser output and a conditioned output.
     */
    protected abstract void parameterUpdate(DI dataItem, PO modelParserOutput, PO conditionedParserOutput, Model<SAMPLE, MR> model, int itemCounter, int epochNumber);

    /** Parses the data item with the current model. */
    protected abstract PO parse(DI dataItem, IDataItemModel<MR> dataItemModel);

    /** Parses the data item, restricting parsing operations with the given filter. */
    protected abstract PO parse(DI dataItem, Predicate<ParsingOp<MR>> filter, IDataItemModel<MR> dataItemModel);

    /** Parses with an additional temporary lexicon and beam size (used for GENLEX). */
    protected abstract PO parse(DI dataItem, Predicate<ParsingOp<MR>> filter, IDataItemModel<MR> dataItemModel, ILexiconImmutable<MR> generatedLexicon, Integer beamSize);

    /** Returns true if the given logical form is a valid label for the data item. */
    protected abstract boolean validate(DI dataItem, MR label);
}
