package org.tribuo.util.onnx;

import ai.onnx.proto.OnnxMl;
import java.nio.FloatBuffer;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/* loaded from: input_file:org/tribuo/util/onnx/ONNXContext.class */
public final class ONNXContext {

    /** Next numeric suffix to hand out for each base name; used by {@link #generateUniqueName(String)}. */
    private final Map<String, Long> nameMap = new HashMap<>();

    /** Accumulates every node, graph input/output and initializer added through this context. */
    private final OnnxMl.GraphProto.Builder protoBuilder = OnnxMl.GraphProto.newBuilder();

    /**
     * Appends a node for {@code op} to the graph, wired to {@code inputs}, with one output per entry of {@code outputs}.
     * <p>
     * The output names actually written into the node are made unique via {@link #generateUniqueName(String)}.
     * Each returned {@link ONNXNode} is constructed with the requested base name and its output index.
     * This method performs no synchronization.
     *
     * @param op         the operator to instantiate.
     * @param inputs     references consumed by the node; every one must belong to this context.
     * @param outputs    base names for the node's outputs.
     * @param attributes operator attributes, passed through to {@link ONNXOperator}'s build method.
     * @param <T>        the reference type of the inputs.
     * @return one {@link ONNXNode} per requested output, in the same order as {@code outputs}.
     * @throws IllegalArgumentException if any input belongs to a different {@code ONNXContext}.
     */
    public <T extends ONNXRef<?>> List<ONNXNode> operation(ONNXOperator op, List<T> inputs, List<String> outputs, Map<String, Object> attributes) {
        // Validate before generating names so a rejected call does not consume name suffixes.
        if (!inputs.stream().allMatch(ref -> ref.context == this)) {
            throw new IllegalArgumentException("All input nodes must belong to this ONNXContext");
        }
        String[] inputNames = inputs.stream().map(ONNXRef::getReference).toArray(String[]::new);
        // Suffixing keeps repeated base names (e.g. several "add" outputs) from colliding in the graph.
        String[] outputNames = outputs.stream().map(this::generateUniqueName).toArray(String[]::new);
        OnnxMl.NodeProto node = op.build(this, inputNames, outputNames, attributes);
        protoBuilder.addNode(node);
        return IntStream.range(0, outputs.size())
                .mapToObj(i -> new ONNXNode(this, node, outputs.get(i), i))
                .collect(Collectors.toList());
    }

    /**
     * Appends a node for {@code op} that is expected to produce exactly one output.
     *
     * @param op         the operator to instantiate.
     * @param inputs     references consumed by the node; every one must belong to this context.
     * @param outputName base name for the single output.
     * @param attributes operator attributes, passed through to {@link ONNXOperator}'s build method.
     * @param <T>        the reference type of the inputs.
     * @return the node's single output.
     * @throws IllegalArgumentException if any input belongs to a different {@code ONNXContext}.
     * @throws IllegalStateException    if the built node reports more than one output. The node has already been
     *                                  added to the graph when this is thrown.
     */
    public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperator op, List<T> inputs, String outputName, Map<String, Object> attributes) {
        ONNXNode result = operation(op, inputs, Collections.singletonList(outputName), attributes).get(0);
        int outputCount = ((OnnxMl.NodeProto) result.backRef).getOutputList().size();
        if (outputCount > 1) {
            throw new IllegalStateException("Requested a single output from operation " + op.getOpName() + " which produced " + outputCount + " outputs");
        }
        return result;
    }

    /**
     * Appends a single-output node for {@code op} with no attributes.
     *
     * @param op         the operator to instantiate.
     * @param inputs     references consumed by the node; every one must belong to this context.
     * @param outputName base name for the single output.
     * @param <T>        the reference type of the inputs.
     * @return the node's single output.
     */
    public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperator op, List<T> inputs, String outputName) {
        return operation(op, inputs, outputName, Collections.emptyMap());
    }

    /**
     * Adds an {@code IDENTITY} node that copies {@code input}'s value into {@code output}'s reference.
     *
     * @param input  the source reference.
     * @param output the destination reference; it is returned unchanged.
     * @param <LHS>  the destination reference type.
     * @param <RHS>  the source reference type.
     * @return {@code output}.
     * @throws IllegalArgumentException if either reference belongs to a different {@code ONNXContext}.
     */
    public <LHS extends ONNXRef<?>, RHS extends ONNXRef<?>> LHS assignTo(RHS input, LHS output) {
        if (input.context != this || output.context != this) {
            throw new IllegalArgumentException("both input and output must both belong to this ONNXContext");
        }
        protoBuilder.addNode(ONNXOperators.IDENTITY.build(this, input.getReference(), output.getReference()));
        return output;
    }

    /**
     * Declares a FLOAT tensor graph input of shape {@code [-1, featureDimension]}, where the first dimension is
     * named {@code "batch"}.
     * <p>
     * The name is used verbatim (it is not passed through {@link #generateUniqueName(String)}).
     *
     * @param name             the graph input name.
     * @param featureDimension size of the second dimension.
     * @return a placeholder referring to the new input.
     */
    public ONNXPlaceholder floatInput(String name, int featureDimension) {
        OnnxMl.ValueInfoProto info = batchedFloatValueInfo(name, featureDimension);
        protoBuilder.addInput(info);
        return new ONNXPlaceholder(this, info, name);
    }

    /**
     * Declares a FLOAT graph input named {@code "input"}; see {@link #floatInput(String, int)}.
     *
     * @param featureDimension size of the second dimension.
     * @return a placeholder referring to the new input.
     */
    public ONNXPlaceholder floatInput(int featureDimension) {
        return floatInput("input", featureDimension);
    }

    /**
     * Declares a FLOAT tensor graph output of shape {@code [-1, featureDimension]}, where the first dimension is
     * named {@code "batch"}.
     * <p>
     * The name is used verbatim (it is not passed through {@link #generateUniqueName(String)}).
     *
     * @param name             the graph output name.
     * @param featureDimension size of the second dimension.
     * @return a placeholder referring to the new output.
     */
    public ONNXPlaceholder floatOutput(String name, int featureDimension) {
        OnnxMl.ValueInfoProto info = batchedFloatValueInfo(name, featureDimension);
        protoBuilder.addOutput(info);
        return new ONNXPlaceholder(this, info, name);
    }

    /**
     * Declares a FLOAT graph output named {@code "output"}; see {@link #floatOutput(String, int)}.
     *
     * @param featureDimension size of the second dimension.
     * @return a placeholder referring to the new output.
     */
    public ONNXPlaceholder floatOutput(int featureDimension) {
        return floatOutput("output", featureDimension);
    }

    /**
     * Adds a FLOAT initializer built by {@code ONNXUtils.floatTensorBuilder}.
     *
     * @param baseName base name for the initializer.
     * @param dims     tensor dimensions.
     * @param fill     callback handed a {@link FloatBuffer}; presumably it writes the tensor contents (confirm in
     *                 {@code ONNXUtils}).
     * @return a reference to the new initializer.
     */
    public ONNXInitializer floatTensor(String baseName, List<Integer> dims, Consumer<FloatBuffer> fill) {
        return addInitializer(ONNXUtils.floatTensorBuilder(this, baseName, dims, fill), baseName);
    }

    /**
     * Adds an initializer holding {@code values}, built by {@code ONNXUtils.arrayBuilder}.
     *
     * @param baseName base name for the initializer.
     * @param values   the tensor contents.
     * @return a reference to the new initializer.
     */
    public ONNXInitializer array(String baseName, long[] values) {
        return addInitializer(ONNXUtils.arrayBuilder(this, baseName, values), baseName);
    }

    /**
     * Adds an initializer holding {@code values}, built by {@code ONNXUtils.arrayBuilder}.
     *
     * @param baseName base name for the initializer.
     * @param values   the tensor contents.
     * @return a reference to the new initializer.
     */
    public ONNXInitializer array(String baseName, int[] values) {
        return addInitializer(ONNXUtils.arrayBuilder(this, baseName, values), baseName);
    }

    /**
     * Adds an initializer holding {@code values}, built by {@code ONNXUtils.arrayBuilder}.
     *
     * @param baseName base name for the initializer.
     * @param values   the tensor contents.
     * @return a reference to the new initializer.
     */
    public ONNXInitializer array(String baseName, float[] values) {
        return addInitializer(ONNXUtils.arrayBuilder(this, baseName, values), baseName);
    }

    /**
     * Adds an initializer holding {@code values}, built by {@code ONNXUtils.arrayBuilder}.
     *
     * @param baseName base name for the initializer.
     * @param values   the tensor contents.
     * @param flag     forwarded to {@code ONNXUtils.arrayBuilder}. NOTE(review): its meaning is not visible here;
     *                 it presumably selects downcasting to FLOAT. Confirm in {@code ONNXUtils}.
     * @return a reference to the new initializer.
     */
    public ONNXInitializer array(String baseName, double[] values, boolean flag) {
        return addInitializer(ONNXUtils.arrayBuilder(this, baseName, values, flag), baseName);
    }

    /**
     * Equivalent to {@code array(baseName, values, true)}.
     *
     * @param baseName base name for the initializer.
     * @param values   the tensor contents.
     * @return a reference to the new initializer.
     */
    public ONNXInitializer array(String baseName, double[] values) {
        return array(baseName, values, true);
    }

    /**
     * Adds a single-element FLOAT initializer.
     * <p>
     * The tensor name is made unique via {@link #generateUniqueName(String)}. The returned
     * {@link ONNXInitializer} is constructed with the base name.
     *
     * @param baseName base name for the constant.
     * @param value    the constant value.
     * @return a reference to the new initializer.
     */
    public ONNXInitializer constant(String baseName, float value) {
        OnnxMl.TensorProto tensor = OnnxMl.TensorProto.newBuilder()
                .setName(generateUniqueName(baseName))
                .setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber())
                .addFloatData(value)
                .build();
        return addInitializer(tensor, baseName);
    }

    /**
     * Adds a single-element INT64 initializer.
     * <p>
     * The tensor name is made unique via {@link #generateUniqueName(String)}. The returned
     * {@link ONNXInitializer} is constructed with the base name.
     *
     * @param baseName base name for the constant.
     * @param value    the constant value.
     * @return a reference to the new initializer.
     */
    public ONNXInitializer constant(String baseName, long value) {
        OnnxMl.TensorProto tensor = OnnxMl.TensorProto.newBuilder()
                .setName(generateUniqueName(baseName))
                .setDataType(OnnxMl.TensorProto.DataType.INT64.getNumber())
                .addInt64Data(value)
                .build();
        return addInitializer(tensor, baseName);
    }

    /**
     * Returns {@code baseName + "_" + n}, where {@code n} counts previous calls with the same base name in this
     * context, starting at 0.
     * <p>
     * The decompiler reports that this was originally package-private. It is kept public so existing callers of
     * this build continue to compile.
     *
     * @param baseName the name to suffix.
     * @return a name that has not been returned before for this base name.
     */
    public String generateUniqueName(String baseName) {
        long index = nameMap.getOrDefault(baseName, 0L);
        nameMap.put(baseName, index + 1);
        return baseName + "_" + index;
    }

    /**
     * Sets the name of the graph being built.
     *
     * @param name the graph name.
     */
    public void setName(String name) {
        protoBuilder.setName(name);
    }

    /**
     * Builds a {@link OnnxMl.GraphProto} from everything added to this context so far.
     *
     * @return the graph.
     */
    public OnnxMl.GraphProto buildGraph() {
        return protoBuilder.build();
    }

    /** Registers {@code tensor} as a graph initializer and wraps it in an {@link ONNXInitializer}. */
    private ONNXInitializer addInitializer(OnnxMl.TensorProto tensor, String baseName) {
        protoBuilder.addInitializer(tensor);
        return new ONNXInitializer(this, tensor, baseName);
    }

    /** Builds the FLOAT {@code [-1 ("batch"), featureDimension]} value info shared by graph inputs and outputs. */
    private static OnnxMl.ValueInfoProto batchedFloatValueInfo(String name, int featureDimension) {
        return OnnxMl.ValueInfoProto.newBuilder()
                .setType(ONNXUtils.buildTensorTypeNode(
                        new ONNXShape(new long[]{-1, featureDimension}, new String[]{"batch", null}),
                        OnnxMl.TensorProto.DataType.FLOAT))
                .setName(name)
                .build();
    }
}
