package org.tribuo.evaluation;

import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import org.tribuo.Dataset;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.dataset.DatasetView;
import org.tribuo.util.Util;

/**
 * Splits a {@link Dataset} into {@code nsplits} consecutive folds for k-fold cross-validation.
 * For each fold, that fold's examples form the test set and the examples of every other fold
 * form the training set.
 *
 * @param <T> The output type of the dataset being split.
 */
public class KFoldSplitter<T extends Output<T>> {
    private static final Logger logger = Logger.getLogger(KFoldSplitter.class.getName());

    /** The number of folds; always at least 2. */
    protected final int nsplits;
    /** The seed used to construct {@link #rng}; also embedded in the fold tags. */
    protected final long seed;
    /**
     * Random source used when {@code split} is asked to shuffle. It is created once per splitter,
     * so repeated shuffled calls to {@code split} draw from the same generator.
     */
    protected final SplittableRandom rng;

    /**
     * A single train/test partition produced by {@link KFoldSplitter#split(Dataset, boolean)}.
     *
     * @param <T> The output type of the underlying dataset.
     */
    public static class TrainTestFold<T extends Output<T>> {
        /** View over the examples that are not in this fold. */
        public final DatasetView<T> train;
        /** View over the examples in this fold. */
        public final DatasetView<T> test;

        TrainTestFold(DatasetView<T> train, DatasetView<T> test) {
            this.train = train;
            this.test = test;
        }
    }

    /**
     * Builds a splitter with the given number of folds and RNG seed.
     *
     * @param nsplits The number of folds; must be at least 2.
     * @param seed    The seed for the shuffling random number generator.
     * @throws IllegalArgumentException if {@code nsplits < 2}.
     */
    public KFoldSplitter(int nsplits, long seed) {
        if (nsplits < 2) {
            throw new IllegalArgumentException("nsplits must be at least 2");
        }
        this.nsplits = nsplits;
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Builds a splitter with the given number of folds, seeded with {@link Trainer#DEFAULT_SEED}.
     *
     * @param nsplits The number of folds; must be at least 2.
     * @throws IllegalArgumentException if {@code nsplits < 2}.
     */
    public KFoldSplitter(int nsplits) {
        this(nsplits, Trainer.DEFAULT_SEED);
    }

    /**
     * Partitions the dataset into {@code nsplits} contiguous folds of (optionally shuffled)
     * example indices. The first {@code size % nsplits} folds each hold one extra example, so
     * fold sizes differ by at most one. Folds are materialised lazily, one per call to
     * {@link Iterator#next()}.
     *
     * @param dataset The dataset to split.
     * @param shuffle If true, indices are permuted with {@link #rng} before folding; otherwise
     *                folds follow the dataset's natural order.
     * @return An iterator yielding exactly {@code nsplits} train/test folds.
     * @throws IllegalArgumentException if the dataset is empty or has fewer examples than
     *                                  {@code nsplits}.
     */
    public Iterator<TrainTestFold<T>> split(final Dataset<T> dataset, boolean shuffle) {
        final int nsamples = dataset.size();
        if (nsamples == 0) {
            throw new IllegalArgumentException("empty input data");
        }
        if (nsplits > nsamples) {
            throw new IllegalArgumentException("cannot have nsplits > nsamples");
        }
        final int[] indices = shuffle ? Util.randperm(nsamples, rng) : IntStream.range(0, nsamples).toArray();

        final int[] foldSizes = new int[nsplits];
        Arrays.fill(foldSizes, nsamples / nsplits);
        // Spread the remainder over the leading folds so that sizes differ by at most one.
        for (int i = 0; i < nsamples % nsplits; i++) {
            foldSizes[i]++;
        }

        return new Iterator<TrainTestFold<T>>() {
            /** Index of the next fold to emit. */
            private int foldPtr = 0;
            /** Start offset (into {@code indices}) of the next fold. */
            private int dataPtr = 0;

            @Override
            public boolean hasNext() {
                return foldPtr < foldSizes.length;
            }

            @Override
            public TrainTestFold<T> next() {
                // Iterator contract: signal exhaustion with NoSuchElementException rather than
                // letting the array lookup below fail with an index exception.
                if (!hasNext()) {
                    throw new NoSuchElementException("All " + nsplits + " folds have already been returned");
                }
                final int start = dataPtr;
                final int end = start + foldSizes[foldPtr];
                foldPtr++;
                dataPtr = end;

                // Test = the contiguous slice [start, end); train = everything before and after it.
                int[] testIdx = Arrays.copyOfRange(indices, start, end);
                int[] trainIdx = new int[indices.length - testIdx.length];
                System.arraycopy(indices, 0, trainIdx, 0, start);
                System.arraycopy(indices, end, trainIdx, start, indices.length - end);

                // foldPtr has already been advanced, so the fold number in the tag is 1-based.
                String trainTag = "TrainFold(seed=" + seed + "," + foldPtr + " of " + nsplits + ")";
                String testTag = "TestFold(seed=" + seed + "," + foldPtr + " of " + nsplits + ")";
                return new TrainTestFold<>(
                        new DatasetView<>(dataset, trainIdx, trainTag),
                        new DatasetView<>(dataset, testIdx, testTag));
            }
        };
    }
}
