package org.tribuo.transform;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.MutableDataset;
import org.tribuo.impl.ArrayExample;
import org.tribuo.protos.core.TransformerProto;
import org.tribuo.test.MockDataSourceProvenance;
import org.tribuo.test.MockModel;
import org.tribuo.test.MockOutput;
import org.tribuo.test.MockOutputFactory;
import org.tribuo.test.protos.TestCountTransformerProto;
import org.tribuo.transform.transformations.LinearScalingTransformation;
import org.tribuo.transform.transformations.SimpleTransform;

/* loaded from: input_file:org/tribuo/transform/TransformationMapTest.class */
public class TransformationMapTest {

    /**
     * Test-only statistics collector that simply tallies what it is shown: how many
     * explicit values were observed, how many sparse observations were reported, and
     * how often each distinct value occurred.
     */
    private static class CountStatistics implements TransformStatistics {
        /** Running total of sparse observations reported so far. */
        public int sparseCount;
        /** Histogram mapping each observed value to the number of times it was seen. */
        public Map<Double, MutableLong> countMap;
        /** Number of calls made to {@link #observeValue(double)}. */
        public int count;

        /** Creates an empty collector with all tallies at zero. */
        private CountStatistics() {
            this.countMap = new HashMap<>();
        }

        /**
         * Records one explicit value: bumps its histogram bucket and the value counter.
         *
         * @param d the observed value
         */
        public void observeValue(double d) {
            MutableLong bucket = this.countMap.computeIfAbsent(d, key -> new MutableLong());
            bucket.increment();
            this.count += 1;
        }

        /** Records a single sparse observation. */
        @Deprecated
        public void observeSparse() {
            this.sparseCount += 1;
        }

        /**
         * Records several sparse observations at once.
         *
         * @param i the number of sparse observations to add to the tally
         */
        public void observeSparse(int i) {
            this.sparseCount = this.sparseCount + i;
        }

        /**
         * Builds a {@link CountTransformer} from the current tallies.
         * The histogram map is handed over by reference, not copied.
         *
         * @return a transformer carrying the sparse count, value count and histogram
         */
        public Transformer generateTransformer() {
            return new CountTransformer(this.sparseCount, this.count, this.countMap);
        }
    }

    /**
     * Test-only {@link Transformation} whose statistics are {@link CountStatistics}
     * and whose provenance is a {@link CountTransformationProvenance}.
     */
    public static class CountTransformation implements Transformation {
        /**
         * Creates a fresh, empty statistics collector.
         *
         * @return a new {@link CountStatistics}
         */
        public TransformStatistics createStats() {
            return new CountStatistics();
        }

        /**
         * Returns the provenance describing this transformation.
         * <p>
         * The decompiler had emitted this method under the mangled name
         * {@code m317getProvenance} (it merged the method with its bridge method),
         * which left the class without a method called {@code getProvenance}.
         * The original name is restored here.
         *
         * @return a new {@link CountTransformationProvenance}
         */
        public TransformationProvenance getProvenance() {
            return new CountTransformationProvenance();
        }

        /**
         * Decompiler-generated alias of {@link #getProvenance()}, kept only so that any
         * code still referring to the mangled name keeps compiling.
         *
         * @return the same value as {@link #getProvenance()}
         * @deprecated use {@link #getProvenance()}
         */
        @Deprecated
        public TransformationProvenance m317getProvenance() {
            return getProvenance();
        }
    }

    /**
     * Provenance for {@link CountTransformation}. It carries no state: it has no
     * configured parameters and every instance is equal to every other instance.
     */
    public static class CountTransformationProvenance implements TransformationProvenance {
        /** Creates the (stateless) provenance. */
        public CountTransformationProvenance() {
        }

        /**
         * Map-based constructor; the supplied map is ignored because there is no state
         * to restore.
         * NOTE(review): presumably required by the provenance deserialization
         * convention - confirm against the provenance framework.
         *
         * @param map ignored
         */
        public CountTransformationProvenance(Map<String, Provenance> map) {
        }

        /**
         * @return the fully qualified class name of {@link CountTransformation}
         */
        public String getClassName() {
            return CountTransformation.class.getName();
        }

        /**
         * @return an empty map, as this transformation has no configurable parameters
         */
        public Map<String, Provenance> getConfiguredParameters() {
            return Collections.emptyMap();
        }

        /**
         * All instances of this provenance are interchangeable.
         *
         * @param obj the object to compare with
         * @return true iff {@code obj} is a {@code CountTransformationProvenance}
         */
        @Override
        public boolean equals(Object obj) {
            return obj instanceof CountTransformationProvenance;
        }

        /**
         * Constant hash, consistent with {@link #equals(Object)}: since all instances are
         * equal they must all share one hash code. Previously {@code equals} was
         * overridden without {@code hashCode}, so equal instances hashed differently.
         *
         * @return the same value for every instance
         */
        @Override
        public int hashCode() {
            return CountTransformationProvenance.class.getName().hashCode();
        }
    }

    /**
     * Test-only {@link Transformer} that carries the tallies gathered by
     * {@link CountStatistics}. The transform itself is the identity function; the
     * class exists so tests can inspect the counts and round-trip them through
     * protobuf serialization.
     */
    private static class CountTransformer implements Transformer {
        /** Number of explicitly observed values. */
        public final int count;
        /** Number of sparse observations. */
        public final int sparseCount;
        /** Histogram from observed value to occurrence count; stored by reference. */
        public final Map<Double, MutableLong> countMap;

        /**
         * Creates a transformer holding the supplied tallies.
         * Note the argument order: the sparse count comes first.
         *
         * @param sparseCount number of sparse observations
         * @param count       number of explicitly observed values
         * @param countMap    histogram of observed values (not copied)
         */
        public CountTransformer(int sparseCount, int count, Map<Double, MutableLong> countMap) {
            this.count = count;
            this.sparseCount = sparseCount;
            this.countMap = countMap;
        }

        /**
         * Rebuilds a transformer from its protobuf form.
         * NOTE(review): nothing in this file calls this method directly; presumably it
         * is located reflectively by the deserialization machinery - confirm.
         *
         * @param version   the serialized version, only {@code 0} is accepted
         * @param className the serialized class name (unused here)
         * @param message   the packed {@link TestCountTransformerProto}
         * @return the deserialized transformer
         * @throws InvalidProtocolBufferException if {@code message} cannot be unpacked
         * @throws IllegalArgumentException if the version is unknown or the key and
         *         value lists have different lengths
         */
        static CountTransformer deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
            TestCountTransformerProto proto = message.unpack(TestCountTransformerProto.class);
            if (version != 0) {
                throw new IllegalArgumentException("Unknown version " + version + " expected {0}");
            }
            int numKeys = proto.getCountMapKeysCount();
            int numValues = proto.getCountMapValuesCount();
            // Keys and values are stored as parallel lists; reject a malformed message
            // up front rather than failing on an index error or silently dropping values.
            if (numKeys != numValues) {
                throw new IllegalArgumentException("Invalid protobuf, found " + numKeys
                        + " count map keys but " + numValues + " count map values");
            }
            Map<Double, MutableLong> counts = new HashMap<>();
            for (int i = 0; i < numKeys; i++) {
                counts.put(proto.getCountMapKeys(i), new MutableLong(proto.getCountMapValues(i)));
            }
            return new CountTransformer(proto.getSparseCount(), proto.getCount(), counts);
        }

        /**
         * Identity transform.
         *
         * @param d the input value
         * @return {@code d} unchanged
         */
        public double transform(double d) {
            return d;
        }

        /**
         * Two transformers are equal when their counts match and their histograms hold
         * the same keys with the same long values. The histogram values are compared
         * via {@code longValue()} rather than {@code equals}.
         *
         * @param obj the object to compare with
         * @return true iff {@code obj} is a {@code CountTransformer} with equal tallies
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            CountTransformer other = (CountTransformer) obj;
            if (this.count != other.count || this.sparseCount != other.sparseCount) {
                return false;
            }
            if (this.countMap == null || other.countMap == null) {
                // Equal only when both histograms are absent.
                return this.countMap == other.countMap;
            }
            if (this.countMap.size() != other.countMap.size()) {
                return false;
            }
            for (Map.Entry<Double, MutableLong> entry : this.countMap.entrySet()) {
                MutableLong otherValue = other.countMap.get(entry.getKey());
                if (otherValue == null || entry.getValue().longValue() != otherValue.longValue()) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Hash consistent with {@link #equals(Object)}: the histogram contributes via its
         * keys and the long values of its counters, not via the counter objects'
         * own {@code hashCode}, so two transformers that compare equal always hash alike.
         *
         * @return the hash code
         */
        @Override
        public int hashCode() {
            int mapHash = 0;
            if (this.countMap != null) {
                // Order-independent sum, mirroring how Map.hashCode combines entries.
                for (Map.Entry<Double, MutableLong> entry : this.countMap.entrySet()) {
                    mapHash += Objects.hashCode(entry.getKey()) ^ Long.hashCode(entry.getValue().longValue());
                }
            }
            return Objects.hash(this.count, this.sparseCount, mapHash);
        }

        /**
         * Serializes this transformer as a version-0 {@link TransformerProto} whose
         * payload is a {@link TestCountTransformerProto} with the counts and the
         * histogram written as parallel key/value lists.
         * <p>
         * The decompiler had emitted this method under the mangled name
         * {@code m318serialize}; the original name is restored.
         *
         * @return the protobuf representation
         */
        public TransformerProto serialize() {
            TestCountTransformerProto.Builder dataBuilder = TestCountTransformerProto.newBuilder()
                    .setCount(this.count)
                    .setSparseCount(this.sparseCount);
            for (Map.Entry<Double, MutableLong> entry : this.countMap.entrySet()) {
                dataBuilder.addCountMapKeys(entry.getKey().doubleValue());
                dataBuilder.addCountMapValues(entry.getValue().longValue());
            }
            TransformerProto.Builder protoBuilder = TransformerProto.newBuilder();
            protoBuilder.setVersion(0);
            protoBuilder.setClassName(getClass().getName());
            // NOTE(review): m309build looks like a decompiler-mangled name for build() on
            // the proto builder, which is declared outside this file - confirm there.
            protoBuilder.setSerializedData(Any.pack(dataBuilder.m309build()));
            return protoBuilder.build();
        }
    }

    /**
     * Builds a small dense dataset: ten examples, each with a value for all five
     * features {@code F0}..{@code F4} and the same {@code "UNK"} output.
     *
     * @return the freshly populated dataset
     */
    public static MutableDataset<MockOutput> generateDenseDataset() {
        MutableDataset<MockOutput> dataset = new MutableDataset<>(new MockDataSourceProvenance(), new MockOutputFactory());
        MockOutput output = new MockOutput("UNK");
        String[] featureNames = {"F0", "F1", "F2", "F3", "F4"};
        // One row per example, columns aligned with featureNames.
        double[][] rows = {
            {1.0d, 1.0d, 1.0d, 1.0d, 1.0d},
            {1.0d, 2.0d, 3.0d, 4.0d, 5.0d},
            {0.5d, 0.5d, 0.5d, 0.5d, 0.5d},
            {0.0d, 0.0d, 0.0d, 0.0d, 0.0d},
            {10.0d, 9.0d, 8.0d, 7.0d, 6.0d},
            {2.0d, 2.0d, 2.0d, 2.0d, 2.0d},
            {10.0d, 10.0d, 10.0d, 10.0d, 10.0d},
            {1.0d, 5.0d, 1.0d, 5.0d, 1.0d},
            {5.0d, 1.0d, 5.0d, 1.0d, 5.0d},
            {1.0d, 2.0d, 3.0d, 4.0d, 5.0d},
        };
        for (double[] row : rows) {
            dataset.add(new ArrayExample<>(output, featureNames, row));
        }
        return dataset;
    }

    /**
     * Builds a small sparse dataset: ten examples over the features
     * {@code F0}..{@code F4}, most of which name only a subset of the features,
     * all sharing the same {@code "UNK"} output.
     *
     * @return the freshly populated dataset
     */
    public static MutableDataset<MockOutput> generateSparseDataset() {
        MutableDataset<MockOutput> dataset = new MutableDataset<>(new MockDataSourceProvenance(), new MockOutputFactory());
        MockOutput output = new MockOutput("UNK");
        // names[i] and values[i] are the parallel feature names / values of example i.
        String[][] names = {
            {"F0", "F1", "F2", "F3", "F4"},
            {"F0", "F1", "F2", "F3"},
            {"F0"},
            {"F0", "F2"},
            {"F1"},
            {"F2"},
            {"F1", "F3"},
            {"F3"},
            {"F3"},
            {"F1", "F2", "F4"},
        };
        double[][] values = {
            {1.0d, 1.0d, 1.0d, 1.0d, 1.0d},
            {1.0d, 2.0d, 3.0d, 4.0d},
            {10.0d},
            {1.0d, 1.0d},
            {1.0d},
            {5.0d},
            {2.0d, 2.0d},
            {2.0d},
            {4.0d},
            {1.0d, 1.0d, 1.0d},
        };
        for (int row = 0; row < names.length; row++) {
            dataset.add(new ArrayExample<>(output, names[row], values[row]));
        }
        return dataset;
    }

    @Test
    /**
     * Registers feature-specific transformations under both {@code "F0"} and the regex
     * {@code "F\\d"} and checks that building the transformers is rejected with an
     * {@link IllegalArgumentException}.
     * NOTE(review): presumably the rejection is because both patterns match feature
     * {@code F0} - confirm against {@code createTransformers}.
     */
    @Test
    public void testRegex() {
        MutableDataset<MockOutput> dataset = generateDenseDataset();
        Map<String, List<Transformation>> featureTransformations = new HashMap<>();
        featureTransformations.put("F0", Collections.singletonList(SimpleTransform.add(0.0d)));
        featureTransformations.put("F\\d", Collections.singletonList(SimpleTransform.add(0.0d)));
        // The TransformationMap is built inside the checked region, as before, so an
        // IllegalArgumentException from either step satisfies the test. Any other
        // exception type makes assertThrows fail the test.
        Assertions.assertThrows(
                IllegalArgumentException.class,
                () -> dataset.createTransformers(new TransformationMap(new ArrayList<>(), featureTransformations)),
                "Should have thrown IllegalArgumentException");
    }

    /**
     * Combines one global transformation with feature-specific lists (including the
     * regex {@code "F[34]"}) and checks how many transformers each feature ends up
     * with, then verifies the resulting map survives a serialization round trip.
     */
    @Test
    public void testNumTransformers() {
        MutableDataset<MockOutput> dataset = generateDenseDataset();
        Map<String, List<Transformation>> featureTransformations = new HashMap<>();
        featureTransformations.put("F0", Collections.singletonList(SimpleTransform.add(1.0d)));
        featureTransformations.put("F[34]", Arrays.asList(SimpleTransform.add(5.0d), SimpleTransform.log()));
        featureTransformations.put("F1", Arrays.asList(SimpleTransform.add(1.0d), SimpleTransform.mul(5.0d), SimpleTransform.exp()));
        TransformationMap transformationMap = new TransformationMap(Collections.singletonList(new LinearScalingTransformation()), featureTransformations);
        TransformerMap transformerMap = dataset.createTransformers(transformationMap);

        for (Map.Entry<String, List<Transformer>> entry : transformerMap.entrySet()) {
            String featureName = entry.getKey();
            int numTransformers = entry.getValue().size();
            // Expected size = feature-specific transformations + the one global scaling.
            switch (featureName) {
                case "F0":
                    Assertions.assertEquals(2, numTransformers);
                    break;
                case "F1":
                    Assertions.assertEquals(4, numTransformers);
                    break;
                case "F2":
                    Assertions.assertEquals(1, numTransformers);
                    break;
                case "F3":
                case "F4":
                    Assertions.assertEquals(3, numTransformers);
                    break;
                default:
                    Assertions.fail("Unknown feature named " + featureName);
                    break;
            }
        }

        // The deserialized copy must be an equal but distinct object.
        TransformerMap roundTripped = TransformerMap.deserialize(transformerMap.serialize());
        Assertions.assertEquals(transformerMap, roundTripped);
        Assertions.assertNotSame(transformerMap, roundTripped);
    }

    /**
     * Fits a single global {@link CountTransformation} on the sparse dataset (passing
     * {@code true} as the second argument of {@code createTransformers}) and checks the
     * per-feature sparse and explicit observation counts, then verifies the resulting
     * map survives a serialization round trip.
     */
    @Test
    public void testSparseObservations() {
        MutableDataset<MockOutput> dataset = generateSparseDataset();
        TransformationMap transformationMap = new TransformationMap(Collections.singletonList(new CountTransformation()), new HashMap<>());
        TransformerMap transformerMap = dataset.createTransformers(transformationMap, true);

        for (Map.Entry<String, List<Transformer>> entry : transformerMap.entrySet()) {
            String featureName = entry.getKey();
            CountTransformer counts = (CountTransformer) entry.getValue().get(0);
            // The dataset has ten examples, so sparseCount + count is 10 for each feature.
            switch (featureName) {
                case "F0":
                    Assertions.assertEquals(6, counts.sparseCount);
                    Assertions.assertEquals(4, counts.count);
                    break;
                case "F1":
                case "F2":
                case "F3":
                    Assertions.assertEquals(5, counts.sparseCount);
                    Assertions.assertEquals(5, counts.count);
                    break;
                case "F4":
                    Assertions.assertEquals(8, counts.sparseCount);
                    Assertions.assertEquals(2, counts.count);
                    break;
                default:
                    Assertions.fail("Unknown feature named " + featureName);
                    break;
            }
        }

        // The deserialized copy must be an equal but distinct object.
        TransformerMap roundTripped = TransformerMap.deserialize(transformerMap.serialize());
        Assertions.assertEquals(transformerMap, roundTripped);
        Assertions.assertNotSame(transformerMap, roundTripped);
    }
}
