// Package hashmap implements persistent hashmap.
package hashmap

import (
    "bytes"
    "encoding"
    "encoding/json"
    "fmt"
    "math/bits"
    "reflect"
    "strconv"
)

// Parameters of the tree. Each level of the tree consumes chunkBits bits of
// the 32-bit hash code: nodeCap is the resulting maximum number of children
// of a node, and chunkMask extracts the lowest chunkBits bits of a value.
const (
    chunkBits = 5
    nodeCap   = 1 << chunkBits
    chunkMask = nodeCap - 1
)

// Equal is the type of a function that reports whether two keys are equal.
// The map calls it to compare a key being looked up, added or removed against
// the keys already stored.
type Equal func(k1, k2 any) bool

// Hash is the type of a function that returns the hash code of a key. The
// 32-bit hash code determines the position of the key in the tree.
type Hash func(k any) uint32

// New takes an equality function and a hash function, and returns an empty
// Map.
func New(e Equal, h Hash) Map {
    return &hashMap{0, emptyBitmapNode, nil, e, h}

type hashMap struct {
    count int
    root  node
    nilV  *any
    equal Equal
    hash  Hash

func (m *hashMap) Len() int {
    return m.count

func (m *hashMap) Index(k any) (any, bool) {
    if k == nil {
        if m.nilV == nil {
            return nil, false
        return *m.nilV, true
    return m.root.find(0, m.hash(k), k, m.equal)

func (m *hashMap) Assoc(k, v any) Map {
    if k == nil {
        newCount := m.count
        if m.nilV == nil {
        return &hashMap{newCount, m.root, &v, m.equal, m.hash}
    newRoot, added := m.root.assoc(0, m.hash(k), k, v, m.hash, m.equal)
    newCount := m.count
    if added {
    return &hashMap{newCount, newRoot, m.nilV, m.equal, m.hash}

func (m *hashMap) Dissoc(k any) Map {
    if k == nil {
        newCount := m.count
        if m.nilV != nil {
        return &hashMap{newCount, m.root, nil, m.equal, m.hash}
    newRoot, deleted := m.root.without(0, m.hash(k), k, m.equal)
    newCount := m.count
    if deleted {
    return &hashMap{newCount, newRoot, m.nilV, m.equal, m.hash}

func (m *hashMap) Iterator() Iterator {
    if m.nilV != nil {
        return &nilVIterator{true, *m.nilV, m.root.iterator()}
    return m.root.iterator()

type nilVIterator struct {
    atNil bool
    nilV  any
    tail  Iterator

func (it *nilVIterator) Elem() (any, any) {
    if it.atNil {
        return nil, it.nilV
    return it.tail.Elem()

func (it *nilVIterator) HasElem() bool {
    return it.atNil || it.tail.HasElem()

func (it *nilVIterator) Next() {
    if it.atNil {
        it.atNil = false
    } else {

func (m *hashMap) MarshalJSON() ([]byte, error) {
    var buf bytes.Buffer
    first := true
    for it := m.Iterator(); it.HasElem(); it.Next() {
        if first {
            first = false
        } else {
        k, v := it.Elem()
        kString, err := convertKey(k)
        if err != nil {
            return nil, err
        kBytes, err := json.Marshal(kString)
        if err != nil {
            return nil, err
        vBytes, err := json.Marshal(v)
        if err != nil {
            return nil, err
    return buf.Bytes(), nil

// convertKey converts a map key to a string. The implementation matches the
// behavior of how json.Marshal encodes keys of the builtin map type: keys of
// string kind are used as is, keys implementing encoding.TextMarshaler are
// converted with MarshalText, and keys of integer kinds are formatted in
// base 10. Any other key results in an error.
func convertKey(k any) (string, error) {
    kref := reflect.ValueOf(k)
    // The string check comes first, so a key of string kind is used directly
    // even if its type also implements encoding.TextMarshaler.
    if kref.Kind() == reflect.String {
        return kref.String(), nil
    }
    if t, ok := k.(encoding.TextMarshaler); ok {
        b2, err := t.MarshalText()
        if err != nil {
            return "", err
        }
        return string(b2), nil
    }
    switch kref.Kind() {
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        return strconv.FormatInt(kref.Int(), 10), nil
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
        return strconv.FormatUint(kref.Uint(), 10), nil
    }
    return "", fmt.Errorf("unsupported key type %T", k)
}

// node is an interface for all nodes in the hash map tree.
type node interface {
    // assoc adds a new pair of key and value. It returns the new node, and
    // whether the key did not exist before (i.e. a new pair has been added,
    // instead of replaced).
    assoc(shift, hash uint32, k, v any, h Hash, eq Equal) (node, bool)
    // without removes a key. It returns the new node and whether the key did
    // not exist before (i.e. a key was indeed removed).
    without(shift, hash uint32, k any, eq Equal) (node, bool)
    // find finds the value for a key. It returns the found value (if any) and
    // whether such a pair exists.
    find(shift, hash uint32, k any, eq Equal) (any, bool)
    // iterator returns an iterator.
    iterator() Iterator

// arrayNode stores all of its children in an array. The array is always at
// least 1/4 full, otherwise it will be packed into a bitmapNode.
type arrayNode struct {
    nChildren int
    children  [nodeCap]node

func (n *arrayNode) withNewChild(i uint32, newChild node, d int) *arrayNode {
    newChildren := n.children
    newChildren[i] = newChild
    return &arrayNode{n.nChildren + d, newChildren}

func (n *arrayNode) assoc(shift, hash uint32, k, v any, h Hash, eq Equal) (node, bool) {
    idx := chunk(shift, hash)
    child := n.children[idx]
    if child == nil {
        newChild, _ := emptyBitmapNode.assoc(shift+chunkBits, hash, k, v, h, eq)
        return n.withNewChild(idx, newChild, 1), true
    newChild, added := child.assoc(shift+chunkBits, hash, k, v, h, eq)
    return n.withNewChild(idx, newChild, 0), added

func (n *arrayNode) without(shift, hash uint32, k any, eq Equal) (node, bool) {
    idx := chunk(shift, hash)
    child := n.children[idx]
    if child == nil {
        return n, false
    newChild, _ := child.without(shift+chunkBits, hash, k, eq)
    if newChild == child {
        return n, false
    if newChild == emptyBitmapNode {
        if n.nChildren <= nodeCap/4 {
            // less than 1/4 full; shrink
            return n.pack(int(idx)), true
        return n.withNewChild(idx, nil, -1), true
    return n.withNewChild(idx, newChild, 0), true

func (n *arrayNode) pack(skip int) *bitmapNode {
    newNode := bitmapNode{0, make([]mapEntry, n.nChildren-1)}
    j := 0
    for i, child := range n.children {
        // TODO(xiaq): Benchmark performance difference after unrolling this
        // into two loops without the if
        if i != skip && child != nil {
            newNode.bitmap |= 1 << uint(i)
            newNode.entries[j].value = child
    return &newNode

func (n *arrayNode) find(shift, hash uint32, k any, eq Equal) (any, bool) {
    idx := chunk(shift, hash)
    child := n.children[idx]
    if child == nil {
        return nil, false
    return child.find(shift+chunkBits, hash, k, eq)

func (n *arrayNode) iterator() Iterator {
    it := &arrayNodeIterator{n, 0, nil}
    return it

type arrayNodeIterator struct {
    n       *arrayNode
    index   int
    current Iterator

func (it *arrayNodeIterator) fixCurrent() {
    for ; it.index < nodeCap && it.n.children[it.index] == nil; it.index++ {
    if it.index < nodeCap {
        it.current = it.n.children[it.index].iterator()
    } else {
        it.current = nil

func (it *arrayNodeIterator) Elem() (any, any) {
    return it.current.Elem()

func (it *arrayNodeIterator) HasElem() bool {
    return it.current != nil

func (it *arrayNodeIterator) Next() {
    if !it.current.HasElem() {

// emptyBitmapNode is the shared bitmapNode with no entries. It is the root
// of an empty map, and the without methods return it (and compare against it
// by identity) to signal that a subtree has become empty.
var emptyBitmapNode = &bitmapNode{}

type bitmapNode struct {
    bitmap  uint32
    entries []mapEntry

// mapEntry is a map entry. When used in a collisionNode, it is always an
// entry with a non-nil key. When used in a bitmapNode, it is also abused to
// represent a child node: the key is then nil and the value holds the node.
type mapEntry struct {
    key   any
    value any
}

func chunk(shift, hash uint32) uint32 {
    return (hash >> shift) & chunkMask

func bitpos(shift, hash uint32) uint32 {
    return 1 << chunk(shift, hash)

func index(bitmap, bit uint32) uint32 {
    return popCount(bitmap & (bit - 1))

// Bit masks selecting alternating groups of 1, 2, 4, 8 and 16 bits, for
// counting the set bits of a 32-bit word by pairwise summation.
const (
    m1  uint32 = 0x55555555
    m2  uint32 = 0x33333333
    m4  uint32 = 0x0f0f0f0f
    m8  uint32 = 0x00ff00ff
    m16 uint32 = 0x0000ffff
)

// popCount returns the number of set bits in u.
func popCount(u uint32) uint32 {
    return uint32(bits.OnesCount32(u))
}

func createNode(shift uint32, k1 any, v1 any, h2 uint32, k2 any, v2 any, h Hash, eq Equal) node {
    h1 := h(k1)
    if h1 == h2 {
        return &collisionNode{h1, []mapEntry{{k1, v1}, {k2, v2}}}
    n, _ := emptyBitmapNode.assoc(shift, h1, k1, v1, h, eq)
    n, _ = n.assoc(shift, h2, k2, v2, h, eq)
    return n

func (n *bitmapNode) unpack(shift, idx uint32, newChild node, h Hash, eq Equal) *arrayNode {
    var newNode arrayNode
    newNode.nChildren = len(n.entries) + 1
    newNode.children[idx] = newChild
    j := 0
    for i := uint(0); i < nodeCap; i++ {
        if (n.bitmap>>i)&1 != 0 {
            entry := n.entries[j]
            if entry.key == nil {
                newNode.children[i] = entry.value.(node)
            } else {
                newNode.children[i], _ = emptyBitmapNode.assoc(
                    shift+chunkBits, h(entry.key), entry.key, entry.value, h, eq)
    return &newNode

func (n *bitmapNode) withoutEntry(bit, idx uint32) *bitmapNode {
    if n.bitmap == bit {
        return emptyBitmapNode
    return &bitmapNode{n.bitmap ^ bit, withoutEntry(n.entries, idx)}

func withoutEntry(entries []mapEntry, idx uint32) []mapEntry {
    newEntries := make([]mapEntry, len(entries)-1)
    copy(newEntries[:idx], entries[:idx])
    copy(newEntries[idx:], entries[idx+1:])
    return newEntries

func (n *bitmapNode) withReplacedEntry(i uint32, entry mapEntry) *bitmapNode {
    return &bitmapNode{n.bitmap, replaceEntry(n.entries, i, entry.key, entry.value)}

func replaceEntry(entries []mapEntry, i uint32, k, v any) []mapEntry {
    newEntries := append([]mapEntry(nil), entries...)
    newEntries[i] = mapEntry{k, v}
    return newEntries

func (n *bitmapNode) assoc(shift, hash uint32, k, v any, h Hash, eq Equal) (node, bool) {
    bit := bitpos(shift, hash)
    idx := index(n.bitmap, bit)
    if n.bitmap&bit == 0 {
        // Entry does not exist yet
        nEntries := len(n.entries)
        if nEntries >= nodeCap/2 {
            // Unpack into an arrayNode
            newNode, _ := emptyBitmapNode.assoc(shift+chunkBits, hash, k, v, h, eq)
            return n.unpack(shift, chunk(shift, hash), newNode, h, eq), true
        // Add a new entry
        newEntries := make([]mapEntry, len(n.entries)+1)
        copy(newEntries[:idx], n.entries[:idx])
        newEntries[idx] = mapEntry{k, v}
        copy(newEntries[idx+1:], n.entries[idx:])
        return &bitmapNode{n.bitmap | bit, newEntries}, true
    // Entry exists
    entry := n.entries[idx]
    if entry.key == nil {
        // Non-leaf child
        child := entry.value.(node)
        newChild, added := child.assoc(shift+chunkBits, hash, k, v, h, eq)
        return n.withReplacedEntry(idx, mapEntry{nil, newChild}), added
    // Leaf
    if eq(k, entry.key) {
        // Identical key, replace
        return n.withReplacedEntry(idx, mapEntry{k, v}), false
    // Create and insert new inner node
    newNode := createNode(shift+chunkBits, entry.key, entry.value, hash, k, v, h, eq)
    return n.withReplacedEntry(idx, mapEntry{nil, newNode}), true

func (n *bitmapNode) without(shift, hash uint32, k any, eq Equal) (node, bool) {
    bit := bitpos(shift, hash)
    if n.bitmap&bit == 0 {
        return n, false
    idx := index(n.bitmap, bit)
    entry := n.entries[idx]
    if entry.key == nil {
        // Non-leaf child
        child := entry.value.(node)
        newChild, deleted := child.without(shift+chunkBits, hash, k, eq)
        if newChild == child {
            return n, false
        if newChild == emptyBitmapNode {
            return n.withoutEntry(bit, idx), true
        return n.withReplacedEntry(idx, mapEntry{nil, newChild}), deleted
    } else if eq(entry.key, k) {
        // Leaf, and this is the entry to delete.
        return n.withoutEntry(bit, idx), true
    // Nothing to delete.
    return n, false

func (n *bitmapNode) find(shift, hash uint32, k any, eq Equal) (any, bool) {
    bit := bitpos(shift, hash)
    if n.bitmap&bit == 0 {
        return nil, false
    idx := index(n.bitmap, bit)
    entry := n.entries[idx]
    if entry.key == nil {
        child := entry.value.(node)
        return child.find(shift+chunkBits, hash, k, eq)
    } else if eq(entry.key, k) {
        return entry.value, true
    return nil, false

func (n *bitmapNode) iterator() Iterator {
    it := &bitmapNodeIterator{n, 0, nil}
    return it

type bitmapNodeIterator struct {
    n       *bitmapNode
    index   int
    current Iterator

func (it *bitmapNodeIterator) fixCurrent() {
    if it.index < len(it.n.entries) {
        entry := it.n.entries[it.index]
        if entry.key == nil {
            it.current = entry.value.(node).iterator()
        } else {
            it.current = nil
    } else {
        it.current = nil

func (it *bitmapNodeIterator) Elem() (any, any) {
    if it.current != nil {
        return it.current.Elem()
    entry := it.n.entries[it.index]
    return entry.key, entry.value

func (it *bitmapNodeIterator) HasElem() bool {
    return it.index < len(it.n.entries)

func (it *bitmapNodeIterator) Next() {
    if it.current != nil {
    if it.current == nil || !it.current.HasElem() {

type collisionNode struct {
    hash    uint32
    entries []mapEntry

func (n *collisionNode) assoc(shift, hash uint32, k, v any, h Hash, eq Equal) (node, bool) {
    if hash == n.hash {
        idx := n.findIndex(k, eq)
        if idx != -1 {
            return &collisionNode{
                n.hash, replaceEntry(n.entries, uint32(idx), k, v)}, false
        newEntries := make([]mapEntry, len(n.entries)+1)
        copy(newEntries[:len(n.entries)], n.entries[:])
        newEntries[len(n.entries)] = mapEntry{k, v}
        return &collisionNode{n.hash, newEntries}, true
    // Wrap in a bitmapNode and add the entry
    wrap := bitmapNode{bitpos(shift, n.hash), []mapEntry{{nil, n}}}
    return wrap.assoc(shift, hash, k, v, h, eq)

func (n *collisionNode) without(shift, hash uint32, k any, eq Equal) (node, bool) {
    idx := n.findIndex(k, eq)
    if idx == -1 {
        return n, false
    if len(n.entries) == 1 {
        return emptyBitmapNode, true
    return &collisionNode{n.hash, withoutEntry(n.entries, uint32(idx))}, true

func (n *collisionNode) find(shift, hash uint32, k any, eq Equal) (any, bool) {
    idx := n.findIndex(k, eq)
    if idx == -1 {
        return nil, false
    return n.entries[idx].value, true

func (n *collisionNode) findIndex(k any, eq Equal) int {
    for i, entry := range n.entries {
        if eq(k, entry.key) {
            return i
    return -1

func (n *collisionNode) iterator() Iterator {
    return &collisionNodeIterator{n, 0}

type collisionNodeIterator struct {
    n     *collisionNode
    index int

func (it *collisionNodeIterator) Elem() (any, any) {
    entry := it.n.entries[it.index]
    return entry.key, entry.value

func (it *collisionNodeIterator) HasElem() bool {
    return it.index < len(it.n.entries)

func (it *collisionNodeIterator) Next() {