socketio/socket.io

View on GitHub
lib/socket.ts

Summary

Maintainability
F
3 days
Test Coverage
import { Packet, PacketType } from "socket.io-parser";
import debugModule from "debug";
import type { Server } from "./index";
import {
  EventParams,
  EventNames,
  EventsMap,
  StrictEventEmitter,
  DefaultEventsMap,
} from "./typed-events";
import type { Client } from "./client";
import type { Namespace, NamespaceReservedEventsMap } from "./namespace";
import type { IncomingMessage, IncomingHttpHeaders } from "http";
import type {
  Adapter,
  BroadcastFlags,
  Room,
  SocketId,
} from "socket.io-adapter";
import base64id from "base64id";
import type { ParsedUrlQuery } from "querystring";
import { BroadcastOperator } from "./broadcast-operator";

// Debug logger for this module, created under the "socket.io:socket" namespace.
const debug = debugModule("socket.io:socket");

// Event name listed in RESERVED_EVENTS in addition to the namespace-, socket- and
// EventEmitter-level reserved names, so that `Socket#emit` refuses to send it.
type ClientReservedEvents = "connect_error";

/**
 * Events that the Socket emits on itself (through `emitReserved`) rather than
 * dispatching them from packets sent by the client.
 */
export interface SocketReservedEventsMap {
  // emitted last in `_onclose`, after the socket has left its rooms and been flagged as disconnected
  disconnect: (reason: string) => void;
  // emitted first in `_onclose`, before the socket leaves its rooms
  disconnecting: (reason: string) => void;
  // emitted by `_onerror` when at least one "error" listener is registered
  error: (err: Error) => void;
}

// EventEmitter reserved events: https://nodejs.org/api/events.html#events_event_newlistener
// Their names are part of RESERVED_EVENTS, so `Socket#emit` rejects them.
// NOTE(review): `Symbol` is the boxed wrapper type; the primitive `symbol` is presumably what was
// meant, but changing it would alter the exported listener signatures — confirm before touching.
export interface EventEmitterReservedEventsMap {
  newListener: (
    eventName: string | Symbol,
    listener: (...args: any[]) => void
  ) => void;
  removeListener: (
    eventName: string | Symbol,
    listener: (...args: any[]) => void
  ) => void;
}

/**
 * Names that cannot be used as custom event names: `Socket#emit` throws when
 * it is called with any member of this set.
 *
 * The type argument ties the literal list to the reserved-event maps so that
 * only names declared in one of them can be listed here.
 */
export const RESERVED_EVENTS: ReadonlySet<string | Symbol> = new Set<
  | ClientReservedEvents
  | keyof NamespaceReservedEventsMap<never, never, never, never>
  | keyof SocketReservedEventsMap
  | keyof EventEmitterReservedEventsMap
>(
  [
    "connect",
    "connect_error",
    "disconnect",
    "disconnecting",
    "newListener",
    "removeListener",
  ] as const
);

/**
 * The handshake details, captured once when the Socket is constructed
 * (see `Socket#buildHandshake`).
 */
export interface Handshake {
  /**
   * The headers sent as part of the handshake (taken from the originating request)
   */
  headers: IncomingHttpHeaders;

  /**
   * The date of creation (as string, i.e. the string form of a `Date`)
   */
  time: string;

  /**
   * The ip of the client (the `remoteAddress` of the underlying connection)
   */
  address: string;

  /**
   * Whether the connection is cross-domain (true when the request carried an `origin` header)
   */
  xdomain: boolean;

  /**
   * Whether the connection is secure (true when the request's connection is flagged as encrypted)
   */
  secure: boolean;

  /**
   * The date of creation (as unix timestamp, in milliseconds)
   */
  issued: number;

  /**
   * The request URL string
   */
  url: string;

  /**
   * The query object
   */
  query: ParsedUrlQuery;

  /**
   * The auth object (the `auth` value passed to the Socket constructor)
   */
  auth: { [key: string]: any };
}

/**
 * An incoming event as handed to the socket middleware (`Socket#use`) and then
 * dispatched to the listeners: `[eventName, ...args]`
 */
export type Event = [string, ...any[]];

/**
 * Server-side representation of one client attached to one namespace.
 *
 * A Socket writes packets to its `Client`, receives the packets the `Client`
 * routes to it (`_onpacket`), runs the per-socket middleware chain and finally
 * dispatches incoming events to the registered listeners.
 */
export class Socket<
  ListenEvents extends EventsMap = DefaultEventsMap,
  EmitEvents extends EventsMap = ListenEvents,
  ServerSideEvents extends EventsMap = DefaultEventsMap,
  SocketData = any
> extends StrictEventEmitter<
  ListenEvents,
  EmitEvents,
  SocketReservedEventsMap
> {
  public readonly id: SocketId;
  public readonly handshake: Handshake;
  /**
   * Additional information that can be attached to the Socket instance and which will be used in the fetchSockets method
   */
  public data: Partial<SocketData> = {};

  /**
   * Set to `true` by `_onconnect` and back to `false` by `_onclose`.
   */
  public connected: boolean = false;

  private readonly server: Server<
    ListenEvents,
    EmitEvents,
    ServerSideEvents,
    SocketData
  >;
  private readonly adapter: Adapter;
  // pending acknowledgement callbacks, keyed by packet id
  private acks: Map<number, (...args: any[]) => void> = new Map();
  // middleware functions registered with `use()`, run in registration order
  private fns: Array<(event: Event, next: (err?: Error) => void) => void> = [];
  // modifiers (compress, volatile, timeout, ...) that apply to the next emission only
  private flags: BroadcastFlags & { timeout?: number } = {};
  // catch-all listeners registered with `onAny()` / `prependAny()`
  private _anyListeners?: Array<(...args: any[]) => void>;

  /**
   * Interface to a `Client` for a given `Namespace`.
   *
   * @param {Namespace} nsp
   * @param {Client} client
   * @param {Object} auth
   * @package
   */
  constructor(
    readonly nsp: Namespace<ListenEvents, EmitEvents, ServerSideEvents>,
    readonly client: Client<ListenEvents, EmitEvents, ServerSideEvents>,
    auth: object
  ) {
    super();
    this.server = nsp.server;
    this.adapter = this.nsp.adapter;
    if (client.conn.protocol === 3) {
      // with protocol 3 the id is derived from the client id, prefixed by the
      // namespace name for any namespace other than "/"
      // @ts-ignore
      this.id = nsp.name !== "/" ? nsp.name + "#" + client.id : client.id;
    } else {
      this.id = base64id.generateId(); // don't reuse the Engine.IO id because it's sensitive information
    }
    this.handshake = this.buildHandshake(auth);
  }

  /**
   * Builds the `handshake` BC object
   *
   * @param {Object} auth - stored as-is in the `auth` field
   * @return {Handshake} a snapshot of the originating request
   * @private
   */
  private buildHandshake(auth: object): Handshake {
    // a single Date so that `time` and `issued` describe the very same instant
    const now = new Date();
    return {
      headers: this.request.headers,
      time: now.toString(),
      address: this.conn.remoteAddress,
      xdomain: !!this.request.headers.origin,
      // @ts-ignore
      secure: !!this.request.connection.encrypted,
      issued: now.getTime(),
      url: this.request.url!,
      // @ts-ignore
      query: this.request._query,
      auth,
    };
  }

  /**
   * Emits to this client.
   *
   * If the last argument is a function, it is removed from the payload and
   * registered as the acknowledgement callback for this packet. The pending
   * flags (`compress`, `volatile`, `timeout`) are consumed by this call.
   *
   * @return Always returns `true`.
   * @throws {Error} if `ev` is one of the RESERVED_EVENTS
   * @public
   */
  public emit<Ev extends EventNames<EmitEvents>>(
    ev: Ev,
    ...args: EventParams<EmitEvents, Ev>
  ): boolean {
    if (RESERVED_EVENTS.has(ev)) {
      throw new Error(`"${ev}" is a reserved event name`);
    }
    const data: any[] = [ev, ...args];
    const packet: any = {
      type: PacketType.EVENT,
      data: data,
    };

    // access last argument to see if it's an ACK callback
    if (typeof data[data.length - 1] === "function") {
      const id = this.nsp._ids++;
      debug("emitting packet with ack id %d", id);

      // must run before the flags are reset below, since it reads `flags.timeout`
      this.registerAckCallback(id, data.pop());
      packet.id = id;
    }

    const flags = Object.assign({}, this.flags);
    this.flags = {};

    this.packet(packet, flags);

    return true;
  }

  /**
   * Stores the acknowledgement callback for the given packet id.
   *
   * Without a `timeout` flag the callback is stored untouched. With a
   * `timeout` flag the callback follows the "error first" convention: it is
   * called with an `Error` if the timer elapses first, or with `null` followed
   * by the acknowledgement arguments otherwise.
   *
   * @param {Number} id - packet id
   * @param {Function} ack - user-provided callback
   * @private
   */
  private registerAckCallback(id: number, ack: (...args: any[]) => void): void {
    const timeout = this.flags.timeout;
    if (timeout === undefined) {
      this.acks.set(id, ack);
      return;
    }

    const timer = setTimeout(() => {
      debug("event with ack id %d has timed out after %d ms", id, timeout);
      // drop the entry so that a late acknowledgement is reported as a bad ack
      this.acks.delete(id);
      ack.call(this, new Error("operation has timed out"));
    }, timeout);

    this.acks.set(id, (...args) => {
      clearTimeout(timer);
      ack.apply(this, [null, ...args]);
    });
  }

  /**
   * Targets a room when broadcasting.
   *
   * @param room
   * @return a new BroadcastOperator (excluding this socket) for chaining
   * @public
   */
  public to(room: Room | Room[]): BroadcastOperator<EmitEvents, SocketData> {
    return this.newBroadcastOperator().to(room);
  }

  /**
   * Targets a room when broadcasting.
   *
   * @param room
   * @return a new BroadcastOperator (excluding this socket) for chaining
   * @public
   */
  public in(room: Room | Room[]): BroadcastOperator<EmitEvents, SocketData> {
    return this.newBroadcastOperator().in(room);
  }

  /**
   * Excludes a room when broadcasting.
   *
   * @param room
   * @return a new BroadcastOperator (excluding this socket) for chaining
   * @public
   */
  public except(
    room: Room | Room[]
  ): BroadcastOperator<EmitEvents, SocketData> {
    return this.newBroadcastOperator().except(room);
  }

  /**
   * Sends a `message` event.
   *
   * @return self
   * @public
   */
  public send(...args: EventParams<EmitEvents, "message">): this {
    this.emit("message", ...args);
    return this;
  }

  /**
   * Sends a `message` event (same behavior as `send`).
   *
   * @return self
   * @public
   */
  public write(...args: EventParams<EmitEvents, "message">): this {
    this.emit("message", ...args);
    return this;
  }

  /**
   * Writes a packet.
   *
   * The packet is stamped with the name of this socket's namespace and
   * `opts.compress` is normalized to a boolean (`true` unless explicitly
   * `false`) before both are handed to the client. Both arguments are mutated.
   *
   * @param {Object} packet - packet object
   * @param {Object} opts - options
   * @private
   */
  private packet(
    packet: Omit<Packet, "nsp"> & Partial<Pick<Packet, "nsp">>,
    opts: any = {}
  ): void {
    packet.nsp = this.nsp.name;
    opts.compress = false !== opts.compress;
    this.client._packet(packet as Packet, opts);
  }

  /**
   * Joins a room.
   *
   * @param {String|Array} rooms - room or array of rooms
   * @return a Promise or nothing, depending on the adapter
   * @public
   */
  public join(rooms: Room | Array<Room>): Promise<void> | void {
    debug("join room %s", rooms);

    return this.adapter.addAll(
      this.id,
      new Set(Array.isArray(rooms) ? rooms : [rooms])
    );
  }

  /**
   * Leaves a room.
   *
   * @param {String} room
   * @return a Promise or nothing, depending on the adapter
   * @public
   */
  public leave(room: string): Promise<void> | void {
    debug("leave room %s", room);

    return this.adapter.del(this.id, room);
  }

  /**
   * Leave all rooms.
   *
   * @private
   */
  private leaveAll(): void {
    this.adapter.delAll(this.id);
  }

  /**
   * Called by `Namespace` upon successful
   * middleware execution (ie: authorization).
   * Socket is added to namespace array before
   * call to join, so adapters can access it.
   *
   * Marks the socket as connected, joins the room named after its own id and
   * writes the CONNECT packet (which carries the socket id, except with
   * protocol 3).
   *
   * @private
   */
  _onconnect(): void {
    debug("socket connected - writing packet");
    this.connected = true;
    this.join(this.id);
    if (this.conn.protocol === 3) {
      this.packet({ type: PacketType.CONNECT });
    } else {
      this.packet({ type: PacketType.CONNECT, data: { sid: this.id } });
    }
  }

  /**
   * Called with each packet. Called by `Client`.
   *
   * Packet types that are not listed below are ignored.
   *
   * @param {Object} packet
   * @private
   */
  _onpacket(packet: Packet): void {
    debug("got packet %j", packet);
    switch (packet.type) {
      case PacketType.EVENT:
      case PacketType.BINARY_EVENT:
        this.onevent(packet);
        break;

      case PacketType.ACK:
      case PacketType.BINARY_ACK:
        this.onack(packet);
        break;

      case PacketType.DISCONNECT:
        this.ondisconnect();
        break;

      case PacketType.CONNECT_ERROR:
        this._onerror(new Error(packet.data));
        break;
    }
  }

  /**
   * Called upon event packet.
   *
   * When the packet carries an id, an acknowledgement callback is appended to
   * the arguments. The catch-all listeners are called synchronously, before
   * the event goes through the middleware chain.
   *
   * @param {Packet} packet - packet object
   * @private
   */
  private onevent(packet: Packet): void {
    const args = packet.data || [];
    debug("emitting event %j", args);

    if (null != packet.id) {
      debug("attaching ack callback to event");
      args.push(this.ack(packet.id));
    }

    if (this._anyListeners && this._anyListeners.length) {
      // iterate over a copy, so that listeners may add/remove catch-all listeners
      const listeners = this._anyListeners.slice();
      for (const listener of listeners) {
        listener.apply(this, args);
      }
    }
    this.dispatch(args);
  }

  /**
   * Produces an ack callback to emit with an event.
   *
   * The returned function sends an ACK packet with the arguments it receives;
   * only the first call has an effect.
   *
   * @param {Number} id - packet id
   * @private
   */
  private ack(id: number): (...args: any[]) => void {
    let sent = false;
    return (...args: any[]) => {
      // prevent double callbacks
      if (sent) return;
      debug("sending ack %j", args);

      this.packet({
        id: id,
        type: PacketType.ACK,
        data: args,
      });

      sent = true;
    };
  }

  /**
   * Called upon ack packet.
   *
   * Looks up the callback registered for the packet id and invokes it with the
   * packet payload. Unknown ids are only logged.
   *
   * @param {Packet} packet - packet object
   * @private
   */
  private onack(packet: Packet): void {
    const ack = this.acks.get(packet.id!);
    if (typeof ack === "function") {
      debug("calling ack %s with %j", packet.id, packet.data);
      // remove the entry first: if the callback throws, no stale entry is left behind
      this.acks.delete(packet.id!);
      ack.apply(this, packet.data);
    } else {
      debug("bad ack %s", packet.id);
    }
  }

  /**
   * Called upon client disconnect packet.
   *
   * @private
   */
  private ondisconnect(): void {
    debug("got disconnect packet");
    this._onclose("client namespace disconnect");
  }

  /**
   * Handles a client error.
   *
   * The error is emitted as an "error" event when at least one listener is
   * registered, and printed with `console.error` otherwise.
   *
   * @param {Error} err
   * @private
   */
  _onerror(err: Error): void {
    if (this.listeners("error").length) {
      this.emitReserved("error", err);
    } else {
      console.error("Missing error handler on `socket`.");
      console.error(err.stack);
    }
  }

  /**
   * Called upon closing. Called by `Client`.
   *
   * Emits "disconnecting" while the socket is still in its rooms, then leaves
   * all rooms, detaches from the namespace and the client, and finally emits
   * "disconnect".
   *
   * @param {String} reason
   * @return the socket itself if it was not connected, `undefined` otherwise
   *
   * @private
   */
  _onclose(reason: string): this | undefined {
    if (!this.connected) return this;
    debug("closing socket - reason %s", reason);
    this.emitReserved("disconnecting", reason);
    this.leaveAll();
    this.nsp._remove(this);
    this.client._remove(this);
    this.connected = false;
    this.emitReserved("disconnect", reason);
    return;
  }

  /**
   * Produces an `error` packet.
   *
   * @param {Object} err - error object, sent as the payload of a CONNECT_ERROR packet
   *
   * @private
   */
  _error(err: any): void {
    this.packet({ type: PacketType.CONNECT_ERROR, data: err });
  }

  /**
   * Disconnects this client.
   *
   * Does nothing when the socket is not connected.
   *
   * @param {Boolean} close - if `true`, closes the underlying connection
   * @return {Socket} self
   *
   * @public
   */
  public disconnect(close = false): this {
    if (!this.connected) return this;
    if (close) {
      this.client._disconnect();
    } else {
      this.packet({ type: PacketType.DISCONNECT });
      this._onclose("server namespace disconnect");
    }
    return this;
  }

  /**
   * Sets the compress flag.
   *
   * @param {Boolean} compress - if `true`, compresses the sending data
   * @return {Socket} self
   * @public
   */
  public compress(compress: boolean): this {
    this.flags.compress = compress;
    return this;
  }

  /**
   * Sets a modifier for a subsequent event emission that the event data may be lost if the client is not ready to
   * receive messages (because of network slowness or other issues, or because it is connected through long polling
   * and is in the middle of a request-response cycle).
   *
   * @return {Socket} self
   * @public
   */
  public get volatile(): this {
    this.flags.volatile = true;
    return this;
  }

  /**
   * Sets a modifier for a subsequent event emission that the event data will only be broadcast to every socket but the
   * sender.
   *
   * @return a new BroadcastOperator (excluding this socket) for chaining
   * @public
   */
  public get broadcast(): BroadcastOperator<EmitEvents, SocketData> {
    return this.newBroadcastOperator();
  }

  /**
   * Sets a modifier for a subsequent event emission that the event data will only be broadcast to the current node.
   *
   * @return a new BroadcastOperator (excluding this socket) for chaining
   * @public
   */
  public get local(): BroadcastOperator<EmitEvents, SocketData> {
    return this.newBroadcastOperator().local;
  }

  /**
   * Sets a modifier for a subsequent event emission that the callback will be called with an error when the
   * given number of milliseconds have elapsed without an acknowledgement from the client:
   *
   * ```
   * socket.timeout(5000).emit("my-event", (err) => {
   *   if (err) {
   *     // the client did not acknowledge the event in the given delay
   *   }
   * });
   * ```
   *
   * @param {Number} timeout - delay in milliseconds
   * @returns self
   * @public
   */
  public timeout(timeout: number): this {
    this.flags.timeout = timeout;
    return this;
  }

  /**
   * Dispatch incoming event to socket listeners.
   *
   * The event goes through the middleware chain first. The outcome is handled
   * on the next tick: a middleware error is forwarded to `_onerror`, otherwise
   * the event is emitted, unless the socket got disconnected in the meantime.
   *
   * @param {Array} event - event that will get emitted
   * @private
   */
  private dispatch(event: Event): void {
    debug("dispatching an event %j", event);
    this.run(event, (err) => {
      process.nextTick(() => {
        if (err) {
          return this._onerror(err);
        }
        if (this.connected) {
          super.emitUntyped.apply(this, event);
        } else {
          debug("ignore packet received after disconnection");
        }
      });
    });
  }

  /**
   * Sets up socket middleware.
   *
   * @param {Function} fn - middleware function (event, next)
   * @return {Socket} self
   * @public
   */
  public use(fn: (event: Event, next: (err?: Error) => void) => void): this {
    this.fns.push(fn);
    return this;
  }

  /**
   * Executes the middleware for an incoming event.
   *
   * The middleware functions are called one after the other, each one being
   * started when the previous one calls `next()`. The chain stops at the first
   * error.
   *
   * @param {Array} event - event that will get emitted
   * @param {Function} fn - last fn call in the middleware
   * @private
   */
  private run(event: Event, fn: (err: Error | null) => void): void {
    // work on a snapshot, so that `use()` calls made meanwhile do not affect this event
    const fns = this.fns.slice(0);
    if (!fns.length) return fn(null);

    function run(i: number) {
      fns[i](event, function (err) {
        // upon error, short-circuit
        if (err) return fn(err);

        // if no middleware left, summon callback
        if (!fns[i + 1]) return fn(null);

        // go on to next
        run(i + 1);
      });
    }

    run(0);
  }

  /**
   * Whether the socket is currently disconnected
   */
  public get disconnected(): boolean {
    return !this.connected;
  }

  /**
   * A reference to the request that originated the underlying Engine.IO Socket.
   *
   * @public
   */
  public get request(): IncomingMessage {
    return this.client.request;
  }

  /**
   * A reference to the underlying Client transport connection (Engine.IO Socket object).
   *
   * @public
   */
  public get conn() {
    return this.client.conn;
  }

  /**
   * The rooms this socket is currently in, as reported by the adapter (an
   * empty Set when the adapter does not return any).
   *
   * @public
   */
  public get rooms(): Set<Room> {
    return this.adapter.socketRooms(this.id) || new Set();
  }

  /**
   * Adds a listener that will be fired when any event is emitted. The event name is passed as the first argument to the
   * callback.
   *
   * @param listener
   * @return self
   * @public
   */
  public onAny(listener: (...args: any[]) => void): this {
    this._anyListeners = this._anyListeners || [];
    this._anyListeners.push(listener);
    return this;
  }

  /**
   * Adds a listener that will be fired when any event is emitted. The event name is passed as the first argument to the
   * callback. The listener is added to the beginning of the listeners array.
   *
   * @param listener
   * @return self
   * @public
   */
  public prependAny(listener: (...args: any[]) => void): this {
    this._anyListeners = this._anyListeners || [];
    this._anyListeners.unshift(listener);
    return this;
  }

  /**
   * Removes the listener that will be fired when any event is emitted.
   *
   * @param listener - the listener to remove (only its first occurrence is removed); when omitted, all the catch-all
   * listeners are removed
   * @return self
   * @public
   */
  public offAny(listener?: (...args: any[]) => void): this {
    if (!this._anyListeners) {
      return this;
    }
    if (listener) {
      const index = this._anyListeners.indexOf(listener);
      if (index !== -1) {
        this._anyListeners.splice(index, 1);
      }
    } else {
      this._anyListeners = [];
    }
    return this;
  }

  /**
   * Returns an array of listeners that are listening for any event that is specified. This array can be manipulated,
   * e.g. to remove listeners.
   *
   * @public
   */
  public listenersAny() {
    return this._anyListeners || [];
  }

  /**
   * Creates a BroadcastOperator bound to the adapter, with no target room,
   * this socket's own id as the only excluded room, and the pending flags
   * (which are consumed by this call).
   *
   * @private
   */
  private newBroadcastOperator(): BroadcastOperator<EmitEvents, SocketData> {
    const flags = Object.assign({}, this.flags);
    this.flags = {};
    return new BroadcastOperator(
      this.adapter,
      new Set<Room>(),
      new Set<Room>([this.id]),
      flags
    );
  }
}