package org.tribuo.transform;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.protos.core.TransformedModelProto;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/transform/TransformedModel.class */
public class TransformedModel<T extends Output<T>> extends Model<T> {
    private static final long serialVersionUID = 1;

    /** The protobuf serialization version written by serialize and accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** The wrapped model; it receives transformed examples. */
    private final Model<T> innerModel;

    /** The transformations applied to every example before it reaches {@link #innerModel}. */
    private final TransformerMap transformerMap;

    /**
     * If true, examples are transformed with the feature-name-aware overloads of
     * {@link TransformerMap} (passing {@link #featureNames}); otherwise the plain overloads are used.
     */
    private final boolean densify;

    /** The keys of this model's feature ID map, sorted lexicographically. */
    private final ArrayList<String> featureNames;

    /**
     * Wraps {@code model} with a transformation step, reusing the inner model's name.
     * <p>
     * NOTE(review): the decompiler reports this constructor was package-private in the original
     * source; it is kept public here so existing callers of this decompiled form still compile.
     *
     * @param modelProvenance the provenance of the combined model
     * @param model           the model that receives transformed examples
     * @param transformerMap  the transformations to apply
     * @param densify         whether to use the feature-name-aware (densifying) transformation path
     */
    public TransformedModel(ModelProvenance modelProvenance, Model<T> model, TransformerMap transformerMap, boolean densify) {
        this(model.getName(), modelProvenance, model, transformerMap, densify);
    }

    /**
     * Wraps {@code model} with a transformation step under an explicit name. The feature and
     * output domains and the probability flag are taken from the inner model.
     *
     * @param name            the name of this model
     * @param modelProvenance the provenance of the combined model
     * @param model           the model that receives transformed examples
     * @param transformerMap  the transformations to apply
     * @param densify         whether to use the feature-name-aware (densifying) transformation path
     */
    private TransformedModel(String name, ModelProvenance modelProvenance, Model<T> model, TransformerMap transformerMap, boolean densify) {
        super(name, modelProvenance, model.getFeatureIDMap(), model.getOutputIDInfo(), model.generatesProbabilities());
        this.innerModel = model;
        this.transformerMap = transformerMap;
        this.densify = densify;
        this.featureNames = new ArrayList<>(this.featureIDMap.keySet());
        Collections.sort(this.featureNames);
    }

    /**
     * Reconstructs a {@code TransformedModel} from its protobuf form.
     *
     * @param version   the serialized version; must be in {@code [0, CURRENT_VERSION]}
     * @param className the class name recorded in the proto (unused here)
     * @param message   an {@link Any} wrapping a {@link TransformedModelProto}
     * @return the deserialized model
     * @throws InvalidProtocolBufferException if {@code message} does not hold a {@link TransformedModelProto}
     * @throws IllegalArgumentException       if {@code version} is not supported by this class
     */
    public static TransformedModel<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        TransformedModelProto proto = message.unpack(TransformedModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        return new TransformedModel<>(carrier.name(), carrier.provenance(), Model.deserialize(proto.getModel()), TransformerMap.deserialize(proto.getTransformerMap()), proto.getDensify());
    }

    /**
     * Returns the transformations applied before prediction.
     *
     * @return the transformer map
     */
    public TransformerMap getTransformerMap() {
        return this.transformerMap;
    }

    /**
     * Returns the wrapped model.
     *
     * @return the inner model
     */
    public Model<T> getInnerModel() {
        return this.innerModel;
    }

    /**
     * Returns whether the feature-name-aware (densifying) transformation path is used.
     *
     * @return the densify flag
     */
    public boolean getDensify() {
        return this.densify;
    }

    /**
     * Applies {@link #transformerMap} to {@code example}, choosing the same overload as
     * dataset prediction does for the current {@link #densify} setting. Every per-example
     * entry point goes through here so predictions and excuses see the same input.
     */
    private Example<T> transform(Example<T> example) {
        return this.densify
                ? this.transformerMap.transformExample(example, this.featureNames)
                : this.transformerMap.transformExample(example);
    }

    /**
     * Transforms {@code example} and predicts it with the inner model.
     *
     * @param example the untransformed example
     * @return the inner model's prediction on the transformed example
     */
    @Override // org.tribuo.Model
    public Prediction<T> predict(Example<T> example) {
        return this.innerModel.predict(transform(example));
    }

    /**
     * Transforms every example in {@code dataset} and predicts each with the inner model.
     *
     * @param dataset the untransformed examples
     * @return one prediction per example, in the transformed dataset's iteration order
     */
    @Override // org.tribuo.Model
    public List<Prediction<T>> predict(Dataset<T> dataset) {
        MutableDataset<T> transformed = this.transformerMap.transformDataset(dataset, this.densify);
        List<Prediction<T>> predictions = new ArrayList<>(transformed.size());
        for (Example<T> example : transformed) {
            predictions.add(this.innerModel.predict(example));
        }
        return predictions;
    }

    /**
     * Delegates to the inner model; the reported features are those the inner model sees
     * (i.e. post-transformation).
     *
     * @param n the number of features to return per output
     * @return the inner model's top features
     */
    @Override // org.tribuo.Model
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return this.innerModel.getTopFeatures(n);
    }

    /**
     * Transforms {@code example} exactly as {@link #predict(Example)} does, then asks the inner
     * model for an excuse. Previously this always used the non-densifying transform, so when
     * {@code densify} was true the excuse could describe a different input than the prediction.
     *
     * @param example the untransformed example
     * @return the inner model's excuse, if it can produce one
     */
    @Override // org.tribuo.Model
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return this.innerModel.getExcuse(transform(example));
    }

    /**
     * Creates a copy of this model with a new name and provenance, sharing the inner model
     * and transformer map. Previously {@code newName} was ignored and the inner model's name
     * was used instead.
     * <p>
     * NOTE(review): the decompiler reports this method was protected in the original source.
     *
     * @param newName       the name of the copy
     * @param newProvenance the provenance of the copy
     * @return the copy
     */
    @Override // org.tribuo.Model
    public TransformedModel<T> copy(String newName, ModelProvenance newProvenance) {
        return new TransformedModel<>(newName, newProvenance, this.innerModel, this.transformerMap, this.densify);
    }

    /**
     * Serializes this model (metadata, inner model, transformer map and densify flag) into a
     * {@link ModelProto} tagged with this class name and {@link #CURRENT_VERSION}.
     * <p>
     * This is {@code serialize()} in the original source; the decompiler renamed it (and the
     * protobuf {@code build()} calls) to resolve name collisions, so those names are kept.
     *
     * @return the protobuf form of this model
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.tribuo.Model, org.tribuo.protos.ProtoSerializable
    /* renamed from: serialize */
    public ModelProto mo14serialize() {
        ModelDataCarrier<T> carrier = createDataCarrier();
        TransformedModelProto.Builder modelBuilder = TransformedModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setModel(this.innerModel.mo14serialize());
        modelBuilder.setTransformerMap(this.transformerMap.mo14serialize());
        modelBuilder.setDensify(this.densify);
        ModelProto.Builder protoBuilder = ModelProto.newBuilder();
        protoBuilder.setSerializedData(Any.pack(modelBuilder.m2264build()));
        protoBuilder.setClassName(TransformedModel.class.getName());
        protoBuilder.setVersion(CURRENT_VERSION);
        return protoBuilder.m1414build();
    }
}
