diff options
-rw-r--r-- | .github/workflows/test.yml | 4 | ||||
-rw-r--r-- | README.md | 34 | ||||
-rw-r--r-- | go.mod | 4 | ||||
-rw-r--r-- | go.sum | 2 | ||||
-rw-r--r-- | immutable.go | 1312 | ||||
-rw-r--r-- | immutable_test.go | 458 |
6 files changed, 720 insertions, 1094 deletions
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5d83c6e..400912e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ jobs: - name: Install Go uses: actions/setup-go@v2 with: - go-version: 1.15.x + go-version: 1.18.x - name: Checkout code uses: actions/checkout@v2 - name: Short test @@ -22,7 +22,7 @@ jobs: - name: Install Go uses: actions/setup-go@v2 with: - go-version: 1.15.x + go-version: 1.18.x - name: Checkout code uses: actions/checkout@v2 - name: Test @@ -1,7 +1,7 @@ Immutable     ========= -This repository contains immutable collection types for Go. It includes +This repository contains *generic* immutable collection types for Go. It includes `List`, `Map`, and `SortedMap` implementations. Immutable collections can provide efficient, lock free sharing of data by requiring that edits to the collections return new collections. @@ -34,7 +34,7 @@ prepending is as efficient as appending. ```go // Create a list with 3 elements. -l := immutable.NewList() +l := immutable.NewList[string]() l = l.Append("foo") l = l.Append("bar") l = l.Prepend("baz") @@ -46,7 +46,7 @@ fmt.Println(l.Get(2)) // "bar" ``` Note that each change to the list results in a new list being created. These -lists are all snapshots at that point in time and cannot be changed so they +lists are all snapshots at that point in time and cannot be changed so they are safe to share between multiple goroutines. ### Updating list elements @@ -57,7 +57,7 @@ new list to a new variable. You can see that our old `l` variable retains a snapshot of the original value. ```go -l := immutable.NewList() +l := immutable.NewList[string]() l = l.Append("foo") l = l.Append("bar") newList := l.Set(2, "baz") @@ -95,7 +95,7 @@ Below is an example of iterating over all elements of our list from above: ```go itr := l.Iterator() for !itr.Done() { - index, value := itr.Next() + index, value, _ := itr.Next() fmt.Printf("Index %d equals %v\n", index, value) } @@ -115,7 +115,7 @@ a list in-place until you are ready to use it. This can improve bulk list building by 10x or more. ```go -b := immutable.NewListBuilder() +b := immutable.NewListBuilder[string]() b.Append("foo") b.Append("bar") b.Set(2, "baz") @@ -151,7 +151,7 @@ the value as well as a flag indicating if the key existed. The flag is useful to check if a `nil` value was set for a key versus a key did not exist. ```go -m := immutable.NewMap(nil) +m := immutable.NewMap[string,int](nil) m = m.Set("jane", 100) m = m.Set("susy", 200) m = m.Set("jane", 300) // overwrite @@ -175,7 +175,7 @@ Keys may be removed from the map by using the `Delete()` method. If the key does not exist then the original map is returned instead of a new one. ```go -m := immutable.NewMap(nil) +m := immutable.NewMap[string,int](nil) m = m.Set("jane", 100) m = m.Delete("jane") @@ -193,7 +193,7 @@ pairs in the collection. Unlike Go maps, iterators are deterministic when iterating over key/value pairs. ```go -m := immutable.NewMap(nil) +m := immutable.NewMap[string,int](nil) m = m.Set("jane", 100) m = m.Set("susy", 200) @@ -215,11 +215,11 @@ keys generate the same hash. ### Efficiently building maps If you are executing multiple mutations on a map, it can be much more efficient -to use the `MapBuilder`. It uses nearly the same API as `Map` except that it -updates a map in-place until you are ready to use it. +to use the `MapBuilder`. It uses nearly the same API as `Map` except that it +updates a map in-place until you are ready to use it. ```go -b := immutable.NewMapBuilder(nil) +b := immutable.NewMapBuilder[string,int](nil) b.Set("foo", 100) b.Set("bar", 200) b.Set("foo", 300) @@ -242,9 +242,9 @@ Hashers are fairly simple. They only need to generate hashes for a given key and check equality given two keys. ```go -type Hasher interface { - Hash(key interface{}) uint32 - Equal(a, b interface{}) bool +type Hasher[K constraints.Ordered] interface { + Hash(key K) uint32 + Equal(a, b K) bool } ``` @@ -278,8 +278,8 @@ Comparers on have one method—`Compare()`. It works the same as the `1` if a is greater than `b`, and returns `0` if `a` is equal to `b`. ```go -type Comparer interface { - Compare(a, b interface{}) int +type Comparer[K constraints.Ordered] interface { + Compare(a, b K) int } ``` @@ -1,3 +1,5 @@ module github.com/benbjohnson/immutable -go 1.12 +go 1.18 + +require golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect @@ -0,0 +1,2 @@ +golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf h1:oXVg4h2qJDd9htKxb5SCpFBHLipW6hXmL3qpUixS2jw= +golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf/go.mod h1:yh0Ynu2b5ZUe3MQfp2nM0ecK7wsgouWTDN0FNeJuIys= diff --git a/immutable.go b/immutable.go index c8da568..b189612 100644 --- a/immutable.go +++ b/immutable.go @@ -42,50 +42,51 @@ package immutable import ( - "bytes" "fmt" "math/bits" "reflect" "sort" "strings" + + "golang.org/x/exp/constraints" ) // List is a dense, ordered, indexed collections. They are analogous to slices // in Go. They can be updated by appending to the end of the list, prepending // values to the beginning of the list, or updating existing indexes in the // list. -type List struct { - root listNode // root node - origin int // offset to zero index element - size int // total number of elements in use +type List[T comparable] struct { + root listNode[T] // root node + origin int // offset to zero index element + size int // total number of elements in use } // NewList returns a new empty instance of List. -func NewList() *List { - return &List{ - root: &listLeafNode{}, +func NewList[T comparable]() *List[T] { + return &List[T]{ + root: &listLeafNode[T]{}, } } // clone returns a copy of the list. -func (l *List) clone() *List { +func (l *List[T]) clone() *List[T] { other := *l return &other } // Len returns the number of elements in the list. -func (l *List) Len() int { +func (l *List[T]) Len() int { return l.size } // cap returns the total number of possible elements for the current depth. -func (l *List) cap() int { +func (l *List[T]) cap() int { return 1 << (l.root.depth() * listNodeBits) } // Get returns the value at the given index. Similar to slices, this method will // panic if index is below zero or is greater than or equal to the list size. -func (l *List) Get(index int) interface{} { +func (l *List[T]) Get(index int) T { if index < 0 || index >= l.size { panic(fmt.Sprintf("immutable.List.Get: index %d out of bounds", index)) } @@ -95,11 +96,11 @@ func (l *List) Get(index int) interface{} { // Set returns a new list with value set at index. Similar to slices, this // method will panic if index is below zero or if the index is greater than // or equal to the list size. -func (l *List) Set(index int, value interface{}) *List { +func (l *List[T]) Set(index int, value T) *List[T] { return l.set(index, value, false) } -func (l *List) set(index int, value interface{}, mutable bool) *List { +func (l *List[T]) set(index int, value T, mutable bool) *List[T] { if index < 0 || index >= l.size { panic(fmt.Sprintf("immutable.List.Set: index %d out of bounds", index)) } @@ -112,11 +113,11 @@ func (l *List) set(index int, value interface{}, mutable bool) *List { } // Append returns a new list with value added to the end of the list. -func (l *List) Append(value interface{}) *List { +func (l *List[T]) Append(value T) *List[T] { return l.append(value, false) } -func (l *List) append(value interface{}, mutable bool) *List { +func (l *List[T]) append(value T, mutable bool) *List[T] { other := l if !mutable { other = l.clone() @@ -124,7 +125,7 @@ func (l *List) append(value interface{}, mutable bool) *List { // Expand list to the right if no slots remain. if other.size+other.origin >= l.cap() { - newRoot := &listBranchNode{d: other.root.depth() + 1} + newRoot := &listBranchNode[T]{d: other.root.depth() + 1} newRoot.children[0] = other.root other.root = newRoot } @@ -136,11 +137,11 @@ func (l *List) append(value interface{}, mutable bool) *List { } // Prepend returns a new list with value added to the beginning of the list. -func (l *List) Prepend(value interface{}) *List { +func (l *List[T]) Prepend(value T) *List[T] { return l.prepend(value, false) } -func (l *List) prepend(value interface{}, mutable bool) *List { +func (l *List[T]) prepend(value T, mutable bool) *List[T] { other := l if !mutable { other = l.clone() @@ -148,7 +149,7 @@ func (l *List) prepend(value interface{}, mutable bool) *List { // Expand list to the left if no slots remain. if other.origin == 0 { - newRoot := &listBranchNode{d: other.root.depth() + 1} + newRoot := &listBranchNode[T]{d: other.root.depth() + 1} newRoot.children[listNodeSize-1] = other.root other.root = newRoot other.origin += (listNodeSize - 1) << (other.root.depth() * listNodeBits) @@ -168,11 +169,11 @@ func (l *List) prepend(value interface{}, mutable bool) *List { // // Unlike Go slices, references to inaccessible elements will be automatically // removed so they can be garbage collected. -func (l *List) Slice(start, end int) *List { +func (l *List[T]) Slice(start, end int) *List[T] { return l.slice(start, end, false) } -func (l *List) slice(start, end int, mutable bool) *List { +func (l *List[T]) slice(start, end int, mutable bool) *List[T] { // Panics similar to Go slices. if start < 0 || start > l.size { panic(fmt.Sprintf("immutable.List.Slice: start index %d out of bounds", start)) @@ -207,7 +208,7 @@ func (l *List) slice(start, end int, mutable bool) *List { // Replace the current root with the single child & update origin offset. other.origin -= i << (other.root.depth() * listNodeBits) - other.root = other.root.(*listBranchNode).children[i] + other.root = other.root.(*listBranchNode[T]).children[i] } // Ensure all references are removed before start & after end. @@ -218,25 +219,25 @@ func (l *List) slice(start, end int, mutable bool) *List { } // Iterator returns a new iterator for this list positioned at the first index. -func (l *List) Iterator() *ListIterator { - itr := &ListIterator{list: l} +func (l *List[T]) Iterator() *ListIterator[T] { + itr := &ListIterator[T]{list: l} itr.First() return itr } // ListBuilder represents an efficient builder for creating new Lists. -type ListBuilder struct { - list *List // current state +type ListBuilder[T comparable] struct { + list *List[T] // current state } // NewListBuilder returns a new instance of ListBuilder. -func NewListBuilder() *ListBuilder { - return &ListBuilder{list: NewList()} +func NewListBuilder[T comparable]() *ListBuilder[T] { + return &ListBuilder[T]{list: NewList[T]()} } // List returns the current copy of the list. // The builder should not be used again after the list after this call. -func (b *ListBuilder) List() *List { +func (b *ListBuilder[T]) List() *List[T] { assert(b.list != nil, "immutable.ListBuilder.List(): duplicate call to fetch list") list := b.list b.list = nil @@ -244,14 +245,14 @@ func (b *ListBuilder) List() *List { } // Len returns the number of elements in the underlying list. -func (b *ListBuilder) Len() int { +func (b *ListBuilder[T]) Len() int { assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") return b.list.Len() } // Get returns the value at the given index. Similar to slices, this method will // panic if index is below zero or is greater than or equal to the list size. -func (b *ListBuilder) Get(index int) interface{} { +func (b *ListBuilder[T]) Get(index int) T { assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") return b.list.Get(index) } @@ -259,32 +260,32 @@ func (b *ListBuilder) Get(index int) interface{} { // Set updates the value at the given index. Similar to slices, this method will // panic if index is below zero or if the index is greater than or equal to the // list size. -func (b *ListBuilder) Set(index int, value interface{}) { +func (b *ListBuilder[T]) Set(index int, value T) { assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") b.list = b.list.set(index, value, true) } // Append adds value to the end of the list. -func (b *ListBuilder) Append(value interface{}) { +func (b *ListBuilder[T]) Append(value T) { assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") b.list = b.list.append(value, true) } // Prepend adds value to the beginning of the list. -func (b *ListBuilder) Prepend(value interface{}) { +func (b *ListBuilder[T]) Prepend(value T) { assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") b.list = b.list.prepend(value, true) } // Slice updates the list with a sublist of elements between start and end index. // See List.Slice() for more details. -func (b *ListBuilder) Slice(start, end int) { +func (b *ListBuilder[T]) Slice(start, end int) { assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") b.list = b.list.slice(start, end, true) } // Iterator returns a new iterator for the underlying list. -func (b *ListBuilder) Iterator() *ListIterator { +func (b *ListBuilder[T]) Iterator() *ListIterator[T] { assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") return b.list.Iterator() } @@ -297,53 +298,53 @@ const ( ) // listNode represents either a branch or leaf node in a List. -type listNode interface { +type listNode[T comparable] interface { depth() uint - get(index int) interface{} - set(index int, v interface{}, mutable bool) listNode + get(index int) T + set(index int, v T, mutable bool) listNode[T] containsBefore(index int) bool containsAfter(index int) bool - deleteBefore(index int, mutable bool) listNode - deleteAfter(index int, mutable bool) listNode + deleteBefore(index int, mutable bool) listNode[T] + deleteAfter(index int, mutable bool) listNode[T] } // newListNode returns a leaf node for depth zero, otherwise returns a branch node. -func newListNode(depth uint) listNode { +func newListNode[T comparable](depth uint) listNode[T] { if depth == 0 { - return &listLeafNode{} + return &listLeafNode[T]{} } - return &listBranchNode{d: depth} + return &listBranchNode[T]{d: depth} } // listBranchNode represents a branch of a List tree at a given depth. -type listBranchNode struct { +type listBranchNode[T comparable] struct { d uint // depth - children [listNodeSize]listNode + children [listNodeSize]listNode[T] } // depth returns the depth of this branch node from the leaf. -func (n *listBranchNode) depth() uint { return n.d } +func (n *listBranchNode[T]) depth() uint { return n.d } // get returns the child node at the segment of the index for this depth. -func (n *listBranchNode) get(index int) interface{} { +func (n *listBranchNode[T]) get(index int) T { idx := (index >> (n.d * listNodeBits)) & listNodeMask return n.children[idx].get(index) } // set recursively updates the value at index for each lower depth from the node. -func (n *listBranchNode) set(index int, v interface{}, mutable bool) listNode { +func (n *listBranchNode[T]) set(index int, v T, mutable bool) listNode[T] { idx := (index >> (n.d * listNodeBits)) & listNodeMask // Find child for the given value in the branch. Create new if it doesn't exist. child := n.children[idx] if child == nil { - child = newListNode(n.depth() - 1) + child = newListNode[T](n.depth() - 1) } // Return a copy of this branch with the new child. - var other *listBranchNode + var other *listBranchNode[T] if mutable { other = n } else { @@ -355,7 +356,7 @@ func (n *listBranchNode) set(index int, v interface{}, mutable bool) listNode { } // containsBefore returns true if non-nil values exists between [0,index). -func (n *listBranchNode) containsBefore(index int) bool { +func (n *listBranchNode[T]) containsBefore(index int) bool { idx := (index >> (n.d * listNodeBits)) & listNodeMask // Quickly check if any direct children exist before this segment of the index. @@ -373,7 +374,7 @@ func (n *listBranchNode) containsBefore(index int) bool { } // containsAfter returns true if non-nil values exists between (index,listNodeSize). -func (n *listBranchNode) containsAfter(index int) bool { +func (n *listBranchNode[T]) containsAfter(index int) bool { idx := (index >> (n.d * listNodeBits)) & listNodeMask // Quickly check if any direct children exist after this segment of the index. @@ -391,7 +392,7 @@ func (n *listBranchNode) containsAfter(index int) bool { } // deleteBefore returns a new node with all elements before index removed. -func (n *listBranchNode) deleteBefore(index int, mutable bool) listNode { +func (n *listBranchNode[T]) deleteBefore(index int, mutable bool) listNode[T] { // Ignore if no nodes exist before the given index. if !n.containsBefore(index) { return n @@ -400,14 +401,14 @@ func (n *listBranchNode) deleteBefore(index int, mutable bool) listNode { // Return a copy with any nodes prior to the index removed. idx := (index >> (n.d * listNodeBits)) & listNodeMask - var other *listBranchNode + var other *listBranchNode[T] if mutable { other = n for i := 0; i < idx; i++ { n.children[i] = nil } } else { - other = &listBranchNode{d: n.d} + other = &listBranchNode[T]{d: n.d} copy(other.children[idx:][:], n.children[idx:][:]) } @@ -418,7 +419,7 @@ func (n *listBranchNode) deleteBefore(index int, mutable bool) listNode { } // deleteBefore returns a new node with all elements before index removed. -func (n *listBranchNode) deleteAfter(index int, mutable bool) listNode { +func (n *listBranchNode[T]) deleteAfter(index int, mutable bool) listNode[T] { // Ignore if no nodes exist after the given index. if !n.containsAfter(index) { return n @@ -427,14 +428,14 @@ func (n *listBranchNode) deleteAfter(index int, mutable bool) listNode { // Return a copy with any nodes after the index removed. idx := (index >> (n.d * listNodeBits)) & listNodeMask - var other *listBranchNode + var other *listBranchNode[T] if mutable { other = n for i := idx + 1; i < len(n.children); i++ { n.children[i] = nil } } else { - other = &listBranchNode{d: n.d} + other = &listBranchNode[T]{d: n.d} copy(other.children[:idx+1], n.children[:idx+1]) } @@ -445,22 +446,22 @@ func (n *listBranchNode) deleteAfter(index int, mutable bool) listNode { } // listLeafNode represents a leaf node in a List. -type listLeafNode struct { - children [listNodeSize]interface{} +type listLeafNode[T comparable] struct { + children [listNodeSize]T } // depth always returns 0 for leaf nodes. -func (n *listLeafNode) depth() uint { return 0 } +func (n *listLeafNode[T]) depth() uint { return 0 } // get returns the value at the given index. -func (n *listLeafNode) get(index int) interface{} { +func (n *listLeafNode[T]) get(index int) T { return n.children[index&listNodeMask] } // set returns a copy of the node with the value at the index updated to v. -func (n *listLeafNode) set(index int, v interface{}, mutable bool) listNode { +func (n *listLeafNode[T]) set(index int, v T, mutable bool) listNode[T] { idx := index & listNodeMask - var other *listLeafNode + var other *listLeafNode[T] if mutable { other = n } else { @@ -468,14 +469,17 @@ func (n *listLeafNode) set(index int, v interface{}, mutable bool) listNode { other = &tmp } other.children[idx] = v - return other + var otherLN listNode[T] + otherLN = other + return otherLN } // containsBefore returns true if non-nil values exists between [0,index). -func (n *listLeafNode) containsBefore(index int) bool { +func (n *listLeafNode[T]) containsBefore(index int) bool { idx := index & listNodeMask + var empty T for i := 0; i < idx; i++ { - if n.children[i] != nil { + if n.children[i] != empty { return true } } @@ -483,10 +487,11 @@ func (n *listLeafNode) containsBefore(index int) bool { } // containsAfter returns true if non-nil values exists between (index,listNodeSize). -func (n *listLeafNode) containsAfter(index int) bool { +func (n *listLeafNode[T]) containsAfter(index int) bool { idx := index & listNodeMask + var empty T for i := idx + 1; i < len(n.children); i++ { - if n.children[i] != nil { + if n.children[i] != empty { return true } } @@ -494,62 +499,64 @@ func (n *listLeafNode) containsAfter(index int) bool { } // deleteBefore returns a new node with all elements before index removed. -func (n *listLeafNode) deleteBefore(index int, mutable bool) listNode { +func (n *listLeafNode[T]) deleteBefore(index int, mutable bool) listNode[T] { if !n.containsBefore(index) { return n } idx := index & listNodeMask - var other *listLeafNode + var other *listLeafNode[T] if mutable { other = n + var empty T for i := 0; i < idx; i++ { - other.children[i] = nil + other.children[i] = empty } } else { - other = &listLeafNode{} + other = &listLeafNode[T]{} copy(other.children[idx:][:], n.children[idx:][:]) } return other } // deleteBefore returns a new node with all elements before index removed. -func (n *listLeafNode) deleteAfter(index int, mutable bool) listNode { +func (n *listLeafNode[T]) deleteAfter(index int, mutable bool) listNode[T] { if !n.containsAfter(index) { return n } idx := index & listNodeMask - var other *listLeafNode + var other *listLeafNode[T] if mutable { other = n + var empty T for i := idx + 1; i < len(n.children); i++ { - other.children[i] = nil + other.children[i] = empty } } else { - other = &listLeafNode{} + other = &listLeafNode[T]{} copy(other.children[:idx+1][:], n.children[:idx+1][:]) } return other } // ListIterator represents an ordered iterator over a list. -type ListIterator struct { - list *List // source list - index int // current index position +type ListIterator[T comparable] struct { + list *List[T] // source list + index int // current index position - stack [32]listIteratorElem // search stack - depth int // stack depth + stack [32]listIteratorElem[T] // search stack + depth int // stack depth } // Done returns true if no more elements remain in the iterator. -func (itr *ListIterator) Done() bool { +func (itr *ListIterator[T]) Done() bool { return itr.index < 0 || itr.index >= itr.list.Len() } // First positions the iterator on the first index. // If source list is empty then no change is made. -func (itr *ListIterator) First() { +func (itr *ListIterator[T]) First() { if itr.list.Len() != 0 { itr.Seek(0) } @@ -557,7 +564,7 @@ func (itr *ListIterator) First() { // Last positions the iterator on the last index. // If source list is empty then no change is made. -func (itr *ListIterator) Last() { +func (itr *ListIterator[T]) Last() { if n := itr.list.Len(); n != 0 { itr.Seek(n - 1) } @@ -566,7 +573,7 @@ func (itr *ListIterator) Last() { // Seek moves the iterator position to the given index in the list. // Similar to Go slices, this method will panic if index is below zero or if // the index is greater than or equal to the list size. -func (itr *ListIterator) Seek(index int) { +func (itr *ListIterator[T]) Seek(index int) { // Panic similar to Go slices. if index < 0 || index >= itr.list.Len() { panic(fmt.Sprintf("immutable.ListIterator.Seek: index %d out of bounds", index)) @@ -574,22 +581,23 @@ func (itr *ListIterator) Seek(index int) { itr.index = index // Reset to the bottom of the stack at seek to the correct position. - itr.stack[0] = listIteratorElem{node: itr.list.root} + itr.stack[0] = listIteratorElem[T]{node: itr.list.root} itr.depth = 0 itr.seek(index) } // Next returns the current index and its value & moves the iterator forward. // Returns an index of -1 if the there are no more elements to return. -func (itr *ListIterator) Next() (index int, value interface{}) { +func (itr *ListIterator[T]) Next() (index int, value T) { // Exit immediately if there are no elements remaining. + var empty T if itr.Done() { - return -1, nil + return -1, empty } // Retrieve current index & value. elem := &itr.stack[itr.depth] - index, value = itr.index, elem.node.(*listLeafNode).children[elem.index] + index, value = itr.index, elem.node.(*listLeafNode[T]).children[elem.index] // Increase index. If index is at the end then return immediately. itr.index++ @@ -609,15 +617,16 @@ func (itr *ListIterator) Next() (index int, value interface{}) { // Prev returns the current index and value and moves the iterator backward. // Returns an index of -1 if the there are no more elements to return. -func (itr *ListIterator) Prev() (index int, value interface{}) { +func (itr *ListIterator[T]) Prev() (index int, value T) { // Exit immediately if there are no elements remaining. + var empty T if itr.Done() { - return -1, nil + return -1, empty } // Retrieve current index & value. elem := &itr.stack[itr.depth] - index, value = itr.index, elem.node.(*listLeafNode).children[elem.index] + index, value = itr.index, elem.node.(*listLeafNode[T]).children[elem.index] // Decrease index. If index is past the beginning then return immediately. itr.index-- @@ -637,26 +646,26 @@ func (itr *ListIterator) Prev() (index int, value interface{}) { // seek positions the stack to the given index from the current depth. // Elements and indexes below the current depth are assumed to be correct. -func (itr *ListIterator) seek(index int) { +func (itr *ListIterator[T]) seek(index int) { // Iterate over each level until we reach a leaf node. for { elem := &itr.stack[itr.depth] elem.index = ((itr.list.origin + index) >> (elem.node.depth() * listNodeBits)) & listNodeMask switch node := elem.node.(type) { - case *listBranchNode: + case *listBranchNode[T]: child := node.children[elem.index] - itr.stack[itr.depth+1] = listIteratorElem{node: child} + itr.stack[itr.depth+1] = listIteratorElem[T]{node: child} itr.depth++ - case *listLeafNode: + case *listLeafNode[T]: return } } } // listIteratorElem represents the node and it's child index within the stack. -type listIteratorElem struct { - node listNode +type listIteratorElem[T comparable] struct { + node listNode[T] index int } @@ -677,28 +686,28 @@ const ( // to generate hashes and check for equality of key values. // // It is implemented as an Hash Array Mapped Trie. -type Map struct { - size int // total number of key/value pairs - root mapNode // root node of trie - hasher Hasher // hasher implementation +type Map[K constraints.Ordered, V any] struct { + size int // total number of key/value pairs + root mapNode[K, V] // root node of trie + hasher Hasher[K] // hasher implementation } // NewMap returns a new instance of Map. If hasher is nil, a default hasher // implementation will automatically be chosen based on the first key added. // Default hasher implementations only exist for int, string, and byte slice types. -func NewMap(hasher Hasher) *Map { - return &Map{ +func NewMap[K constraints.Ordered, V any](hasher Hasher[K]) *Map[K, V] { + return &Map[K, V]{ hasher: hasher, } } // Len returns the number of elements in the map. -func (m *Map) Len() int { +func (m *Map[K, V]) Len() int { return m.size } // clone returns a shallow copy of m. -func (m *Map) clone() *Map { +func (m *Map[K, V]) clone() *Map[K, V] { other := *m return &other } @@ -706,9 +715,10 @@ func (m *Map) clone() *Map { // Get returns the value for a given key and a flag indicating whether the // key exists. This flag distinguishes a nil value set on a key versus a // non-existent key in the map. -func (m *Map) Get(key interface{}) (value interface{}, ok bool) { +func (m *Map[K, V]) Get(key K) (value V, ok bool) { + var empty V if m.root == nil { - return nil, false + return empty, false } keyHash := m.hasher.Hash(key) return m.root.get(key, 0, keyHash, m.hasher) @@ -718,11 +728,11 @@ func (m *Map) Get(key interface{}) (value interface{}, ok bool) { // // This function will return a new map even if the updated value is the same as // the existing value because Map does not track value equality. -func (m *Map) Set(key, value interface{}) *Map { +func (m *Map[K, V]) Set(key K, value V) *Map[K, V] { return m.set(key, value, false) } -func (m *Map) set(key, value interface{}, mutable bool) *Map { +func (m *Map[K, V]) set(key K, value V, mutable bool) *Map[K, V] { // Set a hasher on the first value if one does not already exist. hasher := m.hasher if hasher == nil { @@ -739,7 +749,7 @@ func (m *Map) set(key, value interface{}, mutable bool) *Map { // If the map is empty, initialize with a simple array node. if m.root == nil { other.size = 1 - other.root = &mapArrayNode{entries: []mapEntry{{key: key, value: value}}} + other.root = &mapArrayNode[K, V]{entries: []mapEntry[K, V]{{key: key, value: value}}} return other } @@ -755,11 +765,11 @@ func (m *Map) set(key, value interface{}, mutable bool) *Map { // Delete returns a map with the given key removed. // Removing a non-existent key will cause this method to return the same map. -func (m *Map) Delete(key interface{}) *Map { +func (m *Map[K, V]) Delete(key K) *Map[K, V] { return m.delete(key, false) } -func (m *Map) delete(key interface{}, mutable bool) *Map { +func (m *Map[K, V]) delete(key K, mutable bool) *Map[K, V] { // Return original map if no keys exist. if m.root == nil { return m @@ -785,25 +795,25 @@ func (m *Map) delete(key interface{}, mutable bool) *Map { } // Iterator returns a new iterator for the map. -func (m *Map) Iterator() *MapIterator { - itr := &MapIterator{m: m} +func (m *Map[K, V]) Iterator() *MapIterator[K, V] { + itr := &MapIterator[K, V]{m: m} itr.First() return itr } // MapBuilder represents an efficient builder for creating Maps. -type MapBuilder struct { - m *Map // current state +type MapBuilder[K constraints.Ordered, V any] struct { + m *Map[K, V] // current state } // NewMapBuilder returns a new instance of MapBuilder. -func NewMapBuilder(hasher Hasher) *MapBuilder { - return &MapBuilder{m: NewMap(hasher)} +func NewMapBuilder[K constraints.Ordered, V any](hasher Hasher[K]) *MapBuilder[K, V] { + return &MapBuilder[K, V]{m: NewMap[K, V](hasher)} } // Map returns the underlying map. Only call once. // Builder is invalid after call. Will panic on second invocation. -func (b *MapBuilder) Map() *Map { +func (b *MapBuilder[K, V]) Map() *Map[K, V] { assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map") m := b.m b.m = nil @@ -811,66 +821,66 @@ func (b *MapBuilder) Map() *Map { } // Len returns the number of elements in the underlying map. -func (b *MapBuilder) Len() int { +func (b *MapBuilder[K, V]) Len() int { assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") return b.m.Len() } // Get returns the value for the given key. -func (b *MapBuilder) Get(key interface{}) (value interface{}, ok bool) { +func (b *MapBuilder[K, V]) Get(key K) (value V, ok bool) { assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") return b.m.Get(key) } // Set sets the value of the given key. See Map.Set() for additional details. -func (b *MapBuilder) Set(key, value interface{}) { +func (b *MapBuilder[K, V]) Set(key K, value V) { assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") b.m = b.m.set(key, value, true) } // Delete removes the given key. See Map.Delete() for additional details. -func (b *MapBuilder) Delete(key interface{}) { +func (b *MapBuilder[K, V]) Delete(key K) { assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") b.m = b.m.delete(key, true) } // Iterator returns a new iterator for the underlying map. -func (b *MapBuilder) Iterator() *MapIterator { +func (b *MapBuilder[K, V]) Iterator() *MapIterator[K, V] { assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") return b.m.Iterator() } // mapNode represents any node in the map tree. -type mapNode interface { - get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) - set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode - delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode +type mapNode[K constraints.Ordered, V any] interface { + get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) + set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] + delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] } -var _ mapNode = (*mapArrayNode)(nil) -var _ mapNode = (*mapBitmapIndexedNode)(nil) -var _ mapNode = (*mapHashArrayNode)(nil) -var _ mapNode = (*mapValueNode)(nil) -var _ mapNode = (*mapHashCollisionNode)(nil) +var _ mapNode[string, any] = (*mapArrayNode[string, any])(nil) +var _ mapNode[string, any] = (*mapBitmapIndexedNode[string, any])(nil) +var _ mapNode[string, any] = (*mapHashArrayNode[string, any])(nil) +var _ mapNode[string, any] = (*mapValueNode[string, any])(nil) +var _ mapNode[string, any] = (*mapHashCollisionNode[string, any])(nil) // mapLeafNode represents a node that stores a single key hash at the leaf of the map tree. -type mapLeafNode interface { - mapNode +type mapLeafNode[K constraints.Ordered, V any] interface { + mapNode[K, V] keyHashValue() uint32 } -var _ mapLeafNode = (*mapValueNode)(nil) -var _ mapLeafNode = (*mapHashCollisionNode)(nil) +var _ mapLeafNode[string, any] = (*mapValueNode[string, any])(nil) +var _ mapLeafNode[string, any] = (*mapHashCollisionNode[string, any])(nil) // mapArrayNode is a map node that stores key/value pairs in a slice. // Entries are stored in insertion order. An array node expands into a bitmap // indexed node once a given threshold size is crossed. -type mapArrayNode struct { - entries []mapEntry +type mapArrayNode[K constraints.Ordered, V any] struct { + entries []mapEntry[K, V] } // indexOf returns the entry index of the given key. Returns -1 if key not found. -func (n *mapArrayNode) indexOf(key interface{}, h Hasher) int { +func (n *mapArrayNode[K, V]) indexOf(key K, h Hasher[K]) int { for i := range n.entries { if h.Equal(n.entries[i].key, key) { return i @@ -880,17 +890,17 @@ func (n *mapArrayNode) indexOf(key interface{}, h Hasher) int { } // get returns the value for the given key. -func (n *mapArrayNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) { +func (n *mapArrayNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { i := n.indexOf(key, h) if i == -1 { - return nil, false + return value, false } return n.entries[i].value, true } // set inserts or updates the value for a given key. If the key is inserted and // the new size crosses the max size threshold, a bitmap indexed node is returned. -func (n *mapArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode { +func (n *mapArrayNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { idx := n.indexOf(key, h) // Mark as resized if the key doesn't exist. @@ -901,7 +911,7 @@ func (n *mapArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h // If we are adding and it crosses the max size threshold, expand the node. // We do this by continually setting the entries to a value node and expanding. if idx == -1 && len(n.entries) >= maxArrayMapSize { - var node mapNode = newMapValueNode(h.Hash(key), key, value) + var node mapNode[K, V] = newMapValueNode(h.Hash(key), key, value) for _, entry := range n.entries { node = node.set(entry.key, entry.value, 0, h.Hash(entry.key), h, false, resized) } @@ -911,31 +921,31 @@ func (n *mapArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h // Update in-place if mutable. if mutable { if idx != -1 { - n.entries[idx] = mapEntry{key, value} + n.entries[idx] = mapEntry[K, V]{key, value} } else { - n.entries = append(n.entries, mapEntry{key, value}) + n.entries = append(n.entries, mapEntry[K, V]{key, value}) } return n } // Update existing entry if a match is found. // Otherwise append to the end of the element list if it doesn't exist. - var other mapArrayNode + var other mapArrayNode[K, V] if idx != -1 { - other.entries = make([]mapEntry, len(n.entries)) + other.entries = make([]mapEntry[K, V], len(n.entries)) copy(other.entries, n.entries) - other.entries[idx] = mapEntry{key, value} + other.entries[idx] = mapEntry[K, V]{key, value} } else { - other.entries = make([]mapEntry, len(n.entries)+1) + other.entries = make([]mapEntry[K, V], len(n.entries)+1) copy(other.entries, n.entries) - other.entries[len(other.entries)-1] = mapEntry{key, value} + other.entries[len(other.entries)-1] = mapEntry[K, V]{key, value} } return &other } // delete removes the given key from the node. Returns the same node if key does // not exist. Returns a nil node when removing the last entry. -func (n *mapArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode { +func (n *mapArrayNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { idx := n.indexOf(key, h) // Return original node if key does not exist. @@ -952,13 +962,13 @@ func (n *mapArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Has // Update in-place, if mutable. if mutable { copy(n.entries[idx:], n.entries[idx+1:]) - n.entries[len(n.entries)-1] = mapEntry{} + n.entries[len(n.entries)-1] = mapEntry[K, V]{} n.entries = n.entries[:len(n.entries)-1] return n } // Otherwise create a copy with the given entry removed. - other := &mapArrayNode{entries: make([]mapEntry, len(n.entries)-1)} + other := &mapArrayNode[K, V]{entries: make([]mapEntry[K, V], len(n.entries)-1)} copy(other.entries[:idx], n.entries[:idx]) copy(other.entries[idx:], n.entries[idx+1:]) return other @@ -967,16 +977,16 @@ func (n *mapArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Has // mapBitmapIndexedNode represents a map branch node with a variable number of // node slots and indexed using a bitmap. Indexes for the node slots are // calculated by counting the number of set bits before the target bit using popcount. -type mapBitmapIndexedNode struct { +type mapBitmapIndexedNode[K constraints.Ordered, V any] struct { bitmap uint32 - nodes []mapNode + nodes []mapNode[K, V] } // get returns the value for the given key. -func (n *mapBitmapIndexedNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) { +func (n *mapBitmapIndexedNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { bit := uint32(1) << ((keyHash >> shift) & mapNodeMask) if (n.bitmap & bit) == 0 { - return nil, false + return value, false } child := n.nodes[bits.OnesCount32(n.bitmap&(bit-1))] return child.get(key, shift+mapNodeBits, keyHash, h) @@ -984,7 +994,7 @@ func (n *mapBitmapIndexedNode) get(key interface{}, shift uint, keyHash uint32, // set inserts or updates the value for the given key. If a new key is inserted // and the size crosses the max size threshold then a hash array node is returned. -func (n *mapBitmapIndexedNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode { +func (n *mapBitmapIndexedNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { // Extract the index for the bit segment of the key hash. keyHashFrag := (keyHash >> shift) & mapNodeMask @@ -1002,17 +1012,17 @@ func (n *mapBitmapIndexedNode) set(key, value interface{}, shift uint, keyHash u // If the node already exists, delegate set operation to it. // If the node doesn't exist then create a simple value leaf node. - var newNode mapNode + var newNode mapNode[K, V] if exists { newNode = n.nodes[idx].set(key, value, shift+mapNodeBits, keyHash, h, mutable, resized) } else { - newNode = newMapValueNode(keyHash, key, value) + newNode = newMapValueNode[K, V](keyHash, key, value) } // Convert to a hash-array node once we exceed the max bitmap size. // Copy each node based on their bit position within the bitmap. if !exists && len(n.nodes) > maxBitmapIndexedSize { - var other mapHashArrayNode + var other mapHashArrayNode[K, V] for i := uint(0); i < uint(len(other.nodes)); i++ { if n.bitmap&(uint32(1)<<i) != 0 { other.nodes[i] = n.nodes[other.count] @@ -1039,13 +1049,13 @@ func (n *mapBitmapIndexedNode) set(key, value interface{}, shift uint, keyHash u // If node exists at given slot then overwrite it with new node. // Otherwise expand the node list and insert new node into appropriate position. - other := &mapBitmapIndexedNode{bitmap: n.bitmap | bit} + other := &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap | bit} if exists { - other.nodes = make([]mapNode, len(n.nodes)) + other.nodes = make([]mapNode[K, V], len(n.nodes)) copy(other.nodes, n.nodes) other.nodes[idx] = newNode } else { - other.nodes = make([]mapNode, len(n.nodes)+1) + other.nodes = make([]mapNode[K, V], len(n.nodes)+1) copy(other.nodes, n.nodes[:idx]) other.nodes[idx] = newNode copy(other.nodes[idx+1:], n.nodes[idx:]) @@ -1056,7 +1066,7 @@ func (n *mapBitmapIndexedNode) set(key, value interface{}, shift uint, keyHash u // delete removes the key from the tree. If the key does not exist then the // original node is returned. If removing the last child node then a nil is // returned. Note that shrinking the node will not convert it to an array node. -func (n *mapBitmapIndexedNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode { +func (n *mapBitmapIndexedNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { bit := uint32(1) << ((keyHash >> shift) & mapNodeMask) // Return original node if key does not exist. @@ -1093,7 +1103,7 @@ func (n *mapBitmapIndexedNode) delete(key interface{}, shift uint, keyHash uint3 } // Return copy with bit removed from bitmap and node removed from node list. - other := &mapBitmapIndexedNode{bitmap: n.bitmap ^ bit, nodes: make([]mapNode, len(n.nodes)-1)} + other := &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap ^ bit, nodes: make([]mapNode[K, V], len(n.nodes)-1)} copy(other.nodes[:idx], n.nodes[:idx]) copy(other.nodes[idx:], n.nodes[idx+1:]) return other @@ -1102,7 +1112,7 @@ func (n *mapBitmapIndexedNode) delete(key interface{}, shift uint, keyHash uint3 // Generate copy, if necessary. other := n if !mutable { - other = &mapBitmapIndexedNode{bitmap: n.bitmap, nodes: make([]mapNode, len(n.nodes))} + other = &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap, nodes: make([]mapNode[K, V], len(n.nodes))} copy(other.nodes, n.nodes) } @@ -1113,34 +1123,34 @@ func (n *mapBitmapIndexedNode) delete(key interface{}, shift uint, keyHash uint3 // mapHashArrayNode is a map branch node that stores nodes in a fixed length // array. Child nodes are indexed by their index bit segment for the current depth. -type mapHashArrayNode struct { - count uint // number of set nodes - nodes [mapNodeSize]mapNode // child node slots, may contain empties +type mapHashArrayNode[K constraints.Ordered, V any] struct { + count uint // number of set nodes + nodes [mapNodeSize]mapNode[K, V] // child node slots, may contain empties } // clone returns a shallow copy of n. -func (n *mapHashArrayNode) clone() *mapHashArrayNode { +func (n *mapHashArrayNode[K, V]) clone() *mapHashArrayNode[K, V] { other := *n return &other } // get returns the value for the given key. -func (n *mapHashArrayNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) { +func (n *mapHashArrayNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { node := n.nodes[(keyHash>>shift)&mapNodeMask] if node == nil { - return nil, false + return value, false } return node.get(key, shift+mapNodeBits, keyHash, h) } // set returns a node with the value set for the given key. -func (n *mapHashArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode { +func (n *mapHashArrayNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { idx := (keyHash >> shift) & mapNodeMask node := n.nodes[idx] // If node at index doesn't exist, create a simple value leaf node. // Otherwise delegate set to child node. - var newNode mapNode + var newNode mapNode[K, V] if node == nil { *resized = true newNode = newMapValueNode(keyHash, key, value) @@ -1165,7 +1175,7 @@ func (n *mapHashArrayNode) set(key, value interface{}, shift uint, keyHash uint3 // delete returns a node with the given key removed. Returns the same node if // the key does not exist. If node shrinks to within bitmap-indexed size then // converts to a bitmap-indexed node. -func (n *mapHashArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode { +func (n *mapHashArrayNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { idx := (keyHash >> shift) & mapNodeMask node := n.nodes[idx] @@ -1182,7 +1192,7 @@ func (n *mapHashArrayNode) delete(key interface{}, shift uint, keyHash uint32, h // If we remove a node and drop below a threshold, convert back to bitmap indexed node. if newNode == nil && n.count <= maxBitmapIndexedSize { - other := &mapBitmapIndexedNode{nodes: make([]mapNode, 0, n.count-1)} + other := &mapBitmapIndexedNode[K, V]{nodes: make([]mapNode[K, V], 0, n.count-1)} for i, child := range n.nodes { if child != nil && uint32(i) != idx { other.bitmap |= 1 << uint(i) @@ -1209,15 +1219,15 @@ func (n *mapHashArrayNode) delete(key interface{}, shift uint, keyHash uint32, h // mapValueNode represents a leaf node with a single key/value pair. // A value node can be converted to a hash collision leaf node if a different // key with the same keyHash is inserted. -type mapValueNode struct { +type mapValueNode[K constraints.Ordered, V any] struct { keyHash uint32 - key interface{} - value interface{} + key K + value V } // newMapValueNode returns a new instance of mapValueNode. -func newMapValueNode(keyHash uint32, key, value interface{}) *mapValueNode { - return &mapValueNode{ +func newMapValueNode[K constraints.Ordered, V any](keyHash uint32, key K, value V) *mapValueNode[K, V] { + return &mapValueNode[K, V]{ keyHash: keyHash, key: key, value: value, @@ -1225,14 +1235,14 @@ func newMapValueNode(keyHash uint32, key, value interface{}) *mapValueNode { } // keyHashValue returns the key hash for this node. -func (n *mapValueNode) keyHashValue() uint32 { +func (n *mapValueNode[K, V]) keyHashValue() uint32 { return n.keyHash } // get returns the value for the given key. -func (n *mapValueNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) { +func (n *mapValueNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { if !h.Equal(n.key, key) { - return nil, false + return value, false } return n.value, true } @@ -1241,7 +1251,7 @@ func (n *mapValueNode) get(key interface{}, shift uint, keyHash uint32, h Hasher // the node's key then a new value node is returned. If key is not equal to the // node's key but has the same hash then a hash collision node is returned. // Otherwise the nodes are merged into a branch node. -func (n *mapValueNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode { +func (n *mapValueNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { // If the keys match then return a new value node overwriting the value. if h.Equal(n.key, key) { // Update in-place if mutable. @@ -1257,18 +1267,18 @@ func (n *mapValueNode) set(key, value interface{}, shift uint, keyHash uint32, h // Recursively merge nodes together if key hashes are different. if n.keyHash != keyHash { - return mergeIntoNode(n, shift, keyHash, key, value) + return mergeIntoNode[K, V](n, shift, keyHash, key, value) } // Merge into collision node if hash matches. - return &mapHashCollisionNode{keyHash: keyHash, entries: []mapEntry{ + return &mapHashCollisionNode[K, V]{keyHash: keyHash, entries: []mapEntry[K, V]{ {key: n.key, value: n.value}, {key: key, value: value}, }} } // delete returns nil if the key matches the node's key. Otherwise returns the original node. -func (n *mapValueNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode { +func (n *mapValueNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { // Return original node if the keys do not match. if !h.Equal(n.key, key) { return n @@ -1281,19 +1291,19 @@ func (n *mapValueNode) delete(key interface{}, shift uint, keyHash uint32, h Has // mapHashCollisionNode represents a leaf node that contains two or more key/value // pairs with the same key hash. Single pairs for a hash are stored as value nodes. -type mapHashCollisionNode struct { +type mapHashCollisionNode[K constraints.Ordered, V any] struct { keyHash uint32 // key hash for all entries - entries []mapEntry + entries []mapEntry[K, V] } // keyHashValue returns the key hash for all entries on the node. -func (n *mapHashCollisionNode) keyHashValue() uint32 { +func (n *mapHashCollisionNode[K, V]) keyHashValue() uint32 { return n.keyHash } // indexOf returns the index of the entry for the given key. // Returns -1 if the key does not exist in the node. -func (n *mapHashCollisionNode) indexOf(key interface{}, h Hasher) int { +func (n *mapHashCollisionNode[K, V]) indexOf(key K, h Hasher[K]) int { for i := range n.entries { if h.Equal(n.entries[i].key, key) { return i @@ -1303,46 +1313,46 @@ func (n *mapHashCollisionNode) indexOf(key interface{}, h Hasher) int { } // get returns the value for the given key. -func (n *mapHashCollisionNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) { +func (n *mapHashCollisionNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { for i := range n.entries { if h.Equal(n.entries[i].key, key) { return n.entries[i].value, true } } - return nil, false + return value, false } // set returns a copy of the node with key set to the given value. -func (n *mapHashCollisionNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode { +func (n *mapHashCollisionNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { // Merge node with key/value pair if this is not a hash collision. if n.keyHash != keyHash { *resized = true - return mergeIntoNode(n, shift, keyHash, key, value) + return mergeIntoNode[K, V](n, shift, keyHash, key, value) } // Update in-place if mutable. if mutable { if idx := n.indexOf(key, h); idx == -1 { *resized = true - n.entries = append(n.entries, mapEntry{key, value}) + n.entries = append(n.entries, mapEntry[K, V]{key, value}) } else { - n.entries[idx] = mapEntry{key, value} + n.entries[idx] = mapEntry[K, V]{key, value} } return n } // Append to end of node if key doesn't exist & mark resized. // Otherwise copy nodes and overwrite at matching key index. - other := &mapHashCollisionNode{keyHash: n.keyHash} + other := &mapHashCollisionNode[K, V]{keyHash: n.keyHash} if idx := n.indexOf(key, h); idx == -1 { *resized = true - other.entries = make([]mapEntry, len(n.entries)+1) + other.entries = make([]mapEntry[K, V], len(n.entries)+1) copy(other.entries, n.entries) - other.entries[len(other.entries)-1] = mapEntry{key, value} + other.entries[len(other.entries)-1] = mapEntry[K, V]{key, value} } else { - other.entries = make([]mapEntry, len(n.entries)) + other.entries = make([]mapEntry[K, V], len(n.entries)) copy(other.entries, n.entries) - other.entries[idx] = mapEntry{key, value} + other.entries[idx] = mapEntry[K, V]{key, value} } return other } @@ -1350,7 +1360,7 @@ func (n *mapHashCollisionNode) set(key, value interface{}, shift uint, keyHash u // delete returns a node with the given key deleted. Returns the same node if // the key does not exist. If removing the key would shrink the node to a single // entry then a value node is returned. -func (n *mapHashCollisionNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher, mutable bool, resized *bool) mapNode { +func (n *mapHashCollisionNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { idx := n.indexOf(key, h) // Return original node if key is not found. @@ -1363,7 +1373,7 @@ func (n *mapHashCollisionNode) delete(key interface{}, shift uint, keyHash uint3 // Convert to value node if we move to one entry. if len(n.entries) == 2 { - return &mapValueNode{ + return &mapValueNode[K, V]{ keyHash: n.keyHash, key: n.entries[idx^1].key, value: n.entries[idx^1].value, @@ -1373,13 +1383,13 @@ func (n *mapHashCollisionNode) delete(key interface{}, shift uint, keyHash uint3 // Remove entry in-place if mutable. if mutable { copy(n.entries[idx:], n.entries[idx+1:]) - n.entries[len(n.entries)-1] = mapEntry{} + n.entries[len(n.entries)-1] = mapEntry[K, V]{} n.entries = n.entries[:len(n.entries)-1] return n } // Return copy without entry if immutable. - other := &mapHashCollisionNode{keyHash: n.keyHash, entries: make([]mapEntry, len(n.entries)-1)} + other := &mapHashCollisionNode[K, V]{keyHash: n.keyHash, entries: make([]mapEntry[K, V], len(n.entries)-1)} copy(other.entries[:idx], n.entries[:idx]) copy(other.entries[idx:], n.entries[idx+1:]) return other @@ -1387,46 +1397,46 @@ func (n *mapHashCollisionNode) delete(key interface{}, shift uint, keyHash uint3 // mergeIntoNode merges a key/value pair into an existing node. // Caller must verify that node's keyHash is not equal to keyHash. -func mergeIntoNode(node mapLeafNode, shift uint, keyHash uint32, key, value interface{}) mapNode { +func mergeIntoNode[K constraints.Ordered, V any](node mapLeafNode[K, V], shift uint, keyHash uint32, key K, value V) mapNode[K, V] { idx1 := (node.keyHashValue() >> shift) & mapNodeMask idx2 := (keyHash >> shift) & mapNodeMask // Recursively build branch nodes to combine the node and its key. - other := &mapBitmapIndexedNode{bitmap: (1 << idx1) | (1 << idx2)} + other := &mapBitmapIndexedNode[K, V]{bitmap: (1 << idx1) | (1 << idx2)} if idx1 == idx2 { - other.nodes = []mapNode{mergeIntoNode(node, shift+mapNodeBits, keyHash, key, value)} + other.nodes = []mapNode[K, V]{mergeIntoNode(node, shift+mapNodeBits, keyHash, key, value)} } else { if newNode := newMapValueNode(keyHash, key, value); idx1 < idx2 { - other.nodes = []mapNode{node, newNode} + other.nodes = []mapNode[K, V]{node, newNode} } else { - other.nodes = []mapNode{newNode, node} + other.nodes = []mapNode[K, V]{newNode, node} } } return other } // mapEntry represents a single key/value pair. -type mapEntry struct { - key interface{} - value interface{} +type mapEntry[K constraints.Ordered, V any] struct { + key K + value V } // MapIterator represents an iterator over a map's key/value pairs. Although // map keys are not sorted, the iterator's order is deterministic. -type MapIterator struct { - m *Map // source map +type MapIterator[K constraints.Ordered, V any] struct { + m *Map[K, V] // source map - stack [32]mapIteratorElem // search stack - depth int // stack depth + stack [32]mapIteratorElem[K, V] // search stack + depth int // stack depth } // Done returns true if no more elements remain in the iterator. -func (itr *MapIterator) Done() bool { +func (itr *MapIterator[K, V]) Done() bool { return itr.depth == -1 } // First resets the iterator to the first key/value pair. -func (itr *MapIterator) First() { +func (itr *MapIterator[K, V]) First() { // Exit immediately if the map is empty. if itr.m.root == nil { itr.depth = -1 @@ -1434,27 +1444,27 @@ func (itr *MapIterator) First() { } // Initialize the stack to the left most element. - itr.stack[0] = mapIteratorElem{node: itr.m.root} + itr.stack[0] = mapIteratorElem[K, V]{node: itr.m.root} itr.depth = 0 itr.first() } // Next returns the next key/value pair. Returns a nil key when no elements remain. -func (itr *MapIterator) Next() (key, value interface{}) { +func (itr *MapIterator[K, V]) Next() (key K, value V, ok bool) { // Return nil key if iteration is done. if itr.Done() { - return nil, nil + return key, value, false } // Retrieve current index & value. Current node is always a leaf. elem := &itr.stack[itr.depth] switch node := elem.node.(type) { - case *mapArrayNode: + case *mapArrayNode[K, V]: entry := &node.entries[elem.index] key, value = entry.key, entry.value - case *mapValueNode: + case *mapValueNode[K, V]: key, value = node.key, node.value - case *mapHashCollisionNode: + case *mapHashCollisionNode[K, V]: entry := &node.entries[elem.index] key, value = entry.key, entry.value } @@ -1462,22 +1472,22 @@ func (itr *MapIterator) Next() (key, value interface{}) { // Move up stack until we find a node that has remaining position ahead // and move that element forward by one. itr.next() - return key, value + return key, value, true } // next moves to the next available key. -func (itr *MapIterator) next() { +func (itr *MapIterator[K, V]) next() { for ; itr.depth >= 0; itr.depth-- { elem := &itr.stack[itr.depth] switch node := elem.node.(type) { - case *mapArrayNode: + case *mapArrayNode[K, V]: if elem.index < len(node.entries)-1 { elem.index++ return } - case *mapBitmapIndexedNode: + case *mapBitmapIndexedNode[K, V]: if elem.index < len(node.nodes)-1 { elem.index++ itr.stack[itr.depth+1].node = node.nodes[elem.index] @@ -1486,7 +1496,7 @@ func (itr *MapIterator) next() { return } - case *mapHashArrayNode: + case *mapHashArrayNode[K, V]: for i := elem.index + 1; i < len(node.nodes); i++ { if node.nodes[i] != nil { elem.index = i @@ -1497,10 +1507,10 @@ func (itr *MapIterator) next() { } } - case *mapValueNode: + case *mapValueNode[K, V]: continue // always the last value, traverse up - case *mapHashCollisionNode: + case *mapHashCollisionNode[K, V]: if elem.index < len(node.entries)-1 { elem.index++ return @@ -1511,16 +1521,16 @@ func (itr *MapIterator) next() { // first positions the stack left most index. // Elements and indexes at and below the current depth are assumed to be correct. -func (itr *MapIterator) first() { +func (itr *MapIterator[K, V]) first() { for ; ; itr.depth++ { elem := &itr.stack[itr.depth] switch node := elem.node.(type) { - case *mapBitmapIndexedNode: + case *mapBitmapIndexedNode[K, V]: elem.index = 0 itr.stack[itr.depth+1].node = node.nodes[0] - case *mapHashArrayNode: + case *mapHashArrayNode[K, V]: for i := 0; i < len(node.nodes); i++ { if node.nodes[i] != nil { // find first node elem.index = i @@ -1537,8 +1547,8 @@ func (itr *MapIterator) first() { } // mapIteratorElem represents a node/index pair in the MapIterator stack. -type mapIteratorElem struct { - node mapNode +type mapIteratorElem[K constraints.Ordered, V any] struct { + node mapNode[K, V] index int } @@ -1551,45 +1561,46 @@ const ( // is determined by the Comparer used by the map. // // This map is implemented as a B+tree. -type SortedMap struct { - size int // total number of key/value pairs - root sortedMapNode // root of b+tree - comparer Comparer +type SortedMap[K constraints.Ordered, V any] struct { + size int // total number of key/value pairs + root sortedMapNode[K, V] // root of b+tree + comparer Comparer[K] } // NewSortedMap returns a new instance of SortedMap. If comparer is nil then // a default comparer is set after the first key is inserted. Default comparers // exist for int, string, and byte slice keys. -func NewSortedMap(comparer Comparer) *SortedMap { - return &SortedMap{ +func NewSortedMap[K constraints.Ordered, V any](comparer Comparer[K]) *SortedMap[K, V] { + return &SortedMap[K, V]{ comparer: comparer, } } // Len returns the number of elements in the sorted map. -func (m *SortedMap) Len() int { +func (m *SortedMap[K, V]) Len() int { return m.size } // Get returns the value for a given key and a flag indicating if the key is set. // The flag can be used to distinguish between a nil-set key versus an unset key. -func (m *SortedMap) Get(key interface{}) (interface{}, bool) { +func (m *SortedMap[K, V]) Get(key K) (V, bool) { if m.root == nil { - return nil, false + var v V + return v, false } return m.root.get(key, m.comparer) } // Set returns a copy of the map with the key set to the given value. -func (m *SortedMap) Set(key, value interface{}) *SortedMap { +func (m *SortedMap[K, V]) Set(key K, value V) *SortedMap[K, V] { return m.set(key, value, false) } -func (m *SortedMap) set(key, value interface{}, mutable bool) *SortedMap { +func (m *SortedMap[K, V]) set(key K, value V, mutable bool) *SortedMap[K, V] { // Set a comparer on the first value if one does not already exist. comparer := m.comparer if comparer == nil { - comparer = NewComparer(key) + comparer = NewComparer[K](key) } // Create copy, if necessary. @@ -1602,7 +1613,7 @@ func (m *SortedMap) set(key, value interface{}, mutable bool) *SortedMap { // If no values are set then initialize with a leaf node. if m.root == nil { other.size = 1 - other.root = &sortedMapLeafNode{entries: []mapEntry{{key: key, value: value}}} + other.root = &sortedMapLeafNode[K, V]{entries: []mapEntry[K, V]{{key: key, value: value}}} return other } @@ -1625,11 +1636,11 @@ func (m *SortedMap) set(key, value interface{}, mutable bool) *SortedMap { // Delete returns a copy of the map with the key removed. // Returns the original map if key does not exist. -func (m *SortedMap) Delete(key interface{}) *SortedMap { +func (m *SortedMap[K, V]) Delete(key K) *SortedMap[K, V] { return m.delete(key, false) } -func (m *SortedMap) delete(key interface{}, mutable bool) *SortedMap { +func (m *SortedMap[K, V]) delete(key K, mutable bool) *SortedMap[K, V] { // Return original map if no keys exist. if m.root == nil { return m @@ -1655,31 +1666,31 @@ func (m *SortedMap) delete(key interface{}, mutable bool) *SortedMap { } // clone returns a shallow copy of m. -func (m *SortedMap) clone() *SortedMap { +func (m *SortedMap[K, V]) clone() *SortedMap[K, V] { other := *m return &other } // Iterator returns a new iterator for this map positioned at the first key. -func (m *SortedMap) Iterator() *SortedMapIterator { - itr := &SortedMapIterator{m: m} +func (m *SortedMap[K, V]) Iterator() *SortedMapIterator[K, V] { + itr := &SortedMapIterator[K, V]{m: m} itr.First() return itr } // SortedMapBuilder represents an efficient builder for creating sorted maps. -type SortedMapBuilder struct { - m *SortedMap // current state +type SortedMapBuilder[K constraints.Ordered, V any] struct { + m *SortedMap[K, V] // current state } // NewSortedMapBuilder returns a new instance of SortedMapBuilder. -func NewSortedMapBuilder(comparer Comparer) *SortedMapBuilder { - return &SortedMapBuilder{m: NewSortedMap(comparer)} +func NewSortedMapBuilder[K constraints.Ordered, V any](comparer Comparer[K]) *SortedMapBuilder[K, V] { + return &SortedMapBuilder[K, V]{m: NewSortedMap[K, V](comparer)} } // SortedMap returns the current copy of the map. // The returned map is safe to use even if after the builder continues to be used. -func (b *SortedMapBuilder) Map() *SortedMap { +func (b *SortedMapBuilder[K, V]) Map() *SortedMap[K, V] { assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map") m := b.m b.m = nil @@ -1687,73 +1698,73 @@ func (b *SortedMapBuilder) Map() *SortedMap { } // Len returns the number of elements in the underlying map. -func (b *SortedMapBuilder) Len() int { +func (b *SortedMapBuilder[K, V]) Len() int { assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") return b.m.Len() } // Get returns the value for the given key. -func (b *SortedMapBuilder) Get(key interface{}) (value interface{}, ok bool) { +func (b *SortedMapBuilder[K, V]) Get(key K) (value V, ok bool) { assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") return b.m.Get(key) } // Set sets the value of the given key. See SortedMap.Set() for additional details. -func (b *SortedMapBuilder) Set(key, value interface{}) { +func (b *SortedMapBuilder[K, V]) Set(key K, value V) { assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") b.m = b.m.set(key, value, true) } // Delete removes the given key. See SortedMap.Delete() for additional details. -func (b *SortedMapBuilder) Delete(key interface{}) { +func (b *SortedMapBuilder[K, V]) Delete(key K) { assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") b.m = b.m.delete(key, true) } // Iterator returns a new iterator for the underlying map positioned at the first key. -func (b *SortedMapBuilder) Iterator() *SortedMapIterator { +func (b *SortedMapBuilder[K, V]) Iterator() *SortedMapIterator[K, V] { assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") return b.m.Iterator() } // sortedMapNode represents a branch or leaf node in the sorted map. -type sortedMapNode interface { - minKey() interface{} - indexOf(key interface{}, c Comparer) int - get(key interface{}, c Comparer) (value interface{}, ok bool) - set(key, value interface{}, c Comparer, mutable bool, resized *bool) (sortedMapNode, sortedMapNode) - delete(key interface{}, c Comparer, mutable bool, resized *bool) sortedMapNode +type sortedMapNode[K constraints.Ordered, V any] interface { + minKey() K + indexOf(key K, c Comparer[K]) int + get(key K, c Comparer[K]) (value V, ok bool) + set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) + delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] } -var _ sortedMapNode = (*sortedMapBranchNode)(nil) -var _ sortedMapNode = (*sortedMapLeafNode)(nil) +var _ sortedMapNode[string, any] = (*sortedMapBranchNode[string, any])(nil) +var _ sortedMapNode[string, any] = (*sortedMapLeafNode[string, any])(nil) // sortedMapBranchNode represents a branch in the sorted map. -type sortedMapBranchNode struct { - elems []sortedMapBranchElem +type sortedMapBranchNode[K constraints.Ordered, V any] struct { + elems []sortedMapBranchElem[K, V] } // newSortedMapBranchNode returns a new branch node with the given child nodes. -func newSortedMapBranchNode(children ...sortedMapNode) *sortedMapBranchNode { +func newSortedMapBranchNode[K constraints.Ordered, V any](children ...sortedMapNode[K, V]) *sortedMapBranchNode[K, V] { // Fetch min keys for every child. - elems := make([]sortedMapBranchElem, len(children)) + elems := make([]sortedMapBranchElem[K, V], len(children)) for i, child := range children { - elems[i] = sortedMapBranchElem{ + elems[i] = sortedMapBranchElem[K, V]{ key: child.minKey(), node: child, } } - return &sortedMapBranchNode{elems: elems} + return &sortedMapBranchNode[K, V]{elems: elems} } // minKey returns the lowest key stored in this node's tree. -func (n *sortedMapBranchNode) minKey() interface{} { +func (n *sortedMapBranchNode[K, V]) minKey() K { return n.elems[0].node.minKey() } // indexOf returns the index of the key within the child nodes. -func (n *sortedMapBranchNode) indexOf(key interface{}, c Comparer) int { +func (n *sortedMapBranchNode[K, V]) indexOf(key K, c Comparer[K]) int { if idx := sort.Search(len(n.elems), func(i int) bool { return c.Compare(n.elems[i].key, key) == 1 }); idx > 0 { return idx - 1 } @@ -1761,13 +1772,13 @@ func (n *sortedMapBranchNode) indexOf(key interface{}, c Comparer) int { } // get returns the value for the given key. -func (n *sortedMapBranchNode) get(key interface{}, c Comparer) (value interface{}, ok bool) { +func (n *sortedMapBranchNode[K, V]) get(key K, c Comparer[K]) (value V, ok bool) { idx := n.indexOf(key, c) return n.elems[idx].node.get(key, c) } // set returns a copy of the node with the key set to the given value. -func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, mutable bool, resized *bool) (sortedMapNode, sortedMapNode) { +func (n *sortedMapBranchNode[K, V]) set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) { idx := n.indexOf(key, c) // Delegate insert to child node. @@ -1775,18 +1786,18 @@ func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, mutable bo // Update in-place, if mutable. if mutable { - n.elems[idx] = sortedMapBranchElem{key: newNode.minKey(), node: newNode} + n.elems[idx] = sortedMapBranchElem[K, V]{key: newNode.minKey(), node: newNode} if splitNode != nil { - n.elems = append(n.elems, sortedMapBranchElem{}) + n.elems = append(n.elems, sortedMapBranchElem[K, V]{}) copy(n.elems[idx+1:], n.elems[idx:]) - n.elems[idx+1] = sortedMapBranchElem{key: splitNode.minKey(), node: splitNode} + n.elems[idx+1] = sortedMapBranchElem[K, V]{key: splitNode.minKey(), node: splitNode} } // If the child splits and we have no more room then we split too. if len(n.elems) > sortedMapNodeSize { splitIdx := len(n.elems) / 2 - newNode := &sortedMapBranchNode{elems: n.elems[:splitIdx:splitIdx]} - splitNode := &sortedMapBranchNode{elems: n.elems[splitIdx:]} + newNode := &sortedMapBranchNode[K, V]{elems: n.elems[:splitIdx:splitIdx]} + splitNode := &sortedMapBranchNode[K, V]{elems: n.elems[splitIdx:]} return newNode, splitNode } return n, nil @@ -1794,23 +1805,23 @@ func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, mutable bo // If no split occurs, copy branch and update keys. // If the child splits, insert new key/child into copy of branch. - var other sortedMapBranchNode + var other sortedMapBranchNode[K, V] if splitNode == nil { - other.elems = make([]sortedMapBranchElem, len(n.elems)) + other.elems = make([]sortedMapBranchElem[K, V], len(n.elems)) copy(other.elems, n.elems) - other.elems[idx] = sortedMapBranchElem{ + other.elems[idx] = sortedMapBranchElem[K, V]{ key: newNode.minKey(), node: newNode, } } else { - other.elems = make([]sortedMapBranchElem, len(n.elems)+1) + other.elems = make([]sortedMapBranchElem[K, V], len(n.elems)+1) copy(other.elems[:idx], n.elems[:idx]) copy(other.elems[idx+1:], n.elems[idx:]) - other.elems[idx] = sortedMapBranchElem{ + other.elems[idx] = sortedMapBranchElem[K, V]{ key: newNode.minKey(), node: newNode, } - other.elems[idx+1] = sortedMapBranchElem{ + other.elems[idx+1] = sortedMapBranchElem[K, V]{ key: splitNode.minKey(), node: splitNode, } @@ -1819,8 +1830,8 @@ func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, mutable bo // If the child splits and we have no more room then we split too. if len(other.elems) > sortedMapNodeSize { splitIdx := len(other.elems) / 2 - newNode := &sortedMapBranchNode{elems: other.elems[:splitIdx:splitIdx]} - splitNode := &sortedMapBranchNode{elems: other.elems[splitIdx:]} + newNode := &sortedMapBranchNode[K, V]{elems: other.elems[:splitIdx:splitIdx]} + splitNode := &sortedMapBranchNode[K, V]{elems: other.elems[splitIdx:]} return newNode, splitNode } @@ -1830,7 +1841,7 @@ func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, mutable bo // delete returns a node with the key removed. Returns the same node if the key // does not exist. Returns nil if all child nodes are removed. -func (n *sortedMapBranchNode) delete(key interface{}, c Comparer, mutable bool, resized *bool) sortedMapNode { +func (n *sortedMapBranchNode[K, V]) delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] { idx := n.indexOf(key, c) // Return original node if child has not changed. @@ -1849,13 +1860,13 @@ func (n *sortedMapBranchNode) delete(key interface{}, c Comparer, mutable bool, // If mutable, update in-place. if mutable { copy(n.elems[idx:], n.elems[idx+1:]) - n.elems[len(n.elems)-1] = sortedMapBranchElem{} + n.elems[len(n.elems)-1] = sortedMapBranchElem[K, V]{} n.elems = n.elems[:len(n.elems)-1] return n } // Return a copy without the given node. - other := &sortedMapBranchNode{elems: make([]sortedMapBranchElem, len(n.elems)-1)} + other := &sortedMapBranchNode[K, V]{elems: make([]sortedMapBranchElem[K, V], len(n.elems)-1)} copy(other.elems[:idx], n.elems[:idx]) copy(other.elems[idx:], n.elems[idx+1:]) return other @@ -1863,49 +1874,49 @@ func (n *sortedMapBranchNode) delete(key interface{}, c Comparer, mutable bool, // If mutable, update in-place. if mutable { - n.elems[idx] = sortedMapBranchElem{key: newNode.minKey(), node: newNode} + n.elems[idx] = sortedMapBranchElem[K, V]{key: newNode.minKey(), node: newNode} return n } // Return a copy with the updated node. - other := &sortedMapBranchNode{elems: make([]sortedMapBranchElem, len(n.elems))} + other := &sortedMapBranchNode[K, V]{elems: make([]sortedMapBranchElem[K, V], len(n.elems))} copy(other.elems, n.elems) - other.elems[idx] = sortedMapBranchElem{ + other.elems[idx] = sortedMapBranchElem[K, V]{ key: newNode.minKey(), node: newNode, } return other } -type sortedMapBranchElem struct { - key interface{} - node sortedMapNode +type sortedMapBranchElem[K constraints.Ordered, V any] struct { + key K + node sortedMapNode[K, V] } // sortedMapLeafNode represents a leaf node in the sorted map. -type sortedMapLeafNode struct { - entries []mapEntry +type sortedMapLeafNode[K constraints.Ordered, V any] struct { + entries []mapEntry[K, V] } // minKey returns the first key stored in this node. -func (n *sortedMapLeafNode) minKey() interface{} { +func (n *sortedMapLeafNode[K, V]) minKey() K { return n.entries[0].key } // indexOf returns the index of the given key. -func (n *sortedMapLeafNode) indexOf(key interface{}, c Comparer) int { +func (n *sortedMapLeafNode[K, V]) indexOf(key K, c Comparer[K]) int { return sort.Search(len(n.entries), func(i int) bool { return c.Compare(n.entries[i].key, key) != -1 // GTE }) } // get returns the value of the given key. -func (n *sortedMapLeafNode) get(key interface{}, c Comparer) (value interface{}, ok bool) { +func (n *sortedMapLeafNode[K, V]) get(key K, c Comparer[K]) (value V, ok bool) { idx := n.indexOf(key, c) // If the index is beyond the entry count or the key is not equal then return 'not found'. if idx == len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 { - return nil, false + return value, false } // If the key matches then return its value. @@ -1914,7 +1925,7 @@ func (n *sortedMapLeafNode) get(key interface{}, c Comparer) (value interface{}, // set returns a copy of node with the key set to the given value. If the update // causes the node to grow beyond the maximum size then it is split in two. -func (n *sortedMapLeafNode) set(key, value interface{}, c Comparer, mutable bool, resized *bool) (sortedMapNode, sortedMapNode) { +func (n *sortedMapLeafNode[K, V]) set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) { // Find the insertion index for the key. idx := n.indexOf(key, c) exists := idx < len(n.entries) && c.Compare(n.entries[idx].key, key) == 0 @@ -1923,16 +1934,16 @@ func (n *sortedMapLeafNode) set(key, value interface{}, c Comparer, mutable bool if mutable { if !exists { *resized = true - n.entries = append(n.entries, mapEntry{}) + n.entries = append(n.entries, mapEntry[K, V]{}) copy(n.entries[idx+1:], n.entries[idx:]) } - n.entries[idx] = mapEntry{key: key, value: value} + n.entries[idx] = mapEntry[K, V]{key: key, value: value} // If the key doesn't exist and we exceed our max allowed values then split. if len(n.entries) > sortedMapNodeSize { splitIdx := len(n.entries) / 2 - newNode := &sortedMapLeafNode{entries: n.entries[:splitIdx:splitIdx]} - splitNode := &sortedMapLeafNode{entries: n.entries[splitIdx:]} + newNode := &sortedMapLeafNode[K, V]{entries: n.entries[:splitIdx:splitIdx]} + splitNode := &sortedMapLeafNode[K, V]{entries: n.entries[splitIdx:]} return newNode, splitNode } return n, nil @@ -1940,34 +1951,34 @@ func (n *sortedMapLeafNode) set(key, value interface{}, c Comparer, mutable bool // If the key matches then simply return a copy with the entry overridden. // If there is no match then insert new entry and mark as resized. - var newEntries []mapEntry + var newEntries []mapEntry[K, V] if exists { - newEntries = make([]mapEntry, len(n.entries)) + newEntries = make([]mapEntry[K, V], len(n.entries)) copy(newEntries, n.entries) - newEntries[idx] = mapEntry{key: key, value: value} + newEntries[idx] = mapEntry[K, V]{key: key, value: value} } else { *resized = true - newEntries = make([]mapEntry, len(n.entries)+1) + newEntries = make([]mapEntry[K, V], len(n.entries)+1) copy(newEntries[:idx], n.entries[:idx]) - newEntries[idx] = mapEntry{key: key, value: value} + newEntries[idx] = mapEntry[K, V]{key: key, value: value} copy(newEntries[idx+1:], n.entries[idx:]) } // If the key doesn't exist and we exceed our max allowed values then split. if len(newEntries) > sortedMapNodeSize { splitIdx := len(newEntries) / 2 - newNode := &sortedMapLeafNode{entries: newEntries[:splitIdx:splitIdx]} - splitNode := &sortedMapLeafNode{entries: newEntries[splitIdx:]} + newNode := &sortedMapLeafNode[K, V]{entries: newEntries[:splitIdx:splitIdx]} + splitNode := &sortedMapLeafNode[K, V]{entries: newEntries[splitIdx:]} return newNode, splitNode } // Otherwise return the new leaf node with the updated entry. - return &sortedMapLeafNode{entries: newEntries}, nil + return &sortedMapLeafNode[K, V]{entries: newEntries}, nil } // delete returns a copy of node with key removed. Returns the original node if // the key does not exist. Returns nil if the removed key is the last remaining key. -func (n *sortedMapLeafNode) delete(key interface{}, c Comparer, mutable bool, resized *bool) sortedMapNode { +func (n *sortedMapLeafNode[K, V]) delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] { idx := n.indexOf(key, c) // Return original node if key is not found. @@ -1984,13 +1995,13 @@ func (n *sortedMapLeafNode) delete(key interface{}, c Comparer, mutable bool, re // Update in-place, if mutable. if mutable { copy(n.entries[idx:], n.entries[idx+1:]) - n.entries[len(n.entries)-1] = mapEntry{} + n.entries[len(n.entries)-1] = mapEntry[K, V]{} n.entries = n.entries[:len(n.entries)-1] return n } // Return copy of node with entry removed. - other := &sortedMapLeafNode{entries: make([]mapEntry, len(n.entries)-1)} + other := &sortedMapLeafNode[K, V]{entries: make([]mapEntry[K, V], len(n.entries)-1)} copy(other.entries[:idx], n.entries[:idx]) copy(other.entries[idx:], n.entries[idx+1:]) return other @@ -1998,36 +2009,36 @@ func (n *sortedMapLeafNode) delete(key interface{}, c Comparer, mutable bool, re // SortedMapIterator represents an iterator over a sorted map. // Iteration can occur in natural or reverse order based on use of Next() or Prev(). -type SortedMapIterator struct { - m *SortedMap // source map +type SortedMapIterator[K constraints.Ordered, V any] struct { + m *SortedMap[K, V] // source map - stack [32]sortedMapIteratorElem // search stack - depth int // stack depth + stack [32]sortedMapIteratorElem[K, V] // search stack + depth int // stack depth } // Done returns true if no more key/value pairs remain in the iterator. -func (itr *SortedMapIterator) Done() bool { +func (itr *SortedMapIterator[K, V]) Done() bool { return itr.depth == -1 } // First moves the iterator to the first key/value pair. -func (itr *SortedMapIterator) First() { +func (itr *SortedMapIterator[K, V]) First() { if itr.m.root == nil { itr.depth = -1 return } - itr.stack[0] = sortedMapIteratorElem{node: itr.m.root} + itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root} itr.depth = 0 itr.first() } // Last moves the iterator to the last key/value pair. -func (itr *SortedMapIterator) Last() { +func (itr *SortedMapIterator[K, V]) Last() { if itr.m.root == nil { itr.depth = -1 return } - itr.stack[0] = sortedMapIteratorElem{node: itr.m.root} + itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root} itr.depth = 0 itr.last() } @@ -2035,27 +2046,27 @@ func (itr *SortedMapIterator) Last() { // Seek moves the iterator position to the given key in the map. // If the key does not exist then the next key is used. If no more keys exist // then the iteartor is marked as done. -func (itr *SortedMapIterator) Seek(key interface{}) { +func (itr *SortedMapIterator[K, V]) Seek(key K) { if itr.m.root == nil { itr.depth = -1 return } - itr.stack[0] = sortedMapIteratorElem{node: itr.m.root} + itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root} itr.depth = 0 itr.seek(key) } // Next returns the current key/value pair and moves the iterator forward. // Returns a nil key if the there are no more elements to return. -func (itr *SortedMapIterator) Next() (key, value interface{}) { +func (itr *SortedMapIterator[K, V]) Next() (key K, value V, ok bool) { // Return nil key if iteration is complete. if itr.Done() { - return nil, nil + return key, value, false } // Retrieve current key/value pair. leafElem := &itr.stack[itr.depth] - leafNode := leafElem.node.(*sortedMapLeafNode) + leafNode := leafElem.node.(*sortedMapLeafNode[K, V]) leafEntry := &leafNode.entries[leafElem.index] key, value = leafEntry.key, leafEntry.value @@ -2063,21 +2074,21 @@ func (itr *SortedMapIterator) Next() (key, value interface{}) { itr.next() // Only occurs when iterator is done. - return key, value + return key, value, true } // next moves to the next key. If no keys are after then depth is set to -1. -func (itr *SortedMapIterator) next() { +func (itr *SortedMapIterator[K, V]) next() { for ; itr.depth >= 0; itr.depth-- { elem := &itr.stack[itr.depth] switch node := elem.node.(type) { - case *sortedMapLeafNode: + case *sortedMapLeafNode[K, V]: if elem.index < len(node.entries)-1 { elem.index++ return } - case *sortedMapBranchNode: + case *sortedMapBranchNode[K, V]: if elem.index < len(node.elems)-1 { elem.index++ itr.stack[itr.depth+1].node = node.elems[elem.index].node @@ -2091,34 +2102,34 @@ func (itr *SortedMapIterator) next() { // Prev returns the current key/value pair and moves the iterator backward. // Returns a nil key if the there are no more elements to return. -func (itr *SortedMapIterator) Prev() (key, value interface{}) { +func (itr *SortedMapIterator[K, V]) Prev() (key K, value V, ok bool) { // Return nil key if iteration is complete. if itr.Done() { - return nil, nil + return key, value, false } // Retrieve current key/value pair. leafElem := &itr.stack[itr.depth] - leafNode := leafElem.node.(*sortedMapLeafNode) + leafNode := leafElem.node.(*sortedMapLeafNode[K, V]) leafEntry := &leafNode.entries[leafElem.index] key, value = leafEntry.key, leafEntry.value itr.prev() - return key, value + return key, value, true } // prev moves to the previous key. If no keys are before then depth is set to -1. -func (itr *SortedMapIterator) prev() { +func (itr *SortedMapIterator[K, V]) prev() { for ; itr.depth >= 0; itr.depth-- { elem := &itr.stack[itr.depth] switch node := elem.node.(type) { - case *sortedMapLeafNode: + case *sortedMapLeafNode[K, V]: if elem.index > 0 { elem.index-- return } - case *sortedMapBranchNode: + case *sortedMapBranchNode[K, V]: if elem.index > 0 { elem.index-- itr.stack[itr.depth+1].node = node.elems[elem.index].node @@ -2132,16 +2143,16 @@ func (itr *SortedMapIterator) prev() { // first positions the stack to the leftmost key from the current depth. // Elements and indexes below the current depth are assumed to be correct. -func (itr *SortedMapIterator) first() { +func (itr *SortedMapIterator[K, V]) first() { for { elem := &itr.stack[itr.depth] elem.index = 0 switch node := elem.node.(type) { - case *sortedMapBranchNode: - itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node} + case *sortedMapBranchNode[K, V]: + itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node} itr.depth++ - case *sortedMapLeafNode: + case *sortedMapLeafNode[K, V]: return } } @@ -2149,16 +2160,16 @@ func (itr *SortedMapIterator) first() { // last positions the stack to the rightmost key from the current depth. // Elements and indexes below the current depth are assumed to be correct. -func (itr *SortedMapIterator) last() { +func (itr *SortedMapIterator[K, V]) last() { for { elem := &itr.stack[itr.depth] switch node := elem.node.(type) { - case *sortedMapBranchNode: + case *sortedMapBranchNode[K, V]: elem.index = len(node.elems) - 1 - itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node} + itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node} itr.depth++ - case *sortedMapLeafNode: + case *sortedMapLeafNode[K, V]: elem.index = len(node.entries) - 1 return } @@ -2167,16 +2178,16 @@ func (itr *SortedMapIterator) last() { // seek positions the stack to the given key from the current depth. // Elements and indexes below the current depth are assumed to be correct. -func (itr *SortedMapIterator) seek(key interface{}) { +func (itr *SortedMapIterator[K, V]) seek(key K) { for { elem := &itr.stack[itr.depth] elem.index = elem.node.indexOf(key, itr.m.comparer) switch node := elem.node.(type) { - case *sortedMapBranchNode: - itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node} + case *sortedMapBranchNode[K, V]: + itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node} itr.depth++ - case *sortedMapLeafNode: + case *sortedMapLeafNode[K, V]: if elem.index == len(node.entries) { itr.next() } @@ -2186,59 +2197,33 @@ func (itr *SortedMapIterator) seek(key interface{}) { } // sortedMapIteratorElem represents node/index pair in the SortedMapIterator stack. -type sortedMapIteratorElem struct { - node sortedMapNode +type sortedMapIteratorElem[K constraints.Ordered, V any] struct { + node sortedMapNode[K, V] index int } // Hasher hashes keys and checks them for equality. -type Hasher interface { - // Computes a 32-bit hash for key. - Hash(key interface{}) uint32 +type Hasher[K constraints.Ordered] interface { + // Computes a hash for key. + Hash(key K) uint32 // Returns true if a and b are equal. - Equal(a, b interface{}) bool + Equal(a, b K) bool } // NewHasher returns the built-in hasher for a given key type. -func NewHasher(key interface{}) Hasher { +func NewHasher[K constraints.Ordered](key K) Hasher[K] { // Attempt to use non-reflection based hasher first. - switch key.(type) { - case int: - return &intHasher{} - case int8: - return &int8Hasher{} - case int16: - return &int16Hasher{} - case int32: - return &int32Hasher{} - case int64: - return &int64Hasher{} - case uint: - return &uintHasher{} - case uint8: - return &uint8Hasher{} - case uint16: - return &uint16Hasher{} - case uint32: - return &uint32Hasher{} - case uint64: - return &uint64Hasher{} - case string: - return &stringHasher{} - case []byte: - return &byteSliceHasher{} + switch (any(key)).(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, string: + return &defaultHasher[K]{} } // Fallback to reflection-based hasher otherwise. // This is used when caller wraps a type around a primitive type. switch reflect.TypeOf(key).Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return &reflectIntHasher{} - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return &reflectUintHasher{} - case reflect.String: - return &reflectStringHasher{} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String: + return &reflectHasher[K]{} } // If no hashers match then panic. @@ -2246,227 +2231,49 @@ func NewHasher(key interface{}) Hasher { panic(fmt.Sprintf("immutable.NewHasher: must set hasher for %T type", key)) } -// intHasher implements Hasher for int keys. -type intHasher struct{} - -// Hash returns a hash for key. -func (h *intHasher) Hash(key interface{}) uint32 { - return hashUint64(uint64(key.(int))) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not ints. -func (h *intHasher) Equal(a, b interface{}) bool { - return a.(int) == b.(int) -} - -// int8Hasher implements Hasher for int8 keys. -type int8Hasher struct{} - -// Hash returns a hash for key. -func (h *int8Hasher) Hash(key interface{}) uint32 { - return hashUint64(uint64(key.(int8))) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not int8s. -func (h *int8Hasher) Equal(a, b interface{}) bool { - return a.(int8) == b.(int8) -} - -// int16Hasher implements Hasher for int16 keys. -type int16Hasher struct{} - -// Hash returns a hash for key. -func (h *int16Hasher) Hash(key interface{}) uint32 { - return hashUint64(uint64(key.(int16))) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not int16s. -func (h *int16Hasher) Equal(a, b interface{}) bool { - return a.(int16) == b.(int16) -} - -// int32Hasher implements Hasher for int32 keys. -type int32Hasher struct{} - -// Hash returns a hash for key. -func (h *int32Hasher) Hash(key interface{}) uint32 { - return hashUint64(uint64(key.(int32))) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not int32s. -func (h *int32Hasher) Equal(a, b interface{}) bool { - return a.(int32) == b.(int32) -} - -// int64Hasher implements Hasher for int64 keys. -type int64Hasher struct{} - -// Hash returns a hash for key. -func (h *int64Hasher) Hash(key interface{}) uint32 { - return hashUint64(uint64(key.(int64))) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not int64s. -func (h *int64Hasher) Equal(a, b interface{}) bool { - return a.(int64) == b.(int64) -} - -// uintHasher implements Hasher for uint keys. -type uintHasher struct{} - -// Hash returns a hash for key. -func (h *uintHasher) Hash(key interface{}) uint32 { - return hashUint64(uint64(key.(uint))) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not uints. -func (h *uintHasher) Equal(a, b interface{}) bool { - return a.(uint) == b.(uint) -} - -// uint8Hasher implements Hasher for uint8 keys. -type uint8Hasher struct{} - -// Hash returns a hash for key. -func (h *uint8Hasher) Hash(key interface{}) uint32 { - return hashUint64(uint64(key.(uint8))) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not uint8s. -func (h *uint8Hasher) Equal(a, b interface{}) bool { - return a.(uint8) == b.(uint8) -} - -// uint16Hasher implements Hasher for uint16 keys. -type uint16Hasher struct{} - -// Hash returns a hash for key. -func (h *uint16Hasher) Hash(key interface{}) uint32 { - return hashUint64(uint64(key.(uint16))) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not uint16s. -func (h *uint16Hasher) Equal(a, b interface{}) bool { - return a.(uint16) == b.(uint16) -} - -// uint32Hasher implements Hasher for uint32 keys. -type uint32Hasher struct{} - -// Hash returns a hash for key. -func (h *uint32Hasher) Hash(key interface{}) uint32 { - return hashUint64(uint64(key.(uint32))) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not uint32s. -func (h *uint32Hasher) Equal(a, b interface{}) bool { - return a.(uint32) == b.(uint32) -} - -// uint64Hasher implements Hasher for uint64 keys. -type uint64Hasher struct{} - -// Hash returns a hash for key. -func (h *uint64Hasher) Hash(key interface{}) uint32 { - return hashUint64(key.(uint64)) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not uint64s. -func (h *uint64Hasher) Equal(a, b interface{}) bool { - return a.(uint64) == b.(uint64) -} - -// stringHasher implements Hasher for string keys. -type stringHasher struct{} - -// Hash returns a hash for value. -func (h *stringHasher) Hash(value interface{}) uint32 { - var hash uint32 - for i, value := 0, value.(string); i < len(value); i++ { - hash = 31*hash + uint32(value[i]) - } - return hash -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not strings. -func (h *stringHasher) Equal(a, b interface{}) bool { - return a.(string) == b.(string) -} - -// byteSliceHasher implements Hasher for byte slice keys. -type byteSliceHasher struct{} - // Hash returns a hash for value. -func (h *byteSliceHasher) Hash(value interface{}) uint32 { +func hashString(value string) uint32 { var hash uint32 - for i, value := 0, value.([]byte); i < len(value); i++ { + for i, value := 0, value; i < len(value); i++ { hash = 31*hash + uint32(value[i]) } return hash } -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not byte slices. -func (h *byteSliceHasher) Equal(a, b interface{}) bool { - return bytes.Equal(a.([]byte), b.([]byte)) -} - // reflectIntHasher implements a reflection-based Hasher for int keys. -type reflectIntHasher struct{} +type reflectHasher[K constraints.Ordered] struct{} // Hash returns a hash for key. -func (h *reflectIntHasher) Hash(key interface{}) uint32 { - return hashUint64(uint64(reflect.ValueOf(key).Int())) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not ints. -func (h *reflectIntHasher) Equal(a, b interface{}) bool { - return reflect.ValueOf(a).Int() == reflect.ValueOf(b).Int() -} - -// reflectUintHasher implements a reflection-based Hasher for uint keys. -type reflectUintHasher struct{} - -// Hash returns a hash for key. -func (h *reflectUintHasher) Hash(key interface{}) uint32 { - return hashUint64(reflect.ValueOf(key).Uint()) +func (h *reflectHasher[K]) Hash(key K) uint32 { + switch reflect.TypeOf(key).Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return hashUint64(uint64(reflect.ValueOf(key).Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return hashUint64(reflect.ValueOf(key).Uint()) + case reflect.String: + var hash uint32 + s := reflect.ValueOf(key).String() + for i := 0; i < len(s); i++ { + hash = 31*hash + uint32(s[i]) + } + return hash + } + panic(fmt.Sprintf("immutable.reflectHasher.Hash: reflectHasher does not support %T type", key)) } // Equal returns true if a is equal to b. Otherwise returns false. // Panics if a and b are not ints. -func (h *reflectUintHasher) Equal(a, b interface{}) bool { - return reflect.ValueOf(a).Uint() == reflect.ValueOf(b).Uint() -} - -// reflectStringHasher implements a refletion-based Hasher for string keys. -type reflectStringHasher struct{} - -// Hash returns a hash for value. -func (h *reflectStringHasher) Hash(value interface{}) uint32 { - var hash uint32 - s := reflect.ValueOf(value).String() - for i := 0; i < len(s); i++ { - hash = 31*hash + uint32(s[i]) +func (h *reflectHasher[K]) Equal(a, b K) bool { + switch reflect.TypeOf(a).Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return reflect.ValueOf(a).Int() == reflect.ValueOf(b).Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return reflect.ValueOf(a).Uint() == reflect.ValueOf(b).Uint() + case reflect.String: + return reflect.ValueOf(a).String() == reflect.ValueOf(b).String() } - return hash -} + panic(fmt.Sprintf("immutable.reflectHasher.Equal: reflectHasher does not support %T type", a)) -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not strings. -func (h *reflectStringHasher) Equal(a, b interface{}) bool { - return reflect.ValueOf(a).String() == reflect.ValueOf(b).String() } // hashUint64 returns a 32-bit hash for a 64-bit value. @@ -2479,192 +2286,79 @@ func hashUint64(value uint64) uint32 { return uint32(hash) } -// Comparer allows the comparison of two keys for the purpose of sorting. -type Comparer interface { - // Returns -1 if a is less than b, returns 1 if a is greater than b, - // and returns 0 if a is equal to b. - Compare(a, b interface{}) int -} +// defaultHasher implements Hasher. +type defaultHasher[K constraints.Ordered] struct{} -// NewComparer returns the built-in comparer for a given key type. -func NewComparer(key interface{}) Comparer { - // Attempt to use non-reflection based comparer first. - switch key.(type) { +// Hash returns a hash for key. +func (h *defaultHasher[K]) Hash(key K) uint32 { + // Attempt to use non-reflection based hasher first. + switch x := (any(key)).(type) { case int: - return &intComparer{} + return hashUint64(uint64(x)) case int8: - return &int8Comparer{} + return hashUint64(uint64(x)) case int16: - return &int16Comparer{} + return hashUint64(uint64(x)) case int32: - return &int32Comparer{} + return hashUint64(uint64(x)) case int64: - return &int64Comparer{} + return hashUint64(uint64(x)) case uint: - return &uintComparer{} + return hashUint64(uint64(x)) case uint8: - return &uint8Comparer{} + return hashUint64(uint64(x)) case uint16: - return &uint16Comparer{} + return hashUint64(uint64(x)) case uint32: - return &uint32Comparer{} + return hashUint64(uint64(x)) case uint64: - return &uint64Comparer{} + return hashUint64(uint64(x)) + case uintptr: + return hashUint64(uint64(x)) case string: - return &stringComparer{} - case []byte: - return &byteSliceComparer{} + return hashString(x) } + panic(fmt.Sprintf("immutable.defaultHasher.Hash: must set comparer for %T type", key)) +} + +// Equal returns true if a is equal to b. Otherwise returns false. +// Panics if a and b are not ints. +func (h *defaultHasher[K]) Equal(a, b K) bool { + return a == b +} +// Comparer allows the comparison of two keys for the purpose of sorting. +type Comparer[K constraints.Ordered] interface { + // Returns -1 if a is less than b, returns 1 if a is greater than b, + // and returns 0 if a is equal to b. + Compare(a, b K) int +} + +// NewComparer returns the built-in comparer for a given key type. +func NewComparer[K constraints.Ordered](key K) Comparer[K] { + // Attempt to use non-reflection based comparer first. + switch (any(key)).(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, string: + return &defaultComparer[K]{} + } // Fallback to reflection-based comparer otherwise. // This is used when caller wraps a type around a primitive type. switch reflect.TypeOf(key).Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return &reflectIntComparer{} - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return &reflectUintComparer{} - case reflect.String: - return &reflectStringComparer{} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String: + return &reflectComparer[K]{} } - // If no comparers match then panic. // This is a compile time issue so it should not return an error. panic(fmt.Sprintf("immutable.NewComparer: must set comparer for %T type", key)) } -// intComparer compares two integers. Implements Comparer. -type intComparer struct{} +// defaultComparer compares two integers. Implements Comparer. +type defaultComparer[K constraints.Ordered] struct{} // Compare returns -1 if a is less than b, returns 1 if a is greater than b, and // returns 0 if a is equal to b. Panic if a or b is not an int. -func (c *intComparer) Compare(a, b interface{}) int { - if i, j := a.(int), b.(int); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// int8Comparer compares two int8 values. Implements Comparer. -type int8Comparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an int8. -func (c *int8Comparer) Compare(a, b interface{}) int { - if i, j := a.(int8), b.(int8); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// int16Comparer compares two int16 values. Implements Comparer. -type int16Comparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an int16. -func (c *int16Comparer) Compare(a, b interface{}) int { - if i, j := a.(int16), b.(int16); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// int32Comparer compares two int32 values. Implements Comparer. -type int32Comparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an int32. -func (c *int32Comparer) Compare(a, b interface{}) int { - if i, j := a.(int32), b.(int32); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// int64Comparer compares two int64 values. Implements Comparer. -type int64Comparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an int64. -func (c *int64Comparer) Compare(a, b interface{}) int { - if i, j := a.(int64), b.(int64); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// uintComparer compares two uint values. Implements Comparer. -type uintComparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an uint. -func (c *uintComparer) Compare(a, b interface{}) int { - if i, j := a.(uint), b.(uint); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// uint8Comparer compares two uint8 values. Implements Comparer. -type uint8Comparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an uint8. -func (c *uint8Comparer) Compare(a, b interface{}) int { - if i, j := a.(uint8), b.(uint8); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// uint16Comparer compares two uint16 values. Implements Comparer. -type uint16Comparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an uint16. -func (c *uint16Comparer) Compare(a, b interface{}) int { - if i, j := a.(uint16), b.(uint16); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// uint32Comparer compares two uint32 values. Implements Comparer. -type uint32Comparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an uint32. -func (c *uint32Comparer) Compare(a, b interface{}) int { - if i, j := a.(uint32), b.(uint32); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// uint64Comparer compares two uint64 values. Implements Comparer. -type uint64Comparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an uint64. -func (c *uint64Comparer) Compare(a, b interface{}) int { - if i, j := a.(uint64), b.(uint64); i < j { +func (c *defaultComparer[K]) Compare(i K, j K) int { + if i < j { return -1 } else if i > j { return 1 @@ -2672,59 +2366,31 @@ func (c *uint64Comparer) Compare(a, b interface{}) int { return 0 } -// stringComparer compares two strings. Implements Comparer. -type stringComparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not a string. -func (c *stringComparer) Compare(a, b interface{}) int { - return strings.Compare(a.(string), b.(string)) -} - -// byteSliceComparer compares two byte slices. Implements Comparer. -type byteSliceComparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not a byte slice. -func (c *byteSliceComparer) Compare(a, b interface{}) int { - return bytes.Compare(a.([]byte), b.([]byte)) -} - // reflectIntComparer compares two int values using reflection. Implements Comparer. -type reflectIntComparer struct{} +type reflectComparer[K constraints.Ordered] struct{} // Compare returns -1 if a is less than b, returns 1 if a is greater than b, and // returns 0 if a is equal to b. Panic if a or b is not an int. -func (c *reflectIntComparer) Compare(a, b interface{}) int { - if i, j := reflect.ValueOf(a).Int(), reflect.ValueOf(b).Int(); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// reflectUintComparer compares two uint values using reflection. Implements Comparer. -type reflectUintComparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an int. -func (c *reflectUintComparer) Compare(a, b interface{}) int { - if i, j := reflect.ValueOf(a).Uint(), reflect.ValueOf(b).Uint(); i < j { - return -1 - } else if i > j { - return 1 +func (c *reflectComparer[K]) Compare(a, b K) int { + switch reflect.TypeOf(a).Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if i, j := reflect.ValueOf(a).Int(), reflect.ValueOf(b).Int(); i < j { + return -1 + } else if i > j { + return 1 + } + return 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if i, j := reflect.ValueOf(a).Uint(), reflect.ValueOf(b).Uint(); i < j { + return -1 + } else if i > j { + return 1 + } + return 0 + case reflect.String: + return strings.Compare(reflect.ValueOf(a).String(), reflect.ValueOf(b).String()) } - return 0 -} - -// reflectStringComparer compares two string values using reflection. Implements Comparer. -type reflectStringComparer struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an int. -func (c *reflectStringComparer) Compare(a, b interface{}) int { - return strings.Compare(reflect.ValueOf(a).String(), reflect.ValueOf(b).String()) + panic(fmt.Sprintf("immutable.reflectComparer.Compare: must set comparer for %T type", a)) } func assert(condition bool, message string) { diff --git a/immutable_test.go b/immutable_test.go index 130904f..c4b0297 100644 --- a/immutable_test.go +++ b/immutable_test.go @@ -6,6 +6,8 @@ import ( "math/rand" "sort" "testing" + + "golang.org/x/exp/constraints" ) var ( @@ -15,13 +17,13 @@ var ( func TestList(t *testing.T) { t.Run("Empty", func(t *testing.T) { - if size := NewList().Len(); size != 0 { + if size := NewList[string]().Len(); size != 0 { t.Fatalf("unexpected size: %d", size) } }) t.Run("Shallow", func(t *testing.T) { - list := NewList() + list := NewList[string]() list = list.Append("foo") if v := list.Get(0); v != "foo" { t.Fatalf("unexpected value: %v", v) @@ -40,7 +42,7 @@ func TestList(t *testing.T) { }) t.Run("Deep", func(t *testing.T) { - list := NewList() + list := NewList[int]() var array []int for i := 0; i < 100000; i++ { list = list.Append(i) @@ -51,14 +53,14 @@ func TestList(t *testing.T) { t.Fatalf("List.Len()=%d, exp %d", got, exp) } for j := range array { - if got, exp := list.Get(j).(int), array[j]; got != exp { + if got, exp := list.Get(j), array[j]; got != exp { t.Fatalf("%d. List.Get(%d)=%d, exp %d", len(array), j, got, exp) } } }) t.Run("Set", func(t *testing.T) { - list := NewList() + list := NewList[string]() list = list.Append("foo") list = list.Append("bar") @@ -78,7 +80,7 @@ func TestList(t *testing.T) { var r string func() { defer func() { r = recover().(string) }() - l := NewList() + l := NewList[string]() l = l.Append("foo") l.Get(-1) }() @@ -91,7 +93,7 @@ func TestList(t *testing.T) { var r string func() { defer func() { r = recover().(string) }() - l := NewList() + l := NewList[string]() l = l.Append("foo") l.Get(1) }() @@ -104,7 +106,7 @@ func TestList(t *testing.T) { var r string func() { defer func() { r = recover().(string) }() - l := NewList() + l := NewList[string]() l = l.Append("foo") l.Set(1, "bar") }() @@ -117,7 +119,7 @@ func TestList(t *testing.T) { var r string func() { defer func() { r = recover().(string) }() - l := NewList() + l := NewList[string]() l = l.Append("foo") l.Slice(2, 3) }() @@ -130,7 +132,7 @@ func TestList(t *testing.T) { var r string func() { defer func() { r = recover().(string) }() - l := NewList() + l := NewList[string]() l = l.Append("foo") l.Slice(1, 3) }() @@ -143,7 +145,7 @@ func TestList(t *testing.T) { var r string func() { defer func() { r = recover().(string) }() - l := NewList() + l := NewList[string]() l = l.Append("foo") l = l.Append("bar") l.Slice(2, 1) @@ -154,7 +156,7 @@ func TestList(t *testing.T) { }) t.Run("SliceBeginning", func(t *testing.T) { - l := NewList() + l := NewList[string]() l = l.Append("foo") l = l.Append("bar") l = l.Slice(1, 2) @@ -169,7 +171,7 @@ func TestList(t *testing.T) { var r string func() { defer func() { r = recover().(string) }() - l := NewList() + l := NewList[string]() l = l.Append("foo") l.Iterator().Seek(-1) }() @@ -204,16 +206,16 @@ func TestList(t *testing.T) { // TList represents a list that operates on a standard Go slice & immutable list. type TList struct { - im, prev *List - builder *ListBuilder + im, prev *List[int] + builder *ListBuilder[int] std []int } // NewTList returns a new instance of TList. func NewTList() *TList { return &TList{ - im: NewList(), - builder: NewListBuilder(), + im: NewList[int](), + builder: NewListBuilder[int](), } } @@ -302,7 +304,7 @@ func (l *TList) Validate() error { return nil } -func (l *TList) validateForwardIterator(typ string, itr *ListIterator) error { +func (l *TList) validateForwardIterator(typ string, itr *ListIterator[int]) error { for i := range l.std { if j, v := itr.Next(); i != j || l.std[i] != v { return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected <%v,%v> [%s]", j, v, i, l.std[i], typ) @@ -313,13 +315,13 @@ func (l *TList) validateForwardIterator(typ string, itr *ListIterator) error { return fmt.Errorf("ListIterator.Done()=%v, expected %v [%s]", v, done, typ) } } - if i, v := itr.Next(); i != -1 || v != nil { + if i, v := itr.Next(); i != -1 || v != 0 { return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected DONE [%s]", i, v, typ) } return nil } -func (l *TList) validateBackwardIterator(typ string, itr *ListIterator) error { +func (l *TList) validateBackwardIterator(typ string, itr *ListIterator[int]) error { itr.Last() for i := len(l.std) - 1; i >= 0; i-- { if j, v := itr.Prev(); i != j || l.std[i] != v { @@ -331,7 +333,7 @@ func (l *TList) validateBackwardIterator(typ string, itr *ListIterator) error { return fmt.Errorf("ListIterator.Done()=%v, expected %v [%s]", v, done, typ) } } - if i, v := itr.Prev(); i != -1 || v != nil { + if i, v := itr.Prev(); i != -1 || v != 0 { return fmt.Errorf("ListIterator.Prev()=<%v,%v>, expected DONE [%s]", i, v, typ) } return nil @@ -339,7 +341,7 @@ func (l *TList) validateBackwardIterator(typ string, itr *ListIterator) error { func BenchmarkList_Append(b *testing.B) { b.ReportAllocs() - l := NewList() + l := NewList[int]() for i := 0; i < b.N; i++ { l = l.Append(i) } @@ -347,7 +349,7 @@ func BenchmarkList_Append(b *testing.B) { func BenchmarkList_Prepend(b *testing.B) { b.ReportAllocs() - l := NewList() + l := NewList[int]() for i := 0; i < b.N; i++ { l = l.Prepend(i) } @@ -356,7 +358,7 @@ func BenchmarkList_Prepend(b *testing.B) { func BenchmarkList_Set(b *testing.B) { const n = 10000 - l := NewList() + l := NewList[int]() for i := 0; i < 10000; i++ { l = l.Append(i) } @@ -370,7 +372,7 @@ func BenchmarkList_Set(b *testing.B) { func BenchmarkList_Iterator(b *testing.B) { const n = 10000 - l := NewList() + l := NewList[int]() for i := 0; i < 10000; i++ { l = l.Append(i) } @@ -417,7 +419,7 @@ func BenchmarkBuiltinSlice_Append(b *testing.B) { func BenchmarkListBuilder_Append(b *testing.B) { b.ReportAllocs() - builder := NewListBuilder() + builder := NewListBuilder[int]() for i := 0; i < b.N; i++ { builder.Append(i) } @@ -425,7 +427,7 @@ func BenchmarkListBuilder_Append(b *testing.B) { func BenchmarkListBuilder_Prepend(b *testing.B) { b.ReportAllocs() - builder := NewListBuilder() + builder := NewListBuilder[int]() for i := 0; i < b.N; i++ { builder.Prepend(i) } @@ -434,7 +436,7 @@ func BenchmarkListBuilder_Prepend(b *testing.B) { func BenchmarkListBuilder_Set(b *testing.B) { const n = 10000 - builder := NewListBuilder() + builder := NewListBuilder[int]() for i := 0; i < 10000; i++ { builder.Append(i) } @@ -447,7 +449,7 @@ func BenchmarkListBuilder_Set(b *testing.B) { } func ExampleList_Append() { - l := NewList() + l := NewList[string]() l = l.Append("foo") l = l.Append("bar") l = l.Append("baz") @@ -462,7 +464,7 @@ func ExampleList_Append() { } func ExampleList_Prepend() { - l := NewList() + l := NewList[string]() l = l.Prepend("foo") l = l.Prepend("bar") l = l.Prepend("baz") @@ -477,7 +479,7 @@ func ExampleList_Prepend() { } func ExampleList_Set() { - l := NewList() + l := NewList[string]() l = l.Append("foo") l = l.Append("bar") l = l.Set(1, "baz") @@ -490,7 +492,7 @@ func ExampleList_Set() { } func ExampleList_Slice() { - l := NewList() + l := NewList[string]() l = l.Append("foo") l = l.Append("bar") l = l.Append("baz") @@ -504,7 +506,7 @@ func ExampleList_Slice() { } func ExampleList_Iterator() { - l := NewList() + l := NewList[string]() l = l.Append("foo") l = l.Append("bar") l = l.Append("baz") @@ -521,7 +523,7 @@ func ExampleList_Iterator() { } func ExampleList_Iterator_reverse() { - l := NewList() + l := NewList[string]() l = l.Append("foo") l = l.Append("bar") l = l.Append("baz") @@ -539,7 +541,7 @@ func ExampleList_Iterator_reverse() { } func ExampleListBuilder_Append() { - b := NewListBuilder() + b := NewListBuilder[string]() b.Append("foo") b.Append("bar") b.Append("baz") @@ -555,7 +557,7 @@ func ExampleListBuilder_Append() { } func ExampleListBuilder_Prepend() { - b := NewListBuilder() + b := NewListBuilder[string]() b.Prepend("foo") b.Prepend("bar") b.Prepend("baz") @@ -571,7 +573,7 @@ func ExampleListBuilder_Prepend() { } func ExampleListBuilder_Set() { - b := NewListBuilder() + b := NewListBuilder[string]() b.Append("foo") b.Append("bar") b.Set(1, "baz") @@ -585,7 +587,7 @@ func ExampleListBuilder_Set() { } func ExampleListBuilder_Slice() { - b := NewListBuilder() + b := NewListBuilder[string]() b.Append("foo") b.Append("bar") b.Append("baz") @@ -602,8 +604,8 @@ func ExampleListBuilder_Slice() { // Ensure node can support overwrites as it expands. func TestInternal_mapNode_Overwrite(t *testing.T) { const n = 1000 - var h intHasher - var node mapNode = &mapArrayNode{} + var h defaultHasher[int] + var node mapNode[int, int] = &mapArrayNode[int, int]{} for i := 0; i < n; i++ { var resized bool node = node.set(i, i, 0, h.Hash(i), &h, false, &resized) @@ -637,11 +639,11 @@ func TestInternal_mapNode_Overwrite(t *testing.T) { func TestInternal_mapArrayNode(t *testing.T) { // Ensure 8 or fewer elements stays in an array node. t.Run("Append", func(t *testing.T) { - var h intHasher - n := &mapArrayNode{} + var h defaultHasher[int] + n := &mapArrayNode[int, int]{} for i := 0; i < 8; i++ { var resized bool - n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized).(*mapArrayNode) + n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized).(*mapArrayNode[int, int]) if !resized { t.Fatal("expected resize") } @@ -656,11 +658,11 @@ func TestInternal_mapArrayNode(t *testing.T) { // Ensure 8 or fewer elements stays in an array node when inserted in reverse. t.Run("Prepend", func(t *testing.T) { - var h intHasher - n := &mapArrayNode{} + var h defaultHasher[int] + n := &mapArrayNode[int, int]{} for i := 7; i >= 0; i-- { var resized bool - n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized).(*mapArrayNode) + n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized).(*mapArrayNode[int, int]) if !resized { t.Fatal("expected resize") } @@ -675,8 +677,8 @@ func TestInternal_mapArrayNode(t *testing.T) { // Ensure array can transition between node types. t.Run("Expand", func(t *testing.T) { - var h intHasher - var n mapNode = &mapArrayNode{} + var h defaultHasher[int] + var n mapNode[int, int] = &mapArrayNode[int, int]{} for i := 0; i < 100; i++ { var resized bool n = n.set(i, i, 0, h.Hash(i), &h, false, &resized) @@ -694,8 +696,8 @@ func TestInternal_mapArrayNode(t *testing.T) { // Ensure deleting elements returns the correct new node. RunRandom(t, "Delete", func(t *testing.T, rand *rand.Rand) { - var h intHasher - var n mapNode = &mapArrayNode{} + var h defaultHasher[int] + var n mapNode[int, int] = &mapArrayNode[int, int]{} for i := 0; i < 8; i++ { var resized bool n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized) @@ -713,7 +715,7 @@ func TestInternal_mapArrayNode(t *testing.T) { func TestInternal_mapValueNode(t *testing.T) { t.Run("Simple", func(t *testing.T) { - var h intHasher + var h defaultHasher[int] n := newMapValueNode(h.Hash(2), 2, 3) if v, ok := n.get(2, 0, h.Hash(2), &h); !ok { t.Fatal("expected ok") @@ -723,10 +725,10 @@ func TestInternal_mapValueNode(t *testing.T) { }) t.Run("KeyEqual", func(t *testing.T) { - var h intHasher + var h defaultHasher[int] var resized bool n := newMapValueNode(h.Hash(2), 2, 3) - other := n.set(2, 4, 0, h.Hash(2), &h, false, &resized).(*mapValueNode) + other := n.set(2, 4, 0, h.Hash(2), &h, false, &resized).(*mapValueNode[int, int]) if other == n { t.Fatal("expected new node") } else if got, exp := other.keyHash, h.Hash(2); got != exp { @@ -741,13 +743,13 @@ func TestInternal_mapValueNode(t *testing.T) { }) t.Run("KeyHashEqual", func(t *testing.T) { - h := &mockHasher{ - hash: func(value interface{}) uint32 { return 1 }, - equal: func(a, b interface{}) bool { return a.(int) == b.(int) }, + h := &mockHasher[int]{ + hash: func(value int) uint32 { return 1 }, + equal: func(a, b int) bool { return a == b }, } var resized bool n := newMapValueNode(h.Hash(2), 2, 3) - other := n.set(4, 5, 0, h.Hash(4), h, false, &resized).(*mapHashCollisionNode) + other := n.set(4, 5, 0, h.Hash(4), h, false, &resized).(*mapHashCollisionNode[int, int]) if got, exp := other.keyHash, h.Hash(2); got != exp { t.Fatalf("keyHash=%v, expected %v", got, exp) } else if got, exp := len(other.entries), 2; got != exp { @@ -770,10 +772,10 @@ func TestInternal_mapValueNode(t *testing.T) { t.Run("MergeNode", func(t *testing.T) { // Inserting into a node with a different index in the mask should split into a bitmap node. t.Run("NoConflict", func(t *testing.T) { - var h intHasher + var h defaultHasher[int] var resized bool n := newMapValueNode(h.Hash(2), 2, 3) - other := n.set(4, 5, 0, h.Hash(4), &h, false, &resized).(*mapBitmapIndexedNode) + other := n.set(4, 5, 0, h.Hash(4), &h, false, &resized).(*mapBitmapIndexedNode[int, int]) if got, exp := other.bitmap, uint32(0x14); got != exp { t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) } else if got, exp := len(other.nodes), 2; got != exp { @@ -781,14 +783,14 @@ func TestInternal_mapValueNode(t *testing.T) { } else if !resized { t.Fatal("expected resize") } - if node, ok := other.nodes[0].(*mapValueNode); !ok { + if node, ok := other.nodes[0].(*mapValueNode[int, int]); !ok { t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) } else if got, exp := node.key, 2; got != exp { t.Fatalf("key[0]=%v, expected %v", got, exp) } else if got, exp := node.value, 3; got != exp { t.Fatalf("value[0]=%v, expected %v", got, exp) } - if node, ok := other.nodes[1].(*mapValueNode); !ok { + if node, ok := other.nodes[1].(*mapValueNode[int, int]); !ok { t.Fatalf("node[1]=%T, unexpected type", other.nodes[1]) } else if got, exp := node.key, 4; got != exp { t.Fatalf("key[1]=%v, expected %v", got, exp) @@ -797,19 +799,19 @@ func TestInternal_mapValueNode(t *testing.T) { } // Ensure both values can be read. - if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v.(int) != 3 { + if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v != 3 { t.Fatalf("Get(2)=<%v,%v>", v, ok) - } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v.(int) != 5 { + } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v != 5 { t.Fatalf("Get(4)=<%v,%v>", v, ok) } }) // Reversing the nodes from NoConflict should yield the same result. t.Run("NoConflictReverse", func(t *testing.T) { - var h intHasher + var h defaultHasher[int] var resized bool n := newMapValueNode(h.Hash(4), 4, 5) - other := n.set(2, 3, 0, h.Hash(2), &h, false, &resized).(*mapBitmapIndexedNode) + other := n.set(2, 3, 0, h.Hash(2), &h, false, &resized).(*mapBitmapIndexedNode[int, int]) if got, exp := other.bitmap, uint32(0x14); got != exp { t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) } else if got, exp := len(other.nodes), 2; got != exp { @@ -817,14 +819,14 @@ func TestInternal_mapValueNode(t *testing.T) { } else if !resized { t.Fatal("expected resize") } - if node, ok := other.nodes[0].(*mapValueNode); !ok { + if node, ok := other.nodes[0].(*mapValueNode[int, int]); !ok { t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) } else if got, exp := node.key, 2; got != exp { t.Fatalf("key[0]=%v, expected %v", got, exp) } else if got, exp := node.value, 3; got != exp { t.Fatalf("value[0]=%v, expected %v", got, exp) } - if node, ok := other.nodes[1].(*mapValueNode); !ok { + if node, ok := other.nodes[1].(*mapValueNode[int, int]); !ok { t.Fatalf("node[1]=%T, unexpected type", other.nodes[1]) } else if got, exp := node.key, 4; got != exp { t.Fatalf("key[1]=%v, expected %v", got, exp) @@ -833,22 +835,22 @@ func TestInternal_mapValueNode(t *testing.T) { } // Ensure both values can be read. - if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v.(int) != 3 { + if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v != 3 { t.Fatalf("Get(2)=<%v,%v>", v, ok) - } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v.(int) != 5 { + } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v != 5 { t.Fatalf("Get(4)=<%v,%v>", v, ok) } }) // Inserting a node with the same mask index should nest an additional level of bitmap nodes. t.Run("Conflict", func(t *testing.T) { - h := &mockHasher{ - hash: func(value interface{}) uint32 { return uint32(value.(int)) << 5 }, - equal: func(a, b interface{}) bool { return a.(int) == b.(int) }, + h := &mockHasher[int]{ + hash: func(value int) uint32 { return uint32(value << 5) }, + equal: func(a, b int) bool { return a == b }, } var resized bool n := newMapValueNode(h.Hash(2), 2, 3) - other := n.set(4, 5, 0, h.Hash(4), h, false, &resized).(*mapBitmapIndexedNode) + other := n.set(4, 5, 0, h.Hash(4), h, false, &resized).(*mapBitmapIndexedNode[int, int]) if got, exp := other.bitmap, uint32(0x01); got != exp { // mask is zero, expect first slot. t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) } else if got, exp := len(other.nodes), 1; got != exp { @@ -856,19 +858,19 @@ func TestInternal_mapValueNode(t *testing.T) { } else if !resized { t.Fatal("expected resize") } - child, ok := other.nodes[0].(*mapBitmapIndexedNode) + child, ok := other.nodes[0].(*mapBitmapIndexedNode[int, int]) if !ok { t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) } - if node, ok := child.nodes[0].(*mapValueNode); !ok { + if node, ok := child.nodes[0].(*mapValueNode[int, int]); !ok { t.Fatalf("node[0]=%T, unexpected type", child.nodes[0]) } else if got, exp := node.key, 2; got != exp { t.Fatalf("key[0]=%v, expected %v", got, exp) } else if got, exp := node.value, 3; got != exp { t.Fatalf("value[0]=%v, expected %v", got, exp) } - if node, ok := child.nodes[1].(*mapValueNode); !ok { + if node, ok := child.nodes[1].(*mapValueNode[int, int]); !ok { t.Fatalf("node[1]=%T, unexpected type", child.nodes[1]) } else if got, exp := node.key, 4; got != exp { t.Fatalf("key[1]=%v, expected %v", got, exp) @@ -877,9 +879,9 @@ func TestInternal_mapValueNode(t *testing.T) { } // Ensure both values can be read. - if v, ok := other.get(2, 0, h.Hash(2), h); !ok || v.(int) != 3 { + if v, ok := other.get(2, 0, h.Hash(2), h); !ok || v != 3 { t.Fatalf("Get(2)=<%v,%v>", v, ok) - } else if v, ok := other.get(4, 0, h.Hash(4), h); !ok || v.(int) != 5 { + } else if v, ok := other.get(4, 0, h.Hash(4), h); !ok || v != 5 { t.Fatalf("Get(4)=<%v,%v>", v, ok) } else if v, ok := other.get(10, 0, h.Hash(10), h); ok { t.Fatalf("Get(10)=<%v,%v>, expected no value", v, ok) @@ -890,8 +892,8 @@ func TestInternal_mapValueNode(t *testing.T) { func TestMap_Get(t *testing.T) { t.Run("Empty", func(t *testing.T) { - m := NewMap(nil) - if v, ok := m.Get(100); ok || v != nil { + m := NewMap[int, string](nil) + if v, ok := m.Get(100); ok { t.Fatalf("unexpected value: <%v,%v>", v, ok) } }) @@ -899,17 +901,17 @@ func TestMap_Get(t *testing.T) { func TestMap_Set(t *testing.T) { t.Run("Simple", func(t *testing.T) { - m := NewMap(nil) + m := NewMap[int, string](nil) itr := m.Iterator() if !itr.Done() { t.Fatal("MapIterator.Done()=true, expected false") - } else if k, v := itr.Next(); k != nil || v != nil { + } else if k, v, ok := itr.Next(); ok { t.Fatalf("MapIterator.Next()=<%v,%v>, expected nil", k, v) } }) t.Run("Simple", func(t *testing.T) { - m := NewMap(nil) + m := NewMap[int, string](nil) m = m.Set(100, "foo") if v, ok := m.Get(100); !ok || v != "foo" { t.Fatalf("unexpected value: <%v,%v>", v, ok) @@ -918,7 +920,7 @@ func TestMap_Set(t *testing.T) { t.Run("VerySmall", func(t *testing.T) { const n = 6 - m := NewMap(nil) + m := NewMap[int, int](nil) for i := 0; i < n; i++ { m = m.Set(i, i+1) } @@ -931,7 +933,7 @@ func TestMap_Set(t *testing.T) { // NOTE: Array nodes store entries in insertion order. itr := m.Iterator() for i := 0; i < n; i++ { - if k, v := itr.Next(); k != i || v != i+1 { + if k, v, ok := itr.Next(); !ok || k != i || v != i+1 { t.Fatalf("MapIterator.Next()=<%v,%v>, exp <%v,%v>", k, v, i, i+1) } } @@ -942,7 +944,7 @@ func TestMap_Set(t *testing.T) { t.Run("Small", func(t *testing.T) { const n = 1000 - m := NewMap(nil) + m := NewMap[int, int](nil) for i := 0; i < n; i++ { m = m.Set(i, i+1) } @@ -959,7 +961,7 @@ func TestMap_Set(t *testing.T) { } const n = 1000000 - m := NewMap(nil) + m := NewMap[int, int](nil) for i := 0; i < n; i++ { m = m.Set(i, i+1) } @@ -971,7 +973,7 @@ func TestMap_Set(t *testing.T) { }) t.Run("StringKeys", func(t *testing.T) { - m := NewMap(nil) + m := NewMap[string, string](nil) m = m.Set("foo", "bar") m = m.Set("baz", "bat") m = m.Set("", "EMPTY") @@ -987,36 +989,6 @@ func TestMap_Set(t *testing.T) { } }) - t.Run("ByteSliceKeys", func(t *testing.T) { - m := NewMap(nil) - m = m.Set([]byte("foo"), "bar") - m = m.Set([]byte("baz"), "bat") - m = m.Set([]byte(""), "EMPTY") - if v, ok := m.Get([]byte("foo")); !ok || v != "bar" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } else if v, ok := m.Get([]byte("baz")); !ok || v != "bat" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } else if v, ok := m.Get([]byte("")); !ok || v != "EMPTY" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } - if v, ok := m.Get([]byte("no_such_key")); ok { - t.Fatalf("expected no value: <%v,%v>", v, ok) - } - }) - - t.Run("NoDefaultHasher", func(t *testing.T) { - type T struct{} - var r string - func() { - defer func() { r = recover().(string) }() - m := NewMap(nil) - m = m.Set(T{}, "bar") - }() - if r != `immutable.NewHasher: must set hasher for immutable.T type` { - t.Fatalf("unexpected panic: %q", r) - } - }) - RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { m := NewTestMap() for i := 0; i < 10000; i++ { @@ -1040,7 +1012,7 @@ func TestMap_Overwrite(t *testing.T) { } const n = 10000 - m := NewMap(nil) + m := NewMap[int, int](nil) for i := 0; i < n; i++ { // Set original value. m = m.Set(i, i) @@ -1061,7 +1033,7 @@ func TestMap_Overwrite(t *testing.T) { func TestMap_Delete(t *testing.T) { t.Run("Empty", func(t *testing.T) { - m := NewMap(nil) + m := NewMap[string, int](nil) other := m.Delete("foo") if m != other { t.Fatal("expected same map") @@ -1069,7 +1041,7 @@ func TestMap_Delete(t *testing.T) { }) t.Run("Simple", func(t *testing.T) { - m := NewMap(nil) + m := NewMap[int, string](nil) m = m.Set(100, "foo") if v, ok := m.Get(100); !ok || v != "foo" { t.Fatalf("unexpected value: <%v,%v>", v, ok) @@ -1078,7 +1050,7 @@ func TestMap_Delete(t *testing.T) { t.Run("Small", func(t *testing.T) { const n = 1000 - m := NewMap(nil) + m := NewMap[int, int](nil) for i := 0; i < n; i++ { m = m.Set(i, i+1) } @@ -1095,7 +1067,7 @@ func TestMap_Delete(t *testing.T) { t.Skip("skipping: short") } const n = 1000000 - m := NewMap(nil) + m := NewMap[int, int](nil) for i := 0; i < n; i++ { m = m.Set(i, i+1) } @@ -1142,11 +1114,11 @@ func TestMap_LimitedHash(t *testing.T) { } t.Run("Immutable", func(t *testing.T) { - h := mockHasher{ - hash: func(value interface{}) uint32 { return hashUint64(uint64(value.(int))) % 0xFF }, - equal: func(a, b interface{}) bool { return a.(int) == b.(int) }, + h := mockHasher[int]{ + hash: func(value int) uint32 { return hashUint64(uint64(value)) % 0xFF }, + equal: func(a, b int) bool { return a == b }, } - m := NewMap(&h) + m := NewMap[int, int](&h) rand := rand.New(rand.NewSource(0)) keys := rand.Perm(100000) @@ -1170,8 +1142,8 @@ func TestMap_LimitedHash(t *testing.T) { // Verify iteration. itr := m.Iterator() for !itr.Done() { - if k, v := itr.Next(); v != k.(int)*2 { - t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k.(int)*2) + if k, v, ok := itr.Next(); !ok || v != k*2 { + t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k*2) } } @@ -1195,11 +1167,11 @@ func TestMap_LimitedHash(t *testing.T) { }) t.Run("Builder", func(t *testing.T) { - h := mockHasher{ - hash: func(value interface{}) uint32 { return hashUint64(uint64(value.(int))) % 0xFF }, - equal: func(a, b interface{}) bool { return a.(int) == b.(int) }, + h := mockHasher[int]{ + hash: func(value int) uint32 { return hashUint64(uint64(value)) }, + equal: func(a, b int) bool { return a == b }, } - b := NewMapBuilder(&h) + b := NewMapBuilder[int, int](&h) rand := rand.New(rand.NewSource(0)) keys := rand.Perm(100000) @@ -1223,8 +1195,8 @@ func TestMap_LimitedHash(t *testing.T) { // Verify iteration. itr := b.Iterator() for !itr.Done() { - if k, v := itr.Next(); v != k.(int)*2 { - t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k.(int)*2) + if k, v, ok := itr.Next(); !ok || v != k*2 { + t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k*2) } } @@ -1245,16 +1217,16 @@ func TestMap_LimitedHash(t *testing.T) { // TMap represents a combined immutable and stdlib map. type TMap struct { - im, prev *Map - builder *MapBuilder + im, prev *Map[int, int] + builder *MapBuilder[int, int] std map[int]int keys []int } func NewTestMap() *TMap { return &TMap{ - im: NewMap(nil), - builder: NewMapBuilder(nil), + im: NewMap[int, int](nil), + builder: NewMapBuilder[int, int](nil), std: make(map[int]int), } } @@ -1322,11 +1294,11 @@ func (m *TMap) Validate() error { return nil } -func (m *TMap) validateIterator(itr *MapIterator) error { +func (m *TMap) validateIterator(itr *MapIterator[int, int]) error { other := make(map[int]int) for !itr.Done() { - k, v := itr.Next() - other[k.(int)] = v.(int) + k, v, _ := itr.Next() + other[k] = v } if len(other) != len(m.std) { return fmt.Errorf("map iterator size mismatch: %v!=%v", len(m.std), len(other)) @@ -1336,7 +1308,7 @@ func (m *TMap) validateIterator(itr *MapIterator) error { return fmt.Errorf("map iterator mismatch: key=%v, %v!=%v", k, v, other[k]) } } - if k, v := itr.Next(); k != nil || v != nil { + if k, v, ok := itr.Next(); ok { return fmt.Errorf("map iterator returned key/value after done: <%v/%v>", k, v) } return nil @@ -1367,7 +1339,7 @@ func BenchmarkBuiltinMap_Delete(b *testing.B) { func BenchmarkMap_Set(b *testing.B) { b.ReportAllocs() - m := NewMap(nil) + m := NewMap[int, int](nil) for i := 0; i < b.N; i++ { m = m.Set(i, i) } @@ -1376,7 +1348,7 @@ func BenchmarkMap_Set(b *testing.B) { func BenchmarkMap_Delete(b *testing.B) { const n = 10000000 - builder := NewMapBuilder(nil) + builder := NewMapBuilder[int, int](nil) for i := 0; i < n; i++ { builder.Set(i, i) } @@ -1391,7 +1363,7 @@ func BenchmarkMap_Delete(b *testing.B) { func BenchmarkMap_Iterator(b *testing.B) { const n = 10000 - m := NewMap(nil) + m := NewMap[int, int](nil) for i := 0; i < 10000; i++ { m = m.Set(i, i) } @@ -1411,7 +1383,7 @@ func BenchmarkMap_Iterator(b *testing.B) { func BenchmarkMapBuilder_Set(b *testing.B) { b.ReportAllocs() - builder := NewMapBuilder(nil) + builder := NewMapBuilder[int, int](nil) for i := 0; i < b.N; i++ { builder.Set(i, i) } @@ -1420,7 +1392,7 @@ func BenchmarkMapBuilder_Set(b *testing.B) { func BenchmarkMapBuilder_Delete(b *testing.B) { const n = 10000000 - builder := NewMapBuilder(nil) + builder := NewMapBuilder[int, int](nil) for i := 0; i < n; i++ { builder.Set(i, i) } @@ -1433,7 +1405,7 @@ func BenchmarkMapBuilder_Delete(b *testing.B) { } func ExampleMap_Set() { - m := NewMap(nil) + m := NewMap[string, any](nil) m = m.Set("foo", "bar") m = m.Set("baz", 100) @@ -1452,7 +1424,7 @@ func ExampleMap_Set() { } func ExampleMap_Delete() { - m := NewMap(nil) + m := NewMap[string, any](nil) m = m.Set("foo", "bar") m = m.Set("baz", 100) m = m.Delete("baz") @@ -1468,7 +1440,7 @@ func ExampleMap_Delete() { } func ExampleMap_Iterator() { - m := NewMap(nil) + m := NewMap[string, int](nil) m = m.Set("apple", 100) m = m.Set("grape", 200) m = m.Set("kiwi", 300) @@ -1481,7 +1453,7 @@ func ExampleMap_Iterator() { itr := m.Iterator() for !itr.Done() { - k, v := itr.Next() + k, v, _ := itr.Next() fmt.Println(k, v) } // Output: @@ -1497,7 +1469,7 @@ func ExampleMap_Iterator() { } func ExampleMapBuilder_Set() { - b := NewMapBuilder(nil) + b := NewMapBuilder[string, any](nil) b.Set("foo", "bar") b.Set("baz", 100) @@ -1517,7 +1489,7 @@ func ExampleMapBuilder_Set() { } func ExampleMapBuilder_Delete() { - b := NewMapBuilder(nil) + b := NewMapBuilder[string, any](nil) b.Set("foo", "bar") b.Set("baz", 100) b.Delete("baz") @@ -1535,12 +1507,12 @@ func ExampleMapBuilder_Delete() { func TestInternalSortedMapLeafNode(t *testing.T) { RunRandom(t, "NoSplit", func(t *testing.T, rand *rand.Rand) { - var cmpr intComparer - var node sortedMapNode = &sortedMapLeafNode{} + var cmpr defaultComparer[int] + var node sortedMapNode[int, int] = &sortedMapLeafNode[int, int]{} var keys []int for _, i := range rand.Perm(32) { var resized bool - var splitNode sortedMapNode + var splitNode sortedMapNode[int, int] node, splitNode = node.set(i, i*10, &cmpr, false, &resized) if !resized { t.Fatal("expected resize") @@ -1570,8 +1542,9 @@ func TestInternalSortedMapLeafNode(t *testing.T) { }) RunRandom(t, "Overwrite", func(t *testing.T, rand *rand.Rand) { - var cmpr intComparer - var node sortedMapNode = &sortedMapLeafNode{} + var cmpr defaultComparer[int] + var node sortedMapNode[int, int] = &sortedMapLeafNode[int, int]{} + for _, i := range rand.Perm(32) { var resized bool node, _ = node.set(i, i*2, &cmpr, false, &resized) @@ -1593,9 +1566,9 @@ func TestInternalSortedMapLeafNode(t *testing.T) { }) t.Run("Split", func(t *testing.T) { - // Fill leaf node. - var cmpr intComparer - var node sortedMapNode = &sortedMapLeafNode{} + // Fill leaf node. var cmpr defaultComparer[int] + var cmpr defaultComparer[int] + var node sortedMapNode[int, int] = &sortedMapLeafNode[int, int]{} for i := 0; i < 32; i++ { var resized bool node, _ = node.set(i, i*10, &cmpr, false, &resized) @@ -1606,7 +1579,7 @@ func TestInternalSortedMapLeafNode(t *testing.T) { newNode, splitNode := node.set(32, 320, &cmpr, false, &resized) // Verify node contents. - newLeafNode, ok := newNode.(*sortedMapLeafNode) + newLeafNode, ok := newNode.(*sortedMapLeafNode[int, int]) if !ok { t.Fatalf("unexpected node type: %T", newLeafNode) } else if n := len(newLeafNode.entries); n != 16 { @@ -1619,7 +1592,7 @@ func TestInternalSortedMapLeafNode(t *testing.T) { } // Verify split node contents. - splitLeafNode, ok := splitNode.(*sortedMapLeafNode) + splitLeafNode, ok := splitNode.(*sortedMapLeafNode[int, int]) if !ok { t.Fatalf("unexpected split node type: %T", splitLeafNode) } else if n := len(splitLeafNode.entries); n != 17 { @@ -1643,17 +1616,17 @@ func TestInternalSortedMapBranchNode(t *testing.T) { sort.Ints(keys[:2]) // ensure first two keys are sorted for initial insert. // Initialize branch with two leafs. - var cmpr intComparer - leaf0 := &sortedMapLeafNode{entries: []mapEntry{{key: keys[0], value: keys[0] * 10}}} - leaf1 := &sortedMapLeafNode{entries: []mapEntry{{key: keys[1], value: keys[1] * 10}}} - var node sortedMapNode = newSortedMapBranchNode(leaf0, leaf1) + var cmpr defaultComparer[int] + leaf0 := &sortedMapLeafNode[int, int]{entries: []mapEntry[int, int]{{key: keys[0], value: keys[0] * 10}}} + leaf1 := &sortedMapLeafNode[int, int]{entries: []mapEntry[int, int]{{key: keys[1], value: keys[1] * 10}}} + var node sortedMapNode[int, int] = newSortedMapBranchNode[int, int](leaf0, leaf1) sort.Ints(keys) for _, i := range rand.Perm(len(keys)) { key := keys[i] var resized bool - var splitNode sortedMapNode + var splitNode sortedMapNode[int, int] node, splitNode = node.set(key, key*10, &cmpr, false, &resized) if key == leaf0.entries[0].key || key == leaf1.entries[0].key { if resized { @@ -1684,16 +1657,16 @@ func TestInternalSortedMapBranchNode(t *testing.T) { t.Run("Split", func(t *testing.T) { // Generate leaf nodes. - var cmpr intComparer - children := make([]sortedMapNode, 32) + var cmpr defaultComparer[int] + children := make([]sortedMapNode[int, int], 32) for i := range children { - leaf := &sortedMapLeafNode{entries: make([]mapEntry, 32)} + leaf := &sortedMapLeafNode[int, int]{entries: make([]mapEntry[int, int], 32)} for j := range leaf.entries { - leaf.entries[j] = mapEntry{key: (i * 32) + j, value: ((i * 32) + j) * 100} + leaf.entries[j] = mapEntry[int, int]{key: (i * 32) + j, value: ((i * 32) + j) * 100} } children[i] = leaf } - var node sortedMapNode = newSortedMapBranchNode(children...) + var node sortedMapNode[int, int] = newSortedMapBranchNode(children...) // Add one more and expect split. var resized bool @@ -1701,14 +1674,14 @@ func TestInternalSortedMapBranchNode(t *testing.T) { // Verify node contents. var idx int - newBranchNode, ok := newNode.(*sortedMapBranchNode) + newBranchNode, ok := newNode.(*sortedMapBranchNode[int, int]) if !ok { t.Fatalf("unexpected node type: %T", newBranchNode) } else if n := len(newBranchNode.elems); n != 16 { t.Fatalf("unexpected child elems len: %d", n) } for i, elem := range newBranchNode.elems { - child, ok := elem.node.(*sortedMapLeafNode) + child, ok := elem.node.(*sortedMapLeafNode[int, int]) if !ok { t.Fatalf("unexpected child type") } @@ -1721,14 +1694,14 @@ func TestInternalSortedMapBranchNode(t *testing.T) { } // Verify split node contents. - splitBranchNode, ok := splitNode.(*sortedMapBranchNode) + splitBranchNode, ok := splitNode.(*sortedMapBranchNode[int, int]) if !ok { t.Fatalf("unexpected split node type: %T", splitBranchNode) } else if n := len(splitBranchNode.elems); n != 17 { t.Fatalf("unexpected split node elem len: %d", n) } for i, elem := range splitBranchNode.elems { - child, ok := elem.node.(*sortedMapLeafNode) + child, ok := elem.node.(*sortedMapLeafNode[int, int]) if !ok { t.Fatalf("unexpected split node child type") } @@ -1744,8 +1717,8 @@ func TestInternalSortedMapBranchNode(t *testing.T) { func TestSortedMap_Get(t *testing.T) { t.Run("Empty", func(t *testing.T) { - m := NewSortedMap(nil) - if v, ok := m.Get(100); ok || v != nil { + m := NewSortedMap[int, int](nil) + if v, ok := m.Get(100); ok { t.Fatalf("unexpected value: <%v,%v>", v, ok) } }) @@ -1753,7 +1726,7 @@ func TestSortedMap_Get(t *testing.T) { func TestSortedMap_Set(t *testing.T) { t.Run("Simple", func(t *testing.T) { - m := NewSortedMap(nil) + m := NewSortedMap[int, string](nil) m = m.Set(100, "foo") if v, ok := m.Get(100); !ok || v != "foo" { t.Fatalf("unexpected value: <%v,%v>", v, ok) @@ -1764,7 +1737,7 @@ func TestSortedMap_Set(t *testing.T) { t.Run("Small", func(t *testing.T) { const n = 1000 - m := NewSortedMap(nil) + m := NewSortedMap[int, int](nil) for i := 0; i < n; i++ { m = m.Set(i, i+1) } @@ -1781,7 +1754,7 @@ func TestSortedMap_Set(t *testing.T) { } const n = 1000000 - m := NewSortedMap(nil) + m := NewSortedMap[int, int](nil) for i := 0; i < n; i++ { m = m.Set(i, i+1) } @@ -1793,7 +1766,7 @@ func TestSortedMap_Set(t *testing.T) { }) t.Run("StringKeys", func(t *testing.T) { - m := NewSortedMap(nil) + m := NewSortedMap[string, string](nil) m = m.Set("foo", "bar") m = m.Set("baz", "bat") m = m.Set("", "EMPTY") @@ -1809,28 +1782,11 @@ func TestSortedMap_Set(t *testing.T) { } }) - t.Run("ByteSliceKeys", func(t *testing.T) { - m := NewSortedMap(nil) - m = m.Set([]byte("foo"), "bar") - m = m.Set([]byte("baz"), "bat") - m = m.Set([]byte(""), "EMPTY") - if v, ok := m.Get([]byte("foo")); !ok || v != "bar" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } else if v, ok := m.Get([]byte("baz")); !ok || v != "bat" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } else if v, ok := m.Get([]byte("")); !ok || v != "EMPTY" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } - if v, ok := m.Get([]byte("no_such_key")); ok { - t.Fatalf("expected no value: <%v,%v>", v, ok) - } - }) - t.Run("NoDefaultComparer", func(t *testing.T) { var r string func() { defer func() { r = recover().(string) }() - m := NewSortedMap(nil) + m := NewSortedMap[float64, string](nil) m = m.Set(float64(100), "bar") }() if r != `immutable.NewComparer: must set comparer for float64 type` { @@ -1857,7 +1813,7 @@ func TestSortedMap_Set(t *testing.T) { // Ensure map can support overwrites as it expands. func TestSortedMap_Overwrite(t *testing.T) { const n = 1000 - m := NewSortedMap(nil) + m := NewSortedMap[int, int](nil) for i := 0; i < n; i++ { // Set original value. m = m.Set(i, i) @@ -1878,7 +1834,7 @@ func TestSortedMap_Overwrite(t *testing.T) { func TestSortedMap_Delete(t *testing.T) { t.Run("Empty", func(t *testing.T) { - m := NewSortedMap(nil) + m := NewSortedMap[int, int](nil) m = m.Delete(100) if n := m.Len(); n != 0 { t.Fatalf("SortedMap.Len()=%d, expected 0", n) @@ -1886,7 +1842,7 @@ func TestSortedMap_Delete(t *testing.T) { }) t.Run("Simple", func(t *testing.T) { - m := NewSortedMap(nil) + m := NewSortedMap[int, string](nil) m = m.Set(100, "foo") if v, ok := m.Get(100); !ok || v != "foo" { t.Fatalf("unexpected value: <%v,%v>", v, ok) @@ -1899,7 +1855,7 @@ func TestSortedMap_Delete(t *testing.T) { t.Run("Small", func(t *testing.T) { const n = 1000 - m := NewSortedMap(nil) + m := NewSortedMap[int, int](nil) for i := 0; i < n; i++ { m = m.Set(i, i+1) } @@ -1925,7 +1881,7 @@ func TestSortedMap_Delete(t *testing.T) { } const n = 1000000 - m := NewSortedMap(nil) + m := NewSortedMap[int, int](nil) for i := 0; i < n; i++ { m = m.Set(i, i+1) } @@ -1978,25 +1934,25 @@ func TestSortedMap_Delete(t *testing.T) { func TestSortedMap_Iterator(t *testing.T) { t.Run("Empty", func(t *testing.T) { t.Run("First", func(t *testing.T) { - itr := NewSortedMap(nil).Iterator() + itr := NewSortedMap[int, int](nil).Iterator() itr.First() - if k, v := itr.Next(); k != nil || v != nil { + if k, v, ok := itr.Next(); ok { t.Fatalf("SortedMapIterator.Next()=<%v,%v>, expected nil", k, v) } }) t.Run("Last", func(t *testing.T) { - itr := NewSortedMap(nil).Iterator() + itr := NewSortedMap[int, int](nil).Iterator() itr.Last() - if k, v := itr.Prev(); k != nil || v != nil { + if k, v, ok := itr.Prev(); ok { t.Fatalf("SortedMapIterator.Prev()=<%v,%v>, expected nil", k, v) } }) t.Run("Seek", func(t *testing.T) { - itr := NewSortedMap(nil).Iterator() + itr := NewSortedMap[string, int](nil).Iterator() itr.Seek("foo") - if k, v := itr.Next(); k != nil || v != nil { + if k, v, ok := itr.Next(); ok { t.Fatalf("SortedMapIterator.Next()=<%v,%v>, expected nil", k, v) } }) @@ -2004,7 +1960,7 @@ func TestSortedMap_Iterator(t *testing.T) { t.Run("Seek", func(t *testing.T) { const n = 100 - m := NewSortedMap(nil) + m := NewSortedMap[string, int](nil) for i := 0; i < n; i += 2 { m = m.Set(fmt.Sprintf("%04d", i), i) } @@ -2014,7 +1970,7 @@ func TestSortedMap_Iterator(t *testing.T) { for i := 0; i < n; i += 2 { itr.Seek(fmt.Sprintf("%04d", i)) for j := i; j < n; j += 2 { - if k, _ := itr.Next(); k != fmt.Sprintf("%04d", j) { + if k, _, ok := itr.Next(); !ok || k != fmt.Sprintf("%04d", j) { t.Fatalf("%d/%d. SortedMapIterator.Next()=%v, expected key %04d", i, j, k, j) } } @@ -2029,7 +1985,7 @@ func TestSortedMap_Iterator(t *testing.T) { for i := 1; i < n-2; i += 2 { itr.Seek(fmt.Sprintf("%04d", i)) for j := i + 1; j < n; j += 2 { - if k, _ := itr.Next(); k != fmt.Sprintf("%04d", j) { + if k, _, ok := itr.Next(); !ok || k != fmt.Sprintf("%04d", j) { t.Fatalf("%d/%d. SortedMapIterator.Next()=%v, expected key %04d", i, j, k, j) } } @@ -2043,7 +1999,7 @@ func TestSortedMap_Iterator(t *testing.T) { itr := m.Iterator() itr.Seek("") for i := 0; i < n; i += 2 { - if k, _ := itr.Next(); k != fmt.Sprintf("%04d", i) { + if k, _, ok := itr.Next(); !ok || k != fmt.Sprintf("%04d", i) { t.Fatalf("%d. SortedMapIterator.Next()=%v, expected key %04d", i, k, i) } } @@ -2054,7 +2010,7 @@ func TestSortedMap_Iterator(t *testing.T) { t.Run("AfterLast", func(t *testing.T) { itr := m.Iterator() itr.Seek("1000") - if k, _ := itr.Next(); k != nil { + if k, _, ok := itr.Next(); ok { t.Fatalf("0. SortedMapIterator.Next()=%v, expected nil key", k) } else if !itr.Done() { t.Fatalf("SortedMapIterator.Done()=true, expected false") @@ -2078,7 +2034,7 @@ func TestNewHasher(t *testing.T) { t.Run("uint64", func(t *testing.T) { testNewHasher(t, uint64(100)) }) t.Run("string", func(t *testing.T) { testNewHasher(t, "foo") }) - t.Run("byteSlice", func(t *testing.T) { testNewHasher(t, []byte("foo")) }) + //t.Run("byteSlice", func(t *testing.T) { testNewHasher(t, []byte("foo")) }) }) t.Run("reflection", func(t *testing.T) { @@ -2093,7 +2049,7 @@ func TestNewHasher(t *testing.T) { }) } -func testNewHasher(t *testing.T, v interface{}) { +func testNewHasher[V constraints.Ordered](t *testing.T, v V) { t.Helper() h := NewHasher(v) h.Hash(v) @@ -2117,7 +2073,7 @@ func TestNewComparer(t *testing.T) { t.Run("uint64", func(t *testing.T) { testNewComparer(t, uint64(100), uint64(101)) }) t.Run("string", func(t *testing.T) { testNewComparer(t, "bar", "foo") }) - t.Run("byteSlice", func(t *testing.T) { testNewComparer(t, []byte("bar"), []byte("foo")) }) + //t.Run("byteSlice", func(t *testing.T) { testNewComparer(t, []byte("bar"), []byte("foo")) }) }) t.Run("reflection", func(t *testing.T) { @@ -2132,7 +2088,7 @@ func TestNewComparer(t *testing.T) { }) } -func testNewComparer(t *testing.T, x, y interface{}) { +func testNewComparer[T constraints.Ordered](t *testing.T, x, y T) { t.Helper() c := NewComparer(x) if c.Compare(x, y) != -1 { @@ -2146,16 +2102,16 @@ func testNewComparer(t *testing.T, x, y interface{}) { // TSortedMap represents a combined immutable and stdlib sorted map. type TSortedMap struct { - im, prev *SortedMap - builder *SortedMapBuilder + im, prev *SortedMap[int, int] + builder *SortedMapBuilder[int, int] std map[int]int keys []int } func NewTSortedMap() *TSortedMap { return &TSortedMap{ - im: NewSortedMap(nil), - builder: NewSortedMapBuilder(nil), + im: NewSortedMap[int, int](nil), + builder: NewSortedMapBuilder[int, int](nil), std: make(map[int]int), } } @@ -2235,10 +2191,10 @@ func (m *TSortedMap) Validate() error { return nil } -func (m *TSortedMap) validateForwardIterator(itr *SortedMapIterator) error { +func (m *TSortedMap) validateForwardIterator(itr *SortedMapIterator[int, int]) error { for i, k0 := range m.keys { v0 := m.std[k0] - if k1, v1 := itr.Next(); k0 != k1 || v0 != v1 { + if k1, v1, ok := itr.Next(); !ok || k0 != k1 || v0 != v1 { return fmt.Errorf("%d. SortedMapIterator.Next()=<%v,%v>, expected <%v,%v>", i, k1, v1, k0, v0) } @@ -2247,18 +2203,18 @@ func (m *TSortedMap) validateForwardIterator(itr *SortedMapIterator) error { return fmt.Errorf("%d. SortedMapIterator.Done()=%v, expected %v", i, v, done) } } - if k, v := itr.Next(); k != nil || v != nil { + if k, v, ok := itr.Next(); ok { return fmt.Errorf("SortedMapIterator.Next()=<%v,%v>, expected nil after done", k, v) } return nil } -func (m *TSortedMap) validateBackwardIterator(itr *SortedMapIterator) error { +func (m *TSortedMap) validateBackwardIterator(itr *SortedMapIterator[int, int]) error { itr.Last() for i := len(m.keys) - 1; i >= 0; i-- { k0 := m.keys[i] v0 := m.std[k0] - if k1, v1 := itr.Prev(); k0 != k1 || v0 != v1 { + if k1, v1, ok := itr.Prev(); !ok || k0 != k1 || v0 != v1 { return fmt.Errorf("%d. SortedMapIterator.Prev()=<%v,%v>, expected <%v,%v>", i, k1, v1, k0, v0) } @@ -2267,7 +2223,7 @@ func (m *TSortedMap) validateBackwardIterator(itr *SortedMapIterator) error { return fmt.Errorf("%d. SortedMapIterator.Done()=%v, expected %v", i, v, done) } } - if k, v := itr.Prev(); k != nil || v != nil { + if k, v, ok := itr.Prev(); ok { return fmt.Errorf("SortedMapIterator.Prev()=<%v,%v>, expected nil after done", k, v) } return nil @@ -2275,7 +2231,7 @@ func (m *TSortedMap) validateBackwardIterator(itr *SortedMapIterator) error { func BenchmarkSortedMap_Set(b *testing.B) { b.ReportAllocs() - m := NewSortedMap(nil) + m := NewSortedMap[int, int](nil) for i := 0; i < b.N; i++ { m = m.Set(i, i) } @@ -2284,7 +2240,7 @@ func BenchmarkSortedMap_Set(b *testing.B) { func BenchmarkSortedMap_Delete(b *testing.B) { const n = 10000 - m := NewSortedMap(nil) + m := NewSortedMap[int, int](nil) for i := 0; i < n; i++ { m = m.Set(i, i) } @@ -2298,7 +2254,7 @@ func BenchmarkSortedMap_Delete(b *testing.B) { func BenchmarkSortedMap_Iterator(b *testing.B) { const n = 10000 - m := NewSortedMap(nil) + m := NewSortedMap[int, int](nil) for i := 0; i < 10000; i++ { m = m.Set(i, i) } @@ -2328,7 +2284,7 @@ func BenchmarkSortedMap_Iterator(b *testing.B) { func BenchmarkSortedMapBuilder_Set(b *testing.B) { b.ReportAllocs() - builder := NewSortedMapBuilder(nil) + builder := NewSortedMapBuilder[int, int](nil) for i := 0; i < b.N; i++ { builder.Set(i, i) } @@ -2337,7 +2293,7 @@ func BenchmarkSortedMapBuilder_Set(b *testing.B) { func BenchmarkSortedMapBuilder_Delete(b *testing.B) { const n = 1000000 - builder := NewSortedMapBuilder(nil) + builder := NewSortedMapBuilder[int, int](nil) for i := 0; i < n; i++ { builder.Set(i, i) } @@ -2350,7 +2306,7 @@ func BenchmarkSortedMapBuilder_Delete(b *testing.B) { } func ExampleSortedMap_Set() { - m := NewSortedMap(nil) + m := NewSortedMap[string, any](nil) m = m.Set("foo", "bar") m = m.Set("baz", 100) @@ -2369,7 +2325,7 @@ func ExampleSortedMap_Set() { } func ExampleSortedMap_Delete() { - m := NewSortedMap(nil) + m := NewSortedMap[string, any](nil) m = m.Set("foo", "bar") m = m.Set("baz", 100) m = m.Delete("baz") @@ -2385,7 +2341,7 @@ func ExampleSortedMap_Delete() { } func ExampleSortedMap_Iterator() { - m := NewSortedMap(nil) + m := NewSortedMap[string, any](nil) m = m.Set("strawberry", 900) m = m.Set("kiwi", 300) m = m.Set("apple", 100) @@ -2398,7 +2354,7 @@ func ExampleSortedMap_Iterator() { itr := m.Iterator() for !itr.Done() { - k, v := itr.Next() + k, v, _ := itr.Next() fmt.Println(k, v) } // Output: @@ -2414,7 +2370,7 @@ func ExampleSortedMap_Iterator() { } func ExampleSortedMapBuilder_Set() { - b := NewSortedMapBuilder(nil) + b := NewSortedMapBuilder[string, any](nil) b.Set("foo", "bar") b.Set("baz", 100) @@ -2434,7 +2390,7 @@ func ExampleSortedMapBuilder_Set() { } func ExampleSortedMapBuilder_Delete() { - b := NewSortedMapBuilder(nil) + b := NewSortedMapBuilder[string, any](nil) b.Set("foo", "bar") b.Set("baz", 100) b.Delete("baz") @@ -2479,27 +2435,27 @@ func uniqueIntSlice(a []int) []int { } // mockHasher represents a mock implementation of immutable.Hasher. -type mockHasher struct { - hash func(value interface{}) uint32 - equal func(a, b interface{}) bool +type mockHasher[K constraints.Ordered] struct { + hash func(value K) uint32 + equal func(a, b K) bool } // Hash executes the mocked HashFn function. -func (h *mockHasher) Hash(value interface{}) uint32 { +func (h *mockHasher[K]) Hash(value K) uint32 { return h.hash(value) } // Equal executes the mocked EqualFn function. -func (h *mockHasher) Equal(a, b interface{}) bool { +func (h *mockHasher[K]) Equal(a, b K) bool { return h.equal(a, b) } // mockComparer represents a mock implementation of immutable.Comparer. -type mockComparer struct { - compare func(a, b interface{}) int +type mockComparer[K constraints.Ordered] struct { + compare func(a, b K) int } // Compare executes the mocked CompreFn function. -func (h *mockComparer) Compare(a, b interface{}) int { +func (h *mockComparer[K]) Compare(a, b K) int { return h.compare(a, b) } |