package edu.cornell.cs.nlp.spf.learn.simple;

import edu.cornell.cs.nlp.spf.base.hashvector.HashVectorFactory;
import edu.cornell.cs.nlp.spf.base.hashvector.IHashVector;
import edu.cornell.cs.nlp.spf.ccg.categories.Category;
import edu.cornell.cs.nlp.spf.data.collection.IDataCollection;
import edu.cornell.cs.nlp.spf.data.sentence.Sentence;
import edu.cornell.cs.nlp.spf.data.singlesentence.SingleSentence;
import edu.cornell.cs.nlp.spf.learn.ILearner;
import edu.cornell.cs.nlp.spf.mr.lambda.LogicalExpression;
import edu.cornell.cs.nlp.spf.parser.IDerivation;
import edu.cornell.cs.nlp.spf.parser.IParser;
import edu.cornell.cs.nlp.spf.parser.IParserOutput;
import edu.cornell.cs.nlp.spf.parser.ccg.model.Model;
import edu.cornell.cs.nlp.utils.filter.IFilter;
import edu.cornell.cs.nlp.utils.log.ILogger;
import edu.cornell.cs.nlp.utils.log.LoggerFactory;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/* loaded from: input_file:edu/cornell/cs/nlp/spf/learn/simple/SimplePerceptron.class */
/**
 * Perceptron-style learner over {@link SingleSentence} training data.
 *
 * <p>For every training item the sentence is parsed with the current model. If at least one of
 * the derivations returned by {@link IParserOutput#getBestDerivations()} is not correct for the
 * item, and at least one derivation whose semantics equal the item's label is available, the
 * model parameters are moved towards the averaged features of the label-matching derivations and
 * away from the averaged features of the incorrect best derivations.
 */
public class SimplePerceptron implements ILearner<Sentence, SingleSentence, Model<Sentence, LogicalExpression>> {
    public static final ILogger LOG = LoggerFactory.create((Class<?>) SimplePerceptron.class);

    /** Number of passes made over {@link #trainingData} by {@link #train(Model)}. */
    private final int numIterations;

    /** Parser used to analyse each training sentence. */
    private final IParser<Sentence, LogicalExpression> parser;

    /** Labeled sentences iterated over in every training pass. */
    private final IDataCollection<SingleSentence> trainingData;

    /**
     * Creates a learner.
     *
     * @param i number of passes to make over the training data
     * @param iDataCollection labeled training sentences
     * @param iParser parser used to analyse each training sentence
     */
    public SimplePerceptron(int i, IDataCollection<SingleSentence> iDataCollection, IParser<Sentence, LogicalExpression> iParser) {
        this.numIterations = i;
        this.trainingData = iDataCollection;
        this.parser = iParser;
    }

    /**
     * Runs the configured number of passes over the training data, updating the parameter vector
     * returned by {@code model.getTheta()} in place.
     *
     * @param model model that provides the per-item parsing model and receives the updates
     * @throws IllegalStateException if a computed update is rejected by
     *     {@code model.isValidWeightVector}
     */
    @Override
    public void train(Model<Sentence, LogicalExpression> model) {
        for (int iteration = 0; iteration < this.numIterations; iteration++) {
            LOG.info("=========================");
            LOG.info("Training iteration %d", Integer.valueOf(iteration));
            LOG.info("=========================");
            int itemCounter = -1;
            // dataItem must be final: it is captured by the anonymous filter below.
            for (final SingleSentence dataItem : this.trainingData) {
                final long startTime = System.currentTimeMillis();
                itemCounter++;
                LOG.info("%d : ================== [%d]", Integer.valueOf(itemCounter), Integer.valueOf(iteration));
                LOG.info("Sample type: %s", dataItem.getClass().getSimpleName());
                LOG.info("%s", dataItem);

                final IParserOutput<LogicalExpression> parserOutput = this.parser.parse(dataItem.getSample(), model.createDataItemModel(dataItem.getSample()));
                final List<? extends IDerivation<LogicalExpression>> bestDerivations = parserOutput.getBestDerivations();

                // Max derivations restricted to categories whose semantics equal the gold label.
                final List<? extends IDerivation<LogicalExpression>> correctDerivations = parserOutput.getMaxDerivations(new IFilter<Category<LogicalExpression>>() {
                    @Override
                    public boolean test(Category<LogicalExpression> category) {
                        return dataItem.getLabel().equals(category.getSemantics());
                    }
                });

                final List<IDerivation<LogicalExpression>> violatingDerivations = collectViolatingDerivations(dataItem, bestDerivations);

                if (!violatingDerivations.isEmpty() && !correctDerivations.isEmpty()) {
                    // There is something wrong to push away from and something right to pull towards.
                    final IHashVector update = createUpdate(correctDerivations, violatingDerivations);
                    if (!model.isValidWeightVector(update)) {
                        throw new IllegalStateException("invalid update: " + update);
                    }
                    LOG.info("Update: %s", update);
                    update.addTimesInto(1.0d, model.getTheta());
                } else if (correctDerivations.isEmpty()) {
                    LOG.info("No correct parses. No update.");
                } else {
                    LOG.info("Correct. No update.");
                }
                LOG.info("Sample processing time %.4f", Double.valueOf((System.currentTimeMillis() - startTime) / 1000.0d));
            }
        }
    }

    /**
     * Returns the best derivations whose semantics the data item does not accept as correct,
     * logging each of them.
     */
    private static List<IDerivation<LogicalExpression>> collectViolatingDerivations(SingleSentence dataItem, List<? extends IDerivation<LogicalExpression>> bestDerivations) {
        final List<IDerivation<LogicalExpression>> violating = new LinkedList<IDerivation<LogicalExpression>>();
        for (final IDerivation<LogicalExpression> derivation : bestDerivations) {
            if (!dataItem.isCorrect(derivation.getSemantics())) {
                violating.add(derivation);
                LOG.info("Bad parse: %s", derivation.getSemantics());
            }
        }
        return violating;
    }

    /**
     * Builds the update vector: the mean feature vector of the correct derivations minus the
     * mean feature vector of the violating derivations, followed by {@code dropNoise()}.
     * Both lists must be non-empty.
     */
    private static IHashVector createUpdate(List<? extends IDerivation<LogicalExpression>> correctDerivations, List<? extends IDerivation<LogicalExpression>> violatingDerivations) {
        final IHashVector update = HashVectorFactory.create();

        // Each correct derivation contributes an equal positive share.
        final double positiveWeight = 1.0d / correctDerivations.size();
        for (final IDerivation<LogicalExpression> derivation : correctDerivations) {
            derivation.getAverageMaxFeatureVector().addTimesInto(positiveWeight, update);
        }

        // Each violating derivation contributes an equal negative share.
        final double negativeWeight = (-1.0d) * (1.0d / violatingDerivations.size());
        for (final IDerivation<LogicalExpression> derivation : violatingDerivations) {
            derivation.getAverageMaxFeatureVector().addTimesInto(negativeWeight, update);
        }

        update.dropNoise();
        return update;
    }
}
