/*
 * Decompiled with CFR 0.152.
 */
package es.uvigo.darwin.prottest.consensus;

import es.uvigo.darwin.prottest.selection.InformationCriterion;
import es.uvigo.darwin.prottest.selection.model.SelectionModel;
import es.uvigo.darwin.prottest.tree.TreeUtils;
import es.uvigo.darwin.prottest.tree.WeightedTree;
import es.uvigo.darwin.prottest.util.FixedBitSet;
import es.uvigo.darwin.prottest.util.Utilities;
import es.uvigo.darwin.prottest.util.exception.ImportException;
import es.uvigo.darwin.prottest.util.exception.ProtTestInternalException;
import es.uvigo.darwin.prottest.util.fileio.NexusExporter;
import es.uvigo.darwin.prottest.util.fileio.NexusTreeReader;
import es.uvigo.darwin.prottest.util.fileio.SimpleNewickTreeReader;
import es.uvigo.darwin.prottest.util.fileio.TreeReader;
import es.uvigo.darwin.prottest.util.printer.ProtTestFormattedOutput;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import pal.misc.IdGroup;
import pal.misc.Identifier;
import pal.tree.Node;
import pal.tree.NodeFactory;
import pal.tree.SimpleTree;
import pal.tree.Tree;

public class Consensus {
    public static final boolean SUPPORT_AS_PERCENT = false;
    public static final int BRANCH_LENGTHS_AVERAGE = 1;
    public static final int BRANCH_LENGTHS_MEDIAN = 2;
    private static final BranchDistances DEFAULT_BRANCH_DISTANCES = BranchDistances.WeightedMedian;
    private static final int FIRST = 0;
    private List<WeightedTree> trees;
    private double cumWeight = 0.0;
    private int numTaxa;
    private IdGroup idGroup;
    private Map<FixedBitSet, Support> support = new HashMap<FixedBitSet, Support>();
    private Map<FixedBitSet, Double> cladeSupport;
    private Tree consensusTree;
    private List<FixedBitSet> splitsInConsensus = new ArrayList<FixedBitSet>();
    private List<FixedBitSet> splitsOutFromConsensus = new ArrayList<FixedBitSet>();

    private Map<FixedBitSet, Support> getSupport() {
        return this.support;
    }

    public Map<FixedBitSet, Double> getCladeSupport() {
        if (this.cladeSupport == null) {
            this.cladeSupport = new HashMap<FixedBitSet, Double>(this.support.size());
            Object[] keys = this.support.keySet().toArray(new FixedBitSet[0]);
            Arrays.sort(keys);
            for (Object fbs : keys) {
                this.cladeSupport.put((FixedBitSet)fbs, this.support.get(fbs).treesWeightWithClade / this.cumWeight);
            }
        }
        return this.cladeSupport;
    }

    public IdGroup getIdGroup() {
        return this.idGroup;
    }

    public Tree getConsensusTree() {
        return this.consensusTree;
    }

    public Collection<WeightedTree> getTrees() {
        return this.trees;
    }

    private boolean addTree(WeightedTree wTree) {
        if (wTree.getTree() == null || wTree.getWeight() < 0.0) {
            throw new ProtTestInternalException();
        }
        if (this.trees.isEmpty()) {
            this.trees.add(wTree);
            this.numTaxa = wTree.getTree().getIdCount();
            this.idGroup = pal.tree.TreeUtils.getLeafIdGroup((Tree)wTree.getTree());
        } else {
            if (wTree.getTree().getIdCount() != this.numTaxa) {
                return false;
            }
            Tree pTree = this.trees.get(0).getTree();
            for (int i = 0; i < this.numTaxa; ++i) {
                boolean found = false;
                for (int j = 0; j < this.numTaxa; ++j) {
                    if (!wTree.getTree().getIdentifier(i).equals((Object)pTree.getIdentifier(j))) continue;
                    found = true;
                    break;
                }
                if (found) continue;
                System.out.println("NOT COMPATIBLE TREES");
                return false;
            }
            this.trees.add(wTree);
        }
        this.cumWeight += wTree.getWeight();
        return true;
    }

    public Consensus(InformationCriterion ic, double supportThreshold) {
        this(ic, supportThreshold, 0);
    }

    public Consensus(InformationCriterion ic, double supportThreshold, int branchDistances) {
        this.trees = new ArrayList<WeightedTree>();
        for (SelectionModel model : ic.getConfidenceModels()) {
            WeightedTree wTree = new WeightedTree(model.getModel().getTree(), model.getWeightValue());
            this.addTree(wTree);
        }
        this.consensusTree = this.buildTree(supportThreshold, this.getBranchDistances(branchDistances));
    }

    public Consensus(List<WeightedTree> trees, double supportThreshold, int branchDistances) {
        this.trees = new ArrayList<WeightedTree>();
        for (WeightedTree tree : trees) {
            this.addTree(tree);
        }
        this.consensusTree = this.buildTree(supportThreshold, this.getBranchDistances(branchDistances));
    }

    public Consensus(File treesFile, double supportThreshold) throws ProtTestInternalException, IOException {
        this(treesFile, supportThreshold, 0);
    }

    public Consensus(File treesFile, double supportThreshold, int branchDistances) throws ProtTestInternalException, IOException {
        File resultDir;
        TreeReader treeReader;
        try {
            treeReader = new NexusTreeReader(treesFile);
        }
        catch (ImportException ex) {
            treeReader = new SimpleNewickTreeReader(treesFile);
        }
        this.trees = treeReader.getWeightedTreeList();
        this.cumWeight = treeReader.getCumWeight();
        this.numTaxa = treeReader.getNumTaxa();
        this.idGroup = treeReader.getIdGroup();
        this.consensusTree = this.buildTree(supportThreshold, this.getBranchDistances(branchDistances));
        String fileName = treesFile.getName();
        if (fileName.contains(".")) {
            fileName = fileName.substring(0, fileName.lastIndexOf("."));
        }
        if (!(resultDir = new File("results")).exists()) {
            resultDir.mkdir();
        }
        File outFile = File.createTempFile(fileName, ".con", resultDir);
        NexusExporter nw = new NexusExporter(outFile);
        nw.printConsensusBlock(this.consensusTree, ((NexusTreeReader)treeReader).getNexusId());
        nw.close();
    }

    private FixedBitSet rootedSupport(WeightedTree wTree, Node node, Map<FixedBitSet, Support> support) {
        FixedBitSet clade = new FixedBitSet(this.numTaxa);
        if (node.isLeaf()) {
            clade.set(this.idGroup.whichIdNumber(node.getIdentifier().getName()));
        } else {
            for (int i = 0; i < node.getChildCount(); ++i) {
                Node n = node.getChild(i);
                FixedBitSet childClade = this.rootedSupport(wTree, n, support);
                clade.union(childClade);
            }
        }
        Support s = support.get(clade);
        if (s == null) {
            s = new Support();
            support.put(clade, s);
        }
        s.add(wTree.getWeight(), TreeUtils.safeNodeHeight(wTree.getTree(), node), node.getBranchLength());
        return clade;
    }

    public Node detachChildren(Tree tree, Node node, List<Integer> split) {
        assert (split.size() > 1);
        ArrayList<Node> detached = new ArrayList<Node>();
        for (int n : split) {
            detached.add(node.getChild(n));
        }
        Node saveRoot = tree.getRoot();
        ArrayList<Integer> toRemove = new ArrayList<Integer>();
        for (int i = 0; i < node.getChildCount(); ++i) {
            Node n = node.getChild(i);
            if (!detached.contains(n)) continue;
            toRemove.add(0, i);
        }
        Iterator i$ = toRemove.iterator();
        while (i$.hasNext()) {
            int i = (Integer)i$.next();
            node.removeChild(i);
        }
        Node dnode = NodeFactory.createNode((Node[])detached.toArray(new Node[0]));
        node.addChild(dnode);
        tree.setRoot(saveRoot);
        return dnode;
    }

    private Tree buildTree(double supportThreshold, BranchDistances branchDistances) {
        Map.Entry<FixedBitSet, Support> e;
        Support s;
        double psupport;
        if (this.trees.isEmpty()) {
            throw new ProtTestInternalException("There are no trees to consense");
        }
        if (supportThreshold < 0.5 || supportThreshold > 1.0) {
            throw new ProtTestInternalException("Invalid threshold value: " + supportThreshold);
        }
        this.support = new HashMap<FixedBitSet, Support>();
        int k = 0;
        for (WeightedTree wTree : this.trees) {
            this.rootedSupport(wTree, wTree.getTree().getRoot(), this.support);
            ++k;
        }
        SimpleTree cons = new SimpleTree();
        ArrayList<Node> internalNodes = new ArrayList<Node>(this.numTaxa);
        ArrayList<FixedBitSet> internalNodesTips = new ArrayList<FixedBitSet>(this.numTaxa);
        assert (this.idGroup.getIdCount() == this.numTaxa);
        internalNodesTips.add(new FixedBitSet(this.numTaxa));
        FixedBitSet rooNode = (FixedBitSet)internalNodesTips.get(0);
        Node[] nodes = new Node[this.numTaxa];
        for (int nt = 0; nt < this.numTaxa; ++nt) {
            nodes[nt] = NodeFactory.createNode((Identifier)this.idGroup.getIdentifier(nt));
            rooNode.set(nt);
        }
        Node rootNode = NodeFactory.createNode((Node[])nodes);
        internalNodes.add(rootNode);
        cons.setRoot(rootNode);
        Comparator<Map.Entry<FixedBitSet, Support>> comparator = new Comparator<Map.Entry<FixedBitSet, Support>>(){

            @Override
            public int compare(Map.Entry<FixedBitSet, Support> o1, Map.Entry<FixedBitSet, Support> o2) {
                double diff = o2.getValue().treesWeightWithClade - o1.getValue().treesWeightWithClade;
                if (diff > 0.0) {
                    return 1;
                }
                if (diff < 0.0) {
                    return -1;
                }
                return 0;
            }
        };
        PriorityQueue<Map.Entry<FixedBitSet, Support>> queue = new PriorityQueue<Map.Entry<FixedBitSet, Support>>(this.support.size(), comparator);
        for (Map.Entry<FixedBitSet, Support> se : this.support.entrySet()) {
            Support s2 = se.getValue();
            FixedBitSet clade = se.getKey();
            int cladeSize = clade.cardinality();
            if (cladeSize == this.numTaxa) {
                cons.getRoot().setNodeHeight(s2.sumBranches / (double)this.trees.size());
                cons.getRoot().setBranchLength(branchDistances.build(s2.branchLengths));
                continue;
            }
            if (Math.abs(s2.treesWeightWithClade - this.cumWeight) < 1.0E-5 && cladeSize == 1) {
                int nt = clade.nextOnBit(0);
                Node leaf = cons.getExternalNode(nt);
                leaf.setNodeHeight(s2.sumBranches / (double)this.trees.size());
                leaf.setBranchLength(branchDistances.build(s2.branchLengths));
                continue;
            }
            queue.add(se);
        }
        block3: while (queue.peek() != null && !((psupport = 1.0 * (s = (e = queue.poll()).getValue()).treesWeightWithClade / this.cumWeight) < supportThreshold)) {
            FixedBitSet cladeTips = e.getKey();
            boolean found = false;
            for (int nsub = internalNodesTips.size() - 1; nsub >= 0; --nsub) {
                FixedBitSet allNodeTips = (FixedBitSet)internalNodesTips.get(nsub);
                int nSplit = allNodeTips.intersectCardinality(cladeTips);
                if (nSplit != cladeTips.cardinality()) continue;
                found = true;
                ArrayList<Integer> split = new ArrayList<Integer>();
                Node n = (Node)internalNodes.get(nsub);
                int l = 0;
                for (int j = 0; j < n.getChildCount(); ++j) {
                    Node ch = n.getChild(j);
                    if (ch.isLeaf()) {
                        if (cladeTips.contains(this.idGroup.whichIdNumber(ch.getIdentifier().getName()))) {
                            split.add(l);
                        }
                    } else {
                        int o = internalNodes.indexOf(ch);
                        int i = ((FixedBitSet)internalNodesTips.get(o)).intersectCardinality(cladeTips);
                        if (i == ((FixedBitSet)internalNodesTips.get(o)).cardinality()) {
                            split.add(l);
                        } else if (i > 0) {
                            found = false;
                            break;
                        }
                    }
                    ++l;
                }
                if (!found || split.size() >= n.getChildCount()) {
                    found = false;
                    continue block3;
                }
                if (split.isEmpty()) {
                    System.err.println("Bug??");
                    assert (false);
                }
                Node detached = this.detachChildren((Tree)cons, n, split);
                double height = s.sumBranches / (double)s.nTreesWithClade;
                detached.setNodeHeight(height);
                detached.setBranchLength(branchDistances.build(s.branchLengths));
                cons.setAttribute(detached, "support", (Object)psupport);
                internalNodes.add(nsub + 1, detached);
                internalNodesTips.add(nsub + 1, new FixedBitSet(cladeTips));
                continue block3;
            }
        }
        TreeUtils.insureConsistency((Tree)cons, cons.getRoot());
        String thresholdAsPercent = String.valueOf(supportThreshold * 100.0);
        cons.setAttribute(cons.getRoot(), "treeName", (Object)("cons_" + thresholdAsPercent + "_majRule"));
        Set<FixedBitSet> keySet = this.getSupport().keySet();
        Object[] keys = keySet.toArray(new FixedBitSet[0]);
        Arrays.sort(keys);
        for (Object fbs : keys) {
            if (((FixedBitSet)fbs).cardinality() <= 1) continue;
            double psupport2 = 1.0 * this.getSupport().get(fbs).getTreesWeightWithClade() / this.cumWeight;
            if (psupport2 < supportThreshold) {
                this.splitsOutFromConsensus.add((FixedBitSet)fbs);
                continue;
            }
            this.splitsInConsensus.add((FixedBitSet)fbs);
        }
        return cons;
    }

    /*
     * WARNING - void declaration
     */
    public static void main(String[] args) {
        PrintWriter out = new PrintWriter(System.out);
        if (args.length < 2 || args[0].contains("help")) {
            out.println("This class requires at least 2 arguments: Tree set filename, Threshold value and optionally the branch length calculation method [avg (default), median]");
            out.println("The file format should be:");
            out.println("    \u00b7 (Newick's tree)[Weight];");
            out.println("    \u00b7 Nexus format tree set");
            out.flush();
            System.exit(-1);
        }
        String filename = args[0];
        Double threshold = Double.parseDouble(args[1]);
        File f = new File(filename);
        try {
            int i;
            void var13_21;
            Consensus consensus = args.length >= 3 ? (args[2].equalsIgnoreCase("median") ? new Consensus(f, (double)threshold, 2) : (args[2].equalsIgnoreCase("avg") ? new Consensus(f, (double)threshold, 1) : new Consensus(f, (double)threshold))) : new Consensus(f, (double)threshold);
            Tree consensusTree = consensus.getConsensusTree();
            out.println("");
            Set<FixedBitSet> keySet = consensus.getSupport().keySet();
            Object[] keys = keySet.toArray(new FixedBitSet[0]);
            ArrayList<Object> splitsInConsensus = new ArrayList<Object>();
            ArrayList<Object> splitsOutFromConsensus = new ArrayList<Object>();
            Arrays.sort(keys);
            Object[] arr$ = keys;
            int len$ = arr$.length;
            boolean bl = false;
            while (var13_21 < len$) {
                Object fbs = arr$[var13_21];
                if (((FixedBitSet)fbs).cardinality() > 1) {
                    double psupport = 1.0 * consensus.getSupport().get(fbs).getTreesWeightWithClade() / consensus.cumWeight;
                    if (psupport < threshold) {
                        splitsOutFromConsensus.add(fbs);
                    } else {
                        splitsInConsensus.add(fbs);
                    }
                }
                ++var13_21;
            }
            out.println("# # # # # # # # # # # # # # # #");
            out.println(" ");
            out.println("Species in order:");
            out.println(" ");
            for (int i2 = 0; i2 < consensus.getIdGroup().getIdCount(); ++i2) {
                Identifier id = consensus.getIdGroup().getIdentifier(i2);
                out.println("    " + (i2 + 1) + ". " + id.getName());
            }
            out.println(" ");
            out.println("# # # # # # # # # # # # # # # #");
            out.println(" ");
            out.println("Sets included in the consensus tree");
            out.println(" ");
            out.print("    ");
            int numTaxa = consensus.getIdGroup().getIdCount();
            for (i = 0; i < consensus.getIdGroup().getIdCount(); ++i) {
                out.print(String.valueOf(i + 1).charAt(0));
            }
            out.println(" ");
            if (numTaxa >= 10) {
                ProtTestFormattedOutput.space(13, ' ');
                for (i = 9; i < consensus.getIdGroup().getIdCount(); ++i) {
                    out.print(String.valueOf(i + 1).charAt(1));
                }
            }
            out.println(" ");
            for (FixedBitSet fixedBitSet : splitsInConsensus) {
                out.println("    " + fixedBitSet.splitRepresentation() + " ( " + Utilities.round(consensus.getSupport().get(fixedBitSet).getTreesWeightWithClade(), 3) + " )");
            }
            out.println(" ");
            out.println("Sets NOT included in consensus tree");
            out.println(" ");
            out.print("    ");
            for (int i3 = 0; i3 < consensus.getIdGroup().getIdCount(); ++i3) {
                out.print(i3 + 1);
            }
            out.println(" ");
            for (FixedBitSet fixedBitSet : splitsOutFromConsensus) {
                out.println("    " + fixedBitSet.splitRepresentation() + " ( " + Utilities.round(consensus.getSupport().get(fixedBitSet).getTreesWeightWithClade(), 3) + " )");
            }
            out.println(" ");
            out.println("# # # # # # # # # # # # # # # #");
            TreeUtils.printASCII(consensusTree, out);
            out.println(" ");
            TreeUtils.printBranchInfo(consensusTree, out);
            out.println(" ");
            TreeUtils.heightInfo(consensusTree, out);
            out.println(" ");
            out.println("# # # # # # # # # # # # # # # #");
            out.println(" ");
            out.println(TreeUtils.toNewick(consensusTree, true, true, true));
            out.println(" ");
            out.println("# # # # # # # # # # # # # # # #");
            out.println(" ");
        }
        catch (FileNotFoundException e) {
            out.println("File not found: " + filename);
        }
        catch (IOException e1) {
            out.println("IO Error: " + e1.getMessage());
        }
        out.flush();
    }

    public String getTaxaHeader() {
        int i;
        StringBuilder taxaHeader = new StringBuilder();
        for (i = 0; i < this.numTaxa; ++i) {
            taxaHeader.append(String.valueOf(i + 1).charAt(0));
        }
        if (this.numTaxa >= 10) {
            taxaHeader.append('\n');
            taxaHeader.append(ProtTestFormattedOutput.space(13, ' '));
            for (i = 9; i < this.numTaxa; ++i) {
                taxaHeader.append(String.valueOf(i + 1).charAt(1));
            }
        }
        if (this.numTaxa >= 100) {
            taxaHeader.append('\n');
            taxaHeader.append(ProtTestFormattedOutput.space(103, ' '));
            for (i = 99; i < this.numTaxa; ++i) {
                taxaHeader.append(String.valueOf(i + 1).charAt(2));
            }
        }
        if (this.numTaxa >= 1000) {
            taxaHeader.append('\n');
            taxaHeader.append(ProtTestFormattedOutput.space(1003, ' '));
            for (i = 999; i < this.numTaxa; ++i) {
                taxaHeader.append(String.valueOf(i + 1).charAt(3));
            }
        }
        return taxaHeader.toString();
    }

    public String getSetsIncluded() {
        StringBuilder setsIncluded = new StringBuilder();
        setsIncluded.append("    ");
        setsIncluded.append(this.getTaxaHeader());
        setsIncluded.append('\n');
        for (FixedBitSet fbs : this.splitsInConsensus) {
            setsIncluded.append("    ").append(fbs.splitRepresentation()).append(" ( ").append(Utilities.round(this.getCladeSupport().get(fbs), 5)).append(" )").append('\n');
        }
        return setsIncluded.toString();
    }

    public String getSetsNotIncluded() {
        StringBuilder setsIncluded = new StringBuilder();
        setsIncluded.append("    ");
        setsIncluded.append(this.getTaxaHeader());
        setsIncluded.append('\n');
        for (FixedBitSet fbs : this.splitsOutFromConsensus) {
            setsIncluded.append("    ").append(fbs.splitRepresentation()).append(" ( ").append(Utilities.round(this.getCladeSupport().get(fbs), 5)).append(" )").append('\n');
        }
        return setsIncluded.toString();
    }

    private BranchDistances getBranchDistances(int value) {
        BranchDistances bd;
        switch (value) {
            case 1: {
                bd = BranchDistances.WeightedAverage;
                break;
            }
            case 2: {
                bd = BranchDistances.WeightedMedian;
                break;
            }
            default: {
                bd = DEFAULT_BRANCH_DISTANCES;
            }
        }
        return bd;
    }

    static class UnweightedTree
    extends WeightedTree {
        UnweightedTree(Tree tree) {
            super(tree, 1.0);
        }
    }

    static class WeightLengthPair
    implements Comparable<WeightLengthPair> {
        private double weight;
        private double branchLength;

        WeightLengthPair(double weight, double branchLength) {
            this.weight = weight;
            this.branchLength = branchLength;
        }

        @Override
        public int compareTo(WeightLengthPair o) {
            if (this.branchLength < o.branchLength) {
                return -1;
            }
            if (this.branchLength > o.branchLength) {
                return 1;
            }
            return 0;
        }
    }

    static final class Support {
        private int nTreesWithClade = 0;
        private double treesWeightWithClade = 0.0;
        private ArrayList<WeightLengthPair> branchLengths = new ArrayList();
        private double sumBranches = 0.0;

        public double getTreesWeightWithClade() {
            return this.treesWeightWithClade;
        }

        Support() {
        }

        public final void add(double weight, double height, double branchLength) {
            this.sumBranches += height;
            this.branchLengths.add(new WeightLengthPair(weight, branchLength));
            this.treesWeightWithClade += weight;
            ++this.nTreesWithClade;
            double testW = 0.0;
            for (WeightLengthPair wlp : this.branchLengths) {
                testW += wlp.weight;
            }
        }
    }

    private static enum BranchDistances {
        WeightedAverage{

            @Override
            public double build(List<WeightLengthPair> values) {
                double avg = 0.0;
                double cumWeight = 0.0;
                for (WeightLengthPair pair : values) {
                    avg += pair.branchLength * pair.weight;
                    cumWeight += pair.weight;
                }
                return avg /= cumWeight;
            }
        }
        ,
        WeightedMedian{

            @Override
            public double build(List<WeightLengthPair> values) {
                Collections.sort(values);
                double median = -1.0;
                double cumWeight = 0.0;
                for (WeightLengthPair pair : values) {
                    cumWeight += pair.weight;
                }
                double halfWeight = cumWeight / 2.0;
                double cumValue = 0.0;
                for (WeightLengthPair pair : values) {
                    if (!((cumValue += pair.weight) >= halfWeight)) continue;
                    median = pair.branchLength;
                    break;
                }
                return median;
            }
        };


        public abstract double build(List<WeightLengthPair> var1);
    }
}

