waku-org/go-waku

waku/v2/protocol/peer_exchange/shard_lru.go

package peer_exchange

import (
    "container/list"
    "fmt"
    "math/rand"
    "sync"

    "github.com/ethereum/go-ethereum/p2p/enode"
    "github.com/waku-org/go-waku/waku/v2/protocol"
    wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr"
    "github.com/waku-org/go-waku/waku/v2/utils"
)

// ShardInfo identifies a single cluster/shard pair.
type ShardInfo struct {
    clusterID uint16
    shard     uint16
}

// shardLRU caches discovered nodes in per-shard LRU lists, keeping at most `size` nodes per cluster/shard pair.
type shardLRU struct {
    size       int // number of nodes allowed per shard
    idToNode   map[enode.ID][]*list.Element
    shardNodes map[ShardInfo]*list.List
    rng        *rand.Rand
    mu         sync.RWMutex
}

// newShardLRU creates a cache that keeps at most `size` nodes per cluster/shard pair.
func newShardLRU(size int) *shardLRU {
    return &shardLRU{
        idToNode:   map[enode.ID][]*list.Element{},
        shardNodes: map[ShardInfo]*list.List{},
        size:       size,
        rng:        rand.New(rand.NewSource(rand.Int63())),
    }
}

// nodeWithShardInfo is the value stored in each shard list: the node plus the shard it is listed under.
type nodeWithShardInfo struct {
    key  ShardInfo
    node *enode.Node
}

// remove drops every cached list entry for node.ID
// time complexity: O(number of shard indexes previously cached for node.ID)
func (l *shardLRU) remove(node *enode.Node) {
    elements := l.idToNode[node.ID()]
    for _, element := range elements {
        key := element.Value.(nodeWithShardInfo).key
        l.shardNodes[key].Remove(element)
    }
    delete(l.idToNode, node.ID())
}

// removeFromIdToNode removes the element's back-reference from idToNode when that element is evicted from a shard list
func (l *shardLRU) removeFromIdToNode(ele *list.Element) {
    nodeID := ele.Value.(nodeWithShardInfo).node.ID()
    for ind, entries := range l.idToNode[nodeID] {
        if entries == ele {
            l.idToNode[nodeID] = append(l.idToNode[nodeID][:ind], l.idToNode[nodeID][ind+1:]...)
            break
        }
    }
    if len(l.idToNode[nodeID]) == 0 {
        delete(l.idToNode, nodeID)
    }
}

// nodeToRelayShard extracts relay shard information from the node's ENR
func nodeToRelayShard(node *enode.Node) (*protocol.RelayShards, error) {
    shard, err := wenr.RelaySharding(node.Record())
    if err != nil {
        return nil, err
    }

    if shard == nil { // if the record carries no shard info, default to cluster 0, shard index 0
        shard = &protocol.RelayShards{
            ClusterID: 0,
            ShardIDs:  []uint16{0},
        }
    }

    return shard, nil
}

// add inserts the node into the list of every shard it advertises, evicting the oldest entry when a list is full
// time complexity: O(number of shard indexes in the node's new record)
func (l *shardLRU) add(node *enode.Node) error {
    shard, err := nodeToRelayShard(node)
    if err != nil {
        return err
    }

    elements := []*list.Element{}
    for _, index := range shard.ShardIDs {
        key := ShardInfo{
            shard.ClusterID,
            index,
        }
        if l.shardNodes[key] == nil {
            l.shardNodes[key] = list.New()
        }
        if l.shardNodes[key].Len() >= l.size {
            oldest := l.shardNodes[key].Back()
            l.removeFromIdToNode(oldest)
            l.shardNodes[key].Remove(oldest)
        }
        entry := l.shardNodes[key].PushFront(nodeWithShardInfo{
            key:  key,
            node: node,
        })
        elements = append(elements, entry)
    }
    l.idToNode[node.ID()] = elements

    return nil
}

// Add caches the node, replacing any previously cached record.
// It is called when the node's ENR sequence number is higher than the one already in the cache.
func (l *shardLRU) Add(node *enode.Node) error {
    l.mu.Lock()
    defer l.mu.Unlock()
    // remove any existing entries first: the previously cached record may have advertised different shards
    l.remove(node)
    return l.add(node)
}

// GetRandomNodes returns up to neededPeers random nodes from the cache.
// clusterIndex is nil when peers are requested for no specific shard.
func (l *shardLRU) GetRandomNodes(clusterIndex *ShardInfo, neededPeers int) (nodes []*enode.Node) {
    l.mu.Lock()
    defer l.mu.Unlock()

    availablePeers := l.len(clusterIndex)
    if availablePeers < neededPeers {
        neededPeers = availablePeers
    }
    // if clusterIndex is nil, consider all cached nodes
    var elements []*list.Element
    if clusterIndex == nil {
        elements = make([]*list.Element, 0, len(l.idToNode))
        for _, entries := range l.idToNode {
            elements = append(elements, entries[0])
        }
    } else if entries := l.shardNodes[*clusterIndex]; entries != nil && entries.Len() != 0 {
        elements = make([]*list.Element, 0, entries.Len())
        for ent := entries.Back(); ent != nil; ent = ent.Prev() {
            elements = append(elements, ent)
        }
    }
    utils.Logger().Info(fmt.Sprintf("%d nodes available in peer exchange cache", len(elements)))
    indexes := l.rng.Perm(len(elements))[0:neededPeers]
    for _, ind := range indexes {
        node := elements[ind].Value.(nodeWithShardInfo).node
        nodes = append(nodes, node)
        // move the node to the front of every shard list it belongs to (all of its cluster/shard pairs), marking it as recently used
        l.remove(node)
        _ = l.add(node)
    }
    return nodes
}

// len returns the number of nodes cached for the given shard when clusterIndex is not nil,
// or the total number of cached nodes when clusterIndex is nil
func (l *shardLRU) len(clusterIndex *ShardInfo) int {
    if clusterIndex == nil {
        return len(l.idToNode)
    }
    if entries := l.shardNodes[*clusterIndex]; entries != nil {
        return entries.Len()
    }
    return 0
}

// Get returns the node with the given id, if it is present in the cache
func (l *shardLRU) Get(id enode.ID) *enode.Node {
    l.mu.RLock()
    defer l.mu.RUnlock()

    if elements, ok := l.idToNode[id]; ok && len(elements) > 0 {
        return elements[0].Value.(nodeWithShardInfo).node
    }
    return nil
}
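
Below is a minimal usage sketch, not part of the original file. It assumes a test placed in the same peer_exchange package (so it can reach the unexported newShardLRU), and builds a node with a freshly signed ENR via go-ethereum's enode.NewLocalNode; since that record carries no shard entry, the node should land in the default cluster 0 / shard 0 bucket described in nodeToRelayShard. The test name and flow are illustrative only.

package peer_exchange

import (
    "testing"

    "github.com/ethereum/go-ethereum/crypto"
    "github.com/ethereum/go-ethereum/p2p/enode"
)

// TestShardLRUSketch is a hypothetical example, not part of go-waku.
func TestShardLRUSketch(t *testing.T) {
    // a node with a freshly signed ENR and no shard entry is cached under
    // cluster 0, shard 0 (see nodeToRelayShard)
    priv, err := crypto.GenerateKey()
    if err != nil {
        t.Fatal(err)
    }
    db, err := enode.OpenDB("") // "" opens an in-memory node database
    if err != nil {
        t.Fatal(err)
    }
    defer db.Close()
    node := enode.NewLocalNode(db, priv).Node()

    cache := newShardLRU(3) // keep at most 3 nodes per cluster/shard pair
    if err := cache.Add(node); err != nil {
        t.Fatal(err)
    }

    // nil means "any shard"; the result is capped at the number of cached nodes
    if peers := cache.GetRandomNodes(nil, 10); len(peers) != 1 {
        t.Fatalf("expected 1 peer, got %d", len(peers))
    }
    if cache.Get(node.ID()) == nil {
        t.Fatal("expected node to be retrievable by ID")
    }
}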