package edu.cornell.cs.nlp.spf.parser.ccg.cky.chart;

import edu.cornell.cs.nlp.spf.base.hashvector.HashVectorFactory;
import edu.cornell.cs.nlp.spf.base.hashvector.IHashVector;
import edu.cornell.cs.nlp.spf.base.hashvector.IHashVectorImmutable;
import edu.cornell.cs.nlp.spf.base.token.TokenSeq;
import edu.cornell.cs.nlp.spf.ccg.categories.Category;
import edu.cornell.cs.nlp.spf.parser.ccg.cky.CKYDerivation;
import edu.cornell.cs.nlp.spf.parser.ccg.cky.chart.Cell;
import edu.cornell.cs.nlp.spf.parser.ccg.cky.steps.IWeightedCKYStep;
import edu.cornell.cs.nlp.spf.parser.ccg.rules.Span;
import edu.cornell.cs.nlp.utils.collections.CollectionUtils;
import edu.cornell.cs.nlp.utils.collections.IScorer;
import edu.cornell.cs.nlp.utils.collections.ListUtils;
import edu.cornell.cs.nlp.utils.collections.iterators.CompositeIterator;
import edu.cornell.cs.nlp.utils.collections.queue.DirectAccessBoundedPriorityQueue;
import edu.cornell.cs.nlp.utils.collections.queue.IDirectAccessBoundedPriorityQueue;
import edu.cornell.cs.nlp.utils.collections.queue.OrderInvariantDirectAccessBoundedQueue;
import edu.cornell.cs.nlp.utils.composites.Pair;
import edu.cornell.cs.nlp.utils.filter.IFilter;
import edu.cornell.cs.nlp.utils.log.ILogger;
import edu.cornell.cs.nlp.utils.log.LoggerFactory;
import edu.cornell.cs.nlp.utils.math.LogSumExp;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;

/* loaded from: input_file:edu/cornell/cs/nlp/spf/parser/ccg/cky/chart/Chart.class */
/**
 * Chart for CKY parsing.
 *
 * <p>Cells live in an upper-triangular matrix of spans. The entry at {@code [start][end]} uses
 * inclusive token indices with {@code start <= end}. It holds the retained {@link Cell}s whose
 * {@link Cell#getStart()} and {@link Cell#getEnd()} equal those indices. The number of
 * queue-managed cells per span is bounded by the beam size given to the constructor.
 *
 * @param <MR> meaning representation type
 */
public class Chart<MR> implements Iterable<Cell<MR>> {
    public static final ILogger LOG = LoggerFactory.create(Chart.class.getName());

    private final int beamSize;
    private final AbstractCellFactory<MR> cellFactory;
    /** Span matrix; only entries with {@code start <= end} are populated. */
    private final AbstractSpan<MR>[][] chart;
    private final int sentenceLength;
    private final TokenSeq tokens;

    /**
     * Storage for the cells of a single chart span.
     *
     * @param <MR> meaning representation type
     */
    public static abstract class AbstractSpan<MR> implements Iterable<Cell<MR>> {
        /** Set by {@link Chart#externalPruning(int, int)}; only consulted by {@link #isPruned()}. */
        protected boolean externallyPruned;

        private AbstractSpan() {
            this.externallyPruned = false;
        }

        /**
         * Merges {@code newCell} into {@code existingCell}, a cell already stored in this span.
         */
        public abstract void addToExisting(Cell<MR> existingCell, Cell<MR> newCell);

        /** Returns the stored cell equal to {@code cell}, or {@code null} if there is none. */
        public abstract Cell<MR> get(Cell<MR> cell);

        /** Returns true if the span was pruned externally or its bounded queue reports pruning. */
        public abstract boolean isPruned();

        /**
         * Returns the (prune score, second prune score) pair of the cell at the head of the
         * span's bounded queue, or {@code null} if that queue is empty. Within this file it is only
         * used for debug logging.
         */
        public abstract Pair<Double, Double> minQeueuScore();

        /** Offers a new cell to the span; returns true if the span accepted it. */
        public abstract boolean offer(Cell<MR> cell);

        /** Returns the number of cells currently stored in the span. */
        public abstract int size();
    }

    /**
     * Iterates over the cells of every span {@code [s][e]} with {@code start <= s <= e < end},
     * in row-major order (by {@code s}, then by {@code e}). Within a span, cells are returned in
     * the span's own iteration order, or sorted by the comparator when one is given.
     */
    public class CellIterator implements Iterator<Cell<MR>> {
        private final Comparator<Cell<MR>> comparator;
        private final int end;
        private int i;
        private int j;
        private Iterator<Cell<MR>> spanIterator;
        /** Iterator that produced the last element returned by {@link #next()}, for remove(). */
        private Iterator<Cell<MR>> lastReturnedFrom;

        /**
         * @param start first start index (inclusive)
         * @param end exclusive upper bound for both start and end indices
         * @param comparator per-span ordering, or {@code null} for the span's natural order
         */
        public CellIterator(int start, int end, Comparator<Cell<MR>> comparator) {
            this.end = end;
            this.i = start;
            this.j = start;
            this.comparator = comparator;
            if (start < end) {
                this.spanIterator = openSpan(start, start);
            } else {
                // Empty range: yield nothing, and never index the chart (it may even be 0x0).
                this.spanIterator = Collections.<Cell<MR>>emptyIterator();
            }
        }

        @Override
        public boolean hasNext() {
            if (spanIterator.hasNext()) {
                return true;
            }
            loadNextIteratorIfAvailable();
            return spanIterator.hasNext();
        }

        @Override
        public Cell<MR> next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            lastReturnedFrom = spanIterator;
            return spanIterator.next();
        }

        /**
         * Removes the last returned cell through the span iterator that produced it. That may not
         * be the current span iterator if {@link #hasNext()} has already advanced to a new span.
         */
        @Override
        public void remove() {
            if (lastReturnedFrom == null) {
                throw new IllegalStateException("next() has not been called");
            }
            lastReturnedFrom.remove();
            lastReturnedFrom = null;
        }

        /** Opens the iterator for span {@code [spanStart][spanEnd]}, honoring the comparator. */
        private Iterator<Cell<MR>> openSpan(int spanStart, int spanEnd) {
            return comparator == null
                    ? Chart.this.getSpanIterator(spanStart, spanEnd)
                    : Chart.this.getSpanIterator(spanStart, spanEnd, comparator);
        }

        /** Advances through spans until one has cells or the range is exhausted. */
        private void loadNextIteratorIfAvailable() {
            while (!spanIterator.hasNext()) {
                j++;
                if (j >= end) {
                    // Row exhausted: move on to the next start index.
                    i++;
                    if (i >= end) {
                        return;
                    }
                    j = i;
                }
                spanIterator = openSpan(i, j);
            }
        }
    }

    /** Span that keeps every cell in a single bounded queue. */
    private static class SingleQueueSpan<MR> extends AbstractSpan<MR> {
        private final IDirectAccessBoundedPriorityQueue<Cell<MR>> queue;

        /**
         * @param capacity queue bound (the chart's beam size)
         * @param orderInvariant whether to use {@link OrderInvariantDirectAccessBoundedQueue}
         *        instead of {@link DirectAccessBoundedPriorityQueue}
         */
        public SingleQueueSpan(int capacity, boolean orderInvariant) {
            super();
            this.queue = orderInvariant
                    ? new OrderInvariantDirectAccessBoundedQueue<>(capacity, new Cell.ScoreComparator())
                    : new DirectAccessBoundedPriorityQueue<>(capacity, new Cell.ScoreComparator());
        }

        /**
         * Merges {@code newCell} into {@code existingCell}. When {@code addCell} reports a change,
         * the cell is removed and re-added so the score-ordered queue repositions it.
         */
        @Override
        public void addToExisting(Cell<MR> existingCell, Cell<MR> newCell) {
            if (existingCell.addCell(newCell)) {
                if (!queue.remove(existingCell)) {
                    throw new IllegalStateException("Failed to remove existing cell -- this is a bug");
                }
                queue.add(existingCell);
            }
        }

        @Override
        public Cell<MR> get(Cell<MR> cell) {
            return queue.get(cell);
        }

        @Override
        public boolean isPruned() {
            return externallyPruned || queue.isPruned();
        }

        @Override
        public Iterator<Cell<MR>> iterator() {
            return queue.iterator();
        }

        @Override
        public Pair<Double, Double> minQeueuScore() {
            return queueHeadScores(queue);
        }

        @Override
        public boolean offer(Cell<MR> cell) {
            return queue.offer(cell);
        }

        @Override
        public int size() {
            return queue.size();
        }
    }

    /**
     * Span that stores cells having a lexical step in an unbounded map, outside the beam. All
     * other cells go into a bounded queue.
     */
    private static class TwoQueueSpan<MR> extends AbstractSpan<MR> {
        private final Map<Cell<MR>, Cell<MR>> lexicals;
        private final IDirectAccessBoundedPriorityQueue<Cell<MR>> nonLexicalQueue;

        /**
         * @param capacity bound of the non-lexical queue (the chart's beam size)
         * @param orderInvariant whether to use {@link OrderInvariantDirectAccessBoundedQueue}
         *        instead of {@link DirectAccessBoundedPriorityQueue}
         */
        public TwoQueueSpan(int capacity, boolean orderInvariant) {
            super();
            this.lexicals = new HashMap<>();
            this.nonLexicalQueue = orderInvariant
                    ? new OrderInvariantDirectAccessBoundedQueue<>(capacity, new Cell.ScoreComparator())
                    : new DirectAccessBoundedPriorityQueue<>(capacity, new Cell.ScoreComparator());
        }

        /**
         * Merges {@code newCell} into {@code existingCell}. Lexical cells are merged in place. A
         * non-lexical cell whose merge reports a change is removed and re-added to the queue.
         */
        @Override
        public void addToExisting(Cell<MR> existingCell, Cell<MR> newCell) {
            if (existingCell.hasLexicalStep()) {
                existingCell.addCell(newCell);
            } else if (existingCell.addCell(newCell)) {
                nonLexicalQueue.remove(existingCell);
                nonLexicalQueue.add(existingCell);
            }
        }

        @Override
        public Cell<MR> get(Cell<MR> cell) {
            // Map values are never null (offer() stores cell -> cell), so null means "absent".
            final Cell<MR> lexical = lexicals.get(cell);
            return lexical != null ? lexical : nonLexicalQueue.get(cell);
        }

        @Override
        public boolean isPruned() {
            return externallyPruned || nonLexicalQueue.isPruned();
        }

        /** Iterates lexical cells first, then the non-lexical queue. */
        @Override
        @SuppressWarnings({ "rawtypes", "unchecked" })
        public Iterator<Cell<MR>> iterator() {
            final List<Iterator<Cell<MR>>> parts = new ArrayList<>(2);
            parts.add(lexicals.values().iterator());
            parts.add(nonLexicalQueue.iterator());
            return new CompositeIterator(parts);
        }

        @Override
        public Pair<Double, Double> minQeueuScore() {
            return queueHeadScores(nonLexicalQueue);
        }

        /** Lexical cells are always accepted; others are offered to the bounded queue. */
        @Override
        public boolean offer(Cell<MR> cell) {
            if (cell.hasLexicalStep()) {
                lexicals.put(cell, cell);
                return true;
            }
            return nonLexicalQueue.offer(cell);
        }

        @Override
        public int size() {
            return lexicals.size() + nonLexicalQueue.size();
        }
    }

    /**
     * Creates an empty chart with one span for every {@code 0 <= start <= end < tokenSeq.size()}.
     *
     * @param tokenSeq sentence tokens
     * @param beamSize capacity of each span's bounded queue
     * @param cellFactory factory exposed through {@link #getCellFactory()}
     * @param separateLexicalCells if true, spans keep cells that have a lexical step in an
     *        unbounded map outside the beam; otherwise all cells share one bounded queue
     * @param plainPriorityQueues if true, span queues are
     *        {@link DirectAccessBoundedPriorityQueue}s; otherwise they are
     *        {@link OrderInvariantDirectAccessBoundedQueue}s
     */
    @SuppressWarnings("unchecked")
    public Chart(TokenSeq tokenSeq, int beamSize, AbstractCellFactory<MR> cellFactory,
            boolean separateLexicalCells, boolean plainPriorityQueues) {
        this.beamSize = beamSize;
        this.tokens = tokenSeq;
        this.cellFactory = cellFactory;
        this.sentenceLength = tokenSeq.size();
        this.chart = new AbstractSpan[sentenceLength][sentenceLength];
        for (int start = 0; start < sentenceLength; start++) {
            for (int end = start; end < sentenceLength; end++) {
                if (separateLexicalCells) {
                    chart[start][end] = new TwoQueueSpan<MR>(beamSize, !plainPriorityQueues);
                } else {
                    chart[start][end] = new SingleQueueSpan<MR>(beamSize, !plainPriorityQueues);
                }
            }
        }
    }

    /**
     * Adds a cell to its span. If an equal cell is already stored, the new cell is merged into
     * it. Otherwise the cell is offered to the span, which may reject it.
     */
    public void add(Cell<MR> cell) {
        final AbstractSpan<MR> span = chart[cell.getStart()][cell.getEnd()];
        final Cell<MR> existing = span.get(cell);
        if (existing == null) {
            addNew(cell);
            return;
        }
        LOG.debug("Adding to existing cell: %s --> %s", cell, existing);
        span.addToExisting(existing, cell);
        LOG.debug("Added to cell: %s", existing);
    }

    /** Returns true if the span of {@code cell} stores a cell equal to it. */
    public boolean contains(Cell<MR> cell) {
        return chart[cell.getStart()][cell.getEnd()].get(cell) != null;
    }

    /** Marks span {@code [start][end]} as externally pruned (reported by {@link #getPrunedSpans()}). */
    public void externalPruning(int start, int end) {
        chart[start][end].externallyPruned = true;
    }

    public int getBeamSize() {
        return beamSize;
    }

    /** Returns the stored cell equal to {@code cell}, or {@code null} if there is none. */
    public Cell<MR> getCell(Cell<MR> cell) {
        return chart[cell.getStart()][cell.getEnd()].get(cell);
    }

    public AbstractCellFactory<MR> getCellFactory() {
        return cellFactory;
    }

    /**
     * Descends from {@code cell} through its derivation steps. Only Viterbi steps are followed
     * when {@code viterbi} is true. Cells already present in this chart are skipped.
     *
     * <p>A visited cell that is absent from this chart, and has at least one step whose children
     * are all in this chart, is recorded under its span. Its descendants are then not explored.
     * Otherwise the children of all its steps are explored.
     *
     * @return map from span to the cells recorded for it
     */
    public Map<Span, Set<Cell<MR>>> getMaxNonOverlappingSpans(Cell<MR> cell, boolean viterbi) {
        return getMaxNonOverlappingSpans(cell, new HashMap<>(), new HashSet<>(), viterbi);
    }

    /** Wraps every full parse in the whole-sentence span as a {@link CKYDerivation}. */
    public List<CKYDerivation<MR>> getParseResults() {
        return ListUtils.map(fullparses(), cell -> new CKYDerivation<MR>(cell));
    }

    /** Returns the (start, end) pairs of all spans that report being pruned, in row-major order. */
    public List<Pair<Integer, Integer>> getPrunedSpans() {
        final List<Pair<Integer, Integer>> pruned = new ArrayList<>();
        for (int start = 0; start < sentenceLength; start++) {
            for (int end = start; end < sentenceLength; end++) {
                if (chart[start][end].isPruned()) {
                    pruned.add(Pair.of(start, end));
                }
            }
        }
        return pruned;
    }

    public int getSentenceLength() {
        return sentenceLength;
    }

    /** Returns a reusable view over the cells of span {@code [start][end]}. */
    public Iterable<Cell<MR>> getSpanIterable(int start, int end) {
        return () -> getSpanIterator(start, end);
    }

    /** Returns a reusable view over the cells of span {@code [start][end]}, sorted by {@code comparator}. */
    public Iterable<Cell<MR>> getSpanIterable(int start, int end, Comparator<Cell<MR>> comparator) {
        return () -> getSpanIterator(start, end, comparator);
    }

    /** Returns an iterator over the live cells of span {@code [start][end]}. */
    public Iterator<Cell<MR>> getSpanIterator(int start, int end) {
        return chart[start][end].iterator();
    }

    /**
     * Returns an iterator over a sorted snapshot of span {@code [start][end]}. Removing through
     * this iterator does not affect the chart.
     *
     * @param comparator ordering to apply; must not be {@code null} (checked by an assertion)
     */
    public Iterator<Cell<MR>> getSpanIterator(int start, int end, Comparator<Cell<MR>> comparator) {
        assert comparator != null : "Method requires a comparator";
        final AbstractSpan<MR> span = chart[start][end];
        final List<Cell<MR>> snapshot = new ArrayList<>(span.size());
        for (Cell<MR> cell : span) {
            snapshot.add(cell);
        }
        return CollectionUtils.sorted(snapshot, comparator).iterator();
    }

    public TokenSeq getTokens() {
        return tokens;
    }

    /** Iterates over all cells of the chart, span by span in row-major order. */
    @Override
    public Iterator<Cell<MR>> iterator() {
        return iterator(null);
    }

    /** Like {@link #iterator()}, but each span's cells are sorted by {@code comparator} if non-null. */
    public Iterator<Cell<MR>> iterator(Comparator<Cell<MR>> comparator) {
        return iterator(0, sentenceLength, comparator);
    }

    /** Iterates over spans {@code [s][e]} with {@code start <= s <= e < end}. */
    public Iterator<Cell<MR>> iterator(int start, int end) {
        return iterator(start, end, null);
    }

    /**
     * Iterates over spans {@code [s][e]} with {@code start <= s <= e < end}. Each span's cells are
     * sorted by {@code comparator} if it is non-null. An empty range ({@code start >= end})
     * yields nothing.
     */
    public Iterator<Cell<MR>> iterator(int start, int end, Comparator<Cell<MR>> comparator) {
        return new CellIterator(start, end, comparator);
    }

    /**
     * Computes log expected features. The steps are: initialize each cell's log outside
     * probability from {@code function} over {@code span}, propagate outside scores from long
     * spans to short ones, then accumulate every cell's log expected features into a new vector.
     */
    public IHashVector logExpectedFeatures(Function<Category<MR>, Double> function, Span span) {
        initializeLogOutsideProbabilities(function, span);
        propagateLogOutsideProbabilities();
        return collectLogExpectedFeatures();
    }

    /**
     * Filter variant over the whole sentence. Accepted categories score
     * {@link IHashVector#ZERO_VALUE}, rejected ones negative infinity.
     */
    public IHashVector logExpectedFeatures(IFilter<Category<MR>> iFilter) {
        // Explicit cast: an implicitly typed lambda is ambiguous between the IFilter and IScorer
        // overloads; the scorer overload is the intended target.
        return logExpectedFeatures((IScorer<Category<MR>>) category -> iFilter.test(category)
                ? IHashVector.ZERO_VALUE
                : Double.NEGATIVE_INFINITY);
    }

    /**
     * Filter variant over {@code span}. Accepted categories score {@link IHashVector#ZERO_VALUE},
     * rejected ones negative infinity.
     */
    public IHashVector logExpectedFeatures(IFilter<Category<MR>> iFilter, Span span) {
        return logExpectedFeatures((Function<Category<MR>, Double>) category -> iFilter.test(category)
                ? IHashVector.ZERO_VALUE
                : Double.NEGATIVE_INFINITY, span);
    }

    /** Scorer variant over the whole sentence span {@code [0, tokens.size() - 1]}. */
    public IHashVector logExpectedFeatures(IScorer<Category<MR>> iScorer) {
        return logExpectedFeatures((Function<Category<MR>, Double>) category -> iScorer.score(category),
                Span.of(0, tokens.size() - 1));
    }

    /** Log-sum of the inside scores of whole-sentence cells whose category passes {@code iFilter}. */
    public double logNorm(IFilter<Category<MR>> iFilter) {
        return logNorm(iFilter, Span.of(0, tokens.size() - 1));
    }

    /** Log-sum of the inside scores of cells in {@code span} whose category passes {@code iFilter}. */
    public double logNorm(IFilter<Category<MR>> iFilter, Span span) {
        final AbstractSpan<MR> cells = chart[span.getStart()][span.getEnd()];
        final List<Double> insideScores = new ArrayList<>(cells.size());
        for (Cell<MR> cell : cells) {
            if (iFilter.test(cell.getCategory())) {
                insideScores.add(cell.getLogInsideScore());
            }
        }
        return LogSumExp.of(insideScores);
    }

    /** Calls {@code recomputeScores()} on every cell, from single-token spans up to the whole sentence. */
    public void recomputeScores() {
        for (int offset = 0; offset < sentenceLength; offset++) {
            for (int start = 0; start < sentenceLength - offset; start++) {
                for (Cell<MR> cell : chart[start][start + offset]) {
                    cell.recomputeScores();
                }
            }
        }
    }

    /**
     * Clears all max flags. Then marks the full parses with semantics {@code mr} that share the
     * highest Viterbi score, and propagates the flag down their Viterbi steps.
     *
     * <p>NOTE(review): the running maximum starts at {@code -Double.MAX_VALUE}, so parses whose
     * Viterbi score is negative infinity are never marked.
     */
    public void setMaxes(MR mr) {
        resetMaxes();
        final List<Cell<MR>> best = new ArrayList<>();
        double bestScore = -Double.MAX_VALUE;
        for (Cell<MR> cell : fullparses()) {
            if (!mr.equals(cell.getCategory().getSemantics())) {
                continue;
            }
            final double score = cell.getViterbiScore();
            if (score > bestScore) {
                bestScore = score;
                best.clear();
                best.add(cell);
            } else if (score == bestScore) {
                best.add(cell);
            }
        }
        for (Cell<MR> cell : best) {
            cell.setIsMax(true);
        }
        propogateMaxes();
    }

    public int spanSize(int start, int end) {
        return chart[start][end].size();
    }

    @Override
    public String toString() {
        return toString(true, true, null);
    }

    /**
     * Renders every cell, one per line, followed by the list of pruned spans.
     *
     * @param sorted if true, cells in each span are ordered by descending score, with ties broken
     *        by hash code
     * @param viterbi passed through as the third argument of {@code Cell.toString}
     * @param iHashVectorImmutable passed through as the fourth argument of {@code Cell.toString}
     */
    public String toString(boolean sorted, boolean viterbi, IHashVectorImmutable iHashVectorImmutable) {
        final Iterator<Cell<MR>> cells;
        if (sorted) {
            final Comparator<Cell<MR>> byScore = new Cell.ScoreComparator();
            cells = iterator((a, b) -> {
                final int c = byScore.compare(a, b);
                return c == 0 ? Integer.compare(a.hashCode(), b.hashCode()) : -c;
            });
        } else {
            cells = iterator();
        }
        final StringBuilder sb = new StringBuilder();
        while (cells.hasNext()) {
            final Cell<MR> cell = cells.next();
            final String spanText = ListUtils.join(tokens.subList(cell.getStart(), cell.getEnd() + 1), " ");
            sb.append(cell.toString(false, spanText, viterbi, iHashVectorImmutable)).append("\n");
        }
        sb.append("Spans pruned: ").append(getPrunedSpans());
        return sb.toString();
    }

    /** Returns the (prune score, second prune score) of the queue head, or {@code null} if empty. */
    private static <M> Pair<Double, Double> queueHeadScores(IDirectAccessBoundedPriorityQueue<Cell<M>> queue) {
        if (queue.isEmpty()) {
            return null;
        }
        final Cell<M> head = queue.peek();
        return Pair.of(head.getPruneScore(), head.getSecondPruneScore());
    }

    /** Offers a cell with no stored equivalent to its span, logging the span state around the offer. */
    private void addNew(Cell<MR> cell) {
        final AbstractSpan<MR> span = chart[cell.getStart()][cell.getEnd()];
        LOG.debug("Offering a new cell: %s", cell);
        LOG.debug("Pre-offer size of span: %d", span.size());
        LOG.debug("Pre-offer span minimum score: %s", span.minQeueuScore());
        if (span.offer(cell)) {
            LOG.debug("Cell added");
        } else {
            LOG.debug("Cell rejected");
        }
        LOG.debug("Size of span: %d", span.size());
        LOG.debug("Span minimum score: %s", span.minQeueuScore());
    }

    /** Accumulates every cell's log expected features into a fresh vector. */
    private IHashVector collectLogExpectedFeatures() {
        final IHashVector features = HashVectorFactory.create();
        forEachCellTopDown(cell -> cell.collectLogExpectedFeatures(features));
        return features;
    }

    /**
     * Applies {@code action} to every cell. Spans are visited from the longest (whole sentence)
     * down to single tokens, and left to right within each length.
     */
    private void forEachCellTopDown(Consumer<Cell<MR>> action) {
        for (int offset = sentenceLength - 1; offset >= 0; offset--) {
            for (int start = 0; start < sentenceLength - offset; start++) {
                for (Cell<MR> cell : getSpanIterable(start, start + offset)) {
                    action.accept(cell);
                }
            }
        }
    }

    /** Returns the cells of the whole-sentence span that report being full parses. */
    private List<Cell<MR>> fullparses() {
        final List<Cell<MR>> parses = new ArrayList<>();
        if (sentenceLength == 0) {
            // No tokens means no whole-sentence span; chart[0][-1] would throw.
            return parses;
        }
        for (Cell<MR> cell : getSpanIterable(0, sentenceLength - 1)) {
            if (cell.isFullParse()) {
                parses.add(cell);
            }
        }
        return parses;
    }

    /**
     * Recursive worker for {@link #getMaxNonOverlappingSpans(Cell, boolean)}.
     *
     * @param visited cells already processed; each cell is handled at most once
     */
    private Map<Span, Set<Cell<MR>>> getMaxNonOverlappingSpans(Cell<MR> cell, Map<Span, Set<Cell<MR>>> map,
            Set<Cell<MR>> visited, boolean viterbi) {
        if (!visited.add(cell) || contains(cell)) {
            return map;
        }
        for (IWeightedCKYStep<MR> step : viterbi ? cell.getViterbiSteps() : cell.getSteps()) {
            if (allChildrenInChart(step)) {
                map.computeIfAbsent(Span.of(cell.getStart(), cell.getEnd()), k -> new HashSet<>()).add(cell);
                return map;
            }
        }
        for (IWeightedCKYStep<MR> step : viterbi ? cell.getViterbiSteps() : cell.getSteps()) {
            final int numChildren = step.numChildren();
            for (int child = 0; child < numChildren; child++) {
                getMaxNonOverlappingSpans(step.getChildCell(child), map, visited, viterbi);
            }
        }
        return map;
    }

    /** Returns true if every child cell of {@code step} is present in this chart. */
    private boolean allChildrenInChart(IWeightedCKYStep<MR> step) {
        final int numChildren = step.numChildren();
        for (int child = 0; child < numChildren; child++) {
            if (!contains(step.getChildCell(child))) {
                return false;
            }
        }
        return true;
    }

    /** Calls {@code initializeLogOutsideProbabilities(function, span)} on every cell, top-down. */
    private void initializeLogOutsideProbabilities(Function<Category<MR>, Double> function, Span span) {
        forEachCellTopDown(cell -> cell.initializeLogOutsideProbabilities(function, span));
    }

    /**
     * Propagates outside scores from long spans to short ones. Within each span, every cell's
     * unary children are updated before any cell's binary children.
     */
    private void propagateLogOutsideProbabilities() {
        for (int offset = sentenceLength - 1; offset >= 0; offset--) {
            for (int start = 0; start < sentenceLength - offset; start++) {
                for (Cell<MR> cell : getSpanIterable(start, start + offset)) {
                    cell.updateUnaryChildrenLogOutsideScore();
                }
                for (Cell<MR> cell : getSpanIterable(start, start + offset)) {
                    cell.updateBinaryChildrenLogOutsideScore();
                }
            }
        }
    }

    /**
     * Marks the children of every max-flagged cell's Viterbi steps as max. Spans are processed
     * top-down so that flags reach all descendants.
     */
    private void propogateMaxes() {
        forEachCellTopDown(cell -> {
            if (!cell.isMax()) {
                return;
            }
            for (IWeightedCKYStep<MR> step : cell.getViterbiSteps()) {
                final Iterator<?> children = step.iterator();
                while (children.hasNext()) {
                    ((Cell<?>) children.next()).setIsMax(true);
                }
            }
        });
    }

    /** Clears the max flag on every cell. */
    private void resetMaxes() {
        forEachCellTopDown(cell -> cell.setIsMax(false));
    }
}
