ValueSetImpl.java

package sprouts.impl;

import org.jspecify.annotations.Nullable;
import sprouts.Tuple;
import sprouts.ValueSet;

import java.lang.reflect.Array;
import java.util.*;
import java.util.stream.Stream;

import static sprouts.impl.ArrayUtil.*;

/**
 * Immutable, hash-based set implementation using a Hash Array Mapped Trie (HAMT)-like structure.
 * <p>
 * This class provides an efficient, persistent set implementation with near-constant time complexity
 * for core operations (add/remove/contains) under ideal conditions. The implementation features:
 * <ul>
 *   <li>Persistent structural sharing for memory efficiency</li>
 *   <li>Progressive branching based on node depth</li>
 *   <li>Open addressing with linear probing inside each node's element array</li>
 *   <li>Depth-dependent branching factor and node capacity</li>
 *   <li>Recursive tree traversal for set operations</li>
 * </ul>
 *
 * <h2>Structure Overview</h2>
 * <p>Each node contains:
 * <ul>
 *   <li><b>Elements Array:</b> Contiguous storage for elements (size ≤ depth²)</li>
 *   <li><b>Branches Array:</b> Child nodes (size = 32 + depth, or empty for a node without children)</li>
 *   <li><b>Hash Codes:</b> Cached element hashes for fast comparison</li>
 * </ul>
 *
 * <h2>Key Implementation Details</h2>
 * <ul>
 *   <li><b>Branching:</b> Branch count per node grows with depth (min 32 branches)</li>
 *   <li><b>Node Capacity:</b> A node without branches holds up to {@code depth²} elements before it starts branching
 *       (so the root, at depth 0, never stores elements itself)</li>
 *   <li><b>Hash Distribution:</b> Branch indices mix the element hash and the node depth with the
 *       constants {@code PRIME_1} and {@code PRIME_2}</li>
 *   <li><b>Collision Handling:</b> Linear probing within element arrays</li>
 *   <li><b>Immutability:</b> All modifications return new instances with structural sharing</li>
 * </ul>
 *
 * <h2>Performance Characteristics</h2>
 * <table border="1">
 *   <tr><th>Operation</th><th>Average</th><th>Worst Case</th></tr>
 *   <tr><td>{@code add()}</td><td>O(1)</td><td>O(log<sub>32</sub> n)</td></tr>
 *   <tr><td>{@code remove()}</td><td>O(1)</td><td>O(log<sub>32</sub> n)</td></tr>
 *   <tr><td>{@code contains()}</td><td>O(1)</td><td>O(log<sub>32</sub> n)</td></tr>
 *   <tr><td>{@code iterator()}</td><td>O(n)</td><td>O(n)</td></tr>
 * </table>
 *
 * <h2>Technical Details</h2>
 * <ul>
 *   <li><b>Hash Computation:</b> Runs the element's hash code through a prime-based transformation
 *       ({@code PRIME_1}, {@code PRIME_2}) to improve branch distribution</li>
 *   <li><b>Index Computation:</b> Slot and branch indices are computed as {@code |hash % length|},
 *       which is always a valid index, even for a hash code of {@link Integer#MIN_VALUE}</li>
 *   <li><b>Structural Sharing:</b> Branches are reused when possible during modification; only the path to
 *       the modification is recreated</li>
 *   <li><b>No Branch Handling:</b> Nodes without children share the static empty branch array
 *       ({@code EMPTY_BRANCHES}) instead of {@code null}; individual slots of a non-empty branch array
 *       may be {@code null}</li>
 *   <li><b>Iteration:</b> Depth-first traversal with stack-based state management using a custom stack frame</li>
 * </ul>
 *
 * @param <E> Type of elements maintained by this set
 * @see AssociationImpl
 * @see sprouts.ValueSet
 * @see sprouts.Tuple
 */
final class ValueSetImpl<E> implements ValueSet<E> {

    private static final Node[] EMPTY_BRANCHES = new Node<?>[0];
    private static final boolean ALLOWS_NULL = false;
    private static final long PRIME_1 = 12055296811267L;
    private static final long PRIME_2 = 53982894593057L;

    private static final int BASE_BRANCHING_PER_NODE = 32;
    private static final int BASE_ENTRIES_PER_NODE = 0;


    private final Class<E> _type;
    private final ArrayItemAccess<E, Object> _itemGetter;
    private final Node<E> _root;

    /**
     * An immutable trie node holding a probed element array, the cached hash of every element,
     * and an array of (possibly {@code null}) child branches.
     */
    private static class Node<E> {

        private final int _depth;
        private final int _size;          // elements in this node plus all elements in its branches
        private final Object _elementsArray;
        private final int[] _elementsHashes;
        private final Node<E>[] _branches;

        /**
         * @param depth            depth of this node in the trie (the root has depth 0)
         * @param type             element type, used to create/access the elements array
         * @param newElementsArray the elements stored directly in this node
         * @param keyHashes        cached hashes to reuse when {@code rebuild} is false and the length matches
         * @param branches         child nodes; {@code EMPTY_BRANCHES} when this node has no children
         * @param rebuild          whether elements must be re-placed at their probe positions
         *                         and their hashes recomputed
         */
        Node(
            final int depth,
            final Class<E> type,
            final Object newElementsArray,
            final int[] keyHashes,
            final Node<E>[] branches,
            final boolean rebuild
        ) {
            final int size = _length(newElementsArray);
            if ( rebuild && size > 1 ) {
                // Re-place every element so that probing from |hash % size| finds it again.
                final ArrayItemAccess<?, Object> itemGetter = ArrayItemAccess.of(type, false);
                _elementsArray = _fillNodeArrays(size, type, itemGetter, newElementsArray);
            } else {
                _elementsArray = newElementsArray;
            }
            _depth = depth;
            _branches = branches;
            _size = size + _sumBranchSizes(_branches);
            if ( keyHashes.length != size || rebuild ) {
                _elementsHashes = new int[size];
                for (int i = 0; i < size; i++) {
                    _elementsHashes[i] = Objects.requireNonNull(_getAt(i,_elementsArray)).hashCode();
                }
            } else {
                _elementsHashes = keyHashes;
            }
        }
    }

    ValueSetImpl(
        final Class<E> type
    ) {
        this(
            0, type,
            _createArray(type, ALLOWS_NULL, 0),
            new int[0],
            EMPTY_BRANCHES, true
        );
    }

    private ValueSetImpl(
        final int depth,
        final Class<E> type,
        final Object newElementsArray,
        final int[] keyHashes,
        final Node<E>[] branches,
        final boolean rebuild
    ) {
        this(
            Objects.requireNonNull(type),
            ArrayItemAccess.of(type, false),
            new Node<>(depth, type, newElementsArray, keyHashes, branches, rebuild)
        );
    }

    private ValueSetImpl(
        final Class<E> type,
        final ArrayItemAccess<E, Object> itemGetter,
        final Node<E> root
    ) {
        _type = Objects.requireNonNull(type);
        _itemGetter = itemGetter;
        _root = root;
    }

    /** Returns {@code this} if the root is unchanged, otherwise a new set sharing type and accessor. */
    private ValueSetImpl<E> _withNewRoot(Node<E> newRoot) {
        if ( newRoot == _root )
            return this;
        else
            return new ValueSetImpl<>(_type, _itemGetter, newRoot);
    }

    /**
     * Copies the elements of {@code newElementsArray} into a fresh array, placing each one at the
     * slot found by linear probing from its hash, then flattens the result to the element type's array.
     */
    private static <K> Object _fillNodeArrays(
        final int size,
        final Class<K> type,
        final ArrayItemAccess<?, Object> itemGetter,
        final Object newElementsArray
    ) {
        Object elementsArray = new Object[size];
        for (int i = 0; i < size; i++) {
            K key = _getAt(i, newElementsArray, type);
            Objects.requireNonNull(key);
            int index = _findValidIndexFor(itemGetter, key, key.hashCode(), elementsArray);
            _setAt(index, key, elementsArray);
        }
        return _tryFlatten(elementsArray, type, ALLOWS_NULL);
    }

    /** Sums the sizes of all non-null branches. */
    private static int _sumBranchSizes( Node<?>[] branches) {
        int sum = 0;
        for (Node<?> branch : branches) {
            if ( branch != null ) {
                sum += branch._size;
            }
        }
        return sum;
    }

    /** Maximum number of elements a branch-less node may store before it starts branching. */
    private static int _maxEntriesForThisNode(Node<?> node) {
        return BASE_ENTRIES_PER_NODE + (node._depth * node._depth);
    }

    /** Number of branches created when a node starts branching. */
    private static int _minBranchingPerNode(Node<?> node) {
        return BASE_BRANCHING_PER_NODE + node._depth;
    }

    /**
     * Maps a (possibly negative) hash onto {@code [0, length)}.
     * <p>
     * {@code Math.abs(hash % length)} yields the same value as {@code Math.abs(hash) % length} for every
     * hash except {@link Integer#MIN_VALUE}, for which the latter is negative and would cause an
     * {@link ArrayIndexOutOfBoundsException}.
     */
    private static int _slotFor( final int hash, final int length ) {
        return Math.abs(hash % length);
    }

    /** Returns a copy of {@code node} whose branch at {@code index} is replaced by {@code branch}. */
    private static <E> Node<E> _withBranchAt(
            Node<E> node,
            Class<E> type,
            int index,
            @Nullable Node<E> branch
    ) {
        Node<E>[] newBranches = node._branches.clone();
        newBranches[index] = branch;
        return new Node<>(node._depth, type, node._elementsArray, node._elementsHashes, newBranches, false);
    }

    /**
     * Finds the slot of {@code key} in the node's own element array by linear probing.
     *
     * @return the slot index, or {@code -1} if the key is not stored directly in this node
     */
    private static <E> int _findValidIndexFor(
        final Node<E> node,
        final ArrayItemAccess<?, Object> itemGetter,
        final E key,
        final int hash
    ) {
        int length = node._elementsHashes.length;
        if ( length < 1 ) {
            return -1;
        }
        int index = _slotFor(hash, length);
        int tries = 0;
        while (!_isEqual(node, itemGetter, node._elementsArray, index, key, hash) && tries < length) {
            index = ( index + 1 ) % length;
            tries++;
        }
        if ( tries >= length ) {
            return -1;
        }
        return index;
    }

    /** Compares the cached hash first and only falls back to {@code equals} on a hash match. */
    private static boolean _isEqual(
        final Node<?> node,
        final ArrayItemAccess<?, Object> itemGetter,
        final Object items,
        final int index,
        final Object key,
        final int keyHash
    ) {
        if ( node._elementsHashes[index] != keyHash ) {
            return false;
        }
        return key.equals(itemGetter.get(index, items));
    }

    /**
     * Finds the first free (or equal) slot for {@code key} in a partially filled array by linear probing.
     * Used while (re)building a node's element array.
     *
     * @return the probed slot index, or {@code -1} for an empty array
     */
    private static <K> int _findValidIndexFor(
        final ArrayItemAccess<?, Object> itemGetter,
        final K key,
        final int hash,
        final Object elements
    ) {
        int length = _length(elements);
        if ( length < 1 ) {
            return -1;
        }
        int index = _slotFor(hash, length);
        int tries = 0;
        while (itemGetter.get(index, elements) != null && !Objects.equals(itemGetter.get(index, elements), key) && tries < length) {
            index = ( index + 1 ) % length;
            tries++;
        }
        return index;
    }

    /** Computes which branch of {@code node} an element with the given hash belongs to. */
    private static <E> int _computeBranchIndex(
        final Node<E> node,
        final int hash,
        final int numberOfBranches
    ) {
        // Mix in the depth so that colliding elements spread differently on each level.
        int localHash = Long.hashCode(PRIME_1 * (hash - PRIME_2 * (hash+node._depth)));
        return _slotFor(localHash, numberOfBranches);
    }

    @Override
    public int size() {
        return _root._size;
    }

    @Override
    public boolean isLinked() {
        return false;
    }

    @Override
    public boolean isSorted() {
        return false;
    }

    @Override
    public Class<E> type() {
        return _type;
    }

    @Override
    public Tuple<E> toTuple() {
        return _toTuple(_root, _type);
    }

    /** Collects the elements of {@code node} and all its branches into a tuple. */
    private static <E> Tuple<E> _toTuple(Node<E> node, Class<E> type) {
        if ( node._branches.length == 0 ) {
            return new TupleWithDiff<>(TupleTree.ofRaw(false, type, node._elementsArray), null);
        } else {
            List<E> values = new ArrayList<>(_length(node._elementsArray));
            _each(node._elementsArray, type, value -> {
                if ( value != null ) {
                    values.add(value);
                }
            });
            for (@Nullable Node<E> branch : node._branches) {
                if ( branch != null ) {
                    values.addAll(_toTuple(branch, type).toList());
                }
            }
            return Tuple.of(type, values);
        }
    }

    @Override
    public boolean contains( final E element ) {
        if ( !_type.isAssignableFrom(element.getClass()) ) {
            throw new IllegalArgumentException(
                    "The provided element '" + element + "' is of type '" + element.getClass().getSimpleName() + "', " +
                    "instead of the expected type '" + _type + "'."
                );
        }
        return _contains(_root, _itemGetter, element, element.hashCode());
    }

    /** Looks for {@code element} in {@code node}, then descends into the branch selected by its hash. */
    private static <E> boolean _contains(
        final Node<E> node,
        final ArrayItemAccess<?, Object> itemGetter,
        final E element,
        final int elementHash
    ) {
        Node<E>[] branches = node._branches;
        int index = _findValidIndexFor(node, itemGetter, element, elementHash);
        if ( index < 0 ) {
            if ( branches.length > 0 ) {
                int branchIndex = _computeBranchIndex(node, elementHash, branches.length);
                @Nullable Node<E> branch = branches[branchIndex];
                if ( branch != null ) {
                    return _contains(branch, itemGetter, element, elementHash);
                } else {
                    return false;
                }
            } else {
                return false;
            }
        }
        return true;
    }

    @Override
    public ValueSet<E> add( final E element ) {
        if ( !_type.isAssignableFrom(element.getClass()) ) {
            throw new IllegalArgumentException(
                    "The supplied element '" + element + "' is of type '" + element.getClass().getSimpleName() + "', " +
                    "instead of the expected type '" + _type + "'."
                );
        }
        return _withNewRoot(_with(_root, _type, _itemGetter, element, element.hashCode()));
    }

    /**
     * Returns a node equivalent to {@code node} plus {@code key}, or {@code node} itself if the key
     * is already present. Only the path from {@code node} to the modified node is recreated.
     */
    private static <E> Node<E> _with(
        final Node<E> node,
        final Class<E> type,
        final ArrayItemAccess<?, Object> itemGetter,
        final E key,
        final int keyHash
    ) {
        int depth = node._depth;
        Object elementsArray = node._elementsArray;
        int[] elementsHashes = node._elementsHashes;
        Node<E>[] branches = node._branches;
        int index = _findValidIndexFor(node, itemGetter, key, keyHash);
        if ( index < 0 || index >= _length(elementsArray) ) {
            if ( branches.length == 0 && _length(elementsArray) < _maxEntriesForThisNode(node) ) {
                // Room left in this branch-less node: store the key directly here.
                return new Node<>(
                        depth,
                        type,
                        _withAddAt(_length(elementsArray), key, elementsArray, type, ALLOWS_NULL),
                        elementsHashes,
                        branches,
                        true
                );
            } else {
                if ( branches.length > 0 ) {
                    int branchIndex = _computeBranchIndex(node, keyHash, branches.length);
                    @Nullable Node<E> branch = branches[branchIndex];
                    if (branch == null) {
                        Object newElementsArray = _createArray(type, ALLOWS_NULL, 1);
                        _setAt(0, key, newElementsArray);
                        return _withBranchAt(node, type, branchIndex, new Node<>(depth + 1, type, newElementsArray, elementsHashes, EMPTY_BRANCHES, true));
                    } else {
                        Node<E> newBranch = _with(branch, type, itemGetter, key, keyHash);
                        if ( Util.refEquals(newBranch, branch) ) {
                            return node;
                        } else {
                            return _withBranchAt(node, type, branchIndex, newBranch);
                        }
                    }
                } else {
                    // This node is full and has no branches yet: this is where the tree grows.
                    int newBranchSize = _minBranchingPerNode(node);
                    Node<E>[] newBranches = new Node[newBranchSize];
                    Object newElementsArray = _createArray(type, ALLOWS_NULL, 1);
                    _setAt(0, key, newElementsArray);
                    newBranches[_computeBranchIndex(node, keyHash, newBranchSize)] = new Node<>(
                            depth + 1, type, newElementsArray, elementsHashes, EMPTY_BRANCHES, true
                    );
                    return new Node<>(depth, type, elementsArray, elementsHashes, newBranches, false);
                }
            }
        }
        return node;
    }

    @Override
    public ValueSet<E> remove( final E element ) {
        if ( !_type.isAssignableFrom(element.getClass()) ) {
            throw new IllegalArgumentException(
                    "The supplied element '" + element + "' is of type '" + element.getClass().getSimpleName() + "', " +
                    "instead of the expected type '" + _type + "'."
                );
        }
        return _withNewRoot(_without(_root, _type, _itemGetter, element, element.hashCode()));
    }

    /**
     * Returns a node equivalent to {@code node} without {@code key}, or {@code node} itself if the key
     * is absent. Emptied branches are replaced by {@code null}; if no branch remains, the node
     * falls back to {@code EMPTY_BRANCHES}.
     */
    private static <E> Node<E> _without(
        final Node<E> node,
        final Class<E> type,
        final ArrayItemAccess<?, Object> itemGetter,
        final E key,
        final int keyHash
    ) {
        int depth = node._depth;
        Object elementsArray = node._elementsArray;
        int[] elementsHashes = node._elementsHashes;
        Node<E>[] branches = node._branches;
        int index = _findValidIndexFor(node, itemGetter, key, keyHash);
        if ( index < 0 ) {
            if ( branches.length == 0 ) {
                return node;
            } else {
                int branchIndex = _computeBranchIndex(node, keyHash, branches.length);
                @Nullable Node<E> branch = branches[branchIndex];
                if ( branch == null ) {
                    return node;
                } else {
                    Node<E> newBranch = _without(branch, type, itemGetter, key, keyHash);
                    if ( Util.refEquals(newBranch, branch) ) {
                        return node;
                    } else if ( newBranch._size == 0 ) {
                        // Maybe we can remove all branches now
                        int numberOfNonNullBranches = 0;
                        for (int i = 0; i < branches.length; i++) {
                            if (branches[i] != null && i != branchIndex) {
                                numberOfNonNullBranches++;
                            }
                        }
                        if ( numberOfNonNullBranches == 0 ) {
                            return new Node<>(depth, type, elementsArray, elementsHashes, EMPTY_BRANCHES, false);
                        }
                        newBranch = null;
                    }
                    return _withBranchAt(node, type, branchIndex, newBranch);
                }
            }
        } else {
            Object newElementsArray = _withRemoveRange(index, index+1, elementsArray, type, ALLOWS_NULL);
            return new Node<>(depth, type, newElementsArray, elementsHashes, branches, true);
        }
    }

    /**
     * Adds every element of the stream, one after another.
     * <p>
     * The stream is consumed through its iterator rather than via {@code reduce}: a
     * {@code (a, b) -> a} combiner would silently discard all but the first partial result
     * when a parallel stream is passed in.
     */
    @Override
    public ValueSet<E> addAll( Stream<? extends E> entries ) {
        Objects.requireNonNull(entries);
        // TODO: implement branching based bulk insert
        ValueSet<E> result = this;
        Iterator<? extends E> iterator = entries.iterator();
        while ( iterator.hasNext() ) {
            result = result.add(iterator.next());
        }
        return result;
    }

    /**
     * Removes every element of the stream, one after another.
     * <p>
     * Like {@link #addAll(Stream)}, this consumes the stream sequentially so that no removal
     * is lost when the caller supplies a parallel stream.
     */
    @Override
    public ValueSet<E> removeAll( Stream<? extends E> elements ) {
        if ( this.isEmpty() )
            return this;
        ValueSet<E> result = this;
        Iterator<? extends E> iterator = elements.iterator();
        while ( iterator.hasNext() ) {
            result = result.remove(iterator.next());
        }
        return result;
    }

    @Override
    public ValueSet<E> retainAll( Set<? extends E> elements ) {
        if ( this.isEmpty() )
            return this;
        ValueSet<E> result = this;
        if ( elements.isEmpty() )
            return clear();
        // Iterating "this" while updating "result" is safe because both are immutable.
        for ( E currentElement : this ) {
            if ( !elements.contains(currentElement) ) {
                result = result.remove(currentElement);
            }
        }
        return result;
    }

    @Override
    public ValueSet<E> clear() {
        return Sprouts.factory().valueSetOf(this.type());
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("ValueSet<").append(_type.getSimpleName()).append(">[");
        final int howMany = 8;
        sb = _appendRecursivelyUpTo(_root, _type, _itemGetter, sb, howMany);
        int numberOfElementsLeft = _root._size - howMany;
        if ( numberOfElementsLeft > 0 ) {
            sb.append(", ... ").append(numberOfElementsLeft).append(" items left");
        }
        sb.append("]");
        return sb.toString();
    }

    /**
     * Appends at most {@code size} comma-separated elements of {@code node} (its own elements first,
     * then its branches depth-first) to {@code sb}.
     */
    private static <E> StringBuilder _appendRecursivelyUpTo(
        final Node<E> node,
        final Class<E>  type,
        final ArrayItemAccess<E, Object> access,
        StringBuilder sb,
        final int size
    ) {
        int howMany = Math.min(size, _length(node._elementsArray));
        for (int i = 0; i < howMany; i++) {
            E key = access.get(i, node._elementsArray);
            sb.append(Util._toString(key, type));
            if ( i < howMany - 1 ) {
                sb.append(", ");
            }
        }
        int deltaLeft = size - howMany;
        if ( deltaLeft > 0 ) {
            for (Node<E> branch : node._branches) {
                if ( branch != null ) {
                    // A separator is needed if anything was printed before this branch.
                    if ( deltaLeft < size - howMany || howMany > 0 )
                        sb.append(", ");
                    sb = _appendRecursivelyUpTo(branch, type, access, sb, deltaLeft);
                    deltaLeft -= branch._size;
                    if ( deltaLeft <= 0 ) {
                        break;
                    }
                }
            }
        }
        return sb;
    }

    @Override
    public boolean equals( final Object obj ) {
        if ( obj == this ) {
            return true;
        }
        if ( !(obj instanceof ValueSetImpl) ) {
            return false;
        }
        // Unchecked but harmless: the other set is only probed via hashCode()/equals() of our elements.
        @SuppressWarnings("unchecked")
        ValueSetImpl<E> other = (ValueSetImpl<E>) obj;
        if ( other.size() != this.size() ) {
            return false;
        }
        // Same size and every element of this set is in the other set => both hold the same elements.
        for ( E element : this ) {
            if ( !_contains(other._root, other._itemGetter, element, element.hashCode()) ) {
                return false;
            }
        }
        return true;
    }

    @Override
    public int hashCode() {
        return Long.hashCode(_recursiveHashCode(_root));
    }

    /**
     * Order-independent sum of all element hashes (spread over 64 bits), so that sets holding the
     * same elements hash equally regardless of their internal tree shape.
     */
    private static <E> long _recursiveHashCode(Node<E> node) {
        long baseHash = 0; // -> full 64 bit improve hash distribution
        for ( int elementsHash : node._elementsHashes ) {
            baseHash += elementsHash * PRIME_1; // -> we try to expand to all 64 bits in the long
        }
        for (Node<E> branch : node._branches) {
            if ( branch != null ) {
                baseHash += _recursiveHashCode(branch);
            }
        }
        return baseHash;
    }

    // A helper class to keep track of our position in a node.
    static final class IteratorFrame<E> {
        final @Nullable IteratorFrame<E> parent;
        final Node<E> node;
        final int arrayLength;   // Total entries in the node's arrays
        final int branchesLength; // Total branches in the node
        int arrayIndex;    // Next index in the elements/values arrays
        int branchIndex;   // Next branch index to check

        IteratorFrame(@Nullable IteratorFrame<E> parent, Node<E> node) {
            this.parent = parent;
            this.node = node;
            this.arrayLength = _length(node._elementsArray);
            this.branchesLength = node._branches.length;
            this.arrayIndex = 0;
            this.branchIndex = 0;
        }
    }

    @Override
    public Spliterator<E> spliterator() {
        return Spliterators.spliterator(iterator(), _root._size,
                Spliterator.DISTINCT |
                Spliterator.SIZED    |
                Spliterator.SUBSIZED |
                Spliterator.NONNULL  |
                Spliterator.IMMUTABLE
        );
    }

    @Override
    public Iterator<E> iterator() {
        // _itemGetter is always created as ArrayItemAccess.of(_type, false), so it can be reused here.
        return new ValueSetIterator<>(_root, _itemGetter);
    }


    /** Depth-first iterator over all elements: a node's own elements first, then its branches. */
    private static final class ValueSetIterator<E> implements Iterator<E>
    {
        private final ArrayItemAccess<E,Object> _elementGetter;
        // Top of a linked stack of frames (linked through IteratorFrame.parent).
        private @Nullable IteratorFrame<E> currentFrame = null;

        ValueSetIterator(Node<E> node, ArrayItemAccess<E, Object> elementGetter) {
            _elementGetter = elementGetter;
            // Initialize with this node if there is at least one element.
            if (node._size > 0) {
                currentFrame = new IteratorFrame<>(null, node);
            }
        }

        @Override
        public boolean hasNext() {
            // Loop until we find a frame with an unvisited element or the stack is empty.
            while ( currentFrame != null ) {
                // If there is an element left in the current node, we're done.
                if (currentFrame.arrayIndex < currentFrame.arrayLength) {
                    return true;
                }

                // Otherwise, check for non-null branches to traverse.
                if (currentFrame.branchIndex < currentFrame.branchesLength) {
                    // Look for the next branch.
                    while (currentFrame.branchIndex < currentFrame.branchesLength) {
                        Node<E> branch = currentFrame.node._branches[currentFrame.branchIndex];
                        currentFrame.branchIndex++;
                        if (branch != null && branch._size > 0) {
                            // Found a non-empty branch: push its state on the stack.
                            currentFrame = new IteratorFrame<>(currentFrame, branch);
                            break;
                        }
                    }
                    // Continue the while loop: now the top of the stack may have entries.
                    continue;
                }

                // If no more entries or branches are left in the current node, pop it.
                currentFrame = currentFrame.parent;
            }
            return false;
        }

        @Override
        public E next() {
            if (!hasNext() || currentFrame == null) {
                throw new NoSuchElementException();
            }
            // Retrieve the element at the current position.
            E key = _elementGetter.get(currentFrame.arrayIndex, currentFrame.node._elementsArray);
            currentFrame.arrayIndex++;
            return Objects.requireNonNull(key);
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }

}