package org.tribuo.datasource;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import java.io.DataOutputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.MutableDataset;
import org.tribuo.datasource.IDXDataSource;
import org.tribuo.test.MockOutputFactory;

/**
 * Tests for {@link IDXDataSource}.
 * <p>
 * The tests read their IDX fixtures from the classpath directory {@code /org/tribuo/datasource/}.
 * The {@code generate*} methods (driven by {@link #main(String[])}) regenerate those fixtures,
 * writing them relative to the current working directory.
 */
public class IDXDataSourceTest {
    /** Classpath directory holding the IDX fixtures read by the tests. */
    private static final String RESOURCE_DIR = "/org/tribuo/datasource/";

    private static final MockOutputFactory factory = new MockOutputFactory();

    /**
     * Writes the output fixtures: {@code outputs.idx} (4 BYTE labels over 2 distinct values) and
     * {@code outputs-long.idx} (8 BYTE labels).
     *
     * @throws IOException If a file cannot be written.
     */
    public static void generateOutputData() throws IOException {
        IDXDataSource.IDXData.createIDXData(IDXDataSource.IDXType.BYTE, new int[]{4},
                new double[]{0, 1, 1, 0})
                .save(Paths.get("./outputs.idx"), false);
        IDXDataSource.IDXData.createIDXData(IDXDataSource.IDXType.BYTE, new int[]{8},
                new double[]{0, 1, 1, 0, 2, 1, 1, 2})
                .save(Paths.get("./outputs-long.idx"), false);
    }

    /**
     * Writes BYTE feature fixtures holding the same 16 values: {@code byte.idx} with shape [4,4]
     * and {@code byte-mat.idx} with shape [4,2,2].
     *
     * @throws IOException If a file cannot be written.
     */
    public static void generateByteData() throws IOException {
        double[] values = {1, 2, 3, 4, 4, 3, 2, 1, 5, 6, 7, 8, 8, 7, 6, 5};
        IDXDataSource.IDXData.createIDXData(IDXDataSource.IDXType.BYTE, new int[]{4, 4}, values)
                .save(Paths.get("./byte.idx"), false);
        IDXDataSource.IDXData.createIDXData(IDXDataSource.IDXType.BYTE, new int[]{4, 2, 2}, values)
                .save(Paths.get("./byte-mat.idx"), false);
    }

    /**
     * Writes the INT feature fixture {@code int.idx} with shape [4,4].
     *
     * @throws IOException If the file cannot be written.
     */
    public static void generateIntData() throws IOException {
        IDXDataSource.IDXData.createIDXData(IDXDataSource.IDXType.INT, new int[]{4, 4},
                new double[]{1234, 2, 3, 4, 4321, 3, 2, 1, 5, 6789, 7, 8, 8765, 7, 6, 5})
                .save(Paths.get("./int.idx"), false);
    }

    /**
     * Writes fixtures whose declared shape disagrees with the amount of data stored.
     * <ul>
     *     <li>{@code too-little-data.idx}: shape [5,2], 4 values.</li>
     *     <li>{@code too-much-data.idx}: shape [5], 9 values.</li>
     * </ul>
     * The constructor is used directly instead of {@code createIDXData}, presumably so that the
     * mismatch is not rejected when the object is built.
     *
     * @throws IOException If a file cannot be written.
     */
    public static void generateInvalidIDX() throws IOException {
        new IDXDataSource.IDXData(IDXDataSource.IDXType.INT, new int[]{5, 2},
                new double[]{1, 2, 3, 4})
                .save(Paths.get("./too-little-data.idx"), false);
        new IDXDataSource.IDXData(IDXDataSource.IDXType.INT, new int[]{5},
                new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9})
                .save(Paths.get("./too-much-data.idx"), false);
    }

    /**
     * Writes hand-crafted files with malformed IDX headers or payloads.
     * <ul>
     *     <li>{@code invalid-magic-byte.idx}: leading short is 5.</li>
     *     <li>{@code invalid-type-byte.idx}: type byte is 0x80.</li>
     *     <li>{@code invalid-dim-byte.idx}: dimension count is -2.</li>
     *     <li>{@code incorrect-num-dims.idx}: claims 3 dimensions but stores 2 sizes.</li>
     *     <li>{@code no-data.idx}: valid header, no payload.</li>
     * </ul>
     *
     * @throws IOException If a file cannot be written.
     */
    public static void generateNonsenseIDX() throws IOException {
        byte intType = IDXDataSource.IDXType.INT.value;
        int[] twoByTwo = {2, 2};
        int[] payload = {1, 2, 3, 4};
        writeRawIDX("invalid-magic-byte.idx", 5, intType, 2, twoByTwo, payload);
        writeRawIDX("invalid-type-byte.idx", 0, 128, 2, twoByTwo, payload);
        writeRawIDX("invalid-dim-byte.idx", 0, intType, -2, twoByTwo, payload);
        writeRawIDX("incorrect-num-dims.idx", 0, intType, 3, twoByTwo, payload);
        writeRawIDX("no-data.idx", 0, intType, 2, twoByTwo, new int[0]);
    }

    /**
     * Writes one raw IDX-style file, bypassing all validation.
     * <p>
     * Layout: magic short, type byte, dimension-count byte, each dimension as an int, then each
     * payload value as an int.
     *
     * @param fileName      Output path, relative to the working directory.
     * @param magic         Value written as the leading short.
     * @param typeByte      Value written as the type byte.
     * @param numDimensions Value written as the dimension-count byte.
     * @param dimensions    Dimension sizes written after the header.
     * @param values        Payload values written after the dimensions.
     * @throws IOException If the file cannot be written.
     */
    private static void writeRawIDX(String fileName, int magic, int typeByte, int numDimensions,
                                    int[] dimensions, int[] values) throws IOException {
        // try-with-resources closes the stream exactly once, even if a write fails.
        try (DataOutputStream out = new DataOutputStream(
                new FileOutputStream(Paths.get(fileName).toFile()))) {
            out.writeShort(magic);
            out.writeByte(typeByte);
            out.writeByte(numDimensions);
            for (int dimension : dimensions) {
                out.writeInt(dimension);
            }
            for (int value : values) {
                out.writeInt(value);
            }
        }
    }

    /**
     * Regenerates every IDX fixture into the current working directory.
     *
     * @param strArr Ignored.
     * @throws IOException If a fixture cannot be written.
     */
    public static void main(String[] strArr) throws IOException {
        generateByteData();
        generateIntData();
        generateOutputData();
        generateInvalidIDX();
        generateNonsenseIDX();
    }

    /**
     * Resolves a fixture under {@link #RESOURCE_DIR} to a filesystem path.
     * <p>
     * This fails with a descriptive message rather than an opaque NPE when the fixture is missing.
     *
     * @param fileName The fixture file name.
     * @return The path to the fixture.
     * @throws URISyntaxException If the resource URL cannot be converted to a URI.
     */
    private static Path resourcePath(String fileName) throws URISyntaxException {
        String resource = RESOURCE_DIR + fileName;
        URL url = IDXDataSourceTest.class.getResource(resource);
        if (url == null) {
            throw new IllegalStateException("Missing test resource " + resource);
        }
        return Paths.get(url.toURI());
    }

    @Test
    public void testByteLoading() throws IOException, URISyntaxException {
        Path outputs = resourcePath("outputs.idx");
        testIDXLoading(resourcePath("byte.idx"), outputs);
        testIDXLoading(resourcePath("byte-mat.idx"), outputs);
    }

    @Test
    public void testIntLoading() throws IOException, URISyntaxException {
        testIDXLoading(resourcePath("int.idx"), resourcePath("outputs.idx"));
    }

    /**
     * Loads a features/outputs pair and checks it yields 4 features and 2 output dimensions.
     *
     * @param featurePath Path to the feature IDX file.
     * @param outputPath  Path to the output IDX file.
     * @throws IOException If either file cannot be read.
     */
    private void testIDXLoading(Path featurePath, Path outputPath) throws IOException {
        MutableDataset<?> dataset =
                new MutableDataset<>(new IDXDataSource<>(featurePath, outputPath, factory));
        Assertions.assertEquals(4, dataset.getFeatureMap().size());
        Assertions.assertEquals(2, dataset.getOutputInfo().size());
    }

    @Test
    public void testInvalidCombination() throws IOException, URISyntaxException {
        // int.idx holds 4 examples but outputs-long.idx holds 8 labels.
        Path features = resourcePath("int.idx");
        Path outputs = resourcePath("outputs-long.idx");
        Assertions.assertThrows(IllegalStateException.class,
                () -> new IDXDataSource<>(features, outputs, factory));
    }

    @Test
    public void testInvalidIDX() throws URISyntaxException {
        Path tooMuch = resourcePath("too-much-data.idx");
        Assertions.assertThrows(IllegalStateException.class, () -> IDXDataSource.readData(tooMuch));
        Path tooLittle = resourcePath("too-little-data.idx");
        Assertions.assertThrows(IllegalStateException.class, () -> IDXDataSource.readData(tooLittle));
    }

    @Test
    public void testNonsenseIDX() throws URISyntaxException {
        Path badMagic = resourcePath("invalid-magic-byte.idx");
        Assertions.assertThrows(IllegalStateException.class, () -> IDXDataSource.readData(badMagic));
        Path badType = resourcePath("invalid-type-byte.idx");
        Assertions.assertThrows(IllegalArgumentException.class, () -> IDXDataSource.readData(badType));
        Path badDims = resourcePath("invalid-dim-byte.idx");
        Assertions.assertThrows(IllegalStateException.class, () -> IDXDataSource.readData(badDims));
        Path wrongNumDims = resourcePath("incorrect-num-dims.idx");
        Assertions.assertThrows(IllegalStateException.class, () -> IDXDataSource.readData(wrongNumDims));
        Path noData = resourcePath("no-data.idx");
        Assertions.assertThrows(IllegalStateException.class, () -> IDXDataSource.readData(noData));
    }

    @Test
    public void testFileNotFound() throws IOException {
        // The configuration system must report the missing file as a PropertyException
        // (any other exception type fails assertThrows).
        ConfigurationManager cm = new ConfigurationManager("/org/tribuo/datasource/config.xml");
        PropertyException propertyException =
                Assertions.assertThrows(PropertyException.class, () -> cm.lookup("train"));
        if (!propertyException.getMessage().contains("Failed to load from path - ")) {
            Assertions.fail("Incorrect exception message", propertyException);
        }

        // Direct construction must report the missing file as a FileNotFoundException.
        FileNotFoundException notFound = Assertions.assertThrows(FileNotFoundException.class,
                () -> new IDXDataSource<>(Paths.get("these-features-dont-exist"),
                        Paths.get("these-outputs-dont-exist"), new MockOutputFactory()));
        if (!notFound.getMessage().contains("Failed to load from path - ")) {
            Assertions.fail("Incorrect exception message", notFound);
        }
    }
}
