AssociationImpl.java
package sprouts.impl;
import org.jspecify.annotations.Nullable;
import sprouts.Association;
import sprouts.Pair;
import sprouts.Tuple;
import java.lang.reflect.Array;
import java.util.*;
import java.util.stream.Stream;
import static sprouts.impl.ArrayUtil.*;
final class AssociationImpl<K, V> implements Association<K, V> {
private static final AssociationImpl[] EMPTY_BRANCHES = new AssociationImpl<?, ?>[0];
private static final boolean ALLOWS_NULL = false;
private static final long PRIME_1 = 12055296811267L;
private static final long PRIME_2 = 53982894593057L;
private static final int BASE_BRANCHING_PER_NODE = 32;
private static final int BASE_ENTRIES_PER_NODE = 0;
private final int _depth;
private final int _size;
private final Class<K> _keyType;
private final Object _keysArray;
private final Class<V> _valueType;
private final Object _valuesArray;
private final int[] _keyHashes;
private final AssociationImpl<K, V>[] _branches;
AssociationImpl(
final Class<K> keyType,
final Class<V> valueType
) {
this(
0, keyType,
_createArray(keyType, ALLOWS_NULL, 0),
valueType,
_createArray(valueType, ALLOWS_NULL, 0),
new int[0],
EMPTY_BRANCHES, true
);
}
public AssociationImpl(
final Class<K> keyType,
final Class<V> valueType,
final Stream<Pair<? extends K, ? extends V>> entries
) {
Map<K, V> uniqueEntries = new java.util.HashMap<>();
entries.forEach(entry -> {
if ( entry.first() == null || entry.second() == null ) {
throw new IllegalArgumentException("The given association may not contain null keys or values.");
}
// If the map already contains the key, we do not overwrite it
uniqueEntries.putIfAbsent(entry.first(), entry.second());
});
final Object[] keys = uniqueEntries.keySet().toArray();
final Object[] values = uniqueEntries.values().toArray();
final int size = keys.length;
Pair<Object,Object> localData = _fillNodeArrays(size, keyType, valueType, keys, values);
_keysArray = localData.first();
_valuesArray = localData.second();
_depth = 0;
_keyType = Objects.requireNonNull(keyType);
_valueType = Objects.requireNonNull(valueType);
_keyHashes = new int[size];
for (int i = 0; i < size; i++) {
_keyHashes[i] = Objects.requireNonNull(Array.get(_keysArray, i)).hashCode();
}
_branches = EMPTY_BRANCHES;
_size = size + _sumBranchSizes(_branches);
}
private AssociationImpl(
final int depth,
final Class<K> keyType,
final Object newKeysArray,
final Class<V> valueType,
final Object newValuesArray,
final int[] keyHashes,
final AssociationImpl<K, V>[] branches,
final boolean rebuild
) {
final int size = _length(newKeysArray);
if ( rebuild && size > 1 ) {
Pair<Object,Object> localData = _fillNodeArrays(size, keyType, valueType, newKeysArray, newValuesArray);
_keysArray = localData.first();
_valuesArray = localData.second();
} else {
_keysArray = newKeysArray;
_valuesArray = newValuesArray;
}
_depth = depth;
_keyType = Objects.requireNonNull(keyType);
_valueType = Objects.requireNonNull(valueType);
_branches = branches;
_size = size + _sumBranchSizes(_branches);
if ( keyHashes.length != size || rebuild ) {
_keyHashes = new int[size];
for (int i = 0; i < size; i++) {
_keyHashes[i] = Objects.requireNonNull(Array.get(_keysArray, i)).hashCode();
}
} else {
_keyHashes = keyHashes;
}
}
private static <K,V> Pair<Object,Object> _fillNodeArrays(
final int size,
final Class<K> keyType,
final Class<V> valueType,
final Object newKeysArray,
final Object newValuesArray
) {
Object keysArray = new Object[size];
Object valuesArray = new Object[size];
for (int i = 0; i < size; i++) {
K key = _getAt(i, newKeysArray, keyType);
V value = _getAt(i, newValuesArray, valueType);
Objects.requireNonNull(key);
Objects.requireNonNull(value);
int index = _findValidIndexFor(key, key.hashCode(), keysArray);
_setAt(index, key, keysArray);
_setAt(index, value, valuesArray);
}
return Pair.of(
_tryFlatten(keysArray, keyType, ALLOWS_NULL),
_tryFlatten(valuesArray, valueType, ALLOWS_NULL)
);
}
private static int _sumBranchSizes(AssociationImpl<?, ?>[] branches) {
int sum = 0;
for (AssociationImpl<?, ?> branch : branches) {
if ( branch != null ) {
sum += branch.size();
}
}
return sum;
}
private int _maxEntriesForThisNode() {
return BASE_ENTRIES_PER_NODE + (_depth * _depth);
}
private int _minBranchingPerNode() {
return BASE_BRANCHING_PER_NODE + _depth;
}
private AssociationImpl<K, V> _withBranchAt(
int index,
@Nullable AssociationImpl<K, V> branch
) {
AssociationImpl<K, V>[] newBranches = _branches.clone();
newBranches[index] = branch;
return new AssociationImpl<>(_depth, _keyType, _keysArray, _valueType, _valuesArray, _keyHashes, newBranches, false);
}
private int _findValidIndexFor(final K key, final int hash) {
int length = _keyHashes.length;
if ( length < 1 ) {
return -1;
}
int index = Math.abs(hash) % length;
int tries = 0;
while (!_isEqual(_keysArray, index, key, hash) && tries < length) {
index = ( index + 1 ) % length;
tries++;
}
if ( tries >= length ) {
return -1;
}
return index;
}
private boolean _isEqual(Object items, int index, Object key, int keyHash) {
if ( _keyHashes[index] != keyHash ) {
return false;
}
return key.equals(Array.get(items, index));
}
private static <K> int _findValidIndexFor(final K key, final int hash, final Object keys) {
int length = _length(keys);
if ( length < 1 ) {
return -1;
}
int index = Math.abs(hash) % length;
int tries = 0;
while (Array.get(keys, index) != null && !Objects.equals(Array.get(keys, index), key) && tries < length) {
index = ( index + 1 ) % length;
tries++;
}
return index;
}
private int _computeBranchIndex(int hash, int numberOfBranches) {
int localHash = Long.hashCode(PRIME_1 * (hash - PRIME_2 * (hash+_depth)));
return Math.abs(localHash) % numberOfBranches;
}
@Override
public int size() {
return _size;
}
@Override
public Class<K> keyType() {
return _keyType;
}
@Override
public Class<V> valueType() {
return _valueType;
}
@Override
public Set<K> keySet() {
Set<K> setOfKeys = new java.util.HashSet<>(_size);
populateKeySetRecursively(setOfKeys);
return java.util.Collections.unmodifiableSet(setOfKeys);
}
public void populateKeySetRecursively(Set<K> setOfKeys) {
for (int i = 0; i < _length(_keysArray); i++) {
K key = _getAt(i, _keysArray, _keyType);
setOfKeys.add(key);
}
for (AssociationImpl<K, V> branch : _branches) {
if (branch != null) {
branch.populateKeySetRecursively(setOfKeys);
}
}
}
@Override
public Tuple<V> values() {
if ( _branches.length == 0 ) {
return new TupleImpl<>(false, _valueType, _valuesArray, null);
} else {
List<V> values = new java.util.ArrayList<>(_length(_valuesArray));
_each(_valuesArray, _valueType, value -> {
if ( value != null ) {
values.add(value);
}
});
for (@Nullable AssociationImpl<K, V> branch : _branches) {
if ( branch != null ) {
values.addAll(branch.values().toList());
}
}
return Tuple.of(_valueType, values);
}
}
@Override
public Set<Pair<K, V>> entrySet() {
return new AbstractSet<Pair<K, V>>() {
@Override
public Iterator<Pair<K, V>> iterator() {
return AssociationImpl.this.iterator();
}
@Override
public int size() {
return _size;
}
@Override
public boolean contains(Object o) {
if (o instanceof Pair) {
Pair<?, ?> pair = (Pair<?, ?>) o;
K key = _keyType.cast(pair.first());
return AssociationImpl.this.containsKey(key);
}
return false;
}
};
}
@Override
public boolean containsKey(K key) {
if ( !_keyType.isAssignableFrom(key.getClass()) ) {
throw new IllegalArgumentException(
"The given key '" + key + "' is of type '" + key.getClass().getSimpleName() + "', " +
"instead of the expected type '" + _keyType + "'."
);
}
return _get(key, key.hashCode()) != null;
}
@Override
public Optional<V> get( final K key ) {
if ( !_keyType.isAssignableFrom(key.getClass()) ) {
throw new IllegalArgumentException(
"The given key '" + key + "' is of type '" + key.getClass().getSimpleName() + "', " +
"instead of the expected type '" + _keyType + "'."
);
}
return Optional.ofNullable(_get(key, key.hashCode()));
}
private @Nullable V _get( final K key, final int keyHash ) {
int index = _findValidIndexFor(key, keyHash);
if ( index < 0 ) {
if ( _branches.length > 0 ) {
int branchIndex = _computeBranchIndex(keyHash, _branches.length);
@Nullable AssociationImpl<K, V> branch = _branches[branchIndex];
if ( branch != null ) {
return branch._get(key, keyHash);
} else {
return null;
}
} else {
return null;
}
}
return _getAt(index, _valuesArray, _valueType);
}
@Override
public Association<K, V> put(final K key, final V value) {
if ( !_keyType.isAssignableFrom(key.getClass()) ) {
throw new IllegalArgumentException(
"The given key '" + key + "' is of type '" + key.getClass().getSimpleName() + "', " +
"instead of the expected type '" + _keyType + "'."
);
}
if ( !_valueType.isAssignableFrom(value.getClass()) ) {
throw new IllegalArgumentException(
"The given value '" + value + "' is of type '" + value.getClass().getSimpleName() + "', " +
"instead of the expected type '" + _valueType + "'."
);
}
return _with(key, key.hashCode(), value, false);
}
@Override
public Association<K, V> putIfAbsent(K key, V value) {
if ( !_keyType.isAssignableFrom(key.getClass()) ) {
throw new IllegalArgumentException(
"The given key '" + key + "' is of type '" + key.getClass().getSimpleName() + "', " +
"instead of the expected type '" + _keyType + "'."
);
}
if ( !_valueType.isAssignableFrom(value.getClass()) ) {
throw new IllegalArgumentException(
"The given value '" + value + "' is of type '" + value.getClass().getSimpleName() + "', " +
"instead of the expected type '" + _valueType + "'."
);
}
return _with(key, key.hashCode(), value, true);
}
public AssociationImpl<K, V> _with(final K key, final int keyHash, final V value, boolean putIfAbsent) {
int index = _findValidIndexFor(key, keyHash);
if ( index < 0 || index >= _length(_keysArray) ) {
if ( _length(_keysArray) < _maxEntriesForThisNode() ) {
return new AssociationImpl<>(
_depth,
_keyType,
_withAddAt(_length(_keysArray), key, _keysArray, _keyType, ALLOWS_NULL),
_valueType,
_withAddAt(_length(_valuesArray), value, _valuesArray, _valueType, ALLOWS_NULL),
_keyHashes,
_branches,
true
);
} else {
if ( _branches.length > 0 ) {
int branchIndex = _computeBranchIndex(keyHash, _branches.length);
@Nullable AssociationImpl<K, V> branch = _branches[branchIndex];
if (branch == null) {
Object newKeysArray = _createArray(_keyType, ALLOWS_NULL, 1);
_setAt(0, key, newKeysArray);
Object newValuesArray = _createArray(_valueType, ALLOWS_NULL, 1);
_setAt(0, value, newValuesArray);
return _withBranchAt(branchIndex, new AssociationImpl<>(_depth + 1, _keyType, newKeysArray, _valueType, newValuesArray, _keyHashes, EMPTY_BRANCHES, true));
} else {
AssociationImpl<K, V> newBranch = branch._with(key, keyHash, value, putIfAbsent);
if ( newBranch == branch ) {
return this;
} else {
return _withBranchAt(branchIndex, newBranch);
}
}
} else {
// We create two new branches for this node, this is where the tree grows
int newBranchSize = _minBranchingPerNode();
AssociationImpl<K, V>[] newBranches = new AssociationImpl[newBranchSize];
Object newKeysArray = _createArray(_keyType, ALLOWS_NULL, 1);
_setAt(0, key, newKeysArray);
Object newValuesArray = _createArray(_valueType, ALLOWS_NULL, 1);
_setAt(0, value, newValuesArray);
newBranches[_computeBranchIndex(keyHash, newBranchSize)] = new AssociationImpl<>(
_depth + 1, _keyType, newKeysArray, _valueType, newValuesArray, _keyHashes, EMPTY_BRANCHES, true
);
return new AssociationImpl<>(_depth, _keyType, _keysArray, _valueType, _valuesArray, _keyHashes, newBranches, false);
}
}
} else if ( Objects.equals(_getAt(index, _valuesArray, _valueType), value) ) {
return this;
} else if ( !putIfAbsent ) {
Object newValuesArray = _withSetAt(index, value, _valuesArray, _valueType, ALLOWS_NULL);
return new AssociationImpl<>(_depth, _keyType, _keysArray, _valueType, newValuesArray, _keyHashes, _branches, false);
}
return this;
}
@Override
public Association<K, V> remove( K key ) {
if ( !_keyType.isAssignableFrom(key.getClass()) ) {
throw new IllegalArgumentException(
"The given key '" + key + "' is of type '" + key.getClass().getSimpleName() + "', " +
"instead of the expected type '" + _keyType + "'."
);
}
return _without(key, key.hashCode());
}
private AssociationImpl<K, V> _without(final K key, final int keyHash) {
int index = _findValidIndexFor(key, keyHash);
if ( index < 0 ) {
if ( _branches.length == 0 ) {
return this;
} else {
int branchIndex = _computeBranchIndex(keyHash, _branches.length);
@Nullable AssociationImpl<K, V> branch = _branches[branchIndex];
if ( branch == null ) {
return this;
} else {
AssociationImpl<K, V> newBranch = branch._without(key, keyHash);
if ( newBranch == branch ) {
return this;
} else if ( newBranch._size == 0 ) {
// Maybe we can remove all branches now
int numberOfNonNullBranches = 0;
for (int i = 0; i < _branches.length; i++) {
if (_branches[i] != null && i != branchIndex) {
numberOfNonNullBranches++;
}
}
if ( numberOfNonNullBranches == 0 ) {
return new AssociationImpl<>(_depth, _keyType, _keysArray, _valueType, _valuesArray, _keyHashes, EMPTY_BRANCHES, false);
}
newBranch = null;
}
return _withBranchAt(branchIndex, newBranch);
}
}
} else {
Object newKeysArray = _withRemoveRange(index, index+1, _keysArray, _keyType, ALLOWS_NULL);
Object newValuesArray = _withRemoveRange(index, index+1, _valuesArray, _valueType, ALLOWS_NULL);
return new AssociationImpl<>(_depth, _keyType, newKeysArray, _valueType, newValuesArray, _keyHashes, _branches, true);
}
}
@Override
public Association<K, V> putAll( Stream<Pair<? extends K, ? extends V>> entries ) {
Objects.requireNonNull(entries);
// TODO: implement branching based bulk insert
AssociationImpl<K, V> result = this;
// reduce the stream to a single association
return entries.reduce(
result,
(acc,
entry) -> (AssociationImpl<K, V>) acc.put(entry.first(), entry.second()),
(a, b) -> a);
}
@Override
public Association<K, V> removeAll( Set<? extends K> keys ) {
if ( this.isEmpty() || keys.isEmpty() )
return this;
Association<K, V> result = this;
for ( K key : keys ) {
result = result.remove(key);
}
return result;
}
@Override
public Association<K, V> retainAll( Set<? extends K> keys ) {
if ( this.isEmpty() || keys.isEmpty() )
return this;
Association<K, V> result = this;
for ( K key : this.keySet() ) {
if ( !keys.contains(key) ) {
result = result.remove(key);
}
}
return result;
}
@Override
public Association<K, V> replace( K key, V value ) {
if ( this.containsKey(key) ) {
return this.put(key, value);
} else {
return this;
}
}
@Override
public Association<K, V> replaceAll( Stream<Pair<? extends K, ? extends V>> stream ) {
Objects.requireNonNull(stream);
Association<K, V> result = this;
// reduce the stream to a single association
return stream.reduce(
result,
(acc,
entry) -> acc.replace(entry.first(), entry.second()),
(a, b) -> a);
}
@Override
public Association<K, V> clear() {
return new AssociationImpl<>(_keyType, _valueType);
}
@Override
public Map<K, V> toMap() {
Map<K, V> map = new java.util.HashMap<>();
_toMapRecursively(map);
return Collections.unmodifiableMap(map);
}
private void _toMapRecursively( Map<K, V> map ) {
int size = _length(_keysArray);
for (int i = 0; i < size; i++) {
K key = _getAt(i, _keysArray, _keyType);
V value = _getAt(i, _valuesArray, _valueType);
map.put(key, value);
}
for (AssociationImpl<K, V> branch : _branches) {
if (branch != null) {
branch._toMapRecursively(map);
}
}
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("Association<");
sb.append(_keyType.getSimpleName()).append(",");
sb.append(_valueType.getSimpleName()).append(">[");
final int howMany = 8;
sb = _appendRecursivelyUpTo(sb, howMany);
int numberOfEntriesLeft = _size - howMany;
if ( numberOfEntriesLeft > 0 ) {
sb.append(", ...").append(numberOfEntriesLeft).append(" more entries");
}
sb.append("]");
return sb.toString();
}
private StringBuilder _appendRecursivelyUpTo( StringBuilder sb, int size ) {
int howMany = Math.min(size, _length(_keysArray));
for (int i = 0; i < howMany; i++) {
K key = _getAt(i, _keysArray, _keyType);
V value = _getAt(i, _valuesArray, _valueType);
sb.append(_toString(key, _keyType)).append(" ↦ ").append(_toString(value, _valueType));
if ( i < howMany - 1 ) {
sb.append(", ");
}
}
int deltaLeft = size - howMany;
if ( deltaLeft > 0 ) {
for (AssociationImpl<K, V> branch : _branches) {
if ( branch != null ) {
if ( deltaLeft < size - howMany || howMany > 0 )
sb.append(", ");
sb = branch._appendRecursivelyUpTo(sb, deltaLeft);
deltaLeft -= branch.size();
if ( deltaLeft <= 0 ) {
break;
}
}
}
}
return sb;
}
private static String _toString( @Nullable Object singleItem, Class<?> type ) {
if ( singleItem == null ) {
return "null";
} else if ( type == String.class ) {
return "\"" + singleItem + "\"";
} else if ( type == Character.class ) {
return "'" + singleItem + "'";
} else if ( type == Boolean.class ) {
return singleItem.toString();
} else {
return singleItem.toString();
}
}
@Override
public boolean equals(Object obj) {
if ( obj == this ) {
return true;
}
if ( obj instanceof AssociationImpl) {
AssociationImpl<K, ?> other = (AssociationImpl) obj;
if ( other.size() != this.size() ) {
return false;
}
for ( K key : this.keySet() ) {
int keyHash = key.hashCode();
Object value = this._get(key, keyHash);
if ( !Objects.equals(value, other._get(key, keyHash)) ) {
return false;
}
}
return true;
}
return false;
}
@Override
public int hashCode() {
return Long.hashCode(_recursiveHashCode());
}
private long _recursiveHashCode() {
long baseHash = 0; // -> full 64 bit improve hash distribution
int size = _length(_keysArray);
for (int i = 0; i < size; i++) {
Object key = Array.get(_keysArray, i);
Object value = Array.get(_valuesArray, i);
baseHash += _fullKeyPairHash(key, value);
}
for (AssociationImpl<K, V> branch : _branches) {
if ( branch != null ) {
baseHash += branch._recursiveHashCode();
}
}
return baseHash;
}
private static long _fullKeyPairHash( Object key, Object value ) {
return _combine(key.hashCode(), value.hashCode());
}
private static long _combine( int first32Bits, int last32Bits ) {
return (long) first32Bits << 32 | (last32Bits & 0xFFFFFFFFL);
}
@Override
public Iterator<Pair<K, V>> iterator() {
return new Iterator<Pair<K, V>>() {
// A helper class to keep track of our position in a node.
class NodeState {
final AssociationImpl<K, V> node;
int arrayIndex; // Next index in the keys/values arrays
int branchIndex; // Next branch index to check
final int arrayLength; // Total entries in the node's arrays
final int branchesLength; // Total branches in the node
NodeState(AssociationImpl<K, V> node) {
this.node = node;
this.arrayIndex = 0;
this.branchIndex = 0;
this.arrayLength = _length(node._keysArray);
this.branchesLength = node._branches.length;
}
}
// Use a stack to perform depth-first traversal.
private final Deque<NodeState> stack = new ArrayDeque<>();
{
// Initialize with this node if there is at least one element.
if (_size > 0) {
stack.push(new NodeState(AssociationImpl.this));
}
}
@Override
public boolean hasNext() {
// Loop until we find a node state with an unvisited entry or the stack is empty.
while (!stack.isEmpty()) {
NodeState current = stack.peek();
// If there is a key-value pair left in the current node, we're done.
if (current.arrayIndex < current.arrayLength) {
return true;
}
// Otherwise, check for non-null branches to traverse.
if (current.branchIndex < current.branchesLength) {
// Look for the next branch.
while (current.branchIndex < current.branchesLength) {
AssociationImpl<K, V> branch = current.node._branches[current.branchIndex];
current.branchIndex++;
if (branch != null && branch._size > 0) {
// Found a non-empty branch: push its state on the stack.
stack.push(new NodeState(branch));
break;
}
}
// Continue the while loop: now the top of the stack may have entries.
continue;
}
// If no more entries or branches are left in the current node, pop it.
stack.pop();
}
return false;
}
@Override
public Pair<K, V> next() {
if (!hasNext()) {
throw new NoSuchElementException();
}
NodeState current = stack.peek();
// Retrieve the key and value at the current position.
K key = _getAt(current.arrayIndex, current.node._keysArray, current.node._keyType);
V value = _getAt(current.arrayIndex, current.node._valuesArray, current.node._valueType);
Objects.requireNonNull(key);
Objects.requireNonNull(value);
current.arrayIndex++;
return Pair.of(key, value);
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
};
}
}