package org.tribuo.datasource;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.IOUtil;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.logging.Logger;
import java.util.zip.GZIPOutputStream;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DataSourceProvenance;

/* loaded from: input_file:org/tribuo/datasource/IDXDataSource.class */
/**
 * A {@link ConfigurableDataSource} that loads examples from a pair of IDX files, one holding the
 * features and one holding the outputs.
 * <p>
 * The first dimension of each file indexes the examples, so both files must agree on its size.
 * For each example the remaining feature elements are flattened into features named by their
 * zero-padded flat index (zero-valued features are skipped). The corresponding output elements
 * are rendered as a comma separated string and passed to {@link OutputFactory#generateOutput}.
 * <p>
 * File layout handled here: a two byte zero magic number, one type byte (see {@link IDXType}),
 * one dimension-count byte, one int per dimension, then the elements, all written through
 * {@link DataOutputStream} / read through {@link DataInputStream}.
 *
 * @param <T> The output type.
 */
public final class IDXDataSource<T extends Output<T>> implements ConfigurableDataSource<T> {
    private static final Logger logger = Logger.getLogger(IDXDataSource.class.getName());

    @Config(mandatory = true, description = "Path to load the features from.")
    private Path featuresPath;

    @Config(mandatory = true, description = "Path to load the output from.")
    private Path outputPath;

    @Config(mandatory = true, description = "The output factory to use.")
    private OutputFactory<T> outputFactory;

    private final ArrayList<Example<T>> data = new ArrayList<>();

    // Element type of the features file, recorded in the provenance.
    private IDXType dataType;

    // Lazily computed by getProvenance().
    private IDXDataSourceProvenance provenance;

    /**
     * An in-memory IDX tensor: element type, shape and the elements stored as doubles.
     */
    public static class IDXData {
        final IDXType dataType;
        final int[] shape;
        final double[] data;

        /**
         * Wraps the supplied arrays without copying or validating them. Use
         * {@link #createIDXData} for validated, defensively copied construction.
         */
        IDXData(IDXType dataType, int[] shape, double[] data) {
            this.dataType = dataType;
            this.shape = shape;
            this.data = data;
        }

        /**
         * Validates and copies the inputs into a new {@code IDXData}.
         *
         * @param dataType The element type used when saving.
         * @param shape    The tensor shape; every entry must be positive.
         * @param data     The elements; there must be exactly product(shape) of them and each
         *                 must be exactly representable in {@code dataType}.
         * @return A new IDXData which owns private copies of {@code shape} and {@code data}.
         * @throws IllegalArgumentException If the shape has no dimensions or 128 or more
         *                                  dimensions, contains a non-positive entry, describes
         *                                  more than {@code Integer.MAX_VALUE} elements, does not
         *                                  match the data length, or a value is not representable.
         */
        public static IDXData createIDXData(IDXType dataType, int[] shape, double[] data) {
            int[] shapeCopy = Arrays.copyOf(shape, shape.length);
            double[] dataCopy = Arrays.copyOf(data, data.length);
            // The dimension count is written as a signed byte and read back with readByte(),
            // so only 1..127 dimensions round-trip through save/readData.
            if (shapeCopy.length >= 128) {
                throw new IllegalArgumentException("Must have fewer than 128 dimensions");
            }
            if (shapeCopy.length == 0) {
                throw new IllegalArgumentException("Must have at least one dimension");
            }
            int numElements = 1;
            for (int dim : shapeCopy) {
                if (dim < 1) {
                    throw new IllegalArgumentException("Invalid shape, all elements must be positive, found " + Arrays.toString(shapeCopy));
                }
                try {
                    numElements = Math.multiplyExact(numElements, dim);
                } catch (ArithmeticException e) {
                    throw new IllegalArgumentException("Invalid shape, too many elements, found " + Arrays.toString(shapeCopy), e);
                }
            }
            if (numElements != dataCopy.length) {
                throw new IllegalArgumentException("Incorrect number of elements, expected " + numElements + ", found " + dataCopy.length);
            }
            checkRepresentable(dataType, dataCopy);
            // Hand the validated copies (not the caller's arrays) to the new instance.
            return new IDXData(dataType, shapeCopy, dataCopy);
        }

        /**
         * Throws if any value would change when narrowed to {@code dataType}.
         */
        private static void checkRepresentable(IDXType dataType, double[] values) {
            for (int i = 0; i < values.length; i++) {
                double value = values[i];
                switch (dataType) {
                    case UBYTE:
                        if (value != (0xFF & (int) value)) {
                            throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to unsigned byte");
                        }
                        break;
                    case BYTE:
                        if (value != (byte) value) {
                            throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to byte");
                        }
                        break;
                    case SHORT:
                        if (value != (short) value) {
                            throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to short");
                        }
                        break;
                    case INT:
                        if (value != (int) value) {
                            throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to int");
                        }
                        break;
                    case FLOAT:
                        if (value != (float) value) {
                            throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to float");
                        }
                        break;
                    case DOUBLE:
                        // Every double is representable.
                        break;
                }
            }
        }

        /**
         * Writes this tensor to {@code path} in the layout read by {@link IDXDataSource#readData}.
         *
         * @param path The file to write.
         * @param zip  If true the file is gzip compressed.
         * @throws IOException If the file could not be written.
         */
        public void save(Path path, boolean zip) throws IOException {
            try (DataOutputStream stream = makeStream(path, zip)) {
                stream.writeShort(0);
                stream.writeByte(dataType.value);
                stream.writeByte(shape.length);
                for (int dim : shape) {
                    stream.writeInt(dim);
                }
                for (double value : data) {
                    writeElement(stream, dataType, value);
                }
            }
        }

        /**
         * Writes a single element narrowed to {@code type}.
         */
        private static void writeElement(DataOutputStream stream, IDXType type, double value) throws IOException {
            switch (type) {
                case UBYTE:
                    stream.writeByte(0xFF & (int) value);
                    break;
                case BYTE:
                    stream.writeByte((byte) value);
                    break;
                case SHORT:
                    stream.writeShort((short) value);
                    break;
                case INT:
                    stream.writeInt((int) value);
                    break;
                case FLOAT:
                    stream.writeFloat((float) value);
                    break;
                case DOUBLE:
                    stream.writeDouble(value);
                    break;
            }
        }

        /**
         * Opens a buffered (and optionally gzipped) data stream on {@code path}. The file handle
         * is closed if wrapping it fails, e.g. when the gzip header cannot be written.
         */
        private static DataOutputStream makeStream(Path path, boolean zip) throws IOException {
            FileOutputStream fileStream = new FileOutputStream(path.toFile());
            try {
                return new DataOutputStream(new BufferedOutputStream(zip ? new GZIPOutputStream(fileStream) : fileStream));
            } catch (IOException | RuntimeException e) {
                try {
                    fileStream.close();
                } catch (IOException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
        }
    }

    /**
     * Provenance for {@link IDXDataSource}: file modification times and hashes of both input
     * files, the creation time, and the element type of the features file.
     */
    public static final class IDXDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements DataSourceProvenance {
        private static final long serialVersionUID = 1;

        public static final String OUTPUT_FILE_MODIFIED_TIME = "output-file-modified-time";
        public static final String FEATURES_FILE_MODIFIED_TIME = "features-file-modified-time";
        public static final String FEATURES_RESOURCE_HASH = "features-resource-hash";
        public static final String OUTPUT_RESOURCE_HASH = "output-resource-hash";
        public static final String FEATURE_TYPE = "idx-feature-type";

        private final DateTimeProvenance featuresFileModifiedTime;
        private final DateTimeProvenance outputFileModifiedTime;
        private final DateTimeProvenance dataSourceCreationTime;
        private final HashProvenance featuresSHA256Hash;
        private final HashProvenance outputSHA256Hash;
        private final EnumProvenance<IDXType> featureType;

        /**
         * Captures provenance from a loaded data source; hashes both files on construction.
         */
        <T extends Output<T>> IDXDataSourceProvenance(IDXDataSource<T> host) {
            super(host, "DataSource");
            this.outputFileModifiedTime = modifiedTime(OUTPUT_FILE_MODIFIED_TIME, host.outputPath);
            this.featuresFileModifiedTime = modifiedTime(FEATURES_FILE_MODIFIED_TIME, host.featuresPath);
            this.dataSourceCreationTime = new DateTimeProvenance(DataSourceProvenance.DATASOURCE_CREATION_TIME, OffsetDateTime.now());
            this.featuresSHA256Hash = new HashProvenance(DEFAULT_HASH_TYPE, FEATURES_RESOURCE_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, host.featuresPath));
            this.outputSHA256Hash = new HashProvenance(DEFAULT_HASH_TYPE, OUTPUT_RESOURCE_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, host.outputPath));
            this.featureType = new EnumProvenance<>(FEATURE_TYPE, host.dataType);
        }

        /**
         * Deserialization constructor, used when rebuilding provenance from a map.
         *
         * @param map The provenance fields.
         */
        public IDXDataSourceProvenance(Map<String, Provenance> map) {
            this(extractProvenanceInfo(map));
        }

        // The feature-type entry is checked to be an EnumProvenance by extractProvenanceInfo;
        // its type argument cannot be checked at runtime, hence the unchecked cast.
        @SuppressWarnings("unchecked")
        private IDXDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo extractedInfo) {
            super(extractedInfo);
            this.featuresFileModifiedTime = (DateTimeProvenance) extractedInfo.instanceValues.get(FEATURES_FILE_MODIFIED_TIME);
            this.outputFileModifiedTime = (DateTimeProvenance) extractedInfo.instanceValues.get(OUTPUT_FILE_MODIFIED_TIME);
            this.dataSourceCreationTime = (DateTimeProvenance) extractedInfo.instanceValues.get(DataSourceProvenance.DATASOURCE_CREATION_TIME);
            this.featuresSHA256Hash = (HashProvenance) extractedInfo.instanceValues.get(FEATURES_RESOURCE_HASH);
            this.outputSHA256Hash = (HashProvenance) extractedInfo.instanceValues.get(OUTPUT_RESOURCE_HASH);
            this.featureType = (EnumProvenance<IDXType>) extractedInfo.instanceValues.get(FEATURE_TYPE);
        }

        /**
         * Splits a provenance map into the configured parameters and the instance values
         * tracked by this class.
         */
        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            String provClassName = IDXDataSourceProvenance.class.getSimpleName();
            Map<String, Provenance> configuredParameters = new HashMap<>(map);
            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "class-name", StringProvenance.class, provClassName).getValue();
            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "host-short-name", StringProvenance.class, provClassName).getValue();

            Map<String, PrimitiveProvenance<?>> instanceParameters = new HashMap<>();
            instanceParameters.put(FEATURES_FILE_MODIFIED_TIME, ObjectProvenance.checkAndExtractProvenance(configuredParameters, FEATURES_FILE_MODIFIED_TIME, DateTimeProvenance.class, provClassName));
            instanceParameters.put(OUTPUT_FILE_MODIFIED_TIME, ObjectProvenance.checkAndExtractProvenance(configuredParameters, OUTPUT_FILE_MODIFIED_TIME, DateTimeProvenance.class, provClassName));
            instanceParameters.put(DataSourceProvenance.DATASOURCE_CREATION_TIME, ObjectProvenance.checkAndExtractProvenance(configuredParameters, DataSourceProvenance.DATASOURCE_CREATION_TIME, DateTimeProvenance.class, provClassName));
            instanceParameters.put(FEATURES_RESOURCE_HASH, ObjectProvenance.checkAndExtractProvenance(configuredParameters, FEATURES_RESOURCE_HASH, HashProvenance.class, provClassName));
            instanceParameters.put(OUTPUT_RESOURCE_HASH, ObjectProvenance.checkAndExtractProvenance(configuredParameters, OUTPUT_RESOURCE_HASH, HashProvenance.class, provClassName));
            instanceParameters.put(FEATURE_TYPE, ObjectProvenance.checkAndExtractProvenance(configuredParameters, FEATURE_TYPE, EnumProvenance.class, provClassName));

            return new SkeletalConfiguredObjectProvenance.ExtractedInfo(className, hostTypeStringName, configuredParameters, instanceParameters);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> instanceValues = super.getInstanceValues();
            instanceValues.put(featuresFileModifiedTime.getKey(), featuresFileModifiedTime);
            instanceValues.put(outputFileModifiedTime.getKey(), outputFileModifiedTime);
            instanceValues.put(dataSourceCreationTime.getKey(), dataSourceCreationTime);
            instanceValues.put(featuresSHA256Hash.getKey(), featuresSHA256Hash);
            instanceValues.put(outputSHA256Hash.getKey(), outputSHA256Hash);
            instanceValues.put(featureType.getKey(), featureType);
            return instanceValues;
        }

        /**
         * Builds a timestamp provenance from the file's last-modified time in the system zone.
         */
        private static DateTimeProvenance modifiedTime(String key, Path path) {
            return new DateTimeProvenance(key, OffsetDateTime.ofInstant(Instant.ofEpochMilli(path.toFile().lastModified()), ZoneId.systemDefault()));
        }
    }

    /**
     * The element types supported by this IDX reader/writer, with their header type byte.
     */
    public enum IDXType {
        UBYTE((byte) 8),
        BYTE((byte) 9),
        SHORT((byte) 11),
        INT((byte) 12),
        FLOAT((byte) 13),
        DOUBLE((byte) 14);

        /** The type byte written into the file header. */
        public final byte value;

        IDXType(byte value) {
            this.value = value;
        }

        /**
         * Looks up the type for a header byte.
         *
         * @param value The header type byte.
         * @return The matching type.
         * @throws IllegalArgumentException If no type uses that byte.
         */
        public static IDXType convert(byte value) {
            for (IDXType type : values()) {
                if (type.value == value) {
                    return type;
                }
            }
            throw new IllegalArgumentException("Invalid byte found - " + value);
        }
    }

    /** For OLCUT configuration; fields are injected and then {@link #postConfig()} loads the data. */
    private IDXDataSource() {
    }

    /**
     * Loads a data source from a features file and an outputs file.
     *
     * @param featuresPath  The IDX file holding the features.
     * @param outputPath    The IDX file holding the outputs.
     * @param outputFactory The factory used to convert output strings into outputs.
     * @throws IOException If either file could not be read.
     */
    public IDXDataSource(Path featuresPath, Path outputPath, OutputFactory<T> outputFactory) throws IOException {
        this.outputFactory = outputFactory;
        this.featuresPath = featuresPath;
        this.outputPath = outputPath;
        read();
    }

    /**
     * Called by OLCUT after configuration; loads the data from the configured paths.
     */
    @Override
    public void postConfig() throws IOException {
        read();
    }

    @Override
    public String toString() {
        return "IDXDataSource(featuresPath=" + featuresPath.toString() + ",outputPath=" + outputPath.toString() + ",featureType=" + dataType + ")";
    }

    @Override
    public OutputFactory<T> getOutputFactory() {
        return outputFactory;
    }

    /**
     * Returns the provenance, computing (and hashing the input files) on first call.
     */
    @Override
    public synchronized DataSourceProvenance getProvenance() {
        if (provenance == null) {
            provenance = cacheProvenance();
        }
        return provenance;
    }

    /**
     * Alias kept for callers compiled against the decompiler-generated name.
     *
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public DataSourceProvenance m22getProvenance() {
        return getProvenance();
    }

    private IDXDataSourceProvenance cacheProvenance() {
        return new IDXDataSourceProvenance(this);
    }

    /**
     * Reads both files and converts them into examples.
     *
     * @throws IOException           If either file could not be read.
     * @throws IllegalStateException If the files disagree on the number of examples, or are malformed.
     */
    private void read() throws IOException {
        IDXData features = readData(featuresPath);
        IDXData outputs = readData(outputPath);
        dataType = features.dataType;
        if (features.shape[0] != outputs.shape[0]) {
            throw new IllegalStateException("Features and outputs have different numbers of examples, feature shape = " + Arrays.toString(features.shape) + ", output shape = " + Arrays.toString(outputs.shape));
        }
        int numFeatures = elementsPerExample(features.shape);
        int outputWidth = elementsPerExample(outputs.shape);
        String[] featureNames = generateFeatureNames(numFeatures);

        // Reset so a repeated call replaces rather than duplicates the examples.
        data.clear();
        ArrayList<Feature> exampleFeatures = new ArrayList<>();
        StringBuilder outputBuilder = new StringBuilder();
        int featureIdx = 0;
        int outputOffset = 0;
        for (double value : features.data) {
            // Sparse representation: zero-valued features are not stored.
            if (value != 0.0d) {
                exampleFeatures.add(new Feature(featureNames[featureIdx], value));
            }
            featureIdx++;
            if (featureIdx == numFeatures) {
                String outputString = formatOutput(outputs, outputOffset, outputWidth, outputBuilder);
                outputOffset += outputWidth;
                ArrayExample<T> example = new ArrayExample<>(outputFactory.generateOutput(outputString));
                example.addAll(exampleFeatures);
                data.add(example);
                exampleFeatures.clear();
                featureIdx = 0;
            }
        }
        if (featureIdx != 0) {
            throw new IllegalStateException("Failed to process all the features, missing " + (numFeatures - featureIdx) + " values");
        }
    }

    /**
     * Number of elements per example, i.e. the product of all dimensions after the first.
     */
    private static int elementsPerExample(int[] shape) {
        int product = 1;
        for (int d = 1; d < shape.length; d++) {
            product *= shape[d];
        }
        return product;
    }

    /**
     * Generates feature names "0".."count-1", zero-padded to the width of {@code count}.
     * Built from {@link Integer#toString(int)} so names do not depend on the default locale.
     */
    private static String[] generateFeatureNames(int count) {
        int width = Integer.toString(count).length();
        String[] names = new String[count];
        StringBuilder builder = new StringBuilder(width);
        for (int i = 0; i < count; i++) {
            String digits = Integer.toString(i);
            builder.setLength(0);
            for (int pad = digits.length(); pad < width; pad++) {
                builder.append('0');
            }
            names[i] = builder.append(digits).toString();
        }
        return names;
    }

    /**
     * Renders {@code width} output elements starting at {@code offset} as a comma separated
     * string; integral IDX types are printed as ints, FLOAT/DOUBLE as doubles.
     */
    private static String formatOutput(IDXData outputs, int offset, int width, StringBuilder builder) {
        builder.setLength(0);
        for (int j = 0; j < width; j++) {
            if (j != 0) {
                builder.append(',');
            }
            double value = outputs.data[offset + j];
            switch (outputs.dataType) {
                case UBYTE:
                case BYTE:
                case SHORT:
                case INT:
                    builder.append((int) value);
                    break;
                case FLOAT:
                case DOUBLE:
                    builder.append(value);
                    break;
            }
        }
        return builder.toString();
    }

    /**
     * Reads an IDX file into memory.
     *
     * @param path The file location, resolved through {@code IOUtil.getInputStreamForLocation}.
     * @return The file contents.
     * @throws FileNotFoundException If the location could not be opened.
     * @throws IOException           If reading fails.
     * @throws IllegalStateException If the header is invalid, or the file holds too little or too much data.
     */
    static IDXData readData(Path path) throws IOException {
        InputStream rawStream = IOUtil.getInputStreamForLocation(path.toString());
        if (rawStream == null) {
            throw new FileNotFoundException("Failed to load from path - " + path);
        }
        try (DataInputStream stream = new DataInputStream(rawStream)) {
            short magic = stream.readShort();
            if (magic != 0) {
                throw new IllegalStateException("Invalid IDX file, magic number was not zero. Found " + magic);
            }
            IDXType type = IDXType.convert(stream.readByte());
            // readByte is signed, so counts above 127 show up as negative and are rejected here.
            int numDims = stream.readByte();
            if (numDims < 1) {
                throw new IllegalStateException("Invalid number of dimensions, found " + numDims);
            }
            int[] shape = new int[numDims];
            int numElements = 1;
            for (int d = 0; d < numDims; d++) {
                shape[d] = stream.readInt();
                if (shape[d] < 1) {
                    throw new IllegalStateException("Invalid shape, found " + Arrays.toString(shape));
                }
                try {
                    numElements = Math.multiplyExact(numElements, shape[d]);
                } catch (ArithmeticException e) {
                    throw new IllegalStateException("Invalid shape, too many elements, found " + Arrays.toString(shape), e);
                }
            }
            double[] values = new double[numElements];
            try {
                for (int i = 0; i < numElements; i++) {
                    values[i] = readElement(stream, type);
                }
            } catch (EOFException e) {
                throw new IllegalStateException("Too little data in the file, expected to find " + numElements + " elements", e);
            }
            if (stream.read() != -1) {
                throw new IllegalStateException("Too much data in the file");
            }
            return new IDXData(type, shape, values);
        }
    }

    /**
     * Reads one element of the given type and widens it to a double.
     */
    private static double readElement(DataInputStream stream, IDXType type) throws IOException {
        switch (type) {
            case UBYTE:
                return stream.readUnsignedByte();
            case BYTE:
                return stream.readByte();
            case SHORT:
                return stream.readShort();
            case INT:
                return stream.readInt();
            case FLOAT:
                return stream.readFloat();
            case DOUBLE:
                return stream.readDouble();
            default:
                throw new IllegalStateException("Unknown IDX type " + type);
        }
    }

    /**
     * @return The number of examples loaded.
     */
    public int size() {
        return data.size();
    }

    /**
     * @return The element type of the features file.
     */
    public IDXType getDataType() {
        return dataType;
    }

    @Override
    public Iterator<Example<T>> iterator() {
        return data.iterator();
    }
}
