package org.tribuo.test;

import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.config.ConfigurationData;
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.Assertions;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.MutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.impl.ListExample;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceModel;

/* loaded from: input_file:org/tribuo/test/Helpers.class */
/**
 * Static helpers shared by Tribuo unit tests.
 *
 * <p>Provides builders for small feature maps and mock examples, structural equality checks for
 * datasets and prediction lists, and round-trip checks for OLCUT configuration, provenance
 * marshalling, protobuf serialization and Java serialization. The round-trip helpers report
 * problems through JUnit 5 {@link Assertions}.
 */
public final class Helpers {
    private static final Logger logger = Logger.getLogger(Helpers.class.getName());

    private Helpers() {
    }

    /**
     * Builds an immutable feature map in which each supplied name is observed with value 1.0.
     *
     * @param strArr The feature names; a repeated name is observed once per occurrence.
     * @return An immutable feature map over the supplied names.
     */
    public static ImmutableFeatureMap mkFeatureMap(String... strArr) {
        MutableFeatureMap featureMap = new MutableFeatureMap();
        for (String name : strArr) {
            featureMap.add(name, 1.0);
        }
        return new ImmutableFeatureMap(featureMap);
    }

    /**
     * Builds a {@link ListExample} with the given output whose features are the distinct supplied
     * names, each valued by the number of times it occurs in {@code strArr}.
     *
     * @param mockOutput The output (label) of the example.
     * @param strArr The feature names, possibly repeated.
     * @return A new example containing one feature per distinct name.
     */
    public static Example<MockOutput> mkExample(MockOutput mockOutput, String... strArr) {
        Example<MockOutput> example = new ListExample<>(mockOutput);
        Map<String, Integer> counts = new HashMap<>();
        for (String name : strArr) {
            counts.merge(name, 1, Integer::sum);
        }
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            example.add(new Feature(entry.getKey(), entry.getValue().doubleValue()));
        }
        return example;
    }

    /**
     * Checks two sequence datasets for structural equality: same size, pairwise-equal examples in
     * order, and equal output factories, feature maps and output infos.
     *
     * @param sequenceDataset The first dataset.
     * @param sequenceDataset2 The second dataset.
     * @param <T> The output type.
     * @return True if the datasets are structurally equal.
     */
    public static <T extends Output<T>> boolean sequenceDatasetEquals(SequenceDataset<T> sequenceDataset, SequenceDataset<T> sequenceDataset2) {
        if (sequenceDataset.size() != sequenceDataset2.size()) {
            return false;
        }
        for (int i = 0; i < sequenceDataset.size(); i++) {
            if (!sequenceDataset.getExample(i).equals(sequenceDataset2.getExample(i))) {
                return false;
            }
        }
        return sequenceDataset.getOutputFactory().equals(sequenceDataset2.getOutputFactory())
                && sequenceDataset.getFeatureMap().equals(sequenceDataset2.getFeatureMap())
                && sequenceDataset.getOutputInfo().equals(sequenceDataset2.getOutputInfo());
    }

    /**
     * Checks two datasets for structural equality: same size, pairwise-equal examples in order,
     * and equal output factories, feature maps and output infos.
     *
     * @param dataset The first dataset.
     * @param dataset2 The second dataset.
     * @param <T> The output type.
     * @return True if the datasets are structurally equal.
     */
    public static <T extends Output<T>> boolean datasetEquals(Dataset<T> dataset, Dataset<T> dataset2) {
        if (dataset.size() != dataset2.size()) {
            return false;
        }
        for (int i = 0; i < dataset.size(); i++) {
            if (!dataset.getExample(i).equals(dataset2.getExample(i))) {
                return false;
            }
        }
        return dataset.getOutputFactory().equals(dataset2.getOutputFactory())
                && dataset.getFeatureMap().equals(dataset2.getFeatureMap())
                && dataset.getOutputInfo().equals(dataset2.getOutputInfo());
    }

    /**
     * Asserts that the configuration obtained by importing {@code c} into a fresh
     * {@link ConfigurationManager} is structurally equal to the configuration extracted from its
     * provenance.
     *
     * @param c The configurable, provenancable object under test.
     * @param <P> The provenance type.
     * @param <C> The object type.
     */
    public static <P extends ConfiguredObjectProvenance, C extends Configurable & Provenancable<P>> void testConfigurableRoundtrip(C c) {
        ConfigurationManager cm = new ConfigurationManager();
        String importedName = cm.importConfigurable(c, "item");
        List<ConfigurationData> managerConfig = cm.getComponentNames().stream()
                .map(cm::getConfigurationData)
                .filter(opt -> opt.isPresent())
                .map(opt -> opt.get())
                .collect(Collectors.toList());
        List<ConfigurationData> provenanceConfig = ProvenanceUtil.extractConfiguration(c.getProvenance());
        // Guard the get(0) below so an empty extraction fails with a clear message.
        Assertions.assertFalse(provenanceConfig.isEmpty(), "No configuration extracted from provenance");
        Assertions.assertTrue(ConfigurationData.structuralEquals(managerConfig, provenanceConfig, importedName, provenanceConfig.get(0).getName()));
    }

    /**
     * Asserts that marshalling then unmarshalling the provenance yields an equal provenance.
     *
     * @param objectProvenance The provenance to round-trip.
     */
    public static void testProvenanceMarshalling(ObjectProvenance objectProvenance) {
        ObjectProvenance roundTripped = ProvenanceUtil.unmarshalProvenance(ProvenanceUtil.marshalProvenance(objectProvenance));
        Assertions.assertEquals(objectProvenance, roundTripped);
    }

    /**
     * Round-trips a sequence dataset through protobuf serialization and asserts that the copy is a
     * distinct instance of the same class that is structurally equal to the original.
     *
     * @param sequenceDataset The dataset to round-trip.
     * @param <T> The output type.
     * @return The deserialized copy.
     */
    @SuppressWarnings("unchecked") // the deserialized class is asserted equal to the input's class
    public static <T extends Output<T>> SequenceDataset<T> testSequenceDatasetSerialization(SequenceDataset<T> sequenceDataset) {
        SequenceDataset<T> copy = (SequenceDataset<T>) ProtoUtil.deserialize(sequenceDataset.serialize());
        Assertions.assertEquals(sequenceDataset.getClass(), copy.getClass());
        Assertions.assertNotSame(sequenceDataset, copy);
        Assertions.assertTrue(sequenceDatasetEquals(sequenceDataset, copy));
        return copy;
    }

    /**
     * Round-trips a dataset through protobuf serialization and asserts that the copy is a distinct
     * instance of the same class that is structurally equal to the original.
     *
     * @param dataset The dataset to round-trip.
     * @param <T> The output type.
     * @return The deserialized copy.
     */
    @SuppressWarnings("unchecked") // the deserialized class is asserted equal to the input's class
    public static <T extends Output<T>> Dataset<T> testDatasetSerialization(Dataset<T> dataset) {
        Dataset<T> copy = (Dataset<T>) ProtoUtil.deserialize(dataset.serialize());
        Assertions.assertEquals(dataset.getClass(), copy.getClass());
        Assertions.assertNotSame(dataset, copy);
        Assertions.assertTrue(datasetEquals(dataset, copy));
        return copy;
    }

    /**
     * Round-trips an object through protobuf serialization and asserts the copy equals the input.
     *
     * @param t The object to round-trip.
     * @param <U> The protobuf message type.
     * @param <T> The serializable type.
     * @return The deserialized copy.
     */
    @SuppressWarnings("unchecked") // equality with the input is asserted immediately
    public static <U extends Message, T extends ProtoSerializable<U>> T testProtoSerialization(T t) {
        T copy = (T) ProtoUtil.deserialize(t.serialize());
        Assertions.assertEquals(t, copy);
        return copy;
    }

    /**
     * Equivalent to {@link #testModelProtoSerialization(Model, Class, Iterable, double)} with a
     * tolerance of {@code 1e-15}.
     *
     * @param model The model to round-trip.
     * @param cls The output class.
     * @param iterable Examples used to compare predictions.
     * @param <T> The output type.
     * @return The deserialized model.
     */
    public static <T extends Output<T>> Model<T> testModelProtoSerialization(Model<T> model, Class<T> cls, Iterable<Example<T>> iterable) {
        return testModelProtoSerialization(model, cls, iterable, 1.0E-15d);
    }

    /**
     * Round-trips a model through protobuf serialization and asserts the following:
     * <ul>
     *   <li>its provenance marshals;</li>
     *   <li>the copy has equal provenance;</li>
     *   <li>the copy validates against {@code cls};</li>
     *   <li>the copy produces predictions whose distributions match the original's within {@code d}.</li>
     * </ul>
     *
     * @param model The model to round-trip.
     * @param cls The output class.
     * @param iterable Examples used to compare predictions.
     * @param d The tolerance passed to {@link Prediction#distributionEquals}.
     * @param <T> The output type.
     * @return The deserialized model, cast to {@code Model<T>}.
     */
    public static <T extends Output<T>> Model<T> testModelProtoSerialization(Model<T> model, Class<T> cls, Iterable<Example<T>> iterable, double d) {
        testProvenanceMarshalling(model.getProvenance());
        Model<?> deserialized = Model.deserialize(model.serialize());
        Assertions.assertEquals(model.getProvenance(), deserialized.getProvenance());
        Assertions.assertTrue(deserialized.validate(cls));
        Model<T> castModel = deserialized.castModel(cls);
        List<Prediction<T>> expected = model.predict(iterable);
        List<Prediction<T>> actual = castModel.predict(iterable);
        Assertions.assertEquals(expected.size(), actual.size());
        for (int i = 0; i < expected.size(); i++) {
            Assertions.assertTrue(expected.get(i).distributionEquals(actual.get(i), d), "Prediction " + i + " differs after deserialization");
        }
        return castModel;
    }

    /**
     * Round-trips a sequence model through protobuf serialization and asserts the following:
     * <ul>
     *   <li>its provenance marshals;</li>
     *   <li>the copy has equal provenance;</li>
     *   <li>the copy validates against {@code cls};</li>
     *   <li>the copy produces per-sequence predictions with matching distributions.</li>
     * </ul>
     *
     * @param sequenceModel The model to round-trip.
     * @param cls The output class.
     * @param sequenceDataset Sequences used to compare predictions.
     * @param <T> The output type.
     * @return The deserialized model, cast to {@code SequenceModel<T>}.
     */
    public static <T extends Output<T>> SequenceModel<T> testSequenceModelProtoSerialization(SequenceModel<T> sequenceModel, Class<T> cls, SequenceDataset<T> sequenceDataset) {
        testProvenanceMarshalling(sequenceModel.getProvenance());
        SequenceModel<?> deserialized = SequenceModel.deserialize(sequenceModel.serialize());
        Assertions.assertEquals(sequenceModel.getProvenance(), deserialized.getProvenance());
        Assertions.assertTrue(deserialized.validate(cls));
        SequenceModel<T> castModel = deserialized.castModel(cls);
        List<List<Prediction<T>>> expected = sequenceModel.predict(sequenceDataset);
        List<List<Prediction<T>>> actual = castModel.predict(sequenceDataset);
        Assertions.assertEquals(expected.size(), actual.size());
        for (int i = 0; i < expected.size(); i++) {
            List<Prediction<T>> expectedSeq = expected.get(i);
            List<Prediction<T>> actualSeq = actual.get(i);
            Assertions.assertEquals(expectedSeq.size(), actualSeq.size());
            for (int j = 0; j < expectedSeq.size(); j++) {
                Assertions.assertTrue(expectedSeq.get(j).distributionEquals(actualSeq.get(j)), "Prediction " + j + " of sequence " + i + " differs after deserialization");
            }
        }
        return castModel;
    }

    /**
     * Round-trips a model through Java serialization. Asserts that its provenance marshals, that
     * the copy has equal provenance and validates against {@code cls}, and closes the copy if it is
     * {@link AutoCloseable}.
     *
     * @param model The model to round-trip.
     * @param cls The output class.
     * @param <T> The output type.
     */
    public static <T extends Output<T>> void testModelSerialization(Model<T> model, Class<T> cls) {
        testProvenanceMarshalling(model.getProvenance());
        byte[] bytes = javaSerialize(model, "model class");
        Model<?> deserialized = (Model<?>) javaDeserialize(bytes, "model class", model);
        Assertions.assertEquals(model.getProvenance(), deserialized.getProvenance());
        Assertions.assertTrue(deserialized.validate(cls));
        closeIfCloseable(deserialized, model);
    }

    /**
     * Round-trips a sequence model through Java serialization. Asserts that the copy has equal
     * provenance and validates against {@code cls}, and closes the copy if it is
     * {@link AutoCloseable}.
     *
     * @param sequenceModel The model to round-trip.
     * @param cls The output class.
     * @param <T> The output type.
     */
    public static <T extends Output<T>> void testSequenceModelSerialization(SequenceModel<T> sequenceModel, Class<T> cls) {
        byte[] bytes = javaSerialize(sequenceModel, "sequence model class");
        SequenceModel<?> deserialized = (SequenceModel<?>) javaDeserialize(bytes, "sequence model class", sequenceModel);
        Assertions.assertEquals(sequenceModel.getProvenance(), deserialized.getProvenance());
        Assertions.assertTrue(deserialized.validate(cls));
        closeIfCloseable(deserialized, sequenceModel);
    }

    /** Java-serializes {@code obj} into a byte array, failing the test on I/O error. */
    private static byte[] javaSerialize(Object obj, String description) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        // Closing the stream flushes the buffered bytes into the array before it is read below.
        try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(bytes))) {
            oos.writeObject(obj);
        } catch (IOException e) {
            logger.severe("IOException when writing out model");
            Assertions.fail("Failed to serialize " + description + " " + obj.getClass().toString(), e);
        }
        return bytes.toByteArray();
    }

    /** Java-deserializes one object from {@code bytes}, failing the test on I/O or class-loading error. */
    private static Object javaDeserialize(byte[] bytes, String description, Object original) {
        try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new ByteArrayInputStream(bytes)))) {
            return ois.readObject();
        } catch (IOException e) {
            logger.severe("IOException when reading in model");
            return Assertions.fail("Failed to deserialize " + description + " " + original.getClass().toString(), e);
        } catch (ClassNotFoundException e) {
            logger.severe("ClassNotFoundException when reading in model");
            return Assertions.fail("Failed to deserialize " + description + " " + original.getClass().toString(), e);
        }
    }

    /** Closes {@code deserialized} if it is {@link AutoCloseable}, failing the test if close throws. */
    private static void closeIfCloseable(Object deserialized, Object original) {
        if (deserialized instanceof AutoCloseable) {
            try {
                ((AutoCloseable) deserialized).close();
            } catch (Exception e) {
                logger.severe("Exception thrown when closing model");
                Assertions.fail("Failed to close deserialized model " + original.getClass().toString(), e);
            }
        }
    }

    /**
     * Compares two top-feature maps. They must have the same keys, and per key the same number of
     * entries, with equal names in the same order and values within {@code d} of each other.
     *
     * @param map The first map.
     * @param map2 The second map.
     * @param d The absolute tolerance on feature values.
     * @return True if the maps match.
     */
    public static boolean topFeaturesEqual(Map<String, List<Pair<String, Double>>> map, Map<String, List<Pair<String, Double>>> map2, double d) {
        // Set equality is size equality plus containsAll, matching the original check.
        if (!map.keySet().equals(map2.keySet())) {
            return false;
        }
        for (Map.Entry<String, List<Pair<String, Double>>> entry : map.entrySet()) {
            List<Pair<String, Double>> first = entry.getValue();
            List<Pair<String, Double>> second = map2.get(entry.getKey());
            if (first.size() != second.size()) {
                return false;
            }
            for (int i = 0; i < first.size(); i++) {
                Pair<String, Double> a = first.get(i);
                Pair<String, Double> b = second.get(i);
                if (!a.getA().equals(b.getA()) || Math.abs(a.getB() - b.getB()) > d) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Equivalent to {@link #predictionListDistributionEquals(List, List, double)} with a tolerance
     * of {@code 1e-14}.
     *
     * @param list The first prediction list.
     * @param list2 The second prediction list.
     * @param <T> The output type.
     * @return True if the lists match.
     */
    public static <T extends Output<T>> boolean predictionListDistributionEquals(List<Prediction<T>> list, List<Prediction<T>> list2) {
        return predictionListDistributionEquals(list, list2, 1.0E-14d);
    }

    /**
     * Checks that two prediction lists have the same length and that each pair of predictions has
     * equal distributions within {@code d}.
     *
     * @param list The first prediction list.
     * @param list2 The second prediction list.
     * @param d The tolerance passed to {@link Prediction#distributionEquals}.
     * @param <T> The output type.
     * @return True if the lists match.
     */
    public static <T extends Output<T>> boolean predictionListDistributionEquals(List<Prediction<T>> list, List<Prediction<T>> list2, double d) {
        if (list.size() != list2.size()) {
            return false;
        }
        for (int i = 0; i < list.size(); i++) {
            if (!list.get(i).distributionEquals(list2.get(i), d)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Serializes {@code protoSerializable} to protobuf and writes the message bytes to
     * {@code path}, creating parent directories as needed.
     *
     * @param protoSerializable The object to write.
     * @param path The destination file.
     * @throws IOException If the directories cannot be created or the file cannot be written.
     */
    public static void writeProtobuf(ProtoSerializable<?> protoSerializable, Path path) throws IOException {
        Message message = protoSerializable.serialize();
        // getParent() is null for a bare file name; Files.createDirectories(null) would throw NPE.
        Path parent = path.getParent();
        if (parent != null) {
            Files.createDirectories(parent);
        }
        try (OutputStream out = Files.newOutputStream(path)) {
            message.writeTo(out);
        }
    }
}
