Ilshidur/express-socket.io-jwt

View on GitHub
src/index.js

Summary

Maintainability
A
0 mins
Test Coverage
const jwt = require('jsonwebtoken');
const { isEqual: ipEquals } = require('ip');

// Default input functions

/**
 * Default `jwtFromRequest`: takes the token from `cookie.jwt`, falling back to the
 * `token` query-string parameter.
 *
 * `cookie` and `queryParams` are whatever the caller passes in (parseSocketJwt passes
 * `socket.request` and `socket.handshake.query`). A missing object is treated as
 * "no token there" instead of throwing a TypeError, so the other source can still be used.
 *
 * @param {{cookie?: object, queryParams?: object}} sources
 * @returns {*} the token, or a falsy value when none was found.
 */
const defaultJwtFromRequest = ({ cookie, queryParams }) =>
  (cookie && cookie.jwt) || (queryParams && queryParams.token);

/** Default `connectionNameFromRequest`: every socket is left unnamed (null). */
const defaultConnectionNameFromRequest = () => null;

/**
 * Default `onSocketParseError`: hands a generic error to `next`.
 * The underlying failure (e.g. the error thrown by `jwt.verify`) is attached as
 * `cause` so server-side code can still inspect it, while the message stays generic.
 *
 * @param {Error} err - the parsing/verification failure.
 * @param {object} socket - the socket being parsed (unused).
 * @param {function(Error): *} next - socket.io middleware callback.
 */
const defaultOnSocketParseError = (err, socket, next) => {
  const error = new Error('Could not decode socket token');
  // Assigned manually (not via the ES2022 constructor option) so it works on any Node version.
  error.cause = err;
  return next(error);
};

/**
 * Builds a lookup table of sockets keyed by their `connectionName`
 * (the value parseSocketJwt stores on each socket). When several sockets share a
 * name, the later one in `sockets` wins.
 *
 * The table has a null prototype, so looking up names such as "toString",
 * "constructor" or "__proto__" only ever returns sockets that were actually indexed.
 * It is filled in a single pass instead of re-spreading an accumulator (O(n²)).
 *
 * @param {Array<object>} sockets
 * @returns {Object<string, object>} connectionName -> socket
 */
const indexSockets = (sockets) => {
  const index = Object.create(null);
  sockets.forEach((socket) => {
    index[socket.connectionName] = socket;
  });
  return index;
};

/**
 * Lists the sockets currently connected to a socket.io namespace.
 * socket.io v2 exposes them as the plain object `namespace.connected`; v3 and later
 * removed that property and keep them in the Map `namespace.sockets`.
 *
 * @param {object} namespace - a socket.io namespace, e.g. `io.of('/')`.
 * @returns {Array<object>} the connected sockets.
 * @throws {Error} when neither collection is present.
 */
const listConnectedSockets = (namespace) => {
  if (namespace.connected) {
    return Object.values(namespace.connected);
  }
  if (namespace.sockets instanceof Map) {
    return Array.from(namespace.sockets.values());
  }
  throw new Error('Cannot list connected sockets of the socket.io namespace');
};

/**
 * Turns the `matchSocket` option into a predicate `(request, socket) => boolean`:
 *   - 'ip':     `request.ip` equals the socket handshake address (via `ip.isEqual`);
 *   - 'cookie': `request.cookies.io` equals the socket id (false when `request.cookies`
 *               is absent — presumably populated by cookie-parser; verify in your app);
 *   - a function: used as-is.
 *
 * @throws {TypeError} for any other value, instead of failing later with an opaque
 *   "filterFunc is not a function".
 */
const resolveSocketMatcher = (matchSocket) => {
  switch (matchSocket) {
    case 'ip':
      return (request, socket) => ipEquals(request.ip, socket.handshake.address);
    case 'cookie':
      return (request, socket) => (request.cookies ? request.cookies.io === socket.id : false);
    default:
      if (typeof matchSocket !== 'function') {
        throw new TypeError(`Invalid matchSocket option: ${String(matchSocket)}`);
      }
      return matchSocket;
  }
};

/**
 * Express middleware factory that links an HTTP request to the socket.io
 * connection(s) of the same client, searched in the default namespace ('/').
 *
 * Usage: `authenticateSocketMiddleware(io)(options)` returns `(req, res, next)`.
 *
 * Options (all optional; the options object itself may be omitted):
 *   - `matchSocket` - 'ip' (default), 'cookie', or a custom `(req, socket) => boolean`.
 *   - `required`    - when true (default), calls `next(new Error('Unauthorized'))` unless
 *                     at least one socket matched and every matched socket has a truthy
 *                     `token` and `payload`.
 *
 * On success sets `req.getSocket(connectionName?)`. With a name, it returns the matched socket
 * indexed under that `connectionName`. Without one, it returns the first matched socket
 * (undefined when nothing matched and `required` is false).
 * Errors from the namespace lookup or the matcher are forwarded to `next(err)`.
 *
 * @param {object} io - socket.io server.
 */
const authenticateSocketMiddleware =
  io =>
    ({
      matchSocket = 'ip',
      required = true,
    } = {}) =>
      // Kept `async` so the handler still returns a Promise, as it always has.
      async (req, res, next) => {
        let sockets;
        try {
          const matches = resolveSocketMatcher(matchSocket);
          sockets = listConnectedSockets(io.of('/')).filter(socket => matches(req, socket));
        } catch (err) {
          return next(err);
        }

        if (required) {
          // Every matched socket must carry both the raw token and the decoded payload.
          const authenticated = sockets.length > 0
            && sockets.every(socket => socket.token && socket.payload);
          if (!authenticated) {
            return next(new Error('Unauthorized'));
          }
        }

        const indexedSockets = indexSockets(sockets);
        req.getSocket = connection => (connection ? indexedSockets[connection] : sockets[0]);

        return next();
      };

/**
 * Extracts and verifies the JWT of a connecting socket, storing the results on the
 * socket itself, in this order:
 *   1. `socket.token`          - whatever `jwtFromRequest` resolves to;
 *   2. `socket.payload`        - `jwt.verify(token, secret)`, or null when the token is falsy;
 *   3. `socket.connectionName` - whatever `connectionNameFromRequest` resolves to.
 * Because of this order, `connectionNameFromRequest` may read `socket.token` / `socket.payload`.
 * If `jwt.verify` throws, the error propagates with `socket.token` already assigned
 * and `socket.payload` / `socket.connectionName` left untouched.
 *
 * Each callback receives a fresh `{ cookie, queryParams }` object plus the socket.
 * NOTE(review): `cookie` is the raw handshake request (`socket.request`), not a parsed
 * cookie map — confirm this is intended.
 */
async function parseSocketJwt(socket, { jwtFromRequest, secret, connectionNameFromRequest }) {
  const { request: handshakeRequest, handshake: { query } } = socket;
  const requestInputs = () => ({ cookie: handshakeRequest, queryParams: query });

  // TODO: Option for the request to wait for the socket to prepare

  socket.token = await jwtFromRequest(requestInputs(), socket);

  let payload = null;
  if (socket.token) {
    payload = jwt.verify(socket.token, secret);
  }
  socket.payload = payload;

  socket.connectionName = await connectionNameFromRequest(requestInputs(), socket);
}

/**
 * socket.io middleware factory: populates `token`, `payload` and `connectionName` on
 * each socket it handles (via parseSocketJwt), then calls `next()`.
 *
 * Any failure is routed to `onSocketParseError(err, socket, next)` instead of `next()`.
 * This covers a missing `secret` (reported synchronously, before anything is parsed)
 * and any error thrown while extracting or verifying the token.
 *
 * @param {object} [options]
 * @param {string} options.secret - secret passed to `jwt.verify`.
 * @param {function} [options.jwtFromRequest] - token extractor (default: cookie `jwt`, then query `token`).
 * @param {function} [options.connectionNameFromRequest] - socket naming (default: null).
 * @param {function} [options.onSocketParseError] - error handler (default: generic error to `next`).
 */
const socketMiddleware = ({
  secret,
  jwtFromRequest = defaultJwtFromRequest,
  connectionNameFromRequest = defaultConnectionNameFromRequest,
  onSocketParseError = defaultOnSocketParseError,
} = {}) => async (socket, next) => {
  // A separate flag keeps even a thrown falsy value routed to the error handler.
  let failed = false;
  let failure;

  if (secret) {
    try {
      await parseSocketJwt(socket, { secret, jwtFromRequest, connectionNameFromRequest });
    } catch (err) {
      failed = true;
      failure = err;
    }
  } else {
    failed = true;
    failure = new Error('Cannot decode socket JWT because secret is missing in the server');
  }

  if (failed) {
    onSocketParseError(failure, socket, next);
  } else {
    next();
  }
};

// TODO: Allow choice of namespaces to '.use' (string or function)
/**
 * Wires JWT authentication into a socket.io server and returns the Express-side
 * middleware factory.
 *
 * Registers `socketMiddleware(opts)` on the server via `io.use`, so sockets passing
 * through it get `token` / `payload` / `connectionName` set. It then returns
 * `authenticateSocketMiddleware(io)`, which is called with its own options to build
 * the Express handler.
 *
 * @param {object} io - socket.io server.
 * @param {object} [opts={}] - options forwarded to socketMiddleware.
 * @returns {function} the authenticateSocketMiddleware factory bound to `io`.
 * @throws {Error} 'Missing socket.io server' when `io` is falsy.
 */
const createMiddleware = (io, opts = {}) => {
  if (io) {
    // TODO: Remove this side effect and make the function pure
    io.use(socketMiddleware(opts));
    return authenticateSocketMiddleware(io);
  }
  throw new Error('Missing socket.io server');
};

// Public entry point of this module: createMiddleware(io, opts) registers the socket.io
// JWT middleware and returns the Express middleware factory.
module.exports.createMiddleware = createMiddleware;