package org.tribuo.protos;

import com.google.protobuf.Message;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.CategoricalInfo;
import org.tribuo.CategoricalInfoTest;
import org.tribuo.MutableFeatureMap;
import org.tribuo.RealIDInfo;
import org.tribuo.RealInfo;
import org.tribuo.hash.HashCodeHasher;
import org.tribuo.hash.HashedFeatureMap;
import org.tribuo.hash.MessageDigestHasher;
import org.tribuo.hash.ModHashCodeHasher;
import org.tribuo.protos.core.CategoricalIDInfoProto;
import org.tribuo.protos.core.FeatureDomainProto;
import org.tribuo.protos.core.HashedFeatureMapProto;
import org.tribuo.protos.core.HasherProto;
import org.tribuo.protos.core.MessageDigestHasherProto;
import org.tribuo.protos.core.ModHashCodeHasherProto;
import org.tribuo.protos.core.RealIDInfoProto;
import org.tribuo.protos.core.RealInfoProto;
import org.tribuo.protos.core.VariableInfoProto;

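/**
 * Tests for {@link ProtoUtil}: protobuf round-trips of hashers, hashed feature maps and
 * real-valued variable infos, plus resolution of the serialized proto class via
 * {@code ProtoUtil.getSerializedClass}.
 */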
public class ProtoUtilTest {

    /**
     * Test fixture interface that extends {@code ProtoSerializable<Y>} while declaring two
     * extra, unused type parameters ({@code W}, {@code X}) ahead of the proto type, so the
     * proto type sits in the third position of this interface's parameter list.
     *
     * @param <W> unused in this declaration
     * @param <X> unused in this declaration
     * @param <Y> the protobuf message type passed through to {@code ProtoSerializable}
     */
    public interface IPS<W, X, Y extends Message> extends ProtoSerializable<Y> {
    }

    /**
     * Test fixture interface that forwards its single type parameter directly to
     * {@code ProtoSerializable<Y>}.
     *
     * @param <Y> the protobuf message type passed through to {@code ProtoSerializable}
     */
    public interface IPS2<Y extends Message> extends ProtoSerializable<Y> {
    }

    /**
     * Test fixture implementing {@code IPS<String, String, B>}: the proto type {@code B} is
     * the second type parameter here but the third on {@link IPS}.
     *
     * @param <A> unused in this declaration
     * @param <B> the protobuf message type produced by {@link #serialize()}
     */
    public static class PSA<A, B extends Message> implements IPS<String, String, B> {
        /**
         * Delegates to {@code ProtoUtil.serialize(this)} and casts the result to {@code B}.
         *
         * @return the value returned by {@code ProtoUtil.serialize(this)}, cast to {@code B}
         */
        public B serialize() {
            B proto = (B) ProtoUtil.serialize(this);
            return proto;
        }
    }

    /**
     * Test fixture implementing {@code IPS2<B>} with three type parameters, of which only
     * the middle one ({@code B}) is forwarded as the proto type.
     *
     * @param <A> unused in this declaration
     * @param <B> the protobuf message type produced by {@link #serialize()}
     * @param <Y> unused in this declaration
     */
    public static class PSA2<A, B extends Message, Y extends Message> implements IPS2<B> {
        /**
         * Delegates to {@code ProtoUtil.serialize(this)} and casts the result to {@code B}.
         *
         * @return the value returned by {@code ProtoUtil.serialize(this)}, cast to {@code B}
         */
        public B serialize() {
            B proto = (B) ProtoUtil.serialize(this);
            return proto;
        }
    }

    /**
     * Test fixture that extends {@code PSA<String, C>}, fixing PSA's first type argument
     * and leaving the proto type open as {@code C}.
     *
     * @param <C> the protobuf message type, still unbound at this level
     */
    public static class PSB<C extends Message> extends PSA<String, C> {
    }

    /**
     * Test fixture that extends {@code PSA2<String, Y, CategoricalIDInfoProto>}. PSA2's
     * proto type (its second parameter) is bound to this class's still-open type variable
     * {@code Y}; the concrete {@code CategoricalIDInfoProto} lands on PSA2's third parameter.
     * {@code testGetSerializedClass} expects {@code ProtoUtil.getSerializedClass} to throw
     * {@link IllegalArgumentException} for an instance of this class.
     *
     * @param <Y> the protobuf message type, still unbound at this level
     */
    public static class PSB2<Y extends Message> extends PSA2<String, Y, CategoricalIDInfoProto> {
    }

    /**
     * Test fixture that binds {@link PSB}'s proto type to {@code CategoricalIDInfoProto}.
     * {@code testGetSerializedClass} expects {@code ProtoUtil.getSerializedClass} to return
     * {@code CategoricalIDInfoProto.class} for an instance of this class.
     */
    public static class PSC extends PSB<CategoricalIDInfoProto> {
    }

    /**
     * Test fixture that binds {@link PSB}'s proto type to {@code RealIDInfoProto}.
     * {@code testGetSerializedClass} expects {@code ProtoUtil.getSerializedClass} to return
     * {@code RealIDInfoProto.class} for an instance of this class.
     */
    public static class PSC2 extends PSB<RealIDInfoProto> {
    }

    /**
     * Test fixture that extends {@link PSC2} without adding type parameters, placing the
     * binding to {@code RealIDInfoProto} one level further up the hierarchy.
     * {@code testGetSerializedClass} expects {@code ProtoUtil.getSerializedClass} to return
     * {@code RealIDInfoProto.class} for an instance of this class.
     */
    public static class PSD2 extends PSC2 {
    }

    /**
     * Serializes a {@link HashedFeatureMap} built with a SHA-512 {@link MessageDigestHasher},
     * checks the version and class name of both the feature-domain proto and the nested
     * hasher proto, then deserializes and compares with the original map.
     */
    @Test
    void testHashedFeatureMap() throws Exception {
        final String salt = "abcdefghi";

        MutableFeatureMap features = new MutableFeatureMap();
        features.add("goldrat", 1.618033988749);
        features.add("e", Math.E);
        features.add("pi", Math.PI);

        MessageDigestHasher sha512 = new MessageDigestHasher("SHA-512", salt);
        HashedFeatureMap original = HashedFeatureMap.generateHashedFeatureMap(features, sha512);

        // Outer envelope for the feature map.
        FeatureDomainProto domainProto = original.serialize();
        Assertions.assertEquals(0, domainProto.getVersion());
        Assertions.assertEquals("org.tribuo.hash.HashedFeatureMap", domainProto.getClassName());

        // Hasher envelope nested inside the feature map payload.
        HashedFeatureMapProto mapPayload = domainProto.getSerializedData().unpack(HashedFeatureMapProto.class);
        HasherProto hasherProto = mapPayload.getHasher();
        Assertions.assertEquals(0, hasherProto.getVersion());
        Assertions.assertEquals("org.tribuo.hash.MessageDigestHasher", hasherProto.getClassName());

        MessageDigestHasherProto digestPayload = hasherProto.getSerializedData().unpack(MessageDigestHasherProto.class);
        Assertions.assertEquals("SHA-512", digestPayload.getHashType());

        // The salt is set again on the deserialized map before comparing
        // (presumably it is not carried in the proto - confirm against HashedFeatureMap).
        HashedFeatureMap restored = ProtoUtil.deserialize(domainProto);
        restored.setSalt(salt);
        Assertions.assertEquals(original, restored);
    }

    /**
     * Serializes a {@link ModHashCodeHasher} with dimension 200, checks the proto's version,
     * class name and dimension, deserializes it, and compares with the original hasher.
     */
    @Test
    void testSerializeModHashCodeHasher() throws Exception {
        final int dimension = 200;
        final String salt = "abcdefghi";

        ModHashCodeHasher original = new ModHashCodeHasher(dimension, salt);
        HasherProto hasherProto = original.serialize();
        Assertions.assertEquals(0, hasherProto.getVersion());
        Assertions.assertEquals("org.tribuo.hash.ModHashCodeHasher", hasherProto.getClassName());

        ModHashCodeHasherProto payloadBefore = hasherProto.getSerializedData().unpack(ModHashCodeHasherProto.class);
        Assertions.assertEquals(dimension, payloadBefore.getDimension());

        // The salt is set again on the deserialized hasher before comparing.
        ModHashCodeHasher restored = ProtoUtil.deserialize(hasherProto);
        restored.setSalt(salt);
        Assertions.assertEquals(original, restored);

        // Unpack the same proto once more after the round trip and re-check the dimension.
        ModHashCodeHasherProto payloadAfter = hasherProto.getSerializedData().unpack(ModHashCodeHasherProto.class);
        Assertions.assertEquals(dimension, payloadAfter.getDimension());
    }

    /**
     * Serializes a SHA-256 {@link MessageDigestHasher}, checks the proto's version, class
     * name and hash type, deserializes it, and compares with the original hasher.
     */
    @Test
    void testMessageDigestHasher() throws Exception {
        final String hashType = "SHA-256";
        final String salt = "abcdefghi";

        MessageDigestHasher original = new MessageDigestHasher(hashType, salt);
        HasherProto hasherProto = original.serialize();
        Assertions.assertEquals(0, hasherProto.getVersion());
        Assertions.assertEquals("org.tribuo.hash.MessageDigestHasher", hasherProto.getClassName());

        MessageDigestHasherProto digestPayload = hasherProto.getSerializedData().unpack(MessageDigestHasherProto.class);
        Assertions.assertEquals(hashType, digestPayload.getHashType());

        // The salt is set again on the deserialized hasher before comparing.
        MessageDigestHasher restored = ProtoUtil.deserialize(hasherProto);
        restored.setSalt(salt);
        Assertions.assertEquals(original, restored);
    }

    /**
     * Serializes a {@link HashCodeHasher}, checks the proto's version and class name,
     * deserializes it, and compares with the original hasher.
     */
    @Test
    void testHashCodeHasher() throws Exception {
        final String salt = "abcdefghi";

        HashCodeHasher original = new HashCodeHasher(salt);
        HasherProto hasherProto = original.serialize();
        Assertions.assertEquals(0, hasherProto.getVersion());
        Assertions.assertEquals("org.tribuo.hash.HashCodeHasher", hasherProto.getClassName());

        // The salt is set again on the deserialized hasher before comparing.
        HashCodeHasher restored = ProtoUtil.deserialize(hasherProto);
        restored.setSalt(salt);
        Assertions.assertEquals(original, restored);
    }

    /**
     * Serializes a {@link RealIDInfo}, checks the envelope metadata and every field of the
     * unpacked {@code RealIDInfoProto}, then checks that deserializing yields an equal object.
     */
    @Test
    void testRealIDInfo() throws Exception {
        final String name = "bob";
        final int count = 100;
        final double max = 1000.0;
        final double min = 0.0;
        final double mean = 25.0;
        final double sumSquares = 125.0;
        final int id = 12345;

        RealIDInfo original = new RealIDInfo(name, count, max, min, mean, sumSquares, id);
        VariableInfoProto infoProto = original.serialize();
        Assertions.assertEquals(0, infoProto.getVersion());
        Assertions.assertEquals("org.tribuo.RealIDInfo", infoProto.getClassName());

        // Field-by-field check of the packed payload.
        RealIDInfoProto payload = infoProto.getSerializedData().unpack(RealIDInfoProto.class);
        Assertions.assertEquals(name, payload.getName());
        Assertions.assertEquals(count, payload.getCount());
        Assertions.assertEquals(max, payload.getMax());
        Assertions.assertEquals(min, payload.getMin());
        Assertions.assertEquals(mean, payload.getMean());
        Assertions.assertEquals(sumSquares, payload.getSumSquares());
        Assertions.assertEquals(id, payload.getId());

        Assertions.assertEquals(original, ProtoUtil.deserialize(infoProto));
    }

    /**
     * Serializes a {@link RealInfo}, checks the envelope metadata and every field of the
     * unpacked {@code RealInfoProto}, then checks that deserializing yields an equal object.
     */
    @Test
    void testRealInfo() throws Exception {
        final String name = "bob";
        final int count = 100;
        final double max = 1000.0;
        final double min = 0.0;
        final double mean = 25.0;
        final double sumSquares = 125.0;

        RealInfo original = new RealInfo(name, count, max, min, mean, sumSquares);
        VariableInfoProto infoProto = original.serialize();
        Assertions.assertEquals(0, infoProto.getVersion());
        Assertions.assertEquals("org.tribuo.RealInfo", infoProto.getClassName());

        // Field-by-field check of the packed payload.
        RealInfoProto payload = infoProto.getSerializedData().unpack(RealInfoProto.class);
        Assertions.assertEquals(name, payload.getName());
        Assertions.assertEquals(count, payload.getCount());
        Assertions.assertEquals(max, payload.getMax());
        Assertions.assertEquals(min, payload.getMin());
        Assertions.assertEquals(mean, payload.getMean());
        Assertions.assertEquals(sumSquares, payload.getSumSquares());

        Assertions.assertEquals(original, ProtoUtil.deserialize(infoProto));
    }

    /**
     * Checks the proto class reported by {@code ProtoUtil.getSerializedClass} for variable
     * infos, a hashed feature map, the three hashers, and the nested generic fixtures
     * declared in this test class, including the {@link PSB2} case that is expected to throw.
     */
    @Test
    void testGetSerializedClass() throws Exception {
        final String salt = "abcdefghi";

        // Categorical and real infos, with and without ids, are expected to give VariableInfoProto.
        CategoricalInfo categorical = CategoricalInfoTest.generateProtoTestInfo();
        Assertions.assertEquals(VariableInfoProto.class, ProtoUtil.getSerializedClass(categorical));
        Assertions.assertEquals(VariableInfoProto.class, ProtoUtil.getSerializedClass(categorical.makeIDInfo(12345)));
        RealIDInfo realWithId = new RealIDInfo("bob", 100, 1000.0, 0.0, 25.0, 125.0, 12345);
        Assertions.assertEquals(VariableInfoProto.class, ProtoUtil.getSerializedClass(realWithId));
        RealInfo real = new RealInfo("bob", 100, 1000.0, 0.0, 25.0, 125.0);
        Assertions.assertEquals(VariableInfoProto.class, ProtoUtil.getSerializedClass(real));

        // A hashed feature map is expected to give FeatureDomainProto.
        MutableFeatureMap features = new MutableFeatureMap();
        features.add("goldrat", 1.618033988749);
        features.add("e", Math.E);
        features.add("pi", Math.PI);
        MessageDigestHasher sha512 = new MessageDigestHasher("SHA-512", salt);
        HashedFeatureMap hashedFeatures = HashedFeatureMap.generateHashedFeatureMap(features, sha512);
        Assertions.assertEquals(FeatureDomainProto.class, ProtoUtil.getSerializedClass(hashedFeatures));

        // Each hasher is expected to give HasherProto.
        ModHashCodeHasher modHasher = new ModHashCodeHasher(200, salt);
        Assertions.assertEquals(HasherProto.class, ProtoUtil.getSerializedClass(modHasher));
        MessageDigestHasher sha256 = new MessageDigestHasher("SHA-256", salt);
        Assertions.assertEquals(HasherProto.class, ProtoUtil.getSerializedClass(sha256));
        HashCodeHasher hashCodeHasher = new HashCodeHasher(salt);
        Assertions.assertEquals(HasherProto.class, ProtoUtil.getSerializedClass(hashCodeHasher));

        // Fixtures whose proto type is bound to a concrete class somewhere up the hierarchy.
        PSC categoricalFixture = new PSC();
        Assertions.assertEquals(CategoricalIDInfoProto.class, ProtoUtil.getSerializedClass(categoricalFixture));
        PSD2 realGrandchildFixture = new PSD2();
        Assertions.assertEquals(RealIDInfoProto.class, ProtoUtil.getSerializedClass(realGrandchildFixture));
        PSC2 realFixture = new PSC2();
        Assertions.assertEquals(RealIDInfoProto.class, ProtoUtil.getSerializedClass(realFixture));

        // PSB2 leaves its proto type as an open type variable; an IllegalArgumentException is expected.
        Assertions.assertThrows(IllegalArgumentException.class, () -> ProtoUtil.getSerializedClass(new PSB2()));
    }
}
