dolittle/DotNET.SDK

View on GitHub
Source/Services/ReverseCallClient.cs

Summary

Maintainability
C
1 day
Test Coverage
F
0%
// Copyright (c) Dolittle. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Threading;
using System.Threading.Tasks;
using Dolittle.SDK.DependencyInversion;
using Dolittle.SDK.Execution;
using Dolittle.SDK.Protobuf;
using Dolittle.SDK.Tenancy;
using Dolittle.Services.Contracts;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using Grpc.Core;
using Microsoft.Extensions.Logging;
using ExecutionContext = Dolittle.SDK.Execution.ExecutionContext;

namespace Dolittle.SDK.Services;

/// <summary>
/// Represents an implementation of <see cref="IReverseCallClient{TConnectArguments, TConnectResponse, TRequest, TResponse, TClientMessage}"/>.
/// </summary>
/// <remarks>
/// An instance is single-use: <see cref="Connect"/> can be called once, and <see cref="Handle"/> can be called once
/// after a successful connect. Every write to the client-to-server stream goes through one semaphore, so that pongs,
/// responses and messages written through <see cref="WriteMessage"/> never overlap.
/// </remarks>
/// <typeparam name="TClientMessage">Type of the <see cref="IMessage">messages</see> that is sent from the client to the server.</typeparam>
/// <typeparam name="TServerMessage">Type of the <see cref="IMessage">messages</see> that is sent from the server to the client.</typeparam>
/// <typeparam name="TConnectArguments">Type of the arguments that are sent along with the initial Connect call.</typeparam>
/// <typeparam name="TConnectResponse">Type of the response that is received after the initial Connect call.</typeparam>
/// <typeparam name="TRequest">Type of the requests sent from the server to the client.</typeparam>
/// <typeparam name="TResponse">Type of the responses sent from the client back to the server.</typeparam>
public class ReverseCallClient<TClientMessage, TServerMessage, TConnectArguments, TConnectResponse, TRequest, TResponse>
    : IDisposable, IReverseCallClient<TConnectArguments, TConnectResponse, TRequest, TResponse, TClientMessage>
    where TClientMessage : IMessage
    where TServerMessage : IMessage
    where TConnectArguments : class
    where TConnectResponse : class
    where TRequest : class
    where TResponse : class
{
    /// <summary>
    /// Number of ping intervals that may pass without any message from the server before a read is cancelled.
    /// </summary>
    const int PingTimeoutFactor = 3;

    readonly IAmAReverseCallProtocol<TClientMessage, TServerMessage, TConnectArguments, TConnectResponse, TRequest, TResponse> _protocol;
    readonly TimeSpan _pingInterval;
    readonly IPerformMethodCalls _caller;
    readonly ExecutionContext _executionContext;
    readonly ITenantScopedProviders _tenantScopedProviders;
    readonly ILogger _logger;

    // Serializes all writes to _clientToServer (pongs, responses and raw messages).
    readonly SemaphoreSlim _writeResponseSemaphore = new(1);
    readonly object _connectLock = new();
    readonly object _handleLock = new();
    IClientStreamWriter<TClientMessage>? _clientToServer;
    IAsyncStreamReader<TServerMessage>? _serverToClient;
    bool _connecting;
    bool _connectionEstablished;
    bool _startedHandling;
    bool _disposed;

    /// <summary>
    /// Initializes a new instance of the <see cref="ReverseCallClient{TClientMessage, TServerMessage, TConnectArguments, TConnectResponse, TRequest, TResponse}"/> class.
    /// </summary>
    /// <param name="protocol">The reverse call protocol that will be used to connect to the server.</param>
    /// <param name="pingInterval">The interval at which to request pings from the server to keep the reverse call alive.</param>
    /// <param name="caller">The caller that will be used to perform the method call.</param>
    /// <param name="executionContext">The execution context to use while initiating the reverse call.</param>
    /// <param name="tenantScopedProviders">The <see cref="ITenantScopedProviders"/> for resolving a <see cref="IServiceProvider"/> for a specific <see cref="TenantId"/>.</param>
    /// <param name="logger">The <see cref="ILogger" />.</param>
    /// <exception cref="PingIntervalNotGreaterThanZero">Thrown when <paramref name="pingInterval"/> is zero or negative.</exception>
    public ReverseCallClient(
        IAmAReverseCallProtocol<TClientMessage, TServerMessage, TConnectArguments, TConnectResponse, TRequest, TResponse> protocol,
        TimeSpan pingInterval,
        IPerformMethodCalls caller,
        ExecutionContext executionContext,
        ITenantScopedProviders tenantScopedProviders,
        ILogger logger)
    {
        ThrowIfInvalidPingInterval(pingInterval);
        _protocol = protocol;
        _pingInterval = pingInterval;
        _caller = caller;
        _executionContext = executionContext;
        _tenantScopedProviders = tenantScopedProviders;
        _logger = logger;
    }

    /// <inheritdoc/>
    public TConnectResponse ConnectResponse { get; private set; } = default!;

    /// <summary>
    /// Gets the longest time to wait for the next message from the server before the pending read is cancelled.
    /// </summary>
    TimeSpan PingTimeout => _pingInterval.Multiply(PingTimeoutFactor);

    /// <inheritdoc/>
    public async Task<bool> Connect(TConnectArguments connectArguments, CancellationToken cancellationToken)
    {
        // Cheap check first, then re-check under the lock so only one caller can flip the flag.
        ThrowIfConnecting();
        lock (_connectLock)
        {
            ThrowIfConnecting();
            _connecting = true;
        }

        var streamingCall = _caller.Call(_protocol, cancellationToken);
        _clientToServer = streamingCall.RequestStream;
        _serverToClient = streamingCall.ResponseStream;
        var callContext = CreateReverseCallArgumentsContext();
        _protocol.SetConnectArgumentsContextIn(callContext, connectArguments);
        var connectMessage = _protocol.CreateMessageFrom(connectArguments);

        // ReSharper disable once MethodSupportsCancellation
        await _clientToServer.WriteAsync(connectMessage).ConfigureAwait(false);

        // The linked source cancels the pending read when either the caller cancels or the ping deadline passes.
        using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        linkedCts.CancelAfter(PingTimeout);

        try
        {
            while (await _serverToClient.MoveNext(linkedCts.Token).ConfigureAwait(false))
            {
                var message = _serverToClient.Current;
                var ping = _protocol.GetPingFrom(message);
                var response = _protocol.GetConnectResponseFrom(message);
                if (ping != null)
                {
                    _logger.ReceivedPing();
                    await WritePong(cancellationToken).ConfigureAwait(false);
                }
                else if (response != null)
                {
                    _logger.ReceivedConnectResponse();
                    ConnectResponse = response;
                    _connectionEstablished = true;
                    return true;
                }
                else
                {
                    _logger.ReceivedNonPingOrResponseDuringConnect();
                }

                // Any message from the server counts as a sign of life, so push the deadline forward.
                linkedCts.CancelAfter(PingTimeout);
            }

            // The server ended the stream without ever sending a connect response.
            _logger.TimedOutDuringConnect();
            await _clientToServer.CompleteAsync().ConfigureAwait(false);
            return false;
        }
        catch (RpcException ex) when (ex.StatusCode == StatusCode.Cancelled)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                _logger.CancelledByClientDuringConnect();
            }
            else if (linkedCts.IsCancellationRequested)
            {
                // Only the ping deadline can have cancelled the linked source when the caller's token is untouched.
                _logger.TimedOutDuringConnect();
            }
            else
            {
                _logger.CancelledByServerDuringConnect();
            }

            return false;
        }
    }

    /// <summary>
    /// Writes a message to the server on the client-to-server stream.
    /// </summary>
    /// <param name="message">The message to write.</param>
    /// <param name="token">Not observed; neither the wait for the write slot nor the write itself is cancellable.</param>
    /// <returns>A <see cref="Task"/> that completes when the message has been written.</returns>
    /// <exception cref="ReverseCallClientNotConnected">Thrown when no stream has been opened by <see cref="Connect"/> yet.</exception>
    public async Task WriteMessage(TClientMessage message, CancellationToken token)
    {
        var clientToServer = _clientToServer ?? throw new ReverseCallClientNotConnected();

        // A gRPC stream writer allows only one pending write, so take the same semaphore as pongs and responses.
        // The token is deliberately not passed: an already-cancelled token must not prevent the message from being sent.
        // ReSharper disable once MethodSupportsCancellation
        await _writeResponseSemaphore.WaitAsync().ConfigureAwait(false);
        try
        {
            // ReSharper disable once MethodSupportsCancellation
            await clientToServer.WriteAsync(message).ConfigureAwait(false);
        }
        finally
        {
            _writeResponseSemaphore.Release();
        }
    }

    /// <inheritdoc/>
    public async Task Handle(IReverseCallHandler<TRequest, TResponse> handler, CancellationToken cancellationToken)
    {
        ThrowIfConnectionNotEstablished();

        // Cheap check first, then re-check under the lock so only one caller can flip the flag.
        ThrowIfAlreadyStartedHandling();
        lock (_handleLock)
        {
            ThrowIfAlreadyStartedHandling();
            _startedHandling = true;
        }

        // Both streams are assigned in Connect before the connection is marked as established.
        var serverToClient = _serverToClient!;
        var clientToServer = _clientToServer!;

        using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        linkedCts.CancelAfter(PingTimeout);
        try
        {
            while (await serverToClient.MoveNext(linkedCts.Token).ConfigureAwait(false))
            {
                var message = serverToClient.Current;
                var ping = _protocol.GetPingFrom(message);
                var request = _protocol.GetRequestFrom(message);
                if (ping != null)
                {
                    _logger.ReceivedPing();
                    await WritePong(cancellationToken).ConfigureAwait(false);
                }
                else if (request != null)
                {
                    // Run the handler in the background so the read loop keeps answering pings while it works.
                    // OnReceivedRequest logs its own failures, so the task can be safely discarded.
                    _ = Task.Run(() => OnReceivedRequest(handler, request, cancellationToken), CancellationToken.None);
                }
                else if (_protocol.IsDisconnectAck(message))
                {
                    // The server has acknowledged our disconnect request, and completed in-flight requests.
                    // We can now safely complete the stream.
                    _logger.ReceivedDisconnectAck();
                    await clientToServer.CompleteAsync().ConfigureAwait(false);
                    return;
                }
                else
                {
                    _logger.ReceivedNonPingOrRequestDuringHandling();
                }

                // Any message from the server counts as a sign of life, so push the deadline forward.
                linkedCts.CancelAfter(PingTimeout);
            }
        }
        catch (RpcException ex) when (ex.StatusCode == StatusCode.Cancelled)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                _logger.CancelledByClientDuringHandling();
                return;
            }

            // The caller did not cancel, so a cancelled linked source means the ping deadline passed.
            if (linkedCts.IsCancellationRequested)
            {
                throw new PingTimedOut(_pingInterval);
            }

            _logger.CancelledByServerDuringHandling();
        }
    }

    /// <inheritdoc/>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Dispose the managed and unmanaged resources.
    /// </summary>
    /// <param name="disposing">Whether to dispose managed resources.</param>
    protected virtual void Dispose(bool disposing)
    {
        if (_disposed)
        {
            return;
        }

        if (disposing)
        {
            _writeResponseSemaphore.Dispose();
        }

        _disposed = true;
    }

    /// <summary>
    /// Creates the call context sent with the connect arguments: a fresh head id, the execution context and the ping interval.
    /// </summary>
    ReverseCallArgumentsContext CreateReverseCallArgumentsContext()
        => new()
        {
            HeadId = Guid.NewGuid().ToProtobuf(),
            ExecutionContext = _executionContext.ToProtobuf(),
            PingInterval = Duration.FromTimeSpan(_pingInterval),
        };

    /// <summary>
    /// Writes a pong to the server, unless <paramref name="cancellationToken"/> is cancelled before the write starts.
    /// </summary>
    /// <param name="cancellationToken">Token that cancels waiting for the write slot and skips the write.</param>
    async Task WritePong(CancellationToken cancellationToken)
    {
        try
        {
            await _writeResponseSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
        }
        catch (OperationCanceledException)
        {
            // Treat cancellation while queued the same as cancellation detected after acquiring the slot,
            // so it does not escape the read loop as an unhandled exception.
            _logger.CancelledBeforePongCouldBeSent();
            return;
        }

        try
        {
            if (cancellationToken.IsCancellationRequested)
            {
                _logger.CancelledBeforePongCouldBeSent();
                return;
            }

            var message = _protocol.CreateMessageFrom(new Pong());

            _logger.WritingPong();
            // ReSharper disable once MethodSupportsCancellation
            await _clientToServer!.WriteAsync(message).ConfigureAwait(false);
        }
        finally
        {
            _writeResponseSemaphore.Release();
        }
    }

    /// <summary>
    /// Invokes the handler for one request and writes its response back to the server.
    /// Failures are logged and never thrown, because this runs as a discarded background task.
    /// </summary>
    /// <param name="handler">The handler that produces the response.</param>
    /// <param name="request">The request received from the server.</param>
    /// <param name="cancellationToken">Token passed to the handler and used to skip the response write.</param>
    async Task OnReceivedRequest(IReverseCallHandler<TRequest, TResponse> handler, TRequest request, CancellationToken cancellationToken)
    {
        try
        {
            var requestContext = _protocol.GetRequestContextFrom(request);
            var callId = requestContext.CallId.ToGuid();

            TResponse response;
            try
            {
                _logger.HandlingRequest(callId);

                // Run the handler under the tenant and correlation carried by the request.
                var executionContext = _executionContext
                    .ForTenant(requestContext.ExecutionContext.TenantId.To<TenantId>())
                    .ForCorrelation(requestContext.ExecutionContext.CorrelationId.To<CorrelationId>());

                response = await handler.Handle(request, executionContext, _tenantScopedProviders.ForTenant(executionContext.Tenant), cancellationToken)
                    .ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                _logger.ErrorWhileInvokingHandlerFor(callId, ex);
                return;
            }

            try
            {
                await _writeResponseSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
            }
            catch (OperationCanceledException)
            {
                // Cancelled while queued for the write slot: report it as a cancellation, not as a handling error.
                _logger.CancelledWhileWritingResponseFor(callId);
                return;
            }

            try
            {
                await WriteResponse(response, callId, cancellationToken).ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                _logger.ErrorWhileWritingResponseFor(callId, ex);
            }
            finally
            {
                _writeResponseSemaphore.Release();
            }
        }
        catch (Exception ex)
        {
            _logger.ErrorWhileHandlingRequest(ex);
        }
    }

    /// <summary>
    /// Stamps the response with the call id and writes it to the server, unless cancellation has been requested.
    /// The caller must hold the write semaphore.
    /// </summary>
    /// <param name="response">The response produced by the handler.</param>
    /// <param name="callId">The id of the call that the response belongs to.</param>
    /// <param name="cancellationToken">Token that, when cancelled, causes the write to be skipped.</param>
    /// <returns>The pending write, or a completed task when the write was skipped.</returns>
    Task WriteResponse(TResponse response, Guid callId, CancellationToken cancellationToken)
    {
        var responseContext = new ReverseCallResponseContext { CallId = callId.ToProtobuf() };
        _protocol.SetResponseContextIn(responseContext, response);
        var message = _protocol.CreateMessageFrom(response);

        if (cancellationToken.IsCancellationRequested)
        {
            _logger.CancelledWhileWritingResponseFor(callId);
            return Task.CompletedTask;
        }

        _logger.WritingResponseFor(callId);
        // ReSharper disable once MethodSupportsCancellation
        return _clientToServer!.WriteAsync(message);
    }

    /// <summary>
    /// Throws when the ping interval is zero or negative.
    /// </summary>
    static void ThrowIfInvalidPingInterval(TimeSpan pingInterval)
    {
        if (pingInterval <= TimeSpan.Zero)
        {
            throw new PingIntervalNotGreaterThanZero();
        }
    }

    /// <summary>
    /// Throws when <see cref="Connect"/> has already been called on this instance.
    /// </summary>
    void ThrowIfConnecting()
    {
        if (_connecting)
        {
            throw new ReverseCallClientAlreadyCalledConnect();
        }
    }

    /// <summary>
    /// Throws when <see cref="Handle"/> has already been called on this instance.
    /// </summary>
    void ThrowIfAlreadyStartedHandling()
    {
        if (_startedHandling)
        {
            throw new ReverseCallClientAlreadyStartedHandling();
        }
    }

    /// <summary>
    /// Throws when no connect response has been received yet.
    /// </summary>
    void ThrowIfConnectionNotEstablished()
    {
        if (!_connectionEstablished)
        {
            throw new ReverseCallClientNotConnected();
        }
    }
}