src/gyptis/complex.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Benjamin Vial
# This file is part of gyptis
# Version: 1.0.2
# License: MIT
# See the documentation at gyptis.gitlab.io
"""
Support for complex finite element forms.
This module provides a :class:`Complex` class and complex-aware versions of some ``df``
functions, so that complex-valued problems can be handled by splitting them into real and
imaginary parts.
"""
import functools
from typing import Iterable

import numpy as np
import ufl

from . import dolfin as df
def _complexcheck(func):
"""Wrapper to check if arguments are complex"""
def wrapper(self, z):
if hasattr(z, "real") and hasattr(z, "imag"):
if not isinstance(z, Complex):
z = Complex(z.real, z.imag)
else:
z = Complex(z, 0)
return func(self, z)
return wrapper
def _complexify_linear(func):
def wrapper(z, *args, **kwargs):
if isinstance(z, Complex):
return Complex(func(z.real, *args, **kwargs), func(z.imag, *args, **kwargs))
else:
return func(z, *args, **kwargs)
return wrapper
def _complexify_bilinear(func):
def wrapper(a, b, *args, **kwargs):
if iscomplex(a) and iscomplex(b):
re = func(a.real, b.real, *args, **kwargs) - func(
a.imag, b.imag, *args, **kwargs
)
im = func(a.real, b.imag, *args, **kwargs) + func(
a.imag, b.real, *args, **kwargs
)
return Complex(re, im)
elif iscomplex(a) and not iscomplex(b):
re = func(a.real, b, *args, **kwargs)
im = func(a.imag, b)
return Complex(re, im, *args, **kwargs)
elif not iscomplex(a) and iscomplex(b):
re = func(a, b.real, *args, **kwargs)
im = func(a, b.imag)
return Complex(re, im, *args, **kwargs)
else:
return func(a, b, *args, **kwargs)
return wrapper
def _complexify_vector(func):
    """Wrap a constructor of two-component (real, imag) objects.

    The object built by ``func`` is split with ``df.split`` into exactly two
    parts. These are returned as the real and imaginary parts of a
    :class:`Complex`.
    """

    def wrapper(*args, **kwargs):
        mixed_obj = func(*args, **kwargs)
        real_part, imag_part = df.split(mixed_obj)
        return Complex(real_part, imag_part)

    return wrapper
def _complexify_vector_alt(func):
    """Like ``_complexify_vector`` but splits via the object's own ``split()`` method.

    The object built by ``func`` must provide ``split()`` returning exactly two
    parts. These become the real and imaginary parts of a :class:`Complex`.
    """

    def wrapper(*args, **kwargs):
        mixed_obj = func(*args, **kwargs)
        # Called without arguments; a ``deepcopy=True`` variant was considered
        # in the original code. NOTE(review): presumably the parts share data
        # with ``mixed_obj`` -- confirm against the dolfin version in use.
        real_part, imag_part = mixed_obj.split()
        return Complex(real_part, imag_part)

    return wrapper
class Complex:
    """A complex object.

    Stores the real and imaginary parts as two separate objects, so that
    complex arithmetic can be written with real-valued building blocks
    (numbers, arrays, ``df``/``ufl`` objects).

    Parameters
    ----------
    real : type
        Real part.
    imag : type
        Imaginary part (the default is 0.0).

    Attributes
    ----------
    real
    imag
    """

    def __init__(self, real, imag=0.0):
        self.real = real
        self.imag = imag

    def __len__(self):
        """Length of the parts when both are sized, else 0 (scalar complex).

        Raises
        ------
        ValueError
            If the real and imaginary parts have different lengths.
        """
        if hasattr(self.real, "__len__") and hasattr(self.imag, "__len__"):
            if len(self.real) == len(self.imag):
                return len(self.real)
            else:
                raise ValueError("real and imaginary parts should have the same length")
        else:
            return 0

    def __iter__(self):
        # Component-wise iteration; yields nothing when ``len(self) == 0``.
        for i in range(len(self)):
            yield Complex(self.real[i], self.imag[i])

    def __getitem__(self, i):
        return Complex(self.real[i], self.imag[i])

    @_complexcheck
    def __add__(self, other):
        return Complex(self.real + other.real, self.imag + other.imag)

    # Addition is commutative, so ``other + self`` can reuse ``__add__``.
    __radd__ = __add__

    @_complexcheck
    def __sub__(self, other):
        return Complex(self.real - other.real, self.imag - other.imag)

    @_complexcheck
    def __rsub__(self, other):
        # Reflected subtraction must compute ``other - self``. Aliasing
        # ``__sub__`` here would return ``self - other`` (wrong sign).
        return Complex(other.real - self.real, other.imag - self.imag)

    @_complexcheck
    def __mul__(self, other):
        return Complex(
            self.real * other.real - self.imag * other.imag,
            self.imag * other.real + self.real * other.imag,
        )

    # Complex multiplication of the parts commutes, so reuse ``__mul__``.
    __rmul__ = __mul__

    # Opt out of NumPy ufuncs. ``ndarray <op> Complex`` then returns
    # NotImplemented, and Python falls back to this class's reflected
    # operators (``__radd__``, ``__rmul__``, ...).
    __array_ufunc__ = None

    @_complexcheck
    def __truediv__(self, other):
        sr, si, tr, ti = self.real, self.imag, other.real, other.imag  # short forms
        r = tr**2 + ti**2
        return Complex((sr * tr + si * ti) / r, (si * tr - sr * ti) / r)

    @_complexcheck
    def __rtruediv__(self, other):
        # Computes ``other / self`` (numerator and denominator swapped).
        sr, si, tr, ti = other.real, other.imag, self.real, self.imag  # short forms
        r = tr**2 + ti**2
        return Complex((sr * tr + si * ti) / r, (si * tr - sr * ti) / r)

    @property
    def shape(self):
        """Shape of the real part."""
        return self.real.shape

    @property
    def conj(self):
        """Complex conjugate."""
        return Complex(self.real, -self.imag)

    @property
    def module(self):
        """Module of the complex number"""
        return self.__abs__()

    @property
    def phase(self):
        """Argument (polar angle) of the complex number."""
        return self.__angle__()

    def __abs__(self):
        return df.sqrt(self.real**2 + self.imag**2)

    def __neg__(self):  # defines -c (c is Complex)
        return Complex(-self.real, -self.imag)

    @_complexcheck
    def __eq__(self, other):
        return self.real == other.real and self.imag == other.imag

    @_complexcheck
    def __ne__(self, other):
        return not self.__eq__(other)

    def __str__(self):
        return f"({self.real.__str__()} + {self.imag.__str__()}j)"

    def __repr__(self):
        return f"Complex({self.real.__repr__()}, {self.imag.__repr__()})"

    def __pow__(self, power):
        """Raise to a real power using the polar form.

        Raises
        ------
        NotImplementedError
            If ``power`` has a nonzero imaginary part.
        """
        if iscomplex(power) and power.imag != 0:
            raise NotImplementedError("complex exponent not implemented")
        if isinstance(power, Complex):
            # A ``Complex`` exponent with zero imaginary part: use its real part
            # so that ``A**power`` and ``phi * power`` stay real-valued.
            power = power.real
        A, phi = self.polar()
        return self.polar2cart(A**power, phi * power)

    def __angle__(self):
        x, y = self.real, self.imag
        try:
            return np.angle(x + 1j * y)
        except Exception:
            # Fallback for parts NumPy cannot handle (presumably symbolic
            # ``ufl`` expressions). Uses the half-angle formula
            # arg(z) = 2 * atan(y / (|z| + x)), with 0 at the origin and pi
            # on the negative real axis, where the formula is singular.
            return df.conditional(
                ufl.eq(self.__abs__(), 0),
                0,
                df.conditional(
                    ufl.eq(self.__abs__() + x, 0),
                    df.pi,
                    2 * df.atan(y / (self.__abs__() + x)),
                ),
            )

    def __call__(self, *args, **kwargs):
        "Calls the complex function if base objects are callable"
        return Complex(
            self.real.__call__(*args, **kwargs),
            self.imag.__call__(*args, **kwargs),
        )

    def tocomplex(self):
        """Return ``real + 1j * imag`` (same as :meth:`to_complex`)."""
        return self.real + 1j * self.imag

    @staticmethod
    def polar2cart(module, phase):
        """Polar to cartesian representation.

        Parameters
        ----------
        module : type
            The module (positive).
        phase : type
            The polar angle.

        Returns
        -------
        Complex
            The complex number in cartesian representation.
        """
        return module * Complex(df.cos(phase), df.sin(phase))

    def polar(self):
        """Polar representation.

        Returns
        -------
        tuple
            Modulus and phase.
        """
        return self.__abs__(), self.__angle__()

    def to_complex(self):
        """Return ``real + 1j * imag`` (same as :meth:`tocomplex`)."""
        return self.real + 1j * self.imag
def iscomplex(z):
    """Tell whether ``z`` looks like a complex quantity.

    Any object exposing both ``real`` and ``imag`` attributes, whose ``imag``
    is not ``None``, qualifies. This includes Python ints/floats, NumPy arrays
    and :class:`Complex` instances.

    Parameters
    ----------
    z : type
        Object to inspect.

    Returns
    -------
    bool
        True if z is complex, else False.
    """
    if not (hasattr(z, "real") and hasattr(z, "imag")):
        return False
    return z.imag is not None
class ComplexFunctionSpace(df.FunctionSpace):
    """Complex function space.

    Built on the product ``element * element`` of the element described by the
    constructor arguments (a mixed element in UFL): one copy for the real part
    and one for the imaginary part.
    """

    def __init__(self, *args, **kwargs):
        # Build the base (real-valued) space once to recover its element and mesh.
        super().__init__(*args, **kwargs)
        base_element = super().ufl_element()
        base_mesh = super().mesh()
        # Re-initialise in place on the doubled element.
        super().__init__(base_mesh, base_element * base_element, **kwargs)
def _cplx_iter(f):
    """Wrap a constructor ``f`` so that complex input is split into two calls.

    If ``v`` (or, for an iterable ``v``, any of its items) is complex in the
    sense of :func:`iscomplex`, then ``f`` is called on the real and imaginary
    parts separately and a :class:`Complex` is returned. Iterables are first
    converted with ``np.array``. Otherwise ``f(v, ...)`` is returned unchanged.
    """

    def wrapper(v, *args, **kwargs):
        is_iterable = isinstance(v, Iterable)
        if is_iterable:
            has_complex = any(iscomplex(item) for item in v)
        else:
            has_complex = iscomplex(v)

        if not has_complex:
            return f(v, *args, **kwargs)

        if is_iterable:
            v = np.array(v)
        return Complex(f(v.real, *args, **kwargs), f(v.imag, *args, **kwargs))

    return wrapper
def phasor(prop_cst, direction=0, **kwargs):
    """Phasor ``cos(k x_d) + i sin(k x_d)`` built from two ``df.Expression`` objects.

    Parameters
    ----------
    prop_cst : type
        Propagation constant ``k``, bound as the ``prop_cst`` expression parameter.
    direction : int
        Index ``d`` of the coordinate component ``x[d]`` (the default is 0).
    **kwargs
        Extra keyword arguments forwarded to both ``df.Expression`` calls.

    Returns
    -------
    Complex
        Real part ``cos(prop_cst*x[direction])`` and imaginary part
        ``sin(prop_cst*x[direction])``.
    """
    real_part = df.Expression(
        f"cos(prop_cst*x[{direction}])", prop_cst=prop_cst, **kwargs
    )
    imag_part = df.Expression(
        f"sin(prop_cst*x[{direction}])", prop_cst=prop_cst, **kwargs
    )
    return Complex(real_part, imag_part)
def phase_shift(phase, **kwargs):
    """Unit phase factor ``cos(phase) + i sin(phase)`` as ``df.Expression`` parts.

    ``phase`` is bound as the ``phase`` expression parameter. Extra keyword
    arguments are forwarded to both ``df.Expression`` calls.
    """
    return Complex(
        df.Expression("cos(phase)", phase=phase, **kwargs),
        df.Expression("sin(phase)", phase=phase, **kwargs),
    )
def phase_shift_constant(phase):
    """Unit phase factor ``cos(phase) + i sin(phase)`` built with ``df.cos``/``df.sin``."""
    return Complex(df.cos(phase), df.sin(phase))
def vector(vect):
    """Assemble a complex vector from a sequence of complex entries.

    The ``.real`` and ``.imag`` parts of the entries are gathered, in two passes
    over ``vect``, into two ``df.as_vector`` objects wrapped in a :class:`Complex`.
    """
    real_parts = [component.real for component in vect]
    imag_parts = [component.imag for component in vect]
    real_vec = df.as_vector(real_parts)
    imag_vec = df.as_vector(imag_parts)
    return Complex(real_vec, imag_vec)
def tensor(tens):
    """Assemble a complex rank-2 tensor from rows of complex entries.

    The ``.real`` and ``.imag`` parts are gathered row by row, in two passes over
    ``tens``, into two ``df.as_tensor`` objects wrapped in a :class:`Complex`.
    """
    real_rows = [[entry.real for entry in row] for row in tens]
    imag_rows = [[entry.imag for entry in row] for row in tens]
    real_tens = df.as_tensor(real_rows)
    imag_tens = df.as_tensor(imag_rows)
    return Complex(real_tens, imag_tens)
# Linear operations: applied separately to the real and imaginary parts.
interpolate, assemble, project = (
    _complexify_linear(op) for op in (df.interpolate, df.assemble, df.project)
)
grad, div, curl, sym, tr, Dx = (
    _complexify_linear(op)
    for op in (df.grad, df.div, df.curl, df.sym, df.tr, df.Dx)
)

# Bilinear products, expanded over real and imaginary parts (no conjugation).
inner, dot, cross = map(_complexify_bilinear, (df.inner, df.dot, df.cross))

# Constructors that also accept complex values or iterables of them.
as_tensor, as_vector, Constant = map(
    _cplx_iter, (df.as_tensor, df.as_vector, df.Constant)
)

# Constructors whose result is split into (real, imag) parts and returned as Complex.
Function = _complexify_vector_alt(df.Function)
TrialFunction, TestFunction, TrialFunctions, TestFunctions = map(
    _complexify_vector,
    (df.TrialFunction, df.TestFunction, df.TrialFunctions, df.TestFunctions),
)

# Imaginary unit.
j = Complex(0, 1)
def _traverse(a):
if not isinstance(a, list):
yield a
else:
for e in a:
yield from _traverse(e)
def to_array(a):
    """Convert a (nested list of) :class:`Complex` into a NumPy complex array.

    Each leaf is converted with ``to_complex()``. The result is reshaped to
    ``np.array(a).shape`` without its last axis.
    """
    flat = np.array([leaf.to_complex() for leaf in _traverse(a)])
    # NOTE(review): a scalar Complex defines __getitem__ and has __len__ == 0,
    # so NumPy appears to see each leaf as an empty sequence. That adds a
    # trailing length-0 axis, which is dropped here. Assumes scalar leaves.
    target_shape = np.array(a).shape[:-1]
    return flat.reshape(target_shape)