package org.tribuo.sequence;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Output;
import org.tribuo.hash.HashedFeatureMap;
import org.tribuo.hash.Hasher;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;

/* loaded from: input_file:org/tribuo/sequence/HashingSequenceTrainer.class */
/**
 * A {@link SequenceTrainer} that replaces the feature map of the training data with a
 * hashed one and then delegates the actual training to a wrapped inner trainer.
 *
 * <p>The hashed feature map is produced by
 * {@code HashedFeatureMap.generateHashedFeatureMap} using the configured {@link Hasher}.
 * After training, the returned model must still carry a {@link HashedFeatureMap};
 * otherwise {@link #train} throws an {@link IllegalStateException}.
 *
 * @param <T> the output type handled by the inner trainer
 */
public final class HashingSequenceTrainer<T extends Output<T>> implements SequenceTrainer<T> {
    private static final Logger logger = Logger.getLogger(HashingSequenceTrainer.class.getName());

    @Config(mandatory = true, description = "Trainer to use.")
    private SequenceTrainer<T> innerTrainer;

    @Config(mandatory = true, description = "Feature hashing function to use.")
    private Hasher hasher;

    /**
     * Provenance record for {@link HashingSequenceTrainer}; all behaviour is inherited
     * from {@link SkeletalTrainerProvenance}.
     */
    public static class HashingSequenceTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Builds the provenance from a live trainer instance.
         *
         * @param hashingSequenceTrainer the trainer to describe
         * @param <T> the trainer's output type
         */
        <T extends Output<T>> HashingSequenceTrainerProvenance(HashingSequenceTrainer<T> hashingSequenceTrainer) {
            super(hashingSequenceTrainer);
        }

        /**
         * Rebuilds the provenance from a map of named provenance entries.
         *
         * @param map the provenance entries, passed through {@code extractProvenanceInfo}
         */
        public HashingSequenceTrainerProvenance(Map<String, Provenance> map) {
            super(extractProvenanceInfo(map));
        }
    }

    /**
     * Leaves both {@code @Config} fields unset.
     * NOTE(review): presumably used by the configuration system, which populates the
     * annotated fields afterwards -- confirm.
     */
    private HashingSequenceTrainer() {
    }

    /**
     * Creates a hashing trainer wrapping the supplied trainer.
     *
     * @param sequenceTrainer the trainer that performs the actual training; must not be null
     * @param hasher the feature hashing function; must not be null
     * @throws NullPointerException if either argument is null
     */
    public HashingSequenceTrainer(SequenceTrainer<T> sequenceTrainer, Hasher hasher) {
        // Fail fast here instead of with an anonymous NPE later inside train()/toString().
        this.innerTrainer = Objects.requireNonNull(sequenceTrainer, "sequenceTrainer");
        this.hasher = Objects.requireNonNull(hasher, "hasher");
    }

    /**
     * Hashes the dataset's feature map, trains the inner trainer on the re-mapped dataset
     * and returns the resulting model. The feature counts before and after hashing are
     * logged at {@code INFO} level.
     *
     * @param sequenceDataset the training data
     * @param map provenance information passed through unchanged to the inner trainer
     * @return the model produced by the inner trainer
     * @throws IllegalStateException if the returned model's feature map is not a
     *         {@link HashedFeatureMap}
     */
    @Override // org.tribuo.sequence.SequenceTrainer
    public SequenceModel<T> train(SequenceDataset<T> sequenceDataset, Map<String, Provenance> map) {
        logger.log(Level.INFO, "Before hashing, had " + sequenceDataset.getFeatureIDMap().size() + " features.");
        // Parameterised (not raw) so the hand-off to the inner trainer stays type-checked.
        ImmutableSequenceDataset<T> hashedDataset = ImmutableSequenceDataset.changeFeatureMap(
                sequenceDataset,
                HashedFeatureMap.generateHashedFeatureMap(sequenceDataset.getFeatureIDMap(), this.hasher));
        logger.log(Level.INFO, "After hashing, had " + hashedDataset.getFeatureIDMap().size() + " features.");
        SequenceModel<T> model = this.innerTrainer.train(hashedDataset, map);
        // The inner trainer must hand back a model that still uses a hashed feature map.
        if (!(model.featureIDMap instanceof HashedFeatureMap)) {
            throw new IllegalStateException("Trainer " + this.innerTrainer.getClass().getName() + " does not support hashing.");
        }
        return model;
    }

    /**
     * Returns the invocation count reported by the inner trainer.
     *
     * @return the inner trainer's invocation count
     */
    @Override // org.tribuo.sequence.SequenceTrainer
    public int getInvocationCount() {
        return this.innerTrainer.getInvocationCount();
    }

    /**
     * Describes this trainer as {@code HashingSequenceTrainer(trainer=...,hasher=...)}.
     *
     * @return the string description
     */
    @Override
    public String toString() {
        // Plain concatenation renders an unset field as "null" instead of throwing.
        return "HashingSequenceTrainer(trainer=" + this.innerTrainer + ",hasher=" + this.hasher + ")";
    }

    /**
     * Creates a provenance object describing this trainer.
     *
     * <p>NOTE(review): the name is decompiler-mangled (renamed from {@code getProvenance},
     * merged with its bridge method). Confirm call sites before restoring the original name.
     *
     * @return a new {@link HashingSequenceTrainerProvenance} for this instance
     */
    public TrainerProvenance m2511getProvenance() {
        return new HashingSequenceTrainerProvenance(this);
    }
}
