onebeyond/onebeyond-studio-core

View on GitHub
src/OneBeyond.Studio.DataAccess.EFCore/Repositories/BaseRWRepository.cs

Summary

Maintainability
B
6 hrs
Test Coverage
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
using DelegateDecompiler.EntityFrameworkCore;
using EnsureThat;
using Microsoft.EntityFrameworkCore;
using Nito.AsyncEx;
using OneBeyond.Studio.Application.SharedKernel.DataAccessPolicies;
using OneBeyond.Studio.Application.SharedKernel.Repositories;
using OneBeyond.Studio.Application.SharedKernel.Repositories.Exceptions;
using OneBeyond.Studio.Application.SharedKernel.Specifications;
using OneBeyond.Studio.DataAccess.EFCore.Projections;
using OneBeyond.Studio.Domain.SharedKernel.Entities;

namespace OneBeyond.Studio.DataAccess.EFCore.Repositories;

/// <summary>
/// Implements base read-write repository for the <typeparamref name="TAggregateRoot"/> type using EF Core.
/// All the entities returned by this repository are trackable.
/// </summary>
/// <typeparam name="TDbContext">Type of the EF Core <see cref="DbContext"/> the repository works against.</typeparam>
/// <typeparam name="TAggregateRoot">Type of the aggregate root managed by the repository.</typeparam>
/// <typeparam name="TAggregateRootId">Type of the aggregate root identifier.</typeparam>
public class BaseRWRepository<TDbContext, TAggregateRoot, TAggregateRootId>
    : BaseRORepository<TDbContext, TAggregateRoot, TAggregateRootId>
    , IRWRepository<TAggregateRoot, TAggregateRootId>
    where TDbContext : DbContext
    where TAggregateRoot : AggregateRoot<TAggregateRootId>
    where TAggregateRootId : notnull
{
    /// <summary>
    /// Creates the repository and sets up the lazily resolved create/update/delete
    /// data access checks obtained from <paramref name="rwDataAccessPolicyProvider"/>.
    /// </summary>
    /// <param name="dbContext">EF Core context, forwarded to the base repository.</param>
    /// <param name="rwDataAccessPolicyProvider">Provider of the create, update and delete data access policies.</param>
    /// <param name="entityTypeProjections">Projections, forwarded to the base repository.</param>
    public BaseRWRepository(
        TDbContext dbContext,
        IRWDataAccessPolicyProvider<TAggregateRoot> rwDataAccessPolicyProvider,
        IEntityTypeProjections<TAggregateRoot> entityTypeProjections)
        : base(dbContext, rwDataAccessPolicyProvider, entityTypeProjections)
    {
        // Each policy is resolved and compiled on first use only. RetryOnFailure makes a
        // failed resolution be attempted again on the next access instead of caching the failure.
        EnsureCreateDataAccessPolicy = new AsyncLazy<Action<TAggregateRoot>>(
            async () =>
                CompileDataAccessFunction(
                    await rwDataAccessPolicyProvider
                        .GetCreateDataAccessPolicyAsync()
                        .ConfigureAwait(false)),
            AsyncLazyFlags.RetryOnFailure);
        EnsureUpdateDataAccessPolicy = new AsyncLazy<Action<TAggregateRoot>>(
            async () =>
                CompileDataAccessFunction(
                    await rwDataAccessPolicyProvider
                        .GetUpdateDataAccessPolicyAsync()
                        .ConfigureAwait(false)),
            AsyncLazyFlags.RetryOnFailure);
        EnsureDeleteDataAccessPolicy = new AsyncLazy<Action<TAggregateRoot>>(
            async () =>
                CompileDataAccessFunction(
                    await rwDataAccessPolicyProvider
                        .GetDeleteDataAccessPolicyAsync()
                        .ConfigureAwait(false)),
            AsyncLazyFlags.RetryOnFailure);
    }

    /// <summary>
    /// Lazily built check which throws <see cref="EntityAccessDeniedException{TAggregateRoot, TAggregateRootId}"/>
    /// when the create policy rejects an aggregate root.
    /// </summary>
    protected AsyncLazy<Action<TAggregateRoot>> EnsureCreateDataAccessPolicy { get; }

    /// <summary>
    /// Lazily built check which throws <see cref="EntityAccessDeniedException{TAggregateRoot, TAggregateRootId}"/>
    /// when the update policy rejects an aggregate root.
    /// </summary>
    protected AsyncLazy<Action<TAggregateRoot>> EnsureUpdateDataAccessPolicy { get; }

    /// <summary>
    /// Lazily built check which throws <see cref="EntityAccessDeniedException{TAggregateRoot, TAggregateRootId}"/>
    /// when the delete policy rejects an aggregate root.
    /// </summary>
    protected AsyncLazy<Action<TAggregateRoot>> EnsureDeleteDataAccessPolicy { get; }

    /// <summary>
    /// Loads the single aggregate root with the specified identifier, restricted by the read data access policy.
    /// This member hides the base class method of the same name.
    /// </summary>
    /// <param name="aggregateRootId">Identifier of the aggregate root to load.</param>
    /// <param name="includes">Optional navigation includes applied to the query.</param>
    /// <param name="cancellationToken">Token used to cancel the database calls.</param>
    /// <returns>The matching aggregate root.</returns>
    public new async Task<TAggregateRoot> GetByIdAsync(
        TAggregateRootId aggregateRootId,
        Includes<TAggregateRoot>? includes,
        CancellationToken cancellationToken)
    {
        // Built once and shared by the query and by the follow-up existence check.
        Expression<Func<TAggregateRoot, bool>> hasRequestedId =
            (aggregateRoot) => aggregateRoot.Id!.Equals(aggregateRootId);

        var query = await BuildGetQueryAsync(hasRequestedId, includes).ConfigureAwait(false);
        var entity = await query.SingleOrDefaultAsync(cancellationToken).ConfigureAwait(false);
        if (entity is null)
        {
            // NOTE(review): EnsureEntityExistsAsync is defined in the base class; it presumably throws
            // (not found vs. access denied), which is what makes the null-forgiving return below safe - confirm.
            await EnsureEntityExistsAsync(
                aggregateRootId,
                hasRequestedId,
                cancellationToken).ConfigureAwait(false);
        }
        return entity!;
    }

    /// <summary>
    /// Loads the single aggregate root matching <paramref name="filter"/>, restricted by the read data access policy.
    /// </summary>
    /// <param name="filter">Criteria the aggregate root has to satisfy.</param>
    /// <param name="includes">Optional navigation includes applied to the query.</param>
    /// <param name="cancellationToken">Token used to cancel the database calls.</param>
    /// <returns>The matching aggregate root.</returns>
    public async Task<TAggregateRoot> GetByFilterAsync(
        Expression<Func<TAggregateRoot, bool>> filter,
        Includes<TAggregateRoot>? includes,
        CancellationToken cancellationToken)
    {
        var query = await BuildGetQueryAsync(filter, includes).ConfigureAwait(false);
        var entity = await query.SingleOrDefaultAsync(cancellationToken).ConfigureAwait(false);
        if (entity is null)
        {
            // No identifier is known on this path, hence the default value.
            // NOTE(review): same assumption as in GetByIdAsync - the base class helper is expected to throw.
            await EnsureEntityExistsAsync(default!, filter, cancellationToken).ConfigureAwait(false);
        }
        return entity!;
    }

    /// <summary>
    /// Checks the create data access policy, adds the aggregate root to the set and saves the changes.
    /// </summary>
    /// <param name="aggregateRoot">Aggregate root to create; must not be null.</param>
    /// <param name="cancellationToken">Token used to cancel the save.</param>
    /// <exception cref="EntityAccessDeniedException{TAggregateRoot, TAggregateRootId}">The create policy rejects the aggregate root.</exception>
    public async Task CreateAsync(TAggregateRoot aggregateRoot, CancellationToken cancellationToken)
    {
        EnsureArg.IsNotNull(aggregateRoot, nameof(aggregateRoot));
        var ensureCreateDataAccessPolicy = await EnsureCreateDataAccessPolicy.Task.ConfigureAwait(false);
        ensureCreateDataAccessPolicy(aggregateRoot);
        DbSet.Value.Add(aggregateRoot);
        await SaveChangesAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Checks the update data access policy, marks the aggregate root as updated and saves the changes.
    /// </summary>
    /// <param name="aggregateRoot">Aggregate root to update; must not be null.</param>
    /// <param name="cancellationToken">Token used to cancel the save.</param>
    /// <exception cref="EntityAccessDeniedException{TAggregateRoot, TAggregateRootId}">The update policy rejects the aggregate root.</exception>
    public async Task UpdateAsync(TAggregateRoot aggregateRoot, CancellationToken cancellationToken)
    {
        EnsureArg.IsNotNull(aggregateRoot, nameof(aggregateRoot));
        var ensureUpdateDataAccessPolicy = await EnsureUpdateDataAccessPolicy.Task.ConfigureAwait(false);
        ensureUpdateDataAccessPolicy(aggregateRoot);
        DbSet.Value.Update(aggregateRoot);
        await SaveChangesAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Checks the delete data access policy, removes the aggregate root from the set and saves the changes.
    /// </summary>
    /// <param name="aggregateRoot">Aggregate root to delete; must not be null.</param>
    /// <param name="cancellationToken">Token used to cancel the save.</param>
    /// <exception cref="EntityAccessDeniedException{TAggregateRoot, TAggregateRootId}">The delete policy rejects the aggregate root.</exception>
    public async Task DeleteAsync(TAggregateRoot aggregateRoot, CancellationToken cancellationToken)
    {
        EnsureArg.IsNotNull(aggregateRoot, nameof(aggregateRoot));
        var ensureDeleteDataAccessPolicy = await EnsureDeleteDataAccessPolicy.Task.ConfigureAwait(false);
        ensureDeleteDataAccessPolicy(aggregateRoot);
        DbSet.Value.Remove(aggregateRoot);
        await SaveChangesAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Deletes the aggregate root with the specified key; does nothing when no such aggregate root is found.
    /// </summary>
    /// <param name="id">Key of the aggregate root to delete.</param>
    /// <param name="cancellationToken">Token used to cancel the lookup and the delete.</param>
    public async Task DeleteAsync(TAggregateRootId id, CancellationToken cancellationToken)
    {
        // The lookup is awaited so that the calling thread is not blocked on database I/O
        // and the cancellation token is honoured for the lookup as well.
        var aggregate = await DbSet.Value
            .FindAsync(new object?[] { id }, cancellationToken)
            .ConfigureAwait(false);
        if (aggregate is null)
        {
            return;
        }

        // NOTE(review): unlike DeleteAsync(TAggregateRoot, ...), this path neither consults
        // EnsureDeleteDataAccessPolicy nor goes through SaveChangesAsync - confirm this is intended.
        await DbSet.Value.DeleteByKeyAsync(cancellationToken, id).ConfigureAwait(false);
    }

    /// <summary>
    /// Persists the pending changes of the underlying context. Can be overridden by derived repositories.
    /// </summary>
    /// <param name="cancellationToken">Token used to cancel the save.</param>
    protected virtual Task SaveChangesAsync(CancellationToken cancellationToken)
        => DbContext.SaveChangesAsync(cancellationToken);

    /// <summary>
    /// Builds the query used by the get operations: includes first, then <paramref name="filter"/>,
    /// then the criteria of the read data access policy (if any).
    /// </summary>
    /// <param name="filter">Criteria the aggregate root has to satisfy.</param>
    /// <param name="includes">Optional navigation includes applied to the query.</param>
    /// <returns>The composed query, passed through DelegateDecompiler's <c>DecompileAsync</c>.</returns>
    protected async Task<IQueryable<TAggregateRoot>> BuildGetQueryAsync(
        Expression<Func<TAggregateRoot, bool>> filter,
        Includes<TAggregateRoot>? includes)
    {
        var readDataAccessPolicy = await ReadDataAccessPolicy.Task.ConfigureAwait(false);
        var query = ApplyIncludes(DbSet.Value, includes);
        query = ApplyFiltering(query, filter);
        query = ApplyFiltering(query, readDataAccessPolicy?.CanBeAccessedCriteria);
        return query.DecompileAsync();
    }

    /// <summary>
    /// Turns a data access policy into an action which throws
    /// <see cref="EntityAccessDeniedException{TAggregateRoot, TAggregateRootId}"/> for a rejected aggregate root.
    /// </summary>
    /// <param name="dataAccessPolicy">Policy to compile; null means that there are no restrictions.</param>
    /// <returns>
    /// An action that does nothing when the policy is null or defines no checks; otherwise an action that
    /// requires every defined check (criteria first, then function) to accept the aggregate root.
    /// </returns>
    private static Action<TAggregateRoot> CompileDataAccessFunction(DataAccessPolicy<TAggregateRoot>? dataAccessPolicy)
    {
        // The criteria expression is compiled once here rather than on every check.
        var criteriaCheck = dataAccessPolicy?.CanBeAccessedCriteria?.Compile();
        var functionCheck = dataAccessPolicy?.CanBeAccessedFunction;

        if (criteriaCheck is null && functionCheck is null)
        {
            return _ => { };
        }

        return (aggregateRoot) =>
        {
            // A check which is not defined counts as passed; the function check is evaluated
            // only after the criteria check has accepted the aggregate root.
            if ((criteriaCheck is null || criteriaCheck(aggregateRoot))
                && (functionCheck is null || functionCheck(aggregateRoot)))
            {
                return;
            }
            throw new EntityAccessDeniedException<TAggregateRoot, TAggregateRootId>(aggregateRoot.Id);
        };
    }
}