package edu.cornell.cs.nlp.spf.parser.joint.graph;

import edu.cornell.cs.nlp.spf.base.hashvector.HashVectorUtils;
import edu.cornell.cs.nlp.spf.base.hashvector.IHashVector;
import edu.cornell.cs.nlp.spf.ccg.categories.Category;
import edu.cornell.cs.nlp.spf.parser.graph.IGraphDerivation;
import edu.cornell.cs.nlp.spf.parser.graph.IGraphParserOutput;
import edu.cornell.cs.nlp.spf.parser.joint.AbstractJointOutput;
import edu.cornell.cs.nlp.spf.parser.joint.IEvaluation;
import edu.cornell.cs.nlp.spf.parser.joint.graph.JointGraphDerivation;
import edu.cornell.cs.nlp.utils.collections.IScorer;
import edu.cornell.cs.nlp.utils.collections.ListUtils;
import edu.cornell.cs.nlp.utils.composites.Pair;
import edu.cornell.cs.nlp.utils.filter.FilterUtils;
import edu.cornell.cs.nlp.utils.filter.IFilter;
import edu.cornell.cs.nlp.utils.math.LogSumExp;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * Joint output that combines a graph-based base parser output with an evaluation step.
 *
 * <p>Each {@link JointGraphDerivation} in this output groups every (base derivation, evaluation)
 * inference pair that produced the same evaluation result (see {@link Builder#build()}).
 *
 * @param <MR> meaning representation type of the base parser
 * @param <ERESULT> type of the evaluation result
 */
public class JointGraphOutput<MR, ERESULT>
        extends AbstractJointOutput<MR, ERESULT, JointGraphDerivation<MR, ERESULT>>
        implements IJointGraphOutput<MR, ERESULT> {

    /** The base graph parser output the joint derivations were built from. */
    private final IGraphParserOutput<MR> baseOutput;

    /**
     * Accumulates (base derivation, evaluation) inference pairs and builds a
     * {@link JointGraphOutput} from them.
     */
    public static class Builder<MR, ERESULT> {
        private final IGraphParserOutput<MR> baseOutput;
        private boolean exactEvaluation = false;
        private final List<Pair<IGraphDerivation<MR>, IEvaluation<ERESULT>>> inferencePairs =
                new ArrayList<>();
        private final long inferenceTime;

        /**
         * @param baseOutput base graph parser output the inference pairs refer to
         * @param inferenceTime inference time, forwarded unchanged to the built output
         */
        public Builder(IGraphParserOutput<MR> baseOutput, long inferenceTime) {
            this.baseOutput = baseOutput;
            this.inferenceTime = inferenceTime;
        }

        /** Adds a single (base derivation, evaluation) pair. */
        public Builder<MR, ERESULT> addInferencePair(
                Pair<IGraphDerivation<MR>, IEvaluation<ERESULT>> pair) {
            this.inferencePairs.add(pair);
            return this;
        }

        /** Adds all given (base derivation, evaluation) pairs, preserving their order. */
        public Builder<MR, ERESULT> addInferencePairs(
                List<Pair<IGraphDerivation<MR>, IEvaluation<ERESULT>>> list) {
            this.inferencePairs.addAll(list);
            return this;
        }

        /**
         * Builds the output by grouping the accumulated pairs by evaluation result. Each distinct
         * result becomes one {@link JointGraphDerivation}.
         *
         * @return a new output whose derivation list is unmodifiable and ordered by the first
         *     occurrence of each evaluation result
         */
        public JointGraphOutput<MR, ERESULT> build() {
            // LinkedHashMap keeps results in first-seen order so the derivation order is
            // deterministic. A plain HashMap's order would depend on ERESULT.hashCode().
            final Map<ERESULT, JointGraphDerivation.Builder<MR, ERESULT>> buildersByResult =
                    new LinkedHashMap<>();
            for (final Pair<IGraphDerivation<MR>, IEvaluation<ERESULT>> pair : inferencePairs) {
                final ERESULT result = pair.second().getResult();
                JointGraphDerivation.Builder<MR, ERESULT> builder = buildersByResult.get(result);
                if (builder == null) {
                    builder = new JointGraphDerivation.Builder<MR, ERESULT>(result);
                    buildersByResult.put(result, builder);
                }
                builder.addInferencePair(pair);
            }

            final List<JointGraphDerivation<MR, ERESULT>> derivations =
                    new ArrayList<>(buildersByResult.size());
            for (final JointGraphDerivation.Builder<MR, ERESULT> builder : buildersByResult.values()) {
                derivations.add(builder.build());
            }
            return new JointGraphOutput<>(baseOutput, inferenceTime,
                    Collections.unmodifiableList(derivations), exactEvaluation);
        }

        /**
         * Marks whether the evaluation step was exact. The built output is exact only if this
         * flag is set AND the base output reports {@code isExact()}.
         */
        public Builder<MR, ERESULT> setExactEvaluation(boolean z) {
            this.exactEvaluation = z;
            return this;
        }
    }

    /**
     * @param baseOutput base graph parser output
     * @param inferenceTime inference time, passed to the superclass
     * @param derivations joint derivations, passed to the superclass as-is
     * @param exactEvaluation whether the evaluation step was exact; the output is reported as
     *     exact only when this is true and {@code baseOutput.isExact()} is true
     */
    public JointGraphOutput(IGraphParserOutput<MR> baseOutput, long inferenceTime,
            List<JointGraphDerivation<MR, ERESULT>> derivations, boolean exactEvaluation) {
        super(baseOutput, inferenceTime, derivations, exactEvaluation && baseOutput.isExact());
        this.baseOutput = baseOutput;
    }

    @Override
    public IGraphParserOutput<MR> getBaseParserOutput() {
        return this.baseOutput;
    }

    /** Equivalent to {@link #logExpectedFeatures(IFilter)} with a filter that accepts everything. */
    @Override
    public IHashVector logExpectedFeatures() {
        // Explicit witness: Java 7 does not infer ERESULT from the parameter's target type.
        return logExpectedFeatures(FilterUtils.<ERESULT>stubTrue());
    }

    /**
     * Computes log expected feature values over the derivations whose evaluation result passes
     * {@code filter}.
     *
     * <p>Two contributions are combined:
     * <ol>
     * <li>The base output's log expected features. They are computed with a scorer that gives
     * each root category the log-sum-exp of the evaluation scores of all selected inference pairs
     * rooted at that category, and negative infinity to any other category. (Presumably the base
     * output uses these as initial log outside scores; confirm against
     * {@code IGraphParserOutput}.)</li>
     * <li>For every selected inference pair, the evaluation features weighted in log space by the
     * evaluation score plus the base derivation's log inside score.</li>
     * </ol>
     *
     * @param filter selects which evaluation results participate
     * @return the vector returned by the base output, after the evaluation contributions have been
     *     accumulated into it via {@link HashVectorUtils#logSumExpAdd}
     */
    @Override
    public IHashVector logExpectedFeatures(IFilter<ERESULT> filter) {
        final List<JointGraphDerivation<MR, ERESULT>> selected = filterDerivations(filter);

        // Aggregate evaluation scores per root category of the base derivations.
        final Map<Category<MR>, Double> rootLogScores = new HashMap<>();
        for (final JointGraphDerivation<MR, ERESULT> derivation : selected) {
            for (final Pair<IGraphDerivation<MR>, ? extends IEvaluation<ERESULT>> pair : derivation
                    .getInferencePairs()) {
                final Category<MR> category = pair.first().getCategory();
                final double score = pair.second().getScore();
                final Double previous = rootLogScores.get(category);
                rootLogScores.put(category,
                        previous == null ? score : LogSumExp.of(previous.doubleValue(), score));
            }
        }

        // Expectation from the base parser output.
        final IHashVector logExpectedFeatures = baseOutput
                .logExpectedFeatures(new IScorer<Category<MR>>() {
                    @Override
                    public double score(Category<MR> category) {
                        final Double logScore = rootLogScores.get(category);
                        return logScore == null ? Double.NEGATIVE_INFINITY
                                : logScore.doubleValue();
                    }
                });

        // Expectation from the evaluation step.
        for (final JointGraphDerivation<MR, ERESULT> derivation : selected) {
            for (final Pair<IGraphDerivation<MR>, ? extends IEvaluation<ERESULT>> pair : derivation
                    .getInferencePairs()) {
                final double evaluationLogInsideScore = pair.second().getScore()
                        + pair.first().getLogInsideScore();
                // Derivations that pass the filter implicitly get outside score log(1.0) = 0.
                // It is kept explicit for clarity even though adding it changes nothing.
                final double evaluationLogOutsideScore = 0.0;
                HashVectorUtils.logSumExpAdd(evaluationLogInsideScore + evaluationLogOutsideScore,
                        pair.second().getFeatures(), logExpectedFeatures);
            }
        }
        return logExpectedFeatures;
    }

    /** Equivalent to {@link #logNorm(IFilter)} with a filter that accepts everything. */
    @Override
    public double logNorm() {
        // Explicit witness: Java 7 does not infer ERESULT from the parameter's target type.
        return logNorm(FilterUtils.<ERESULT>stubTrue());
    }

    /**
     * Returns the log-sum-exp of the log inside scores of all derivations whose evaluation result
     * passes {@code filter}.
     */
    @Override
    public double logNorm(IFilter<ERESULT> filter) {
        final List<Double> logInsideScores = new ArrayList<>(derivations.size());
        for (final JointGraphDerivation<MR, ERESULT> derivation : filterDerivations(filter)) {
            logInsideScores.add(derivation.getLogInsideScore());
        }
        return LogSumExp.of(logInsideScores);
    }

    /** Returns, in their original order, the derivations whose result passes {@code filter}. */
    private List<JointGraphDerivation<MR, ERESULT>> filterDerivations(IFilter<ERESULT> filter) {
        final List<JointGraphDerivation<MR, ERESULT>> selected = new ArrayList<>();
        for (final JointGraphDerivation<MR, ERESULT> derivation : derivations) {
            if (filter.test(derivation.getResult())) {
                selected.add(derivation);
            }
        }
        return selected;
    }
}
