waku/v2/protocol/peer_exchange/shard_lru.go
package peer_exchange
import (
"container/list"
"fmt"
"math/rand"
"sync"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/waku-org/go-waku/waku/v2/protocol"
wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr"
"github.com/waku-org/go-waku/waku/v2/utils"
)
type ShardInfo struct {
clusterID uint16
shard uint16
}
type shardLRU struct {
size int // number of nodes allowed per shard
idToNode map[enode.ID][]*list.Element
shardNodes map[ShardInfo]*list.List
rng *rand.Rand
mu sync.RWMutex
}
func newShardLRU(size int) *shardLRU {
return &shardLRU{
idToNode: map[enode.ID][]*list.Element{},
shardNodes: map[ShardInfo]*list.List{},
size: size,
rng: rand.New(rand.NewSource(rand.Int63())),
}
}
type nodeWithShardInfo struct {
key ShardInfo
node *enode.Node
}
// time complexity: O(number of previous indexes present for node.ID)
func (l *shardLRU) remove(node *enode.Node) {
elements := l.idToNode[node.ID()]
for _, element := range elements {
key := element.Value.(nodeWithShardInfo).key
l.shardNodes[key].Remove(element)
}
delete(l.idToNode, node.ID())
}
// if a node is removed for a list, remove it from idToNode too
func (l *shardLRU) removeFromIdToNode(ele *list.Element) {
nodeID := ele.Value.(nodeWithShardInfo).node.ID()
for ind, entries := range l.idToNode[nodeID] {
if entries == ele {
l.idToNode[nodeID] = append(l.idToNode[nodeID][:ind], l.idToNode[nodeID][ind+1:]...)
break
}
}
if len(l.idToNode[nodeID]) == 0 {
delete(l.idToNode, nodeID)
}
}
func nodeToRelayShard(node *enode.Node) (*protocol.RelayShards, error) {
shard, err := wenr.RelaySharding(node.Record())
if err != nil {
return nil, err
}
if shard == nil { // if no shard info, then add to node to Cluster 0, Index 0
shard = &protocol.RelayShards{
ClusterID: 0,
ShardIDs: []uint16{0},
}
}
return shard, nil
}
// time complexity: O(new number of indexes in node's shard)
func (l *shardLRU) add(node *enode.Node) error {
shard, err := nodeToRelayShard(node)
if err != nil {
return err
}
elements := []*list.Element{}
for _, index := range shard.ShardIDs {
key := ShardInfo{
shard.ClusterID,
index,
}
if l.shardNodes[key] == nil {
l.shardNodes[key] = list.New()
}
if l.shardNodes[key].Len() >= l.size {
oldest := l.shardNodes[key].Back()
l.removeFromIdToNode(oldest)
l.shardNodes[key].Remove(oldest)
}
entry := l.shardNodes[key].PushFront(nodeWithShardInfo{
key: key,
node: node,
})
elements = append(elements, entry)
}
l.idToNode[node.ID()] = elements
return nil
}
// this will be called when the seq number of node is more than the one in cache
func (l *shardLRU) Add(node *enode.Node) error {
l.mu.Lock()
defer l.mu.Unlock()
// removing bcz previous node might be subscribed to different shards, we need to remove node from those shards
l.remove(node)
return l.add(node)
}
// clusterIndex is nil when peers for no specific shard are requested
func (l *shardLRU) GetRandomNodes(clusterIndex *ShardInfo, neededPeers int) (nodes []*enode.Node) {
l.mu.Lock()
defer l.mu.Unlock()
availablePeers := l.len(clusterIndex)
if availablePeers < neededPeers {
neededPeers = availablePeers
}
// if clusterIndex is nil, then return all nodes
var elements []*list.Element
if clusterIndex == nil {
elements = make([]*list.Element, 0, len(l.idToNode))
for _, entries := range l.idToNode {
elements = append(elements, entries[0])
}
} else if entries := l.shardNodes[*clusterIndex]; entries != nil && entries.Len() != 0 {
elements = make([]*list.Element, 0, entries.Len())
for ent := entries.Back(); ent != nil; ent = ent.Prev() {
elements = append(elements, ent)
}
}
utils.Logger().Info(fmt.Sprintf("%d", len(elements)))
indexes := l.rng.Perm(len(elements))[0:neededPeers]
for _, ind := range indexes {
node := elements[ind].Value.(nodeWithShardInfo).node
nodes = append(nodes, node)
// this removes the node from all list (all cluster/shard pair that the node has) and adds it to the front
l.remove(node)
_ = l.add(node)
}
return nodes
}
// if clusterIndex is not nil, return len of nodes maintained for a given shard
// if clusterIndex is nil, return count of all nodes maintained
func (l *shardLRU) len(clusterIndex *ShardInfo) int {
if clusterIndex == nil {
return len(l.idToNode)
}
if entries := l.shardNodes[*clusterIndex]; entries != nil {
return entries.Len()
}
return 0
}
// get the node with the given id, if it is present in cache
func (l *shardLRU) Get(id enode.ID) *enode.Node {
l.mu.RLock()
defer l.mu.RUnlock()
if elements, ok := l.idToNode[id]; ok && len(elements) > 0 {
return elements[0].Value.(nodeWithShardInfo).node
}
return nil
}