onebeyond/onebeyond-studio-core

View on GitHub
src/OneBeyond.Studio.Crosscuts/Expressions/ExpressionExtensions.cs

Summary

Maintainability
C
7 hrs
Test Coverage
using System;
using System.Collections.ObjectModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using EnsureThat;

namespace OneBeyond.Studio.Crosscuts.Expressions;

/// <summary>
/// Extension methods for various types of <see cref="Expression"/>
/// </summary>
public static class ExpressionExtensions
{
    /// <summary>
    /// Replaces all occurences of <paramref name="source"/> paramter by <paramref name="target"/> parameter in the expression.
    /// </summary>
    /// <param name="this"></param>
    /// <param name="source"></param>
    /// <param name="target"></param>
    /// <returns></returns>
    public static TExpression ReplaceParameter<TExpression>(this TExpression @this, ParameterExpression source, ParameterExpression target)
        where TExpression : Expression
    {
        EnsureArg.IsNotNull(@this, nameof(@this));

        var parameterReplacer = new ExpressionReplacer(source, target);
        return (TExpression)parameterReplacer.Visit(@this);
    }

    public static TExpression ReplaceExpression<TExpression>(this TExpression @this, Expression source, Expression target)
        where TExpression : Expression
    {
        EnsureArg.IsNotNull(@this, nameof(@this));

        var expressionReplacer = new ExpressionReplacer(source, target);
        return (TExpression)expressionReplacer.Visit(@this);
    }

    [return: NotNullIfNotNull("operand1"), NotNullIfNotNull("operand2")]
    public static Expression<Func<T, bool>>? And<T>(
        this Expression<Func<T, bool>>? operand1,
        Expression<Func<T, bool>>? operand2)
    {
        if (operand1 is null)
        {
            return operand2;
        }
        if (operand2 is null)
        {
            return operand1;
        }
        operand2 = operand2.ReplaceParameter(operand2.Parameters[0], operand1.Parameters[0]);
        var body = Expression.AndAlso(operand1.Body, operand2.Body);
        return Expression.Lambda<Func<T, bool>>(body, operand1.Parameters);
    }

    [return: NotNullIfNotNull("operand1"), NotNullIfNotNull("operand2")]
    public static Expression<Func<T, bool>>? Or<T>(
        this Expression<Func<T, bool>>? operand1,
        Expression<Func<T, bool>>? operand2)
    {
        if (operand1 is null)
        {
            return operand2;
        }
        if (operand2 is null)
        {
            return operand1;
        }
        operand2 = operand2.ReplaceParameter(operand2.Parameters[0], operand1.Parameters[0]);
        var body = Expression.OrElse(operand1.Body, operand2.Body);
        return Expression.Lambda<Func<T, bool>>(body, operand1.Parameters);
    }

    /// <summary>
    /// Given an expression for a method that takes in a single parameter and
    /// returns a bool, converts the parameter type of the parameter
    /// from TSource to TTarget.
    /// </summary>
    /// <typeparam name="TSource"></typeparam>
    /// <typeparam name="TTarget"></typeparam>
    /// <param name="this"></param>
    /// <returns></returns>
    // Given an expression for a method that takes in a single parameter (and
    // returns a bool), this method converts the parameter type of the parameter
    // from TSource to TTarget.
    public static Expression<Func<TTarget, bool>> ConvertParameterType<TSource, TTarget>(this Expression<Func<TSource, bool>> @this)
    {
        var visitor = new ParameterTypeConverter<TSource, TTarget>();
        return (Expression<Func<TTarget, bool>>)visitor.Visit(@this);
    }

    private sealed class ParameterTypeConverter<TSource, TTarget> : ExpressionVisitor
    {
        private ReadOnlyCollection<ParameterExpression>? _parameters;

        protected override Expression VisitParameter(ParameterExpression node)
        {
            return _parameters?.FirstOrDefault(p => p.Name == node.Name)
              ?? (node.Type == typeof(TSource)
                    ? Expression.Parameter(typeof(TTarget), node.Name)
                    : node);
        }

        protected override Expression VisitLambda<T>(Expression<T> node)
        {
            _parameters = VisitAndConvert(node.Parameters, "VisitLambda");
            return Expression.Lambda(Visit(node.Body), _parameters);
        }
    }

    private sealed class ExpressionReplacer : ExpressionVisitor
    {
        private readonly Expression _source;
        private readonly Expression _target;

        public ExpressionReplacer(Expression source, Expression target)
        {
            EnsureArg.IsNotNull(source, nameof(source));
            EnsureArg.IsNotNull(target, nameof(target));
            EnsureArg.IsTrue(
                source.Type == target.Type,
                default,
                (options) => options.WithMessage($"Source type {source.Type.FullName} is equal to target type {target.Type.FullName}."));

            _source = source;
            _target = target;
        }

        [return: NotNullIfNotNull("node")]
        public override Expression? Visit(Expression? node)
        {
            return _source.Equals(node)
                ? _target
                : base.Visit(node);
        }
    }
}