yuanx749/complex_systems

View on GitHub
mcs/pde.py

Summary

Maintainability
A
0 mins
Test Coverage
from .mcs import *


class PDE(MCS):
    """PDEs simulation on a periodic square grid.

    `update` advances the state with an explicit (forward) Euler step, and
    `_laplacian` uses a five-point finite-difference stencil with periodic
    (wrap-around) boundaries.

    Override this class to customize.

    Attributes:
        max_step: The max step.
        dim: The number of variables.
        dt: The time step.
        dh: Spatial resolution.
        size: Size of grid.
        f: An `~numpy.ndarray` of shape (max_step, size, size, dim) representing the states.
        step: The current step.
        xv, yv: Cell-edge coordinate meshes of shape (size + 1, size + 1),
            set by `initialize` and used by `visualize`.
    """

    def __init__(self, max_step: int, dim: int, dt: float, dh: float, size: int):
        super().__init__(max_step)
        self.dim = dim
        self.dt = dt
        self.dh = dh
        self.size = size
        # One (size, size, dim) slab per step; slots not yet written stay zero.
        self.f = np.zeros((max_step, size, size, dim))

    def _edge_mesh(self):
        """Returns the (xv, yv) meshes of the ``size + 1`` cell edges per axis.

        An integer ``arange`` scaled by ``dh`` always yields exactly
        ``size + 1`` points. A float-step ``arange(0, dh * (size + 1), dh)``
        can pick up an extra point from rounding (e.g. ``dh=0.1, size=100``
        gives 102). Its edges then no longer match the (size, size) data
        drawn by `visualize`.
        """
        edges = np.arange(self.size + 1) * self.dh
        return np.meshgrid(edges, edges)

    def initialize(self):
        """Sets up the initial conditions.

        Builds the plotting meshes and sets the first two variables at step 0
        to ``1 + U(-0.01, 0.01)`` noise. Reseeds the global NumPy RNG with 42
        so the initial state is reproducible. Assumes ``dim >= 2``; any
        further variables are left at zero.
        """
        self.xv, self.yv = self._edge_mesh()
        np.random.seed(42)
        self.f[0, ..., 0] = 1 + np.random.uniform(-0.01, 0.01, (self.size, self.size))
        self.f[0, ..., 1] = 1 + np.random.uniform(-0.01, 0.01, (self.size, self.size))

    def update(self, *, F: Callable = None):
        r"""Advances the state by one explicit Euler step.

        Computes ``f[step + 1] = f[step] + F(f[step]) * dt`` and increments
        ``step``.

        Args:
            F: A state transition function corresponding to :math:`\partial f/\partial t = F(f,...,x,y,t)`.
                It is called with the current (size, size, dim) state only.
                If `None`, ``self._identity`` is used.
        """
        if F is None:
            F = self._identity
        config = self.f[self.step]
        self.step += 1
        self.f[self.step] = config + F(config) * self.dt

    @staticmethod
    def turing(a, b, c, d, h, k, Du, Dv, dh):
        r"""Reaction-diffusion equations:

        .. math::
            \partial u/\partial t = a(u-h) + b(v-k) + D_u \Delta u

            \partial v/\partial t = c(u-h) + d(v-k) + D_v \Delta v.

        Returns:
            A function mapping a (size, size, 2) state to its time derivative
            of the same shape, usable as ``F`` in `update`.
        """

        def dfdt(config):
            lap = PDE._laplacian(config, dh)
            u, v = np.moveaxis(config, -1, 0)
            delta_u, delta_v = np.moveaxis(lap, -1, 0)
            du = a * (u - h) + b * (v - k) + Du * delta_u
            dv = c * (u - h) + d * (v - k) + Dv * delta_v
            # Stack on the last axis so the result matches config's layout.
            return np.stack([du, dv], axis=-1)

        return dfdt

    @staticmethod
    def _laplacian(u, dh):
        """Five-point Laplacian over axes 0 and 1 with periodic boundaries.

        Args:
            u: Array of shape (rows, cols, dim); each trailing slice is
                differentiated independently.
            dh: Grid spacing.

        Returns:
            An array of the same shape as ``u``.

        Raises:
            ValueError: If ``u`` is not 3-dimensional.
        """
        if u.ndim != 3:
            # Explicit check rather than `assert`, which is stripped under -O.
            raise ValueError(f"expected a 3-D array, got ndim={u.ndim}")
        neighbors = (
            np.roll(u, -1, axis=1)
            + np.roll(u, 1, axis=1)
            + np.roll(u, -1, axis=0)
            + np.roll(u, 1, axis=0)
        )
        return (neighbors - 4 * u) / (dh**2)

    def visualize(
        self,
        *,
        step: int = -1,
        indices: List[int] = None,
        vmin: float = 0,
        vmax: float = 2,
    ):
        """Visualizes the states of the system using heatmap.

        Args:
            step: The step to plot. Defaults to -1, the last slot of `f`,
                which stays zero unless a run has filled it.
            indices: A list of indices of the states to plot.
                If `None`, plot all states.
            vmin: Lower bound of the colour scale (default 0).
            vmax: Upper bound of the colour scale (default 2).
        Returns:
            A list of `matplotlib.figure.Figure` objects, one per plotted index.
        """
        # Fall back to the default edge mesh if initialize() was overridden
        # without building it; do not mutate self from a plotting call.
        if hasattr(self, "xv") and hasattr(self, "yv"):
            xv, yv = self.xv, self.yv
        else:
            xv, yv = self._edge_mesh()
        indices = range(self.dim) if indices is None else indices
        figs = []
        for state in indices:
            fig, ax = plt.subplots()
            pcm = ax.pcolormesh(xv, yv, self.f[step, ..., state], vmin=vmin, vmax=vmax)
            ax.set_aspect("equal")
            fig.colorbar(pcm, ax=ax)
            figs.append(fig)
        return figs