package org.tribuo.evaluation;

import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Logger;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.Dataset;
import org.tribuo.MutableDataset;
import org.tribuo.evaluation.KFoldSplitter;
import org.tribuo.test.MockDataSource;
import org.tribuo.test.MockOutput;

/**
 * Unit tests for {@link KFoldSplitter}.
 *
 * <p>Covers fold sizes when the number of folds divides (and does not divide) the number of
 * examples, rejection of more folds than examples, and separation of train/test indices.
 */
public class KFoldSplitterTest {
    private static final Logger logger = Logger.getLogger(KFoldSplitterTest.class.getName());

    /** 50 examples into 10 folds: every fold should be 45 train / 5 test, and there are 10 folds. */
    @Test
    public void testKFolder() {
        int numExamples = 50;
        int numFolds = 10;
        int testSize = numExamples / numFolds;
        int trainSize = numExamples - testSize;

        Iterator<KFoldSplitter.TrainTestFold<MockOutput>> folds =
                new KFoldSplitter<MockOutput>(numFolds, 3L).split(getData(numExamples), true);
        int foldCount = 0;
        while (folds.hasNext()) {
            assertFoldSizes(folds.next(), trainSize, testSize);
            foldCount++;
        }
        Assertions.assertEquals(numFolds, foldCount);
    }

    /**
     * 52 examples into 10 folds: the test expects the first {@code 52 % 10} folds to carry one
     * extra test example each, and the remaining folds to have the base size.
     */
    @Test
    public void testKFolderKDoesNotDivideN() {
        int numExamples = 52;
        int numFolds = 10;
        int testSize = numExamples / numFolds;
        int trainSize = numExamples - testSize;
        // Number of leading folds expected to absorb one leftover example each.
        int numLargerFolds = numExamples % numFolds;

        Iterator<KFoldSplitter.TrainTestFold<MockOutput>> folds =
                new KFoldSplitter<MockOutput>(numFolds, 3L).split(getData(numExamples), true);
        int foldCount = 0;
        while (folds.hasNext()) {
            KFoldSplitter.TrainTestFold<MockOutput> fold = folds.next();
            if (foldCount < numLargerFolds) {
                assertFoldSizes(fold, trainSize - 1, testSize + 1);
            } else {
                assertFoldSizes(fold, trainSize, testSize);
            }
            foldCount++;
        }
        Assertions.assertEquals(numFolds, foldCount);
    }

    /** Requesting more folds than there are examples must raise {@link IllegalArgumentException}. */
    @Test
    public void testKFolderNsplitsGTN() {
        Assertions.assertThrows(
                IllegalArgumentException.class,
                () -> new KFoldSplitter<MockOutput>(11, 3L).split(getData(10), false),
                "should fail for nsplits > ndata");
    }

    /**
     * Two unshuffled folds over 50 examples: exactly two folds are produced, and in each one the
     * train and test indices differ, share no index, and together account for every example.
     */
    @Test
    public void testKFolderTwoSplits() {
        int numExamples = 50;
        int numFolds = 2;

        Iterator<KFoldSplitter.TrainTestFold<MockOutput>> folds =
                new KFoldSplitter<MockOutput>(numFolds, 1L).split(getData(numExamples), false);
        int foldCount = 0;
        while (folds.hasNext()) {
            KFoldSplitter.TrainTestFold<MockOutput> fold = folds.next();
            int[] trainIndices = fold.train.getExampleIndices();
            int[] testIndices = fold.test.getExampleIndices();
            Assertions.assertFalse(Arrays.equals(trainIndices, testIndices));
            assertDisjoint(trainIndices, testIndices);
            Assertions.assertEquals(numExamples, trainIndices.length + testIndices.length);
            foldCount++;
        }
        Assertions.assertEquals(numFolds, foldCount);
    }

    /** Asserts that {@code fold} has exactly the given train and test sizes. */
    private static void assertFoldSizes(
            KFoldSplitter.TrainTestFold<MockOutput> fold, int expectedTrain, int expectedTest) {
        Assertions.assertEquals(expectedTrain, fold.train.size(), "train size");
        Assertions.assertEquals(expectedTest, fold.test.size(), "test size");
    }

    /** Asserts that no index in {@code testIndices} also appears in {@code trainIndices}. */
    private static void assertDisjoint(int[] trainIndices, int[] testIndices) {
        // Sort a copy so the fold's own array is left untouched.
        int[] sortedTrain = Arrays.copyOf(trainIndices, trainIndices.length);
        Arrays.sort(sortedTrain);
        for (int index : testIndices) {
            Assertions.assertTrue(
                    Arrays.binarySearch(sortedTrain, index) < 0,
                    "index " + index + " appears in both train and test");
        }
    }

    /** Builds a mutable dataset backed by a {@link MockDataSource} of {@code numExamples} examples. */
    private Dataset<MockOutput> getData(int numExamples) {
        return new MutableDataset<>(new MockDataSource(numExamples));
    }
}
