diff options
-rw-r--r-- | LICENSE | 19 | ||||
-rw-r--r-- | README.md | 257 | ||||
-rw-r--r-- | go.mod | 5 | ||||
-rw-r--r-- | go.sum | 2 | ||||
-rw-r--r-- | immutable.go | 1864 | ||||
-rw-r--r-- | immutable_test.go | 2038 |
6 files changed, 4185 insertions, 0 deletions
@@ -0,0 +1,19 @@ +Copyright 2019 Ben Johnson + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + diff --git a/README.md b/README.md new file mode 100644 index 0000000..71c33d7 --- /dev/null +++ b/README.md @@ -0,0 +1,257 @@ +Immutable   +========= + +This repository contains immutable collection types for Go. It includes +`List`, `Map`, and `SortedMap` implementations. Immutable collections can +provide efficient, lock free sharing of data by requiring that edits to the +collections return new collections. + +The collection types in this library are meant to mimic Go built-in collections +such as`slice` and `map`. The primary usage difference between Go collections +and `immutable` collections is that `immutable` collections always return a new +collection on mutation so you will need to save the new reference. + +Immutable collections are not for every situation, however, as they can incur +additional CPU and memory overhead. Please evaluate the cost/benefit for your +particular project. + +Special thanks to the [Immutable.js](https://immutable-js.github.io/immutable-js/) +team as the `List` & `Map` implementations are loose ports from that project. + + +## List + +The `List` type represents a sorted, indexed collection of values and operates +similarly to a Go slice. It supports efficient append, prepend, update, and +slice operations. + + +### Adding list elements + +Elements can be added to the end of the list with the `Append()` method or added +to the beginning of the list with the `Prepend()` method. Unlike Go slices, +prepending is as efficient as appending. + +```go +// Create a list with 3 elements. +l := immutable.NewList() +l = l.Append("foo") +l = l.Append("bar") +l = l.Prepend("baz") + +fmt.Println(l.Len()) // 3 +fmt.Println(l.Get(0)) // "baz" +fmt.Println(l.Get(1)) // "foo" +fmt.Println(l.Get(2)) // "bar" +``` + +Note that each change to the list results in a new list being created. These +lists are all snapshots at that point in time and cannot be changed so they +are safe to share between multiple goroutines. + +### Updating list elements + +You can also overwrite existing elements by using the `Set()` method. In the +following example, we'll update the third element in our list and return the +new list to a new variable. You can see that our old `l` variable retains a +snapshot of the original value. + +```go +l := immutable.NewList() +l = l.Append("foo") +l = l.Append("bar") +newList := l.Set(2, "baz") + +fmt.Println(l.Get(1)) // "bar" +fmt.Println(newList.Get(1)) // "baz" +``` + +### Deriving sublists + +You can create a sublist by using the `Slice()` method. This method works with +the same rules as subslicing a Go slice: + +```go +l = l.Slice(0, 2) + +fmt.Println(l.Len()) // 2 +fmt.Println(l.Get(0)) // "baz" +fmt.Println(l.Get(1)) // "foo" +``` + +Please note that since `List` follows the same rules as slices, it will panic if +you try to `Get()`, `Set()`, or `Slice()` with indexes that are outside of +the range of the `List`. + + + +### Iterating lists + +Iterators provide a clean, simple way to iterate over the elements of the list +in order. This is more efficient than simply calling `Get()` for each index. + +Below is an example of iterating over all elements of our list from above: + +```go +itr := l.Iterator() +for !itr.Done() { + index, value := itr.Next() + fmt.Printf("Index %d equals %v\n", index, value) +} + +// Index 0 equals baz +// Index 1 equals foo +``` + +By default iterators start from index zero, however, the `Seek()` method can be +used to jump to a given index. + + + +## Map + +The `Map` represents an associative array that maps unique keys to values. It +is implemented to act similarly to the built-in Go `map` type. It is implemented +as a [Hash-Array Mapped Trie](https://lampwww.epfl.ch/papers/idealhashtrees.pdf). + +Maps require a `Hasher` to hash keys and check for equality. There are built-in +hasher implementations for `int`, `string`, and `[]byte` keys. You may pass in +a `nil` hasher to `NewMap()` if you are using one of these key types. + + +### Setting map key/value pairs + +You can add a key/value pair to the map by using the `Set()` method. It will +add the key if it does not exist or it will overwrite the value for the key if +it does exist. + +Values may be fetched for a key using the `Get()` method. This method returns +the value as well as a flag indicating if the key existed. The flag is useful +to check if a `nil` value was set for a key versus a key did not exist. + +```go +m := immutable.NewMap(nil) +m = m.Set("jane", 100) +m = m.Set("susy", 200) +m = m.Set("jane", 300) // overwrite + +fmt.Println(m.Len()) // 2 + +v, ok := m.Get("jane") +fmt.Println(v, ok) // 300 true + +v, ok = m.Get("susy") +fmt.Println(v, ok) // 200, true + +v, ok = m.Get("john") +fmt.Println(v, ok) // nil, false +``` + + +### Removing map keys + +Keys may be removed from the map by using the `Delete()` method. If the key does +not exist then the original map is returned instead of a new one. + +```go +m := immutable.NewMap(nil) +m = m.Set("jane", 100) +m = m.Delete("jane") + +fmt.Println(m.Len()) // 0 + +v, ok := m.Get("jane") +fmt.Println(v, ok) // nil false +``` + + +### Iterating maps + +Maps are unsorted, however, iterators can be used to loop over all key/value +pairs in the collection. Unlike Go maps, iterators are deterministic when +iterating over key/value pairs. + +```go +m := immutable.NewMap(nil) +m = m.Set("jane", 100) +m = m.Set("susy", 200) + +itr := m.Iterator() +for !itr.Done() { + k, v := itr.Next() + fmt.Println(k, v) +} + +// susy 200 +// jane 100 +``` + +Note that you should not rely on two maps with the same key/value pairs to +iterate in the same order. Ordering can be insertion order dependent when two +keys generate the same hash. + + +### Implementing a custom Hasher + +If you need to use a key type besides `int`, `string`, or `[]byte` then you'll +need to create a custom `Hasher` implementation and pass it to `NewMap()` on +creation. + +Hashers are fairly simple. They only need to generate hashes for a given key +and check equality given two keys. + +```go +type Hasher interface { + Hash(key interface{}) uint32 + Equal(a, b interface{}) bool +} +``` + +Please see the `IntHasher`, `StringHasher`, or `ByteSliceHasher` for examples. + + +## Sorted Map + +The `SortedMap` represents an associative array that maps unique keys to values. +Unlike the `Map`, however, keys can be iterated over in-order. It is implemented +as a B+tree. + +Sorted maps require a `Comparer` to sort keys and check for equality. There are +built-in comparer implementations for `int`, `string`, and `[]byte` keys. You +may pass a `nil` comparer to `NewSortedMap()` if you are using one of these key +types. + +The API is identical to the `Map` implementation. + + +### Implementing a custom Comparer + +If you need to use a key type besides `int`, `string`, or `[]byte` then you'll +need to create a custom `Comparer` implementation and pass it to +`NewSortedMap()` on creation. + +Comparers on have one method—`Compare()`. It works the same as the +`strings.Compare()` function. It returns `-1` if `a` is less than `b`, returns +`1` if a is greater than `b`, and returns `0` if `a` is equal to `b`. + +```go +type Comparer interface { + Compare(a, b interface{}) int +} +``` + +Please see the `IntComparer`, `StringComparer`, or `ByteSliceComparer` for examples. + + + +## Contributing + +The goal of `immutable` is to provide stable, reasonably performant, immutable +collections library for Go that has a simple, idiomatic API. As such, additional +features and minor performance improvements will generally not be accepted. If +you have a suggestion for a clearer API or substantial performance improvement, +_please_ open an issue first to discuss. All pull requests without a related +issue will be closed immediately. + +Please submit issues relating to bugs & documentation improvements. + @@ -0,0 +1,5 @@ +module github.com/benbjohnson/immutable + +go 1.12 + +require github.com/google/go-cmp v0.2.0 @@ -0,0 +1,2 @@ +github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= diff --git a/immutable.go b/immutable.go new file mode 100644 index 0000000..6f5e6f6 --- /dev/null +++ b/immutable.go @@ -0,0 +1,1864 @@ +// Package immutable provides immutable collection types. +// +// Introduction +// +// Immutable collections provide an efficient, safe way to share collections +// of data while minimizing locks. The collections in this package provide +// List, Map, and SortedMap implementations. These act similarly to slices +// and maps, respectively, except that altering a collection returns a new +// copy of the collection with that change. +// +// Because collections are unable to change, they are safe for multiple +// goroutines to read from at the same time without a mutex. However, these +// types of collections come with increased CPU & memory usage as compared +// with Go's built-in collection types so please evaluate for your specific +// use. +// +// Collection Types +// +// The List type provides an API similar to Go slices. They allow appending, +// prepending, and updating of elements. Elements can also be fetched by index +// or iterated over using a ListIterator. +// +// The Map & SortedMap types provide an API similar to Go maps. They allow +// values to be assigned to unique keys and allow for the deletion of keys. +// Values can be fetched by key and key/value pairs can be iterated over using +// the appropriate iterator type. Both map types provide the same API. The +// SortedMap, however, provides iteration over sorted keys while the Map +// provides iteration over unsorted keys. Maps improved performance and memory +// usage as compared to SortedMaps. +// +// Hashing and Sorting +// +// Map types require the use of a Hasher implementation to calculate hashes for +// their keys and check for key equality. SortedMaps require the use of a +// Comparer implementation to sort keys in the map. +// +// These collection types automatically provide built-in hasher and comparers +// for int, string, and byte slice keys. If you are using one of these key types +// then simply pass a nil into the constructor. Otherwise you will need to +// implement a custom Hasher or Comparer type. Please see the provided +// implementations for reference. +package immutable + +import ( + "bytes" + "fmt" + "math/bits" + "sort" + "strings" +) + +// Lists are dense, ordered, indexed collections. They are analogous to slices +// in Go. They can be updated by appending to the end of the list, prepending +// values to the beginning of the list, or updating existing indexes in the +// list. +type List struct { + root listNode // root node + origin int // offset to zero index element + size int // total number of elements in use +} + +// NewList returns a new empty instance of List. +func NewList() *List { + return &List{ + root: &listLeafNode{}, + } +} + +// Len returns the number of elements in the list. +func (l *List) Len() int { + return l.size +} + +// cap returns the total number of possible elements for the current depth. +func (l *List) cap() int { + return 1 << (l.root.depth() * listNodeBits) +} + +// Get returns the value at the given index. Similar to slices, this method will +// panic if index is below zero or is greater than or equal to the list size. +func (l *List) Get(index int) interface{} { + if index < 0 || index >= l.size { + panic(fmt.Sprintf("immutable.List.Get: index %d out of bounds", index)) + } + return l.root.get(l.origin + index) +} + +// Set returns a new list with value set at index. Similar to slices, this +// method will panic if index is below zero or if the index is greater than +// or equal to the list size. +func (l *List) Set(index int, value interface{}) *List { + if index < 0 || index >= l.size { + panic(fmt.Sprintf("immutable.List.Set: index %d out of bounds", index)) + } + other := *l + other.root = other.root.set(l.origin+index, value) + return &other +} + +// Append returns a new list with value added to the end of the list. +func (l *List) Append(value interface{}) *List { + // Expand list to the right if no slots remain. + other := *l + if other.size+other.origin >= l.cap() { + newRoot := &listBranchNode{d: other.root.depth() + 1} + newRoot.children[0] = other.root + other.root = newRoot + } + + // Increase size and set the last element to the new value. + other.size++ + other.root = other.root.set(other.origin+other.size-1, value) + return &other +} + +// Prepend returns a new list with value added to the beginning of the list. +func (l *List) Prepend(value interface{}) *List { + // Expand list to the left if no slots remain. + other := *l + if other.origin == 0 { + newRoot := &listBranchNode{d: other.root.depth() + 1} + newRoot.children[listNodeSize-1] = other.root + other.root = newRoot + other.origin += (listNodeSize - 1) << (other.root.depth() * listNodeBits) + } + + // Increase size and move origin back. Update first element to value. + other.size++ + other.origin-- + other.root = other.root.set(other.origin, value) + return &other +} + +// Slice returns a new list of elements between start index and end index. +// Similar to slices, this method will panic if start or end are below zero or +// greater than the list size. A panic will also occur if start is greater than +// end. +// +// Unlike Go slices, references to inaccessible elements will be automatically +// removed so they can be garbage collected. +func (l *List) Slice(start, end int) *List { + // Panics similar to Go slices. + if start < 0 || start > l.size { + panic(fmt.Sprintf("immutable.List.Slice: start index %d out of bounds", start)) + } else if end < 0 || end > l.size { + panic(fmt.Sprintf("immutable.List.Slice: end index %d out of bounds", end)) + } else if start > end { + panic(fmt.Sprintf("immutable.List.Slice: invalid slice index: [%d:%d]", start, end)) + } + + // Return the same list if the start and end are the entire range. + if start == 0 && end == l.size { + return l + } + + // Create copy with new origin/size. + other := *l + other.origin = l.origin + start + other.size = end - start + + // Contract tree while the start & end are in the same child node. + for other.root.depth() > 1 { + i := (other.origin >> (other.root.depth() * listNodeBits)) & listNodeMask + j := ((other.origin + other.size - 1) >> (other.root.depth() * listNodeBits)) & listNodeMask + if i != j { + break // branch contains at least two nodes, exit + } + + // Replace the current root with the single child & update origin offset. + other.origin -= i << (other.root.depth() * listNodeBits) + other.root = other.root.(*listBranchNode).children[i] + } + + // Ensure all references are removed before start & after end. + other.root = other.root.deleteBefore(other.origin) + other.root = other.root.deleteAfter(other.origin + other.size - 1) + + return &other +} + +// Iterator returns a new iterator for this list positioned at the first index. +func (l *List) Iterator() *ListIterator { + itr := &ListIterator{list: l} + itr.First() + return itr +} + +// Constants for bit shifts used for levels in the List trie. +const ( + listNodeBits = 5 + listNodeSize = 1 << listNodeBits + listNodeMask = listNodeSize - 1 +) + +// listNode represents either a branch or leaf node in a List. +type listNode interface { + depth() uint + get(index int) interface{} + set(index int, v interface{}) listNode + + containsBefore(index int) bool + containsAfter(index int) bool + + deleteBefore(index int) listNode + deleteAfter(index int) listNode +} + +// newListNode returns a leaf node for depth zero, otherwise returns a branch node. +func newListNode(depth uint) listNode { + if depth == 0 { + return &listLeafNode{} + } + return &listBranchNode{d: depth} +} + +// listBranchNode represents a branch of a List tree at a given depth. +type listBranchNode struct { + d uint // depth + children [listNodeSize]listNode +} + +// depth returns the depth of this branch node from the leaf. +func (n *listBranchNode) depth() uint { return n.d } + +// get returns the child node at the segment of the index for this depth. +func (n *listBranchNode) get(index int) interface{} { + idx := (index >> (n.d * listNodeBits)) & listNodeMask + return n.children[idx].get(index) +} + +// set recursively updates the value at index for each lower depth from the node. +func (n *listBranchNode) set(index int, v interface{}) listNode { + idx := (index >> (n.d * listNodeBits)) & listNodeMask + + // Find child for the given value in the branch. Create new if it doesn't exist. + child := n.children[idx] + if child == nil { + child = newListNode(n.depth() - 1) + } + + // Return a copy of this branch with the new child. + other := *n + other.children[idx] = child.set(index, v) + return &other +} + +// containsBefore returns true if non-nil values exists between [0,index). +func (n *listBranchNode) containsBefore(index int) bool { + idx := (index >> (n.d * listNodeBits)) & listNodeMask + + // Quickly check if any direct children exist before this segment of the index. + for i := 0; i < idx; i++ { + if n.children[i] != nil { + return true + } + } + + // Recursively check for children directly at the given index at this segment. + if n.children[idx] != nil && n.children[idx].containsBefore(index) { + return true + } + return false +} + +// containsAfter returns true if non-nil values exists between (index,listNodeSize). +func (n *listBranchNode) containsAfter(index int) bool { + idx := (index >> (n.d * listNodeBits)) & listNodeMask + + // Quickly check if any direct children exist after this segment of the index. + for i := idx + 1; i < len(n.children); i++ { + if n.children[i] != nil { + return true + } + } + + // Recursively check for children directly at the given index at this segment. + if n.children[idx] != nil && n.children[idx].containsAfter(index) { + return true + } + return false +} + +// deleteBefore returns a new node with all elements before index removed. +func (n *listBranchNode) deleteBefore(index int) listNode { + // Ignore if no nodes exist before the given index. + if !n.containsBefore(index) { + return n + } + + // Return a copy with any nodes prior to the index removed. + idx := (index >> (n.d * listNodeBits)) & listNodeMask + other := &listBranchNode{d: n.d} + copy(other.children[idx:][:], n.children[idx:][:]) + if other.children[idx] != nil { + other.children[idx] = other.children[idx].deleteBefore(index) + } + return other +} + +// deleteBefore returns a new node with all elements before index removed. +func (n *listBranchNode) deleteAfter(index int) listNode { + // Ignore if no nodes exist after the given index. + if !n.containsAfter(index) { + return n + } + + // Return a copy with any nodes after the index removed. + idx := (index >> (n.d * listNodeBits)) & listNodeMask + other := &listBranchNode{d: n.d} + copy(other.children[:idx+1], n.children[:idx+1]) + if other.children[idx] != nil { + other.children[idx] = other.children[idx].deleteAfter(index) + } + return other +} + +// listLeafNode represents a leaf node in a List. +type listLeafNode struct { + children [listNodeSize]interface{} +} + +// depth always returns 0 for leaf nodes. +func (n *listLeafNode) depth() uint { return 0 } + +// get returns the value at the given index. +func (n *listLeafNode) get(index int) interface{} { + return n.children[index&listNodeMask] +} + +// set returns a copy of the node with the value at the index updated to v. +func (n *listLeafNode) set(index int, v interface{}) listNode { + idx := index & listNodeMask + other := *n + other.children[idx] = v + return &other +} + +// containsBefore returns true if non-nil values exists between [0,index). +func (n *listLeafNode) containsBefore(index int) bool { + idx := index & listNodeMask + for i := 0; i < idx; i++ { + if n.children[i] != nil { + return true + } + } + return false +} + +// containsAfter returns true if non-nil values exists between (index,listNodeSize). +func (n *listLeafNode) containsAfter(index int) bool { + idx := index & listNodeMask + for i := idx + 1; i < len(n.children); i++ { + if n.children[i] != nil { + return true + } + } + return false +} + +// deleteBefore returns a new node with all elements before index removed. +func (n *listLeafNode) deleteBefore(index int) listNode { + if !n.containsBefore(index) { + return n + } + + idx := index & listNodeMask + var other listLeafNode + copy(other.children[idx:][:], n.children[idx:][:]) + return &other +} + +// deleteBefore returns a new node with all elements before index removed. +func (n *listLeafNode) deleteAfter(index int) listNode { + if !n.containsAfter(index) { + return n + } + + idx := index & listNodeMask + var other listLeafNode + copy(other.children[:idx+1][:], n.children[:idx+1][:]) + return &other +} + +// ListIterator represents an ordered iterator over a list. +type ListIterator struct { + list *List // source list + index int // current index position + + stack [32]listIteratorElem // search stack + depth int // stack depth +} + +// Done returns true if no more elements remain in the iterator. +func (itr *ListIterator) Done() bool { + return itr.index < 0 || itr.index >= itr.list.Len() +} + +// First positions the iterator on the first index. +// If source list is empty then no change is made. +func (itr *ListIterator) First() { + if itr.list.Len() != 0 { + itr.Seek(0) + } +} + +// Last positions the iterator on the last index. +// If source list is empty then no change is made. +func (itr *ListIterator) Last() { + if n := itr.list.Len(); n != 0 { + itr.Seek(n - 1) + } +} + +// Seek moves the iterator position to the given index in the list. +// Similar to Go slices, this method will panic if index is below zero or if +// the index is greater than or equal to the list size. +func (itr *ListIterator) Seek(index int) { + // Panic similar to Go slices. + if index < 0 || index >= itr.list.Len() { + panic(fmt.Sprintf("immutable.ListIterator.Seek: index %d out of bounds", index)) + } + itr.index = index + + // Reset to the bottom of the stack at seek to the correct position. + itr.stack[0] = listIteratorElem{node: itr.list.root} + itr.depth = 0 + itr.seek(index) +} + +// Next returns the current index and its value & moves the iterator forward. +// Returns an index of -1 if the there are no more elements to return. +func (itr *ListIterator) Next() (index int, value interface{}) { + // Exit immediately if there are no elements remaining. + if itr.Done() { + return -1, nil + } + + // Retrieve current index & value. + elem := &itr.stack[itr.depth] + index, value = itr.index, elem.node.(*listLeafNode).children[elem.index] + + // Increase index. If index is at the end then return immediately. + itr.index++ + if itr.Done() { + return index, value + } + + // Move up stack until we find a node that has remaining position ahead. + for ; itr.depth > 0 && itr.stack[itr.depth].index >= listNodeSize-1; itr.depth-- { + } + + // Seek to correct position from current depth. + itr.seek(itr.index) + + return index, value +} + +// Prev returns the current index and value and moves the iterator backward. +// Returns an index of -1 if the there are no more elements to return. +func (itr *ListIterator) Prev() (index int, value interface{}) { + // Exit immediately if there are no elements remaining. + if itr.Done() { + return -1, nil + } + + // Retrieve current index & value. + elem := &itr.stack[itr.depth] + index, value = itr.index, elem.node.(*listLeafNode).children[elem.index] + + // Decrease index. If index is past the beginning then return immediately. + itr.index-- + if itr.Done() { + return index, value + } + + // Move up stack until we find a node that has remaining position behind. + for ; itr.depth > 0 && itr.stack[itr.depth].index == 0; itr.depth-- { + } + + // Seek to correct position from current depth. + itr.seek(itr.index) + + return index, value +} + +// seek positions the stack to the given index from the current depth. +// Elements and indexes below the current depth are assumed to be correct. +func (itr *ListIterator) seek(index int) { + // Iterate over each level until we reach a leaf node. + for { + elem := &itr.stack[itr.depth] + elem.index = ((itr.list.origin + index) >> (elem.node.depth() * listNodeBits)) & listNodeMask + + switch node := elem.node.(type) { + case *listBranchNode: + child := node.children[elem.index] + itr.stack[itr.depth+1] = listIteratorElem{node: child} + itr.depth++ + case *listLeafNode: + return + } + } +} + +// listIteratorElem represents the node and it's child index within the stack. +type listIteratorElem struct { + node listNode + index int +} + +// Size thresholds for each type of branch node. +const ( + maxArrayMapSize = 8 + maxBitmapIndexedSize = 16 +) + +// Segment bit shifts within the map tree. +const ( + mapNodeBits = 5 + mapNodeSize = 1 << mapNodeBits + mapNodeMask = mapNodeSize - 1 +) + +// Map represents an immutable hash map implementation. The map uses a Hasher +// to generate hashes and check for equality of key values. +// +// It is implemented as an Hash Array Mapped Trie. +type Map struct { + size int // total number of key/value pairs + root mapNode // root node of trie + hasher Hasher // hasher implementation +} + +// NewMap returns a new instance of Map. If hasher is nil, a default hasher +// implementation will automatically be chosen based on the first key added. +// Default hasher implementations only exist for int, string, and byte slice types. +func NewMap(hasher Hasher) *Map { + return &Map{ + hasher: hasher, + } +} + +// Len returns the number of elements in the map. +func (m *Map) Len() int { + return m.size +} + +// Get returns the value for a given key and a flag indicating whether the +// key exists. This flag distinguishes a nil value set on a key versus a +// non-existent key in the map. +func (m *Map) Get(key interface{}) (value interface{}, ok bool) { + if m.root == nil { + return nil, false + } + keyHash := m.hasher.Hash(key) + return m.root.get(key, 0, keyHash, m.hasher) +} + +// Set returns a map with the key set to the new value. A nil value is allowed. +// +// This function will return a new map even if the updated value is the same as +// the existing value because Map does not track value equality. +func (m *Map) Set(key, value interface{}) *Map { + // Set a hasher on the first value if one does not already exist. + hasher := m.hasher + if hasher == nil { + switch key.(type) { + case int: + hasher = &intHasher{} + case string: + hasher = &stringHasher{} + case []byte: + hasher = &byteSliceHasher{} + default: + panic(fmt.Sprintf("immutable.Map.Set: must set hasher for %T type", key)) + } + } + + // If the map is empty, initialize with a simple array node. + if m.root == nil { + return &Map{ + size: 1, + root: &mapArrayNode{entries: []mapEntry{{key: key, value: value}}}, + hasher: hasher, + } + } + + // Otherwise copy the map and delegate insertion to the root. + // Resized will return true if the key does not currently exist. + var resized bool + other := &Map{ + size: m.size, + root: m.root.set(key, value, 0, hasher.Hash(key), hasher, &resized), + hasher: hasher, + } + if resized { + other.size++ + } + return other +} + +// Delete returns a map with the given key removed. +// Removing a non-existent key will cause this method to return the same map. +func (m *Map) Delete(key interface{}) *Map { + // Return original map if no keys exist. + if m.root == nil { + return m + } + + // If the delete did not change the node then return the original map. + newRoot := m.root.delete(key, 0, m.hasher.Hash(key), m.hasher) + if newRoot == m.root { + return m + } + + // Return copy of map with new root and decreased size. + return &Map{ + size: m.size - 1, + root: newRoot, + hasher: m.hasher, + } +} + +// Iterator returns a new iterator for the map. +func (m *Map) Iterator() *MapIterator { + itr := &MapIterator{m: m} + itr.First() + return itr +} + +// mapNode represents any node in the map tree. +type mapNode interface { + get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) + set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode + delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode +} + +var _ mapNode = (*mapArrayNode)(nil) +var _ mapNode = (*mapBitmapIndexedNode)(nil) +var _ mapNode = (*mapHashArrayNode)(nil) +var _ mapNode = (*mapValueNode)(nil) +var _ mapNode = (*mapHashCollisionNode)(nil) + +// mapLeafNode represents a node that stores a single key hash at the leaf of the map tree. +type mapLeafNode interface { + mapNode + keyHashValue() uint32 +} + +var _ mapLeafNode = (*mapValueNode)(nil) +var _ mapLeafNode = (*mapHashCollisionNode)(nil) + +// mapArrayNode is a map node that stores key/value pairs in a slice. +// Entries are stored in insertion order. An array node expands into a bitmap +// indexed node once a given threshold size is crossed. +type mapArrayNode struct { + entries []mapEntry +} + +// indexOf returns the entry index of the given key. Returns -1 if key not found. +func (n *mapArrayNode) indexOf(key interface{}, h Hasher) int { + for i := range n.entries { + if h.Equal(n.entries[i].key, key) { + return i + } + } + return -1 +} + +// get returns the value for the given key. +func (n *mapArrayNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) { + i := n.indexOf(key, h) + if i == -1 { + return nil, false + } + return n.entries[i].value, true +} + +// set inserts or updates the value for a given key. If the key is inserted and +// the new size crosses the max size threshold, a bitmap indexed node is returned. +func (n *mapArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode { + idx := n.indexOf(key, h) + + // Mark as resized if the key doesn't exist. + if idx == -1 { + *resized = true + } + + // If we are adding and it crosses the max size threshold, expand the node. + // We do this by continually setting the entries to a value node and expanding. + if idx == -1 && len(n.entries) >= maxArrayMapSize { + var node mapNode = newMapValueNode(h.Hash(key), key, value) + for _, entry := range n.entries { + node = node.set(entry.key, entry.value, 0, h.Hash(entry.key), h, resized) + } + return node + } + + // Update existing entry if a match is found. + // Otherwise append to the end of the element list if it doesn't exist. + var other mapArrayNode + if idx != -1 { + other.entries = make([]mapEntry, len(n.entries)) + copy(other.entries, n.entries) + other.entries[idx] = mapEntry{key, value} + } else { + other.entries = make([]mapEntry, len(n.entries)+1) + copy(other.entries, n.entries) + other.entries[len(other.entries)-1] = mapEntry{key, value} + } + return &other +} + +// delete removes the given key from the node. Returns the same node if key does +// not exist. Returns a nil node when removing the last entry. +func (n *mapArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode { + idx := n.indexOf(key, h) + + // Return original node if key does not exist. + if idx == -1 { + return n + } + + // Return nil if this node will contain no nodes. + if len(n.entries) == 1 { + return nil + } + + // Otherwise create a copy with the given entry removed. + other := &mapArrayNode{entries: make([]mapEntry, len(n.entries)-1)} + copy(other.entries[:idx], n.entries[:idx]) + copy(other.entries[idx:], n.entries[idx+1:]) + return other +} + +// mapBitmapIndexedNode represents a map branch node with a variable number of +// node slots and indexed using a bitmap. Indexes for the node slots are +// calculated by counting the number of set bits before the target bit using popcount. +type mapBitmapIndexedNode struct { + bitmap uint32 + nodes []mapNode +} + +// get returns the value for the given key. +func (n *mapBitmapIndexedNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) { + bit := uint32(1) << ((keyHash >> shift) & mapNodeMask) + if (n.bitmap & bit) == 0 { + return nil, false + } + child := n.nodes[bits.OnesCount32(n.bitmap&(bit-1))] + return child.get(key, shift+mapNodeBits, keyHash, h) +} + +// set inserts or updates the value for the given key. If a new key is inserted +// and the size crosses the max size threshold then a hash array node is returned. +func (n *mapBitmapIndexedNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode { + // Extract the index for the bit segment of the key hash. + keyHashFrag := (keyHash >> shift) & mapNodeMask + + // Determine the bit based on the hash index. + bit := uint32(1) << keyHashFrag + exists := (n.bitmap & bit) != 0 + + // Mark as resized if the key doesn't exist. + if !exists { + *resized = true + } + + // Find index of node based on popcount of bits before it. + idx := bits.OnesCount32(n.bitmap & (bit - 1)) + + // If the node already exists, delegate set operation to it. + // If the node doesn't exist then create a simple value leaf node. + var newNode mapNode + if exists { + newNode = n.nodes[idx].set(key, value, shift+mapNodeBits, keyHash, h, resized) + } else { + newNode = newMapValueNode(keyHash, key, value) + } + + // Convert to a hash-array node once we exceed the max bitmap size. + // Copy each node based on their bit position within the bitmap. + if !exists && len(n.nodes) > maxBitmapIndexedSize { + var other mapHashArrayNode + for i := uint(0); i < uint(len(other.nodes)); i++ { + if n.bitmap&(uint32(1)<<i) != 0 { + other.nodes[i] = n.nodes[other.count] + other.count++ + } + } + other.nodes[keyHashFrag] = newNode + other.count++ + return &other + } + + // If node exists at given slot then overwrite it with new node. + // Otherwise expand the node list and insert new node into appropriate position. + other := &mapBitmapIndexedNode{bitmap: n.bitmap | bit} + if exists { + other.nodes = make([]mapNode, len(n.nodes)) + copy(other.nodes, n.nodes) + other.nodes[idx] = newNode + } else { + other.nodes = make([]mapNode, len(n.nodes)+1) + copy(other.nodes, n.nodes[:idx]) + other.nodes[idx] = newNode + copy(other.nodes[idx+1:], n.nodes[idx:]) + } + return other +} + +// delete removes the key from the tree. If the key does not exist then the +// original node is returned. If removing the last child node then a nil is +// returned. Note that shrinking the node will not convert it to an array node. +func (n *mapBitmapIndexedNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode { + bit := uint32(1) << ((keyHash >> shift) & mapNodeMask) + + // Return original node if key does not exist. + if (n.bitmap & bit) == 0 { + return n + } + + // Find index of node based on popcount of bits before it. + idx := bits.OnesCount32(n.bitmap & (bit - 1)) + + // Delegate delete to child node. + child := n.nodes[idx] + newChild := child.delete(key, shift+mapNodeBits, keyHash, h) + + // Return original node if key doesn't exist in child. + if newChild == child { + return n + } + + // Remove if returned child has been deleted. + if newChild == nil { + // If we won't have any children then return nil. + if len(n.nodes) == 1 { + return nil + } + + // Return copy with bit removed from bitmap and node removed from node list. + other := &mapBitmapIndexedNode{bitmap: n.bitmap ^ bit, nodes: make([]mapNode, len(n.nodes)-1)} + copy(other.nodes[:idx], n.nodes[:idx]) + copy(other.nodes[idx:], n.nodes[idx+1:]) + return other + } + + // Return copy with child updated. + other := &mapBitmapIndexedNode{bitmap: n.bitmap, nodes: make([]mapNode, len(n.nodes))} + copy(other.nodes, n.nodes) + other.nodes[idx] = newChild + return other +} + +// mapHashArrayNode is a map branch node that stores nodes in a fixed length +// array. Child nodes are indexed by their index bit segment for the current depth. +type mapHashArrayNode struct { + count uint // number of set nodes + nodes [mapNodeSize]mapNode // child node slots, may contain empties +} + +// get returns the value for the given key. +func (n *mapHashArrayNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) { + node := n.nodes[(keyHash>>shift)&mapNodeMask] + if node == nil { + return nil, false + } + return node.get(key, shift+mapNodeBits, keyHash, h) +} + +// set returns a node with the value set for the given key. +func (n *mapHashArrayNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode { + idx := (keyHash >> shift) & mapNodeMask + node := n.nodes[idx] + + // If node at index doesn't exist, create a simple value leaf node. + // Otherwise delegate set to child node. + var newNode mapNode + if node == nil { + *resized = true + newNode = newMapValueNode(keyHash, key, value) + } else { + newNode = node.set(key, value, shift+mapNodeBits, keyHash, h, resized) + } + + // Return a copy of node with updated child node (and updated size, if new). + other := *n + if node == nil { + other.count++ + } + other.nodes[idx] = newNode + return &other +} + +// delete returns a node with the given key removed. Returns the same node if +// the key does not exist. If node shrinks to within bitmap-indexed size then +// converts to a bitmap-indexed node. +func (n *mapHashArrayNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode { + idx := (keyHash >> shift) & mapNodeMask + node := n.nodes[idx] + + // Return original node if child is not found. + if node == nil { + return n + } + + // Return original node if child is unchanged. + newNode := node.delete(key, shift+mapNodeBits, keyHash, h) + if newNode == node { + return n + } + + // If we remove a node and drop below a threshold, convert back to bitmap indexed node. + if newNode == nil && n.count <= maxBitmapIndexedSize { + other := &mapBitmapIndexedNode{nodes: make([]mapNode, 0, n.count-1)} + for i, child := range n.nodes { + if child != nil && uint32(i) != idx { + other.bitmap |= 1 << uint(i) + other.nodes = append(other.nodes, child) + } + } + return other + } + + // Return copy of node with child updated. + other := *n + other.nodes[idx] = newNode + if newNode == nil { + other.count-- + } + return &other +} + +// mapValueNode represents a leaf node with a single key/value pair. +// A value node can be converted to a hash collision leaf node if a different +// key with the same keyHash is inserted. +type mapValueNode struct { + keyHash uint32 + key interface{} + value interface{} +} + +// newMapValueNode returns a new instance of mapValueNode. +func newMapValueNode(keyHash uint32, key, value interface{}) *mapValueNode { + return &mapValueNode{ + keyHash: keyHash, + key: key, + value: value, + } +} + +// keyHashValue returns the key hash for this node. +func (n *mapValueNode) keyHashValue() uint32 { + return n.keyHash +} + +// get returns the value for the given key. +func (n *mapValueNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) { + if !h.Equal(n.key, key) { + return nil, false + } + return n.value, true +} + +// set returns a new node with the new value set for the key. If the key equals +// the node's key then a new value node is returned. If key is not equal to the +// node's key but has the same hash then a hash collision node is returned. +// Otherwise the nodes are merged into a branch node. +func (n *mapValueNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode { + // If the keys match then return a new value node overwriting the value. + if h.Equal(n.key, key) { + return newMapValueNode(n.keyHash, key, value) + } + + *resized = true + + // Recursively merge nodes together if key hashes are different. + if n.keyHash != keyHash { + return mergeIntoNode(n, shift, keyHash, key, value) + } + + // Merge into collision node if hash matches. + return &mapHashCollisionNode{keyHash: keyHash, entries: []mapEntry{ + {key: n.key, value: n.value}, + {key: key, value: value}, + }} +} + +// delete returns nil if the key matches the node's key. Otherwise returns the original node. +func (n *mapValueNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode { + // Return original node if the keys do not match. + if !h.Equal(n.key, key) { + return n + } + + // Otherwise remove the node if keys do match. + return nil +} + +// mapHashCollisionNode represents a leaf node that contains two or more key/value +// pairs with the same key hash. Single pairs for a hash are stored as value nodes. +type mapHashCollisionNode struct { + keyHash uint32 // key hash for all entries + entries []mapEntry +} + +// keyHashValue returns the key hash for all entries on the node. +func (n *mapHashCollisionNode) keyHashValue() uint32 { + return n.keyHash +} + +// indexOf returns the index of the entry for the given key. +// Returns -1 if the key does not exist in the node. +func (n *mapHashCollisionNode) indexOf(key interface{}, h Hasher) int { + for i := range n.entries { + if h.Equal(n.entries[i].key, key) { + return i + } + } + return -1 +} + +// get returns the value for the given key. +func (n *mapHashCollisionNode) get(key interface{}, shift uint, keyHash uint32, h Hasher) (value interface{}, ok bool) { + for i := range n.entries { + if h.Equal(n.entries[i].key, key) { + return n.entries[i].value, true + } + } + return nil, false +} + +// set returns a copy of the node with key set to the given value. +func (n *mapHashCollisionNode) set(key, value interface{}, shift uint, keyHash uint32, h Hasher, resized *bool) mapNode { + // Merge node with key/value pair if this is not a hash collision. + if n.keyHash != keyHash { + *resized = true + return mergeIntoNode(n, shift, keyHash, key, value) + } + + // Append to end of node if key doesn't exist & mark resized. + // Otherwise copy nodes and overwrite at matching key index. + other := &mapHashCollisionNode{keyHash: n.keyHash} + if idx := n.indexOf(key, h); idx == -1 { + *resized = true + other.entries = make([]mapEntry, len(n.entries)+1) + copy(other.entries, n.entries) + other.entries[len(other.entries)-1] = mapEntry{key, value} + } else { + other.entries = make([]mapEntry, len(n.entries)) + copy(other.entries, n.entries) + other.entries[idx] = mapEntry{key, value} + } + return other +} + +// delete returns a node with the given key deleted. Returns the same node if +// the key does not exist. If removing the key would shrink the node to a single +// entry then a value node is returned. +func (n *mapHashCollisionNode) delete(key interface{}, shift uint, keyHash uint32, h Hasher) mapNode { + idx := n.indexOf(key, h) + + // Return original node if key is not found. + if idx == -1 { + return n + } + + // Convert to value node if we move to one entry. + if len(n.entries) == 2 { + return &mapValueNode{ + keyHash: n.keyHash, + key: n.entries[idx^1].key, + value: n.entries[idx^1].value, + } + } + + // Otherwise return copy with entry removed. + other := &mapHashCollisionNode{keyHash: n.keyHash, entries: make([]mapEntry, len(n.entries)-1)} + copy(other.entries[:idx], n.entries[:idx]) + copy(other.entries[idx:], n.entries[idx+1:]) + return other +} + +// mergeIntoNode merges a key/value pair into an existing node. +// Caller must verify that node's keyHash is not equal to keyHash. +func mergeIntoNode(node mapLeafNode, shift uint, keyHash uint32, key, value interface{}) mapNode { + idx1 := (node.keyHashValue() >> shift) & mapNodeMask + idx2 := (keyHash >> shift) & mapNodeMask + + // Recursively build branch nodes to combine the node and its key. + other := &mapBitmapIndexedNode{bitmap: (1 << idx1) | (1 << idx2)} + if idx1 == idx2 { + other.nodes = []mapNode{mergeIntoNode(node, shift+mapNodeBits, keyHash, key, value)} + } else { + if newNode := newMapValueNode(keyHash, key, value); idx1 < idx2 { + other.nodes = []mapNode{node, newNode} + } else { + other.nodes = []mapNode{newNode, node} + } + } + return other +} + +// mapEntry represents a single key/value pair. +type mapEntry struct { + key interface{} + value interface{} +} + +// MapIterator represents an iterator over a map's key/value pairs. Although +// map keys are not sorted, the iterator's order is deterministic. +type MapIterator struct { + m *Map // source map + + stack [32]mapIteratorElem // search stack + depth int // stack depth +} + +// Done returns true if no more elements remain in the iterator. +func (itr *MapIterator) Done() bool { + return itr.depth == -1 +} + +// First resets the iterator to the first key/value pair. +func (itr *MapIterator) First() { + // Exit immediately if the map is empty. + if itr.m.root == nil { + itr.depth = -1 + return + } + + // Initialize the stack to the left most element. + itr.stack[0] = mapIteratorElem{node: itr.m.root} + itr.depth = 0 + itr.first() +} + +// Next returns the next key/value pair. Returns a nil key when no elements remain. +func (itr *MapIterator) Next() (key, value interface{}) { + // Return nil key if iteration is done. + if itr.Done() { + return nil, nil + } + + // Retrieve current index & value. Current node is always a leaf. + elem := &itr.stack[itr.depth] + switch node := elem.node.(type) { + case *mapArrayNode: + entry := &node.entries[elem.index] + key, value = entry.key, entry.value + case *mapValueNode: + key, value = node.key, node.value + case *mapHashCollisionNode: + entry := &node.entries[elem.index] + key, value = entry.key, entry.value + } + + // Move up stack until we find a node that has remaining position ahead + // and move that element forward by one. + for ; itr.depth >= 0; itr.depth-- { + elem := &itr.stack[itr.depth] + + switch node := elem.node.(type) { + case *mapArrayNode: + if elem.index < len(node.entries)-1 { + elem.index++ + return key, value + } + + case *mapBitmapIndexedNode: + if elem.index < len(node.nodes)-1 { + elem.index++ + itr.stack[itr.depth+1].node = node.nodes[elem.index] + itr.depth++ + itr.first() + return key, value + } + + case *mapHashArrayNode: + for i := elem.index + 1; i < len(node.nodes); i++ { + if node.nodes[i] != nil { + elem.index = i + itr.stack[itr.depth+1].node = node.nodes[elem.index] + itr.depth++ + itr.first() + return key, value + } + } + + case *mapValueNode: + continue // always the last value, traverse up + + case *mapHashCollisionNode: + if elem.index < len(node.entries)-1 { + elem.index++ + return key, value + } + } + } + + // This only occurs if depth is -1. + return key, value +} + +// first positions the stack left most index. +// Elements and indexes at and below the current depth are assumed to be correct. +func (itr *MapIterator) first() { + for ; ; itr.depth++ { + elem := &itr.stack[itr.depth] + + switch node := elem.node.(type) { + case *mapBitmapIndexedNode: + elem.index = 0 + itr.stack[itr.depth+1].node = node.nodes[0] + + case *mapHashArrayNode: + for i := 0; i < len(node.nodes); i++ { + if node.nodes[i] != nil { // find first node + elem.index = i + itr.stack[itr.depth+1].node = node.nodes[i] + break + } + } + + default: // *mapArrayNode, mapLeafNode + elem.index = 0 + return + } + } +} + +// mapIteratorElem represents a node/index pair in the MapIterator stack. +type mapIteratorElem struct { + node mapNode + index int +} + +// Sorted map child node limit size. +const ( + sortedMapNodeSize = 32 +) + +// SortedMap represents a map of key/value pairs sorted by key. The sort order +// is determined by the Comparer used by the map. +// +// This map is implemented as a B+tree. +type SortedMap struct { + size int // total number of key/value pairs + root sortedMapNode // root of b+tree + comparer Comparer +} + +// NewSortedMap returns a new instance of SortedMap. If comparer is nil then +// a default comparer is set after the first key is inserted. Default comparers +// exist for int, string, and byte slice keys. +func NewSortedMap(comparer Comparer) *SortedMap { + return &SortedMap{ + comparer: comparer, + } +} + +// Len returns the number of elements in the sorted map. +func (m *SortedMap) Len() int { + return m.size +} + +// Get returns the value for a given key and a flag indicating if the key is set. +// The flag can be used to distinguish between a nil-set key versus an unset key. +func (m *SortedMap) Get(key interface{}) (interface{}, bool) { + if m.root == nil { + return nil, false + } + return m.root.get(key, m.comparer) +} + +// Set returns a copy of the map with the key set to the given value. +func (m *SortedMap) Set(key, value interface{}) *SortedMap { + // Set a comparer on the first value if one does not already exist. + comparer := m.comparer + if comparer == nil { + switch key.(type) { + case int: + comparer = &intComparer{} + case string: + comparer = &stringComparer{} + case []byte: + comparer = &byteSliceComparer{} + default: + panic(fmt.Sprintf("immutable.SortedMap.Set: must set comparer for %T type", key)) + } + } + + // If no values are set then initialize with a leaf node. + if m.root == nil { + return &SortedMap{ + size: 1, + root: &sortedMapLeafNode{entries: []mapEntry{{key: key, value: value}}}, + comparer: comparer, + } + } + + // Otherwise delegate to root node. + // If a split occurs then grow the tree from the root. + var resized bool + newRoot, splitNode := m.root.set(key, value, comparer, &resized) + if splitNode != nil { + newRoot = newSortedMapBranchNode(newRoot, splitNode) + } + + // Return a new map with the new root. + other := &SortedMap{ + size: m.size, + root: newRoot, + comparer: comparer, + } + if resized { + other.size++ + } + return other +} + +// Delete returns a copy of the map with the key removed. +// Returns the original map if key does not exist. +func (m *SortedMap) Delete(key interface{}) *SortedMap { + // Return original map if no keys exist. + if m.root == nil { + return m + } + + // If the delete did not change the node then return the original map. + newRoot := m.root.delete(key, m.comparer) + if newRoot == m.root { + return m + } + + // Return new copy with the root and size updated. + return &SortedMap{ + size: m.size - 1, + root: newRoot, + comparer: m.comparer, + } +} + +// Iterator returns a new iterator for this map positioned at the first key. +func (m *SortedMap) Iterator() *SortedMapIterator { + itr := &SortedMapIterator{m: m} + itr.First() + return itr +} + +// sortedMapNode represents a branch or leaf node in the sorted map. +type sortedMapNode interface { + minKey() interface{} + indexOf(key interface{}, c Comparer) int + get(key interface{}, c Comparer) (value interface{}, ok bool) + set(key, value interface{}, c Comparer, resized *bool) (sortedMapNode, sortedMapNode) + delete(key interface{}, c Comparer) sortedMapNode +} + +var _ sortedMapNode = (*sortedMapBranchNode)(nil) +var _ sortedMapNode = (*sortedMapLeafNode)(nil) + +// sortedMapBranchNode represents a branch in the sorted map. +type sortedMapBranchNode struct { + elems []sortedMapBranchElem +} + +// newSortedMapBranchNode returns a new branch node with the given child nodes. +func newSortedMapBranchNode(children ...sortedMapNode) *sortedMapBranchNode { + // Fetch min keys for every child. + elems := make([]sortedMapBranchElem, len(children)) + for i, child := range children { + elems[i] = sortedMapBranchElem{ + key: child.minKey(), + node: child, + } + } + + return &sortedMapBranchNode{elems: elems} +} + +// minKey returns the lowest key stored in this node's tree. +func (n *sortedMapBranchNode) minKey() interface{} { + return n.elems[0].node.minKey() +} + +// indexOf returns the index of the key within the child nodes. +func (n *sortedMapBranchNode) indexOf(key interface{}, c Comparer) int { + if idx := sort.Search(len(n.elems), func(i int) bool { return c.Compare(n.elems[i].key, key) == 1 }); idx > 0 { + return idx - 1 + } + return 0 +} + +// get returns the value for the given key. +func (n *sortedMapBranchNode) get(key interface{}, c Comparer) (value interface{}, ok bool) { + idx := n.indexOf(key, c) + return n.elems[idx].node.get(key, c) +} + +// set returns a copy of the node with the key set to the given value. +func (n *sortedMapBranchNode) set(key, value interface{}, c Comparer, resized *bool) (sortedMapNode, sortedMapNode) { + idx := n.indexOf(key, c) + + // Delegate insert to child node. + newNode, splitNode := n.elems[idx].node.set(key, value, c, resized) + + // If no split occurs, copy branch and update keys. + // If the child splits, insert new key/child into copy of branch. + var other sortedMapBranchNode + if splitNode == nil { + other.elems = make([]sortedMapBranchElem, len(n.elems)) + copy(other.elems, n.elems) + other.elems[idx] = sortedMapBranchElem{ + key: newNode.minKey(), + node: newNode, + } + } else { + other.elems = make([]sortedMapBranchElem, len(n.elems)+1) + copy(other.elems[:idx], n.elems[:idx]) + copy(other.elems[idx+1:], n.elems[idx:]) + other.elems[idx] = sortedMapBranchElem{ + key: newNode.minKey(), + node: newNode, + } + other.elems[idx+1] = sortedMapBranchElem{ + key: splitNode.minKey(), + node: splitNode, + } + } + + // If the child splits and we have no more room then we split too. + if len(other.elems) > sortedMapNodeSize { + splitIdx := len(other.elems) / 2 + newNode := &sortedMapBranchNode{elems: other.elems[:splitIdx]} + splitNode := &sortedMapBranchNode{elems: other.elems[splitIdx:]} + return newNode, splitNode + } + + // Otherwise return the new branch node with the updated entry. + return &other, nil +} + +// delete returns a node with the key removed. Returns the same node if the key +// does not exist. Returns nil if all child nodes are removed. +func (n *sortedMapBranchNode) delete(key interface{}, c Comparer) sortedMapNode { + idx := n.indexOf(key, c) + + // Return original node if child has not changed. + newNode := n.elems[idx].node.delete(key, c) + if newNode == n.elems[idx].node { + return n + } + + // Remove child if it is now nil. + if newNode == nil { + // If this node will become empty then simply return nil. + if len(n.elems) == 1 { + return nil + } + + // Return a copy without the given node. + other := &sortedMapBranchNode{elems: make([]sortedMapBranchElem, len(n.elems)-1)} + copy(other.elems[:idx], n.elems[:idx]) + copy(other.elems[idx:], n.elems[idx+1:]) + return other + } + + // Return a copy with the updated node. + other := &sortedMapBranchNode{elems: make([]sortedMapBranchElem, len(n.elems))} + copy(other.elems, n.elems) + other.elems[idx] = sortedMapBranchElem{ + key: newNode.minKey(), + node: newNode, + } + return other +} + +type sortedMapBranchElem struct { + key interface{} + node sortedMapNode +} + +// sortedMapLeafNode represents a leaf node in the sorted map. +type sortedMapLeafNode struct { + entries []mapEntry +} + +// minKey returns the first key stored in this node. +func (n *sortedMapLeafNode) minKey() interface{} { + return n.entries[0].key +} + +// indexOf returns the index of the given key. +func (n *sortedMapLeafNode) indexOf(key interface{}, c Comparer) int { + return sort.Search(len(n.entries), func(i int) bool { + return c.Compare(n.entries[i].key, key) != -1 // GTE + }) +} + +// get returns the value of the given key. +func (n *sortedMapLeafNode) get(key interface{}, c Comparer) (value interface{}, ok bool) { + idx := n.indexOf(key, c) + + // If the index is beyond the entry count or the key is not equal then return 'not found'. + if idx == len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 { + return nil, false + } + + // If the key matches then return its value. + return n.entries[idx].value, true +} + +// set returns a copy of node with the key set to the given value. If the update +// causes the node to grow beyond the maximum size then it is split in two. +func (n *sortedMapLeafNode) set(key, value interface{}, c Comparer, resized *bool) (sortedMapNode, sortedMapNode) { + // Find the insertion index for the key. + idx := n.indexOf(key, c) + + // If the key matches then simply return a copy with the entry overridden. + // If there is no match then insert new entry and mark as resized. + var newEntries []mapEntry + if idx < len(n.entries) && c.Compare(n.entries[idx].key, key) == 0 { + newEntries = make([]mapEntry, len(n.entries)) + copy(newEntries, n.entries) + newEntries[idx] = mapEntry{key: key, value: value} + } else { + *resized = true + newEntries = make([]mapEntry, len(n.entries)+1) + copy(newEntries[:idx], n.entries[:idx]) + newEntries[idx] = mapEntry{key: key, value: value} + copy(newEntries[idx+1:], n.entries[idx:]) + } + + // If the key doesn't exist and we exceed our max allowed values then split. + if len(newEntries) > sortedMapNodeSize { + newNode := &sortedMapLeafNode{entries: newEntries[:len(newEntries)/2]} + splitNode := &sortedMapLeafNode{entries: newEntries[len(newEntries)/2:]} + return newNode, splitNode + } + + // Otherwise return the new leaf node with the updated entry. + return &sortedMapLeafNode{entries: newEntries}, nil +} + +// delete returns a copy of node with key removed. Returns the original node if +// the key does not exist. Returns nil if the removed key is the last remaining key. +func (n *sortedMapLeafNode) delete(key interface{}, c Comparer) sortedMapNode { + idx := n.indexOf(key, c) + + // Return original node if key is not found. + if idx >= len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 { + return n + } + + // If this is the last entry then return nil. + if len(n.entries) == 1 { + return nil + } + + // Return copy of node with entry removed. + other := &sortedMapLeafNode{entries: make([]mapEntry, len(n.entries)-1)} + copy(other.entries[:idx], n.entries[:idx]) + copy(other.entries[idx:], n.entries[idx+1:]) + return other +} + +// SortedMapIterator represents an iterator over a sorted map. +// Iteration can occur in natural or reverse order based on use of Next() or Prev(). +type SortedMapIterator struct { + m *SortedMap // source map + + stack [32]sortedMapIteratorElem // search stack + depth int // stack depth +} + +// Done returns true if no more key/value pairs remain in the iterator. +func (itr *SortedMapIterator) Done() bool { + return itr.depth == -1 +} + +// First moves the iterator to the first key/value pair. +func (itr *SortedMapIterator) First() { + if itr.m.root != nil { + itr.stack[0] = sortedMapIteratorElem{node: itr.m.root} + itr.depth = 0 + itr.first() + } +} + +// Last moves the iterator to the last key/value pair. +func (itr *SortedMapIterator) Last() { + if itr.m.root != nil { + itr.stack[0] = sortedMapIteratorElem{node: itr.m.root} + itr.depth = 0 + itr.last() + } +} + +// Seek moves the iterator position to the given key in the map. +// If the key does not exist then the next key is used. If no more keys exist +// then the iteartor is marked as done. +func (itr *SortedMapIterator) Seek(key interface{}) { + if itr.m.root != nil { + itr.stack[0] = sortedMapIteratorElem{node: itr.m.root} + itr.depth = 0 + itr.seek(key) + } +} + +// Next returns the current key/value pair and moves the iterator forward. +// Returns a nil key if the there are no more elements to return. +func (itr *SortedMapIterator) Next() (key, value interface{}) { + // Return nil key if iteration is complete. + if itr.Done() { + return nil, nil + } + + // Retrieve current key/value pair. + leafElem := &itr.stack[itr.depth] + leafNode := leafElem.node.(*sortedMapLeafNode) + leafEntry := &leafNode.entries[leafElem.index] + key, value = leafEntry.key, leafEntry.value + + // Move to the next available key/value pair. + itr.next() + + // Only occurs when iterator is done. + return key, value +} + +// next moves to the next key. If no keys are after then depth is set to -1. +func (itr *SortedMapIterator) next() { + for ; itr.depth >= 0; itr.depth-- { + elem := &itr.stack[itr.depth] + + switch node := elem.node.(type) { + case *sortedMapLeafNode: + if elem.index < len(node.entries)-1 { + elem.index++ + return + } + case *sortedMapBranchNode: + if elem.index < len(node.elems)-1 { + elem.index++ + itr.stack[itr.depth+1].node = node.elems[elem.index].node + itr.depth++ + itr.first() + return + } + } + } +} + +// Prev returns the current key/value pair and moves the iterator backward. +// Returns a nil key if the there are no more elements to return. +func (itr *SortedMapIterator) Prev() (key, value interface{}) { + // Return nil key if iteration is complete. + if itr.Done() { + return nil, nil + } + + // Retrieve current key/value pair. + leafElem := &itr.stack[itr.depth] + leafNode := leafElem.node.(*sortedMapLeafNode) + leafEntry := &leafNode.entries[leafElem.index] + key, value = leafEntry.key, leafEntry.value + + itr.prev() + return key, value +} + +// prev moves to the previous key. If no keys are before then depth is set to -1. +func (itr *SortedMapIterator) prev() { + for ; itr.depth >= 0; itr.depth-- { + elem := &itr.stack[itr.depth] + + switch node := elem.node.(type) { + case *sortedMapLeafNode: + if elem.index > 0 { + elem.index-- + return + } + case *sortedMapBranchNode: + if elem.index > 0 { + elem.index-- + itr.stack[itr.depth+1].node = node.elems[elem.index].node + itr.depth++ + itr.last() + return + } + } + } +} + +// first positions the stack to the leftmost key from the current depth. +// Elements and indexes below the current depth are assumed to be correct. +func (itr *SortedMapIterator) first() { + for { + elem := &itr.stack[itr.depth] + elem.index = 0 + + switch node := elem.node.(type) { + case *sortedMapBranchNode: + itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node} + itr.depth++ + case *sortedMapLeafNode: + return + } + } +} + +// last positions the stack to the rightmost key from the current depth. +// Elements and indexes below the current depth are assumed to be correct. +func (itr *SortedMapIterator) last() { + for { + elem := &itr.stack[itr.depth] + + switch node := elem.node.(type) { + case *sortedMapBranchNode: + elem.index = len(node.elems) - 1 + itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node} + itr.depth++ + case *sortedMapLeafNode: + elem.index = len(node.entries) - 1 + return + } + } +} + +// seek positions the stack to the given key from the current depth. +// Elements and indexes below the current depth are assumed to be correct. +func (itr *SortedMapIterator) seek(key interface{}) { + for { + elem := &itr.stack[itr.depth] + elem.index = elem.node.indexOf(key, itr.m.comparer) + + switch node := elem.node.(type) { + case *sortedMapBranchNode: + itr.stack[itr.depth+1] = sortedMapIteratorElem{node: node.elems[elem.index].node} + itr.depth++ + case *sortedMapLeafNode: + if elem.index == len(node.entries) { + itr.next() + } + return + } + } +} + +// sortedMapIteratorElem represents node/index pair in the SortedMapIterator stack. +type sortedMapIteratorElem struct { + node sortedMapNode + index int +} + +// Hasher hashes keys and checks them for equality. +type Hasher interface { + // Computes a 32-bit hash for key. + Hash(key interface{}) uint32 + + // Returns true if a and b are equal. + Equal(a, b interface{}) bool +} + +// intHasher implements Hasher for int keys. +type intHasher struct{} + +// Hash returns a hash for key. +func (h *intHasher) Hash(key interface{}) uint32 { + return hashUint64(uint64(key.(int))) +} + +// Equal returns true if a is equal to b. Otherwise returns false. +// Panics if a and b are not ints. +func (h *intHasher) Equal(a, b interface{}) bool { + return a.(int) == b.(int) +} + +// stringHasher implements Hasher for string keys. +type stringHasher struct{} + +// Hash returns a hash for value. +func (h *stringHasher) Hash(value interface{}) uint32 { + var hash uint32 + for i, value := 0, value.(string); i < len(value); i++ { + hash = 31*hash + uint32(value[i]) + } + return hash +} + +// Equal returns true if a is equal to b. Otherwise returns false. +// Panics if a and b are not strings. +func (h *stringHasher) Equal(a, b interface{}) bool { + return a.(string) == b.(string) +} + +// byteSliceHasher implements Hasher for string keys. +type byteSliceHasher struct{} + +// Hash returns a hash for value. +func (h *byteSliceHasher) Hash(value interface{}) uint32 { + var hash uint32 + for i, value := 0, value.([]byte); i < len(value); i++ { + hash = 31*hash + uint32(value[i]) + } + return hash +} + +// Equal returns true if a is equal to b. Otherwise returns false. +// Panics if a and b are not byte slices. +func (h *byteSliceHasher) Equal(a, b interface{}) bool { + return bytes.Equal(a.([]byte), b.([]byte)) +} + +// hashUint64 returns a 32-bit hash for a 64-bit value. +func hashUint64(value uint64) uint32 { + hash := value + for value > 0xffffffff { + value /= 0xffffffff + hash ^= value + } + return uint32(hash) +} + +// Comparer allows the comparison of two keys for the purpose of sorting. +type Comparer interface { + // Returns -1 if a is less than b, returns 1 if a is greater than b, + // and returns 0 if a is equal to b. + Compare(a, b interface{}) int +} + +// intComparer compares two integers. Implements Comparer. +type intComparer struct{} + +// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and +// returns 0 if a is equal to b. Panic if a or b is not an int. +func (c *intComparer) Compare(a, b interface{}) int { + if i, j := a.(int), b.(int); i < j { + return -1 + } else if i > j { + return 1 + } + return 0 +} + +// stringComparer compares two strings. Implements Comparer. +type stringComparer struct{} + +// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and +// returns 0 if a is equal to b. Panic if a or b is not a string. +func (c *stringComparer) Compare(a, b interface{}) int { + return strings.Compare(a.(string), b.(string)) +} + +// byteSliceComparer compares two byte slices. Implements Comparer. +type byteSliceComparer struct{} + +// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and +// returns 0 if a is equal to b. Panic if a or b is not a byte slice. +func (c *byteSliceComparer) Compare(a, b interface{}) int { + return bytes.Compare(a.([]byte), b.([]byte)) +} diff --git a/immutable_test.go b/immutable_test.go new file mode 100644 index 0000000..a0c7821 --- /dev/null +++ b/immutable_test.go @@ -0,0 +1,2038 @@ +package immutable + +import ( + "flag" + "fmt" + "math/rand" + "sort" + "testing" + + "github.com/google/go-cmp/cmp" +) + +var ( + veryVerbose = flag.Bool("vv", false, "very verbose") + randomN = flag.Int("random.n", 100, "number of RunRandom() iterations") +) + +func TestList(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + if size := NewList().Len(); size != 0 { + t.Fatalf("unexpected size: %d", size) + } + }) + + t.Run("Shallow", func(t *testing.T) { + list := NewList() + list = list.Append("foo") + if v := list.Get(0); v != "foo" { + t.Fatalf("unexpected value: %v", v) + } + + other := list.Append("bar") + if v := other.Get(0); v != "foo" { + t.Fatalf("unexpected value: %v", v) + } else if v := other.Get(1); v != "bar" { + t.Fatalf("unexpected value: %v", v) + } + + if v := list.Len(); v != 1 { + t.Fatalf("unexpected value: %v", v) + } + }) + + t.Run("Deep", func(t *testing.T) { + list := NewList() + var array []int + for i := 0; i < 100000; i++ { + list = list.Append(i) + array = append(array, i) + } + + if got, exp := len(array), list.Len(); got != exp { + t.Fatalf("List.Len()=%d, exp %d", got, exp) + } + for j := range array { + if got, exp := list.Get(j).(int), array[j]; got != exp { + t.Fatalf("%d. List.Get(%d)=%d, exp %d", len(array), j, got, exp) + } + } + }) + + t.Run("Set", func(t *testing.T) { + list := NewList() + list = list.Append("foo") + list = list.Append("bar") + + if v := list.Get(0); v != "foo" { + t.Fatalf("unexpected value: %v", v) + } + + list = list.Set(0, "baz") + if v := list.Get(0); v != "baz" { + t.Fatalf("unexpected value: %v", v) + } else if v := list.Get(1); v != "bar" { + t.Fatalf("unexpected value: %v", v) + } + }) + + t.Run("GetBelowRange", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList() + l = l.Append("foo") + l.Get(-1) + }() + if r != `immutable.List.Get: index -1 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("GetAboveRange", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList() + l = l.Append("foo") + l.Get(1) + }() + if r != `immutable.List.Get: index 1 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("SetOutOfRange", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList() + l = l.Append("foo") + l.Set(1, "bar") + }() + if r != `immutable.List.Set: index 1 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("SliceStartOutOfRange", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList() + l = l.Append("foo") + l.Slice(2, 3) + }() + if r != `immutable.List.Slice: start index 2 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("SliceEndOutOfRange", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList() + l = l.Append("foo") + l.Slice(1, 3) + }() + if r != `immutable.List.Slice: end index 3 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("SliceInvalidIndex", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList() + l = l.Append("foo") + l = l.Append("bar") + l.Slice(2, 1) + }() + if r != `immutable.List.Slice: invalid slice index: [2:1]` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("SliceBeginning", func(t *testing.T) { + l := NewList() + l = l.Append("foo") + l = l.Append("bar") + l = l.Slice(1, 2) + if got, exp := l.Len(), 1; got != exp { + t.Fatalf("List.Len()=%d, exp %d", got, exp) + } else if got, exp := l.Get(0), "bar"; got != exp { + t.Fatalf("List.Get(0)=%v, exp %v", got, exp) + } + }) + + t.Run("IteratorSeekOutOfBounds", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList() + l = l.Append("foo") + l.Iterator().Seek(-1) + }() + if r != `immutable.ListIterator.Seek: index -1 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { + l := NewTList() + for i := 0; i < 100000; i++ { + rnd := rand.Intn(70) + switch { + case rnd == 0: // slice + start, end := l.ChooseSliceIndices(rand) + l.Slice(start, end) + case rnd < 10: // set + if l.Len() > 0 { + l.Set(l.ChooseIndex(rand), rand.Intn(10000)) + } + case rnd < 30: // prepend + l.Prepend(rand.Intn(10000)) + default: // append + l.Append(rand.Intn(10000)) + } + + if err := l.Validate(); err != nil { + t.Fatal(err) + } + } + if err := l.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +// TList represents a list that operates on a standard Go slice & immutable list. +type TList struct { + im, prev *List + std []int +} + +// NewTList returns a new instance of TList. +func NewTList() *TList { + return &TList{ + im: NewList(), + } +} + +// Len returns the size of the list. +func (l *TList) Len() int { + return len(l.std) +} + +// ChooseIndex returns a randomly chosen, valid index from the standard slice. +func (l *TList) ChooseIndex(rand *rand.Rand) int { + if len(l.std) == 0 { + return -1 + } + return rand.Intn(len(l.std)) +} + +// ChooseSliceIndices returns randomly chosen, valid indices for slicing. +func (l *TList) ChooseSliceIndices(rand *rand.Rand) (start, end int) { + if len(l.std) == 0 { + return 0, 0 + } + start = rand.Intn(len(l.std)) + end = rand.Intn(len(l.std)-start) + start + return start, end +} + +// Append adds v to the end of slice and List. +func (l *TList) Append(v int) { + l.prev = l.im + l.im = l.im.Append(v) + l.std = append(l.std, v) +} + +// Prepend adds v to the beginning of the slice and List. +func (l *TList) Prepend(v int) { + l.prev = l.im + l.im = l.im.Prepend(v) + l.std = append([]int{v}, l.std...) +} + +// Set updates the value at index i to v in the slice and List. +func (l *TList) Set(i, v int) { + l.prev = l.im + l.im = l.im.Set(i, v) + l.std[i] = v +} + +// Slice contracts the slice and List to the range of start/end indices. +func (l *TList) Slice(start, end int) { + l.prev = l.im + l.im = l.im.Slice(start, end) + l.std = l.std[start:end] +} + +// Validate returns an error if the slice and List are different. +func (l *TList) Validate() error { + if got, exp := len(l.std), l.im.Len(); got != exp { + return fmt.Errorf("Len()=%v, expected %d", got, exp) + } + + for i := range l.std { + if got, exp := l.im.Get(i), l.std[i]; got != exp { + return fmt.Errorf("Get(%d)=%v, expected %v", i, got, exp) + } + } + + if err := l.validateForwardIterator(); err != nil { + return err + } else if err := l.validateBackwardIterator(); err != nil { + return err + } + return nil +} + +func (l *TList) validateForwardIterator() error { + itr := l.im.Iterator() + for i := range l.std { + if j, v := itr.Next(); i != j || l.std[i] != v { + return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected <%v,%v>", j, v, i, l.std[i]) + } + + done := i == len(l.std)-1 + if v := itr.Done(); v != done { + return fmt.Errorf("ListIterator.Done()=%v, expected %v", v, done) + } + } + if i, v := itr.Next(); i != -1 || v != nil { + return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected DONE", i, v) + } + return nil +} + +func (l *TList) validateBackwardIterator() error { + itr := l.im.Iterator() + itr.Last() + for i := len(l.std) - 1; i >= 0; i-- { + if j, v := itr.Prev(); i != j || l.std[i] != v { + return fmt.Errorf("ListIterator.Prev()=<%v,%v>, expected <%v,%v>", j, v, i, l.std[i]) + } + + done := i == 0 + if v := itr.Done(); v != done { + return fmt.Errorf("ListIterator.Done()=%v, expected %v", v, done) + } + } + if i, v := itr.Prev(); i != -1 || v != nil { + return fmt.Errorf("ListIterator.Prev()=<%v,%v>, expected DONE", i, v) + } + return nil +} + +func BenchmarkList_Append(b *testing.B) { + b.ReportAllocs() + l := NewList() + for i := 0; i < b.N; i++ { + l = l.Append(i) + } +} + +func BenchmarkList_Prepend(b *testing.B) { + b.ReportAllocs() + l := NewList() + for i := 0; i < b.N; i++ { + l = l.Prepend(i) + } +} + +func BenchmarkList_Set(b *testing.B) { + const n = 10000 + + l := NewList() + for i := 0; i < 10000; i++ { + l = l.Append(i) + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + l = l.Set(i%n, i*10) + } +} + +func BenchmarkList_Iterator(b *testing.B) { + const n = 10000 + l := NewList() + for i := 0; i < 10000; i++ { + l = l.Append(i) + } + b.ReportAllocs() + b.ResetTimer() + + b.Run("Forward", func(b *testing.B) { + itr := l.Iterator() + for i := 0; i < b.N; i++ { + if i%n == 0 { + itr.First() + } + itr.Next() + } + }) + + b.Run("Reverse", func(b *testing.B) { + itr := l.Iterator() + for i := 0; i < b.N; i++ { + if i%n == 0 { + itr.Last() + } + itr.Prev() + } + }) +} + +func ExampleList_Append() { + l := NewList() + l = l.Append("foo") + l = l.Append("bar") + l = l.Append("baz") + + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + fmt.Println(l.Get(2)) + // Output: + // foo + // bar + // baz +} + +func ExampleList_Prepend() { + l := NewList() + l = l.Prepend("foo") + l = l.Prepend("bar") + l = l.Prepend("baz") + + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + fmt.Println(l.Get(2)) + // Output: + // baz + // bar + // foo +} + +func ExampleList_Set() { + l := NewList() + l = l.Append("foo") + l = l.Append("bar") + l = l.Set(1, "baz") + + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + // Output: + // foo + // baz +} + +func ExampleList_Slice() { + l := NewList() + l = l.Append("foo") + l = l.Append("bar") + l = l.Append("baz") + l = l.Slice(1, 3) + + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + // Output: + // bar + // baz +} + +func ExampleList_Iterator() { + l := NewList() + l = l.Append("foo") + l = l.Append("bar") + l = l.Append("baz") + + itr := l.Iterator() + for !itr.Done() { + i, v := itr.Next() + fmt.Println(i, v) + } + // Output: + // 0 foo + // 1 bar + // 2 baz +} + +func ExampleList_Iterator_reverse() { + l := NewList() + l = l.Append("foo") + l = l.Append("bar") + l = l.Append("baz") + + itr := l.Iterator() + itr.Last() + for !itr.Done() { + i, v := itr.Prev() + fmt.Println(i, v) + } + // Output: + // 2 baz + // 1 bar + // 0 foo +} + +// Ensure node can support overwrites as it expands. +func TestIngernal_mapNode_Overwrite(t *testing.T) { + const n = 1000 + var h intHasher + var node mapNode = &mapArrayNode{} + for i := 0; i < n; i++ { + var resized bool + node = node.set(i, i, 0, h.Hash(i), &h, &resized) + if !resized { + t.Fatal("expected resize") + } + + // Overwrite every node. + for j := 0; j <= i; j++ { + var resized bool + node = node.set(j, i*j, 0, h.Hash(j), &h, &resized) + if resized { + t.Fatalf("expected no resize: i=%d, j=%d", i, j) + } + } + + // Verify not found at each branch type. + if _, ok := node.get(1000000, 0, h.Hash(1000000), &h); ok { + t.Fatal("expected no value") + } + } + + // Verify all key/value pairs in map. + for i := 0; i < n; i++ { + if v, ok := node.get(i, 0, h.Hash(i), &h); !ok || v != i*(n-1) { + t.Fatalf("get(%d)=<%v,%v>", i, v, ok) + } + } +} + +func TestIngernal_mapArrayNode(t *testing.T) { + // Ensure 8 or fewer elements stays in an array node. + t.Run("Append", func(t *testing.T) { + var h intHasher + n := &mapArrayNode{} + for i := 0; i < 8; i++ { + var resized bool + n = n.set(i*10, i, 0, h.Hash(i*10), &h, &resized).(*mapArrayNode) + if !resized { + t.Fatal("expected resize") + } + + for j := 0; j < i; j++ { + if v, ok := n.get(j*10, 0, h.Hash(j*10), &h); !ok || v != j { + t.Fatalf("get(%d)=<%v,%v>", j, v, ok) + } + } + } + }) + + // Ensure 8 or fewer elements stays in an array node when inserted in reverse. + t.Run("Prepend", func(t *testing.T) { + var h intHasher + n := &mapArrayNode{} + for i := 7; i >= 0; i-- { + var resized bool + n = n.set(i*10, i, 0, h.Hash(i*10), &h, &resized).(*mapArrayNode) + if !resized { + t.Fatal("expected resize") + } + + for j := i; j <= 7; j++ { + if v, ok := n.get(j*10, 0, h.Hash(j*10), &h); !ok || v != j { + t.Fatalf("get(%d)=<%v,%v>", j, v, ok) + } + } + } + }) + + // Ensure array can transition between node types. + t.Run("Expand", func(t *testing.T) { + var h intHasher + var n mapNode = &mapArrayNode{} + for i := 0; i < 100; i++ { + var resized bool + n = n.set(i, i, 0, h.Hash(i), &h, &resized) + if !resized { + t.Fatal("expected resize") + } + + for j := 0; j < i; j++ { + if v, ok := n.get(j, 0, h.Hash(j), &h); !ok || v != j { + t.Fatalf("get(%d)=<%v,%v>", j, v, ok) + } + } + } + }) + + // Ensure deleting elements returns the correct new node. + RunRandom(t, "Delete", func(t *testing.T, rand *rand.Rand) { + var h intHasher + var n mapNode = &mapArrayNode{} + for i := 0; i < 8; i++ { + var resized bool + n = n.set(i*10, i, 0, h.Hash(i*10), &h, &resized) + } + + for _, i := range rand.Perm(8) { + n = n.delete(i*10, 0, h.Hash(i*10), &h) + } + if n != nil { + t.Fatal("expected nil rand") + } + }) +} + +func TestIngernal_mapValueNode(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + var h intHasher + n := newMapValueNode(h.Hash(2), 2, 3) + if v, ok := n.get(2, 0, h.Hash(2), &h); !ok { + t.Fatal("expected ok") + } else if v != 3 { + t.Fatalf("unexpected value: %v", v) + } + }) + + t.Run("KeyEqual", func(t *testing.T) { + var h intHasher + var resized bool + n := newMapValueNode(h.Hash(2), 2, 3) + other := n.set(2, 4, 0, h.Hash(2), &h, &resized).(*mapValueNode) + if other == n { + t.Fatal("expected new node") + } else if got, exp := other.keyHash, h.Hash(2); got != exp { + t.Fatalf("keyHash=%v, expected %v", got, exp) + } else if got, exp := other.key, 2; got != exp { + t.Fatalf("key=%v, expected %v", got, exp) + } else if got, exp := other.value, 4; got != exp { + t.Fatalf("value=%v, expected %v", got, exp) + } else if resized { + t.Fatal("unexpected resize") + } + }) + + t.Run("KeyHashEqual", func(t *testing.T) { + h := &mockHasher{ + hash: func(value interface{}) uint32 { return 1 }, + equal: func(a, b interface{}) bool { return a.(int) == b.(int) }, + } + var resized bool + n := newMapValueNode(h.Hash(2), 2, 3) + other := n.set(4, 5, 0, h.Hash(4), h, &resized).(*mapHashCollisionNode) + if got, exp := other.keyHash, h.Hash(2); got != exp { + t.Fatalf("keyHash=%v, expected %v", got, exp) + } else if got, exp := len(other.entries), 2; got != exp { + t.Fatalf("entries=%v, expected %v", got, exp) + } else if !resized { + t.Fatal("expected resize") + } + if got, exp := other.entries[0].key, 2; got != exp { + t.Fatalf("key[0]=%v, expected %v", got, exp) + } else if got, exp := other.entries[0].value, 3; got != exp { + t.Fatalf("value[0]=%v, expected %v", got, exp) + } + if got, exp := other.entries[1].key, 4; got != exp { + t.Fatalf("key[1]=%v, expected %v", got, exp) + } else if got, exp := other.entries[1].value, 5; got != exp { + t.Fatalf("value[1]=%v, expected %v", got, exp) + } + }) + + t.Run("MergeNode", func(t *testing.T) { + // Inserting into a node with a different index in the mask should split into a bitmap node. + t.Run("NoConflict", func(t *testing.T) { + var h intHasher + var resized bool + n := newMapValueNode(h.Hash(2), 2, 3) + other := n.set(4, 5, 0, h.Hash(4), &h, &resized).(*mapBitmapIndexedNode) + if got, exp := other.bitmap, uint32(0x14); got != exp { + t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) + } else if got, exp := len(other.nodes), 2; got != exp { + t.Fatalf("nodes=%v, expected %v", got, exp) + } else if !resized { + t.Fatal("expected resize") + } + if node, ok := other.nodes[0].(*mapValueNode); !ok { + t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) + } else if got, exp := node.key, 2; got != exp { + t.Fatalf("key[0]=%v, expected %v", got, exp) + } else if got, exp := node.value, 3; got != exp { + t.Fatalf("value[0]=%v, expected %v", got, exp) + } + if node, ok := other.nodes[1].(*mapValueNode); !ok { + t.Fatalf("node[1]=%T, unexpected type", other.nodes[1]) + } else if got, exp := node.key, 4; got != exp { + t.Fatalf("key[1]=%v, expected %v", got, exp) + } else if got, exp := node.value, 5; got != exp { + t.Fatalf("value[1]=%v, expected %v", got, exp) + } + + // Ensure both values can be read. + if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v.(int) != 3 { + t.Fatalf("Get(2)=<%v,%v>", v, ok) + } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v.(int) != 5 { + t.Fatalf("Get(4)=<%v,%v>", v, ok) + } + }) + + // Reversing the nodes from NoConflict should yield the same result. + t.Run("NoConflictReverse", func(t *testing.T) { + var h intHasher + var resized bool + n := newMapValueNode(h.Hash(4), 4, 5) + other := n.set(2, 3, 0, h.Hash(2), &h, &resized).(*mapBitmapIndexedNode) + if got, exp := other.bitmap, uint32(0x14); got != exp { + t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) + } else if got, exp := len(other.nodes), 2; got != exp { + t.Fatalf("nodes=%v, expected %v", got, exp) + } else if !resized { + t.Fatal("expected resize") + } + if node, ok := other.nodes[0].(*mapValueNode); !ok { + t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) + } else if got, exp := node.key, 2; got != exp { + t.Fatalf("key[0]=%v, expected %v", got, exp) + } else if got, exp := node.value, 3; got != exp { + t.Fatalf("value[0]=%v, expected %v", got, exp) + } + if node, ok := other.nodes[1].(*mapValueNode); !ok { + t.Fatalf("node[1]=%T, unexpected type", other.nodes[1]) + } else if got, exp := node.key, 4; got != exp { + t.Fatalf("key[1]=%v, expected %v", got, exp) + } else if got, exp := node.value, 5; got != exp { + t.Fatalf("value[1]=%v, expected %v", got, exp) + } + + // Ensure both values can be read. + if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v.(int) != 3 { + t.Fatalf("Get(2)=<%v,%v>", v, ok) + } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v.(int) != 5 { + t.Fatalf("Get(4)=<%v,%v>", v, ok) + } + }) + + // Inserting a node with the same mask index should nest an additional level of bitmap nodes. + t.Run("Conflict", func(t *testing.T) { + h := &mockHasher{ + hash: func(value interface{}) uint32 { return uint32(value.(int)) << 5 }, + equal: func(a, b interface{}) bool { return a.(int) == b.(int) }, + } + var resized bool + n := newMapValueNode(h.Hash(2), 2, 3) + other := n.set(4, 5, 0, h.Hash(4), h, &resized).(*mapBitmapIndexedNode) + if got, exp := other.bitmap, uint32(0x01); got != exp { // mask is zero, expect first slot. + t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) + } else if got, exp := len(other.nodes), 1; got != exp { + t.Fatalf("nodes=%v, expected %v", got, exp) + } else if !resized { + t.Fatal("expected resize") + } + child, ok := other.nodes[0].(*mapBitmapIndexedNode) + if !ok { + t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) + } + + if node, ok := child.nodes[0].(*mapValueNode); !ok { + t.Fatalf("node[0]=%T, unexpected type", child.nodes[0]) + } else if got, exp := node.key, 2; got != exp { + t.Fatalf("key[0]=%v, expected %v", got, exp) + } else if got, exp := node.value, 3; got != exp { + t.Fatalf("value[0]=%v, expected %v", got, exp) + } + if node, ok := child.nodes[1].(*mapValueNode); !ok { + t.Fatalf("node[1]=%T, unexpected type", child.nodes[1]) + } else if got, exp := node.key, 4; got != exp { + t.Fatalf("key[1]=%v, expected %v", got, exp) + } else if got, exp := node.value, 5; got != exp { + t.Fatalf("value[1]=%v, expected %v", got, exp) + } + + // Ensure both values can be read. + if v, ok := other.get(2, 0, h.Hash(2), h); !ok || v.(int) != 3 { + t.Fatalf("Get(2)=<%v,%v>", v, ok) + } else if v, ok := other.get(4, 0, h.Hash(4), h); !ok || v.(int) != 5 { + t.Fatalf("Get(4)=<%v,%v>", v, ok) + } else if v, ok := other.get(10, 0, h.Hash(10), h); ok { + t.Fatalf("Get(10)=<%v,%v>, expected no value", v, ok) + } + }) + }) +} + +func TestMap_Get(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + m := NewMap(nil) + if v, ok := m.Get(100); ok || v != nil { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + }) +} + +func TestMap_Set(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + m := NewMap(nil) + itr := m.Iterator() + if !itr.Done() { + t.Fatal("MapIterator.Done()=true, expected false") + } else if k, v := itr.Next(); k != nil || v != nil { + t.Fatalf("MapIterator.Next()=<%v,%v>, expected nil", k, v) + } + }) + + t.Run("Simple", func(t *testing.T) { + m := NewMap(nil) + m = m.Set(100, "foo") + if v, ok := m.Get(100); !ok || v != "foo" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + }) + + t.Run("VerySmall", func(t *testing.T) { + const n = 6 + m := NewMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + + // NOTE: Array nodes store entries in insertion order. + itr := m.Iterator() + for i := 0; i < n; i++ { + if k, v := itr.Next(); k != i || v != i+1 { + t.Fatalf("MapIterator.Next()=<%v,%v>, exp <%v,%v>", k, v, i, i+1) + } + } + if !itr.Done() { + t.Fatal("expected iterator done") + } + }) + + t.Run("Small", func(t *testing.T) { + const n = 1000 + m := NewMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + t.Run("Large", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping: short") + } + + const n = 1000000 + m := NewMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + t.Run("StringKeys", func(t *testing.T) { + m := NewMap(nil) + m = m.Set("foo", "bar") + m = m.Set("baz", "bat") + m = m.Set("", "EMPTY") + if v, ok := m.Get("foo"); !ok || v != "bar" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get("baz"); !ok || v != "bat" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get(""); !ok || v != "EMPTY" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + if v, ok := m.Get("no_such_key"); ok { + t.Fatalf("expected no value: <%v,%v>", v, ok) + } + }) + + t.Run("ByteSliceKeys", func(t *testing.T) { + m := NewMap(nil) + m = m.Set([]byte("foo"), "bar") + m = m.Set([]byte("baz"), "bat") + m = m.Set([]byte(""), "EMPTY") + if v, ok := m.Get([]byte("foo")); !ok || v != "bar" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get([]byte("baz")); !ok || v != "bat" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get([]byte("")); !ok || v != "EMPTY" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + if v, ok := m.Get([]byte("no_such_key")); ok { + t.Fatalf("expected no value: <%v,%v>", v, ok) + } + }) + + t.Run("NoDefaultHasher", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + m := NewMap(nil) + m = m.Set(uint64(100), "bar") + }() + if r != `immutable.Map.Set: must set hasher for uint64 type` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { + m := NewTestMap() + for i := 0; i < 100000; i++ { + switch rand.Intn(2) { + case 1: // overwrite + m.Set(m.ExistingKey(rand), rand.Intn(10000)) + default: // set new key + m.Set(m.NewKey(rand), rand.Intn(10000)) + } + } + if err := m.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +// Ensure map can support overwrites as it expands. +func TestMap_Overwrite(t *testing.T) { + const n = 10000 + m := NewMap(nil) + for i := 0; i < n; i++ { + // Set original value. + m = m.Set(i, i) + + // Overwrite every node. + for j := 0; j <= i; j++ { + m = m.Set(j, i*j) + } + } + + // Verify all key/value pairs in map. + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i*(n-1) { + t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) + } + } +} + +func TestMap_Delete(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + m := NewMap(nil) + other := m.Delete("foo") + if m != other { + t.Fatal("expected same map") + } + }) + + t.Run("Simple", func(t *testing.T) { + m := NewMap(nil) + m = m.Set(100, "foo") + if v, ok := m.Get(100); !ok || v != "foo" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + }) + + t.Run("Small", func(t *testing.T) { + const n = 1000 + m := NewMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := range rand.New(rand.NewSource(0)).Perm(n) { + m = m.Delete(i) + } + if m.Len() != 0 { + t.Fatalf("expected no elements, got %d", m.Len()) + } + }) + + t.Run("Large", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping: short") + } + const n = 1000000 + m := NewMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := range rand.New(rand.NewSource(0)).Perm(n) { + m = m.Delete(i) + } + if m.Len() != 0 { + t.Fatalf("expected no elements, got %d", m.Len()) + } + }) + + RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { + m := NewTestMap() + for i := 0; i < 100000; i++ { + switch rand.Intn(8) { + case 0: // overwrite + m.Set(m.ExistingKey(rand), rand.Intn(10000)) + case 1: // delete existing key + m.Delete(m.ExistingKey(rand)) + case 2: // delete non-existent key. + m.Delete(m.NewKey(rand)) + default: // set new key + m.Set(m.NewKey(rand), rand.Intn(10000)) + } + } + if err := m.Validate(); err != nil { + t.Fatal(err) + } + + // Delete all and verify they are gone. + keys := make([]int, len(m.keys)) + copy(keys, m.keys) + + for _, key := range keys { + m.Delete(key) + } + if err := m.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +// Ensure map works even with hash conflicts. +func TestMap_LimitedHash(t *testing.T) { + h := mockHasher{ + hash: func(value interface{}) uint32 { return hashUint64(uint64(value.(int))) % 0xFF }, + equal: func(a, b interface{}) bool { return a.(int) == b.(int) }, + } + m := NewMap(&h) + + rand := rand.New(rand.NewSource(0)) + keys := rand.Perm(100000) + for _, i := range keys { + m = m.Set(i, i) // initial set + } + for i := range keys { + m = m.Set(i, i*2) // overwrite + } + if m.Len() != len(keys) { + t.Fatalf("unexpected len: %d", m.Len()) + } + + // Verify all key/value pairs in map. + for i := 0; i < m.Len(); i++ { + if v, ok := m.Get(i); !ok || v != i*2 { + t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) + } + } + + // Verify iteration. + itr := m.Iterator() + for !itr.Done() { + if k, v := itr.Next(); v != k.(int)*2 { + t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k.(int)*2) + } + } + + // Verify not found works. + if _, ok := m.Get(10000000); ok { + t.Fatal("expected no value") + } + + // Verify delete non-existent key works. + if other := m.Delete(10000000 + 1); m != other { + t.Fatal("expected no change") + } + + // Remove all keys. + for _, key := range keys { + m = m.Delete(key) + } + if m.Len() != 0 { + t.Fatalf("unexpected size: %d", m.Len()) + } +} + +// TestMap represents a combined immutable and stdlib map. +type TestMap struct { + im, prev *Map + std map[int]int + keys []int +} + +func NewTestMap() *TestMap { + return &TestMap{ + im: NewMap(nil), + std: make(map[int]int), + } +} + +func (m *TestMap) NewKey(rand *rand.Rand) int { + for { + k := rand.Int() + if _, ok := m.std[k]; !ok { + return k + } + } +} + +func (m *TestMap) ExistingKey(rand *rand.Rand) int { + if len(m.keys) == 0 { + return 0 + } + return m.keys[rand.Intn(len(m.keys))] +} + +func (m *TestMap) Set(k, v int) { + m.prev = m.im + m.im = m.im.Set(k, v) + + _, exists := m.std[k] + if !exists { + m.keys = append(m.keys, k) + } + m.std[k] = v +} + +func (m *TestMap) Delete(k int) { + m.prev = m.im + m.im = m.im.Delete(k) + delete(m.std, k) + + for i := range m.keys { + if m.keys[i] == k { + m.keys = append(m.keys[:i], m.keys[i+1:]...) + break + } + } +} + +func (m *TestMap) Validate() error { + for _, k := range m.keys { + if v, ok := m.im.Get(k); !ok { + return fmt.Errorf("key not found: %d", k) + } else if v != m.std[k] { + return fmt.Errorf("key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) + } + } + if err := m.validateIterator(); err != nil { + return err + } + return nil +} + +func (m *TestMap) validateIterator() error { + other := make(map[int]int) + itr := m.im.Iterator() + for !itr.Done() { + k, v := itr.Next() + other[k.(int)] = v.(int) + } + if diff := cmp.Diff(other, m.std); diff != "" { + return fmt.Errorf("map iterator mismatch: %s", diff) + } + if k, v := itr.Next(); k != nil || v != nil { + return fmt.Errorf("map iterator returned key/value after done: <%v/%v>", k, v) + } + return nil +} + +func BenchmarkMap_Set(b *testing.B) { + b.ReportAllocs() + m := NewMap(nil) + for i := 0; i < b.N; i++ { + m = m.Set(i, i) + } +} + +func BenchmarkMap_Delete(b *testing.B) { + const n = 10000 + + m := NewMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i) + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + m.Delete(i % n) // Do not update map, always operate on original + } +} + +func BenchmarkMap_Iterator(b *testing.B) { + const n = 10000 + m := NewMap(nil) + for i := 0; i < 10000; i++ { + m = m.Set(i, i) + } + b.ReportAllocs() + b.ResetTimer() + + b.Run("Forward", func(b *testing.B) { + itr := m.Iterator() + for i := 0; i < b.N; i++ { + if i%n == 0 { + itr.First() + } + itr.Next() + } + }) +} + +func ExampleMap_Set() { + m := NewMap(nil) + m = m.Set("foo", "bar") + m = m.Set("baz", 100) + + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + + v, ok = m.Get("bat") // does not exist + fmt.Println("bat", v, ok) + // Output: + // foo bar true + // baz 100 true + // bat <nil> false +} + +func ExampleMap_Delete() { + m := NewMap(nil) + m = m.Set("foo", "bar") + m = m.Set("baz", 100) + m = m.Delete("baz") + + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + // Output: + // foo bar true + // baz <nil> false +} + +func ExampleMap_Iterator() { + m := NewMap(nil) + m = m.Set("apple", 100) + m = m.Set("grape", 200) + m = m.Set("kiwi", 300) + m = m.Set("mango", 400) + m = m.Set("orange", 500) + m = m.Set("peach", 600) + m = m.Set("pear", 700) + m = m.Set("pineapple", 800) + m = m.Set("strawberry", 900) + + itr := m.Iterator() + for !itr.Done() { + k, v := itr.Next() + fmt.Println(k, v) + } + // Output: + // mango 400 + // pear 700 + // pineapple 800 + // grape 200 + // orange 500 + // strawberry 900 + // kiwi 300 + // peach 600 + // apple 100 +} + +func TestIngernalSortedMapLeafNode(t *testing.T) { + RunRandom(t, "NoSplit", func(t *testing.T, rand *rand.Rand) { + var cmpr intComparer + var node sortedMapNode = &sortedMapLeafNode{} + var keys []int + for _, i := range rand.Perm(32) { + var resized bool + var splitNode sortedMapNode + node, splitNode = node.set(i, i*10, &cmpr, &resized) + if !resized { + t.Fatal("expected resize") + } else if splitNode != nil { + t.Fatal("expected split") + } + keys = append(keys, i) + + // Verify not found at each size. + if _, ok := node.get(rand.Int()+32, &cmpr); ok { + t.Fatal("expected no value") + } + + // Verify min key is always the lowest. + sort.Ints(keys) + if got, exp := node.minKey(), keys[0]; got != exp { + t.Fatalf("minKey()=%d, expected %d", got, exp) + } + } + + // Verify all key/value pairs in node. + for i := range keys { + if v, ok := node.get(i, &cmpr); !ok || v != i*10 { + t.Fatalf("get(%d)=<%v,%v>", i, v, ok) + } + } + }) + + RunRandom(t, "Overwrite", func(t *testing.T, rand *rand.Rand) { + var cmpr intComparer + var node sortedMapNode = &sortedMapLeafNode{} + for _, i := range rand.Perm(32) { + var resized bool + node, _ = node.set(i, i*2, &cmpr, &resized) + } + for _, i := range rand.Perm(32) { + var resized bool + node, _ = node.set(i, i*3, &cmpr, &resized) + if resized { + t.Fatal("expected no resize") + } + } + + // Verify all overwritten key/value pairs in node. + for i := 0; i < 32; i++ { + if v, ok := node.get(i, &cmpr); !ok || v != i*3 { + t.Fatalf("get(%d)=<%v,%v>", i, v, ok) + } + } + }) + + t.Run("Split", func(t *testing.T) { + // Fill leaf node. + var cmpr intComparer + var node sortedMapNode = &sortedMapLeafNode{} + for i := 0; i < 32; i++ { + var resized bool + node, _ = node.set(i, i*10, &cmpr, &resized) + } + + // Add one more and expect split. + var resized bool + newNode, splitNode := node.set(32, 320, &cmpr, &resized) + + // Verify node contents. + newLeafNode, ok := newNode.(*sortedMapLeafNode) + if !ok { + t.Fatalf("unexpected node type: %T", newLeafNode) + } else if n := len(newLeafNode.entries); n != 16 { + t.Fatalf("unexpected node len: %d", n) + } + for i := range newLeafNode.entries { + if entry := newLeafNode.entries[i]; entry.key != i || entry.value != i*10 { + t.Fatalf("%d. unexpected entry: %v=%v", i, entry.key, entry.value) + } + } + + // Verify split node contents. + splitLeafNode, ok := splitNode.(*sortedMapLeafNode) + if !ok { + t.Fatalf("unexpected split node type: %T", splitLeafNode) + } else if n := len(splitLeafNode.entries); n != 17 { + t.Fatalf("unexpected split node len: %d", n) + } + for i := range splitLeafNode.entries { + if entry := splitLeafNode.entries[i]; entry.key != (i+16) || entry.value != (i+16)*10 { + t.Fatalf("%d. unexpected split node entry: %v=%v", i, entry.key, entry.value) + } + } + }) +} + +func TestIngernalSortedMapBranchNode(t *testing.T) { + RunRandom(t, "NoSplit", func(t *testing.T, rand *rand.Rand) { + keys := make([]int, 32*16) + for i := range keys { + keys[i] = rand.Intn(10000) + } + keys = uniqueIntSlice(keys) + sort.Ints(keys[:2]) // ensure first two keys are sorted for initial insert. + + // Initialize branch with two leafs. + var cmpr intComparer + leaf0 := &sortedMapLeafNode{entries: []mapEntry{{key: keys[0], value: keys[0] * 10}}} + leaf1 := &sortedMapLeafNode{entries: []mapEntry{{key: keys[1], value: keys[1] * 10}}} + var node sortedMapNode = newSortedMapBranchNode(leaf0, leaf1) + + sort.Ints(keys) + for _, i := range rand.Perm(len(keys)) { + key := keys[i] + + var resized bool + var splitNode sortedMapNode + node, splitNode = node.set(key, key*10, &cmpr, &resized) + if key == leaf0.entries[0].key || key == leaf1.entries[0].key { + if resized { + t.Fatalf("expected no resize: key=%d", key) + } + } else { + if !resized { + t.Fatalf("expected resize: key=%d", key) + } + } + if splitNode != nil { + t.Fatal("unexpected split") + } + } + + // Verify all key/value pairs in node. + for _, key := range keys { + if v, ok := node.get(key, &cmpr); !ok || v != key*10 { + t.Fatalf("get(%d)=<%v,%v>", key, v, ok) + } + } + + // Verify min key is the lowest key. + if got, exp := node.minKey(), keys[0]; got != exp { + t.Fatalf("minKey()=%d, expected %d", got, exp) + } + }) + + t.Run("Split", func(t *testing.T) { + // Generate leaf nodes. + var cmpr intComparer + children := make([]sortedMapNode, 32) + for i := range children { + leaf := &sortedMapLeafNode{entries: make([]mapEntry, 32)} + for j := range leaf.entries { + leaf.entries[j] = mapEntry{key: (i * 32) + j, value: ((i * 32) + j) * 100} + } + children[i] = leaf + } + var node sortedMapNode = newSortedMapBranchNode(children...) + + // Add one more and expect split. + var resized bool + newNode, splitNode := node.set((32 * 32), (32*32)*100, &cmpr, &resized) + + // Verify node contents. + var idx int + newBranchNode, ok := newNode.(*sortedMapBranchNode) + if !ok { + t.Fatalf("unexpected node type: %T", newBranchNode) + } else if n := len(newBranchNode.elems); n != 16 { + t.Fatalf("unexpected child elems len: %d", n) + } + for i, elem := range newBranchNode.elems { + child, ok := elem.node.(*sortedMapLeafNode) + if !ok { + t.Fatalf("unexpected child type") + } + for j, entry := range child.entries { + if entry.key != idx || entry.value != idx*100 { + t.Fatalf("%d/%d. unexpected entry: %v=%v", i, j, entry.key, entry.value) + } + idx++ + } + } + + // Verify split node contents. + splitBranchNode, ok := splitNode.(*sortedMapBranchNode) + if !ok { + t.Fatalf("unexpected split node type: %T", splitBranchNode) + } else if n := len(splitBranchNode.elems); n != 17 { + t.Fatalf("unexpected split node elem len: %d", n) + } + for i, elem := range splitBranchNode.elems { + child, ok := elem.node.(*sortedMapLeafNode) + if !ok { + t.Fatalf("unexpected split node child type") + } + for j, entry := range child.entries { + if entry.key != idx || entry.value != idx*100 { + t.Fatalf("%d/%d. unexpected split node entry: %v=%v", i, j, entry.key, entry.value) + } + idx++ + } + } + }) +} + +func TestSortedMap_Get(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + m := NewSortedMap(nil) + if v, ok := m.Get(100); ok || v != nil { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + }) +} + +func TestSortedMap_Set(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + m := NewSortedMap(nil) + m = m.Set(100, "foo") + if v, ok := m.Get(100); !ok || v != "foo" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if got, exp := m.Len(), 1; got != exp { + t.Fatalf("SortedMap.Len()=%d, exp %d", got, exp) + } + }) + + t.Run("Small", func(t *testing.T) { + const n = 1000 + m := NewSortedMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + t.Run("Large", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping: short") + } + + const n = 1000000 + m := NewSortedMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + t.Run("StringKeys", func(t *testing.T) { + m := NewSortedMap(nil) + m = m.Set("foo", "bar") + m = m.Set("baz", "bat") + m = m.Set("", "EMPTY") + if v, ok := m.Get("foo"); !ok || v != "bar" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get("baz"); !ok || v != "bat" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get(""); !ok || v != "EMPTY" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + if v, ok := m.Get("no_such_key"); ok { + t.Fatalf("expected no value: <%v,%v>", v, ok) + } + }) + + t.Run("ByteSliceKeys", func(t *testing.T) { + m := NewSortedMap(nil) + m = m.Set([]byte("foo"), "bar") + m = m.Set([]byte("baz"), "bat") + m = m.Set([]byte(""), "EMPTY") + if v, ok := m.Get([]byte("foo")); !ok || v != "bar" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get([]byte("baz")); !ok || v != "bat" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get([]byte("")); !ok || v != "EMPTY" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + if v, ok := m.Get([]byte("no_such_key")); ok { + t.Fatalf("expected no value: <%v,%v>", v, ok) + } + }) + + t.Run("NoDefaultComparer", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + m := NewSortedMap(nil) + m = m.Set(uint64(100), "bar") + }() + if r != `immutable.SortedMap.Set: must set comparer for uint64 type` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { + m := NewTestSortedMap() + for j := 0; j < 10000; j++ { + switch rand.Intn(2) { + case 1: // overwrite + m.Set(m.ExistingKey(rand), rand.Intn(10000)) + default: // set new key + m.Set(m.NewKey(rand), rand.Intn(10000)) + } + } + if err := m.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +// Ensure map can support overwrites as it expands. +func TestSortedMap_Overwrite(t *testing.T) { + const n = 1000 + m := NewSortedMap(nil) + for i := 0; i < n; i++ { + // Set original value. + m = m.Set(i, i) + + // Overwrite every node. + for j := 0; j <= i; j++ { + m = m.Set(j, i*j) + } + } + + // Verify all key/value pairs in map. + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i*(n-1) { + t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) + } + } +} + +func TestSortedMap_Delete(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + m := NewSortedMap(nil) + m = m.Delete(100) + if n := m.Len(); n != 0 { + t.Fatalf("SortedMap.Len()=%d, expected 0", n) + } + }) + + t.Run("Simple", func(t *testing.T) { + m := NewSortedMap(nil) + m = m.Set(100, "foo") + if v, ok := m.Get(100); !ok || v != "foo" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + m = m.Delete(100) + if v, ok := m.Get(100); ok { + t.Fatalf("unexpected no value: <%v,%v>", v, ok) + } + }) + + t.Run("Small", func(t *testing.T) { + const n = 1000 + m := NewSortedMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + + for i := 0; i < n; i++ { + m = m.Delete(i) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); ok { + t.Fatalf("expected no value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + t.Run("Large", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping: short") + } + + const n = 1000000 + m := NewSortedMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + + for i := 0; i < n; i++ { + m = m.Delete(i) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); ok { + t.Fatalf("unexpected no value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { + m := NewTestSortedMap() + for j := 0; j < 10000; j++ { + switch rand.Intn(8) { + case 0: // overwrite + m.Set(m.ExistingKey(rand), rand.Intn(10000)) + case 1: // delete existing key + m.Delete(m.ExistingKey(rand)) + case 2: // delete non-existent key. + m.Delete(m.NewKey(rand)) + default: // set new key + m.Set(m.NewKey(rand), rand.Intn(10000)) + } + } + if err := m.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +func TestSortedMap_Iterator(t *testing.T) { + t.Run("Seek", func(t *testing.T) { + const n = 100 + m := NewSortedMap(nil) + for i := 0; i < n; i += 2 { + m = m.Set(fmt.Sprintf("%04d", i), i) + } + + t.Run("Exact", func(t *testing.T) { + itr := m.Iterator() + for i := 0; i < n; i += 2 { + itr.Seek(fmt.Sprintf("%04d", i)) + for j := i; j < n; j += 2 { + if k, _ := itr.Next(); k != fmt.Sprintf("%04d", j) { + t.Fatalf("%d/%d. SortedMapIterator.Next()=%v, expected key %04d", i, j, k, j) + } + } + if !itr.Done() { + t.Fatalf("SortedMapIterator.Done()=true, expected false") + } + } + }) + + t.Run("Miss", func(t *testing.T) { + itr := m.Iterator() + for i := 1; i < n-2; i += 2 { + itr.Seek(fmt.Sprintf("%04d", i)) + for j := i + 1; j < n; j += 2 { + if k, _ := itr.Next(); k != fmt.Sprintf("%04d", j) { + t.Fatalf("%d/%d. SortedMapIterator.Next()=%v, expected key %04d", i, j, k, j) + } + } + if !itr.Done() { + t.Fatalf("SortedMapIterator.Done()=true, expected false") + } + } + }) + + t.Run("BeforeFirst", func(t *testing.T) { + itr := m.Iterator() + itr.Seek("") + for i := 0; i < n; i += 2 { + if k, _ := itr.Next(); k != fmt.Sprintf("%04d", i) { + t.Fatalf("%d. SortedMapIterator.Next()=%v, expected key %04d", i, k, i) + } + } + if !itr.Done() { + t.Fatalf("SortedMapIterator.Done()=true, expected false") + } + }) + t.Run("AfterLast", func(t *testing.T) { + itr := m.Iterator() + itr.Seek("1000") + if k, _ := itr.Next(); k != nil { + t.Fatalf("0. SortedMapIterator.Next()=%v, expected nil key", k) + } else if !itr.Done() { + t.Fatalf("SortedMapIterator.Done()=true, expected false") + } + }) + }) +} + +// TestSortedMap represents a combined immutable and stdlib sorted map. +type TestSortedMap struct { + im, prev *SortedMap + std map[int]int + keys []int +} + +func NewTestSortedMap() *TestSortedMap { + return &TestSortedMap{ + im: NewSortedMap(nil), + std: make(map[int]int), + } +} + +func (m *TestSortedMap) NewKey(rand *rand.Rand) int { + for { + k := rand.Int() + if _, ok := m.std[k]; !ok { + return k + } + } +} + +func (m *TestSortedMap) ExistingKey(rand *rand.Rand) int { + if len(m.keys) == 0 { + return 0 + } + return m.keys[rand.Intn(len(m.keys))] +} + +func (m *TestSortedMap) Set(k, v int) { + m.prev = m.im + m.im = m.im.Set(k, v) + + if _, ok := m.std[k]; !ok { + m.keys = append(m.keys, k) + sort.Ints(m.keys) + } + m.std[k] = v +} + +func (m *TestSortedMap) Delete(k int) { + m.prev = m.im + m.im = m.im.Delete(k) + delete(m.std, k) + + for i := range m.keys { + if m.keys[i] == k { + m.keys = append(m.keys[:i], m.keys[i+1:]...) + break + } + } +} + +func (m *TestSortedMap) Validate() error { + for _, k := range m.keys { + if v, ok := m.im.Get(k); !ok { + return fmt.Errorf("key not found: %d", k) + } else if v != m.std[k] { + return fmt.Errorf("key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) + } + } + + sort.Ints(m.keys) + if err := m.validateForwardIterator(); err != nil { + return err + } else if err := m.validateBackwardIterator(); err != nil { + return err + } + return nil +} + +func (m *TestSortedMap) validateForwardIterator() error { + itr := m.im.Iterator() + for i, k0 := range m.keys { + v0 := m.std[k0] + if k1, v1 := itr.Next(); k0 != k1 || v0 != v1 { + return fmt.Errorf("%d. SortedMapIterator.Next()=<%v,%v>, expected <%v,%v>", i, k1, v1, k0, v0) + } + + done := i == len(m.keys)-1 + if v := itr.Done(); v != done { + return fmt.Errorf("%d. SortedMapIterator.Done()=%v, expected %v", i, v, done) + } + } + if k, v := itr.Next(); k != nil || v != nil { + return fmt.Errorf("SortedMapIterator.Next()=<%v,%v>, expected nil after done", k, v) + } + return nil +} + +func (m *TestSortedMap) validateBackwardIterator() error { + itr := m.im.Iterator() + itr.Last() + for i := len(m.keys) - 1; i >= 0; i-- { + k0 := m.keys[i] + v0 := m.std[k0] + if k1, v1 := itr.Prev(); k0 != k1 || v0 != v1 { + return fmt.Errorf("%d. SortedMapIterator.Prev()=<%v,%v>, expected <%v,%v>", i, k1, v1, k0, v0) + } + + done := i == 0 + if v := itr.Done(); v != done { + return fmt.Errorf("%d. SortedMapIterator.Done()=%v, expected %v", i, v, done) + } + } + if k, v := itr.Prev(); k != nil || v != nil { + return fmt.Errorf("SortedMapIterator.Prev()=<%v,%v>, expected nil after done", k, v) + } + return nil +} + +func BenchmarkSortedMap_Set(b *testing.B) { + b.ReportAllocs() + m := NewSortedMap(nil) + for i := 0; i < b.N; i++ { + m = m.Set(i, i) + } +} + +func BenchmarkSortedMap_Delete(b *testing.B) { + const n = 10000 + + m := NewSortedMap(nil) + for i := 0; i < n; i++ { + m = m.Set(i, i) + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + m.Delete(i % n) // Do not update map, always operate on original + } +} + +func BenchmarkSortedMap_Iterator(b *testing.B) { + const n = 10000 + m := NewSortedMap(nil) + for i := 0; i < 10000; i++ { + m = m.Set(i, i) + } + b.ReportAllocs() + b.ResetTimer() + + b.Run("Forward", func(b *testing.B) { + itr := m.Iterator() + for i := 0; i < b.N; i++ { + if i%n == 0 { + itr.First() + } + itr.Next() + } + }) + + b.Run("Reverse", func(b *testing.B) { + itr := m.Iterator() + for i := 0; i < b.N; i++ { + if i%n == 0 { + itr.Last() + } + itr.Prev() + } + }) +} + +func ExampleSortedMap_Set() { + m := NewSortedMap(nil) + m = m.Set("foo", "bar") + m = m.Set("baz", 100) + + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + + v, ok = m.Get("bat") // does not exist + fmt.Println("bat", v, ok) + // Output: + // foo bar true + // baz 100 true + // bat <nil> false +} + +func ExampleSortedMap_Delete() { + m := NewSortedMap(nil) + m = m.Set("foo", "bar") + m = m.Set("baz", 100) + m = m.Delete("baz") + + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + // Output: + // foo bar true + // baz <nil> false +} + +func ExampleSortedMap_Iterator() { + m := NewSortedMap(nil) + m = m.Set("strawberry", 900) + m = m.Set("kiwi", 300) + m = m.Set("apple", 100) + m = m.Set("pear", 700) + m = m.Set("pineapple", 800) + m = m.Set("peach", 600) + m = m.Set("orange", 500) + m = m.Set("grape", 200) + m = m.Set("mango", 400) + + itr := m.Iterator() + for !itr.Done() { + k, v := itr.Next() + fmt.Println(k, v) + } + // Output: + // apple 100 + // grape 200 + // kiwi 300 + // mango 400 + // orange 500 + // peach 600 + // pear 700 + // pineapple 800 + // strawberry 900 +} + +// RunRandom executes fn multiple times with a different rand. +func RunRandom(t *testing.T, name string, fn func(t *testing.T, rand *rand.Rand)) { + if testing.Short() { + t.Skip("short mode") + } + t.Run(name, func(t *testing.T) { + for i := 0; i < *randomN; i++ { + t.Run(fmt.Sprintf("%08d", i), func(t *testing.T) { + t.Parallel() + fn(t, rand.New(rand.NewSource(int64(i)))) + }) + } + }) +} + +func uniqueIntSlice(a []int) []int { + m := make(map[int]struct{}) + other := make([]int, 0, len(a)) + for _, v := range a { + if _, ok := m[v]; ok { + continue + } + m[v] = struct{}{} + other = append(other, v) + } + return other +} + +// mockHasher represents a mock implementation of immutable.Hasher. +type mockHasher struct { + hash func(value interface{}) uint32 + equal func(a, b interface{}) bool +} + +// Hash executes the mocked HashFn function. +func (h *mockHasher) Hash(value interface{}) uint32 { + return h.hash(value) +} + +// Equal executes the mocked EqualFn function. +func (h *mockHasher) Equal(a, b interface{}) bool { + return h.equal(a, b) +} + +// mockComparer represents a mock implementation of immutable.Comparer. +type mockComparer struct { + compare func(a, b interface{}) int +} + +// Compare executes the mocked CompreFn function. +func (h *mockComparer) Compare(a, b interface{}) int { + return h.compare(a, b) +} |