From 06eaf1738494a9e783ed100565911d08efaae826 Mon Sep 17 00:00:00 2001 From: EuAndreh Date: Sat, 14 Dec 2024 10:30:47 -0300 Subject: Add Makefile and move files to structured folders --- .gitignore | 15 + Makefile | 159 ++++ deps.mk | 25 + immutable.go | 2459 ------------------------------------------------ immutable_test.go | 2544 ------------------------------------------------- mkdeps.sh | 29 + sets.go | 243 ----- sets_test.go | 126 --- src/pds.go | 2700 +++++++++++++++++++++++++++++++++++++++++++++++++++++ tests/main.go | 7 + tests/pds.go | 2669 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 11 files changed, 5604 insertions(+), 5372 deletions(-) create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 deps.mk delete mode 100644 immutable.go delete mode 100644 immutable_test.go create mode 100755 mkdeps.sh delete mode 100644 sets.go delete mode 100644 sets_test.go create mode 100644 src/pds.go create mode 100644 tests/main.go create mode 100644 tests/pds.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..094db69 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +/src/version.go +/*.bin +/src/*.a +/src/*.bin +/src/*cgo* +/tests/*.a +/tests/*.bin +/tests/functional/*/*.a +/tests/functional/*/*.bin +/tests/fuzz/*/*.a +/tests/fuzz/*/*.bin +/tests/benchmarks/*/*.a +/tests/benchmarks/*/*.bin +/tests/benchmarks/*/*.txt +/tests/fuzz/corpus/ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..001b63c --- /dev/null +++ b/Makefile @@ -0,0 +1,159 @@ +.POSIX: +DATE = 1970-01-01 +VERSION = 0.1.0 +NAME = pds +NAME_UC = $(NAME) +LANGUAGES = en +## Installation prefix. Defaults to "/usr". +PREFIX = /usr +BINDIR = $(PREFIX)/bin +LIBDIR = $(PREFIX)/lib +GOLIBDIR = $(LIBDIR)/go +INCLUDEDIR = $(PREFIX)/include +SRCDIR = $(PREFIX)/src/$(NAME) +SHAREDIR = $(PREFIX)/share +LOCALEDIR = $(SHAREDIR)/locale +MANDIR = $(SHAREDIR)/man +EXEC = ./ +## Where to store the installation. Empty by default. +DESTDIR = +LDLIBS = --static +GOCFLAGS = -I $(GOLIBDIR) +GOLDFLAGS = -L $(GOLIBDIR) + + + +.SUFFIXES: +.SUFFIXES: .go .a .bin .bin-check + +.go.a: + go tool compile -I $(@D) $(GOCFLAGS) -o $@ -p $(*F) \ + `find $< $$(if [ $(*F) != main ]; then \ + echo src/$(NAME).go src/version.go; fi) | uniq` + +.a.bin: + go tool link -L $(@D) $(GOLDFLAGS) -o $@ --extldflags '$(LDLIBS)' $< + + + +all: +include deps.mk + + +libs.a = $(libs.go:.go=.a) +mains.a = $(mains.go:.go=.a) +mains.bin = $(mains.go:.go=.bin) +functional-tests/lib.a = $(functional-tests/lib.go:.go=.a) +fuzz-targets/lib.a = $(fuzz-targets/lib.go:.go=.a) +benchmarks/lib.a = $(benchmarks/lib.go:.go=.a) + +sources = \ + src/$(NAME).go \ + src/version.go \ + + +derived-assets = \ + src/version.go \ + $(libs.a) \ + $(mains.a) \ + $(mains.bin) \ + +side-assets = \ + tests/fuzz/corpus/ \ + tests/benchmarks/*/main.txt \ + + + +## Default target. Builds all artifacts required for testing +## and installation. +all: $(derived-assets) + + +$(libs.a): Makefile deps.mk +$(libs.a): src/$(NAME).go src/version.go + + +$(fuzz-targets/lib.a): + go tool compile $(GOCFLAGS) -o $@ -p $(NAME) -d=libfuzzer \ + $*.go src/$(NAME).go src/version.go + +src/version.go: Makefile + echo 'package $(NAME); const Version = "$(VERSION)"' > $@ + + + +tests.bin-check = \ + tests/main.bin-check \ + $(functional-tests/main.go:.go=.bin-check) \ + +$(tests.bin-check): + $(EXEC)$*.bin + +check-unit: $(tests.bin-check) + + +integration-tests = \ + +.PRECIOUS: $(integration-tests) +$(integration-tests): $(NAME).bin +$(integration-tests): ALWAYS + sh $@ + +check-integration: $(integration-tests) +check-integration: fuzz + + +## Run all tests. Each test suite is isolated, so that a parallel +## build can run tests at the same time. The required artifacts +## are created if missing. +check: check-unit check-integration + + + +FUZZSEC=1 +fuzz-targets/main.bin-check = $(fuzz-targets/main.go:.go=.bin-check) +$(fuzz-targets/main.bin-check): + $(EXEC)$*.bin --test.fuzztime=$(FUZZSEC)s \ + --test.fuzz='.*' --test.fuzzcachedir=tests/fuzz/corpus + +fuzz: $(fuzz-targets/main.bin-check) + + + +benchmarks/main.bin-check = $(benchmarks/main.go:.go=.bin-check) +$(benchmarks/main.bin-check): + printf '%s\n' '$(EXEC)$*.bin' > $*.txt + LANG=POSIX.UTF-8 time -p $(EXEC)$*.bin 2>> $*.txt + printf '%s\n' '$*.txt' + +bench: $(benchmarks/main.bin-check) + + + +## Remove *all* derived artifacts produced during the build. +## A dedicated test asserts that this is always true. +clean: + rm -rf $(derived-assets) $(side-assets) + + +## Installs into $(DESTDIR)$(PREFIX). Its dependency target +## ensures that all installable artifacts are crafted beforehand. +install: all + mkdir -p \ + '$(DESTDIR)$(GOLIBDIR)' \ + '$(DESTDIR)$(SRCDIR)' \ + + cp src/$(NAME).a '$(DESTDIR)$(GOLIBDIR)' + cp $(sources) '$(DESTDIR)$(SRCDIR)' + +## Uninstalls from $(DESTDIR)$(PREFIX). This is a perfect mirror +## of the "install" target, and removes *all* that was installed. +## A dedicated test asserts that this is always true. +uninstall: + rm -rf \ + '$(DESTDIR)$(GOLIBDIR)'/$(NAME).a \ + '$(DESTDIR)$(SRCDIR)' \ + + + +ALWAYS: diff --git a/deps.mk b/deps.mk new file mode 100644 index 0000000..b7701af --- /dev/null +++ b/deps.mk @@ -0,0 +1,25 @@ +libs.go = \ + src/pds.go \ + tests/pds.go \ + +mains.go = \ + tests/main.go \ + +functional-tests/lib.go = \ + +functional-tests/main.go = \ + +fuzz-targets/lib.go = \ + +fuzz-targets/main.go = \ + +benchmarks/lib.go = \ + +benchmarks/main.go = \ + +src/pds.a: src/pds.go +tests/main.a: tests/main.go +tests/pds.a: tests/pds.go +tests/main.bin: tests/main.a +tests/main.bin-check: tests/main.bin +tests/main.a: tests/$(NAME).a diff --git a/immutable.go b/immutable.go deleted file mode 100644 index 1642de3..0000000 --- a/immutable.go +++ /dev/null @@ -1,2459 +0,0 @@ -// Package immutable provides immutable collection types. -// -// # Introduction -// -// Immutable collections provide an efficient, safe way to share collections -// of data while minimizing locks. The collections in this package provide -// List, Map, and SortedMap implementations. These act similarly to slices -// and maps, respectively, except that altering a collection returns a new -// copy of the collection with that change. -// -// Because collections are unable to change, they are safe for multiple -// goroutines to read from at the same time without a mutex. However, these -// types of collections come with increased CPU & memory usage as compared -// with Go's built-in collection types so please evaluate for your specific -// use. -// -// # Collection Types -// -// The List type provides an API similar to Go slices. They allow appending, -// prepending, and updating of elements. Elements can also be fetched by index -// or iterated over using a ListIterator. -// -// The Map & SortedMap types provide an API similar to Go maps. They allow -// values to be assigned to unique keys and allow for the deletion of keys. -// Values can be fetched by key and key/value pairs can be iterated over using -// the appropriate iterator type. Both map types provide the same API. The -// SortedMap, however, provides iteration over sorted keys while the Map -// provides iteration over unsorted keys. Maps improved performance and memory -// usage as compared to SortedMaps. -// -// # Hashing and Sorting -// -// Map types require the use of a Hasher implementation to calculate hashes for -// their keys and check for key equality. SortedMaps require the use of a -// Comparer implementation to sort keys in the map. -// -// These collection types automatically provide built-in hasher and comparers -// for int, string, and byte slice keys. If you are using one of these key types -// then simply pass a nil into the constructor. Otherwise you will need to -// implement a custom Hasher or Comparer type. Please see the provided -// implementations for reference. -package immutable - -import ( - "fmt" - "math/bits" - "reflect" - "sort" - "strings" - - "golang.org/x/exp/constraints" -) - -// List is a dense, ordered, indexed collections. They are analogous to slices -// in Go. They can be updated by appending to the end of the list, prepending -// values to the beginning of the list, or updating existing indexes in the -// list. -type List[T any] struct { - root listNode[T] // root node - origin int // offset to zero index element - size int // total number of elements in use -} - -// NewList returns a new empty instance of List. -func NewList[T any](values ...T) *List[T] { - l := &List[T]{ - root: &listLeafNode[T]{}, - } - for _, value := range values { - l.append(value, true) - } - return l -} - -// clone returns a copy of the list. -func (l *List[T]) clone() *List[T] { - other := *l - return &other -} - -// Len returns the number of elements in the list. -func (l *List[T]) Len() int { - return l.size -} - -// cap returns the total number of possible elements for the current depth. -func (l *List[T]) cap() int { - return 1 << (l.root.depth() * listNodeBits) -} - -// Get returns the value at the given index. Similar to slices, this method will -// panic if index is below zero or is greater than or equal to the list size. -func (l *List[T]) Get(index int) T { - if index < 0 || index >= l.size { - panic(fmt.Sprintf("immutable.List.Get: index %d out of bounds", index)) - } - return l.root.get(l.origin + index) -} - -// Set returns a new list with value set at index. Similar to slices, this -// method will panic if index is below zero or if the index is greater than -// or equal to the list size. -func (l *List[T]) Set(index int, value T) *List[T] { - return l.set(index, value, false) -} - -func (l *List[T]) set(index int, value T, mutable bool) *List[T] { - if index < 0 || index >= l.size { - panic(fmt.Sprintf("immutable.List.Set: index %d out of bounds", index)) - } - other := l - if !mutable { - other = l.clone() - } - other.root = other.root.set(l.origin+index, value, mutable) - return other -} - -// Append returns a new list with value added to the end of the list. -func (l *List[T]) Append(value T) *List[T] { - return l.append(value, false) -} - -func (l *List[T]) append(value T, mutable bool) *List[T] { - other := l - if !mutable { - other = l.clone() - } - - // Expand list to the right if no slots remain. - if other.size+other.origin >= l.cap() { - newRoot := &listBranchNode[T]{d: other.root.depth() + 1} - newRoot.children[0] = other.root - other.root = newRoot - } - - // Increase size and set the last element to the new value. - other.size++ - other.root = other.root.set(other.origin+other.size-1, value, mutable) - return other -} - -// Prepend returns a new list with value(s) added to the beginning of the list. -func (l *List[T]) Prepend(value T) *List[T] { - return l.prepend(value, false) -} - -func (l *List[T]) prepend(value T, mutable bool) *List[T] { - other := l - if !mutable { - other = l.clone() - } - - // Expand list to the left if no slots remain. - if other.origin == 0 { - newRoot := &listBranchNode[T]{d: other.root.depth() + 1} - newRoot.children[listNodeSize-1] = other.root - other.root = newRoot - other.origin += (listNodeSize - 1) << (other.root.depth() * listNodeBits) - } - - // Increase size and move origin back. Update first element to value. - other.size++ - other.origin-- - other.root = other.root.set(other.origin, value, mutable) - return other -} - -// Slice returns a new list of elements between start index and end index. -// Similar to slices, this method will panic if start or end are below zero or -// greater than the list size. A panic will also occur if start is greater than -// end. -// -// Unlike Go slices, references to inaccessible elements will be automatically -// removed so they can be garbage collected. -func (l *List[T]) Slice(start, end int) *List[T] { - return l.slice(start, end, false) -} - -func (l *List[T]) slice(start, end int, mutable bool) *List[T] { - // Panics similar to Go slices. - if start < 0 || start > l.size { - panic(fmt.Sprintf("immutable.List.Slice: start index %d out of bounds", start)) - } else if end < 0 || end > l.size { - panic(fmt.Sprintf("immutable.List.Slice: end index %d out of bounds", end)) - } else if start > end { - panic(fmt.Sprintf("immutable.List.Slice: invalid slice index: [%d:%d]", start, end)) - } - - // Return the same list if the start and end are the entire range. - if start == 0 && end == l.size { - return l - } - - // Create copy, if immutable. - other := l - if !mutable { - other = l.clone() - } - - // Update origin/size. - other.origin = l.origin + start - other.size = end - start - - // Contract tree while the start & end are in the same child node. - for other.root.depth() > 1 { - i := (other.origin >> (other.root.depth() * listNodeBits)) & listNodeMask - j := ((other.origin + other.size - 1) >> (other.root.depth() * listNodeBits)) & listNodeMask - if i != j { - break // branch contains at least two nodes, exit - } - - // Replace the current root with the single child & update origin offset. - other.origin -= i << (other.root.depth() * listNodeBits) - other.root = other.root.(*listBranchNode[T]).children[i] - } - - // Ensure all references are removed before start & after end. - other.root = other.root.deleteBefore(other.origin, mutable) - other.root = other.root.deleteAfter(other.origin+other.size-1, mutable) - - return other -} - -// Iterator returns a new iterator for this list positioned at the first index. -func (l *List[T]) Iterator() *ListIterator[T] { - itr := &ListIterator[T]{list: l} - itr.First() - return itr -} - -// ListBuilder represents an efficient builder for creating new Lists. -type ListBuilder[T any] struct { - list *List[T] // current state -} - -// NewListBuilder returns a new instance of ListBuilder. -func NewListBuilder[T any]() *ListBuilder[T] { - return &ListBuilder[T]{list: NewList[T]()} -} - -// List returns the current copy of the list. -// The builder should not be used again after the list after this call. -func (b *ListBuilder[T]) List() *List[T] { - assert(b.list != nil, "immutable.ListBuilder.List(): duplicate call to fetch list") - list := b.list - b.list = nil - return list -} - -// Len returns the number of elements in the underlying list. -func (b *ListBuilder[T]) Len() int { - assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") - return b.list.Len() -} - -// Get returns the value at the given index. Similar to slices, this method will -// panic if index is below zero or is greater than or equal to the list size. -func (b *ListBuilder[T]) Get(index int) T { - assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") - return b.list.Get(index) -} - -// Set updates the value at the given index. Similar to slices, this method will -// panic if index is below zero or if the index is greater than or equal to the -// list size. -func (b *ListBuilder[T]) Set(index int, value T) { - assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") - b.list = b.list.set(index, value, true) -} - -// Append adds value to the end of the list. -func (b *ListBuilder[T]) Append(value T) { - assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") - b.list = b.list.append(value, true) -} - -// Prepend adds value to the beginning of the list. -func (b *ListBuilder[T]) Prepend(value T) { - assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") - b.list = b.list.prepend(value, true) -} - -// Slice updates the list with a sublist of elements between start and end index. -// See List.Slice() for more details. -func (b *ListBuilder[T]) Slice(start, end int) { - assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") - b.list = b.list.slice(start, end, true) -} - -// Iterator returns a new iterator for the underlying list. -func (b *ListBuilder[T]) Iterator() *ListIterator[T] { - assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") - return b.list.Iterator() -} - -// Constants for bit shifts used for levels in the List trie. -const ( - listNodeBits = 5 - listNodeSize = 1 << listNodeBits - listNodeMask = listNodeSize - 1 -) - -// listNode represents either a branch or leaf node in a List. -type listNode[T any] interface { - depth() uint - get(index int) T - set(index int, v T, mutable bool) listNode[T] - - containsBefore(index int) bool - containsAfter(index int) bool - - deleteBefore(index int, mutable bool) listNode[T] - deleteAfter(index int, mutable bool) listNode[T] -} - -// newListNode returns a leaf node for depth zero, otherwise returns a branch node. -func newListNode[T any](depth uint) listNode[T] { - if depth == 0 { - return &listLeafNode[T]{} - } - return &listBranchNode[T]{d: depth} -} - -// listBranchNode represents a branch of a List tree at a given depth. -type listBranchNode[T any] struct { - d uint // depth - children [listNodeSize]listNode[T] -} - -// depth returns the depth of this branch node from the leaf. -func (n *listBranchNode[T]) depth() uint { return n.d } - -// get returns the child node at the segment of the index for this depth. -func (n *listBranchNode[T]) get(index int) T { - idx := (index >> (n.d * listNodeBits)) & listNodeMask - return n.children[idx].get(index) -} - -// set recursively updates the value at index for each lower depth from the node. -func (n *listBranchNode[T]) set(index int, v T, mutable bool) listNode[T] { - idx := (index >> (n.d * listNodeBits)) & listNodeMask - - // Find child for the given value in the branch. Create new if it doesn't exist. - child := n.children[idx] - if child == nil { - child = newListNode[T](n.depth() - 1) - } - - // Return a copy of this branch with the new child. - var other *listBranchNode[T] - if mutable { - other = n - } else { - tmp := *n - other = &tmp - } - other.children[idx] = child.set(index, v, mutable) - return other -} - -// containsBefore returns true if non-nil values exists between [0,index). -func (n *listBranchNode[T]) containsBefore(index int) bool { - idx := (index >> (n.d * listNodeBits)) & listNodeMask - - // Quickly check if any direct children exist before this segment of the index. - for i := 0; i < idx; i++ { - if n.children[i] != nil { - return true - } - } - - // Recursively check for children directly at the given index at this segment. - if n.children[idx] != nil && n.children[idx].containsBefore(index) { - return true - } - return false -} - -// containsAfter returns true if non-nil values exists between (index,listNodeSize). -func (n *listBranchNode[T]) containsAfter(index int) bool { - idx := (index >> (n.d * listNodeBits)) & listNodeMask - - // Quickly check if any direct children exist after this segment of the index. - for i := idx + 1; i < len(n.children); i++ { - if n.children[i] != nil { - return true - } - } - - // Recursively check for children directly at the given index at this segment. - if n.children[idx] != nil && n.children[idx].containsAfter(index) { - return true - } - return false -} - -// deleteBefore returns a new node with all elements before index removed. -func (n *listBranchNode[T]) deleteBefore(index int, mutable bool) listNode[T] { - // Ignore if no nodes exist before the given index. - if !n.containsBefore(index) { - return n - } - - // Return a copy with any nodes prior to the index removed. - idx := (index >> (n.d * listNodeBits)) & listNodeMask - - var other *listBranchNode[T] - if mutable { - other = n - for i := 0; i < idx; i++ { - n.children[i] = nil - } - } else { - other = &listBranchNode[T]{d: n.d} - copy(other.children[idx:][:], n.children[idx:][:]) - } - - if other.children[idx] != nil { - other.children[idx] = other.children[idx].deleteBefore(index, mutable) - } - return other -} - -// deleteBefore returns a new node with all elements before index removed. -func (n *listBranchNode[T]) deleteAfter(index int, mutable bool) listNode[T] { - // Ignore if no nodes exist after the given index. - if !n.containsAfter(index) { - return n - } - - // Return a copy with any nodes after the index removed. - idx := (index >> (n.d * listNodeBits)) & listNodeMask - - var other *listBranchNode[T] - if mutable { - other = n - for i := idx + 1; i < len(n.children); i++ { - n.children[i] = nil - } - } else { - other = &listBranchNode[T]{d: n.d} - copy(other.children[:idx+1], n.children[:idx+1]) - } - - if other.children[idx] != nil { - other.children[idx] = other.children[idx].deleteAfter(index, mutable) - } - return other -} - -// listLeafNode represents a leaf node in a List. -type listLeafNode[T any] struct { - children [listNodeSize]T - // bitset with ones at occupied positions, position 0 is the LSB - occupied uint32 -} - -// depth always returns 0 for leaf nodes. -func (n *listLeafNode[T]) depth() uint { return 0 } - -// get returns the value at the given index. -func (n *listLeafNode[T]) get(index int) T { - return n.children[index&listNodeMask] -} - -// set returns a copy of the node with the value at the index updated to v. -func (n *listLeafNode[T]) set(index int, v T, mutable bool) listNode[T] { - idx := index & listNodeMask - var other *listLeafNode[T] - if mutable { - other = n - } else { - tmp := *n - other = &tmp - } - other.children[idx] = v - other.occupied |= 1 << idx - return other -} - -// containsBefore returns true if non-nil values exists between [0,index). -func (n *listLeafNode[T]) containsBefore(index int) bool { - idx := index & listNodeMask - return bits.TrailingZeros32(n.occupied) < idx -} - -// containsAfter returns true if non-nil values exists between (index,listNodeSize). -func (n *listLeafNode[T]) containsAfter(index int) bool { - idx := index & listNodeMask - lastSetPos := 31 - bits.LeadingZeros32(n.occupied) - return lastSetPos > idx -} - -// deleteBefore returns a new node with all elements before index removed. -func (n *listLeafNode[T]) deleteBefore(index int, mutable bool) listNode[T] { - if !n.containsBefore(index) { - return n - } - - idx := index & listNodeMask - var other *listLeafNode[T] - if mutable { - other = n - var empty T - for i := 0; i < idx; i++ { - other.children[i] = empty - } - } else { - other = &listLeafNode[T]{occupied: n.occupied} - copy(other.children[idx:][:], n.children[idx:][:]) - } - // Set the first idx bits to 0. - other.occupied &= ^((1 << idx) - 1) - return other -} - -// deleteAfter returns a new node with all elements after index removed. -func (n *listLeafNode[T]) deleteAfter(index int, mutable bool) listNode[T] { - if !n.containsAfter(index) { - return n - } - - idx := index & listNodeMask - var other *listLeafNode[T] - if mutable { - other = n - var empty T - for i := idx + 1; i < len(n.children); i++ { - other.children[i] = empty - } - } else { - other = &listLeafNode[T]{occupied: n.occupied} - copy(other.children[:idx+1][:], n.children[:idx+1][:]) - } - // Set bits after idx to 0. idx < 31 because n.containsAfter(index) == true. - other.occupied &= (1 << (idx + 1)) - 1 - return other -} - -// ListIterator represents an ordered iterator over a list. -type ListIterator[T any] struct { - list *List[T] // source list - index int // current index position - - stack [32]listIteratorElem[T] // search stack - depth int // stack depth -} - -// Done returns true if no more elements remain in the iterator. -func (itr *ListIterator[T]) Done() bool { - return itr.index < 0 || itr.index >= itr.list.Len() -} - -// First positions the iterator on the first index. -// If source list is empty then no change is made. -func (itr *ListIterator[T]) First() { - if itr.list.Len() != 0 { - itr.Seek(0) - } -} - -// Last positions the iterator on the last index. -// If source list is empty then no change is made. -func (itr *ListIterator[T]) Last() { - if n := itr.list.Len(); n != 0 { - itr.Seek(n - 1) - } -} - -// Seek moves the iterator position to the given index in the list. -// Similar to Go slices, this method will panic if index is below zero or if -// the index is greater than or equal to the list size. -func (itr *ListIterator[T]) Seek(index int) { - // Panic similar to Go slices. - if index < 0 || index >= itr.list.Len() { - panic(fmt.Sprintf("immutable.ListIterator.Seek: index %d out of bounds", index)) - } - itr.index = index - - // Reset to the bottom of the stack at seek to the correct position. - itr.stack[0] = listIteratorElem[T]{node: itr.list.root} - itr.depth = 0 - itr.seek(index) -} - -// Next returns the current index and its value & moves the iterator forward. -// Returns an index of -1 if the there are no more elements to return. -func (itr *ListIterator[T]) Next() (index int, value T) { - // Exit immediately if there are no elements remaining. - var empty T - if itr.Done() { - return -1, empty - } - - // Retrieve current index & value. - elem := &itr.stack[itr.depth] - index, value = itr.index, elem.node.(*listLeafNode[T]).children[elem.index] - - // Increase index. If index is at the end then return immediately. - itr.index++ - if itr.Done() { - return index, value - } - - // Move up stack until we find a node that has remaining position ahead. - for ; itr.depth > 0 && itr.stack[itr.depth].index >= listNodeSize-1; itr.depth-- { - } - - // Seek to correct position from current depth. - itr.seek(itr.index) - - return index, value -} - -// Prev returns the current index and value and moves the iterator backward. -// Returns an index of -1 if the there are no more elements to return. -func (itr *ListIterator[T]) Prev() (index int, value T) { - // Exit immediately if there are no elements remaining. - var empty T - if itr.Done() { - return -1, empty - } - - // Retrieve current index & value. - elem := &itr.stack[itr.depth] - index, value = itr.index, elem.node.(*listLeafNode[T]).children[elem.index] - - // Decrease index. If index is past the beginning then return immediately. - itr.index-- - if itr.Done() { - return index, value - } - - // Move up stack until we find a node that has remaining position behind. - for ; itr.depth > 0 && itr.stack[itr.depth].index == 0; itr.depth-- { - } - - // Seek to correct position from current depth. - itr.seek(itr.index) - - return index, value -} - -// seek positions the stack to the given index from the current depth. -// Elements and indexes below the current depth are assumed to be correct. -func (itr *ListIterator[T]) seek(index int) { - // Iterate over each level until we reach a leaf node. - for { - elem := &itr.stack[itr.depth] - elem.index = ((itr.list.origin + index) >> (elem.node.depth() * listNodeBits)) & listNodeMask - - switch node := elem.node.(type) { - case *listBranchNode[T]: - child := node.children[elem.index] - itr.stack[itr.depth+1] = listIteratorElem[T]{node: child} - itr.depth++ - case *listLeafNode[T]: - return - } - } -} - -// listIteratorElem represents the node and it's child index within the stack. -type listIteratorElem[T any] struct { - node listNode[T] - index int -} - -// Size thresholds for each type of branch node. -const ( - maxArrayMapSize = 8 - maxBitmapIndexedSize = 16 -) - -// Segment bit shifts within the map tree. -const ( - mapNodeBits = 5 - mapNodeSize = 1 << mapNodeBits - mapNodeMask = mapNodeSize - 1 -) - -// Map represents an immutable hash map implementation. The map uses a Hasher -// to generate hashes and check for equality of key values. -// -// It is implemented as an Hash Array Mapped Trie. -type Map[K, V any] struct { - size int // total number of key/value pairs - root mapNode[K, V] // root node of trie - hasher Hasher[K] // hasher implementation -} - -// NewMap returns a new instance of Map. If hasher is nil, a default hasher -// implementation will automatically be chosen based on the first key added. -// Default hasher implementations only exist for int, string, and byte slice types. -func NewMap[K, V any](hasher Hasher[K]) *Map[K, V] { - return &Map[K, V]{ - hasher: hasher, - } -} - -// NewMapOf returns a new instance of Map, containing a map of provided entries. -// -// If hasher is nil, a default hasher implementation will automatically be chosen based on the first key added. -// Default hasher implementations only exist for int, string, and byte slice types. -func NewMapOf[K comparable, V any](hasher Hasher[K], entries map[K]V) *Map[K, V] { - m := &Map[K, V]{ - hasher: hasher, - } - for k, v := range entries { - m.set(k, v, true) - } - return m -} - -// Len returns the number of elements in the map. -func (m *Map[K, V]) Len() int { - return m.size -} - -// clone returns a shallow copy of m. -func (m *Map[K, V]) clone() *Map[K, V] { - other := *m - return &other -} - -// Get returns the value for a given key and a flag indicating whether the -// key exists. This flag distinguishes a nil value set on a key versus a -// non-existent key in the map. -func (m *Map[K, V]) Get(key K) (value V, ok bool) { - var empty V - if m.root == nil { - return empty, false - } - keyHash := m.hasher.Hash(key) - return m.root.get(key, 0, keyHash, m.hasher) -} - -// Set returns a map with the key set to the new value. A nil value is allowed. -// -// This function will return a new map even if the updated value is the same as -// the existing value because Map does not track value equality. -func (m *Map[K, V]) Set(key K, value V) *Map[K, V] { - return m.set(key, value, false) -} - -func (m *Map[K, V]) set(key K, value V, mutable bool) *Map[K, V] { - // Set a hasher on the first value if one does not already exist. - hasher := m.hasher - if hasher == nil { - hasher = NewHasher(key) - } - - // Generate copy if necessary. - other := m - if !mutable { - other = m.clone() - } - other.hasher = hasher - - // If the map is empty, initialize with a simple array node. - if m.root == nil { - other.size = 1 - other.root = &mapArrayNode[K, V]{entries: []mapEntry[K, V]{{key: key, value: value}}} - return other - } - - // Otherwise copy the map and delegate insertion to the root. - // Resized will return true if the key does not currently exist. - var resized bool - other.root = m.root.set(key, value, 0, hasher.Hash(key), hasher, mutable, &resized) - if resized { - other.size++ - } - return other -} - -// Delete returns a map with the given key removed. -// Removing a non-existent key will cause this method to return the same map. -func (m *Map[K, V]) Delete(key K) *Map[K, V] { - return m.delete(key, false) -} - -func (m *Map[K, V]) delete(key K, mutable bool) *Map[K, V] { - // Return original map if no keys exist. - if m.root == nil { - return m - } - - // If the delete did not change the node then return the original map. - var resized bool - newRoot := m.root.delete(key, 0, m.hasher.Hash(key), m.hasher, mutable, &resized) - if !resized { - return m - } - - // Generate copy if necessary. - other := m - if !mutable { - other = m.clone() - } - - // Return copy of map with new root and decreased size. - other.size = m.size - 1 - other.root = newRoot - return other -} - -// Iterator returns a new iterator for the map. -func (m *Map[K, V]) Iterator() *MapIterator[K, V] { - itr := &MapIterator[K, V]{m: m} - itr.First() - return itr -} - -// MapBuilder represents an efficient builder for creating Maps. -type MapBuilder[K, V any] struct { - m *Map[K, V] // current state -} - -// NewMapBuilder returns a new instance of MapBuilder. -func NewMapBuilder[K, V any](hasher Hasher[K]) *MapBuilder[K, V] { - return &MapBuilder[K, V]{m: NewMap[K, V](hasher)} -} - -// Map returns the underlying map. Only call once. -// Builder is invalid after call. Will panic on second invocation. -func (b *MapBuilder[K, V]) Map() *Map[K, V] { - assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map") - m := b.m - b.m = nil - return m -} - -// Len returns the number of elements in the underlying map. -func (b *MapBuilder[K, V]) Len() int { - assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") - return b.m.Len() -} - -// Get returns the value for the given key. -func (b *MapBuilder[K, V]) Get(key K) (value V, ok bool) { - assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") - return b.m.Get(key) -} - -// Set sets the value of the given key. See Map.Set() for additional details. -func (b *MapBuilder[K, V]) Set(key K, value V) { - assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") - b.m = b.m.set(key, value, true) -} - -// Delete removes the given key. See Map.Delete() for additional details. -func (b *MapBuilder[K, V]) Delete(key K) { - assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") - b.m = b.m.delete(key, true) -} - -// Iterator returns a new iterator for the underlying map. -func (b *MapBuilder[K, V]) Iterator() *MapIterator[K, V] { - assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") - return b.m.Iterator() -} - -// mapNode represents any node in the map tree. -type mapNode[K, V any] interface { - get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) - set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] - delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] -} - -var _ mapNode[string, any] = (*mapArrayNode[string, any])(nil) -var _ mapNode[string, any] = (*mapBitmapIndexedNode[string, any])(nil) -var _ mapNode[string, any] = (*mapHashArrayNode[string, any])(nil) -var _ mapNode[string, any] = (*mapValueNode[string, any])(nil) -var _ mapNode[string, any] = (*mapHashCollisionNode[string, any])(nil) - -// mapLeafNode represents a node that stores a single key hash at the leaf of the map tree. -type mapLeafNode[K, V any] interface { - mapNode[K, V] - keyHashValue() uint32 -} - -var _ mapLeafNode[string, any] = (*mapValueNode[string, any])(nil) -var _ mapLeafNode[string, any] = (*mapHashCollisionNode[string, any])(nil) - -// mapArrayNode is a map node that stores key/value pairs in a slice. -// Entries are stored in insertion order. An array node expands into a bitmap -// indexed node once a given threshold size is crossed. -type mapArrayNode[K, V any] struct { - entries []mapEntry[K, V] -} - -// indexOf returns the entry index of the given key. Returns -1 if key not found. -func (n *mapArrayNode[K, V]) indexOf(key K, h Hasher[K]) int { - for i := range n.entries { - if h.Equal(n.entries[i].key, key) { - return i - } - } - return -1 -} - -// get returns the value for the given key. -func (n *mapArrayNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { - i := n.indexOf(key, h) - if i == -1 { - return value, false - } - return n.entries[i].value, true -} - -// set inserts or updates the value for a given key. If the key is inserted and -// the new size crosses the max size threshold, a bitmap indexed node is returned. -func (n *mapArrayNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { - idx := n.indexOf(key, h) - - // Mark as resized if the key doesn't exist. - if idx == -1 { - *resized = true - } - - // If we are adding and it crosses the max size threshold, expand the node. - // We do this by continually setting the entries to a value node and expanding. - if idx == -1 && len(n.entries) >= maxArrayMapSize { - var node mapNode[K, V] = newMapValueNode(h.Hash(key), key, value) - for _, entry := range n.entries { - node = node.set(entry.key, entry.value, 0, h.Hash(entry.key), h, false, resized) - } - return node - } - - // Update in-place if mutable. - if mutable { - if idx != -1 { - n.entries[idx] = mapEntry[K, V]{key, value} - } else { - n.entries = append(n.entries, mapEntry[K, V]{key, value}) - } - return n - } - - // Update existing entry if a match is found. - // Otherwise append to the end of the element list if it doesn't exist. - var other mapArrayNode[K, V] - if idx != -1 { - other.entries = make([]mapEntry[K, V], len(n.entries)) - copy(other.entries, n.entries) - other.entries[idx] = mapEntry[K, V]{key, value} - } else { - other.entries = make([]mapEntry[K, V], len(n.entries)+1) - copy(other.entries, n.entries) - other.entries[len(other.entries)-1] = mapEntry[K, V]{key, value} - } - return &other -} - -// delete removes the given key from the node. Returns the same node if key does -// not exist. Returns a nil node when removing the last entry. -func (n *mapArrayNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { - idx := n.indexOf(key, h) - - // Return original node if key does not exist. - if idx == -1 { - return n - } - *resized = true - - // Return nil if this node will contain no nodes. - if len(n.entries) == 1 { - return nil - } - - // Update in-place, if mutable. - if mutable { - copy(n.entries[idx:], n.entries[idx+1:]) - n.entries[len(n.entries)-1] = mapEntry[K, V]{} - n.entries = n.entries[:len(n.entries)-1] - return n - } - - // Otherwise create a copy with the given entry removed. - other := &mapArrayNode[K, V]{entries: make([]mapEntry[K, V], len(n.entries)-1)} - copy(other.entries[:idx], n.entries[:idx]) - copy(other.entries[idx:], n.entries[idx+1:]) - return other -} - -// mapBitmapIndexedNode represents a map branch node with a variable number of -// node slots and indexed using a bitmap. Indexes for the node slots are -// calculated by counting the number of set bits before the target bit using popcount. -type mapBitmapIndexedNode[K, V any] struct { - bitmap uint32 - nodes []mapNode[K, V] -} - -// get returns the value for the given key. -func (n *mapBitmapIndexedNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { - bit := uint32(1) << ((keyHash >> shift) & mapNodeMask) - if (n.bitmap & bit) == 0 { - return value, false - } - child := n.nodes[bits.OnesCount32(n.bitmap&(bit-1))] - return child.get(key, shift+mapNodeBits, keyHash, h) -} - -// set inserts or updates the value for the given key. If a new key is inserted -// and the size crosses the max size threshold then a hash array node is returned. -func (n *mapBitmapIndexedNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { - // Extract the index for the bit segment of the key hash. - keyHashFrag := (keyHash >> shift) & mapNodeMask - - // Determine the bit based on the hash index. - bit := uint32(1) << keyHashFrag - exists := (n.bitmap & bit) != 0 - - // Mark as resized if the key doesn't exist. - if !exists { - *resized = true - } - - // Find index of node based on popcount of bits before it. - idx := bits.OnesCount32(n.bitmap & (bit - 1)) - - // If the node already exists, delegate set operation to it. - // If the node doesn't exist then create a simple value leaf node. - var newNode mapNode[K, V] - if exists { - newNode = n.nodes[idx].set(key, value, shift+mapNodeBits, keyHash, h, mutable, resized) - } else { - newNode = newMapValueNode(keyHash, key, value) - } - - // Convert to a hash-array node once we exceed the max bitmap size. - // Copy each node based on their bit position within the bitmap. - if !exists && len(n.nodes) > maxBitmapIndexedSize { - var other mapHashArrayNode[K, V] - for i := uint(0); i < uint(len(other.nodes)); i++ { - if n.bitmap&(uint32(1)<> shift) & mapNodeMask) - - // Return original node if key does not exist. - if (n.bitmap & bit) == 0 { - return n - } - - // Find index of node based on popcount of bits before it. - idx := bits.OnesCount32(n.bitmap & (bit - 1)) - - // Delegate delete to child node. - child := n.nodes[idx] - newChild := child.delete(key, shift+mapNodeBits, keyHash, h, mutable, resized) - - // Return original node if key doesn't exist in child. - if !*resized { - return n - } - - // Remove if returned child has been deleted. - if newChild == nil { - // If we won't have any children then return nil. - if len(n.nodes) == 1 { - return nil - } - - // Update in-place if mutable. - if mutable { - n.bitmap ^= bit - copy(n.nodes[idx:], n.nodes[idx+1:]) - n.nodes[len(n.nodes)-1] = nil - n.nodes = n.nodes[:len(n.nodes)-1] - return n - } - - // Return copy with bit removed from bitmap and node removed from node list. - other := &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap ^ bit, nodes: make([]mapNode[K, V], len(n.nodes)-1)} - copy(other.nodes[:idx], n.nodes[:idx]) - copy(other.nodes[idx:], n.nodes[idx+1:]) - return other - } - - // Generate copy, if necessary. - other := n - if !mutable { - other = &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap, nodes: make([]mapNode[K, V], len(n.nodes))} - copy(other.nodes, n.nodes) - } - - // Update child. - other.nodes[idx] = newChild - return other -} - -// mapHashArrayNode is a map branch node that stores nodes in a fixed length -// array. Child nodes are indexed by their index bit segment for the current depth. -type mapHashArrayNode[K, V any] struct { - count uint // number of set nodes - nodes [mapNodeSize]mapNode[K, V] // child node slots, may contain empties -} - -// clone returns a shallow copy of n. -func (n *mapHashArrayNode[K, V]) clone() *mapHashArrayNode[K, V] { - other := *n - return &other -} - -// get returns the value for the given key. -func (n *mapHashArrayNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { - node := n.nodes[(keyHash>>shift)&mapNodeMask] - if node == nil { - return value, false - } - return node.get(key, shift+mapNodeBits, keyHash, h) -} - -// set returns a node with the value set for the given key. -func (n *mapHashArrayNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { - idx := (keyHash >> shift) & mapNodeMask - node := n.nodes[idx] - - // If node at index doesn't exist, create a simple value leaf node. - // Otherwise delegate set to child node. - var newNode mapNode[K, V] - if node == nil { - *resized = true - newNode = newMapValueNode(keyHash, key, value) - } else { - newNode = node.set(key, value, shift+mapNodeBits, keyHash, h, mutable, resized) - } - - // Generate copy, if necessary. - other := n - if !mutable { - other = n.clone() - } - - // Update child node (and update size, if new). - if node == nil { - other.count++ - } - other.nodes[idx] = newNode - return other -} - -// delete returns a node with the given key removed. Returns the same node if -// the key does not exist. If node shrinks to within bitmap-indexed size then -// converts to a bitmap-indexed node. -func (n *mapHashArrayNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { - idx := (keyHash >> shift) & mapNodeMask - node := n.nodes[idx] - - // Return original node if child is not found. - if node == nil { - return n - } - - // Return original node if child is unchanged. - newNode := node.delete(key, shift+mapNodeBits, keyHash, h, mutable, resized) - if !*resized { - return n - } - - // If we remove a node and drop below a threshold, convert back to bitmap indexed node. - if newNode == nil && n.count <= maxBitmapIndexedSize { - other := &mapBitmapIndexedNode[K, V]{nodes: make([]mapNode[K, V], 0, n.count-1)} - for i, child := range n.nodes { - if child != nil && uint32(i) != idx { - other.bitmap |= 1 << uint(i) - other.nodes = append(other.nodes, child) - } - } - return other - } - - // Generate copy, if necessary. - other := n - if !mutable { - other = n.clone() - } - - // Return copy of node with child updated. - other.nodes[idx] = newNode - if newNode == nil { - other.count-- - } - return other -} - -// mapValueNode represents a leaf node with a single key/value pair. -// A value node can be converted to a hash collision leaf node if a different -// key with the same keyHash is inserted. -type mapValueNode[K, V any] struct { - keyHash uint32 - key K - value V -} - -// newMapValueNode returns a new instance of mapValueNode. -func newMapValueNode[K, V any](keyHash uint32, key K, value V) *mapValueNode[K, V] { - return &mapValueNode[K, V]{ - keyHash: keyHash, - key: key, - value: value, - } -} - -// keyHashValue returns the key hash for this node. -func (n *mapValueNode[K, V]) keyHashValue() uint32 { - return n.keyHash -} - -// get returns the value for the given key. -func (n *mapValueNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { - if !h.Equal(n.key, key) { - return value, false - } - return n.value, true -} - -// set returns a new node with the new value set for the key. If the key equals -// the node's key then a new value node is returned. If key is not equal to the -// node's key but has the same hash then a hash collision node is returned. -// Otherwise the nodes are merged into a branch node. -func (n *mapValueNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { - // If the keys match then return a new value node overwriting the value. - if h.Equal(n.key, key) { - // Update in-place if mutable. - if mutable { - n.value = value - return n - } - // Otherwise return a new copy. - return newMapValueNode(n.keyHash, key, value) - } - - *resized = true - - // Recursively merge nodes together if key hashes are different. - if n.keyHash != keyHash { - return mergeIntoNode[K, V](n, shift, keyHash, key, value) - } - - // Merge into collision node if hash matches. - return &mapHashCollisionNode[K, V]{keyHash: keyHash, entries: []mapEntry[K, V]{ - {key: n.key, value: n.value}, - {key: key, value: value}, - }} -} - -// delete returns nil if the key matches the node's key. Otherwise returns the original node. -func (n *mapValueNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { - // Return original node if the keys do not match. - if !h.Equal(n.key, key) { - return n - } - - // Otherwise remove the node if keys do match. - *resized = true - return nil -} - -// mapHashCollisionNode represents a leaf node that contains two or more key/value -// pairs with the same key hash. Single pairs for a hash are stored as value nodes. -type mapHashCollisionNode[K, V any] struct { - keyHash uint32 // key hash for all entries - entries []mapEntry[K, V] -} - -// keyHashValue returns the key hash for all entries on the node. -func (n *mapHashCollisionNode[K, V]) keyHashValue() uint32 { - return n.keyHash -} - -// indexOf returns the index of the entry for the given key. -// Returns -1 if the key does not exist in the node. -func (n *mapHashCollisionNode[K, V]) indexOf(key K, h Hasher[K]) int { - for i := range n.entries { - if h.Equal(n.entries[i].key, key) { - return i - } - } - return -1 -} - -// get returns the value for the given key. -func (n *mapHashCollisionNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { - for i := range n.entries { - if h.Equal(n.entries[i].key, key) { - return n.entries[i].value, true - } - } - return value, false -} - -// set returns a copy of the node with key set to the given value. -func (n *mapHashCollisionNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { - // Merge node with key/value pair if this is not a hash collision. - if n.keyHash != keyHash { - *resized = true - return mergeIntoNode[K, V](n, shift, keyHash, key, value) - } - - // Update in-place if mutable. - if mutable { - if idx := n.indexOf(key, h); idx == -1 { - *resized = true - n.entries = append(n.entries, mapEntry[K, V]{key, value}) - } else { - n.entries[idx] = mapEntry[K, V]{key, value} - } - return n - } - - // Append to end of node if key doesn't exist & mark resized. - // Otherwise copy nodes and overwrite at matching key index. - other := &mapHashCollisionNode[K, V]{keyHash: n.keyHash} - if idx := n.indexOf(key, h); idx == -1 { - *resized = true - other.entries = make([]mapEntry[K, V], len(n.entries)+1) - copy(other.entries, n.entries) - other.entries[len(other.entries)-1] = mapEntry[K, V]{key, value} - } else { - other.entries = make([]mapEntry[K, V], len(n.entries)) - copy(other.entries, n.entries) - other.entries[idx] = mapEntry[K, V]{key, value} - } - return other -} - -// delete returns a node with the given key deleted. Returns the same node if -// the key does not exist. If removing the key would shrink the node to a single -// entry then a value node is returned. -func (n *mapHashCollisionNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { - idx := n.indexOf(key, h) - - // Return original node if key is not found. - if idx == -1 { - return n - } - - // Mark as resized if key exists. - *resized = true - - // Convert to value node if we move to one entry. - if len(n.entries) == 2 { - return &mapValueNode[K, V]{ - keyHash: n.keyHash, - key: n.entries[idx^1].key, - value: n.entries[idx^1].value, - } - } - - // Remove entry in-place if mutable. - if mutable { - copy(n.entries[idx:], n.entries[idx+1:]) - n.entries[len(n.entries)-1] = mapEntry[K, V]{} - n.entries = n.entries[:len(n.entries)-1] - return n - } - - // Return copy without entry if immutable. - other := &mapHashCollisionNode[K, V]{keyHash: n.keyHash, entries: make([]mapEntry[K, V], len(n.entries)-1)} - copy(other.entries[:idx], n.entries[:idx]) - copy(other.entries[idx:], n.entries[idx+1:]) - return other -} - -// mergeIntoNode merges a key/value pair into an existing node. -// Caller must verify that node's keyHash is not equal to keyHash. -func mergeIntoNode[K, V any](node mapLeafNode[K, V], shift uint, keyHash uint32, key K, value V) mapNode[K, V] { - idx1 := (node.keyHashValue() >> shift) & mapNodeMask - idx2 := (keyHash >> shift) & mapNodeMask - - // Recursively build branch nodes to combine the node and its key. - other := &mapBitmapIndexedNode[K, V]{bitmap: (1 << idx1) | (1 << idx2)} - if idx1 == idx2 { - other.nodes = []mapNode[K, V]{mergeIntoNode(node, shift+mapNodeBits, keyHash, key, value)} - } else { - if newNode := newMapValueNode(keyHash, key, value); idx1 < idx2 { - other.nodes = []mapNode[K, V]{node, newNode} - } else { - other.nodes = []mapNode[K, V]{newNode, node} - } - } - return other -} - -// mapEntry represents a single key/value pair. -type mapEntry[K, V any] struct { - key K - value V -} - -// MapIterator represents an iterator over a map's key/value pairs. Although -// map keys are not sorted, the iterator's order is deterministic. -type MapIterator[K, V any] struct { - m *Map[K, V] // source map - - stack [32]mapIteratorElem[K, V] // search stack - depth int // stack depth -} - -// Done returns true if no more elements remain in the iterator. -func (itr *MapIterator[K, V]) Done() bool { - return itr.depth == -1 -} - -// First resets the iterator to the first key/value pair. -func (itr *MapIterator[K, V]) First() { - // Exit immediately if the map is empty. - if itr.m.root == nil { - itr.depth = -1 - return - } - - // Initialize the stack to the left most element. - itr.stack[0] = mapIteratorElem[K, V]{node: itr.m.root} - itr.depth = 0 - itr.first() -} - -// Next returns the next key/value pair. Returns a nil key when no elements remain. -func (itr *MapIterator[K, V]) Next() (key K, value V, ok bool) { - // Return nil key if iteration is done. - if itr.Done() { - return key, value, false - } - - // Retrieve current index & value. Current node is always a leaf. - elem := &itr.stack[itr.depth] - switch node := elem.node.(type) { - case *mapArrayNode[K, V]: - entry := &node.entries[elem.index] - key, value = entry.key, entry.value - case *mapValueNode[K, V]: - key, value = node.key, node.value - case *mapHashCollisionNode[K, V]: - entry := &node.entries[elem.index] - key, value = entry.key, entry.value - } - - // Move up stack until we find a node that has remaining position ahead - // and move that element forward by one. - itr.next() - return key, value, true -} - -// next moves to the next available key. -func (itr *MapIterator[K, V]) next() { - for ; itr.depth >= 0; itr.depth-- { - elem := &itr.stack[itr.depth] - - switch node := elem.node.(type) { - case *mapArrayNode[K, V]: - if elem.index < len(node.entries)-1 { - elem.index++ - return - } - - case *mapBitmapIndexedNode[K, V]: - if elem.index < len(node.nodes)-1 { - elem.index++ - itr.stack[itr.depth+1].node = node.nodes[elem.index] - itr.depth++ - itr.first() - return - } - - case *mapHashArrayNode[K, V]: - for i := elem.index + 1; i < len(node.nodes); i++ { - if node.nodes[i] != nil { - elem.index = i - itr.stack[itr.depth+1].node = node.nodes[elem.index] - itr.depth++ - itr.first() - return - } - } - - case *mapValueNode[K, V]: - continue // always the last value, traverse up - - case *mapHashCollisionNode[K, V]: - if elem.index < len(node.entries)-1 { - elem.index++ - return - } - } - } -} - -// first positions the stack left most index. -// Elements and indexes at and below the current depth are assumed to be correct. -func (itr *MapIterator[K, V]) first() { - for ; ; itr.depth++ { - elem := &itr.stack[itr.depth] - - switch node := elem.node.(type) { - case *mapBitmapIndexedNode[K, V]: - elem.index = 0 - itr.stack[itr.depth+1].node = node.nodes[0] - - case *mapHashArrayNode[K, V]: - for i := 0; i < len(node.nodes); i++ { - if node.nodes[i] != nil { // find first node - elem.index = i - itr.stack[itr.depth+1].node = node.nodes[i] - break - } - } - - default: // *mapArrayNode, mapLeafNode - elem.index = 0 - return - } - } -} - -// mapIteratorElem represents a node/index pair in the MapIterator stack. -type mapIteratorElem[K, V any] struct { - node mapNode[K, V] - index int -} - -// Sorted map child node limit size. -const ( - sortedMapNodeSize = 32 -) - -// SortedMap represents a map of key/value pairs sorted by key. The sort order -// is determined by the Comparer used by the map. -// -// This map is implemented as a B+tree. -type SortedMap[K, V any] struct { - size int // total number of key/value pairs - root sortedMapNode[K, V] // root of b+tree - comparer Comparer[K] -} - -// NewSortedMap returns a new instance of SortedMap. If comparer is nil then -// a default comparer is set after the first key is inserted. Default comparers -// exist for int, string, and byte slice keys. -func NewSortedMap[K, V any](comparer Comparer[K]) *SortedMap[K, V] { - return &SortedMap[K, V]{ - comparer: comparer, - } -} - -// NewSortedMapOf returns a new instance of SortedMap, containing a map of provided entries. -// -// If comparer is nil then a default comparer is set after the first key is inserted. Default comparers -// exist for int, string, and byte slice keys. -func NewSortedMapOf[K comparable, V any](comparer Comparer[K], entries map[K]V) *SortedMap[K, V] { - m := &SortedMap[K, V]{ - comparer: comparer, - } - for k, v := range entries { - m.set(k, v, true) - } - return m -} - -// Len returns the number of elements in the sorted map. -func (m *SortedMap[K, V]) Len() int { - return m.size -} - -// Get returns the value for a given key and a flag indicating if the key is set. -// The flag can be used to distinguish between a nil-set key versus an unset key. -func (m *SortedMap[K, V]) Get(key K) (V, bool) { - if m.root == nil { - var v V - return v, false - } - return m.root.get(key, m.comparer) -} - -// Set returns a copy of the map with the key set to the given value. -func (m *SortedMap[K, V]) Set(key K, value V) *SortedMap[K, V] { - return m.set(key, value, false) -} - -func (m *SortedMap[K, V]) set(key K, value V, mutable bool) *SortedMap[K, V] { - // Set a comparer on the first value if one does not already exist. - comparer := m.comparer - if comparer == nil { - comparer = NewComparer(key) - } - - // Create copy, if necessary. - other := m - if !mutable { - other = m.clone() - } - other.comparer = comparer - - // If no values are set then initialize with a leaf node. - if m.root == nil { - other.size = 1 - other.root = &sortedMapLeafNode[K, V]{entries: []mapEntry[K, V]{{key: key, value: value}}} - return other - } - - // Otherwise delegate to root node. - // If a split occurs then grow the tree from the root. - var resized bool - newRoot, splitNode := m.root.set(key, value, comparer, mutable, &resized) - if splitNode != nil { - newRoot = newSortedMapBranchNode(newRoot, splitNode) - } - - // Update root and size (if resized). - other.size = m.size - other.root = newRoot - if resized { - other.size++ - } - return other -} - -// Delete returns a copy of the map with the key removed. -// Returns the original map if key does not exist. -func (m *SortedMap[K, V]) Delete(key K) *SortedMap[K, V] { - return m.delete(key, false) -} - -func (m *SortedMap[K, V]) delete(key K, mutable bool) *SortedMap[K, V] { - // Return original map if no keys exist. - if m.root == nil { - return m - } - - // If the delete did not change the node then return the original map. - var resized bool - newRoot := m.root.delete(key, m.comparer, mutable, &resized) - if !resized { - return m - } - - // Create copy, if necessary. - other := m - if !mutable { - other = m.clone() - } - - // Update root and size. - other.size = m.size - 1 - other.root = newRoot - return other -} - -// clone returns a shallow copy of m. -func (m *SortedMap[K, V]) clone() *SortedMap[K, V] { - other := *m - return &other -} - -// Iterator returns a new iterator for this map positioned at the first key. -func (m *SortedMap[K, V]) Iterator() *SortedMapIterator[K, V] { - itr := &SortedMapIterator[K, V]{m: m} - itr.First() - return itr -} - -// SortedMapBuilder represents an efficient builder for creating sorted maps. -type SortedMapBuilder[K, V any] struct { - m *SortedMap[K, V] // current state -} - -// NewSortedMapBuilder returns a new instance of SortedMapBuilder. -func NewSortedMapBuilder[K, V any](comparer Comparer[K]) *SortedMapBuilder[K, V] { - return &SortedMapBuilder[K, V]{m: NewSortedMap[K, V](comparer)} -} - -// SortedMap returns the current copy of the map. -// The returned map is safe to use even if after the builder continues to be used. -func (b *SortedMapBuilder[K, V]) Map() *SortedMap[K, V] { - assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map") - m := b.m - b.m = nil - return m -} - -// Len returns the number of elements in the underlying map. -func (b *SortedMapBuilder[K, V]) Len() int { - assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") - return b.m.Len() -} - -// Get returns the value for the given key. -func (b *SortedMapBuilder[K, V]) Get(key K) (value V, ok bool) { - assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") - return b.m.Get(key) -} - -// Set sets the value of the given key. See SortedMap.Set() for additional details. -func (b *SortedMapBuilder[K, V]) Set(key K, value V) { - assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") - b.m = b.m.set(key, value, true) -} - -// Delete removes the given key. See SortedMap.Delete() for additional details. -func (b *SortedMapBuilder[K, V]) Delete(key K) { - assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") - b.m = b.m.delete(key, true) -} - -// Iterator returns a new iterator for the underlying map positioned at the first key. -func (b *SortedMapBuilder[K, V]) Iterator() *SortedMapIterator[K, V] { - assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") - return b.m.Iterator() -} - -// sortedMapNode represents a branch or leaf node in the sorted map. -type sortedMapNode[K, V any] interface { - minKey() K - indexOf(key K, c Comparer[K]) int - get(key K, c Comparer[K]) (value V, ok bool) - set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) - delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] -} - -var _ sortedMapNode[string, any] = (*sortedMapBranchNode[string, any])(nil) -var _ sortedMapNode[string, any] = (*sortedMapLeafNode[string, any])(nil) - -// sortedMapBranchNode represents a branch in the sorted map. -type sortedMapBranchNode[K, V any] struct { - elems []sortedMapBranchElem[K, V] -} - -// newSortedMapBranchNode returns a new branch node with the given child nodes. -func newSortedMapBranchNode[K, V any](children ...sortedMapNode[K, V]) *sortedMapBranchNode[K, V] { - // Fetch min keys for every child. - elems := make([]sortedMapBranchElem[K, V], len(children)) - for i, child := range children { - elems[i] = sortedMapBranchElem[K, V]{ - key: child.minKey(), - node: child, - } - } - - return &sortedMapBranchNode[K, V]{elems: elems} -} - -// minKey returns the lowest key stored in this node's tree. -func (n *sortedMapBranchNode[K, V]) minKey() K { - return n.elems[0].node.minKey() -} - -// indexOf returns the index of the key within the child nodes. -func (n *sortedMapBranchNode[K, V]) indexOf(key K, c Comparer[K]) int { - if idx := sort.Search(len(n.elems), func(i int) bool { return c.Compare(n.elems[i].key, key) == 1 }); idx > 0 { - return idx - 1 - } - return 0 -} - -// get returns the value for the given key. -func (n *sortedMapBranchNode[K, V]) get(key K, c Comparer[K]) (value V, ok bool) { - idx := n.indexOf(key, c) - return n.elems[idx].node.get(key, c) -} - -// set returns a copy of the node with the key set to the given value. -func (n *sortedMapBranchNode[K, V]) set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) { - idx := n.indexOf(key, c) - - // Delegate insert to child node. - newNode, splitNode := n.elems[idx].node.set(key, value, c, mutable, resized) - - // Update in-place, if mutable. - if mutable { - n.elems[idx] = sortedMapBranchElem[K, V]{key: newNode.minKey(), node: newNode} - if splitNode != nil { - n.elems = append(n.elems, sortedMapBranchElem[K, V]{}) - copy(n.elems[idx+1:], n.elems[idx:]) - n.elems[idx+1] = sortedMapBranchElem[K, V]{key: splitNode.minKey(), node: splitNode} - } - - // If the child splits and we have no more room then we split too. - if len(n.elems) > sortedMapNodeSize { - splitIdx := len(n.elems) / 2 - newNode := &sortedMapBranchNode[K, V]{elems: n.elems[:splitIdx:splitIdx]} - splitNode := &sortedMapBranchNode[K, V]{elems: n.elems[splitIdx:]} - return newNode, splitNode - } - return n, nil - } - - // If no split occurs, copy branch and update keys. - // If the child splits, insert new key/child into copy of branch. - var other sortedMapBranchNode[K, V] - if splitNode == nil { - other.elems = make([]sortedMapBranchElem[K, V], len(n.elems)) - copy(other.elems, n.elems) - other.elems[idx] = sortedMapBranchElem[K, V]{ - key: newNode.minKey(), - node: newNode, - } - } else { - other.elems = make([]sortedMapBranchElem[K, V], len(n.elems)+1) - copy(other.elems[:idx], n.elems[:idx]) - copy(other.elems[idx+1:], n.elems[idx:]) - other.elems[idx] = sortedMapBranchElem[K, V]{ - key: newNode.minKey(), - node: newNode, - } - other.elems[idx+1] = sortedMapBranchElem[K, V]{ - key: splitNode.minKey(), - node: splitNode, - } - } - - // If the child splits and we have no more room then we split too. - if len(other.elems) > sortedMapNodeSize { - splitIdx := len(other.elems) / 2 - newNode := &sortedMapBranchNode[K, V]{elems: other.elems[:splitIdx:splitIdx]} - splitNode := &sortedMapBranchNode[K, V]{elems: other.elems[splitIdx:]} - return newNode, splitNode - } - - // Otherwise return the new branch node with the updated entry. - return &other, nil -} - -// delete returns a node with the key removed. Returns the same node if the key -// does not exist. Returns nil if all child nodes are removed. -func (n *sortedMapBranchNode[K, V]) delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] { - idx := n.indexOf(key, c) - - // Return original node if child has not changed. - newNode := n.elems[idx].node.delete(key, c, mutable, resized) - if !*resized { - return n - } - - // Remove child if it is now nil. - if newNode == nil { - // If this node will become empty then simply return nil. - if len(n.elems) == 1 { - return nil - } - - // If mutable, update in-place. - if mutable { - copy(n.elems[idx:], n.elems[idx+1:]) - n.elems[len(n.elems)-1] = sortedMapBranchElem[K, V]{} - n.elems = n.elems[:len(n.elems)-1] - return n - } - - // Return a copy without the given node. - other := &sortedMapBranchNode[K, V]{elems: make([]sortedMapBranchElem[K, V], len(n.elems)-1)} - copy(other.elems[:idx], n.elems[:idx]) - copy(other.elems[idx:], n.elems[idx+1:]) - return other - } - - // If mutable, update in-place. - if mutable { - n.elems[idx] = sortedMapBranchElem[K, V]{key: newNode.minKey(), node: newNode} - return n - } - - // Return a copy with the updated node. - other := &sortedMapBranchNode[K, V]{elems: make([]sortedMapBranchElem[K, V], len(n.elems))} - copy(other.elems, n.elems) - other.elems[idx] = sortedMapBranchElem[K, V]{ - key: newNode.minKey(), - node: newNode, - } - return other -} - -type sortedMapBranchElem[K, V any] struct { - key K - node sortedMapNode[K, V] -} - -// sortedMapLeafNode represents a leaf node in the sorted map. -type sortedMapLeafNode[K, V any] struct { - entries []mapEntry[K, V] -} - -// minKey returns the first key stored in this node. -func (n *sortedMapLeafNode[K, V]) minKey() K { - return n.entries[0].key -} - -// indexOf returns the index of the given key. -func (n *sortedMapLeafNode[K, V]) indexOf(key K, c Comparer[K]) int { - return sort.Search(len(n.entries), func(i int) bool { - return c.Compare(n.entries[i].key, key) != -1 // GTE - }) -} - -// get returns the value of the given key. -func (n *sortedMapLeafNode[K, V]) get(key K, c Comparer[K]) (value V, ok bool) { - idx := n.indexOf(key, c) - - // If the index is beyond the entry count or the key is not equal then return 'not found'. - if idx == len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 { - return value, false - } - - // If the key matches then return its value. - return n.entries[idx].value, true -} - -// set returns a copy of node with the key set to the given value. If the update -// causes the node to grow beyond the maximum size then it is split in two. -func (n *sortedMapLeafNode[K, V]) set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) { - // Find the insertion index for the key. - idx := n.indexOf(key, c) - exists := idx < len(n.entries) && c.Compare(n.entries[idx].key, key) == 0 - - // Update in-place, if mutable. - if mutable { - if !exists { - *resized = true - n.entries = append(n.entries, mapEntry[K, V]{}) - copy(n.entries[idx+1:], n.entries[idx:]) - } - n.entries[idx] = mapEntry[K, V]{key: key, value: value} - - // If the key doesn't exist and we exceed our max allowed values then split. - if len(n.entries) > sortedMapNodeSize { - splitIdx := len(n.entries) / 2 - newNode := &sortedMapLeafNode[K, V]{entries: n.entries[:splitIdx:splitIdx]} - splitNode := &sortedMapLeafNode[K, V]{entries: n.entries[splitIdx:]} - return newNode, splitNode - } - return n, nil - } - - // If the key matches then simply return a copy with the entry overridden. - // If there is no match then insert new entry and mark as resized. - var newEntries []mapEntry[K, V] - if exists { - newEntries = make([]mapEntry[K, V], len(n.entries)) - copy(newEntries, n.entries) - newEntries[idx] = mapEntry[K, V]{key: key, value: value} - } else { - *resized = true - newEntries = make([]mapEntry[K, V], len(n.entries)+1) - copy(newEntries[:idx], n.entries[:idx]) - newEntries[idx] = mapEntry[K, V]{key: key, value: value} - copy(newEntries[idx+1:], n.entries[idx:]) - } - - // If the key doesn't exist and we exceed our max allowed values then split. - if len(newEntries) > sortedMapNodeSize { - splitIdx := len(newEntries) / 2 - newNode := &sortedMapLeafNode[K, V]{entries: newEntries[:splitIdx:splitIdx]} - splitNode := &sortedMapLeafNode[K, V]{entries: newEntries[splitIdx:]} - return newNode, splitNode - } - - // Otherwise return the new leaf node with the updated entry. - return &sortedMapLeafNode[K, V]{entries: newEntries}, nil -} - -// delete returns a copy of node with key removed. Returns the original node if -// the key does not exist. Returns nil if the removed key is the last remaining key. -func (n *sortedMapLeafNode[K, V]) delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] { - idx := n.indexOf(key, c) - - // Return original node if key is not found. - if idx >= len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 { - return n - } - *resized = true - - // If this is the last entry then return nil. - if len(n.entries) == 1 { - return nil - } - - // Update in-place, if mutable. - if mutable { - copy(n.entries[idx:], n.entries[idx+1:]) - n.entries[len(n.entries)-1] = mapEntry[K, V]{} - n.entries = n.entries[:len(n.entries)-1] - return n - } - - // Return copy of node with entry removed. - other := &sortedMapLeafNode[K, V]{entries: make([]mapEntry[K, V], len(n.entries)-1)} - copy(other.entries[:idx], n.entries[:idx]) - copy(other.entries[idx:], n.entries[idx+1:]) - return other -} - -// SortedMapIterator represents an iterator over a sorted map. -// Iteration can occur in natural or reverse order based on use of Next() or Prev(). -type SortedMapIterator[K, V any] struct { - m *SortedMap[K, V] // source map - - stack [32]sortedMapIteratorElem[K, V] // search stack - depth int // stack depth -} - -// Done returns true if no more key/value pairs remain in the iterator. -func (itr *SortedMapIterator[K, V]) Done() bool { - return itr.depth == -1 -} - -// First moves the iterator to the first key/value pair. -func (itr *SortedMapIterator[K, V]) First() { - if itr.m.root == nil { - itr.depth = -1 - return - } - itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root} - itr.depth = 0 - itr.first() -} - -// Last moves the iterator to the last key/value pair. -func (itr *SortedMapIterator[K, V]) Last() { - if itr.m.root == nil { - itr.depth = -1 - return - } - itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root} - itr.depth = 0 - itr.last() -} - -// Seek moves the iterator position to the given key in the map. -// If the key does not exist then the next key is used. If no more keys exist -// then the iteartor is marked as done. -func (itr *SortedMapIterator[K, V]) Seek(key K) { - if itr.m.root == nil { - itr.depth = -1 - return - } - itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root} - itr.depth = 0 - itr.seek(key) -} - -// Next returns the current key/value pair and moves the iterator forward. -// Returns a nil key if the there are no more elements to return. -func (itr *SortedMapIterator[K, V]) Next() (key K, value V, ok bool) { - // Return nil key if iteration is complete. - if itr.Done() { - return key, value, false - } - - // Retrieve current key/value pair. - leafElem := &itr.stack[itr.depth] - leafNode := leafElem.node.(*sortedMapLeafNode[K, V]) - leafEntry := &leafNode.entries[leafElem.index] - key, value = leafEntry.key, leafEntry.value - - // Move to the next available key/value pair. - itr.next() - - // Only occurs when iterator is done. - return key, value, true -} - -// next moves to the next key. If no keys are after then depth is set to -1. -func (itr *SortedMapIterator[K, V]) next() { - for ; itr.depth >= 0; itr.depth-- { - elem := &itr.stack[itr.depth] - - switch node := elem.node.(type) { - case *sortedMapLeafNode[K, V]: - if elem.index < len(node.entries)-1 { - elem.index++ - return - } - case *sortedMapBranchNode[K, V]: - if elem.index < len(node.elems)-1 { - elem.index++ - itr.stack[itr.depth+1].node = node.elems[elem.index].node - itr.depth++ - itr.first() - return - } - } - } -} - -// Prev returns the current key/value pair and moves the iterator backward. -// Returns a nil key if the there are no more elements to return. -func (itr *SortedMapIterator[K, V]) Prev() (key K, value V, ok bool) { - // Return nil key if iteration is complete. - if itr.Done() { - return key, value, false - } - - // Retrieve current key/value pair. - leafElem := &itr.stack[itr.depth] - leafNode := leafElem.node.(*sortedMapLeafNode[K, V]) - leafEntry := &leafNode.entries[leafElem.index] - key, value = leafEntry.key, leafEntry.value - - itr.prev() - return key, value, true -} - -// prev moves to the previous key. If no keys are before then depth is set to -1. -func (itr *SortedMapIterator[K, V]) prev() { - for ; itr.depth >= 0; itr.depth-- { - elem := &itr.stack[itr.depth] - - switch node := elem.node.(type) { - case *sortedMapLeafNode[K, V]: - if elem.index > 0 { - elem.index-- - return - } - case *sortedMapBranchNode[K, V]: - if elem.index > 0 { - elem.index-- - itr.stack[itr.depth+1].node = node.elems[elem.index].node - itr.depth++ - itr.last() - return - } - } - } -} - -// first positions the stack to the leftmost key from the current depth. -// Elements and indexes below the current depth are assumed to be correct. -func (itr *SortedMapIterator[K, V]) first() { - for { - elem := &itr.stack[itr.depth] - elem.index = 0 - - switch node := elem.node.(type) { - case *sortedMapBranchNode[K, V]: - itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node} - itr.depth++ - case *sortedMapLeafNode[K, V]: - return - } - } -} - -// last positions the stack to the rightmost key from the current depth. -// Elements and indexes below the current depth are assumed to be correct. -func (itr *SortedMapIterator[K, V]) last() { - for { - elem := &itr.stack[itr.depth] - - switch node := elem.node.(type) { - case *sortedMapBranchNode[K, V]: - elem.index = len(node.elems) - 1 - itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node} - itr.depth++ - case *sortedMapLeafNode[K, V]: - elem.index = len(node.entries) - 1 - return - } - } -} - -// seek positions the stack to the given key from the current depth. -// Elements and indexes below the current depth are assumed to be correct. -func (itr *SortedMapIterator[K, V]) seek(key K) { - for { - elem := &itr.stack[itr.depth] - elem.index = elem.node.indexOf(key, itr.m.comparer) - - switch node := elem.node.(type) { - case *sortedMapBranchNode[K, V]: - itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node} - itr.depth++ - case *sortedMapLeafNode[K, V]: - if elem.index == len(node.entries) { - itr.next() - } - return - } - } -} - -// sortedMapIteratorElem represents node/index pair in the SortedMapIterator stack. -type sortedMapIteratorElem[K, V any] struct { - node sortedMapNode[K, V] - index int -} - -// Hasher hashes keys and checks them for equality. -type Hasher[K any] interface { - // Computes a hash for key. - Hash(key K) uint32 - - // Returns true if a and b are equal. - Equal(a, b K) bool -} - -// NewHasher returns the built-in hasher for a given key type. -func NewHasher[K any](key K) Hasher[K] { - // Attempt to use non-reflection based hasher first. - switch (any(key)).(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, string: - return &defaultHasher[K]{} - } - - // Fallback to reflection-based hasher otherwise. - // This is used when caller wraps a type around a primitive type. - switch reflect.TypeOf(key).Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String: - return &reflectHasher[K]{} - } - - // If no hashers match then panic. - // This is a compile time issue so it should not return an error. - panic(fmt.Sprintf("immutable.NewHasher: must set hasher for %T type", key)) -} - -// Hash returns a hash for value. -func hashString(value string) uint32 { - var hash uint32 - for i, value := 0, value; i < len(value); i++ { - hash = 31*hash + uint32(value[i]) - } - return hash -} - -// reflectIntHasher implements a reflection-based Hasher for keys. -type reflectHasher[K any] struct{} - -// Hash returns a hash for key. -func (h *reflectHasher[K]) Hash(key K) uint32 { - switch reflect.TypeOf(key).Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return hashUint64(uint64(reflect.ValueOf(key).Int())) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return hashUint64(reflect.ValueOf(key).Uint()) - case reflect.String: - var hash uint32 - s := reflect.ValueOf(key).String() - for i := 0; i < len(s); i++ { - hash = 31*hash + uint32(s[i]) - } - return hash - } - panic(fmt.Sprintf("immutable.reflectHasher.Hash: reflectHasher does not support %T type", key)) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not int-ish or string-ish. -func (h *reflectHasher[K]) Equal(a, b K) bool { - switch reflect.TypeOf(a).Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return reflect.ValueOf(a).Int() == reflect.ValueOf(b).Int() - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return reflect.ValueOf(a).Uint() == reflect.ValueOf(b).Uint() - case reflect.String: - return reflect.ValueOf(a).String() == reflect.ValueOf(b).String() - } - panic(fmt.Sprintf("immutable.reflectHasher.Equal: reflectHasher does not support %T type", a)) - -} - -// hashUint64 returns a 32-bit hash for a 64-bit value. -func hashUint64(value uint64) uint32 { - hash := value - for value > 0xffffffff { - value /= 0xffffffff - hash ^= value - } - return uint32(hash) -} - -// defaultHasher implements Hasher. -type defaultHasher[K any] struct{} - -// Hash returns a hash for key. -func (h *defaultHasher[K]) Hash(key K) uint32 { - switch x := (any(key)).(type) { - case int: - return hashUint64(uint64(x)) - case int8: - return hashUint64(uint64(x)) - case int16: - return hashUint64(uint64(x)) - case int32: - return hashUint64(uint64(x)) - case int64: - return hashUint64(uint64(x)) - case uint: - return hashUint64(uint64(x)) - case uint8: - return hashUint64(uint64(x)) - case uint16: - return hashUint64(uint64(x)) - case uint32: - return hashUint64(uint64(x)) - case uint64: - return hashUint64(uint64(x)) - case uintptr: - return hashUint64(uint64(x)) - case string: - return hashString(x) - } - panic(fmt.Sprintf("immutable.defaultHasher.Hash: must set comparer for %T type", key)) -} - -// Equal returns true if a is equal to b. Otherwise returns false. -// Panics if a and b are not comparable. -func (h *defaultHasher[K]) Equal(a, b K) bool { - return any(a) == any(b) -} - -// Comparer allows the comparison of two keys for the purpose of sorting. -type Comparer[K any] interface { - // Returns -1 if a is less than b, returns 1 if a is greater than b, - // and returns 0 if a is equal to b. - Compare(a, b K) int -} - -// NewComparer returns the built-in comparer for a given key type. -// Note that only int-ish and string-ish types are supported, despite the 'comparable' constraint. -// Attempts to use other types will result in a panic - users should define their own Comparers for these cases. -func NewComparer[K any](key K) Comparer[K] { - // Attempt to use non-reflection based comparer first. - switch (any(key)).(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, string: - return &defaultComparer[K]{} - } - // Fallback to reflection-based comparer otherwise. - // This is used when caller wraps a type around a primitive type. - switch reflect.TypeOf(key).Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String: - return &reflectComparer[K]{} - } - // If no comparers match then panic. - // This is a compile time issue so it should not return an error. - panic(fmt.Sprintf("immutable.NewComparer: must set comparer for %T type", key)) -} - -// defaultComparer compares two values (int-ish and string-ish types are supported). Implements Comparer. -type defaultComparer[K any] struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not a string or int* type -func (c *defaultComparer[K]) Compare(i K, j K) int { - switch x := (any(i)).(type) { - case int: - return defaultCompare(x, (any(j)).(int)) - case int8: - return defaultCompare(x, (any(j)).(int8)) - case int16: - return defaultCompare(x, (any(j)).(int16)) - case int32: - return defaultCompare(x, (any(j)).(int32)) - case int64: - return defaultCompare(x, (any(j)).(int64)) - case uint: - return defaultCompare(x, (any(j)).(uint)) - case uint8: - return defaultCompare(x, (any(j)).(uint8)) - case uint16: - return defaultCompare(x, (any(j)).(uint16)) - case uint32: - return defaultCompare(x, (any(j)).(uint32)) - case uint64: - return defaultCompare(x, (any(j)).(uint64)) - case uintptr: - return defaultCompare(x, (any(j)).(uintptr)) - case string: - return defaultCompare(x, (any(j)).(string)) - } - panic(fmt.Sprintf("immutable.defaultComparer: must set comparer for %T type", i)) -} - -// defaultCompare only operates on constraints.Ordered. -// For other types, users should bring their own comparers -func defaultCompare[K constraints.Ordered](i, j K) int { - if i < j { - return -1 - } else if i > j { - return 1 - } - return 0 -} - -// reflectIntComparer compares two values using reflection. Implements Comparer. -type reflectComparer[K any] struct{} - -// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and -// returns 0 if a is equal to b. Panic if a or b is not an int-ish or string-ish type. -func (c *reflectComparer[K]) Compare(a, b K) int { - switch reflect.TypeOf(a).Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if i, j := reflect.ValueOf(a).Int(), reflect.ValueOf(b).Int(); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - if i, j := reflect.ValueOf(a).Uint(), reflect.ValueOf(b).Uint(); i < j { - return -1 - } else if i > j { - return 1 - } - return 0 - case reflect.String: - return strings.Compare(reflect.ValueOf(a).String(), reflect.ValueOf(b).String()) - } - panic(fmt.Sprintf("immutable.reflectComparer.Compare: must set comparer for %T type", a)) -} - -func assert(condition bool, message string) { - if !condition { - panic(message) - } -} diff --git a/immutable_test.go b/immutable_test.go deleted file mode 100644 index 581d4d8..0000000 --- a/immutable_test.go +++ /dev/null @@ -1,2544 +0,0 @@ -package immutable - -import ( - "flag" - "fmt" - "math/rand" - "sort" - "testing" - - "golang.org/x/exp/constraints" -) - -var ( - veryVerbose = flag.Bool("vv", false, "very verbose") - randomN = flag.Int("random.n", 100, "number of RunRandom() iterations") -) - -func TestList(t *testing.T) { - t.Run("Empty", func(t *testing.T) { - if size := NewList[string]().Len(); size != 0 { - t.Fatalf("unexpected size: %d", size) - } - }) - - t.Run("Shallow", func(t *testing.T) { - list := NewList[string]() - list = list.Append("foo") - if v := list.Get(0); v != "foo" { - t.Fatalf("unexpected value: %v", v) - } - - other := list.Append("bar") - if v := other.Get(0); v != "foo" { - t.Fatalf("unexpected value: %v", v) - } else if v := other.Get(1); v != "bar" { - t.Fatalf("unexpected value: %v", v) - } - - if v := list.Len(); v != 1 { - t.Fatalf("unexpected value: %v", v) - } - }) - - t.Run("Deep", func(t *testing.T) { - list := NewList[int]() - var array []int - for i := 0; i < 100000; i++ { - list = list.Append(i) - array = append(array, i) - } - - if got, exp := len(array), list.Len(); got != exp { - t.Fatalf("List.Len()=%d, exp %d", got, exp) - } - for j := range array { - if got, exp := list.Get(j), array[j]; got != exp { - t.Fatalf("%d. List.Get(%d)=%d, exp %d", len(array), j, got, exp) - } - } - }) - - t.Run("Set", func(t *testing.T) { - list := NewList[string]() - list = list.Append("foo") - list = list.Append("bar") - - if v := list.Get(0); v != "foo" { - t.Fatalf("unexpected value: %v", v) - } - - list = list.Set(0, "baz") - if v := list.Get(0); v != "baz" { - t.Fatalf("unexpected value: %v", v) - } else if v := list.Get(1); v != "bar" { - t.Fatalf("unexpected value: %v", v) - } - }) - - t.Run("GetBelowRange", func(t *testing.T) { - var r string - func() { - defer func() { r = recover().(string) }() - l := NewList[string]() - l = l.Append("foo") - l.Get(-1) - }() - if r != `immutable.List.Get: index -1 out of bounds` { - t.Fatalf("unexpected panic: %q", r) - } - }) - - t.Run("GetAboveRange", func(t *testing.T) { - var r string - func() { - defer func() { r = recover().(string) }() - l := NewList[string]() - l = l.Append("foo") - l.Get(1) - }() - if r != `immutable.List.Get: index 1 out of bounds` { - t.Fatalf("unexpected panic: %q", r) - } - }) - - t.Run("SetOutOfRange", func(t *testing.T) { - var r string - func() { - defer func() { r = recover().(string) }() - l := NewList[string]() - l = l.Append("foo") - l.Set(1, "bar") - }() - if r != `immutable.List.Set: index 1 out of bounds` { - t.Fatalf("unexpected panic: %q", r) - } - }) - - t.Run("SliceStartOutOfRange", func(t *testing.T) { - var r string - func() { - defer func() { r = recover().(string) }() - l := NewList[string]() - l = l.Append("foo") - l.Slice(2, 3) - }() - if r != `immutable.List.Slice: start index 2 out of bounds` { - t.Fatalf("unexpected panic: %q", r) - } - }) - - t.Run("SliceEndOutOfRange", func(t *testing.T) { - var r string - func() { - defer func() { r = recover().(string) }() - l := NewList[string]() - l = l.Append("foo") - l.Slice(1, 3) - }() - if r != `immutable.List.Slice: end index 3 out of bounds` { - t.Fatalf("unexpected panic: %q", r) - } - }) - - t.Run("SliceInvalidIndex", func(t *testing.T) { - var r string - func() { - defer func() { r = recover().(string) }() - l := NewList[string]() - l = l.Append("foo") - l = l.Append("bar") - l.Slice(2, 1) - }() - if r != `immutable.List.Slice: invalid slice index: [2:1]` { - t.Fatalf("unexpected panic: %q", r) - } - }) - - t.Run("SliceBeginning", func(t *testing.T) { - l := NewList[string]() - l = l.Append("foo") - l = l.Append("bar") - l = l.Slice(1, 2) - if got, exp := l.Len(), 1; got != exp { - t.Fatalf("List.Len()=%d, exp %d", got, exp) - } else if got, exp := l.Get(0), "bar"; got != exp { - t.Fatalf("List.Get(0)=%v, exp %v", got, exp) - } - }) - - t.Run("IteratorSeekOutOfBounds", func(t *testing.T) { - var r string - func() { - defer func() { r = recover().(string) }() - l := NewList[string]() - l = l.Append("foo") - l.Iterator().Seek(-1) - }() - if r != `immutable.ListIterator.Seek: index -1 out of bounds` { - t.Fatalf("unexpected panic: %q", r) - } - }) - - t.Run("TestSliceFreesReferences", func(t *testing.T) { - /* Test that the leaf node in a sliced list contains zero'ed entries at - * the correct positions. To do this we directly access the internal - * tree structure of the list. - */ - l := NewList[*int]() - var ints [5]int - for i := 0; i < 5; i++ { - l = l.Append(&ints[i]) - } - sl := l.Slice(2, 4) - - var findLeaf func(listNode[*int]) *listLeafNode[*int] - findLeaf = func(n listNode[*int]) *listLeafNode[*int] { - switch n := n.(type) { - case *listBranchNode[*int]: - if n.children[0] == nil { - t.Fatal("Failed to find leaf node due to nil child") - } - return findLeaf(n.children[0]) - case *listLeafNode[*int]: - return n - default: - panic("Unexpected case") - } - } - - leaf := findLeaf(sl.root) - if leaf.occupied != 0b1100 { - t.Errorf("Expected occupied to be 1100, was %032b", leaf.occupied) - } - - for i := 0; i < listNodeSize; i++ { - if 2 <= i && i < 4 { - if leaf.children[i] != &ints[i] { - t.Errorf("Position %v does not contain the right pointer?", i) - } - } else if leaf.children[i] != nil { - t.Errorf("Expected position %v to be cleared, was %v", i, leaf.children[i]) - } - } - }) - - t.Run("AppendImmutable", func(t *testing.T) { - outer_l := NewList[int]() - for N := 0; N < 1_000; N++ { - l1 := outer_l.Append(0) - outer_l.Append(1) - if actual := l1.Get(N); actual != 0 { - t.Fatalf("Append mutates list with %d elements. Got %d instead of 0", N, actual) - } - - outer_l = outer_l.Append(0) - } - }) - - RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { - l := NewTList() - for i := 0; i < 100000; i++ { - rnd := rand.Intn(70) - switch { - case rnd == 0: // slice - start, end := l.ChooseSliceIndices(rand) - l.Slice(start, end) - case rnd < 10: // set - if l.Len() > 0 { - l.Set(l.ChooseIndex(rand), rand.Intn(10000)) - } - case rnd < 30: // prepend - l.Prepend(rand.Intn(10000)) - default: // append - l.Append(rand.Intn(10000)) - } - } - if err := l.Validate(); err != nil { - t.Fatal(err) - } - }) -} - -// TList represents a list that operates on a standard Go slice & immutable list. -type TList struct { - im, prev *List[int] - builder *ListBuilder[int] - std []int -} - -// NewTList returns a new instance of TList. -func NewTList() *TList { - return &TList{ - im: NewList[int](), - builder: NewListBuilder[int](), - } -} - -// Len returns the size of the list. -func (l *TList) Len() int { - return len(l.std) -} - -// ChooseIndex returns a randomly chosen, valid index from the standard slice. -func (l *TList) ChooseIndex(rand *rand.Rand) int { - if len(l.std) == 0 { - return -1 - } - return rand.Intn(len(l.std)) -} - -// ChooseSliceIndices returns randomly chosen, valid indices for slicing. -func (l *TList) ChooseSliceIndices(rand *rand.Rand) (start, end int) { - if len(l.std) == 0 { - return 0, 0 - } - start = rand.Intn(len(l.std)) - end = rand.Intn(len(l.std)-start) + start - return start, end -} - -// Append adds v to the end of slice and List. -func (l *TList) Append(v int) { - l.prev = l.im - l.im = l.im.Append(v) - l.builder.Append(v) - l.std = append(l.std, v) -} - -// Prepend adds v to the beginning of the slice and List. -func (l *TList) Prepend(v int) { - l.prev = l.im - l.im = l.im.Prepend(v) - l.builder.Prepend(v) - l.std = append([]int{v}, l.std...) -} - -// Set updates the value at index i to v in the slice and List. -func (l *TList) Set(i, v int) { - l.prev = l.im - l.im = l.im.Set(i, v) - l.builder.Set(i, v) - l.std[i] = v -} - -// Slice contracts the slice and List to the range of start/end indices. -func (l *TList) Slice(start, end int) { - l.prev = l.im - l.im = l.im.Slice(start, end) - l.builder.Slice(start, end) - l.std = l.std[start:end] -} - -// Validate returns an error if the slice and List are different. -func (l *TList) Validate() error { - if got, exp := l.im.Len(), len(l.std); got != exp { - return fmt.Errorf("Len()=%v, expected %d", got, exp) - } else if got, exp := l.builder.Len(), len(l.std); got != exp { - return fmt.Errorf("Len()=%v, expected %d", got, exp) - } - - for i := range l.std { - if got, exp := l.im.Get(i), l.std[i]; got != exp { - return fmt.Errorf("Get(%d)=%v, expected %v", i, got, exp) - } else if got, exp := l.builder.Get(i), l.std[i]; got != exp { - return fmt.Errorf("Builder.List/Get(%d)=%v, expected %v", i, got, exp) - } - } - - if err := l.validateForwardIterator("basic", l.im.Iterator()); err != nil { - return err - } else if err := l.validateBackwardIterator("basic", l.im.Iterator()); err != nil { - return err - } - - if err := l.validateForwardIterator("builder", l.builder.Iterator()); err != nil { - return err - } else if err := l.validateBackwardIterator("builder", l.builder.Iterator()); err != nil { - return err - } - return nil -} - -func (l *TList) validateForwardIterator(typ string, itr *ListIterator[int]) error { - for i := range l.std { - if j, v := itr.Next(); i != j || l.std[i] != v { - return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected <%v,%v> [%s]", j, v, i, l.std[i], typ) - } - - done := i == len(l.std)-1 - if v := itr.Done(); v != done { - return fmt.Errorf("ListIterator.Done()=%v, expected %v [%s]", v, done, typ) - } - } - if i, v := itr.Next(); i != -1 || v != 0 { - return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected DONE [%s]", i, v, typ) - } - return nil -} - -func (l *TList) validateBackwardIterator(typ string, itr *ListIterator[int]) error { - itr.Last() - for i := len(l.std) - 1; i >= 0; i-- { - if j, v := itr.Prev(); i != j || l.std[i] != v { - return fmt.Errorf("ListIterator.Prev()=<%v,%v>, expected <%v,%v> [%s]", j, v, i, l.std[i], typ) - } - - done := i == 0 - if v := itr.Done(); v != done { - return fmt.Errorf("ListIterator.Done()=%v, expected %v [%s]", v, done, typ) - } - } - if i, v := itr.Prev(); i != -1 || v != 0 { - return fmt.Errorf("ListIterator.Prev()=<%v,%v>, expected DONE [%s]", i, v, typ) - } - return nil -} - -func BenchmarkList_Append(b *testing.B) { - b.ReportAllocs() - l := NewList[int]() - for i := 0; i < b.N; i++ { - l = l.Append(i) - } -} - -func BenchmarkList_Prepend(b *testing.B) { - b.ReportAllocs() - l := NewList[int]() - for i := 0; i < b.N; i++ { - l = l.Prepend(i) - } -} - -func BenchmarkList_Set(b *testing.B) { - const n = 10000 - - l := NewList[int]() - for i := 0; i < 10000; i++ { - l = l.Append(i) - } - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - l = l.Set(i%n, i*10) - } -} - -func BenchmarkList_Iterator(b *testing.B) { - const n = 10000 - l := NewList[int]() - for i := 0; i < 10000; i++ { - l = l.Append(i) - } - b.ReportAllocs() - b.ResetTimer() - - b.Run("Forward", func(b *testing.B) { - itr := l.Iterator() - for i := 0; i < b.N; i++ { - if i%n == 0 { - itr.First() - } - itr.Next() - } - }) - - b.Run("Reverse", func(b *testing.B) { - itr := l.Iterator() - for i := 0; i < b.N; i++ { - if i%n == 0 { - itr.Last() - } - itr.Prev() - } - }) -} - -func BenchmarkBuiltinSlice_Append(b *testing.B) { - b.Run("Int", func(b *testing.B) { - b.ReportAllocs() - var a []int - for i := 0; i < b.N; i++ { - a = append(a, i) - } - }) - b.Run("Interface", func(b *testing.B) { - b.ReportAllocs() - var a []interface{} - for i := 0; i < b.N; i++ { - a = append(a, i) - } - }) -} - -func BenchmarkListBuilder_Append(b *testing.B) { - b.ReportAllocs() - builder := NewListBuilder[int]() - for i := 0; i < b.N; i++ { - builder.Append(i) - } -} - -func BenchmarkListBuilder_Prepend(b *testing.B) { - b.ReportAllocs() - builder := NewListBuilder[int]() - for i := 0; i < b.N; i++ { - builder.Prepend(i) - } -} - -func BenchmarkListBuilder_Set(b *testing.B) { - const n = 10000 - - builder := NewListBuilder[int]() - for i := 0; i < 10000; i++ { - builder.Append(i) - } - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - builder.Set(i%n, i*10) - } -} - -func ExampleList_Append() { - l := NewList[string]() - l = l.Append("foo") - l = l.Append("bar") - l = l.Append("baz") - - fmt.Println(l.Get(0)) - fmt.Println(l.Get(1)) - fmt.Println(l.Get(2)) - // Output: - // foo - // bar - // baz -} - -func ExampleList_Prepend() { - l := NewList[string]() - l = l.Prepend("foo") - l = l.Prepend("bar") - l = l.Prepend("baz") - - fmt.Println(l.Get(0)) - fmt.Println(l.Get(1)) - fmt.Println(l.Get(2)) - // Output: - // baz - // bar - // foo -} - -func ExampleList_Set() { - l := NewList[string]() - l = l.Append("foo") - l = l.Append("bar") - l = l.Set(1, "baz") - - fmt.Println(l.Get(0)) - fmt.Println(l.Get(1)) - // Output: - // foo - // baz -} - -func ExampleList_Slice() { - l := NewList[string]() - l = l.Append("foo") - l = l.Append("bar") - l = l.Append("baz") - l = l.Slice(1, 3) - - fmt.Println(l.Get(0)) - fmt.Println(l.Get(1)) - // Output: - // bar - // baz -} - -func ExampleList_Iterator() { - l := NewList[string]() - l = l.Append("foo") - l = l.Append("bar") - l = l.Append("baz") - - itr := l.Iterator() - for !itr.Done() { - i, v := itr.Next() - fmt.Println(i, v) - } - // Output: - // 0 foo - // 1 bar - // 2 baz -} - -func ExampleList_Iterator_reverse() { - l := NewList[string]() - l = l.Append("foo") - l = l.Append("bar") - l = l.Append("baz") - - itr := l.Iterator() - itr.Last() - for !itr.Done() { - i, v := itr.Prev() - fmt.Println(i, v) - } - // Output: - // 2 baz - // 1 bar - // 0 foo -} - -func ExampleListBuilder_Append() { - b := NewListBuilder[string]() - b.Append("foo") - b.Append("bar") - b.Append("baz") - - l := b.List() - fmt.Println(l.Get(0)) - fmt.Println(l.Get(1)) - fmt.Println(l.Get(2)) - // Output: - // foo - // bar - // baz -} - -func ExampleListBuilder_Prepend() { - b := NewListBuilder[string]() - b.Prepend("foo") - b.Prepend("bar") - b.Prepend("baz") - - l := b.List() - fmt.Println(l.Get(0)) - fmt.Println(l.Get(1)) - fmt.Println(l.Get(2)) - // Output: - // baz - // bar - // foo -} - -func ExampleListBuilder_Set() { - b := NewListBuilder[string]() - b.Append("foo") - b.Append("bar") - b.Set(1, "baz") - - l := b.List() - fmt.Println(l.Get(0)) - fmt.Println(l.Get(1)) - // Output: - // foo - // baz -} - -func ExampleListBuilder_Slice() { - b := NewListBuilder[string]() - b.Append("foo") - b.Append("bar") - b.Append("baz") - b.Slice(1, 3) - - l := b.List() - fmt.Println(l.Get(0)) - fmt.Println(l.Get(1)) - // Output: - // bar - // baz -} - -// Ensure node can support overwrites as it expands. -func TestInternal_mapNode_Overwrite(t *testing.T) { - const n = 1000 - var h defaultHasher[int] - var node mapNode[int, int] = &mapArrayNode[int, int]{} - for i := 0; i < n; i++ { - var resized bool - node = node.set(i, i, 0, h.Hash(i), &h, false, &resized) - if !resized { - t.Fatal("expected resize") - } - - // Overwrite every node. - for j := 0; j <= i; j++ { - var resized bool - node = node.set(j, i*j, 0, h.Hash(j), &h, false, &resized) - if resized { - t.Fatalf("expected no resize: i=%d, j=%d", i, j) - } - } - - // Verify not found at each branch type. - if _, ok := node.get(1000000, 0, h.Hash(1000000), &h); ok { - t.Fatal("expected no value") - } - } - - // Verify all key/value pairs in map. - for i := 0; i < n; i++ { - if v, ok := node.get(i, 0, h.Hash(i), &h); !ok || v != i*(n-1) { - t.Fatalf("get(%d)=<%v,%v>", i, v, ok) - } - } -} - -func TestInternal_mapArrayNode(t *testing.T) { - // Ensure 8 or fewer elements stays in an array node. - t.Run("Append", func(t *testing.T) { - var h defaultHasher[int] - n := &mapArrayNode[int, int]{} - for i := 0; i < 8; i++ { - var resized bool - n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized).(*mapArrayNode[int, int]) - if !resized { - t.Fatal("expected resize") - } - - for j := 0; j < i; j++ { - if v, ok := n.get(j*10, 0, h.Hash(j*10), &h); !ok || v != j { - t.Fatalf("get(%d)=<%v,%v>", j, v, ok) - } - } - } - }) - - // Ensure 8 or fewer elements stays in an array node when inserted in reverse. - t.Run("Prepend", func(t *testing.T) { - var h defaultHasher[int] - n := &mapArrayNode[int, int]{} - for i := 7; i >= 0; i-- { - var resized bool - n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized).(*mapArrayNode[int, int]) - if !resized { - t.Fatal("expected resize") - } - - for j := i; j <= 7; j++ { - if v, ok := n.get(j*10, 0, h.Hash(j*10), &h); !ok || v != j { - t.Fatalf("get(%d)=<%v,%v>", j, v, ok) - } - } - } - }) - - // Ensure array can transition between node types. - t.Run("Expand", func(t *testing.T) { - var h defaultHasher[int] - var n mapNode[int, int] = &mapArrayNode[int, int]{} - for i := 0; i < 100; i++ { - var resized bool - n = n.set(i, i, 0, h.Hash(i), &h, false, &resized) - if !resized { - t.Fatal("expected resize") - } - - for j := 0; j < i; j++ { - if v, ok := n.get(j, 0, h.Hash(j), &h); !ok || v != j { - t.Fatalf("get(%d)=<%v,%v>", j, v, ok) - } - } - } - }) - - // Ensure deleting elements returns the correct new node. - RunRandom(t, "Delete", func(t *testing.T, rand *rand.Rand) { - var h defaultHasher[int] - var n mapNode[int, int] = &mapArrayNode[int, int]{} - for i := 0; i < 8; i++ { - var resized bool - n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized) - } - - for _, i := range rand.Perm(8) { - var resized bool - n = n.delete(i*10, 0, h.Hash(i*10), &h, false, &resized) - } - if n != nil { - t.Fatal("expected nil rand") - } - }) -} - -func TestInternal_mapValueNode(t *testing.T) { - t.Run("Simple", func(t *testing.T) { - var h defaultHasher[int] - n := newMapValueNode(h.Hash(2), 2, 3) - if v, ok := n.get(2, 0, h.Hash(2), &h); !ok { - t.Fatal("expected ok") - } else if v != 3 { - t.Fatalf("unexpected value: %v", v) - } - }) - - t.Run("KeyEqual", func(t *testing.T) { - var h defaultHasher[int] - var resized bool - n := newMapValueNode(h.Hash(2), 2, 3) - other := n.set(2, 4, 0, h.Hash(2), &h, false, &resized).(*mapValueNode[int, int]) - if other == n { - t.Fatal("expected new node") - } else if got, exp := other.keyHash, h.Hash(2); got != exp { - t.Fatalf("keyHash=%v, expected %v", got, exp) - } else if got, exp := other.key, 2; got != exp { - t.Fatalf("key=%v, expected %v", got, exp) - } else if got, exp := other.value, 4; got != exp { - t.Fatalf("value=%v, expected %v", got, exp) - } else if resized { - t.Fatal("unexpected resize") - } - }) - - t.Run("KeyHashEqual", func(t *testing.T) { - h := &mockHasher[int]{ - hash: func(value int) uint32 { return 1 }, - equal: func(a, b int) bool { return a == b }, - } - var resized bool - n := newMapValueNode(h.Hash(2), 2, 3) - other := n.set(4, 5, 0, h.Hash(4), h, false, &resized).(*mapHashCollisionNode[int, int]) - if got, exp := other.keyHash, h.Hash(2); got != exp { - t.Fatalf("keyHash=%v, expected %v", got, exp) - } else if got, exp := len(other.entries), 2; got != exp { - t.Fatalf("entries=%v, expected %v", got, exp) - } else if !resized { - t.Fatal("expected resize") - } - if got, exp := other.entries[0].key, 2; got != exp { - t.Fatalf("key[0]=%v, expected %v", got, exp) - } else if got, exp := other.entries[0].value, 3; got != exp { - t.Fatalf("value[0]=%v, expected %v", got, exp) - } - if got, exp := other.entries[1].key, 4; got != exp { - t.Fatalf("key[1]=%v, expected %v", got, exp) - } else if got, exp := other.entries[1].value, 5; got != exp { - t.Fatalf("value[1]=%v, expected %v", got, exp) - } - }) - - t.Run("MergeNode", func(t *testing.T) { - // Inserting into a node with a different index in the mask should split into a bitmap node. - t.Run("NoConflict", func(t *testing.T) { - var h defaultHasher[int] - var resized bool - n := newMapValueNode(h.Hash(2), 2, 3) - other := n.set(4, 5, 0, h.Hash(4), &h, false, &resized).(*mapBitmapIndexedNode[int, int]) - if got, exp := other.bitmap, uint32(0x14); got != exp { - t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) - } else if got, exp := len(other.nodes), 2; got != exp { - t.Fatalf("nodes=%v, expected %v", got, exp) - } else if !resized { - t.Fatal("expected resize") - } - if node, ok := other.nodes[0].(*mapValueNode[int, int]); !ok { - t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) - } else if got, exp := node.key, 2; got != exp { - t.Fatalf("key[0]=%v, expected %v", got, exp) - } else if got, exp := node.value, 3; got != exp { - t.Fatalf("value[0]=%v, expected %v", got, exp) - } - if node, ok := other.nodes[1].(*mapValueNode[int, int]); !ok { - t.Fatalf("node[1]=%T, unexpected type", other.nodes[1]) - } else if got, exp := node.key, 4; got != exp { - t.Fatalf("key[1]=%v, expected %v", got, exp) - } else if got, exp := node.value, 5; got != exp { - t.Fatalf("value[1]=%v, expected %v", got, exp) - } - - // Ensure both values can be read. - if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v != 3 { - t.Fatalf("Get(2)=<%v,%v>", v, ok) - } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v != 5 { - t.Fatalf("Get(4)=<%v,%v>", v, ok) - } - }) - - // Reversing the nodes from NoConflict should yield the same result. - t.Run("NoConflictReverse", func(t *testing.T) { - var h defaultHasher[int] - var resized bool - n := newMapValueNode(h.Hash(4), 4, 5) - other := n.set(2, 3, 0, h.Hash(2), &h, false, &resized).(*mapBitmapIndexedNode[int, int]) - if got, exp := other.bitmap, uint32(0x14); got != exp { - t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) - } else if got, exp := len(other.nodes), 2; got != exp { - t.Fatalf("nodes=%v, expected %v", got, exp) - } else if !resized { - t.Fatal("expected resize") - } - if node, ok := other.nodes[0].(*mapValueNode[int, int]); !ok { - t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) - } else if got, exp := node.key, 2; got != exp { - t.Fatalf("key[0]=%v, expected %v", got, exp) - } else if got, exp := node.value, 3; got != exp { - t.Fatalf("value[0]=%v, expected %v", got, exp) - } - if node, ok := other.nodes[1].(*mapValueNode[int, int]); !ok { - t.Fatalf("node[1]=%T, unexpected type", other.nodes[1]) - } else if got, exp := node.key, 4; got != exp { - t.Fatalf("key[1]=%v, expected %v", got, exp) - } else if got, exp := node.value, 5; got != exp { - t.Fatalf("value[1]=%v, expected %v", got, exp) - } - - // Ensure both values can be read. - if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v != 3 { - t.Fatalf("Get(2)=<%v,%v>", v, ok) - } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v != 5 { - t.Fatalf("Get(4)=<%v,%v>", v, ok) - } - }) - - // Inserting a node with the same mask index should nest an additional level of bitmap nodes. - t.Run("Conflict", func(t *testing.T) { - h := &mockHasher[int]{ - hash: func(value int) uint32 { return uint32(value << 5) }, - equal: func(a, b int) bool { return a == b }, - } - var resized bool - n := newMapValueNode(h.Hash(2), 2, 3) - other := n.set(4, 5, 0, h.Hash(4), h, false, &resized).(*mapBitmapIndexedNode[int, int]) - if got, exp := other.bitmap, uint32(0x01); got != exp { // mask is zero, expect first slot. - t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) - } else if got, exp := len(other.nodes), 1; got != exp { - t.Fatalf("nodes=%v, expected %v", got, exp) - } else if !resized { - t.Fatal("expected resize") - } - child, ok := other.nodes[0].(*mapBitmapIndexedNode[int, int]) - if !ok { - t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) - } - - if node, ok := child.nodes[0].(*mapValueNode[int, int]); !ok { - t.Fatalf("node[0]=%T, unexpected type", child.nodes[0]) - } else if got, exp := node.key, 2; got != exp { - t.Fatalf("key[0]=%v, expected %v", got, exp) - } else if got, exp := node.value, 3; got != exp { - t.Fatalf("value[0]=%v, expected %v", got, exp) - } - if node, ok := child.nodes[1].(*mapValueNode[int, int]); !ok { - t.Fatalf("node[1]=%T, unexpected type", child.nodes[1]) - } else if got, exp := node.key, 4; got != exp { - t.Fatalf("key[1]=%v, expected %v", got, exp) - } else if got, exp := node.value, 5; got != exp { - t.Fatalf("value[1]=%v, expected %v", got, exp) - } - - // Ensure both values can be read. - if v, ok := other.get(2, 0, h.Hash(2), h); !ok || v != 3 { - t.Fatalf("Get(2)=<%v,%v>", v, ok) - } else if v, ok := other.get(4, 0, h.Hash(4), h); !ok || v != 5 { - t.Fatalf("Get(4)=<%v,%v>", v, ok) - } else if v, ok := other.get(10, 0, h.Hash(10), h); ok { - t.Fatalf("Get(10)=<%v,%v>, expected no value", v, ok) - } - }) - }) -} - -func TestMap_Get(t *testing.T) { - t.Run("Empty", func(t *testing.T) { - m := NewMap[int, string](nil) - if v, ok := m.Get(100); ok { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } - }) -} - -func TestMap_Set(t *testing.T) { - t.Run("Simple", func(t *testing.T) { - m := NewMap[int, string](nil) - itr := m.Iterator() - if !itr.Done() { - t.Fatal("MapIterator.Done()=true, expected false") - } else if k, v, ok := itr.Next(); ok { - t.Fatalf("MapIterator.Next()=<%v,%v>, expected nil", k, v) - } - }) - - t.Run("Simple", func(t *testing.T) { - m := NewMap[int, string](nil) - m = m.Set(100, "foo") - if v, ok := m.Get(100); !ok || v != "foo" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } - }) - - t.Run("Multi", func(t *testing.T) { - m := NewMapOf(nil, map[int]string{1: "foo"}) - itr := m.Iterator() - if itr.Done() { - t.Fatal("MapIterator.Done()=false, expected true") - } - if k, v, ok := itr.Next(); !ok { - t.Fatalf("MapIterator.Next()!=ok, expected ok") - } else if k != 1 || v != "foo" { - t.Fatalf("MapIterator.Next()=<%v,%v>, expected <1, \"foo\">", k, v) - } - if k, v, ok := itr.Next(); ok { - t.Fatalf("MapIterator.Next()=<%v,%v>, expected nil", k, v) - } - }) - - t.Run("VerySmall", func(t *testing.T) { - const n = 6 - m := NewMap[int, int](nil) - for i := 0; i < n; i++ { - m = m.Set(i, i+1) - } - for i := 0; i < n; i++ { - if v, ok := m.Get(i); !ok || v != i+1 { - t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) - } - } - - // NOTE: Array nodes store entries in insertion order. - itr := m.Iterator() - for i := 0; i < n; i++ { - if k, v, ok := itr.Next(); !ok || k != i || v != i+1 { - t.Fatalf("MapIterator.Next()=<%v,%v>, exp <%v,%v>", k, v, i, i+1) - } - } - if !itr.Done() { - t.Fatal("expected iterator done") - } - }) - - t.Run("Small", func(t *testing.T) { - const n = 1000 - m := NewMap[int, int](nil) - for i := 0; i < n; i++ { - m = m.Set(i, i+1) - } - for i := 0; i < n; i++ { - if v, ok := m.Get(i); !ok || v != i+1 { - t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) - } - } - }) - - t.Run("Large", func(t *testing.T) { - if testing.Short() { - t.Skip("skipping: short") - } - - const n = 1000000 - m := NewMap[int, int](nil) - for i := 0; i < n; i++ { - m = m.Set(i, i+1) - } - for i := 0; i < n; i++ { - if v, ok := m.Get(i); !ok || v != i+1 { - t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) - } - } - }) - - t.Run("StringKeys", func(t *testing.T) { - m := NewMap[string, string](nil) - m = m.Set("foo", "bar") - m = m.Set("baz", "bat") - m = m.Set("", "EMPTY") - if v, ok := m.Get("foo"); !ok || v != "bar" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } else if v, ok := m.Get("baz"); !ok || v != "bat" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } else if v, ok := m.Get(""); !ok || v != "EMPTY" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } - if v, ok := m.Get("no_such_key"); ok { - t.Fatalf("expected no value: <%v,%v>", v, ok) - } - }) - - RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { - m := NewTestMap() - for i := 0; i < 10000; i++ { - switch rand.Intn(2) { - case 1: // overwrite - m.Set(m.ExistingKey(rand), rand.Intn(10000)) - default: // set new key - m.Set(m.NewKey(rand), rand.Intn(10000)) - } - } - if err := m.Validate(); err != nil { - t.Fatal(err) - } - }) -} - -// Ensure map can support overwrites as it expands. -func TestMap_Overwrite(t *testing.T) { - if testing.Short() { - t.Skip("short mode") - } - - const n = 10000 - m := NewMap[int, int](nil) - for i := 0; i < n; i++ { - // Set original value. - m = m.Set(i, i) - - // Overwrite every node. - for j := 0; j <= i; j++ { - m = m.Set(j, i*j) - } - } - - // Verify all key/value pairs in map. - for i := 0; i < n; i++ { - if v, ok := m.Get(i); !ok || v != i*(n-1) { - t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) - } - } - - t.Run("Simple", func(t *testing.T) { - m := NewMap[int, string](nil) - itr := m.Iterator() - if !itr.Done() { - t.Fatal("MapIterator.Done()=true, expected false") - } else if k, v, ok := itr.Next(); ok { - t.Fatalf("MapIterator.Next()=<%v,%v>, expected nil", k, v) - } - }) -} - -func TestMap_Delete(t *testing.T) { - t.Run("Empty", func(t *testing.T) { - m := NewMap[string, int](nil) - other := m.Delete("foo") - if m != other { - t.Fatal("expected same map") - } - }) - - t.Run("Simple", func(t *testing.T) { - m := NewMap[int, string](nil) - m = m.Set(100, "foo") - if v, ok := m.Get(100); !ok || v != "foo" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } - }) - - t.Run("Small", func(t *testing.T) { - const n = 1000 - m := NewMap[int, int](nil) - for i := 0; i < n; i++ { - m = m.Set(i, i+1) - } - for i := range rand.New(rand.NewSource(0)).Perm(n) { - m = m.Delete(i) - } - if m.Len() != 0 { - t.Fatalf("expected no elements, got %d", m.Len()) - } - }) - - t.Run("Large", func(t *testing.T) { - if testing.Short() { - t.Skip("skipping: short") - } - const n = 1000000 - m := NewMap[int, int](nil) - for i := 0; i < n; i++ { - m = m.Set(i, i+1) - } - for i := range rand.New(rand.NewSource(0)).Perm(n) { - m = m.Delete(i) - } - if m.Len() != 0 { - t.Fatalf("expected no elements, got %d", m.Len()) - } - }) - - RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { - m := NewTestMap() - for i := 0; i < 10000; i++ { - switch rand.Intn(8) { - case 0: // overwrite - m.Set(m.ExistingKey(rand), rand.Intn(10000)) - case 1: // delete existing key - m.Delete(m.ExistingKey(rand)) - case 2: // delete non-existent key. - m.Delete(m.NewKey(rand)) - default: // set new key - m.Set(m.NewKey(rand), rand.Intn(10000)) - } - } - - // Delete all and verify they are gone. - keys := make([]int, len(m.keys)) - copy(keys, m.keys) - - for _, key := range keys { - m.Delete(key) - } - if err := m.Validate(); err != nil { - t.Fatal(err) - } - }) -} - -// Ensure map works even with hash conflicts. -func TestMap_LimitedHash(t *testing.T) { - if testing.Short() { - t.Skip("skipping: short") - } - - t.Run("Immutable", func(t *testing.T) { - h := mockHasher[int]{ - hash: func(value int) uint32 { return hashUint64(uint64(value)) % 0xFF }, - equal: func(a, b int) bool { return a == b }, - } - m := NewMap[int, int](&h) - - rand := rand.New(rand.NewSource(0)) - keys := rand.Perm(100000) - for _, i := range keys { - m = m.Set(i, i) // initial set - } - for i := range keys { - m = m.Set(i, i*2) // overwrite - } - if m.Len() != len(keys) { - t.Fatalf("unexpected len: %d", m.Len()) - } - - // Verify all key/value pairs in map. - for i := 0; i < m.Len(); i++ { - if v, ok := m.Get(i); !ok || v != i*2 { - t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) - } - } - - // Verify iteration. - itr := m.Iterator() - for !itr.Done() { - if k, v, ok := itr.Next(); !ok || v != k*2 { - t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k*2) - } - } - - // Verify not found works. - if _, ok := m.Get(10000000); ok { - t.Fatal("expected no value") - } - - // Verify delete non-existent key works. - if other := m.Delete(10000000 + 1); m != other { - t.Fatal("expected no change") - } - - // Remove all keys. - for _, key := range keys { - m = m.Delete(key) - } - if m.Len() != 0 { - t.Fatalf("unexpected size: %d", m.Len()) - } - }) - - t.Run("Builder", func(t *testing.T) { - h := mockHasher[int]{ - hash: func(value int) uint32 { return hashUint64(uint64(value)) }, - equal: func(a, b int) bool { return a == b }, - } - b := NewMapBuilder[int, int](&h) - - rand := rand.New(rand.NewSource(0)) - keys := rand.Perm(100000) - for _, i := range keys { - b.Set(i, i) // initial set - } - for i := range keys { - b.Set(i, i*2) // overwrite - } - if b.Len() != len(keys) { - t.Fatalf("unexpected len: %d", b.Len()) - } - - // Verify all key/value pairs in map. - for i := 0; i < b.Len(); i++ { - if v, ok := b.Get(i); !ok || v != i*2 { - t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) - } - } - - // Verify iteration. - itr := b.Iterator() - for !itr.Done() { - if k, v, ok := itr.Next(); !ok || v != k*2 { - t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k*2) - } - } - - // Verify not found works. - if _, ok := b.Get(10000000); ok { - t.Fatal("expected no value") - } - - // Remove all keys. - for _, key := range keys { - b.Delete(key) - } - if b.Len() != 0 { - t.Fatalf("unexpected size: %d", b.Len()) - } - }) -} - -// TMap represents a combined immutable and stdlib map. -type TMap struct { - im, prev *Map[int, int] - builder *MapBuilder[int, int] - std map[int]int - keys []int -} - -func NewTestMap() *TMap { - return &TMap{ - im: NewMap[int, int](nil), - builder: NewMapBuilder[int, int](nil), - std: make(map[int]int), - } -} - -func (m *TMap) NewKey(rand *rand.Rand) int { - for { - k := rand.Int() - if _, ok := m.std[k]; !ok { - return k - } - } -} - -func (m *TMap) ExistingKey(rand *rand.Rand) int { - if len(m.keys) == 0 { - return 0 - } - return m.keys[rand.Intn(len(m.keys))] -} - -func (m *TMap) Set(k, v int) { - m.prev = m.im - m.im = m.im.Set(k, v) - m.builder.Set(k, v) - - _, exists := m.std[k] - if !exists { - m.keys = append(m.keys, k) - } - m.std[k] = v -} - -func (m *TMap) Delete(k int) { - m.prev = m.im - m.im = m.im.Delete(k) - m.builder.Delete(k) - delete(m.std, k) - - for i := range m.keys { - if m.keys[i] == k { - m.keys = append(m.keys[:i], m.keys[i+1:]...) - break - } - } -} - -func (m *TMap) Validate() error { - for _, k := range m.keys { - if v, ok := m.im.Get(k); !ok { - return fmt.Errorf("key not found: %d", k) - } else if v != m.std[k] { - return fmt.Errorf("key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) - } - if v, ok := m.builder.Get(k); !ok { - return fmt.Errorf("builder key not found: %d", k) - } else if v != m.std[k] { - return fmt.Errorf("builder key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) - } - } - if err := m.validateIterator(m.im.Iterator()); err != nil { - return fmt.Errorf("basic: %s", err) - } else if err := m.validateIterator(m.builder.Iterator()); err != nil { - return fmt.Errorf("builder: %s", err) - } - return nil -} - -func (m *TMap) validateIterator(itr *MapIterator[int, int]) error { - other := make(map[int]int) - for !itr.Done() { - k, v, _ := itr.Next() - other[k] = v - } - if len(other) != len(m.std) { - return fmt.Errorf("map iterator size mismatch: %v!=%v", len(m.std), len(other)) - } - for k, v := range m.std { - if v != other[k] { - return fmt.Errorf("map iterator mismatch: key=%v, %v!=%v", k, v, other[k]) - } - } - if k, v, ok := itr.Next(); ok { - return fmt.Errorf("map iterator returned key/value after done: <%v/%v>", k, v) - } - return nil -} - -func BenchmarkBuiltinMap_Set(b *testing.B) { - b.ReportAllocs() - m := make(map[int]int) - for i := 0; i < b.N; i++ { - m[i] = i - } -} - -func BenchmarkBuiltinMap_Delete(b *testing.B) { - const n = 10000000 - - m := make(map[int]int) - for i := 0; i < n; i++ { - m[i] = i - } - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - delete(m, i%n) - } -} - -func BenchmarkMap_Set(b *testing.B) { - b.ReportAllocs() - m := NewMap[int, int](nil) - for i := 0; i < b.N; i++ { - m = m.Set(i, i) - } -} - -func BenchmarkMap_Delete(b *testing.B) { - const n = 10000000 - - builder := NewMapBuilder[int, int](nil) - for i := 0; i < n; i++ { - builder.Set(i, i) - } - b.ReportAllocs() - b.ResetTimer() - - m := builder.Map() - for i := 0; i < b.N; i++ { - m.Delete(i % n) // Do not update map, always operate on original - } -} - -func BenchmarkMap_Iterator(b *testing.B) { - const n = 10000 - m := NewMap[int, int](nil) - for i := 0; i < 10000; i++ { - m = m.Set(i, i) - } - b.ReportAllocs() - b.ResetTimer() - - b.Run("Forward", func(b *testing.B) { - itr := m.Iterator() - for i := 0; i < b.N; i++ { - if i%n == 0 { - itr.First() - } - itr.Next() - } - }) -} - -func BenchmarkMapBuilder_Set(b *testing.B) { - b.ReportAllocs() - builder := NewMapBuilder[int, int](nil) - for i := 0; i < b.N; i++ { - builder.Set(i, i) - } -} - -func BenchmarkMapBuilder_Delete(b *testing.B) { - const n = 10000000 - - builder := NewMapBuilder[int, int](nil) - for i := 0; i < n; i++ { - builder.Set(i, i) - } - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - builder.Delete(i % n) - } -} - -func ExampleMap_Set() { - m := NewMap[string, any](nil) - m = m.Set("foo", "bar") - m = m.Set("baz", 100) - - v, ok := m.Get("foo") - fmt.Println("foo", v, ok) - - v, ok = m.Get("baz") - fmt.Println("baz", v, ok) - - v, ok = m.Get("bat") // does not exist - fmt.Println("bat", v, ok) - // Output: - // foo bar true - // baz 100 true - // bat false -} - -func ExampleMap_Delete() { - m := NewMap[string, any](nil) - m = m.Set("foo", "bar") - m = m.Set("baz", 100) - m = m.Delete("baz") - - v, ok := m.Get("foo") - fmt.Println("foo", v, ok) - - v, ok = m.Get("baz") - fmt.Println("baz", v, ok) - // Output: - // foo bar true - // baz false -} - -func ExampleMap_Iterator() { - m := NewMap[string, int](nil) - m = m.Set("apple", 100) - m = m.Set("grape", 200) - m = m.Set("kiwi", 300) - m = m.Set("mango", 400) - m = m.Set("orange", 500) - m = m.Set("peach", 600) - m = m.Set("pear", 700) - m = m.Set("pineapple", 800) - m = m.Set("strawberry", 900) - - itr := m.Iterator() - for !itr.Done() { - k, v, _ := itr.Next() - fmt.Println(k, v) - } - // Output: - // mango 400 - // pear 700 - // pineapple 800 - // grape 200 - // orange 500 - // strawberry 900 - // kiwi 300 - // peach 600 - // apple 100 -} - -func ExampleMapBuilder_Set() { - b := NewMapBuilder[string, any](nil) - b.Set("foo", "bar") - b.Set("baz", 100) - - m := b.Map() - v, ok := m.Get("foo") - fmt.Println("foo", v, ok) - - v, ok = m.Get("baz") - fmt.Println("baz", v, ok) - - v, ok = m.Get("bat") // does not exist - fmt.Println("bat", v, ok) - // Output: - // foo bar true - // baz 100 true - // bat false -} - -func ExampleMapBuilder_Delete() { - b := NewMapBuilder[string, any](nil) - b.Set("foo", "bar") - b.Set("baz", 100) - b.Delete("baz") - - m := b.Map() - v, ok := m.Get("foo") - fmt.Println("foo", v, ok) - - v, ok = m.Get("baz") - fmt.Println("baz", v, ok) - // Output: - // foo bar true - // baz false -} - -func TestInternalSortedMapLeafNode(t *testing.T) { - RunRandom(t, "NoSplit", func(t *testing.T, rand *rand.Rand) { - var cmpr defaultComparer[int] - var node sortedMapNode[int, int] = &sortedMapLeafNode[int, int]{} - var keys []int - for _, i := range rand.Perm(32) { - var resized bool - var splitNode sortedMapNode[int, int] - node, splitNode = node.set(i, i*10, &cmpr, false, &resized) - if !resized { - t.Fatal("expected resize") - } else if splitNode != nil { - t.Fatal("expected split") - } - keys = append(keys, i) - - // Verify not found at each size. - if _, ok := node.get(rand.Int()+32, &cmpr); ok { - t.Fatal("expected no value") - } - - // Verify min key is always the lowest. - sort.Ints(keys) - if got, exp := node.minKey(), keys[0]; got != exp { - t.Fatalf("minKey()=%d, expected %d", got, exp) - } - } - - // Verify all key/value pairs in node. - for i := range keys { - if v, ok := node.get(i, &cmpr); !ok || v != i*10 { - t.Fatalf("get(%d)=<%v,%v>", i, v, ok) - } - } - }) - - RunRandom(t, "Overwrite", func(t *testing.T, rand *rand.Rand) { - var cmpr defaultComparer[int] - var node sortedMapNode[int, int] = &sortedMapLeafNode[int, int]{} - - for _, i := range rand.Perm(32) { - var resized bool - node, _ = node.set(i, i*2, &cmpr, false, &resized) - } - for _, i := range rand.Perm(32) { - var resized bool - node, _ = node.set(i, i*3, &cmpr, false, &resized) - if resized { - t.Fatal("expected no resize") - } - } - - // Verify all overwritten key/value pairs in node. - for i := 0; i < 32; i++ { - if v, ok := node.get(i, &cmpr); !ok || v != i*3 { - t.Fatalf("get(%d)=<%v,%v>", i, v, ok) - } - } - }) - - t.Run("Split", func(t *testing.T) { - // Fill leaf node. var cmpr defaultComparer[int] - var cmpr defaultComparer[int] - var node sortedMapNode[int, int] = &sortedMapLeafNode[int, int]{} - for i := 0; i < 32; i++ { - var resized bool - node, _ = node.set(i, i*10, &cmpr, false, &resized) - } - - // Add one more and expect split. - var resized bool - newNode, splitNode := node.set(32, 320, &cmpr, false, &resized) - - // Verify node contents. - newLeafNode, ok := newNode.(*sortedMapLeafNode[int, int]) - if !ok { - t.Fatalf("unexpected node type: %T", newLeafNode) - } else if n := len(newLeafNode.entries); n != 16 { - t.Fatalf("unexpected node len: %d", n) - } - for i := range newLeafNode.entries { - if entry := newLeafNode.entries[i]; entry.key != i || entry.value != i*10 { - t.Fatalf("%d. unexpected entry: %v=%v", i, entry.key, entry.value) - } - } - - // Verify split node contents. - splitLeafNode, ok := splitNode.(*sortedMapLeafNode[int, int]) - if !ok { - t.Fatalf("unexpected split node type: %T", splitLeafNode) - } else if n := len(splitLeafNode.entries); n != 17 { - t.Fatalf("unexpected split node len: %d", n) - } - for i := range splitLeafNode.entries { - if entry := splitLeafNode.entries[i]; entry.key != (i+16) || entry.value != (i+16)*10 { - t.Fatalf("%d. unexpected split node entry: %v=%v", i, entry.key, entry.value) - } - } - }) -} - -func TestInternalSortedMapBranchNode(t *testing.T) { - RunRandom(t, "NoSplit", func(t *testing.T, rand *rand.Rand) { - keys := make([]int, 32*16) - for i := range keys { - keys[i] = rand.Intn(10000) - } - keys = uniqueIntSlice(keys) - sort.Ints(keys[:2]) // ensure first two keys are sorted for initial insert. - - // Initialize branch with two leafs. - var cmpr defaultComparer[int] - leaf0 := &sortedMapLeafNode[int, int]{entries: []mapEntry[int, int]{{key: keys[0], value: keys[0] * 10}}} - leaf1 := &sortedMapLeafNode[int, int]{entries: []mapEntry[int, int]{{key: keys[1], value: keys[1] * 10}}} - var node sortedMapNode[int, int] = newSortedMapBranchNode[int, int](leaf0, leaf1) - - sort.Ints(keys) - for _, i := range rand.Perm(len(keys)) { - key := keys[i] - - var resized bool - var splitNode sortedMapNode[int, int] - node, splitNode = node.set(key, key*10, &cmpr, false, &resized) - if key == leaf0.entries[0].key || key == leaf1.entries[0].key { - if resized { - t.Fatalf("expected no resize: key=%d", key) - } - } else { - if !resized { - t.Fatalf("expected resize: key=%d", key) - } - } - if splitNode != nil { - t.Fatal("unexpected split") - } - } - - // Verify all key/value pairs in node. - for _, key := range keys { - if v, ok := node.get(key, &cmpr); !ok || v != key*10 { - t.Fatalf("get(%d)=<%v,%v>", key, v, ok) - } - } - - // Verify min key is the lowest key. - if got, exp := node.minKey(), keys[0]; got != exp { - t.Fatalf("minKey()=%d, expected %d", got, exp) - } - }) - - t.Run("Split", func(t *testing.T) { - // Generate leaf nodes. - var cmpr defaultComparer[int] - children := make([]sortedMapNode[int, int], 32) - for i := range children { - leaf := &sortedMapLeafNode[int, int]{entries: make([]mapEntry[int, int], 32)} - for j := range leaf.entries { - leaf.entries[j] = mapEntry[int, int]{key: (i * 32) + j, value: ((i * 32) + j) * 100} - } - children[i] = leaf - } - var node sortedMapNode[int, int] = newSortedMapBranchNode(children...) - - // Add one more and expect split. - var resized bool - newNode, splitNode := node.set((32 * 32), (32*32)*100, &cmpr, false, &resized) - - // Verify node contents. - var idx int - newBranchNode, ok := newNode.(*sortedMapBranchNode[int, int]) - if !ok { - t.Fatalf("unexpected node type: %T", newBranchNode) - } else if n := len(newBranchNode.elems); n != 16 { - t.Fatalf("unexpected child elems len: %d", n) - } - for i, elem := range newBranchNode.elems { - child, ok := elem.node.(*sortedMapLeafNode[int, int]) - if !ok { - t.Fatalf("unexpected child type") - } - for j, entry := range child.entries { - if entry.key != idx || entry.value != idx*100 { - t.Fatalf("%d/%d. unexpected entry: %v=%v", i, j, entry.key, entry.value) - } - idx++ - } - } - - // Verify split node contents. - splitBranchNode, ok := splitNode.(*sortedMapBranchNode[int, int]) - if !ok { - t.Fatalf("unexpected split node type: %T", splitBranchNode) - } else if n := len(splitBranchNode.elems); n != 17 { - t.Fatalf("unexpected split node elem len: %d", n) - } - for i, elem := range splitBranchNode.elems { - child, ok := elem.node.(*sortedMapLeafNode[int, int]) - if !ok { - t.Fatalf("unexpected split node child type") - } - for j, entry := range child.entries { - if entry.key != idx || entry.value != idx*100 { - t.Fatalf("%d/%d. unexpected split node entry: %v=%v", i, j, entry.key, entry.value) - } - idx++ - } - } - }) -} - -func TestSortedMap_Get(t *testing.T) { - t.Run("Empty", func(t *testing.T) { - m := NewSortedMap[int, int](nil) - if v, ok := m.Get(100); ok { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } - }) -} - -func TestSortedMap_Set(t *testing.T) { - t.Run("Simple", func(t *testing.T) { - m := NewSortedMap[int, string](nil) - m = m.Set(100, "foo") - if v, ok := m.Get(100); !ok || v != "foo" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } else if got, exp := m.Len(), 1; got != exp { - t.Fatalf("SortedMap.Len()=%d, exp %d", got, exp) - } - }) - - t.Run("Small", func(t *testing.T) { - const n = 1000 - m := NewSortedMap[int, int](nil) - for i := 0; i < n; i++ { - m = m.Set(i, i+1) - } - for i := 0; i < n; i++ { - if v, ok := m.Get(i); !ok || v != i+1 { - t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) - } - } - }) - - t.Run("Large", func(t *testing.T) { - if testing.Short() { - t.Skip("skipping: short") - } - - const n = 1000000 - m := NewSortedMap[int, int](nil) - for i := 0; i < n; i++ { - m = m.Set(i, i+1) - } - for i := 0; i < n; i++ { - if v, ok := m.Get(i); !ok || v != i+1 { - t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) - } - } - }) - - t.Run("StringKeys", func(t *testing.T) { - m := NewSortedMap[string, string](nil) - m = m.Set("foo", "bar") - m = m.Set("baz", "bat") - m = m.Set("", "EMPTY") - if v, ok := m.Get("foo"); !ok || v != "bar" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } else if v, ok := m.Get("baz"); !ok || v != "bat" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } else if v, ok := m.Get(""); !ok || v != "EMPTY" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } - if v, ok := m.Get("no_such_key"); ok { - t.Fatalf("expected no value: <%v,%v>", v, ok) - } - }) - - t.Run("NoDefaultComparer", func(t *testing.T) { - var r string - func() { - defer func() { r = recover().(string) }() - m := NewSortedMap[float64, string](nil) - m = m.Set(float64(100), "bar") - }() - if r != `immutable.NewComparer: must set comparer for float64 type` { - t.Fatalf("unexpected panic: %q", r) - } - }) - - RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { - m := NewTSortedMap() - for j := 0; j < 10000; j++ { - switch rand.Intn(2) { - case 1: // overwrite - m.Set(m.ExistingKey(rand), rand.Intn(10000)) - default: // set new key - m.Set(m.NewKey(rand), rand.Intn(10000)) - } - } - if err := m.Validate(); err != nil { - t.Fatal(err) - } - }) -} - -// Ensure map can support overwrites as it expands. -func TestSortedMap_Overwrite(t *testing.T) { - const n = 1000 - m := NewSortedMap[int, int](nil) - for i := 0; i < n; i++ { - // Set original value. - m = m.Set(i, i) - - // Overwrite every node. - for j := 0; j <= i; j++ { - m = m.Set(j, i*j) - } - } - - // Verify all key/value pairs in map. - for i := 0; i < n; i++ { - if v, ok := m.Get(i); !ok || v != i*(n-1) { - t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) - } - } -} - -func TestSortedMap_Delete(t *testing.T) { - t.Run("Empty", func(t *testing.T) { - m := NewSortedMap[int, int](nil) - m = m.Delete(100) - if n := m.Len(); n != 0 { - t.Fatalf("SortedMap.Len()=%d, expected 0", n) - } - }) - - t.Run("Simple", func(t *testing.T) { - m := NewSortedMap[int, string](nil) - m = m.Set(100, "foo") - if v, ok := m.Get(100); !ok || v != "foo" { - t.Fatalf("unexpected value: <%v,%v>", v, ok) - } - m = m.Delete(100) - if v, ok := m.Get(100); ok { - t.Fatalf("unexpected no value: <%v,%v>", v, ok) - } - }) - - t.Run("Small", func(t *testing.T) { - const n = 1000 - m := NewSortedMap[int, int](nil) - for i := 0; i < n; i++ { - m = m.Set(i, i+1) - } - for i := 0; i < n; i++ { - if v, ok := m.Get(i); !ok || v != i+1 { - t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) - } - } - - for i := 0; i < n; i++ { - m = m.Delete(i) - } - for i := 0; i < n; i++ { - if v, ok := m.Get(i); ok { - t.Fatalf("expected no value for key=%v: <%v,%v>", i, v, ok) - } - } - }) - - t.Run("Large", func(t *testing.T) { - if testing.Short() { - t.Skip("skipping: short") - } - - const n = 1000000 - m := NewSortedMap[int, int](nil) - for i := 0; i < n; i++ { - m = m.Set(i, i+1) - } - for i := 0; i < n; i++ { - if v, ok := m.Get(i); !ok || v != i+1 { - t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) - } - } - - for i := 0; i < n; i++ { - m = m.Delete(i) - } - for i := 0; i < n; i++ { - if v, ok := m.Get(i); ok { - t.Fatalf("unexpected no value for key=%v: <%v,%v>", i, v, ok) - } - } - }) - - RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { - m := NewTSortedMap() - for j := 0; j < 10000; j++ { - switch rand.Intn(8) { - case 0: // overwrite - m.Set(m.ExistingKey(rand), rand.Intn(10000)) - case 1: // delete existing key - m.Delete(m.ExistingKey(rand)) - case 2: // delete non-existent key. - m.Delete(m.NewKey(rand)) - default: // set new key - m.Set(m.NewKey(rand), rand.Intn(10000)) - } - } - if err := m.Validate(); err != nil { - t.Fatal(err) - } - - // Delete all keys. - keys := make([]int, len(m.keys)) - copy(keys, m.keys) - for _, k := range keys { - m.Delete(k) - } - if err := m.Validate(); err != nil { - t.Fatal(err) - } - }) -} - -func TestSortedMap_Iterator(t *testing.T) { - t.Run("Empty", func(t *testing.T) { - t.Run("First", func(t *testing.T) { - itr := NewSortedMap[int, int](nil).Iterator() - itr.First() - if k, v, ok := itr.Next(); ok { - t.Fatalf("SortedMapIterator.Next()=<%v,%v>, expected nil", k, v) - } - }) - - t.Run("Last", func(t *testing.T) { - itr := NewSortedMap[int, int](nil).Iterator() - itr.Last() - if k, v, ok := itr.Prev(); ok { - t.Fatalf("SortedMapIterator.Prev()=<%v,%v>, expected nil", k, v) - } - }) - - t.Run("Seek", func(t *testing.T) { - itr := NewSortedMap[string, int](nil).Iterator() - itr.Seek("foo") - if k, v, ok := itr.Next(); ok { - t.Fatalf("SortedMapIterator.Next()=<%v,%v>, expected nil", k, v) - } - }) - }) - - t.Run("Seek", func(t *testing.T) { - const n = 100 - m := NewSortedMap[string, int](nil) - for i := 0; i < n; i += 2 { - m = m.Set(fmt.Sprintf("%04d", i), i) - } - - t.Run("Exact", func(t *testing.T) { - itr := m.Iterator() - for i := 0; i < n; i += 2 { - itr.Seek(fmt.Sprintf("%04d", i)) - for j := i; j < n; j += 2 { - if k, _, ok := itr.Next(); !ok || k != fmt.Sprintf("%04d", j) { - t.Fatalf("%d/%d. SortedMapIterator.Next()=%v, expected key %04d", i, j, k, j) - } - } - if !itr.Done() { - t.Fatalf("SortedMapIterator.Done()=true, expected false") - } - } - }) - - t.Run("Miss", func(t *testing.T) { - itr := m.Iterator() - for i := 1; i < n-2; i += 2 { - itr.Seek(fmt.Sprintf("%04d", i)) - for j := i + 1; j < n; j += 2 { - if k, _, ok := itr.Next(); !ok || k != fmt.Sprintf("%04d", j) { - t.Fatalf("%d/%d. SortedMapIterator.Next()=%v, expected key %04d", i, j, k, j) - } - } - if !itr.Done() { - t.Fatalf("SortedMapIterator.Done()=true, expected false") - } - } - }) - - t.Run("BeforeFirst", func(t *testing.T) { - itr := m.Iterator() - itr.Seek("") - for i := 0; i < n; i += 2 { - if k, _, ok := itr.Next(); !ok || k != fmt.Sprintf("%04d", i) { - t.Fatalf("%d. SortedMapIterator.Next()=%v, expected key %04d", i, k, i) - } - } - if !itr.Done() { - t.Fatalf("SortedMapIterator.Done()=true, expected false") - } - }) - t.Run("AfterLast", func(t *testing.T) { - itr := m.Iterator() - itr.Seek("1000") - if k, _, ok := itr.Next(); ok { - t.Fatalf("0. SortedMapIterator.Next()=%v, expected nil key", k) - } else if !itr.Done() { - t.Fatalf("SortedMapIterator.Done()=true, expected false") - } - }) - }) -} - -func TestNewHasher(t *testing.T) { - t.Run("builtin", func(t *testing.T) { - t.Run("int", func(t *testing.T) { testNewHasher(t, int(100)) }) - t.Run("int8", func(t *testing.T) { testNewHasher(t, int8(100)) }) - t.Run("int16", func(t *testing.T) { testNewHasher(t, int16(100)) }) - t.Run("int32", func(t *testing.T) { testNewHasher(t, int32(100)) }) - t.Run("int64", func(t *testing.T) { testNewHasher(t, int64(100)) }) - - t.Run("uint", func(t *testing.T) { testNewHasher(t, uint(100)) }) - t.Run("uint8", func(t *testing.T) { testNewHasher(t, uint8(100)) }) - t.Run("uint16", func(t *testing.T) { testNewHasher(t, uint16(100)) }) - t.Run("uint32", func(t *testing.T) { testNewHasher(t, uint32(100)) }) - t.Run("uint64", func(t *testing.T) { testNewHasher(t, uint64(100)) }) - - t.Run("string", func(t *testing.T) { testNewHasher(t, "foo") }) - //t.Run("byteSlice", func(t *testing.T) { testNewHasher(t, []byte("foo")) }) - }) - - t.Run("reflection", func(t *testing.T) { - type Int int - t.Run("int", func(t *testing.T) { testNewHasher(t, Int(100)) }) - - type Uint uint - t.Run("uint", func(t *testing.T) { testNewHasher(t, Uint(100)) }) - - type String string - t.Run("string", func(t *testing.T) { testNewHasher(t, String("foo")) }) - }) -} - -func testNewHasher[V constraints.Ordered](t *testing.T, v V) { - t.Helper() - h := NewHasher(v) - h.Hash(v) - if !h.Equal(v, v) { - t.Fatal("expected hash equality") - } -} - -func TestNewComparer(t *testing.T) { - t.Run("builtin", func(t *testing.T) { - t.Run("int", func(t *testing.T) { testNewComparer(t, int(100), int(101)) }) - t.Run("int8", func(t *testing.T) { testNewComparer(t, int8(100), int8(101)) }) - t.Run("int16", func(t *testing.T) { testNewComparer(t, int16(100), int16(101)) }) - t.Run("int32", func(t *testing.T) { testNewComparer(t, int32(100), int32(101)) }) - t.Run("int64", func(t *testing.T) { testNewComparer(t, int64(100), int64(101)) }) - - t.Run("uint", func(t *testing.T) { testNewComparer(t, uint(100), uint(101)) }) - t.Run("uint8", func(t *testing.T) { testNewComparer(t, uint8(100), uint8(101)) }) - t.Run("uint16", func(t *testing.T) { testNewComparer(t, uint16(100), uint16(101)) }) - t.Run("uint32", func(t *testing.T) { testNewComparer(t, uint32(100), uint32(101)) }) - t.Run("uint64", func(t *testing.T) { testNewComparer(t, uint64(100), uint64(101)) }) - - t.Run("string", func(t *testing.T) { testNewComparer(t, "bar", "foo") }) - //t.Run("byteSlice", func(t *testing.T) { testNewComparer(t, []byte("bar"), []byte("foo")) }) - }) - - t.Run("reflection", func(t *testing.T) { - type Int int - t.Run("int", func(t *testing.T) { testNewComparer(t, Int(100), Int(101)) }) - - type Uint uint - t.Run("uint", func(t *testing.T) { testNewComparer(t, Uint(100), Uint(101)) }) - - type String string - t.Run("string", func(t *testing.T) { testNewComparer(t, String("bar"), String("foo")) }) - }) -} - -func testNewComparer[T constraints.Ordered](t *testing.T, x, y T) { - t.Helper() - c := NewComparer(x) - if c.Compare(x, y) != -1 { - t.Fatal("expected comparer LT") - } else if c.Compare(x, x) != 0 { - t.Fatal("expected comparer EQ") - } else if c.Compare(y, x) != 1 { - t.Fatal("expected comparer GT") - } -} - -// TSortedMap represents a combined immutable and stdlib sorted map. -type TSortedMap struct { - im, prev *SortedMap[int, int] - builder *SortedMapBuilder[int, int] - std map[int]int - keys []int -} - -func NewTSortedMap() *TSortedMap { - return &TSortedMap{ - im: NewSortedMap[int, int](nil), - builder: NewSortedMapBuilder[int, int](nil), - std: make(map[int]int), - } -} - -func (m *TSortedMap) NewKey(rand *rand.Rand) int { - for { - k := rand.Int() - if _, ok := m.std[k]; !ok { - return k - } - } -} - -func (m *TSortedMap) ExistingKey(rand *rand.Rand) int { - if len(m.keys) == 0 { - return 0 - } - return m.keys[rand.Intn(len(m.keys))] -} - -func (m *TSortedMap) Set(k, v int) { - m.prev = m.im - m.im = m.im.Set(k, v) - m.builder.Set(k, v) - - if _, ok := m.std[k]; !ok { - m.keys = append(m.keys, k) - sort.Ints(m.keys) - } - m.std[k] = v -} - -func (m *TSortedMap) Delete(k int) { - m.prev = m.im - m.im = m.im.Delete(k) - m.builder.Delete(k) - delete(m.std, k) - - for i := range m.keys { - if m.keys[i] == k { - m.keys = append(m.keys[:i], m.keys[i+1:]...) - break - } - } -} - -func (m *TSortedMap) Validate() error { - for _, k := range m.keys { - if v, ok := m.im.Get(k); !ok { - return fmt.Errorf("key not found: %d", k) - } else if v != m.std[k] { - return fmt.Errorf("key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) - } - if v, ok := m.builder.Get(k); !ok { - return fmt.Errorf("builder key not found: %d", k) - } else if v != m.std[k] { - return fmt.Errorf("builder key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) - } - } - - if got, exp := m.builder.Len(), len(m.std); got != exp { - return fmt.Errorf("SortedMapBuilder.Len()=%d, expected %d", got, exp) - } - - sort.Ints(m.keys) - if err := m.validateForwardIterator(m.im.Iterator()); err != nil { - return fmt.Errorf("basic: %s", err) - } else if err := m.validateBackwardIterator(m.im.Iterator()); err != nil { - return fmt.Errorf("basic: %s", err) - } - - if err := m.validateForwardIterator(m.builder.Iterator()); err != nil { - return fmt.Errorf("basic: %s", err) - } else if err := m.validateBackwardIterator(m.builder.Iterator()); err != nil { - return fmt.Errorf("basic: %s", err) - } - return nil -} - -func (m *TSortedMap) validateForwardIterator(itr *SortedMapIterator[int, int]) error { - for i, k0 := range m.keys { - v0 := m.std[k0] - if k1, v1, ok := itr.Next(); !ok || k0 != k1 || v0 != v1 { - return fmt.Errorf("%d. SortedMapIterator.Next()=<%v,%v>, expected <%v,%v>", i, k1, v1, k0, v0) - } - - done := i == len(m.keys)-1 - if v := itr.Done(); v != done { - return fmt.Errorf("%d. SortedMapIterator.Done()=%v, expected %v", i, v, done) - } - } - if k, v, ok := itr.Next(); ok { - return fmt.Errorf("SortedMapIterator.Next()=<%v,%v>, expected nil after done", k, v) - } - return nil -} - -func (m *TSortedMap) validateBackwardIterator(itr *SortedMapIterator[int, int]) error { - itr.Last() - for i := len(m.keys) - 1; i >= 0; i-- { - k0 := m.keys[i] - v0 := m.std[k0] - if k1, v1, ok := itr.Prev(); !ok || k0 != k1 || v0 != v1 { - return fmt.Errorf("%d. SortedMapIterator.Prev()=<%v,%v>, expected <%v,%v>", i, k1, v1, k0, v0) - } - - done := i == 0 - if v := itr.Done(); v != done { - return fmt.Errorf("%d. SortedMapIterator.Done()=%v, expected %v", i, v, done) - } - } - if k, v, ok := itr.Prev(); ok { - return fmt.Errorf("SortedMapIterator.Prev()=<%v,%v>, expected nil after done", k, v) - } - return nil -} - -func BenchmarkSortedMap_Set(b *testing.B) { - b.ReportAllocs() - m := NewSortedMap[int, int](nil) - for i := 0; i < b.N; i++ { - m = m.Set(i, i) - } -} - -func BenchmarkSortedMap_Delete(b *testing.B) { - const n = 10000 - - m := NewSortedMap[int, int](nil) - for i := 0; i < n; i++ { - m = m.Set(i, i) - } - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - m.Delete(i % n) // Do not update map, always operate on original - } -} - -func BenchmarkSortedMap_Iterator(b *testing.B) { - const n = 10000 - m := NewSortedMap[int, int](nil) - for i := 0; i < 10000; i++ { - m = m.Set(i, i) - } - b.ReportAllocs() - b.ResetTimer() - - b.Run("Forward", func(b *testing.B) { - itr := m.Iterator() - for i := 0; i < b.N; i++ { - if i%n == 0 { - itr.First() - } - itr.Next() - } - }) - - b.Run("Reverse", func(b *testing.B) { - itr := m.Iterator() - for i := 0; i < b.N; i++ { - if i%n == 0 { - itr.Last() - } - itr.Prev() - } - }) -} - -func BenchmarkSortedMapBuilder_Set(b *testing.B) { - b.ReportAllocs() - builder := NewSortedMapBuilder[int, int](nil) - for i := 0; i < b.N; i++ { - builder.Set(i, i) - } -} - -func BenchmarkSortedMapBuilder_Delete(b *testing.B) { - const n = 1000000 - - builder := NewSortedMapBuilder[int, int](nil) - for i := 0; i < n; i++ { - builder.Set(i, i) - } - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - builder.Delete(i % n) - } -} - -func ExampleSortedMap_Set() { - m := NewSortedMap[string, any](nil) - m = m.Set("foo", "bar") - m = m.Set("baz", 100) - - v, ok := m.Get("foo") - fmt.Println("foo", v, ok) - - v, ok = m.Get("baz") - fmt.Println("baz", v, ok) - - v, ok = m.Get("bat") // does not exist - fmt.Println("bat", v, ok) - // Output: - // foo bar true - // baz 100 true - // bat false -} - -func ExampleSortedMap_Delete() { - m := NewSortedMap[string, any](nil) - m = m.Set("foo", "bar") - m = m.Set("baz", 100) - m = m.Delete("baz") - - v, ok := m.Get("foo") - fmt.Println("foo", v, ok) - - v, ok = m.Get("baz") - fmt.Println("baz", v, ok) - // Output: - // foo bar true - // baz false -} - -func ExampleSortedMap_Iterator() { - m := NewSortedMap[string, any](nil) - m = m.Set("strawberry", 900) - m = m.Set("kiwi", 300) - m = m.Set("apple", 100) - m = m.Set("pear", 700) - m = m.Set("pineapple", 800) - m = m.Set("peach", 600) - m = m.Set("orange", 500) - m = m.Set("grape", 200) - m = m.Set("mango", 400) - - itr := m.Iterator() - for !itr.Done() { - k, v, _ := itr.Next() - fmt.Println(k, v) - } - // Output: - // apple 100 - // grape 200 - // kiwi 300 - // mango 400 - // orange 500 - // peach 600 - // pear 700 - // pineapple 800 - // strawberry 900 -} - -func ExampleSortedMapBuilder_Set() { - b := NewSortedMapBuilder[string, any](nil) - b.Set("foo", "bar") - b.Set("baz", 100) - - m := b.Map() - v, ok := m.Get("foo") - fmt.Println("foo", v, ok) - - v, ok = m.Get("baz") - fmt.Println("baz", v, ok) - - v, ok = m.Get("bat") // does not exist - fmt.Println("bat", v, ok) - // Output: - // foo bar true - // baz 100 true - // bat false -} - -func ExampleSortedMapBuilder_Delete() { - b := NewSortedMapBuilder[string, any](nil) - b.Set("foo", "bar") - b.Set("baz", 100) - b.Delete("baz") - - m := b.Map() - v, ok := m.Get("foo") - fmt.Println("foo", v, ok) - - v, ok = m.Get("baz") - fmt.Println("baz", v, ok) - // Output: - // foo bar true - // baz false -} - -// RunRandom executes fn multiple times with a different rand. -func RunRandom(t *testing.T, name string, fn func(t *testing.T, rand *rand.Rand)) { - if testing.Short() { - t.Skip("short mode") - } - t.Run(name, func(t *testing.T) { - for i := 0; i < *randomN; i++ { - i := i - t.Run(fmt.Sprintf("%08d", i), func(t *testing.T) { - t.Parallel() - fn(t, rand.New(rand.NewSource(int64(i)))) - }) - } - }) -} - -func uniqueIntSlice(a []int) []int { - m := make(map[int]struct{}) - other := make([]int, 0, len(a)) - for _, v := range a { - if _, ok := m[v]; ok { - continue - } - m[v] = struct{}{} - other = append(other, v) - } - return other -} - -// mockHasher represents a mock implementation of immutable.Hasher. -type mockHasher[K constraints.Ordered] struct { - hash func(value K) uint32 - equal func(a, b K) bool -} - -// Hash executes the mocked HashFn function. -func (h *mockHasher[K]) Hash(value K) uint32 { - return h.hash(value) -} - -// Equal executes the mocked EqualFn function. -func (h *mockHasher[K]) Equal(a, b K) bool { - return h.equal(a, b) -} - -// mockComparer represents a mock implementation of immutable.Comparer. -type mockComparer[K constraints.Ordered] struct { - compare func(a, b K) int -} - -// Compare executes the mocked CompreFn function. -func (h *mockComparer[K]) Compare(a, b K) int { - return h.compare(a, b) -} diff --git a/mkdeps.sh b/mkdeps.sh new file mode 100755 index 0000000..e8da8a4 --- /dev/null +++ b/mkdeps.sh @@ -0,0 +1,29 @@ +#!/bin/sh +set -eu + +export LANG=POSIX.UTF-8 + + +libs() { + find src tests -name '*.go' | grep -v '/main\.go$' | + grep -v '/version\.go$' +} + +mains() { + find src tests -name '*.go' | grep '/main\.go$' +} + +libs | varlist 'libs.go' +mains | varlist 'mains.go' + +find tests/functional/*/*.go -not -name main.go | varlist 'functional-tests/lib.go' +find tests/functional/*/main.go | varlist 'functional-tests/main.go' +find tests/fuzz/*/*.go -not -name main.go | varlist 'fuzz-targets/lib.go' +find tests/fuzz/*/main.go | varlist 'fuzz-targets/main.go' +find tests/benchmarks/*/*.go -not -name main.go | varlist 'benchmarks/lib.go' +find tests/benchmarks/*/main.go | varlist 'benchmarks/main.go' + +{ libs; mains; } | sort | sed 's/^\(.*\)\.go$/\1.a:\t\1.go/' +mains | sort | sed 's/^\(.*\)\.go$/\1.bin:\t\1.a/' +mains | sort | sed 's/^\(.*\)\.go$/\1.bin-check:\t\1.bin/' +mains | sort | sed 's|^\(.*\)/main\.go$|\1/main.a:\t\1/$(NAME).a|' diff --git a/sets.go b/sets.go deleted file mode 100644 index b41bd37..0000000 --- a/sets.go +++ /dev/null @@ -1,243 +0,0 @@ -package immutable - -// Set represents a collection of unique values. The set uses a Hasher -// to generate hashes and check for equality of key values. -// -// Internally, the Set stores values as keys of a Map[T,struct{}] -type Set[T any] struct { - m *Map[T, struct{}] -} - -// NewSet returns a new instance of Set. -// -// If hasher is nil, a default hasher implementation will automatically be chosen based on the first key added. -// Default hasher implementations only exist for int, string, and byte slice types. -// NewSet can also take some initial values as varargs. -func NewSet[T any](hasher Hasher[T], values ...T) Set[T] { - m := NewMap[T, struct{}](hasher) - for _, value := range values { - m = m.set(value, struct{}{}, true) - } - return Set[T]{m} -} - -// Add returns a set containing the new value. -// -// This function will return a new set even if the set already contains the value. -func (s Set[T]) Add(value T) Set[T] { - return Set[T]{s.m.Set(value, struct{}{})} -} - -// Delete returns a set with the given key removed. -func (s Set[T]) Delete(value T) Set[T] { - return Set[T]{s.m.Delete(value)} -} - -// Has returns true when the set contains the given value -func (s Set[T]) Has(val T) bool { - _, ok := s.m.Get(val) - return ok -} - -// Len returns the number of elements in the underlying map. -func (s Set[K]) Len() int { - return s.m.Len() -} - -// Items returns a slice of the items inside the set -func (s Set[T]) Items() []T { - r := make([]T, 0, s.Len()) - itr := s.Iterator() - for !itr.Done() { - v, _ := itr.Next() - r = append(r, v) - } - return r -} - -// Iterator returns a new iterator for this set positioned at the first value. -func (s Set[T]) Iterator() *SetIterator[T] { - itr := &SetIterator[T]{mi: s.m.Iterator()} - itr.mi.First() - return itr -} - -// SetIterator represents an iterator over a set. -// Iteration can occur in natural or reverse order based on use of Next() or Prev(). -type SetIterator[T any] struct { - mi *MapIterator[T, struct{}] -} - -// Done returns true if no more values remain in the iterator. -func (itr *SetIterator[T]) Done() bool { - return itr.mi.Done() -} - -// First moves the iterator to the first value. -func (itr *SetIterator[T]) First() { - itr.mi.First() -} - -// Next moves the iterator to the next value. -func (itr *SetIterator[T]) Next() (val T, ok bool) { - val, _, ok = itr.mi.Next() - return -} - -type SetBuilder[T any] struct { - s Set[T] -} - -func NewSetBuilder[T any](hasher Hasher[T]) *SetBuilder[T] { - return &SetBuilder[T]{s: NewSet(hasher)} -} - -func (s SetBuilder[T]) Set(val T) { - s.s.m = s.s.m.set(val, struct{}{}, true) -} - -func (s SetBuilder[T]) Delete(val T) { - s.s.m = s.s.m.delete(val, true) -} - -func (s SetBuilder[T]) Has(val T) bool { - return s.s.Has(val) -} - -func (s SetBuilder[T]) Len() int { - return s.s.Len() -} - -type SortedSet[T any] struct { - m *SortedMap[T, struct{}] -} - -// NewSortedSet returns a new instance of SortedSet. -// -// If comparer is nil then -// a default comparer is set after the first key is inserted. Default comparers -// exist for int, string, and byte slice keys. -// NewSortedSet can also take some initial values as varargs. -func NewSortedSet[T any](comparer Comparer[T], values ...T) SortedSet[T] { - m := NewSortedMap[T, struct{}](comparer) - for _, value := range values { - m = m.set(value, struct{}{}, true) - } - return SortedSet[T]{m} -} - -// Add returns a set containing the new value. -// -// This function will return a new set even if the set already contains the value. -func (s SortedSet[T]) Add(value T) SortedSet[T] { - return SortedSet[T]{s.m.Set(value, struct{}{})} -} - -// Delete returns a set with the given key removed. -func (s SortedSet[T]) Delete(value T) SortedSet[T] { - return SortedSet[T]{s.m.Delete(value)} -} - -// Has returns true when the set contains the given value -func (s SortedSet[T]) Has(val T) bool { - _, ok := s.m.Get(val) - return ok -} - -// Len returns the number of elements in the underlying map. -func (s SortedSet[K]) Len() int { - return s.m.Len() -} - -// Items returns a slice of the items inside the set -func (s SortedSet[T]) Items() []T { - r := make([]T, 0, s.Len()) - itr := s.Iterator() - for !itr.Done() { - v, _ := itr.Next() - r = append(r, v) - } - return r -} - -// Iterator returns a new iterator for this set positioned at the first value. -func (s SortedSet[T]) Iterator() *SortedSetIterator[T] { - itr := &SortedSetIterator[T]{mi: s.m.Iterator()} - itr.mi.First() - return itr -} - -// SortedSetIterator represents an iterator over a sorted set. -// Iteration can occur in natural or reverse order based on use of Next() or Prev(). -type SortedSetIterator[T any] struct { - mi *SortedMapIterator[T, struct{}] -} - -// Done returns true if no more values remain in the iterator. -func (itr *SortedSetIterator[T]) Done() bool { - return itr.mi.Done() -} - -// First moves the iterator to the first value. -func (itr *SortedSetIterator[T]) First() { - itr.mi.First() -} - -// Last moves the iterator to the last value. -func (itr *SortedSetIterator[T]) Last() { - itr.mi.Last() -} - -// Next moves the iterator to the next value. -func (itr *SortedSetIterator[T]) Next() (val T, ok bool) { - val, _, ok = itr.mi.Next() - return -} - -// Prev moves the iterator to the previous value. -func (itr *SortedSetIterator[T]) Prev() (val T, ok bool) { - val, _, ok = itr.mi.Prev() - return -} - -// Seek moves the iterator to the given value. -// -// If the value does not exist then the next value is used. If no more keys exist -// then the iterator is marked as done. -func (itr *SortedSetIterator[T]) Seek(val T) { - itr.mi.Seek(val) -} - -type SortedSetBuilder[T any] struct { - s *SortedSet[T] -} - -func NewSortedSetBuilder[T any](comparer Comparer[T]) *SortedSetBuilder[T] { - s := NewSortedSet(comparer) - return &SortedSetBuilder[T]{s: &s} -} - -func (s SortedSetBuilder[T]) Set(val T) { - s.s.m = s.s.m.set(val, struct{}{}, true) -} - -func (s SortedSetBuilder[T]) Delete(val T) { - s.s.m = s.s.m.delete(val, true) -} - -func (s SortedSetBuilder[T]) Has(val T) bool { - return s.s.Has(val) -} - -func (s SortedSetBuilder[T]) Len() int { - return s.s.Len() -} - -// SortedSet returns the current copy of the set. -// The builder should not be used again after the list after this call. -func (s SortedSetBuilder[T]) SortedSet() SortedSet[T] { - assert(s.s != nil, "immutable.SortedSetBuilder.SortedSet(): duplicate call to fetch sorted set") - set := s.s - s.s = nil - return *set -} diff --git a/sets_test.go b/sets_test.go deleted file mode 100644 index 6612cba..0000000 --- a/sets_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package immutable - -import ( - "testing" -) - -func TestSetsPut(t *testing.T) { - s := NewSet[string](nil) - s2 := s.Add("1").Add("1") - s2.Add("2") - if s.Len() != 0 { - t.Fatalf("Unexpected mutation of set") - } - if s.Has("1") { - t.Fatalf("Unexpected set element") - } - if s2.Len() != 1 { - t.Fatalf("Unexpected non-mutation of set") - } - if !s2.Has("1") { - t.Fatalf("Set element missing") - } - itr := s2.Iterator() - counter := 0 - for !itr.Done() { - i, v := itr.Next() - t.Log(i, v) - counter++ - } - if counter != 1 { - t.Fatalf("iterator wrong length") - } -} - -func TestSetsDelete(t *testing.T) { - s := NewSet[string](nil) - s2 := s.Add("1") - s3 := s.Delete("1") - if s2.Len() != 1 { - t.Fatalf("Unexpected non-mutation of set") - } - if !s2.Has("1") { - t.Fatalf("Set element missing") - } - if s3.Len() != 0 { - t.Fatalf("Unexpected set length after delete") - } - if s3.Has("1") { - t.Fatalf("Unexpected set element after delete") - } -} - -func TestSortedSetsPut(t *testing.T) { - s := NewSortedSet[string](nil) - s2 := s.Add("1").Add("1").Add("0") - if s.Len() != 0 { - t.Fatalf("Unexpected mutation of set") - } - if s.Has("1") { - t.Fatalf("Unexpected set element") - } - if s2.Len() != 2 { - t.Fatalf("Unexpected non-mutation of set") - } - if !s2.Has("1") { - t.Fatalf("Set element missing") - } - - itr := s2.Iterator() - counter := 0 - for !itr.Done() { - i, v := itr.Next() - t.Log(i, v) - if counter == 0 && i != "0" { - t.Fatalf("sort did not work for first el") - } - if counter == 1 && i != "1" { - t.Fatalf("sort did not work for second el") - } - counter++ - } - if counter != 2 { - t.Fatalf("iterator wrong length") - } -} - -func TestSortedSetsDelete(t *testing.T) { - s := NewSortedSet[string](nil) - s2 := s.Add("1") - s3 := s.Delete("1") - if s2.Len() != 1 { - t.Fatalf("Unexpected non-mutation of set") - } - if !s2.Has("1") { - t.Fatalf("Set element missing") - } - if s3.Len() != 0 { - t.Fatalf("Unexpected set length after delete") - } - if s3.Has("1") { - t.Fatalf("Unexpected set element after delete") - } -} - -func TestSortedSetBuilder(t *testing.T) { - b := NewSortedSetBuilder[string](nil) - b.Set("test3") - b.Set("test1") - b.Set("test2") - - s := b.SortedSet() - items := s.Items() - - if len(items) != 3 { - t.Fatalf("Set has wrong number of items") - } - if items[0] != "test1" { - t.Fatalf("First item incorrectly sorted") - } - if items[1] != "test2" { - t.Fatalf("Second item incorrectly sorted") - } - if items[2] != "test3" { - t.Fatalf("Third item incorrectly sorted") - } -} diff --git a/src/pds.go b/src/pds.go new file mode 100644 index 0000000..0ce9ade --- /dev/null +++ b/src/pds.go @@ -0,0 +1,2700 @@ +// Package immutable provides immutable collection types. +// +// # Introduction +// +// Immutable collections provide an efficient, safe way to share collections +// of data while minimizing locks. The collections in this package provide +// List, Map, and SortedMap implementations. These act similarly to slices +// and maps, respectively, except that altering a collection returns a new +// copy of the collection with that change. +// +// Because collections are unable to change, they are safe for multiple +// goroutines to read from at the same time without a mutex. However, these +// types of collections come with increased CPU & memory usage as compared +// with Go's built-in collection types so please evaluate for your specific +// use. +// +// # Collection Types +// +// The List type provides an API similar to Go slices. They allow appending, +// prepending, and updating of elements. Elements can also be fetched by index +// or iterated over using a ListIterator. +// +// The Map & SortedMap types provide an API similar to Go maps. They allow +// values to be assigned to unique keys and allow for the deletion of keys. +// Values can be fetched by key and key/value pairs can be iterated over using +// the appropriate iterator type. Both map types provide the same API. The +// SortedMap, however, provides iteration over sorted keys while the Map +// provides iteration over unsorted keys. Maps improved performance and memory +// usage as compared to SortedMaps. +// +// # Hashing and Sorting +// +// Map types require the use of a Hasher implementation to calculate hashes for +// their keys and check for key equality. SortedMaps require the use of a +// Comparer implementation to sort keys in the map. +// +// These collection types automatically provide built-in hasher and comparers +// for int, string, and byte slice keys. If you are using one of these key types +// then simply pass a nil into the constructor. Otherwise you will need to +// implement a custom Hasher or Comparer type. Please see the provided +// implementations for reference. +package pds + +import ( + "cmp" + "fmt" + "math/bits" + "reflect" + "sort" + "strings" +) + +// List is a dense, ordered, indexed collections. They are analogous to slices +// in Go. They can be updated by appending to the end of the list, prepending +// values to the beginning of the list, or updating existing indexes in the +// list. +type List[T any] struct { + root listNode[T] // root node + origin int // offset to zero index element + size int // total number of elements in use +} + +// NewList returns a new empty instance of List. +func NewList[T any](values ...T) *List[T] { + l := &List[T]{ + root: &listLeafNode[T]{}, + } + for _, value := range values { + l.append(value, true) + } + return l +} + +// clone returns a copy of the list. +func (l *List[T]) clone() *List[T] { + other := *l + return &other +} + +// Len returns the number of elements in the list. +func (l *List[T]) Len() int { + return l.size +} + +// cap returns the total number of possible elements for the current depth. +func (l *List[T]) cap() int { + return 1 << (l.root.depth() * listNodeBits) +} + +// Get returns the value at the given index. Similar to slices, this method will +// panic if index is below zero or is greater than or equal to the list size. +func (l *List[T]) Get(index int) T { + if index < 0 || index >= l.size { + panic(fmt.Sprintf("immutable.List.Get: index %d out of bounds", index)) + } + return l.root.get(l.origin + index) +} + +// Set returns a new list with value set at index. Similar to slices, this +// method will panic if index is below zero or if the index is greater than +// or equal to the list size. +func (l *List[T]) Set(index int, value T) *List[T] { + return l.set(index, value, false) +} + +func (l *List[T]) set(index int, value T, mutable bool) *List[T] { + if index < 0 || index >= l.size { + panic(fmt.Sprintf("immutable.List.Set: index %d out of bounds", index)) + } + other := l + if !mutable { + other = l.clone() + } + other.root = other.root.set(l.origin+index, value, mutable) + return other +} + +// Append returns a new list with value added to the end of the list. +func (l *List[T]) Append(value T) *List[T] { + return l.append(value, false) +} + +func (l *List[T]) append(value T, mutable bool) *List[T] { + other := l + if !mutable { + other = l.clone() + } + + // Expand list to the right if no slots remain. + if other.size+other.origin >= l.cap() { + newRoot := &listBranchNode[T]{d: other.root.depth() + 1} + newRoot.children[0] = other.root + other.root = newRoot + } + + // Increase size and set the last element to the new value. + other.size++ + other.root = other.root.set(other.origin+other.size-1, value, mutable) + return other +} + +// Prepend returns a new list with value(s) added to the beginning of the list. +func (l *List[T]) Prepend(value T) *List[T] { + return l.prepend(value, false) +} + +func (l *List[T]) prepend(value T, mutable bool) *List[T] { + other := l + if !mutable { + other = l.clone() + } + + // Expand list to the left if no slots remain. + if other.origin == 0 { + newRoot := &listBranchNode[T]{d: other.root.depth() + 1} + newRoot.children[listNodeSize-1] = other.root + other.root = newRoot + other.origin += (listNodeSize - 1) << (other.root.depth() * listNodeBits) + } + + // Increase size and move origin back. Update first element to value. + other.size++ + other.origin-- + other.root = other.root.set(other.origin, value, mutable) + return other +} + +// Slice returns a new list of elements between start index and end index. +// Similar to slices, this method will panic if start or end are below zero or +// greater than the list size. A panic will also occur if start is greater than +// end. +// +// Unlike Go slices, references to inaccessible elements will be automatically +// removed so they can be garbage collected. +func (l *List[T]) Slice(start, end int) *List[T] { + return l.slice(start, end, false) +} + +func (l *List[T]) slice(start, end int, mutable bool) *List[T] { + // Panics similar to Go slices. + if start < 0 || start > l.size { + panic(fmt.Sprintf("immutable.List.Slice: start index %d out of bounds", start)) + } else if end < 0 || end > l.size { + panic(fmt.Sprintf("immutable.List.Slice: end index %d out of bounds", end)) + } else if start > end { + panic(fmt.Sprintf("immutable.List.Slice: invalid slice index: [%d:%d]", start, end)) + } + + // Return the same list if the start and end are the entire range. + if start == 0 && end == l.size { + return l + } + + // Create copy, if immutable. + other := l + if !mutable { + other = l.clone() + } + + // Update origin/size. + other.origin = l.origin + start + other.size = end - start + + // Contract tree while the start & end are in the same child node. + for other.root.depth() > 1 { + i := (other.origin >> (other.root.depth() * listNodeBits)) & listNodeMask + j := ((other.origin + other.size - 1) >> (other.root.depth() * listNodeBits)) & listNodeMask + if i != j { + break // branch contains at least two nodes, exit + } + + // Replace the current root with the single child & update origin offset. + other.origin -= i << (other.root.depth() * listNodeBits) + other.root = other.root.(*listBranchNode[T]).children[i] + } + + // Ensure all references are removed before start & after end. + other.root = other.root.deleteBefore(other.origin, mutable) + other.root = other.root.deleteAfter(other.origin+other.size-1, mutable) + + return other +} + +// Iterator returns a new iterator for this list positioned at the first index. +func (l *List[T]) Iterator() *ListIterator[T] { + itr := &ListIterator[T]{list: l} + itr.First() + return itr +} + +// ListBuilder represents an efficient builder for creating new Lists. +type ListBuilder[T any] struct { + list *List[T] // current state +} + +// NewListBuilder returns a new instance of ListBuilder. +func NewListBuilder[T any]() *ListBuilder[T] { + return &ListBuilder[T]{list: NewList[T]()} +} + +// List returns the current copy of the list. +// The builder should not be used again after the list after this call. +func (b *ListBuilder[T]) List() *List[T] { + assert(b.list != nil, "immutable.ListBuilder.List(): duplicate call to fetch list") + list := b.list + b.list = nil + return list +} + +// Len returns the number of elements in the underlying list. +func (b *ListBuilder[T]) Len() int { + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + return b.list.Len() +} + +// Get returns the value at the given index. Similar to slices, this method will +// panic if index is below zero or is greater than or equal to the list size. +func (b *ListBuilder[T]) Get(index int) T { + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + return b.list.Get(index) +} + +// Set updates the value at the given index. Similar to slices, this method will +// panic if index is below zero or if the index is greater than or equal to the +// list size. +func (b *ListBuilder[T]) Set(index int, value T) { + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + b.list = b.list.set(index, value, true) +} + +// Append adds value to the end of the list. +func (b *ListBuilder[T]) Append(value T) { + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + b.list = b.list.append(value, true) +} + +// Prepend adds value to the beginning of the list. +func (b *ListBuilder[T]) Prepend(value T) { + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + b.list = b.list.prepend(value, true) +} + +// Slice updates the list with a sublist of elements between start and end index. +// See List.Slice() for more details. +func (b *ListBuilder[T]) Slice(start, end int) { + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + b.list = b.list.slice(start, end, true) +} + +// Iterator returns a new iterator for the underlying list. +func (b *ListBuilder[T]) Iterator() *ListIterator[T] { + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + return b.list.Iterator() +} + +// Constants for bit shifts used for levels in the List trie. +const ( + listNodeBits = 5 + listNodeSize = 1 << listNodeBits + listNodeMask = listNodeSize - 1 +) + +// listNode represents either a branch or leaf node in a List. +type listNode[T any] interface { + depth() uint + get(index int) T + set(index int, v T, mutable bool) listNode[T] + + containsBefore(index int) bool + containsAfter(index int) bool + + deleteBefore(index int, mutable bool) listNode[T] + deleteAfter(index int, mutable bool) listNode[T] +} + +// newListNode returns a leaf node for depth zero, otherwise returns a branch node. +func newListNode[T any](depth uint) listNode[T] { + if depth == 0 { + return &listLeafNode[T]{} + } + return &listBranchNode[T]{d: depth} +} + +// listBranchNode represents a branch of a List tree at a given depth. +type listBranchNode[T any] struct { + d uint // depth + children [listNodeSize]listNode[T] +} + +// depth returns the depth of this branch node from the leaf. +func (n *listBranchNode[T]) depth() uint { return n.d } + +// get returns the child node at the segment of the index for this depth. +func (n *listBranchNode[T]) get(index int) T { + idx := (index >> (n.d * listNodeBits)) & listNodeMask + return n.children[idx].get(index) +} + +// set recursively updates the value at index for each lower depth from the node. +func (n *listBranchNode[T]) set(index int, v T, mutable bool) listNode[T] { + idx := (index >> (n.d * listNodeBits)) & listNodeMask + + // Find child for the given value in the branch. Create new if it doesn't exist. + child := n.children[idx] + if child == nil { + child = newListNode[T](n.depth() - 1) + } + + // Return a copy of this branch with the new child. + var other *listBranchNode[T] + if mutable { + other = n + } else { + tmp := *n + other = &tmp + } + other.children[idx] = child.set(index, v, mutable) + return other +} + +// containsBefore returns true if non-nil values exists between [0,index). +func (n *listBranchNode[T]) containsBefore(index int) bool { + idx := (index >> (n.d * listNodeBits)) & listNodeMask + + // Quickly check if any direct children exist before this segment of the index. + for i := 0; i < idx; i++ { + if n.children[i] != nil { + return true + } + } + + // Recursively check for children directly at the given index at this segment. + if n.children[idx] != nil && n.children[idx].containsBefore(index) { + return true + } + return false +} + +// containsAfter returns true if non-nil values exists between (index,listNodeSize). +func (n *listBranchNode[T]) containsAfter(index int) bool { + idx := (index >> (n.d * listNodeBits)) & listNodeMask + + // Quickly check if any direct children exist after this segment of the index. + for i := idx + 1; i < len(n.children); i++ { + if n.children[i] != nil { + return true + } + } + + // Recursively check for children directly at the given index at this segment. + if n.children[idx] != nil && n.children[idx].containsAfter(index) { + return true + } + return false +} + +// deleteBefore returns a new node with all elements before index removed. +func (n *listBranchNode[T]) deleteBefore(index int, mutable bool) listNode[T] { + // Ignore if no nodes exist before the given index. + if !n.containsBefore(index) { + return n + } + + // Return a copy with any nodes prior to the index removed. + idx := (index >> (n.d * listNodeBits)) & listNodeMask + + var other *listBranchNode[T] + if mutable { + other = n + for i := 0; i < idx; i++ { + n.children[i] = nil + } + } else { + other = &listBranchNode[T]{d: n.d} + copy(other.children[idx:][:], n.children[idx:][:]) + } + + if other.children[idx] != nil { + other.children[idx] = other.children[idx].deleteBefore(index, mutable) + } + return other +} + +// deleteBefore returns a new node with all elements before index removed. +func (n *listBranchNode[T]) deleteAfter(index int, mutable bool) listNode[T] { + // Ignore if no nodes exist after the given index. + if !n.containsAfter(index) { + return n + } + + // Return a copy with any nodes after the index removed. + idx := (index >> (n.d * listNodeBits)) & listNodeMask + + var other *listBranchNode[T] + if mutable { + other = n + for i := idx + 1; i < len(n.children); i++ { + n.children[i] = nil + } + } else { + other = &listBranchNode[T]{d: n.d} + copy(other.children[:idx+1], n.children[:idx+1]) + } + + if other.children[idx] != nil { + other.children[idx] = other.children[idx].deleteAfter(index, mutable) + } + return other +} + +// listLeafNode represents a leaf node in a List. +type listLeafNode[T any] struct { + children [listNodeSize]T + // bitset with ones at occupied positions, position 0 is the LSB + occupied uint32 +} + +// depth always returns 0 for leaf nodes. +func (n *listLeafNode[T]) depth() uint { return 0 } + +// get returns the value at the given index. +func (n *listLeafNode[T]) get(index int) T { + return n.children[index&listNodeMask] +} + +// set returns a copy of the node with the value at the index updated to v. +func (n *listLeafNode[T]) set(index int, v T, mutable bool) listNode[T] { + idx := index & listNodeMask + var other *listLeafNode[T] + if mutable { + other = n + } else { + tmp := *n + other = &tmp + } + other.children[idx] = v + other.occupied |= 1 << idx + return other +} + +// containsBefore returns true if non-nil values exists between [0,index). +func (n *listLeafNode[T]) containsBefore(index int) bool { + idx := index & listNodeMask + return bits.TrailingZeros32(n.occupied) < idx +} + +// containsAfter returns true if non-nil values exists between (index,listNodeSize). +func (n *listLeafNode[T]) containsAfter(index int) bool { + idx := index & listNodeMask + lastSetPos := 31 - bits.LeadingZeros32(n.occupied) + return lastSetPos > idx +} + +// deleteBefore returns a new node with all elements before index removed. +func (n *listLeafNode[T]) deleteBefore(index int, mutable bool) listNode[T] { + if !n.containsBefore(index) { + return n + } + + idx := index & listNodeMask + var other *listLeafNode[T] + if mutable { + other = n + var empty T + for i := 0; i < idx; i++ { + other.children[i] = empty + } + } else { + other = &listLeafNode[T]{occupied: n.occupied} + copy(other.children[idx:][:], n.children[idx:][:]) + } + // Set the first idx bits to 0. + other.occupied &= ^((1 << idx) - 1) + return other +} + +// deleteAfter returns a new node with all elements after index removed. +func (n *listLeafNode[T]) deleteAfter(index int, mutable bool) listNode[T] { + if !n.containsAfter(index) { + return n + } + + idx := index & listNodeMask + var other *listLeafNode[T] + if mutable { + other = n + var empty T + for i := idx + 1; i < len(n.children); i++ { + other.children[i] = empty + } + } else { + other = &listLeafNode[T]{occupied: n.occupied} + copy(other.children[:idx+1][:], n.children[:idx+1][:]) + } + // Set bits after idx to 0. idx < 31 because n.containsAfter(index) == true. + other.occupied &= (1 << (idx + 1)) - 1 + return other +} + +// ListIterator represents an ordered iterator over a list. +type ListIterator[T any] struct { + list *List[T] // source list + index int // current index position + + stack [32]listIteratorElem[T] // search stack + depth int // stack depth +} + +// Done returns true if no more elements remain in the iterator. +func (itr *ListIterator[T]) Done() bool { + return itr.index < 0 || itr.index >= itr.list.Len() +} + +// First positions the iterator on the first index. +// If source list is empty then no change is made. +func (itr *ListIterator[T]) First() { + if itr.list.Len() != 0 { + itr.Seek(0) + } +} + +// Last positions the iterator on the last index. +// If source list is empty then no change is made. +func (itr *ListIterator[T]) Last() { + if n := itr.list.Len(); n != 0 { + itr.Seek(n - 1) + } +} + +// Seek moves the iterator position to the given index in the list. +// Similar to Go slices, this method will panic if index is below zero or if +// the index is greater than or equal to the list size. +func (itr *ListIterator[T]) Seek(index int) { + // Panic similar to Go slices. + if index < 0 || index >= itr.list.Len() { + panic(fmt.Sprintf("immutable.ListIterator.Seek: index %d out of bounds", index)) + } + itr.index = index + + // Reset to the bottom of the stack at seek to the correct position. + itr.stack[0] = listIteratorElem[T]{node: itr.list.root} + itr.depth = 0 + itr.seek(index) +} + +// Next returns the current index and its value & moves the iterator forward. +// Returns an index of -1 if the there are no more elements to return. +func (itr *ListIterator[T]) Next() (index int, value T) { + // Exit immediately if there are no elements remaining. + var empty T + if itr.Done() { + return -1, empty + } + + // Retrieve current index & value. + elem := &itr.stack[itr.depth] + index, value = itr.index, elem.node.(*listLeafNode[T]).children[elem.index] + + // Increase index. If index is at the end then return immediately. + itr.index++ + if itr.Done() { + return index, value + } + + // Move up stack until we find a node that has remaining position ahead. + for ; itr.depth > 0 && itr.stack[itr.depth].index >= listNodeSize-1; itr.depth-- { + } + + // Seek to correct position from current depth. + itr.seek(itr.index) + + return index, value +} + +// Prev returns the current index and value and moves the iterator backward. +// Returns an index of -1 if the there are no more elements to return. +func (itr *ListIterator[T]) Prev() (index int, value T) { + // Exit immediately if there are no elements remaining. + var empty T + if itr.Done() { + return -1, empty + } + + // Retrieve current index & value. + elem := &itr.stack[itr.depth] + index, value = itr.index, elem.node.(*listLeafNode[T]).children[elem.index] + + // Decrease index. If index is past the beginning then return immediately. + itr.index-- + if itr.Done() { + return index, value + } + + // Move up stack until we find a node that has remaining position behind. + for ; itr.depth > 0 && itr.stack[itr.depth].index == 0; itr.depth-- { + } + + // Seek to correct position from current depth. + itr.seek(itr.index) + + return index, value +} + +// seek positions the stack to the given index from the current depth. +// Elements and indexes below the current depth are assumed to be correct. +func (itr *ListIterator[T]) seek(index int) { + // Iterate over each level until we reach a leaf node. + for { + elem := &itr.stack[itr.depth] + elem.index = ((itr.list.origin + index) >> (elem.node.depth() * listNodeBits)) & listNodeMask + + switch node := elem.node.(type) { + case *listBranchNode[T]: + child := node.children[elem.index] + itr.stack[itr.depth+1] = listIteratorElem[T]{node: child} + itr.depth++ + case *listLeafNode[T]: + return + } + } +} + +// listIteratorElem represents the node and it's child index within the stack. +type listIteratorElem[T any] struct { + node listNode[T] + index int +} + +// Size thresholds for each type of branch node. +const ( + maxArrayMapSize = 8 + maxBitmapIndexedSize = 16 +) + +// Segment bit shifts within the map tree. +const ( + mapNodeBits = 5 + mapNodeSize = 1 << mapNodeBits + mapNodeMask = mapNodeSize - 1 +) + +// Map represents an immutable hash map implementation. The map uses a Hasher +// to generate hashes and check for equality of key values. +// +// It is implemented as an Hash Array Mapped Trie. +type Map[K, V any] struct { + size int // total number of key/value pairs + root mapNode[K, V] // root node of trie + hasher Hasher[K] // hasher implementation +} + +// NewMap returns a new instance of Map. If hasher is nil, a default hasher +// implementation will automatically be chosen based on the first key added. +// Default hasher implementations only exist for int, string, and byte slice types. +func NewMap[K, V any](hasher Hasher[K]) *Map[K, V] { + return &Map[K, V]{ + hasher: hasher, + } +} + +// NewMapOf returns a new instance of Map, containing a map of provided entries. +// +// If hasher is nil, a default hasher implementation will automatically be chosen based on the first key added. +// Default hasher implementations only exist for int, string, and byte slice types. +func NewMapOf[K comparable, V any](hasher Hasher[K], entries map[K]V) *Map[K, V] { + m := &Map[K, V]{ + hasher: hasher, + } + for k, v := range entries { + m.set(k, v, true) + } + return m +} + +// Len returns the number of elements in the map. +func (m *Map[K, V]) Len() int { + return m.size +} + +// clone returns a shallow copy of m. +func (m *Map[K, V]) clone() *Map[K, V] { + other := *m + return &other +} + +// Get returns the value for a given key and a flag indicating whether the +// key exists. This flag distinguishes a nil value set on a key versus a +// non-existent key in the map. +func (m *Map[K, V]) Get(key K) (value V, ok bool) { + var empty V + if m.root == nil { + return empty, false + } + keyHash := m.hasher.Hash(key) + return m.root.get(key, 0, keyHash, m.hasher) +} + +// Set returns a map with the key set to the new value. A nil value is allowed. +// +// This function will return a new map even if the updated value is the same as +// the existing value because Map does not track value equality. +func (m *Map[K, V]) Set(key K, value V) *Map[K, V] { + return m.set(key, value, false) +} + +func (m *Map[K, V]) set(key K, value V, mutable bool) *Map[K, V] { + // Set a hasher on the first value if one does not already exist. + hasher := m.hasher + if hasher == nil { + hasher = NewHasher(key) + } + + // Generate copy if necessary. + other := m + if !mutable { + other = m.clone() + } + other.hasher = hasher + + // If the map is empty, initialize with a simple array node. + if m.root == nil { + other.size = 1 + other.root = &mapArrayNode[K, V]{entries: []mapEntry[K, V]{{key: key, value: value}}} + return other + } + + // Otherwise copy the map and delegate insertion to the root. + // Resized will return true if the key does not currently exist. + var resized bool + other.root = m.root.set(key, value, 0, hasher.Hash(key), hasher, mutable, &resized) + if resized { + other.size++ + } + return other +} + +// Delete returns a map with the given key removed. +// Removing a non-existent key will cause this method to return the same map. +func (m *Map[K, V]) Delete(key K) *Map[K, V] { + return m.delete(key, false) +} + +func (m *Map[K, V]) delete(key K, mutable bool) *Map[K, V] { + // Return original map if no keys exist. + if m.root == nil { + return m + } + + // If the delete did not change the node then return the original map. + var resized bool + newRoot := m.root.delete(key, 0, m.hasher.Hash(key), m.hasher, mutable, &resized) + if !resized { + return m + } + + // Generate copy if necessary. + other := m + if !mutable { + other = m.clone() + } + + // Return copy of map with new root and decreased size. + other.size = m.size - 1 + other.root = newRoot + return other +} + +// Iterator returns a new iterator for the map. +func (m *Map[K, V]) Iterator() *MapIterator[K, V] { + itr := &MapIterator[K, V]{m: m} + itr.First() + return itr +} + +// MapBuilder represents an efficient builder for creating Maps. +type MapBuilder[K, V any] struct { + m *Map[K, V] // current state +} + +// NewMapBuilder returns a new instance of MapBuilder. +func NewMapBuilder[K, V any](hasher Hasher[K]) *MapBuilder[K, V] { + return &MapBuilder[K, V]{m: NewMap[K, V](hasher)} +} + +// Map returns the underlying map. Only call once. +// Builder is invalid after call. Will panic on second invocation. +func (b *MapBuilder[K, V]) Map() *Map[K, V] { + assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map") + m := b.m + b.m = nil + return m +} + +// Len returns the number of elements in the underlying map. +func (b *MapBuilder[K, V]) Len() int { + assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") + return b.m.Len() +} + +// Get returns the value for the given key. +func (b *MapBuilder[K, V]) Get(key K) (value V, ok bool) { + assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") + return b.m.Get(key) +} + +// Set sets the value of the given key. See Map.Set() for additional details. +func (b *MapBuilder[K, V]) Set(key K, value V) { + assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") + b.m = b.m.set(key, value, true) +} + +// Delete removes the given key. See Map.Delete() for additional details. +func (b *MapBuilder[K, V]) Delete(key K) { + assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") + b.m = b.m.delete(key, true) +} + +// Iterator returns a new iterator for the underlying map. +func (b *MapBuilder[K, V]) Iterator() *MapIterator[K, V] { + assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") + return b.m.Iterator() +} + +// mapNode represents any node in the map tree. +type mapNode[K, V any] interface { + get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) + set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] + delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] +} + +var _ mapNode[string, any] = (*mapArrayNode[string, any])(nil) +var _ mapNode[string, any] = (*mapBitmapIndexedNode[string, any])(nil) +var _ mapNode[string, any] = (*mapHashArrayNode[string, any])(nil) +var _ mapNode[string, any] = (*mapValueNode[string, any])(nil) +var _ mapNode[string, any] = (*mapHashCollisionNode[string, any])(nil) + +// mapLeafNode represents a node that stores a single key hash at the leaf of the map tree. +type mapLeafNode[K, V any] interface { + mapNode[K, V] + keyHashValue() uint32 +} + +var _ mapLeafNode[string, any] = (*mapValueNode[string, any])(nil) +var _ mapLeafNode[string, any] = (*mapHashCollisionNode[string, any])(nil) + +// mapArrayNode is a map node that stores key/value pairs in a slice. +// Entries are stored in insertion order. An array node expands into a bitmap +// indexed node once a given threshold size is crossed. +type mapArrayNode[K, V any] struct { + entries []mapEntry[K, V] +} + +// indexOf returns the entry index of the given key. Returns -1 if key not found. +func (n *mapArrayNode[K, V]) indexOf(key K, h Hasher[K]) int { + for i := range n.entries { + if h.Equal(n.entries[i].key, key) { + return i + } + } + return -1 +} + +// get returns the value for the given key. +func (n *mapArrayNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { + i := n.indexOf(key, h) + if i == -1 { + return value, false + } + return n.entries[i].value, true +} + +// set inserts or updates the value for a given key. If the key is inserted and +// the new size crosses the max size threshold, a bitmap indexed node is returned. +func (n *mapArrayNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { + idx := n.indexOf(key, h) + + // Mark as resized if the key doesn't exist. + if idx == -1 { + *resized = true + } + + // If we are adding and it crosses the max size threshold, expand the node. + // We do this by continually setting the entries to a value node and expanding. + if idx == -1 && len(n.entries) >= maxArrayMapSize { + var node mapNode[K, V] = newMapValueNode(h.Hash(key), key, value) + for _, entry := range n.entries { + node = node.set(entry.key, entry.value, 0, h.Hash(entry.key), h, false, resized) + } + return node + } + + // Update in-place if mutable. + if mutable { + if idx != -1 { + n.entries[idx] = mapEntry[K, V]{key, value} + } else { + n.entries = append(n.entries, mapEntry[K, V]{key, value}) + } + return n + } + + // Update existing entry if a match is found. + // Otherwise append to the end of the element list if it doesn't exist. + var other mapArrayNode[K, V] + if idx != -1 { + other.entries = make([]mapEntry[K, V], len(n.entries)) + copy(other.entries, n.entries) + other.entries[idx] = mapEntry[K, V]{key, value} + } else { + other.entries = make([]mapEntry[K, V], len(n.entries)+1) + copy(other.entries, n.entries) + other.entries[len(other.entries)-1] = mapEntry[K, V]{key, value} + } + return &other +} + +// delete removes the given key from the node. Returns the same node if key does +// not exist. Returns a nil node when removing the last entry. +func (n *mapArrayNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { + idx := n.indexOf(key, h) + + // Return original node if key does not exist. + if idx == -1 { + return n + } + *resized = true + + // Return nil if this node will contain no nodes. + if len(n.entries) == 1 { + return nil + } + + // Update in-place, if mutable. + if mutable { + copy(n.entries[idx:], n.entries[idx+1:]) + n.entries[len(n.entries)-1] = mapEntry[K, V]{} + n.entries = n.entries[:len(n.entries)-1] + return n + } + + // Otherwise create a copy with the given entry removed. + other := &mapArrayNode[K, V]{entries: make([]mapEntry[K, V], len(n.entries)-1)} + copy(other.entries[:idx], n.entries[:idx]) + copy(other.entries[idx:], n.entries[idx+1:]) + return other +} + +// mapBitmapIndexedNode represents a map branch node with a variable number of +// node slots and indexed using a bitmap. Indexes for the node slots are +// calculated by counting the number of set bits before the target bit using popcount. +type mapBitmapIndexedNode[K, V any] struct { + bitmap uint32 + nodes []mapNode[K, V] +} + +// get returns the value for the given key. +func (n *mapBitmapIndexedNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { + bit := uint32(1) << ((keyHash >> shift) & mapNodeMask) + if (n.bitmap & bit) == 0 { + return value, false + } + child := n.nodes[bits.OnesCount32(n.bitmap&(bit-1))] + return child.get(key, shift+mapNodeBits, keyHash, h) +} + +// set inserts or updates the value for the given key. If a new key is inserted +// and the size crosses the max size threshold then a hash array node is returned. +func (n *mapBitmapIndexedNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { + // Extract the index for the bit segment of the key hash. + keyHashFrag := (keyHash >> shift) & mapNodeMask + + // Determine the bit based on the hash index. + bit := uint32(1) << keyHashFrag + exists := (n.bitmap & bit) != 0 + + // Mark as resized if the key doesn't exist. + if !exists { + *resized = true + } + + // Find index of node based on popcount of bits before it. + idx := bits.OnesCount32(n.bitmap & (bit - 1)) + + // If the node already exists, delegate set operation to it. + // If the node doesn't exist then create a simple value leaf node. + var newNode mapNode[K, V] + if exists { + newNode = n.nodes[idx].set(key, value, shift+mapNodeBits, keyHash, h, mutable, resized) + } else { + newNode = newMapValueNode(keyHash, key, value) + } + + // Convert to a hash-array node once we exceed the max bitmap size. + // Copy each node based on their bit position within the bitmap. + if !exists && len(n.nodes) > maxBitmapIndexedSize { + var other mapHashArrayNode[K, V] + for i := uint(0); i < uint(len(other.nodes)); i++ { + if n.bitmap&(uint32(1)<> shift) & mapNodeMask) + + // Return original node if key does not exist. + if (n.bitmap & bit) == 0 { + return n + } + + // Find index of node based on popcount of bits before it. + idx := bits.OnesCount32(n.bitmap & (bit - 1)) + + // Delegate delete to child node. + child := n.nodes[idx] + newChild := child.delete(key, shift+mapNodeBits, keyHash, h, mutable, resized) + + // Return original node if key doesn't exist in child. + if !*resized { + return n + } + + // Remove if returned child has been deleted. + if newChild == nil { + // If we won't have any children then return nil. + if len(n.nodes) == 1 { + return nil + } + + // Update in-place if mutable. + if mutable { + n.bitmap ^= bit + copy(n.nodes[idx:], n.nodes[idx+1:]) + n.nodes[len(n.nodes)-1] = nil + n.nodes = n.nodes[:len(n.nodes)-1] + return n + } + + // Return copy with bit removed from bitmap and node removed from node list. + other := &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap ^ bit, nodes: make([]mapNode[K, V], len(n.nodes)-1)} + copy(other.nodes[:idx], n.nodes[:idx]) + copy(other.nodes[idx:], n.nodes[idx+1:]) + return other + } + + // Generate copy, if necessary. + other := n + if !mutable { + other = &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap, nodes: make([]mapNode[K, V], len(n.nodes))} + copy(other.nodes, n.nodes) + } + + // Update child. + other.nodes[idx] = newChild + return other +} + +// mapHashArrayNode is a map branch node that stores nodes in a fixed length +// array. Child nodes are indexed by their index bit segment for the current depth. +type mapHashArrayNode[K, V any] struct { + count uint // number of set nodes + nodes [mapNodeSize]mapNode[K, V] // child node slots, may contain empties +} + +// clone returns a shallow copy of n. +func (n *mapHashArrayNode[K, V]) clone() *mapHashArrayNode[K, V] { + other := *n + return &other +} + +// get returns the value for the given key. +func (n *mapHashArrayNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { + node := n.nodes[(keyHash>>shift)&mapNodeMask] + if node == nil { + return value, false + } + return node.get(key, shift+mapNodeBits, keyHash, h) +} + +// set returns a node with the value set for the given key. +func (n *mapHashArrayNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { + idx := (keyHash >> shift) & mapNodeMask + node := n.nodes[idx] + + // If node at index doesn't exist, create a simple value leaf node. + // Otherwise delegate set to child node. + var newNode mapNode[K, V] + if node == nil { + *resized = true + newNode = newMapValueNode(keyHash, key, value) + } else { + newNode = node.set(key, value, shift+mapNodeBits, keyHash, h, mutable, resized) + } + + // Generate copy, if necessary. + other := n + if !mutable { + other = n.clone() + } + + // Update child node (and update size, if new). + if node == nil { + other.count++ + } + other.nodes[idx] = newNode + return other +} + +// delete returns a node with the given key removed. Returns the same node if +// the key does not exist. If node shrinks to within bitmap-indexed size then +// converts to a bitmap-indexed node. +func (n *mapHashArrayNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { + idx := (keyHash >> shift) & mapNodeMask + node := n.nodes[idx] + + // Return original node if child is not found. + if node == nil { + return n + } + + // Return original node if child is unchanged. + newNode := node.delete(key, shift+mapNodeBits, keyHash, h, mutable, resized) + if !*resized { + return n + } + + // If we remove a node and drop below a threshold, convert back to bitmap indexed node. + if newNode == nil && n.count <= maxBitmapIndexedSize { + other := &mapBitmapIndexedNode[K, V]{nodes: make([]mapNode[K, V], 0, n.count-1)} + for i, child := range n.nodes { + if child != nil && uint32(i) != idx { + other.bitmap |= 1 << uint(i) + other.nodes = append(other.nodes, child) + } + } + return other + } + + // Generate copy, if necessary. + other := n + if !mutable { + other = n.clone() + } + + // Return copy of node with child updated. + other.nodes[idx] = newNode + if newNode == nil { + other.count-- + } + return other +} + +// mapValueNode represents a leaf node with a single key/value pair. +// A value node can be converted to a hash collision leaf node if a different +// key with the same keyHash is inserted. +type mapValueNode[K, V any] struct { + keyHash uint32 + key K + value V +} + +// newMapValueNode returns a new instance of mapValueNode. +func newMapValueNode[K, V any](keyHash uint32, key K, value V) *mapValueNode[K, V] { + return &mapValueNode[K, V]{ + keyHash: keyHash, + key: key, + value: value, + } +} + +// keyHashValue returns the key hash for this node. +func (n *mapValueNode[K, V]) keyHashValue() uint32 { + return n.keyHash +} + +// get returns the value for the given key. +func (n *mapValueNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { + if !h.Equal(n.key, key) { + return value, false + } + return n.value, true +} + +// set returns a new node with the new value set for the key. If the key equals +// the node's key then a new value node is returned. If key is not equal to the +// node's key but has the same hash then a hash collision node is returned. +// Otherwise the nodes are merged into a branch node. +func (n *mapValueNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { + // If the keys match then return a new value node overwriting the value. + if h.Equal(n.key, key) { + // Update in-place if mutable. + if mutable { + n.value = value + return n + } + // Otherwise return a new copy. + return newMapValueNode(n.keyHash, key, value) + } + + *resized = true + + // Recursively merge nodes together if key hashes are different. + if n.keyHash != keyHash { + return mergeIntoNode[K, V](n, shift, keyHash, key, value) + } + + // Merge into collision node if hash matches. + return &mapHashCollisionNode[K, V]{keyHash: keyHash, entries: []mapEntry[K, V]{ + {key: n.key, value: n.value}, + {key: key, value: value}, + }} +} + +// delete returns nil if the key matches the node's key. Otherwise returns the original node. +func (n *mapValueNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { + // Return original node if the keys do not match. + if !h.Equal(n.key, key) { + return n + } + + // Otherwise remove the node if keys do match. + *resized = true + return nil +} + +// mapHashCollisionNode represents a leaf node that contains two or more key/value +// pairs with the same key hash. Single pairs for a hash are stored as value nodes. +type mapHashCollisionNode[K, V any] struct { + keyHash uint32 // key hash for all entries + entries []mapEntry[K, V] +} + +// keyHashValue returns the key hash for all entries on the node. +func (n *mapHashCollisionNode[K, V]) keyHashValue() uint32 { + return n.keyHash +} + +// indexOf returns the index of the entry for the given key. +// Returns -1 if the key does not exist in the node. +func (n *mapHashCollisionNode[K, V]) indexOf(key K, h Hasher[K]) int { + for i := range n.entries { + if h.Equal(n.entries[i].key, key) { + return i + } + } + return -1 +} + +// get returns the value for the given key. +func (n *mapHashCollisionNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) { + for i := range n.entries { + if h.Equal(n.entries[i].key, key) { + return n.entries[i].value, true + } + } + return value, false +} + +// set returns a copy of the node with key set to the given value. +func (n *mapHashCollisionNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { + // Merge node with key/value pair if this is not a hash collision. + if n.keyHash != keyHash { + *resized = true + return mergeIntoNode[K, V](n, shift, keyHash, key, value) + } + + // Update in-place if mutable. + if mutable { + if idx := n.indexOf(key, h); idx == -1 { + *resized = true + n.entries = append(n.entries, mapEntry[K, V]{key, value}) + } else { + n.entries[idx] = mapEntry[K, V]{key, value} + } + return n + } + + // Append to end of node if key doesn't exist & mark resized. + // Otherwise copy nodes and overwrite at matching key index. + other := &mapHashCollisionNode[K, V]{keyHash: n.keyHash} + if idx := n.indexOf(key, h); idx == -1 { + *resized = true + other.entries = make([]mapEntry[K, V], len(n.entries)+1) + copy(other.entries, n.entries) + other.entries[len(other.entries)-1] = mapEntry[K, V]{key, value} + } else { + other.entries = make([]mapEntry[K, V], len(n.entries)) + copy(other.entries, n.entries) + other.entries[idx] = mapEntry[K, V]{key, value} + } + return other +} + +// delete returns a node with the given key deleted. Returns the same node if +// the key does not exist. If removing the key would shrink the node to a single +// entry then a value node is returned. +func (n *mapHashCollisionNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] { + idx := n.indexOf(key, h) + + // Return original node if key is not found. + if idx == -1 { + return n + } + + // Mark as resized if key exists. + *resized = true + + // Convert to value node if we move to one entry. + if len(n.entries) == 2 { + return &mapValueNode[K, V]{ + keyHash: n.keyHash, + key: n.entries[idx^1].key, + value: n.entries[idx^1].value, + } + } + + // Remove entry in-place if mutable. + if mutable { + copy(n.entries[idx:], n.entries[idx+1:]) + n.entries[len(n.entries)-1] = mapEntry[K, V]{} + n.entries = n.entries[:len(n.entries)-1] + return n + } + + // Return copy without entry if immutable. + other := &mapHashCollisionNode[K, V]{keyHash: n.keyHash, entries: make([]mapEntry[K, V], len(n.entries)-1)} + copy(other.entries[:idx], n.entries[:idx]) + copy(other.entries[idx:], n.entries[idx+1:]) + return other +} + +// mergeIntoNode merges a key/value pair into an existing node. +// Caller must verify that node's keyHash is not equal to keyHash. +func mergeIntoNode[K, V any](node mapLeafNode[K, V], shift uint, keyHash uint32, key K, value V) mapNode[K, V] { + idx1 := (node.keyHashValue() >> shift) & mapNodeMask + idx2 := (keyHash >> shift) & mapNodeMask + + // Recursively build branch nodes to combine the node and its key. + other := &mapBitmapIndexedNode[K, V]{bitmap: (1 << idx1) | (1 << idx2)} + if idx1 == idx2 { + other.nodes = []mapNode[K, V]{mergeIntoNode(node, shift+mapNodeBits, keyHash, key, value)} + } else { + if newNode := newMapValueNode(keyHash, key, value); idx1 < idx2 { + other.nodes = []mapNode[K, V]{node, newNode} + } else { + other.nodes = []mapNode[K, V]{newNode, node} + } + } + return other +} + +// mapEntry represents a single key/value pair. +type mapEntry[K, V any] struct { + key K + value V +} + +// MapIterator represents an iterator over a map's key/value pairs. Although +// map keys are not sorted, the iterator's order is deterministic. +type MapIterator[K, V any] struct { + m *Map[K, V] // source map + + stack [32]mapIteratorElem[K, V] // search stack + depth int // stack depth +} + +// Done returns true if no more elements remain in the iterator. +func (itr *MapIterator[K, V]) Done() bool { + return itr.depth == -1 +} + +// First resets the iterator to the first key/value pair. +func (itr *MapIterator[K, V]) First() { + // Exit immediately if the map is empty. + if itr.m.root == nil { + itr.depth = -1 + return + } + + // Initialize the stack to the left most element. + itr.stack[0] = mapIteratorElem[K, V]{node: itr.m.root} + itr.depth = 0 + itr.first() +} + +// Next returns the next key/value pair. Returns a nil key when no elements remain. +func (itr *MapIterator[K, V]) Next() (key K, value V, ok bool) { + // Return nil key if iteration is done. + if itr.Done() { + return key, value, false + } + + // Retrieve current index & value. Current node is always a leaf. + elem := &itr.stack[itr.depth] + switch node := elem.node.(type) { + case *mapArrayNode[K, V]: + entry := &node.entries[elem.index] + key, value = entry.key, entry.value + case *mapValueNode[K, V]: + key, value = node.key, node.value + case *mapHashCollisionNode[K, V]: + entry := &node.entries[elem.index] + key, value = entry.key, entry.value + } + + // Move up stack until we find a node that has remaining position ahead + // and move that element forward by one. + itr.next() + return key, value, true +} + +// next moves to the next available key. +func (itr *MapIterator[K, V]) next() { + for ; itr.depth >= 0; itr.depth-- { + elem := &itr.stack[itr.depth] + + switch node := elem.node.(type) { + case *mapArrayNode[K, V]: + if elem.index < len(node.entries)-1 { + elem.index++ + return + } + + case *mapBitmapIndexedNode[K, V]: + if elem.index < len(node.nodes)-1 { + elem.index++ + itr.stack[itr.depth+1].node = node.nodes[elem.index] + itr.depth++ + itr.first() + return + } + + case *mapHashArrayNode[K, V]: + for i := elem.index + 1; i < len(node.nodes); i++ { + if node.nodes[i] != nil { + elem.index = i + itr.stack[itr.depth+1].node = node.nodes[elem.index] + itr.depth++ + itr.first() + return + } + } + + case *mapValueNode[K, V]: + continue // always the last value, traverse up + + case *mapHashCollisionNode[K, V]: + if elem.index < len(node.entries)-1 { + elem.index++ + return + } + } + } +} + +// first positions the stack left most index. +// Elements and indexes at and below the current depth are assumed to be correct. +func (itr *MapIterator[K, V]) first() { + for ; ; itr.depth++ { + elem := &itr.stack[itr.depth] + + switch node := elem.node.(type) { + case *mapBitmapIndexedNode[K, V]: + elem.index = 0 + itr.stack[itr.depth+1].node = node.nodes[0] + + case *mapHashArrayNode[K, V]: + for i := 0; i < len(node.nodes); i++ { + if node.nodes[i] != nil { // find first node + elem.index = i + itr.stack[itr.depth+1].node = node.nodes[i] + break + } + } + + default: // *mapArrayNode, mapLeafNode + elem.index = 0 + return + } + } +} + +// mapIteratorElem represents a node/index pair in the MapIterator stack. +type mapIteratorElem[K, V any] struct { + node mapNode[K, V] + index int +} + +// Sorted map child node limit size. +const ( + sortedMapNodeSize = 32 +) + +// SortedMap represents a map of key/value pairs sorted by key. The sort order +// is determined by the Comparer used by the map. +// +// This map is implemented as a B+tree. +type SortedMap[K, V any] struct { + size int // total number of key/value pairs + root sortedMapNode[K, V] // root of b+tree + comparer Comparer[K] +} + +// NewSortedMap returns a new instance of SortedMap. If comparer is nil then +// a default comparer is set after the first key is inserted. Default comparers +// exist for int, string, and byte slice keys. +func NewSortedMap[K, V any](comparer Comparer[K]) *SortedMap[K, V] { + return &SortedMap[K, V]{ + comparer: comparer, + } +} + +// NewSortedMapOf returns a new instance of SortedMap, containing a map of provided entries. +// +// If comparer is nil then a default comparer is set after the first key is inserted. Default comparers +// exist for int, string, and byte slice keys. +func NewSortedMapOf[K comparable, V any](comparer Comparer[K], entries map[K]V) *SortedMap[K, V] { + m := &SortedMap[K, V]{ + comparer: comparer, + } + for k, v := range entries { + m.set(k, v, true) + } + return m +} + +// Len returns the number of elements in the sorted map. +func (m *SortedMap[K, V]) Len() int { + return m.size +} + +// Get returns the value for a given key and a flag indicating if the key is set. +// The flag can be used to distinguish between a nil-set key versus an unset key. +func (m *SortedMap[K, V]) Get(key K) (V, bool) { + if m.root == nil { + var v V + return v, false + } + return m.root.get(key, m.comparer) +} + +// Set returns a copy of the map with the key set to the given value. +func (m *SortedMap[K, V]) Set(key K, value V) *SortedMap[K, V] { + return m.set(key, value, false) +} + +func (m *SortedMap[K, V]) set(key K, value V, mutable bool) *SortedMap[K, V] { + // Set a comparer on the first value if one does not already exist. + comparer := m.comparer + if comparer == nil { + comparer = NewComparer(key) + } + + // Create copy, if necessary. + other := m + if !mutable { + other = m.clone() + } + other.comparer = comparer + + // If no values are set then initialize with a leaf node. + if m.root == nil { + other.size = 1 + other.root = &sortedMapLeafNode[K, V]{entries: []mapEntry[K, V]{{key: key, value: value}}} + return other + } + + // Otherwise delegate to root node. + // If a split occurs then grow the tree from the root. + var resized bool + newRoot, splitNode := m.root.set(key, value, comparer, mutable, &resized) + if splitNode != nil { + newRoot = newSortedMapBranchNode(newRoot, splitNode) + } + + // Update root and size (if resized). + other.size = m.size + other.root = newRoot + if resized { + other.size++ + } + return other +} + +// Delete returns a copy of the map with the key removed. +// Returns the original map if key does not exist. +func (m *SortedMap[K, V]) Delete(key K) *SortedMap[K, V] { + return m.delete(key, false) +} + +func (m *SortedMap[K, V]) delete(key K, mutable bool) *SortedMap[K, V] { + // Return original map if no keys exist. + if m.root == nil { + return m + } + + // If the delete did not change the node then return the original map. + var resized bool + newRoot := m.root.delete(key, m.comparer, mutable, &resized) + if !resized { + return m + } + + // Create copy, if necessary. + other := m + if !mutable { + other = m.clone() + } + + // Update root and size. + other.size = m.size - 1 + other.root = newRoot + return other +} + +// clone returns a shallow copy of m. +func (m *SortedMap[K, V]) clone() *SortedMap[K, V] { + other := *m + return &other +} + +// Iterator returns a new iterator for this map positioned at the first key. +func (m *SortedMap[K, V]) Iterator() *SortedMapIterator[K, V] { + itr := &SortedMapIterator[K, V]{m: m} + itr.First() + return itr +} + +// SortedMapBuilder represents an efficient builder for creating sorted maps. +type SortedMapBuilder[K, V any] struct { + m *SortedMap[K, V] // current state +} + +// NewSortedMapBuilder returns a new instance of SortedMapBuilder. +func NewSortedMapBuilder[K, V any](comparer Comparer[K]) *SortedMapBuilder[K, V] { + return &SortedMapBuilder[K, V]{m: NewSortedMap[K, V](comparer)} +} + +// SortedMap returns the current copy of the map. +// The returned map is safe to use even if after the builder continues to be used. +func (b *SortedMapBuilder[K, V]) Map() *SortedMap[K, V] { + assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map") + m := b.m + b.m = nil + return m +} + +// Len returns the number of elements in the underlying map. +func (b *SortedMapBuilder[K, V]) Len() int { + assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") + return b.m.Len() +} + +// Get returns the value for the given key. +func (b *SortedMapBuilder[K, V]) Get(key K) (value V, ok bool) { + assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") + return b.m.Get(key) +} + +// Set sets the value of the given key. See SortedMap.Set() for additional details. +func (b *SortedMapBuilder[K, V]) Set(key K, value V) { + assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") + b.m = b.m.set(key, value, true) +} + +// Delete removes the given key. See SortedMap.Delete() for additional details. +func (b *SortedMapBuilder[K, V]) Delete(key K) { + assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") + b.m = b.m.delete(key, true) +} + +// Iterator returns a new iterator for the underlying map positioned at the first key. +func (b *SortedMapBuilder[K, V]) Iterator() *SortedMapIterator[K, V] { + assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") + return b.m.Iterator() +} + +// sortedMapNode represents a branch or leaf node in the sorted map. +type sortedMapNode[K, V any] interface { + minKey() K + indexOf(key K, c Comparer[K]) int + get(key K, c Comparer[K]) (value V, ok bool) + set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) + delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] +} + +var _ sortedMapNode[string, any] = (*sortedMapBranchNode[string, any])(nil) +var _ sortedMapNode[string, any] = (*sortedMapLeafNode[string, any])(nil) + +// sortedMapBranchNode represents a branch in the sorted map. +type sortedMapBranchNode[K, V any] struct { + elems []sortedMapBranchElem[K, V] +} + +// newSortedMapBranchNode returns a new branch node with the given child nodes. +func newSortedMapBranchNode[K, V any](children ...sortedMapNode[K, V]) *sortedMapBranchNode[K, V] { + // Fetch min keys for every child. + elems := make([]sortedMapBranchElem[K, V], len(children)) + for i, child := range children { + elems[i] = sortedMapBranchElem[K, V]{ + key: child.minKey(), + node: child, + } + } + + return &sortedMapBranchNode[K, V]{elems: elems} +} + +// minKey returns the lowest key stored in this node's tree. +func (n *sortedMapBranchNode[K, V]) minKey() K { + return n.elems[0].node.minKey() +} + +// indexOf returns the index of the key within the child nodes. +func (n *sortedMapBranchNode[K, V]) indexOf(key K, c Comparer[K]) int { + if idx := sort.Search(len(n.elems), func(i int) bool { return c.Compare(n.elems[i].key, key) == 1 }); idx > 0 { + return idx - 1 + } + return 0 +} + +// get returns the value for the given key. +func (n *sortedMapBranchNode[K, V]) get(key K, c Comparer[K]) (value V, ok bool) { + idx := n.indexOf(key, c) + return n.elems[idx].node.get(key, c) +} + +// set returns a copy of the node with the key set to the given value. +func (n *sortedMapBranchNode[K, V]) set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) { + idx := n.indexOf(key, c) + + // Delegate insert to child node. + newNode, splitNode := n.elems[idx].node.set(key, value, c, mutable, resized) + + // Update in-place, if mutable. + if mutable { + n.elems[idx] = sortedMapBranchElem[K, V]{key: newNode.minKey(), node: newNode} + if splitNode != nil { + n.elems = append(n.elems, sortedMapBranchElem[K, V]{}) + copy(n.elems[idx+1:], n.elems[idx:]) + n.elems[idx+1] = sortedMapBranchElem[K, V]{key: splitNode.minKey(), node: splitNode} + } + + // If the child splits and we have no more room then we split too. + if len(n.elems) > sortedMapNodeSize { + splitIdx := len(n.elems) / 2 + newNode := &sortedMapBranchNode[K, V]{elems: n.elems[:splitIdx:splitIdx]} + splitNode := &sortedMapBranchNode[K, V]{elems: n.elems[splitIdx:]} + return newNode, splitNode + } + return n, nil + } + + // If no split occurs, copy branch and update keys. + // If the child splits, insert new key/child into copy of branch. + var other sortedMapBranchNode[K, V] + if splitNode == nil { + other.elems = make([]sortedMapBranchElem[K, V], len(n.elems)) + copy(other.elems, n.elems) + other.elems[idx] = sortedMapBranchElem[K, V]{ + key: newNode.minKey(), + node: newNode, + } + } else { + other.elems = make([]sortedMapBranchElem[K, V], len(n.elems)+1) + copy(other.elems[:idx], n.elems[:idx]) + copy(other.elems[idx+1:], n.elems[idx:]) + other.elems[idx] = sortedMapBranchElem[K, V]{ + key: newNode.minKey(), + node: newNode, + } + other.elems[idx+1] = sortedMapBranchElem[K, V]{ + key: splitNode.minKey(), + node: splitNode, + } + } + + // If the child splits and we have no more room then we split too. + if len(other.elems) > sortedMapNodeSize { + splitIdx := len(other.elems) / 2 + newNode := &sortedMapBranchNode[K, V]{elems: other.elems[:splitIdx:splitIdx]} + splitNode := &sortedMapBranchNode[K, V]{elems: other.elems[splitIdx:]} + return newNode, splitNode + } + + // Otherwise return the new branch node with the updated entry. + return &other, nil +} + +// delete returns a node with the key removed. Returns the same node if the key +// does not exist. Returns nil if all child nodes are removed. +func (n *sortedMapBranchNode[K, V]) delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] { + idx := n.indexOf(key, c) + + // Return original node if child has not changed. + newNode := n.elems[idx].node.delete(key, c, mutable, resized) + if !*resized { + return n + } + + // Remove child if it is now nil. + if newNode == nil { + // If this node will become empty then simply return nil. + if len(n.elems) == 1 { + return nil + } + + // If mutable, update in-place. + if mutable { + copy(n.elems[idx:], n.elems[idx+1:]) + n.elems[len(n.elems)-1] = sortedMapBranchElem[K, V]{} + n.elems = n.elems[:len(n.elems)-1] + return n + } + + // Return a copy without the given node. + other := &sortedMapBranchNode[K, V]{elems: make([]sortedMapBranchElem[K, V], len(n.elems)-1)} + copy(other.elems[:idx], n.elems[:idx]) + copy(other.elems[idx:], n.elems[idx+1:]) + return other + } + + // If mutable, update in-place. + if mutable { + n.elems[idx] = sortedMapBranchElem[K, V]{key: newNode.minKey(), node: newNode} + return n + } + + // Return a copy with the updated node. + other := &sortedMapBranchNode[K, V]{elems: make([]sortedMapBranchElem[K, V], len(n.elems))} + copy(other.elems, n.elems) + other.elems[idx] = sortedMapBranchElem[K, V]{ + key: newNode.minKey(), + node: newNode, + } + return other +} + +type sortedMapBranchElem[K, V any] struct { + key K + node sortedMapNode[K, V] +} + +// sortedMapLeafNode represents a leaf node in the sorted map. +type sortedMapLeafNode[K, V any] struct { + entries []mapEntry[K, V] +} + +// minKey returns the first key stored in this node. +func (n *sortedMapLeafNode[K, V]) minKey() K { + return n.entries[0].key +} + +// indexOf returns the index of the given key. +func (n *sortedMapLeafNode[K, V]) indexOf(key K, c Comparer[K]) int { + return sort.Search(len(n.entries), func(i int) bool { + return c.Compare(n.entries[i].key, key) != -1 // GTE + }) +} + +// get returns the value of the given key. +func (n *sortedMapLeafNode[K, V]) get(key K, c Comparer[K]) (value V, ok bool) { + idx := n.indexOf(key, c) + + // If the index is beyond the entry count or the key is not equal then return 'not found'. + if idx == len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 { + return value, false + } + + // If the key matches then return its value. + return n.entries[idx].value, true +} + +// set returns a copy of node with the key set to the given value. If the update +// causes the node to grow beyond the maximum size then it is split in two. +func (n *sortedMapLeafNode[K, V]) set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) { + // Find the insertion index for the key. + idx := n.indexOf(key, c) + exists := idx < len(n.entries) && c.Compare(n.entries[idx].key, key) == 0 + + // Update in-place, if mutable. + if mutable { + if !exists { + *resized = true + n.entries = append(n.entries, mapEntry[K, V]{}) + copy(n.entries[idx+1:], n.entries[idx:]) + } + n.entries[idx] = mapEntry[K, V]{key: key, value: value} + + // If the key doesn't exist and we exceed our max allowed values then split. + if len(n.entries) > sortedMapNodeSize { + splitIdx := len(n.entries) / 2 + newNode := &sortedMapLeafNode[K, V]{entries: n.entries[:splitIdx:splitIdx]} + splitNode := &sortedMapLeafNode[K, V]{entries: n.entries[splitIdx:]} + return newNode, splitNode + } + return n, nil + } + + // If the key matches then simply return a copy with the entry overridden. + // If there is no match then insert new entry and mark as resized. + var newEntries []mapEntry[K, V] + if exists { + newEntries = make([]mapEntry[K, V], len(n.entries)) + copy(newEntries, n.entries) + newEntries[idx] = mapEntry[K, V]{key: key, value: value} + } else { + *resized = true + newEntries = make([]mapEntry[K, V], len(n.entries)+1) + copy(newEntries[:idx], n.entries[:idx]) + newEntries[idx] = mapEntry[K, V]{key: key, value: value} + copy(newEntries[idx+1:], n.entries[idx:]) + } + + // If the key doesn't exist and we exceed our max allowed values then split. + if len(newEntries) > sortedMapNodeSize { + splitIdx := len(newEntries) / 2 + newNode := &sortedMapLeafNode[K, V]{entries: newEntries[:splitIdx:splitIdx]} + splitNode := &sortedMapLeafNode[K, V]{entries: newEntries[splitIdx:]} + return newNode, splitNode + } + + // Otherwise return the new leaf node with the updated entry. + return &sortedMapLeafNode[K, V]{entries: newEntries}, nil +} + +// delete returns a copy of node with key removed. Returns the original node if +// the key does not exist. Returns nil if the removed key is the last remaining key. +func (n *sortedMapLeafNode[K, V]) delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] { + idx := n.indexOf(key, c) + + // Return original node if key is not found. + if idx >= len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 { + return n + } + *resized = true + + // If this is the last entry then return nil. + if len(n.entries) == 1 { + return nil + } + + // Update in-place, if mutable. + if mutable { + copy(n.entries[idx:], n.entries[idx+1:]) + n.entries[len(n.entries)-1] = mapEntry[K, V]{} + n.entries = n.entries[:len(n.entries)-1] + return n + } + + // Return copy of node with entry removed. + other := &sortedMapLeafNode[K, V]{entries: make([]mapEntry[K, V], len(n.entries)-1)} + copy(other.entries[:idx], n.entries[:idx]) + copy(other.entries[idx:], n.entries[idx+1:]) + return other +} + +// SortedMapIterator represents an iterator over a sorted map. +// Iteration can occur in natural or reverse order based on use of Next() or Prev(). +type SortedMapIterator[K, V any] struct { + m *SortedMap[K, V] // source map + + stack [32]sortedMapIteratorElem[K, V] // search stack + depth int // stack depth +} + +// Done returns true if no more key/value pairs remain in the iterator. +func (itr *SortedMapIterator[K, V]) Done() bool { + return itr.depth == -1 +} + +// First moves the iterator to the first key/value pair. +func (itr *SortedMapIterator[K, V]) First() { + if itr.m.root == nil { + itr.depth = -1 + return + } + itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root} + itr.depth = 0 + itr.first() +} + +// Last moves the iterator to the last key/value pair. +func (itr *SortedMapIterator[K, V]) Last() { + if itr.m.root == nil { + itr.depth = -1 + return + } + itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root} + itr.depth = 0 + itr.last() +} + +// Seek moves the iterator position to the given key in the map. +// If the key does not exist then the next key is used. If no more keys exist +// then the iteartor is marked as done. +func (itr *SortedMapIterator[K, V]) Seek(key K) { + if itr.m.root == nil { + itr.depth = -1 + return + } + itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root} + itr.depth = 0 + itr.seek(key) +} + +// Next returns the current key/value pair and moves the iterator forward. +// Returns a nil key if the there are no more elements to return. +func (itr *SortedMapIterator[K, V]) Next() (key K, value V, ok bool) { + // Return nil key if iteration is complete. + if itr.Done() { + return key, value, false + } + + // Retrieve current key/value pair. + leafElem := &itr.stack[itr.depth] + leafNode := leafElem.node.(*sortedMapLeafNode[K, V]) + leafEntry := &leafNode.entries[leafElem.index] + key, value = leafEntry.key, leafEntry.value + + // Move to the next available key/value pair. + itr.next() + + // Only occurs when iterator is done. + return key, value, true +} + +// next moves to the next key. If no keys are after then depth is set to -1. +func (itr *SortedMapIterator[K, V]) next() { + for ; itr.depth >= 0; itr.depth-- { + elem := &itr.stack[itr.depth] + + switch node := elem.node.(type) { + case *sortedMapLeafNode[K, V]: + if elem.index < len(node.entries)-1 { + elem.index++ + return + } + case *sortedMapBranchNode[K, V]: + if elem.index < len(node.elems)-1 { + elem.index++ + itr.stack[itr.depth+1].node = node.elems[elem.index].node + itr.depth++ + itr.first() + return + } + } + } +} + +// Prev returns the current key/value pair and moves the iterator backward. +// Returns a nil key if the there are no more elements to return. +func (itr *SortedMapIterator[K, V]) Prev() (key K, value V, ok bool) { + // Return nil key if iteration is complete. + if itr.Done() { + return key, value, false + } + + // Retrieve current key/value pair. + leafElem := &itr.stack[itr.depth] + leafNode := leafElem.node.(*sortedMapLeafNode[K, V]) + leafEntry := &leafNode.entries[leafElem.index] + key, value = leafEntry.key, leafEntry.value + + itr.prev() + return key, value, true +} + +// prev moves to the previous key. If no keys are before then depth is set to -1. +func (itr *SortedMapIterator[K, V]) prev() { + for ; itr.depth >= 0; itr.depth-- { + elem := &itr.stack[itr.depth] + + switch node := elem.node.(type) { + case *sortedMapLeafNode[K, V]: + if elem.index > 0 { + elem.index-- + return + } + case *sortedMapBranchNode[K, V]: + if elem.index > 0 { + elem.index-- + itr.stack[itr.depth+1].node = node.elems[elem.index].node + itr.depth++ + itr.last() + return + } + } + } +} + +// first positions the stack to the leftmost key from the current depth. +// Elements and indexes below the current depth are assumed to be correct. +func (itr *SortedMapIterator[K, V]) first() { + for { + elem := &itr.stack[itr.depth] + elem.index = 0 + + switch node := elem.node.(type) { + case *sortedMapBranchNode[K, V]: + itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node} + itr.depth++ + case *sortedMapLeafNode[K, V]: + return + } + } +} + +// last positions the stack to the rightmost key from the current depth. +// Elements and indexes below the current depth are assumed to be correct. +func (itr *SortedMapIterator[K, V]) last() { + for { + elem := &itr.stack[itr.depth] + + switch node := elem.node.(type) { + case *sortedMapBranchNode[K, V]: + elem.index = len(node.elems) - 1 + itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node} + itr.depth++ + case *sortedMapLeafNode[K, V]: + elem.index = len(node.entries) - 1 + return + } + } +} + +// seek positions the stack to the given key from the current depth. +// Elements and indexes below the current depth are assumed to be correct. +func (itr *SortedMapIterator[K, V]) seek(key K) { + for { + elem := &itr.stack[itr.depth] + elem.index = elem.node.indexOf(key, itr.m.comparer) + + switch node := elem.node.(type) { + case *sortedMapBranchNode[K, V]: + itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node} + itr.depth++ + case *sortedMapLeafNode[K, V]: + if elem.index == len(node.entries) { + itr.next() + } + return + } + } +} + +// sortedMapIteratorElem represents node/index pair in the SortedMapIterator stack. +type sortedMapIteratorElem[K, V any] struct { + node sortedMapNode[K, V] + index int +} + +// Hasher hashes keys and checks them for equality. +type Hasher[K any] interface { + // Computes a hash for key. + Hash(key K) uint32 + + // Returns true if a and b are equal. + Equal(a, b K) bool +} + +// NewHasher returns the built-in hasher for a given key type. +func NewHasher[K any](key K) Hasher[K] { + // Attempt to use non-reflection based hasher first. + switch (any(key)).(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, string: + return &defaultHasher[K]{} + } + + // Fallback to reflection-based hasher otherwise. + // This is used when caller wraps a type around a primitive type. + switch reflect.TypeOf(key).Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String: + return &reflectHasher[K]{} + } + + // If no hashers match then panic. + // This is a compile time issue so it should not return an error. + panic(fmt.Sprintf("immutable.NewHasher: must set hasher for %T type", key)) +} + +// Hash returns a hash for value. +func hashString(value string) uint32 { + var hash uint32 + for i, value := 0, value; i < len(value); i++ { + hash = 31*hash + uint32(value[i]) + } + return hash +} + +// reflectIntHasher implements a reflection-based Hasher for keys. +type reflectHasher[K any] struct{} + +// Hash returns a hash for key. +func (h *reflectHasher[K]) Hash(key K) uint32 { + switch reflect.TypeOf(key).Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return hashUint64(uint64(reflect.ValueOf(key).Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return hashUint64(reflect.ValueOf(key).Uint()) + case reflect.String: + var hash uint32 + s := reflect.ValueOf(key).String() + for i := 0; i < len(s); i++ { + hash = 31*hash + uint32(s[i]) + } + return hash + } + panic(fmt.Sprintf("immutable.reflectHasher.Hash: reflectHasher does not support %T type", key)) +} + +// Equal returns true if a is equal to b. Otherwise returns false. +// Panics if a and b are not int-ish or string-ish. +func (h *reflectHasher[K]) Equal(a, b K) bool { + switch reflect.TypeOf(a).Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return reflect.ValueOf(a).Int() == reflect.ValueOf(b).Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return reflect.ValueOf(a).Uint() == reflect.ValueOf(b).Uint() + case reflect.String: + return reflect.ValueOf(a).String() == reflect.ValueOf(b).String() + } + panic(fmt.Sprintf("immutable.reflectHasher.Equal: reflectHasher does not support %T type", a)) + +} + +// hashUint64 returns a 32-bit hash for a 64-bit value. +func hashUint64(value uint64) uint32 { + hash := value + for value > 0xffffffff { + value /= 0xffffffff + hash ^= value + } + return uint32(hash) +} + +// defaultHasher implements Hasher. +type defaultHasher[K any] struct{} + +// Hash returns a hash for key. +func (h *defaultHasher[K]) Hash(key K) uint32 { + switch x := (any(key)).(type) { + case int: + return hashUint64(uint64(x)) + case int8: + return hashUint64(uint64(x)) + case int16: + return hashUint64(uint64(x)) + case int32: + return hashUint64(uint64(x)) + case int64: + return hashUint64(uint64(x)) + case uint: + return hashUint64(uint64(x)) + case uint8: + return hashUint64(uint64(x)) + case uint16: + return hashUint64(uint64(x)) + case uint32: + return hashUint64(uint64(x)) + case uint64: + return hashUint64(uint64(x)) + case uintptr: + return hashUint64(uint64(x)) + case string: + return hashString(x) + } + panic(fmt.Sprintf("immutable.defaultHasher.Hash: must set comparer for %T type", key)) +} + +// Equal returns true if a is equal to b. Otherwise returns false. +// Panics if a and b are not comparable. +func (h *defaultHasher[K]) Equal(a, b K) bool { + return any(a) == any(b) +} + +// Comparer allows the comparison of two keys for the purpose of sorting. +type Comparer[K any] interface { + // Returns -1 if a is less than b, returns 1 if a is greater than b, + // and returns 0 if a is equal to b. + Compare(a, b K) int +} + +// NewComparer returns the built-in comparer for a given key type. +// Note that only int-ish and string-ish types are supported, despite the 'comparable' constraint. +// Attempts to use other types will result in a panic - users should define their own Comparers for these cases. +func NewComparer[K any](key K) Comparer[K] { + // Attempt to use non-reflection based comparer first. + switch (any(key)).(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, string: + return &defaultComparer[K]{} + } + // Fallback to reflection-based comparer otherwise. + // This is used when caller wraps a type around a primitive type. + switch reflect.TypeOf(key).Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String: + return &reflectComparer[K]{} + } + // If no comparers match then panic. + // This is a compile time issue so it should not return an error. + panic(fmt.Sprintf("immutable.NewComparer: must set comparer for %T type", key)) +} + +// defaultComparer compares two values (int-ish and string-ish types are supported). Implements Comparer. +type defaultComparer[K any] struct{} + +// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and +// returns 0 if a is equal to b. Panic if a or b is not a string or int* type +func (c *defaultComparer[K]) Compare(i K, j K) int { + switch x := (any(i)).(type) { + case int: + return defaultCompare(x, (any(j)).(int)) + case int8: + return defaultCompare(x, (any(j)).(int8)) + case int16: + return defaultCompare(x, (any(j)).(int16)) + case int32: + return defaultCompare(x, (any(j)).(int32)) + case int64: + return defaultCompare(x, (any(j)).(int64)) + case uint: + return defaultCompare(x, (any(j)).(uint)) + case uint8: + return defaultCompare(x, (any(j)).(uint8)) + case uint16: + return defaultCompare(x, (any(j)).(uint16)) + case uint32: + return defaultCompare(x, (any(j)).(uint32)) + case uint64: + return defaultCompare(x, (any(j)).(uint64)) + case uintptr: + return defaultCompare(x, (any(j)).(uintptr)) + case string: + return defaultCompare(x, (any(j)).(string)) + } + panic(fmt.Sprintf("immutable.defaultComparer: must set comparer for %T type", i)) +} + +// defaultCompare only operates on constraints.Ordered. +// For other types, users should bring their own comparers +func defaultCompare[K cmp.Ordered](i, j K) int { + if i < j { + return -1 + } else if i > j { + return 1 + } + return 0 +} + +// reflectIntComparer compares two values using reflection. Implements Comparer. +type reflectComparer[K any] struct{} + +// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and +// returns 0 if a is equal to b. Panic if a or b is not an int-ish or string-ish type. +func (c *reflectComparer[K]) Compare(a, b K) int { + switch reflect.TypeOf(a).Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if i, j := reflect.ValueOf(a).Int(), reflect.ValueOf(b).Int(); i < j { + return -1 + } else if i > j { + return 1 + } + return 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if i, j := reflect.ValueOf(a).Uint(), reflect.ValueOf(b).Uint(); i < j { + return -1 + } else if i > j { + return 1 + } + return 0 + case reflect.String: + return strings.Compare(reflect.ValueOf(a).String(), reflect.ValueOf(b).String()) + } + panic(fmt.Sprintf("immutable.reflectComparer.Compare: must set comparer for %T type", a)) +} + +func assert(condition bool, message string) { + if !condition { + panic(message) + } +} + +// Set represents a collection of unique values. The set uses a Hasher +// to generate hashes and check for equality of key values. +// +// Internally, the Set stores values as keys of a Map[T,struct{}] +type Set[T any] struct { + m *Map[T, struct{}] +} + +// NewSet returns a new instance of Set. +// +// If hasher is nil, a default hasher implementation will automatically be chosen based on the first key added. +// Default hasher implementations only exist for int, string, and byte slice types. +// NewSet can also take some initial values as varargs. +func NewSet[T any](hasher Hasher[T], values ...T) Set[T] { + m := NewMap[T, struct{}](hasher) + for _, value := range values { + m = m.set(value, struct{}{}, true) + } + return Set[T]{m} +} + +// Add returns a set containing the new value. +// +// This function will return a new set even if the set already contains the value. +func (s Set[T]) Add(value T) Set[T] { + return Set[T]{s.m.Set(value, struct{}{})} +} + +// Delete returns a set with the given key removed. +func (s Set[T]) Delete(value T) Set[T] { + return Set[T]{s.m.Delete(value)} +} + +// Has returns true when the set contains the given value +func (s Set[T]) Has(val T) bool { + _, ok := s.m.Get(val) + return ok +} + +// Len returns the number of elements in the underlying map. +func (s Set[K]) Len() int { + return s.m.Len() +} + +// Items returns a slice of the items inside the set +func (s Set[T]) Items() []T { + r := make([]T, 0, s.Len()) + itr := s.Iterator() + for !itr.Done() { + v, _ := itr.Next() + r = append(r, v) + } + return r +} + +// Iterator returns a new iterator for this set positioned at the first value. +func (s Set[T]) Iterator() *SetIterator[T] { + itr := &SetIterator[T]{mi: s.m.Iterator()} + itr.mi.First() + return itr +} + +// SetIterator represents an iterator over a set. +// Iteration can occur in natural or reverse order based on use of Next() or Prev(). +type SetIterator[T any] struct { + mi *MapIterator[T, struct{}] +} + +// Done returns true if no more values remain in the iterator. +func (itr *SetIterator[T]) Done() bool { + return itr.mi.Done() +} + +// First moves the iterator to the first value. +func (itr *SetIterator[T]) First() { + itr.mi.First() +} + +// Next moves the iterator to the next value. +func (itr *SetIterator[T]) Next() (val T, ok bool) { + val, _, ok = itr.mi.Next() + return +} + +type SetBuilder[T any] struct { + s Set[T] +} + +func NewSetBuilder[T any](hasher Hasher[T]) *SetBuilder[T] { + return &SetBuilder[T]{s: NewSet(hasher)} +} + +func (s SetBuilder[T]) Set(val T) { + s.s.m = s.s.m.set(val, struct{}{}, true) +} + +func (s SetBuilder[T]) Delete(val T) { + s.s.m = s.s.m.delete(val, true) +} + +func (s SetBuilder[T]) Has(val T) bool { + return s.s.Has(val) +} + +func (s SetBuilder[T]) Len() int { + return s.s.Len() +} + +type SortedSet[T any] struct { + m *SortedMap[T, struct{}] +} + +// NewSortedSet returns a new instance of SortedSet. +// +// If comparer is nil then +// a default comparer is set after the first key is inserted. Default comparers +// exist for int, string, and byte slice keys. +// NewSortedSet can also take some initial values as varargs. +func NewSortedSet[T any](comparer Comparer[T], values ...T) SortedSet[T] { + m := NewSortedMap[T, struct{}](comparer) + for _, value := range values { + m = m.set(value, struct{}{}, true) + } + return SortedSet[T]{m} +} + +// Add returns a set containing the new value. +// +// This function will return a new set even if the set already contains the value. +func (s SortedSet[T]) Add(value T) SortedSet[T] { + return SortedSet[T]{s.m.Set(value, struct{}{})} +} + +// Delete returns a set with the given key removed. +func (s SortedSet[T]) Delete(value T) SortedSet[T] { + return SortedSet[T]{s.m.Delete(value)} +} + +// Has returns true when the set contains the given value +func (s SortedSet[T]) Has(val T) bool { + _, ok := s.m.Get(val) + return ok +} + +// Len returns the number of elements in the underlying map. +func (s SortedSet[K]) Len() int { + return s.m.Len() +} + +// Items returns a slice of the items inside the set +func (s SortedSet[T]) Items() []T { + r := make([]T, 0, s.Len()) + itr := s.Iterator() + for !itr.Done() { + v, _ := itr.Next() + r = append(r, v) + } + return r +} + +// Iterator returns a new iterator for this set positioned at the first value. +func (s SortedSet[T]) Iterator() *SortedSetIterator[T] { + itr := &SortedSetIterator[T]{mi: s.m.Iterator()} + itr.mi.First() + return itr +} + +// SortedSetIterator represents an iterator over a sorted set. +// Iteration can occur in natural or reverse order based on use of Next() or Prev(). +type SortedSetIterator[T any] struct { + mi *SortedMapIterator[T, struct{}] +} + +// Done returns true if no more values remain in the iterator. +func (itr *SortedSetIterator[T]) Done() bool { + return itr.mi.Done() +} + +// First moves the iterator to the first value. +func (itr *SortedSetIterator[T]) First() { + itr.mi.First() +} + +// Last moves the iterator to the last value. +func (itr *SortedSetIterator[T]) Last() { + itr.mi.Last() +} + +// Next moves the iterator to the next value. +func (itr *SortedSetIterator[T]) Next() (val T, ok bool) { + val, _, ok = itr.mi.Next() + return +} + +// Prev moves the iterator to the previous value. +func (itr *SortedSetIterator[T]) Prev() (val T, ok bool) { + val, _, ok = itr.mi.Prev() + return +} + +// Seek moves the iterator to the given value. +// +// If the value does not exist then the next value is used. If no more keys exist +// then the iterator is marked as done. +func (itr *SortedSetIterator[T]) Seek(val T) { + itr.mi.Seek(val) +} + +type SortedSetBuilder[T any] struct { + s *SortedSet[T] +} + +func NewSortedSetBuilder[T any](comparer Comparer[T]) *SortedSetBuilder[T] { + s := NewSortedSet(comparer) + return &SortedSetBuilder[T]{s: &s} +} + +func (s SortedSetBuilder[T]) Set(val T) { + s.s.m = s.s.m.set(val, struct{}{}, true) +} + +func (s SortedSetBuilder[T]) Delete(val T) { + s.s.m = s.s.m.delete(val, true) +} + +func (s SortedSetBuilder[T]) Has(val T) bool { + return s.s.Has(val) +} + +func (s SortedSetBuilder[T]) Len() int { + return s.s.Len() +} + +// SortedSet returns the current copy of the set. +// The builder should not be used again after the list after this call. +func (s SortedSetBuilder[T]) SortedSet() SortedSet[T] { + assert(s.s != nil, "immutable.SortedSetBuilder.SortedSet(): duplicate call to fetch sorted set") + set := s.s + s.s = nil + return *set +} diff --git a/tests/main.go b/tests/main.go new file mode 100644 index 0000000..db923fd --- /dev/null +++ b/tests/main.go @@ -0,0 +1,7 @@ +package main + +import "pds" + +func main() { + pds.MainTest() +} diff --git a/tests/pds.go b/tests/pds.go new file mode 100644 index 0000000..ad74531 --- /dev/null +++ b/tests/pds.go @@ -0,0 +1,2669 @@ +package pds + +import ( + "cmp" + "flag" + "fmt" + "math/rand" + "sort" + "testing" +) + +var ( + veryVerbose = flag.Bool("vv", false, "very verbose") + randomN = flag.Int("random.n", 100, "number of RunRandom() iterations") +) + +func TestList(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + if size := NewList[string]().Len(); size != 0 { + t.Fatalf("unexpected size: %d", size) + } + }) + + t.Run("Shallow", func(t *testing.T) { + list := NewList[string]() + list = list.Append("foo") + if v := list.Get(0); v != "foo" { + t.Fatalf("unexpected value: %v", v) + } + + other := list.Append("bar") + if v := other.Get(0); v != "foo" { + t.Fatalf("unexpected value: %v", v) + } else if v := other.Get(1); v != "bar" { + t.Fatalf("unexpected value: %v", v) + } + + if v := list.Len(); v != 1 { + t.Fatalf("unexpected value: %v", v) + } + }) + + t.Run("Deep", func(t *testing.T) { + list := NewList[int]() + var array []int + for i := 0; i < 100000; i++ { + list = list.Append(i) + array = append(array, i) + } + + if got, exp := len(array), list.Len(); got != exp { + t.Fatalf("List.Len()=%d, exp %d", got, exp) + } + for j := range array { + if got, exp := list.Get(j), array[j]; got != exp { + t.Fatalf("%d. List.Get(%d)=%d, exp %d", len(array), j, got, exp) + } + } + }) + + t.Run("Set", func(t *testing.T) { + list := NewList[string]() + list = list.Append("foo") + list = list.Append("bar") + + if v := list.Get(0); v != "foo" { + t.Fatalf("unexpected value: %v", v) + } + + list = list.Set(0, "baz") + if v := list.Get(0); v != "baz" { + t.Fatalf("unexpected value: %v", v) + } else if v := list.Get(1); v != "bar" { + t.Fatalf("unexpected value: %v", v) + } + }) + + t.Run("GetBelowRange", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList[string]() + l = l.Append("foo") + l.Get(-1) + }() + if r != `immutable.List.Get: index -1 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("GetAboveRange", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList[string]() + l = l.Append("foo") + l.Get(1) + }() + if r != `immutable.List.Get: index 1 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("SetOutOfRange", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList[string]() + l = l.Append("foo") + l.Set(1, "bar") + }() + if r != `immutable.List.Set: index 1 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("SliceStartOutOfRange", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList[string]() + l = l.Append("foo") + l.Slice(2, 3) + }() + if r != `immutable.List.Slice: start index 2 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("SliceEndOutOfRange", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList[string]() + l = l.Append("foo") + l.Slice(1, 3) + }() + if r != `immutable.List.Slice: end index 3 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("SliceInvalidIndex", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList[string]() + l = l.Append("foo") + l = l.Append("bar") + l.Slice(2, 1) + }() + if r != `immutable.List.Slice: invalid slice index: [2:1]` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("SliceBeginning", func(t *testing.T) { + l := NewList[string]() + l = l.Append("foo") + l = l.Append("bar") + l = l.Slice(1, 2) + if got, exp := l.Len(), 1; got != exp { + t.Fatalf("List.Len()=%d, exp %d", got, exp) + } else if got, exp := l.Get(0), "bar"; got != exp { + t.Fatalf("List.Get(0)=%v, exp %v", got, exp) + } + }) + + t.Run("IteratorSeekOutOfBounds", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + l := NewList[string]() + l = l.Append("foo") + l.Iterator().Seek(-1) + }() + if r != `immutable.ListIterator.Seek: index -1 out of bounds` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + t.Run("TestSliceFreesReferences", func(t *testing.T) { + /* Test that the leaf node in a sliced list contains zero'ed entries at + * the correct positions. To do this we directly access the internal + * tree structure of the list. + */ + l := NewList[*int]() + var ints [5]int + for i := 0; i < 5; i++ { + l = l.Append(&ints[i]) + } + sl := l.Slice(2, 4) + + var findLeaf func(listNode[*int]) *listLeafNode[*int] + findLeaf = func(n listNode[*int]) *listLeafNode[*int] { + switch n := n.(type) { + case *listBranchNode[*int]: + if n.children[0] == nil { + t.Fatal("Failed to find leaf node due to nil child") + } + return findLeaf(n.children[0]) + case *listLeafNode[*int]: + return n + default: + panic("Unexpected case") + } + } + + leaf := findLeaf(sl.root) + if leaf.occupied != 0b1100 { + t.Errorf("Expected occupied to be 1100, was %032b", leaf.occupied) + } + + for i := 0; i < listNodeSize; i++ { + if 2 <= i && i < 4 { + if leaf.children[i] != &ints[i] { + t.Errorf("Position %v does not contain the right pointer?", i) + } + } else if leaf.children[i] != nil { + t.Errorf("Expected position %v to be cleared, was %v", i, leaf.children[i]) + } + } + }) + + t.Run("AppendImmutable", func(t *testing.T) { + outer_l := NewList[int]() + for N := 0; N < 1_000; N++ { + l1 := outer_l.Append(0) + outer_l.Append(1) + if actual := l1.Get(N); actual != 0 { + t.Fatalf("Append mutates list with %d elements. Got %d instead of 0", N, actual) + } + + outer_l = outer_l.Append(0) + } + }) + + RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { + l := NewTList() + for i := 0; i < 100000; i++ { + rnd := rand.Intn(70) + switch { + case rnd == 0: // slice + start, end := l.ChooseSliceIndices(rand) + l.Slice(start, end) + case rnd < 10: // set + if l.Len() > 0 { + l.Set(l.ChooseIndex(rand), rand.Intn(10000)) + } + case rnd < 30: // prepend + l.Prepend(rand.Intn(10000)) + default: // append + l.Append(rand.Intn(10000)) + } + } + if err := l.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +// TList represents a list that operates on a standard Go slice & immutable list. +type TList struct { + im, prev *List[int] + builder *ListBuilder[int] + std []int +} + +// NewTList returns a new instance of TList. +func NewTList() *TList { + return &TList{ + im: NewList[int](), + builder: NewListBuilder[int](), + } +} + +// Len returns the size of the list. +func (l *TList) Len() int { + return len(l.std) +} + +// ChooseIndex returns a randomly chosen, valid index from the standard slice. +func (l *TList) ChooseIndex(rand *rand.Rand) int { + if len(l.std) == 0 { + return -1 + } + return rand.Intn(len(l.std)) +} + +// ChooseSliceIndices returns randomly chosen, valid indices for slicing. +func (l *TList) ChooseSliceIndices(rand *rand.Rand) (start, end int) { + if len(l.std) == 0 { + return 0, 0 + } + start = rand.Intn(len(l.std)) + end = rand.Intn(len(l.std)-start) + start + return start, end +} + +// Append adds v to the end of slice and List. +func (l *TList) Append(v int) { + l.prev = l.im + l.im = l.im.Append(v) + l.builder.Append(v) + l.std = append(l.std, v) +} + +// Prepend adds v to the beginning of the slice and List. +func (l *TList) Prepend(v int) { + l.prev = l.im + l.im = l.im.Prepend(v) + l.builder.Prepend(v) + l.std = append([]int{v}, l.std...) +} + +// Set updates the value at index i to v in the slice and List. +func (l *TList) Set(i, v int) { + l.prev = l.im + l.im = l.im.Set(i, v) + l.builder.Set(i, v) + l.std[i] = v +} + +// Slice contracts the slice and List to the range of start/end indices. +func (l *TList) Slice(start, end int) { + l.prev = l.im + l.im = l.im.Slice(start, end) + l.builder.Slice(start, end) + l.std = l.std[start:end] +} + +// Validate returns an error if the slice and List are different. +func (l *TList) Validate() error { + if got, exp := l.im.Len(), len(l.std); got != exp { + return fmt.Errorf("Len()=%v, expected %d", got, exp) + } else if got, exp := l.builder.Len(), len(l.std); got != exp { + return fmt.Errorf("Len()=%v, expected %d", got, exp) + } + + for i := range l.std { + if got, exp := l.im.Get(i), l.std[i]; got != exp { + return fmt.Errorf("Get(%d)=%v, expected %v", i, got, exp) + } else if got, exp := l.builder.Get(i), l.std[i]; got != exp { + return fmt.Errorf("Builder.List/Get(%d)=%v, expected %v", i, got, exp) + } + } + + if err := l.validateForwardIterator("basic", l.im.Iterator()); err != nil { + return err + } else if err := l.validateBackwardIterator("basic", l.im.Iterator()); err != nil { + return err + } + + if err := l.validateForwardIterator("builder", l.builder.Iterator()); err != nil { + return err + } else if err := l.validateBackwardIterator("builder", l.builder.Iterator()); err != nil { + return err + } + return nil +} + +func (l *TList) validateForwardIterator(typ string, itr *ListIterator[int]) error { + for i := range l.std { + if j, v := itr.Next(); i != j || l.std[i] != v { + return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected <%v,%v> [%s]", j, v, i, l.std[i], typ) + } + + done := i == len(l.std)-1 + if v := itr.Done(); v != done { + return fmt.Errorf("ListIterator.Done()=%v, expected %v [%s]", v, done, typ) + } + } + if i, v := itr.Next(); i != -1 || v != 0 { + return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected DONE [%s]", i, v, typ) + } + return nil +} + +func (l *TList) validateBackwardIterator(typ string, itr *ListIterator[int]) error { + itr.Last() + for i := len(l.std) - 1; i >= 0; i-- { + if j, v := itr.Prev(); i != j || l.std[i] != v { + return fmt.Errorf("ListIterator.Prev()=<%v,%v>, expected <%v,%v> [%s]", j, v, i, l.std[i], typ) + } + + done := i == 0 + if v := itr.Done(); v != done { + return fmt.Errorf("ListIterator.Done()=%v, expected %v [%s]", v, done, typ) + } + } + if i, v := itr.Prev(); i != -1 || v != 0 { + return fmt.Errorf("ListIterator.Prev()=<%v,%v>, expected DONE [%s]", i, v, typ) + } + return nil +} + +func BenchmarkList_Append(b *testing.B) { + b.ReportAllocs() + l := NewList[int]() + for i := 0; i < b.N; i++ { + l = l.Append(i) + } +} + +func BenchmarkList_Prepend(b *testing.B) { + b.ReportAllocs() + l := NewList[int]() + for i := 0; i < b.N; i++ { + l = l.Prepend(i) + } +} + +func BenchmarkList_Set(b *testing.B) { + const n = 10000 + + l := NewList[int]() + for i := 0; i < 10000; i++ { + l = l.Append(i) + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + l = l.Set(i%n, i*10) + } +} + +func BenchmarkList_Iterator(b *testing.B) { + const n = 10000 + l := NewList[int]() + for i := 0; i < 10000; i++ { + l = l.Append(i) + } + b.ReportAllocs() + b.ResetTimer() + + b.Run("Forward", func(b *testing.B) { + itr := l.Iterator() + for i := 0; i < b.N; i++ { + if i%n == 0 { + itr.First() + } + itr.Next() + } + }) + + b.Run("Reverse", func(b *testing.B) { + itr := l.Iterator() + for i := 0; i < b.N; i++ { + if i%n == 0 { + itr.Last() + } + itr.Prev() + } + }) +} + +func BenchmarkBuiltinSlice_Append(b *testing.B) { + b.Run("Int", func(b *testing.B) { + b.ReportAllocs() + var a []int + for i := 0; i < b.N; i++ { + a = append(a, i) + } + }) + b.Run("Interface", func(b *testing.B) { + b.ReportAllocs() + var a []interface{} + for i := 0; i < b.N; i++ { + a = append(a, i) + } + }) +} + +func BenchmarkListBuilder_Append(b *testing.B) { + b.ReportAllocs() + builder := NewListBuilder[int]() + for i := 0; i < b.N; i++ { + builder.Append(i) + } +} + +func BenchmarkListBuilder_Prepend(b *testing.B) { + b.ReportAllocs() + builder := NewListBuilder[int]() + for i := 0; i < b.N; i++ { + builder.Prepend(i) + } +} + +func BenchmarkListBuilder_Set(b *testing.B) { + const n = 10000 + + builder := NewListBuilder[int]() + for i := 0; i < 10000; i++ { + builder.Append(i) + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + builder.Set(i%n, i*10) + } +} + +func ExampleList_Append() { + l := NewList[string]() + l = l.Append("foo") + l = l.Append("bar") + l = l.Append("baz") + + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + fmt.Println(l.Get(2)) + // Output: + // foo + // bar + // baz +} + +func ExampleList_Prepend() { + l := NewList[string]() + l = l.Prepend("foo") + l = l.Prepend("bar") + l = l.Prepend("baz") + + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + fmt.Println(l.Get(2)) + // Output: + // baz + // bar + // foo +} + +func ExampleList_Set() { + l := NewList[string]() + l = l.Append("foo") + l = l.Append("bar") + l = l.Set(1, "baz") + + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + // Output: + // foo + // baz +} + +func ExampleList_Slice() { + l := NewList[string]() + l = l.Append("foo") + l = l.Append("bar") + l = l.Append("baz") + l = l.Slice(1, 3) + + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + // Output: + // bar + // baz +} + +func ExampleList_Iterator() { + l := NewList[string]() + l = l.Append("foo") + l = l.Append("bar") + l = l.Append("baz") + + itr := l.Iterator() + for !itr.Done() { + i, v := itr.Next() + fmt.Println(i, v) + } + // Output: + // 0 foo + // 1 bar + // 2 baz +} + +func ExampleList_Iterator_reverse() { + l := NewList[string]() + l = l.Append("foo") + l = l.Append("bar") + l = l.Append("baz") + + itr := l.Iterator() + itr.Last() + for !itr.Done() { + i, v := itr.Prev() + fmt.Println(i, v) + } + // Output: + // 2 baz + // 1 bar + // 0 foo +} + +func ExampleListBuilder_Append() { + b := NewListBuilder[string]() + b.Append("foo") + b.Append("bar") + b.Append("baz") + + l := b.List() + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + fmt.Println(l.Get(2)) + // Output: + // foo + // bar + // baz +} + +func ExampleListBuilder_Prepend() { + b := NewListBuilder[string]() + b.Prepend("foo") + b.Prepend("bar") + b.Prepend("baz") + + l := b.List() + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + fmt.Println(l.Get(2)) + // Output: + // baz + // bar + // foo +} + +func ExampleListBuilder_Set() { + b := NewListBuilder[string]() + b.Append("foo") + b.Append("bar") + b.Set(1, "baz") + + l := b.List() + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + // Output: + // foo + // baz +} + +func ExampleListBuilder_Slice() { + b := NewListBuilder[string]() + b.Append("foo") + b.Append("bar") + b.Append("baz") + b.Slice(1, 3) + + l := b.List() + fmt.Println(l.Get(0)) + fmt.Println(l.Get(1)) + // Output: + // bar + // baz +} + +// Ensure node can support overwrites as it expands. +func TestInternal_mapNode_Overwrite(t *testing.T) { + const n = 1000 + var h defaultHasher[int] + var node mapNode[int, int] = &mapArrayNode[int, int]{} + for i := 0; i < n; i++ { + var resized bool + node = node.set(i, i, 0, h.Hash(i), &h, false, &resized) + if !resized { + t.Fatal("expected resize") + } + + // Overwrite every node. + for j := 0; j <= i; j++ { + var resized bool + node = node.set(j, i*j, 0, h.Hash(j), &h, false, &resized) + if resized { + t.Fatalf("expected no resize: i=%d, j=%d", i, j) + } + } + + // Verify not found at each branch type. + if _, ok := node.get(1000000, 0, h.Hash(1000000), &h); ok { + t.Fatal("expected no value") + } + } + + // Verify all key/value pairs in map. + for i := 0; i < n; i++ { + if v, ok := node.get(i, 0, h.Hash(i), &h); !ok || v != i*(n-1) { + t.Fatalf("get(%d)=<%v,%v>", i, v, ok) + } + } +} + +func TestInternal_mapArrayNode(t *testing.T) { + // Ensure 8 or fewer elements stays in an array node. + t.Run("Append", func(t *testing.T) { + var h defaultHasher[int] + n := &mapArrayNode[int, int]{} + for i := 0; i < 8; i++ { + var resized bool + n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized).(*mapArrayNode[int, int]) + if !resized { + t.Fatal("expected resize") + } + + for j := 0; j < i; j++ { + if v, ok := n.get(j*10, 0, h.Hash(j*10), &h); !ok || v != j { + t.Fatalf("get(%d)=<%v,%v>", j, v, ok) + } + } + } + }) + + // Ensure 8 or fewer elements stays in an array node when inserted in reverse. + t.Run("Prepend", func(t *testing.T) { + var h defaultHasher[int] + n := &mapArrayNode[int, int]{} + for i := 7; i >= 0; i-- { + var resized bool + n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized).(*mapArrayNode[int, int]) + if !resized { + t.Fatal("expected resize") + } + + for j := i; j <= 7; j++ { + if v, ok := n.get(j*10, 0, h.Hash(j*10), &h); !ok || v != j { + t.Fatalf("get(%d)=<%v,%v>", j, v, ok) + } + } + } + }) + + // Ensure array can transition between node types. + t.Run("Expand", func(t *testing.T) { + var h defaultHasher[int] + var n mapNode[int, int] = &mapArrayNode[int, int]{} + for i := 0; i < 100; i++ { + var resized bool + n = n.set(i, i, 0, h.Hash(i), &h, false, &resized) + if !resized { + t.Fatal("expected resize") + } + + for j := 0; j < i; j++ { + if v, ok := n.get(j, 0, h.Hash(j), &h); !ok || v != j { + t.Fatalf("get(%d)=<%v,%v>", j, v, ok) + } + } + } + }) + + // Ensure deleting elements returns the correct new node. + RunRandom(t, "Delete", func(t *testing.T, rand *rand.Rand) { + var h defaultHasher[int] + var n mapNode[int, int] = &mapArrayNode[int, int]{} + for i := 0; i < 8; i++ { + var resized bool + n = n.set(i*10, i, 0, h.Hash(i*10), &h, false, &resized) + } + + for _, i := range rand.Perm(8) { + var resized bool + n = n.delete(i*10, 0, h.Hash(i*10), &h, false, &resized) + } + if n != nil { + t.Fatal("expected nil rand") + } + }) +} + +func TestInternal_mapValueNode(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + var h defaultHasher[int] + n := newMapValueNode(h.Hash(2), 2, 3) + if v, ok := n.get(2, 0, h.Hash(2), &h); !ok { + t.Fatal("expected ok") + } else if v != 3 { + t.Fatalf("unexpected value: %v", v) + } + }) + + t.Run("KeyEqual", func(t *testing.T) { + var h defaultHasher[int] + var resized bool + n := newMapValueNode(h.Hash(2), 2, 3) + other := n.set(2, 4, 0, h.Hash(2), &h, false, &resized).(*mapValueNode[int, int]) + if other == n { + t.Fatal("expected new node") + } else if got, exp := other.keyHash, h.Hash(2); got != exp { + t.Fatalf("keyHash=%v, expected %v", got, exp) + } else if got, exp := other.key, 2; got != exp { + t.Fatalf("key=%v, expected %v", got, exp) + } else if got, exp := other.value, 4; got != exp { + t.Fatalf("value=%v, expected %v", got, exp) + } else if resized { + t.Fatal("unexpected resize") + } + }) + + t.Run("KeyHashEqual", func(t *testing.T) { + h := &mockHasher[int]{ + hash: func(value int) uint32 { return 1 }, + equal: func(a, b int) bool { return a == b }, + } + var resized bool + n := newMapValueNode(h.Hash(2), 2, 3) + other := n.set(4, 5, 0, h.Hash(4), h, false, &resized).(*mapHashCollisionNode[int, int]) + if got, exp := other.keyHash, h.Hash(2); got != exp { + t.Fatalf("keyHash=%v, expected %v", got, exp) + } else if got, exp := len(other.entries), 2; got != exp { + t.Fatalf("entries=%v, expected %v", got, exp) + } else if !resized { + t.Fatal("expected resize") + } + if got, exp := other.entries[0].key, 2; got != exp { + t.Fatalf("key[0]=%v, expected %v", got, exp) + } else if got, exp := other.entries[0].value, 3; got != exp { + t.Fatalf("value[0]=%v, expected %v", got, exp) + } + if got, exp := other.entries[1].key, 4; got != exp { + t.Fatalf("key[1]=%v, expected %v", got, exp) + } else if got, exp := other.entries[1].value, 5; got != exp { + t.Fatalf("value[1]=%v, expected %v", got, exp) + } + }) + + t.Run("MergeNode", func(t *testing.T) { + // Inserting into a node with a different index in the mask should split into a bitmap node. + t.Run("NoConflict", func(t *testing.T) { + var h defaultHasher[int] + var resized bool + n := newMapValueNode(h.Hash(2), 2, 3) + other := n.set(4, 5, 0, h.Hash(4), &h, false, &resized).(*mapBitmapIndexedNode[int, int]) + if got, exp := other.bitmap, uint32(0x14); got != exp { + t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) + } else if got, exp := len(other.nodes), 2; got != exp { + t.Fatalf("nodes=%v, expected %v", got, exp) + } else if !resized { + t.Fatal("expected resize") + } + if node, ok := other.nodes[0].(*mapValueNode[int, int]); !ok { + t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) + } else if got, exp := node.key, 2; got != exp { + t.Fatalf("key[0]=%v, expected %v", got, exp) + } else if got, exp := node.value, 3; got != exp { + t.Fatalf("value[0]=%v, expected %v", got, exp) + } + if node, ok := other.nodes[1].(*mapValueNode[int, int]); !ok { + t.Fatalf("node[1]=%T, unexpected type", other.nodes[1]) + } else if got, exp := node.key, 4; got != exp { + t.Fatalf("key[1]=%v, expected %v", got, exp) + } else if got, exp := node.value, 5; got != exp { + t.Fatalf("value[1]=%v, expected %v", got, exp) + } + + // Ensure both values can be read. + if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v != 3 { + t.Fatalf("Get(2)=<%v,%v>", v, ok) + } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v != 5 { + t.Fatalf("Get(4)=<%v,%v>", v, ok) + } + }) + + // Reversing the nodes from NoConflict should yield the same result. + t.Run("NoConflictReverse", func(t *testing.T) { + var h defaultHasher[int] + var resized bool + n := newMapValueNode(h.Hash(4), 4, 5) + other := n.set(2, 3, 0, h.Hash(2), &h, false, &resized).(*mapBitmapIndexedNode[int, int]) + if got, exp := other.bitmap, uint32(0x14); got != exp { + t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) + } else if got, exp := len(other.nodes), 2; got != exp { + t.Fatalf("nodes=%v, expected %v", got, exp) + } else if !resized { + t.Fatal("expected resize") + } + if node, ok := other.nodes[0].(*mapValueNode[int, int]); !ok { + t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) + } else if got, exp := node.key, 2; got != exp { + t.Fatalf("key[0]=%v, expected %v", got, exp) + } else if got, exp := node.value, 3; got != exp { + t.Fatalf("value[0]=%v, expected %v", got, exp) + } + if node, ok := other.nodes[1].(*mapValueNode[int, int]); !ok { + t.Fatalf("node[1]=%T, unexpected type", other.nodes[1]) + } else if got, exp := node.key, 4; got != exp { + t.Fatalf("key[1]=%v, expected %v", got, exp) + } else if got, exp := node.value, 5; got != exp { + t.Fatalf("value[1]=%v, expected %v", got, exp) + } + + // Ensure both values can be read. + if v, ok := other.get(2, 0, h.Hash(2), &h); !ok || v != 3 { + t.Fatalf("Get(2)=<%v,%v>", v, ok) + } else if v, ok := other.get(4, 0, h.Hash(4), &h); !ok || v != 5 { + t.Fatalf("Get(4)=<%v,%v>", v, ok) + } + }) + + // Inserting a node with the same mask index should nest an additional level of bitmap nodes. + t.Run("Conflict", func(t *testing.T) { + h := &mockHasher[int]{ + hash: func(value int) uint32 { return uint32(value << 5) }, + equal: func(a, b int) bool { return a == b }, + } + var resized bool + n := newMapValueNode(h.Hash(2), 2, 3) + other := n.set(4, 5, 0, h.Hash(4), h, false, &resized).(*mapBitmapIndexedNode[int, int]) + if got, exp := other.bitmap, uint32(0x01); got != exp { // mask is zero, expect first slot. + t.Fatalf("bitmap=0x%02x, expected 0x%02x", got, exp) + } else if got, exp := len(other.nodes), 1; got != exp { + t.Fatalf("nodes=%v, expected %v", got, exp) + } else if !resized { + t.Fatal("expected resize") + } + child, ok := other.nodes[0].(*mapBitmapIndexedNode[int, int]) + if !ok { + t.Fatalf("node[0]=%T, unexpected type", other.nodes[0]) + } + + if node, ok := child.nodes[0].(*mapValueNode[int, int]); !ok { + t.Fatalf("node[0]=%T, unexpected type", child.nodes[0]) + } else if got, exp := node.key, 2; got != exp { + t.Fatalf("key[0]=%v, expected %v", got, exp) + } else if got, exp := node.value, 3; got != exp { + t.Fatalf("value[0]=%v, expected %v", got, exp) + } + if node, ok := child.nodes[1].(*mapValueNode[int, int]); !ok { + t.Fatalf("node[1]=%T, unexpected type", child.nodes[1]) + } else if got, exp := node.key, 4; got != exp { + t.Fatalf("key[1]=%v, expected %v", got, exp) + } else if got, exp := node.value, 5; got != exp { + t.Fatalf("value[1]=%v, expected %v", got, exp) + } + + // Ensure both values can be read. + if v, ok := other.get(2, 0, h.Hash(2), h); !ok || v != 3 { + t.Fatalf("Get(2)=<%v,%v>", v, ok) + } else if v, ok := other.get(4, 0, h.Hash(4), h); !ok || v != 5 { + t.Fatalf("Get(4)=<%v,%v>", v, ok) + } else if v, ok := other.get(10, 0, h.Hash(10), h); ok { + t.Fatalf("Get(10)=<%v,%v>, expected no value", v, ok) + } + }) + }) +} + +func TestMap_Get(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + m := NewMap[int, string](nil) + if v, ok := m.Get(100); ok { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + }) +} + +func TestMap_Set(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + m := NewMap[int, string](nil) + itr := m.Iterator() + if !itr.Done() { + t.Fatal("MapIterator.Done()=true, expected false") + } else if k, v, ok := itr.Next(); ok { + t.Fatalf("MapIterator.Next()=<%v,%v>, expected nil", k, v) + } + }) + + t.Run("Simple", func(t *testing.T) { + m := NewMap[int, string](nil) + m = m.Set(100, "foo") + if v, ok := m.Get(100); !ok || v != "foo" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + }) + + t.Run("Multi", func(t *testing.T) { + m := NewMapOf(nil, map[int]string{1: "foo"}) + itr := m.Iterator() + if itr.Done() { + t.Fatal("MapIterator.Done()=false, expected true") + } + if k, v, ok := itr.Next(); !ok { + t.Fatalf("MapIterator.Next()!=ok, expected ok") + } else if k != 1 || v != "foo" { + t.Fatalf("MapIterator.Next()=<%v,%v>, expected <1, \"foo\">", k, v) + } + if k, v, ok := itr.Next(); ok { + t.Fatalf("MapIterator.Next()=<%v,%v>, expected nil", k, v) + } + }) + + t.Run("VerySmall", func(t *testing.T) { + const n = 6 + m := NewMap[int, int](nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + + // NOTE: Array nodes store entries in insertion order. + itr := m.Iterator() + for i := 0; i < n; i++ { + if k, v, ok := itr.Next(); !ok || k != i || v != i+1 { + t.Fatalf("MapIterator.Next()=<%v,%v>, exp <%v,%v>", k, v, i, i+1) + } + } + if !itr.Done() { + t.Fatal("expected iterator done") + } + }) + + t.Run("Small", func(t *testing.T) { + const n = 1000 + m := NewMap[int, int](nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + t.Run("Large", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping: short") + } + + const n = 1000000 + m := NewMap[int, int](nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + t.Run("StringKeys", func(t *testing.T) { + m := NewMap[string, string](nil) + m = m.Set("foo", "bar") + m = m.Set("baz", "bat") + m = m.Set("", "EMPTY") + if v, ok := m.Get("foo"); !ok || v != "bar" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get("baz"); !ok || v != "bat" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get(""); !ok || v != "EMPTY" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + if v, ok := m.Get("no_such_key"); ok { + t.Fatalf("expected no value: <%v,%v>", v, ok) + } + }) + + RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { + m := NewTestMap() + for i := 0; i < 10000; i++ { + switch rand.Intn(2) { + case 1: // overwrite + m.Set(m.ExistingKey(rand), rand.Intn(10000)) + default: // set new key + m.Set(m.NewKey(rand), rand.Intn(10000)) + } + } + if err := m.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +// Ensure map can support overwrites as it expands. +func TestMap_Overwrite(t *testing.T) { + if testing.Short() { + t.Skip("short mode") + } + + const n = 10000 + m := NewMap[int, int](nil) + for i := 0; i < n; i++ { + // Set original value. + m = m.Set(i, i) + + // Overwrite every node. + for j := 0; j <= i; j++ { + m = m.Set(j, i*j) + } + } + + // Verify all key/value pairs in map. + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i*(n-1) { + t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) + } + } + + t.Run("Simple", func(t *testing.T) { + m := NewMap[int, string](nil) + itr := m.Iterator() + if !itr.Done() { + t.Fatal("MapIterator.Done()=true, expected false") + } else if k, v, ok := itr.Next(); ok { + t.Fatalf("MapIterator.Next()=<%v,%v>, expected nil", k, v) + } + }) +} + +func TestMap_Delete(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + m := NewMap[string, int](nil) + other := m.Delete("foo") + if m != other { + t.Fatal("expected same map") + } + }) + + t.Run("Simple", func(t *testing.T) { + m := NewMap[int, string](nil) + m = m.Set(100, "foo") + if v, ok := m.Get(100); !ok || v != "foo" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + }) + + t.Run("Small", func(t *testing.T) { + const n = 1000 + m := NewMap[int, int](nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := range rand.New(rand.NewSource(0)).Perm(n) { + m = m.Delete(i) + } + if m.Len() != 0 { + t.Fatalf("expected no elements, got %d", m.Len()) + } + }) + + t.Run("Large", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping: short") + } + const n = 1000000 + m := NewMap[int, int](nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := range rand.New(rand.NewSource(0)).Perm(n) { + m = m.Delete(i) + } + if m.Len() != 0 { + t.Fatalf("expected no elements, got %d", m.Len()) + } + }) + + RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { + m := NewTestMap() + for i := 0; i < 10000; i++ { + switch rand.Intn(8) { + case 0: // overwrite + m.Set(m.ExistingKey(rand), rand.Intn(10000)) + case 1: // delete existing key + m.Delete(m.ExistingKey(rand)) + case 2: // delete non-existent key. + m.Delete(m.NewKey(rand)) + default: // set new key + m.Set(m.NewKey(rand), rand.Intn(10000)) + } + } + + // Delete all and verify they are gone. + keys := make([]int, len(m.keys)) + copy(keys, m.keys) + + for _, key := range keys { + m.Delete(key) + } + if err := m.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +// Ensure map works even with hash conflicts. +func TestMap_LimitedHash(t *testing.T) { + if testing.Short() { + t.Skip("skipping: short") + } + + t.Run("Immutable", func(t *testing.T) { + h := mockHasher[int]{ + hash: func(value int) uint32 { return hashUint64(uint64(value)) % 0xFF }, + equal: func(a, b int) bool { return a == b }, + } + m := NewMap[int, int](&h) + + rand := rand.New(rand.NewSource(0)) + keys := rand.Perm(100000) + for _, i := range keys { + m = m.Set(i, i) // initial set + } + for i := range keys { + m = m.Set(i, i*2) // overwrite + } + if m.Len() != len(keys) { + t.Fatalf("unexpected len: %d", m.Len()) + } + + // Verify all key/value pairs in map. + for i := 0; i < m.Len(); i++ { + if v, ok := m.Get(i); !ok || v != i*2 { + t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) + } + } + + // Verify iteration. + itr := m.Iterator() + for !itr.Done() { + if k, v, ok := itr.Next(); !ok || v != k*2 { + t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k*2) + } + } + + // Verify not found works. + if _, ok := m.Get(10000000); ok { + t.Fatal("expected no value") + } + + // Verify delete non-existent key works. + if other := m.Delete(10000000 + 1); m != other { + t.Fatal("expected no change") + } + + // Remove all keys. + for _, key := range keys { + m = m.Delete(key) + } + if m.Len() != 0 { + t.Fatalf("unexpected size: %d", m.Len()) + } + }) + + t.Run("Builder", func(t *testing.T) { + h := mockHasher[int]{ + hash: func(value int) uint32 { return hashUint64(uint64(value)) }, + equal: func(a, b int) bool { return a == b }, + } + b := NewMapBuilder[int, int](&h) + + rand := rand.New(rand.NewSource(0)) + keys := rand.Perm(100000) + for _, i := range keys { + b.Set(i, i) // initial set + } + for i := range keys { + b.Set(i, i*2) // overwrite + } + if b.Len() != len(keys) { + t.Fatalf("unexpected len: %d", b.Len()) + } + + // Verify all key/value pairs in map. + for i := 0; i < b.Len(); i++ { + if v, ok := b.Get(i); !ok || v != i*2 { + t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) + } + } + + // Verify iteration. + itr := b.Iterator() + for !itr.Done() { + if k, v, ok := itr.Next(); !ok || v != k*2 { + t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k*2) + } + } + + // Verify not found works. + if _, ok := b.Get(10000000); ok { + t.Fatal("expected no value") + } + + // Remove all keys. + for _, key := range keys { + b.Delete(key) + } + if b.Len() != 0 { + t.Fatalf("unexpected size: %d", b.Len()) + } + }) +} + +// TMap represents a combined immutable and stdlib map. +type TMap struct { + im, prev *Map[int, int] + builder *MapBuilder[int, int] + std map[int]int + keys []int +} + +func NewTestMap() *TMap { + return &TMap{ + im: NewMap[int, int](nil), + builder: NewMapBuilder[int, int](nil), + std: make(map[int]int), + } +} + +func (m *TMap) NewKey(rand *rand.Rand) int { + for { + k := rand.Int() + if _, ok := m.std[k]; !ok { + return k + } + } +} + +func (m *TMap) ExistingKey(rand *rand.Rand) int { + if len(m.keys) == 0 { + return 0 + } + return m.keys[rand.Intn(len(m.keys))] +} + +func (m *TMap) Set(k, v int) { + m.prev = m.im + m.im = m.im.Set(k, v) + m.builder.Set(k, v) + + _, exists := m.std[k] + if !exists { + m.keys = append(m.keys, k) + } + m.std[k] = v +} + +func (m *TMap) Delete(k int) { + m.prev = m.im + m.im = m.im.Delete(k) + m.builder.Delete(k) + delete(m.std, k) + + for i := range m.keys { + if m.keys[i] == k { + m.keys = append(m.keys[:i], m.keys[i+1:]...) + break + } + } +} + +func (m *TMap) Validate() error { + for _, k := range m.keys { + if v, ok := m.im.Get(k); !ok { + return fmt.Errorf("key not found: %d", k) + } else if v != m.std[k] { + return fmt.Errorf("key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) + } + if v, ok := m.builder.Get(k); !ok { + return fmt.Errorf("builder key not found: %d", k) + } else if v != m.std[k] { + return fmt.Errorf("builder key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) + } + } + if err := m.validateIterator(m.im.Iterator()); err != nil { + return fmt.Errorf("basic: %s", err) + } else if err := m.validateIterator(m.builder.Iterator()); err != nil { + return fmt.Errorf("builder: %s", err) + } + return nil +} + +func (m *TMap) validateIterator(itr *MapIterator[int, int]) error { + other := make(map[int]int) + for !itr.Done() { + k, v, _ := itr.Next() + other[k] = v + } + if len(other) != len(m.std) { + return fmt.Errorf("map iterator size mismatch: %v!=%v", len(m.std), len(other)) + } + for k, v := range m.std { + if v != other[k] { + return fmt.Errorf("map iterator mismatch: key=%v, %v!=%v", k, v, other[k]) + } + } + if k, v, ok := itr.Next(); ok { + return fmt.Errorf("map iterator returned key/value after done: <%v/%v>", k, v) + } + return nil +} + +func BenchmarkBuiltinMap_Set(b *testing.B) { + b.ReportAllocs() + m := make(map[int]int) + for i := 0; i < b.N; i++ { + m[i] = i + } +} + +func BenchmarkBuiltinMap_Delete(b *testing.B) { + const n = 10000000 + + m := make(map[int]int) + for i := 0; i < n; i++ { + m[i] = i + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + delete(m, i%n) + } +} + +func BenchmarkMap_Set(b *testing.B) { + b.ReportAllocs() + m := NewMap[int, int](nil) + for i := 0; i < b.N; i++ { + m = m.Set(i, i) + } +} + +func BenchmarkMap_Delete(b *testing.B) { + const n = 10000000 + + builder := NewMapBuilder[int, int](nil) + for i := 0; i < n; i++ { + builder.Set(i, i) + } + b.ReportAllocs() + b.ResetTimer() + + m := builder.Map() + for i := 0; i < b.N; i++ { + m.Delete(i % n) // Do not update map, always operate on original + } +} + +func BenchmarkMap_Iterator(b *testing.B) { + const n = 10000 + m := NewMap[int, int](nil) + for i := 0; i < 10000; i++ { + m = m.Set(i, i) + } + b.ReportAllocs() + b.ResetTimer() + + b.Run("Forward", func(b *testing.B) { + itr := m.Iterator() + for i := 0; i < b.N; i++ { + if i%n == 0 { + itr.First() + } + itr.Next() + } + }) +} + +func BenchmarkMapBuilder_Set(b *testing.B) { + b.ReportAllocs() + builder := NewMapBuilder[int, int](nil) + for i := 0; i < b.N; i++ { + builder.Set(i, i) + } +} + +func BenchmarkMapBuilder_Delete(b *testing.B) { + const n = 10000000 + + builder := NewMapBuilder[int, int](nil) + for i := 0; i < n; i++ { + builder.Set(i, i) + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + builder.Delete(i % n) + } +} + +func ExampleMap_Set() { + m := NewMap[string, any](nil) + m = m.Set("foo", "bar") + m = m.Set("baz", 100) + + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + + v, ok = m.Get("bat") // does not exist + fmt.Println("bat", v, ok) + // Output: + // foo bar true + // baz 100 true + // bat false +} + +func ExampleMap_Delete() { + m := NewMap[string, any](nil) + m = m.Set("foo", "bar") + m = m.Set("baz", 100) + m = m.Delete("baz") + + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + // Output: + // foo bar true + // baz false +} + +func ExampleMap_Iterator() { + m := NewMap[string, int](nil) + m = m.Set("apple", 100) + m = m.Set("grape", 200) + m = m.Set("kiwi", 300) + m = m.Set("mango", 400) + m = m.Set("orange", 500) + m = m.Set("peach", 600) + m = m.Set("pear", 700) + m = m.Set("pineapple", 800) + m = m.Set("strawberry", 900) + + itr := m.Iterator() + for !itr.Done() { + k, v, _ := itr.Next() + fmt.Println(k, v) + } + // Output: + // mango 400 + // pear 700 + // pineapple 800 + // grape 200 + // orange 500 + // strawberry 900 + // kiwi 300 + // peach 600 + // apple 100 +} + +func ExampleMapBuilder_Set() { + b := NewMapBuilder[string, any](nil) + b.Set("foo", "bar") + b.Set("baz", 100) + + m := b.Map() + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + + v, ok = m.Get("bat") // does not exist + fmt.Println("bat", v, ok) + // Output: + // foo bar true + // baz 100 true + // bat false +} + +func ExampleMapBuilder_Delete() { + b := NewMapBuilder[string, any](nil) + b.Set("foo", "bar") + b.Set("baz", 100) + b.Delete("baz") + + m := b.Map() + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + // Output: + // foo bar true + // baz false +} + +func TestInternalSortedMapLeafNode(t *testing.T) { + RunRandom(t, "NoSplit", func(t *testing.T, rand *rand.Rand) { + var cmpr defaultComparer[int] + var node sortedMapNode[int, int] = &sortedMapLeafNode[int, int]{} + var keys []int + for _, i := range rand.Perm(32) { + var resized bool + var splitNode sortedMapNode[int, int] + node, splitNode = node.set(i, i*10, &cmpr, false, &resized) + if !resized { + t.Fatal("expected resize") + } else if splitNode != nil { + t.Fatal("expected split") + } + keys = append(keys, i) + + // Verify not found at each size. + if _, ok := node.get(rand.Int()+32, &cmpr); ok { + t.Fatal("expected no value") + } + + // Verify min key is always the lowest. + sort.Ints(keys) + if got, exp := node.minKey(), keys[0]; got != exp { + t.Fatalf("minKey()=%d, expected %d", got, exp) + } + } + + // Verify all key/value pairs in node. + for i := range keys { + if v, ok := node.get(i, &cmpr); !ok || v != i*10 { + t.Fatalf("get(%d)=<%v,%v>", i, v, ok) + } + } + }) + + RunRandom(t, "Overwrite", func(t *testing.T, rand *rand.Rand) { + var cmpr defaultComparer[int] + var node sortedMapNode[int, int] = &sortedMapLeafNode[int, int]{} + + for _, i := range rand.Perm(32) { + var resized bool + node, _ = node.set(i, i*2, &cmpr, false, &resized) + } + for _, i := range rand.Perm(32) { + var resized bool + node, _ = node.set(i, i*3, &cmpr, false, &resized) + if resized { + t.Fatal("expected no resize") + } + } + + // Verify all overwritten key/value pairs in node. + for i := 0; i < 32; i++ { + if v, ok := node.get(i, &cmpr); !ok || v != i*3 { + t.Fatalf("get(%d)=<%v,%v>", i, v, ok) + } + } + }) + + t.Run("Split", func(t *testing.T) { + // Fill leaf node. var cmpr defaultComparer[int] + var cmpr defaultComparer[int] + var node sortedMapNode[int, int] = &sortedMapLeafNode[int, int]{} + for i := 0; i < 32; i++ { + var resized bool + node, _ = node.set(i, i*10, &cmpr, false, &resized) + } + + // Add one more and expect split. + var resized bool + newNode, splitNode := node.set(32, 320, &cmpr, false, &resized) + + // Verify node contents. + newLeafNode, ok := newNode.(*sortedMapLeafNode[int, int]) + if !ok { + t.Fatalf("unexpected node type: %T", newLeafNode) + } else if n := len(newLeafNode.entries); n != 16 { + t.Fatalf("unexpected node len: %d", n) + } + for i := range newLeafNode.entries { + if entry := newLeafNode.entries[i]; entry.key != i || entry.value != i*10 { + t.Fatalf("%d. unexpected entry: %v=%v", i, entry.key, entry.value) + } + } + + // Verify split node contents. + splitLeafNode, ok := splitNode.(*sortedMapLeafNode[int, int]) + if !ok { + t.Fatalf("unexpected split node type: %T", splitLeafNode) + } else if n := len(splitLeafNode.entries); n != 17 { + t.Fatalf("unexpected split node len: %d", n) + } + for i := range splitLeafNode.entries { + if entry := splitLeafNode.entries[i]; entry.key != (i+16) || entry.value != (i+16)*10 { + t.Fatalf("%d. unexpected split node entry: %v=%v", i, entry.key, entry.value) + } + } + }) +} + +func TestInternalSortedMapBranchNode(t *testing.T) { + RunRandom(t, "NoSplit", func(t *testing.T, rand *rand.Rand) { + keys := make([]int, 32*16) + for i := range keys { + keys[i] = rand.Intn(10000) + } + keys = uniqueIntSlice(keys) + sort.Ints(keys[:2]) // ensure first two keys are sorted for initial insert. + + // Initialize branch with two leafs. + var cmpr defaultComparer[int] + leaf0 := &sortedMapLeafNode[int, int]{entries: []mapEntry[int, int]{{key: keys[0], value: keys[0] * 10}}} + leaf1 := &sortedMapLeafNode[int, int]{entries: []mapEntry[int, int]{{key: keys[1], value: keys[1] * 10}}} + var node sortedMapNode[int, int] = newSortedMapBranchNode[int, int](leaf0, leaf1) + + sort.Ints(keys) + for _, i := range rand.Perm(len(keys)) { + key := keys[i] + + var resized bool + var splitNode sortedMapNode[int, int] + node, splitNode = node.set(key, key*10, &cmpr, false, &resized) + if key == leaf0.entries[0].key || key == leaf1.entries[0].key { + if resized { + t.Fatalf("expected no resize: key=%d", key) + } + } else { + if !resized { + t.Fatalf("expected resize: key=%d", key) + } + } + if splitNode != nil { + t.Fatal("unexpected split") + } + } + + // Verify all key/value pairs in node. + for _, key := range keys { + if v, ok := node.get(key, &cmpr); !ok || v != key*10 { + t.Fatalf("get(%d)=<%v,%v>", key, v, ok) + } + } + + // Verify min key is the lowest key. + if got, exp := node.minKey(), keys[0]; got != exp { + t.Fatalf("minKey()=%d, expected %d", got, exp) + } + }) + + t.Run("Split", func(t *testing.T) { + // Generate leaf nodes. + var cmpr defaultComparer[int] + children := make([]sortedMapNode[int, int], 32) + for i := range children { + leaf := &sortedMapLeafNode[int, int]{entries: make([]mapEntry[int, int], 32)} + for j := range leaf.entries { + leaf.entries[j] = mapEntry[int, int]{key: (i * 32) + j, value: ((i * 32) + j) * 100} + } + children[i] = leaf + } + var node sortedMapNode[int, int] = newSortedMapBranchNode(children...) + + // Add one more and expect split. + var resized bool + newNode, splitNode := node.set((32 * 32), (32*32)*100, &cmpr, false, &resized) + + // Verify node contents. + var idx int + newBranchNode, ok := newNode.(*sortedMapBranchNode[int, int]) + if !ok { + t.Fatalf("unexpected node type: %T", newBranchNode) + } else if n := len(newBranchNode.elems); n != 16 { + t.Fatalf("unexpected child elems len: %d", n) + } + for i, elem := range newBranchNode.elems { + child, ok := elem.node.(*sortedMapLeafNode[int, int]) + if !ok { + t.Fatalf("unexpected child type") + } + for j, entry := range child.entries { + if entry.key != idx || entry.value != idx*100 { + t.Fatalf("%d/%d. unexpected entry: %v=%v", i, j, entry.key, entry.value) + } + idx++ + } + } + + // Verify split node contents. + splitBranchNode, ok := splitNode.(*sortedMapBranchNode[int, int]) + if !ok { + t.Fatalf("unexpected split node type: %T", splitBranchNode) + } else if n := len(splitBranchNode.elems); n != 17 { + t.Fatalf("unexpected split node elem len: %d", n) + } + for i, elem := range splitBranchNode.elems { + child, ok := elem.node.(*sortedMapLeafNode[int, int]) + if !ok { + t.Fatalf("unexpected split node child type") + } + for j, entry := range child.entries { + if entry.key != idx || entry.value != idx*100 { + t.Fatalf("%d/%d. unexpected split node entry: %v=%v", i, j, entry.key, entry.value) + } + idx++ + } + } + }) +} + +func TestSortedMap_Get(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + m := NewSortedMap[int, int](nil) + if v, ok := m.Get(100); ok { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + }) +} + +func TestSortedMap_Set(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + m := NewSortedMap[int, string](nil) + m = m.Set(100, "foo") + if v, ok := m.Get(100); !ok || v != "foo" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if got, exp := m.Len(), 1; got != exp { + t.Fatalf("SortedMap.Len()=%d, exp %d", got, exp) + } + }) + + t.Run("Small", func(t *testing.T) { + const n = 1000 + m := NewSortedMap[int, int](nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + t.Run("Large", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping: short") + } + + const n = 1000000 + m := NewSortedMap[int, int](nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + t.Run("StringKeys", func(t *testing.T) { + m := NewSortedMap[string, string](nil) + m = m.Set("foo", "bar") + m = m.Set("baz", "bat") + m = m.Set("", "EMPTY") + if v, ok := m.Get("foo"); !ok || v != "bar" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get("baz"); !ok || v != "bat" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } else if v, ok := m.Get(""); !ok || v != "EMPTY" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + if v, ok := m.Get("no_such_key"); ok { + t.Fatalf("expected no value: <%v,%v>", v, ok) + } + }) + + t.Run("NoDefaultComparer", func(t *testing.T) { + var r string + func() { + defer func() { r = recover().(string) }() + m := NewSortedMap[float64, string](nil) + m = m.Set(float64(100), "bar") + }() + if r != `immutable.NewComparer: must set comparer for float64 type` { + t.Fatalf("unexpected panic: %q", r) + } + }) + + RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { + m := NewTSortedMap() + for j := 0; j < 10000; j++ { + switch rand.Intn(2) { + case 1: // overwrite + m.Set(m.ExistingKey(rand), rand.Intn(10000)) + default: // set new key + m.Set(m.NewKey(rand), rand.Intn(10000)) + } + } + if err := m.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +// Ensure map can support overwrites as it expands. +func TestSortedMap_Overwrite(t *testing.T) { + const n = 1000 + m := NewSortedMap[int, int](nil) + for i := 0; i < n; i++ { + // Set original value. + m = m.Set(i, i) + + // Overwrite every node. + for j := 0; j <= i; j++ { + m = m.Set(j, i*j) + } + } + + // Verify all key/value pairs in map. + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i*(n-1) { + t.Fatalf("Get(%d)=<%v,%v>", i, v, ok) + } + } +} + +func TestSortedMap_Delete(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + m := NewSortedMap[int, int](nil) + m = m.Delete(100) + if n := m.Len(); n != 0 { + t.Fatalf("SortedMap.Len()=%d, expected 0", n) + } + }) + + t.Run("Simple", func(t *testing.T) { + m := NewSortedMap[int, string](nil) + m = m.Set(100, "foo") + if v, ok := m.Get(100); !ok || v != "foo" { + t.Fatalf("unexpected value: <%v,%v>", v, ok) + } + m = m.Delete(100) + if v, ok := m.Get(100); ok { + t.Fatalf("unexpected no value: <%v,%v>", v, ok) + } + }) + + t.Run("Small", func(t *testing.T) { + const n = 1000 + m := NewSortedMap[int, int](nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + + for i := 0; i < n; i++ { + m = m.Delete(i) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); ok { + t.Fatalf("expected no value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + t.Run("Large", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping: short") + } + + const n = 1000000 + m := NewSortedMap[int, int](nil) + for i := 0; i < n; i++ { + m = m.Set(i, i+1) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); !ok || v != i+1 { + t.Fatalf("unexpected value for key=%v: <%v,%v>", i, v, ok) + } + } + + for i := 0; i < n; i++ { + m = m.Delete(i) + } + for i := 0; i < n; i++ { + if v, ok := m.Get(i); ok { + t.Fatalf("unexpected no value for key=%v: <%v,%v>", i, v, ok) + } + } + }) + + RunRandom(t, "Random", func(t *testing.T, rand *rand.Rand) { + m := NewTSortedMap() + for j := 0; j < 10000; j++ { + switch rand.Intn(8) { + case 0: // overwrite + m.Set(m.ExistingKey(rand), rand.Intn(10000)) + case 1: // delete existing key + m.Delete(m.ExistingKey(rand)) + case 2: // delete non-existent key. + m.Delete(m.NewKey(rand)) + default: // set new key + m.Set(m.NewKey(rand), rand.Intn(10000)) + } + } + if err := m.Validate(); err != nil { + t.Fatal(err) + } + + // Delete all keys. + keys := make([]int, len(m.keys)) + copy(keys, m.keys) + for _, k := range keys { + m.Delete(k) + } + if err := m.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +func TestSortedMap_Iterator(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + t.Run("First", func(t *testing.T) { + itr := NewSortedMap[int, int](nil).Iterator() + itr.First() + if k, v, ok := itr.Next(); ok { + t.Fatalf("SortedMapIterator.Next()=<%v,%v>, expected nil", k, v) + } + }) + + t.Run("Last", func(t *testing.T) { + itr := NewSortedMap[int, int](nil).Iterator() + itr.Last() + if k, v, ok := itr.Prev(); ok { + t.Fatalf("SortedMapIterator.Prev()=<%v,%v>, expected nil", k, v) + } + }) + + t.Run("Seek", func(t *testing.T) { + itr := NewSortedMap[string, int](nil).Iterator() + itr.Seek("foo") + if k, v, ok := itr.Next(); ok { + t.Fatalf("SortedMapIterator.Next()=<%v,%v>, expected nil", k, v) + } + }) + }) + + t.Run("Seek", func(t *testing.T) { + const n = 100 + m := NewSortedMap[string, int](nil) + for i := 0; i < n; i += 2 { + m = m.Set(fmt.Sprintf("%04d", i), i) + } + + t.Run("Exact", func(t *testing.T) { + itr := m.Iterator() + for i := 0; i < n; i += 2 { + itr.Seek(fmt.Sprintf("%04d", i)) + for j := i; j < n; j += 2 { + if k, _, ok := itr.Next(); !ok || k != fmt.Sprintf("%04d", j) { + t.Fatalf("%d/%d. SortedMapIterator.Next()=%v, expected key %04d", i, j, k, j) + } + } + if !itr.Done() { + t.Fatalf("SortedMapIterator.Done()=true, expected false") + } + } + }) + + t.Run("Miss", func(t *testing.T) { + itr := m.Iterator() + for i := 1; i < n-2; i += 2 { + itr.Seek(fmt.Sprintf("%04d", i)) + for j := i + 1; j < n; j += 2 { + if k, _, ok := itr.Next(); !ok || k != fmt.Sprintf("%04d", j) { + t.Fatalf("%d/%d. SortedMapIterator.Next()=%v, expected key %04d", i, j, k, j) + } + } + if !itr.Done() { + t.Fatalf("SortedMapIterator.Done()=true, expected false") + } + } + }) + + t.Run("BeforeFirst", func(t *testing.T) { + itr := m.Iterator() + itr.Seek("") + for i := 0; i < n; i += 2 { + if k, _, ok := itr.Next(); !ok || k != fmt.Sprintf("%04d", i) { + t.Fatalf("%d. SortedMapIterator.Next()=%v, expected key %04d", i, k, i) + } + } + if !itr.Done() { + t.Fatalf("SortedMapIterator.Done()=true, expected false") + } + }) + t.Run("AfterLast", func(t *testing.T) { + itr := m.Iterator() + itr.Seek("1000") + if k, _, ok := itr.Next(); ok { + t.Fatalf("0. SortedMapIterator.Next()=%v, expected nil key", k) + } else if !itr.Done() { + t.Fatalf("SortedMapIterator.Done()=true, expected false") + } + }) + }) +} + +func TestNewHasher(t *testing.T) { + t.Run("builtin", func(t *testing.T) { + t.Run("int", func(t *testing.T) { testNewHasher(t, int(100)) }) + t.Run("int8", func(t *testing.T) { testNewHasher(t, int8(100)) }) + t.Run("int16", func(t *testing.T) { testNewHasher(t, int16(100)) }) + t.Run("int32", func(t *testing.T) { testNewHasher(t, int32(100)) }) + t.Run("int64", func(t *testing.T) { testNewHasher(t, int64(100)) }) + + t.Run("uint", func(t *testing.T) { testNewHasher(t, uint(100)) }) + t.Run("uint8", func(t *testing.T) { testNewHasher(t, uint8(100)) }) + t.Run("uint16", func(t *testing.T) { testNewHasher(t, uint16(100)) }) + t.Run("uint32", func(t *testing.T) { testNewHasher(t, uint32(100)) }) + t.Run("uint64", func(t *testing.T) { testNewHasher(t, uint64(100)) }) + + t.Run("string", func(t *testing.T) { testNewHasher(t, "foo") }) + //t.Run("byteSlice", func(t *testing.T) { testNewHasher(t, []byte("foo")) }) + }) + + t.Run("reflection", func(t *testing.T) { + type Int int + t.Run("int", func(t *testing.T) { testNewHasher(t, Int(100)) }) + + type Uint uint + t.Run("uint", func(t *testing.T) { testNewHasher(t, Uint(100)) }) + + type String string + t.Run("string", func(t *testing.T) { testNewHasher(t, String("foo")) }) + }) +} + +func testNewHasher[V cmp.Ordered](t *testing.T, v V) { + t.Helper() + h := NewHasher(v) + h.Hash(v) + if !h.Equal(v, v) { + t.Fatal("expected hash equality") + } +} + +func TestNewComparer(t *testing.T) { + t.Run("builtin", func(t *testing.T) { + t.Run("int", func(t *testing.T) { testNewComparer(t, int(100), int(101)) }) + t.Run("int8", func(t *testing.T) { testNewComparer(t, int8(100), int8(101)) }) + t.Run("int16", func(t *testing.T) { testNewComparer(t, int16(100), int16(101)) }) + t.Run("int32", func(t *testing.T) { testNewComparer(t, int32(100), int32(101)) }) + t.Run("int64", func(t *testing.T) { testNewComparer(t, int64(100), int64(101)) }) + + t.Run("uint", func(t *testing.T) { testNewComparer(t, uint(100), uint(101)) }) + t.Run("uint8", func(t *testing.T) { testNewComparer(t, uint8(100), uint8(101)) }) + t.Run("uint16", func(t *testing.T) { testNewComparer(t, uint16(100), uint16(101)) }) + t.Run("uint32", func(t *testing.T) { testNewComparer(t, uint32(100), uint32(101)) }) + t.Run("uint64", func(t *testing.T) { testNewComparer(t, uint64(100), uint64(101)) }) + + t.Run("string", func(t *testing.T) { testNewComparer(t, "bar", "foo") }) + //t.Run("byteSlice", func(t *testing.T) { testNewComparer(t, []byte("bar"), []byte("foo")) }) + }) + + t.Run("reflection", func(t *testing.T) { + type Int int + t.Run("int", func(t *testing.T) { testNewComparer(t, Int(100), Int(101)) }) + + type Uint uint + t.Run("uint", func(t *testing.T) { testNewComparer(t, Uint(100), Uint(101)) }) + + type String string + t.Run("string", func(t *testing.T) { testNewComparer(t, String("bar"), String("foo")) }) + }) +} + +func testNewComparer[T cmp.Ordered](t *testing.T, x, y T) { + t.Helper() + c := NewComparer(x) + if c.Compare(x, y) != -1 { + t.Fatal("expected comparer LT") + } else if c.Compare(x, x) != 0 { + t.Fatal("expected comparer EQ") + } else if c.Compare(y, x) != 1 { + t.Fatal("expected comparer GT") + } +} + +// TSortedMap represents a combined immutable and stdlib sorted map. +type TSortedMap struct { + im, prev *SortedMap[int, int] + builder *SortedMapBuilder[int, int] + std map[int]int + keys []int +} + +func NewTSortedMap() *TSortedMap { + return &TSortedMap{ + im: NewSortedMap[int, int](nil), + builder: NewSortedMapBuilder[int, int](nil), + std: make(map[int]int), + } +} + +func (m *TSortedMap) NewKey(rand *rand.Rand) int { + for { + k := rand.Int() + if _, ok := m.std[k]; !ok { + return k + } + } +} + +func (m *TSortedMap) ExistingKey(rand *rand.Rand) int { + if len(m.keys) == 0 { + return 0 + } + return m.keys[rand.Intn(len(m.keys))] +} + +func (m *TSortedMap) Set(k, v int) { + m.prev = m.im + m.im = m.im.Set(k, v) + m.builder.Set(k, v) + + if _, ok := m.std[k]; !ok { + m.keys = append(m.keys, k) + sort.Ints(m.keys) + } + m.std[k] = v +} + +func (m *TSortedMap) Delete(k int) { + m.prev = m.im + m.im = m.im.Delete(k) + m.builder.Delete(k) + delete(m.std, k) + + for i := range m.keys { + if m.keys[i] == k { + m.keys = append(m.keys[:i], m.keys[i+1:]...) + break + } + } +} + +func (m *TSortedMap) Validate() error { + for _, k := range m.keys { + if v, ok := m.im.Get(k); !ok { + return fmt.Errorf("key not found: %d", k) + } else if v != m.std[k] { + return fmt.Errorf("key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) + } + if v, ok := m.builder.Get(k); !ok { + return fmt.Errorf("builder key not found: %d", k) + } else if v != m.std[k] { + return fmt.Errorf("builder key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) + } + } + + if got, exp := m.builder.Len(), len(m.std); got != exp { + return fmt.Errorf("SortedMapBuilder.Len()=%d, expected %d", got, exp) + } + + sort.Ints(m.keys) + if err := m.validateForwardIterator(m.im.Iterator()); err != nil { + return fmt.Errorf("basic: %s", err) + } else if err := m.validateBackwardIterator(m.im.Iterator()); err != nil { + return fmt.Errorf("basic: %s", err) + } + + if err := m.validateForwardIterator(m.builder.Iterator()); err != nil { + return fmt.Errorf("basic: %s", err) + } else if err := m.validateBackwardIterator(m.builder.Iterator()); err != nil { + return fmt.Errorf("basic: %s", err) + } + return nil +} + +func (m *TSortedMap) validateForwardIterator(itr *SortedMapIterator[int, int]) error { + for i, k0 := range m.keys { + v0 := m.std[k0] + if k1, v1, ok := itr.Next(); !ok || k0 != k1 || v0 != v1 { + return fmt.Errorf("%d. SortedMapIterator.Next()=<%v,%v>, expected <%v,%v>", i, k1, v1, k0, v0) + } + + done := i == len(m.keys)-1 + if v := itr.Done(); v != done { + return fmt.Errorf("%d. SortedMapIterator.Done()=%v, expected %v", i, v, done) + } + } + if k, v, ok := itr.Next(); ok { + return fmt.Errorf("SortedMapIterator.Next()=<%v,%v>, expected nil after done", k, v) + } + return nil +} + +func (m *TSortedMap) validateBackwardIterator(itr *SortedMapIterator[int, int]) error { + itr.Last() + for i := len(m.keys) - 1; i >= 0; i-- { + k0 := m.keys[i] + v0 := m.std[k0] + if k1, v1, ok := itr.Prev(); !ok || k0 != k1 || v0 != v1 { + return fmt.Errorf("%d. SortedMapIterator.Prev()=<%v,%v>, expected <%v,%v>", i, k1, v1, k0, v0) + } + + done := i == 0 + if v := itr.Done(); v != done { + return fmt.Errorf("%d. SortedMapIterator.Done()=%v, expected %v", i, v, done) + } + } + if k, v, ok := itr.Prev(); ok { + return fmt.Errorf("SortedMapIterator.Prev()=<%v,%v>, expected nil after done", k, v) + } + return nil +} + +func BenchmarkSortedMap_Set(b *testing.B) { + b.ReportAllocs() + m := NewSortedMap[int, int](nil) + for i := 0; i < b.N; i++ { + m = m.Set(i, i) + } +} + +func BenchmarkSortedMap_Delete(b *testing.B) { + const n = 10000 + + m := NewSortedMap[int, int](nil) + for i := 0; i < n; i++ { + m = m.Set(i, i) + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + m.Delete(i % n) // Do not update map, always operate on original + } +} + +func BenchmarkSortedMap_Iterator(b *testing.B) { + const n = 10000 + m := NewSortedMap[int, int](nil) + for i := 0; i < 10000; i++ { + m = m.Set(i, i) + } + b.ReportAllocs() + b.ResetTimer() + + b.Run("Forward", func(b *testing.B) { + itr := m.Iterator() + for i := 0; i < b.N; i++ { + if i%n == 0 { + itr.First() + } + itr.Next() + } + }) + + b.Run("Reverse", func(b *testing.B) { + itr := m.Iterator() + for i := 0; i < b.N; i++ { + if i%n == 0 { + itr.Last() + } + itr.Prev() + } + }) +} + +func BenchmarkSortedMapBuilder_Set(b *testing.B) { + b.ReportAllocs() + builder := NewSortedMapBuilder[int, int](nil) + for i := 0; i < b.N; i++ { + builder.Set(i, i) + } +} + +func BenchmarkSortedMapBuilder_Delete(b *testing.B) { + const n = 1000000 + + builder := NewSortedMapBuilder[int, int](nil) + for i := 0; i < n; i++ { + builder.Set(i, i) + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + builder.Delete(i % n) + } +} + +func ExampleSortedMap_Set() { + m := NewSortedMap[string, any](nil) + m = m.Set("foo", "bar") + m = m.Set("baz", 100) + + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + + v, ok = m.Get("bat") // does not exist + fmt.Println("bat", v, ok) + // Output: + // foo bar true + // baz 100 true + // bat false +} + +func ExampleSortedMap_Delete() { + m := NewSortedMap[string, any](nil) + m = m.Set("foo", "bar") + m = m.Set("baz", 100) + m = m.Delete("baz") + + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + // Output: + // foo bar true + // baz false +} + +func ExampleSortedMap_Iterator() { + m := NewSortedMap[string, any](nil) + m = m.Set("strawberry", 900) + m = m.Set("kiwi", 300) + m = m.Set("apple", 100) + m = m.Set("pear", 700) + m = m.Set("pineapple", 800) + m = m.Set("peach", 600) + m = m.Set("orange", 500) + m = m.Set("grape", 200) + m = m.Set("mango", 400) + + itr := m.Iterator() + for !itr.Done() { + k, v, _ := itr.Next() + fmt.Println(k, v) + } + // Output: + // apple 100 + // grape 200 + // kiwi 300 + // mango 400 + // orange 500 + // peach 600 + // pear 700 + // pineapple 800 + // strawberry 900 +} + +func ExampleSortedMapBuilder_Set() { + b := NewSortedMapBuilder[string, any](nil) + b.Set("foo", "bar") + b.Set("baz", 100) + + m := b.Map() + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + + v, ok = m.Get("bat") // does not exist + fmt.Println("bat", v, ok) + // Output: + // foo bar true + // baz 100 true + // bat false +} + +func ExampleSortedMapBuilder_Delete() { + b := NewSortedMapBuilder[string, any](nil) + b.Set("foo", "bar") + b.Set("baz", 100) + b.Delete("baz") + + m := b.Map() + v, ok := m.Get("foo") + fmt.Println("foo", v, ok) + + v, ok = m.Get("baz") + fmt.Println("baz", v, ok) + // Output: + // foo bar true + // baz false +} + +// RunRandom executes fn multiple times with a different rand. +func RunRandom(t *testing.T, name string, fn func(t *testing.T, rand *rand.Rand)) { + if testing.Short() { + t.Skip("short mode") + } + t.Run(name, func(t *testing.T) { + for i := 0; i < *randomN; i++ { + i := i + t.Run(fmt.Sprintf("%08d", i), func(t *testing.T) { + t.Parallel() + fn(t, rand.New(rand.NewSource(int64(i)))) + }) + } + }) +} + +func uniqueIntSlice(a []int) []int { + m := make(map[int]struct{}) + other := make([]int, 0, len(a)) + for _, v := range a { + if _, ok := m[v]; ok { + continue + } + m[v] = struct{}{} + other = append(other, v) + } + return other +} + +// mockHasher represents a mock implementation of immutable.Hasher. +type mockHasher[K cmp.Ordered] struct { + hash func(value K) uint32 + equal func(a, b K) bool +} + +// Hash executes the mocked HashFn function. +func (h *mockHasher[K]) Hash(value K) uint32 { + return h.hash(value) +} + +// Equal executes the mocked EqualFn function. +func (h *mockHasher[K]) Equal(a, b K) bool { + return h.equal(a, b) +} + +// mockComparer represents a mock implementation of immutable.Comparer. +type mockComparer[K cmp.Ordered] struct { + compare func(a, b K) int +} + +// Compare executes the mocked CompreFn function. +func (h *mockComparer[K]) Compare(a, b K) int { + return h.compare(a, b) +} + +func TestSetsPut(t *testing.T) { + s := NewSet[string](nil) + s2 := s.Add("1").Add("1") + s2.Add("2") + if s.Len() != 0 { + t.Fatalf("Unexpected mutation of set") + } + if s.Has("1") { + t.Fatalf("Unexpected set element") + } + if s2.Len() != 1 { + t.Fatalf("Unexpected non-mutation of set") + } + if !s2.Has("1") { + t.Fatalf("Set element missing") + } + itr := s2.Iterator() + counter := 0 + for !itr.Done() { + i, v := itr.Next() + t.Log(i, v) + counter++ + } + if counter != 1 { + t.Fatalf("iterator wrong length") + } +} + +func TestSetsDelete(t *testing.T) { + s := NewSet[string](nil) + s2 := s.Add("1") + s3 := s.Delete("1") + if s2.Len() != 1 { + t.Fatalf("Unexpected non-mutation of set") + } + if !s2.Has("1") { + t.Fatalf("Set element missing") + } + if s3.Len() != 0 { + t.Fatalf("Unexpected set length after delete") + } + if s3.Has("1") { + t.Fatalf("Unexpected set element after delete") + } +} + +func TestSortedSetsPut(t *testing.T) { + s := NewSortedSet[string](nil) + s2 := s.Add("1").Add("1").Add("0") + if s.Len() != 0 { + t.Fatalf("Unexpected mutation of set") + } + if s.Has("1") { + t.Fatalf("Unexpected set element") + } + if s2.Len() != 2 { + t.Fatalf("Unexpected non-mutation of set") + } + if !s2.Has("1") { + t.Fatalf("Set element missing") + } + + itr := s2.Iterator() + counter := 0 + for !itr.Done() { + i, v := itr.Next() + t.Log(i, v) + if counter == 0 && i != "0" { + t.Fatalf("sort did not work for first el") + } + if counter == 1 && i != "1" { + t.Fatalf("sort did not work for second el") + } + counter++ + } + if counter != 2 { + t.Fatalf("iterator wrong length") + } +} + +func TestSortedSetsDelete(t *testing.T) { + s := NewSortedSet[string](nil) + s2 := s.Add("1") + s3 := s.Delete("1") + if s2.Len() != 1 { + t.Fatalf("Unexpected non-mutation of set") + } + if !s2.Has("1") { + t.Fatalf("Set element missing") + } + if s3.Len() != 0 { + t.Fatalf("Unexpected set length after delete") + } + if s3.Has("1") { + t.Fatalf("Unexpected set element after delete") + } +} + +func TestSortedSetBuilder(t *testing.T) { + b := NewSortedSetBuilder[string](nil) + b.Set("test3") + b.Set("test1") + b.Set("test2") + + s := b.SortedSet() + items := s.Items() + + if len(items) != 3 { + t.Fatalf("Set has wrong number of items") + } + if items[0] != "test1" { + t.Fatalf("First item incorrectly sorted") + } + if items[1] != "test2" { + t.Fatalf("Second item incorrectly sorted") + } + if items[2] != "test3" { + t.Fatalf("Third item incorrectly sorted") + } +} + + + +func MainTest() { +} -- cgit v1.2.3