package edu.cornell.cs.nlp.spf.parser.joint.injective.graph;

import edu.cornell.cs.nlp.spf.base.hashvector.HashVectorUtils;
import edu.cornell.cs.nlp.spf.base.hashvector.IHashVector;
import edu.cornell.cs.nlp.spf.ccg.categories.Category;
import edu.cornell.cs.nlp.spf.parser.graph.IGraphParserOutput;
import edu.cornell.cs.nlp.spf.parser.joint.graph.IJointGraphOutput;
import edu.cornell.cs.nlp.spf.parser.joint.injective.AbstractInjectiveJointOutput;
import edu.cornell.cs.nlp.utils.collections.IScorer;
import edu.cornell.cs.nlp.utils.filter.IFilter;
import edu.cornell.cs.nlp.utils.math.LogSumExp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/cornell/cs/nlp/spf/parser/joint/injective/graph/InjectiveJointGraphOutput.class */
/**
 * Joint-inference output built on top of a graph-based base parser output.
 *
 * <p>Each joint derivation pairs a base parse with an execution result. Derivations are grouped
 * into {@link ResultCell}s keyed by {@code derivation.getResult()}. All scores handled here are
 * log-space values (they are combined with {@link LogSumExp}).
 *
 * <p>NOTE(review): {@link #logExpectedFeatures(IFilter)} overwrites per-cell outside scores, and
 * there is no synchronization, so instances should not be queried concurrently.
 */
public class InjectiveJointGraphOutput<MR, ERESULT>
        extends AbstractInjectiveJointOutput<MR, ERESULT, InjectiveJointGraphDerivation<MR, ERESULT>>
        implements IJointGraphOutput<MR, ERESULT> {
    private final IGraphParserOutput<MR> baseParserOutput;
    private final Map<ERESULT, ResultCell> resultCells;

    /**
     * All joint derivations that share the same execution result, together with their aggregated
     * log inside score and the log outside score set by the most recent filter.
     */
    public class ResultCell {
        /** Log-sum-exp over parses of (execution score + base parse log inside score). */
        private double logInsideScore = Double.NEGATIVE_INFINITY;
        /** 0.0 (log 1) if the result passed the last filter, negative infinity otherwise. */
        private double logOutsideScore = Double.NEGATIVE_INFINITY;
        private final List<InjectiveJointGraphDerivation<MR, ERESULT>> parses =
                new ArrayList<InjectiveJointGraphDerivation<MR, ERESULT>>();
        private final ERESULT result;

        /**
         * Creates a cell for {@code result} seeded with its first derivation.
         *
         * @param result the execution result shared by every parse in this cell
         * @param firstParse the first derivation that produced {@code result}
         */
        public ResultCell(ERESULT result, InjectiveJointGraphDerivation<MR, ERESULT> firstParse) {
            this.result = result;
            addParse(firstParse);
        }

        /**
         * Adds a derivation to this cell and folds its joint log score into the cell's inside
         * score.
         *
         * @param parse a derivation whose result equals this cell's result
         */
        public void addParse(InjectiveJointGraphDerivation<MR, ERESULT> parse) {
            parses.add(parse);
            // Joint log score = execution log score + base parse log inside score.
            logInsideScore = LogSumExp.of(logInsideScore,
                    parse.getExecResult().getScore() + parse.getBaseParse().getLogInsideScore());
        }

        /**
         * Sets the log outside score according to whether this cell's result passes the filter.
         *
         * @param filter decides which results count towards the expectation
         */
        public void initLogOutsideScore(IFilter<ERESULT> filter) {
            // log(1) = 0.0 keeps the full inside mass; log(0) = -inf discards it.
            logOutsideScore = filter.test(result) ? 0.0 : Double.NEGATIVE_INFINITY;
        }
    }

    /**
     * Builds the joint output and groups derivations by execution result.
     *
     * @param baseParserOutput the base graph parser output the derivations were built from
     * @param derivations joint derivations; each is placed in the cell keyed by its result
     * @param inferenceTime forwarded unchanged to the superclass (presumably inference time —
     *     confirm against {@code AbstractInjectiveJointOutput})
     * @param exactEvaluation forwarded to the superclass, AND-ed with
     *     {@code baseParserOutput.isExact()} so the output is exact only if the base output is
     */
    public InjectiveJointGraphOutput(IGraphParserOutput<MR> baseParserOutput,
            List<InjectiveJointGraphDerivation<MR, ERESULT>> derivations, long inferenceTime,
            boolean exactEvaluation) {
        super(derivations, inferenceTime, exactEvaluation && baseParserOutput.isExact());
        this.resultCells = new HashMap<ERESULT, ResultCell>();
        this.baseParserOutput = baseParserOutput;
        for (InjectiveJointGraphDerivation<MR, ERESULT> derivation : derivations) {
            final ERESULT result = derivation.getResult();
            final ResultCell cell = resultCells.get(result);
            if (cell == null) {
                resultCells.put(result, new ResultCell(result, derivation));
            } else {
                cell.addParse(derivation);
            }
        }
    }

    @Override // edu.cornell.cs.nlp.spf.parser.joint.IJointOutput
    public IGraphParserOutput<MR> getBaseParserOutput() {
        return this.baseParserOutput;
    }

    /** Log expected features with every execution result accepted. */
    @Override // edu.cornell.cs.nlp.spf.parser.joint.graph.IJointGraphOutput
    public IHashVector logExpectedFeatures() {
        return logExpectedFeatures(InjectiveJointGraphOutput.<ERESULT>acceptAll());
    }

    /**
     * Computes log expected features over joint derivations whose result passes {@code filter}.
     *
     * <p>Base-parser features come from {@code baseParserOutput.logExpectedFeatures}. It is
     * seeded with a per-category outside score that folds in the execution step. Execution-step
     * features are then added with weight inside + execution + outside (all in log space).
     * Side effect: resets every cell's outside score according to {@code filter}.
     *
     * @param filter selects which execution results contribute
     * @return the log expected feature vector produced by the base parser output, augmented with
     *     execution-step features
     */
    @Override // edu.cornell.cs.nlp.spf.parser.joint.graph.IJointGraphOutput
    public IHashVector logExpectedFeatures(IFilter<ERESULT> filter) {
        for (ResultCell cell : resultCells.values()) {
            cell.initLogOutsideScore(filter);
        }
        final Map<Category<MR>, Double> baseOutsideScores = computeBaseLogOutsideScores();
        final IHashVector logExpectedFeatures = baseParserOutput.logExpectedFeatures(
                new IScorer<Category<MR>>() {
                    @Override // edu.cornell.cs.nlp.utils.collections.IScorer
                    public double score(Category<MR> category) {
                        final Double score = baseOutsideScores.get(category);
                        return score == null ? Double.NEGATIVE_INFINITY : score.doubleValue();
                    }
                });
        addExecutionLogExpectedFeatures(logExpectedFeatures);
        return logExpectedFeatures;
    }

    /**
     * Maps each root category of a base parse to the log-sum-exp of (result outside score +
     * execution score) over all derivations using that category. Cells with a -inf outside score
     * contribute nothing and are skipped; absent categories are scored -inf by the caller.
     */
    private Map<Category<MR>, Double> computeBaseLogOutsideScores() {
        final Map<Category<MR>, Double> scores = new HashMap<Category<MR>, Double>();
        for (ResultCell cell : resultCells.values()) {
            if (cell.logOutsideScore == Double.NEGATIVE_INFINITY) {
                continue;
            }
            for (InjectiveJointGraphDerivation<MR, ERESULT> parse : cell.parses) {
                final Category<MR> category = parse.getBaseParse().getCategory();
                final double score = cell.logOutsideScore + parse.getExecResult().getScore();
                final Double previous = scores.get(category);
                scores.put(category,
                        previous == null ? score : LogSumExp.of(previous.doubleValue(), score));
            }
        }
        return scores;
    }

    /**
     * Adds each derivation's execution-step features into {@code logExpectedFeatures}. Each is
     * weighted by its full joint log weight, skipping cells with zero (-inf) outside mass.
     */
    private void addExecutionLogExpectedFeatures(IHashVector logExpectedFeatures) {
        for (ResultCell cell : resultCells.values()) {
            if (cell.logOutsideScore == Double.NEGATIVE_INFINITY) {
                continue;
            }
            for (InjectiveJointGraphDerivation<MR, ERESULT> parse : cell.parses) {
                // Product of inside, execution and outside weights is a SUM in log space.
                // (Multiplying log scores gave +inf/NaN for -inf outside scores.)
                final double logWeight = parse.getExecResult().getScore()
                        + parse.getBaseParse().getLogInsideScore() + cell.logOutsideScore;
                HashVectorUtils.logSumExpAdd(logWeight, parse.getExecResult().getFeatures(),
                        logExpectedFeatures);
            }
        }
    }

    /** Log normalizer with every execution result accepted. */
    @Override // edu.cornell.cs.nlp.spf.parser.joint.graph.IJointGraphOutput
    public double logNorm() {
        return logNorm(InjectiveJointGraphOutput.<ERESULT>acceptAll());
    }

    /**
     * Log-sum-exp of the inside scores of all result cells whose result passes {@code filter}.
     *
     * @param filter selects which execution results contribute
     * @return the log normalizer; the empty-selection behavior is whatever
     *     {@code LogSumExp.of} returns for an empty list
     */
    @Override // edu.cornell.cs.nlp.spf.parser.joint.graph.IJointGraphOutput
    public double logNorm(IFilter<ERESULT> filter) {
        final List<Double> logInsideScores = new ArrayList<Double>(resultCells.size());
        for (ResultCell cell : resultCells.values()) {
            if (filter.test(cell.result)) {
                logInsideScores.add(cell.logInsideScore);
            }
        }
        return LogSumExp.of(logInsideScores);
    }

    /** Returns a filter that accepts every element. */
    private static <T> IFilter<T> acceptAll() {
        return new IFilter<T>() {
            @Override // edu.cornell.cs.nlp.utils.filter.IFilter
            public boolean test(T element) {
                return true;
            }
        };
    }
}
