package edu.cornell.cs.nlp.spf.learn.situated;

import edu.cornell.cs.nlp.spf.ccg.categories.ICategoryServices;
import edu.cornell.cs.nlp.spf.ccg.lexicon.ILexiconImmutable;
import edu.cornell.cs.nlp.spf.ccg.lexicon.LexicalEntry;
import edu.cornell.cs.nlp.spf.data.ILabeledDataItem;
import edu.cornell.cs.nlp.spf.data.collection.IDataCollection;
import edu.cornell.cs.nlp.spf.data.sentence.Sentence;
import edu.cornell.cs.nlp.spf.data.situated.ISituatedDataItem;
import edu.cornell.cs.nlp.spf.genlex.ccg.ILexiconGenerator;
import edu.cornell.cs.nlp.spf.genlex.ccg.LexiconGenerationServices;
import edu.cornell.cs.nlp.spf.learn.ILearner;
import edu.cornell.cs.nlp.spf.learn.LearningStats;
import edu.cornell.cs.nlp.spf.parser.ccg.IWeightedParseStep;
import edu.cornell.cs.nlp.spf.parser.ccg.model.IDataItemModel;
import edu.cornell.cs.nlp.spf.parser.joint.IJointDerivation;
import edu.cornell.cs.nlp.spf.parser.joint.IJointOutput;
import edu.cornell.cs.nlp.spf.parser.joint.IJointOutputLogger;
import edu.cornell.cs.nlp.spf.parser.joint.IJointParser;
import edu.cornell.cs.nlp.spf.parser.joint.model.IJointDataItemModel;
import edu.cornell.cs.nlp.spf.parser.joint.model.IJointModelImmutable;
import edu.cornell.cs.nlp.spf.parser.joint.model.JointModel;
import edu.cornell.cs.nlp.utils.collections.CollectionUtils;
import edu.cornell.cs.nlp.utils.composites.Pair;
import edu.cornell.cs.nlp.utils.log.ILogger;
import edu.cornell.cs.nlp.utils.log.LoggerFactory;
import edu.cornell.cs.nlp.utils.system.MemoryReport;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/cornell/cs/nlp/spf/learn/situated/AbstractSituatedLearner.class */
/**
 * Base class for learners that train a {@link JointModel} over situated data items.
 *
 * <p>Each epoch iterates over the training data. For every sample whose sentence is not longer
 * than the configured maximum, it does two things:
 * <ol>
 * <li>It optionally runs lexical induction: it generates a candidate lexicon with GENLEX, parses
 * with it, and adds the generated entries that appear in the best-scoring valid derivations to
 * the model.</li>
 * <li>It delegates the parameter update to {@link #parameterUpdate}.</li>
 * </ol>
 *
 * @param <SAMPLE>
 *            situated sample type whose language component is a {@link Sentence}
 * @param <MR>
 *            meaning representation type
 * @param <ESTEP>
 *            step type parameter shared by the joint parser and joint model
 * @param <ERESULT>
 *            result type of joint derivations (checked by {@link #validate})
 * @param <DI>
 *            labeled data item type wrapping a {@code SAMPLE}
 */
public abstract class AbstractSituatedLearner<SAMPLE extends ISituatedDataItem<Sentence, ?>, MR, ESTEP, ERESULT, DI extends ILabeledDataItem<SAMPLE, ?>>
        implements ILearner<SAMPLE, DI, JointModel<SAMPLE, MR, ESTEP>> {
    public static final ILogger LOG = LoggerFactory.create((Class<?>) AbstractSituatedLearner.class);

    /** Stat key: the best-scoring result equals the provided gold debug result. */
    protected static final String GOLD_LF_IS_MAX = "G";
    /** Stat key: the sample has a valid parse. */
    protected static final String HAS_VALID_LF = "V";
    /** Stat key: the sample triggered a parameter update. */
    protected static final String TRIGGERED_UPDATE = "U";

    private final ICategoryServices<MR> categoryServices;
    /** Number of passes over the training data. */
    private final int epochs;
    /** Lexicon generator; may be {@code null}, in which case lexical induction is skipped. */
    private final ILexiconGenerator<DI, MR, IJointModelImmutable<SAMPLE, MR, ESTEP>> genlex;
    /** Beam size passed to the parser during lexical induction. */
    private final int lexiconGenerationBeamSize;
    /** Samples with more tokens than this are skipped. */
    private final int maxSentenceLength;
    private final IDataCollection<DI> trainingData;
    /** Optional gold (MR, result) pairs used only for logging; may be {@code null}. */
    private final Map<DI, Pair<MR, ERESULT>> trainingDataDebug;
    protected final IJointParser<SAMPLE, MR, ESTEP, ERESULT> parser;
    protected final IJointOutputLogger<MR, ESTEP, ERESULT> parserOutputLogger;
    protected final LearningStats stats;

    /**
     * @param epochs
     *            number of passes over the training data
     * @param trainingData
     *            training data items
     * @param trainingDataDebug
     *            optional map from data items to gold (MR, result) pairs, used only to mark
     *            gold derivations in logs; may be {@code null}
     * @param maxSentenceLength
     *            samples whose sentence has more tokens than this are skipped
     * @param lexiconGenerationBeamSize
     *            beam size for the lexical induction parse
     * @param parser
     *            joint parser
     * @param parserOutputLogger
     *            logger for parser outputs
     * @param categoryServices
     *            category services handed to GENLEX
     * @param genlex
     *            lexicon generator; {@code null} disables lexical induction
     */
    public AbstractSituatedLearner(int epochs, IDataCollection<DI> trainingData,
            Map<DI, Pair<MR, ERESULT>> trainingDataDebug, int maxSentenceLength,
            int lexiconGenerationBeamSize, IJointParser<SAMPLE, MR, ESTEP, ERESULT> parser,
            IJointOutputLogger<MR, ESTEP, ERESULT> parserOutputLogger,
            ICategoryServices<MR> categoryServices,
            ILexiconGenerator<DI, MR, IJointModelImmutable<SAMPLE, MR, ESTEP>> genlex) {
        this.epochs = epochs;
        this.trainingData = trainingData;
        this.trainingDataDebug = trainingDataDebug;
        this.maxSentenceLength = maxSentenceLength;
        this.lexiconGenerationBeamSize = lexiconGenerationBeamSize;
        this.parser = parser;
        this.parserOutputLogger = parserOutputLogger;
        this.categoryServices = categoryServices;
        this.genlex = genlex;
        this.stats = new LearningStats.Builder(trainingData.size())
                .addStat(HAS_VALID_LF, "Has a valid parse")
                .addStat(TRIGGERED_UPDATE, "Sample triggered update")
                .addStat(GOLD_LF_IS_MAX, "The best-scoring LF equals the provided GOLD debug LF")
                .setNumberStat("Number of new lexical entries added for sample")
                .build();
    }

    /**
     * Trains the model for the configured number of epochs.
     *
     * @param model
     *            model to update in place (lexicon and parameters)
     */
    @Override // edu.cornell.cs.nlp.spf.learn.ILearner
    public void train(JointModel<SAMPLE, MR, ESTEP> model) {
        // GENLEX is optional: the per-sample loop skips lexical induction when it is null, so
        // initialization must be skipped as well (previously this threw a NullPointerException).
        if (genlex != null) {
            LOG.info("Initializing GENLEX ...");
            genlex.init(model);
        }
        for (int epochNumber = 0; epochNumber < epochs; epochNumber++) {
            LOG.info("=========================");
            LOG.info("Training epoch %d", epochNumber);
            LOG.info("=========================");
            int itemNumber = -1;
            for (final DI dataItem : trainingData) {
                final long startTime = System.currentTimeMillis();
                itemNumber++;
                LOG.info("%d : ================== [%d]", itemNumber, epochNumber);
                LOG.info("Sample type: %s", dataItem.getClass().getSimpleName());
                LOG.info("%s", dataItem);

                final SAMPLE sample = dataItem.getSample();
                final Sentence sentence = sample.getSample();
                if (sentence.getTokens().size() > maxSentenceLength) {
                    LOG.warn("Training sample too long, skipping");
                    continue;
                }

                final IJointDataItemModel<MR, ESTEP> dataItemModel = model
                        .createJointDataItemModel(sample);
                if (genlex != null) {
                    lexicalInduction(dataItem, dataItemModel, model, itemNumber, epochNumber);
                }
                parameterUpdate(dataItem, dataItemModel, model, itemNumber, epochNumber);

                // Measure once so the recorded stat and the logged time agree.
                final double seconds = (System.currentTimeMillis() - startTime) / 1000.0;
                stats.mean("sample processing", seconds, "sec");
                stats.count("processed", epochNumber);
                LOG.info("Total sample handling time: %.4fsec", seconds);
            }
            LOG.info("System memory: %s", MemoryReport.generate());
            LOG.info("Epoch stats:");
            LOG.info(stats);
        }
    }

    /**
     * Runs GENLEX for a single sample. It parses with the generated lexicon, keeps the valid
     * derivations, selects the highest-scoring ones, and adds the generated lexical entries they
     * use (plus linked entries) to the model.
     */
    private void lexicalInduction(DI dataItem, IJointDataItemModel<MR, ESTEP> dataItemModel,
            JointModel<SAMPLE, MR, ESTEP> model, int dataItemNumber, int epochNumber) {
        final ILexiconImmutable<MR> generatedLexicon = genlex.generate(dataItem, model,
                categoryServices);
        LOG.info("Generated lexicon size = %d", generatedLexicon.size());
        if (generatedLexicon.size() <= 0) {
            LOG.info("Skipped GENLEX step. No generated lexical items.");
            return;
        }

        final IJointOutput<MR, ERESULT> parserOutput = parser.parse(dataItem.getSample(),
                dataItemModel, false, generatedLexicon,
                Integer.valueOf(lexiconGenerationBeamSize));
        final double parseSeconds = parserOutput.getInferenceTime() / 1000.0;
        stats.mean("genlex parse", parseSeconds, "sec");
        LOG.info("Lexicon induction parsing time: %.4fsec", parseSeconds);
        LOG.info("Output is %s", parserOutput.isExact() ? "exact" : "approximate");
        parserOutputLogger.log(parserOutput, dataItemModel,
                String.format("%d-genlex", dataItemNumber));

        final List<IJointDerivation<MR, ERESULT>> validDerivations = new ArrayList<>(
                parserOutput.getDerivations());
        final int generatedCount = validDerivations.size();
        LOG.info("Created %d lexicon generation parses for training sample", generatedCount);
        validDerivations.removeIf(derivation -> !validate(dataItem, derivation.getResult()));
        LOG.info("Removed %d invalid parses", generatedCount - validDerivations.size());

        final List<IJointDerivation<MR, ERESULT>> bestDerivations = maxScoringDerivations(
                validDerivations);
        LOG.info("%d valid best parses for lexical generation:", bestDerivations.size());
        for (final IJointDerivation<MR, ERESULT> derivation : bestDerivations) {
            logParse(dataItem, derivation, true, true, dataItemModel);
        }

        final int newEntries = addGeneratedEntries(bestDerivations, model);
        if (newEntries > 0) {
            stats.appendSampleStat(dataItemNumber, epochNumber, newEntries);
        }
    }

    /** Returns all derivations tied for the maximal Viterbi score, in input order. */
    private List<IJointDerivation<MR, ERESULT>> maxScoringDerivations(
            List<IJointDerivation<MR, ERESULT>> derivations) {
        final List<IJointDerivation<MR, ERESULT>> best = new ArrayList<>();
        // Start at -Infinity (not -Double.MAX_VALUE) so derivations scored -Infinity are not
        // silently dropped.
        double bestScore = Double.NEGATIVE_INFINITY;
        for (final IJointDerivation<MR, ERESULT> derivation : derivations) {
            final double score = derivation.getViterbiScore();
            if (score > bestScore) {
                bestScore = score;
                best.clear();
                best.add(derivation);
            } else if (score == bestScore) {
                best.add(derivation);
            }
        }
        return best;
    }

    /**
     * Adds the GENLEX-generated entries used by the given derivations, plus their linked
     * entries, to the model. Entries are unmarked before insertion.
     *
     * @return number of entries that were actually new to the model
     */
    private int addGeneratedEntries(List<IJointDerivation<MR, ERESULT>> derivations,
            JointModel<SAMPLE, MR, ESTEP> model) {
        int added = 0;
        for (final IJointDerivation<MR, ERESULT> derivation : derivations) {
            for (final LexicalEntry<MR> entry : derivation.getMaxLexicalEntries()) {
                if (!genlex.isGenerated(entry)) {
                    continue;
                }
                if (addUnmarked(entry, model, "Added LexicalEntry to model: %s [%s]")) {
                    added++;
                }
                for (final LexicalEntry<MR> linked : entry.getLinkedEntries()) {
                    if (addUnmarked(linked, model,
                            "Added (linked) LexicalEntry to model: %s [%s]")) {
                        added++;
                    }
                }
            }
        }
        return added;
    }

    /**
     * Adds the unmarked version of {@code entry} to the model and logs it with its feature
     * values if the model reports it as newly added.
     *
     * @return {@code true} if the model accepted the entry as new
     */
    private boolean addUnmarked(LexicalEntry<MR> entry, JointModel<SAMPLE, MR, ESTEP> model,
            String logFormat) {
        if (!model.addLexEntry(LexiconGenerationServices.unmark(entry))) {
            return false;
        }
        LOG.info(logFormat, entry, model.getTheta().printValues(model.computeFeatures(entry)));
        return true;
    }

    /**
     * Checks whether {@code eresult} matches the gold debug result registered for
     * {@code di}.
     *
     * @return {@code true} only if a debug map is present, it has a gold pair for
     *         {@code di}, and the pair's result component equals {@code eresult}
     */
    public boolean isGoldDebugCorrect(DI di, ERESULT eresult) {
        if (trainingDataDebug == null) {
            return false;
        }
        final Pair<MR, ERESULT> gold = trainingDataDebug.get(di);
        // Compare against the pair's result component; comparing the Pair itself to an
        // ERESULT could never succeed.
        return gold != null && gold.second() != null && gold.second().equals(eresult);
    }

    /** Logs a derivation without a tag. See the six-argument overload. */
    public void logParse(DI di, IJointDerivation<MR, ERESULT> iJointDerivation, Boolean bool,
            boolean z, IDataItemModel<MR> iDataItemModel) {
        logParse(di, iJointDerivation, bool, z, null, iDataItemModel);
    }

    /**
     * Logs a derivation: a gold marker ({@code "* "} when it matches the gold debug result),
     * an optional tag, the Viterbi score, an optional validity flag, and the derivation itself.
     *
     * @param di
     *            data item the derivation belongs to
     * @param iJointDerivation
     *            derivation to log
     * @param bool
     *            validity flag ({@code ", V"} / {@code ", X"}), or {@code null} to omit it
     * @param z
     *            when {@code true}, also log each max-scoring parse step
     * @param str
     *            optional tag prefix, or {@code null}
     * @param iDataItemModel
     *            model whose theta is used to print step features
     */
    protected void logParse(DI di, IJointDerivation<MR, ERESULT> iJointDerivation, Boolean bool,
            boolean z, String str, IDataItemModel<MR> iDataItemModel) {
        final String goldMarker = isGoldDebugCorrect(di, iJointDerivation.getResult()) ? "* "
                : "  ";
        final String tag = str == null ? "" : str + " ";
        final String validity = bool == null ? "" : bool.booleanValue() ? ", V" : ", X";
        LOG.info("%s%s[%.2f%s] %s", goldMarker, tag,
                Double.valueOf(iJointDerivation.getViterbiScore()), validity, iJointDerivation);
        if (z) {
            for (final IWeightedParseStep<MR> step : iJointDerivation.getMaxSteps()) {
                LOG.info("\t%s", step.toString(false, false, iDataItemModel.getTheta()));
            }
        }
    }

    /**
     * Performs the parameter update for a single sample.
     *
     * @param dataItem
     *            current data item
     * @param dataItemModel
     *            data-item-specific model view
     * @param model
     *            model being trained
     * @param dataItemNumber
     *            index of the item within the epoch
     * @param epochNumber
     *            current epoch index
     */
    protected abstract void parameterUpdate(DI dataItem,
            IJointDataItemModel<MR, ESTEP> dataItemModel, JointModel<SAMPLE, MR, ESTEP> model,
            int dataItemNumber, int epochNumber);

    /**
     * Decides whether a derivation's result is acceptable for the given data item. It is used
     * to filter lexical induction parses.
     */
    protected abstract boolean validate(DI dataItem, ERESULT result);
}
