package com.wjholden.cmsc451;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.Point;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.function.BiFunction;
import javax.swing.BorderFactory;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;

/**
 * Entry point for an interactive demo of Kruskal's minimum spanning tree
 * algorithm.
 *
 * @author William John Holden
 */
public class MinimumSpanningTree {

    /**
     * Schedules construction of the demo window on the Swing event dispatch
     * thread.
     *
     * @param args the command line arguments (unused)
     */
    public static void main(String[] args) {
        SwingUtilities.invokeLater(MinimumSpanningTree::createAndShowGUI);
    }

    /**
     * Builds a frame containing an {@code MSTPanel}, sizes it to the panel's
     * preferred size and makes it visible. Invoked from {@link #main} via
     * {@code SwingUtilities.invokeLater}.
     */
    private static void createAndShowGUI() {
        final JFrame frame = new JFrame("Minimum Spanning Tree with Kruskal's Algorithm");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.add(new MSTPanel());
        frame.pack();
        frame.setVisible(true);
    }

}

/**
 * Interactive drawing canvas for the MST demo.
 * <ul>
 * <li>Left click adds a vertex centred on the click point.</li>
 * <li>Left press-and-drag moves the vertex nearest the press point.</li>
 * <li>Right click selects a vertex (highlighted blue); a second right click on
 * another vertex adds an edge between them. A second right click on the
 * selected vertex itself, or on empty space, cancels the selection.</li>
 * </ul>
 * On every repaint the MST of the drawn graph is recomputed with
 * {@link #kruskal} and overlaid in orange.
 */
class MSTPanel extends JPanel {

    /** Radius, in pixels, of the circle drawn around each vertex's location. */
    private static final int RADIUS = 8;

    /** All vertices on the canvas. */
    final Set<Vertex> v = new TreeSet<>();
    /** Adjacency map: each vertex to the vertices it was connected to. */
    final Map<Vertex, Set<Vertex>> e = new TreeMap<>();
    /** First endpoint of a pending right-click connection, or null if none. */
    Vertex blue = null;
    MST kruskal = new Kruskal();

    class Mouse extends MouseAdapter {

        /** Vertex currently being dragged, or null. */
        Point drag = null;

        @Override
        public void mouseClicked(MouseEvent event) {
            Point location = event.getPoint();
            if (SwingUtilities.isLeftMouseButton(event)) {
                Vertex vertex = new Vertex(location);
                v.add(vertex);
                e.put(vertex, new TreeSet<>());
            }

            if (SwingUtilities.isRightMouseButton(event)) {
                Vertex picked = Vertex.getNearest(v, location);
                if (blue == null) {
                    blue = picked;
                } else {
                    if (picked != null && picked != blue) {
                        e.get(blue).add(picked);
                    }
                    // The pending selection always ends here: either an edge was
                    // added, or the user clicked the selected vertex / empty space
                    // to cancel. Previously the selection could never be cleared.
                    blue = null;
                }
            }

            repaint();
        }

        @Override
        public void mousePressed(MouseEvent event) {
            if (SwingUtilities.isLeftMouseButton(event)) {
                // getNearest returns null for an empty set or a miss.
                drag = Vertex.getNearest(v, event.getPoint());
            }
        }

        @Override
        public void mouseReleased(MouseEvent event) {
            drag = null;
        }

        @Override
        public void mouseDragged(MouseEvent event) {
            if (drag != null && SwingUtilities.isLeftMouseButton(event)) {
                drag.setLocation(event.getPoint());
                repaint();
            }
        }

    }

    public MSTPanel() {
        setBorder(BorderFactory.createLineBorder(Color.black));

        Mouse mouse = new Mouse();
        addMouseListener(mouse);
        addMouseMotionListener(mouse);
    }

    @Override
    public Dimension getPreferredSize() {
        return new Dimension(500, 500);
    }

    @Override
    protected void paintComponent(Graphics g) {
        super.paintComponent(g);

        g.drawString("Left click to draw a point. Click and drag to move them.", 10, 20);
        g.drawString("Right click on two points to connect them.", 10, 40);

        // A vertex's (x, y) is the point the user clicked/dragged to, and
        // Vertex.getNearest measures distance to that point, so draw each circle
        // centred on it rather than using it as the top-left corner.
        final int diameter = 2 * RADIUS;
        for (Vertex vertex : v) {
            g.drawOval(vertex.x - RADIUS, vertex.y - RADIUS, diameter, diameter);
        }

        e.forEach((start, ends) ->
                ends.forEach(end -> g.drawLine(start.x, start.y, end.x, end.y)));

        if (blue != null) {
            g.setColor(Color.blue);
            g.fillOval(blue.x - RADIUS, blue.y - RADIUS, diameter, diameter);
        }

        g.setColor(Color.orange);
        for (Edge edge : kruskal.apply(v, e)) {
            g.drawLine(edge.start.x, edge.start.y, edge.finish.x, edge.finish.y);
        }
    }

}

/**
 * A minimum-spanning-tree algorithm. Given the set of vertices and the
 * adjacency map (each vertex to the set of vertices it has an edge to),
 * {@code apply} returns the set of edges it selects.
 */
abstract class MST
        implements BiFunction<Set<Vertex>, Map<Vertex, Set<Vertex>>, Set<Edge>> {
}

/**
 * Kruskal's algorithm. Edges are considered in ascending length, and an edge
 * is kept when its endpoints are not yet connected (checked with
 * {@code UnionFind}). Edge direction is ignored. If the graph is disconnected,
 * the result is a minimum spanning forest.
 */
class Kruskal extends MST {

    /**
     * Computes the minimum spanning tree (or forest) of the given graph.
     *
     * @param v the vertices
     * @param e adjacency map from each vertex to the vertices it connects to;
     *          edges with an endpoint not in {@code v} are ignored
     * @return the selected edges, in the order they were accepted (ascending length)
     */
    @Override
    public Set<Edge> apply(Set<Vertex> v, Map<Vertex, Set<Vertex>> e) {
        // Dense 0..n-1 index per vertex for UnionFind. A TreeMap orders by
        // Vertex.compareTo (unique id). List.indexOf would use Point.equals,
        // confusing distinct vertices that share coordinates, and costs O(V)
        // per lookup.
        final Map<Vertex, Integer> index = new TreeMap<>();
        for (Vertex vertex : v) {
            index.put(vertex, index.size());
        }

        // Candidate edges go in a List, not a sorted set. A comparator that
        // only looks at length would make a set silently drop distinct edges
        // of equal length. List.sort is stable and keeps all of them.
        final List<Edge> all = new ArrayList<>();
        e.forEach((start, ends) -> ends.forEach(finish -> all.add(new Edge(start, finish))));
        all.sort(Edge::compareTo);

        // Same reason: the result must not dedupe by comparator.
        final Set<Edge> mst = new LinkedHashSet<>();
        final UnionFind uf = new UnionFind(index.size());
        final int edgesNeeded = index.size() - 1;

        for (Edge edge : all) {
            if (mst.size() >= edgesNeeded) {
                break; // spanning tree complete
            }
            Integer p = index.get(edge.start);
            Integer q = index.get(edge.finish);
            if (p == null || q == null) {
                continue; // endpoint not part of the vertex set
            }
            if (!uf.connected(p, q)) {
                uf.union(p, q);
                mst.add(edge);
            }
        }

        return mst;
    }

}

/**
 * A directed edge from {@code start} to {@code finish}, weighted by the
 * Euclidean distance between the two vertices' current locations.
 * <p>
 * Natural ordering is by length. Ties are broken by the start id, then the
 * finish id, so two distinct edges of equal length never compare as equal.
 * Sorted sets therefore do not drop either of them. {@code equals} and
 * {@code hashCode} use the same endpoint ids, so the ordering is consistent
 * with equals.
 */
class Edge implements Comparable<Edge> {

    Vertex start, finish;

    public Edge(Vertex start, Vertex finish) {
        this.start = start;
        this.finish = finish;
    }

    /** Returns the Euclidean distance between the endpoints' current locations. */
    double length() {
        return start.distance(finish);
    }

    @Override
    public int compareTo(Edge o) {
        int byLength = Double.compare(length(), o.length());
        if (byLength != 0) {
            return byLength;
        }
        int byStart = Integer.compare(start.id, o.start.id);
        if (byStart != 0) {
            return byStart;
        }
        return Integer.compare(finish.id, o.finish.id);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Edge)) {
            return false;
        }
        Edge other = (Edge) obj;
        return start.id == other.start.id && finish.id == other.finish.id;
    }

    @Override
    public int hashCode() {
        return 31 * start.id + finish.id;
    }
}

class Vertex extends Point implements Comparable<Vertex> {

    /* Yes, this is a terrible O(n) stupid search. Could be smarter. */
    static Vertex getNearest(Collection<Vertex> points, Point p) {
        Vertex nearest = null;
        for (Vertex vertex : points) {
            if (nearest == null || p.distance(nearest) > p.distance(vertex)) {
                nearest = vertex;
            }
        }
        if (nearest == null || p.distance(nearest) > 20) {
            return null;
        } else {
            return nearest;
        }
    }

    final int id;
    static int counter = 0;

    public Vertex(Point p) {
        super(p);
        id = counter;
        counter++;
    }

    @Override
    public int compareTo(Vertex o) {
        return Integer.valueOf(id).compareTo(o.id);
    }
}

/* I don't fully understand this; it is mostly copied from Sedgewick & Wayne's
 quick-find union-find implementation:
 http://algs4.cs.princeton.edu/15uf/QuickFindUF.java.html

 Hopefully the code works. UCSD has a more approachable explanation of the
 union-find concept than Sedgewick & Wayne:

 http://www.math.ucsd.edu/~fan/teach/202/notes/01UnionFind.pdf

 All this really does is provide a fast (if confusing) "would this edge create
 a cycle?" check. Kruskal's algorithm inserts an edge into the minimum spanning
 tree only if its two endpoints are not already connected. In this quick-find
 variant, connected() is O(1) but union() is O(n).
 */
/**
 * Quick-find disjoint-set structure over the elements {@code 0 .. n-1}.
 * Each element stores the label of its component. Two elements are connected
 * exactly when their labels match.
 */
class UnionFind {

    /** id[i] is the component label of element i. */
    int[] id;
    /** Number of distinct components. */
    int count;

    /** Creates {@code n} singleton components, labelling element i with i. */
    public UnionFind(int n) {
        count = n;
        id = new int[n];
        int i = 0;
        while (i < n) {
            id[i] = i;
            i++;
        }
    }

    /**
     * Merges the components of {@code p} and {@code q} by relabelling every
     * member of p's component with q's label. No-op if already merged.
     */
    void union(int p, int q) {
        final int fromLabel = find(p);
        final int toLabel = find(q);
        if (fromLabel != toLabel) {
            for (int k = 0; k < id.length; k++) {
                if (id[k] == fromLabel) {
                    id[k] = toLabel;
                }
            }
            count--;
        }
    }

    /** Returns the component label of {@code p}. */
    int find(int p) {
        return id[p];
    }

    /** Returns true if {@code p} and {@code q} share a component. */
    boolean connected(int p, int q) {
        return find(p) == find(q);
    }

    /** Returns the number of components. */
    int count() {
        return count;
    }
}
