package io.trino.operator.output;

import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import io.trino.SequencePageBuilder;
import io.trino.operator.PartitionFunction;
import io.trino.spi.Page;
import io.trino.spi.type.BigintType;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.ArrayList;
import java.util.List;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link SkewedPartitionRebalancer}, observed through the partition-to-task mapping
 * produced by a {@link SkewedPartitionFunction} that wraps a simple modulo partition function.
 */
class TestSkewedPartitionRebalancer {
    private static final long MIN_PARTITION_DATA_PROCESSED_REBALANCE_THRESHOLD = DataSize.of(1, DataSize.Unit.MEGABYTE).toBytes();
    private static final long MIN_DATA_PROCESSED_REBALANCE_THRESHOLD = DataSize.of(50, DataSize.Unit.MEGABYTE).toBytes();

    @Test
    void testRebalanceWithSkewness() {
        SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer(3, 3, 3, MIN_PARTITION_DATA_PROCESSED_REBALANCE_THRESHOLD, MIN_DATA_PROCESSED_REBALANCE_THRESHOLD);
        SkewedPartitionFunction function = new SkewedPartitionFunction(new TestPartitionFunction(3), rebalancer);

        // 40 MB is below MIN_DATA_PROCESSED_REBALANCE_THRESHOLD (50 MB): assignments are expected to stay unchanged
        rebalancer.addPartitionRowCount(0, 1000L);
        rebalancer.addPartitionRowCount(1, 1000L);
        rebalancer.addPartitionRowCount(2, 1000L);
        rebalancer.addDataProcessed(DataSize.of(40L, DataSize.Unit.MEGABYTE).toBytes());
        rebalancer.rebalance();

        Assertions.assertThat(getPartitionPositions(function, 17))
                .containsExactly(
                        new IntArrayList(ImmutableList.of(0, 3, 6, 9, 12, 15)),
                        new IntArrayList(ImmutableList.of(1, 4, 7, 10, 13, 16)),
                        new IntArrayList(ImmutableList.of(2, 5, 8, 11, 14)));
        Assertions.assertThat(rebalancer.getPartitionAssignments())
                .containsExactly(ImmutableList.of(0), ImmutableList.of(1), ImmutableList.of(2));

        rebalancer.addPartitionRowCount(0, 1000L);
        rebalancer.addPartitionRowCount(1, 1000L);
        rebalancer.addPartitionRowCount(2, 1000L);
        rebalancer.addDataProcessed(DataSize.of(20L, DataSize.Unit.MEGABYTE).toBytes());
        rebalancer.rebalance();

        Assertions.assertThat(getPartitionPositions(function, 17))
                .containsExactly(
                        new IntArrayList(ImmutableList.of(0, 2, 4, 6, 8, 10, 12, 14, 16)),
                        new IntArrayList(ImmutableList.of(1, 3, 7, 9, 13, 15)),
                        new IntArrayList(ImmutableList.of(5, 11)));
        Assertions.assertThat(rebalancer.getPartitionAssignments())
                .containsExactly(ImmutableList.of(0, 1), ImmutableList.of(1, 0), ImmutableList.of(2, 0));

        rebalancer.addPartitionRowCount(0, 1000L);
        rebalancer.addPartitionRowCount(1, 1000L);
        rebalancer.addPartitionRowCount(2, 1000L);
        rebalancer.addDataProcessed(DataSize.of(200L, DataSize.Unit.MEGABYTE).toBytes());
        rebalancer.rebalance();

        Assertions.assertThat(getPartitionPositions(function, 17))
                .containsExactly(
                        new IntArrayList(ImmutableList.of(0, 2, 4, 9, 11, 13)),
                        new IntArrayList(ImmutableList.of(1, 3, 5, 10, 12, 14)),
                        new IntArrayList(ImmutableList.of(6, 7, 8, 15, 16)));
        Assertions.assertThat(rebalancer.getPartitionAssignments())
                .containsExactly(ImmutableList.of(0, 1, 2), ImmutableList.of(1, 0, 2), ImmutableList.of(2, 0, 1));
    }

    @Test
    void testRebalanceWithoutSkewness() {
        SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer(6, 3, 2, MIN_PARTITION_DATA_PROCESSED_REBALANCE_THRESHOLD, MIN_DATA_PROCESSED_REBALANCE_THRESHOLD);
        SkewedPartitionFunction function = new SkewedPartitionFunction(new TestPartitionFunction(6), rebalancer);

        rebalancer.addPartitionRowCount(0, 1000L);
        rebalancer.addPartitionRowCount(1, 700L);
        rebalancer.addPartitionRowCount(2, 600L);
        rebalancer.addPartitionRowCount(3, 1000L);
        rebalancer.addPartitionRowCount(4, 700L);
        rebalancer.addPartitionRowCount(5, 600L);
        rebalancer.addDataProcessed(DataSize.of(500L, DataSize.Unit.MEGABYTE).toBytes());
        rebalancer.rebalance();

        Assertions.assertThat(getPartitionPositions(function, 6))
                .containsExactly(
                        new IntArrayList(ImmutableList.of(0, 3)),
                        new IntArrayList(ImmutableList.of(1, 4)),
                        new IntArrayList(ImmutableList.of(2, 5)));
        Assertions.assertThat(rebalancer.getPartitionAssignments())
                .containsExactly(
                        ImmutableList.of(0),
                        ImmutableList.of(1),
                        ImmutableList.of(2),
                        ImmutableList.of(0),
                        ImmutableList.of(1),
                        ImmutableList.of(2));
    }

    @Test
    void testNoRebalanceWhenDataWrittenIsLessThanTheRebalanceLimit() {
        SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer(3, 3, 3, MIN_PARTITION_DATA_PROCESSED_REBALANCE_THRESHOLD, MIN_DATA_PROCESSED_REBALANCE_THRESHOLD);
        SkewedPartitionFunction function = new SkewedPartitionFunction(new TestPartitionFunction(3), rebalancer);

        // all rows land on partition 0, but only 40 MB (< 50 MB threshold) has been processed
        rebalancer.addPartitionRowCount(0, 1000L);
        rebalancer.addPartitionRowCount(1, 0L);
        rebalancer.addPartitionRowCount(2, 0L);
        rebalancer.addDataProcessed(DataSize.of(40L, DataSize.Unit.MEGABYTE).toBytes());
        rebalancer.rebalance();

        Assertions.assertThat(getPartitionPositions(function, 6))
                .containsExactly(
                        new IntArrayList(ImmutableList.of(0, 3)),
                        new IntArrayList(ImmutableList.of(1, 4)),
                        new IntArrayList(ImmutableList.of(2, 5)));
        Assertions.assertThat(rebalancer.getPartitionAssignments())
                .containsExactly(ImmutableList.of(0), ImmutableList.of(1), ImmutableList.of(2));
    }

    @Test
    void testNoRebalanceWhenDataWrittenByThePartitionIsLessThanWriterScalingMinDataProcessed() {
        // per-partition threshold raised to 50 MB for this scenario
        SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer(3, 3, 3, DataSize.of(50L, DataSize.Unit.MEGABYTE).toBytes(), MIN_DATA_PROCESSED_REBALANCE_THRESHOLD);
        SkewedPartitionFunction function = new SkewedPartitionFunction(new TestPartitionFunction(3), rebalancer);

        rebalancer.addPartitionRowCount(0, 1000L);
        rebalancer.addPartitionRowCount(1, 600L);
        rebalancer.addPartitionRowCount(2, 0L);
        rebalancer.addDataProcessed(DataSize.of(60L, DataSize.Unit.MEGABYTE).toBytes());
        rebalancer.rebalance();

        Assertions.assertThat(getPartitionPositions(function, 6))
                .containsExactly(
                        new IntArrayList(ImmutableList.of(0, 3)),
                        new IntArrayList(ImmutableList.of(1, 4)),
                        new IntArrayList(ImmutableList.of(2, 5)));
        Assertions.assertThat(rebalancer.getPartitionAssignments())
                .containsExactly(ImmutableList.of(0), ImmutableList.of(1), ImmutableList.of(2));
    }

    @Test
    void testRebalancePartitionToSingleTaskInARebalancingLoop() {
        SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer(3, 3, 3, MIN_PARTITION_DATA_PROCESSED_REBALANCE_THRESHOLD, MIN_DATA_PROCESSED_REBALANCE_THRESHOLD);
        SkewedPartitionFunction function = new SkewedPartitionFunction(new TestPartitionFunction(3), rebalancer);

        rebalancer.addPartitionRowCount(0, 1000L);
        rebalancer.addPartitionRowCount(1, 0L);
        rebalancer.addPartitionRowCount(2, 0L);
        rebalancer.addDataProcessed(DataSize.of(60L, DataSize.Unit.MEGABYTE).toBytes());
        rebalancer.rebalance();

        Assertions.assertThat(getPartitionPositions(function, 17))
                .containsExactly(
                        new IntArrayList(ImmutableList.of(0, 6, 12)),
                        new IntArrayList(ImmutableList.of(1, 3, 4, 7, 9, 10, 13, 15, 16)),
                        new IntArrayList(ImmutableList.of(2, 5, 8, 11, 14)));
        Assertions.assertThat(rebalancer.getPartitionAssignments())
                .containsExactly(ImmutableList.of(0, 1), ImmutableList.of(1), ImmutableList.of(2));

        // partition 0 remains skewed in the next cycle; it is expected to spread to one more task
        rebalancer.addPartitionRowCount(0, 1000L);
        rebalancer.addPartitionRowCount(1, 0L);
        rebalancer.addPartitionRowCount(2, 0L);
        rebalancer.addDataProcessed(DataSize.of(60L, DataSize.Unit.MEGABYTE).toBytes());
        rebalancer.rebalance();

        Assertions.assertThat(getPartitionPositions(function, 17))
                .containsExactly(
                        new IntArrayList(ImmutableList.of(0, 9)),
                        new IntArrayList(ImmutableList.of(1, 3, 4, 7, 10, 12, 13, 16)),
                        new IntArrayList(ImmutableList.of(2, 5, 6, 8, 11, 14, 15)));
        Assertions.assertThat(rebalancer.getPartitionAssignments())
                .containsExactly(ImmutableList.of(0, 1, 2), ImmutableList.of(1), ImmutableList.of(2));
    }

    @Test
    void testConsiderSkewedPartitionOnlyWithinACycle() {
        SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer(3, 3, 1, MIN_PARTITION_DATA_PROCESSED_REBALANCE_THRESHOLD, MIN_DATA_PROCESSED_REBALANCE_THRESHOLD);
        SkewedPartitionFunction function = new SkewedPartitionFunction(new TestPartitionFunction(3), rebalancer);

        rebalancer.addPartitionRowCount(0, 1000L);
        rebalancer.addPartitionRowCount(1, 800L);
        rebalancer.addPartitionRowCount(2, 0L);
        rebalancer.addDataProcessed(DataSize.of(60L, DataSize.Unit.MEGABYTE).toBytes());
        rebalancer.rebalance();

        Assertions.assertThat(getPartitionPositions(function, 17))
                .containsExactly(
                        new IntArrayList(ImmutableList.of(0, 6, 12)),
                        new IntArrayList(ImmutableList.of(1, 4, 7, 10, 13, 16)),
                        new IntArrayList(ImmutableList.of(2, 3, 5, 8, 9, 11, 14, 15)));
        Assertions.assertThat(rebalancer.getPartitionAssignments())
                .containsExactly(ImmutableList.of(0, 2), ImmutableList.of(1), ImmutableList.of(2));

        // skew moves from partition 0 to partition 2 in this cycle
        rebalancer.addPartitionRowCount(0, 0L);
        rebalancer.addPartitionRowCount(1, 800L);
        rebalancer.addPartitionRowCount(2, 1000L);
        rebalancer.addDataProcessed(DataSize.of(60L, DataSize.Unit.MEGABYTE).toBytes());
        rebalancer.rebalance();

        Assertions.assertThat(getPartitionPositions(function, 17))
                .containsExactly(
                        new IntArrayList(ImmutableList.of(0, 2, 6, 8, 12, 14)),
                        new IntArrayList(ImmutableList.of(1, 4, 7, 10, 13, 16)),
                        new IntArrayList(ImmutableList.of(3, 5, 9, 11, 15)));
        Assertions.assertThat(rebalancer.getPartitionAssignments())
                .containsExactly(ImmutableList.of(0, 2), ImmutableList.of(1), ImmutableList.of(2, 0));
    }

    /**
     * Groups positions {@code 0 .. maxPosition - 1} by the partition {@code function} assigns to each.
     *
     * @param function the partition function under test
     * @param maxPosition number of positions to route (exclusive upper bound)
     * @return one list per partition (index = partition id), each holding the routed positions in ascending order
     */
    private static List<List<Integer>> getPartitionPositions(PartitionFunction function, int maxPosition) {
        List<List<Integer>> partitionPositions = new ArrayList<>(function.partitionCount());
        for (int partition = 0; partition < function.partitionCount(); partition++) {
            partitionPositions.add(new ArrayList<>());
        }

        // TestPartitionFunction ignores the page contents, so one page is reused for every position
        Page page = dummyPage();
        for (int position = 0; position < maxPosition; position++) {
            int partition = function.getPartition(page, position);
            partitionPositions.get(partition).add(position);
        }
        return partitionPositions;
    }

    /** Builds a 100-row single-BIGINT-column sequence page used as a placeholder input. */
    private static Page dummyPage() {
        return SequencePageBuilder.createSequencePage(ImmutableList.of(BigintType.BIGINT), 100, 0);
    }

    /**
     * Deterministic partition function that maps a position to {@code position % partitionCount}.
     */
    private record TestPartitionFunction(int partitionCount)
            implements PartitionFunction {
        @Override
        public int getPartition(Page page, int position) {
            return position % partitionCount;
        }
    }
}
