package org.apache.tez.runtime.library.cartesianproduct;

import com.google.common.math.LongMath;
import com.google.common.primitives.Ints;
import com.google.protobuf.ByteString;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.TaskLocationHint;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexLocationHint;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.runtime.api.TaskAttemptIdentifier;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
import org.apache.tez.runtime.library.utils.Grouper;
import org.roaringbitmap.RoaringBitmap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/apache/tez/runtime/library/cartesianproduct/FairCartesianProductVertexManager.class */
public class FairCartesianProductVertexManager extends CartesianProductVertexManagerReal {
    private static final Logger LOG = LoggerFactory.getLogger(FairCartesianProductVertexManager.class);
    // Configuration proto received in initialize(); also used as the template for the
    // per-edge payloads built when the vertex is reconfigured.
    private CartesianProductUserPayload.CartesianProductConfigProto config;
    // Source names as listed in the config; a source's index in this list is its position.
    private List<String> sourceList;
    // Cartesian product sources keyed by source name.
    private Map<String, Source> sourcesByName;
    // Source vertices connected through a cartesian product edge, keyed by vertex name.
    private Map<String, SrcVertex> srcVerticesByName;
    // When false, each source gets as many chunks as its largest vertex has tasks.
    private boolean enableGrouping;
    // Upper bound for the parallelism estimate and for the chunk count of a single source.
    private int maxParallelism;
    // Multiplied with a vertex's task count to obtain the number of items handed to the grouper.
    private int numPartitions;
    // Divisor used to turn the estimated total number of operations into a parallelism.
    private long minOpsPerWorker;
    // Record count a source vertex must report before estimation proceeds without all VM events.
    private long minNumRecordForEstimation;
    // Set once reconfigureVertex() has been issued (with zero or more tasks).
    private boolean vertexReconfigured;
    // Set by onVertexStarted().
    private boolean vertexStarted;
    // Latched by tryStartSchedule(): reconfigured, started and no broadcast source pending.
    private boolean vertexStartSchedule;
    // Cartesian product source vertices that have not yet reported CONFIGURED.
    private int numCPSrcNotInConfigureState;
    // Other source vertices that have not yet reported RUNNING.
    private int numBroadcastSrcNotInRunningState;
    // Completed source task attempts waiting to be examined by scheduleTasksDependOnCompletion().
    private Queue<TaskAttemptIdentifier> completedSrcTaskToProcess;
    // Ids of the tasks of this vertex that were already handed to scheduleTasks().
    private RoaringBitmap scheduledTasks;
    // Number of tasks of this vertex after reconfiguration; 0 means nothing will be scheduled.
    private int parallelism;
    // Chunk count of every source, indexed by source position.
    private int[] numChunksPerSrc;
    // Shared grouper; it is re-initialised before each use, so it carries no state between calls.
    private Grouper grouper;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/tez/runtime/library/cartesianproduct/FairCartesianProductVertexManager$Source.class */
    /**
     * One operand of the cartesian product. A source is backed by one or more source vertices
     * and aggregates their task counts, record estimates and chunk completion state.
     */
    public static class Source {
        /** Source vertices feeding this operand. */
        List<SrcVertex> srcVertices = new ArrayList<>();
        /** Index of this source in the configured source list. */
        int position;
        /** Name of this source as listed in the configuration. */
        String name;
        /** Number of chunks the output of this source is split into. */
        int numChunk;
        /** Estimated number of output records, filled in during reconfiguration. */
        long numRecord;

        Source() {
        }

        /**
         * Renders the source and all of its vertices for logging, e.g.
         * {@code Source at position 0, name a, num chunk 2: {[vertex ...], [vertex ...]}}.
         *
         * @return a human readable description; a source without vertices is rendered with
         *         an empty {@code {}} list
         */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("Source at position ").append(position);
            if (name != null) {
                sb.append(", name ").append(name);
            }
            sb.append(", num chunk ").append(numChunk).append(": {");
            // Emit the separator before every element but the first, so that nothing has to
            // be trimmed afterwards (trimming corrupted the output for an empty vertex list).
            String separator = "";
            for (SrcVertex srcVertex : srcVertices) {
                sb.append(separator).append('[').append(srcVertex).append(']');
                separator = ", ";
            }
            return sb.append('}').toString();
        }

        /**
         * @return the sum of the record estimates of all vertices of this source
         */
        public long estimateNumRecord() {
            long total = 0;
            for (SrcVertex srcVertex : srcVertices) {
                total += srcVertex.estimateNumRecord();
            }
            return total;
        }

        /**
         * @param i chunk id
         * @return {@code true} when every vertex of this source reports chunk {@code i} as completed
         */
        public boolean isChunkCompleted(int i) {
            for (SrcVertex srcVertex : srcVertices) {
                if (!srcVertex.isChunkCompleted(i)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * @return the total number of tasks over all vertices of this source
         */
        public int getNumTask() {
            int total = 0;
            for (SrcVertex srcVertex : srcVertices) {
                total += srcVertex.numTask;
            }
            return total;
        }

        /**
         * @return the vertex with the highest {@code numRecord}; the first one wins on ties,
         *         {@code null} when this source has no vertices
         */
        public SrcVertex getSrcVertexWithMostOutput() {
            SrcVertex best = null;
            for (SrcVertex candidate : srcVertices) {
                if (best == null || candidate.numRecord > best.numRecord) {
                    best = candidate;
                }
            }
            return best;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/tez/runtime/library/cartesianproduct/FairCartesianProductVertexManager$SrcVertex.class */
    /**
     * Bookkeeping for a single source vertex: which of its tasks completed, which of them
     * reported a vertex manager event and how many records those events accounted for.
     */
    public class SrcVertex {
        Source source;
        String name;
        int numTask;
        RoaringBitmap taskCompleted = new RoaringBitmap();
        RoaringBitmap taskWithVMEvent = new RoaringBitmap();
        long numRecord;

        SrcVertex() {
        }

        /** Describes the vertex together with its current record estimate. */
        @Override
        public String toString() {
            return "vertex " + name + ", "
                + numTask + " tasks, "
                + taskWithVMEvent.getCardinality() + " VMEvents, "
                + "numRecord " + numRecord + ", "
                + "estimated # output records " + estimateNumRecord();
        }

        /**
         * Scales the records reported so far up to the full task count.
         *
         * @return 0 while no task has reported, otherwise
         *         {@code numRecord * numTask / (number of reporting tasks)}
         */
        public long estimateNumRecord() {
            int reportingTasks = taskWithVMEvent.getCardinality();
            if (reportingTasks == 0) {
                return 0L;
            }
            return numRecord * numTask / reportingTasks;
        }

        /**
         * Checks whether every task contributing to the given chunk has completed.
         *
         * @param i chunk id
         * @return {@code false} as soon as one task in the chunk's task range is not completed
         */
        public boolean isChunkCompleted(int i) {
            // The enclosing manager's grouper is shared, so it is set up for this vertex first.
            grouper.init(numTask * numPartitions, source.numChunk);
            int firstTask = grouper.getFirstItemInGroup(i) / numPartitions;
            int lastTask = grouper.getLastItemInGroup(i) / numPartitions;
            int task = firstTask;
            while (task <= lastTask) {
                if (!taskCompleted.contains(task)) {
                    return false;
                }
                task++;
            }
            return true;
        }
    }

    /**
     * Creates the manager with empty bookkeeping structures; the actual configuration is
     * supplied later through {@code initialize}.
     *
     * @param vertexManagerPluginContext context passed on to the superclass
     */
    public FairCartesianProductVertexManager(VertexManagerPluginContext vertexManagerPluginContext) {
        super(vertexManagerPluginContext);

        // lookup tables and work queues
        sourcesByName = new HashMap<>();
        srcVerticesByName = new HashMap<>();
        completedSrcTaskToProcess = new LinkedList<>();
        scheduledTasks = new RoaringBitmap();
        grouper = new Grouper();

        // life-cycle flags
        vertexReconfigured = false;
        vertexStarted = false;
        vertexStartSchedule = false;

        // counters of source vertices still being waited for
        numCPSrcNotInConfigureState = 0;
        numBroadcastSrcNotInRunningState = 0;
    }

    /**
     * Reads the configuration, classifies the incoming edges and builds the source model.
     *
     * <p>Edges that are CUSTOM and managed by {@link CartesianProductEdgeManager} become
     * cartesian product source vertices (tracked until CONFIGURED); every other incoming
     * vertex is tracked until RUNNING. Finally the vertex announces that it is going to be
     * reconfigured.
     *
     * @param cartesianProductConfigProto configuration of the cartesian product
     * @throws IllegalStateException if a configured source refers to a vertex that is not
     *         connected through a cartesian product edge
     * @throws Exception declared by the overridden method
     */
    @Override // org.apache.tez.runtime.library.cartesianproduct.CartesianProductVertexManagerReal
    public void initialize(CartesianProductUserPayload.CartesianProductConfigProto cartesianProductConfigProto) throws Exception {
        this.config = cartesianProductConfigProto;
        this.maxParallelism = cartesianProductConfigProto.hasMaxParallelism()
            ? cartesianProductConfigProto.getMaxParallelism()
            : CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_MAX_PARALLELISM_DEFAULT;
        // Grouping is on unless the configuration explicitly disables it.
        this.enableGrouping = !cartesianProductConfigProto.hasEnableGrouping()
            || cartesianProductConfigProto.getEnableGrouping();
        this.minOpsPerWorker = cartesianProductConfigProto.hasMinOpsPerWorker()
            ? cartesianProductConfigProto.getMinOpsPerWorker()
            : CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_MIN_OPS_PER_WORKER_DEFAULT;
        this.sourceList = cartesianProductConfigProto.getSourcesList();
        if (cartesianProductConfigProto.hasNumPartitionsForFairCase()) {
            this.numPartitions = cartesianProductConfigProto.getNumPartitionsForFairCase();
        } else {
            // Default: the n-th root of the max parallelism, n being the number of sources.
            this.numPartitions = (int) Math.pow(this.maxParallelism, 1.0d / this.sourceList.size());
        }

        for (Map.Entry<String, EdgeProperty> entry : getContext().getInputVertexEdgeProperties().entrySet()) {
            String vertexName = entry.getKey();
            EdgeProperty edgeProperty = entry.getValue();
            boolean cartesianProductEdge =
                edgeProperty.getDataMovementType() == EdgeProperty.DataMovementType.CUSTOM
                    && edgeProperty.getEdgeManagerDescriptor().getClassName()
                        .equals(CartesianProductEdgeManager.class.getName());
            if (cartesianProductEdge) {
                SrcVertex srcVertex = new SrcVertex();
                srcVertex.name = vertexName;
                this.srcVerticesByName.put(vertexName, srcVertex);
                getContext().registerForVertexStateUpdates(vertexName, EnumSet.of(VertexState.CONFIGURED));
                this.numCPSrcNotInConfigureState++;
            } else {
                getContext().registerForVertexStateUpdates(vertexName, EnumSet.of(VertexState.RUNNING));
                this.numBroadcastSrcNotInRunningState++;
            }
        }

        Map<String, List<String>> inputVertexGroups = getContext().getInputVertexGroups();
        for (int position = 0; position < this.sourceList.size(); position++) {
            String sourceName = this.sourceList.get(position);
            Source source = new Source();
            source.position = position;
            source.name = sourceName;
            if (inputVertexGroups.containsKey(sourceName)) {
                // The source is a vertex group: attach every member vertex.
                for (String memberName : inputVertexGroups.get(sourceName)) {
                    attachSrcVertex(source, memberName);
                }
            } else {
                // The source is a single vertex carrying the same name.
                attachSrcVertex(source, sourceName);
            }
            this.sourcesByName.put(sourceName, source);
        }

        this.minNumRecordForEstimation =
            (long) Math.pow(this.minOpsPerWorker * this.maxParallelism, 1.0d / this.sourceList.size());
        this.numChunksPerSrc = new int[this.sourcesByName.size()];
        getContext().vertexReconfigurationPlanned();
    }

    /** Links a registered source vertex and its source, failing clearly when the vertex is unknown. */
    private void attachSrcVertex(Source source, String srcVertexName) {
        SrcVertex srcVertex = this.srcVerticesByName.get(srcVertexName);
        if (srcVertex == null) {
            throw new IllegalStateException("Vertex " + srcVertexName + " of cartesian product source "
                + source.name + " is not connected through a cartesian product edge");
        }
        source.srcVertices.add(srcVertex);
        srcVertex.source = source;
    }

    /**
     * Marks the vertex as started, queues the source tasks that had already completed and
     * then attempts to schedule.
     *
     * @param list already completed source task attempts, may be {@code null}
     */
    @Override // org.apache.tez.runtime.library.cartesianproduct.CartesianProductVertexManagerReal
    public synchronized void onVertexStarted(List<TaskAttemptIdentifier> list) throws Exception {
        vertexStarted = true;
        if (list != null) {
            LOG.info("OnVertexStarted with " + list.size() + " completed source task");
            for (TaskAttemptIdentifier completedAttempt : list) {
                addCompletedSrcTaskToProcess(completedAttempt);
            }
        }
        tryScheduleTasks();
    }

    /**
     * Handles a state change of a source vertex: CONFIGURED records the task count of that
     * vertex, RUNNING only decrements the counter of pending non-cartesian-product sources.
     * Either way a scheduling attempt follows.
     *
     * @param vertexStateUpdate the reported state change
     */
    @Override // org.apache.tez.runtime.library.cartesianproduct.CartesianProductVertexManagerReal
    public synchronized void onVertexStateUpdated(VertexStateUpdate vertexStateUpdate) throws IOException {
        String updatedVertex = vertexStateUpdate.getVertexName();
        VertexState newState = vertexStateUpdate.getVertexState();
        if (VertexState.CONFIGURED == newState) {
            SrcVertex srcVertex = srcVerticesByName.get(updatedVertex);
            srcVertex.numTask = getContext().getVertexNumTasks(updatedVertex);
            numCPSrcNotInConfigureState -= 1;
        } else if (VertexState.RUNNING == newState) {
            numBroadcastSrcNotInRunningState -= 1;
        }
        tryScheduleTasks();
    }

    /**
     * Queues the completed source task attempt and attempts to schedule.
     *
     * @param taskAttemptIdentifier the source task attempt that completed
     */
    @Override // org.apache.tez.runtime.library.cartesianproduct.CartesianProductVertexManagerReal
    public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier taskAttemptIdentifier) throws Exception {
        // Queue first: the scheduling pass below drains the queue when the vertex is ready.
        this.addCompletedSrcTaskToProcess(taskAttemptIdentifier);
        this.tryScheduleTasks();
    }

    /**
     * Records a completed source task and queues it for scheduling. Attempts of vertices
     * that are not cartesian product sources, and tasks already recorded as completed,
     * are ignored.
     */
    private void addCompletedSrcTaskToProcess(TaskAttemptIdentifier taskAttemptIdentifier) {
        int taskIndex = taskAttemptIdentifier.getTaskIdentifier().getIdentifier();
        String vertexName = taskAttemptIdentifier.getTaskIdentifier().getVertexIdentifier().getName();
        SrcVertex owner = srcVerticesByName.get(vertexName);
        if (owner == null) {
            return;
        }
        if (owner.taskCompleted.contains(taskIndex)) {
            return;
        }
        owner.taskCompleted.add(taskIndex);
        completedSrcTaskToProcess.add(taskAttemptIdentifier);
    }

    /**
     * Re-evaluates and stores whether scheduling may start.
     *
     * @return {@code true} when the vertex is reconfigured and started and no
     *         non-cartesian-product source is still awaited
     */
    private boolean tryStartSchedule() {
        boolean canStart = vertexReconfigured
            && vertexStarted
            && numBroadcastSrcNotInRunningState == 0;
        vertexStartSchedule = canStart;
        return canStart;
    }

    /**
     * Accumulates the record count reported by a source task. Events are ignored once the
     * vertex has been reconfigured; events of vertices that are not cartesian product
     * sources are dropped without a scheduling attempt.
     *
     * @param vertexManagerEvent the event sent by a source task
     */
    @Override // org.apache.tez.runtime.library.cartesianproduct.CartesianProductVertexManagerReal
    public synchronized void onVertexManagerEventReceived(VertexManagerEvent vertexManagerEvent) throws IOException {
        if (vertexReconfigured) {
            return;
        }
        if (vertexManagerEvent.getUserPayload() == null) {
            // Nothing to account for, but the event may still be the trigger to proceed.
            tryScheduleTasks();
            return;
        }
        TaskAttemptIdentifier producer = vertexManagerEvent.getProducerAttemptIdentifier();
        SrcVertex reporter = srcVerticesByName.get(producer.getTaskIdentifier().getVertexIdentifier().getName());
        if (reporter == null) {
            return;
        }
        ByteString serialized = ByteString.copyFrom(vertexManagerEvent.getUserPayload());
        reporter.numRecord += ShuffleUserPayloads.VertexManagerEventPayloadProto.parseFrom(serialized).getNumRecord();
        reporter.taskWithVMEvent.add(producer.getTaskIdentifier().getIdentifier());
        tryScheduleTasks();
    }

    /** Reconfigures this vertex to zero tasks and marks the reconfiguration as finished. */
    private void reconfigureWithZeroTask() {
        final int noTasks = 0;
        getContext().reconfigureVertex(noTasks, (VertexLocationHint) null, (Map) null);
        vertexReconfigured = true;
        getContext().doneReconfiguringVertex();
    }

    /**
     * Reconfigures this vertex once enough is known about the sources.
     *
     * <p>Steps: wait until every cartesian product source vertex is configured; reconfigure
     * to zero tasks when a source has no task or no estimated output; otherwise wait for
     * enough statistics, estimate the total number of operations, derive the chunk count of
     * every source, and reconfigure the vertex with updated edge payloads.
     *
     * @return {@code true} when the vertex was reconfigured by this call, {@code false} when
     *         more information is needed first
     * @throws IOException declared for callers; not raised by the visible code itself
     */
    private boolean tryReconfigure() throws IOException {
        if (this.numCPSrcNotInConfigureState > 0) {
            return false;
        }
        for (Source source : this.sourcesByName.values()) {
            if (source.getNumTask() == 0) {
                this.parallelism = 0;
                reconfigureWithZeroTask();
                return true;
            }
        }
        if (!hasEnoughStatsForEstimation()) {
            return false;
        }

        LOG.info("Start reconfiguring vertex " + getContext().getVertexName() + ", max parallelism: " + this.maxParallelism + ", min-ops-per-worker: " + this.minOpsPerWorker + ", num partition: " + this.numPartitions);
        for (Source source : this.sourcesByName.values()) {
            LOG.info(source.toString());
        }

        // Total ops = product of the estimated record counts, saturating at Long.MAX_VALUE.
        long totalOps = 1;
        for (Source source : this.sourcesByName.values()) {
            source.numRecord = source.estimateNumRecord();
            if (source.numRecord == 0) {
                LOG.info("Set parallelism to 0 because source " + source.name + " has 0 output record");
                this.parallelism = 0;
                reconfigureWithZeroTask();
                return true;
            }
            try {
                totalOps = LongMath.checkedMultiply(totalOps, source.numRecord);
            } catch (ArithmeticException e) {
                LOG.info("totalOps exceeds " + Long.MAX_VALUE + ", capping to " + Long.MAX_VALUE);
                totalOps = Long.MAX_VALUE;
            }
        }

        if (totalOps / this.minOpsPerWorker >= this.maxParallelism) {
            this.parallelism = this.maxParallelism;
        } else {
            // Ceiling division: one worker per started batch of minOpsPerWorker operations.
            this.parallelism = (int) (((totalOps + this.minOpsPerWorker) - 1) / this.minOpsPerWorker);
        }
        LOG.info("Total ops " + totalOps + ", initial parallelism " + this.parallelism);

        if (this.enableGrouping) {
            determineNumChunks(this.sourcesByName, this.parallelism);
        } else {
            for (Source source : this.sourcesByName.values()) {
                source.numChunk = source.getSrcVertexWithMostOutput().numTask;
            }
        }

        // The final parallelism is the number of chunk combinations.
        this.parallelism = 1;
        for (Source source : this.sourcesByName.values()) {
            this.parallelism *= source.numChunk;
        }
        LOG.info("After reconfigure, final parallelism " + this.parallelism);
        for (Source source : this.sourcesByName.values()) {
            LOG.info(source.toString());
        }

        for (int i = 0; i < this.numChunksPerSrc.length; i++) {
            this.numChunksPerSrc[i] = this.sourcesByName.get(this.sourceList.get(i)).numChunk;
        }
        CartesianProductUserPayload.CartesianProductConfigProto.Builder builder =
            CartesianProductUserPayload.CartesianProductConfigProto.newBuilder(this.config);
        builder.addAllNumChunks(Ints.asList(this.numChunksPerSrc));

        // Only CUSTOM edges are passed back to the framework.
        Map<String, EdgeProperty> edgeProperties = getContext().getInputVertexEdgeProperties();
        Iterator<Map.Entry<String, EdgeProperty>> edgeIterator = edgeProperties.entrySet().iterator();
        while (edgeIterator.hasNext()) {
            if (edgeIterator.next().getValue().getDataMovementType() != EdgeProperty.DataMovementType.CUSTOM) {
                edgeIterator.remove();
            }
        }

        for (Source source : this.sourcesByName.values()) {
            builder.clearNumTaskPerVertexInGroup();
            for (int positionInGroup = 0; positionInGroup < source.srcVertices.size(); positionInGroup++) {
                SrcVertex srcVertex = source.srcVertices.get(positionInGroup);
                builder.setPositionInGroup(positionInGroup);
                // The payload is built BEFORE this vertex's task count is appended, so it
                // carries the task counts of the preceding vertices of the group only.
                edgeProperties.get(srcVertex.name).getEdgeManagerDescriptor()
                    .setUserPayload(UserPayload.create(ByteBuffer.wrap(builder.build().toByteArray())));
                builder.addNumTaskPerVertexInGroup(srcVertex.numTask);
            }
        }

        getContext().reconfigureVertex(this.parallelism, (VertexLocationHint) null, edgeProperties);
        this.vertexReconfigured = true;
        getContext().doneReconfiguringVertex();
        return true;
    }

    /**
     * Decides whether the source vertices have reported enough to estimate their output.
     *
     * <p>With a positive grouping fraction, a vertex that has not finished blocks estimation
     * while fewer than {@code numTask * fraction} of its tasks completed or while it reported
     * no record. Otherwise a vertex blocks estimation while it reported fewer than
     * {@code minNumRecordForEstimation} records and not all of its tasks sent an event.
     */
    private boolean hasEnoughStatsForEstimation() {
        boolean useGroupingFraction = this.config.hasGroupingFraction() && this.config.getGroupingFraction() > 0;
        for (SrcVertex srcVertex : this.srcVerticesByName.values()) {
            if (useGroupingFraction) {
                int numCompleted = srcVertex.taskCompleted.getCardinality();
                if (numCompleted < srcVertex.numTask
                    && (srcVertex.numTask * this.config.getGroupingFraction() > numCompleted
                        || srcVertex.numRecord == 0)) {
                    return false;
                }
            } else if (srcVertex.numRecord < this.minNumRecordForEstimation
                && srcVertex.taskWithVMEvent.getCardinality() < srcVertex.numTask) {
                return false;
            }
        }
        return true;
    }

    /**
     * Distributes the target parallelism over the sources as per-source chunk counts.
     *
     * <p>Both rounds compute a factor {@code k} such that the product of
     * {@code numRecord * k} over the considered sources equals the target parallelism.
     * Round one pins sources whose share would be below 2 to a single chunk; round two
     * recomputes {@code k} over the remaining sources and assigns their chunk counts, capped
     * by the max parallelism and by {@code numTask * numPartitions} of the largest vertex.
     *
     * @param map sources keyed by name; {@code numRecord} must already be set and positive
     * @param i   target parallelism
     */
    private void determineNumChunks(Map<String, Source> map, int i) {
        // Round one: find the sources that are too small to be worth splitting.
        double logFactor = Math.log10(i);
        for (Source source : map.values()) {
            logFactor -= Math.log10(source.numRecord);
        }
        double factor = Math.pow(10.0d, logFactor / map.size());
        for (Source source : map.values()) {
            if (source.numRecord * factor < 2.0d) {
                source.numChunk = 1;
            }
        }

        // Round two: redo the computation over the sources that are still to be split.
        double logLargeFactor = Math.log10(i);
        int numLargeSources = 0;
        for (Source source : map.values()) {
            if (source.numChunk != 1) {
                logLargeFactor -= Math.log10(source.numRecord);
                numLargeSources++;
            }
        }
        if (numLargeSources == 0) {
            // Every source is a single chunk; there is nothing left to distribute.
            return;
        }
        double largeFactor = Math.pow(10.0d, logLargeFactor / numLargeSources);
        for (Source source : map.values()) {
            if (source.numChunk != 1) {
                int wanted = Math.max(1, (int) (source.numRecord * largeFactor));
                int available = source.getSrcVertexWithMostOutput().numTask * this.numPartitions;
                source.numChunk = Math.min(this.maxParallelism, Math.min(available, wanted));
            }
        }
    }

    /**
     * Drains the queue of completed source tasks once the vertex is reconfigured and
     * scheduling is allowed to start; otherwise leaves the queue untouched.
     */
    private void tryScheduleTasks() throws IOException {
        if (!vertexReconfigured && !tryReconfigure()) {
            return;
        }
        if (!vertexStartSchedule && !tryStartSchedule()) {
            return;
        }
        while (!completedSrcTaskToProcess.isEmpty()) {
            TaskAttemptIdentifier completedAttempt = completedSrcTaskToProcess.poll();
            scheduleTasksDependOnCompletion(completedAttempt);
        }
    }

    /**
     * Schedules every not yet scheduled task whose input chunks are all complete, looking
     * only at the tasks that consume a chunk touched by the given source task.
     *
     * @param taskAttemptIdentifier a completed task attempt of a cartesian product source vertex
     */
    private void scheduleTasksDependOnCompletion(TaskAttemptIdentifier taskAttemptIdentifier) {
        if (parallelism == 0) {
            return;
        }
        int completedTask = taskAttemptIdentifier.getTaskIdentifier().getIdentifier();
        String vertexName = taskAttemptIdentifier.getTaskIdentifier().getVertexIdentifier().getName();
        SrcVertex completedVertex = srcVerticesByName.get(vertexName);
        Source completedSource = completedVertex.source;
        List<VertexManagerPluginContext.ScheduleTaskRequest> requests = new ArrayList<>();
        CartesianProductCombination walker = new CartesianProductCombination(numChunksPerSrc, completedSource.position);

        // Range of chunks that receive data from the completed task.
        grouper.init(completedVertex.numTask * numPartitions, completedSource.numChunk);
        int firstChunk = grouper.getGroupId(completedTask * numPartitions);
        int lastChunk = grouper.getGroupId(((completedTask * numPartitions) + numPartitions) - 1);

        for (int chunk = firstChunk; chunk <= lastChunk; chunk++) {
            walker.firstTaskWithFixedChunk(chunk);
            do {
                List<Integer> chunkIds = walker.getCombination();
                if (scheduledTasks.contains(walker.getTaskId())) {
                    continue;
                }
                // Check the chunk of the completed source first, then the other sources in
                // position order, stopping at the first incomplete chunk.
                boolean ready = completedSource.isChunkCompleted(chunkIds.get(completedSource.position).intValue());
                int position = 0;
                while (ready && position < chunkIds.size()) {
                    if (position != completedSource.position) {
                        Source other = sourcesByName.get(sourceList.get(position));
                        ready = other.isChunkCompleted(chunkIds.get(position).intValue());
                    }
                    position++;
                }
                if (ready) {
                    requests.add(VertexManagerPluginContext.ScheduleTaskRequest.create(walker.getTaskId(), (TaskLocationHint) null));
                    scheduledTasks.add(walker.getTaskId());
                }
            } while (walker.nextTaskWithFixedChunk());
        }

        if (!requests.isEmpty()) {
            getContext().scheduleTasks(requests);
        }
    }
}
