package org.tribuo.transform.transformations;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.DoubleUnaryOperator;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.MutableDataset;
import org.tribuo.impl.ArrayExample;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.SimpleTransformProto;
import org.tribuo.protos.core.TransformerProto;
import org.tribuo.test.MockDataSourceProvenance;
import org.tribuo.test.MockModel;
import org.tribuo.test.MockOutput;
import org.tribuo.test.MockOutputFactory;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.Transformer;
import org.tribuo.transform.TransformerMap;
import org.tribuo.transform.transformations.SimpleTransform;

/**
 * Tests for {@link SimpleTransform}.
 * <p>
 * Covers each supported operation, compositions of operations, thresholding, per-feature
 * versus global transformation maps, and protobuf serialization.
 */
public class SimpleTransformTest {

    /** Absolute tolerance used when comparing transformed feature values. */
    private static final double TOLERANCE = 1.0e-12;

    /**
     * Computes the value a feature is expected to hold after transformation.
     */
    @FunctionalInterface
    private interface ExpectedValue {
        /**
         * @param featureName   the name of the feature being checked.
         * @param originalValue the feature's value before transformation.
         * @return the value the transformed feature should have.
         */
        double compute(String featureName, double originalValue);
    }

    /**
     * Builds a dense dataset of ten examples.
     * <p>
     * Every example has features F0 through F9 and the output "UNK". The values include
     * negatives, zero-crossings and fractional values.
     *
     * @return a freshly constructed dataset.
     */
    public static MutableDataset<MockOutput> generateDenseDataset() {
        MutableDataset<MockOutput> dataset = new MutableDataset<>(new MockDataSourceProvenance(), new MockOutputFactory());
        MockOutput output = new MockOutput("UNK");
        String[] featureNames = {"F0", "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9"};
        double[][] rows = {
            {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0},
            {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0},
            {0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5},
            {-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0},
            {10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0},
            {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0},
            {2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0},
            {1.0, 5.0, 1.0, 5.0, 1.0, 5.0, 1.0, 5.0, 1.0, 5.0},
            {5.0, 1.0, 5.0, 1.0, 5.0, 1.0, 5.0, 1.0, 5.0, 1.0},
            {1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0}
        };
        for (double[] row : rows) {
            dataset.add(new ArrayExample<>(output, featureNames, row));
        }
        return dataset;
    }

    /**
     * Asserts that {@code transformed} is a feature-by-feature transformation of {@code original}.
     * <p>
     * Checks that both datasets have the same number of examples and that each example keeps
     * its size and feature names. Each transformed value must match {@code expected} within
     * {@link #TOLERANCE}.
     *
     * @param original    the dataset before transformation.
     * @param transformed the dataset after transformation.
     * @param expected    computes the expected transformed value of each feature.
     */
    private static void assertTransformed(MutableDataset<MockOutput> original,
                                          MutableDataset<MockOutput> transformed,
                                          ExpectedValue expected) {
        Assertions.assertEquals(original.size(), transformed.size(), "Transformed dataset not the same size as original.");
        for (int i = 0; i < original.size(); i++) {
            Example<MockOutput> before = original.getData().get(i);
            Example<MockOutput> after = transformed.getData().get(i);
            Assertions.assertEquals(before.size(), after.size(), "Transformed not the same size as original.");
            Iterator<Feature> beforeIter = before.iterator();
            Iterator<Feature> afterIter = after.iterator();
            while (beforeIter.hasNext() && afterIter.hasNext()) {
                Feature beforeFeature = beforeIter.next();
                Feature afterFeature = afterIter.next();
                Assertions.assertEquals(beforeFeature.getName(), afterFeature.getName());
                double expectedValue = expected.compute(beforeFeature.getName(), beforeFeature.getValue());
                Assertions.assertEquals(expectedValue, afterFeature.getValue(), TOLERANCE);
            }
        }
    }

    @Test
    public void testAddSub() {
        testSimple(new TransformationMap(Arrays.asList(SimpleTransform.add(10.0), SimpleTransform.sub(10.0)), new HashMap<>()), DoubleUnaryOperator.identity());
    }

    @Test
    public void testMulDiv() {
        testSimple(new TransformationMap(Arrays.asList(SimpleTransform.mul(10.0), SimpleTransform.div(10.0)), new HashMap<>()), DoubleUnaryOperator.identity());
    }

    @Test
    public void testExpLog() {
        testSimple(new TransformationMap(Arrays.asList(SimpleTransform.exp(), SimpleTransform.log()), new HashMap<>()), DoubleUnaryOperator.identity());
    }

    @Test
    public void testExp() {
        testSimple(new TransformationMap(Collections.singletonList(SimpleTransform.exp()), new HashMap<>()), Math::exp);
    }

    @Test
    public void testLog() {
        // The dataset contains negative values, so Math.log yields NaN for some features.
        // JUnit's assertEquals(double, double, double) treats two NaNs as equal.
        testSimple(new TransformationMap(Collections.singletonList(SimpleTransform.log()), new HashMap<>()), Math::log);
    }

    @Test
    public void testAdd() {
        testSimple(new TransformationMap(Collections.singletonList(SimpleTransform.add(5.0)), new HashMap<>()), v -> v + 5.0);
    }

    @Test
    public void testSub() {
        testSimple(new TransformationMap(Collections.singletonList(SimpleTransform.sub(100.0)), new HashMap<>()), v -> v - 100.0);
    }

    @Test
    public void testMul() {
        testSimple(new TransformationMap(Collections.singletonList(SimpleTransform.mul(-2.0)), new HashMap<>()), v -> v * -2.0);
    }

    @Test
    public void testDiv() {
        testSimple(new TransformationMap(Collections.singletonList(SimpleTransform.div(45.0)), new HashMap<>()), v -> v / 45.0);
    }

    @Test
    public void testBinarise() {
        testSimple(new TransformationMap(Collections.singletonList(SimpleTransform.binarise()), new HashMap<>()), v -> v < 1.0e-12 ? 0.0 : 1.0);
    }

    /**
     * Applies {@code transformationMap} to the dense dataset and checks every feature against
     * {@code expected}.
     * <p>
     * If feature F0 ends up with exactly one transformer, also checks that the transformer
     * survives a serialize/deserialize round trip as an equal but distinct object.
     *
     * @param transformationMap the transformations to fit and apply.
     * @param expected          maps each original feature value to its expected transformed value.
     */
    public void testSimple(TransformationMap transformationMap, DoubleUnaryOperator expected) {
        MutableDataset<MockOutput> dataset = generateDenseDataset();
        TransformerMap transformers = dataset.createTransformers(transformationMap);
        MutableDataset<MockOutput> transformed = transformers.transformDataset(dataset);
        assertTransformed(dataset, transformed, (name, value) -> expected.applyAsDouble(value));

        List<Transformer> f0Transformers = transformers.get("F0");
        if (f0Transformers.size() == 1) {
            Transformer fitted = f0Transformers.get(0);
            Transformer roundTripped = Transformer.deserialize(fitted.serialize());
            Assertions.assertEquals(fitted, roundTripped);
            Assertions.assertNotSame(fitted, roundTripped);
        }
    }

    @Test
    public void testThresholdBelow() {
        testThresholding(new TransformationMap(Collections.singletonList(SimpleTransform.threshold(0.0, Double.POSITIVE_INFINITY)), new HashMap<>()), 0.0, Double.POSITIVE_INFINITY);
    }

    @Test
    public void testThresholdAbove() {
        testThresholding(new TransformationMap(Collections.singletonList(SimpleTransform.threshold(Double.NEGATIVE_INFINITY, 1.0)), new HashMap<>()), Double.NEGATIVE_INFINITY, 1.0);
    }

    @Test
    public void testThreshold() {
        testThresholding(new TransformationMap(Collections.singletonList(SimpleTransform.threshold(0.0, 1.0)), new HashMap<>()), 0.0, 1.0);
    }

    /**
     * Applies {@code transformationMap} to the dense dataset and checks that every feature value
     * was clamped into {@code [min, max]}.
     *
     * @param transformationMap the thresholding transformation to fit and apply.
     * @param min               values below this are expected to become {@code min}.
     * @param max               values above this are expected to become {@code max}.
     */
    public void testThresholding(TransformationMap transformationMap, double min, double max) {
        MutableDataset<MockOutput> dataset = generateDenseDataset();
        MutableDataset<MockOutput> transformed = dataset.createTransformers(transformationMap).transformDataset(dataset);
        assertTransformed(dataset, transformed, (name, value) -> {
            if (value < min) {
                return min;
            } else if (value > max) {
                return max;
            } else {
                return value;
            }
        });
    }

    @Test
    public void testFeatureSpecific() {
        MutableDataset<MockOutput> dataset = generateDenseDataset();
        Map<String, List<Transformation>> featureTransformations = new HashMap<>();
        featureTransformations.put("F0", Arrays.asList(SimpleTransform.exp(), SimpleTransform.log()));
        featureTransformations.put("F2", Arrays.asList(SimpleTransform.exp(), SimpleTransform.log(), SimpleTransform.exp()));
        featureTransformations.put("F4", Arrays.asList(SimpleTransform.exp(), SimpleTransform.log()));
        featureTransformations.put("F6", Arrays.asList(SimpleTransform.exp(), SimpleTransform.log(), SimpleTransform.add(5.0)));
        featureTransformations.put("F8", Arrays.asList(SimpleTransform.exp(), SimpleTransform.log()));
        TransformationMap transformationMap = new TransformationMap(new ArrayList<>(), featureTransformations);
        MutableDataset<MockOutput> transformed = dataset.createTransformers(transformationMap).transformDataset(dataset);
        assertTransformed(dataset, transformed, (name, value) -> {
            switch (name) {
                case "F2":
                    // exp, log, exp collapses to a single exp.
                    return Math.exp(value);
                case "F6":
                    // exp, log, add(5) collapses to add(5).
                    return value + 5.0;
                default:
                    // F0, F4 and F8 apply exp then log (the identity); all other features are untouched.
                    return value;
            }
        });
    }

    @Test
    public void testSpecificAndGlobal() {
        MutableDataset<MockOutput> dataset = generateDenseDataset();
        Map<String, List<Transformation>> featureTransformations = new HashMap<>();
        featureTransformations.put("F0", Arrays.asList(SimpleTransform.exp(), SimpleTransform.log()));
        featureTransformations.put("F2", Arrays.asList(SimpleTransform.exp(), SimpleTransform.log(), SimpleTransform.exp()));
        featureTransformations.put("F4", Arrays.asList(SimpleTransform.exp(), SimpleTransform.log(), SimpleTransform.add(5.0), SimpleTransform.sub(5.0)));
        featureTransformations.put("F6", Arrays.asList(SimpleTransform.exp(), SimpleTransform.log(), SimpleTransform.add(5.0)));
        featureTransformations.put("F8", Arrays.asList(SimpleTransform.exp(), SimpleTransform.log()));
        TransformationMap transformationMap = new TransformationMap(Collections.singletonList(SimpleTransform.mul(-1.0)), featureTransformations);
        MutableDataset<MockOutput> transformed = dataset.createTransformers(transformationMap).transformDataset(dataset);
        // These expectations encode that the feature-specific transformations run first and the
        // global mul(-1) is applied to their result (e.g. F2 becomes -exp(x), not exp(-x)).
        assertTransformed(dataset, transformed, (name, value) -> {
            switch (name) {
                case "F2":
                    return -Math.exp(value);
                case "F6":
                    return -(value + 5.0);
                default:
                    // Every other feature's specific chain (if any) is the identity, so only the negation remains.
                    return -value;
            }
        });
    }

    @Test
    void testSerializeSimpleTransform() throws Exception {
        SimpleTransform transform = new SimpleTransform(SimpleTransform.Operation.exp, Math.E, Math.PI);
        TransformerProto proto = transform.serialize();
        Assertions.assertEquals(0, proto.getVersion());
        Assertions.assertEquals("org.tribuo.transform.transformations.SimpleTransform", proto.getClassName());
        SimpleTransformProto payload = proto.getSerializedData().unpack(SimpleTransformProto.class);
        Assertions.assertEquals("exp", payload.getOp());
        Assertions.assertEquals(Math.E, payload.getFirstOperand());
        Assertions.assertEquals(Math.PI, payload.getSecondOperand());
        Assertions.assertEquals(transform, ProtoUtil.deserialize(proto));
    }
}
