thomasdashney/primus-rooms-metroplex-adapter
lib/primus-rooms-metroplex-adapter.js
const PrimusRoomsAdapter = require('primus-rooms-adapter')
const Redis = require('ioredis')
const { compact, chunk, flatten, uniq, without } = require('lodash')
const async = require('async')

const MAX_SPARK_FORWARDS_PER_BATCH = 50000

const DEFAULT_ROOM_REFRESH_INTERVAL_MS = 300000 // metroplex interval default

// some libraries - namely engine.io - may take a while
// to determine that the connection is closed
// (i.e. if using long polling instead of websockets).
//
// this number is set defensively high to give a strong
// guarantee that the client has actually disconnected
const DEFAULT_HEARTBEAT_PING_INTERVAL_MS = 300000

// key TTLs are set a factor higher than the interval at which we refresh them
const TTL_REFRESH_DRIFT_FACTOR = 1.2
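// e.g. with the metroplex default refresh interval of 300000 ms, a room set's
// TTL works out to Math.ceil(300000 / 1000 * 1.2) = 360 seconds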

const DEFAULT_KEYS_MATCH_SCAN_COUNT = 100
const MEMBERS_SSCAN_COUNT = 10000
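
// a minimal usage sketch - the server setup and plugin wiring below are
// illustrative assumptions (they may vary by Primus version), not part of
// this module:
//
//   const Primus = require('primus')
//   const Redis = require('ioredis')
//
//   const primus = new Primus(httpServer, { redis: new Redis() })
//   primus.plugin('metroplex', require('metroplex'))
//
//   const adapter = new PrimusRoomsMetroplexAdapter(new Redis(), primus)
//   adapter.initialize()
//
//   // the adapter is then handed to primus-rooms per that plugin's docs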

module.exports = class PrimusRoomsMetroplexAdapter extends PrimusRoomsAdapter {
  constructor (redis, primus, opts = {}) {
    super(opts)

    if (redis instanceof Redis) {
      this.redis = redis
    } else {
      throw new Error('redis object is not an instance of ioredis')
    }

    if (primus) {
      this.primus = primus
    } else {
      throw new Error('a primus instance is required')
    }

    if (!primus.metroplex) {
      throw new Error('PrimusRoomsMetroplexAdapter must be instantiated after metroplex')
    }

    this._namespace = opts.namespace || 'room_manager'
    this._identifier = opts.identifier || Date.now()

    this._roomSparkSetTTLSeconds = Math.ceil(
      (this.primus.metroplex.interval || DEFAULT_ROOM_REFRESH_INTERVAL_MS) / 1000 * TTL_REFRESH_DRIFT_FACTOR
    )
    this._sparkRoomSetTTLSeconds = Math.ceil(
      DEFAULT_HEARTBEAT_PING_INTERVAL_MS / 1000 * TTL_REFRESH_DRIFT_FACTOR
    )
  }

  /**
  * Initializes redis key TTL refreshers
  */
  initialize () {
    this._initializeRoomSetTTLRefresher()
    this._initializeSparkSetsTTLRefresher()
  }

  /**
  * Adds a socket to a room
  * @param {String} id - Socket id
  * @param {String} room - Room name
  * @param {Function} callback - Callback
  */
  add (id, room, callback) {
    const roomSparkSetKey = this._roomSparkSetKey(room)
    const sparkRoomSetKey = this._sparkRoomSetKey(id)
    this.redis.multi()
      .sadd(roomSparkSetKey, id)
      .expire(roomSparkSetKey, this._roomSparkSetTTLSeconds)
      .sadd(sparkRoomSetKey, room)
      .expire(sparkRoomSetKey, this._sparkRoomSetTTLSeconds)
      .exec((err, results) => {
        if (err) return callback(err)
        this._validateMultiResults(results, err => callback(err))
      })
  }
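
  // e.g. calling add('abc123', 'lobby', cb) above touches two sets (the key
  // shapes come from _roomSparkSetKey/_sparkRoomSetKey below):
  //
  //   room_manager:rooms:<address>_<identifier>:lobby -> { 'abc123' }
  //   room_manager:sparks:abc123                      -> { 'lobby' }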

  /**
  * Get rooms the socket is in or get all rooms if no socket ID is provided
  * returns Array of room names
  * @param {String} [id] Socket id
  * @param {Function} callback Callback
  */
  get (id, callback) {
    if (id) {
      this.redis.smembers(this._sparkRoomSetKey(id), callback)
    } else {
      this._keysMatchingPattern(`${this._namespace}:rooms:*:*`, (err, keys) => {
        if (err) return callback(err)

        const rooms = uniq(keys.map(key => {
          const parts = key.split(':')
          return parts[parts.length - 1]
        }))

        callback(null, rooms)
      })
    }
  }

  /**
  * Remove a socket from a room or all rooms if a room name is not passed.
  * @param {String} id - Socket id
  * @param {String} [room] - Room name
  * @param {Function} callback Callback
  */
  del (id, room, callback) {
    if (room) {
      this.redis.multi()
        .srem(this._roomSparkSetKey(room), id)
        .srem(this._sparkRoomSetKey(id), room)
        .exec((err, results) => {
          if (err) return callback(err)
          this._validateMultiResults(results, err => callback(err))
        })
    } else {
      this.redis.smembers(this._sparkRoomSetKey(id), (err, roomsForUser) => {
        if (err) return callback(err)

        const multi = this.redis.multi()

        roomsForUser.forEach(room => {
          multi.srem(this._roomSparkSetKey(room), id)
        })

        multi
          .del(this._sparkRoomSetKey(id))
          .exec((err, results) => {
            if (err) return callback(err)
            this._validateMultiResults(results, err => callback(err))
          })
      })
    }
  }

  /**
  * Broadcast a packet. If no rooms are provided, will broadcast to all rooms.
  * @param {*} data - Data to broadcast
  * @param {Object} [opts] - Broadcast options
  * @param {Array} [opts.except=[]] - Socket ids to exclude
  * @param {Array} [opts.rooms=[]] - List of rooms to broadcast to
  * @param {Function} [opts.transformer] - Message transformer
  * @param {Object} clients - Connected clients
  * @param {Function} [callback] - Optional callback
  */
  broadcast (data, opts, clients, callback) {
    opts = opts || {}
    opts.rooms = opts.rooms || []
    opts.except = opts.except || []
    opts.transformer = opts.transformer || (data => data[0])
    callback = callback || (err => err && console.error(err))

    const transformedData = opts.transformer(data)

    if (opts.rooms.length > 0) {
      async.waterfall([
        callback => {
          async.map(opts.rooms, this.clients.bind(this), (err, results) => {
            if (err) return callback(err)
            callback(null, flatten(results))
          })
        },

        (sparkIds, callback) => {
          const withoutExcluded = without(sparkIds, ...opts.except)
          this._sendToSparks(withoutExcluded, transformedData, callback)
        }
      ], callback)
    } else {
      // this is currently very inefficient for a large number of sparks,
      // as it simply scans all spark-room keys and parses the spark ids
      // out of the key names.
      //
      // ideally, we would store a set at `room_manager:${namespace}_${identifier}:sparks`
      // to quickly look up all of the sparks across servers; however, this
      // path doesn't seem to get hit in practice, so I'm avoiding over-optimizing.
      this._keysMatchingPattern(`${this._namespace}:sparks:*`, 5000, (err, keys) => {
        if (err) return callback(err)
        const sparkIds = keys.map(key => key.substring(`${this._namespace}:sparks:`.length))
        this._sendToSparks(sparkIds, transformedData, callback)
      })
    }
  }
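
  // e.g. a room-scoped broadcast (the payload and ids here are illustrative):
  //
  //   adapter.broadcast(['hello'], { rooms: ['lobby'], except: ['sparkA'] }, {})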

  /**
  * Get client ids connected to a room.
  * returns Array of spark ids
  * @param {String} room - Room name
  * @param {Function} callback - Callback
  */
  clients (room, callback) {
    async.waterfall([
      callback => {
        this._keysMatchingPattern(`${this._namespace}:rooms:*:${room}`, callback)
      },

      (roomKeys, callback) => {
        this._setMembersForKeys(roomKeys, callback)
      }
    ], callback)
  }

  /**
  * Remove all sockets from a room.
  * @param {String|Array} room - Room name
  * @param {Function} callback - Callback
  */
  empty (room, callback) {
    async.autoInject({
      roomKeys: callback => {
        this._keysMatchingPattern(`${this._namespace}:rooms:*:${room}`, callback)
      },

      sparkIds: (roomKeys, callback) => {
        this._setMembersForKeys(roomKeys, callback)
      },

      removeData: (roomKeys, sparkIds, callback) => {
        // `del` errors when given zero keys, so short-circuit
        // if the room has no backing sets
        if (roomKeys.length === 0) return callback()

        const multi = this.redis.multi().del(roomKeys)

        sparkIds.forEach(sparkId => {
          multi.srem(this._sparkRoomSetKey(sparkId), room)
        })

        multi.exec((err, results) => {
          if (err) return callback(err)
          this._validateMultiResults(results, err => callback(err))
        })
      }
    }, callback)
  }

  /**
  * Check if a room is empty.
  * returns `true` if the room is empty, else `false`
  * @param {String} room - Room name
  * @param {Function} callback - Callback
  */
  isEmpty (room, callback) {
    this._keysMatchingPattern(`${this._namespace}:rooms:*:${room}`, (err, roomKeys) => {
      if (err) return callback(err)
      callback(null, roomKeys.length === 0)
    })
  }

  /**
  * Reset the store. This removes everything, including socket data written by other adapters in the same cluster
  * @param {Function} callback - Callback
  */
  clear (callback) {
    this._keysMatchingPattern(`${this._namespace}:*`, (err, keys) => {
      if (err) return callback(err)
      if (keys.length === 0) return callback()
      this.redis.del(keys, err => callback(err))
    })
  }

  _initializeRoomSetTTLRefresher () {
    setInterval(() => {
      this._refreshRoomSetsTTL(this._roomSparkSetTTLSeconds, err => {
        if (err) console.error(new Error(`Error refreshing room->spark set TTL: ${err}`))
      })
    }, Math.floor(this._roomSparkSetTTLSeconds / TTL_REFRESH_DRIFT_FACTOR * 1000))
  }

  _refreshRoomSetsTTL (ttl, callback) {
    this._keysMatchingPattern(`${this._namespace}:rooms:${this._serverInstance}:*`, (err, roomKeys) => {
      if (err) return callback(err)

      const multi = this.redis.multi()
      roomKeys.forEach(roomKey => {
        multi.expire(roomKey, ttl)
      })
      multi.exec((err, results) => {
        if (err) return callback(err)
        this._validateMultiResults(results, err => callback(err))
      })
    })
  }

  _initializeSparkSetsTTLRefresher () {
    this.primus.on('connection', spark => {
      spark.on('heartbeat', () => {
        this.redis.expire(this._sparkRoomSetKey(spark.id), this._sparkRoomSetTTLSeconds, err => {
          if (err) console.error(new Error(`Error refreshing spark->room set TTL: ${err}`))
        })
      })
    })
  }

  get _serverInstance () {
    // primus.metroplex.address is set asynchronously,
    // which is why we can't access & set this value in the constructor

    // concatenate with unique identifier so server instances
    // are unique between crashes, preventing the re-usage
    // of past room lists
    if (!this._serverInstanceValue) {
      this._serverInstanceValue = `${this.primus.metroplex.address}_${this._identifier}`
    }

    return this._serverInstanceValue
  }

  _roomSparkSetKey (room) {
    return `${this._namespace}:rooms:${this._serverInstance}:${room}`
  }

  _sparkRoomSetKey (sparkId) {
    return `${this._namespace}:sparks:${sparkId}`
  }

  // redis recommends we use `scan` instead of `keys`
  // when iterating large sets of keys.
  // see https://redis.io/commands/scan
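  // each reply is a [nextCursor, keys] pair; iteration is
  // complete once the server returns a cursor of '0'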
  _keysMatchingPattern (pattern, count, callback) {
    if (!callback) {
      callback = count
      count = DEFAULT_KEYS_MATCH_SCAN_COUNT
    }

    let cursor = '0'
    const keys = []

    async.doWhilst(cb => {
      this.redis.scan(cursor, 'MATCH', pattern, 'COUNT', count, (err, result) => {
        if (err) return cb(err)
        const [newCursor, newKeys] = result
        cursor = newCursor
        keys.push(...newKeys)
        cb()
      })
    }, () => cursor !== '0', err => {
      if (err) return callback(err)
      callback(null, keys)
    })
  }

  // returns all members of n sets using sscan instead of smembers
  // useful in situations where a set can contain many members
  // and we want to avoid blocking redis, which is single-threaded
  // see https://redis.io/commands/scan
  _setMembersForKeys (keys, callback) {
    async.map(keys, (roomKey, callback) => {
      let cursor = '0'
      const members = []

      async.doWhilst(cb => {
        this.redis.sscan(roomKey, cursor, 'COUNT', MEMBERS_SSCAN_COUNT, (err, result) => {
          if (err) return cb(err)
          const [newCursor, newMembers] = result
          cursor = newCursor
          members.push(...newMembers)
          cb()
        })
      }, () => cursor !== '0', err => {
        if (err) return callback(err)
        callback(null, members)
      })
    }, (err, memberGroups) => {
      if (err) return callback(err)
      callback(null, flatten(memberGroups))
    })
  }

  // redis commands grouped using `multi` or `pipeline`
  // may yield errors per result.
  // this helper checks each result for errors and combines
  // them into a single error that can be handled
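  // (ioredis `exec` yields an array of [err, reply] pairs,
  //  e.g. [[null, 1], [null, 1]] when every command succeeds)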
  _validateMultiResults (results, callback) {
    const errors = results.map(result => result[0])

    if (errors.some(error => !!error)) {
      callback(new Error(`${compact(errors)}`))
    } else {
      callback()
    }
  }

  _sendToSparks (sparkIds, data, callback) {
    // primus.forward.sparks runs into memory issues when called
    // with too many spark ids. to avoid this, we call it
    // using sub-groups of spark ids
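    // e.g. 120000 spark ids are forwarded as three batches of
    // at most 50000 (MAX_SPARK_FORWARDS_PER_BATCH) ids each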
    const sparkIdChunks = chunk(sparkIds, MAX_SPARK_FORWARDS_PER_BATCH)
    async.each(sparkIdChunks, (sparkIds, callback) => {
      this.primus.forward.sparks(sparkIds, data, callback)
    }, callback)
  }
}