NeuraLegion/sectester-net

View on GitHub
src/SecTester.Bus/Dispatchers/RmqEventBus.cs

Summary

Maintainability
A
2 hrs
Test Coverage
A
99%
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using RabbitMQ.Client;
using RabbitMQ.Client.Events;
using SecTester.Bus.Exceptions;
using SecTester.Bus.Extensions;
using SecTester.Core;
using SecTester.Core.Bus;
using SecTester.Core.Utils;

namespace SecTester.Bus.Dispatchers;

/// <summary>
/// RabbitMQ-backed implementation of <see cref="IEventBus" />.
/// Publishes events to the configured exchange, executes commands against the application queue
/// (optionally awaiting a reply through the direct reply-to queue) and dispatches incoming
/// messages from the client queue to the registered event listeners.
/// </summary>
public class RmqEventBus : IEventBus
{
  // RabbitMQ "direct reply-to" pseudo-queue. Replies are consumed from it on the long-lived
  // consumer channel, which is why commands are published through that very same channel.
  private const string ReplyQueueName = "amq.rabbitmq.reply-to";

  private readonly IRmqConnectionManager _connectionManager;

  // Event types that currently have at least one handler; used to resolve a payload type by event name.
  private readonly List<Type> _eventTypes = new();

  // Event name -> handler types registered for that event.
  private readonly Dictionary<string, List<Type>> _handlers = new();
  private readonly ILogger _logger;
  private readonly RmqEventBusOptions _options;

  // Correlation id -> completion source of a command that is still waiting for its reply.
  private readonly ConcurrentDictionary<string, TaskCompletionSource<string>> _pendingMessages = new();
  private readonly IServiceScopeFactory _scopeFactory;

  // Long-lived channel that consumes the client queue and the reply queue; it is replaced
  // by a fresh one when a consumer callback raises an exception.
  private IModel _channel;

  /// <summary>
  /// Creates the bus and immediately opens the consumer channel (declares the exchange and the client queue,
  /// then starts consuming the client and reply queues).
  /// </summary>
  /// <param name="options">Exchange, queue and prefetch settings.</param>
  /// <param name="connectionManager">Provides the connection, channels and consumers.</param>
  /// <param name="logger">Logger used for diagnostic messages.</param>
  /// <param name="scopeFactory">Factory of DI scopes from which event handlers are resolved.</param>
  /// <exception cref="ArgumentNullException">Any of the arguments is <c>null</c>.</exception>
  public RmqEventBus(RmqEventBusOptions options, IRmqConnectionManager connectionManager, ILogger<RmqEventBus> logger,
    IServiceScopeFactory scopeFactory)
  {
    _options = options ?? throw new ArgumentNullException(nameof(options));
    _connectionManager = connectionManager ?? throw new ArgumentNullException(nameof(connectionManager));
    _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    _scopeFactory = scopeFactory ?? throw new ArgumentNullException(nameof(scopeFactory));
    _channel = CreateConsumerChannel();
  }

  /// <summary>
  /// Publishes an event to the configured exchange, using the event type as the routing key.
  /// </summary>
  /// <typeparam name="TEvent">Type of the event being published.</typeparam>
  /// <param name="message">The event to publish.</param>
  /// <returns>An already completed task; the message is sent synchronously.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="message" /> is <c>null</c>.</exception>
  public Task Publish<TEvent>(TEvent message) where TEvent : Event
  {
    if (message is null)
    {
      throw new ArgumentNullException(nameof(message));
    }

    _connectionManager.TryConnect();

    SendMessage(new MessageParams<TEvent>
    {
      Payload = message,
      Type = message.Type,
      RoutingKey = message.Type,
      Exchange = _options.Exchange,
      CorrelationId = message.CorrelationId,
      CreatedAt = message.CreatedAt
    });

    return Task.CompletedTask;
  }

  /// <summary>
  /// Sends a command to the application queue and, when the command expects a reply,
  /// waits for the reply carrying the same correlation id.
  /// </summary>
  /// <typeparam name="TResult">Type the reply payload is deserialized to.</typeparam>
  /// <param name="message">The command to execute.</param>
  /// <returns>
  /// The deserialized reply, or <c>default</c> when the command does not expect a reply.
  /// </returns>
  /// <exception cref="ArgumentNullException"><paramref name="message" /> is <c>null</c>.</exception>
  /// <exception cref="TaskCanceledException">No reply arrived within the command's TTL.</exception>
  public async Task<TResult?> Execute<TResult>(Command<TResult> message)
  {
    if (message is null)
    {
      throw new ArgumentNullException(nameof(message));
    }

    if (!message.ExpectReply)
    {
      // Fire-and-forget: nothing to track, so no pending entry and no timer are created.
      SendCommand(message);

      return default;
    }

    // Continuations must not run inline on the consumer thread that delivers the reply.
    var tcs = new TaskCompletionSource<string>(TaskCreationOptions.RunContinuationsAsynchronously);

    // Register before publishing so that a fast reply always finds its pending entry.
    _pendingMessages[message.CorrelationId] = tcs;

    try
    {
      using var cts = new CancellationTokenSource(message.Ttl);
      using var registration = cts.Token.Register(() => tcs.TrySetCanceled(), false);

      SendCommand(message);

      var result = await tcs.Task.ConfigureAwait(false);

      return MessageSerializer.Deserialize<TResult>(result);
    }
    finally
    {
      // Always drop the entry (reply received, timeout or send failure) so the map cannot grow unbounded.
      _pendingMessages.TryRemove(message.CorrelationId, out _);
    }
  }

  /// <summary>
  /// Registers a handler for an event. The first handler registered for an event also binds
  /// the client queue to the exchange using the event name as the routing key.
  /// </summary>
  /// <typeparam name="THandler">Handler type; resolved from the DI container when an event arrives.</typeparam>
  /// <typeparam name="TEvent">Event type the handler listens to.</typeparam>
  /// <typeparam name="TResult">Type of the result produced by the handler.</typeparam>
  public void Register<THandler, TEvent, TResult>() where THandler : IEventListener<TEvent, TResult> where TEvent : Event
  {
    var eventName = MessageUtils.GetMessageType<TEvent>();

    if (!_handlers.TryGetValue(eventName, out var handlers))
    {
      handlers = new List<Type>();
      _eventTypes.Add(typeof(TEvent));
      _handlers.Add(eventName, handlers);
      BindQueue(eventName);
    }

    handlers.Add(typeof(THandler));
  }

  /// <summary>
  /// Removes a handler of an event. When the last handler is removed, the client queue
  /// is unbound from the exchange for that event.
  /// </summary>
  /// <typeparam name="THandler">Handler type to remove.</typeparam>
  /// <typeparam name="TEvent">Event type the handler listens to.</typeparam>
  /// <typeparam name="TResult">Type of the result produced by the handler.</typeparam>
  /// <exception cref="NoSubscriptionFoundException">No handler is registered for the event.</exception>
  public void Unregister<THandler, TEvent, TResult>() where THandler : IEventListener<TEvent, TResult> where TEvent : Event
  {
    var eventName = MessageUtils.GetMessageType<TEvent>();

    if (!_handlers.TryGetValue(eventName, out var handlers))
    {
      throw new NoSubscriptionFoundException(eventName);
    }

    handlers.Remove(typeof(THandler));

    if (handlers.Count > 0)
    {
      return;
    }

    _eventTypes.Remove(typeof(TEvent));
    _handlers.Remove(eventName);
    UnBindQueue(eventName);
  }

  /// <summary>
  /// Registers a handler that produces no meaningful result (uses <see cref="Unit" /> as the result type).
  /// </summary>
  public void Register<THandler, TEvent>() where THandler : IEventListener<TEvent> where TEvent : Event =>
    Register<THandler, TEvent, Unit>();

  /// <summary>
  /// Removes a handler that produces no meaningful result (uses <see cref="Unit" /> as the result type).
  /// </summary>
  /// <exception cref="NoSubscriptionFoundException">No handler is registered for the event.</exception>
  public void Unregister<THandler, TEvent>() where THandler : IEventListener<TEvent> where TEvent : Event =>
    Unregister<THandler, TEvent, Unit>();

  /// <summary>
  /// Releases the consumer channel owned by the bus and disposes the connection manager.
  /// </summary>
  public void Dispose()
  {
    // The channel is created and owned by this class, so it is released before the connection.
    _channel.Dispose();
    _connectionManager.Dispose();
    GC.SuppressFinalize(this);
  }

  // Publishes a command to the application queue through the consumer channel,
  // asking the receiver to reply to the direct reply-to queue.
  private void SendCommand<TResult>(Command<TResult> message)
  {
    SendMessage(_channel, new MessageParams<object>
    {
      Payload = message,
      Type = message.Type,
      ReplyTo = ReplyQueueName,
      RoutingKey = _options.AppQueue,
      CorrelationId = message.CorrelationId,
      CreatedAt = message.CreatedAt
    });
  }

  // Completes the pending command that matches the correlation id of an incoming reply.
  // Replies without a correlation id, or without a pending command, are ignored.
  private Task ReplyReceiverHandler(BasicDeliverEventArgs args)
  {
    var correlationId = args.BasicProperties.CorrelationId;

    if (string.IsNullOrEmpty(correlationId))
    {
      return Task.CompletedTask;
    }

    var data = Encoding.UTF8.GetString(args.Body.ToArray());

    _logger.LogDebug(
      "Received a reply ({CorrelationId}) with following payload: {Payload}",
      correlationId,
      data
    );

    if (_pendingMessages.TryRemove(correlationId!, out var tcs))
    {
      // TrySetResult: the command may have timed out (been cancelled) just before the reply arrived,
      // in which case SetResult would throw inside the consumer callback.
      tcs.TrySetResult(data);
    }

    return Task.CompletedTask;
  }

  // Opens a channel, declares the topology and starts consuming the client and reply queues.
  private IModel CreateConsumerChannel()
  {
    _connectionManager.TryConnect();

    var channel = _connectionManager.CreateChannel();
    channel.CallbackException += (_, _) =>
    {
      // Replace the broken channel with a fresh one. The field may still be unassigned
      // if the callback fires while the constructor is running, hence the null-conditional call.
      _channel?.Dispose();
      _channel = CreateConsumerChannel();
    };

    BindQueueToExchange(channel);
    StartBasicConsume(channel);
    StartReplyQueueConsume(channel);

    return channel;
  }

  // Starts consuming events from the client queue (second argument: automatic acknowledgement).
  private void StartBasicConsume(IModel channel)
  {
    var consumer = _connectionManager.CreateConsumer(channel);
    consumer.Received += ReceiverHandler;
    channel.BasicConsume(_options.ClientQueue, true, consumer);
  }

  // Starts consuming command replies from the reply queue (second argument: automatic acknowledgement).
  private void StartReplyQueueConsume(IModel channel)
  {
    var consumer = _connectionManager.CreateConsumer(channel);
    consumer.Received += (_, args) => ReplyReceiverHandler(args);
    channel.BasicConsume(ReplyQueueName, true, consumer);
  }

  // Declares the direct exchange and the client queue, and applies the prefetch limit to the channel.
  // Routing-key bindings are added later, one per registered event (see BindQueue).
  private void BindQueueToExchange(IModel channel)
  {
    channel.ExchangeDeclare(_options.Exchange, "direct", true);
    channel.QueueDeclare(_options.ClientQueue, exclusive: false, autoDelete: true, durable: true);
    channel.BasicQos(0, _options.PrefetchCount, false);
  }

  // Dispatches a message from the client queue to every handler registered for its event name.
  // Redelivered messages are skipped. The event name is taken from the message type property,
  // falling back to the routing key. Throws (via GetHandlers) when nothing is registered for the event.
  private async Task ReceiverHandler(object sender, BasicDeliverEventArgs args)
  {
    if (args.Redelivered)
    {
      return;
    }

    var name = string.IsNullOrEmpty(args.BasicProperties.Type) ? args.RoutingKey : args.BasicProperties.Type;
    var handlers = GetHandlers(name);
    var body = Encoding.UTF8.GetString(args.Body.ToArray());
    var consumedMessage = new ConsumedMessage
    {
      Name = name,
      Payload = body,
      ReplyTo = args.BasicProperties.ReplyTo,
      CorrelationId = args.BasicProperties.CorrelationId
    };

    _logger.LogDebug(
      "Received a event ({Name}) with following payload: {Body}", consumedMessage.Name,
      body
    );

    // Handlers run sequentially, in registration order.
    foreach (var handler in handlers)
    {
      await HandleEvent(handler, consumedMessage).ConfigureAwait(false);
    }
  }

  // Returns the handler types registered for the event.
  // Throws NoSubscriptionFoundException when the event is unknown and
  // EventHandlerNotFoundException when it is known but has no handlers.
  private List<Type> GetHandlers(string eventName)
  {
    if (!_handlers.TryGetValue(eventName, out var handlers))
    {
      throw new NoSubscriptionFoundException(eventName);
    }

    if (handlers is null or { Count: 0 })
    {
      throw new EventHandlerNotFoundException(eventName);
    }

    return handlers;
  }

  // Resolves the handler from a fresh DI scope, deserializes the payload to the registered event type,
  // invokes the handler's Handle method via reflection and, when the handler returns a non-null result
  // and the message carries a reply-to address, sends that result back.
  // Any exception is logged and swallowed so that one failing handler does not affect the others.
  private async Task HandleEvent(Type eventHandler, ConsumedMessage consumedMessage)
  {
    try
    {
      var scope = _scopeFactory.CreateAsyncScope();
      await using var _ = scope.ConfigureAwait(false);
      var instance = scope.ServiceProvider.GetService(eventHandler);
      var eventType = GetEventType(consumedMessage.Name!);

      // Nothing to do when the handler is not registered in the container or the event type is unknown.
      if (instance is null || eventType is null)
      {
        return;
      }

      var concreteType = eventHandler.GetConcreteEventListenerType();
      var payload = MessageSerializer.Deserialize(consumedMessage.Payload!, eventType);
      var task = (Task)concreteType.InvokeMember(nameof(IEventListener<Event>.Handle), BindingFlags.InvokeMethod, null, instance, new[]
      {
        payload
      }, CultureInfo.InvariantCulture);
      var response = await task.Cast<object?>().ConfigureAwait(false);

      if (response != null && !string.IsNullOrEmpty(consumedMessage.ReplyTo))
      {
        SendReplyOnEvent(consumedMessage, response);
      }
    }
    catch (Exception err)
    {
      _logger.LogDebug(err, "Error while processing a message ({CorrelationId}) due to error occurred. Event: {Payload}",
        consumedMessage.CorrelationId, consumedMessage.Payload);
    }
  }

  // Sends a handler's result to the queue named in the consumed message's reply-to property,
  // preserving the correlation id. No exchange is set, so the default ("") exchange is used.
  private void SendReplyOnEvent<T>(ConsumedMessage consumedMessage, T response)
  {
    _logger.LogDebug(
      "Sending a reply ({Event}) back with following payload: {Json}",
      consumedMessage.Name,
      response
    );

    SendMessage(new MessageParams<T>
    {
      Payload = response,
      RoutingKey = consumedMessage.ReplyTo!,
      CorrelationId = consumedMessage.CorrelationId
    });
  }

  // Creates persistent JSON message properties stamped with the creation time (defaults to now, UTC).
  // NOTE(review): the timestamp is passed as Unix milliseconds, whereas AMQP timestamps are
  // conventionally expressed in seconds - confirm that consumers expect milliseconds.
  private static IBasicProperties CreateMessageProperties(IModel channel, DateTime? createdAt = default)
  {
    var properties = channel.CreateBasicProperties();
    var timestamp = new DateTimeOffset(createdAt ?? DateTime.UtcNow);
    properties.Timestamp = new AmqpTimestamp(timestamp.ToUnixTimeMilliseconds());
    properties.Persistent = true;
    properties.ContentType = "application/json";
    return properties;
  }

  // Publishes a message through the given channel. A missing exchange falls back to the default ("") one;
  // the third BasicPublish argument marks the message as mandatory.
  private void SendMessage<T>(IModel channel, MessageParams<T> messageParams)
  {
    var properties = CreateMessageProperties(channel, messageParams.CreatedAt);
    properties.CorrelationId = messageParams.CorrelationId;
    properties.Type = messageParams.Type;
    properties.ReplyTo = messageParams.ReplyTo;

    _logger.LogDebug("Send a message with following parameters: {Params}", messageParams);

    channel.BasicPublish(messageParams.Exchange ?? "",
      messageParams.RoutingKey,
      true,
      properties,
      messageParams.ToBytes());
  }

  // Publishes a message through a short-lived channel that is disposed right after sending.
  private void SendMessage<T>(MessageParams<T> messageParams)
  {
    using var channel = _connectionManager.CreateChannel();
    SendMessage(channel, messageParams);
  }

  // Binds the client queue to the exchange using the event name as the routing key.
  private void BindQueue(string eventName)
  {
    _connectionManager.TryConnect();
    using var channel = _connectionManager.CreateChannel();
    channel.QueueBind(_options.ClientQueue,
      _options.Exchange,
      eventName);
  }

  // Removes the binding between the client queue and the exchange for the event name.
  private void UnBindQueue(string eventName)
  {
    _connectionManager.TryConnect();
    using var channel = _connectionManager.CreateChannel();
    channel.QueueUnbind(_options.ClientQueue,
      _options.Exchange,
      eventName);
  }

  // Finds the registered event type whose message type equals the event name; null when there is none.
  private Type? GetEventType(string eventName) => _eventTypes.SingleOrDefault(x => MessageUtils.GetMessageType(x) == eventName);
}