aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Johnson <benbjohnson@yahoo.com>2019-03-01 13:44:37 -0700
committerBen Johnson <benbjohnson@yahoo.com>2019-03-01 14:07:27 -0700
commit020c5a4c470f65e2968cc0e73c972beccc50b39f (patch)
treed3e077a39fa30328c90a69fe399165bf15c80294
downloadpds-020c5a4c470f65e2968cc0e73c972beccc50b39f.tar.gz
pds-020c5a4c470f65e2968cc0e73c972beccc50b39f.tar.xz
-rw-r--r--LICENSE19
-rw-r--r--README.md257
-rw-r--r--go.mod5
-rw-r--r--go.sum2
-rw-r--r--immutable.go1864
-rw-r--r--immutable_test.go2038
6 files changed, 4185 insertions, 0 deletions
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..428ce7d
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,19 @@
+Copyright 2019 Ben Johnson
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..71c33d7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,257 @@
+Immutable ![license](https://img.shields.io/github/license/benbjohnson/immutable.svg?style=flat-square) ![release](https://img.shields.io/github/release/benbjohnson/immutable.svg?style=flat-square)
+=========
+
+This repository contains immutable collection types for Go. It includes
+`List`, `Map`, and `SortedMap` implementations. Immutable collections can
+provide efficient, lock free sharing of data by requiring that edits to the
+collections return new collections.
+
+The collection types in this library are meant to mimic Go built-in collections
+such as`slice` and `map`. The primary usage difference between Go collections
+and `immutable` collections is that `immutable` collections always return a new
+collection on mutation so you will need to save the new reference.
+
+Immutable collections are not for every situation, however, as they can incur
+additional CPU and memory overhead. Please evaluate the cost/benefit for your
+particular project.
+
+Special thanks to the [Immutable.js](https://immutable-js.github.io/immutable-js/)
+team as the `List` & `Map` implementations are loose ports from that project.
+
+
+## List
+
+The `List` type represents a sorted, indexed collection of values and operates
+similarly to a Go slice. It supports efficient append, prepend, update, and
+slice operations.
+
+
+### Adding list elements
+
+Elements can be added to the end of the list with the `Append()` method or added
+to the beginning of the list with the `Prepend()` method. Unlike Go slices,
+prepending is as efficient as appending.
+
+```go
+// Create a list with 3 elements.
+l := immutable.NewList()
+l = l.Append("foo")
+l = l.Append("bar")
+l = l.Prepend("baz")
+
+fmt.Println(l.Len()) // 3
+fmt.Println(l.Get(0)) // "baz"
+fmt.Println(l.Get(1)) // "foo"
+fmt.Println(l.Get(2)) // "bar"
+```
+
+Note that each change to the list results in a new list being created. These
+lists are all snapshots at that point in time and cannot be changed so they
+are safe to share between multiple goroutines.
+
+### Updating list elements
+
+You can also overwrite existing elements by using the `Set()` method. In the
+following example, we'll update the third element in our list and return the
+new list to a new variable. You can see that our old `l` variable retains a
+snapshot of the original value.
+
+```go
+l := immutable.NewList()
+l = l.Append("foo")
+l = l.Append("bar")
+newList := l.Set(2, "baz")
+
+fmt.Println(l.Get(1)) // "bar"
+fmt.Println(newList.Get(1)) // "baz"
+```
+
+### Deriving sublists
+
+You can create a sublist by using the `Slice()` method. This method works with
+the same rules as subslicing a Go slice:
+
+```go
+l = l.Slice(0, 2)
+
+fmt.Println(l.Len()) // 2
+fmt.Println(l.Get(0)) // "baz"
+fmt.Println(l.Get(1)) // "foo"
+```
+
+Please note that since `List` follows the same rules as slices, it will panic if
+you try to `Get()`, `Set()`, or `Slice()` with indexes that are outside of
+the range of the `List`.
+
+
+
+### Iterating lists
+
+Iterators provide a clean, simple way to iterate over the elements of the list
+in order. This is more efficient than simply calling `Get()` for each index.
+
+Below is an example of iterating over all elements of our list from above:
+
+```go
+itr := l.Iterator()
+for !itr.Done() {
+ index, value := itr.Next()
+ fmt.Printf("Index %d equals %v\n", index, value)
+}
+
+// Index 0 equals baz
+// Index 1 equals foo
+```
+
+By default iterators start from index zero, however, the `Seek()` method can be
+used to jump to a given index.
+
+
+
+## Map
+
+The `Map` represents an associative array that maps unique keys to values. It
+is implemented to act similarly to the built-in Go `map` type. It is implemented
+as a [Hash-Array Mapped Trie](https://lampwww.epfl.ch/papers/idealhashtrees.pdf).
+
+Maps require a `Hasher` to hash keys and check for equality. There are built-in
+hasher implementations for `int`, `string`, and `[]byte` keys. You may pass in
+a `nil` hasher to `NewMap()` if you are using one of these key types.
+
+
+### Setting map key/value pairs
+
+You can add a key/value pair to the map by using the `Set()` method. It will
+add the key if it does not exist or it will overwrite the value for the key if
+it does exist.
+
+Values may be fetched for a key using the `Get()` method. This method returns
+the value as well as a flag indicating if the key existed. The flag is useful
+to check if a `nil` value was set for a key versus a key did not exist.
+
+```go
+m := immutable.NewMap(nil)
+m = m.Set("jane", 100)
+m = m.Set("susy", 200)
+m = m.Set("jane", 300) // overwrite
+
+fmt.Println(m.Len()) // 2
+
+v, ok := m.Get("jane")
+fmt.Println(v, ok) // 300 true
+
+v, ok = m.Get("susy")
+fmt.Println(v, ok) // 200, true
+
+v, ok = m.Get("john")
+fmt.Println(v, ok) // nil, false
+```
+
+
+### Removing map keys
+
+Keys may be removed from the map by using the `Delete()` method. If the key does
+not exist then the original map is returned instead of a new one.
+
+```go
+m := immutable.NewMap(nil)
+m = m.Set("jane", 100)
+m = m.Delete("jane")
+
+fmt.Println(m.Len()) // 0
+
+v, ok := m.Get("jane")
+fmt.Println(v, ok) // nil false
+```
+
+
+### Iterating maps
+
+Maps are unsorted, however, iterators can be used to loop over all key/value
+pairs in the collection. Unlike Go maps, iterators are deterministic when
+iterating over key/value pairs.
+
+```go
+m := immutable.NewMap(nil)
+m = m.Set("jane", 100)
+m = m.Set("susy", 200)
+
+itr := m.Iterator()
+for !itr.Done() {
+ k, v := itr.Next()
+ fmt.Println(k, v)
+}
+
+// susy 200
+// jane 100
+```
+
+Note that you should not rely on two maps with the same key/value pairs to
+iterate in the same order. Ordering can be insertion order dependent when two
+keys generate the same hash.
+
+
+### Implementing a custom Hasher
+
+If you need to use a key type besides `int`, `string`, or `[]byte` then you'll
+need to create a custom `Hasher` implementation and pass it to `NewMap()` on
+creation.
+
+Hashers are fairly simple. They only need to generate hashes for a given key
+and check equality given two keys.
+
+```go
+type Hasher interface {
+ Hash(key interface{}) uint32
+ Equal(a, b interface{}) bool
+}
+```
+
+Please see the `IntHasher`, `StringHasher`, or `ByteSliceHasher` for examples.
+
+
+## Sorted Map
+
+The `SortedMap` represents an associative array that maps unique keys to values.
+Unlike the `Map`, however, keys can be iterated over in-order. It is implemented
+as a B+tree.
+
+Sorted maps require a `Comparer` to sort keys and check for equality. There are
+built-in comparer implementations for `int`, `string`, and `[]byte` keys. You
+may pass a `nil` comparer to `NewSortedMap()` if you are using one of these key
+types.
+
+The API is identical to the `Map` implementation.
+
+
+### Implementing a custom Comparer
+
+If you need to use a key type besides `int`, `string`, or `[]byte` then you'll
+need to create a custom `Comparer` implementation and pass it to
+`NewSortedMap()` on creation.
+
+Comparers on have one method—`Compare()`. It works the same as the
+`strings.Compare()` function. It returns `-1` if `a` is less than `b`, returns
+`1` if a is greater than `b`, and returns `0` if `a` is equal to `b`.
+
+```go
+type Comparer interface {
+ Compare(a, b interface{}) int
+}
+```
+
+Please see the `IntComparer`, `StringComparer`, or `ByteSliceComparer` for examples.
+
+
+
+## Contributing
+
+The goal of `immutable` is to provide stable, reasonably performant, immutable
+collections library for Go that has a simple, idiomatic API. As such, additional
+features and minor performance improvements will generally not be accepted. If
+you have a suggestion for a clearer API or substantial performance improvement,
+_please_ open an issue first to discuss. All pull requests without a related
+issue will be closed immediately.
+
+Please submit issues relating to bugs & documentation improvements.
+
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..9fa105c
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module github.com/benbjohnson/immutable
+
+go 1.12
+
+require github.com/google/go-cmp v0.2.0
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..5f4f636
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
+github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
diff --git a/immutable.go b/immutable.go
new file mode 100644
index 0000000..6f5e6f6
--- /dev/null
+++ b/immutable.go
@@ -0,0 +1,1864 @@
+// Package immutable provides immutable collection types.
+//
+// Introduction
+//
+// Immutable collections provide an efficient, safe way to share collections
+// of data while minimizing locks. The collections in this package provide
+// List, Map, and SortedMap implementations. These act similarly to slices
+// and maps, respectively, except that altering a collection returns a new
+// copy of the collection with that change.
+//
+// Because collections are unable to change, they are safe for multiple
+// goroutines to read from at the same time without a mutex. However, these
+// types of collections come with increased CPU & memory usage as compared
+// with Go's built-in collection types so please evaluate for your specific
+// use.
+//
+// Collection Types
+//
+// The List type provides an API similar to Go slices. They allow appending,
+// prepending, and updating of elements. Elements can also be fetched by index
+// or iterated over using a ListIterator.
+//
+// The Map & SortedMap types provide an API similar to Go maps. They allow
+// values to be assigned to unique keys and allow for the deletion of keys.
+// Values can be fetched by key and key/value pairs can be iterated over using
+// the appropriate iterator type. Both map types provide the same API. The
+// SortedMap, however, provides iteration over sorted keys while the Map
+// provides iteration over unsorted keys. Maps improved performance and memory
+// usage as compared to SortedMaps.
+//
+// Hashing and Sorting
+//
+// Map types require the use of a Hasher implementation to calculate hashes for
+// their keys and check for key equality. SortedMaps require the use of a
+// Comparer implementation to sort keys in the map.
+//
+// These collection types automatically provide built-in hasher and comparers
+// for int, string, and byte slice keys. If you are using one of these key types
+// then simply pass a nil into the constructor. Otherwise you will need to
+// implement a custom Hasher or Comparer type. Please see the provided
+// implementations for reference.
+package immutable
+
+import (
+ "bytes"
+ "fmt"
+ "math/bits"
+ "sort"
+ "strings"
+)
+
+// Lists are dense, ordered, indexed collections. They are analogous to slices
+// in Go. They can be updated by appending to the end of the list, prepending
+// values to the beginning of the list, or updating existing indexes in the
+// list.
+type List struct {
+ root listNode // root node
+ origin int // offset to zero index element
+ size int // total number of elements in use
+}
+
+// NewList returns a new empty instance of List.
+func NewList() *List {
+ return &List{
+ root: &listLeafNode{},
+ }
+}
+
+// Len returns the number of elements in the list.
+func (l *List) Len() int {
+ return l.size
+}
+
+// cap returns the total number of possible elements for the current depth.
+func (l *List) cap() int {
+ return 1 << (l.root.depth() * listNodeBits)
+}
+
+// Get returns the value at the given index. Similar to slices, this method will
+// panic if index is below zero or is greater than or equal to the list size.
+func (l *List) Get(index int) interface{} {
+ if index < 0 || index >= l.size {
+ panic(fmt.Sprintf("immutable.List.Get: index %d out of bounds", index))
+ }
+ return l.root.get(l.origin + index)
+}
+
+// Set returns a new list with value set at index. Similar to slices, this
+// method will panic if index is below zero or if the index is greater than
+// or equal to the list size.
+func (l *List) Set(index int, value interface{}) *List {
+ if index < 0 || index >= l.size {
+ panic(fmt.Sprintf("immutable.List.Set: index %d out of bounds", index))
+ }
+ other := *l
+ other.root = other.root.set(l.origin+index, value)
+ return &other
+}
+
+// Append returns a new list with value added to the end of the list.
+func (l *List) Append(value interface{}) *List {
+ // Expand list to the right if no slots remain.
+ other := *l
+ if other.size+other.origin >= l.cap() {
+ newRoot := &listBranchNode{d: other.root.depth() + 1}
+ newRoot.children[0] = other.root
+ other.root = newRoot
+ }
+
+ // Increase size and set the last element to the new value.
+ other.size++
+ other.root = other.root.set(other.origin+other.size-1, value)
+ return &other
+}
+
+// Prepend returns a new list with value added to the beginning of the list.
+func (l *List) Prepend(value interface{}) *List {
+ // Expand list to the left if no slots remain.
+ other := *l
+ if other.origin == 0 {
+ newRoot := &listBranchNode{d: other.root.depth() + 1}
+ newRoot.children[listNodeSize-1] = other.root
+ other.root = newRoot
+ other.origin += (listNodeSize - 1) << (other.root.depth() * listNodeBits)
+ }
+
+ // Increase size and move origin back. Update first element to value.
+ other.size++
+ other.origin--
+ other.root = other.root.set(other.origin, value)
+ return &other
+}
+
+// Slice returns a new list of elements between start index and end index.
+// Similar to slices, this method will panic if start or end are below zero or
+// greater than the list size. A panic will also occur if start is greater than
+// end.
+//
+// Unlike Go slices, references to inaccessible elements will be automatically
+// removed so they can be garbage collected.
+func (l *List) Slice(start, end int) *List {
+ // Panics similar to Go slices.
+ if start < 0 || start > l.size {
+ panic(fmt.Sprintf("immutable.List.Slice: start index %d out of bounds", start))
+ } else if end < 0 || end > l.size {
+ panic(fmt.Sprintf("immutable.List.Slice: end index %d out of bounds", end))
+ } else if start > end {
+ panic(fmt.Sprintf("immutable.List.Slice: invalid slice index: [%d:%d]", start, end))
+ }
+
+ // Return the same list if the start and end are the entire range.
+ if start == 0 && end == l.size {
+ return l
+ }
+
+ // Create copy with new origin/size.
+ other := *l
+ other.origin = l.origin + start
+ other.size = end - start
+
+ // Contract tree while the start & end are in the same child node.
+ for other.root.depth() > 1 {
+ i := (other.origin >> (other.root.depth() * listNodeBits)) & listNodeMask
+ j := ((other.origin + other.size - 1) >> (other.root.depth() * listNodeBits)) & listNodeMask
+ if i != j {
+ break // branch contains at least two nodes, exit
+ }
+
+ // Replace the current root with the single child & update origin offset.
+ other.origin -= i << (other.root.depth() * listNodeBits)
+ other.root = other.root.(*listBranchNode).children[i]
+ }
+
+ // Ensure all references are removed before start & after end.
+ other.root = other.root.deleteBefore(other.origin)
+ other.root = other.root.deleteAfter(other.origin + other.size - 1)
+
+ return &other
+}
+
+// Iterator returns a new iterator for this list positioned at the first index.
+func (l *List) Iterator() *ListIterator {
+ itr := &ListIterator{list: l}
+ itr.First()
+ return itr
+}
+
+// Constants for bit shifts used for levels in the List trie.
+const (
+ listNodeBits = 5
+ listNodeSize = 1 << listNodeBits
+ listNodeMask = listNodeSize - 1
+)
+
+// listNode represents either a branch or leaf node in a List.
+type listNode interface {
+ depth() uint
+ get(index int) interface{}
+ set(index int, v interface{}) listNode
+
+ containsBefore(index int) bool
+ containsAfter(index int) bool
+
+ deleteBefore(index int) listNode
+ deleteAfter(index int) listNode
+}
+
+// newListNode returns a leaf node for depth zero, otherwise returns a branch node.
+func newListNode(depth uint) listNode {
+ if depth == 0 {
+ return &listLeafNode{}
+ }
+ return &listBranchNode{d: depth}
+}
+
+// listBranchNode represents a branch of a List tree at a given depth.
+type listBranchNode struct {
+ d uint // depth
+ children [listNodeSize]listNode
+}
+
+// depth returns the depth of this branch node from the leaf.
+func (n *listBranchNode) depth() uint { return n.d }
+
+// get returns the child node at the segment of the index for this depth.
+func (n *listBranchNode) get(index int) interface{} {
+ idx := (index >> (n.d * listNodeBits)) & listNodeMask
+ return n.children[idx].get(index)
+}
+
+// set recursively updates the value at index for each lower depth from the node.
+func (n *listBranchNode) set(index int, v interface{}) listNode {
+ idx := (index >> (n.d * listNodeBits)) & listNodeMask
+
+ // Find child for the given value in the branch. Create new if it doesn't exist.
+ child := n.children[idx]
+ if child == nil {
+ child = newListNode(n.depth() - 1)
+ }
+
+ // Return a copy of this branch with the new child.
+ other := *n
+ other.children[idx] = child.set(index, v)
+ return &other
+}
+
+// containsBefore returns true if non-nil values exists between [0,index).
+func (n *listBranchNode) containsBefore(index int) bool {
+ idx := (index >> (n.d * listNodeBits)) & listNodeMask
+
+ // Quickly check if any direct children exist before this segment of the index.
+ for i := 0; i < idx; i++ {
+ if n.children[i] != nil {
+ return true
+ }
+ }
+
+ // Recursively check for children directly at the given index at this segment.
+ if n.children[idx] != nil && n.children[idx].containsBefore(index) {
+ return true
+ }
+ return false
+}
+
+// containsAfter returns true if non-nil values exists between (index,listNodeSize).
+func (n *listBranchNode) containsAfter(index int) bool {
+ idx := (index >> (n.d * listNodeBits)) & listNodeMask
+
+ // Quickly check if any direct children exist after this segment of the index.
+ for i := idx + 1; i < len(n.children); i++ {
+ if n.children[i] != nil {
+ return true
+ }
+ }
+
+ // Recursively check for children directly at the given index at this segment.
+ if n.children[idx] != nil && n.children[idx].containsAfter(index) {
+ return true
+ }
+ return false
+}
+
+// deleteBefore returns a new node with all elements before index removed.
+func (n *listBranchNode) deleteBefore(index int) listNode {
+ // Ignore if no nodes exist before the given index.
+ if !n.containsBefore(index) {
+ return n
+ }
+
+ // Return a copy with any nodes prior to the index removed.
+ idx := (index >> (n.d * listNodeBits)) & listNodeMask
+ other := &listBranchNode{d: n.d}
+ copy(other.children[idx:][:], n.children[idx:][:])
+ if other.children[idx] != nil {
+ other.children[idx] = other.children[idx].deleteBefore(index)
+ }
+ return other
+}
+
+// deleteBefore returns a new node with all elements before index removed.
+func (n *listBranchNode) deleteAfter(index int) listNode {
+ // Ignore if no nodes exist after the given index.
+ if !n.containsAfter(index) {
+ return n
+ }
+
+ // Return a copy with any nodes after the index removed.
+ idx := (index >> (n.d * listNodeBits)) & listNodeMask
+ other := &listBranchNode{d: n.d}
+ copy(other.children[:idx+1], n.children[:idx+1])
+ if other.children[idx] != nil {
+ other.children[idx] = other.children[idx].deleteAfter(index)
+ }
+ return other
+}
+
+// listLeafNode represents a leaf node in a List.
+type listLeafNode struct {
+ children [listNodeSize]interface{}
+}
+
+// depth always returns 0 for leaf nodes.
+func (n *listLeafNode) depth() uint { return 0 }
+
+// get returns the value at the given index.
+func (n *listLeafNode) get(index int) interface{} {
+ return n.children[index&listNodeMask]
+}
+
+// set returns a copy of the node with the value at the index updated to v.
+func (n *listLeafNode) set(index int, v interface{}) listNode {
+ idx := index & listNodeMask
+ other := *n
+ other.children[idx] = v
+ return &other
+}
+
+// containsBefore returns true if non-nil values exists between [0,index).
+func (n *listLeafNode) containsBefore(index int) bool {
+ idx := index & listNodeMask
+ for i := 0; i < idx; i++ {
+ if n.children[i] != nil {
+ return true
+ }
+ }
+ return false
+}
+
+// containsAfter returns true if non-nil values exists between (index,listNodeSize).
+func (n *listLeafNode) containsAfter(index int) bool {
+ idx := index & listNodeMask
+ for i := idx + 1; i < len(n.children); i++ {
+ if n.children[i] != nil {
+ return true
+ }
+ }
+ return false
+}
+
+// deleteBefore returns a new node with all elements before index removed.
+func (n *listLeafNode) deleteBefore(index int) listNode {
+ if !n.containsBefore(index) {
+ return n
+ }
+
+ idx := index & listNodeMask
+ var other listLeafNode
+ copy(other.children[idx:][:], n.children[idx:][:])
+ return &other
+}
+
+// deleteBefore returns a new node with all elements before index removed.
+func (n *listLeafNode) deleteAfter(index int) listNode {
+ if !n.containsAfter(index) {
+ return n
+ }
+
+ idx := index & listNodeMask
+ var other listLeafNode
+ copy(other.children[:idx+1][:], n.children[:idx+1][:])
+ return &other
+}
+
+// ListIterator represents an ordered iterator over a list.
+type ListIterator struct {
+ list *List // source list
+ index int // current index position
+
+ stack [32]listIteratorElem // search stack
+ depth int // stack depth
+}
+
+// Done returns true if no more elements remain in the iterator.
+func (itr *ListIterator) Done() bool {
+ return itr.index < 0 || itr.index >= itr.list.Len()
+}
+
+// First positions the iterator on the first index.
+// If source list is empty then no change is made.
+func (itr *ListIterator) First() {
+ if itr.list.Len() != 0 {
+ itr.Seek(0)
+ }
+}
+
+// Last positions the iterator on the last index.
+// If source list is empty then no change is made.
+func (itr *ListIterator) Last() {
+ if n := itr.list.Len(); n != 0 {
+ itr.Seek(n - 1)
+ }
+}
+
+// Seek moves the iterator position to the given index in the list.
+// Similar to Go slices, this method will panic if index is below zero or if
+// the index is greater than or equal to the list size.
+func (itr *ListIterator) Seek(index int) {
+ // Panic similar to Go slices.
+ if index < 0 || index >= itr.list.Len() {
+ panic(fmt.Sprintf("immutable.ListIterator.Seek: index %d out of bounds", index))
+ }
+ itr.index = index
+
+ // Reset to the bottom of the stack at seek to the correct position.
+ itr.stack[0] = listIteratorElem{node: itr.list.root}
+ itr.depth = 0
+ itr.seek(index)
+}
+
+// Next returns the current index and its value & moves the iterator forward.
+// Returns an index of -1 if the there are no more elements to return.
+func (itr *ListIterator) Next() (index int, value interface{}) {
+ // Exit immediately if there are no elements remaining.
+ if itr.Done() {
+ return -1, nil
+ }
+
+ // Retrieve current index & value.
+ elem := &itr.stack[itr.depth]
+ index, value = itr.index, elem.node.(*listLeafNode).children[elem.index]
+
+ // Increase index. If index is at the end then return immediately.
+ itr.index++
+ if itr.Done() {
+ return index, value
+ }
+
+ // Move up stack until we find a node that has remaining position ahead.
+ for ; itr.depth > 0 && itr.stack[itr.depth].index >= listNodeSize-1; itr.depth-- {
+ }
+
+ // Seek to correct position from current depth.
+ itr.seek(itr.index)
+
+ return index, value
+}
+
+// Prev returns the current index and value and moves the iterator backward.
+// Returns an index of -1 if the there are no more elements to return.
+func (itr *ListIterator) Prev() (index int, value interface{}) {
+ // Exit immediately if there are no elements remaining.
+ if itr.Done() {
+ return -1, nil
+ }
+
+ // Retrieve current index & value.
+ elem := &itr.stack[itr.depth]
+ index, value = itr.index, elem.node.(*listLeafNode).children[elem.index]
+
+ // Decrease index. If index is past the beginning then return immediately.
+ itr.index--
+ if itr.Done() {
+ return index, value
+ }
+
+ // Move up stack until we find a node that has remaining position behind.
+ for ; itr.depth > 0 && itr.stack[itr.depth].index == 0; itr.depth-- {
+ }
+
+ // Seek to correct position from current depth.
+ itr.seek(itr.index)
+
+ return index, value
+}
+
+// seek positions the stack to the given index from the current depth.
+// Elements and indexes below the current depth are assumed to be correct.
+func (itr *ListIterator) seek(index int) {
+ // Iterate over each level until we reach a leaf node.
+ for {
+ elem := &itr.stack[itr.depth]
+ elem.index = ((itr.list.origin + index) >> (elem.node.depth() * listNodeBits)) & listNodeMask
+
+ switch node := elem.node.(type) {
+ case *listBranchNode:
+ child := node.children[elem.index]
+ itr.stack[itr.depth+1] = listIteratorElem{node: child}
+ itr.depth++
+ case *listLeafNode:
+ return
+ }
+ }
+}
+
+// listIteratorElem represents the node and it's child index within the stack.
+type listIteratorElem struct {
+ node listNode
+ index int
+}
+
+// Size thresholds for each type of branch node.
+const (
+ maxArrayMapSize = 8
+ maxBitmapIndexedSize = 16
+)
+
+// Segment bit shifts within the map tree.
+const (
+ mapNodeBits = 5
+ mapNodeSize = 1 << mapNodeBits
+ mapNodeMask = mapNodeSize - 1
+)
+
+// Map represents an immutable hash map implementation. The map uses a Hasher
+// to generate hashes and check for equality of key values.
+//
+// It is implemented as an Hash Array Mapped Trie.
+type Map struct {
+ size int // total number of key/value pairs
+ root mapNode // root node of trie
+ hasher Hasher // hasher implementation
+}
+
+// NewMap returns a new instance of Map. If hasher is nil, a default hasher
+// implementation will automatically be chosen based on the first key added.
+// Default hasher implementations only exist for int, string, and byte slice types.
+func NewMap(hasher Hasher) *Map {
+ return &Map{
+ hasher: hasher,
+ }
+}
+
+// Len returns the number of elements in the map.
+func (m *Map) Len() int {
+ return m.size
+}
+
+// Get returns the value for a given key and a flag indicating whether the
+// key exists. This flag distinguishes a nil value set on a key versus a
+// non-existent key in the map.
+func (m *Map) Get(key interface{}) (value interface{}, ok bool) {
+ if m.root == nil {
+ return nil, false
+ }
+ keyHash := m.hasher.Hash(key)
+ return m.root.get(key, 0, keyHash, m.hasher)
+}
+
+// Set returns a map with the key set to the new value. A nil value is allowed.
+//
+// This function will return a new map even if the updated value is the same as
+// the existing value because Map does not track value equality.
+func (m *Map) Set(key, value interface{}) *Map {
+ // Set a hasher on the first value if one does not already exist.
+ hasher := m.hasher
+ if hasher == nil {
+ switch key.(type) {
+ case int:
+ hasher = &intHasher{}
+ case string:
+ hasher = &stringHasher{}
+ case []byte:
+ hasher = &byteSliceHasher{}
+ default:
+ panic(fmt.Sprintf("immutable.Map.Set: must set hasher for %T type", key))
+ }
+ }
+
+ // If the map is empty, initialize with a simple array node.
+ if m.root == nil {
+ return &Map{
+ size: 1,
+ root: &mapArrayNode{entries: []mapEntry{{key: key, value: value}}},
+ hasher: hasher,
+ }
+ }
+
+ // Otherwise copy the map and delegate insertion to the root.
+ // Resized will return true if the key does not currently exist.
+ var resized bool
+ other := &Map{
+ size: m.size,
+ root: m.root.set(key, value, 0, hasher.Hash(key), hasher, &resized),
+ hasher: hasher,
+ }
+ if resized {
+ other.size++
+ }
+ return other
+}
+
+// Delete returns a map with the given key removed.
+// Removing a non-existent key will cause this method to return the same map.
+func (m *Map) Delete(key interface{}) *Map {
+ // Return original map if no keys exist.
+ if m.root == nil {
+ return m
+ }
+
+ // If the delete did not change the node then return the original map.
+ newRoot := m.root.delete(key, 0, m.hasher.Hash(key), m.hasher)
+ if newRoot == m.root {
+ return m
+ }
+
+ // Return copy of map with new root and decreased size.
+ return &Map{
+ size: m.size - 1,
+ root: newRoot,
+ hasher: m.hasher,
+ }
+}
+
+// Iterator returns a new iterator for the map.
+func (m *Map) Iterator() *MapIterator {
+ itr := &MapIterator{m: m}
+ itr.First()
+ return itr
+}
+
+// mapNode represents any node in the map tree.
+type mapNode interface {
+ get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool)
+ set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode
+ delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode
+}
+
+var _ mapNode = (*mapArrayNode)(nil)
+var _ mapNode = (*mapBitmapIndexedNode)(nil)
+var _ mapNode = (*mapHashArrayNode)(nil)
+var _ mapNode = (*mapValueNode)(nil)
+var _ mapNode = (*mapHashCollisionNode)(nil)
+
+// mapLeafNode represents a node that stores a single key hash at the leaf of the map tree.
+type mapLeafNode interface {
+ mapNode
+ keyHashValue() uint32
+}
+
+var _ mapLeafNode = (*mapValueNode)(nil)
+var _ mapLeafNode = (*mapHashCollisionNode)(nil)
+
+// mapArrayNode is a map node that stores key/value pairs in a slice.
+// Entries are stored in insertion order. An array node expands into a bitmap
+// indexed node once a given threshold size is crossed.
+type mapArrayNode struct {
+ entries []mapEntry
+}
+
+// indexOf returns the entry index of the given key. Returns -1 if key not found.
+func (n *mapArrayNode) indexOf(key interface{}, h Hasher) int {
+ for i := range n.entries {
+ if h.Equal(n.entries[i].key, key) {
+ return i
+ }
+ }
+ return -1
+}
+
+// get returns the value for the given key.
+func (n *mapArrayNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) {
+ i := n.indexOf(key, h)
+ if i == -1 {
+ return nil, false
+ }
+ return n.entries[i].value, true
+}
+
+// set inserts or updates the value for a given key. If the key is inserted and
+// the new size crosses the max size threshold, a bitmap indexed node is returned.
+func (n *mapArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode {
+ idx := n.indexOf(key, h)
+
+ // Mark as resized if the key doesn't exist.
+ if idx == -1 {
+ *resized = true
+ }
+
+ // If we are adding and it crosses the max size threshold, expand the node.
+ // We do this by continually setting the entries to a value node and expanding.
+ if idx == -1 && len(n.entries) >= maxArrayMapSize {
+ var node mapNode = newMapValueNode(h.Hash(key), key, value)
+ for _, entry := range n.entries {
+ node = node.set(entry.key, entry.value, 0, h.Hash(entry.key), h, resized)
+ }
+ return node
+ }
+
+ // Update existing entry if a match is found.
+ // Otherwise append to the end of the element list if it doesn't exist.
+ var other mapArrayNode
+ if idx != -1 {
+ other.entries = make([]mapEntry, len(n.entries))
+ copy(other.entries, n.entries)
+ other.entries[idx] = mapEntry{key, value}
+ } else {
+ other.entries = make([]mapEntry, len(n.entries)+1)
+ copy(other.entries, n.entries)
+ other.entries[len(other.entries)-1] = mapEntry{key, value}
+ }
+ return &other
+}
+
+// delete removes the given key from the node. Returns the same node if key does
+// not exist. Returns a nil node when removing the last entry.
+func (n *mapArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode {
+ idx := n.indexOf(key, h)
+
+ // Return original node if key does not exist.
+ if idx == -1 {
+ return n
+ }
+
+ // Return nil if this node will contain no nodes.
+ if len(n.entries) == 1 {
+ return nil
+ }
+
+ // Otherwise create a copy with the given entry removed.
+ other := &mapArrayNode{entries: make([]mapEntry, len(n.entries)-1)}
+ copy(other.entries[:idx], n.entries[:idx])
+ copy(other.entries[idx:], n.entries[idx+1:])
+ return other
+}
+
+// mapBitmapIndexedNode represents a map branch node with a variable number of
+// node slots and indexed using a bitmap. Indexes for the node slots are
+// calculated by counting the number of set bits before the target bit using popcount.
+type mapBitmapIndexedNode struct {
+ bitmap uint32
+ nodes []mapNode
+}
+
+// get returns the value for the given key.
+func (n *mapBitmapIndexedNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) {
+ bit := uint32(1) << ((keyHash >> shift) & mapNodeMask)
+ if (n.bitmap & bit) == 0 {
+ return nil, false
+ }
+ child := n.nodes[bits.OnesCount32(n.bitmap&(bit-1))]
+ return child.get(key, shift+mapNodeBits, keyHash, h)
+}
+
+// set inserts or updates the value for the given key. If a new key is inserted
+// and the size crosses the max size threshold then a hash array node is returned.
+func (n *mapBitmapIndexedNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode {
+ // Extract the index for the bit segment of the key hash.
+ keyHashFrag := (keyHash >> shift) & mapNodeMask
+
+ // Determine the bit based on the hash index.
+ bit := uint32(1) << keyHashFrag
+ exists := (n.bitmap & bit) != 0
+
+ // Mark as resized if the key doesn't exist.
+ if !exists {
+ *resized = true
+ }
+
+ // Find index of node based on popcount of bits before it.
+ idx := bits.OnesCount32(n.bitmap & (bit - 1))
+
+ // If the node already exists, delegate set operation to it.
+ // If the node doesn't exist then create a simple value leaf node.
+ var newNode mapNode
+ if exists {
+ newNode = n.nodes[idx].set(key, value, shift+mapNodeBits, keyHash, h, resized)
+ } else {
+ newNode = newMapValueNode(keyHash, key, value)
+ }
+
+ // Convert to a hash-array node once we exceed the max bitmap size.
+ // Copy each node based on their bit position within the bitmap.
+ if !exists && len(n.nodes) > maxBitmapIndexedSize {
+ var other mapHashArrayNode
+ for i := uint(0); i < uint(len(other.nodes)); i++ {
+ if n.bitmap&(uint32(1)<<i) != 0 {
+ other.nodes[i] = n.nodes[other.count]
+ other.count++
+ }
+ }
+ other.nodes[keyHashFrag] = newNode
+ other.count++
+ return &other
+ }
+
+ // If node exists at given slot then overwrite it with new node.
+ // Otherwise expand the node list and insert new node into appropriate position.
+ other := &mapBitmapIndexedNode{bitmap: n.bitmap | bit}
+ if exists {
+ other.nodes = make([]mapNode, len(n.nodes))
+ copy(other.nodes, n.nodes)
+ other.nodes[idx] = newNode
+ } else {
+ other.nodes = make([]mapNode, len(n.nodes)+1)
+ copy(other.nodes, n.nodes[:idx])
+ other.nodes[idx] = newNode
+ copy(other.nodes[idx+1:], n.nodes[idx:])
+ }
+ return other
+}
+
+// delete removes the key from the tree. If the key does not exist then the
+// original node is returned. If removing the last child node then a nil is
+// returned. Note that shrinking the node will not convert it to an array node.
+func (n *mapBitmapIndexedNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode {
+ bit := uint32(1) << ((keyHash >> shift) & mapNodeMask)
+
+ // Return original node if key does not exist.
+ if (n.bitmap & bit) == 0 {
+ return n
+ }
+
+ // Find index of node based on popcount of bits before it.
+ idx := bits.OnesCount32(n.bitmap & (bit - 1))
+
+ // Delegate delete to child node.
+ child := n.nodes[idx]
+ newChild := child.delete(key, shift+mapNodeBits, keyHash, h)
+
+ // Return original node if key doesn't exist in child.
+ if newChild == child {
+ return n
+ }
+
+ // Remove if returned child has been deleted.
+ if newChild == nil {
+ // If we won't have any children then return nil.
+ if len(n.nodes) == 1 {
+ return nil
+ }
+
+ // Return copy with bit removed from bitmap and node removed from node list.
+ other := &mapBitmapIndexedNode{bitmap: n.bitmap ^ bit, nodes: make([]mapNode, len(n.nodes)-1)}
+ copy(other.nodes[:idx], n.nodes[:idx])
+ copy(other.nodes[idx:], n.nodes[idx+1:])
+ return other
+ }
+
+ // Return copy with child updated.
+ other := &mapBitmapIndexedNode{bitmap: n.bitmap, nodes: make([]mapNode, len(n.nodes))}
+ copy(other.nodes, n.nodes)
+ other.nodes[idx] = newChild
+ return other
+}
+
+// mapHashArrayNode is a map branch node that stores nodes in a fixed length
+// array. Child nodes are indexed by their index bit segment for the current depth.
+type mapHashArrayNode struct {
+ count uint // number of set nodes
+ nodes [mapNodeSize]mapNode // child node slots, may contain empties
+}
+
+// get returns the value for the given key.
+func (n *mapHashArrayNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) {
+ node := n.nodes[(keyHash>>shift)&mapNodeMask]
+ if node == nil {
+ return nil, false
+ }
+ return node.get(key, shift+mapNodeBits, keyHash, h)
+}
+
+// set returns a node with the value set for the given key.
+func (n *mapHashArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode {
+ idx := (keyHash >> shift) & mapNodeMask
+ node := n.nodes[idx]
+
+ // If node at index doesn't exist, create a simple value leaf node.
+ // Otherwise delegate set to child node.
+ var newNode mapNode
+ if node == nil {
+ *resized = true
+ newNode = newMapValueNode(keyHash, key, value)
+ } else {
+ newNode = node.set(key, value, shift+mapNodeBits, keyHash, h, resized)
+ }
+
+ // Return a copy of node with updated child node (and updated size, if new).
+ other := *n
+ if node == nil {
+ other.count++
+ }
+ other.nodes[idx] = newNode
+ return &other
+}
+
+// delete returns a node with the given key removed. Returns the same node if
+// the key does not exist. If node shrinks to within bitmap-indexed size then
+// converts to a bitmap-indexed node.
+func (n *mapHashArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode {
+ idx := (keyHash >> shift) & mapNodeMask
+ node := n.nodes[idx]
+
+ // Return original node if child is not found.
+ if node == nil {
+ return n
+ }
+
+ // Return original node if child is unchanged.
+ newNode := node.delete(key, shift+mapNodeBits, keyHash, h)
+ if newNode == node {
+ return n
+ }
+
+ // If we remove a node and drop below a threshold, convert back to bitmap indexed node.
+ if newNode == nil && n.count <= maxBitmapIndexedSize {
+ other := &mapBitmapIndexedNode{nodes: make([]mapNode, 0, n.count-1)}
+ for i, child := range n.nodes {
+ if child != nil && uint32(i) != idx {
+ other.bitmap |= 1 << uint(i)
+ other.nodes = append(other.nodes, child)
+ }
+ }
+ return other
+ }
+
+ // Return copy of node with child updated.
+ other := *n
+ other.nodes[idx] = newNode
+ if newNode == nil {
+ other.count--
+ }
+ return &other
+}
+
+// mapValueNode represents a leaf node with a single key/value pair.
+// A value node can be converted to a hash collision leaf node if a different
+// key with the same keyHash is inserted.
+type mapValueNode struct {
+ keyHash uint32
+ key interface{}
+ value interface{}
+}
+
+// newMapValueNode returns a new instance of mapValueNode.
+func newMapValueNode(keyHash uint32, key, value interface{}) *mapValueNode {
+ return &mapValueNode{
+ keyHash: keyHash,
+ key: key,
+ value: value,
+ }
+}
+
+// keyHashValue returns the key hash for this node.
+func (n *mapValueNode) keyHashValue() uint32 {
+ return n.keyHash
+}
+
+// get returns the value for the given key.
+func (n *mapValueNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) {
+ if !h.Equal(n.key, key) {
+ return nil, false
+ }
+ return n.value, true
+}
+
+// set returns a new node with the new value set for the key. If the key equals
+// the node's key then a new value node is returned. If key is not equal to the
+// node's key but has the same hash then a hash collision node is returned.
+// Otherwise the nodes are merged into a branch node.
+func (n *mapValueNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode {
+ // If the keys match then return a new value node overwriting the value.
+ if h.Equal(n.key, key) {
+ return newMapValueNode(n.keyHash, key, value)
+ }
+
+ *resized = true
+
+ // Recursively merge nodes together if key hashes are different.
+ if n.keyHash != keyHash {
+ return mergeIntoNode(n, shift, keyHash, key, value)
+ }
+
+ // Merge into collision node if hash matches.
+ return &mapHashCollisionNode{keyHash: keyHash, entries: []mapEntry{
+ {key: n.key, value: n.value},
+ {key: key, value: value},
+ }}
+}
+
+// delete returns nil if the key matches the node's key. Otherwise returns the original node.
+func (n *mapValueNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode {
+ // Return original node if the keys do not match.
+ if !h.Equal(n.key, key) {
+ return n
+ }
+
+ // Otherwise remove the node if keys do match.
+ return nil
+}
+
+// mapHashCollisionNode represents a leaf node that contains two or more key/value
+// pairs with the same key hash. Single pairs for a hash are stored as value nodes.
+type mapHashCollisionNode struct {
+ keyHash uint32 // key hash for all entries
+ entries []mapEntry
+}
+
+// keyHashValue returns the key hash for all entries on the node.
+func (n *mapHashCollisionNode) keyHashValue() uint32 {
+ return n.keyHash
+}
+
+// indexOf returns the index of the entry for the given key.
+// Returns -1 if the key does not exist in the node.
+func (n *mapHashCollisionNode) indexOf(key interface{}, h Hasher) int {
+ for i := range n.entries {
+ if h.Equal(n.entries[i].key, key) {
+ return i
+ }
+ }
+ return -1
+}
+
+// get returns the value for the given key.
+func (n *mapHashCollisionNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) {
+ for i := range n.entries {
+ if h.Equal(n.entries[i].key, key) {
+ return n.entries[i].value, true
+ }
+ }
+ return nil, false
+}
+
+// set returns a copy of the node with key set to the given value.
+func (n *mapHashCollisionNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode {
+ // Merge node with key/value pair if this is not a hash collision.
+ if n.keyHash != keyHash {
+ *resized = true
+ return mergeIntoNode(n, shift, keyHash, key, value)
+ }
+
+ // Append to end of node if key doesn't exist & mark resized.
+ // Otherwise copy nodes and overwrite at matching key index.
+ other := &mapHashCollisionNode{keyHash: n.keyHash}
+ if idx := n.indexOf(key, h); idx == -1 {
+ *resized = true
+ other.entries = make([]mapEntry, len(n.entries)+1)
+ copy(other.entries, n.entries)
+ other.entries[len(other.entries)-1] = mapEntry{key, value}
+ } else {
+ other.entries = make([]mapEntry, len(n.entries))
+ copy(other.entries, n.entries)
+ other.entries[idx] = mapEntry{key, value}
+ }
+ return other
+}
+
+// delete returns a node with the given key deleted. Returns the same node if
+// the key does not exist. If removing the key would shrink the node to a single
+// entry then a value node is returned.
+func (n *mapHashCollisionNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode {
+ idx := n.indexOf(key, h)
+
+ // Return original node if key is not found.
+ if idx == -1 {
+ return n
+ }
+
+ // Convert to value node if we move to one entry.
+ if len(n.entries) == 2 {
+ return &mapValueNode{
+ keyHash: n.keyHash,
+ key: n.entries[idx^1].key,
+ value: n.entries[idx^1].value,
+ }
+ }
+
+ // Otherwise return copy with entry removed.
+ other := &mapHashCollisionNode{keyHash: n.keyHash, entries: make([]mapEntry, len(n.entries)-1)}
+ copy(other.entries[:idx], n.entries[:idx])
+ copy(other.entries[idx:], n.entries[idx+1:])
+ return other
+}
+
+// mergeIntoNode merges a key/value pair into an existing node.
+// Caller must verify that node's keyHash is not equal to keyHash.
+func mergeIntoNode(node mapLeafNode, shift uint, keyHash uint32, key, value interface{}) mapNode {
+ idx1 := (node.keyHashValue() >> shift) & mapNodeMask
+ idx2 := (keyHash >> shift) & mapNodeMask
+
+ // Recursively build branch nodes to combine the node and its key.
+ other := &mapBitmapIndexedNode{bitmap: (1 << idx1) | (1 << idx2)}
+ if idx1 == idx2 {
+ other.nodes = []mapNode{mergeIntoNode(node, shift+mapNodeBits, keyHash, key, value)}
+ } else {
+ if newNode := newMapValueNode(keyHash, key, value); idx1 < idx2 {
+ other.nodes = []mapNode{node, newNode}
+ } else {
+ other.nodes = []mapNode{newNode, node}
+ }
+ }
+ return other
+}
+
+// mapEntry represents a single key/value pair.
+type mapEntry struct {
+ key interface{}
+ value interface{}
+}
+
+// MapIterator represents an iterator over a map's key/value pairs. Although
+// map keys are not sorted, the iterator's order is deterministic.
+type MapIterator struct {
+ m *Map // source map
+
+ stack [32]mapIteratorElem // search stack
+ depth int // stack depth
+}
+
+// Done returns true if no more elements remain in the iterator.
+func (itr *MapIterator) Done() bool {
+ return itr.depth == -1
+}
+
+// First resets the iterator to the first key/value pair.
+func (itr *MapIterator) First() {
+ // Exit immediately if the map is empty.
+ if itr.m.root == nil {
+ itr.depth = -1
+ return
+ }
+
+ // Initialize the stack to the left most element.
+ itr.stack[0] = mapIteratorElem{node: itr.m.root}
+ itr.depth = 0
+ itr.first()
+}
+
+// Next returns the next key/value pair. Returns a nil key when no elements remain.
+func (itr *MapIterator) Next() (key, value interface{}) {
+ // Return nil key if iteration is done.
+ if itr.Done() {
+ return nil, nil
+ }
+
+ // Retrieve current index & value. Current node is always a leaf.
+ elem := &itr.stack[itr.depth]
+ switch node := elem.node.(type) {
+ case *mapArrayNode:
+ entry := &node.entries[elem.index]
+ key, value = entry.key, entry.value
+ case *mapValueNode:
+ key, value = node.key, node.value
+ case *mapHashCollisionNode:
+ entry := &node.entries[elem.index]
+ key, value = entry.key, entry.value
+ }
+
+ // Move up stack until we find a node that has remaining position ahead
+ // and move that element forward by one.
+ for ; itr.depth >= 0; itr.depth-- {
+ elem := &itr.stack[itr.depth]
+
+ switch node := elem.node.(type) {
+ case *mapArrayNode:
+ if elem.index < len(node.entries)-1 {
+ elem.index++
+ return key, value
+ }
+
+ case *mapBitmapIndexedNode:
+ if elem.index < len(node.nodes)-1 {
+ elem.index++
+ itr.stack[itr.depth+1].node = node.nodes[elem.index]
+ itr.depth++
+ itr.first()
+ return key, value
+ }
+
+ case *mapHashArrayNode:
+ for i := elem.index + 1; i < len(node.nodes); i++ {
+ if node.nodes[i] != nil {
+ elem.index = i
+ itr.stack[itr.depth+1].node = node.nodes[elem.index]
+ itr.depth++
+ itr.first()
+ return key, value
+ }
+ }
+
+ case *mapValueNode:
+ continue // always the last value, traverse up
+
+ case *mapHashCollisionNode:
+ if elem.index < len(node.entries)-1 {
+ elem.index++
+ return key, value
+ }
+ }
+ }
+
+ // This only occurs if depth is -1.
+ return key, value
+}
+
+// first positions the stack left most index.
+// Elements and indexes at and below the current depth are assumed to be correct.
+func (itr *MapIterator) first() {
+ for ; ; itr.depth++ {
+ elem := &itr.stack[itr.depth]
+
+ switch node := elem.node.(type) {
+ case *mapBitmapIndexedNode:
+ elem.index = 0
+ itr.stack[itr.depth+1].node = node.nodes[0]
+
+ case *mapHashArrayNode:
+ for i := 0; i < len(node.nodes); i++ {
+ if node.nodes[i] != nil { // find first node
+ elem.index = i
+ itr.stack[itr.depth+1].node = node.nodes[i]
+ break
+ }
+ }
+
+ default: // *mapArrayNode, mapLeafNode
+ elem.index = 0
+ return
+ }
+ }
+}
+
+// mapIteratorElem represents a node/index pair in the MapIterator stack.
+type mapIteratorElem struct {
+ node mapNode
+ index int
+}
+
+// Sorted map child node limit size.
+const (
+ sortedMapNodeSize = 32
+)
+
+// SortedMap represents a map of key/value pairs sorted by key. The sort order
+// is determined by the Comparer used by the map.
+//
+// This map is implemented as a B+tree.
+type SortedMap struct {
+ size int // total number of key/value pairs
+ root sortedMapNode // root of b+tree
+ comparer Comparer
+}
+
+// NewSortedMap returns a new instance of SortedMap. If comparer is nil then
+// a default comparer is set after the first key is inserted. Default comparers
+// exist for int, string, and byte slice keys.
+func NewSortedMap(comparer Comparer) *SortedMap {
+ return &SortedMap{
+ comparer: comparer,
+ }
+}
+
+// Len returns the number of elements in the sorted map.
+func (m *SortedMap) Len() int {
+ return m.size
+}
+
+// Get returns the value for a given key and a flag indicating if the key is set.
+// The flag can be used to distinguish between a nil-set key versus an unset key.
+func (m *SortedMap) Get(key interface{}) (interface{}, bool) {
+ if m.root == nil {
+ return nil, false
+ }
+ return m.root.get(key, m.comparer)
+}
+
+// Set returns a copy of the map with the key set to the given value.
+func (m *SortedMap) Set(key, value interface{}) *SortedMap {
+ // Set a comparer on the first value if one does not already exist.
+ comparer := m.comparer
+ if comparer == nil {
+ switch key.(type) {
+ case int:
+ comparer = &intComparer{}
+ case string:
+ comparer = &stringComparer{}
+ case []byte:
+ comparer = &byteSliceComparer{}
+ default:
+ panic(fmt.Sprintf("immutable.SortedMap.Set: must set comparer for %T type", key))
+ }
+ }
+
+ // If no values are set then initialize with a leaf node.
+ if m.root == nil {
+ return &SortedMap{
+ size: 1,
+ root: &sortedMapLeafNode{entries: []mapEntry{{key: key, value: value}}},
+ comparer: comparer,
+ }
+ }
+
+ // Otherwise delegate to root node.
+ // If a split occurs then grow the tree from the root.
+ var resized bool
+ newRoot, splitNode := m.root.set(key, value, comparer, &resized)
+ if splitNode != nil {
+ newRoot = newSortedMapBranchNode(newRoot, splitNode)
+ }
+
+ // Return a new map with the new root.
+ other := &SortedMap{
+ size: m.size,
+ root: newRoot,
+ comparer: comparer,
+ }
+ if resized {
+ other.size++
+ }
+ return other
+}
+
+// Delete returns a copy of the map with the key removed.
+// Returns the original map if key does not exist.
+func (m *SortedMap) Delete(key interface{}) *SortedMap {
+ // Return original map if no keys exist.
+ if m.root == nil {
+ return m
+ }
+
+ // If the delete did not change the node then return the original map.
+ newRoot := m.root.delete(key, m.comparer)
+ if newRoot == m.root {
+ return m
+ }
+
+ // Return new copy with the root and size updated.
+ return &SortedMap{
+ size: m.size - 1,
+ root: newRoot,
+ comparer: m.comparer,
+ }
+}
+
+// Iterator returns a new iterator for this map positioned at the first key.
+func (m *SortedMap) Iterator() *SortedMapIterator {
+ itr := &SortedMapIterator{m: m}
+ itr.First()
+ return itr
+}
+
+// sortedMapNode represents a branch or leaf node in the sorted map.
+type sortedMapNode interface {
+ minKey() interface{}
+ indexOf(key interface{}, c Comparer) int
+ get(key interface{}, c Comparer) (value interface{}, ok bool)
+ set(key, value interface{}, c Comparer, resized *bool) (sortedMapNode, sortedMapNode)
+ delete(key interface{}, c Comparer) sortedMapNode
+}
+
+var _ sortedMapNode = (*sortedMapBranchNode)(nil)
+var _ sortedMapNode = (*sortedMapLeafNode)(nil)
+
+// sortedMapBranchNode represents a branch in the sorted map.
+type sortedMapBranchNode struct {
+ elems []sortedMapBranchElem
+}
+
+// newSortedMapBranchNode returns a new branch node with the given child nodes.
+func newSortedMapBranchNode(children ...sortedMapNode) *sortedMapBranchNode {
+ // Fetch min keys for every child.
+ elems := make([]sortedMapBranchElem, len(children))
+ for i, child := range children {
+ elems[i] = sortedMapBranchElem{
+ key: child.minKey(),
+ node: child,
+ }
+ }
+
+ return &sortedMapBranchNode{elems: elems}
+}
+
+// minKey returns the lowest key stored in this node's tree.
+func (n *sortedMapBranchNode) minKey() interface{} {
+ return n.elems[0].node.minKey()
+}
+
+// indexOf returns the index of the key within the child nodes.
+func (n *sortedMapBranchNode) indexOf(key interface{}, c Comparer) int {
+ if idx := sort.Search(len(n.elems), func(i int) bool { return c.Compare(n.elems[i].key, key) == 1 }); idx > 0 {
+ return idx - 1
+ }
+ return 0
+}
+
+// get returns the value for the given key.
+func (n *sortedMapBranchNode) get(key interface{}, c Comparer) (value interface{}, ok bool) {
+ idx := n.indexOf(key, c)
+ return n.elems[idx].node.get(key, c)
+}
+
+// set returns a copy of the node with the key set to the given value.
+func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, resized *bool) (sortedMapNode, sortedMapNode) {
+ idx := n.indexOf(key, c)
+
+ // Delegate insert to child node.
+ newNode, splitNode := n.elems[idx].node.set(key, value, c, resized)
+
+ // If no split occurs, copy branch and update keys.
+ // If the child splits, insert new key/child into copy of branch.
+ var other sortedMapBranchNode
+ if splitNode == nil {
+ other.elems = make([]sortedMapBranchElem, len(n.elems))
+ copy(other.elems, n.elems)
+ other.elems[idx] = sortedMapBranchElem{
+ key: newNode.minKey(),
+ node: newNode,
+ }
+ } else {
+ other.elems = make([]sortedMapBranchElem, len(n.elems)+1)
+ copy(other.elems[:idx], n.elems[:idx])
+ copy(other.elems[idx+1:], n.elems[idx:])
+ other.elems[idx] = sortedMapBranchElem{
+ key: newNode.minKey(),
+ node: newNode,
+ }
+ other.elems[idx+1] = sortedMapBranchElem{
+ key: splitNode.minKey(),
+ node: splitNode,
+ }
+ }
+
+ // If the child splits and we have no more room then we split too.
+ if len(other.elems) > sortedMapNodeSize {
+ splitIdx := len(other.elems) / 2
+ newNode := &sortedMapBranchNode{elems: other.elems[:splitIdx]}
+ splitNode := &sortedMapBranchNode{elems: other.elems[splitIdx:]}
+ return newNode, splitNode
+ }
+
+ // Otherwise return the new branch node with the updated entry.
+ return &other, nil
+}
+
+// delete returns a node with the key removed. Returns the same node if the key
+// does not exist. Returns nil if all child nodes are removed.
+func (n *sortedMapBranchNode) delete(key interface{}, c Comparer) sortedMapNode {
+ idx := n.indexOf(key, c)
+
+ // Return original node if child has not changed.
+ newNode := n.elems[idx].node.delete(key, c)
+ if newNode == n.elems[idx].node {
+ return n
+ }
+
+ // Remove child if it is now nil.
+ if newNode == nil {
+ // If this node will become empty then simply return nil.
+ if len(n.elems) == 1 {
+ return nil
+ }
+
+ // Return a copy without the given node.
+ other := &sortedMapBranchNode{elems: make([]sortedMapBranchElem, len(n.elems)-1)}
+ copy(other.elems[:idx], n.elems[:idx])
+ copy(other.elems[idx:], n.elems[idx+1:])
+ return other
+ }
+
+ // Return a copy with the updated node.
+ other := &sortedMapBranchNode{elems: make([]sortedMapBranchElem, len(n.elems))}
+ copy(other.elems, n.elems)
+ other.elems[idx] = sortedMapBranchElem{
+ key: newNode.minKey(),
+ node: newNode,
+ }
+ return other
+}
+
+type sortedMapBranchElem struct {
+ key interface{}
+ node sortedMapNode
+}
+
+// sortedMapLeafNode represents a leaf node in the sorted map.
+type sortedMapLeafNode struct {
+ entries []mapEntry
+}
+
+// minKey returns the first key stored in this node.
+func (n *sortedMapLeafNode) minKey() interface{} {
+ return n.entries[0].key
+}
+
+// indexOf returns the index of the given key.
+func (n *sortedMapLeafNode) indexOf(key interface{}, c Comparer) int {
+ return sort.Search(len(n.entries), func(i int) bool {
+ return c.Compare(n.entries[i].key, key) != -1 // GTE
+ })
+}
+
+// get returns the value of the given key.
+func (n *sortedMapLeafNode) get(key interface{}, c Comparer) (value interface{}, ok bool) {
+ idx := n.indexOf(key, c)
+
+ // If the index is beyond the entry count or the key is not equal then return 'not found'.
+ if idx == len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 {
+ return nil, false
+ }
+
+ // If the key matches then return its value.
+ return n.entries[idx].value, true
+}
+
+// set returns a copy of node with the key set to the given value. If the update
+// causes the node to grow beyond the maximum size then it is split in two.
+func (n *sortedMapLeafNode) set(key, value interface{}, c Comparer, resized *bool) (sortedMapNode, sortedMapNode) {
+ // Find the insertion index for the key.
+ idx := n.indexOf(key, c)
+
+ // If the key matches then simply return a copy with the entry overridden.
+ // If there is no match then insert new entry and mark as resized.
+ var newEntries []mapEntry
+ if idx < len(n.entries) && c.Compare(n.entries[idx].key, key) == 0 {
+ newEntries = make([]mapEntry, len(n.entries))
+ copy(newEntries, n.entries)
+ newEntries[idx] = mapEntry{key: key, value: value}
+ } else {
+ *resized = true
+ newEntries = make([]mapEntry, len(n.entries)+1)
+ copy(newEntries[:idx], n.entries[:idx])
+ newEntries[idx] = mapEntry{key: key, value: value}
+ copy(newEntries[idx+1:], n.entries[idx:])
+ }
+
+ // If the key doesn't exist and we exceed our max allowed values then split.
+ if len(newEntries) > sortedMapNodeSize {
+ newNode := &sortedMapLeafNode{entries: newEntries[:len(newEntries)/2]}
+ splitNode := &sortedMapLeafNode{entries: newEntries[len(newEntries)/2:]}
+ return newNode, splitNode
+ }
+
+ // Otherwise return the new leaf node with the updated entry.
+ return &sortedMapLeafNode{entries: newEntries}, nil
+}
+
+// delete returns a copy of node with key removed. Returns the original node if
+// the key does not exist. Returns nil if the removed key is the last remaining key.
+func (n *sortedMapLeafNode) delete(key interface{}, c Comparer) sortedMapNode {
+ idx := n.indexOf(key, c)
+
+ // Return original node if key is not found.
+ if idx >= len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 {
+ return n
+ }
+
+ // If this is the last entry then return nil.
+ if len(n.entries) == 1 {
+ return nil
+ }
+
+ // Return copy of node with entry removed.
+ other := &sortedMapLeafNode{entries: make([]mapEntry, len(n.entries)-1)}
+ copy(other.entries[:idx], n.entries[:idx])
+ copy(other.entries[idx:], n.entries[idx+1:])
+ return other
+}
+
+// SortedMapIterator represents an iterator over a sorted map.
+// Iteration can occur in natural or reverse order based on use of Next() or Prev().
+type SortedMapIterator struct {
+ m *SortedMap // source map
+
+ stack [32]sortedMapIteratorElem // search stack
+ depth int // stack depth
+}
+
+// Done returns true if no more key/value pairs remain in the iterator.
+func (itr *SortedMapIterator) Done() bool {
+ return itr.depth == -1
+}
+
+// First moves the iterator to the first key/value pair.
+func (itr *SortedMapIterator) First() {
+ if itr.m.root != nil {
+ itr.stack[0] = sortedMapIteratorElem{node: itr.m.root}
+ itr.depth = 0
+ itr.first()
+ }
+}
+
+// Last moves the iterator to the last key/value pair.
+func (itr *SortedMapIterator) Last() {
+ if itr.m.root != nil {
+ itr.stack[0] = sortedMapIteratorElem{node: itr.m.root}
+ itr.depth = 0
+ itr.last()
+ }
+}
+
+// Seek moves the iterator position to the given key in the map.
+// If the key does not exist then the next key is used. If no more keys exist
+// then the iteartor is marked as done.
+func (itr *SortedMapIterator) Seek(key interface{}) {
+ if itr.m.root != nil {
+ itr.stack[0] = sortedMapIteratorElem{node: itr.m.root}
+ itr.depth = 0
+ itr.seek(key)
+ }
+}
+
+// Next returns the current key/value pair and moves the iterator forward.
+// Returns a nil key if the there are no more elements to return.
+func (itr *SortedMapIterator) Next() (key, value interface{}) {
+ // Return nil key if iteration is complete.
+ if itr.Done() {
+ return nil, nil
+ }
+
+ // Retrieve current key/value pair.
+ leafElem := &itr.stack[itr.depth]
+ leafNode := leafElem.node.(*sortedMapLeafNode)
+ leafEntry := &leafNode.entries[leafElem.index]
+ key, value = leafEntry.key, leafEntry.value
+
+ // Move to the next available key/value pair.
+ itr.next()
+
+ // Only occurs when iterator is done.
+ return key, value
+}
+
+// next moves to the next key. If no keys are after then depth is set to -1.
+func (itr *SortedMapIterator) next() {
+ for ; itr.depth >= 0; itr.depth-- {
+ elem := &itr.stack[itr.depth]
+
+ switch node := elem.node.(type) {
+ case *sortedMapLeafNode:
+ if elem.index < len(node.entries)-1 {
+ elem.index++
+ return
+ }
+ case *sortedMapBranchNode:
+ if elem.index < len(node.elems)-1 {
+ elem.index++
+ itr.stack[itr.depth+1].node = node.elems[elem.index].node
+ itr.depth++
+ itr.first()
+ return
+ }
+ }
+ }
+}
+
+// Prev returns the current key/value pair and moves the iterator backward.
+// Returns a nil key if the there are no more elements to return.
+func (itr *SortedMapIterator) Prev() (key, value interface{}) {
+ // Return nil key if iteration is complete.
+ if itr.Done() {
+ return nil, nil
+ }
+
+ // Retrieve current key/value pair.
+ leafElem := &itr.stack[itr.depth]
+ leafNode := leafElem.node.(*sortedMapLeafNode)
+ leafEntry := &leafNode.entries[leafElem.index]
+ key, value = leafEntry.key, leafEntry.value
+
+ itr.prev()
+ return key, value
+}
+
+// prev moves to the previous key. If no keys are before then depth is set to -1.
+func (itr *SortedMapIterator) prev() {
+ for ; itr.depth >= 0; itr.depth-- {
+ elem := &itr.stack[itr.depth]
+
+ switch node := elem.node.(type) {
+ case *sortedMapLeafNode:
+ if elem.index > 0 {
+ elem.index--
+ return
+ }
+ case *sortedMapBranchNode:
+ if elem.index > 0 {
+ elem.index--
+ itr.stack[itr.depth+1].node = node.elems[elem.index].node
+ itr.depth++
+ itr.last()
+ return
+ }
+ }
+ }
+}
+
+// first positions the stack to the leftmost key from the current depth.
+// Elements and indexes below the current depth are assumed to be correct.
+func (itr *SortedMapIterator) first() {
+ for {
+ elem := &itr.stack[itr.depth]
+ elem.index = 0
+
+ switch node := elem.node.(type) {
+ case *sortedMapBranchNode:
+ itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node}
+ itr.depth++
+ case *sortedMapLeafNode:
+ return
+ }
+ }
+}
+
+// last positions the stack to the rightmost key from the current depth.
+// Elements and indexes below the current depth are assumed to be correct.
+func (itr *SortedMapIterator) last() {
+ for {
+ elem := &itr.stack[itr.depth]
+
+ switch node := elem.node.(type) {
+ case *sortedMapBranchNode:
+ elem.index = len(node.elems) - 1
+ itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node}
+ itr.depth++
+ case *sortedMapLeafNode:
+ elem.index = len(node.entries) - 1
+ return
+ }
+ }
+}
+
+// seek positions the stack to the given key from the current depth.
+// Elements and indexes below the current depth are assumed to be correct.
+func (itr *SortedMapIterator) seek(key interface{}) {
+ for {
+ elem := &itr.stack[itr.depth]
+ elem.index = elem.node.indexOf(key, itr.m.comparer)
+
+ switch node := elem.node.(type) {
+ case *sortedMapBranchNode:
+ itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node}
+ itr.depth++
+ case *sortedMapLeafNode:
+ if elem.index == len(node.entries) {
+ itr.next()
+ }
+ return
+ }
+ }
+}
+
+// sortedMapIteratorElem represents node/index pair in the SortedMapIterator stack.
+type sortedMapIteratorElem struct {
+ node sortedMapNode
+ index int
+}
+
+// Hasher hashes keys and checks them for equality.
+type Hasher interface {
+ // Computes a 32-bit hash for key.
+ Hash(key interface{}) uint32
+
+ // Returns true if a and b are equal.
+ Equal(a, b interface{}) bool
+}
+
+// intHasher implements Hasher for int keys.
+type intHasher struct{}
+
+// Hash returns a hash for key.
+func (h *intHasher) Hash(key interface{}) uint32 {
+ return hashUint64(uint64(key.(int)))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not ints.
+func (h *intHasher) Equal(a, b interface{}) bool {
+ return a.(int) == b.(int)
+}
+
+// stringHasher implements Hasher for string keys.
+type stringHasher struct{}
+
+// Hash returns a hash for value.
+func (h *stringHasher) Hash(value interface{}) uint32 {
+ var hash uint32
+ for i, value := 0, value.(string); i < len(value); i++ {
+ hash = 31*hash + uint32(value[i])
+ }
+ return hash
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not strings.
+func (h *stringHasher) Equal(a, b interface{}) bool {
+ return a.(string) == b.(string)
+}
+
+// byteSliceHasher implements Hasher for string keys.
+type byteSliceHasher struct{}
+
+// Hash returns a hash for value.
+func (h *byteSliceHasher) Hash(value interface{}) uint32 {
+ var hash uint32
+ for i, value := 0, value.([]byte); i < len(value); i++ {
+ hash = 31*hash + uint32(value[i])
+ }
+ return hash
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not byte slices.
+func (h *byteSliceHasher) Equal(a, b interface{}) bool {
+ return bytes.Equal(a.([]byte), b.([]byte))
+}
+
+// hashUint64 returns a 32-bit hash for a 64-bit value.
+func hashUint64(value uint64) uint32 {
+ hash := value
+ for value > 0xffffffff {
+ value /= 0xffffffff
+ hash ^= value
+ }
+ return uint32(hash)
+}
+
+// Comparer allows the comparison of two keys for the purpose of sorting.
+type Comparer interface {
+ // Returns -1 if a is less than b, returns 1 if a is greater than b,
+ // and returns 0 if a is equal to b.
+ Compare(a, b interface{}) int
+}
+
+// intComparer compares two integers. Implements Comparer.
+type intComparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an int.
+func (c *intComparer) Compare(a, b interface{}) int {
+ if i, j := a.(int), b.(int); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// stringComparer compares two strings. Implements Comparer.
+type stringComparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not a string.
+func (c *stringComparer) Compare(a, b interface{}) int {
+ return strings.Compare(a.(string), b.(string))
+}
+
+// byteSliceComparer compares two byte slices. Implements Comparer.
+type byteSliceComparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not a byte slice.
+func (c *byteSliceComparer) Compare(a, b interface{}) int {
+ return bytes.Compare(a.([]byte), b.([]byte))
+}
diff --git a/immutable_test.go b/immutable_test.go
new file mode 100644
index 0000000..a0c7821
--- /dev/null
+++ b/immutable_test.go
@@ -0,0 +1,2038 @@
+package immutable
+
+import (
+ "flag"
+ "fmt"
+ "math/rand"
+ "sort"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+var (
+ veryVerbose = flag.Bool("vv", false, "very verbose")
+ randomN = flag.Int("random.n", 100, "number of RunRandom() iterations")
+)
+
+func TestList(t *testing.T) {
+ t.Run("Empty", func(t *testing.T) {
+ if size := NewList().Len(); size != 0 {
+ t.Fatalf("unexpected size: %d", size)
+ }
+ })
+
+ t.Run("Shallow", func(t *testing.T) {
+ list := NewList()
+ list = list.Append("foo")
+ if v := list.Get(0); v != "foo" {
+ t.Fatalf("unexpected value: %v", v)
+ }
+
+ other := list.Append("bar")
+ if v := other.Get(0); v != "foo" {
+ t.Fatalf("unexpected value: %v", v)
+ } else if v := other.Get(1); v != "bar" {
+ t.Fatalf("unexpected value: %v", v)
+ }
+
+ if v := list.Len(); v != 1 {
+ t.Fatalf("unexpected value: %v", v)
+ }
+ })
+
+ t.Run("Deep", func(t *testing.T) {
+ list := NewList()
+ var array []int
+ for i := 0; i < 100000; i++ {
+ list = list.Append(i)
+ array = append(array, i)
+ }
+
+ if got, exp := len(array), list.Len(); got != exp {
+ t.Fatalf("List.Len()=%d, exp %d", got, exp)
+ }
+ for j := range array {
+ if got, exp := list.Get(j).(int), array[j]; got != exp {
+ t.Fatalf("%d. List.Get(%d)=%d, exp %d", len(array), j, got, exp)
+ }
+ }
+ })
+
+ t.Run("Set", func(t *testing.T) {
+ list := NewList()
+ list = list.Append("foo")
+ list = list.Append("bar")
+
+ if v := list.Get(0); v != "foo" {
+ t.Fatalf("unexpected value: %v", v)
+ }
+
+ list = list.Set(0, "baz")
+ if v := list.Get(0); v != "baz" {
+ t.Fatalf("unexpected value: %v", v)
+ } else if v := list.Get(1); v != "bar" {
+ t.Fatalf("unexpected value: %v", v)
+ }
+ })
+
+ t.Run("GetBelowRange", func(t *testing.T) {
+ var r string
+ func() {
+ defer func() { r = recover().(string) }()
+ l := NewList()
+ l = l.Append("foo")
+ l.Get(-1)
+ }()
+ if r != `immutable.List.Get: index -1 out of bounds` {
+ t.Fatalf("unexpected panic: %q", r)
+ }
+ })
+
+ t.Run("GetAboveRange", func(t *testing.T) {
+ var r string
+ func() {
+ defer func() { r = recover().(string) }()
+ l := NewList()
+ l = l.Append("foo")
+ l.Get(1)
+ }()
+ if r != `immutable.List.Get: index 1 out of bounds` {
+ t.Fatalf("unexpected panic: %q", r)
+ }
+ })
+
+ t.Run("SetOutOfRange", func(t *testing.T) {
+ var r string
+ func() {
+ defer func() { r = recover().(string) }()
+ l := NewList()
+ l = l.Append("foo")
+ l.Set(1, "bar")
+ }()
+ if r != `immutable.List.Set: index 1 out of bounds` {
+ t.Fatalf("unexpected panic: %q", r)
+ }
+ })
+
+ t.Run("SliceStartOutOfRange", func(t *testing.T) {
+ var r string
+ func() {
+ defer func() { r = recover().(string) }()
+ l := NewList()
+ l = l.Append("foo")
+ l.Slice(2, 3)
+ }()
+ if r != `immutable.List.Slice: start index 2 out of bounds` {
+ t.Fatalf("unexpected panic: %q", r)
+ }
+ })
+
+ t.Run("SliceEndOutOfRange", func(t *testing.T) {
+ var r string
+ func() {
+ defer func() { r = recover().(string) }()
+ l := NewList()
+ l = l.Append("foo")
+ l.Slice(1, 3)
+ }()
+ if r != `immutable.List.Slice: end index 3 out of bounds` {
+ t.Fatalf("unexpected panic: %q", r)
+ }
+ })
+
+ t.Run("SliceInvalidIndex", func(t *testing.T) {
+ var r string
+ func() {
+ defer func() { r = recover().(string) }()
+ l := NewList()
+ l = l.Append("foo")
+ l = l.Append("bar")
+ l.Slice(2, 1)
+ }()
+ if r != `immutable.List.Slice: invalid slice index: [2:1]` {
+ t.Fatalf("unexpected panic: %q", r)
+ }
+ })
+
+ t.Run("SliceBeginning", func(t *testing.T) {
+ l := NewList()
+ l = l.Append("foo")
+ l = l.Append("bar")
+ l = l.Slice(1, 2)
+ if got, exp := l.Len(), 1; got != exp {
+ t.Fatalf("List.Len()=%d, exp %d", got, exp)
+ } else if got, exp := l.Get(0), "bar"; got != exp {
+ t.Fatalf("List.Get(0)=%v, exp %v", got, exp)
+ }
+ })
+
+ t.Run("IteratorSeekOutOfBounds", func(t *testing.T) {
+ var r string
+ func() {
+ defer func() { r = recover().(string) }()
+ l := NewList()
+ l = l.Append("foo")
+ l.Iterator().Seek(-1)
+ }()
+ if r != `immutable.ListIterator.Seek: index -1 out of bounds` {
+ t.Fatalf("unexpected panic: %q", r)
+ }
+ })
+
+ RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) {
+ l := NewTList()
+ for i := 0; i < 100000; i++ {
+ rnd := rand.Intn(70)
+ switch {
+ case rnd == 0: // slice
+ start, end := l.ChooseSliceIndices(rand)
+ l.Slice(start, end)
+ case rnd < 10: // set
+ if l.Len() > 0 {
+ l.Set(l.ChooseIndex(rand), rand.Intn(10000))
+ }
+ case rnd < 30: // prepend
+ l.Prepend(rand.Intn(10000))
+ default: // append
+ l.Append(rand.Intn(10000))
+ }
+
+ if err := l.Validate(); err != nil {
+ t.Fatal(err)
+ }
+ }
+ if err := l.Validate(); err != nil {
+ t.Fatal(err)
+ }
+ })
+}
+
+// TList represents a list that operates on a standard Go slice & immutable list.
+type TList struct {
+ im, prev *List
+ std []int
+}
+
+// NewTList returns a new instance of TList.
+func NewTList() *TList {
+ return &TList{
+ im: NewList(),
+ }
+}
+
+// Len returns the size of the list.
+func (l *TList) Len() int {
+ return len(l.std)
+}
+
+// ChooseIndex returns a randomly chosen, valid index from the standard slice.
+func (l *TList) ChooseIndex(rand *rand.Rand) int {
+ if len(l.std) == 0 {
+ return -1
+ }
+ return rand.Intn(len(l.std))
+}
+
+// ChooseSliceIndices returns randomly chosen, valid indices for slicing.
+func (l *TList) ChooseSliceIndices(rand *rand.Rand) (start, end int) {
+ if len(l.std) == 0 {
+ return 0, 0
+ }
+ start = rand.Intn(len(l.std))
+ end = rand.Intn(len(l.std)-start) + start
+ return start, end
+}
+
+// Append adds v to the end of slice and List.
+func (l *TList) Append(v int) {
+ l.prev = l.im
+ l.im = l.im.Append(v)
+ l.std = append(l.std, v)
+}
+
+// Prepend adds v to the beginning of the slice and List.
+func (l *TList) Prepend(v int) {
+ l.prev = l.im
+ l.im = l.im.Prepend(v)
+ l.std = append([]int{v}, l.std...)
+}
+
+// Set updates the value at index i to v in the slice and List.
+func (l *TList) Set(i, v int) {
+ l.prev = l.im
+ l.im = l.im.Set(i, v)
+ l.std[i] = v
+}
+
+// Slice contracts the slice and List to the range of start/end indices.
+func (l *TList) Slice(start, end int) {
+ l.prev = l.im
+ l.im = l.im.Slice(start, end)
+ l.std = l.std[start:end]
+}
+
+// Validate returns an error if the slice and List are different.
+func (l *TList) Validate() error {
+ if got, exp := len(l.std), l.im.Len(); got != exp {
+ return fmt.Errorf("Len()=%v, expected %d", got, exp)
+ }
+
+ for i := range l.std {
+ if got, exp := l.im.Get(i), l.std[i]; got != exp {
+ return fmt.Errorf("Get(%d)=%v, expected %v", i, got, exp)
+ }
+ }
+
+ if err := l.validateForwardIterator(); err != nil {
+ return err
+ } else if err := l.validateBackwardIterator(); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (l *TList) validateForwardIterator() error {
+ itr := l.im.Iterator()
+ for i := range l.std {
+ if j, v := itr.Next(); i != j || l.std[i] != v {
+ return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected <%v,%v>", j, v, i, l.std[i])
+ }
+
+ done := i == len(l.std)-1
+ if v := itr.Done(); v != done {
+ return fmt.Errorf("ListIterator.Done()=%v, expected %v", v, done)
+ }
+ }
+ if i, v := itr.Next(); i != -1 || v != nil {
+ return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected DONE", i, v)
+ }
+ return nil
+}
+
+func (l *TList) validateBackwardIterator() error {
+ itr := l.im.Iterator()
+ itr.Last()
+ for i := len(l.std) - 1; i >= 0; i-- {
+ if j, v := itr.Prev(); i != j || l.std[i] != v {
+ return fmt.Errorf("ListIterator.Prev()=<%v,%v>, expected <%v,%v>", j, v, i, l.std[i])
+ }
+
+ done := i == 0
+ if v := itr.Done(); v != done {
+ return fmt.Errorf("ListIterator.Done()=%v, expected %v", v, done)
+ }
+ }
+ if i, v := itr.Prev(); i != -1 || v != nil {
+ return fmt.Errorf("ListIterator.Prev()=<%v,%v>, expected DONE", i, v)
+ }
+ return nil
+}
+
+func BenchmarkList_Append(b *testing.B) {
+ b.ReportAllocs()
+ l := NewList()
+ for i := 0; i < b.N; i++ {
+ l = l.Append(i)
+ }
+}
+
+func BenchmarkList_Prepend(b *testing.B) {
+ b.ReportAllocs()
+ l := NewList()
+ for i := 0; i < b.N; i++ {
+ l = l.Prepend(i)
+ }
+}
+
+func BenchmarkList_Set(b *testing.B) {
+ const n = 10000
+
+ l := NewList()
+ for i := 0; i < 10000; i++ {
+ l = l.Append(i)
+ }
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ l = l.Set(i%n, i*10)
+ }
+}
+
+func BenchmarkList_Iterator(b *testing.B) {
+ const n = 10000
+ l := NewList()
+ for i := 0; i < 10000; i++ {
+ l = l.Append(i)
+ }
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.Run("Forward", func(b *testing.B) {
+ itr := l.Iterator()
+ for i := 0; i < b.N; i++ {
+ if i%n == 0 {
+ itr.First()
+ }
+ itr.Next()
+ }
+ })
+
+ b.Run("Reverse", func(b *testing.B) {
+ itr := l.Iterator()
+ for i := 0; i < b.N; i++ {
+ if i%n == 0 {
+ itr.Last()
+ }
+ itr.Prev()
+ }
+ })
+}
+
+func ExampleList_Append() {
+ l := NewList()
+ l = l.Append("foo")
+ l = l.Append("bar")
+ l = l.Append("baz")
+
+ fmt.Println(l.Get(0))
+ fmt.Println(l.Get(1))
+ fmt.Println(l.Get(2))
+ // Output:
+ // foo
+ // bar
+ // baz
+}
+
+func ExampleList_Prepend() {
+ l := NewList()
+ l = l.Prepend("foo")
+ l = l.Prepend("bar")
+ l = l.Prepend("baz")
+
+ fmt.Println(l.Get(0))
+ fmt.Println(l.Get(1))
+ fmt.Println(l.Get(2))
+ // Output:
+ // baz
+ // bar
+ // foo
+}
+
+func ExampleList_Set() {
+ l := NewList()
+ l = l.Append("foo")
+ l = l.Append("bar")
+ l = l.Set(1, "baz")
+
+ fmt.Println(l.Get(0))
+ fmt.Println(l.Get(1))
+ // Output:
+ // foo
+ // baz
+}
+
+func ExampleList_Slice() {
+ l := NewList()
+ l = l.Append("foo")
+ l = l.Append("bar")
+ l = l.Append("baz")
+ l = l.Slice(1, 3)
+
+ fmt.Println(l.Get(0))
+ fmt.Println(l.Get(1))
+ // Output:
+ // bar
+ // baz
+}
+
+func ExampleList_Iterator() {
+ l := NewList()
+ l = l.Append("foo")
+ l = l.Append("bar")
+ l = l.Append("baz")
+
+ itr := l.Iterator()
+ for !itr.Done() {
+ i, v := itr.Next()
+ fmt.Println(i, v)
+ }
+ // Output:
+ // 0 foo
+ // 1 bar
+ // 2 baz
+}
+
+func ExampleList_Iterator_reverse() {
+ l := NewList()
+ l = l.Append("foo")
+ l = l.Append("bar")
+ l = l.Append("baz")
+
+ itr := l.Iterator()
+ itr.Last()
+ for !itr.Done() {
+ i, v := itr.Prev()
+ fmt.Println(i, v)
+ }
+ // Output:
+ // 2 baz
+ // 1 bar
+ // 0 foo
+}
+
+// Ensure node can support overwrites as it expands.
+func TestIngernal_mapNode_Overwrite(t *testing.T) {
+ const n = 1000
+ var h intHasher
+ var node mapNode = &mapArrayNode{}
+ for i := 0; i < n; i++ {
+ var resized bool
+ node = node.set(i, i, 0, h.Hash(i), &h, &resized)
+ if !resized {
+ t.Fatal("expected resize")
+ }
+
+ // Overwrite every node.
+ for j := 0; j <= i; j++ {
+ var resized bool
+ node = node.set(j, i*j, 0, h.Hash(j), &h, &resized)
+ if resized {
+ t.Fatalf("expected no resize: i=%d, j=%d", i, j)
+ }
+ }
+
+ // Verify not found at each branch type.
+ if _, ok := node.get(1000000, 0, h.Hash(1000000), &h); ok {
+ t.Fatal("expected no value")
+ }
+ }
+
+ // Verify all key/value pairs in map.
+ for i := 0; i < n; i++ {
+ if v, ok := node.get(i, 0, h.Hash(i), &h); !ok || v != i*(n-1) {
+ t.Fatalf("get(%d)=<%v,%v>", i, v, ok)
+ }
+ }
+}
+
+func TestIngernal_mapArrayNode(t *testing.T) {
+ // Ensure 8 or fewer elements stays in an array node.
+ t.Run("Append", func(t *testing.T) {
+ var h intHasher
+ n := &mapArrayNode{}
+ for i := 0; i < 8; i++ {
+ var resized bool
+ n = n.set(i*10, i, 0, h.Hash(i*10), &h, &resized).(*mapArrayNode)
+ if !resized {
+ t.Fatal("expected resize")
+ }
+
+ for j := 0; j < i; j++ {
+ if v, ok := n.get(j*10, 0, h.Hash(j*10), &h); !ok || v != j {
+ t.Fatalf("get(%d)=<%v,%v>", j, v, ok)
+ }
+ }
+ }
+ })
+
+ // Ensure 8 or fewer elements stays in an array node when inserted in reverse.
+ t.Run("Prepend", func(t *testing.T) {
+ var h intHasher
+ n := &mapArrayNode{}
+ for i := 7; i >= 0; i-- {
+ var resized bool
+ n = n.set(i*10, i, 0, h.Hash(i*10), &h, &resized).(*mapArrayNode)
+ if !resized {
+ t.Fatal("expected resize")
+ }
+
+ for j := i; j <= 7; j++ {
+ if v, ok := n.get(j*10, 0, h.Hash(j*10), &h); !ok || v != j {
+ t.Fatalf("get(%d)=<%v,%v>", j, v, ok)
+ }
+ }
+ }
+ })
+
+ // Ensure array can transition between node types.
+ t.Run("Expand", func(t *testing.T) {
+ var h intHasher
+ var n mapNode = &mapArrayNode{}
+ for i := 0; i < 100; i++ {
+ var resized bool
+ n = n.set(i, i, 0, h.Hash(i), &h, &resized)
+ if !resized {
+ t.Fatal("expected resize")
+ }
+
+ for j := 0; j < i; j++ {
+ if v, ok := n.get(j, 0, h.Hash(j), &h); !ok || v != j {
+ t.Fatalf("get(%d)=<%v,%v>", j, v, ok)
+ }
+ }
+ }
+ })
+
+ // Ensure deleting elements returns the correct new node.
+ RunRandom(t, "Delete", func(t *testing.T, rand *rand.Rand) {
+ var h intHasher
+ var n mapNode = &mapArrayNode{}
+ for i := 0; i < 8; i++ {
+ var resized bool
+ n = n.set(i*10, i, 0, h.Hash(i*10), &h, &resized)
+ }
+
+ for _, i := range rand.Perm(8) {
+ n = n.delete(i*10, 0, h.Hash(i*10), &h)
+ }
+ if n != nil {
+ t.Fatal("expected nil rand")
+ }
+ })
+}
+
+func TestIngernal_mapValueNode(t *testing.T) {
+ t.Run("Simple", func(t *testing.T) {
+ var h intHasher
+ n := newMapValueNode(h.Hash(2), 2, 3)
+ if v, ok := n.get(2, 0, h.Hash(2), &h); !ok {
+ t.Fatal("expected ok")
+ } else if v != 3 {
+ t.Fatalf("unexpected value: %v", v)
+ }
+ })
+
+ t.Run("KeyEqual", func(t *testing.T) {
+ var h intHasher
+ var resized bool
+ n := newMapValueNode(h.Hash(2), 2, 3)
+ other := n.set(2, 4, 0, h.Hash(2), &h, &resized).(*mapValueNode)
+ if other == n {
+ t.Fatal("expected new node")
+ } else if got, exp := other.keyHash, h.Hash(2); got != exp {
+ t.Fatalf("keyHash=%v, expected %v", got, exp)
+ } else if got, exp := other.key, 2; got != exp {
+ t.Fatalf("key=%v, expected %v", got, exp)
+ } else if got, exp := other.value, 4; got != exp {
+ t.Fatalf("value=%v, expected %v", got, exp)
+ } else if resized {
+ t.Fatal("unexpected resize")
+ }
+ })
+
+ t.Run("KeyHashEqual", func(t *testing.T) {
+ h := &mockHasher{
+ hash: func(value interface{}) uint32 { return 1 },
+ equal: func(a, b interface{}) bool { return a.(int) == b.(int) },
+ }
+ var resized bool
+ n := newMapValueNode(h.Hash(2), 2, 3)
+ other := n.set(4, 5, 0, h.Hash(4), h, &resized).(*mapHashCollisionNode)
+ if got, exp := other.keyHash, h.Hash(2); got != exp {
+ t.Fatalf("keyHash=%v, expected %v", got, exp)
+ } else if got, exp := len(other.entries), 2; got != exp {
+ t.Fatalf("entries=%v, expected %v", got, exp)
+ } else if !resized {
+ t.Fatal("expected resize")
+ }
+ if got, exp := other.entries[0].key, 2; got != exp {
+ t.Fatalf("key[0]=%v, expected %v", got, exp)
+ } else if got, exp := other.entries[0].value, 3; got != exp {
+ t.Fatalf("value[0]=%v, expected %v", got, exp)
+ }
+ if got, exp := other.entries[1].key, 4; got != exp {
+ t.Fatalf("key[1]=%v, expected %v", got, exp)
+ } else if got, exp := other.entries[1].value, 5; got != exp {
+ t.Fatalf("value[1]=%v, expected %v", got, exp)
+ }
+ })
+
+ t.Run("MergeNode", func(t *testing.T) {
+ // Inserting into a node with a different index in the mask should split into a bitmap node.
+ t.Run("NoConflict", func(t *testing.T) {
+ var h intHasher
+ var resized bool
+ n := newMapValueNode(h.Hash(2), 2, 3)
+ other := n.set(4, 5, 0, h.Hash(4), &h, &resized).(*mapBitmapIndexedNode)
+ if got, exp := other.bitmap, uint32(0x14); got != exp {
+ t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp)
+ } else if got, exp := len(other.nodes), 2; got != exp {
+ t.Fatalf("nodes=%v, expected %v", got, exp)
+ } else if !resized {
+ t.Fatal("expected resize")
+ }
+ if node, ok := other.nodes[0].(*mapValueNode); !ok {
+ t.Fatalf("node[0]=%T, unexpected type", other.nodes[0])
+ } else if got, exp := node.key, 2; got != exp {
+ t.Fatalf("key[0]=%v, expected %v", got, exp)
+ } else if got, exp := node.value, 3; got != exp {
+ t.Fatalf("value[0]=%v, expected %v", got, exp)
+ }
+ if node, ok := other.nodes[1].(*mapValueNode); !ok {
+ t.Fatalf("node[1]=%T, unexpected type", other.nodes[1])
+ } else if got, exp := node.key, 4; got != exp {
+ t.Fatalf("key[1]=%v, expected %v", got, exp)
+ } else if got, exp := node.value, 5; got != exp {
+ t.Fatalf("value[1]=%v, expected %v", got, exp)
+ }
+
+ // Ensure both values can be read.
+ if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v.(int) != 3 {
+ t.Fatalf("Get(2)=<%v,%v>", v, ok)
+ } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v.(int) != 5 {
+ t.Fatalf("Get(4)=<%v,%v>", v, ok)
+ }
+ })
+
+ // Reversing the nodes from NoConflict should yield the same result.
+ t.Run("NoConflictReverse", func(t *testing.T) {
+ var h intHasher
+ var resized bool
+ n := newMapValueNode(h.Hash(4), 4, 5)
+ other := n.set(2, 3, 0, h.Hash(2), &h, &resized).(*mapBitmapIndexedNode)
+ if got, exp := other.bitmap, uint32(0x14); got != exp {
+ t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp)
+ } else if got, exp := len(other.nodes), 2; got != exp {
+ t.Fatalf("nodes=%v, expected %v", got, exp)
+ } else if !resized {
+ t.Fatal("expected resize")
+ }
+ if node, ok := other.nodes[0].(*mapValueNode); !ok {
+ t.Fatalf("node[0]=%T, unexpected type", other.nodes[0])
+ } else if got, exp := node.key, 2; got != exp {
+ t.Fatalf("key[0]=%v, expected %v", got, exp)
+ } else if got, exp := node.value, 3; got != exp {
+ t.Fatalf("value[0]=%v, expected %v", got, exp)
+ }
+ if node, ok := other.nodes[1].(*mapValueNode); !ok {
+ t.Fatalf("node[1]=%T, unexpected type", other.nodes[1])
+ } else if got, exp := node.key, 4; got != exp {
+ t.Fatalf("key[1]=%v, expected %v", got, exp)
+ } else if got, exp := node.value, 5; got != exp {
+ t.Fatalf("value[1]=%v, expected %v", got, exp)
+ }
+
+ // Ensure both values can be read.
+ if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v.(int) != 3 {
+ t.Fatalf("Get(2)=<%v,%v>", v, ok)
+ } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v.(int) != 5 {
+ t.Fatalf("Get(4)=<%v,%v>", v, ok)
+ }
+ })
+
+ // Inserting a node with the same mask index should nest an additional level of bitmap nodes.
+ t.Run("Conflict", func(t *testing.T) {
+ h := &mockHasher{
+ hash: func(value interface{}) uint32 { return uint32(value.(int)) << 5 },
+ equal: func(a, b interface{}) bool { return a.(int) == b.(int) },
+ }
+ var resized bool
+ n := newMapValueNode(h.Hash(2), 2, 3)
+ other := n.set(4, 5, 0, h.Hash(4), h, &resized).(*mapBitmapIndexedNode)
+ if got, exp := other.bitmap, uint32(0x01); got != exp { // mask is zero, expect first slot.
+ t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp)
+ } else if got, exp := len(other.nodes), 1; got != exp {
+ t.Fatalf("nodes=%v, expected %v", got, exp)
+ } else if !resized {
+ t.Fatal("expected resize")
+ }
+ child, ok := other.nodes[0].(*mapBitmapIndexedNode)
+ if !ok {
+ t.Fatalf("node[0]=%T, unexpected type", other.nodes[0])
+ }
+
+ if node, ok := child.nodes[0].(*mapValueNode); !ok {
+ t.Fatalf("node[0]=%T, unexpected type", child.nodes[0])
+ } else if got, exp := node.key, 2; got != exp {
+ t.Fatalf("key[0]=%v, expected %v", got, exp)
+ } else if got, exp := node.value, 3; got != exp {
+ t.Fatalf("value[0]=%v, expected %v", got, exp)
+ }
+ if node, ok := child.nodes[1].(*mapValueNode); !ok {
+ t.Fatalf("node[1]=%T, unexpected type", child.nodes[1])
+ } else if got, exp := node.key, 4; got != exp {
+ t.Fatalf("key[1]=%v, expected %v", got, exp)
+ } else if got, exp := node.value, 5; got != exp {
+ t.Fatalf("value[1]=%v, expected %v", got, exp)
+ }
+
+ // Ensure both values can be read.
+ if v, ok := other.get(2, 0, h.Hash(2), h); !ok || v.(int) != 3 {
+ t.Fatalf("Get(2)=<%v,%v>", v, ok)
+ } else if v, ok := other.get(4, 0, h.Hash(4), h); !ok || v.(int) != 5 {
+ t.Fatalf("Get(4)=<%v,%v>", v, ok)
+ } else if v, ok := other.get(10, 0, h.Hash(10), h); ok {
+ t.Fatalf("Get(10)=<%v,%v>, expected no value", v, ok)
+ }
+ })
+ })
+}
+
+func TestMap_Get(t *testing.T) {
+ t.Run("Empty", func(t *testing.T) {
+ m := NewMap(nil)
+ if v, ok := m.Get(100); ok || v != nil {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ }
+ })
+}
+
+func TestMap_Set(t *testing.T) {
+ t.Run("Simple", func(t *testing.T) {
+ m := NewMap(nil)
+ itr := m.Iterator()
+ if !itr.Done() {
+ t.Fatal("MapIterator.Done()=true, expected false")
+ } else if k, v := itr.Next(); k != nil || v != nil {
+ t.Fatalf("MapIterator.Next()=<%v,%v>, expected nil", k, v)
+ }
+ })
+
+ t.Run("Simple", func(t *testing.T) {
+ m := NewMap(nil)
+ m = m.Set(100, "foo")
+ if v, ok := m.Get(100); !ok || v != "foo" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ }
+ })
+
+ t.Run("VerySmall", func(t *testing.T) {
+ const n = 6
+ m := NewMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i+1)
+ }
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); !ok || v != i+1 {
+ t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok)
+ }
+ }
+
+ // NOTE: Array nodes store entries in insertion order.
+ itr := m.Iterator()
+ for i := 0; i < n; i++ {
+ if k, v := itr.Next(); k != i || v != i+1 {
+ t.Fatalf("MapIterator.Next()=<%v,%v>, exp <%v,%v>", k, v, i, i+1)
+ }
+ }
+ if !itr.Done() {
+ t.Fatal("expected iterator done")
+ }
+ })
+
+ t.Run("Small", func(t *testing.T) {
+ const n = 1000
+ m := NewMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i+1)
+ }
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); !ok || v != i+1 {
+ t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok)
+ }
+ }
+ })
+
+ t.Run("Large", func(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping: short")
+ }
+
+ const n = 1000000
+ m := NewMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i+1)
+ }
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); !ok || v != i+1 {
+ t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok)
+ }
+ }
+ })
+
+ t.Run("StringKeys", func(t *testing.T) {
+ m := NewMap(nil)
+ m = m.Set("foo", "bar")
+ m = m.Set("baz", "bat")
+ m = m.Set("", "EMPTY")
+ if v, ok := m.Get("foo"); !ok || v != "bar" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ } else if v, ok := m.Get("baz"); !ok || v != "bat" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ } else if v, ok := m.Get(""); !ok || v != "EMPTY" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ }
+ if v, ok := m.Get("no_such_key"); ok {
+ t.Fatalf("expected no value: <%v,%v>", v, ok)
+ }
+ })
+
+ t.Run("ByteSliceKeys", func(t *testing.T) {
+ m := NewMap(nil)
+ m = m.Set([]byte("foo"), "bar")
+ m = m.Set([]byte("baz"), "bat")
+ m = m.Set([]byte(""), "EMPTY")
+ if v, ok := m.Get([]byte("foo")); !ok || v != "bar" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ } else if v, ok := m.Get([]byte("baz")); !ok || v != "bat" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ } else if v, ok := m.Get([]byte("")); !ok || v != "EMPTY" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ }
+ if v, ok := m.Get([]byte("no_such_key")); ok {
+ t.Fatalf("expected no value: <%v,%v>", v, ok)
+ }
+ })
+
+ t.Run("NoDefaultHasher", func(t *testing.T) {
+ var r string
+ func() {
+ defer func() { r = recover().(string) }()
+ m := NewMap(nil)
+ m = m.Set(uint64(100), "bar")
+ }()
+ if r != `immutable.Map.Set: must set hasher for uint64 type` {
+ t.Fatalf("unexpected panic: %q", r)
+ }
+ })
+
+ RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) {
+ m := NewTestMap()
+ for i := 0; i < 100000; i++ {
+ switch rand.Intn(2) {
+ case 1: // overwrite
+ m.Set(m.ExistingKey(rand), rand.Intn(10000))
+ default: // set new key
+ m.Set(m.NewKey(rand), rand.Intn(10000))
+ }
+ }
+ if err := m.Validate(); err != nil {
+ t.Fatal(err)
+ }
+ })
+}
+
+// Ensure map can support overwrites as it expands.
+func TestMap_Overwrite(t *testing.T) {
+ const n = 10000
+ m := NewMap(nil)
+ for i := 0; i < n; i++ {
+ // Set original value.
+ m = m.Set(i, i)
+
+ // Overwrite every node.
+ for j := 0; j <= i; j++ {
+ m = m.Set(j, i*j)
+ }
+ }
+
+ // Verify all key/value pairs in map.
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); !ok || v != i*(n-1) {
+ t.Fatalf("Get(%d)=<%v,%v>", i, v, ok)
+ }
+ }
+}
+
+func TestMap_Delete(t *testing.T) {
+ t.Run("Empty", func(t *testing.T) {
+ m := NewMap(nil)
+ other := m.Delete("foo")
+ if m != other {
+ t.Fatal("expected same map")
+ }
+ })
+
+ t.Run("Simple", func(t *testing.T) {
+ m := NewMap(nil)
+ m = m.Set(100, "foo")
+ if v, ok := m.Get(100); !ok || v != "foo" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ }
+ })
+
+ t.Run("Small", func(t *testing.T) {
+ const n = 1000
+ m := NewMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i+1)
+ }
+ for i := range rand.New(rand.NewSource(0)).Perm(n) {
+ m = m.Delete(i)
+ }
+ if m.Len() != 0 {
+ t.Fatalf("expected no elements, got %d", m.Len())
+ }
+ })
+
+ t.Run("Large", func(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping: short")
+ }
+ const n = 1000000
+ m := NewMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i+1)
+ }
+ for i := range rand.New(rand.NewSource(0)).Perm(n) {
+ m = m.Delete(i)
+ }
+ if m.Len() != 0 {
+ t.Fatalf("expected no elements, got %d", m.Len())
+ }
+ })
+
+ RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) {
+ m := NewTestMap()
+ for i := 0; i < 100000; i++ {
+ switch rand.Intn(8) {
+ case 0: // overwrite
+ m.Set(m.ExistingKey(rand), rand.Intn(10000))
+ case 1: // delete existing key
+ m.Delete(m.ExistingKey(rand))
+ case 2: // delete non-existent key.
+ m.Delete(m.NewKey(rand))
+ default: // set new key
+ m.Set(m.NewKey(rand), rand.Intn(10000))
+ }
+ }
+ if err := m.Validate(); err != nil {
+ t.Fatal(err)
+ }
+
+ // Delete all and verify they are gone.
+ keys := make([]int, len(m.keys))
+ copy(keys, m.keys)
+
+ for _, key := range keys {
+ m.Delete(key)
+ }
+ if err := m.Validate(); err != nil {
+ t.Fatal(err)
+ }
+ })
+}
+
+// Ensure map works even with hash conflicts.
+func TestMap_LimitedHash(t *testing.T) {
+ h := mockHasher{
+ hash: func(value interface{}) uint32 { return hashUint64(uint64(value.(int))) % 0xFF },
+ equal: func(a, b interface{}) bool { return a.(int) == b.(int) },
+ }
+ m := NewMap(&h)
+
+ rand := rand.New(rand.NewSource(0))
+ keys := rand.Perm(100000)
+ for _, i := range keys {
+ m = m.Set(i, i) // initial set
+ }
+ for i := range keys {
+ m = m.Set(i, i*2) // overwrite
+ }
+ if m.Len() != len(keys) {
+ t.Fatalf("unexpected len: %d", m.Len())
+ }
+
+ // Verify all key/value pairs in map.
+ for i := 0; i < m.Len(); i++ {
+ if v, ok := m.Get(i); !ok || v != i*2 {
+ t.Fatalf("Get(%d)=<%v,%v>", i, v, ok)
+ }
+ }
+
+ // Verify iteration.
+ itr := m.Iterator()
+ for !itr.Done() {
+ if k, v := itr.Next(); v != k.(int)*2 {
+ t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k.(int)*2)
+ }
+ }
+
+ // Verify not found works.
+ if _, ok := m.Get(10000000); ok {
+ t.Fatal("expected no value")
+ }
+
+ // Verify delete non-existent key works.
+ if other := m.Delete(10000000 + 1); m != other {
+ t.Fatal("expected no change")
+ }
+
+ // Remove all keys.
+ for _, key := range keys {
+ m = m.Delete(key)
+ }
+ if m.Len() != 0 {
+ t.Fatalf("unexpected size: %d", m.Len())
+ }
+}
+
+// TestMap represents a combined immutable and stdlib map.
+type TestMap struct {
+ im, prev *Map
+ std map[int]int
+ keys []int
+}
+
+func NewTestMap() *TestMap {
+ return &TestMap{
+ im: NewMap(nil),
+ std: make(map[int]int),
+ }
+}
+
+func (m *TestMap) NewKey(rand *rand.Rand) int {
+ for {
+ k := rand.Int()
+ if _, ok := m.std[k]; !ok {
+ return k
+ }
+ }
+}
+
+func (m *TestMap) ExistingKey(rand *rand.Rand) int {
+ if len(m.keys) == 0 {
+ return 0
+ }
+ return m.keys[rand.Intn(len(m.keys))]
+}
+
+func (m *TestMap) Set(k, v int) {
+ m.prev = m.im
+ m.im = m.im.Set(k, v)
+
+ _, exists := m.std[k]
+ if !exists {
+ m.keys = append(m.keys, k)
+ }
+ m.std[k] = v
+}
+
+func (m *TestMap) Delete(k int) {
+ m.prev = m.im
+ m.im = m.im.Delete(k)
+ delete(m.std, k)
+
+ for i := range m.keys {
+ if m.keys[i] == k {
+ m.keys = append(m.keys[:i], m.keys[i+1:]...)
+ break
+ }
+ }
+}
+
+func (m *TestMap) Validate() error {
+ for _, k := range m.keys {
+ if v, ok := m.im.Get(k); !ok {
+ return fmt.Errorf("key not found: %d", k)
+ } else if v != m.std[k] {
+ return fmt.Errorf("key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k])
+ }
+ }
+ if err := m.validateIterator(); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (m *TestMap) validateIterator() error {
+ other := make(map[int]int)
+ itr := m.im.Iterator()
+ for !itr.Done() {
+ k, v := itr.Next()
+ other[k.(int)] = v.(int)
+ }
+ if diff := cmp.Diff(other, m.std); diff != "" {
+ return fmt.Errorf("map iterator mismatch: %s", diff)
+ }
+ if k, v := itr.Next(); k != nil || v != nil {
+ return fmt.Errorf("map iterator returned key/value after done: <%v/%v>", k, v)
+ }
+ return nil
+}
+
+func BenchmarkMap_Set(b *testing.B) {
+ b.ReportAllocs()
+ m := NewMap(nil)
+ for i := 0; i < b.N; i++ {
+ m = m.Set(i, i)
+ }
+}
+
+func BenchmarkMap_Delete(b *testing.B) {
+ const n = 10000
+
+ m := NewMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i)
+ }
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ m.Delete(i % n) // Do not update map, always operate on original
+ }
+}
+
+func BenchmarkMap_Iterator(b *testing.B) {
+ const n = 10000
+ m := NewMap(nil)
+ for i := 0; i < 10000; i++ {
+ m = m.Set(i, i)
+ }
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.Run("Forward", func(b *testing.B) {
+ itr := m.Iterator()
+ for i := 0; i < b.N; i++ {
+ if i%n == 0 {
+ itr.First()
+ }
+ itr.Next()
+ }
+ })
+}
+
+func ExampleMap_Set() {
+ m := NewMap(nil)
+ m = m.Set("foo", "bar")
+ m = m.Set("baz", 100)
+
+ v, ok := m.Get("foo")
+ fmt.Println("foo", v, ok)
+
+ v, ok = m.Get("baz")
+ fmt.Println("baz", v, ok)
+
+ v, ok = m.Get("bat") // does not exist
+ fmt.Println("bat", v, ok)
+ // Output:
+ // foo bar true
+ // baz 100 true
+ // bat <nil> false
+}
+
+func ExampleMap_Delete() {
+ m := NewMap(nil)
+ m = m.Set("foo", "bar")
+ m = m.Set("baz", 100)
+ m = m.Delete("baz")
+
+ v, ok := m.Get("foo")
+ fmt.Println("foo", v, ok)
+
+ v, ok = m.Get("baz")
+ fmt.Println("baz", v, ok)
+ // Output:
+ // foo bar true
+ // baz <nil> false
+}
+
+func ExampleMap_Iterator() {
+ m := NewMap(nil)
+ m = m.Set("apple", 100)
+ m = m.Set("grape", 200)
+ m = m.Set("kiwi", 300)
+ m = m.Set("mango", 400)
+ m = m.Set("orange", 500)
+ m = m.Set("peach", 600)
+ m = m.Set("pear", 700)
+ m = m.Set("pineapple", 800)
+ m = m.Set("strawberry", 900)
+
+ itr := m.Iterator()
+ for !itr.Done() {
+ k, v := itr.Next()
+ fmt.Println(k, v)
+ }
+ // Output:
+ // mango 400
+ // pear 700
+ // pineapple 800
+ // grape 200
+ // orange 500
+ // strawberry 900
+ // kiwi 300
+ // peach 600
+ // apple 100
+}
+
+func TestIngernalSortedMapLeafNode(t *testing.T) {
+ RunRandom(t, "NoSplit", func(t *testing.T, rand *rand.Rand) {
+ var cmpr intComparer
+ var node sortedMapNode = &sortedMapLeafNode{}
+ var keys []int
+ for _, i := range rand.Perm(32) {
+ var resized bool
+ var splitNode sortedMapNode
+ node, splitNode = node.set(i, i*10, &cmpr, &resized)
+ if !resized {
+ t.Fatal("expected resize")
+ } else if splitNode != nil {
+ t.Fatal("expected split")
+ }
+ keys = append(keys, i)
+
+ // Verify not found at each size.
+ if _, ok := node.get(rand.Int()+32, &cmpr); ok {
+ t.Fatal("expected no value")
+ }
+
+ // Verify min key is always the lowest.
+ sort.Ints(keys)
+ if got, exp := node.minKey(), keys[0]; got != exp {
+ t.Fatalf("minKey()=%d, expected %d", got, exp)
+ }
+ }
+
+ // Verify all key/value pairs in node.
+ for i := range keys {
+ if v, ok := node.get(i, &cmpr); !ok || v != i*10 {
+ t.Fatalf("get(%d)=<%v,%v>", i, v, ok)
+ }
+ }
+ })
+
+ RunRandom(t, "Overwrite", func(t *testing.T, rand *rand.Rand) {
+ var cmpr intComparer
+ var node sortedMapNode = &sortedMapLeafNode{}
+ for _, i := range rand.Perm(32) {
+ var resized bool
+ node, _ = node.set(i, i*2, &cmpr, &resized)
+ }
+ for _, i := range rand.Perm(32) {
+ var resized bool
+ node, _ = node.set(i, i*3, &cmpr, &resized)
+ if resized {
+ t.Fatal("expected no resize")
+ }
+ }
+
+ // Verify all overwritten key/value pairs in node.
+ for i := 0; i < 32; i++ {
+ if v, ok := node.get(i, &cmpr); !ok || v != i*3 {
+ t.Fatalf("get(%d)=<%v,%v>", i, v, ok)
+ }
+ }
+ })
+
+ t.Run("Split", func(t *testing.T) {
+ // Fill leaf node.
+ var cmpr intComparer
+ var node sortedMapNode = &sortedMapLeafNode{}
+ for i := 0; i < 32; i++ {
+ var resized bool
+ node, _ = node.set(i, i*10, &cmpr, &resized)
+ }
+
+ // Add one more and expect split.
+ var resized bool
+ newNode, splitNode := node.set(32, 320, &cmpr, &resized)
+
+ // Verify node contents.
+ newLeafNode, ok := newNode.(*sortedMapLeafNode)
+ if !ok {
+ t.Fatalf("unexpected node type: %T", newLeafNode)
+ } else if n := len(newLeafNode.entries); n != 16 {
+ t.Fatalf("unexpected node len: %d", n)
+ }
+ for i := range newLeafNode.entries {
+ if entry := newLeafNode.entries[i]; entry.key != i || entry.value != i*10 {
+ t.Fatalf("%d. unexpected entry: %v=%v", i, entry.key, entry.value)
+ }
+ }
+
+ // Verify split node contents.
+ splitLeafNode, ok := splitNode.(*sortedMapLeafNode)
+ if !ok {
+ t.Fatalf("unexpected split node type: %T", splitLeafNode)
+ } else if n := len(splitLeafNode.entries); n != 17 {
+ t.Fatalf("unexpected split node len: %d", n)
+ }
+ for i := range splitLeafNode.entries {
+ if entry := splitLeafNode.entries[i]; entry.key != (i+16) || entry.value != (i+16)*10 {
+ t.Fatalf("%d. unexpected split node entry: %v=%v", i, entry.key, entry.value)
+ }
+ }
+ })
+}
+
+func TestIngernalSortedMapBranchNode(t *testing.T) {
+ RunRandom(t, "NoSplit", func(t *testing.T, rand *rand.Rand) {
+ keys := make([]int, 32*16)
+ for i := range keys {
+ keys[i] = rand.Intn(10000)
+ }
+ keys = uniqueIntSlice(keys)
+ sort.Ints(keys[:2]) // ensure first two keys are sorted for initial insert.
+
+ // Initialize branch with two leafs.
+ var cmpr intComparer
+ leaf0 := &sortedMapLeafNode{entries: []mapEntry{{key: keys[0], value: keys[0] * 10}}}
+ leaf1 := &sortedMapLeafNode{entries: []mapEntry{{key: keys[1], value: keys[1] * 10}}}
+ var node sortedMapNode = newSortedMapBranchNode(leaf0, leaf1)
+
+ sort.Ints(keys)
+ for _, i := range rand.Perm(len(keys)) {
+ key := keys[i]
+
+ var resized bool
+ var splitNode sortedMapNode
+ node, splitNode = node.set(key, key*10, &cmpr, &resized)
+ if key == leaf0.entries[0].key || key == leaf1.entries[0].key {
+ if resized {
+ t.Fatalf("expected no resize: key=%d", key)
+ }
+ } else {
+ if !resized {
+ t.Fatalf("expected resize: key=%d", key)
+ }
+ }
+ if splitNode != nil {
+ t.Fatal("unexpected split")
+ }
+ }
+
+ // Verify all key/value pairs in node.
+ for _, key := range keys {
+ if v, ok := node.get(key, &cmpr); !ok || v != key*10 {
+ t.Fatalf("get(%d)=<%v,%v>", key, v, ok)
+ }
+ }
+
+ // Verify min key is the lowest key.
+ if got, exp := node.minKey(), keys[0]; got != exp {
+ t.Fatalf("minKey()=%d, expected %d", got, exp)
+ }
+ })
+
+ t.Run("Split", func(t *testing.T) {
+ // Generate leaf nodes.
+ var cmpr intComparer
+ children := make([]sortedMapNode, 32)
+ for i := range children {
+ leaf := &sortedMapLeafNode{entries: make([]mapEntry, 32)}
+ for j := range leaf.entries {
+ leaf.entries[j] = mapEntry{key: (i * 32) + j, value: ((i * 32) + j) * 100}
+ }
+ children[i] = leaf
+ }
+ var node sortedMapNode = newSortedMapBranchNode(children...)
+
+ // Add one more and expect split.
+ var resized bool
+ newNode, splitNode := node.set((32 * 32), (32*32)*100, &cmpr, &resized)
+
+ // Verify node contents.
+ var idx int
+ newBranchNode, ok := newNode.(*sortedMapBranchNode)
+ if !ok {
+ t.Fatalf("unexpected node type: %T", newBranchNode)
+ } else if n := len(newBranchNode.elems); n != 16 {
+ t.Fatalf("unexpected child elems len: %d", n)
+ }
+ for i, elem := range newBranchNode.elems {
+ child, ok := elem.node.(*sortedMapLeafNode)
+ if !ok {
+ t.Fatalf("unexpected child type")
+ }
+ for j, entry := range child.entries {
+ if entry.key != idx || entry.value != idx*100 {
+ t.Fatalf("%d/%d. unexpected entry: %v=%v", i, j, entry.key, entry.value)
+ }
+ idx++
+ }
+ }
+
+ // Verify split node contents.
+ splitBranchNode, ok := splitNode.(*sortedMapBranchNode)
+ if !ok {
+ t.Fatalf("unexpected split node type: %T", splitBranchNode)
+ } else if n := len(splitBranchNode.elems); n != 17 {
+ t.Fatalf("unexpected split node elem len: %d", n)
+ }
+ for i, elem := range splitBranchNode.elems {
+ child, ok := elem.node.(*sortedMapLeafNode)
+ if !ok {
+ t.Fatalf("unexpected split node child type")
+ }
+ for j, entry := range child.entries {
+ if entry.key != idx || entry.value != idx*100 {
+ t.Fatalf("%d/%d. unexpected split node entry: %v=%v", i, j, entry.key, entry.value)
+ }
+ idx++
+ }
+ }
+ })
+}
+
+func TestSortedMap_Get(t *testing.T) {
+ t.Run("Empty", func(t *testing.T) {
+ m := NewSortedMap(nil)
+ if v, ok := m.Get(100); ok || v != nil {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ }
+ })
+}
+
+func TestSortedMap_Set(t *testing.T) {
+ t.Run("Simple", func(t *testing.T) {
+ m := NewSortedMap(nil)
+ m = m.Set(100, "foo")
+ if v, ok := m.Get(100); !ok || v != "foo" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ } else if got, exp := m.Len(), 1; got != exp {
+ t.Fatalf("SortedMap.Len()=%d, exp %d", got, exp)
+ }
+ })
+
+ t.Run("Small", func(t *testing.T) {
+ const n = 1000
+ m := NewSortedMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i+1)
+ }
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); !ok || v != i+1 {
+ t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok)
+ }
+ }
+ })
+
+ t.Run("Large", func(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping: short")
+ }
+
+ const n = 1000000
+ m := NewSortedMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i+1)
+ }
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); !ok || v != i+1 {
+ t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok)
+ }
+ }
+ })
+
+ t.Run("StringKeys", func(t *testing.T) {
+ m := NewSortedMap(nil)
+ m = m.Set("foo", "bar")
+ m = m.Set("baz", "bat")
+ m = m.Set("", "EMPTY")
+ if v, ok := m.Get("foo"); !ok || v != "bar" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ } else if v, ok := m.Get("baz"); !ok || v != "bat" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ } else if v, ok := m.Get(""); !ok || v != "EMPTY" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ }
+ if v, ok := m.Get("no_such_key"); ok {
+ t.Fatalf("expected no value: <%v,%v>", v, ok)
+ }
+ })
+
+ t.Run("ByteSliceKeys", func(t *testing.T) {
+ m := NewSortedMap(nil)
+ m = m.Set([]byte("foo"), "bar")
+ m = m.Set([]byte("baz"), "bat")
+ m = m.Set([]byte(""), "EMPTY")
+ if v, ok := m.Get([]byte("foo")); !ok || v != "bar" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ } else if v, ok := m.Get([]byte("baz")); !ok || v != "bat" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ } else if v, ok := m.Get([]byte("")); !ok || v != "EMPTY" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ }
+ if v, ok := m.Get([]byte("no_such_key")); ok {
+ t.Fatalf("expected no value: <%v,%v>", v, ok)
+ }
+ })
+
+ t.Run("NoDefaultComparer", func(t *testing.T) {
+ var r string
+ func() {
+ defer func() { r = recover().(string) }()
+ m := NewSortedMap(nil)
+ m = m.Set(uint64(100), "bar")
+ }()
+ if r != `immutable.SortedMap.Set: must set comparer for uint64 type` {
+ t.Fatalf("unexpected panic: %q", r)
+ }
+ })
+
+ RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) {
+ m := NewTestSortedMap()
+ for j := 0; j < 10000; j++ {
+ switch rand.Intn(2) {
+ case 1: // overwrite
+ m.Set(m.ExistingKey(rand), rand.Intn(10000))
+ default: // set new key
+ m.Set(m.NewKey(rand), rand.Intn(10000))
+ }
+ }
+ if err := m.Validate(); err != nil {
+ t.Fatal(err)
+ }
+ })
+}
+
+// Ensure map can support overwrites as it expands.
+func TestSortedMap_Overwrite(t *testing.T) {
+ const n = 1000
+ m := NewSortedMap(nil)
+ for i := 0; i < n; i++ {
+ // Set original value.
+ m = m.Set(i, i)
+
+ // Overwrite every node.
+ for j := 0; j <= i; j++ {
+ m = m.Set(j, i*j)
+ }
+ }
+
+ // Verify all key/value pairs in map.
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); !ok || v != i*(n-1) {
+ t.Fatalf("Get(%d)=<%v,%v>", i, v, ok)
+ }
+ }
+}
+
+func TestSortedMap_Delete(t *testing.T) {
+ t.Run("Empty", func(t *testing.T) {
+ m := NewSortedMap(nil)
+ m = m.Delete(100)
+ if n := m.Len(); n != 0 {
+ t.Fatalf("SortedMap.Len()=%d, expected 0", n)
+ }
+ })
+
+ t.Run("Simple", func(t *testing.T) {
+ m := NewSortedMap(nil)
+ m = m.Set(100, "foo")
+ if v, ok := m.Get(100); !ok || v != "foo" {
+ t.Fatalf("unexpected value: <%v,%v>", v, ok)
+ }
+ m = m.Delete(100)
+ if v, ok := m.Get(100); ok {
+ t.Fatalf("unexpected no value: <%v,%v>", v, ok)
+ }
+ })
+
+ t.Run("Small", func(t *testing.T) {
+ const n = 1000
+ m := NewSortedMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i+1)
+ }
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); !ok || v != i+1 {
+ t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok)
+ }
+ }
+
+ for i := 0; i < n; i++ {
+ m = m.Delete(i)
+ }
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); ok {
+ t.Fatalf("expected no value for key=%v: <%v,%v>", i, v, ok)
+ }
+ }
+ })
+
+ t.Run("Large", func(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping: short")
+ }
+
+ const n = 1000000
+ m := NewSortedMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i+1)
+ }
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); !ok || v != i+1 {
+ t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok)
+ }
+ }
+
+ for i := 0; i < n; i++ {
+ m = m.Delete(i)
+ }
+ for i := 0; i < n; i++ {
+ if v, ok := m.Get(i); ok {
+ t.Fatalf("unexpected no value for key=%v: <%v,%v>", i, v, ok)
+ }
+ }
+ })
+
+ RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) {
+ m := NewTestSortedMap()
+ for j := 0; j < 10000; j++ {
+ switch rand.Intn(8) {
+ case 0: // overwrite
+ m.Set(m.ExistingKey(rand), rand.Intn(10000))
+ case 1: // delete existing key
+ m.Delete(m.ExistingKey(rand))
+ case 2: // delete non-existent key.
+ m.Delete(m.NewKey(rand))
+ default: // set new key
+ m.Set(m.NewKey(rand), rand.Intn(10000))
+ }
+ }
+ if err := m.Validate(); err != nil {
+ t.Fatal(err)
+ }
+ })
+}
+
+func TestSortedMap_Iterator(t *testing.T) {
+ t.Run("Seek", func(t *testing.T) {
+ const n = 100
+ m := NewSortedMap(nil)
+ for i := 0; i < n; i += 2 {
+ m = m.Set(fmt.Sprintf("%04d", i), i)
+ }
+
+ t.Run("Exact", func(t *testing.T) {
+ itr := m.Iterator()
+ for i := 0; i < n; i += 2 {
+ itr.Seek(fmt.Sprintf("%04d", i))
+ for j := i; j < n; j += 2 {
+ if k, _ := itr.Next(); k != fmt.Sprintf("%04d", j) {
+ t.Fatalf("%d/%d. SortedMapIterator.Next()=%v, expected key %04d", i, j, k, j)
+ }
+ }
+ if !itr.Done() {
+ t.Fatalf("SortedMapIterator.Done()=true, expected false")
+ }
+ }
+ })
+
+ t.Run("Miss", func(t *testing.T) {
+ itr := m.Iterator()
+ for i := 1; i < n-2; i += 2 {
+ itr.Seek(fmt.Sprintf("%04d", i))
+ for j := i + 1; j < n; j += 2 {
+ if k, _ := itr.Next(); k != fmt.Sprintf("%04d", j) {
+ t.Fatalf("%d/%d. SortedMapIterator.Next()=%v, expected key %04d", i, j, k, j)
+ }
+ }
+ if !itr.Done() {
+ t.Fatalf("SortedMapIterator.Done()=true, expected false")
+ }
+ }
+ })
+
+ t.Run("BeforeFirst", func(t *testing.T) {
+ itr := m.Iterator()
+ itr.Seek("")
+ for i := 0; i < n; i += 2 {
+ if k, _ := itr.Next(); k != fmt.Sprintf("%04d", i) {
+ t.Fatalf("%d. SortedMapIterator.Next()=%v, expected key %04d", i, k, i)
+ }
+ }
+ if !itr.Done() {
+ t.Fatalf("SortedMapIterator.Done()=true, expected false")
+ }
+ })
+ t.Run("AfterLast", func(t *testing.T) {
+ itr := m.Iterator()
+ itr.Seek("1000")
+ if k, _ := itr.Next(); k != nil {
+ t.Fatalf("0. SortedMapIterator.Next()=%v, expected nil key", k)
+ } else if !itr.Done() {
+ t.Fatalf("SortedMapIterator.Done()=true, expected false")
+ }
+ })
+ })
+}
+
+// TestSortedMap represents a combined immutable and stdlib sorted map.
+type TestSortedMap struct {
+ im, prev *SortedMap
+ std map[int]int
+ keys []int
+}
+
+func NewTestSortedMap() *TestSortedMap {
+ return &TestSortedMap{
+ im: NewSortedMap(nil),
+ std: make(map[int]int),
+ }
+}
+
+func (m *TestSortedMap) NewKey(rand *rand.Rand) int {
+ for {
+ k := rand.Int()
+ if _, ok := m.std[k]; !ok {
+ return k
+ }
+ }
+}
+
+func (m *TestSortedMap) ExistingKey(rand *rand.Rand) int {
+ if len(m.keys) == 0 {
+ return 0
+ }
+ return m.keys[rand.Intn(len(m.keys))]
+}
+
+func (m *TestSortedMap) Set(k, v int) {
+ m.prev = m.im
+ m.im = m.im.Set(k, v)
+
+ if _, ok := m.std[k]; !ok {
+ m.keys = append(m.keys, k)
+ sort.Ints(m.keys)
+ }
+ m.std[k] = v
+}
+
+func (m *TestSortedMap) Delete(k int) {
+ m.prev = m.im
+ m.im = m.im.Delete(k)
+ delete(m.std, k)
+
+ for i := range m.keys {
+ if m.keys[i] == k {
+ m.keys = append(m.keys[:i], m.keys[i+1:]...)
+ break
+ }
+ }
+}
+
+func (m *TestSortedMap) Validate() error {
+ for _, k := range m.keys {
+ if v, ok := m.im.Get(k); !ok {
+ return fmt.Errorf("key not found: %d", k)
+ } else if v != m.std[k] {
+ return fmt.Errorf("key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k])
+ }
+ }
+
+ sort.Ints(m.keys)
+ if err := m.validateForwardIterator(); err != nil {
+ return err
+ } else if err := m.validateBackwardIterator(); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (m *TestSortedMap) validateForwardIterator() error {
+ itr := m.im.Iterator()
+ for i, k0 := range m.keys {
+ v0 := m.std[k0]
+ if k1, v1 := itr.Next(); k0 != k1 || v0 != v1 {
+ return fmt.Errorf("%d. SortedMapIterator.Next()=<%v,%v>, expected <%v,%v>", i, k1, v1, k0, v0)
+ }
+
+ done := i == len(m.keys)-1
+ if v := itr.Done(); v != done {
+ return fmt.Errorf("%d. SortedMapIterator.Done()=%v, expected %v", i, v, done)
+ }
+ }
+ if k, v := itr.Next(); k != nil || v != nil {
+ return fmt.Errorf("SortedMapIterator.Next()=<%v,%v>, expected nil after done", k, v)
+ }
+ return nil
+}
+
+func (m *TestSortedMap) validateBackwardIterator() error {
+ itr := m.im.Iterator()
+ itr.Last()
+ for i := len(m.keys) - 1; i >= 0; i-- {
+ k0 := m.keys[i]
+ v0 := m.std[k0]
+ if k1, v1 := itr.Prev(); k0 != k1 || v0 != v1 {
+ return fmt.Errorf("%d. SortedMapIterator.Prev()=<%v,%v>, expected <%v,%v>", i, k1, v1, k0, v0)
+ }
+
+ done := i == 0
+ if v := itr.Done(); v != done {
+ return fmt.Errorf("%d. SortedMapIterator.Done()=%v, expected %v", i, v, done)
+ }
+ }
+ if k, v := itr.Prev(); k != nil || v != nil {
+ return fmt.Errorf("SortedMapIterator.Prev()=<%v,%v>, expected nil after done", k, v)
+ }
+ return nil
+}
+
+func BenchmarkSortedMap_Set(b *testing.B) {
+ b.ReportAllocs()
+ m := NewSortedMap(nil)
+ for i := 0; i < b.N; i++ {
+ m = m.Set(i, i)
+ }
+}
+
+func BenchmarkSortedMap_Delete(b *testing.B) {
+ const n = 10000
+
+ m := NewSortedMap(nil)
+ for i := 0; i < n; i++ {
+ m = m.Set(i, i)
+ }
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ m.Delete(i % n) // Do not update map, always operate on original
+ }
+}
+
+func BenchmarkSortedMap_Iterator(b *testing.B) {
+ const n = 10000
+ m := NewSortedMap(nil)
+ for i := 0; i < 10000; i++ {
+ m = m.Set(i, i)
+ }
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.Run("Forward", func(b *testing.B) {
+ itr := m.Iterator()
+ for i := 0; i < b.N; i++ {
+ if i%n == 0 {
+ itr.First()
+ }
+ itr.Next()
+ }
+ })
+
+ b.Run("Reverse", func(b *testing.B) {
+ itr := m.Iterator()
+ for i := 0; i < b.N; i++ {
+ if i%n == 0 {
+ itr.Last()
+ }
+ itr.Prev()
+ }
+ })
+}
+
+func ExampleSortedMap_Set() {
+ m := NewSortedMap(nil)
+ m = m.Set("foo", "bar")
+ m = m.Set("baz", 100)
+
+ v, ok := m.Get("foo")
+ fmt.Println("foo", v, ok)
+
+ v, ok = m.Get("baz")
+ fmt.Println("baz", v, ok)
+
+ v, ok = m.Get("bat") // does not exist
+ fmt.Println("bat", v, ok)
+ // Output:
+ // foo bar true
+ // baz 100 true
+ // bat <nil> false
+}
+
+func ExampleSortedMap_Delete() {
+ m := NewSortedMap(nil)
+ m = m.Set("foo", "bar")
+ m = m.Set("baz", 100)
+ m = m.Delete("baz")
+
+ v, ok := m.Get("foo")
+ fmt.Println("foo", v, ok)
+
+ v, ok = m.Get("baz")
+ fmt.Println("baz", v, ok)
+ // Output:
+ // foo bar true
+ // baz <nil> false
+}
+
+func ExampleSortedMap_Iterator() {
+ m := NewSortedMap(nil)
+ m = m.Set("strawberry", 900)
+ m = m.Set("kiwi", 300)
+ m = m.Set("apple", 100)
+ m = m.Set("pear", 700)
+ m = m.Set("pineapple", 800)
+ m = m.Set("peach", 600)
+ m = m.Set("orange", 500)
+ m = m.Set("grape", 200)
+ m = m.Set("mango", 400)
+
+ itr := m.Iterator()
+ for !itr.Done() {
+ k, v := itr.Next()
+ fmt.Println(k, v)
+ }
+ // Output:
+ // apple 100
+ // grape 200
+ // kiwi 300
+ // mango 400
+ // orange 500
+ // peach 600
+ // pear 700
+ // pineapple 800
+ // strawberry 900
+}
+
+// RunRandom executes fn multiple times with a different rand.
+func RunRandom(t *testing.T, name string, fn func(t *testing.T, rand *rand.Rand)) {
+ if testing.Short() {
+ t.Skip("short mode")
+ }
+ t.Run(name, func(t *testing.T) {
+ for i := 0; i < *randomN; i++ {
+ t.Run(fmt.Sprintf("%08d", i), func(t *testing.T) {
+ t.Parallel()
+ fn(t, rand.New(rand.NewSource(int64(i))))
+ })
+ }
+ })
+}
+
+func uniqueIntSlice(a []int) []int {
+ m := make(map[int]struct{})
+ other := make([]int, 0, len(a))
+ for _, v := range a {
+ if _, ok := m[v]; ok {
+ continue
+ }
+ m[v] = struct{}{}
+ other = append(other, v)
+ }
+ return other
+}
+
+// mockHasher represents a mock implementation of immutable.Hasher.
+type mockHasher struct {
+ hash func(value interface{}) uint32
+ equal func(a, b interface{}) bool
+}
+
+// Hash executes the mocked HashFn function.
+func (h *mockHasher) Hash(value interface{}) uint32 {
+ return h.hash(value)
+}
+
+// Equal executes the mocked EqualFn function.
+func (h *mockHasher) Equal(a, b interface{}) bool {
+ return h.equal(a, b)
+}
+
+// mockComparer represents a mock implementation of immutable.Comparer.
+type mockComparer struct {
+ compare func(a, b interface{}) int
+}
+
+// Compare executes the mocked CompreFn function.
+func (h *mockComparer) Compare(a, b interface{}) int {
+ return h.compare(a, b)
+}