/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.physical;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import org.apache.hadoop.hive.ql.exec.CommonMergeJoinOperator;
import org.apache.hadoop.hive.ql.exec.ConditionalTask;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.optimizer.physical.CrossProductCheck;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.SparkWork;
import org.apache.hadoop.hive.ql.session.SessionState;

/**
 * Physical-plan resolver that walks the task graph and prints console warnings for joins in
 * Spark tasks that look like cross products.
 *
 * <p>For every {@link SparkTask} reached, including tasks listed by a {@link ConditionalTask},
 * two checks run:
 * <ul>
 *   <li>shuffle joins: a reduce-side {@link JoinOperator} or {@link CommonMergeJoinOperator}
 *       whose incoming reduce-sink information has no key columns is reported;</li>
 *   <li>map joins: each work unit is handed to {@link CrossProductCheck.MapJoinCheck} and every
 *       warning it returns is printed.</li>
 * </ul>
 * {@link #resolve(PhysicalContext)} returns the context it was given.
 */
public class SparkCrossProductCheck implements PhysicalPlanResolver, Dispatcher {

    /**
     * Runs the cross-product checks on a single task of the plan.
     *
     * <p>Spark tasks are checked for shuffle-join and map-join cross products. Conditional tasks
     * are unwrapped and each listed task is dispatched recursively. Other task types are ignored.
     *
     * @param nd the node being visited; cast to {@link Task}
     * @param stack the walker's traversal stack, passed through unchanged to recursive calls
     * @param nodeOutputs outputs supplied by the walker, passed through unchanged to recursive calls
     * @return always {@code null}
     * @throws SemanticException if analysis of a work unit fails
     */
    @Override
    public Object dispatch(Node nd, Stack<Node> stack, Object... nodeOutputs)
            throws SemanticException {
        Task<?> currTask = (Task<?>) nd;
        if (currTask instanceof SparkTask) {
            SparkTask sparkTask = (SparkTask) currTask;
            checkShuffleJoin((SparkWork) sparkTask.getWork());
            checkMapJoin(sparkTask);
        } else if (currTask instanceof ConditionalTask) {
            // Check every task the conditional task lists.
            List<Task<? extends Serializable>> candidates =
                    ((ConditionalTask) currTask).getListTasks();
            for (Task<? extends Serializable> task : candidates) {
                dispatch(task, stack, nodeOutputs);
            }
        }
        return null;
    }

    /**
     * Walks the task graph starting from the plan's root tasks, dispatching each task to
     * {@link #dispatch(Node, Stack, Object...)}.
     *
     * @param pctx the physical plan context
     * @return the same {@code pctx} instance that was passed in
     * @throws SemanticException if the walk or any check fails
     */
    @Override
    public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
        TaskGraphWalker walker = new TaskGraphWalker(this);
        List<Node> topNodes = new ArrayList<Node>(pctx.getRootTasks());
        walker.startWalking(topNodes, null);
        return pctx;
    }

    /**
     * Prints {@code msg}, prefixed with {@code "Warning: "}, through the session console.
     */
    private void warn(String msg) {
        SessionState.getConsole().printInfo("Warning: " + msg, false);
    }

    /**
     * Checks every reduce work whose reducer is a join operator.
     *
     * <p>The reduce-sink information of all parent works is merged and passed to
     * {@link #checkForCrossProduct}.
     *
     * @param sparkWork the Spark work graph to inspect
     * @throws SemanticException if extracting reduce-sink information fails
     */
    private void checkShuffleJoin(SparkWork sparkWork) throws SemanticException {
        for (ReduceWork reduceWork : sparkWork.getAllReduceWork()) {
            Operator<?> reducer = reduceWork.getReducer();
            boolean isJoin = reducer instanceof JoinOperator
                    || reducer instanceof CommonMergeJoinOperator;
            if (!isJoin) {
                continue;
            }
            Map<Integer, CrossProductCheck.ExtractReduceSinkInfo.Info> rsInfo =
                    new HashMap<Integer, CrossProductCheck.ExtractReduceSinkInfo.Info>();
            for (BaseWork parent : sparkWork.getParents(reduceWork)) {
                rsInfo.putAll(new CrossProductCheck.ExtractReduceSinkInfo(null).analyze(parent));
            }
            checkForCrossProduct(reduceWork.getName(), reducer, rsInfo);
        }
    }

    /**
     * Runs {@link CrossProductCheck.MapJoinCheck} on every work unit of the task and prints each
     * warning it returns.
     *
     * @param sparkTask the Spark task whose work units are inspected
     * @throws SemanticException if the map-join analysis fails
     */
    private void checkMapJoin(SparkTask sparkTask) throws SemanticException {
        SparkWork sparkWork = (SparkWork) sparkTask.getWork();
        for (BaseWork baseWork : sparkWork.getAllWorkUnsorted()) {
            List<String> warnings =
                    new CrossProductCheck.MapJoinCheck(sparkTask.toString()).analyze(baseWork);
            for (String w : warnings) {
                warn(w);
            }
        }
    }

    /**
     * Warns when the shuffle join in {@code workName} has no join key columns.
     *
     * <p>The warning lists the input aliases of every entry in {@code rsInfo}. Nothing is printed
     * when {@code rsInfo} is empty.
     *
     * @param workName name of the reduce work containing the join, used in the message
     * @param reducer the join operator, rendered with {@code toString()} in the message
     * @param rsInfo reduce-sink information gathered from the join's parent works
     */
    private void checkForCrossProduct(String workName, Operator<? extends OperatorDesc> reducer,
            Map<Integer, CrossProductCheck.ExtractReduceSinkInfo.Info> rsInfo) {
        if (rsInfo.isEmpty()) {
            return;
        }
        Iterator<CrossProductCheck.ExtractReduceSinkInfo.Info> it = rsInfo.values().iterator();
        CrossProductCheck.ExtractReduceSinkInfo.Info first = it.next();
        // NOTE(review): only the first entry's key columns are inspected. This presumably relies
        // on all join inputs having the same number of key columns; verify against
        // CrossProductCheck.
        if (!first.keyCols.isEmpty()) {
            return;
        }
        List<String> iAliases = new ArrayList<String>(first.inputAliases);
        while (it.hasNext()) {
            iAliases.addAll(it.next().inputAliases);
        }
        String warning = String.format("Shuffle Join %s[tables = %s] in Work '%s' is a cross product", reducer.toString(), iAliases, workName);
        warn(warning);
    }
}

