aboutsummaryrefslogtreecommitdiff
path: root/immutable.go
diff options
context:
space:
mode:
Diffstat (limited to 'immutable.go')
-rw-r--r--immutable.go1105
1 files changed, 382 insertions, 723 deletions
diff --git a/immutable.go b/immutable.go
index 1ec851e..9666fa6 100644
--- a/immutable.go
+++ b/immutable.go
@@ -42,12 +42,13 @@
package immutable
import (
- "bytes"
"fmt"
"math/bits"
"reflect"
"sort"
"strings"
+
+ "golang.org/x/exp/constraints"
)
// List is a dense, ordered, indexed collections. They are analogous to slices
@@ -685,28 +686,28 @@ const (
// to generate hashes and check for equality of key values.
//
// It is implemented as an Hash Array Mapped Trie.
-type Map struct {
- size int // total number of key/value pairs
- root mapNode // root node of trie
- hasher Hasher // hasher implementation
+type Map[K constraints.Ordered, V any] struct {
+ size int // total number of key/value pairs
+ root mapNode[K, V] // root node of trie
+ hasher Hasher[K] // hasher implementation
}
// NewMap returns a new instance of Map. If hasher is nil, a default hasher
// implementation will automatically be chosen based on the first key added.
// Default hasher implementations only exist for int, string, and byte slice types.
-func NewMap(hasher Hasher) *Map {
- return &Map{
+func NewMap[K constraints.Ordered, V any](hasher Hasher[K]) *Map[K, V] {
+ return &Map[K, V]{
hasher: hasher,
}
}
// Len returns the number of elements in the map.
-func (m *Map) Len() int {
+func (m *Map[K, V]) Len() int {
return m.size
}
// clone returns a shallow copy of m.
-func (m *Map) clone() *Map {
+func (m *Map[K, V]) clone() *Map[K, V] {
other := *m
return &other
}
@@ -714,9 +715,10 @@ func (m *Map) clone() *Map {
// Get returns the value for a given key and a flag indicating whether the
// key exists. This flag distinguishes a nil value set on a key versus a
// non-existent key in the map.
-func (m *Map) Get(key interface{}) (value interface{}, ok bool) {
+func (m *Map[K, V]) Get(key K) (value V, ok bool) {
+ var empty V
if m.root == nil {
- return nil, false
+ return empty, false
}
keyHash := m.hasher.Hash(key)
return m.root.get(key, 0, keyHash, m.hasher)
@@ -726,11 +728,11 @@ func (m *Map) Get(key interface{}) (value interface{}, ok bool) {
//
// This function will return a new map even if the updated value is the same as
// the existing value because Map does not track value equality.
-func (m *Map) Set(key, value interface{}) *Map {
+func (m *Map[K, V]) Set(key K, value V) *Map[K, V] {
return m.set(key, value, false)
}
-func (m *Map) set(key, value interface{}, mutable bool) *Map {
+func (m *Map[K, V]) set(key K, value V, mutable bool) *Map[K, V] {
// Set a hasher on the first value if one does not already exist.
hasher := m.hasher
if hasher == nil {
@@ -747,7 +749,7 @@ func (m *Map) set(key, value interface{}, mutable bool) *Map {
// If the map is empty, initialize with a simple array node.
if m.root == nil {
other.size = 1
- other.root = &mapArrayNode{entries: []mapEntry{{key: key, value: value}}}
+ other.root = &mapArrayNode[K, V]{entries: []mapEntry[K, V]{{key: key, value: value}}}
return other
}
@@ -763,11 +765,11 @@ func (m *Map) set(key, value interface{}, mutable bool) *Map {
// Delete returns a map with the given key removed.
// Removing a non-existent key will cause this method to return the same map.
-func (m *Map) Delete(key interface{}) *Map {
+func (m *Map[K, V]) Delete(key K) *Map[K, V] {
return m.delete(key, false)
}
-func (m *Map) delete(key interface{}, mutable bool) *Map {
+func (m *Map[K, V]) delete(key K, mutable bool) *Map[K, V] {
// Return original map if no keys exist.
if m.root == nil {
return m
@@ -793,25 +795,25 @@ func (m *Map) delete(key interface{}, mutable bool) *Map {
}
// Iterator returns a new iterator for the map.
-func (m *Map) Iterator() *MapIterator {
- itr := &MapIterator{m: m}
+func (m *Map[K, V]) Iterator() *MapIterator[K, V] {
+ itr := &MapIterator[K, V]{m: m}
itr.First()
return itr
}
// MapBuilder represents an efficient builder for creating Maps.
-type MapBuilder struct {
- m *Map // current state
+type MapBuilder[K constraints.Ordered, V any] struct {
+ m *Map[K, V] // current state
}
// NewMapBuilder returns a new instance of MapBuilder.
-func NewMapBuilder(hasher Hasher) *MapBuilder {
- return &MapBuilder{m: NewMap(hasher)}
+func NewMapBuilder[K constraints.Ordered, V any](hasher Hasher[K]) *MapBuilder[K, V] {
+ return &MapBuilder[K, V]{m: NewMap[K, V](hasher)}
}
// Map returns the underlying map. Only call once.
// Builder is invalid after call. Will panic on second invocation.
-func (b *MapBuilder) Map() *Map {
+func (b *MapBuilder[K, V]) Map() *Map[K, V] {
assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map")
m := b.m
b.m = nil
@@ -819,66 +821,66 @@ func (b *MapBuilder) Map() *Map {
}
// Len returns the number of elements in the underlying map.
-func (b *MapBuilder) Len() int {
+func (b *MapBuilder[K, V]) Len() int {
assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation")
return b.m.Len()
}
// Get returns the value for the given key.
-func (b *MapBuilder) Get(key interface{}) (value interface{}, ok bool) {
+func (b *MapBuilder[K, V]) Get(key K) (value V, ok bool) {
assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation")
return b.m.Get(key)
}
// Set sets the value of the given key. See Map.Set() for additional details.
-func (b *MapBuilder) Set(key, value interface{}) {
+func (b *MapBuilder[K, V]) Set(key K, value V) {
assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation")
b.m = b.m.set(key, value, true)
}
// Delete removes the given key. See Map.Delete() for additional details.
-func (b *MapBuilder) Delete(key interface{}) {
+func (b *MapBuilder[K, V]) Delete(key K) {
assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation")
b.m = b.m.delete(key, true)
}
// Iterator returns a new iterator for the underlying map.
-func (b *MapBuilder) Iterator() *MapIterator {
+func (b *MapBuilder[K, V]) Iterator() *MapIterator[K, V] {
assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation")
return b.m.Iterator()
}
// mapNode represents any node in the map tree.
-type mapNode interface {
- get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool)
- set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode
- delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode
+type mapNode[K constraints.Ordered, V any] interface {
+ get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool)
+ set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V]
+ delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V]
}
-var _ mapNode = (*mapArrayNode)(nil)
-var _ mapNode = (*mapBitmapIndexedNode)(nil)
-var _ mapNode = (*mapHashArrayNode)(nil)
-var _ mapNode = (*mapValueNode)(nil)
-var _ mapNode = (*mapHashCollisionNode)(nil)
+var _ mapNode[string, any] = (*mapArrayNode[string, any])(nil)
+var _ mapNode[string, any] = (*mapBitmapIndexedNode[string, any])(nil)
+var _ mapNode[string, any] = (*mapHashArrayNode[string, any])(nil)
+var _ mapNode[string, any] = (*mapValueNode[string, any])(nil)
+var _ mapNode[string, any] = (*mapHashCollisionNode[string, any])(nil)
// mapLeafNode represents a node that stores a single key hash at the leaf of the map tree.
-type mapLeafNode interface {
- mapNode
+type mapLeafNode[K constraints.Ordered, V any] interface {
+ mapNode[K, V]
keyHashValue() uint32
}
-var _ mapLeafNode = (*mapValueNode)(nil)
-var _ mapLeafNode = (*mapHashCollisionNode)(nil)
+var _ mapLeafNode[string, any] = (*mapValueNode[string, any])(nil)
+var _ mapLeafNode[string, any] = (*mapHashCollisionNode[string, any])(nil)
// mapArrayNode is a map node that stores key/value pairs in a slice.
// Entries are stored in insertion order. An array node expands into a bitmap
// indexed node once a given threshold size is crossed.
-type mapArrayNode struct {
- entries []mapEntry
+type mapArrayNode[K constraints.Ordered, V any] struct {
+ entries []mapEntry[K, V]
}
// indexOf returns the entry index of the given key. Returns -1 if key not found.
-func (n *mapArrayNode) indexOf(key interface{}, h Hasher) int {
+func (n *mapArrayNode[K, V]) indexOf(key K, h Hasher[K]) int {
for i := range n.entries {
if h.Equal(n.entries[i].key, key) {
return i
@@ -888,17 +890,17 @@ func (n *mapArrayNode) indexOf(key interface{}, h Hasher) int {
}
// get returns the value for the given key.
-func (n *mapArrayNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) {
+func (n *mapArrayNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) {
i := n.indexOf(key, h)
if i == -1 {
- return nil, false
+ return value, false
}
return n.entries[i].value, true
}
// set inserts or updates the value for a given key. If the key is inserted and
// the new size crosses the max size threshold, a bitmap indexed node is returned.
-func (n *mapArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode {
+func (n *mapArrayNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
idx := n.indexOf(key, h)
// Mark as resized if the key doesn't exist.
@@ -909,7 +911,7 @@ func (n *mapArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h
// If we are adding and it crosses the max size threshold, expand the node.
// We do this by continually setting the entries to a value node and expanding.
if idx == -1 && len(n.entries) >= maxArrayMapSize {
- var node mapNode = newMapValueNode(h.Hash(key), key, value)
+ var node mapNode[K, V] = newMapValueNode(h.Hash(key), key, value)
for _, entry := range n.entries {
node = node.set(entry.key, entry.value, 0, h.Hash(entry.key), h, false, resized)
}
@@ -919,31 +921,31 @@ func (n *mapArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h
// Update in-place if mutable.
if mutable {
if idx != -1 {
- n.entries[idx] = mapEntry{key, value}
+ n.entries[idx] = mapEntry[K, V]{key, value}
} else {
- n.entries = append(n.entries, mapEntry{key, value})
+ n.entries = append(n.entries, mapEntry[K, V]{key, value})
}
return n
}
// Update existing entry if a match is found.
// Otherwise append to the end of the element list if it doesn't exist.
- var other mapArrayNode
+ var other mapArrayNode[K, V]
if idx != -1 {
- other.entries = make([]mapEntry, len(n.entries))
+ other.entries = make([]mapEntry[K, V], len(n.entries))
copy(other.entries, n.entries)
- other.entries[idx] = mapEntry{key, value}
+ other.entries[idx] = mapEntry[K, V]{key, value}
} else {
- other.entries = make([]mapEntry, len(n.entries)+1)
+ other.entries = make([]mapEntry[K, V], len(n.entries)+1)
copy(other.entries, n.entries)
- other.entries[len(other.entries)-1] = mapEntry{key, value}
+ other.entries[len(other.entries)-1] = mapEntry[K, V]{key, value}
}
return &other
}
// delete removes the given key from the node. Returns the same node if key does
// not exist. Returns a nil node when removing the last entry.
-func (n *mapArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode {
+func (n *mapArrayNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
idx := n.indexOf(key, h)
// Return original node if key does not exist.
@@ -960,13 +962,13 @@ func (n *mapArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Has
// Update in-place, if mutable.
if mutable {
copy(n.entries[idx:], n.entries[idx+1:])
- n.entries[len(n.entries)-1] = mapEntry{}
+ n.entries[len(n.entries)-1] = mapEntry[K, V]{}
n.entries = n.entries[:len(n.entries)-1]
return n
}
// Otherwise create a copy with the given entry removed.
- other := &mapArrayNode{entries: make([]mapEntry, len(n.entries)-1)}
+ other := &mapArrayNode[K, V]{entries: make([]mapEntry[K, V], len(n.entries)-1)}
copy(other.entries[:idx], n.entries[:idx])
copy(other.entries[idx:], n.entries[idx+1:])
return other
@@ -975,16 +977,16 @@ func (n *mapArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Has
// mapBitmapIndexedNode represents a map branch node with a variable number of
// node slots and indexed using a bitmap. Indexes for the node slots are
// calculated by counting the number of set bits before the target bit using popcount.
-type mapBitmapIndexedNode struct {
+type mapBitmapIndexedNode[K constraints.Ordered, V any] struct {
bitmap uint32
- nodes []mapNode
+ nodes []mapNode[K, V]
}
// get returns the value for the given key.
-func (n *mapBitmapIndexedNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) {
+func (n *mapBitmapIndexedNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) {
bit := uint32(1) << ((keyHash >> shift) & mapNodeMask)
if (n.bitmap & bit) == 0 {
- return nil, false
+ return value, false
}
child := n.nodes[bits.OnesCount32(n.bitmap&(bit-1))]
return child.get(key, shift+mapNodeBits, keyHash, h)
@@ -992,7 +994,7 @@ func (n *mapBitmapIndexedNode) get(key interface{}, shift uint, keyHash uint32,
// set inserts or updates the value for the given key. If a new key is inserted
// and the size crosses the max size threshold then a hash array node is returned.
-func (n *mapBitmapIndexedNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode {
+func (n *mapBitmapIndexedNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
// Extract the index for the bit segment of the key hash.
keyHashFrag := (keyHash >> shift) & mapNodeMask
@@ -1010,17 +1012,17 @@ func (n *mapBitmapIndexedNode) set(key, value interface{}, shift uint, keyHash u
// If the node already exists, delegate set operation to it.
// If the node doesn't exist then create a simple value leaf node.
- var newNode mapNode
+ var newNode mapNode[K, V]
if exists {
newNode = n.nodes[idx].set(key, value, shift+mapNodeBits, keyHash, h, mutable, resized)
} else {
- newNode = newMapValueNode(keyHash, key, value)
+ newNode = newMapValueNode[K, V](keyHash, key, value)
}
// Convert to a hash-array node once we exceed the max bitmap size.
// Copy each node based on their bit position within the bitmap.
if !exists && len(n.nodes) > maxBitmapIndexedSize {
- var other mapHashArrayNode
+ var other mapHashArrayNode[K, V]
for i := uint(0); i < uint(len(other.nodes)); i++ {
if n.bitmap&(uint32(1)<<i) != 0 {
other.nodes[i] = n.nodes[other.count]
@@ -1047,13 +1049,13 @@ func (n *mapBitmapIndexedNode) set(key, value interface{}, shift uint, keyHash u
// If node exists at given slot then overwrite it with new node.
// Otherwise expand the node list and insert new node into appropriate position.
- other := &mapBitmapIndexedNode{bitmap: n.bitmap | bit}
+ other := &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap | bit}
if exists {
- other.nodes = make([]mapNode, len(n.nodes))
+ other.nodes = make([]mapNode[K, V], len(n.nodes))
copy(other.nodes, n.nodes)
other.nodes[idx] = newNode
} else {
- other.nodes = make([]mapNode, len(n.nodes)+1)
+ other.nodes = make([]mapNode[K, V], len(n.nodes)+1)
copy(other.nodes, n.nodes[:idx])
other.nodes[idx] = newNode
copy(other.nodes[idx+1:], n.nodes[idx:])
@@ -1064,7 +1066,7 @@ func (n *mapBitmapIndexedNode) set(key, value interface{}, shift uint, keyHash u
// delete removes the key from the tree. If the key does not exist then the
// original node is returned. If removing the last child node then a nil is
// returned. Note that shrinking the node will not convert it to an array node.
-func (n *mapBitmapIndexedNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode {
+func (n *mapBitmapIndexedNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
bit := uint32(1) << ((keyHash >> shift) & mapNodeMask)
// Return original node if key does not exist.
@@ -1101,7 +1103,7 @@ func (n *mapBitmapIndexedNode) delete(key interface{}, shift uint, keyHash uint3
}
// Return copy with bit removed from bitmap and node removed from node list.
- other := &mapBitmapIndexedNode{bitmap: n.bitmap ^ bit, nodes: make([]mapNode, len(n.nodes)-1)}
+ other := &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap ^ bit, nodes: make([]mapNode[K, V], len(n.nodes)-1)}
copy(other.nodes[:idx], n.nodes[:idx])
copy(other.nodes[idx:], n.nodes[idx+1:])
return other
@@ -1110,7 +1112,7 @@ func (n *mapBitmapIndexedNode) delete(key interface{}, shift uint, keyHash uint3
// Generate copy, if necessary.
other := n
if !mutable {
- other = &mapBitmapIndexedNode{bitmap: n.bitmap, nodes: make([]mapNode, len(n.nodes))}
+ other = &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap, nodes: make([]mapNode[K, V], len(n.nodes))}
copy(other.nodes, n.nodes)
}
@@ -1121,34 +1123,34 @@ func (n *mapBitmapIndexedNode) delete(key interface{}, shift uint, keyHash uint3
// mapHashArrayNode is a map branch node that stores nodes in a fixed length
// array. Child nodes are indexed by their index bit segment for the current depth.
-type mapHashArrayNode struct {
- count uint // number of set nodes
- nodes [mapNodeSize]mapNode // child node slots, may contain empties
+type mapHashArrayNode[K constraints.Ordered, V any] struct {
+ count uint // number of set nodes
+ nodes [mapNodeSize]mapNode[K, V] // child node slots, may contain empties
}
// clone returns a shallow copy of n.
-func (n *mapHashArrayNode) clone() *mapHashArrayNode {
+func (n *mapHashArrayNode[K, V]) clone() *mapHashArrayNode[K, V] {
other := *n
return &other
}
// get returns the value for the given key.
-func (n *mapHashArrayNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) {
+func (n *mapHashArrayNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) {
node := n.nodes[(keyHash>>shift)&mapNodeMask]
if node == nil {
- return nil, false
+ return value, false
}
return node.get(key, shift+mapNodeBits, keyHash, h)
}
// set returns a node with the value set for the given key.
-func (n *mapHashArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode {
+func (n *mapHashArrayNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
idx := (keyHash >> shift) & mapNodeMask
node := n.nodes[idx]
// If node at index doesn't exist, create a simple value leaf node.
// Otherwise delegate set to child node.
- var newNode mapNode
+ var newNode mapNode[K, V]
if node == nil {
*resized = true
newNode = newMapValueNode(keyHash, key, value)
@@ -1173,7 +1175,7 @@ func (n *mapHashArrayNode) set(key, value interface{}, shift uint, keyHash uint3
// delete returns a node with the given key removed. Returns the same node if
// the key does not exist. If node shrinks to within bitmap-indexed size then
// converts to a bitmap-indexed node.
-func (n *mapHashArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode {
+func (n *mapHashArrayNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
idx := (keyHash >> shift) & mapNodeMask
node := n.nodes[idx]
@@ -1190,7 +1192,7 @@ func (n *mapHashArrayNode) delete(key interface{}, shift uint, keyHash uint32, h
// If we remove a node and drop below a threshold, convert back to bitmap indexed node.
if newNode == nil && n.count <= maxBitmapIndexedSize {
- other := &mapBitmapIndexedNode{nodes: make([]mapNode, 0, n.count-1)}
+ other := &mapBitmapIndexedNode[K, V]{nodes: make([]mapNode[K, V], 0, n.count-1)}
for i, child := range n.nodes {
if child != nil && uint32(i) != idx {
other.bitmap |= 1 << uint(i)
@@ -1217,15 +1219,15 @@ func (n *mapHashArrayNode) delete(key interface{}, shift uint, keyHash uint32, h
// mapValueNode represents a leaf node with a single key/value pair.
// A value node can be converted to a hash collision leaf node if a different
// key with the same keyHash is inserted.
-type mapValueNode struct {
+type mapValueNode[K constraints.Ordered, V any] struct {
keyHash uint32
- key interface{}
- value interface{}
+ key K
+ value V
}
// newMapValueNode returns a new instance of mapValueNode.
-func newMapValueNode(keyHash uint32, key, value interface{}) *mapValueNode {
- return &mapValueNode{
+func newMapValueNode[K constraints.Ordered, V any](keyHash uint32, key K, value V) *mapValueNode[K, V] {
+ return &mapValueNode[K, V]{
keyHash: keyHash,
key: key,
value: value,
@@ -1233,14 +1235,14 @@ func newMapValueNode(keyHash uint32, key, value interface{}) *mapValueNode {
}
// keyHashValue returns the key hash for this node.
-func (n *mapValueNode) keyHashValue() uint32 {
+func (n *mapValueNode[K, V]) keyHashValue() uint32 {
return n.keyHash
}
// get returns the value for the given key.
-func (n *mapValueNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) {
+func (n *mapValueNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) {
if !h.Equal(n.key, key) {
- return nil, false
+ return value, false
}
return n.value, true
}
@@ -1249,7 +1251,7 @@ func (n *mapValueNode) get(key interface{}, shift uint, keyHash uint32, h Hasher
// the node's key then a new value node is returned. If key is not equal to the
// node's key but has the same hash then a hash collision node is returned.
// Otherwise the nodes are merged into a branch node.
-func (n *mapValueNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode {
+func (n *mapValueNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
// If the keys match then return a new value node overwriting the value.
if h.Equal(n.key, key) {
// Update in-place if mutable.
@@ -1265,18 +1267,18 @@ func (n *mapValueNode) set(key, value interface{}, shift uint, keyHash uint32, h
// Recursively merge nodes together if key hashes are different.
if n.keyHash != keyHash {
- return mergeIntoNode(n, shift, keyHash, key, value)
+ return mergeIntoNode[K, V](n, shift, keyHash, key, value)
}
// Merge into collision node if hash matches.
- return &mapHashCollisionNode{keyHash: keyHash, entries: []mapEntry{
+ return &mapHashCollisionNode[K, V]{keyHash: keyHash, entries: []mapEntry[K, V]{
{key: n.key, value: n.value},
{key: key, value: value},
}}
}
// delete returns nil if the key matches the node's key. Otherwise returns the original node.
-func (n *mapValueNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode {
+func (n *mapValueNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
// Return original node if the keys do not match.
if !h.Equal(n.key, key) {
return n
@@ -1289,19 +1291,19 @@ func (n *mapValueNode) delete(key interface{}, shift uint, keyHash uint32, h Has
// mapHashCollisionNode represents a leaf node that contains two or more key/value
// pairs with the same key hash. Single pairs for a hash are stored as value nodes.
-type mapHashCollisionNode struct {
+type mapHashCollisionNode[K constraints.Ordered, V any] struct {
keyHash uint32 // key hash for all entries
- entries []mapEntry
+ entries []mapEntry[K, V]
}
// keyHashValue returns the key hash for all entries on the node.
-func (n *mapHashCollisionNode) keyHashValue() uint32 {
+func (n *mapHashCollisionNode[K, V]) keyHashValue() uint32 {
return n.keyHash
}
// indexOf returns the index of the entry for the given key.
// Returns -1 if the key does not exist in the node.
-func (n *mapHashCollisionNode) indexOf(key interface{}, h Hasher) int {
+func (n *mapHashCollisionNode[K, V]) indexOf(key K, h Hasher[K]) int {
for i := range n.entries {
if h.Equal(n.entries[i].key, key) {
return i
@@ -1311,46 +1313,46 @@ func (n *mapHashCollisionNode) indexOf(key interface{}, h Hasher) int {
}
// get returns the value for the given key.
-func (n *mapHashCollisionNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) {
+func (n *mapHashCollisionNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) {
for i := range n.entries {
if h.Equal(n.entries[i].key, key) {
return n.entries[i].value, true
}
}
- return nil, false
+ return value, false
}
// set returns a copy of the node with key set to the given value.
-func (n *mapHashCollisionNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode {
+func (n *mapHashCollisionNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
// Merge node with key/value pair if this is not a hash collision.
if n.keyHash != keyHash {
*resized = true
- return mergeIntoNode(n, shift, keyHash, key, value)
+ return mergeIntoNode[K, V](n, shift, keyHash, key, value)
}
// Update in-place if mutable.
if mutable {
if idx := n.indexOf(key, h); idx == -1 {
*resized = true
- n.entries = append(n.entries, mapEntry{key, value})
+ n.entries = append(n.entries, mapEntry[K, V]{key, value})
} else {
- n.entries[idx] = mapEntry{key, value}
+ n.entries[idx] = mapEntry[K, V]{key, value}
}
return n
}
// Append to end of node if key doesn't exist & mark resized.
// Otherwise copy nodes and overwrite at matching key index.
- other := &mapHashCollisionNode{keyHash: n.keyHash}
+ other := &mapHashCollisionNode[K, V]{keyHash: n.keyHash}
if idx := n.indexOf(key, h); idx == -1 {
*resized = true
- other.entries = make([]mapEntry, len(n.entries)+1)
+ other.entries = make([]mapEntry[K, V], len(n.entries)+1)
copy(other.entries, n.entries)
- other.entries[len(other.entries)-1] = mapEntry{key, value}
+ other.entries[len(other.entries)-1] = mapEntry[K, V]{key, value}
} else {
- other.entries = make([]mapEntry, len(n.entries))
+ other.entries = make([]mapEntry[K, V], len(n.entries))
copy(other.entries, n.entries)
- other.entries[idx] = mapEntry{key, value}
+ other.entries[idx] = mapEntry[K, V]{key, value}
}
return other
}
@@ -1358,7 +1360,7 @@ func (n *mapHashCollisionNode) set(key, value interface{}, shift uint, keyHash u
// delete returns a node with the given key deleted. Returns the same node if
// the key does not exist. If removing the key would shrink the node to a single
// entry then a value node is returned.
-func (n *mapHashCollisionNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode {
+func (n *mapHashCollisionNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
idx := n.indexOf(key, h)
// Return original node if key is not found.
@@ -1371,7 +1373,7 @@ func (n *mapHashCollisionNode) delete(key interface{}, shift uint, keyHash uint3
// Convert to value node if we move to one entry.
if len(n.entries) == 2 {
- return &mapValueNode{
+ return &mapValueNode[K, V]{
keyHash: n.keyHash,
key: n.entries[idx^1].key,
value: n.entries[idx^1].value,
@@ -1381,13 +1383,13 @@ func (n *mapHashCollisionNode) delete(key interface{}, shift uint, keyHash uint3
// Remove entry in-place if mutable.
if mutable {
copy(n.entries[idx:], n.entries[idx+1:])
- n.entries[len(n.entries)-1] = mapEntry{}
+ n.entries[len(n.entries)-1] = mapEntry[K, V]{}
n.entries = n.entries[:len(n.entries)-1]
return n
}
// Return copy without entry if immutable.
- other := &mapHashCollisionNode{keyHash: n.keyHash, entries: make([]mapEntry, len(n.entries)-1)}
+ other := &mapHashCollisionNode[K, V]{keyHash: n.keyHash, entries: make([]mapEntry[K, V], len(n.entries)-1)}
copy(other.entries[:idx], n.entries[:idx])
copy(other.entries[idx:], n.entries[idx+1:])
return other
@@ -1395,46 +1397,46 @@ func (n *mapHashCollisionNode) delete(key interface{}, shift uint, keyHash uint3
// mergeIntoNode merges a key/value pair into an existing node.
// Caller must verify that node's keyHash is not equal to keyHash.
-func mergeIntoNode(node mapLeafNode, shift uint, keyHash uint32, key, value interface{}) mapNode {
+func mergeIntoNode[K constraints.Ordered, V any](node mapLeafNode[K, V], shift uint, keyHash uint32, key K, value V) mapNode[K, V] {
idx1 := (node.keyHashValue() >> shift) & mapNodeMask
idx2 := (keyHash >> shift) & mapNodeMask
// Recursively build branch nodes to combine the node and its key.
- other := &mapBitmapIndexedNode{bitmap: (1 << idx1) | (1 << idx2)}
+ other := &mapBitmapIndexedNode[K, V]{bitmap: (1 << idx1) | (1 << idx2)}
if idx1 == idx2 {
- other.nodes = []mapNode{mergeIntoNode(node, shift+mapNodeBits, keyHash, key, value)}
+ other.nodes = []mapNode[K, V]{mergeIntoNode(node, shift+mapNodeBits, keyHash, key, value)}
} else {
if newNode := newMapValueNode(keyHash, key, value); idx1 < idx2 {
- other.nodes = []mapNode{node, newNode}
+ other.nodes = []mapNode[K, V]{node, newNode}
} else {
- other.nodes = []mapNode{newNode, node}
+ other.nodes = []mapNode[K, V]{newNode, node}
}
}
return other
}
// mapEntry represents a single key/value pair.
-type mapEntry struct {
- key interface{}
- value interface{}
+type mapEntry[K constraints.Ordered, V any] struct {
+ key K
+ value V
}
// MapIterator represents an iterator over a map's key/value pairs. Although
// map keys are not sorted, the iterator's order is deterministic.
-type MapIterator struct {
- m *Map // source map
+type MapIterator[K constraints.Ordered, V any] struct {
+ m *Map[K, V] // source map
- stack [32]mapIteratorElem // search stack
- depth int // stack depth
+ stack [32]mapIteratorElem[K, V] // search stack
+ depth int // stack depth
}
// Done returns true if no more elements remain in the iterator.
-func (itr *MapIterator) Done() bool {
+func (itr *MapIterator[K, V]) Done() bool {
return itr.depth == -1
}
// First resets the iterator to the first key/value pair.
-func (itr *MapIterator) First() {
+func (itr *MapIterator[K, V]) First() {
// Exit immediately if the map is empty.
if itr.m.root == nil {
itr.depth = -1
@@ -1442,27 +1444,27 @@ func (itr *MapIterator) First() {
}
// Initialize the stack to the left most element.
- itr.stack[0] = mapIteratorElem{node: itr.m.root}
+ itr.stack[0] = mapIteratorElem[K, V]{node: itr.m.root}
itr.depth = 0
itr.first()
}
// Next returns the next key/value pair. Returns a nil key when no elements remain.
-func (itr *MapIterator) Next() (key, value interface{}) {
+func (itr *MapIterator[K, V]) Next() (key K, value V, ok bool) {
// Return nil key if iteration is done.
if itr.Done() {
- return nil, nil
+ return key, value, false
}
// Retrieve current index & value. Current node is always a leaf.
elem := &itr.stack[itr.depth]
switch node := elem.node.(type) {
- case *mapArrayNode:
+ case *mapArrayNode[K, V]:
entry := &node.entries[elem.index]
key, value = entry.key, entry.value
- case *mapValueNode:
+ case *mapValueNode[K, V]:
key, value = node.key, node.value
- case *mapHashCollisionNode:
+ case *mapHashCollisionNode[K, V]:
entry := &node.entries[elem.index]
key, value = entry.key, entry.value
}
@@ -1470,22 +1472,22 @@ func (itr *MapIterator) Next() (key, value interface{}) {
// Move up stack until we find a node that has remaining position ahead
// and move that element forward by one.
itr.next()
- return key, value
+ return key, value, true
}
// next moves to the next available key.
-func (itr *MapIterator) next() {
+func (itr *MapIterator[K, V]) next() {
for ; itr.depth >= 0; itr.depth-- {
elem := &itr.stack[itr.depth]
switch node := elem.node.(type) {
- case *mapArrayNode:
+ case *mapArrayNode[K, V]:
if elem.index < len(node.entries)-1 {
elem.index++
return
}
- case *mapBitmapIndexedNode:
+ case *mapBitmapIndexedNode[K, V]:
if elem.index < len(node.nodes)-1 {
elem.index++
itr.stack[itr.depth+1].node = node.nodes[elem.index]
@@ -1494,7 +1496,7 @@ func (itr *MapIterator) next() {
return
}
- case *mapHashArrayNode:
+ case *mapHashArrayNode[K, V]:
for i := elem.index + 1; i < len(node.nodes); i++ {
if node.nodes[i] != nil {
elem.index = i
@@ -1505,10 +1507,10 @@ func (itr *MapIterator) next() {
}
}
- case *mapValueNode:
+ case *mapValueNode[K, V]:
continue // always the last value, traverse up
- case *mapHashCollisionNode:
+ case *mapHashCollisionNode[K, V]:
if elem.index < len(node.entries)-1 {
elem.index++
return
@@ -1519,16 +1521,16 @@ func (itr *MapIterator) next() {
// first positions the stack left most index.
// Elements and indexes at and below the current depth are assumed to be correct.
-func (itr *MapIterator) first() {
+func (itr *MapIterator[K, V]) first() {
for ; ; itr.depth++ {
elem := &itr.stack[itr.depth]
switch node := elem.node.(type) {
- case *mapBitmapIndexedNode:
+ case *mapBitmapIndexedNode[K, V]:
elem.index = 0
itr.stack[itr.depth+1].node = node.nodes[0]
- case *mapHashArrayNode:
+ case *mapHashArrayNode[K, V]:
for i := 0; i < len(node.nodes); i++ {
if node.nodes[i] != nil { // find first node
elem.index = i
@@ -1545,8 +1547,8 @@ func (itr *MapIterator) first() {
}
// mapIteratorElem represents a node/index pair in the MapIterator stack.
-type mapIteratorElem struct {
- node mapNode
+type mapIteratorElem[K constraints.Ordered, V any] struct {
+ node mapNode[K, V]
index int
}
@@ -1559,45 +1561,46 @@ const (
// is determined by the Comparer used by the map.
//
// This map is implemented as a B+tree.
-type SortedMap struct {
- size int // total number of key/value pairs
- root sortedMapNode // root of b+tree
- comparer Comparer
+type SortedMap[K constraints.Ordered, V any] struct {
+ size int // total number of key/value pairs
+ root sortedMapNode[K, V] // root of b+tree
+ comparer Comparer[K]
}
// NewSortedMap returns a new instance of SortedMap. If comparer is nil then
// a default comparer is set after the first key is inserted. Default comparers
// exist for int, string, and byte slice keys.
-func NewSortedMap(comparer Comparer) *SortedMap {
- return &SortedMap{
+func NewSortedMap[K constraints.Ordered, V any](comparer Comparer[K]) *SortedMap[K, V] {
+ return &SortedMap[K, V]{
comparer: comparer,
}
}
// Len returns the number of elements in the sorted map.
-func (m *SortedMap) Len() int {
+func (m *SortedMap[K, V]) Len() int {
return m.size
}
// Get returns the value for a given key and a flag indicating if the key is set.
// The flag can be used to distinguish between a nil-set key versus an unset key.
-func (m *SortedMap) Get(key interface{}) (interface{}, bool) {
+func (m *SortedMap[K, V]) Get(key K) (V, bool) {
if m.root == nil {
- return nil, false
+ var v V
+ return v, false
}
return m.root.get(key, m.comparer)
}
// Set returns a copy of the map with the key set to the given value.
-func (m *SortedMap) Set(key, value interface{}) *SortedMap {
+func (m *SortedMap[K, V]) Set(key K, value V) *SortedMap[K, V] {
return m.set(key, value, false)
}
-func (m *SortedMap) set(key, value interface{}, mutable bool) *SortedMap {
+func (m *SortedMap[K, V]) set(key K, value V, mutable bool) *SortedMap[K, V] {
// Set a comparer on the first value if one does not already exist.
comparer := m.comparer
if comparer == nil {
- comparer = NewComparer(key)
+ comparer = NewComparer[K](key)
}
// Create copy, if necessary.
@@ -1610,7 +1613,7 @@ func (m *SortedMap) set(key, value interface{}, mutable bool) *SortedMap {
// If no values are set then initialize with a leaf node.
if m.root == nil {
other.size = 1
- other.root = &sortedMapLeafNode{entries: []mapEntry{{key: key, value: value}}}
+ other.root = &sortedMapLeafNode[K, V]{entries: []mapEntry[K, V]{{key: key, value: value}}}
return other
}
@@ -1633,11 +1636,11 @@ func (m *SortedMap) set(key, value interface{}, mutable bool) *SortedMap {
// Delete returns a copy of the map with the key removed.
// Returns the original map if key does not exist.
-func (m *SortedMap) Delete(key interface{}) *SortedMap {
+func (m *SortedMap[K, V]) Delete(key K) *SortedMap[K, V] {
return m.delete(key, false)
}
-func (m *SortedMap) delete(key interface{}, mutable bool) *SortedMap {
+func (m *SortedMap[K, V]) delete(key K, mutable bool) *SortedMap[K, V] {
// Return original map if no keys exist.
if m.root == nil {
return m
@@ -1663,31 +1666,31 @@ func (m *SortedMap) delete(key interface{}, mutable bool) *SortedMap {
}
// clone returns a shallow copy of m.
-func (m *SortedMap) clone() *SortedMap {
+func (m *SortedMap[K, V]) clone() *SortedMap[K, V] {
other := *m
return &other
}
// Iterator returns a new iterator for this map positioned at the first key.
-func (m *SortedMap) Iterator() *SortedMapIterator {
- itr := &SortedMapIterator{m: m}
+func (m *SortedMap[K, V]) Iterator() *SortedMapIterator[K, V] {
+ itr := &SortedMapIterator[K, V]{m: m}
itr.First()
return itr
}
// SortedMapBuilder represents an efficient builder for creating sorted maps.
-type SortedMapBuilder struct {
- m *SortedMap // current state
+type SortedMapBuilder[K constraints.Ordered, V any] struct {
+ m *SortedMap[K, V] // current state
}
// NewSortedMapBuilder returns a new instance of SortedMapBuilder.
-func NewSortedMapBuilder(comparer Comparer) *SortedMapBuilder {
- return &SortedMapBuilder{m: NewSortedMap(comparer)}
+func NewSortedMapBuilder[K constraints.Ordered, V any](comparer Comparer[K]) *SortedMapBuilder[K, V] {
+ return &SortedMapBuilder[K, V]{m: NewSortedMap[K, V](comparer)}
}
// SortedMap returns the current copy of the map.
// The returned map is safe to use even if after the builder continues to be used.
-func (b *SortedMapBuilder) Map() *SortedMap {
+func (b *SortedMapBuilder[K, V]) Map() *SortedMap[K, V] {
assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map")
m := b.m
b.m = nil
@@ -1695,73 +1698,73 @@ func (b *SortedMapBuilder) Map() *SortedMap {
}
// Len returns the number of elements in the underlying map.
-func (b *SortedMapBuilder) Len() int {
+func (b *SortedMapBuilder[K, V]) Len() int {
assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation")
return b.m.Len()
}
// Get returns the value for the given key.
-func (b *SortedMapBuilder) Get(key interface{}) (value interface{}, ok bool) {
+func (b *SortedMapBuilder[K, V]) Get(key K) (value V, ok bool) {
assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation")
return b.m.Get(key)
}
// Set sets the value of the given key. See SortedMap.Set() for additional details.
-func (b *SortedMapBuilder) Set(key, value interface{}) {
+func (b *SortedMapBuilder[K, V]) Set(key K, value V) {
assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation")
b.m = b.m.set(key, value, true)
}
// Delete removes the given key. See SortedMap.Delete() for additional details.
-func (b *SortedMapBuilder) Delete(key interface{}) {
+func (b *SortedMapBuilder[K, V]) Delete(key K) {
assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation")
b.m = b.m.delete(key, true)
}
// Iterator returns a new iterator for the underlying map positioned at the first key.
-func (b *SortedMapBuilder) Iterator() *SortedMapIterator {
+func (b *SortedMapBuilder[K, V]) Iterator() *SortedMapIterator[K, V] {
assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation")
return b.m.Iterator()
}
// sortedMapNode represents a branch or leaf node in the sorted map.
-type sortedMapNode interface {
- minKey() interface{}
- indexOf(key interface{}, c Comparer) int
- get(key interface{}, c Comparer) (value interface{}, ok bool)
- set(key, value interface{}, c Comparer, mutable bool, resized *bool) (sortedMapNode, sortedMapNode)
- delete(key interface{}, c Comparer, mutable bool, resized *bool) sortedMapNode
+type sortedMapNode[K constraints.Ordered, V any] interface {
+ minKey() K
+ indexOf(key K, c Comparer[K]) int
+ get(key K, c Comparer[K]) (value V, ok bool)
+ set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V])
+ delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V]
}
-var _ sortedMapNode = (*sortedMapBranchNode)(nil)
-var _ sortedMapNode = (*sortedMapLeafNode)(nil)
+var _ sortedMapNode[string, any] = (*sortedMapBranchNode[string, any])(nil)
+var _ sortedMapNode[string, any] = (*sortedMapLeafNode[string, any])(nil)
// sortedMapBranchNode represents a branch in the sorted map.
-type sortedMapBranchNode struct {
- elems []sortedMapBranchElem
+type sortedMapBranchNode[K constraints.Ordered, V any] struct {
+ elems []sortedMapBranchElem[K, V]
}
// newSortedMapBranchNode returns a new branch node with the given child nodes.
-func newSortedMapBranchNode(children ...sortedMapNode) *sortedMapBranchNode {
+func newSortedMapBranchNode[K constraints.Ordered, V any](children ...sortedMapNode[K, V]) *sortedMapBranchNode[K, V] {
// Fetch min keys for every child.
- elems := make([]sortedMapBranchElem, len(children))
+ elems := make([]sortedMapBranchElem[K, V], len(children))
for i, child := range children {
- elems[i] = sortedMapBranchElem{
+ elems[i] = sortedMapBranchElem[K, V]{
key: child.minKey(),
node: child,
}
}
- return &sortedMapBranchNode{elems: elems}
+ return &sortedMapBranchNode[K, V]{elems: elems}
}
// minKey returns the lowest key stored in this node's tree.
-func (n *sortedMapBranchNode) minKey() interface{} {
+func (n *sortedMapBranchNode[K, V]) minKey() K {
return n.elems[0].node.minKey()
}
// indexOf returns the index of the key within the child nodes.
-func (n *sortedMapBranchNode) indexOf(key interface{}, c Comparer) int {
+func (n *sortedMapBranchNode[K, V]) indexOf(key K, c Comparer[K]) int {
if idx := sort.Search(len(n.elems), func(i int) bool { return c.Compare(n.elems[i].key, key) == 1 }); idx > 0 {
return idx - 1
}
@@ -1769,13 +1772,13 @@ func (n *sortedMapBranchNode) indexOf(key interface{}, c Comparer) int {
}
// get returns the value for the given key.
-func (n *sortedMapBranchNode) get(key interface{}, c Comparer) (value interface{}, ok bool) {
+func (n *sortedMapBranchNode[K, V]) get(key K, c Comparer[K]) (value V, ok bool) {
idx := n.indexOf(key, c)
return n.elems[idx].node.get(key, c)
}
// set returns a copy of the node with the key set to the given value.
-func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, mutable bool, resized *bool) (sortedMapNode, sortedMapNode) {
+func (n *sortedMapBranchNode[K, V]) set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) {
idx := n.indexOf(key, c)
// Delegate insert to child node.
@@ -1783,18 +1786,18 @@ func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, mutable bo
// Update in-place, if mutable.
if mutable {
- n.elems[idx] = sortedMapBranchElem{key: newNode.minKey(), node: newNode}
+ n.elems[idx] = sortedMapBranchElem[K, V]{key: newNode.minKey(), node: newNode}
if splitNode != nil {
- n.elems = append(n.elems, sortedMapBranchElem{})
+ n.elems = append(n.elems, sortedMapBranchElem[K, V]{})
copy(n.elems[idx+1:], n.elems[idx:])
- n.elems[idx+1] = sortedMapBranchElem{key: splitNode.minKey(), node: splitNode}
+ n.elems[idx+1] = sortedMapBranchElem[K, V]{key: splitNode.minKey(), node: splitNode}
}
// If the child splits and we have no more room then we split too.
if len(n.elems) > sortedMapNodeSize {
splitIdx := len(n.elems) / 2
- newNode := &sortedMapBranchNode{elems: n.elems[:splitIdx:splitIdx]}
- splitNode := &sortedMapBranchNode{elems: n.elems[splitIdx:]}
+ newNode := &sortedMapBranchNode[K, V]{elems: n.elems[:splitIdx:splitIdx]}
+ splitNode := &sortedMapBranchNode[K, V]{elems: n.elems[splitIdx:]}
return newNode, splitNode
}
return n, nil
@@ -1802,23 +1805,23 @@ func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, mutable bo
// If no split occurs, copy branch and update keys.
// If the child splits, insert new key/child into copy of branch.
- var other sortedMapBranchNode
+ var other sortedMapBranchNode[K, V]
if splitNode == nil {
- other.elems = make([]sortedMapBranchElem, len(n.elems))
+ other.elems = make([]sortedMapBranchElem[K, V], len(n.elems))
copy(other.elems, n.elems)
- other.elems[idx] = sortedMapBranchElem{
+ other.elems[idx] = sortedMapBranchElem[K, V]{
key: newNode.minKey(),
node: newNode,
}
} else {
- other.elems = make([]sortedMapBranchElem, len(n.elems)+1)
+ other.elems = make([]sortedMapBranchElem[K, V], len(n.elems)+1)
copy(other.elems[:idx], n.elems[:idx])
copy(other.elems[idx+1:], n.elems[idx:])
- other.elems[idx] = sortedMapBranchElem{
+ other.elems[idx] = sortedMapBranchElem[K, V]{
key: newNode.minKey(),
node: newNode,
}
- other.elems[idx+1] = sortedMapBranchElem{
+ other.elems[idx+1] = sortedMapBranchElem[K, V]{
key: splitNode.minKey(),
node: splitNode,
}
@@ -1827,8 +1830,8 @@ func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, mutable bo
// If the child splits and we have no more room then we split too.
if len(other.elems) > sortedMapNodeSize {
splitIdx := len(other.elems) / 2
- newNode := &sortedMapBranchNode{elems: other.elems[:splitIdx:splitIdx]}
- splitNode := &sortedMapBranchNode{elems: other.elems[splitIdx:]}
+ newNode := &sortedMapBranchNode[K, V]{elems: other.elems[:splitIdx:splitIdx]}
+ splitNode := &sortedMapBranchNode[K, V]{elems: other.elems[splitIdx:]}
return newNode, splitNode
}
@@ -1838,7 +1841,7 @@ func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, mutable bo
// delete returns a node with the key removed. Returns the same node if the key
// does not exist. Returns nil if all child nodes are removed.
-func (n *sortedMapBranchNode) delete(key interface{}, c Comparer, mutable bool, resized *bool) sortedMapNode {
+func (n *sortedMapBranchNode[K, V]) delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] {
idx := n.indexOf(key, c)
// Return original node if child has not changed.
@@ -1857,13 +1860,13 @@ func (n *sortedMapBranchNode) delete(key interface{}, c Comparer, mutable bool,
// If mutable, update in-place.
if mutable {
copy(n.elems[idx:], n.elems[idx+1:])
- n.elems[len(n.elems)-1] = sortedMapBranchElem{}
+ n.elems[len(n.elems)-1] = sortedMapBranchElem[K, V]{}
n.elems = n.elems[:len(n.elems)-1]
return n
}
// Return a copy without the given node.
- other := &sortedMapBranchNode{elems: make([]sortedMapBranchElem, len(n.elems)-1)}
+ other := &sortedMapBranchNode[K, V]{elems: make([]sortedMapBranchElem[K, V], len(n.elems)-1)}
copy(other.elems[:idx], n.elems[:idx])
copy(other.elems[idx:], n.elems[idx+1:])
return other
@@ -1871,49 +1874,49 @@ func (n *sortedMapBranchNode) delete(key interface{}, c Comparer, mutable bool,
// If mutable, update in-place.
if mutable {
- n.elems[idx] = sortedMapBranchElem{key: newNode.minKey(), node: newNode}
+ n.elems[idx] = sortedMapBranchElem[K, V]{key: newNode.minKey(), node: newNode}
return n
}
// Return a copy with the updated node.
- other := &sortedMapBranchNode{elems: make([]sortedMapBranchElem, len(n.elems))}
+ other := &sortedMapBranchNode[K, V]{elems: make([]sortedMapBranchElem[K, V], len(n.elems))}
copy(other.elems, n.elems)
- other.elems[idx] = sortedMapBranchElem{
+ other.elems[idx] = sortedMapBranchElem[K, V]{
key: newNode.minKey(),
node: newNode,
}
return other
}
-type sortedMapBranchElem struct {
- key interface{}
- node sortedMapNode
+type sortedMapBranchElem[K constraints.Ordered, V any] struct {
+ key K
+ node sortedMapNode[K, V]
}
// sortedMapLeafNode represents a leaf node in the sorted map.
-type sortedMapLeafNode struct {
- entries []mapEntry
+type sortedMapLeafNode[K constraints.Ordered, V any] struct {
+ entries []mapEntry[K, V]
}
// minKey returns the first key stored in this node.
-func (n *sortedMapLeafNode) minKey() interface{} {
+func (n *sortedMapLeafNode[K, V]) minKey() K {
return n.entries[0].key
}
// indexOf returns the index of the given key.
-func (n *sortedMapLeafNode) indexOf(key interface{}, c Comparer) int {
+func (n *sortedMapLeafNode[K, V]) indexOf(key K, c Comparer[K]) int {
return sort.Search(len(n.entries), func(i int) bool {
return c.Compare(n.entries[i].key, key) != -1 // GTE
})
}
// get returns the value of the given key.
-func (n *sortedMapLeafNode) get(key interface{}, c Comparer) (value interface{}, ok bool) {
+func (n *sortedMapLeafNode[K, V]) get(key K, c Comparer[K]) (value V, ok bool) {
idx := n.indexOf(key, c)
// If the index is beyond the entry count or the key is not equal then return 'not found'.
if idx == len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 {
- return nil, false
+ return value, false
}
// If the key matches then return its value.
@@ -1922,7 +1925,7 @@ func (n *sortedMapLeafNode) get(key interface{}, c Comparer) (value interface{},
// set returns a copy of node with the key set to the given value. If the update
// causes the node to grow beyond the maximum size then it is split in two.
-func (n *sortedMapLeafNode) set(key, value interface{}, c Comparer, mutable bool, resized *bool) (sortedMapNode, sortedMapNode) {
+func (n *sortedMapLeafNode[K, V]) set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) {
// Find the insertion index for the key.
idx := n.indexOf(key, c)
exists := idx < len(n.entries) && c.Compare(n.entries[idx].key, key) == 0
@@ -1931,16 +1934,16 @@ func (n *sortedMapLeafNode) set(key, value interface{}, c Comparer, mutable bool
if mutable {
if !exists {
*resized = true
- n.entries = append(n.entries, mapEntry{})
+ n.entries = append(n.entries, mapEntry[K, V]{})
copy(n.entries[idx+1:], n.entries[idx:])
}
- n.entries[idx] = mapEntry{key: key, value: value}
+ n.entries[idx] = mapEntry[K, V]{key: key, value: value}
// If the key doesn't exist and we exceed our max allowed values then split.
if len(n.entries) > sortedMapNodeSize {
splitIdx := len(n.entries) / 2
- newNode := &sortedMapLeafNode{entries: n.entries[:splitIdx:splitIdx]}
- splitNode := &sortedMapLeafNode{entries: n.entries[splitIdx:]}
+ newNode := &sortedMapLeafNode[K, V]{entries: n.entries[:splitIdx:splitIdx]}
+ splitNode := &sortedMapLeafNode[K, V]{entries: n.entries[splitIdx:]}
return newNode, splitNode
}
return n, nil
@@ -1948,34 +1951,34 @@ func (n *sortedMapLeafNode) set(key, value interface{}, c Comparer, mutable bool
// If the key matches then simply return a copy with the entry overridden.
// If there is no match then insert new entry and mark as resized.
- var newEntries []mapEntry
+ var newEntries []mapEntry[K, V]
if exists {
- newEntries = make([]mapEntry, len(n.entries))
+ newEntries = make([]mapEntry[K, V], len(n.entries))
copy(newEntries, n.entries)
- newEntries[idx] = mapEntry{key: key, value: value}
+ newEntries[idx] = mapEntry[K, V]{key: key, value: value}
} else {
*resized = true
- newEntries = make([]mapEntry, len(n.entries)+1)
+ newEntries = make([]mapEntry[K, V], len(n.entries)+1)
copy(newEntries[:idx], n.entries[:idx])
- newEntries[idx] = mapEntry{key: key, value: value}
+ newEntries[idx] = mapEntry[K, V]{key: key, value: value}
copy(newEntries[idx+1:], n.entries[idx:])
}
// If the key doesn't exist and we exceed our max allowed values then split.
if len(newEntries) > sortedMapNodeSize {
splitIdx := len(newEntries) / 2
- newNode := &sortedMapLeafNode{entries: newEntries[:splitIdx:splitIdx]}
- splitNode := &sortedMapLeafNode{entries: newEntries[splitIdx:]}
+ newNode := &sortedMapLeafNode[K, V]{entries: newEntries[:splitIdx:splitIdx]}
+ splitNode := &sortedMapLeafNode[K, V]{entries: newEntries[splitIdx:]}
return newNode, splitNode
}
// Otherwise return the new leaf node with the updated entry.
- return &sortedMapLeafNode{entries: newEntries}, nil
+ return &sortedMapLeafNode[K, V]{entries: newEntries}, nil
}
// delete returns a copy of node with key removed. Returns the original node if
// the key does not exist. Returns nil if the removed key is the last remaining key.
-func (n *sortedMapLeafNode) delete(key interface{}, c Comparer, mutable bool, resized *bool) sortedMapNode {
+func (n *sortedMapLeafNode[K, V]) delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] {
idx := n.indexOf(key, c)
// Return original node if key is not found.
@@ -1992,13 +1995,13 @@ func (n *sortedMapLeafNode) delete(key interface{}, c Comparer, mutable bool, re
// Update in-place, if mutable.
if mutable {
copy(n.entries[idx:], n.entries[idx+1:])
- n.entries[len(n.entries)-1] = mapEntry{}
+ n.entries[len(n.entries)-1] = mapEntry[K, V]{}
n.entries = n.entries[:len(n.entries)-1]
return n
}
// Return copy of node with entry removed.
- other := &sortedMapLeafNode{entries: make([]mapEntry, len(n.entries)-1)}
+ other := &sortedMapLeafNode[K, V]{entries: make([]mapEntry[K, V], len(n.entries)-1)}
copy(other.entries[:idx], n.entries[:idx])
copy(other.entries[idx:], n.entries[idx+1:])
return other
@@ -2006,36 +2009,36 @@ func (n *sortedMapLeafNode) delete(key interface{}, c Comparer, mutable bool, re
// SortedMapIterator represents an iterator over a sorted map.
// Iteration can occur in natural or reverse order based on use of Next() or Prev().
-type SortedMapIterator struct {
- m *SortedMap // source map
+type SortedMapIterator[K constraints.Ordered, V any] struct {
+ m *SortedMap[K, V] // source map
- stack [32]sortedMapIteratorElem // search stack
- depth int // stack depth
+ stack [32]sortedMapIteratorElem[K, V] // search stack
+ depth int // stack depth
}
// Done returns true if no more key/value pairs remain in the iterator.
-func (itr *SortedMapIterator) Done() bool {
+func (itr *SortedMapIterator[K, V]) Done() bool {
return itr.depth == -1
}
// First moves the iterator to the first key/value pair.
-func (itr *SortedMapIterator) First() {
+func (itr *SortedMapIterator[K, V]) First() {
if itr.m.root == nil {
itr.depth = -1
return
}
- itr.stack[0] = sortedMapIteratorElem{node: itr.m.root}
+ itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root}
itr.depth = 0
itr.first()
}
// Last moves the iterator to the last key/value pair.
-func (itr *SortedMapIterator) Last() {
+func (itr *SortedMapIterator[K, V]) Last() {
if itr.m.root == nil {
itr.depth = -1
return
}
- itr.stack[0] = sortedMapIteratorElem{node: itr.m.root}
+ itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root}
itr.depth = 0
itr.last()
}
@@ -2043,27 +2046,27 @@ func (itr *SortedMapIterator) Last() {
// Seek moves the iterator position to the given key in the map.
// If the key does not exist then the next key is used. If no more keys exist
// then the iteartor is marked as done.
-func (itr *SortedMapIterator) Seek(key interface{}) {
+func (itr *SortedMapIterator[K, V]) Seek(key K) {
if itr.m.root == nil {
itr.depth = -1
return
}
- itr.stack[0] = sortedMapIteratorElem{node: itr.m.root}
+ itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root}
itr.depth = 0
itr.seek(key)
}
// Next returns the current key/value pair and moves the iterator forward.
// Returns a nil key if the there are no more elements to return.
-func (itr *SortedMapIterator) Next() (key, value interface{}) {
+func (itr *SortedMapIterator[K, V]) Next() (key K, value V, ok bool) {
// Return nil key if iteration is complete.
if itr.Done() {
- return nil, nil
+ return key, value, false
}
// Retrieve current key/value pair.
leafElem := &itr.stack[itr.depth]
- leafNode := leafElem.node.(*sortedMapLeafNode)
+ leafNode := leafElem.node.(*sortedMapLeafNode[K, V])
leafEntry := &leafNode.entries[leafElem.index]
key, value = leafEntry.key, leafEntry.value
@@ -2071,21 +2074,21 @@ func (itr *SortedMapIterator) Next() (key, value interface{}) {
itr.next()
// Only occurs when iterator is done.
- return key, value
+ return key, value, true
}
// next moves to the next key. If no keys are after then depth is set to -1.
-func (itr *SortedMapIterator) next() {
+func (itr *SortedMapIterator[K, V]) next() {
for ; itr.depth >= 0; itr.depth-- {
elem := &itr.stack[itr.depth]
switch node := elem.node.(type) {
- case *sortedMapLeafNode:
+ case *sortedMapLeafNode[K, V]:
if elem.index < len(node.entries)-1 {
elem.index++
return
}
- case *sortedMapBranchNode:
+ case *sortedMapBranchNode[K, V]:
if elem.index < len(node.elems)-1 {
elem.index++
itr.stack[itr.depth+1].node = node.elems[elem.index].node
@@ -2099,34 +2102,34 @@ func (itr *SortedMapIterator) next() {
// Prev returns the current key/value pair and moves the iterator backward.
// Returns a nil key if the there are no more elements to return.
-func (itr *SortedMapIterator) Prev() (key, value interface{}) {
+func (itr *SortedMapIterator[K, V]) Prev() (key K, value V, ok bool) {
// Return nil key if iteration is complete.
if itr.Done() {
- return nil, nil
+ return key, value, false
}
// Retrieve current key/value pair.
leafElem := &itr.stack[itr.depth]
- leafNode := leafElem.node.(*sortedMapLeafNode)
+ leafNode := leafElem.node.(*sortedMapLeafNode[K, V])
leafEntry := &leafNode.entries[leafElem.index]
key, value = leafEntry.key, leafEntry.value
itr.prev()
- return key, value
+ return key, value, true
}
// prev moves to the previous key. If no keys are before then depth is set to -1.
-func (itr *SortedMapIterator) prev() {
+func (itr *SortedMapIterator[K, V]) prev() {
for ; itr.depth >= 0; itr.depth-- {
elem := &itr.stack[itr.depth]
switch node := elem.node.(type) {
- case *sortedMapLeafNode:
+ case *sortedMapLeafNode[K, V]:
if elem.index > 0 {
elem.index--
return
}
- case *sortedMapBranchNode:
+ case *sortedMapBranchNode[K, V]:
if elem.index > 0 {
elem.index--
itr.stack[itr.depth+1].node = node.elems[elem.index].node
@@ -2140,16 +2143,16 @@ func (itr *SortedMapIterator) prev() {
// first positions the stack to the leftmost key from the current depth.
// Elements and indexes below the current depth are assumed to be correct.
-func (itr *SortedMapIterator) first() {
+func (itr *SortedMapIterator[K, V]) first() {
for {
elem := &itr.stack[itr.depth]
elem.index = 0
switch node := elem.node.(type) {
- case *sortedMapBranchNode:
- itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node}
+ case *sortedMapBranchNode[K, V]:
+ itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node}
itr.depth++
- case *sortedMapLeafNode:
+ case *sortedMapLeafNode[K, V]:
return
}
}
@@ -2157,16 +2160,16 @@ func (itr *SortedMapIterator) first() {
// last positions the stack to the rightmost key from the current depth.
// Elements and indexes below the current depth are assumed to be correct.
-func (itr *SortedMapIterator) last() {
+func (itr *SortedMapIterator[K, V]) last() {
for {
elem := &itr.stack[itr.depth]
switch node := elem.node.(type) {
- case *sortedMapBranchNode:
+ case *sortedMapBranchNode[K, V]:
elem.index = len(node.elems) - 1
- itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node}
+ itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node}
itr.depth++
- case *sortedMapLeafNode:
+ case *sortedMapLeafNode[K, V]:
elem.index = len(node.entries) - 1
return
}
@@ -2175,16 +2178,16 @@ func (itr *SortedMapIterator) last() {
// seek positions the stack to the given key from the current depth.
// Elements and indexes below the current depth are assumed to be correct.
-func (itr *SortedMapIterator) seek(key interface{}) {
+func (itr *SortedMapIterator[K, V]) seek(key K) {
for {
elem := &itr.stack[itr.depth]
elem.index = elem.node.indexOf(key, itr.m.comparer)
switch node := elem.node.(type) {
- case *sortedMapBranchNode:
- itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node}
+ case *sortedMapBranchNode[K, V]:
+ itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node}
itr.depth++
- case *sortedMapLeafNode:
+ case *sortedMapLeafNode[K, V]:
if elem.index == len(node.entries) {
itr.next()
}
@@ -2194,59 +2197,33 @@ func (itr *SortedMapIterator) seek(key interface{}) {
}
// sortedMapIteratorElem represents node/index pair in the SortedMapIterator stack.
-type sortedMapIteratorElem struct {
- node sortedMapNode
+type sortedMapIteratorElem[K constraints.Ordered, V any] struct {
+ node sortedMapNode[K, V]
index int
}
// Hasher hashes keys and checks them for equality.
-type Hasher interface {
- // Computes a 32-bit hash for key.
- Hash(key interface{}) uint32
+type Hasher[K constraints.Ordered] interface {
+ // Computes a hash for key.
+ Hash(key K) uint32
// Returns true if a and b are equal.
- Equal(a, b interface{}) bool
+ Equal(a, b K) bool
}
// NewHasher returns the built-in hasher for a given key type.
-func NewHasher(key interface{}) Hasher {
+func NewHasher[K constraints.Ordered](key K) Hasher[K] {
// Attempt to use non-reflection based hasher first.
- switch key.(type) {
- case int:
- return &intHasher{}
- case int8:
- return &int8Hasher{}
- case int16:
- return &int16Hasher{}
- case int32:
- return &int32Hasher{}
- case int64:
- return &int64Hasher{}
- case uint:
- return &uintHasher{}
- case uint8:
- return &uint8Hasher{}
- case uint16:
- return &uint16Hasher{}
- case uint32:
- return &uint32Hasher{}
- case uint64:
- return &uint64Hasher{}
- case string:
- return &stringHasher{}
- case []byte:
- return &byteSliceHasher{}
+ switch (any(key)).(type) {
+ case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, string:
+ return &defaultHasher[K]{}
}
// Fallback to reflection-based hasher otherwise.
// This is used when caller wraps a type around a primitive type.
switch reflect.TypeOf(key).Kind() {
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return &reflectIntHasher{}
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- return &reflectUintHasher{}
- case reflect.String:
- return &reflectStringHasher{}
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.String:
+ return &reflectHasher[K]{}
}
// If no hashers match then panic.
@@ -2254,227 +2231,49 @@ func NewHasher(key interface{}) Hasher {
panic(fmt.Sprintf("immutable.NewHasher: must set hasher for %T type", key))
}
-// intHasher implements Hasher for int keys.
-type intHasher struct{}
-
-// Hash returns a hash for key.
-func (h *intHasher) Hash(key interface{}) uint32 {
- return hashUint64(uint64(key.(int)))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not ints.
-func (h *intHasher) Equal(a, b interface{}) bool {
- return a.(int) == b.(int)
-}
-
-// int8Hasher implements Hasher for int8 keys.
-type int8Hasher struct{}
-
-// Hash returns a hash for key.
-func (h *int8Hasher) Hash(key interface{}) uint32 {
- return hashUint64(uint64(key.(int8)))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not int8s.
-func (h *int8Hasher) Equal(a, b interface{}) bool {
- return a.(int8) == b.(int8)
-}
-
-// int16Hasher implements Hasher for int16 keys.
-type int16Hasher struct{}
-
-// Hash returns a hash for key.
-func (h *int16Hasher) Hash(key interface{}) uint32 {
- return hashUint64(uint64(key.(int16)))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not int16s.
-func (h *int16Hasher) Equal(a, b interface{}) bool {
- return a.(int16) == b.(int16)
-}
-
-// int32Hasher implements Hasher for int32 keys.
-type int32Hasher struct{}
-
-// Hash returns a hash for key.
-func (h *int32Hasher) Hash(key interface{}) uint32 {
- return hashUint64(uint64(key.(int32)))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not int32s.
-func (h *int32Hasher) Equal(a, b interface{}) bool {
- return a.(int32) == b.(int32)
-}
-
-// int64Hasher implements Hasher for int64 keys.
-type int64Hasher struct{}
-
-// Hash returns a hash for key.
-func (h *int64Hasher) Hash(key interface{}) uint32 {
- return hashUint64(uint64(key.(int64)))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not int64s.
-func (h *int64Hasher) Equal(a, b interface{}) bool {
- return a.(int64) == b.(int64)
-}
-
-// uintHasher implements Hasher for uint keys.
-type uintHasher struct{}
-
-// Hash returns a hash for key.
-func (h *uintHasher) Hash(key interface{}) uint32 {
- return hashUint64(uint64(key.(uint)))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not uints.
-func (h *uintHasher) Equal(a, b interface{}) bool {
- return a.(uint) == b.(uint)
-}
-
-// uint8Hasher implements Hasher for uint8 keys.
-type uint8Hasher struct{}
-
-// Hash returns a hash for key.
-func (h *uint8Hasher) Hash(key interface{}) uint32 {
- return hashUint64(uint64(key.(uint8)))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not uint8s.
-func (h *uint8Hasher) Equal(a, b interface{}) bool {
- return a.(uint8) == b.(uint8)
-}
-
-// uint16Hasher implements Hasher for uint16 keys.
-type uint16Hasher struct{}
-
-// Hash returns a hash for key.
-func (h *uint16Hasher) Hash(key interface{}) uint32 {
- return hashUint64(uint64(key.(uint16)))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not uint16s.
-func (h *uint16Hasher) Equal(a, b interface{}) bool {
- return a.(uint16) == b.(uint16)
-}
-
-// uint32Hasher implements Hasher for uint32 keys.
-type uint32Hasher struct{}
-
-// Hash returns a hash for key.
-func (h *uint32Hasher) Hash(key interface{}) uint32 {
- return hashUint64(uint64(key.(uint32)))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not uint32s.
-func (h *uint32Hasher) Equal(a, b interface{}) bool {
- return a.(uint32) == b.(uint32)
-}
-
-// uint64Hasher implements Hasher for uint64 keys.
-type uint64Hasher struct{}
-
-// Hash returns a hash for key.
-func (h *uint64Hasher) Hash(key interface{}) uint32 {
- return hashUint64(key.(uint64))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not uint64s.
-func (h *uint64Hasher) Equal(a, b interface{}) bool {
- return a.(uint64) == b.(uint64)
-}
-
-// stringHasher implements Hasher for string keys.
-type stringHasher struct{}
-
-// Hash returns a hash for value.
-func (h *stringHasher) Hash(value interface{}) uint32 {
- var hash uint32
- for i, value := 0, value.(string); i < len(value); i++ {
- hash = 31*hash + uint32(value[i])
- }
- return hash
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not strings.
-func (h *stringHasher) Equal(a, b interface{}) bool {
- return a.(string) == b.(string)
-}
-
-// byteSliceHasher implements Hasher for byte slice keys.
-type byteSliceHasher struct{}
-
// Hash returns a hash for value.
-func (h *byteSliceHasher) Hash(value interface{}) uint32 {
+func hashString(value string) uint32 {
var hash uint32
- for i, value := 0, value.([]byte); i < len(value); i++ {
+ for i, value := 0, value; i < len(value); i++ {
hash = 31*hash + uint32(value[i])
}
return hash
}
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not byte slices.
-func (h *byteSliceHasher) Equal(a, b interface{}) bool {
- return bytes.Equal(a.([]byte), b.([]byte))
-}
-
// reflectIntHasher implements a reflection-based Hasher for int keys.
-type reflectIntHasher struct{}
+type reflectHasher[K constraints.Ordered] struct{}
// Hash returns a hash for key.
-func (h *reflectIntHasher) Hash(key interface{}) uint32 {
- return hashUint64(uint64(reflect.ValueOf(key).Int()))
-}
-
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not ints.
-func (h *reflectIntHasher) Equal(a, b interface{}) bool {
- return reflect.ValueOf(a).Int() == reflect.ValueOf(b).Int()
-}
-
-// reflectUintHasher implements a reflection-based Hasher for uint keys.
-type reflectUintHasher struct{}
-
-// Hash returns a hash for key.
-func (h *reflectUintHasher) Hash(key interface{}) uint32 {
- return hashUint64(reflect.ValueOf(key).Uint())
+func (h *reflectHasher[K]) Hash(key K) uint32 {
+ switch reflect.TypeOf(key).Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return hashUint64(uint64(reflect.ValueOf(key).Int()))
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return hashUint64(reflect.ValueOf(key).Uint())
+ case reflect.String:
+ var hash uint32
+ s := reflect.ValueOf(key).String()
+ for i := 0; i < len(s); i++ {
+ hash = 31*hash + uint32(s[i])
+ }
+ return hash
+ }
+ panic(fmt.Sprintf("immutable.reflectHasher.Hash: reflectHasher does not support %T type", key))
}
// Equal returns true if a is equal to b. Otherwise returns false.
// Panics if a and b are not ints.
-func (h *reflectUintHasher) Equal(a, b interface{}) bool {
- return reflect.ValueOf(a).Uint() == reflect.ValueOf(b).Uint()
-}
-
-// reflectStringHasher implements a refletion-based Hasher for string keys.
-type reflectStringHasher struct{}
-
-// Hash returns a hash for value.
-func (h *reflectStringHasher) Hash(value interface{}) uint32 {
- var hash uint32
- s := reflect.ValueOf(value).String()
- for i := 0; i < len(s); i++ {
- hash = 31*hash + uint32(s[i])
+func (h *reflectHasher[K]) Equal(a, b K) bool {
+ switch reflect.TypeOf(a).Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return reflect.ValueOf(a).Int() == reflect.ValueOf(b).Int()
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return reflect.ValueOf(a).Uint() == reflect.ValueOf(b).Uint()
+ case reflect.String:
+ return reflect.ValueOf(a).String() == reflect.ValueOf(b).String()
}
- return hash
-}
+ panic(fmt.Sprintf("immutable.reflectHasher.Equal: reflectHasher does not support %T type", a))
-// Equal returns true if a is equal to b. Otherwise returns false.
-// Panics if a and b are not strings.
-func (h *reflectStringHasher) Equal(a, b interface{}) bool {
- return reflect.ValueOf(a).String() == reflect.ValueOf(b).String()
}
// hashUint64 returns a 32-bit hash for a 64-bit value.
@@ -2487,66 +2286,80 @@ func hashUint64(value uint64) uint32 {
return uint32(hash)
}
-// Comparer allows the comparison of two keys for the purpose of sorting.
-type Comparer interface {
- // Returns -1 if a is less than b, returns 1 if a is greater than b,
- // and returns 0 if a is equal to b.
- Compare(a, b interface{}) int
-}
+// defaultHasher implements Hasher.
+type defaultHasher[K constraints.Ordered] struct{}
-// NewComparer returns the built-in comparer for a given key type.
-func NewComparer(key interface{}) Comparer {
- // Attempt to use non-reflection based comparer first.
- switch key.(type) {
+// Hash returns a hash for key.
+func (h *defaultHasher[K]) Hash(key K) uint32 {
+ // Attempt to use non-reflection based hasher first.
+ switch x := (any(key)).(type) {
+ // int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
case int:
- return &intComparer{}
+ return hashUint64(uint64(x))
case int8:
- return &int8Comparer{}
+ return hashUint64(uint64(x))
case int16:
- return &int16Comparer{}
+ return hashUint64(uint64(x))
case int32:
- return &int32Comparer{}
+ return hashUint64(uint64(x))
case int64:
- return &int64Comparer{}
+ return hashUint64(uint64(x))
case uint:
- return &uintComparer{}
+ return hashUint64(uint64(x))
case uint8:
- return &uint8Comparer{}
+ return hashUint64(uint64(x))
case uint16:
- return &uint16Comparer{}
+ return hashUint64(uint64(x))
case uint32:
- return &uint32Comparer{}
+ return hashUint64(uint64(x))
case uint64:
- return &uint64Comparer{}
+ return hashUint64(uint64(x))
+ case uintptr:
+ return hashUint64(uint64(x))
case string:
- return &stringComparer{}
- case []byte:
- return &byteSliceComparer{}
+ return hashString(x)
}
+ panic(fmt.Sprintf("immutable.defaultHasher.Hash: must set comparer for %T type", key))
+}
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not ints.
+func (h *defaultHasher[K]) Equal(a, b K) bool {
+ return a == b
+}
+
+// Comparer allows the comparison of two keys for the purpose of sorting.
+type Comparer[K constraints.Ordered] interface {
+ // Returns -1 if a is less than b, returns 1 if a is greater than b,
+ // and returns 0 if a is equal to b.
+ Compare(a, b K) int
+}
+
+// NewComparer returns the built-in comparer for a given key type.
+func NewComparer[K constraints.Ordered](key K) Comparer[K] {
+ // Attempt to use non-reflection based comparer first.
+ switch (any(key)).(type) {
+ case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, string:
+ return &defaultComparer[K]{}
+ }
// Fallback to reflection-based comparer otherwise.
// This is used when caller wraps a type around a primitive type.
switch reflect.TypeOf(key).Kind() {
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return &reflectIntComparer{}
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- return &reflectUintComparer{}
- case reflect.String:
- return &reflectStringComparer{}
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.String:
+ return &reflectComparer[K]{}
}
-
// If no comparers match then panic.
// This is a compile time issue so it should not return an error.
panic(fmt.Sprintf("immutable.NewComparer: must set comparer for %T type", key))
}
-// intComparer compares two integers. Implements Comparer.
-type intComparer struct{}
+// defaultComparer compares two integers. Implements Comparer.
+type defaultComparer[K constraints.Ordered] struct{}
// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
// returns 0 if a is equal to b. Panic if a or b is not an int.
-func (c *intComparer) Compare(a, b interface{}) int {
- if i, j := a.(int), b.(int); i < j {
+func (c *defaultComparer[K]) Compare(i K, j K) int {
+ if i < j {
return -1
} else if i > j {
return 1
@@ -2554,185 +2367,31 @@ func (c *intComparer) Compare(a, b interface{}) int {
return 0
}
-// int8Comparer compares two int8 values. Implements Comparer.
-type int8Comparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an int8.
-func (c *int8Comparer) Compare(a, b interface{}) int {
- if i, j := a.(int8), b.(int8); i < j {
- return -1
- } else if i > j {
- return 1
- }
- return 0
-}
-
-// int16Comparer compares two int16 values. Implements Comparer.
-type int16Comparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an int16.
-func (c *int16Comparer) Compare(a, b interface{}) int {
- if i, j := a.(int16), b.(int16); i < j {
- return -1
- } else if i > j {
- return 1
- }
- return 0
-}
-
-// int32Comparer compares two int32 values. Implements Comparer.
-type int32Comparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an int32.
-func (c *int32Comparer) Compare(a, b interface{}) int {
- if i, j := a.(int32), b.(int32); i < j {
- return -1
- } else if i > j {
- return 1
- }
- return 0
-}
-
-// int64Comparer compares two int64 values. Implements Comparer.
-type int64Comparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an int64.
-func (c *int64Comparer) Compare(a, b interface{}) int {
- if i, j := a.(int64), b.(int64); i < j {
- return -1
- } else if i > j {
- return 1
- }
- return 0
-}
-
-// uintComparer compares two uint values. Implements Comparer.
-type uintComparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an uint.
-func (c *uintComparer) Compare(a, b interface{}) int {
- if i, j := a.(uint), b.(uint); i < j {
- return -1
- } else if i > j {
- return 1
- }
- return 0
-}
-
-// uint8Comparer compares two uint8 values. Implements Comparer.
-type uint8Comparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an uint8.
-func (c *uint8Comparer) Compare(a, b interface{}) int {
- if i, j := a.(uint8), b.(uint8); i < j {
- return -1
- } else if i > j {
- return 1
- }
- return 0
-}
-
-// uint16Comparer compares two uint16 values. Implements Comparer.
-type uint16Comparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an uint16.
-func (c *uint16Comparer) Compare(a, b interface{}) int {
- if i, j := a.(uint16), b.(uint16); i < j {
- return -1
- } else if i > j {
- return 1
- }
- return 0
-}
-
-// uint32Comparer compares two uint32 values. Implements Comparer.
-type uint32Comparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an uint32.
-func (c *uint32Comparer) Compare(a, b interface{}) int {
- if i, j := a.(uint32), b.(uint32); i < j {
- return -1
- } else if i > j {
- return 1
- }
- return 0
-}
-
-// uint64Comparer compares two uint64 values. Implements Comparer.
-type uint64Comparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an uint64.
-func (c *uint64Comparer) Compare(a, b interface{}) int {
- if i, j := a.(uint64), b.(uint64); i < j {
- return -1
- } else if i > j {
- return 1
- }
- return 0
-}
-
-// stringComparer compares two strings. Implements Comparer.
-type stringComparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not a string.
-func (c *stringComparer) Compare(a, b interface{}) int {
- return strings.Compare(a.(string), b.(string))
-}
-
-// byteSliceComparer compares two byte slices. Implements Comparer.
-type byteSliceComparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not a byte slice.
-func (c *byteSliceComparer) Compare(a, b interface{}) int {
- return bytes.Compare(a.([]byte), b.([]byte))
-}
-
// reflectIntComparer compares two int values using reflection. Implements Comparer.
-type reflectIntComparer struct{}
+type reflectComparer[K constraints.Ordered] struct{}
// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
// returns 0 if a is equal to b. Panic if a or b is not an int.
-func (c *reflectIntComparer) Compare(a, b interface{}) int {
- if i, j := reflect.ValueOf(a).Int(), reflect.ValueOf(b).Int(); i < j {
- return -1
- } else if i > j {
- return 1
- }
- return 0
-}
-
-// reflectUintComparer compares two uint values using reflection. Implements Comparer.
-type reflectUintComparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an int.
-func (c *reflectUintComparer) Compare(a, b interface{}) int {
- if i, j := reflect.ValueOf(a).Uint(), reflect.ValueOf(b).Uint(); i < j {
- return -1
- } else if i > j {
- return 1
+func (c *reflectComparer[K]) Compare(a, b K) int {
+ switch reflect.TypeOf(a).Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if i, j := reflect.ValueOf(a).Int(), reflect.ValueOf(b).Int(); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ if i, j := reflect.ValueOf(a).Uint(), reflect.ValueOf(b).Uint(); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+ case reflect.String:
+ return strings.Compare(reflect.ValueOf(a).String(), reflect.ValueOf(b).String())
}
- return 0
-}
-
-// reflectStringComparer compares two string values using reflection. Implements Comparer.
-type reflectStringComparer struct{}
-
-// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
-// returns 0 if a is equal to b. Panic if a or b is not an int.
-func (c *reflectStringComparer) Compare(a, b interface{}) int {
- return strings.Compare(reflect.ValueOf(a).String(), reflect.ValueOf(b).String())
+ panic(fmt.Sprintf("immutable.reflectComparer.Compare: must set comparer for %T type", a))
}
func assert(condition bool, message string) {