package org.tribuo.transform.transformations;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.FeatureMap;
import org.tribuo.MutableDataset;
import org.tribuo.impl.ArrayExample;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.MeanStdDevTransformerProto;
import org.tribuo.protos.core.TransformerProto;
import org.tribuo.test.MockDataSourceProvenance;
import org.tribuo.test.MockOutput;
import org.tribuo.test.MockOutputFactory;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.Transformer;
import org.tribuo.transform.TransformerMap;
import org.tribuo.transform.transformations.MeanStdDevTransformation;

/**
 * Tests for {@link MeanStdDevTransformation}.
 * <p>
 * Checks that a fitted transformer rescales features to the requested mean and standard deviation,
 * that non-positive target standard deviations are rejected, and that the fitted transformer
 * round-trips through protobuf serialization.
 */
public class MeanStdDevTest {

    /** Number of examples generated for each synthetic dataset. */
    private static final int NUM_EXAMPLES = 10000;

    /** Tolerance used when checking statistics on the dataset the transformer was fitted on. */
    private static final double FIT_TOLERANCE = 1.0E-5;

    /**
     * Generates a dense two-feature dataset where every example has output {@code "UNK"}.
     * <p>
     * Feature {@code F0} is {@code nextGaussian() * 5 + 10} and feature {@code F1} is
     * {@code nextDouble() * -20}, both drawn from a {@link Random} seeded with {@code i}.
     *
     * @param i The RNG seed.
     * @return A dataset containing {@value #NUM_EXAMPLES} examples.
     */
    public static MutableDataset<MockOutput> generateDenseDataset(int i) {
        MutableDataset<MockOutput> dataset = new MutableDataset<>(new MockDataSourceProvenance(), new MockOutputFactory());
        Random rng = new Random(i);
        MockOutput output = new MockOutput("UNK");
        String[] featureNames = {"F0", "F1"};
        for (int n = 0; n < NUM_EXAMPLES; n++) {
            // Draw F0 before F1 to keep the RNG sequence identical for a given seed.
            double f0 = rng.nextGaussian() * 5.0 + 10.0;
            double f1 = rng.nextDouble() * -20.0;
            dataset.add(new ArrayExample<>(output, featureNames, new double[]{f0, f1}));
        }
        return dataset;
    }

    @Test
    public void testMeanZeroStdDevOne() {
        testMeanStdDev(globalTransformationMap(new MeanStdDevTransformation()), 0.0, 1.0);
    }

    @Test
    public void testMeanZeroStdDevFive() {
        testMeanStdDev(globalTransformationMap(new MeanStdDevTransformation(0.0, 5.0)), 0.0, 5.0);
    }

    @Test
    public void testMeanFiveStdDevOne() {
        testMeanStdDev(globalTransformationMap(new MeanStdDevTransformation(5.0, 1.0)), 5.0, 1.0);
    }

    @Test
    public void testMeanMinusFiveStdDevFive() {
        testMeanStdDev(globalTransformationMap(new MeanStdDevTransformation(-5.0, 5.0)), -5.0, 5.0);
    }

    @Test
    public void testInvalidTransformation() {
        // A zero or negative target standard deviation must be rejected at construction time.
        Assertions.assertThrows(IllegalArgumentException.class, () -> new MeanStdDevTransformation(0.0, 0.0));
        Assertions.assertThrows(IllegalArgumentException.class, () -> new MeanStdDevTransformation(0.0, -1.0));
    }

    /**
     * Fits the supplied transformations on a seed-1 dataset, applies them to that dataset and to
     * an independent seed-2 dataset, and checks the resulting feature statistics. It also checks
     * that the fitted {@code F0} transformer survives a serialize/deserialize round trip.
     *
     * @param transformationMap The transformations to fit.
     * @param expectedMean The target mean every feature should have after transformation.
     * @param expectedStdDev The target standard deviation every feature should have after transformation.
     */
    public void testMeanStdDev(TransformationMap transformationMap, double expectedMean, double expectedStdDev) {
        MutableDataset<MockOutput> fitData = generateDenseDataset(1);
        MutableDataset<MockOutput> heldOutData = generateDenseDataset(2);

        TransformerMap transformers = fitData.createTransformers(transformationMap);
        MutableDataset<MockOutput> transformedFit = transformers.transformDataset(fitData);
        MutableDataset<MockOutput> transformedHeldOut = transformers.transformDataset(heldOutData);

        // On the data the transformer was fitted on, the statistics should match the target almost exactly.
        FeatureMap fitFeatures = transformedFit.getFeatureMap();
        assertMeanAndStdDev(fitFeatures, "F0", expectedMean, expectedStdDev, FIT_TOLERANCE);
        assertMeanAndStdDev(fitFeatures, "F1", expectedMean, expectedStdDev, FIT_TOLERANCE);

        // On held-out data from the same generator the statistics only need to be close.
        // The per-feature tolerances are kept from the original test (tuned for these fixed seeds).
        FeatureMap heldOutFeatures = transformedHeldOut.getFeatureMap();
        assertMeanAndStdDev(heldOutFeatures, "F0", expectedMean, expectedStdDev, 0.1);
        assertMeanAndStdDev(heldOutFeatures, "F1", expectedMean, expectedStdDev, 0.01);

        // Serialization round trip: equal but a distinct instance.
        List<Transformer> f0Transformers = transformers.get("F0");
        Assertions.assertEquals(1, f0Transformers.size());
        Transformer fitted = f0Transformers.get(0);
        Transformer roundTripped = Transformer.deserialize(fitted.serialize());
        Assertions.assertEquals(fitted, roundTripped);
        Assertions.assertNotSame(fitted, roundTripped);
    }

    @Test
    void testSerializeMeanStdDevTransformer() throws Exception {
        final double observedMean = Math.E;
        final double observedStdDev = Math.PI;
        final double targetMean = 0.618033988749;
        final double targetStdDev = 1.059463094359;

        MeanStdDevTransformation.MeanStdDevTransformer transformer =
                new MeanStdDevTransformation.MeanStdDevTransformer(observedMean, observedStdDev, targetMean, targetStdDev);
        TransformerProto proto = transformer.serialize();
        Assertions.assertEquals(0, proto.getVersion());
        Assertions.assertEquals("org.tribuo.transform.transformations.MeanStdDevTransformation$MeanStdDevTransformer", proto.getClassName());

        // Protobuf doubles are stored exactly, so exact equality is appropriate here.
        MeanStdDevTransformerProto data = proto.getSerializedData().unpack(MeanStdDevTransformerProto.class);
        Assertions.assertEquals(observedMean, data.getObservedMean());
        Assertions.assertEquals(observedStdDev, data.getObservedStdDev());
        Assertions.assertEquals(targetMean, data.getTargetMean());
        Assertions.assertEquals(targetStdDev, data.getTargetStdDev());

        Assertions.assertEquals(transformer, ProtoUtil.deserialize(proto));
    }

    /**
     * Wraps a single transformation as a global {@link TransformationMap} with no
     * per-feature transformations.
     */
    private static TransformationMap globalTransformationMap(MeanStdDevTransformation transformation) {
        return new TransformationMap(Collections.singletonList(transformation), new HashMap<>());
    }

    /**
     * Asserts that the named feature's mean and standard deviation (the square root of its
     * variance) are within {@code tolerance} of the expected values.
     */
    private static void assertMeanAndStdDev(FeatureMap featureMap, String featureName,
                                            double expectedMean, double expectedStdDev, double tolerance) {
        Assertions.assertEquals(expectedMean, featureMap.get(featureName).getMean(), tolerance,
                featureName + " mean");
        Assertions.assertEquals(expectedStdDev, Math.sqrt(featureMap.get(featureName).getVariance()), tolerance,
                featureName + " standard deviation");
    }
}
