j84reginato/my-eval

View on GitHub
src/Solving/Differentiator.php

Summary

Maintainability
D
1 day
Test Coverage
<?php

declare(strict_types=1);

namespace MyEval\Solving;

use MyEval\Exceptions\DivisionByZeroException;
use MyEval\Exceptions\ExponentialException;
use MyEval\Exceptions\NullOperandException;
use MyEval\Exceptions\SyntaxErrorException;
use MyEval\Exceptions\UnexpectedOperatorException;
use MyEval\Exceptions\UnknownFunctionException;
use MyEval\Exceptions\UnknownOperatorException;
use MyEval\Parsing\Nodes\Node;
use MyEval\Parsing\Nodes\Operand\ConstantNode;
use MyEval\Parsing\Nodes\Operand\FloatNode;
use MyEval\Parsing\Nodes\Operand\IntegerNode;
use MyEval\Parsing\Nodes\Operand\RationalNode;
use MyEval\Parsing\Nodes\Operand\VariableNode;
use MyEval\Parsing\Nodes\Operator\FunctionNode;
use MyEval\Parsing\Nodes\Operator\InfixExpressionNode;
use MyEval\Parsing\Operations\OperationBuilder;

/**
 * Differentiate an abstract syntax tree (AST).
 *
 * Implementation of a Visitor, transforming an AST into another AST representing the derivative of the original AST.
 *
 * The class implements differentiation rules for all arithmetic operators as well as every elementary function
 * recognized by StdMathLexer and Parser, handling for example the product rule and the chain rule correctly.
 *
 * To keep the resulting AST reasonably simple, a number of simplification rules are built in.
 *
 * ## Example:
 *
 * ~~~{.php}
 * use MyEval\StdMathEval;
 * use MyEval\Solving\Differentiator;
 *
 * $parser = new StdMathEval();
 * $ast = $parser->parse('exp(2x)+xy');
 * $ddx = new Differentiator('x');     // Create a d/dx operator
 * $df = $ast->accept($ddx);           // $df now contains the AST of '2exp(2x)+y'
 *
 * ~~~
 *
 * or more complex use:
 *
 * ~~~{.php}
 * use MyEval\Lexing\StdMathLexer;
 * use MyEval\Parsing\Parser;
 * use MyEval\Solving\Differentiator;
 *
 * // Tokenize
 * $lexer = new StdMathLexer();
 * $tokens = $lexer->tokenize('exp(2x)+xy');
 *
 * // Parse
 * $parser = new Parser();
 * $ast = $parser->parse($tokens);
 *
 * // Differentiate
 * $differentiator = new Differentiator('x');  // Create a d/dx operator
 * $value = $ast->accept($differentiator);     // $value now contains the AST of '2exp(2x)+y'
 * ~~~
 */
class Differentiator implements Visitor
{
    /**
     * Create a Differentiator.
     *
     * @param string           $variable    Name of the variable that we differentiate with respect to.
     * @param OperationBuilder $nodeFactory Builder used to assemble the nodes of the resulting AST.
     */
    public function __construct(
        private string $variable,
        private OperationBuilder $nodeFactory = new OperationBuilder()
    ) {
    }

    /**
     * Differentiate an IntegerNode.
     *
     * The derivative of a constant is identically 0, so the result is always an IntegerNode holding `0`.
     *
     * @param IntegerNode $node AST to be differentiated.
     *
     * @return IntegerNode
     */
    public function visitIntegerNode(IntegerNode $node): IntegerNode
    {
        return new IntegerNode(0);
    }

    /**
     * Differentiate a RationalNode.
     *
     * The derivative of a constant is identically 0, so the result is always an IntegerNode holding `0`.
     *
     * @param RationalNode $node AST to be differentiated.
     *
     * @return IntegerNode
     */
    public function visitRationalNode(RationalNode $node): IntegerNode
    {
        return new IntegerNode(0);
    }

    /**
     * Differentiate a FloatNode.
     *
     * The derivative of a constant is identically 0, so the result is always an IntegerNode holding `0`.
     *
     * @param FloatNode $node AST to be differentiated.
     *
     * @return IntegerNode
     */
    public function visitNumberNode(FloatNode $node): IntegerNode
    {
        return new IntegerNode(0);
    }

    /**
     * Differentiate a BooleanNode.
     *
     * Not supported by this visitor: the method unconditionally throws.
     *
     * @return void
     * @throws SyntaxErrorException Always.
     */
    public function visitBooleanNode(): void
    {
        throw new SyntaxErrorException();
    }

    /**
     * Differentiate a VariableNode.
     *
     * Yields `1` when the node is the differentiation variable and `0` for every other variable.
     *
     * @param VariableNode $node AST to be differentiated.
     *
     * @return IntegerNode
     */
    public function visitVariableNode(VariableNode $node): IntegerNode
    {
        return new IntegerNode($node->value === $this->variable ? 1 : 0);
    }

    /**
     * Differentiate a ConstantNode.
     *
     * The derivative of a constant is identically 0. The only exception is the constant `NAN`, which is handed
     * back unchanged so that it keeps propagating through the resulting AST.
     *
     * @param ConstantNode $node AST to be differentiated.
     *
     * @return IntegerNode|ConstantNode
     */
    public function visitConstantNode(ConstantNode $node): IntegerNode|ConstantNode
    {
        return $node->value === 'NAN' ? $node : new IntegerNode(0);
    }

    /**
     * Differentiate an InfixExpressionNode.
     *
     * Using the usual rules for differentiating, create an AST representing the derivative `(x op y)'` where `op`
     * is one of `+`, `-`, `*`, `/` or `^`.
     *
     * ## Differentiation rules:
     *
     * - \\( (f+g)' = f' + g' \\)
     * - \\( (f-g)' = f' - g' \\)
     * - \\( (-f)' = -f' \\)
     * - \\( (f*g)' = f g' + f' g \\)
     * - \\( (f/g)' = (f' g - f g')/g^2 \\)
     * - \\( (f^g)' = f^g (g' \\ln(f) + g f'/f) \\), with simpler expressions when `g` is an integer, float or
     *   rational literal, or when `f` is the constant `e`.
     *
     * @param InfixExpressionNode $node AST to be differentiated.
     *
     * @return Node
     * @throws DivisionByZeroException
     * @throws NullOperandException If the left operand is missing, or the right operand is missing for any
     *                              operator other than `-`.
     * @throws UnexpectedOperatorException
     * @throws UnknownOperatorException If the operator is something other than `+`, `-`, `*`, `/` or `^`.
     * @throws ExponentialException
     */
    public function visitInfixExpressionNode(InfixExpressionNode $node): Node
    {
        $left     = $node->getLeft();
        $operator = $node->operator;
        $right    = $node->getRight();

        // '-' is the only operator accepted without a right operand; that case is forwarded to the
        // builder's subtraction() with a null second argument.
        if ($left === null || ($right === null && $operator !== '-')) {
            throw new NullOperandException();
        }

        return match ($operator) {
            '+'     => $this->nodeFactory->addition($left->accept($this), $right->accept($this)),
            '-'     => $this->nodeFactory->subtraction($left->accept($this), $right?->accept($this)),
            '*'     => $this->differentiateProduct($left, $right),
            '/'     => $this->differentiateQuotient($left, $right),
            '^'     => $this->differentiatePower($node, $left, $right),
            default => throw new UnknownOperatorException($operator),
        };
    }

    /**
     * Differentiate a TernaryNode.
     *
     * Not supported by this visitor: the method unconditionally throws.
     *
     * @return void
     * @throws SyntaxErrorException Always.
     */
    public function visitTernaryNode(): void
    {
        throw new SyntaxErrorException();
    }

    /**
     * Differentiate a FunctionNode.
     *
     * Create an AST representing the derivative of `g(f(x))` where `g` is one of the elementary functions below
     * and `f(x)` is the operand of the node.
     *
     * ## Differentiation rules:
     *
     *  - \\( \\sin(f(x))' = f'(x) \\cos(f(x)) \\)
     *  - \\( \\cos(f(x))' = -f'(x) \\sin(f(x)) \\)
     *  - \\( \\tan(f(x))' = f'(x) (1 + \\tan(f(x))^2) \\)
     *  - \\( \\operatorname{cot}(f(x))' = f'(x) (-1 - \\operatorname{cot}(f(x))^2) \\)
     *  - \\( \\arcsin(f(x))' = f'(x) / \\sqrt{1-f(x)^2} \\)
     *  - \\( \\arccos(f(x))' = -f'(x) / \\sqrt{1-f(x)^2} \\)
     *  - \\( \\arctan(f(x))' = f'(x) / (1+f(x)^2) \\)
     *  - \\( \\operatorname{arccot}(f(x))' = -f'(x) / (1+f(x)^2) \\)
     *  - \\( \\exp(f(x))' = f'(x) \\exp(f(x)) \\)
     *  - \\( \\ln(f(x))' = \\log(f(x))' = f'(x) / f(x) \\)
     *  - \\( \\operatorname{lg}(f(x))' = f'(x) / (\\ln(10) f(x)) \\)
     *  - \\( \\sqrt{f(x)}' = f'(x) / (2 \\sqrt{f(x)}) \\)
     *  - \\( \\sinh(f(x))' = f'(x) \\cosh(f(x)) \\)
     *  - \\( \\cosh(f(x))' = f'(x) \\sinh(f(x)) \\)
     *  - \\( \\tanh(f(x))' = f'(x) (1-\\tanh(f(x))^2) \\)
     *  - \\( \\operatorname{coth}(f(x))' = f'(x) (1-\\operatorname{coth}(f(x))^2) \\)
     *  - \\( \\operatorname{arsinh}(f(x))' = f'(x) / \\sqrt{f(x)^2+1} \\)
     *  - \\( \\operatorname{arcosh}(f(x))' = f'(x) / \\sqrt{f(x)^2-1} \\)
     *  - \\( \\operatorname{artanh}(f(x))' = f'(x) / (1-f(x)^2) \\)
     *  - \\( \\operatorname{arcoth}(f(x))' = f'(x) / (1-f(x)^2) \\)
     *  - \\( |f(x)|' = f'(x) \\operatorname{sgn}(f(x)) \\)
     *
     * @param FunctionNode $node AST to be differentiated.
     *
     * @return Node
     * @throws DivisionByZeroException If the result contains a division by zero.
     * @throws NullOperandException If the operand is null.
     * @throws UnknownFunctionException If the function name doesn't match one of the above.
     * @throws UnknownOperatorException If the operator name is unknown.
     * @throws UnexpectedOperatorException
     * @throws ExponentialException
     */
    public function visitFunctionNode(FunctionNode $node): Node
    {
        $arg = $node->operand;

        if ($arg === null) {
            throw new NullOperandException();
        }

        // f'(x): derivative of the operand, needed by every rule (chain rule).
        $inner = $arg->accept($this);

        // Rules of the shape f'(x) * g'(f(x)).
        $outer = $this->outerDerivative($node, $arg);
        if ($outer !== null) {
            return $this->nodeFactory->multiplication($inner, $outer);
        }

        // Remaining rules are written as a single quotient; unknown names are rejected there.
        return $this->quotientDerivative($node, $arg, $inner);
    }

    /**
     * Product rule: \\( (fg)' = f g' + f' g \\).
     *
     * @param Node $left  Factor `f`.
     * @param Node $right Factor `g`.
     *
     * @return Node
     */
    private function differentiateProduct(Node $left, Node $right): Node
    {
        $leftTimesRightPrime = $this->nodeFactory->multiplication($left, $right->accept($this));
        $leftPrimeTimesRight = $this->nodeFactory->multiplication($left->accept($this), $right);

        return $this->nodeFactory->addition($leftTimesRightPrime, $leftPrimeTimesRight);
    }

    /**
     * Quotient rule: \\( (f/g)' = (f' g - f g')/g^2 \\).
     *
     * @param Node $left  Numerator `f`.
     * @param Node $right Denominator `g`.
     *
     * @return Node
     */
    private function differentiateQuotient(Node $left, Node $right): Node
    {
        $leftPrimeTimesRight = $this->nodeFactory->multiplication($left->accept($this), $right);
        $leftTimesRightPrime = $this->nodeFactory->multiplication($left, $right->accept($this));
        $numerator           = $this->nodeFactory->subtraction($leftPrimeTimesRight, $leftTimesRightPrime);
        $denominator         = $this->squareOf($right);

        return $this->nodeFactory->division($numerator, $denominator);
    }

    /**
     * Differentiate a power `f^g`.
     *
     * Since \\( f^g = \\exp(g \\ln(f)) \\), the general rule is \\( (f^g)' = f^g (g' \\ln(f) + g f'/f) \\).
     * Two shortcuts are applied first:
     *
     * - numeric literal exponent `c`: \\( (f^c)' = c f^{c-1} f' \\)
     * - base `e`: \\( (e^g)' = g' e^g \\)
     *
     * @param InfixExpressionNode $node     The original power node, reused as `f^g` in the result.
     * @param Node                $base     Base `f`.
     * @param Node                $exponent Exponent `g`.
     *
     * @return Node
     */
    private function differentiatePower(InfixExpressionNode $node, Node $base, Node $exponent): Node
    {
        $powerRule = $this->powerRuleOperands($exponent);

        if ($powerRule !== null) {
            [$coefficient, $reducedExponent] = $powerRule;
            $reducedPower = $this->nodeFactory->exponentiation($base, $reducedExponent);

            return $this->nodeFactory->multiplication(
                $coefficient,
                $this->nodeFactory->multiplication($reducedPower, $base->accept($this))
            );
        }

        if ($base instanceof ConstantNode && $base->value === 'e') {
            return $this->nodeFactory->multiplication($exponent->accept($this), $node);
        }

        // g' ln(f)
        $logTerm = $this->nodeFactory->multiplication($exponent->accept($this), new FunctionNode('ln', $base));
        // g f' / f
        $ratioTerm = $this->nodeFactory->division(
            $this->nodeFactory->multiplication($exponent, $base->accept($this)),
            $base
        );

        return $this->nodeFactory->multiplication($node, $this->nodeFactory->addition($logTerm, $ratioTerm));
    }

    /**
     * Build the operands of the simple power rule for a numeric literal exponent `c`.
     *
     * The type checks run in the order integer, float, rational.
     *
     * @param Node $exponent Exponent of the power being differentiated.
     *
     * @return array{0: Node, 1: Node}|null `[c, c - 1]` as nodes, or null when the exponent is not a numeric literal.
     */
    private function powerRuleOperands(Node $exponent): ?array
    {
        if ($exponent instanceof IntegerNode) {
            return [new IntegerNode($exponent->value), new IntegerNode($exponent->value - 1)];
        }

        if ($exponent instanceof FloatNode) {
            return [new FloatNode($exponent->value), new FloatNode($exponent->value - 1)];
        }

        if ($exponent instanceof RationalNode) {
            // p/q - 1 = (p - q)/q
            $reduced = new RationalNode(
                $exponent->getNumerator() - $exponent->getDenominator(),
                $exponent->getDenominator()
            );

            return [$exponent, $reduced];
        }

        return null;
    }

    /**
     * Outer derivative `g'(f(x))` for the functions whose rule is written as `f'(x) * g'(f(x))`.
     *
     * @param FunctionNode $node Function node being differentiated (reused for `tan` and `cot`).
     * @param Node         $arg  Operand `f(x)` of the function.
     *
     * @return Node|null The outer derivative, or null when the function is not handled by this group.
     */
    private function outerDerivative(FunctionNode $node, Node $arg): ?Node
    {
        $builder = $this->nodeFactory;

        return match ($node->operator) {
            'sin'    => new FunctionNode('cos', $arg),
            'cos'    => $builder->unaryMinus(new FunctionNode('sin', $arg)),
            'tan'    => $builder->addition(new IntegerNode(1), $this->squareOf($node)),
            'cot'    => $builder->subtraction($builder->unaryMinus(new IntegerNode(1)), $this->squareOf($node)),
            'arccot' => $builder->unaryMinus(
                $builder->division(
                    new IntegerNode(1),
                    $builder->addition(new IntegerNode(1), $this->squareOf($arg))
                )
            ),
            'exp'    => new FunctionNode('exp', $arg),
            'sinh'   => new FunctionNode('cosh', $arg),
            'cosh'   => new FunctionNode('sinh', $arg),
            'tanh'   => $builder->subtraction(new IntegerNode(1), $this->squareOf(new FunctionNode('tanh', $arg))),
            'coth'   => $builder->subtraction(new IntegerNode(1), $this->squareOf(new FunctionNode('coth', $arg))),
            'abs'    => new FunctionNode('sgn', $arg),
            default  => null,
        };
    }

    /**
     * Derivative for the functions whose rule is written as a single quotient `±f'(x) / d(f(x))`.
     *
     * @param FunctionNode $node  Function node being differentiated (reused for `sqrt`).
     * @param Node         $arg   Operand `f(x)` of the function.
     * @param Node         $inner Derivative `f'(x)` of the operand.
     *
     * @return Node
     * @throws UnknownFunctionException If the function name is not one of the supported functions.
     */
    private function quotientDerivative(FunctionNode $node, Node $arg, Node $inner): Node
    {
        $builder = $this->nodeFactory;
        $name    = $node->operator;

        $denominator = match ($name) {
            'arcsin', 'arccos' => new FunctionNode('sqrt', $this->oneMinusSquareOf($arg)),
            'arctan'           => $builder->addition(new IntegerNode(1), $this->squareOf($arg)),
            'ln', 'log'        => $arg,
            'lg'               => $builder->multiplication(new FunctionNode('ln', new IntegerNode(10)), $arg),
            'sqrt'             => $builder->multiplication(new IntegerNode(2), $node),
            'arsinh'           => new FunctionNode(
                'sqrt',
                $builder->addition($this->squareOf($arg), new IntegerNode(1))
            ),
            'arcosh'           => new FunctionNode(
                'sqrt',
                $builder->subtraction($this->squareOf($arg), new IntegerNode(1))
            ),
            'artanh', 'arcoth' => $this->oneMinusSquareOf($arg),
            default            => throw new UnknownFunctionException($name),
        };

        // arccos is the only rule in this group carrying a negated numerator.
        $numerator = $name === 'arccos' ? $builder->unaryMinus($inner) : $inner;

        return $builder->division($numerator, $denominator);
    }

    /**
     * Build the node `operand^2`.
     *
     * @param Node $operand Expression to be squared.
     *
     * @return Node
     */
    private function squareOf(Node $operand): Node
    {
        return $this->nodeFactory->exponentiation($operand, new IntegerNode(2));
    }

    /**
     * Build the node `1 - operand^2`.
     *
     * @param Node $operand Expression to be squared and subtracted from one.
     *
     * @return Node
     */
    private function oneMinusSquareOf(Node $operand): Node
    {
        return $this->nodeFactory->subtraction(new IntegerNode(1), $this->squareOf($operand));
    }
}