package org.tribuo.dataset;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.ProtocolStringList;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.FeatureMap;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.MutableFeatureMap;
import org.tribuo.MutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.SelectedFeatureSet;
import org.tribuo.impl.ArrayExample;
import org.tribuo.impl.DatasetDataCarrier;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.DatasetProto;
import org.tribuo.protos.core.ExampleProto;
import org.tribuo.protos.core.SelectedFeatureDatasetProto;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.FeatureSetProvenance;

/* loaded from: input_file:org/tribuo/dataset/SelectedFeatureDataset.class */
public final class SelectedFeatureDataset<T extends Output<T>> extends ImmutableDataset<T> {
    private static final long serialVersionUID = 1;
    private static final Logger logger = Logger.getLogger(SelectedFeatureDataset.class.getName());
    public static final int CURRENT_VERSION = 0;
    private final int k;
    private final SelectedFeatureSet featureSet;
    private final Set<String> selectedFeatures;
    private final int numExamplesRemoved;

    /* loaded from: input_file:org/tribuo/dataset/SelectedFeatureDataset$SelectedFeatureDatasetProvenance.class */
    /**
     * Provenance for a {@link SelectedFeatureDataset}.
     * <p>
     * In addition to the fields recorded by {@link DatasetProvenance}, this stores the
     * requested number of features ({@code k}), the provenance of the feature set that
     * drove the selection, and the provenance of the dataset the selection was applied to.
     */
    public static final class SelectedFeatureDatasetProvenance extends DatasetProvenance {
        private static final long serialVersionUID = 1;
        private static final String K = "k";
        private static final String FEATURE_SET_PROVENANCE = "feature-set-provenance";
        private static final String DATASET_PROVENANCE = "original-data-provenance";
        private final IntProvenance k;
        private final FeatureSetProvenance featureSetProvenance;
        private final DataProvenance datasetProvenance;

        /**
         * Builds the provenance by reading the state of the supplied dataset.
         *
         * @param selectedFeatureDataset The dataset to describe.
         * @param <T> The output type of the dataset.
         */
        <T extends Output<T>> SelectedFeatureDatasetProvenance(SelectedFeatureDataset<T> selectedFeatureDataset) {
            // No transformations are recorded, hence the empty list provenance.
            super(selectedFeatureDataset.sourceProvenance, new ListProvenance<>(), selectedFeatureDataset);
            this.datasetProvenance = selectedFeatureDataset.sourceProvenance;
            this.featureSetProvenance = selectedFeatureDataset.featureSet.m11getProvenance();
            this.k = new IntProvenance(K, selectedFeatureDataset.k);
        }

        /**
         * Rebuilds the provenance from a map of named provenances.
         * <p>
         * The map must contain entries for {@code "k"}, {@code "feature-set-provenance"} and
         * {@code "original-data-provenance"} of the expected types; extraction and validation
         * are delegated to {@link ObjectProvenance#checkAndExtractProvenance}.
         *
         * @param map The provenance map.
         */
        public SelectedFeatureDatasetProvenance(Map<String, Provenance> map) {
            super(map);
            String hostName = SelectedFeatureDatasetProvenance.class.getSimpleName();
            this.k = ObjectProvenance.checkAndExtractProvenance(map, K, IntProvenance.class, hostName);
            this.featureSetProvenance = ObjectProvenance.checkAndExtractProvenance(map, FEATURE_SET_PROVENANCE, FeatureSetProvenance.class, hostName);
            this.datasetProvenance = ObjectProvenance.checkAndExtractProvenance(map, DATASET_PROVENANCE, DataProvenance.class, hostName);
        }

        /**
         * Two provenances are equal when they have the same runtime class, the superclass
         * considers them equal, and k, the feature set provenance and the original dataset
         * provenance all match.
         */
        @Override // org.tribuo.provenance.DatasetProvenance
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (getClass() != obj.getClass()) {
                return false;
            }
            if (!super.equals(obj)) {
                return false;
            }
            SelectedFeatureDatasetProvenance other = (SelectedFeatureDatasetProvenance) obj;
            if (!k.equals(other.k)) {
                return false;
            }
            if (!featureSetProvenance.equals(other.featureSetProvenance)) {
                return false;
            }
            return datasetProvenance.equals(other.datasetProvenance);
        }

        /**
         * Hashes the superclass hash together with the three fields compared in
         * {@link #equals(Object)}.
         */
        @Override // org.tribuo.provenance.DatasetProvenance
        public int hashCode() {
            return Objects.hash(super.hashCode(), k, featureSetProvenance, datasetProvenance);
        }

        /**
         * Returns the superclass's provenance pairs with this class's three entries
         * appended, in the order k, feature set provenance, original dataset provenance.
         *
         * @return The named provenances of this object.
         */
        /* JADX INFO: Access modifiers changed from: protected */
        @Override // org.tribuo.provenance.DatasetProvenance
        public List<Pair<String, Provenance>> allProvenances() {
            List<Pair<String, Provenance>> provenances = super.allProvenances();
            provenances.add(new Pair<>(K, k));
            provenances.add(new Pair<>(FEATURE_SET_PROVENANCE, featureSetProvenance));
            provenances.add(new Pair<>(DATASET_PROVENANCE, datasetProvenance));
            return provenances;
        }
    }

    /**
     * Constructs a dataset restricted to every feature named in the supplied feature set.
     * <p>
     * Delegates to the three-argument constructor with {@code k == -1}, which that
     * constructor treats as "use all feature names in the set".
     *
     * @param dataset The dataset to copy examples from.
     * @param selectedFeatureSet The feature set naming the features to keep.
     */
    public SelectedFeatureDataset(Dataset<T> dataset, SelectedFeatureSet selectedFeatureSet) {
        this(dataset, selectedFeatureSet, -1);
    }

    /**
     * Constructs a dataset containing copies of the supplied dataset's examples, restricted
     * to the selected features.
     * <p>
     * When {@code i} is positive the first {@code i} names of the feature set are selected;
     * when it is {@code -1} every name in the feature set is selected. Examples left with no
     * features after the restriction are dropped and counted (see
     * {@link #getNumExamplesRemoved()}). The feature domain and output domain are rebuilt
     * from the retained features and examples.
     *
     * @param dataset The dataset to copy examples from.
     * @param selectedFeatureSet The feature set naming the features to keep.
     * @param i The number of features to select (k): either {@code -1} for all of them, or a
     *          value in {@code [1, N]} where N is the number of names in the feature set.
     * @throws IllegalArgumentException If {@code i} is zero or the feature set is empty, if a
     *         top-k selection is requested from an unordered feature set, if {@code i} exceeds
     *         the number of available features, if {@code i} is less than -1, or if none of
     *         the selected features appear in the dataset's feature map.
     */
    public SelectedFeatureDataset(Dataset<T> dataset, SelectedFeatureSet selectedFeatureSet, int i) {
        super(dataset.getProvenance(), dataset.getOutputFactory());
        this.featureSet = selectedFeatureSet;
        this.k = i;

        List<String> rankedNames = selectedFeatureSet.featureNames();
        int available = rankedNames.size();
        if (i == 0 || available == 0) {
            throw new IllegalArgumentException("Tried to select zero features.");
        }
        if (i != -1 && !selectedFeatureSet.isOrdered()) {
            throw new IllegalArgumentException("Tried to select the top " + i + " features from an unordered feature set.");
        }
        if (i > available) {
            throw new IllegalArgumentException("Tried to select more features than are available in feature set, requested " + i + ", found " + available);
        }
        if (i < -1) {
            throw new IllegalArgumentException("Supplied k " + i + " but only k == -1 or 1 <= k <= N is allowed.");
        }

        // LinkedHashSet keeps the feature set's ordering while giving constant-time lookups.
        Set<String> chosen = new LinkedHashSet<>(i > 0 ? rankedNames.subList(0, i) : rankedNames);
        this.selectedFeatures = Collections.unmodifiableSet(chosen);

        if (Collections.disjoint(dataset.getFeatureMap().keySet(), this.selectedFeatures)) {
            throw new IllegalArgumentException("The selected feature set had no overlap with the supplied dataset.");
        }

        int removedCount = 0;
        MutableFeatureMap featureInfos = new MutableFeatureMap();
        MutableOutputInfo<T> outputInfo = dataset.getOutputFactory().generateInfo();
        // Scratch list reused across examples to collect the features to strip.
        List<Feature> rejected = new ArrayList<>();
        for (Example<T> example : dataset) {
            rejected.clear();
            // Work on a copy so the source dataset's examples are left untouched.
            ArrayExample<T> filtered = new ArrayExample<>(example);
            for (Feature feature : example) {
                if (this.selectedFeatures.contains(feature.getName())) {
                    featureInfos.add(feature.getName(), feature.getValue());
                } else {
                    rejected.add(feature);
                }
            }
            if (!rejected.isEmpty()) {
                filtered.removeFeatures(rejected);
            }
            if (filtered.size() > 0) {
                this.data.add(filtered);
                outputInfo.observe(example.getOutput());
            } else {
                removedCount++;
            }
        }
        this.numExamplesRemoved = removedCount;
        this.featureIDMap = new ImmutableFeatureMap(featureInfos);
        this.outputIDInfo = outputInfo.generateImmutableOutputInfo();
        if (this.numExamplesRemoved > 0) {
            logger.info(String.format("filtered out %d examples because they had zero features after the selected feature set was applied.", this.numExamplesRemoved));
        }
    }

    /**
     * Deserialization constructor: installs already-built state without re-running the
     * feature selection.
     * <p>
     * NOTE(review): the trailing {@code false} handed to the superclass constructor
     * presumably disables dropping of invalid examples; confirm against ImmutableDataset.
     *
     * @param provenance The provenance of the source data.
     * @param factory The output factory.
     * @param tribuoVersion The Tribuo version string recorded with the dataset.
     * @param featureDomain The feature domain.
     * @param outputDomain The output domain.
     * @param examples The examples.
     * @param k The number of features that were requested, or -1 for all.
     * @param featureSet The feature set the selection was drawn from.
     * @param selectedFeatures The selected feature names; wrapped in an unmodifiable view.
     * @param numExamplesRemoved The number of examples removed by the selection.
     */
    private SelectedFeatureDataset(DataProvenance provenance, OutputFactory<T> factory, String tribuoVersion, ImmutableFeatureMap featureDomain, ImmutableOutputInfo<T> outputDomain, List<Example<T>> examples, int k, SelectedFeatureSet featureSet, Set<String> selectedFeatures, int numExamplesRemoved) {
        super(provenance, factory, tribuoVersion, featureDomain, outputDomain, examples, false);
        this.featureSet = featureSet;
        this.k = k;
        this.numExamplesRemoved = numExamplesRemoved;
        this.selectedFeatures = Collections.unmodifiableSet(selectedFeatures);
    }

    /**
     * Deserialization factory.
     * <p>
     * Unpacks a {@code SelectedFeatureDatasetProto}, validates it, and rebuilds the dataset.
     * The validation mirrors the invariants established by the public constructor: every
     * example's output has the output factory's class, every example feature is in the
     * feature domain, k is positive or -1, the removed-example count is non-negative, the
     * selected feature names are unique, and every feature in the feature domain is one of
     * the selected features. Selected features that are absent from the feature domain are
     * accepted, because the public constructor only requires the selected features to
     * overlap the source dataset, not to be contained in it.
     *
     * @param i The serialized version number.
     * @param str The class name; not used by this method.
     * @param any The packed {@code SelectedFeatureDatasetProto}.
     * @return The deserialized dataset.
     * @throws InvalidProtocolBufferException If the message cannot be unpacked.
     * @throws IllegalArgumentException If the version is not supported.
     * @throws IllegalStateException If the protobuf contents fail validation.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // the raw constructor call is guarded by the per-example output class check
    public static SelectedFeatureDataset<?> deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
        if (i < 0 || i > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version " + CURRENT_VERSION);
        }
        SelectedFeatureDatasetProto proto = any.unpack(SelectedFeatureDatasetProto.class);
        DatasetDataCarrier<?> carrier = DatasetDataCarrier.deserialize(proto.getMetadata());
        Class<?> expectedOutputClass = carrier.outputFactory().getUnknownOutput().getClass();
        FeatureMap featureDomain = carrier.featureDomain();

        List<Example<?>> examples = new ArrayList<>();
        for (ExampleProto exampleProto : proto.getExamplesList()) {
            // The index of the example being validated equals the number accepted so far.
            int idx = examples.size();
            Example<?> example = Example.deserialize(exampleProto);
            if (!example.getOutput().getClass().equals(expectedOutputClass)) {
                throw new IllegalStateException("Invalid protobuf, expected all examples to have output class " + expectedOutputClass + ", but found " + example.getOutput().getClass() + " in example idx " + idx);
            }
            for (Feature feature : example) {
                if (featureDomain.get(feature.getName()) == null) {
                    throw new IllegalStateException("Invalid protobuf, feature domain does not contain feature " + feature.getName() + " present in example at idx " + idx);
                }
            }
            examples.add(example);
        }

        if (!(featureDomain instanceof ImmutableFeatureMap)) {
            throw new IllegalStateException("Invalid protobuf, feature map was not immutable");
        }
        if (!(carrier.outputDomain() instanceof ImmutableOutputInfo)) {
            throw new IllegalStateException("Invalid protobuf, output info was not immutable");
        }
        int k = proto.getK();
        if (k < 1 && k != -1) {
            throw new IllegalStateException("Invalid protobuf, k must be positive or -1, found " + k);
        }
        int numExamplesRemoved = proto.getNumExamplesRemoved();
        if (numExamplesRemoved < 0) {
            throw new IllegalStateException("Invalid protobuf, number of examples removed must be non-negative, found " + numExamplesRemoved);
        }
        SelectedFeatureSet selectedFeatureSet = (SelectedFeatureSet) ProtoUtil.deserialize(proto.getFeatureSet());

        ProtocolStringList selectedNames = proto.mo1901getSelectedFeaturesList();
        Set<String> selectedFeatures = new LinkedHashSet<>(selectedNames);
        if (selectedFeatures.size() != selectedNames.size()) {
            throw new IllegalStateException("Invalid protobuf, selected features contained duplicates, features = " + selectedNames);
        }
        // The public constructor only adds selected features to the feature domain, so the
        // domain must be a subset of the selection. The reverse need not hold: a selected
        // feature that never occurred in the source dataset is legal and is serialized.
        for (String name : featureDomain.keySet()) {
            if (!selectedFeatures.contains(name)) {
                throw new IllegalStateException("Invalid protobuf, feature domain contains feature " + name + " which is not in the selected feature set.");
            }
        }
        return new SelectedFeatureDataset(carrier.provenance(), carrier.outputFactory(), carrier.tribuoVersion(), (ImmutableFeatureMap) featureDomain, (ImmutableOutputInfo) carrier.outputDomain(), examples, k, selectedFeatureSet, selectedFeatures, numExamplesRemoved);
    }

    /**
     * The number of examples dropped from the source dataset because they had no features
     * left once the selected feature set was applied.
     *
     * @return The number of removed examples.
     */
    public int getNumExamplesRemoved() {
        return numExamplesRemoved;
    }

    /**
     * The number of features requested when this dataset was built; -1 means every
     * feature named in the feature set was selected.
     *
     * @return The value of k.
     */
    public int getK() {
        return k;
    }

    /**
     * The feature set the selected features were drawn from.
     *
     * @return The feature set.
     */
    public SelectedFeatureSet getFeatureSet() {
        return featureSet;
    }

    /**
     * The names of the selected features, as an unmodifiable set.
     *
     * @return The selected feature names.
     */
    public Set<String> getSelectedFeatures() {
        return selectedFeatures;
    }

    /**
     * Creates a fresh provenance object describing the current state of this dataset.
     *
     * @return A new {@link SelectedFeatureDatasetProvenance} built from this dataset.
     */
    @Override // org.tribuo.ImmutableDataset
    /* renamed from: getProvenance */
    public DatasetProvenance mo5getProvenance() {
        DatasetProvenance provenance = new SelectedFeatureDatasetProvenance(this);
        return provenance;
    }

    /**
     * Serializes this dataset to a {@link DatasetProto}.
     * <p>
     * The payload is a {@code SelectedFeatureDatasetProto} holding the dataset metadata,
     * every example, the removed-example count, k, the feature set, and the selected feature
     * names. It is packed into an {@link Any} and wrapped with this class's name and
     * {@link #CURRENT_VERSION}.
     *
     * @return The serialized dataset.
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.tribuo.ImmutableDataset, org.tribuo.protos.ProtoSerializable
    /* renamed from: serialize */
    public DatasetProto mo14serialize() {
        SelectedFeatureDatasetProto.Builder payload = SelectedFeatureDatasetProto.newBuilder();
        payload.setMetadata(createDataCarrier(this.featureIDMap, this.outputIDInfo).serialize());
        for (Example<T> example : this.data) {
            payload.addExamples(example.mo14serialize());
        }
        payload.setNumExamplesRemoved(numExamplesRemoved);
        payload.setK(k);
        payload.setFeatureSet(featureSet.mo14serialize());
        payload.addAllSelectedFeatures(selectedFeatures);

        DatasetProto.Builder envelope = DatasetProto.newBuilder();
        envelope.setVersion(CURRENT_VERSION);
        envelope.setClassName(SelectedFeatureDataset.class.getName());
        envelope.setSerializedData(Any.pack(payload.m1934build()));
        return envelope.m327build();
    }
}
