package io.trino.operator.aggregation;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/operator/aggregation/AbstractTestApproximateCountDistinct.class */
/**
 * Shared test scenarios for the {@code approx_distinct(value, maxStandardError)} aggregation.
 *
 * <p>Subclasses supply the value type under test ({@link #getValueType()}) and a generator of random
 * values of that type ({@link #randomValue()}). Each scenario is exercised with two standard-error
 * settings, 0.023 and 0.0115.
 */
public abstract class AbstractTestApproximateCountDistinct {
    private static final TestingFunctionResolution FUNCTION_RESOLUTION = new TestingFunctionResolution();

    /** Returns the type of the values whose distinct count is being estimated. */
    protected abstract Type getValueType();

    /**
     * Returns a random, non-null value of {@link #getValueType()}.
     *
     * <p>The value must be in the Java representation that {@code createBlock} writes for that type:
     * a {@code Boolean}, {@code Long}, {@code Double} or {@code Slice} matching the type's Java type,
     * or any object accepted by {@code Type.writeObject}. The generator must be able to produce at
     * least {@link #getUniqueValuesCount()} distinct values; otherwise {@code makeRandomSet} never
     * terminates.
     */
    protected abstract Object randomValue();

    /** Upper bound (inclusive) on the number of distinct values used by the randomized scenarios. */
    protected int getUniqueValuesCount() {
        return 20000;
    }

    @Test
    public void testNoPositions() {
        assertCount(ImmutableList.of(), 0.023d, 0L);
        assertCount(ImmutableList.of(), 0.0115d, 0L);
    }

    @Test
    public void testSinglePosition() {
        assertCount(ImmutableList.of(randomValue()), 0.023d, 1L);
        assertCount(ImmutableList.of(randomValue()), 0.0115d, 1L);
    }

    @Test
    public void testAllPositionsNull() {
        assertCount(Collections.nCopies(100, null), 0.023d, 0L);
        assertCount(Collections.nCopies(100, null), 0.0115d, 0L);
    }

    @Test
    public void testMixedNullsAndNonNulls() {
        testMixedNullsAndNonNulls(0.023d);
        testMixedNullsAndNonNulls(0.0115d);
    }

    /**
     * Checks that interleaving nulls into a sample does not change the estimate.
     *
     * @param maxStandardError the standard-error argument passed to {@code approx_distinct}
     */
    private void testMixedNullsAndNonNulls(double maxStandardError) {
        int uniques = getUniqueValuesCount();
        List<Object> baseline = createRandomSample(uniques, (int) (uniques * 1.5d));

        // The iterator only advances when a non-null value is emitted. Every baseline value therefore
        // appears exactly once, in its original order, with random nulls in between. Preserving the
        // order keeps the estimate comparable to the baseline's.
        Iterator<Object> iterator = baseline.iterator();
        List<Object> mixed = new ArrayList<>();
        while (iterator.hasNext()) {
            mixed.add(ThreadLocalRandom.current().nextBoolean() ? null : iterator.next());
        }

        assertCount(mixed, maxStandardError, estimateGroupByCount(baseline, maxStandardError));
    }

    @Test
    public void testMultiplePositions() {
        testMultiplePositions(0.023d);
        testMultiplePositions(0.0115d);
    }

    /**
     * Estimates 500 random samples and checks the distribution of the relative error.
     *
     * <p>The mean of the errors must be below 1%. Their standard deviation must be below
     * {@code 0.01 + maxStandardError}.
     *
     * @param maxStandardError the standard-error argument passed to {@code approx_distinct}
     */
    private void testMultiplePositions(double maxStandardError) {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (int i = 0; i < 500; i++) {
            // Draw the true distinct count once. The same value sizes the sample and serves as the
            // denominator of the relative error.
            int uniques = ThreadLocalRandom.current().nextInt(getUniqueValuesCount()) + 1;
            List<Object> values = createRandomSample(uniques, (int) (uniques * 1.5d));

            long actual = estimateGroupByCount(values, maxStandardError);
            double relativeError = ((actual - uniques) * 1.0d) / uniques;
            stats.addValue(relativeError);
        }
        Assertions.assertThat(stats.getMean()).isLessThan(0.01d);
        Assertions.assertThat(stats.getStandardDeviation()).isLessThan(0.01d + maxStandardError);
    }

    @Test
    public void testMultiplePositionsPartial() {
        testMultiplePositionsPartial(0.023d);
        testMultiplePositionsPartial(0.0115d);
    }

    /**
     * Checks that the partial-aggregation path returns exactly the grouped-aggregation estimate for
     * 100 random samples.
     *
     * @param maxStandardError the standard-error argument passed to {@code approx_distinct}
     */
    private void testMultiplePositionsPartial(double maxStandardError) {
        for (int i = 0; i < 100; i++) {
            int uniques = ThreadLocalRandom.current().nextInt(getUniqueValuesCount()) + 1;
            List<Object> values = createRandomSample(uniques, (int) (uniques * 1.5d));
            Assertions.assertThat(estimateCountPartial(values, maxStandardError))
                    .isEqualTo(estimateGroupByCount(values, maxStandardError));
        }
    }

    /**
     * Asserts that the plain, partial and (for non-empty input) grouped aggregation paths all return
     * {@code expectedCount}.
     *
     * @param values the input values; null elements become null positions
     * @param maxStandardError the standard-error argument passed to {@code approx_distinct}
     * @param expectedCount the expected estimate
     */
    public void assertCount(List<?> values, double maxStandardError, long expectedCount) {
        // The grouped path is skipped for empty input, which is built as a zero-channel page.
        if (!values.isEmpty()) {
            Assertions.assertThat(estimateGroupByCount(values, maxStandardError)).isEqualTo(expectedCount);
        }
        Assertions.assertThat(estimateCount(values, maxStandardError)).isEqualTo(expectedCount);
        Assertions.assertThat(estimateCountPartial(values, maxStandardError)).isEqualTo(expectedCount);
    }

    /** Runs the aggregation through {@code AggregationTestUtils.groupedAggregation}. */
    private long estimateGroupByCount(List<?> values, double maxStandardError) {
        return ((Long) AggregationTestUtils.groupedAggregation(getAggregationFunction(), createPage(values, maxStandardError))).longValue();
    }

    /** Runs the aggregation through {@code AggregationTestUtils.aggregation}. */
    private long estimateCount(List<?> values, double maxStandardError) {
        return ((Long) AggregationTestUtils.aggregation(getAggregationFunction(), createPage(values, maxStandardError))).longValue();
    }

    /** Runs the aggregation through {@code AggregationTestUtils.partialAggregation}. */
    private long estimateCountPartial(List<?> values, double maxStandardError) {
        return ((Long) AggregationTestUtils.partialAggregation(getAggregationFunction(), createPage(values, maxStandardError))).longValue();
    }

    /** Resolves {@code approx_distinct(<value type>, double)}. */
    private TestingAggregationFunction getAggregationFunction() {
        return FUNCTION_RESOLUTION.getAggregateFunction(
                "approx_distinct",
                TypeSignatureProvider.fromTypes(new Type[] {getValueType(), DoubleType.DOUBLE}));
    }

    /**
     * Builds a two-channel page: the values, plus a constant {@code maxStandardError} column of
     * the same length. An empty list yields an empty page with no channels.
     */
    private Page createPage(List<?> values, double maxStandardError) {
        if (values.isEmpty()) {
            return new Page(0);
        }
        Block valueBlock = createBlock(getValueType(), values);
        Block errorBlock = createBlock(
                DoubleType.DOUBLE,
                ImmutableList.copyOf(Collections.nCopies(values.size(), Double.valueOf(maxStandardError))));
        return new Page(values.size(), new Block[] {valueBlock, errorBlock});
    }

    /**
     * Writes {@code values} into a new block of {@code type}. Each write method is chosen from the
     * type's Java type, and null elements are appended as nulls.
     */
    private static Block createBlock(Type type, List<?> values) {
        BlockBuilder builder = type.createBlockBuilder((BlockBuilderStatus) null, values.size());
        // The Java type depends only on the type, not on the element, so look it up once.
        Class<?> javaType = type.getJavaType();
        for (Object value : values) {
            if (value == null) {
                builder.appendNull();
            } else if (javaType == Boolean.TYPE) {
                type.writeBoolean(builder, ((Boolean) value).booleanValue());
            } else if (javaType == Long.TYPE) {
                type.writeLong(builder, ((Long) value).longValue());
            } else if (javaType == Double.TYPE) {
                type.writeDouble(builder, ((Double) value).doubleValue());
            } else if (javaType == Slice.class) {
                Slice slice = (Slice) value;
                type.writeSlice(builder, slice, 0, slice.length());
            } else {
                type.writeObject(builder, value);
            }
        }
        return builder.build();
    }

    /**
     * Returns {@code total} values that contain exactly {@code uniques} distinct values.
     *
     * <p>The list starts with a random set of {@code uniques} values. Random existing elements are
     * then appended as duplicates until the list holds {@code total} entries.
     *
     * @throws IllegalArgumentException if {@code uniques > total}
     */
    private List<Object> createRandomSample(int uniques, int total) {
        Preconditions.checkArgument(uniques <= total, "uniques (%s) must be <= total (%s)", uniques, total);
        List<Object> sample = new ArrayList<>(total);
        sample.addAll(makeRandomSet(uniques));
        ThreadLocalRandom random = ThreadLocalRandom.current();
        while (sample.size() < total) {
            sample.add(sample.get(random.nextInt(sample.size())));
        }
        return sample;
    }

    /** Draws {@link #randomValue()} until {@code count} distinct values have been collected. */
    private Set<Object> makeRandomSet(int count) {
        Set<Object> result = new HashSet<>();
        while (result.size() < count) {
            result.add(randomValue());
        }
        return result;
    }
}
