package org.apache.beam.sdk.util.construction.graph;

import java.util.AbstractMap;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverride;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.ProjectionProducer;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TaggedPValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/beam/sdk/util/construction/graph/ProjectionPushdownOptimizer.class */
public class ProjectionPushdownOptimizer {
    private static final Logger LOG = LoggerFactory.getLogger(ProjectionPushdownOptimizer.class);

    /* loaded from: input_file:org/apache/beam/sdk/util/construction/graph/ProjectionPushdownOptimizer$PushdownOverrideFactory.class */
    private static class PushdownOverrideFactory<OutputT extends POutput, TransformT extends PTransform<PBegin, OutputT>> implements PTransformOverrideFactory<PBegin, OutputT, TransformT> {
        private final Map<TupleTag<?>, FieldAccessDescriptor> fields;

        PushdownOverrideFactory(Map<TupleTag<?>, FieldAccessDescriptor> map) {
            this.fields = map;
        }

        @Override // org.apache.beam.sdk.runners.PTransformOverrideFactory
        public PTransformOverrideFactory.PTransformReplacement<PBegin, OutputT> getReplacementTransform(AppliedPTransform<PBegin, OutputT, TransformT> appliedPTransform) {
            return PTransformOverrideFactory.PTransformReplacement.of(appliedPTransform.getPipeline().begin(), (PTransform) ((ProjectionProducer) appliedPTransform.getTransform()).actuateProjectionPushdown(this.fields));
        }

        @Override // org.apache.beam.sdk.runners.PTransformOverrideFactory
        public Map<PCollection<?>, PTransformOverrideFactory.ReplacementOutput> mapOutputs(Map<TupleTag<?>, PCollection<?>> map, OutputT outputt) {
            return (Map) map.entrySet().stream().map(entry -> {
                PCollection pCollection = outputt.expand().size() == 1 ? (PCollection) Iterables.getOnlyElement(outputt.expand().values()) : (PCollection) Preconditions.checkArgumentNotNull((PCollection) outputt.expand().get(entry.getKey()), "No PCollection found for output tag %s. Were output tags changed in actuateProjectionPushdown?", entry.getKey());
                return new AbstractMap.SimpleEntry(pCollection, PTransformOverrideFactory.ReplacementOutput.of(TaggedPValue.ofExpandedValue((PCollection) entry.getValue()), TaggedPValue.ofExpandedValue(pCollection)));
            }).collect(Collectors.toMap((v0) -> {
                return v0.getKey();
            }, (v0) -> {
                return v0.getValue();
            }));
        }
    }

    public static void optimize(Pipeline pipeline) {
        FieldAccessVisitor fieldAccessVisitor = new FieldAccessVisitor();
        pipeline.traverseTopologically(fieldAccessVisitor);
        ProjectionProducerVisitor projectionProducerVisitor = new ProjectionProducerVisitor(fieldAccessVisitor.getPCollectionFieldAccess());
        pipeline.traverseTopologically(projectionProducerVisitor);
        PCollectionOutputTagVisitor pCollectionOutputTagVisitor = new PCollectionOutputTagVisitor(projectionProducerVisitor.getPushdownOpportunities());
        pipeline.traverseTopologically(pCollectionOutputTagVisitor);
        for (Map.Entry<ProjectionProducer<PTransform<?, ?>>, Map<TupleTag<?>, FieldAccessDescriptor>> entry : pCollectionOutputTagVisitor.getTaggedFieldAccess().entrySet()) {
            for (Map.Entry<TupleTag<?>, FieldAccessDescriptor> entry2 : entry.getValue().entrySet()) {
                LOG.info("Optimizing transform {}: output {} will contain reduced field set {}", new Object[]{entry.getKey(), entry2.getKey(), entry2.getValue().fieldNamesAccessed()});
            }
            pipeline.replaceAll(ImmutableList.of(PTransformOverride.of(appliedPTransform -> {
                return appliedPTransform.getTransform() == entry.getKey();
            }, new PushdownOverrideFactory(entry.getValue()))));
        }
    }
}
