package io.trino.sql.ir.optimizer;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.type.BigintType;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.Switch;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.ir.optimizer.rule.DistributeComparisonOverSwitch;
import io.trino.testing.TestingSession;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.OptionalAssert;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link DistributeComparisonOverSwitch}.
 *
 * <p>Each case checks that a comparison with a {@link Switch} operand is rewritten into a
 * {@link Switch} with the same operand and WHEN conditions, where every branch result (and the
 * default) carries the comparison instead. When the switch is the right-hand operand, the expected
 * branches use the flipped operator ({@code LESS_THAN} becomes {@code GREATER_THAN}) with the
 * branch result on the left.
 */
public class TestDistributeComparisonOverSwitch {
    @Test
    void test() {
        Expression m = reference("m");
        Expression one = new Constant(BigintType.BIGINT, 1L);

        assertRewrite(
                "switch(...) < reference",
                new Comparison(Comparison.Operator.LESS_THAN, switchOnS(), m),
                distributed(Comparison.Operator.LESS_THAN, m));

        assertRewrite(
                "switch(...) < constant",
                new Comparison(Comparison.Operator.LESS_THAN, switchOnS(), one),
                distributed(Comparison.Operator.LESS_THAN, one));

        // Switch on the right-hand side: the expected branches compare "result > other".
        assertRewrite(
                "reference < switch(...)",
                new Comparison(Comparison.Operator.LESS_THAN, m, switchOnS()),
                distributed(Comparison.Operator.GREATER_THAN, m));

        assertRewrite(
                "constant < switch(...)",
                new Comparison(Comparison.Operator.LESS_THAN, one, switchOnS()),
                distributed(Comparison.Operator.GREATER_THAN, one));
    }

    /**
     * Asserts that optimizing {@code input} yields exactly {@code expected}.
     *
     * @param description label attached to the assertion for failure messages
     * @param input expression handed to the rule
     * @param expected expression the rule is expected to produce
     */
    private void assertRewrite(String description, Expression input, Expression expected) {
        // describedAs already returns the typed OptionalAssert, so no cast is needed.
        Assertions.assertThat(optimize(input))
                .describedAs(description)
                .isEqualTo(Optional.of(expected));
    }

    /**
     * Builds a BIGINT {@link Reference} with the given name.
     *
     * @param name symbol name
     * @return the reference
     */
    private static Reference reference(String name) {
        return new Reference(BigintType.BIGINT, name);
    }

    /**
     * Builds the switch operand shared by every case: {@code switch(s) when a -> x, when b -> y, else z}.
     *
     * @return a new switch expression
     */
    private static Switch switchOnS() {
        return new Switch(
                reference("s"),
                ImmutableList.of(
                        new WhenClause(reference("a"), reference("x")),
                        new WhenClause(reference("b"), reference("y"))),
                reference("z"));
    }

    /**
     * Builds the expected rewrite of {@link #switchOnS()}: each result {@code r} (including the
     * default {@code z}) is replaced by {@code Comparison(operator, r, other)}.
     *
     * @param operator comparison operator expected in every branch
     * @param other the non-switch operand of the original comparison
     * @return the expected switch expression
     */
    private static Switch distributed(Comparison.Operator operator, Expression other) {
        return new Switch(
                reference("s"),
                ImmutableList.of(
                        new WhenClause(reference("a"), new Comparison(operator, reference("x"), other)),
                        new WhenClause(reference("b"), new Comparison(operator, reference("y"), other))),
                new Comparison(operator, reference("z"), other));
    }

    /**
     * Applies {@link DistributeComparisonOverSwitch} to {@code expression} using a testing session
     * and an empty map argument.
     *
     * @param expression the expression to rewrite
     * @return the rule's result
     */
    private Optional<Expression> optimize(Expression expression) {
        return new DistributeComparisonOverSwitch().apply(expression, TestingSession.testSession(), ImmutableMap.of());
    }
}
