package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.ObjectDetectionTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * An {@link ObjectDetectionTranslator} that converts the output of a YOLO detection model (class
 * indices, scores and corner-format boxes) into {@link DetectedObjects}.
 */
public class YoloTranslator extends ObjectDetectionTranslator {

    /** Builder for {@link YoloTranslator}. */
    public static class Builder extends ObjectDetectionTranslator.ObjectDetectionBuilder<Builder> {

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }

        /**
         * Runs {@code validate()} on the current configuration and builds the translator.
         *
         * @return a new {@link YoloTranslator} using this builder's configuration
         */
        public YoloTranslator build() {
            validate();
            return new YoloTranslator(this);
        }
    }

    /**
     * Creates a translator from the given builder.
     *
     * @param builder the builder holding the pre/post-processing configuration
     */
    public YoloTranslator(Builder builder) {
        super(builder);
    }

    /**
     * Converts the raw model output into {@link DetectedObjects}.
     *
     * <p>{@code list} is read as: index 0 = class indices, index 1 = scores, index 2 = boxes
     * whose columns 0..3 are {@code xMin, yMin, xMax, yMax}. The clip/divide below implies the
     * box coordinates are pixels of a {@code width x height} input (assumption: confirm against
     * the model). A detection is dropped when its class index is negative or its score is below
     * {@code threshold}. Each kept box is clipped to the input and returned as a {@link
     * Rectangle} whose {@code x, y, width, height} are ratios of the input size.
     *
     * @param ctx the translator context (not used by this method)
     * @param list the model output
     * @return the detections that passed the filters
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        int[] classIndices = list.get(0).toType(DataType.INT32, true).flatten().toIntArray();
        double[] probs = list.get(1).toType(DataType.FLOAT64, true).flatten().toDoubleArray();
        NDArray boundingBoxes = list.get(2);
        int detected = probs.length;

        // Clip every corner into the input and normalize it, so all values below are ratios.
        NDArray xMin = boundingBoxes.get(":, 0").clip(0, width).div(width);
        NDArray yMin = boundingBoxes.get(":, 1").clip(0, height).div(height);
        NDArray xMax = boundingBoxes.get(":, 2").clip(0, width).div(width);
        NDArray yMax = boundingBoxes.get(":, 3").clip(0, height).div(height);

        float[] boxX = xMin.toFloatArray();
        float[] boxY = yMin.toFloatArray();
        float[] boxWidth = xMax.sub(xMin).toFloatArray();
        float[] boxHeight = yMax.sub(yMin).toFloatArray();

        List<String> retClasses = new ArrayList<>(detected);
        List<Double> retProbs = new ArrayList<>(detected);
        List<BoundingBox> retBoxes = new ArrayList<>(detected);
        for (int i = 0; i < detected; i++) {
            if (classIndices[i] < 0 || probs[i] < threshold) {
                continue;
            }
            retClasses.add(classes.get(classIndices[i]));
            retProbs.add(probs[i]);
            // applyRatio is deliberately not applied here: the coordinates were already divided
            // by width/height above, and dividing again would shrink every box a second time.
            retBoxes.add(new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]));
        }
        return new DetectedObjects(retClasses, retProbs, retBoxes);
    }

    /**
     * Converts boxes from center format {@code [cx, cy, w, h]} to corner format {@code [x1, y1,
     * x2, y2]} along the last axis.
     *
     * <p>Only the first four entries of the last axis are read. Inputs with extra trailing
     * values per box are therefore accepted, and the result always has four entries on that
     * axis.
     *
     * @param array boxes whose last axis starts with {@code cx, cy, w, h}
     * @return {@code [cx - w/2, cy - h/2, cx + w/2, cy + h/2]} along the last axis
     */
    public static NDArray xywh2xyxy(NDArray array) {
        NDArray center = array.get("..., :2");
        NDArray halfSize = array.get("..., 2:4").div(2);
        return center.sub(halfSize).concat(center.add(halfSize), -1);
    }

    /**
     * Creates an empty builder.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from {@code arguments}.
     *
     * <p>The map is handed to both {@code configPreProcess} and {@code configPostProcess}.
     *
     * @param arguments translator configuration entries
     * @return a builder populated from {@code arguments}
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }
}
