package org.tribuo.sequence;

import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.FeatureMap;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.OutputInfo;
import org.tribuo.Tribuo;
import org.tribuo.impl.DatasetDataCarrier;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.SequenceDatasetProto;
import org.tribuo.protos.core.SequenceExampleProto;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.DatasetProvenance;

/**
 * Base class for a collection of {@link SequenceExample}s.
 * <p>
 * Holds the sequence examples in insertion order together with the {@link OutputFactory},
 * the {@link DataProvenance} of the source and the Tribuo version string recorded at
 * construction. Subclasses supply the feature and output domains.
 * @param <T> The output type.
 */
public abstract class SequenceDataset<T extends Output<T>> implements Iterable<SequenceExample<T>>, ProtoSerializable<SequenceDatasetProto>, Provenancable<DatasetProvenance>, Serializable {
    private static final Logger logger = Logger.getLogger(SequenceDataset.class.getName());
    private static final long serialVersionUID = 2L;

    /** The output factory for this dataset. */
    protected final OutputFactory<T> outputFactory;

    /** The sequence examples, in insertion order. */
    protected final List<SequenceExample<T>> data;

    /**
     * The Tribuo version string supplied at construction. May be {@code null};
     * {@link #createDataCarrier} substitutes {@link Tribuo#VERSION} in that case.
     */
    protected final String tribuoVersion;

    /** The provenance of the data source. */
    protected final DataProvenance sourceProvenance;

    /**
     * An {@link ImmutableDataset} containing every {@link Example} of every
     * {@link SequenceExample} in a sequence dataset, in iteration order.
     * <p>
     * The examples are the same instances, not copies. The flat dataset reuses the
     * sequence dataset's source provenance, output factory, feature id map and output id info.
     * @param <T> The output type.
     */
    private static class FlatDataset<T extends Output<T>> extends ImmutableDataset<T> {
        private static final long serialVersionUID = 1L;

        FlatDataset(SequenceDataset<T> sequenceDataset) {
            super(sequenceDataset.sourceProvenance, sequenceDataset.outputFactory, sequenceDataset.getFeatureIDMap(), sequenceDataset.getOutputIDInfo());
            for (SequenceExample<T> sequence : sequenceDataset) {
                for (Example<T> example : sequence) {
                    this.data.add(example);
                }
            }
        }
    }

    /**
     * Creates an empty sequence dataset stamped with the current {@link Tribuo#VERSION}.
     * @param dataProvenance The provenance of the data source.
     * @param outputFactory The output factory.
     */
    protected SequenceDataset(DataProvenance dataProvenance, OutputFactory<T> outputFactory) {
        this(dataProvenance, outputFactory, Tribuo.VERSION);
    }

    /**
     * Creates an empty sequence dataset with an explicit Tribuo version string.
     * @param dataProvenance The provenance of the data source.
     * @param outputFactory The output factory.
     * @param str The Tribuo version string to record.
     */
    protected SequenceDataset(DataProvenance dataProvenance, OutputFactory<T> outputFactory, String str) {
        this.data = new ArrayList<>();
        this.sourceProvenance = dataProvenance;
        this.outputFactory = outputFactory;
        this.tribuoVersion = str;
    }

    /**
     * Returns a description of this dataset built from its source provenance.
     * @return The description string.
     */
    public String getSourceDescription() {
        return "SequenceDataset(source=" + this.sourceProvenance.toString() + ")";
    }

    /**
     * Returns an unmodifiable view of the sequence examples.
     * @return The examples, in insertion order.
     */
    public List<SequenceExample<T>> getData() {
        return Collections.unmodifiableList(this.data);
    }

    /**
     * Returns the provenance of the data source.
     * @return The source provenance.
     */
    public DataProvenance getSourceProvenance() {
        return this.sourceProvenance;
    }

    /**
     * Returns the set of outputs present in this dataset.
     * @return The outputs.
     */
    public abstract Set<T> getOutputs();

    /**
     * Returns the sequence example at the supplied index.
     * @param i The index.
     * @return The sequence example.
     * @throws IllegalArgumentException If {@code i} is negative or not less than {@link #size()}.
     */
    public SequenceExample<T> getExample(int i) {
        if (i < 0 || i >= size()) {
            throw new IllegalArgumentException("Example index " + i + " is out of bounds.");
        }
        return this.data.get(i);
    }

    /**
     * Builds a new {@link Dataset} containing every {@link Example} from every sequence,
     * concatenated in iteration order.
     * @return The flattened dataset.
     */
    public Dataset<T> getFlatDataset() {
        return new FlatDataset<>(this);
    }

    /**
     * Returns the number of sequence examples.
     * @return The number of sequences.
     */
    public int size() {
        return this.data.size();
    }

    /**
     * Returns the immutable output domain.
     * @return The output id info.
     */
    public abstract ImmutableOutputInfo<T> getOutputIDInfo();

    /**
     * Returns the output domain.
     * @return The output info.
     */
    public abstract OutputInfo<T> getOutputInfo();

    /**
     * Returns the immutable feature domain.
     * @return The feature id map.
     */
    public abstract ImmutableFeatureMap getFeatureIDMap();

    /**
     * Returns the feature domain.
     * @return The feature map.
     */
    public abstract FeatureMap getFeatureMap();

    /**
     * Returns the output factory for this dataset.
     * @return The output factory.
     */
    public OutputFactory<T> getOutputFactory() {
        return this.outputFactory;
    }

    /**
     * Iterates the underlying example list directly.
     * @return An iterator over the sequence examples.
     */
    @Override
    public Iterator<SequenceExample<T>> iterator() {
        return this.data.iterator();
    }

    @Override
    public String toString() {
        return "SequenceDataset(source=" + this.sourceProvenance.toString() + ")";
    }

    /**
     * Checks whether every output in {@link #getOutputInfo()}'s domain is an instance of
     * the supplied class.
     * @param cls The class to check against.
     * @return True if all outputs in the domain are instances of {@code cls} (vacuously true
     *         for an empty domain).
     */
    public boolean validate(Class<? extends Output<?>> cls) {
        for (T output : getOutputInfo().getDomain()) {
            if (!cls.isInstance(output)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Casts a wildcard-typed sequence dataset to the supplied output type after checking its
     * output domain with {@link #validate(Class)}.
     * @param sequenceDataset The dataset to cast.
     * @param cls The target output class.
     * @param <T> The target output type.
     * @return The same dataset instance, typed as {@code SequenceDataset<T>}.
     * @throws ClassCastException If the output domain contains non-{@code cls} outputs.
     */
    @SuppressWarnings("unchecked") // Guarded by validate(cls) on the output domain.
    public static <T extends Output<T>> SequenceDataset<T> castDataset(SequenceDataset<?> sequenceDataset, Class<T> cls) {
        if (sequenceDataset.validate(cls)) {
            return (SequenceDataset<T>) sequenceDataset;
        }
        throw new ClassCastException("Attempted to cast dataset to " + cls.getName() + " which is not valid for dataset " + sequenceDataset.toString());
    }

    /**
     * Deserializes a sequence dataset from its protobuf form via {@link ProtoUtil}.
     * @param sequenceDatasetProto The protobuf to deserialize.
     * @return The deserialized dataset.
     */
    public static SequenceDataset<?> deserialize(SequenceDatasetProto sequenceDatasetProto) {
        return (SequenceDataset<?>) ProtoUtil.deserialize(sequenceDatasetProto);
    }

    /**
     * Reads a {@link SequenceDatasetProto} from the file at {@code path} and deserializes it.
     * @param path The file to read.
     * @return The deserialized dataset.
     * @throws IOException If the file cannot be read or parsed.
     */
    public static SequenceDataset<?> deserializeFromFile(Path path) throws IOException {
        try (InputStream in = new BufferedInputStream(Files.newInputStream(path))) {
            return deserializeFromStream(in);
        }
    }

    /**
     * Parses a {@link SequenceDatasetProto} from the stream and deserializes it. The stream
     * is not closed.
     * @param inputStream The stream to read.
     * @return The deserialized dataset.
     * @throws IOException If the stream cannot be read or parsed.
     */
    public static SequenceDataset<?> deserializeFromStream(InputStream inputStream) throws IOException {
        return deserialize(SequenceDatasetProto.parseFrom(inputStream));
    }

    /**
     * Writes this dataset's protobuf form to the file at {@code path}. With no open options,
     * {@link Files#newOutputStream} creates the file or truncates an existing one.
     * @param path The file to write.
     * @throws IOException If the file cannot be written.
     */
    public void serializeToFile(Path path) throws IOException {
        try (OutputStream out = new BufferedOutputStream(Files.newOutputStream(path))) {
            serializeToStream(out);
        }
    }

    /**
     * Writes this dataset's protobuf form to the stream. The stream is not closed.
     * @param outputStream The stream to write to.
     * @throws IOException If the write fails.
     */
    public void serializeToStream(OutputStream outputStream) throws IOException {
        // serialize() is the ProtoSerializable contract; the decompiled "mo14serialize" name was an artifact.
        serialize().writeTo(outputStream);
    }

    /**
     * Builds a {@link DatasetDataCarrier} from this dataset's provenance, output factory and
     * version, using the supplied domains and an empty transformation list.
     * @param featureMap The feature domain.
     * @param outputInfo The output domain.
     * @return The data carrier.
     */
    public DatasetDataCarrier<T> createDataCarrier(FeatureMap featureMap, OutputInfo<T> outputInfo) {
        // Fall back to the running Tribuo version when none was recorded (tribuoVersion may be null).
        String version = this.tribuoVersion == null ? Tribuo.VERSION : this.tribuoVersion;
        return new DatasetDataCarrier<>(this.sourceProvenance, featureMap, outputInfo, this.outputFactory, Collections.emptyList(), version);
    }

    /**
     * Deserializes sequence examples and validates them against the expected output class and
     * feature domain.
     * @param list The serialized sequence examples.
     * @param cls The output class every example must have exactly (checked with {@code equals},
     *            so subclasses are rejected).
     * @param featureMap The feature domain every feature name must be present in.
     * @return The deserialized sequence examples, in input order.
     * @throws IllegalStateException If an example has a different output class, or contains a
     *                               feature absent from {@code featureMap}.
     */
    public static List<SequenceExample<?>> deserializeExamples(List<SequenceExampleProto> list, Class<?> cls, FeatureMap featureMap) {
        List<SequenceExample<?>> examples = new ArrayList<>(list.size());
        for (SequenceExampleProto proto : list) {
            SequenceExample<?> sequence = SequenceExample.deserialize(proto);
            for (Example<?> example : sequence) {
                if (!example.getOutput().getClass().equals(cls)) {
                    throw new IllegalStateException("Invalid protobuf, expected all examples to have output class " + cls + ", but found " + example.getOutput().getClass());
                }
                for (Feature feature : example) {
                    if (featureMap.get(feature.getName()) == null) {
                        throw new IllegalStateException("Invalid protobuf, feature domain does not contain feature " + feature.getName() + " present in an example");
                    }
                }
            }
            examples.add(sequence);
        }
        return examples;
    }
}
