/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

/**
 * Planner rule that rewrites an aggregate containing distinct calls
 * (such as {@code COUNT(DISTINCT x)}) into plans without distinct calls.
 *
 * <p>Two strategies are used:
 * <ul>
 * <li>If every call is distinct and all calls share one argument list, the
 * aggregate is placed on top of a {@code SELECT DISTINCT} of the group keys
 * plus those arguments (see {@link #convertMonopole}).</li>
 * <li>Otherwise the non-distinct calls are computed by one aggregate over the
 * original input. Each distinct argument list gets its own aggregate over a
 * {@code SELECT DISTINCT}. The partial results are combined with inner joins
 * on the group key and indicator columns using
 * {@code IS NOT DISTINCT FROM}.</li>
 * </ul>
 *
 * <p>Only aggregates whose single grouping set is their group set are
 * rewritten; see {@link #isSimple(Aggregate)}.
 */
public final class AggregateExpandDistinctAggregatesRule extends RelOptRule {
    /** Default instance: matches {@link LogicalAggregate} and builds joins
     * with {@link RelFactories#DEFAULT_JOIN_FACTORY}. */
    public static final AggregateExpandDistinctAggregatesRule INSTANCE =
        new AggregateExpandDistinctAggregatesRule(LogicalAggregate.class,
            RelFactories.DEFAULT_JOIN_FACTORY);

    /** Factory used to create the joins that combine partial aggregates. */
    private final RelFactories.JoinFactory joinFactory;

    /**
     * Creates an AggregateExpandDistinctAggregatesRule.
     *
     * @param clazz       aggregate class this rule matches
     * @param joinFactory factory used to create the joins between the
     *                    partial aggregates
     */
    public AggregateExpandDistinctAggregatesRule(
            Class<? extends LogicalAggregate> clazz,
            RelFactories.JoinFactory joinFactory) {
        super(operand(clazz, any()));
        this.joinFactory = joinFactory;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = (Aggregate) call.rel(0);
        if (!aggregate.containsDistinctCall()) {
            return;
        }
        // Both rewrites renumber the group keys to 0..k-1 and aggregate rows
        // that were de-duplicated on (group keys, args). That is only valid
        // when the aggregate has a single grouping set equal to its group set.
        // With several grouping sets, a rolled-up set would see the same
        // distinct value once per key combination, and convertMonopole would
        // drop the grouping sets entirely. Leave such aggregates alone.
        if (!isSimple(aggregate)) {
            return;
        }

        // Collect the distinct argument lists (first-seen order) and count
        // the non-distinct calls.
        int nonDistinctCount = 0;
        final Set<List<Integer>> argListSets =
            new LinkedHashSet<List<Integer>>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            if (aggCall.isDistinct()) {
                argListSets.add(new ArrayList<Integer>(aggCall.getArgList()));
            } else {
                ++nonDistinctCount;
            }
        }
        Util.permAssert(!argListSets.isEmpty(), "containsDistinctCall lied");

        // Every call is distinct over the same arguments: no join needed.
        if (nonDistinctCount == 0 && argListSets.size() == 1) {
            call.transformTo(
                convertMonopole(aggregate, argListSets.iterator().next()));
            return;
        }

        final List<RelDataTypeField> aggFields =
            aggregate.getRowType().getFieldList();
        final List<String> fieldNames = aggregate.getRowType().getFieldNames();
        final int groupAndIndicatorCount =
            aggregate.getGroupCount() + aggregate.getIndicatorCount();

        // refs.get(i) is the expression for output field i of the original
        // aggregate in terms of the rewritten tree. Group and indicator
        // fields pass straight through. Non-distinct calls come from the
        // leading aggregate built below. Distinct calls get a null
        // placeholder that doRewrite fills in.
        final List<RexInputRef> refs = new ArrayList<RexInputRef>();
        for (int i = 0; i < groupAndIndicatorCount; i++) {
            refs.add(RexInputRef.of(i, aggFields));
        }
        final List<AggregateCall> aggCalls = aggregate.getAggCallList();
        final List<AggregateCall> newAggCallList =
            new ArrayList<AggregateCall>();
        for (int i = 0; i < aggCalls.size(); i++) {
            final AggregateCall aggCall = aggCalls.get(i);
            if (aggCall.isDistinct()) {
                refs.add(null);
                continue;
            }
            refs.add(
                new RexInputRef(groupAndIndicatorCount + newAggCallList.size(),
                    aggFields.get(groupAndIndicatorCount + i).getType()));
            newAggCallList.add(aggCall);
        }

        // Aggregate the non-distinct calls over the original input, if there
        // are any. Otherwise the first distinct rewrite becomes the leftmost
        // input of the join chain.
        RelNode rel = newAggCallList.isEmpty()
            ? null
            : LogicalAggregate.create(aggregate.getInput(),
                aggregate.indicator, aggregate.getGroupSet(),
                aggregate.getGroupSets(), newAggCallList);
        for (List<Integer> argList : argListSets) {
            rel = doRewrite(aggregate, rel, argList, refs);
        }
        call.transformTo(RelOptUtil.createProject(rel, refs, fieldNames));
    }

    /**
     * Returns whether the aggregate's only grouping set is its group set,
     * i.e. it has no additional grouping sets.
     */
    private static boolean isSimple(Aggregate aggregate) {
        final List<ImmutableBitSet> groupSets = aggregate.getGroupSets();
        return groupSets.size() == 1
            && groupSets.get(0).equals(aggregate.getGroupSet());
    }

    /**
     * Rewrites an aggregate whose calls are all distinct over the same
     * argument list. For example,
     *
     * <pre>SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
     * FROM emp GROUP BY deptno</pre>
     *
     * <p>becomes
     *
     * <pre>SELECT deptno, COUNT(sal), SUM(sal)
     * FROM (SELECT DISTINCT deptno, sal FROM emp)
     * GROUP BY deptno</pre>
     *
     * @param aggregate original aggregate
     * @param argList   argument list shared by the distinct calls
     * @return equivalent expression with no distinct calls
     */
    private RelNode convertMonopole(Aggregate aggregate, List<Integer> argList) {
        final Map<Integer, Integer> sourceOf = new HashMap<Integer, Integer>();
        final Aggregate distinct =
            createSelectDistinct(aggregate, argList, sourceOf);

        final List<AggregateCall> newAggCalls =
            Lists.newArrayList(aggregate.getAggCallList());
        rewriteAggCalls(newAggCalls, argList, sourceOf);

        // The group keys are the leading columns of the distinct input.
        final int cardinality = aggregate.getGroupSet().cardinality();
        return aggregate.copy(aggregate.getTraitSet(), distinct,
            aggregate.indicator, ImmutableBitSet.range(cardinality), null,
            newAggCalls);
    }

    /**
     * Builds the aggregate for the distinct calls over {@code argList}. If
     * {@code left} is not null, it is inner-joined to {@code left} on the
     * group key and indicator columns.
     *
     * <p>Side effect: replaces the null placeholders in {@code refs} for the
     * matching distinct calls. The new entries reference this aggregate's
     * output fields, offset past {@code left}'s fields when {@code left} is
     * not null.
     *
     * @param aggregate original aggregate
     * @param left      expression built so far, or null if there is none
     * @param argList   argument list of the distinct calls to rewrite
     * @param refs      output-field expressions, updated in place
     * @return {@code left} joined with the new aggregate, or the new
     *         aggregate alone if {@code left} is null
     */
    private RelNode doRewrite(Aggregate aggregate, RelNode left,
            List<Integer> argList, List<RexInputRef> refs) {
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final List<RelDataTypeField> leftFields =
            left == null ? null : left.getRowType().getFieldList();

        final Map<Integer, Integer> sourceOf = new HashMap<Integer, Integer>();
        final Aggregate distinct =
            createSelectDistinct(aggregate, argList, sourceOf);

        // In the join row, this aggregate's fields follow the left input's.
        final int offset = left == null ? 0 : leftFields.size();
        final int groupAndIndicatorCount =
            aggregate.getGroupCount() + aggregate.getIndicatorCount();
        final List<AggregateCall> aggCalls = aggregate.getAggCallList();
        final List<AggregateCall> aggCallList = new ArrayList<AggregateCall>();
        for (int i = 0; i < aggCalls.size(); i++) {
            final AggregateCall aggCall = aggCalls.get(i);
            if (!isDistinctOver(aggCall, argList)) {
                continue;
            }
            final AggregateCall newAggCall = toNonDistinct(aggCall, sourceOf);
            final int refIndex = groupAndIndicatorCount + i;
            assert refs.get(refIndex) == null;
            refs.set(refIndex,
                new RexInputRef(
                    offset + groupAndIndicatorCount + aggCallList.size(),
                    newAggCall.getType()));
            aggCallList.add(newAggCall);
        }

        // The group keys now occupy columns 0..k-1 of the distinct input.
        // Pass null grouping sets, meaning "just the group set". The original
        // aggregate's grouping sets use the original input's column numbers
        // and would not be subsets of the renumbered group set.
        final Aggregate distinctAgg = aggregate.copy(aggregate.getTraitSet(),
            distinct, aggregate.indicator,
            ImmutableBitSet.range(aggregate.getGroupSet().cardinality()), null,
            aggCallList);
        if (left == null) {
            return distinctAgg;
        }

        // Match rows on every group key and indicator column. IS NOT DISTINCT
        // FROM (rather than =) lets NULL keys join to each other.
        final List<RelDataTypeField> distinctFields =
            distinctAgg.getRowType().getFieldList();
        final List<RexNode> conditions = new ArrayList<RexNode>();
        for (int i = 0; i < groupAndIndicatorCount; i++) {
            conditions.add(
                rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
                    RexInputRef.of(i, leftFields),
                    new RexInputRef(leftFields.size() + i,
                        distinctFields.get(i).getType())));
        }
        final Set<String> variablesStopped = ImmutableSet.of();
        return joinFactory.createJoin(left, distinctAgg,
            RexUtil.composeConjunction(rexBuilder, conditions, false),
            JoinRelType.INNER, variablesStopped, false);
    }

    /**
     * Replaces, in place, each call in {@code newAggCalls} that is distinct
     * over exactly {@code argList* with a non-distinct call whose arguments
     * are renumbered through {@code sourceOf}.
     */
    private static void rewriteAggCalls(List<AggregateCall> newAggCalls,
            List<Integer> argList, Map<Integer, Integer> sourceOf) {
        for (int i = 0; i < newAggCalls.size(); ++i) {
            final AggregateCall aggCall = newAggCalls.get(i);
            if (isDistinctOver(aggCall, argList)) {
                newAggCalls.set(i, toNonDistinct(aggCall, sourceOf));
            }
        }
    }

    /** Returns whether {@code aggCall} is a distinct call whose argument list
     * equals {@code argList}. */
    private static boolean isDistinctOver(AggregateCall aggCall,
            List<Integer> argList) {
        return aggCall.isDistinct() && aggCall.getArgList().equals(argList);
    }

    /** Returns a non-distinct copy of {@code aggCall}. Each argument is
     * renumbered through {@code sourceOf}; aggregation, type and name are
     * kept. */
    private static AggregateCall toNonDistinct(AggregateCall aggCall,
            Map<Integer, Integer> sourceOf) {
        final List<Integer> args = aggCall.getArgList();
        final List<Integer> newArgs = new ArrayList<Integer>(args.size());
        for (Integer arg : args) {
            newArgs.add(sourceOf.get(arg));
        }
        return new AggregateCall(aggCall.getAggregation(), false, newArgs,
            aggCall.getType(), aggCall.getName());
    }

    /**
     * Builds a {@code SELECT DISTINCT} over the aggregate's input.
     *
     * <p>The projection contains the group key columns in ascending order,
     * followed by each column of {@code argList} that is not already
     * projected. It is wrapped in an aggregate that groups by every
     * projected column and has no calls.
     *
     * @param aggregate original aggregate; its input is projected
     * @param argList   argument columns to include
     * @param sourceOf  output: maps each projected input column to its
     *                  position in the projection
     * @return the distinct aggregate
     */
    private static Aggregate createSelectDistinct(Aggregate aggregate,
            List<Integer> argList, Map<Integer, Integer> sourceOf) {
        final List<Pair<RexNode, String>> projects =
            new ArrayList<Pair<RexNode, String>>();
        final RelNode child = aggregate.getInput();
        final List<RelDataTypeField> childFields =
            child.getRowType().getFieldList();
        for (int i : aggregate.getGroupSet()) {
            sourceOf.put(i, projects.size());
            projects.add(RexInputRef.of2(i, childFields));
        }
        for (Integer arg : argList) {
            if (sourceOf.containsKey(arg)) {
                continue;
            }
            sourceOf.put(arg, projects.size());
            projects.add(RexInputRef.of2(arg, childFields));
        }
        final RelNode project = RelOptUtil.createProject(child, projects, false);
        return aggregate.copy(aggregate.getTraitSet(), project, false,
            ImmutableBitSet.range(projects.size()), null,
            ImmutableList.<AggregateCall>of());
    }
}

