pygae/galgebra

View on GitHub
galgebra/atoms.py

Summary

Maintainability
A
25 mins
Test Coverage
""" Sympy primitives for representing atoms of ga expressions """

from typing import Union

from sympy import Symbol, AtomicExpr, S, Basic, sympify, MatrixExpr
from sympy import Determinant as _Determinant
from sympy.core import numbers
from sympy.core.function import AppliedUndef, UndefinedFunction
from sympy.printing.pretty.stringpict import prettyForm, stringPict
from sympy.printing.pretty.pretty_symbology import U

# Names exported by `from galgebra.atoms import *`.
# NOTE(review): `MatrixFunction` and `Determinant` are defined in this module
# but are not listed here - confirm that leaving them out is intentional.
__all__ = [
    'BasisVectorSymbol',
    'BasisBladeSymbol',
    'BasisBladeNoWedgeSymbol',
    'BasisBaseSymbol',
    'DotProductSymbol',
]


def _all_same(items):
    return all(x == items[0] for x in items)


class BasisVectorSymbol(Symbol):
    """ A non-commutative symbol that stands for a single basis vector """
    is_commutative = False

    def _latex(self, print_obj):
        """ Render this symbol in bold using the given LaTeX printer """
        try:
            bold = print_obj._print_Symbol(self, style="bold")
        except TypeError:
            # this sympy is too old to accept `style=`, so add the bold
            # wrapper around the plain rendering by hand
            plain = print_obj._print_Symbol(self)
            return r"\mathbf{{{}}}".format(plain)
        else:
            return bold


class _GradedSymbol(AtomicExpr):
    """ Common base class of the graded symbols

    The constructor does not always return an instance of this class:

    * with no arguments it returns the scalar `S.One`;
    * with exactly one argument it returns that argument unchanged;
    * with two or more arguments it builds a real instance.

    This may change in future.
    """
    # A scalar would be commutative, but `__new__` hands back `S.One` for the
    # empty case, so no object of this type is ever created for a scalar.
    is_commutative = False

    def __new__(cls, *args: BasisVectorSymbol) -> Union[
        numbers.One,
        BasisVectorSymbol,
        "_GradedSymbol"
    ]:
        if not args:
            return S.One
        if len(args) == 1:
            return args[0]
        return super().__new__(cls, *args)


class _JoinedPrinterMixin(Basic):
    """ Helper class that prints `Basic.args` joined by an operator symbol.

    Subclasses must populate `_op_sympystr` and `_op_latex` (strings used as
    the separator by the str and LaTeX printers) and `_op_pretty` (the object
    inserted between the arguments by the pretty printer).
    """

    def _sympystr(self, printer):
        rendered = [printer._print(arg) for arg in self.args]
        return self._op_sympystr.join(rendered)

    def _pretty(self, printer):
        pieces = []
        for arg in self.args:
            # every argument after the first is preceded by the operator
            if pieces:
                pieces.append(self._op_pretty)
            pieces.append(printer._print(arg))
        return prettyForm(*stringPict.next(*pieces))

    def _latex(self, printer):
        rendered = [printer._print(arg) for arg in self.args]
        return self._op_latex.join(rendered)


class BasisBaseSymbol(_GradedSymbol, _JoinedPrinterMixin):
    r""" A basis base in a non-orthogonal algebra, such as :math:`e_1 e_2` """
    # separators read by the printing methods of `_JoinedPrinterMixin`:
    # str output joins the arguments with '*'
    _op_sympystr = '*'
    # pretty output places a '*' form between the arguments
    _op_pretty = prettyForm('*')
    # LaTeX output concatenates the arguments with nothing in between
    _op_latex = ''


class BasisBladeSymbol(_GradedSymbol, _JoinedPrinterMixin):
    r""" A basis blade such as :math:`e_1 \wedge e_2` """
    # separators read by the printing methods of `_JoinedPrinterMixin`:
    # str output joins the arguments with '^'
    _op_sympystr = '^'
    # pretty output places a '^' form between the arguments
    _op_pretty = prettyForm('^')
    # LaTeX output joins the arguments with `\wedge ` (trailing space included)
    _op_latex = r'\wedge '


class BasisBladeNoWedgeSymbol(BasisBladeSymbol):
    r""" A basis blade printed in compact form, such as :math:`e_{12}` """
    def _split_name(self):
        """ Split the names of the basis vectors into a root and subscripts.

        Returns a pair ``(root, subscripts)``: ``root`` is the part before the
        underscore, which must be the same for every argument, and
        ``subscripts`` is the concatenation of the parts after it.

        Raises ``ValueError`` if a name does not split into exactly two pieces
        on ``'_'``, or if the roots are not all equal.
        """
        roots = []
        subscripts = []
        for vec in self.args:
            pieces = vec.name.split('_')
            if len(pieces) != 2:
                raise ValueError('Incompatible basis vector {} for wedgeless printing'.format(vec))
            root, subscript = pieces
            roots.append(root)
            subscripts.append(subscript)
        if not _all_same(roots):
            raise ValueError('No unique root symbol to use for wedgeless printing')
        return roots[0], ''.join(subscripts)

    def __common_printer(self, printer):
        # hand the printer a stand-in basis vector named "<root>_<subscripts>"
        root, sub = self._split_name()
        stand_in = BasisVectorSymbol("{}_{}".format(root, sub))
        return printer._print(stand_in)

    # the str, pretty and LaTeX printers all share one implementation
    _sympystr = _pretty = _latex = __common_printer


class DotProductSymbol(AtomicExpr):
    """ An atom that stands for the dot product of its two arguments, like :class:`sympy.DotProduct` """
    is_real = True

    def _sympystr(self, printer):
        """ Plain-text form: the two arguments around a ``.`` in parentheses """
        lhs, rhs = self.args
        lhs_text = printer._print(lhs)
        rhs_text = printer._print(rhs)
        return "({}.{})".format(lhs_text, rhs_text)

    def _latex(self, printer):
        r""" LaTeX form: the two arguments around ``\cdot`` in sized parentheses """
        lhs, rhs = self.args
        lhs_text = printer._print(lhs)
        rhs_text = printer._print(rhs)
        return r"\left ({}\cdot {}\right ) ".format(lhs_text, rhs_text)

    def _pretty(self, printer):
        """ Pretty form: the two arguments around a dot operator in parentheses """
        lhs, rhs = self.args
        pieces = [
            printer._print(lhs),
            printer._print(U('DOT OPERATOR')),
            printer._print(rhs),
        ]
        joined = prettyForm(*stringPict.next(*pieces))
        return prettyForm(*joined.parens())


class MatrixFunction(UndefinedFunction):
    """ Like a MatrixSymbol, but for functions.

    ``MatrixFunction(name, m, n)`` creates a new undefined-function class
    whose bases are `AppliedUndef` and `MatrixExpr`, and whose ``shape``
    attribute is ``(m, n)``.

    Any extra keyword arguments are forwarded to `UndefinedFunction`.
    """
    def __new__(mcl, name, m, n, **kwargs):
        cls = super().__new__(mcl, name, (AppliedUndef, MatrixExpr), {}, **kwargs)
        # Both dimensions go through strict sympify. The first entry must come
        # from `m`: the previous code used `n` twice, which silently dropped
        # `m` and made every function report an (n, n) shape.
        rows = sympify(m, strict=True)
        cols = sympify(n, strict=True)
        cls.shape = rows, cols
        return cls


# workaround until sympy/sympy#19354 is merged: when the imported Determinant
# does not report `is_commutative` as True, use a subclass that does
if _Determinant.is_commutative is True:
    Determinant = _Determinant
else:
    class Determinant(_Determinant):
        is_commutative = True