package org.tribuo.transform;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
import org.tribuo.impl.ArrayExample;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.test.Helpers;
import org.tribuo.test.MockDataSourceProvenance;
import org.tribuo.test.MockOutput;
import org.tribuo.test.MockOutputFactory;
import org.tribuo.test.MockTrainer;
import org.tribuo.transform.transformations.BinningTransformation;
import org.tribuo.transform.transformations.IDFTransformation;
import org.tribuo.transform.transformations.LinearScalingTransformation;
import org.tribuo.transform.transformations.MeanStdDevTransformation;
import org.tribuo.transform.transformations.SimpleTransform;

/**
 * Tests for {@link TransformedModel}: protobuf round-tripping of a freshly trained model, and
 * loading of a model serialized by Tribuo 4.3.1 from a test resource.
 */
public class TransformedModelTest {

    /** File name of the serialized 4.3.1 model fixture, resolved relative to this class's package. */
    private static final String MODEL_431_RESOURCE = "transformed-model-431.tribuo";

    /**
     * Builds a small dense dataset of ten examples over features {@code F0}..{@code F4}, every
     * example carrying the output {@code "UNK"}.
     *
     * @return a new mutable dataset containing the ten examples, in a fixed order.
     */
    public static MutableDataset<MockOutput> generateDenseDataset() {
        MutableDataset<MockOutput> dataset = new MutableDataset<>(new MockDataSourceProvenance(), new MockOutputFactory());
        MockOutput output = new MockOutput("UNK");
        String[] featureNames = {"F0", "F1", "F2", "F3", "F4"};
        double[][] rows = {
            {1.0, 1.0, 1.0, 1.0, 1.0},
            {1.0, 2.0, 3.0, 4.0, 5.0},
            {0.5, 0.5, 0.5, 0.5, 0.5},
            {0.0, 0.0, 0.0, 0.0, 0.0},
            {10.0, 9.0, 8.0, 7.0, 6.0},
            {2.0, 2.0, 2.0, 2.0, 2.0},
            {10.0, 10.0, 10.0, 10.0, 10.0},
            {1.0, 5.0, 1.0, 5.0, 1.0},
            {5.0, 1.0, 5.0, 1.0, 5.0},
            {1.0, 2.0, 3.0, 4.0, 5.0}
        };
        for (double[] row : rows) {
            dataset.add(new ArrayExample<>(output, featureNames, row));
        }
        return dataset;
    }

    /**
     * Trains a model wrapped with a single global {@link LinearScalingTransformation} and checks
     * that it survives a protobuf round trip.
     */
    @Test
    public void serializationTest() {
        MutableDataset<MockOutput> dataset = generateDenseDataset();
        TransformationMap globalScaling = new TransformationMap(Collections.singletonList(new LinearScalingTransformation()));
        TransformedModel<MockOutput> model = new TransformTrainer<>(new MockTrainer("UNK"), globalScaling).train(dataset);
        Helpers.testModelProtoSerialization(model, MockOutput.class, dataset);
    }

    /**
     * Loads the model serialized by Tribuo 4.3.1 and checks its recorded version. It also checks
     * that the loaded model produces one prediction per example, matching a model freshly trained
     * with the same per-feature transformations.
     *
     * @throws IOException if the fixture cannot be read or parsed.
     * @throws URISyntaxException if the fixture URL cannot be converted to a URI.
     */
    @Test
    public void test431Protobufs() throws IOException, URISyntaxException {
        MutableDataset<MockOutput> dataset = generateDenseDataset();
        TransformedModel<MockOutput> model = trainPerFeatureModel(dataset);

        // getResource returns null for a missing file; fail with a clear message instead of an NPE.
        URL resource = TransformedModelTest.class.getResource(MODEL_431_RESOURCE);
        Assertions.assertNotNull(resource, "Missing test resource " + MODEL_431_RESOURCE);

        ModelProto proto;
        try (InputStream stream = Files.newInputStream(Paths.get(resource.toURI()))) {
            proto = ModelProto.parseFrom(stream);
        }

        // Model.deserialize returns Model<?>; the fixture was written from a TransformedModel<MockOutput>.
        @SuppressWarnings("unchecked")
        TransformedModel<MockOutput> deserialized = (TransformedModel<MockOutput>) Model.deserialize(proto);
        Assertions.assertEquals("4.3.1", deserialized.getProvenance().getTribuoVersion());

        List<Prediction<MockOutput>> expected = model.predict(dataset);
        List<Prediction<MockOutput>> actual = deserialized.predict(dataset);
        Assertions.assertEquals(expected.size(), actual.size());
        Assertions.assertEquals(dataset.size(), actual.size());
    }

    /**
     * Rewrites the 4.3.1 fixture under {@code src/test/resources/org/tribuo/transform}, relative to
     * the current working directory.
     *
     * <p>This method is not annotated with {@code @Test}, so the test suite does not run it. It is
     * presumably invoked by hand when the fixture needs regenerating.
     *
     * @throws IOException if the fixture cannot be written.
     */
    public void generateProtobufs() throws IOException {
        TransformedModel<MockOutput> model = trainPerFeatureModel(generateDenseDataset());
        Helpers.writeProtobuf(model, Paths.get("src", "test", "resources", "org", "tribuo", "transform", MODEL_431_RESOURCE));
    }

    /**
     * Builds the per-feature transformation configuration shared by the fixture generator and the
     * fixture test. It maps each feature to one transformation:
     * <ul>
     *   <li>F0: 3-bin equal-width binning</li>
     *   <li>F1: IDF</li>
     *   <li>F2: linear scaling with (1, 10)</li>
     *   <li>F3: mean/std-dev with (1, 5)</li>
     *   <li>F4: multiplication by 5</li>
     * </ul>
     *
     * @return a new transformation map with no global transformations.
     */
    private static TransformationMap createPerFeatureTransformations() {
        Map<String, List<Transformation>> transformations = new HashMap<>();
        transformations.put("F0", Collections.singletonList(BinningTransformation.equalWidth(3)));
        transformations.put("F1", Collections.singletonList(new IDFTransformation()));
        transformations.put("F2", Collections.singletonList(new LinearScalingTransformation(1.0, 10.0)));
        transformations.put("F3", Collections.singletonList(new MeanStdDevTransformation(1.0, 5.0)));
        transformations.put("F4", Collections.singletonList(SimpleTransform.mul(5.0)));
        return new TransformationMap(transformations);
    }

    /**
     * Trains a constant-output mock model wrapped with {@link #createPerFeatureTransformations()}.
     *
     * @param dataset the training data.
     * @return the trained transformed model.
     */
    private static TransformedModel<MockOutput> trainPerFeatureModel(MutableDataset<MockOutput> dataset) {
        TransformTrainer<MockOutput> trainer = new TransformTrainer<>(new MockTrainer("UNK"), createPerFeatureTransformations());
        return trainer.train(dataset);
    }
}
