package org.tribuo.transform.transformations;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.MutableDataset;
import org.tribuo.impl.ArrayExample;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.LinearScalingTransformerProto;
import org.tribuo.protos.core.TransformerProto;
import org.tribuo.test.MockDataSourceProvenance;
import org.tribuo.test.MockOutput;
import org.tribuo.test.MockOutputFactory;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.Transformer;
import org.tribuo.transform.TransformerMap;
import org.tribuo.transform.transformations.LinearScalingTransformation;

/**
 * Tests for {@link LinearScalingTransformation}. The transformation is tested on its own, chained
 * with {@link SimpleTransform} operations, and combined with feature-specific transformations.
 * The protobuf serialization of the fitted transformer is also tested.
 */
public class LinearScalingTest {

    /** Slack allowed when checking that transformed values lie inside the expected range. */
    private static final double TOLERANCE = 1.0e-12d;

    /**
     * Builds a small dense dataset of ten examples over features {@code F0}..{@code F9}.
     * Every example carries the label {@code "UNK"}. Across the whole dataset, every feature
     * value lies in [0, 10].
     *
     * @return a freshly built mutable dataset.
     */
    public static MutableDataset<MockOutput> generateDenseDataset() {
        MutableDataset<MockOutput> dataset = new MutableDataset<>(new MockDataSourceProvenance(), new MockOutputFactory());
        MockOutput output = new MockOutput("UNK");
        String[] names = {"F0", "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9"};
        dataset.add(new ArrayExample<>(output, names, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d}));
        dataset.add(new ArrayExample<>(output, names, new double[]{1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d, 8.0d, 9.0d, 10.0d}));
        dataset.add(new ArrayExample<>(output, names, new double[]{0.5d, 0.5d, 0.5d, 0.5d, 0.5d, 0.5d, 0.5d, 0.5d, 0.5d, 0.5d}));
        dataset.add(new ArrayExample<>(output, names, new double[]{0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d}));
        dataset.add(new ArrayExample<>(output, names, new double[]{10.0d, 9.0d, 8.0d, 7.0d, 6.0d, 5.0d, 4.0d, 3.0d, 2.0d, 1.0d}));
        dataset.add(new ArrayExample<>(output, names, new double[]{2.0d, 2.0d, 2.0d, 2.0d, 2.0d, 2.0d, 2.0d, 2.0d, 2.0d, 2.0d}));
        dataset.add(new ArrayExample<>(output, names, new double[]{10.0d, 10.0d, 10.0d, 10.0d, 10.0d, 10.0d, 10.0d, 10.0d, 10.0d, 10.0d}));
        dataset.add(new ArrayExample<>(output, names, new double[]{1.0d, 5.0d, 1.0d, 5.0d, 1.0d, 5.0d, 1.0d, 5.0d, 1.0d, 5.0d}));
        dataset.add(new ArrayExample<>(output, names, new double[]{5.0d, 1.0d, 5.0d, 1.0d, 5.0d, 1.0d, 5.0d, 1.0d, 5.0d, 1.0d}));
        dataset.add(new ArrayExample<>(output, names, new double[]{1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d, 8.0d, 9.0d, 10.0d}));
        return dataset;
    }

    /**
     * Fits {@code transformationMap} on {@link #generateDenseDataset()} and applies the resulting
     * transformers. It then checks that the transformed dataset has the same number of examples
     * as the original. It also checks that every transformed example keeps its size and feature
     * names, and that every transformed value lies within {@code [lowerBound, upperBound]}, up to
     * {@link #TOLERANCE}. When exactly one transformer was fitted for {@code F0}, that transformer
     * is additionally round-tripped through {@link Transformer#serialize()} and
     * {@link Transformer#deserialize}.
     *
     * @param transformationMap the transformations to fit and apply.
     * @param lowerBound the smallest value a transformed feature may take.
     * @param upperBound the largest value a transformed feature may take.
     */
    public void testGlobalLinearScaling(TransformationMap transformationMap, double lowerBound, double upperBound) {
        MutableDataset<MockOutput> dataset = generateDenseDataset();
        TransformerMap transformerMap = dataset.createTransformers(transformationMap);
        MutableDataset<MockOutput> transformed = transformerMap.transformDataset(dataset);

        // Guard the indexed loop below: a size mismatch should be an assertion failure, not an IOOBE.
        Assertions.assertEquals(dataset.size(), transformed.size(), "Transformed dataset not the same size as original.");

        List<Example<MockOutput>> originalData = dataset.getData();
        List<Example<MockOutput>> transformedData = transformed.getData();
        for (int i = 0; i < dataset.size(); i++) {
            Example<MockOutput> original = originalData.get(i);
            Example<MockOutput> scaled = transformedData.get(i);
            Assertions.assertEquals(original.size(), scaled.size(), "Transformed not the same size as original.");

            Iterator<Feature> originalFeatures = original.iterator();
            Iterator<Feature> scaledFeatures = scaled.iterator();
            while (originalFeatures.hasNext() && scaledFeatures.hasNext()) {
                Feature originalFeature = originalFeatures.next();
                Feature scaledFeature = scaledFeatures.next();
                Assertions.assertEquals(originalFeature.getName(), scaledFeature.getName());

                double value = scaledFeature.getValue();
                Assertions.assertTrue(value > lowerBound - TOLERANCE,
                        "Feature " + scaledFeature.getName() + " in example " + i + " has value " + value
                                + " below lower bound " + lowerBound);
                Assertions.assertTrue(value < upperBound + TOLERANCE,
                        "Feature " + scaledFeature.getName() + " in example " + i + " has value " + value
                                + " above upper bound " + upperBound);
            }
        }

        // Only a lone transformer is round-tripped here; chained lists are left unchecked as in the
        // original test.
        List<Transformer> f0Transformers = transformerMap.get("F0");
        if (f0Transformers.size() == 1) {
            Transformer fitted = f0Transformers.get(0);
            Transformer roundTripped = Transformer.deserialize(fitted.serialize());
            Assertions.assertEquals(fitted, roundTripped);
            Assertions.assertNotSame(fitted, roundTripped);
        }
    }

    @Test
    public void testGlobalLinearScalingSimple() {
        testGlobalLinearScaling(new TransformationMap(Collections.singletonList(new LinearScalingTransformation()), new HashMap<>()), 0.0d, 1.0d);
    }

    @Test
    public void testGlobalLinearScalingAdd() {
        testGlobalLinearScaling(new TransformationMap(Arrays.asList(SimpleTransform.add(5.0d), new LinearScalingTransformation()), new HashMap<>()), 0.0d, 1.0d);
    }

    @Test
    public void testGlobalLinearScalingSub() {
        testGlobalLinearScaling(new TransformationMap(Arrays.asList(SimpleTransform.sub(5.0d), new LinearScalingTransformation()), new HashMap<>()), 0.0d, 1.0d);
    }

    @Test
    public void testGlobalLinearScalingChain() {
        testGlobalLinearScaling(new TransformationMap(Arrays.asList(SimpleTransform.add(5.0d), SimpleTransform.sub(10.0d), new LinearScalingTransformation()), new HashMap<>()), 0.0d, 1.0d);
    }

    @Test
    public void testGlobalLinearScalingInvertedChain() {
        // Scaling to [0, 1], then *5 -> [0, 5], then -2.5 -> [-2.5, 2.5].
        testGlobalLinearScaling(new TransformationMap(Arrays.asList(new LinearScalingTransformation(), SimpleTransform.mul(5.0d), SimpleTransform.sub(2.5d)), new HashMap<>()), -2.5d, 2.5d);
    }

    @Test
    public void testGlobalLinearScalingRange() {
        testGlobalLinearScaling(new TransformationMap(Collections.singletonList(new LinearScalingTransformation(-5.0d, 5.0d)), new HashMap<>()), -5.0d, 5.0d);
    }

    @Test
    public void testGlobalLinearScalingFeatureSpecific() {
        Map<String, List<Transformation>> featureTransformations = new HashMap<>();
        featureTransformations.put("F0", Collections.singletonList(SimpleTransform.add(5.0d)));
        featureTransformations.put("F9", Collections.singletonList(SimpleTransform.div(5.0d)));
        testGlobalLinearScaling(new TransformationMap(Collections.singletonList(new LinearScalingTransformation()), featureTransformations), 0.0d, 1.0d);
    }

    /**
     * Checks that a {@link LinearScalingTransformation.LinearScalingTransformer} serializes to a
     * version-0 {@link TransformerProto} with the expected class name and field values. It also
     * checks that the transformer deserializes back to an equal transformer.
     */
    @Test
    void testSerialize() throws Exception {
        double observedMin = 2.718281828459045d;
        double observedMax = 3.141592653589793d;
        double targetMin = 0.618033988749d;
        double targetMax = 1.059463094359d;
        LinearScalingTransformation.LinearScalingTransformer transformer =
                new LinearScalingTransformation.LinearScalingTransformer(observedMin, observedMax, targetMin, targetMax);

        TransformerProto proto = transformer.serialize();
        Assertions.assertEquals(0, proto.getVersion());
        Assertions.assertEquals("org.tribuo.transform.transformations.LinearScalingTransformation$LinearScalingTransformer", proto.getClassName());

        LinearScalingTransformerProto data = proto.getSerializedData().unpack(LinearScalingTransformerProto.class);
        Assertions.assertEquals(observedMin, data.getObservedMin());
        Assertions.assertEquals(observedMax, data.getObservedMax());
        Assertions.assertEquals(targetMin, data.getTargetMin());
        Assertions.assertEquals(targetMax, data.getTargetMax());

        Assertions.assertEquals(transformer, ProtoUtil.deserialize(proto));
    }
}
