package io.trino.sql.query;

import io.trino.sql.query.QueryAssertions;
import io.trino.testing.MaterializedResult;
import io.trino.testing.MaterializedRow;
import java.util.ArrayList;
import java.util.Random;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

/**
 * Query-level tests for the t-digest SQL functions: {@code tdigest_agg},
 * {@code value_at_quantile}, {@code values_at_quantiles}, and the casts between
 * {@code tdigest} and {@code varbinary}.
 *
 * <p>The class uses the per-class test instance lifecycle, so one {@link QueryAssertions}
 * instance is shared by every test method and closed once in {@link #teardown()}. Test
 * methods are annotated to run concurrently.
 */
@Execution(ExecutionMode.CONCURRENT)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class TestTDigestFunctions {
    private final QueryAssertions assertions = new QueryAssertions();

    /** Closes the shared {@link QueryAssertions} after all tests in this class have run. */
    @AfterAll
    public void teardown() {
        assertions.close();
    }

    /** Checks {@code value_at_quantile} for positive, negative and skewed inputs. */
    @Test
    public void testValueAtQuantile() {
        Assertions.assertThat(assertions.query(
                        "SELECT value_at_quantile(tdigest_agg(d), 0.75e0) "
                                + "FROM (VALUES 0.1e0, 0.2e0, 0.3e0, 0.4e0) T(d)"))
                .matches("VALUES 0.4e0");

        Assertions.assertThat(assertions.query(
                        "SELECT value_at_quantile(tdigest_agg(d), 0.75e0) "
                                + "FROM (VALUES -0.1e0, -0.2e0, -0.3e0, -0.4e0) T(d)"))
                .matches("VALUES -0.1e0");

        Assertions.assertThat(assertions.query(
                        "SELECT value_at_quantile(tdigest_agg(d), 0.9e0) "
                                + "FROM (VALUES 0.1e0, 0.1e0, 0.1e0, 0.1e0, 10e0) T(d)"))
                .matches("VALUES 10e0");
    }

    /**
     * Checks {@code values_at_quantiles} for several inputs, and that a descending
     * percentile array is rejected with the expected error message.
     */
    @Test
    public void testValuesAtQuantiles() {
        Assertions.assertThat(assertions.query(
                        "SELECT values_at_quantiles(tdigest_agg(d), ARRAY[0.0001e0, 0.75e0, 0.85e0]) "
                                + "FROM (VALUES 0.1e0, 0.2e0, 0.3e0, 0.4e0) T(d)"))
                .matches("VALUES ARRAY[0.1e0, 0.4e0, 0.4e0]");

        // NOTE(review): the last expected element is spelled -0.10 (no exponent), unlike its
        // siblings; presumably it is coerced to double - confirm whether -0.1e0 was intended.
        Assertions.assertThat(assertions.query(
                        "SELECT values_at_quantiles(tdigest_agg(d), ARRAY[0.0001e0, 0.75e0, 0.85e0]) "
                                + "FROM (VALUES -0.1e0, -0.2e0, -0.3e0, -0.4e0) T(d)"))
                .matches("VALUES ARRAY[-0.4e0, -0.1e0, -0.10]");

        Assertions.assertThat(assertions.query(
                        "SELECT values_at_quantiles(tdigest_agg(d), ARRAY[0.0001e0, 0.75e0, 0.85e0]) "
                                + "FROM (VALUES 0.1e0, 0.1e0, 0.1e0, 0.1e0, 10e0) T(d)"))
                .matches("VALUES ARRAY[0.1e0, 0.1e0, 10.0e0]");

        // Percentiles given in descending order must fail.
        Assertions.assertThat(assertions.query(
                        "SELECT values_at_quantiles(tdigest_agg(d), ARRAY[1e0, 0e0]) "
                                + "FROM (VALUES 0.1e0) T(d)"))
                .failure().hasMessage("percentiles must be sorted in increasing order");
    }

    /** An empty percentile array yields an empty {@code array(double)}. */
    @Test
    public void testEmptyArrayOfQuantiles() {
        Assertions.assertThat(assertions.query(
                        "SELECT values_at_quantiles(tdigest_agg(d), ARRAY[]) "
                                + "FROM (VALUES 0.1e0, 0.2e0, 0.3e0, 0.4e0) T(d)"))
                .matches("VALUES CAST(ARRAY[] AS array(double))");
    }

    /**
     * Aggregating zero rows yields a null {@code tdigest}, and querying quantiles of that
     * null digest yields a null {@code array(double)}.
     */
    @Test
    public void testEmptyTDigestInput() {
        Assertions.assertThat(assertions.query(
                        "SELECT tdigest_agg(d) "
                                + "FROM (SELECT 1e0 WHERE false) T(d)"))
                .matches("VALUES CAST(null AS tdigest)");

        // Uses tdigest_agg (previously qdigest_agg) so that the empty t-digest path of
        // values_at_quantiles is what this t-digest test actually exercises.
        Assertions.assertThat(assertions.query(
                        "SELECT values_at_quantiles(tdigest_agg(d), ARRAY[0.5e0]) "
                                + "FROM (SELECT 1e0 WHERE false) T(d)"))
                .matches("VALUES CAST(null AS array(double))");
    }

    /**
     * Compares t-digest quantiles against q-digest quantiles at extreme percentiles over the
     * same weighted data set: each q-digest value must equal the rounded t-digest value.
     */
    @Test
    public void testAccuracyAtHighAndLowPercentiles() {
        String percentiles = "ARRAY[0.00001, 0.0001, 0.001, 0.01, 0.99, 0.999, 0.9999, 0.99999]";
        int totalRows = 2000;
        // One row is hard-coded below; the remaining rows are generated.
        int generatedRows = totalRows - 1;

        // Fixed seed: the generated data set is the same on every run.
        Random random = new Random(1L);
        long[] values = random.longs(generatedRows, 0L, 1000L).toArray();
        long[] weights = random.longs(generatedRows, 1L, 10L).toArray();

        StringBuilder rows = new StringBuilder("VALUES (BIGINT '1', BIGINT '1')");
        for (int i = 0; i < generatedRows; i++) {
            rows.append(", (").append(values[i]).append(", ").append(weights[i]).append(")");
        }
        String fromClause = ") FROM (" + rows + ") T(n, w)";

        MaterializedResult tDigestResult = assertions.getQueryRunner().execute(
                assertions.getDefaultSession(),
                "SELECT values_at_quantiles(tdigest_agg(n, w), " + percentiles + fromClause);
        MaterializedResult qDigestResult = assertions.getQueryRunner().execute(
                assertions.getDefaultSession(),
                "SELECT values_at_quantiles(qdigest_agg(n, w, 0.00001), " + percentiles + fromClause);

        // Each query returns a single row whose only column is the array of quantile values.
        MaterializedRow tDigestRow = tDigestResult.getMaterializedRows().get(0);
        MaterializedRow qDigestRow = qDigestResult.getMaterializedRows().get(0);
        ArrayList<?> tDigestValues = (ArrayList<?>) tDigestRow.getField(0);
        ArrayList<?> qDigestValues = (ArrayList<?>) qDigestRow.getField(0);

        // Guard the element-wise loop: both results must cover the same percentiles.
        Assertions.assertThat(qDigestValues.size())
                .as("number of quantile values returned by qdigest vs tdigest")
                .isEqualTo(tDigestValues.size());

        for (int i = 0; i < tDigestValues.size(); i++) {
            long roundedTDigestValue = Math.round((Double) tDigestValues.get(i));
            Assertions.assertThat((Long) qDigestValues.get(i))
                    .as("value at percentile index %d", i)
                    .isEqualTo(roundedTDigestValue);
        }
    }

    /** A t-digest survives a round trip through {@code varbinary} and back. */
    @Test
    public void testCastOperators() {
        Assertions.assertThat(assertions.query(
                        "SELECT values_at_quantiles(CAST(CAST(tdigest_agg(d) AS varbinary) AS tdigest), ARRAY[0, 1]) "
                                + "FROM (VALUES 1, 2, 3) T(d)"))
                .matches("VALUES CAST(ARRAY[1, 3] AS array(double))");
    }
}
