onebeyond/onebeyond-studio-core

View on GitHub
src/OneBeyond.Studio.Application.SharedKernel/Authorization/AuthorizationRequirementBehavior.cs

Summary

Maintainability
A
3 hrs
Test Coverage
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Autofac;
using EnsureThat;
using MediatR;
using Microsoft.Extensions.Logging;
using OneBeyond.Studio.Application.SharedKernel.Exceptions;
using OneBeyond.Studio.Crosscuts.Exceptions;
using OneBeyond.Studio.Crosscuts.Logging;
using OneBeyond.Studio.Domain.SharedKernel.Authorization;

namespace OneBeyond.Studio.Application.SharedKernel.Authorization;

/// <summary>
/// MediatR pipeline behavior that enforces the <see cref="AuthorizationPolicyAttribute"/>s
/// declared on a request type before the request is passed on to the next handler.
/// </summary>
/// <remarks>
/// Every policy found on the request type must be met. A policy is met as soon as one of its
/// requirements is satisfied, i.e. the resolved requirement handler completes without throwing.
/// When no policy is declared, the request is rejected unless
/// <see cref="AuthorizationOptions.AllowUnattributedRequests"/> is set.
/// </remarks>
/// <typeparam name="TRequest">The MediatR request type.</typeparam>
/// <typeparam name="TResponse">The MediatR response type.</typeparam>
public class AuthorizationRequirementBehavior<TRequest, TResponse>
    : AuthorizationRequirementBehavior
    , IPipelineBehavior<TRequest, TResponse>
    where TRequest : class, IBaseRequest
{
    private static readonly ILogger Logger = LogManager.CreateLogger<AuthorizationRequirementBehavior<TRequest, TResponse>>();

    // Strongly-typed invocation wrappers keyed by requirement type. The dictionary is static per
    // closed generic type, so each (TRequest, TResponse) pair keeps its own set of wrappers.
    private static readonly ConcurrentDictionary<Type, AuthorizationRequirementHandler> AuthorizationRequirementHandlerWrappers = new();

    private readonly ILifetimeScope _container;
    private readonly AuthorizationOptions _authorizationOptions;

    /// <summary>
    /// Creates the behavior.
    /// </summary>
    /// <param name="container">Scope used to resolve requirement handlers.</param>
    /// <param name="authorizationOptions">Controls whether requests without policies are allowed.</param>
    public AuthorizationRequirementBehavior(
        ILifetimeScope container,
        AuthorizationOptions authorizationOptions)
    {
        EnsureArg.IsNotNull(container, nameof(container));
        EnsureArg.IsNotNull(authorizationOptions, nameof(authorizationOptions));

        _container = container;
        _authorizationOptions = authorizationOptions;
    }

    /// <summary>
    /// Validates all authorization policies declared on <paramref name="request"/>'s runtime type
    /// and, when every policy is met, invokes <paramref name="next"/>.
    /// </summary>
    /// <exception cref="AuthorizationPolicyMissingException">
    /// The request type declares no policy and unattributed requests are not allowed.
    /// </exception>
    /// <exception cref="AuthorizationPolicyFailedException">
    /// None of the requirements of some policy could be met; the collected requirement
    /// exceptions are attached.
    /// </exception>
    /// <exception cref="OperationCanceledException">
    /// <paramref name="cancellationToken"/> was cancelled while a requirement was being evaluated.
    /// </exception>
    public async Task<TResponse> Handle(
        TRequest request,
        RequestHandlerDelegate<TResponse> next,
        CancellationToken cancellationToken)
    {
        EnsureArg.IsNotNull(request, nameof(request));
        EnsureArg.IsNotNull(next, nameof(next));

        var requestType = request.GetType();

        var policies = (AuthorizationPolicyAttribute[])
            Attribute.GetCustomAttributes(requestType, typeof(AuthorizationPolicyAttribute));

        if (!_authorizationOptions.AllowUnattributedRequests
            && policies.Length == 0)
        {
            throw new AuthorizationPolicyMissingException(requestType);
        }

        foreach (var policy in policies)
        {
            var isPolicyMet = false;
            var requirementExceptions = new List<Exception>();

            foreach (var requirementType in policy.RequirementTypes)
            {
                Logger.LogInformation(
                    "Validating authorization requirement {AuthorizationRequirementType} on request {RequestType}",
                    requirementType.Key.FullName,
                    requestType.FullName);

                try
                {
                    await HandleRequirementAsync(
                        requirementType.Key,
                        requirementType.Value,
                        request,
                        requestType,
                        cancellationToken).ConfigureAwait(false);

                    Logger.LogInformation(
                        "Authorization requirement {AuthorizationRequirementType} is met on request {RequestType}",
                        requirementType.Key.FullName,
                        requestType.FullName);

                    // One satisfied requirement is enough for the whole policy.
                    isPolicyMet = true;
                    break;
                }
                catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
                {
                    // The caller cancelled the request: propagate the cancellation instead of
                    // recording it as an unmet requirement (which would surface as an
                    // authorization failure).
                    throw;
                }
                catch (Exception exception) when (!exception.IsCritical())
                {
                    // A throwing handler means "requirement not met"; keep the reason and try the next one.
                    requirementExceptions.Add(exception);
                }
            }

            if (!isPolicyMet)
            {
                throw new AuthorizationPolicyFailedException(policy, requestType, requirementExceptions);
            }
        }

        return await next().ConfigureAwait(false);
    }

    /// <summary>
    /// Resolves the handler for a single requirement and runs it against <paramref name="request"/>.
    /// Completes normally when the requirement is met; any exception means it is not.
    /// </summary>
    /// <param name="requirementType">Concrete <see cref="AuthorizationRequirement"/> type.</param>
    /// <param name="requirementArgs">Constructor arguments used to create the requirement instance.</param>
    /// <param name="request">The request being authorized.</param>
    /// <param name="requestType">Runtime type of <paramref name="request"/>, used to close the handler interface.</param>
    /// <param name="cancellationToken">Token flowed to the requirement handler.</param>
    private async Task HandleRequirementAsync(
        Type requirementType,
        IReadOnlyCollection<object> requirementArgs,
        TRequest request,
        Type requestType,
        CancellationToken cancellationToken)
    {
        var requirementHandlerType = typeof(IAuthorizationRequirementHandler<,>)
            .MakeGenericType(requirementType, requestType);

        // Requirement instances are cached process-wide by (requirement type, constructor args).
        var requirement = AuthorizationRequirements.GetOrAdd(
            new AuthorizationRequirementKey(requirementType, requirementArgs),
            (_) => (AuthorizationRequirement)Activator.CreateInstance(
                requirementType,
                requirementArgs.ToArray())!);

        // The nested wrapper is generic over (TRequest, TResponse, TRequirement), hence three type arguments.
        var requirementHandlerWrapper = AuthorizationRequirementHandlerWrappers.GetOrAdd(
            requirementType,
            static (wrappedRequirementType) => (AuthorizationRequirementHandler)Activator.CreateInstance(
                typeof(AuthorizationRequirementHandler<>).MakeGenericType(
                    typeof(TRequest),
                    typeof(TResponse),
                    wrappedRequirementType))!);

        var requirementHandler = _container.Resolve(requirementHandlerType);

        await requirementHandlerWrapper.HandleAsync(
            requirementHandler,
            requirement,
            request,
            cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Non-generic entry point that lets the behavior invoke a requirement handler that was
    /// resolved as <see cref="object"/>.
    /// </summary>
    private abstract class AuthorizationRequirementHandler
    {
        public abstract Task HandleAsync(
            object requirementHandler,
            AuthorizationRequirement requirement,
            TRequest request,
            CancellationToken cancellationToken);
    }

    /// <summary>
    /// Casts the untyped handler and requirement to their strongly-typed forms and forwards the
    /// call to <c>IAuthorizationRequirementHandler&lt;TRequirement, TRequest&gt;.HandleAsync</c>.
    /// </summary>
    /// <typeparam name="TRequirement">Concrete requirement type handled by this wrapper.</typeparam>
    private sealed class AuthorizationRequirementHandler<TRequirement> : AuthorizationRequirementHandler
        where TRequirement : AuthorizationRequirement
    {
        public override Task HandleAsync(
            object requirementHandler,
            AuthorizationRequirement requirement,
            TRequest request,
            CancellationToken cancellationToken)
        {
            return ((IAuthorizationRequirementHandler<TRequirement, TRequest>)requirementHandler)
                .HandleAsync((TRequirement)requirement, request, cancellationToken);
        }
    }
}

/// <summary>
/// Non-generic base of the authorization pipeline behavior. It hosts state that must be shared
/// across every closed <c>AuthorizationRequirementBehavior&lt;TRequest, TResponse&gt;</c>.
/// </summary>
public abstract class AuthorizationRequirementBehavior
{
    /// <summary>
    /// Cache of requirement instances keyed by requirement type and constructor arguments.
    /// Because it is static on the non-generic base, a single cache is shared by all request
    /// and response type combinations.
    /// </summary>
    protected static ConcurrentDictionary<AuthorizationRequirementKey, AuthorizationRequirement> AuthorizationRequirements { get; } = new();

    /// <summary>
    /// Immutable dictionary key identifying a requirement by its type plus the ordered sequence of
    /// constructor arguments used to create it.
    /// </summary>
    /// <remarks>
    /// Arguments are compared with their default equality (<see cref="Enumerable.SequenceEqual{TSource}(IEnumerable{TSource}, IEnumerable{TSource})"/>).
    /// NOTE(review): argument values without value equality (e.g. arrays) yield a distinct key per
    /// instance, so they will not hit the cache if the argument objects are re-created per call.
    /// </remarks>
    protected readonly struct AuthorizationRequirementKey : IEquatable<AuthorizationRequirementKey>
    {
        private readonly int _hashCode;
        private readonly Type _type;
        private readonly IReadOnlyCollection<object> _args;

        /// <summary>
        /// Creates a key for the given requirement type and constructor arguments.
        /// </summary>
        /// <param name="authorizationRequirementType">Concrete requirement type.</param>
        /// <param name="authorizationRequirementArgs">Constructor arguments; individual entries may be <c>null</c>.</param>
        /// <exception cref="ArgumentNullException">Either argument is <c>null</c>.</exception>
        public AuthorizationRequirementKey(
            Type authorizationRequirementType,
            IReadOnlyCollection<object> authorizationRequirementArgs)
        {
            _type = authorizationRequirementType
                ?? throw new ArgumentNullException(nameof(authorizationRequirementType));
            _args = authorizationRequirementArgs
                ?? throw new ArgumentNullException(nameof(authorizationRequirementArgs));
            // Computed once here: the key is immutable, and dictionaries copy struct keys, so a lazily
            // cached value stored in a copy would be lost anyway.
            _hashCode = ComputeHashCode(authorizationRequirementType, authorizationRequirementArgs);
        }

        /// <inheritdoc/>
        public override int GetHashCode() => _hashCode;

        /// <inheritdoc/>
        public bool Equals(AuthorizationRequirementKey other)
            => _hashCode == other._hashCode // cheap reject: equal keys always share a hash
            && _type == other._type
            && ArgsEqual(_args, other._args);

        /// <inheritdoc/>
        public override bool Equals(object? obj)
            => obj is AuthorizationRequirementKey key && Equals(key);

        public static bool operator ==(AuthorizationRequirementKey left, AuthorizationRequirementKey right)
            => left.Equals(right);

        public static bool operator !=(AuthorizationRequirementKey left, AuthorizationRequirementKey right)
            => !(left == right);

        /// <summary>
        /// Combines the type hash with each argument's hash (31-multiplier fold). A null argument
        /// contributes 0. The arithmetic deliberately wraps on overflow.
        /// </summary>
        private static int ComputeHashCode(Type type, IReadOnlyCollection<object> args)
        {
            unchecked
            {
                var hash = type.GetHashCode();
                foreach (var arg in args)
                {
                    hash = hash * 31 + (arg?.GetHashCode() ?? 0);
                }
                return hash;
            }
        }

        /// <summary>
        /// Sequence equality that also tolerates the <c>null</c> collections of a <c>default</c> key.
        /// </summary>
        private static bool ArgsEqual(IReadOnlyCollection<object>? left, IReadOnlyCollection<object>? right)
            => ReferenceEquals(left, right)
            || (left is not null && right is not null && left.SequenceEqual(right));
    }
}