// SortedValueSetImpl.java

package sprouts.impl;

import org.jspecify.annotations.Nullable;
import sprouts.Val;
import sprouts.ValueSet;

import java.util.*;
import java.util.stream.Stream;

import static sprouts.impl.ArrayUtil.*;

final class SortedValueSetImpl<E> implements ValueSet<E> {

    /**
     * Flag handed to every {@code _createArray} call in this class.
     * NOTE(review): presumably marks the backing arrays as not permitting {@code null}
     * entries; confirm against {@code ArrayUtil._createArray}.
     */
    private static final boolean ALLOWS_NULL = false;
    /**
     * Shared childless node created from an array requested with length 0.
     * It is the root of a freshly constructed set and is substituted in {@code remove}
     * whenever removing an element leaves no node behind.
     */
    private static final Node NULL_NODE = new Node(
            _createArray(Object.class, ALLOWS_NULL, 0)
    );

    /**
     * Computes how many elements the local array of a node at the given depth may hold
     * before insertions are pushed down into its subtrees.
     * The value is {@code depth * depth / 2} (integer division) but never less than 1.
     *
     * @param depth the depth value passed along by the caller (0 for the root).
     * @return the capacity for that depth, at least 1.
     */
    private static int BASE_ENTRIES_PER_NODE(int depth) {
        final int quadraticCapacity = (depth * depth) / 2;
        return quadraticCapacity < 1 ? 1 : quadraticCapacity;
    }

    /**
     * A tree node made of a local element array plus optional left and right subtrees.
     * All fields are final; "modification" is expressed through the {@code withNew...}
     * methods, which return fresh nodes. The cached size counts the local elements
     * together with the sizes of both subtrees.
     */
    static class Node {
        private final int _size;
        private final Object _elementsArray;
        private final @Nullable Node _left;
        private final @Nullable Node _right;

        /** Creates a childless node around the given element array. */
        Node(Object elementsArray) {
            this(elementsArray, null, null);
        }

        /**
         * Creates a node and derives its size from the array length and the
         * sizes of the (optional) subtrees.
         */
        Node(Object elementsArray, @Nullable Node left, @Nullable Node right) {
            int total = _length(elementsArray);
            if (left != null) {
                total += left.size();
            }
            if (right != null) {
                total += right.size();
            }
            _size = total;
            _elementsArray = elementsArray;
            _left = left;
            _right = right;
        }

        /**
         * Creates a node with a caller-supplied size; the size is stored as given
         * and is not checked against the array or the subtrees.
         */
        Node(int size, Object elementsArray, @Nullable Node left, @Nullable Node right) {
            _elementsArray = elementsArray;
            _left = left;
            _right = right;
            _size = size;
        }

        /** @return the array object holding this node's local elements. */
        public Object elementsArray() {
            return _elementsArray;
        }

        /** @return the left subtree, or {@code null} if there is none. */
        public @Nullable Node left() {
            return _left;
        }

        /** @return the right subtree, or {@code null} if there is none. */
        public @Nullable Node right() {
            return _right;
        }

        /** @return the cached element count of this node including both subtrees. */
        public int size() {
            return _size;
        }

        /** @return a node with the same subtrees but the given element array (size is recomputed). */
        public Node withNewArrays(Object newElementsArray) {
            return new Node(newElementsArray, _left, _right);
        }

        /** @return a node with the same elements and right subtree but the given left subtree. */
        public Node withNewLeft(@Nullable Node left) {
            return new Node(_elementsArray, left, _right);
        }

        /** @return a node with the same elements and left subtree but the given right subtree. */
        public Node withNewRight(@Nullable Node right) {
            return new Node(_elementsArray, _left, right);
        }

        /** Combines the size, the {@code Val.hashCode} of the element array and both subtrees. */
        @Override
        public int hashCode() {
            final int arrayHash = Val.hashCode(_elementsArray);
            return Objects.hash(_size, arrayHash, _left, _right);
        }

        /**
         * Two nodes are equal when their sizes match, their element arrays are equal
         * according to {@code Val.equals} and both subtrees are equal.
         */
        @Override
        public boolean equals(@Nullable Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Node)) {
                return false;
            }
            final Node that = (Node) obj;
            if (_size != that._size) {
                return false;
            }
            if (!Val.equals(_elementsArray, that._elementsArray)) {
                return false;
            }
            if (!Objects.equals(_left, that._left)) {
                return false;
            }
            return Objects.equals(_right, that._right);
        }
    }

    /** Element class; {@code contains}, {@code add} and {@code remove} reject instances not assignable to it. */
    private final Class<E> _type;
    /** Comparator handed to every binary search and tree update performed by this set. */
    private final Comparator<E> _comparator;
    /** Root of the node tree; a newly created set starts out with {@code NULL_NODE}. */
    private final Node _root;

    /**
     * Creates an empty set rooted at {@code NULL_NODE}.
     *
     * @param type the element class of the set.
     * @param comparator the comparator used to order and locate elements.
     */
    SortedValueSetImpl(
            final Class<E> type,
            final Comparator<E> comparator
    ) {
        this(type, comparator, NULL_NODE);
    }

    /**
     * Wraps an already built node tree. The arguments are stored as given,
     * without any validation.
     *
     * @param type the element class of the set.
     * @param comparator the comparator used to order and locate elements.
     * @param root the root node of the tree backing the new set.
     */
    private SortedValueSetImpl(
            final Class<E> type,
            final Comparator<E> comparator,
            final Node root
    ) {
        _type = type;
        _comparator = comparator;
        _root = root;
    }

    /** @return the number of elements, read from the size cached in the root node. */
    @Override
    public int size() {
        return _root.size();
    }

    /** @return always {@code true} for this implementation. */
    @Override
    public boolean isSorted() {
        return true;
    }

    /** @return the element class this set was created with. */
    @Override
    public Class<E> type() {
        return _type;
    }

    /**
     * Tells whether the given element is stored in this set.
     *
     * @param element the element to look for, must not be {@code null}.
     * @return {@code true} if the tree search finds the element.
     * @throws NullPointerException if {@code element} is {@code null}.
     * @throws ClassCastException if {@code element} is not an instance of {@link #type()}.
     */
    @Override
    public boolean contains(E element) {
        Objects.requireNonNull(element, "Null element");
        if (!_type.isInstance(element)) {
            throw new ClassCastException("Element type mismatch");
        }
        final E match = _findElement(_root, _type, _comparator, element);
        return match != null;
    }

    /**
     * Searches the subtree rooted at {@code node} for {@code element}.
     * The binary search over the local array decides where to continue: a negative
     * index descends into the left subtree, an index at or past the array length
     * descends into the right subtree, and any other index is compared locally.
     *
     * @return the stored element equal to {@code element}, or {@code null} if none was found.
     */
    private static <E> @Nullable E _findElement(
            Node node,
            Class<E> type,
            Comparator<E> comparator,
            E element
    ) {
        final Object localElements = node.elementsArray();
        final int localCount = _length(localElements);
        final int position = _binarySearch(localElements, type, comparator, element);

        if (position < 0) {
            final Node leftChild = node.left();
            return leftChild == null ? null : _findElement(leftChild, type, comparator, element);
        }
        if (position >= localCount) {
            final Node rightChild = node.right();
            return rightChild == null ? null : _findElement(rightChild, type, comparator, element);
        }
        // The index points into the local array: it is a hit only if the stored value is equal.
        if (Objects.equals(element, _getAt(position, localElements, type))) {
            return _getAt(position, localElements, type);
        }
        return null;
    }

    /**
     * Returns a set that contains the given element in addition to the elements of this set.
     *
     * @param element the element to add, must not be {@code null}.
     * @return this very instance if the tree root did not change, otherwise a new set
     *         sharing the type and comparator of this one.
     * @throws NullPointerException if {@code element} is {@code null}.
     * @throws ClassCastException if {@code element} is not assignable to {@link #type()}.
     */
    @Override
    public ValueSet<E> add(E element) {
        if (element == null) {
            throw new NullPointerException("Null element");
        }
        final Class<?> actualType = element.getClass();
        if (!_type.isAssignableFrom(actualType)) {
            throw new ClassCastException("Element type mismatch");
        }
        final Node grownRoot = _updateElement(_root, _type, _comparator, element, 0);
        final Node balancedRoot = _balance(grownRoot);
        if (!Util.refEquals(balancedRoot, _root)) {
            return new SortedValueSetImpl<>(_type, _comparator, balancedRoot);
        }
        return this;
    }

    /**
     * Returns a set containing the elements of this set plus every element of the stream.
     * <p>
     * The stream is consumed through its iterator, one element after the other in
     * encounter order. A {@code reduce} whose combiner simply keeps one of the two
     * partial results must not be used here: on a parallel stream it would silently
     * drop every element that was accumulated into the discarded partial result.
     *
     * @param elements the elements to add, must not be {@code null}.
     * @return the set resulting from adding each element in turn (this instance if nothing changed).
     * @throws NullPointerException if {@code elements} or one of its elements is {@code null}.
     * @throws ClassCastException if an element is not assignable to {@link #type()}.
     */
    @Override
    public ValueSet<E> addAll(Stream<? extends E> elements) {
        Objects.requireNonNull(elements);
        // TODO: implement branching based bulk insert
        ValueSet<E> result = this;
        final Iterator<? extends E> source = elements.iterator();
        while (source.hasNext()) {
            result = result.add(source.next());
        }
        return result;
    }

    /**
     * Returns a node that represents {@code node} with {@code key} added to its subtree,
     * or {@code node} itself when nothing had to change (the key is already stored, or
     * the recursive update of a child returned that same child).
     * <p>
     * The index returned by {@code _binarySearch} steers the update: a negative index
     * sends the key to the left subtree, an index at or past the local array length
     * sends it to the right subtree, and any other index addresses a slot of the local array.
     * NOTE(review): this presumably means "smaller than all local keys" / "larger than all
     * local keys" / "insertion point or match" — confirm against {@code ArrayUtil._binarySearch}.
     *
     * @param node the root of the subtree to update.
     * @param keyType the element class, used for array creation and element access.
     * @param keyComparator the comparator handed to the binary search.
     * @param key the element to add.
     * @param depth the depth of {@code node} as tracked by the callers (0 for the root,
     *              incremented per recursion); it determines the local array capacity
     *              through {@code BASE_ENTRIES_PER_NODE(depth)}.
     * @return the updated node, or {@code node} if no change was necessary.
     */
    private static <E> Node _updateElement(
            Node node,
            Class<E> keyType,
            Comparator<E> keyComparator,
            E key,
            int depth
    ) {
        int numberOfKeys = _length(node.elementsArray());
        int index = _binarySearch(node.elementsArray(), keyType, keyComparator, key);
        // Despite the name, this only says that the index addresses a local slot;
        // whether the key is really stored there is checked further below.
        boolean foundInCurrentNode = index >= 0 && index < numberOfKeys;
        boolean leftAndRightAreNull = node.left() == null && node.right() == null;
        if ( leftAndRightAreNull && !foundInCurrentNode && numberOfKeys < BASE_ENTRIES_PER_NODE(depth) ) {
            // Leaf node with spare capacity and the key lies outside the local range:
            // grow the local array by one slot at the front or at the back.
            Object newKeysArray = _createArray(keyType, ALLOWS_NULL, numberOfKeys+1);
            if ( index < 0 ) {
                // Prepend: shift the existing keys one slot to the right.
                if ( numberOfKeys > 0 ) {
                    System.arraycopy(node.elementsArray(), 0, newKeysArray, 1, numberOfKeys);
                }
                _setAt(0, key, newKeysArray);
            } else {
                // Append: keep the existing keys in place and put the new key last.
                if ( numberOfKeys > 0 ) {
                    System.arraycopy(node.elementsArray(), 0, newKeysArray, 0, numberOfKeys);
                }
                _setAt(numberOfKeys, key, newKeysArray);
            }
            return node.withNewArrays(newKeysArray);
        }
        if ( index < 0 ) {
            Node left = node.left();
            if ( left != null ) {
                Node newLeft = _balance(_updateElement(left, keyType, keyComparator, key, depth+1));
                if ( Util.refEquals(newLeft, left) ) {
                    return node; // No change in the left node
                }
                return node.withNewLeft(newLeft);
            } else { // Left is null, we create a new node
                Object newKeysArray = _createArray(keyType, ALLOWS_NULL, 1);
                _setAt(0, key, newKeysArray);
                return node.withNewLeft(new Node(newKeysArray));
            }
        }
        if ( index >= numberOfKeys ) {
            Node right = node.right();
            if ( right != null ) {
                Node newRight = _balance(_updateElement(right, keyType, keyComparator, key, depth+1));
                if ( Util.refEquals(newRight, right) ) {
                    return node; // No change in the right node
                }
                return node.withNewRight(newRight);
            } else { // Right is null, we create a new node
                Object newKeysArray = _createArray(keyType, ALLOWS_NULL, 1);
                _setAt(0, key, newKeysArray);
                return node.withNewRight(new Node(newKeysArray));
            }
        }

        // From here on the index addresses a slot of the local array.
        boolean keyAlreadyExists = Objects.equals(key, _getAt(index, node.elementsArray(), keyType));
        if ( !keyAlreadyExists ) {
            if ( numberOfKeys < BASE_ENTRIES_PER_NODE(depth) ) {
                // There is still room: insert the key at 'index' in a copy that is one slot larger.
                Object newKeysArray = _createArray(keyType, ALLOWS_NULL, numberOfKeys + 1);
                // Keys in front of the insertion point keep their position...
                System.arraycopy(node.elementsArray(), 0, newKeysArray, 0, index);
                _setAt(index, key, newKeysArray);
                // ...and the remaining keys move one slot to the right.
                System.arraycopy(node.elementsArray(), index, newKeysArray, index + 1, numberOfKeys - index);
                return node.withNewArrays(newKeysArray);
            } else {
                /*
                    Ok, so this is an interesting case. We have a full node, and we need to INSERT a new key
                    somewhere in the middle of the node. We do this by popping an excess key from
                    one of the sides of the local array and then let this popped-off key trickle down
                    to the left or right side of the tree.
                    The side is chosen by comparing the local array lengths of the two direct
                    children (a missing child counts as 0): the child with fewer local keys
                    receives the popped key, ties go to the right.
                */
                Object newKeysArray = _createArray(keyType, ALLOWS_NULL, numberOfKeys);
                int numberOfEntriesLeft = node.left() == null ? 0 : _length(node.left().elementsArray());
                int numberOfEntriesRight = node.right() == null ? 0 : _length(node.right().elementsArray());
                if ( numberOfEntriesLeft < numberOfEntriesRight ) {
                    if ( index == 0 ) {
                        // The new key would become the first local key, so instead of
                        // popping anything it is handed directly to the left subtree.
                        Node newLeft;
                        if ( node.left() != null ) {
                            // Add the new key to the existing left node
                            newLeft = _balance(_updateElement(node.left(), keyType, keyComparator, key, depth+1));
                        } else {
                            newLeft = _createSingleEntryNode(keyType, key);
                        }
                        return node.withNewLeft(newLeft);
                    }
                    // Pop the first local key and push it down into the left subtree.
                    E poppedOffKey = _getNonNullAt(0, node.elementsArray(), keyType);
                    Node newLeft;
                    if ( node.left() != null ) {
                        // Re-add the popped key to the left node
                        newLeft = _balance(_updateElement(node.left(), keyType, keyComparator, poppedOffKey, depth+1));
                    } else {
                        newLeft = _createSingleEntryNode(keyType, poppedOffKey);
                    }
                    // We pop from the left
                    if ( numberOfKeys == 1 ) {
                        // The single local slot now holds the new key; the old one lives in the new left node
                        _setAt(0, key, newKeysArray);
                    } else {
                        // First, insert the key at the index (shifted by one because the first key was popped)
                        _setAt(index-1, key, newKeysArray);
                        // Then, copy the keys between the popped one and the insertion point, moved one slot left
                        System.arraycopy(node.elementsArray(), 1, newKeysArray, 0, index-1);
                        // Finally, copy the rest of the keys, which keep their positions
                        System.arraycopy(node.elementsArray(), index, newKeysArray, index, numberOfKeys - index);
                    }
                    return new Node(newKeysArray, newLeft, node.right());
                } else {
                    if ( index == numberOfKeys ) {
                        // The new key would come after the last local key, so it is handed
                        // directly to the right subtree.
                        Node newRight;
                        if ( node.right() != null ) {
                            // Add the new key to the existing right node
                            newRight = _balance(_updateElement(node.right(), keyType, keyComparator, key, depth+1));
                        } else {
                            newRight = _createSingleEntryNode(keyType, key);
                        }
                        return node.withNewRight(newRight);
                    }
                    // Pop the last local key and push it down into the right subtree.
                    E poppedOffKey = _getNonNullAt(numberOfKeys-1, node.elementsArray(), keyType);
                    Node newRight;
                    if ( node.right() != null ) {
                        // Re-add the popped key to the right node
                        newRight = _balance(_updateElement(node.right(), keyType, keyComparator, poppedOffKey, depth+1));
                    } else {
                        newRight = _createSingleEntryNode(keyType, poppedOffKey);
                    }
                    // We pop from the right
                    if ( numberOfKeys == 1 ) {
                        // The single local slot now holds the new key; the old one lives in the new right node
                        _setAt(0, key, newKeysArray);
                    } else {
                        // First, insert the key at the index
                        _setAt(index, key, newKeysArray);
                        // Then, copy the keys in front of the index, which keep their positions
                        System.arraycopy(node.elementsArray(), 0, newKeysArray, 0, index);
                        // Finally, copy the rest of the keys (except the popped last one), moved one slot right
                        System.arraycopy(node.elementsArray(), index, newKeysArray, index+1, numberOfKeys - index - 1);
                    }
                    return new Node(newKeysArray, node.left(), newRight);
                }
            }
        }
        // The key is already stored at 'index': nothing to do.
        return node;
    }

    /**
     * Null-tolerant variant of {@code _balance}.
     *
     * @return {@code null} for a {@code null} node, otherwise the result of balancing it.
     */
    private static @Nullable Node _balanceNullable(@Nullable Node node){
        return node == null ? null : _balance(node);
    }

    /**
     * Performs at most one rotation on the given node if doing so reduces the size
     * difference between its left and right subtree.
     * <p>
     * If the right subtree is larger, the right child becomes the new top node and the
     * current node (with its left subtree and the right child's left subtree) becomes the
     * new left child. If the left subtree is larger, the mirrored rotation is applied.
     * The rotation is only carried out when the absolute size difference after the
     * rotation is strictly smaller than before; otherwise {@code node} is returned as is.
     *
     * @param node the node to balance.
     * @return the rotated top node, or {@code node} itself if no rotation was worthwhile.
     */
    private static Node _balance(Node node){

        final Node right = node.right();
        final Node left = node.left();
        final int leftSize = left == null ? 0 : left.size();
        final int rightSize = right == null ? 0 : right.size();
        if ( leftSize == rightSize ) {
            return node;
        }
        final int currentNodeArraySize = _length(node.elementsArray());
        if ( leftSize < rightSize && right != null ) {
            final int imbalance = rightSize - leftSize;
            final int rightArraySize = _length(right.elementsArray());
            final int rightLeftSize = right.left() == null ? 0 : right.left().size();
            // After a left rotation the right subtree is what used to be right.right()...
            final int newRightSize = rightSize - rightLeftSize - rightArraySize;
            // ...and the left subtree is the old left, the local elements and right.left().
            final int newLeftSize = leftSize + rightLeftSize + currentNodeArraySize;
            final int newImbalance = Math.abs(newRightSize - newLeftSize);
            if ( newImbalance < imbalance ) { // We only re-balance if it is worth it!
                Node newLeft = new Node(newLeftSize, node.elementsArray(), left, right.left());
                return new Node(
                        node.size(), right.elementsArray(), newLeft, right.right()
                    );
            }
        }
        if ( rightSize < leftSize && left != null ) {
            final int imbalance = leftSize - rightSize;
            final int leftArraySize = _length(left.elementsArray());
            final int leftRightSize = left.right() == null ? 0 : left.right().size();
            // After a right rotation the left subtree is what used to be left.left()...
            final int newLeftSize = leftSize - leftRightSize - leftArraySize;
            // ...and the right subtree is left.right(), the local elements and the old right.
            // This value is also stored as the size of the new right node, so it has to be exact.
            final int newRightSize = rightSize + leftRightSize + currentNodeArraySize;
            final int newImbalance = Math.abs(newLeftSize - newRightSize);
            if ( newImbalance < imbalance ) { // We only re-balance if it is worth it!
                Node newRight = new Node(newRightSize, node.elementsArray(), left.right(), right);
                return new Node(
                        node.size(), left.elementsArray(), left.left(), newRight
                    );
            }
        }
        return node;
    }

    /**
     * Builds a childless node whose local array holds exactly the given key.
     *
     * @param keyType the element class used to create the backing array.
     * @param key the single element to store.
     * @return a new node without subtrees.
     */
    private static Node _createSingleEntryNode(
            Class<?> keyType, Object key
    ) {
        final Object singleSlotArray = _createArray(keyType, ALLOWS_NULL, 1);
        _setAt(0, key, singleSlotArray);
        return new Node(singleSlotArray, null, null);
    }

    /**
     * Returns a set without the given element.
     *
     * @param element the element to remove, must not be {@code null}.
     * @return this very instance if the tree root did not change, otherwise a new set
     *         sharing the type and comparator of this one.
     * @throws NullPointerException if {@code element} is {@code null}.
     * @throws ClassCastException if {@code element} is not an instance of {@link #type()}.
     */
    @Override
    public ValueSet<E> remove(E element) {
        Objects.requireNonNull(element, "Null element");
        if (!_type.isInstance(element)) {
            throw new ClassCastException("Element type mismatch");
        }
        final Node remainder = _balanceNullable(_removeElement(_root, _type, _comparator, element));
        // A tree that lost its last node is represented by the shared empty node.
        final Node newRoot = remainder != null ? remainder : NULL_NODE;
        if ( !Util.refEquals(newRoot, _root) ) {
            return new SortedValueSetImpl<>(_type, _comparator, newRoot);
        }
        return this;
    }

    /**
     * Returns a set without any of the elements supplied by the stream.
     * <p>
     * The stream is consumed through its iterator, one element after the other in
     * encounter order. A {@code reduce} whose combiner simply keeps one of the two
     * partial results must not be used here: on a parallel stream it would silently
     * ignore every removal that was accumulated into the discarded partial result.
     *
     * @param elements the elements to remove.
     * @return this instance if this set is empty, otherwise the set resulting
     *         from removing each element in turn.
     * @throws NullPointerException if this set is not empty and {@code elements}
     *         or one of its elements is {@code null}.
     * @throws ClassCastException if an element is not assignable to {@link #type()}.
     */
    @Override
    public ValueSet<E> removeAll(Stream<? extends E> elements) {
        if ( this.isEmpty() )
            return this;
        ValueSet<E> result = this;
        final Iterator<? extends E> source = elements.iterator();
        while ( source.hasNext() ) {
            result = result.remove(source.next());
        }
        return result;
    }

    /**
     * Returns a set holding only those elements of this set that are also
     * contained in the given set.
     *
     * @param elements the elements to keep.
     * @return this instance if this set is empty, the result of {@link #clear()} if
     *         {@code elements} is empty, otherwise this set with every element
     *         not contained in {@code elements} removed.
     */
    @Override
    public ValueSet<E> retainAll(Set<? extends E> elements) {
        if ( this.isEmpty() ) {
            return this;
        }
        if ( elements.isEmpty() ) {
            return clear();
        }
        ValueSet<E> survivors = this;
        final Iterator<E> own = this.iterator();
        while ( own.hasNext() ) {
            final E candidate = own.next();
            if ( elements.contains(candidate) ) {
                continue;
            }
            survivors = survivors.remove(candidate);
        }
        return survivors;
    }

    /**
     * Returns a node representing {@code node} without {@code key} in its subtree.
     * <p>
     * The index returned by {@code _binarySearch} steers the search: a negative index
     * continues in the left subtree, an index at or past the local array length continues
     * in the right subtree, and any other index addresses a slot of the local array.
     * If a subtree comes back unchanged, {@code node} itself is returned instead of a
     * copy, so that callers can detect a no-op by reference comparison.
     *
     * @param node the root of the subtree to remove from.
     * @param keyType the element class, used for array creation and element access.
     * @param keyComparator the comparator handed to the binary search.
     * @param key the element to remove.
     * @return the updated node, {@code node} itself if the key was not found, or
     *         {@code null} if the node held only this key and had no subtrees.
     */
    private static <E> @Nullable Node _removeElement(
            Node node,
            Class<E> keyType,
            Comparator<E> keyComparator,
            E key
    ) {
        int numberOfKeys = _length(node.elementsArray());
        int index = _binarySearch(node.elementsArray(), keyType, keyComparator, key);
        if ( index < 0 ) {
            Node left = node.left();
            if ( left == null ) {
                return node; // Key not found
            }
            Node newLeft = _balanceNullable(_removeElement(left, keyType, keyComparator, key));
            if ( newLeft != null && Util.refEquals(newLeft, left) ) {
                return node; // No change in the left node
            }
            return node.withNewLeft(newLeft);
        }
        if ( index >= numberOfKeys ) {
            Node right = node.right();
            if ( right == null ) {
                return node; // Key not found
            }
            Node newRight = _balanceNullable(_removeElement(right, keyType, keyComparator, key));
            if ( newRight != null && Util.refEquals(newRight, right) ) {
                return node; // No change in the right node
            }
            return node.withNewRight(newRight);
        }
        boolean keyAlreadyExists = Objects.equals(key, _getAt(index, node.elementsArray(), keyType));
        if ( !keyAlreadyExists ) {
            return node; // The slot holds a different key, nothing to remove
        }
        if ( numberOfKeys == 1 ) {
            Node left = node.left();
            Node right = node.right();
            if ( left == null || right == null ) {
                // At most one subtree: it simply takes the place of this node.
                if ( left != null ) {
                    return left;
                }
                if ( right != null ) {
                    return right;
                }
                return null;
            }
            // Both subtrees exist and this node would become empty, so a replacement key is
            // pulled up from the larger subtree (the right one on ties): the right-most key
            // of the left subtree or the left-most key of the right subtree.
            Object newKeysArray = _createArray(keyType, ALLOWS_NULL, 1);
            int leftSize = left.size();
            int rightSize = right.size();
            if ( leftSize > rightSize ) {
                E rightMostKey = _findRightMostElement(left, keyType);
                _setAt(0, rightMostKey, newKeysArray);
                left = _balanceNullable(_removeElement(left, keyType, keyComparator, rightMostKey));
            } else {
                E leftMostKey = _findLeftMostElement(right, keyType);
                _setAt(0, leftMostKey, newKeysArray);
                right = _balanceNullable(_removeElement(right, keyType, keyComparator, leftMostKey));
            }
            return new Node(node._size - 1, newKeysArray, left, right);
        }
        // We found the key among several local keys: copy the array without the slot at 'index'.
        Object newKeysArray = _createArray(keyType, ALLOWS_NULL, numberOfKeys-1);
        System.arraycopy(node.elementsArray(), 0, newKeysArray, 0, index);
        System.arraycopy(node.elementsArray(), index+1, newKeysArray, index, numberOfKeys-index-1);
        return node.withNewArrays(newKeysArray);
    }

    /**
     * Follows the right children down to the last node and returns the final
     * entry of that node's local array.
     *
     * @param node the root of the subtree to inspect.
     * @param type the element class used for element access.
     * @return the last element of the right-most node.
     */
    private static <E> E _findRightMostElement(Node node, Class<E> type) {
        Node current = node;
        Node next = current.right();
        while (next != null) {
            current = next;
            next = current.right();
        }
        final int lastIndex = _length(current.elementsArray()) - 1;
        return _getNonNullAt(lastIndex, current.elementsArray(), type);
    }

    /**
     * Follows the left children down to the last node and returns the first
     * entry of that node's local array.
     *
     * @param node the root of the subtree to inspect.
     * @param type the element class used for element access.
     * @return the first element of the left-most node.
     */
    private static <E> E _findLeftMostElement(Node node, Class<E> type) {
        Node current = node;
        Node next = current.left();
        while (next != null) {
            current = next;
            next = current.left();
        }
        return _getNonNullAt(0, current.elementsArray(), type);
    }

    /**
     * Requests a sorted value set for the same element type and comparator
     * from the {@code Sprouts} factory and returns it.
     *
     * @return the set supplied by {@code Sprouts.factory().valueSetOfSorted(..)}.
     */
    @Override
    public ValueSet<E> clear() {
        return Sprouts.factory().valueSetOfSorted(_type, _comparator);
    }

    /**
     * Creates a spliterator on top of {@link #iterator()} that reports the exact
     * size of this set.
     *
     * @return a spliterator flagged as sorted, ordered, distinct, sized, sub-sized,
     *         non-null and immutable.
     */
    @Override
    public Spliterator<E> spliterator() {
        final int characteristics =
                Spliterator.SORTED
                | Spliterator.ORDERED
                | Spliterator.DISTINCT
                | Spliterator.SIZED
                | Spliterator.SUBSIZED
                | Spliterator.NONNULL
                | Spliterator.IMMUTABLE;
        final int exactSize = _root.size();
        return Spliterators.spliterator(iterator(), exactSize, characteristics);
    }

    /**
     * Creates an iterator that walks the node tree in order: for every node first the
     * left subtree, then the local elements, then the right subtree. The traversal state
     * is kept in a linked chain of {@link IteratorFrame}s instead of the call stack.
     *
     * @return an iterator over all elements of this set.
     */
    @Override
    public Iterator<E> iterator() {
        return new Iterator<E>() {
            // The frame currently being processed; null once the traversal is finished
            // (or right from the start if the set is empty).
            private @Nullable IteratorFrame currentFrame =
                    _root.size() > 0 ? new IteratorFrame(null, _root) : null;

            @Override
            public boolean hasNext() {
                @Nullable IteratorFrame frame = currentFrame;
                while (frame != null) {
                    switch (frame.stage) {
                        case 0: { // Descend into the left subtree first, if there is one.
                            frame.stage = 1;
                            final Node leftChild = frame.node.left();
                            if (leftChild != null) {
                                frame = new IteratorFrame(frame, leftChild);
                            }
                            break;
                        }
                        case 1: { // Hand out the local elements.
                            if (frame.index < _length(frame.node.elementsArray())) {
                                currentFrame = frame;
                                return true;
                            }
                            frame.stage = 2;
                            break;
                        }
                        case 2: { // Then descend into the right subtree, if there is one.
                            frame.stage = 3;
                            final Node rightChild = frame.node.right();
                            if (rightChild != null) {
                                frame = new IteratorFrame(frame, rightChild);
                            }
                            break;
                        }
                        default: { // This node is done, continue with its parent.
                            frame = frame.parent;
                            break;
                        }
                    }
                }
                currentFrame = null;
                return false;
            }

            @Override
            public E next() {
                final IteratorFrame frame = hasNext() ? currentFrame : null;
                if (frame == null) {
                    throw new NoSuchElementException();
                }
                final E element = _getNonNullAt(frame.index, frame.node.elementsArray());
                frame.index += 1;
                return element;
            }
        };
    }

    /**
     * One entry of the explicit traversal stack used by {@link #iterator()}:
     * it remembers the node being visited, how far the visit has progressed
     * and which frame to return to afterwards.
     */
    static class IteratorFrame {
        /** The frame to continue with once this one is done; {@code null} for the root frame. */
        final @Nullable IteratorFrame parent;
        /** The node this frame iterates over. */
        final Node node;
        /** Progress marker: 0=left, 1=elements, 2=right, 3=done. */
        byte stage;
        /** Position of the next local element to hand out. */
        int index;

        IteratorFrame(@Nullable IteratorFrame parent, Node n) {
            this.node = n;
            this.parent = parent;
            this.stage = 0;
            this.index = 0;
        }
    }

    /**
     * Renders the set as {@code SortedValueSet<Type>[e1, e2, ...]}. At most 8 elements
     * are printed; if there are more, the number of remaining items is reported instead.
     *
     * @return the textual representation of this set.
     */
    @Override
    public String toString() {
        final int MAX_ITEMS = 8;
        final StringBuilder text = new StringBuilder("SortedValueSet<" + _type.getSimpleName() + ">[");
        final Iterator<E> elements = iterator();
        for (int shown = 0; elements.hasNext(); shown++) {
            if (shown >= MAX_ITEMS) {
                final int itemsLeft = _root.size() - shown;
                text.append("... ").append(itemsLeft).append(" items left");
                break;
            }
            text.append(_toString(elements.next(), _type));
            if (elements.hasNext()) {
                text.append(", ");
            }
        }
        text.append("]");
        return text.toString();
    }

    /**
     * Formats a single item for {@link #toString()}: strings are wrapped in double
     * quotes, characters in single quotes and everything else uses its own
     * {@code toString()}; {@code null} becomes the text {@code "null"}.
     * The decision is based on the given type, not on the runtime class of the item.
     *
     * @param singleItem the item to format, may be {@code null}.
     * @param type the declared element type of the set.
     * @return the formatted item.
     */
    private static String _toString(@Nullable Object singleItem, Class<?> type) {
        if (singleItem == null) {
            return "null";
        }
        if (type == String.class) {
            return "\"" + singleItem + "\"";
        }
        if (type == Character.class) {
            return "'" + singleItem + "'";
        }
        return singleItem.toString();
    }

    /**
     * Two sorted value sets are equal if they are of the same class, have equal
     * element types and comparators, and yield equal elements in the same
     * iteration order. The shape of the underlying trees is irrelevant.
     *
     * @param obj the object to compare with, may be {@code null}.
     * @return {@code true} if both sets are equal in the sense described above.
     */
    @Override
    public boolean equals(@Nullable Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        final SortedValueSetImpl<?> that = (SortedValueSetImpl<?>) obj;
        if (!Objects.equals(_type, that._type) || !Objects.equals(_comparator, that._comparator)) {
            return false;
        }

        final Iterator<E> mine = iterator();
        final Iterator<?> theirs = that.iterator();
        for (;;) {
            final boolean moreHere = mine.hasNext();
            final boolean moreThere = theirs.hasNext();
            if (!moreHere || !moreThere) {
                // Equal only if both iterators ran out at the same time.
                return !moreHere && !moreThere;
            }
            if (!Objects.equals(mine.next(), theirs.next())) {
                return false;
            }
        }
    }

    /**
     * Combines the hash of the element type and comparator with a hash folded
     * over all elements in iteration order.
     *
     * @return the hash code of this set.
     */
    @Override
    public int hashCode() {
        int contentHash = 31;
        final Iterator<E> elements = iterator();
        while (elements.hasNext()) {
            final E element = elements.next();
            contentHash = 31 * contentHash + Objects.hash(element);
        }
        final int headerHash = Objects.hash(_type, _comparator);
        return 31 * headerHash + contentHash;
    }
}