import javax.swing.*;
import javax.swing.event.*;
import javax.swing.tree.*;
/**
 * A binary search tree of distinct {@link Comparable} elements that can also be shown directly
 * in a {@link JTree}, because it implements {@link TreeModel}.
 *
 * <p>Elements are ordered by {@code compareTo}. Adding an element that compares equal to one
 * already present leaves the tree unchanged. The tree does not rebalance itself on insertion.
 * {@link #balance()} rebuilds it explicitly so that, at every node, the two subtree sizes differ
 * by at most one.
 *
 * <p>Each structural change ({@code add}, {@code remove}, {@code balance}) is reported to
 * registered {@link TreeModelListener}s as a {@code treeStructureChanged} event.
 *
 * <p>This class performs no synchronization. When it is displayed in a {@link JTree}, mutate it
 * only on the Swing event dispatch thread.
 *
 * @param <E> the element type
 */
public class BinSortTree<E extends Comparable<E>> implements SortTree<E>, TreeModel {
    /** Root of the tree, or {@code null} when the tree is empty. */
    private BinTreeNode<E> root = null;

    /** Registered {@link TreeModelListener}s; notified after every structural change. */
    private final EventListenerList listeners = new EventListenerList();

    /** Returns the number of elements in the tree. This walks the whole tree, so it is O(n). */
    public int size() {
        return size(this.root);
    }

    private static <E extends Comparable<E>> int size(BinTreeNode<E> node) {
        if (node == null) {
            return 0;
        }
        return 1 + size(node.left) + size(node.right);
    }

    /** Returns {@code true} if the tree holds an element that compares equal to {@code elem}. */
    public boolean contains(E elem) {
        return contains(this.root, elem);
    }

    private static <E extends Comparable<E>> boolean contains(BinTreeNode<E> node, E elem) {
        BinTreeNode<E> current = node;
        while (current != null) {
            int cmp = elem.compareTo(current.elem);
            if (cmp == 0) {
                return true;
            }
            current = cmp < 0 ? current.left : current.right;
        }
        return false;
    }

    /**
     * Inserts {@code elem} unless an equal element is already present, then notifies listeners.
     *
     * @param elem the element to insert; must not be {@code null}, because {@code compareTo} is
     *     called on it
     */
    public void add(E elem) {
        this.root = add(this.root, elem);
        fireTreeStructureChanged();
    }

    /** Inserts {@code elem} into the subtree in place and returns the subtree's (possibly new) root. */
    private static <E extends Comparable<E>> BinTreeNode<E> add(BinTreeNode<E> node, E elem) {
        if (node == null) {
            return new BinTreeNode<E>(elem, null, null);
        }
        int cmp = elem.compareTo(node.elem);
        if (cmp < 0) {
            node.left = add(node.left, elem);
        } else if (cmp > 0) {
            node.right = add(node.right, elem);
        }
        // cmp == 0: the element is already present, so there is nothing to do.
        return node;
    }

    /**
     * Removes the element that compares equal to {@code elem}, if any, then notifies listeners.
     *
     * @param elem the element to remove; must not be {@code null}
     */
    public void remove(E elem) {
        this.root = remove(this.root, elem);
        fireTreeStructureChanged();
    }

    /**
     * Returns the subtree with {@code elem} removed. Nodes on the search path are rebuilt, not
     * modified. Subtrees off that path are shared with the original tree.
     */
    private static <E extends Comparable<E>> BinTreeNode<E> remove(BinTreeNode<E> node, E elem) {
        if (node == null) {
            return null;
        }
        int cmp = elem.compareTo(node.elem);
        if (cmp < 0) {
            return new BinTreeNode<E>(node.elem, remove(node.left, elem), node.right);
        }
        if (cmp > 0) {
            return new BinTreeNode<E>(node.elem, node.left, remove(node.right, elem));
        }
        if (node.left == null) {
            return node.right;
        }
        if (node.right == null) {
            return node.left;
        }
        // Two children: replace this element with its in-order predecessor.
        E max = max(node.left);
        return new BinTreeNode<E>(max, remove(node.left, max), node.right);
    }

    /** Returns the largest element of a non-empty subtree (its rightmost node). */
    private static <E extends Comparable<E>> E max(BinTreeNode<E> node) {
        BinTreeNode<E> current = node;
        while (current.right != null) {
            current = current.right;
        }
        return current.elem;
    }

    /** Returns all elements in ascending order, as a new mutable list. */
    public java.util.List<E> traverse() {
        return traverse(this.root, new java.util.ArrayList<E>());
    }

    /** Appends the subtree's elements to {@code result} in order and returns {@code result}. */
    private static <E extends Comparable<E>> java.util.List<E> traverse(
            BinTreeNode<E> node, java.util.List<E> result) {
        if (node != null) {
            traverse(node.left, result);
            result.add(node.elem);
            traverse(node.right, result);
        }
        return result;
    }

    /**
     * Rebuilds the tree so that, at every node, the two subtree sizes differ by at most one. Then
     * notifies listeners. Runs in O(n), and every node is replaced.
     */
    public void balance() {
        java.util.List<E> sorted = traverse(this.root, new java.util.ArrayList<E>());
        this.root = build(sorted, 0, sorted.size());
        fireTreeStructureChanged();
    }

    /**
     * Builds a size-balanced tree from {@code sorted[from, to)}. The range must be in ascending
     * order; the middle element becomes the root.
     */
    private static <E extends Comparable<E>> BinTreeNode<E> build(
            java.util.List<E> sorted, int from, int to) {
        if (from >= to) {
            return null;
        }
        int mid = (from + to) >>> 1;
        return new BinTreeNode<E>(
                sorted.get(mid), build(sorted, from, mid), build(sorted, mid + 1, to));
    }

    /** Demo: fills a tree with letters, balances it and shows it in a {@link JTree} window. */
    public static void main(String[] args) {
        final BinSortTree<Character> tree = new BinSortTree<Character>();
        for (Character arg : new Character[] {'D','B','C','A','I','E','G','H','F'}) {
            tree.add(arg);
        }
        for (Character arg : new Character[] {'J','K','L','M','Z','Y','X','W','V','N','O','U','T','S','R','P','Q'}) {
            tree.add(arg);
        }
        tree.balance();
        // Swing components must be created and shown on the event dispatch thread.
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                JScrollPane content = new JScrollPane(new JTree(tree));
                JFrame window = new JFrame("BinSortTree");
                window.setContentPane(content);
                window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
                window.pack();
                window.setVisible(true);
            }
        });
    }

    /**
     * Tells every registered listener that the whole tree, from the root down, may have changed.
     * When the tree is empty the event's path is {@code null}.
     */
    private void fireTreeStructureChanged() {
        TreePath rootPath = this.root == null ? null : new TreePath(this.root);
        TreeModelEvent event = new TreeModelEvent(this, rootPath);
        for (TreeModelListener listener : this.listeners.getListeners(TreeModelListener.class)) {
            listener.treeStructureChanged(event);
        }
    }

    /** Casts a node object handed back by Swing; every node this model exposes is a {@code BinTreeNode<E>}. */
    @SuppressWarnings("unchecked")
    private BinTreeNode<E> asNode(Object obj) {
        return (BinTreeNode<E>) obj;
    }

    @Override
    public void addTreeModelListener(TreeModelListener listener) {
        this.listeners.add(TreeModelListener.class, listener);
    }

    /**
     * Returns the child of {@code parent} at {@code index}. Children are listed left then right,
     * skipping absent ones, so a lone right child has index 0.
     */
    @Override
    public Object getChild(Object parent, int index) {
        BinTreeNode<E> node = asNode(parent);
        if (index == 0 && node.left != null) {
            return node.left;
        }
        return node.right;
    }

    /** Returns the number of non-null children of {@code parent} (0, 1 or 2). */
    @Override
    public int getChildCount(Object parent) {
        BinTreeNode<E> node = asNode(parent);
        return (node.left == null ? 0 : 1) + (node.right == null ? 0 : 1);
    }

    /**
     * Returns the index of {@code child} under {@code parent}, using the same numbering as
     * {@link #getChild}. Returns -1 if either argument is {@code null} or if {@code child} is
     * not a child of {@code parent}.
     */
    @Override
    public int getIndexOfChild(Object parent, Object child) {
        if (!(parent instanceof BinTreeNode<?>) || child == null) {
            return -1;
        }
        BinTreeNode<E> node = asNode(parent);
        if (node.left != null && node.left.equals(child)) {
            return 0;
        }
        if (node.right != null && node.right.equals(child)) {
            return node.left == null ? 0 : 1;
        }
        return -1;
    }

    /** Returns the root node, or {@code null} when the tree is empty. */
    @Override
    public Object getRoot() {
        return this.root;
    }

    /** Returns {@code true} if {@code node} has no children. */
    @Override
    public boolean isLeaf(Object node) {
        return this.getChildCount(node) == 0;
    }

    @Override
    public void removeTreeModelListener(TreeModelListener listener) {
        this.listeners.remove(TreeModelListener.class, listener);
    }

    /** Editing through the view is not supported; the call is ignored. */
    @Override
    public void valueForPathChanged(TreePath path, Object newValue) {}
}
