package org.tribuo.transform.transformations;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Logger;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.BinningTransformerProto;
import org.tribuo.protos.core.TransformerProto;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationProvenance;
import org.tribuo.transform.Transformer;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/transform/transformations/BinningTransformation.class */
public final class BinningTransformation implements Transformation {
    /** Key under which the number of bins is stored in the provenance. */
    private static final String NUM_BINS = "numBins";
    /** Key under which the binning type is stored in the provenance. */
    private static final String TYPE = "type";

    /** Number of bins to produce; validated by {@link #postConfig()}. */
    @Config(description = "Number of bins.")
    private int numBins;

    /** Binning algorithm; selects the statistics class built by {@link #createStats()}. */
    @Config(description = "Binning algorithm to use.")
    private BinningType type;

    /* loaded from: input_file:org/tribuo/transform/transformations/BinningTransformation$BinningTransformationProvenance.class */
    /**
     * Provenance for a {@link BinningTransformation}, recording its number of bins
     * and its binning type.
     */
    public static final class BinningTransformationProvenance implements TransformationProvenance {
        private static final long serialVersionUID = 1;
        private final IntProvenance numBins;
        private final EnumProvenance<BinningType> type;

        /**
         * Records the settings of the supplied transformation.
         *
         * @param binningTransformation the transformation whose bin count and type are captured.
         */
        BinningTransformationProvenance(BinningTransformation binningTransformation) {
            int binCount = binningTransformation.numBins;
            BinningType binningType = binningTransformation.type;
            this.numBins = new IntProvenance(BinningTransformation.NUM_BINS, binCount);
            this.type = new EnumProvenance<>(BinningTransformation.TYPE, binningType);
        }

        /**
         * Rebuilds the provenance from a map of named provenance objects.
         *
         * @param map the provenances, looked up under the {@code "numBins"} and {@code "type"} keys.
         */
        public BinningTransformationProvenance(Map<String, Provenance> map) {
            // The simple class name is handed to the extractor alongside each key.
            final String hostName = BinningTransformationProvenance.class.getSimpleName();
            this.numBins = ObjectProvenance.checkAndExtractProvenance(map, BinningTransformation.NUM_BINS, IntProvenance.class, hostName);
            this.type = ObjectProvenance.checkAndExtractProvenance(map, BinningTransformation.TYPE, EnumProvenance.class, hostName);
        }

        /**
         * Returns the fully qualified name of {@link BinningTransformation}.
         *
         * @return the class name of the transformation this provenance describes.
         */
        public String getClassName() {
            return BinningTransformation.class.getName();
        }

        /**
         * Two provenances are equal when both their bin counts and their types are equal.
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj instanceof BinningTransformationProvenance) {
                BinningTransformationProvenance other = (BinningTransformationProvenance) obj;
                return numBins.equals(other.numBins) && type.equals(other.type);
            }
            return false;
        }

        @Override
        public int hashCode() {
            return Objects.hash(numBins, type);
        }

        /**
         * Returns an unmodifiable map holding the bin count and type provenances.
         *
         * @return the configured parameters keyed by {@code "numBins"} and {@code "type"}.
         */
        public Map<String, Provenance> getConfiguredParameters() {
            Map<String, Provenance> parameters = new HashMap<>();
            parameters.put(BinningTransformation.NUM_BINS, numBins);
            parameters.put(BinningTransformation.TYPE, type);
            return Collections.unmodifiableMap(parameters);
        }
    }

    /**
     * Transformer which replaces a feature value with the output value of the bin it falls in.
     * <p>
     * {@code bins[i]} is the inclusive upper boundary of bin {@code i} and {@code values[i]}
     * is the value emitted for that bin. Inputs above the last boundary are assigned to the
     * last bin.
     */
    @ProtoSerializableClass(version = 0, serializedDataClass = BinningTransformerProto.class)
    /* loaded from: input_file:org/tribuo/transform/transformations/BinningTransformation$BinningTransformer.class */
    public static final class BinningTransformer implements Transformer {
        private static final long serialVersionUID = 1;
        /** Protobuf serialization version. */
        public static final int CURRENT_VERSION = 0;

        @ProtoSerializableField(name = "binningType")
        private final BinningType type;

        @ProtoSerializableField
        private final double[] bins;

        @ProtoSerializableField
        private final double[] values;

        /**
         * Constructs a binning transformer. The arrays are stored without copying.
         *
         * @param binningType the binning algorithm which produced the bins.
         * @param dArr the upper bin boundaries, in non-decreasing order.
         * @param dArr2 the output value for each bin, the same length as {@code dArr}.
         * @throws IllegalArgumentException if either array is null or empty, the boundaries
         *         decrease, or the two arrays differ in length.
         */
        public BinningTransformer(BinningType binningType, double[] dArr, double[] dArr2) {
            if (dArr == null || dArr.length == 0) {
                throw new IllegalArgumentException("Invalid bin array");
            }
            if (dArr2 == null || dArr2.length == 0) {
                throw new IllegalArgumentException("Invalid value array");
            }
            // transform() indexes values by bin index, and deserializeFromProto enforces the
            // same invariant, so reject mismatched arrays here instead of failing later.
            if (dArr.length != dArr2.length) {
                throw new IllegalArgumentException("Invalid arrays, found " + dArr.length + " bins and " + dArr2.length + " values.");
            }
            double previous = dArr[0];
            for (int i = 1; i < dArr.length; i++) {
                double current = dArr[i];
                if (previous > current) {
                    throw new IllegalArgumentException("Invalid bin array, values are not increasing.");
                }
                previous = current;
            }
            this.type = binningType;
            this.bins = dArr;
            this.values = dArr2;
        }

        /**
         * Deserialization factory.
         *
         * @param i the serialized version, must equal {@link #CURRENT_VERSION}.
         * @param str not used by this method.
         * @param any the packed {@code BinningTransformerProto}.
         * @return the deserialized transformer.
         * @throws InvalidProtocolBufferException if the message cannot be unpacked.
         * @throws IllegalArgumentException if the version is unknown or the proto holds a
         *         differing number of bins and values.
         */
        static BinningTransformer deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
            BinningTransformerProto unpack = any.unpack(BinningTransformerProto.class);
            if (i != CURRENT_VERSION) {
                throw new IllegalArgumentException("Unknown version " + i + " expected {0}");
            }
            if (unpack.getBinsCount() != unpack.getValuesCount()) {
                throw new IllegalArgumentException("Invalid protobuf, found a differing number of bins and values.");
            }
            return new BinningTransformer(BinningType.valueOf(unpack.getBinningType()), Util.toPrimitiveDouble(unpack.getBinsList()), Util.toPrimitiveDouble(unpack.getValuesList()));
        }

        /**
         * Maps the input onto the value of its bin.
         * <p>
         * Ordering follows {@link Double#compare(double, double)}, so NaN sorts above every
         * boundary and lands in the last bin.
         *
         * @param d the input value.
         * @return the output value of the bin containing {@code d}.
         */
        @Override // org.tribuo.transform.Transformer
        public double transform(double d) {
            int last = this.bins.length - 1;
            // Use the same total order as Arrays.binarySearch. With a primitive '>' a NaN input
            // (or 0.0 against a trailing -0.0) slipped past this guard and the search returned an
            // insertion point of bins.length, indexing one past the end of values.
            if (Double.compare(d, this.bins[last]) > 0) {
                return this.values[last];
            }
            int position = Arrays.binarySearch(this.bins, d);
            // A negative result encodes -(insertionPoint) - 1; the insertion point is the bin index.
            int binIndex = position < 0 ? (-1) - position : position;
            return this.values[binIndex];
        }

        @Override
        public String toString() {
            return "BinningTransformer(type=" + this.type + ",bins=" + Arrays.toString(this.bins) + ",values=" + Arrays.toString(this.values) + ")";
        }

        /**
         * Transformers are equal when they have the same type, bins and values.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            BinningTransformer other = (BinningTransformer) obj;
            return this.type == other.type && Arrays.equals(this.bins, other.bins) && Arrays.equals(this.values, other.values);
        }

        @Override
        public int hashCode() {
            return (31 * ((31 * Objects.hash(this.type)) + Arrays.hashCode(this.bins))) + Arrays.hashCode(this.values);
        }

        /**
         * Serializes this transformer to a protobuf via {@link ProtoUtil#serialize}.
         *
         * @return the serialized form.
         */
        @Override // org.tribuo.protos.ProtoSerializable
        /* renamed from: serialize, reason: avoid collision after fix types in other method and merged with bridge method [inline-methods] */
        public TransformerProto mo14serialize() {
            return ProtoUtil.serialize(this);
        }
    }

    /* loaded from: input_file:org/tribuo/transform/transformations/BinningTransformation$BinningType.class */
    /**
     * The binning algorithms supported by {@link BinningTransformation}.
     */
    public enum BinningType {
        /** Bins of equal width spanning the observed minimum to the observed maximum. */
        EQUAL_WIDTH,
        /** Bin boundaries drawn from the sorted observed values at evenly spaced ranks. */
        EQUAL_FREQUENCY,
        /** Bin boundaries placed at whole multiples of the standard deviation around the mean. */
        STD_DEVS
    }

    /* loaded from: input_file:org/tribuo/transform/transformations/BinningTransformation$EqualFreqStats.class */
    /**
     * Statistics for equal frequency binning. Stores every observed value in a growable
     * array, then picks bin boundaries from the sorted values at evenly spaced ranks.
     */
    private static class EqualFreqStats implements TransformStatistics {
        /** Initial capacity of the observation buffer. */
        private static final int DEFAULT_SIZE = 50;
        private final int numBins;
        /** Observation buffer; only the first {@code count} slots are live. */
        private double[] observedValues = new double[DEFAULT_SIZE];
        private int count = 0;

        /**
         * @param i the number of bins to generate.
         */
        public EqualFreqStats(int i) {
            this.numBins = i;
        }

        /**
         * Appends a value to the buffer, growing it first when it is (nearly) full.
         *
         * @param d the observed value.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public void observeValue(double d) {
            // Must be '<=' rather than '==': observeSparse(int) can leave count equal to the
            // buffer length, in which case an equality test never fires and the write below
            // would run off the end of the array.
            if (this.observedValues.length <= this.count + 1) {
                growArray();
            }
            this.observedValues[this.count] = d;
            this.count++;
        }

        /**
         * Grows the buffer so it can hold at least {@code i} values.
         *
         * @param i the minimum required capacity.
         */
        protected void growArray(int i) {
            this.observedValues = Arrays.copyOf(this.observedValues, newCapacity(i));
        }

        /**
         * Computes the new capacity: 1.5x the current length, or {@code i} if that is larger.
         * The subtraction form keeps the comparison correct when the 1.5x value overflows.
         *
         * @param i the minimum required capacity.
         * @return the capacity to allocate.
         * @throws OutOfMemoryError if {@code i} is negative (an overflowed request).
         */
        private int newCapacity(int i) {
            int currentLength = this.observedValues.length;
            int grown = currentLength + (currentLength >> 1);
            if (grown - i > 0) {
                return grown;
            }
            if (i < 0) {
                throw new OutOfMemoryError();
            }
            return i;
        }

        /**
         * Grows the buffer so it can hold one more value.
         */
        protected void growArray() {
            growArray(this.count + 1);
        }

        /**
         * Observes a single implicit zero.
         */
        @Override // org.tribuo.transform.TransformStatistics
        @Deprecated
        public void observeSparse() {
            observeValue(0.0d);
        }

        /**
         * Observes {@code i} implicit zeros. Only the count is advanced: slots at or beyond
         * {@code count} are never written elsewhere and {@link Arrays#copyOf} pads with zeros,
         * so the newly covered slots already hold 0.0.
         *
         * @param i the number of zeros to observe.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public void observeSparse(int i) {
            if (this.observedValues.length < this.count + i) {
                growArray(this.count + i);
            }
            this.count += i;
        }

        /**
         * Sorts the observed values and builds the transformer. Boundary {@code k} is the
         * sorted value at rank {@code (k + 1) * (count / numBins)}, and the final boundary is
         * the largest observed value. Bin {@code k} emits {@code k + 1}.
         *
         * @return the equal frequency binning transformer.
         * @throws IllegalStateException if fewer values than bins were observed.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public Transformer generateTransformer() {
            // Compare against the number of observations, not the buffer capacity: the capacity
            // starts at 50, so the old check let an empty buffer through to an index of -1.
            if (this.numBins > this.count) {
                throw new IllegalStateException("Needs more values than bins, requested " + this.numBins + " bins, but only found " + this.count + " values.");
            }
            Arrays.sort(this.observedValues, 0, this.count);
            double[] boundaries = new double[this.numBins];
            double[] outputs = new double[this.numBins];
            int valuesPerBin = this.count / this.numBins;
            for (int bin = 0; bin < this.numBins - 1; bin++) {
                boundaries[bin] = this.observedValues[(bin + 1) * valuesPerBin];
                outputs[bin] = bin + 1;
            }
            boundaries[this.numBins - 1] = this.observedValues[this.count - 1];
            outputs[this.numBins - 1] = this.numBins;
            return new BinningTransformer(BinningType.EQUAL_FREQUENCY, boundaries, outputs);
        }

        @Override
        public String toString() {
            return "EqualFreqStatistics(count=" + this.count + ",numBins=" + this.numBins + ")";
        }
    }

    /* loaded from: input_file:org/tribuo/transform/transformations/BinningTransformation$EqualWidthStats.class */
    /**
     * Statistics for equal width binning. Tracks the smallest and largest observed values
     * and splits that range into {@code numBins} bins of identical width.
     */
    private static class EqualWidthStats implements TransformStatistics {
        private final int numBins;
        private double min = Double.POSITIVE_INFINITY;
        private double max = Double.NEGATIVE_INFINITY;

        /**
         * @param i the number of bins to generate.
         */
        public EqualWidthStats(int i) {
            this.numBins = i;
        }

        /**
         * Widens the tracked range to include the value.
         *
         * @param d the observed value.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public void observeValue(double d) {
            // Plain comparisons are deliberate: they leave the range untouched for NaN.
            if (this.min > d) {
                this.min = d;
            }
            if (this.max < d) {
                this.max = d;
            }
        }

        /**
         * Observes a single implicit zero.
         */
        @Override // org.tribuo.transform.TransformStatistics
        @Deprecated
        public void observeSparse() {
            observeValue(0.0d);
        }

        /**
         * Observes implicit zeros. The count is irrelevant to a min/max range, so a single
         * zero is observed.
         *
         * @param i the number of zeros (unused).
         */
        @Override // org.tribuo.transform.TransformStatistics
        public void observeSparse(int i) {
            observeValue(0.0d);
        }

        /**
         * Builds the transformer: boundary {@code k} is {@code min + (k + 1) * width} where
         * {@code width = |max - min| / numBins}, and bin {@code k} emits {@code k + 1}.
         *
         * @return the equal width binning transformer.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public Transformer generateTransformer() {
            final double[] boundaries = new double[this.numBins];
            final double[] outputs = new double[this.numBins];
            final double width = Math.abs(this.max - this.min) / this.numBins;
            int slot = 0;
            while (slot < boundaries.length) {
                int binId = slot + 1;
                boundaries[slot] = this.min + (binId * width);
                outputs[slot] = binId;
                slot = binId;
            }
            return new BinningTransformer(BinningType.EQUAL_WIDTH, boundaries, outputs);
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("EqualWidthStatistics(min=");
            text.append(this.min).append(",max=").append(this.max);
            text.append(",numBins=").append(this.numBins).append(")");
            return text.toString();
        }
    }

    /* loaded from: input_file:org/tribuo/transform/transformations/BinningTransformation$StdDevStats.class */
    /**
     * Statistics for standard deviation binning. Maintains a running mean and a running sum
     * of squared deviations from the mean, then places bin boundaries at whole multiples of
     * the sample standard deviation around the mean.
     */
    private static class StdDevStats implements TransformStatistics {
        private static final Logger logger = Logger.getLogger(StdDevStats.class.getName());
        private final int numBins;
        /** Running mean of all observed values. */
        private double mean = 0.0d;
        /** Running sum of squared deviations from the mean. */
        private double sumSquares = 0.0d;
        private long count = 0;

        /**
         * @param i the number of bins to generate.
         */
        public StdDevStats(int i) {
            this.numBins = i;
        }

        /**
         * Folds one value into the running mean and squared-deviation sum using the
         * single-pass incremental update.
         *
         * @param d the observed value.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public void observeValue(double d) {
            this.count++;
            double delta = d - this.mean;
            this.mean += delta / this.count;
            // Uses the deviation from the old mean times the deviation from the new mean.
            this.sumSquares += delta * (d - this.mean);
        }

        /**
         * Observes a single implicit zero.
         */
        @Override // org.tribuo.transform.TransformStatistics
        @Deprecated
        public void observeSparse() {
            observeValue(0.0d);
        }

        /**
         * Folds {@code i} implicit zeros into the statistics in one step.
         * <p>
         * This merges the current statistics with a batch of {@code i} zeros (mean 0, no
         * internal spread). With {@code n} prior observations and {@code delta = 0 - mean}:
         * {@code mean += delta * i / (n + i)} and
         * {@code sumSquares += delta^2 * n * i / (n + i)}.
         * The previous code reset the mean to zero and left {@code sumSquares} unchanged.
         *
         * @param i the number of zeros to observe.
         * @throws IllegalArgumentException if {@code i} is negative.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public void observeSparse(int i) {
            if (i < 0) {
                throw new IllegalArgumentException("Sparse count must be non-negative, found " + i);
            }
            if (i == 0) {
                // Nothing to merge; also avoids 0/0 when no values have been observed yet.
                return;
            }
            long previousCount = this.count;
            this.count += i;
            double delta = -this.mean;
            double sparseFraction = ((double) i) / this.count;
            this.mean += delta * sparseFraction;
            this.sumSquares += delta * delta * previousCount * sparseFraction;
        }

        /**
         * Builds the transformer. For {@code k} running from {@code -(numBins / 2) + 1}
         * upwards, boundary {@code j} is {@code mean + k * stdDev}, and bin {@code j} emits
         * {@code j + 1}. The standard deviation is the sample (n - 1) estimate, and is treated
         * as zero when fewer than two values were observed.
         *
         * @return the standard deviation binning transformer.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public Transformer generateTransformer() {
            if (this.sumSquares == 0.0d) {
                logger.info("Only observed a single value (" + this.mean + ") when building a BinningTransformer using standard deviation bins.");
            }
            double[] boundaries = new double[this.numBins];
            double[] outputs = new double[this.numBins];
            // With a single observation the sample variance is 0/0 = NaN, which would turn
            // every boundary into NaN; fall back to a zero spread instead.
            double stdDev = this.count > 1 ? Math.sqrt(this.sumSquares / (this.count - 1)) : 0.0d;
            int multiplier = -(this.numBins / 2);
            for (int bin = 0; bin < boundaries.length; bin++) {
                outputs[bin] = bin + 1;
                multiplier++;
                boundaries[bin] = this.mean + (multiplier * stdDev);
            }
            return new BinningTransformer(BinningType.STD_DEVS, boundaries, outputs);
        }

        @Override
        public String toString() {
            return "StdDevStatistics(mean=" + this.mean + ",sumSquares=" + this.sumSquares + ",count=" + this.count + ",numBins=" + this.numBins + ")";
        }
    }

    /**
     * No-argument constructor; leaves {@code numBins} and {@code type} at their default values.
     * NOTE(review): presumably exists for the configuration system that populates the
     * {@code @Config} fields - confirm.
     */
    private BinningTransformation() {
    }

    /**
     * Builds a transformation of the given type and bin count, then validates the
     * combination via {@link #postConfig()}.
     *
     * @param binningType the binning algorithm.
     * @param i the number of bins.
     */
    private BinningTransformation(BinningType binningType, int i) {
        this.numBins = i;
        this.type = binningType;
        this.postConfig();
    }

    /**
     * Validates the configured fields.
     *
     * @throws IllegalArgumentException if there are fewer than two bins, or if standard
     *         deviation binning is requested with an odd number of bins.
     */
    public void postConfig() {
        final int binCount = this.numBins;
        if (binCount < 2) {
            throw new IllegalArgumentException("Number of bins must be 2 or greater, found " + binCount);
        }
        // binCount is at least 2 here, so a non-zero remainder means it is odd.
        final boolean oddBinCount = (binCount % 2) != 0;
        if (oddBinCount && this.type == BinningType.STD_DEVS) {
            throw new IllegalArgumentException("Std dev must have an even number of bins, found " + binCount);
        }
    }

    /**
     * Creates the statistics object matching the configured binning type.
     *
     * @return a fresh statistics accumulator for {@code numBins} bins.
     * @throws IllegalStateException if the binning type is not one of the known constants.
     */
    @Override // org.tribuo.transform.Transformation
    public TransformStatistics createStats() {
        final TransformStatistics stats;
        switch (this.type) {
            case STD_DEVS:
                stats = new StdDevStats(this.numBins);
                break;
            case EQUAL_FREQUENCY:
                stats = new EqualFreqStats(this.numBins);
                break;
            case EQUAL_WIDTH:
                stats = new EqualWidthStats(this.numBins);
                break;
            default:
                throw new IllegalStateException("Unknown binning type " + this.type);
        }
        return stats;
    }

    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    /**
     * Builds a provenance object describing this transformation's bin count and type.
     *
     * @return a new {@link BinningTransformationProvenance} for this instance.
     */
    public TransformationProvenance m2528getProvenance() {
        final BinningTransformationProvenance provenance = new BinningTransformationProvenance(this);
        return provenance;
    }

    /**
     * Describes this transformation as {@code BinningTransformation(type=...,numBins=...)}.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("BinningTransformation(type=");
        text.append(this.type);
        text.append(",numBins=");
        text.append(this.numBins);
        text.append(")");
        return text.toString();
    }

    /**
     * Creates a transformation which bins values into equal width bins.
     *
     * @param i the number of bins, at least 2.
     * @return the equal width binning transformation.
     * @throws IllegalArgumentException if {@code i} is less than 2.
     */
    public static BinningTransformation equalWidth(int i) {
        final BinningType algorithm = BinningType.EQUAL_WIDTH;
        return new BinningTransformation(algorithm, i);
    }

    /**
     * Creates a transformation which bins values into equal frequency bins.
     *
     * @param i the number of bins, at least 2.
     * @return the equal frequency binning transformation.
     * @throws IllegalArgumentException if {@code i} is less than 2.
     */
    public static BinningTransformation equalFrequency(int i) {
        final BinningType algorithm = BinningType.EQUAL_FREQUENCY;
        return new BinningTransformation(algorithm, i);
    }

    /**
     * Creates a transformation which bins values by standard deviations around the mean.
     * Twice the supplied number of bins are created, which keeps the total even as
     * {@link #postConfig()} requires.
     *
     * @param i half the total number of bins.
     * @return the standard deviation binning transformation.
     * @throws IllegalArgumentException if the doubled bin count is less than 2.
     */
    public static BinningTransformation stdDevs(int i) {
        final int totalBins = 2 * i;
        return new BinningTransformation(BinningType.STD_DEVS, totalBins);
    }
}
