package edu.cornell.cs.nlp.spf.test;

import edu.cornell.cs.nlp.spf.base.hashvector.IHashVector;
import edu.cornell.cs.nlp.spf.ccg.lexicon.LexicalEntry;
import edu.cornell.cs.nlp.spf.data.IDataItem;
import edu.cornell.cs.nlp.spf.data.ILabeledDataItem;
import edu.cornell.cs.nlp.spf.data.collection.IDataCollection;
import edu.cornell.cs.nlp.spf.explat.IResourceRepository;
import edu.cornell.cs.nlp.spf.explat.ParameterizedExperiment;
import edu.cornell.cs.nlp.spf.explat.resources.IResourceObjectCreator;
import edu.cornell.cs.nlp.spf.explat.resources.usage.ResourceUsage;
import edu.cornell.cs.nlp.spf.parser.IDerivation;
import edu.cornell.cs.nlp.spf.parser.IOutputLogger;
import edu.cornell.cs.nlp.spf.parser.IParser;
import edu.cornell.cs.nlp.spf.parser.IParserOutput;
import edu.cornell.cs.nlp.spf.parser.ccg.IWeightedParseStep;
import edu.cornell.cs.nlp.spf.parser.ccg.model.IDataItemModel;
import edu.cornell.cs.nlp.spf.parser.ccg.model.IModelImmutable;
import edu.cornell.cs.nlp.spf.parser.ccg.rules.skolem.SkolemIDRule;
import edu.cornell.cs.nlp.spf.test.stats.ITestingStatistics;
import edu.cornell.cs.nlp.utils.collections.ListUtils;
import edu.cornell.cs.nlp.utils.filter.IFilter;
import edu.cornell.cs.nlp.utils.log.ILogger;
import edu.cornell.cs.nlp.utils.log.LoggerFactory;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/cornell/cs/nlp/spf/test/Tester.class */
/**
 * Tests a model by parsing every item of a labeled test collection and recording the
 * outcome in an {@link ITestingStatistics} instance.
 * <p>
 * Each item is first parsed normally:
 * <ul>
 * <li>A single best derivation is recorded and logged as CORRECT or WRONG.</li>
 * <li>Several tied best derivations are all recorded.</li>
 * <li>If there is no derivation and the skip-parsing filter accepts the sample, the item is
 * parsed again with the parser's third {@code parse} argument set to {@code true}. That result
 * is recorded through the statistics' "with skipping" methods.</li>
 * </ul>
 *
 * @param <SAMPLE> type of the sample given to the parser
 * @param <MR> meaning representation (label) type
 * @param <DI> labeled data item type
 */
public class Tester<SAMPLE extends IDataItem<?>, MR, DI extends ILabeledDataItem<SAMPLE, MR>> implements ITester<SAMPLE, MR, DI> {
    public static final ILogger LOG = LoggerFactory.create(Tester.class.getName());
    /** Receives the parser output of every parse attempt, tagged "test-N" or "test-N-sloppy". */
    private final IOutputLogger<MR> outputLogger;
    private final IParser<SAMPLE, MR> parser;
    /** If this returns {@code false} for a sample, the second (skipping) parse is not attempted. */
    private final IFilter<SAMPLE> skipParsingFilter;
    private final IDataCollection<? extends DI> testData;

    /**
     * Builder for {@link Tester}.
     * <p>
     * Defaults: parser outputs are not logged, and the skipping parse is always allowed.
     */
    /* loaded from: input_file:edu/cornell/cs/nlp/spf/test/Tester$Builder.class */
    public static class Builder<SAMPLE extends IDataItem<?>, MR, DI extends ILabeledDataItem<SAMPLE, MR>> {
        private final IParser<SAMPLE, MR> parser;
        private final IDataCollection<? extends DI> testData;
        /** Default output logger: discards everything. */
        private IOutputLogger<MR> outputLogger = new IOutputLogger<MR>() {
            private static final long serialVersionUID = -2828347737693835555L;

            @Override
            public void log(IParserOutput<MR> parserOutput, IDataItemModel<MR> dataItemModel, String tag) {
                // Intentionally empty: output logging is off unless a logger is supplied.
            }
        };
        /** Default filter accepts every sample, so the skipping parse is always attempted. */
        private IFilter<SAMPLE> skipParsingFilter = sample -> true;

        /**
         * @param testData collection of labeled items to test on
         * @param parser parser used for every test item
         */
        public Builder(IDataCollection<? extends DI> testData, IParser<SAMPLE, MR> parser) {
            this.testData = testData;
            this.parser = parser;
        }

        /** Creates a {@link Tester} from the current builder settings. */
        public Tester<SAMPLE, MR, DI> build() {
            return new Tester<>(this.testData, this.skipParsingFilter, this.parser, this.outputLogger);
        }

        /** Sets the logger that receives every parser output. */
        public Builder<SAMPLE, MR, DI> setOutputLogger(IOutputLogger<MR> outputLogger) {
            this.outputLogger = outputLogger;
            return this;
        }

        /**
         * Sets the filter that decides whether the skipping parse is attempted.
         * <p>
         * The skipping parse runs only when the filter returns {@code true} for the sample.
         */
        public Builder<SAMPLE, MR, DI> setSkipParsingFilter(IFilter<SAMPLE> skipParsingFilter) {
            this.skipParsingFilter = skipParsingFilter;
            return this;
        }
    }

    /**
     * Experiment-resource creator for {@link Tester} (resource type {@code "tester"}).
     * <p>
     * Reads the parameters {@code data}, {@link ParameterizedExperiment#PARSER_RESOURCE} and,
     * optionally, {@code skippingFilter}.
     */
    /* loaded from: input_file:edu/cornell/cs/nlp/spf/test/Tester$Creator.class */
    public static class Creator<SAMPLE extends IDataItem<?>, LABEL, DI extends ILabeledDataItem<SAMPLE, LABEL>> implements IResourceObjectCreator<Tester<SAMPLE, LABEL, DI>> {

        /**
         * Builds a tester from experiment parameters.
         *
         * @throws RuntimeException if the {@code data} resource is missing or is not an
         *         {@link IDataCollection}
         * @throws IllegalStateException if no parser parameter is given
         */
        // The repository is untyped, so casting resources to their generic types is unchecked.
        @SuppressWarnings("unchecked")
        @Override
        public Tester<SAMPLE, LABEL, DI> create(ParameterizedExperiment.Parameters parameters, IResourceRepository resourceRepository) {
            final Object data = resourceRepository.get(parameters.get("data"));
            // instanceof is false for null, so this also covers a missing resource.
            if (!(data instanceof IDataCollection)) {
                throw new RuntimeException("Unknown or non labeled dataset: " + parameters.get("data"));
            }
            if (!parameters.contains(ParameterizedExperiment.PARSER_RESOURCE)) {
                throw new IllegalStateException("tester now requires you to provide a parser");
            }
            final IParser<SAMPLE, LABEL> parser = (IParser<SAMPLE, LABEL>) resourceRepository.get(parameters.get(ParameterizedExperiment.PARSER_RESOURCE));
            final Builder<SAMPLE, LABEL, DI> builder = new Builder<>((IDataCollection<? extends DI>) data, parser);
            if (parameters.get("skippingFilter") != null) {
                builder.setSkipParsingFilter((IFilter<SAMPLE>) resourceRepository.get(parameters.get("skippingFilter")));
            }
            return builder.build();
        }

        @Override
        public String type() {
            return "tester";
        }

        @Override
        public ResourceUsage usage() {
            return new ResourceUsage.Builder(type(), Tester.class).setDescription("Model tester. Tests inference using the model on some testing data").addParam("data", SkolemIDRule.RULE_LABEL, "IDataCollection that holds ILabaledDataItem entries").addParam(ParameterizedExperiment.PARSER_RESOURCE, SkolemIDRule.RULE_LABEL, "Parser object").addParam("skippingFilter", SkolemIDRule.RULE_LABEL, "IFilter used to decide which data items to skip").build();
        }
    }

    private Tester(IDataCollection<? extends DI> testData, IFilter<SAMPLE> skipParsingFilter, IParser<SAMPLE, MR> parser, IOutputLogger<MR> outputLogger) {
        this.testData = testData;
        this.skipParsingFilter = skipParsingFilter;
        this.parser = parser;
        this.outputLogger = outputLogger;
        LOG.info("Init Tester:  testData.size()=%d", Integer.valueOf(testData.size()));
    }

    /** Tests {@code model} on this tester's data and records the results in {@code stats}. */
    @Override
    public void test(IModelImmutable<SAMPLE, MR> model, ITestingStatistics<SAMPLE, MR, DI> stats) {
        test(this.testData, model, stats);
    }

    /** Logs a derivation's score and each of its max-scoring parse steps. */
    private void logDerivation(IDerivation<MR> derivation, IDataItemModel<MR> dataItemModel) {
        LOG.info("[%.2f] %s", Double.valueOf(derivation.getScore()), derivation);
        for (final IWeightedParseStep<MR> step : derivation.getMaxSteps()) {
            LOG.info("\t%s", step.toString(false, false, dataItemModel.getTheta()));
        }
    }

    /**
     * Logs a single parse.
     * <p>
     * The line is prefixed with {@code "* "} when the item reports the derivation's semantics as
     * correct. The parse's recomputed score and feature values are also logged.
     *
     * @param logLexicalEntries when {@code true}, also logs each max lexical entry with its score
     * @param tag optional prefix for the first line; may be {@code null}
     */
    private void logParse(ILabeledDataItem<SAMPLE, MR> dataItem, IDerivation<MR> derivation, boolean logLexicalEntries, String tag, IModelImmutable<SAMPLE, MR> model) {
        // Compare the derivation's semantics (an MR), not its Category, against the label.
        final String marker = dataItem.isCorrect(derivation.getSemantics()) ? "* " : "  ";
        LOG.info("%s%s[S%.2f] %s", marker, tag == null ? "" : tag + " ", Double.valueOf(derivation.getScore()), derivation);
        LOG.info("Calculated score: %f", Double.valueOf(model.score(derivation.getAverageMaxFeatureVector())));
        LOG.info("Features: %s", model.getTheta().printValues(derivation.getAverageMaxFeatureVector()));
        if (logLexicalEntries) {
            for (final LexicalEntry<MR> entry : derivation.getMaxLexicalEntries()) {
                LOG.info("\t[%f] %s", Double.valueOf(model.score(entry)), entry);
            }
        }
    }

    /** Returns the max-scoring derivations whose semantics equal the item's label. */
    private List<? extends IDerivation<MR>> getCorrectMaxDerivations(IParserOutput<MR> parserOutput, DI dataItem) {
        return parserOutput.getMaxDerivations(category -> dataItem.getLabel().equals(category.getSemantics()));
    }

    /**
     * Logs a set of tied best parses, then any max-scoring derivations that match the label.
     */
    private void logTiedParses(DI dataItem, IParserOutput<MR> parserOutput, List<? extends IDerivation<MR>> bestDerivations, IDataItemModel<MR> dataItemModel, IModelImmutable<SAMPLE, MR> model) {
        for (final IDerivation<MR> derivation : bestDerivations) {
            logParse(dataItem, derivation, false, null, model);
        }
        final List<? extends IDerivation<MR>> correctDerivations = getCorrectMaxDerivations(parserOutput, dataItem);
        LOG.info("Had correct parses: %s", Boolean.valueOf(!correctDerivations.isEmpty()));
        for (final IDerivation<MR> derivation : correctDerivations) {
            logDerivation(derivation, dataItemModel);
        }
    }

    /**
     * Records and logs the outcome when parsing produced exactly one best derivation.
     * <p>
     * If that derivation is wrong, this also logs every label-matching max derivation and the
     * feature difference between it and the predicted one.
     *
     * @param withSkipping whether the derivation came from the skipping parse; selects which
     *        statistics method is used
     */
    private void processSingleBestParse(DI dataItem, IDataItemModel<MR> dataItemModel, IParserOutput<MR> parserOutput, IDerivation<MR> derivation, boolean withSkipping, ITestingStatistics<SAMPLE, MR, DI> stats) {
        final MR semantics = derivation.getSemantics();
        if (withSkipping) {
            stats.recordParseWithSkipping(dataItem, semantics);
        } else {
            stats.recordParse(dataItem, semantics);
        }

        if (dataItem.isCorrect(semantics)) {
            LOG.info("CORRECT");
            logDerivation(derivation, dataItemModel);
            return;
        }

        LOG.info("WRONG");
        logDerivation(derivation, dataItemModel);
        final List<? extends IDerivation<MR>> correctDerivations = getCorrectMaxDerivations(parserOutput, dataItem);
        LOG.info("Had correct parses: %s", Boolean.valueOf(!correctDerivations.isEmpty()));
        for (final IDerivation<MR> correct : correctDerivations) {
            LOG.info("Correct derivation:");
            logDerivation(correct, dataItemModel);
            // addTimes(-1, predicted) on the correct derivation's features; presumably yields
            // (correct - predicted), to help diagnose the error.
            final IHashVector diff = correct.getAverageMaxFeatureVector().addTimes(-1.0d, derivation.getAverageMaxFeatureVector());
            diff.dropNoise();
            LOG.info("Diff: %s", dataItemModel.getTheta().printValues(diff));
        }
        LOG.info("Feats: %s", dataItemModel.getTheta().printValues(derivation.getAverageMaxFeatureVector()));
    }

    /** Tests every item of {@code data}, numbering the items from 1 in the log. */
    private void test(IDataCollection<? extends DI> data, IModelImmutable<SAMPLE, MR> model, ITestingStatistics<SAMPLE, MR, DI> stats) {
        int itemNumber = 0;
        for (final DI dataItem : data) {
            itemNumber++;
            test(itemNumber, dataItem, model, stats);
        }
    }

    /**
     * Tests a single item.
     * <p>
     * The item is parsed normally. If that yields no derivation and the skip-parsing filter
     * accepts the sample, it is parsed again with skipping enabled.
     */
    private void test(int itemNumber, DI dataItem, IModelImmutable<SAMPLE, MR> model, ITestingStatistics<SAMPLE, MR, DI> stats) {
        LOG.info("%d : ==================", Integer.valueOf(itemNumber));
        LOG.info("%s", dataItem);
        final IDataItemModel<MR> dataItemModel = model.createDataItemModel(dataItem.getSample());

        // First attempt: regular parse.
        final IParserOutput<MR> parserOutput = this.parser.parse(dataItem.getSample(), dataItemModel);
        LOG.info("Test parsing time %.2fsec", Double.valueOf(parserOutput.getParsingTime() / 1000.0d));
        this.outputLogger.log(parserOutput, dataItemModel, String.format("test-%d", Integer.valueOf(itemNumber)));
        final List<? extends IDerivation<MR>> bestDerivations = parserOutput.getBestDerivations();

        if (bestDerivations.size() == 1) {
            processSingleBestParse(dataItem, dataItemModel, parserOutput, bestDerivations.get(0), false, stats);
            return;
        }
        if (bestDerivations.size() > 1) {
            stats.recordParses(dataItem, ListUtils.map(bestDerivations, derivation -> derivation.getSemantics()));
            LOG.info("too many parses");
            LOG.info("%d parses:", Integer.valueOf(bestDerivations.size()));
            logTiedParses(dataItem, parserOutput, bestDerivations, dataItemModel, model);
            return;
        }

        // No parse: optionally retry with skipping enabled.
        LOG.info("no parses");
        stats.recordNoParse(dataItem);
        if (!this.skipParsingFilter.test(dataItem.getSample())) {
            LOG.info("Skipping word-skip parsing due to length");
            stats.recordNoParseWithSkipping(dataItem);
            return;
        }

        // Pass the sample itself; casting it to IParser would always fail.
        final IParserOutput<MR> skippingOutput = this.parser.parse(dataItem.getSample(), dataItemModel, true);
        LOG.info("EMPTY Parsing time %fsec", Double.valueOf(skippingOutput.getParsingTime() / 1000.0d));
        this.outputLogger.log(skippingOutput, dataItemModel, String.format("test-%d-sloppy", Integer.valueOf(itemNumber)));
        final List<? extends IDerivation<MR>> skippingBest = skippingOutput.getBestDerivations();

        if (skippingBest.size() == 1) {
            processSingleBestParse(dataItem, dataItemModel, skippingOutput, skippingBest.get(0), true, stats);
        } else if (skippingBest.isEmpty()) {
            LOG.info("no parses");
            stats.recordNoParseWithSkipping(dataItem);
        } else {
            stats.recordParsesWithSkipping(dataItem, ListUtils.map(skippingBest, derivation -> derivation.getSemantics()));
            LOG.info("WRONG: %d parses", Integer.valueOf(skippingBest.size()));
            logTiedParses(dataItem, skippingOutput, skippingBest, dataItemModel, model);
        }
    }
}
