package io.trino.sql.ir.optimizer;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.VarcharType;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.Switch;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.ir.optimizer.rule.RemoveRedundantSwitchClauses;
import io.trino.testing.TestingSession;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.OptionalAssert;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link RemoveRedundantSwitchClauses}.
 *
 * <p>Each case builds a {@link Switch} expression, applies the rule to it, and compares the
 * optional result with the expected rewritten expression.
 */
public class TestRemoveRedundantSwitchClauses {
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", ImmutableList.of());

    @Test
    void test() {
        // A later WHEN operand that repeats an earlier one ("a") is expected to be dropped.
        assertOptimized(
                new Switch(
                        bigint("x"),
                        ImmutableList.of(
                                new WhenClause(bigint("a"), varchar("r1")),
                                new WhenClause(bigint("b"), varchar("r2")),
                                new WhenClause(bigint("a"), varchar("r3"))),
                        varchar("d")),
                "redundant terms")
                .isEqualTo(Optional.of(new Switch(
                        bigint("x"),
                        ImmutableList.of(
                                new WhenClause(bigint("a"), varchar("r1")),
                                new WhenClause(bigint("b"), varchar("r2"))),
                        varchar("d"))));

        // With a constant operand (1), a WHEN clause on a different constant (2) is expected to be
        // dropped, while the clause on a non-constant operand is kept.
        assertOptimized(
                new Switch(
                        new Constant(BigintType.BIGINT, 1L),
                        ImmutableList.of(
                                new WhenClause(new Constant(BigintType.BIGINT, 2L), varchar("r1")),
                                new WhenClause(bigint("x"), varchar("r2"))),
                        varchar("d")),
                "redundant constants")
                .isEqualTo(Optional.of(new Switch(
                        new Constant(BigintType.BIGINT, 1L),
                        ImmutableList.of(new WhenClause(bigint("x"), varchar("r2"))),
                        varchar("d"))));

        // A WHEN operand identical to the switch operand ("x") is expected to turn its result into
        // the new default, replacing the original default "d".
        assertOptimized(
                new Switch(
                        bigint("x"),
                        ImmutableList.of(
                                new WhenClause(bigint("a"), varchar("r1")),
                                new WhenClause(bigint("x"), varchar("r2"))),
                        varchar("d")),
                "short-circuit")
                .isEqualTo(Optional.of(new Switch(
                        bigint("x"),
                        ImmutableList.of(new WhenClause(bigint("a"), varchar("r1"))),
                        varchar("r2"))));

        // When the very first WHEN operand is the switch operand, the whole switch is expected to
        // reduce to that clause's result.
        assertOptimized(
                new Switch(
                        bigint("x"),
                        ImmutableList.of(
                                new WhenClause(bigint("x"), varchar("r1")),
                                new WhenClause(bigint("a"), varchar("r2"))),
                        varchar("d")),
                "short-circuit on first term")
                .isEqualTo(Optional.of(varchar("r1")));

        // Clauses whose operand is a random() call are expected to be kept even though they repeat;
        // the trailing duplicate of "a" is still expected to be dropped.
        assertOptimized(
                new Switch(
                        doubleReference("x"),
                        ImmutableList.of(
                                new WhenClause(doubleReference("a"), varchar("r1")),
                                new WhenClause(doubleReference("b"), varchar("r2")),
                                new WhenClause(random(), varchar("r3")),
                                new WhenClause(random(), varchar("r4")),
                                new WhenClause(doubleReference("a"), varchar("r5"))),
                        varchar("d")),
                "non-deterministic terms")
                .isEqualTo(Optional.of(new Switch(
                        doubleReference("x"),
                        ImmutableList.of(
                                new WhenClause(doubleReference("a"), varchar("r1")),
                                new WhenClause(doubleReference("b"), varchar("r2")),
                                new WhenClause(random(), varchar("r3")),
                                new WhenClause(random(), varchar("r4"))),
                        varchar("d"))));
    }

    /**
     * Applies {@link RemoveRedundantSwitchClauses} to {@code expression} using a testing session
     * and an empty bindings map.
     */
    private Optional<Expression> optimize(Expression expression) {
        return new RemoveRedundantSwitchClauses(FUNCTIONS.getPlannerContext()).apply(expression, TestingSession.testSession(), ImmutableMap.of());
    }

    /**
     * Optimizes {@code expression} and returns an AssertJ assertion on the result, labelled with
     * {@code description} so failures identify the case.
     */
    private OptionalAssert<Expression> assertOptimized(Expression expression, String description) {
        return Assertions.assertThat(optimize(expression)).describedAs(description);
    }

    /** Creates a {@code BIGINT} reference to the symbol {@code name}. */
    private static Reference bigint(String name) {
        return new Reference(BigintType.BIGINT, name);
    }

    /** Creates a {@code VARCHAR} reference to the symbol {@code name}. */
    private static Reference varchar(String name) {
        return new Reference(VarcharType.VARCHAR, name);
    }

    /** Creates a {@code DOUBLE} reference to the symbol {@code name}. */
    private static Reference doubleReference(String name) {
        return new Reference(DoubleType.DOUBLE, name);
    }

    /** Creates a call to the zero-argument {@code random} function resolved in {@link #RANDOM}. */
    private static Call random() {
        return new Call(RANDOM, ImmutableList.of());
    }
}
