Mordil/RediStack

View on GitHub
Sources/RediStack/ChannelHandlers/RedisPubSubHandler.swift

Summary

Maintainability
A
0 mins
Test Coverage
//===----------------------------------------------------------------------===//
//
// This source file is part of the RediStack open source project
//
// Copyright (c) 2020-2022 RediStack project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of RediStack project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOCore

/// The possible events that are received from Redis Pub/Sub channels.
public enum RedisPubSubEvent {
    /// The available sources of Pub/Sub unsubscribe events.
    public enum UnsubscribeEventSource {
        /// The unsubscribe was not caused by an error.
        ///
        /// This is the source when the client sent an unsubscribe command (either as UNSUBSCRIBE or PUNSUBSCRIBE),
        /// and also when the Pub/Sub handler was removed from the channel pipeline without an error.
        case userInitiated
        /// The client encountered an error and had to unsubscribe.
        /// - Parameter error: The error the client encountered.
        case clientError(_ error: Error)
    }

    /// The connection has been subscribed to a channel.
    ///
    /// This event should only be received once, before receiving messages.
    /// - Parameters:
    ///     - key: The channel name or pattern that was subscribed to.
    ///     - currentSubscriptionCount: The total number of subscriptions Redis reports for the connection after subscribing.
    case subscribed(key: String, currentSubscriptionCount: Int)
    /// The connection has been unsubscribed from a channel.
    ///
    /// This event should only be received once, after all messages received have been processed, with no further messages being received.
    /// - Parameters:
    ///     - key: The channel name or pattern that was unsubscribed from.
    ///     - currentSubscriptionCount: The total number of subscriptions the connection has after unsubscribing.
    ///         When the client unsubscribes on its own (an error, or the handler being removed), this is reported as `0`.
    ///     - source: The source of the unsubscribe event.
    case unsubscribed(key: String, currentSubscriptionCount: Int, source: UnsubscribeEventSource)
    /// The connection has received a message on the given channel.
    ///
    /// This event can be received any number of times, until the connection has unsubscribed from the channel.
    /// - Parameters:
    ///     - publisher: The name of the channel that published the message.
    ///     - message: The message data that was received from the `publisher`.
    case message(publisher: RedisChannelName, message: RESPValue)
}

/// A closure that receives individual Pub/Sub events from Redis subscriptions to channels and patterns.
/// - Warning: The receiver is called on the same `NIO.EventLoop` that processed the message (the `EventLoop` of the `NIO.ChannelPipeline`).
///
/// If you are doing non-trivial work in response to PubSub messages, it is **highly recommended** that the work
/// be dispatched to another thread, so as to not block further messages from being processed.
///
/// The closure's single argument is the event that the connection is responding to.
public typealias RedisPubSubEventReceiver = (RedisPubSubEvent) -> Void

/// A list of patterns or channels that a Pub/Sub subscription change is targeting.
///
/// See ``RedisChannelName`` or the Redis documentation on [PSUBSCRIBE](https://redis.io/commands/psubscribe) and [SUBSCRIBE](https://redis.io/commands/subscribe).
///
/// Use the `values` property to quickly access the underlying list of the target for any purpose that requires the `String` values.
public enum RedisSubscriptionTarget: Equatable, CustomDebugStringConvertible {
    /// Exact channel names, changed with SUBSCRIBE and UNSUBSCRIBE.
    case channels([RedisChannelName])
    /// Patterns, changed with PSUBSCRIBE and PUNSUBSCRIBE.
    case patterns([String])

    /// The raw `String` values of the target, in the order they were provided.
    public var values: [String] {
        switch self {
        case let .channels(names): return names.map { $0.rawValue }
        case let .patterns(values): return values
        }
    }

    /// A description of the target kind and its values, such as `Channels 'a, b'`.
    public var debugDescription: String {
        let values = self.values.joined(separator: ", ")
        switch self {
        case .channels: return "Channels '\(values)'"
        case .patterns: return "Patterns '\(values)'"
        }
    }

    /// Two targets are equal when they are the same kind and list the same values in the same order.
    public static func ==(lhs: RedisSubscriptionTarget, rhs: RedisSubscriptionTarget) -> Bool {
        switch (lhs, rhs) {
        case let (.channels(left), .channels(right)): return left == right
        case let (.patterns(left), .patterns(right)): return left == right
        default: return false
        }
    }
}

/// A channel handler that stores a map of closures and channel or pattern names subscribed to in Redis using Pub/Sub.
public final class RedisPubSubHandler {
    /// Whether the handler still accepts subscription changes.
    private var state: State = .default

    // every key in the maps below _must_ be prefixed, since a channel name and a pattern can share the same text

    /// Registered receivers, keyed by prefixed channel name or pattern.
    private var subscriptions: [String: Subscription] = [:]
    /// Subscribe changes sent to Redis that are still waiting for their completion notification.
    private var pendingSubscribes: PendingSubscriptionChangeQueue = [:]
    /// Unsubscribe changes sent to Redis that are still waiting for their completion notification.
    private var pendingUnsubscribes: PendingSubscriptionChangeQueue = [:]

    /// The event loop the public API hops onto before touching any of the state above.
    private let eventLoop: EventLoop

    // only set between `handlerAdded` and `handlerRemoved`: never use it before the handler has been added
    private var context: ChannelHandlerContext!

    /// Creates a handler with no subscriptions.
    /// - Parameters:
    ///     - eventLoop: The event loop the `NIO.Channel` that this handler was added to is bound to.
    ///     - queueCapacity: The initial capacity of queues used for processing subscription changes. Defaults to `3`.
    ///
    ///         Unless you are subscribing and unsubscribing from a large volume of channels or patterns at a single time,
    ///         such as a single SUBSCRIBE call, you do not need to modify this value.
    public init(eventLoop: EventLoop, initialSubscriptionQueueCapacity queueCapacity: Int = 3) {
        self.eventLoop = eventLoop

        // only the pending-change queues are pre-sized; the receiver map grows on demand
        self.pendingSubscribes.reserveCapacity(queueCapacity)
        self.pendingUnsubscribes.reserveCapacity(queueCapacity)
    }
}

// MARK: PubSub Message Handling

extension RedisPubSubHandler {
    /// Handles a `subscribe`/`psubscribe` confirmation from Redis.
    ///
    /// If a receiver is registered for the key, the matching metrics gauge is incremented and the receiver is sent
    /// a `.subscribed` event. The pending subscribe change for the key is resolved on every path.
    /// - Parameters:
    ///     - subscriptionKey: The channel name or pattern as reported by Redis (unprefixed).
    ///     - subscriptionCount: The total subscription count Redis reported for the connection.
    ///     - keyPrefix: The prefix separating channel keys from pattern keys in the local maps.
    private func handleSubscribeMessage(
        withSubscriptionKey subscriptionKey: String,
        reportedSubscriptionCount subscriptionCount: Int,
        keyPrefix: String
    ) {
        let prefixedKey = self.prefixKey(subscriptionKey, with: keyPrefix)

        // the pending change is resolved on every exit path, after the receiver has been notified
        defer { self.pendingSubscribes.removeValue(forKey: prefixedKey)?.succeed(subscriptionCount) }

        guard let subscription = self.subscriptions[prefixedKey] else { return }

        subscription.type.gauge.increment()
        // `Subscription` is a class that is already stored in the map, so it is deliberately NOT written back:
        // doing so would clobber a receiver that `onEvent` itself registered for the same key,
        // or re-insert a receiver that was removed while `onEvent` ran
        subscription.onEvent(.subscribed(key: subscriptionKey, currentSubscriptionCount: subscriptionCount))
    }

    /// Handles an `unsubscribe`/`punsubscribe` confirmation from Redis.
    ///
    /// Removes the receiver registered for the key (if any), decrements its metrics gauge and sends it an
    /// `.unsubscribed` event. It then resolves either the pending change for this specific key, or, when there is none,
    /// the "unsubscribe from all" change once no subscriptions of that type remain.
    /// - Parameters:
    ///     - subscriptionKey: The channel name or pattern as reported by Redis (unprefixed).
    ///     - subscriptionCount: The total subscription count Redis reported for the connection.
    ///     - unsubscribeFromAllKey: The `pendingUnsubscribes` key used for an "unsubscribe from all" of this type.
    ///     - keyPrefix: The prefix separating channel keys from pattern keys in the local maps.
    private func handleUnsubscribeMessage(
        withSubscriptionKey subscriptionKey: String,
        reportedSubscriptionCount subscriptionCount: Int,
        unsubscribeFromAllKey: String,
        keyPrefix: String
    ) {
        let prefixedKey = self.prefixKey(subscriptionKey, with: keyPrefix)
        let removedSubscription = self.subscriptions.removeValue(forKey: prefixedKey)

        if let subscription = removedSubscription {
            subscription.type.gauge.decrement()
            subscription.onEvent(.unsubscribed(
                key: subscriptionKey,
                currentSubscriptionCount: subscriptionCount,
                source: .userInitiated
            ))
        }

        // a specific pattern/channel was being removed, so just fulfill its notification.
        // this happens even without a local receiver, otherwise unsubscribing from a target
        // that was never subscribed to would leave the caller's future pending forever
        if let promise = self.pendingUnsubscribes.removeValue(forKey: prefixedKey) {
            promise.succeed(subscriptionCount)
            return
        }

        // if one wasn't found, this means a [p]unsubscribe all was issued,
        // which is only tracked through targets that had a local receiver
        guard let subscription = removedSubscription else { return }

        // we want to wait for the subscription count to be 0 before we resolve its notification.
        // this count may be from what Redis reports, or the count of subscriptions for this particular type
        guard
            subscriptionCount == 0 || self.subscriptions.count(where: { $0.type == subscription.type }) == 0
        else { return }
        // always report back the count according to Redis, it is the source of truth
        self.pendingUnsubscribes.removeValue(forKey: unsubscribeFromAllKey)?.succeed(subscriptionCount)
    }

    /// Delivers a published message to the receiver registered for the subscription key, if there is one.
    ///
    /// The received-messages metric is only incremented when a receiver was found.
    /// - Parameters:
    ///     - message: The published payload.
    ///     - channel: The channel the message was published on.
    ///     - subscriptionKey: The channel name (for `message`) or pattern (for `pmessage`) that matched.
    ///     - keyPrefix: The prefix separating channel keys from pattern keys in the local maps.
    private func handleMessage(
        _ message: RESPValue,
        from channel: RedisChannelName,
        withSubscriptionKey subscriptionKey: String,
        keyPrefix: String
    ) {
        guard let subscription = self.subscriptions[self.prefixKey(subscriptionKey, with: keyPrefix)] else { return }
        subscription.onEvent(.message(publisher: channel, message: message))
        RedisMetrics.subscriptionMessagesReceivedCount.increment()
    }
}

// MARK: Subscription Management

extension RedisPubSubHandler {
    /// Registers the provided subscription event handler to receive events from the specified subscription target.
    /// - Important: Any previously registered receiver will be replaced and not notified.
    /// - Parameters:
    ///     - target: The channels or patterns that the receiver should receive messages for.
    ///     - receiver: The closure that receives any future pub/sub events.
    /// - Returns: A `NIO.EventLoopFuture` that resolves the number of subscriptions the client has after the subscription has been added.
    ///
    ///     If every target already had a receiver, nothing is sent to Redis and the future resolves the local count of
    ///     registered channel and pattern receivers. The future fails if the handler has been removed from the pipeline
    ///     or has previously encountered an error.
    public func addSubscription(
        for target: RedisSubscriptionTarget,
        receiver: @escaping RedisPubSubEventReceiver
    ) -> EventLoopFuture<Int> {
        guard self.eventLoop.inEventLoop else {
            return self.eventLoop.flatSubmit {
                return self.addSubscription(for: target, receiver: receiver)
            }
        }

        switch self.state {
        case .removed: return self.eventLoop.makeFailedFuture(RedisClientError.subscriptionModeRaceCondition)

        case let .error(e): return self.eventLoop.makeFailedFuture(e)

        case .default:
            // go through all the target patterns/names and update the map with the new receiver if it's already registered
            // if it was a new registration, not an update, we keep that name to send to Redis
            // we do this so that we save on data transfer bandwidth

            let newSubscriptionTargets = target.values
                .compactMap { (targetKey) -> String? in
                    let subscription = Subscription(type: target.subscriptionType, eventReceiver: receiver)
                    let prefixedKey = self.prefixKey(targetKey, with: target.keyPrefix)
                    guard self.subscriptions.updateValue(subscription, forKey: prefixedKey) == nil else { return nil }
                    return targetKey
                }

            // if there aren't any new actual subscriptions,
            // then we just short circuit and return our local count of subscriptions
            guard !newSubscriptionTargets.isEmpty else {
                return self.eventLoop.makeSucceededFuture(self.subscriptions.count)
            }

            return self.sendSubscriptionChange(
                subscriptionChangeKeyword: target.subscribeKeyword,
                subscriptionTargets: newSubscriptionTargets,
                queue: \.pendingSubscribes,
                keyPrefix: target.keyPrefix
            )
        }
    }

    /// Removes the provided target as a subscription, stopping future messages from being received.
    ///
    /// A target with no values unsubscribes from every channel (or every pattern, for a pattern target).
    /// If the handler has been removed or has encountered an error, nothing is sent and the future resolves `0`.
    /// - Parameter target: The channel or pattern that a receiver should be removed for.
    /// - Returns: A `NIO.EventLoopFuture` that resolves the number of subscriptions the client has after the subscription has been removed.
    public func removeSubscription(for target: RedisSubscriptionTarget) -> EventLoopFuture<Int> {
        guard self.eventLoop.inEventLoop else {
            return self.eventLoop.flatSubmit { self.removeSubscription(for: target) }
        }

        // if we're not in our default state,
        // this essentially is a no-op because an error triggers all receivers to be removed
        guard case .default = self.state else { return self.eventLoop.makeSucceededFuture(0) }

        // we send the UNSUBSCRIBE message to Redis,
        // and in the response we handle the actual removal of the receiver closure

        // if there are no channels / patterns specified,
        // then this is a special case of unsubscribing from all patterns / channels
        guard !target.values.isEmpty else {
            return self.unsubscribeAll(for: target)
        }

        return self.sendSubscriptionChange(
            subscriptionChangeKeyword: target.unsubscribeKeyword,
            subscriptionTargets: target.values,
            queue: \.pendingUnsubscribes,
            keyPrefix: target.keyPrefix
        )
    }

    /// Sends a [P]SUBSCRIBE / [P]UNSUBSCRIBE command for `targets` and waits for Redis to confirm every one of them.
    /// - Parameters:
    ///     - keyword: The command keyword to send.
    ///     - targets: The unprefixed channel names or patterns to change; callers never pass an empty list.
    ///     - pendingQueue: The queue the per-target completion promises are registered in.
    ///     - keyPrefix: The prefix separating channel keys from pattern keys in the local maps.
    /// - Returns: A future resolving the subscription count from the last successful confirmation,
    ///     or failing with the first error if no confirmation succeeded (including a failed write).
    private func sendSubscriptionChange(
        subscriptionChangeKeyword keyword: String,
        subscriptionTargets targets: [String],
        queue pendingQueue: ReferenceWritableKeyPath<RedisPubSubHandler, PendingSubscriptionChangeQueue>,
        keyPrefix: String
    ) -> EventLoopFuture<Int> {
        self.eventLoop.assertInEventLoop()

        var command = [RESPValue(bulk: keyword)]
        command.append(convertingContentsOf: targets)

        // the command does not respond in a normal command response fashion of the end count of subscriptions
        // after all of them have been established (or removed)
        //
        // instead, it replies with a subscribe/unsubscribe message for each channel/pattern that was sent
        //
        // so we have to create a top-level future that synchronizes all of the responses
        // where we take the last response from Redis as the count of active subscriptions

        // register a pending change for each individual subscription target
        let pendingChanges = targets.map {
            self.registerPendingChange(forKey: self.prefixKey($0, with: keyPrefix), in: pendingQueue)
        }

        // synchronize all of the individual subscription changes
        let subscriptionCountFuture = EventLoopFuture<Int>
            .whenAllComplete(
                pendingChanges.map { $0.promise.futureResult },
                on: self.eventLoop
            )
            .flatMapThrowing { (results) -> Int in
                // trust the last success response as the most current count
                guard let latestSubscriptionCount = results
                    .lazy
                    .reversed() // reverse to save time-complexity, as we just need the last (first) successful value
                    .compactMap({ try? $0.get() })
                    .first
                // if we have no success cases, we will still have at least one response that we can
                // rely on the 'get' method to throw the error for us, rather than unwrapping it ourselves
                else { return try results.first!.get() }

                return latestSubscriptionCount
            }

        return self
            .writeSubscriptionChange(command, pendingChanges: pendingChanges, in: pendingQueue)
            .flatMap { subscriptionCountFuture }
    }

    /// Sends a bare UNSUBSCRIBE / PUNSUBSCRIBE and waits until Redis reports that no subscriptions of that type remain.
    private func unsubscribeAll(for target: RedisSubscriptionTarget) -> EventLoopFuture<Int> {
        let command = [RESPValue(bulk: target.unsubscribeKeyword)]

        let pendingChange = self.registerPendingChange(forKey: target.unsubscribeAllKey, in: \.pendingUnsubscribes)

        return self
            .writeSubscriptionChange(command, pendingChanges: [pendingChange], in: \.pendingUnsubscribes)
            .flatMap { pendingChange.promise.futureResult }
    }

    /// Creates a completion promise for a subscription change and stores it in `pendingQueue` under `key`.
    ///
    /// If a change was already pending for the same key (e.g. the same target listed twice), the superseded promise
    /// is completed with the new promise's result instead of being dropped. Otherwise its future would never complete.
    private func registerPendingChange(
        forKey key: String,
        in pendingQueue: ReferenceWritableKeyPath<RedisPubSubHandler, PendingSubscriptionChangeQueue>
    ) -> (key: String, promise: EventLoopPromise<Int>) {
        let promise = self.eventLoop.makePromise(of: Int.self)
        if let superseded = self[keyPath: pendingQueue].updateValue(promise, forKey: key) {
            promise.futureResult.cascade(to: superseded)
        }
        return (key: key, promise: promise)
    }

    /// Writes a subscription change command to the channel.
    ///
    /// If the write fails, Redis will never confirm the change. Any of `pendingChanges` still registered in
    /// `pendingQueue` is therefore removed and failed with the write error.
    private func writeSubscriptionChange(
        _ command: [RESPValue],
        pendingChanges: [(key: String, promise: EventLoopPromise<Int>)],
        in pendingQueue: ReferenceWritableKeyPath<RedisPubSubHandler, PendingSubscriptionChangeQueue>
    ) -> EventLoopFuture<Void> {
        let writeFuture = self.context.writeAndFlush(self.wrapOutboundOut(.array(command)))
        writeFuture.whenFailure { error in
            for change in pendingChanges {
                // only fail promises that are still ours: the others were already completed, or were
                // superseded by a newer change that will complete them through `cascade(to:)`
                guard self[keyPath: pendingQueue][change.key]?.futureResult === change.promise.futureResult else { continue }
                self[keyPath: pendingQueue].removeValue(forKey: change.key)
                change.promise.fail(error)
            }
        }
        return writeFuture
    }
}

// MARK: ChannelHandler

extension RedisPubSubHandler {
    /// Keeps the handler's context so that subscription changes can later be written through it.
    public func handlerAdded(context: ChannelHandlerContext) {
        self.context = context
    }

    /// Drops the stored context; releasing it breaks reference cycles with the pipeline.
    public func handlerRemoved(context _: ChannelHandlerContext) {
        self.context = .none
    }
}

// MARK: RemoveableChannelHandler

extension RedisPubSubHandler: RemovableChannelHandler {
    /// Takes the handler out of the pipeline and tells every registered receiver it has been unsubscribed.
    ///
    /// After this runs, `addSubscription` fails and `removeSubscription` is a no-op.
    public func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) {
        // the state flips first so no new subscription request is accepted while we leave the pipeline
        self.state = .removed
        context.leavePipeline(removalToken: removalToken)

        // "close" every receiver; no error is associated with a deliberate removal
        self.removeAllReceivers(because: nil)
    }
}

// MARK: ChannelInboundHandler

extension RedisPubSubHandler: ChannelInboundHandler {
    public typealias InboundIn = RESPValue
    public typealias InboundOut = RESPValue

    /// Dispatches Pub/Sub messages from Redis to the registered receivers and pending subscription changes.
    ///
    /// Any value that is not a recognized, well-formed Pub/Sub message is forwarded to the next handler
    /// to be treated as a normal command response.
    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let value = self.unwrapInboundIn(data)

        // Redis defines the format as [messageKeyword: String, channelName: String, message: RESPValue]
        // unless the messageType is 'pmessage', in which case it's [messageKeyword, pattern: String, channelName, message]

        // these guards extract some of the basic details of a pubsub message
        guard
            let array = value.array,
            array.count >= 3,
            let channelOrPattern = array[1].string,
            let messageKeyword = array[0].string
        else {
            context.fireChannelRead(data)
            return
        }

        let wasHandled = self.handlePubSubMessage(
            keyword: messageKeyword,
            channelOrPattern: channelOrPattern,
            elements: array
        )
        // if we don't have a match, fire a channel read to forward to the next handler
        if !wasHandled {
            context.fireChannelRead(data)
        }
    }

    /// Routes a Pub/Sub message to the matching handler based on its keyword.
    /// - Parameters:
    ///     - keyword: The first element of the message array.
    ///     - channelOrPattern: The second element: a channel name, or a pattern for `pmessage`/`psubscribe`/`punsubscribe`.
    ///     - elements: The whole message array; the caller guarantees it has at least 3 elements.
    /// - Returns: `false` when the keyword is not a Pub/Sub keyword or the message is malformed,
    ///     in which case nothing was handled and the caller should forward the value.
    private func handlePubSubMessage(keyword: String, channelOrPattern: String, elements: [RESPValue]) -> Bool {
        guard let message = elements.last else { return false }

        switch keyword {
        case "message":
            self.handleMessage(
                message,
                from: .init(channelOrPattern),
                withSubscriptionKey: channelOrPattern,
                keyPrefix: kSubscriptionKeyPrefixChannel
            )

        case "pmessage":
            // the channel name is stored as the 3rd element in the array in 'pmessage' streams, before the message
            guard elements.count >= 4, let channelName = elements[2].string else { return false }
            self.handleMessage(
                message,
                from: .init(channelName),
                withSubscriptionKey: channelOrPattern,
                keyPrefix: kSubscriptionKeyPrefixPattern
            )

        // if the message keyword is for subscribing or unsubscribing,
        // the message is expected to be the count of subscriptions the connection still has
        case "subscribe":
            guard let subscriptionCount = message.int else { return false }
            self.handleSubscribeMessage(
                withSubscriptionKey: channelOrPattern,
                reportedSubscriptionCount: subscriptionCount,
                keyPrefix: kSubscriptionKeyPrefixChannel
            )

        case "psubscribe":
            guard let subscriptionCount = message.int else { return false }
            self.handleSubscribeMessage(
                withSubscriptionKey: channelOrPattern,
                reportedSubscriptionCount: subscriptionCount,
                keyPrefix: kSubscriptionKeyPrefixPattern
            )

        case "unsubscribe":
            guard let subscriptionCount = message.int else { return false }
            self.handleUnsubscribeMessage(
                withSubscriptionKey: channelOrPattern,
                reportedSubscriptionCount: subscriptionCount,
                unsubscribeFromAllKey: kUnsubscribeAllChannelsKey,
                keyPrefix: kSubscriptionKeyPrefixChannel
            )

        case "punsubscribe":
            guard let subscriptionCount = message.int else { return false }
            self.handleUnsubscribeMessage(
                withSubscriptionKey: channelOrPattern,
                reportedSubscriptionCount: subscriptionCount,
                unsubscribeFromAllKey: kUnsubscribeAllPatternsKey,
                keyPrefix: kSubscriptionKeyPrefixPattern
            )

        default:
            return false
        }
        return true
    }

    /// Unsubscribes every receiver because of `error`, then forwards the error down the pipeline.
    public func errorCaught(context: ChannelHandlerContext, error: Error) {
        self.removeAllReceivers(because: error)
        context.fireErrorCaught(error)
    }

    /// Unsubscribes every receiver with `RedisClientError.connectionClosed`, then forwards the event.
    public func channelInactive(context: ChannelHandlerContext) {
        self.removeAllReceivers(because: RedisClientError.connectionClosed)
        context.fireChannelInactive()
    }

    /// Removes every registered receiver, notifying each with an `.unsubscribed` event, and fails every pending subscription change.
    /// - Parameter error: When provided, the handler moves into the error state, receivers are told the unsubscribe was
    ///     caused by it, and pending changes fail with it. When `nil`, the unsubscribe is reported as `.userInitiated` and
    ///     pending changes fail with `RedisClientError.subscriptionModeRaceCondition`.
    private func removeAllReceivers(because error: Error? = nil) {
        if let error = error { self.state = .error(error) }

        let source: RedisPubSubEvent.UnsubscribeEventSource = error.map { .clientError($0) } ?? .userInitiated

        // detach every map before invoking user code, as receivers and future callbacks may re-enter this handler
        let receivers = self.subscriptions
        self.subscriptions.removeAll()
        let pendingChanges = Array(self.pendingSubscribes.values) + Array(self.pendingUnsubscribes.values)
        self.pendingSubscribes.removeAll()
        self.pendingUnsubscribes.removeAll()

        for (prefixedKey, subscription) in receivers {
            // report the channel name or pattern the user subscribed to, not the internal prefixed map key
            let key = self.unprefixedKey(prefixedKey, for: subscription.type)
            subscription.onEvent(.unsubscribed(key: key, currentSubscriptionCount: 0, source: source))
            subscription.type.gauge.decrement()
        }

        // Redis will no longer confirm these changes through this handler, so complete them now;
        // otherwise the futures returned by addSubscription/removeSubscription would never complete
        let failure: Error = error ?? RedisClientError.subscriptionModeRaceCondition
        pendingChanges.forEach { $0.fail(failure) }
    }

    /// Recovers the channel name or pattern that a subscription map key was built from by `prefixKey(_:with:)`.
    private func unprefixedKey(_ prefixedKey: String, for type: SubscriptionType) -> String {
        let keyPrefix: String
        switch type {
        case .channel: keyPrefix = kSubscriptionKeyPrefixChannel
        case .pattern: keyPrefix = kSubscriptionKeyPrefixPattern
        }
        // applying the helper to an empty key yields exactly the text that was prepended
        let prepended = self.prefixKey("", with: keyPrefix)
        // compare and drop UTF-8 code units, so a key starting with a combining character can't merge with the separator
        guard prefixedKey.utf8.starts(with: prepended.utf8) else { return prefixedKey }
        return String(decoding: prefixedKey.utf8.dropFirst(prepended.utf8.count), as: UTF8.self)
    }
}

// MARK: ChannelOutboundHandler

extension RedisPubSubHandler: ChannelOutboundHandler {
    // No outbound method is implemented here, so the handler stays transparent to writes.
    // The conformance exists only to place the handler in the outbound path of the pipeline,
    // so pub/sub subscription changes can be written from here, bypassing the command handler.
    public typealias OutboundOut = RESPValue
    public typealias OutboundIn = RESPValue
}

// MARK: Private Types

// `pendingUnsubscribes` keys for the promise of an "unsubscribe from all" request, one per subscription type
private let kUnsubscribeAllChannelsKey: String = "__RS_ALL_CHS"
private let kUnsubscribeAllPatternsKey: String = "__RS_ALL_PNS"

/// Distinguishes subscriptions to exact channel names from subscriptions to patterns.
fileprivate enum SubscriptionType {
    case channel
    case pattern

    /// The metrics gauge counting active subscriptions of this type.
    var gauge: RedisMetrics.IncrementalGauge {
        switch self {
        case .pattern:
            return RedisMetrics.activePatternSubscriptions
        case .channel:
            return RedisMetrics.activeChannelSubscriptions
        }
    }
}

extension RedisPubSubHandler {
    /// Completion promises of subscription changes sent to Redis, keyed by prefixed channel/pattern
    /// or by an "unsubscribe from all" key.
    private typealias PendingSubscriptionChangeQueue = [String: EventLoopPromise<Int>]

    /// The lifecycle states of the handler.
    private enum State {
        /// Subscription changes are accepted.
        case `default`
        /// The handler has been removed from the pipeline.
        case removed
        /// The handler stopped because of the associated error.
        case error(Error)
    }

    /// A registered event receiver together with the kind of target it was registered for.
    fileprivate final class Subscription {
        let type: SubscriptionType
        let onEvent: RedisPubSubEventReceiver

        init(type subscriptionType: SubscriptionType, eventReceiver receiver: @escaping RedisPubSubEventReceiver) {
            self.type = subscriptionType
            self.onEvent = receiver
        }
    }
}

// MARK: Subscription Management Helpers

private let kSubscriptionKeyPrefixChannel: String = "__RS_CS" // prefix for channel-name keys
private let kSubscriptionKeyPrefixPattern: String = "__RS_PS" // prefix for pattern keys

extension RedisPubSubHandler {
    /// Builds a map key by joining `prefix` and `key` with an underscore, so that a channel name and a pattern
    /// with the same text never collide in the subscription maps.
    private func prefixKey(_ key: String, with prefix: String) -> String {
        return prefix + "_" + key
    }
}

extension RedisSubscriptionTarget {
    /// Picks `channels` for a channel target and `patterns` for a pattern target.
    private func choose<Value>(channels: Value, patterns: Value) -> Value {
        switch self {
        case .channels: return channels
        case .patterns: return patterns
        }
    }

    /// The `pendingUnsubscribes` key used when unsubscribing from every target of this kind.
    fileprivate var unsubscribeAllKey: String {
        self.choose(channels: kUnsubscribeAllChannelsKey, patterns: kUnsubscribeAllPatternsKey)
    }

    /// The prefix applied to this kind of target's keys in the subscription maps.
    fileprivate var keyPrefix: String {
        self.choose(channels: kSubscriptionKeyPrefixChannel, patterns: kSubscriptionKeyPrefixPattern)
    }

    /// The subscription type registered receivers for this target are tagged with.
    fileprivate var subscriptionType: SubscriptionType {
        self.choose(channels: SubscriptionType.channel, patterns: SubscriptionType.pattern)
    }

    /// The Redis command keyword that subscribes to this kind of target.
    fileprivate var subscribeKeyword: String {
        self.choose(channels: "SUBSCRIBE", patterns: "PSUBSCRIBE")
    }

    /// The Redis command keyword that unsubscribes from this kind of target.
    fileprivate var unsubscribeKeyword: String {
        self.choose(channels: "UNSUBSCRIBE", patterns: "PUNSUBSCRIBE")
    }
}

extension Dictionary where Key == String, Value == RedisPubSubHandler.Subscription {
    /// Returns how many subscriptions stored in the map satisfy `isIncluded`.
    func count(where isIncluded: (Value) -> Bool) -> Int {
        var matches = 0
        for subscription in self.values where isIncluded(subscription) {
            matches += 1
        }
        return matches
    }
}