package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.VisionLanguageInput;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.modality.nlp.NlpUtils;
import ai.djl.modality.nlp.preprocess.LowerCaseConvertor;
import ai.djl.modality.nlp.preprocess.PunctuationSeparator;
import ai.djl.modality.nlp.preprocess.TextCleaner;
import ai.djl.modality.nlp.preprocess.TextProcessor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;
import ai.djl.util.Utils;
import com.google.gson.reflect.TypeToken;
import java.io.BufferedReader;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:ai/djl/modality/cv/translator/YoloWorldTranslator.class */
public class YoloWorldTranslator implements NoBatchifyTranslator<VisionLanguageInput, DetectedObjects> {
    private static final int MAX_DETECTION = 300;
    private static final int[] AXIS_0 = {0};
    private SimpleBpeTokenizer tokenizer;
    private BaseImageTranslator<?> imageProcessor;
    private Predictor<NDList, NDList> predictor;
    private String clipModelPath;
    private float threshold;
    private float nmsThreshold;

    /* loaded from: input_file:ai/djl/modality/cv/translator/YoloWorldTranslator$Builder.class */
    /**
     * Builder that collects the settings of a {@link YoloWorldTranslator}.
     *
     * <p>Defaults: score threshold {@code 0.25}, NMS threshold {@code 0.7}, CLIP model path
     * {@code "clip.pt"}.
     */
    public static class Builder extends BaseImageTranslator.BaseBuilder<Builder> {
        // Package-private on purpose: the enclosing translator's constructor reads these directly.
        float threshold = 0.25f;
        float nmsThreshold = 0.7f;
        String clipModelPath = "clip.pt";

        /**
         * Returns this builder, typed as the concrete builder class.
         *
         * @return this builder
         */
        @Override
        public Builder self() {
            return this;
        }

        /**
         * Sets the minimum class score a candidate box must exceed to be kept.
         *
         * @param f the score threshold
         * @return this builder
         */
        public Builder optThreshold(float f) {
            this.threshold = f;
            return self();
        }

        /**
         * Sets the threshold handed to non-maximum suppression.
         *
         * @param f the NMS threshold
         * @return this builder
         */
        public Builder optNmsThreshold(float f) {
            this.nmsThreshold = f;
            return self();
        }

        /**
         * Sets the location of the CLIP text model.
         *
         * @param str the path of the CLIP model; a relative path that does not exist as given is
         *     later resolved against the model directory
         * @return this builder
         */
        public Builder optClipModelPath(String str) {
            this.clipModelPath = str;
            return self();
        }

        /**
         * Reads the post-processing options {@code threshold}, {@code nmsThreshold} and {@code
         * clipModelPath} from the argument map.
         *
         * <p>Every option that is absent from the map keeps the value currently held by this
         * builder, so values set through the {@code opt*} methods are not reset.
         *
         * @param map the translator arguments
         */
        @Override
        public void configPostProcess(Map<String, ?> map) {
            super.configPostProcess(map);
            optThreshold(ArgumentsUtil.floatValue(map, "threshold", this.threshold));
            optNmsThreshold(ArgumentsUtil.floatValue(map, "nmsThreshold", this.nmsThreshold));
            // Fall back to the current value (initially "clip.pt") rather than the literal
            // default, consistent with the two options above.
            optClipModelPath(ArgumentsUtil.stringValue(map, "clipModelPath", this.clipModelPath));
        }

        /**
         * Builds the translator from the collected settings.
         *
         * @return a new {@link YoloWorldTranslator}
         */
        public YoloWorldTranslator build() {
            return new YoloWorldTranslator(this);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:ai/djl/modality/cv/translator/YoloWorldTranslator$SimpleBpeTokenizer.class */
    /**
     * A small byte-pair-encoding tokenizer that converts candidate texts into token id rows.
     *
     * <p>The vocabulary is read from {@code vocab.json} and the merge ranks from {@code
     * merges.txt}; both files are looked up in the directory given to {@link #newInstance(Path)}.
     */
    public static final class SimpleBpeTokenizer {
        /** Rows produced by {@link #batchEncode(String[])} are zero-padded to at least this width. */
        private static final int MIN_CONTEXT_LENGTH = 77;
        /** Token sequences longer than this are truncated. */
        private static final int MAX_CONTEXT_LENGTH = 512;
        private static final Type MAP_TYPE = new TypeToken<Map<String, Integer>>() {}.getType();
        /** Marker appended to the last symbol of every word. */
        private static final String END_OF_WORD = "</w>";

        private final Map<String, Integer> vocabulary;
        private final Map<Pair<String, String>, Integer> ranks;
        private final int sot;
        private final int eot;

        /**
         * Creates a tokenizer from an already loaded vocabulary and merge table.
         *
         * @param map the vocabulary; must contain {@code <|startoftext|>} and {@code
         *     <|endoftext|>}
         * @param map2 the merge rules mapped to their rank (lower rank is merged first)
         */
        SimpleBpeTokenizer(Map<String, Integer> map, Map<Pair<String, String>, Integer> map2) {
            this.vocabulary = map;
            this.ranks = map2;
            this.sot = map.get("<|startoftext|>").intValue();
            this.eot = map.get("<|endoftext|>").intValue();
        }

        /**
         * Loads a tokenizer from {@code vocab.json} and {@code merges.txt} inside a directory.
         *
         * @param path the directory holding both files
         * @return the loaded tokenizer
         * @throws IOException if a file cannot be read or the vocabulary file is empty
         */
        static SimpleBpeTokenizer newInstance(Path path) throws IOException {
            Path vocabFile = path.resolve("vocab.json");
            Path mergesFile = path.resolve("merges.txt");

            Map<Pair<String, String>, Integer> mergeRanks = new ConcurrentHashMap<>();
            List<String> lines = Utils.readLines(mergesFile);
            int rank = 0;
            // The first line is skipped (presumably a header line); starting the index at 1 also
            // copes with a completely empty file.
            for (int i = 1; i < lines.size(); i++) {
                String[] parts = lines.get(i).split(" ");
                if (parts.length < 2) {
                    // Blank or malformed line: not a merge rule, and it must not consume a rank.
                    continue;
                }
                mergeRanks.put(new Pair<>(parts[0], parts[1]), rank++);
            }

            try (BufferedReader reader = Files.newBufferedReader(vocabFile)) {
                Map<String, Integer> vocab = JsonUtils.GSON.fromJson(reader, MAP_TYPE);
                if (vocab == null) {
                    throw new IOException("Empty vocabulary file: " + vocabFile);
                }
                return new SimpleBpeTokenizer(vocab, mergeRanks);
            }
        }

        /**
         * Encodes several texts into a rectangular, zero-padded id matrix.
         *
         * <p>Each row is truncated to {@code MAX_CONTEXT_LENGTH} ids. The matrix width is the
         * longest (truncated) row, but never less than {@code MIN_CONTEXT_LENGTH}.
         *
         * @param strArr the texts to encode
         * @return one row of token ids per input text
         */
        int[][] batchEncode(String[] strArr) {
            List<List<Integer>> encoded = new ArrayList<>(strArr.length);
            int longest = 0;
            for (String text : strArr) {
                List<Integer> ids = encode(text);
                if (ids.size() > MAX_CONTEXT_LENGTH) {
                    // NOTE(review): truncation drops the trailing end-of-text id; confirm whether
                    // the text model needs it re-inserted at the last position.
                    ids = ids.subList(0, MAX_CONTEXT_LENGTH);
                }
                // Measure after truncation so the matrix never grows past the limit.
                longest = Math.max(longest, ids.size());
                encoded.add(ids);
            }

            int[][] matrix = new int[strArr.length][Math.max(longest, MIN_CONTEXT_LENGTH)];
            for (int row = 0; row < encoded.size(); row++) {
                List<Integer> ids = encoded.get(row);
                for (int col = 0; col < ids.size(); col++) {
                    matrix[row][col] = ids.get(col);
                }
            }
            return matrix;
        }

        /**
         * Encodes one text: lower-cases it, normalizes whitespace, separates punctuation, applies
         * byte-pair merges per word and wraps the ids in start/end-of-text markers.
         *
         * <p>Sub-word symbols that are not part of the vocabulary are skipped.
         *
         * @param str the text to encode
         * @return the token ids, starting with the start-of-text id and ending with the
         *     end-of-text id; never contains {@code null}
         */
        List<Integer> encode(String str) {
            List<String> words = new ArrayList<>(Collections.singletonList(str));
            List<TextProcessor> processors =
                    Arrays.asList(
                            new LowerCaseConvertor(),
                            new TextCleaner(NlpUtils::isWhiteSpace, ' '),
                            new PunctuationSeparator());
            for (TextProcessor processor : processors) {
                words = processor.preprocess(words);
            }

            List<Integer> ids = new ArrayList<>();
            ids.add(sot);
            for (String word : words) {
                // bpe() joins the merged symbols with spaces; each symbol is its own vocabulary
                // entry, so they have to be looked up one by one.
                for (String symbol : bpe(word).split(" ")) {
                    Integer id = vocabulary.get(symbol);
                    if (id != null) {
                        ids.add(id);
                    }
                }
            }
            ids.add(eot);
            return ids;
        }

        /**
         * Applies the merge rules to a single word.
         *
         * @param str the word to split into sub-word symbols
         * @return the resulting symbols joined by single spaces; the last symbol carries the
         *     {@code </w>} marker. An empty input is returned unchanged.
         */
        private String bpe(String str) {
            if (str.isEmpty()) {
                return str;
            }
            List<String> symbols = new ArrayList<>(str.length());
            for (char c : str.toCharArray()) {
                symbols.add(String.valueOf(c));
            }
            int lastIndex = symbols.size() - 1;
            symbols.set(lastIndex, symbols.get(lastIndex) + END_OF_WORD);

            Set<Pair<String, String>> pairs = getPairs(symbols);
            if (pairs.isEmpty()) {
                // Single-character word: nothing to merge.
                return str + END_OF_WORD;
            }

            while (true) {
                // Pick the adjacent pair with the lowest rank; unknown pairs sort last.
                Pair<String, String> best =
                        Collections.min(
                                pairs,
                                (a, b) ->
                                        Integer.compare(
                                                ranks.getOrDefault(a, Integer.MAX_VALUE),
                                                ranks.getOrDefault(b, Integer.MAX_VALUE)));
                if (!ranks.containsKey(best)) {
                    // No remaining pair is a known merge rule.
                    break;
                }
                String first = best.getKey();
                String second = best.getValue();

                List<String> merged = new ArrayList<>(symbols.size());
                int i = 0;
                while (i < symbols.size()) {
                    int offset = symbols.subList(i, symbols.size()).indexOf(first);
                    if (offset < 0) {
                        merged.addAll(symbols.subList(i, symbols.size()));
                        break;
                    }
                    int pos = i + offset;
                    merged.addAll(symbols.subList(i, pos));
                    if (pos < symbols.size() - 1 && symbols.get(pos + 1).equals(second)) {
                        merged.add(first + second);
                        i = pos + 2;
                    } else {
                        merged.add(symbols.get(pos));
                        i = pos + 1;
                    }
                }
                symbols = merged;
                if (symbols.size() == 1) {
                    break;
                }
                pairs = getPairs(symbols);
            }
            return String.join(" ", symbols);
        }

        /**
         * Collects every pair of adjacent symbols.
         *
         * @param list the symbols of one word
         * @return the set of adjacent pairs; empty when there are fewer than two symbols
         */
        private Set<Pair<String, String>> getPairs(List<String> list) {
            if (list.size() < 2) {
                return Collections.emptySet();
            }
            Set<Pair<String, String>> pairs = new HashSet<>();
            for (int i = 1; i < list.size(); i++) {
                pairs.add(new Pair<>(list.get(i - 1), list.get(i)));
            }
            return pairs;
        }
    }

    /**
     * Creates the translator from the settings collected by a {@link Builder}.
     *
     * @param builder the builder holding thresholds, CLIP model path and image pre-processing
     *     configuration
     */
    YoloWorldTranslator(Builder builder) {
        clipModelPath = builder.clipModelPath;
        nmsThreshold = builder.nmsThreshold;
        threshold = builder.threshold;
        // Image pre-processing is delegated to a translator configured from the same builder.
        imageProcessor = new BaseImagePreProcessor(builder);
    }

    /**
     * Loads the CLIP model and the tokenizer needed to encode the text candidates.
     *
     * <p>The configured CLIP path is used as given when it is absolute or exists; otherwise it is
     * resolved against the directory of the main model. The tokenizer files are read from the
     * main model directory.
     *
     * @param translatorContext the context providing the main model and the {@link NDManager}
     * @throws Exception if the CLIP model file is missing or any resource fails to load
     */
    @Override
    public void prepare(TranslatorContext translatorContext) throws Exception {
        Model model = translatorContext.getModel();
        Path modelPath = model.getModelPath();

        Path clipPath = Paths.get(this.clipModelPath);
        if (!clipPath.isAbsolute() && Files.notExists(clipPath)) {
            clipPath = modelPath.resolve(this.clipModelPath);
        }
        if (!Files.exists(clipPath)) {
            throw new IOException("clip model not found: " + this.clipModelPath);
        }

        NDManager manager = translatorContext.getNDManager();
        Model clip = manager.getEngine().newModel("clip", manager.getDevice());
        boolean attached = false;
        try {
            clip.load(clipPath);
            Predictor<NDList, NDList> clipPredictor = clip.newPredictor(new NoopTranslator(null));
            // Register both with the main model's manager under unique keys.
            NDManager owner = model.getNDManager();
            owner.attachInternal(UUID.randomUUID().toString(), clipPredictor);
            owner.attachInternal(UUID.randomUUID().toString(), clip);
            this.predictor = clipPredictor;
            attached = true;
        } finally {
            if (!attached) {
                // Nothing owns the CLIP model yet, so release it here instead of leaking it.
                clip.close();
            }
        }

        this.tokenizer = SimpleBpeTokenizer.newInstance(modelPath);
    }

    /**
     * Builds the model input from the candidate texts and the image.
     *
     * <p>The candidates are tokenized and run through the CLIP predictor; the image is handed to
     * the image pre-processor and given a leading dimension. The candidates are stored in the
     * context under {@code "candidates"} for use during post-processing.
     *
     * @param translatorContext the translator context
     * @param visionLanguageInput the image together with the candidate texts
     * @return a list holding the CLIP output followed by the pre-processed image
     * @throws TranslateException if no candidates are supplied or encoding fails
     */
    @Override
    public NDList processInput(TranslatorContext translatorContext, VisionLanguageInput visionLanguageInput) throws TranslateException {
        NDManager manager = translatorContext.getNDManager();
        String[] labels = visionLanguageInput.getCandidates();
        boolean hasLabels = labels != null && labels.length > 0;
        if (!hasLabels) {
            throw new TranslateException("Missing candidates in input");
        }

        // Text branch: token ids -> CLIP predictor -> first output.
        int[][] tokenIds = this.tokenizer.batchEncode(labels);
        NDList clipInput = new NDList(manager.create(tokenIds));
        NDArray textOutput = this.predictor.predict(clipInput).get(0);

        // Image branch: reuse the image pre-processor, then add a leading axis.
        NDList imageOutput = this.imageProcessor.processInput(translatorContext, visionLanguageInput.getImage());
        NDArray image = imageOutput.get(0).expandDims(0);

        translatorContext.setAttachment("candidates", labels);
        return new NDList(textOutput, image);
    }

    /**
     * Converts the raw model output into detected objects.
     *
     * <p>The first output is squeezed on axis 0 and read as four box rows followed by one score
     * row per candidate. Columns whose best candidate score exceeds the threshold are kept,
     * their boxes are converted with {@code YoloTranslator.xywh2xyxy}, filtered by non-maximum
     * suppression, limited to {@code MAX_DETECTION} entries and scaled by the {@code "width"} and
     * {@code "height"} context attachments.
     *
     * @param translatorContext the context carrying the {@code "candidates"}, {@code "width"} and
     *     {@code "height"} attachments
     * @param nDList the model output
     * @return the detected objects with class name, score and normalized rectangle
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext translatorContext, NDList nDList) {
        List<String> classes = Arrays.asList((String[]) translatorContext.getAttachment("candidates"));
        int width = (Integer) translatorContext.getAttachment("width");
        int height = (Integer) translatorContext.getAttachment("height");

        NDArray pred = nDList.get(0).squeeze(0);
        int numClasses = classes.size();
        int boxIndex = numClasses + 4;
        // Keep only the columns whose best class score is above the threshold.
        NDArray keep = pred.get("4:" + boxIndex).max(AXIS_0).gt(this.threshold);
        pred = pred.transpose();
        NDArray corners = YoloTranslator.xywh2xyxy(pred.get("..., :4"));
        pred = corners.concat(pred.get("..., 4:"), -1).get(keep);

        NDList split = pred.split(new long[] {4, boxIndex}, 1);
        NDArray box = split.get(0);
        int numBox = Math.toIntExact(box.getShape().get(0));
        if (numBox == 0) {
            return new DetectedObjects(new ArrayList<>(), new ArrayList<>(), new ArrayList<>());
        }

        NDArray classScores = split.get(1);
        float[] coords = box.toFloatArray();
        // Flattened row by row: numBox rows of numClasses scores each.
        float[] confidences = classScores.toFloatArray();
        long[] ids = classScores.argMax(1).toLongArray();

        List<Rectangle> boxes = new ArrayList<>(numBox);
        List<Double> scores = new ArrayList<>(numBox);
        for (int i = 0; i < numBox; i++) {
            float x = coords[i * 4];
            float y = coords[i * 4 + 1];
            float w = coords[i * 4 + 2] - x;
            float h = coords[i * 4 + 3] - y;
            boxes.add(new Rectangle(x, y, w, h));
            // Score of the best class of box i, addressed inside the flattened score matrix.
            scores.add((double) confidences[i * numClasses + (int) ids[i]]);
        }

        List<Integer> nms = Rectangle.nms(boxes, scores, this.nmsThreshold);
        if (nms.size() > MAX_DETECTION) {
            nms = nms.subList(0, MAX_DETECTION);
        }

        List<String> names = new ArrayList<>(nms.size());
        List<Double> probabilities = new ArrayList<>(nms.size());
        List<Rectangle> normalized = new ArrayList<>(nms.size());
        for (int index : nms) {
            names.add(classes.get((int) ids[index]));
            probabilities.add(scores.get(index));
            Rectangle rect = boxes.get(index);
            normalized.add(
                    new Rectangle(
                            rect.getX() / width,
                            rect.getY() / height,
                            rect.getWidth() / width,
                            rect.getHeight() / height));
        }
        // The copy lets the compiler infer the element type DetectedObjects expects.
        return new DetectedObjects(names, probabilities, new ArrayList<>(normalized));
    }

    /**
     * Creates a builder holding the default settings.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        Builder fresh = new Builder();
        return fresh;
    }

    /**
     * Creates a builder and configures it from an argument map.
     *
     * <p>The pre-processing arguments are applied first, then the post-processing arguments.
     *
     * @param map the translator arguments
     * @return the configured {@link Builder}
     */
    public static Builder builder(Map<String, ?> map) {
        Builder configured = new Builder();
        configured.configPreProcess(map);
        configured.configPostProcess(map);
        return configured;
    }
}
