package org.grouplens.lenskit.eval.metrics.predict;

import com.google.common.collect.ImmutableList;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongListIterator;
import java.util.List;
import javax.annotation.Nonnull;
import org.grouplens.lenskit.eval.Attributed;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.AbstractTestUserMetric;
import org.grouplens.lenskit.eval.metrics.TestUserMetricAccumulator;
import org.grouplens.lenskit.eval.traintest.TestUser;
import org.grouplens.lenskit.vectors.SparseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Evaluate predictions with a normalized half-life utility (HLU) metric.
 *
 * <p>For each test user, the predicted items are ranked by their predicted values. Each item then
 * contributes its <em>test rating</em>, discounted by {@code 2^((rank - 1) / (alpha - 1))}, where
 * rank starts at 1. This utility is divided by the utility of the ideal ranking: the test items
 * ranked by their own ratings, scored the same way. The per-user value is reported in the user
 * row, and the mean over all evaluated users is reported in the final row.
 *
 * <p>Users with no predictions, or whose ideal utility is zero (e.g. no test ratings), are not
 * scored. They produce an empty user row and do not contribute to the mean.
 */
public class HLUtilityPredictMetric extends AbstractTestUserMetric {
    private static final Logger logger = LoggerFactory.getLogger(HLUtilityPredictMetric.class);
    private static final List<String> COLUMNS = ImmutableList.of("HLUtility");

    /** Half-life used by the no-argument constructor. */
    private static final double DEFAULT_ALPHA = 5.0;

    /** Half-life rank parameter; always greater than 1 (enforced by the constructor). */
    private final double alpha;

    /**
     * Accumulates per-user normalized HLU values and reports their mean.
     *
     * <p>This is an inner (non-static) class because it uses the enclosing metric's
     * {@code alpha}, {@code userRow} and {@code finalRow}.
     */
    public class Accum implements TestUserMetricAccumulator {
        /** Sum of the normalized HLU values of all scored users. */
        double total = 0.0;
        /** Number of users that contributed to {@link #total}. */
        int nusers = 0;

        public Accum() {
        }

        /**
         * Score one test user.
         *
         * @param testUser the user to evaluate
         * @return an empty row if the user has no predictions, otherwise the user's HLU row
         */
        @Override
        @Nonnull
        public List<Object> evaluate(TestUser testUser) {
            SparseVector predictions = testUser.getPredictions();
            if (predictions == null) {
                return userRow();
            }
            return evaluatePredictions(testUser.getTestRatings(), predictions);
        }

        /**
         * Compute the normalized HLU for one user and add it to the running totals.
         *
         * @param ratings     the user's test ratings; used both for scoring and for the ideal ranking
         * @param predictions the predicted values; used only to order the items
         * @return a row with the normalized utility, or an empty row if the ideal utility is zero
         */
        @Nonnull
        List<Object> evaluatePredictions(SparseVector ratings, SparseVector predictions) {
            double idealUtility = computeHLU(ratings.keysByValue(true), ratings);
            if (idealUtility == 0.0) {
                // Nothing to normalize against (e.g. no test ratings). Dividing here would yield
                // NaN or infinity and poison the mean in finalResults(), so skip this user.
                return userRow();
            }
            double actualUtility = computeHLU(predictions.keysByValue(true), ratings);
            double normalized = actualUtility / idealUtility;
            total += normalized;
            nusers++;
            return userRow(normalized);
        }

        /**
         * Report the mean normalized HLU across scored users.
         *
         * @return an empty row if no user was scored, otherwise a row with the mean HLU
         */
        @Override
        @Nonnull
        public List<Object> finalResults() {
            if (nusers <= 0) {
                return finalRow();
            }
            double mean = total / nusers;
            logger.info("HLU: {}", mean);
            return finalRow(mean);
        }
    }

    /**
     * Create the metric with a specific half-life.
     *
     * @param alpha the half-life rank; must be greater than 1
     * @throws IllegalArgumentException if {@code alpha} is not greater than 1 (including NaN),
     *                                  since the rank discount is then undefined or inverted
     */
    public HLUtilityPredictMetric(double alpha) {
        if (!(alpha > 1.0)) {
            throw new IllegalArgumentException("alpha must be greater than 1, got " + alpha);
        }
        this.alpha = alpha;
    }

    /** Create the metric with the default half-life of 5. */
    public HLUtilityPredictMetric() {
        this(DEFAULT_ALPHA);
    }

    @Override
    public Accum makeAccumulator(Attributed attributed, TTDataSet tTDataSet) {
        return new Accum();
    }

    @Override
    public List<String> getColumnLabels() {
        return COLUMNS;
    }

    @Override
    public List<String> getUserColumnLabels() {
        return COLUMNS;
    }

    /**
     * Compute the (unnormalized) half-life utility of a ranked item list.
     *
     * <p>The item at 1-based position {@code rank} contributes
     * {@code values.get(item) / 2^((rank - 1) / (alpha - 1))}.
     *
     * <p>NOTE(review): items are looked up with {@code SparseVector.get}. What it does for keys
     * absent from {@code values} is defined by {@code SparseVector} and not visible here. Confirm
     * that predicted items are always a subset of the test-rated items.
     *
     * @param items  the items, in ranked order
     * @param values the value to credit for each item
     * @return the discounted sum of values
     */
    double computeHLU(LongList items, SparseVector values) {
        final double halfLifeDenominator = alpha - 1.0;
        double utility = 0.0;
        int rank = 0;
        LongListIterator iter = items.iterator();
        while (iter.hasNext()) {
            rank++;
            utility += values.get(iter.nextLong()) / Math.pow(2.0, (rank - 1) / halfLifeDenominator);
        }
        return utility;
    }
}
