aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--immutable.go453
-rw-r--r--immutable_test.go98
2 files changed, 526 insertions, 25 deletions
diff --git a/immutable.go b/immutable.go
index c1a8aef..5adb01a 100644
--- a/immutable.go
+++ b/immutable.go
@@ -45,6 +45,7 @@ import (
"bytes"
"fmt"
"math/bits"
+ "reflect"
"sort"
"strings"
)
@@ -721,16 +722,7 @@ func (m *Map) set(key, value interface{}, mutable bool) *Map {
// Set a hasher on the first value if one does not already exist.
hasher := m.hasher
if hasher == nil {
- switch key.(type) {
- case int:
- hasher = &intHasher{}
- case string:
- hasher = &stringHasher{}
- case []byte:
- hasher = &byteSliceHasher{}
- default:
- panic(fmt.Sprintf("immutable.Map.Set: must set hasher for %T type", key))
- }
+ hasher = NewHasher(key)
}
// Generate copy if necessary.
@@ -1589,16 +1581,7 @@ func (m *SortedMap) set(key, value interface{}, mutable bool) *SortedMap {
// Set a comparer on the first value if one does not already exist.
comparer := m.comparer
if comparer == nil {
- switch key.(type) {
- case int:
- comparer = &intComparer{}
- case string:
- comparer = &stringComparer{}
- case []byte:
- comparer = &byteSliceComparer{}
- default:
- panic(fmt.Sprintf("immutable.SortedMap.Set: must set comparer for %T type", key))
- }
+ comparer = NewComparer(key)
}
// Create copy, if necessary.
@@ -2205,6 +2188,52 @@ type Hasher interface {
Equal(a, b interface{}) bool
}
+// NewHasher returns the built-in hasher for a given key type.
+func NewHasher(key interface{}) Hasher {
+ // Attempt to use non-reflection based hasher first.
+ switch key.(type) {
+ case int:
+ return &intHasher{}
+ case int8:
+ return &int8Hasher{}
+ case int16:
+ return &int16Hasher{}
+ case int32:
+ return &int32Hasher{}
+ case int64:
+ return &int64Hasher{}
+ case uint:
+ return &uintHasher{}
+ case uint8:
+ return &uint8Hasher{}
+ case uint16:
+ return &uint16Hasher{}
+ case uint32:
+ return &uint32Hasher{}
+ case uint64:
+ return &uint64Hasher{}
+ case string:
+ return &stringHasher{}
+ case []byte:
+ return &byteSliceHasher{}
+ }
+
+ // Fallback to reflection-based hasher otherwise.
+ // This is used when caller wraps a type around a primitive type.
+ switch reflect.TypeOf(key).Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return &reflectIntHasher{}
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return &reflectUintHasher{}
+ case reflect.String:
+ return &reflectStringHasher{}
+ }
+
+ // If no hashers match then panic.
+ // This is a compile time issue so it should not return an error.
+ panic(fmt.Sprintf("immutable.NewHasher: must set hasher for %T type", key))
+}
+
// intHasher implements Hasher for int keys.
type intHasher struct{}
@@ -2219,6 +2248,132 @@ func (h *intHasher) Equal(a, b interface{}) bool {
return a.(int) == b.(int)
}
+// int8Hasher implements Hasher for int8 keys.
+type int8Hasher struct{}
+
+// Hash returns a hash for key.
+func (h *int8Hasher) Hash(key interface{}) uint32 {
+ return hashUint64(uint64(key.(int8)))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not int8s.
+func (h *int8Hasher) Equal(a, b interface{}) bool {
+ return a.(int8) == b.(int8)
+}
+
+// int16Hasher implements Hasher for int16 keys.
+type int16Hasher struct{}
+
+// Hash returns a hash for key.
+func (h *int16Hasher) Hash(key interface{}) uint32 {
+ return hashUint64(uint64(key.(int16)))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not int16s.
+func (h *int16Hasher) Equal(a, b interface{}) bool {
+ return a.(int16) == b.(int16)
+}
+
+// int32Hasher implements Hasher for int32 keys.
+type int32Hasher struct{}
+
+// Hash returns a hash for key.
+func (h *int32Hasher) Hash(key interface{}) uint32 {
+ return hashUint64(uint64(key.(int32)))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not int32s.
+func (h *int32Hasher) Equal(a, b interface{}) bool {
+ return a.(int32) == b.(int32)
+}
+
+// int64Hasher implements Hasher for int64 keys.
+type int64Hasher struct{}
+
+// Hash returns a hash for key.
+func (h *int64Hasher) Hash(key interface{}) uint32 {
+ return hashUint64(uint64(key.(int64)))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not int64s.
+func (h *int64Hasher) Equal(a, b interface{}) bool {
+ return a.(int64) == b.(int64)
+}
+
+// uintHasher implements Hasher for uint keys.
+type uintHasher struct{}
+
+// Hash returns a hash for key.
+func (h *uintHasher) Hash(key interface{}) uint32 {
+ return hashUint64(uint64(key.(uint)))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not uints.
+func (h *uintHasher) Equal(a, b interface{}) bool {
+ return a.(uint) == b.(uint)
+}
+
+// uint8Hasher implements Hasher for uint8 keys.
+type uint8Hasher struct{}
+
+// Hash returns a hash for key.
+func (h *uint8Hasher) Hash(key interface{}) uint32 {
+ return hashUint64(uint64(key.(uint8)))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not uint8s.
+func (h *uint8Hasher) Equal(a, b interface{}) bool {
+ return a.(uint8) == b.(uint8)
+}
+
+// uint16Hasher implements Hasher for uint16 keys.
+type uint16Hasher struct{}
+
+// Hash returns a hash for key.
+func (h *uint16Hasher) Hash(key interface{}) uint32 {
+ return hashUint64(uint64(key.(uint16)))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not uint16s.
+func (h *uint16Hasher) Equal(a, b interface{}) bool {
+ return a.(uint16) == b.(uint16)
+}
+
+// uint32Hasher implements Hasher for uint32 keys.
+type uint32Hasher struct{}
+
+// Hash returns a hash for key.
+func (h *uint32Hasher) Hash(key interface{}) uint32 {
+ return hashUint64(uint64(key.(uint32)))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not uint32s.
+func (h *uint32Hasher) Equal(a, b interface{}) bool {
+ return a.(uint32) == b.(uint32)
+}
+
+// uint64Hasher implements Hasher for uint64 keys.
+type uint64Hasher struct{}
+
+// Hash returns a hash for key.
+func (h *uint64Hasher) Hash(key interface{}) uint32 {
+ return hashUint64(key.(uint64))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not uint64s.
+func (h *uint64Hasher) Equal(a, b interface{}) bool {
+ return a.(uint64) == b.(uint64)
+}
+
// stringHasher implements Hasher for string keys.
type stringHasher struct{}
@@ -2237,7 +2392,7 @@ func (h *stringHasher) Equal(a, b interface{}) bool {
return a.(string) == b.(string)
}
-// byteSliceHasher implements Hasher for string keys.
+// byteSliceHasher implements Hasher for byte slice keys.
type byteSliceHasher struct{}
// Hash returns a hash for value.
@@ -2255,6 +2410,53 @@ func (h *byteSliceHasher) Equal(a, b interface{}) bool {
return bytes.Equal(a.([]byte), b.([]byte))
}
+// reflectIntHasher implements a reflection-based Hasher for int keys.
+type reflectIntHasher struct{}
+
+// Hash returns a hash for key.
+func (h *reflectIntHasher) Hash(key interface{}) uint32 {
+ return hashUint64(uint64(reflect.ValueOf(key).Int()))
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not ints.
+func (h *reflectIntHasher) Equal(a, b interface{}) bool {
+ return reflect.ValueOf(a).Int() == reflect.ValueOf(b).Int()
+}
+
+// reflectUintHasher implements a reflection-based Hasher for uint keys.
+type reflectUintHasher struct{}
+
+// Hash returns a hash for key.
+func (h *reflectUintHasher) Hash(key interface{}) uint32 {
+ return hashUint64(reflect.ValueOf(key).Uint())
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not ints.
+func (h *reflectUintHasher) Equal(a, b interface{}) bool {
+ return reflect.ValueOf(a).Uint() == reflect.ValueOf(b).Uint()
+}
+
+// reflectStringHasher implements a refletion-based Hasher for string keys.
+type reflectStringHasher struct{}
+
+// Hash returns a hash for value.
+func (h *reflectStringHasher) Hash(value interface{}) uint32 {
+ var hash uint32
+ s := reflect.ValueOf(value).String()
+ for i := 0; i < len(s); i++ {
+ hash = 31*hash + uint32(s[i])
+ }
+ return hash
+}
+
+// Equal returns true if a is equal to b. Otherwise returns false.
+// Panics if a and b are not strings.
+func (h *reflectStringHasher) Equal(a, b interface{}) bool {
+ return reflect.ValueOf(a).String() == reflect.ValueOf(b).String()
+}
+
// hashUint64 returns a 32-bit hash for a 64-bit value.
func hashUint64(value uint64) uint32 {
hash := value
@@ -2272,6 +2474,52 @@ type Comparer interface {
Compare(a, b interface{}) int
}
+// NewComparer returns the built-in comparer for a given key type.
+func NewComparer(key interface{}) Comparer {
+ // Attempt to use non-reflection based comparer first.
+ switch key.(type) {
+ case int:
+ return &intComparer{}
+ case int8:
+ return &int8Comparer{}
+ case int16:
+ return &int16Comparer{}
+ case int32:
+ return &int32Comparer{}
+ case int64:
+ return &int64Comparer{}
+ case uint:
+ return &uintComparer{}
+ case uint8:
+ return &uint8Comparer{}
+ case uint16:
+ return &uint16Comparer{}
+ case uint32:
+ return &uint32Comparer{}
+ case uint64:
+ return &uint64Comparer{}
+ case string:
+ return &stringComparer{}
+ case []byte:
+ return &byteSliceComparer{}
+ }
+
+ // Fallback to reflection-based comparer otherwise.
+ // This is used when caller wraps a type around a primitive type.
+ switch reflect.TypeOf(key).Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return &reflectIntComparer{}
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return &reflectUintComparer{}
+ case reflect.String:
+ return &reflectStringComparer{}
+ }
+
+ // If no comparers match then panic.
+ // This is a compile time issue so it should not return an error.
+ panic(fmt.Sprintf("immutable.NewComparer: must set comparer for %T type", key))
+}
+
// intComparer compares two integers. Implements Comparer.
type intComparer struct{}
@@ -2286,6 +2534,132 @@ func (c *intComparer) Compare(a, b interface{}) int {
return 0
}
+// int8Comparer compares two int8 values. Implements Comparer.
+type int8Comparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an int8.
+func (c *int8Comparer) Compare(a, b interface{}) int {
+ if i, j := a.(int8), b.(int8); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// int16Comparer compares two int16 values. Implements Comparer.
+type int16Comparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an int16.
+func (c *int16Comparer) Compare(a, b interface{}) int {
+ if i, j := a.(int16), b.(int16); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// int32Comparer compares two int32 values. Implements Comparer.
+type int32Comparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an int32.
+func (c *int32Comparer) Compare(a, b interface{}) int {
+ if i, j := a.(int32), b.(int32); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// int64Comparer compares two int64 values. Implements Comparer.
+type int64Comparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an int64.
+func (c *int64Comparer) Compare(a, b interface{}) int {
+ if i, j := a.(int64), b.(int64); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// uintComparer compares two uint values. Implements Comparer.
+type uintComparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an uint.
+func (c *uintComparer) Compare(a, b interface{}) int {
+ if i, j := a.(uint), b.(uint); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// uint8Comparer compares two uint8 values. Implements Comparer.
+type uint8Comparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an uint8.
+func (c *uint8Comparer) Compare(a, b interface{}) int {
+ if i, j := a.(uint8), b.(uint8); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// uint16Comparer compares two uint16 values. Implements Comparer.
+type uint16Comparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an uint16.
+func (c *uint16Comparer) Compare(a, b interface{}) int {
+ if i, j := a.(uint16), b.(uint16); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// uint32Comparer compares two uint32 values. Implements Comparer.
+type uint32Comparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an uint32.
+func (c *uint32Comparer) Compare(a, b interface{}) int {
+ if i, j := a.(uint32), b.(uint32); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// uint64Comparer compares two uint64 values. Implements Comparer.
+type uint64Comparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an uint64.
+func (c *uint64Comparer) Compare(a, b interface{}) int {
+ if i, j := a.(uint64), b.(uint64); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
// stringComparer compares two strings. Implements Comparer.
type stringComparer struct{}
@@ -2303,3 +2677,40 @@ type byteSliceComparer struct{}
func (c *byteSliceComparer) Compare(a, b interface{}) int {
return bytes.Compare(a.([]byte), b.([]byte))
}
+
+// reflectIntComparer compares two int values using reflection. Implements Comparer.
+type reflectIntComparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an int.
+func (c *reflectIntComparer) Compare(a, b interface{}) int {
+ if i, j := reflect.ValueOf(a).Int(), reflect.ValueOf(b).Int(); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// reflectUintComparer compares two uint values using reflection. Implements Comparer.
+type reflectUintComparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an int.
+func (c *reflectUintComparer) Compare(a, b interface{}) int {
+ if i, j := reflect.ValueOf(a).Uint(), reflect.ValueOf(b).Uint(); i < j {
+ return -1
+ } else if i > j {
+ return 1
+ }
+ return 0
+}
+
+// reflectStringComparer compares two string values using reflection. Implements Comparer.
+type reflectStringComparer struct{}
+
+// Compare returns -1 if a is less than b, returns 1 if a is greater than b, and
+// returns 0 if a is equal to b. Panic if a or b is not an int.
+func (c *reflectStringComparer) Compare(a, b interface{}) int {
+ return strings.Compare(reflect.ValueOf(a).String(), reflect.ValueOf(b).String())
+}
diff --git a/immutable_test.go b/immutable_test.go
index c6dfc13..9275f1d 100644
--- a/immutable_test.go
+++ b/immutable_test.go
@@ -1010,13 +1010,14 @@ func TestMap_Set(t *testing.T) {
})
t.Run("NoDefaultHasher", func(t *testing.T) {
+ type T struct{}
var r string
func() {
defer func() { r = recover().(string) }()
m := NewMap(nil)
- m = m.Set(uint64(100), "bar")
+ m = m.Set(T{}, "bar")
}()
- if r != `immutable.Map.Set: must set hasher for uint64 type` {
+ if r != `immutable.NewHasher: must set hasher for immutable.T type` {
t.Fatalf("unexpected panic: %q", r)
}
})
@@ -1039,6 +1040,10 @@ func TestMap_Set(t *testing.T) {
// Ensure map can support overwrites as it expands.
func TestMap_Overwrite(t *testing.T) {
+ if testing.Short() {
+ t.Skip("short mode")
+ }
+
const n = 10000
m := NewMap(nil)
for i := 0; i < n; i++ {
@@ -1140,6 +1145,10 @@ func TestMap_Delete(t *testing.T) {
// Ensure map works even with hash conflicts.
func TestMap_LimitedHash(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping: short")
+ }
+
t.Run("Immutable", func(t *testing.T) {
h := mockHasher{
hash: func(value interface{}) uint32 { return hashUint64(uint64(value.(int))) % 0xFF },
@@ -1837,9 +1846,9 @@ func TestSortedMap_Set(t *testing.T) {
func() {
defer func() { r = recover().(string) }()
m := NewSortedMap(nil)
- m = m.Set(uint64(100), "bar")
+ m = m.Set(float64(100), "bar")
}()
- if r != `immutable.SortedMap.Set: must set comparer for uint64 type` {
+ if r != `immutable.NewComparer: must set comparer for float64 type` {
t.Fatalf("unexpected panic: %q", r)
}
})
@@ -2069,6 +2078,87 @@ func TestSortedMap_Iterator(t *testing.T) {
})
}
+func TestNewHasher(t *testing.T) {
+ t.Run("builtin", func(t *testing.T) {
+ t.Run("int", func(t *testing.T) { testNewHasher(t, int(100)) })
+ t.Run("int8", func(t *testing.T) { testNewHasher(t, int8(100)) })
+ t.Run("int16", func(t *testing.T) { testNewHasher(t, int16(100)) })
+ t.Run("int32", func(t *testing.T) { testNewHasher(t, int32(100)) })
+ t.Run("int64", func(t *testing.T) { testNewHasher(t, int64(100)) })
+
+ t.Run("uint", func(t *testing.T) { testNewHasher(t, uint(100)) })
+ t.Run("uint8", func(t *testing.T) { testNewHasher(t, uint8(100)) })
+ t.Run("uint16", func(t *testing.T) { testNewHasher(t, uint16(100)) })
+ t.Run("uint32", func(t *testing.T) { testNewHasher(t, uint32(100)) })
+ t.Run("uint64", func(t *testing.T) { testNewHasher(t, uint64(100)) })
+
+ t.Run("string", func(t *testing.T) { testNewHasher(t, "foo") })
+ t.Run("byteSlice", func(t *testing.T) { testNewHasher(t, []byte("foo")) })
+ })
+
+ t.Run("reflection", func(t *testing.T) {
+ type Int int
+ t.Run("int", func(t *testing.T) { testNewHasher(t, Int(100)) })
+
+ type Uint uint
+ t.Run("uint", func(t *testing.T) { testNewHasher(t, Uint(100)) })
+
+ type String string
+ t.Run("string", func(t *testing.T) { testNewHasher(t, String("foo")) })
+ })
+}
+
+func testNewHasher(t *testing.T, v interface{}) {
+ t.Helper()
+ h := NewHasher(v)
+ h.Hash(v)
+ if !h.Equal(v, v) {
+ t.Fatal("expected hash equality")
+ }
+}
+
+func TestNewComparer(t *testing.T) {
+ t.Run("builtin", func(t *testing.T) {
+ t.Run("int", func(t *testing.T) { testNewComparer(t, int(100), int(101)) })
+ t.Run("int8", func(t *testing.T) { testNewComparer(t, int8(100), int8(101)) })
+ t.Run("int16", func(t *testing.T) { testNewComparer(t, int16(100), int16(101)) })
+ t.Run("int32", func(t *testing.T) { testNewComparer(t, int32(100), int32(101)) })
+ t.Run("int64", func(t *testing.T) { testNewComparer(t, int64(100), int64(101)) })
+
+ t.Run("uint", func(t *testing.T) { testNewComparer(t, uint(100), uint(101)) })
+ t.Run("uint8", func(t *testing.T) { testNewComparer(t, uint8(100), uint8(101)) })
+ t.Run("uint16", func(t *testing.T) { testNewComparer(t, uint16(100), uint16(101)) })
+ t.Run("uint32", func(t *testing.T) { testNewComparer(t, uint32(100), uint32(101)) })
+ t.Run("uint64", func(t *testing.T) { testNewComparer(t, uint64(100), uint64(101)) })
+
+ t.Run("string", func(t *testing.T) { testNewComparer(t, "bar", "foo") })
+ t.Run("byteSlice", func(t *testing.T) { testNewComparer(t, []byte("bar"), []byte("foo")) })
+ })
+
+ t.Run("reflection", func(t *testing.T) {
+ type Int int
+ t.Run("int", func(t *testing.T) { testNewComparer(t, Int(100), Int(101)) })
+
+ type Uint uint
+ t.Run("uint", func(t *testing.T) { testNewComparer(t, Uint(100), Uint(101)) })
+
+ type String string
+ t.Run("string", func(t *testing.T) { testNewComparer(t, String("bar"), String("foo")) })
+ })
+}
+
+func testNewComparer(t *testing.T, x, y interface{}) {
+ t.Helper()
+ c := NewComparer(x)
+ if c.Compare(x, y) != -1 {
+ t.Fatal("expected comparer LT")
+ } else if c.Compare(x, x) != 0 {
+ t.Fatal("expected comparer EQ")
+ } else if c.Compare(y, x) != 1 {
+ t.Fatal("expected comparer GT")
+ }
+}
+
// TSortedMap represents a combined immutable and stdlib sorted map.
type TSortedMap struct {
im, prev *SortedMap