package edu.cornell.cs.nlp.spf.parser.ccg.cky.chart;

import edu.cornell.cs.nlp.spf.base.hashvector.HashVectorFactory;
import edu.cornell.cs.nlp.spf.base.hashvector.HashVectorUtils;
import edu.cornell.cs.nlp.spf.base.hashvector.IHashVector;
import edu.cornell.cs.nlp.spf.base.hashvector.IHashVectorImmutable;
import edu.cornell.cs.nlp.spf.ccg.categories.Category;
import edu.cornell.cs.nlp.spf.ccg.lexicon.LexicalEntry;
import edu.cornell.cs.nlp.spf.mr.language.type.RecursiveComplexType;
import edu.cornell.cs.nlp.spf.parser.RuleUsageTriplet;
import edu.cornell.cs.nlp.spf.parser.ccg.ILexicalParseStep;
import edu.cornell.cs.nlp.spf.parser.ccg.cky.steps.IWeightedCKYStep;
import edu.cornell.cs.nlp.spf.parser.ccg.rules.IArrayRuleNameSet;
import edu.cornell.cs.nlp.spf.parser.ccg.rules.RuleName;
import edu.cornell.cs.nlp.spf.parser.ccg.rules.Span;
import edu.cornell.cs.nlp.utils.composites.Pair;
import edu.cornell.cs.nlp.utils.log.ILogger;
import edu.cornell.cs.nlp.utils.log.LoggerFactory;
import edu.cornell.cs.nlp.utils.math.LogSumExp;
import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import org.apache.commons.cli.HelpFormatter;

/* loaded from: input_file:edu/cornell/cs/nlp/spf/parser/ccg/cky/chart/Cell.class */
public class Cell<MR> implements IArrayRuleNameSet {
    public static final ILogger LOG = LoggerFactory.create((Class<?>) Cell.class);

    private final Category<MR> category;
    private final int end;
    private final int hashCodeCache;
    private final boolean isCompleteSpan;
    private final boolean isFullParse;
    private boolean isMax;
    private final int start;
    /** Distinct rule names of {@link #steps}; built lazily and dropped whenever steps change. */
    private RuleName[] generatingRules = null;
    private double logInsideScore = Double.NEGATIVE_INFINITY;
    private double logOutsideScore = Double.NEGATIVE_INFINITY;
    /** Number of steps whose total score equals {@link #viterbiScore}. */
    private int numViterbiSteps = 0;
    private final Set<IWeightedCKYStep<MR>> steps = new HashSet<>();
    /**
     * Best total step score seen so far. It starts at {@code -Double.MAX_VALUE}, so a step
     * scoring negative infinity is never counted as a Viterbi step.
     */
    private double viterbiScore = -Double.MAX_VALUE;
    protected long numParses = 0;
    protected long numViterbiParses = 0;
    /** Lazily computed Viterbi steps; {@code null} means they must be recomputed. */
    protected List<IWeightedCKYStep<MR>> viterbiSteps = null;

    /**
     * Orders cells by ascending {@link Cell#getPruneScore()}, breaking ties with
     * {@link Cell#getSecondPruneScore()}.
     */
    public static class ScoreComparator<MR> implements Comparator<Cell<MR>>, Serializable {
        private static final long serialVersionUID = 5348011347391634770L;

        @Override
        public int compare(Cell<MR> cell, Cell<MR> cell2) {
            final int byPrune = Double.compare(cell.getPruneScore(), cell2.getPruneScore());
            if (byPrune != 0) {
                return byPrune;
            }
            return Double.compare(cell.getSecondPruneScore(), cell2.getSecondPruneScore());
        }
    }

    /**
     * Creates a cell that holds a single step and initializes its scores from that step. The
     * decompiler reports that this constructor was originally {@code protected}.
     *
     * @param firstStep the initial step; its root category, span and full-parse flag define
     *            the cell
     * @param isCompleteSpan value reported by {@link #isCompleteSpan()}
     */
    public Cell(IWeightedCKYStep<MR> firstStep, boolean isCompleteSpan) {
        this.isCompleteSpan = isCompleteSpan;
        this.isFullParse = firstStep.isFullParse();
        this.category = firstStep.getRoot();
        this.start = firstStep.getStart();
        this.end = firstStep.getEnd();
        this.steps.add(firstStep);
        updateScores(firstStep);
        this.hashCodeCache = calcHashCode();
    }

    /**
     * Typed view of a step's child cells. Steps are iterated to reach their children (every
     * element is used as a {@code Cell}). The double cast only fixes the static element type
     * that the decompiled signature lost.
     */
    @SuppressWarnings("unchecked")
    private static <MR> Iterable<Cell<MR>> childrenOf(IWeightedCKYStep<MR> step) {
        return (Iterable<Cell<MR>>) (Iterable<?>) step;
    }

    /** Returns the lexical entry of a step already known to be an {@link ILexicalParseStep}. */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    private static <MR> LexicalEntry<MR> lexicalEntryOf(IWeightedCKYStep<MR> step) {
        return ((ILexicalParseStep) step).getLexicalEntry();
    }

    /**
     * Merges the steps of {@code cell} into this cell. Steps already present are ignored.
     *
     * @return {@code true} if at least one newly added step ties or improves the Viterbi score
     */
    public boolean addCell(Cell<MR> cell) {
        boolean addedViterbiStep = false;
        for (final IWeightedCKYStep<MR> step : cell.steps) {
            assert step.getStart() == start && step.getEnd() == end;
            if (steps.add(step)) {
                // A new step invalidates the cached derived data.
                viterbiSteps = null;
                generatingRules = null;
                // updateScores must run for every new step, so it is evaluated first.
                addedViterbiStep = updateScores(step) || addedViterbiStep;
            }
        }
        return addedViterbiStep;
    }

    /**
     * Averages, over the Viterbi steps, the step features plus the (recursively averaged)
     * features of their children.
     */
    public IHashVector computeMaxAvgFeaturesRecursively() {
        return computeMaxAvgFeaturesRecursively(new HashMap<>());
    }

    /** Two cells are equal when they have the same span and the same category. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        final Cell<?> other = (Cell<?>) obj;
        return start == other.start && category.equals(other.category) && end == other.end;
    }

    /** Collects the lexical entries of all steps reachable from this cell, without duplicates. */
    public LinkedHashSet<LexicalEntry<MR>> getAllLexicalEntriesRecursively() {
        final LinkedHashSet<LexicalEntry<MR>> entries = new LinkedHashSet<>();
        recursiveGetLexicalEntries(entries, new HashSet<>(), false);
        return entries;
    }

    /** Collects all steps reachable from this cell, children before parents. */
    public Set<IWeightedCKYStep<MR>> getAllSteps() {
        final LinkedHashSet<IWeightedCKYStep<MR>> collected = new LinkedHashSet<>();
        recursiveGetParseSteps(collected, new HashSet<>(), false);
        return collected;
    }

    public Category<MR> getCategory() {
        return category;
    }

    public int getEnd() {
        return end;
    }

    public double getLogInsideScore() {
        return logInsideScore;
    }

    /** Collects the lexical entries reachable through Viterbi steps only. */
    public LinkedHashSet<LexicalEntry<MR>> getMaxLexicalEntriesRecursively() {
        final LinkedHashSet<LexicalEntry<MR>> entries = new LinkedHashSet<>();
        recursiveGetLexicalEntries(entries, new HashSet<>(), true);
        return entries;
    }

    /** Collects the rule usages (rule name plus child spans) of all reachable Viterbi steps. */
    public LinkedHashSet<RuleUsageTriplet> getMaxRulesUsedRecursively() {
        final LinkedHashSet<RuleUsageTriplet> usages = new LinkedHashSet<>();
        recursiveGetMaxRulesUsed(usages, new HashSet<>());
        return usages;
    }

    /** Collects all Viterbi steps reachable from this cell, children before parents. */
    public LinkedHashSet<IWeightedCKYStep<MR>> getMaxSteps() {
        final LinkedHashSet<IWeightedCKYStep<MR>> collected = new LinkedHashSet<>();
        recursiveGetParseSteps(collected, new HashSet<>(), true);
        return collected;
    }

    public long getNumParses() {
        return numParses;
    }

    public long getNumViterbiParses() {
        return numViterbiParses;
    }

    /** Primary pruning score; here it is the Viterbi score. */
    public double getPruneScore() {
        return viterbiScore;
    }

    @Override
    public RuleName getRuleName(int i) {
        return ruleNames()[i];
    }

    /** Tie-breaking pruning score; by default identical to {@link #getPruneScore()}. */
    public double getSecondPruneScore() {
        return getPruneScore();
    }

    public int getStart() {
        return start;
    }

    /** Returns a read-only view of this cell's steps. */
    public Set<IWeightedCKYStep<MR>> getSteps() {
        return Collections.unmodifiableSet(steps);
    }

    /** Returns the lexical entries of this cell's own Viterbi steps (not recursive). */
    public Set<LexicalEntry<MR>> getViterbiLexicalEntries() {
        final Set<LexicalEntry<MR>> entries = new HashSet<>();
        for (final IWeightedCKYStep<MR> step : getViterbiSteps()) {
            if (step instanceof ILexicalParseStep) {
                entries.add(lexicalEntryOf(step));
            }
        }
        return entries;
    }

    public double getViterbiScore() {
        return viterbiScore;
    }

    /** Returns a read-only list of the steps whose total score equals the Viterbi score. */
    public List<IWeightedCKYStep<MR>> getViterbiSteps() {
        if (viterbiSteps == null) {
            computeViterbiSteps();
        }
        return Collections.unmodifiableList(viterbiSteps);
    }

    @Override
    public int hashCode() {
        return hashCodeCache;
    }

    /** Whether any Viterbi step of this cell is lexical. */
    public boolean hasLexicalMaxStep() {
        for (final IWeightedCKYStep<MR> step : getViterbiSteps()) {
            if (step instanceof ILexicalParseStep) {
                return true;
            }
        }
        return false;
    }

    /** Whether any step of this cell is lexical. */
    public boolean hasLexicalStep() {
        for (final IWeightedCKYStep<MR> step : steps) {
            if (step instanceof ILexicalParseStep) {
                return true;
            }
        }
        return false;
    }

    public boolean isCompleteSpan() {
        return isCompleteSpan;
    }

    public boolean isFullParse() {
        return isFullParse;
    }

    public boolean isMax() {
        return isMax;
    }

    @Override
    public int numRuleNames() {
        return ruleNames().length;
    }

    public int numSteps() {
        return steps.size();
    }

    @Override
    public String toString() {
        return toString(false, null, true, null);
    }

    /**
     * Renders the cell header (span, optional tag, category, scores, counts) followed by its
     * steps.
     *
     * @param forwardedFlag passed through as the second argument of each step's
     *            {@code toString}
     * @param tag optional label printed before the category; {@code null} omits it
     * @param viterbiOnly print only Viterbi steps; otherwise print all steps and prefix the
     *            Viterbi ones with a marker
     * @param vector passed through to each step's {@code toString}; may be {@code null}
     */
    public String toString(boolean forwardedFlag, String tag, boolean viterbiOnly,
            IHashVectorImmutable vector) {
        final Object pruneText = getPruneScore() == getSecondPruneScore()
                ? Double.valueOf(getPruneScore())
                : String.format("(%f,%f)", Double.valueOf(getPruneScore()),
                        Double.valueOf(getSecondPruneScore()));
        final StringBuilder sb = new StringBuilder();
        sb.append("[");
        sb.append(start).append("-").append(end).append(" : ")
                .append(tag == null ? "" : tag).append(tag == null ? "" : " :- ")
                .append(category).append(" : ")
                .append("prune=").append(pruneText).append(" : ")
                .append("numParses=").append(numParses).append(" : ")
                .append("numViterbiParses=").append(numViterbiParses).append(" : ")
                .append("hash=").append(hashCode()).append(" : ")
                .append(steps.size()).append(" : ")
                .append(viterbiScore).append(" : ");
        sb.append("[");
        boolean first = true;
        if (viterbiOnly) {
            for (final IWeightedCKYStep<MR> step : getViterbiSteps()) {
                if (!first) {
                    sb.append(", ");
                }
                first = false;
                sb.append(step.toString(true, forwardedFlag, vector));
            }
        } else {
            // Hoisted: the Viterbi list does not change while rendering.
            final List<IWeightedCKYStep<MR>> maxSteps = getViterbiSteps();
            for (final IWeightedCKYStep<MR> step : steps) {
                if (!first) {
                    sb.append(", ");
                }
                first = false;
                if (maxSteps.contains(step)) {
                    // NOTE(review): decompiler-resolved marker constant (presumably "*");
                    // kept as-is to preserve output.
                    sb.append(RecursiveComplexType.Option.DOMAIN_REPEAT_ORDER_INSENSITIVE);
                }
                sb.append(step.toString(true, forwardedFlag, vector));
            }
        }
        sb.append("]");
        sb.append("]");
        return sb.toString();
    }

    private int calcHashCode() {
        int result = 1;
        result = 31 * result + start;
        result = 31 * result + (category == null ? 0 : category.hashCode());
        result = 31 * result + end;
        return result;
    }

    /** Memoized worker for {@link #computeMaxAvgFeaturesRecursively()}. */
    private IHashVector computeMaxAvgFeaturesRecursively(Map<Cell<MR>, IHashVector> memo) {
        final IHashVector cached = memo.get(this);
        if (cached != null) {
            return cached;
        }
        final IHashVector features = HashVectorFactory.create();
        int numMaxSteps = 0;
        for (final IWeightedCKYStep<MR> step : getViterbiSteps()) {
            for (final Cell<MR> child : childrenOf(step)) {
                child.computeMaxAvgFeaturesRecursively(memo).addTimesInto(1.0d, features);
            }
            step.getStepFeatures().addTimesInto(1.0d, features);
            numMaxSteps++;
        }
        if (numMaxSteps > 1) {
            features.divideBy(numMaxSteps);
        }
        memo.put(this, features);
        return features;
    }

    /** Rebuilds {@link #viterbiSteps} from the steps whose total score equals the max. */
    private void computeViterbiSteps() {
        final List<IWeightedCKYStep<MR>> found = new ArrayList<>(numViterbiSteps);
        for (final IWeightedCKYStep<MR> step : steps) {
            // Same summation order as updateScores, so exact equality is meaningful.
            double stepScore = step.getStepScore();
            for (final Cell<MR> child : childrenOf(step)) {
                stepScore += child.getViterbiScore();
            }
            assert !(stepScore > viterbiScore);
            if (stepScore == viterbiScore) {
                found.add(step);
            }
        }
        assert found.size() == numViterbiSteps;
        viterbiSteps = found;
    }

    private void createGeneratingRules() {
        final Set<RuleName> names = new HashSet<>();
        for (final IWeightedCKYStep<MR> step : steps) {
            names.add(step.getRuleName());
        }
        generatingRules = names.toArray(new RuleName[names.size()]);
    }

    /** Returns the generating rule names, building them on first use. */
    private RuleName[] ruleNames() {
        if (generatingRules == null) {
            createGeneratingRules();
        }
        return generatingRules;
    }

    /**
     * Depth-first collection of lexical entries. A cell is marked visited on entry, so
     * revisiting it during its own traversal (a cyclic unary step) terminates instead of
     * recursing forever. On acyclic charts the result is unchanged.
     */
    private void recursiveGetLexicalEntries(LinkedHashSet<LexicalEntry<MR>> entries,
            Set<Cell<MR>> visited, boolean viterbiOnly) {
        if (!visited.add(this)) {
            return;
        }
        for (final IWeightedCKYStep<MR> step : viterbiOnly ? getViterbiSteps() : steps) {
            if (step instanceof ILexicalParseStep) {
                entries.add(lexicalEntryOf(step));
            }
            for (final Cell<MR> child : childrenOf(step)) {
                child.recursiveGetLexicalEntries(entries, visited, viterbiOnly);
            }
        }
    }

    /**
     * Depth-first collection of Viterbi rule usages. A lexical step (no children) records this
     * cell's own span. The cell is marked visited on entry (see recursiveGetLexicalEntries).
     */
    private void recursiveGetMaxRulesUsed(LinkedHashSet<RuleUsageTriplet> usages,
            HashSet<Cell<MR>> visited) {
        if (!visited.add(this)) {
            return;
        }
        for (final IWeightedCKYStep<MR> step : getViterbiSteps()) {
            final ArrayList<Pair<Integer, Integer>> spans = new ArrayList<>();
            for (final Cell<MR> child : childrenOf(step)) {
                child.recursiveGetMaxRulesUsed(usages, visited);
                spans.add(Pair.of(Integer.valueOf(child.getStart()),
                        Integer.valueOf(child.getEnd())));
            }
            if (spans.isEmpty()) {
                spans.add(Pair.of(Integer.valueOf(start), Integer.valueOf(end)));
            }
            usages.add(new RuleUsageTriplet(step.getRuleName(), spans));
        }
    }

    /**
     * Depth-first collection of steps. Children are visited before this cell's steps are
     * added, so the result is ordered bottom-up. The cell is marked visited on entry.
     */
    private void recursiveGetParseSteps(LinkedHashSet<IWeightedCKYStep<MR>> collected,
            HashSet<Cell<MR>> visited, boolean viterbiOnly) {
        if (!visited.add(this)) {
            return;
        }
        for (final IWeightedCKYStep<MR> step : viterbiOnly ? getViterbiSteps() : steps) {
            for (final Cell<MR> child : childrenOf(step)) {
                child.recursiveGetParseSteps(collected, visited, viterbiOnly);
            }
        }
        collected.addAll(viterbiOnly ? getViterbiSteps() : steps);
    }

    /**
     * Folds one step into the cell's scores and counts: log-sum-exp into the inside score, parse
     * counts (products over children), and the Viterbi score/step/parse counts.
     *
     * @return {@code true} if the step ties or improves the Viterbi score
     */
    private boolean updateScores(IWeightedCKYStep<MR> step) {
        LOG.debug("Updating score from: %s", step);
        final double localScore = step.getStepScore();
        LOG.debug("Step local score: %f", Double.valueOf(localScore));
        double viterbi = localScore;
        double inside = localScore;
        long parses = 1;
        long viterbiParses = 1;
        for (final Cell<MR> child : childrenOf(step)) {
            inside += child.getLogInsideScore();
            viterbi += child.getViterbiScore();
            parses *= child.numParses;
            viterbiParses *= child.numViterbiParses;
        }
        LOG.debug("Step viterbi score: %f", Double.valueOf(viterbi));
        LOG.debug("Step contribution to inside score: %f", Double.valueOf(inside));
        LOG.debug("# parses in step: %d", Long.valueOf(parses));
        LOG.debug("# viterbi parses in step: %d", Long.valueOf(viterbiParses));
        logInsideScore = LogSumExp.of(logInsideScore, inside);
        numParses += parses;
        if (viterbi == viterbiScore) {
            LOG.debug("Step is a viterbi step");
            numViterbiParses += viterbiParses;
            numViterbiSteps++;
            return true;
        }
        if (viterbi <= viterbiScore) {
            return false;
        }
        LOG.debug("Step re-set viterbi score, step is a viterbi step");
        viterbiScore = viterbi;
        numViterbiParses = viterbiParses;
        numViterbiSteps = 1;
        return true;
    }

    /**
     * Adds, in log space, each step's features weighted by outside(this) + step score + the
     * children's inside scores into {@code logExpectedFeatures}. This is a no-op while the
     * outside score is negative infinity.
     */
    public void collectLogExpectedFeatures(IHashVector logExpectedFeatures) {
        if (logOutsideScore == Double.NEGATIVE_INFINITY) {
            return;
        }
        for (final IWeightedCKYStep<MR> step : steps) {
            double score = logOutsideScore + step.getStepScore();
            for (final Cell<MR> child : childrenOf(step)) {
                score += child.logInsideScore;
            }
            HashVectorUtils.logSumExpAdd(score, step.getStepFeatures(), logExpectedFeatures);
        }
    }

    /**
     * Seeds the log outside score: {@code initializer(category)} when this cell covers exactly
     * {@code span}, negative infinity otherwise.
     */
    public void initializeLogOutsideProbabilities(Function<Category<MR>, Double> initializer,
            Span span) {
        final boolean coversSpan = span.getStart() == start && span.getEnd() == end;
        logOutsideScore = coversSpan
                ? initializer.apply(category).doubleValue()
                : Double.NEGATIVE_INFINITY;
    }

    /**
     * Clears all derived scores and counts and recomputes them by replaying every step.
     * {@code numViterbiSteps} is reset as well. Otherwise, when no step beats
     * {@code -Double.MAX_VALUE}, the stale count would survive and break
     * {@link #computeViterbiSteps()}.
     */
    public void recomputeScores() {
        numParses = 0L;
        numViterbiParses = 0L;
        numViterbiSteps = 0;
        viterbiScore = -Double.MAX_VALUE;
        logInsideScore = Double.NEGATIVE_INFINITY;
        viterbiSteps = null;
        generatingRules = null;
        for (final IWeightedCKYStep<MR> step : steps) {
            updateScores(step);
        }
    }

    public void setIsMax(boolean isMax) {
        this.isMax = isMax;
    }

    /**
     * Propagates this cell's log outside score to both children of every binary step. Each
     * child receives outside(this) + step score + inside(sibling).
     */
    public void updateBinaryChildrenLogOutsideScore() {
        if (logOutsideScore == Double.NEGATIVE_INFINITY) {
            return;
        }
        for (final IWeightedCKYStep<MR> step : steps) {
            if (step.numChildren() == 2) {
                final double stepScore = step.getStepScore();
                final Cell<MR> left = step.getChildCell(0);
                final Cell<MR> right = step.getChildCell(1);
                left.logOutsideScore = LogSumExp.of(left.logOutsideScore,
                        logOutsideScore + right.getLogInsideScore() + stepScore);
                right.logOutsideScore = LogSumExp.of(right.logOutsideScore,
                        logOutsideScore + left.getLogInsideScore() + stepScore);
            }
        }
    }

    /**
     * Propagates this cell's log outside score (plus step score) to the child of every unary
     * step.
     */
    public void updateUnaryChildrenLogOutsideScore() {
        for (final IWeightedCKYStep<MR> step : steps) {
            if (step.numChildren() == 1) {
                final Cell<MR> child = step.getChildCell(0);
                child.logOutsideScore = LogSumExp.of(child.logOutsideScore,
                        logOutsideScore + step.getStepScore());
            }
        }
    }
}
