KarrLab/de_sim

View on GitHub
de_sim/examples/sirs.py

Summary

Maintainability
A
50 mins
Test Coverage
A
100%
""" Example DE-Sim implementations of stochastic Susceptible, Infectious, or Recovered (SIR) epidemic models

:Author: Arthur Goldberg <Arthur.Goldberg@mssm.edu>
:Date: 2020-07-08
:Copyright: 2020, Karr Lab
:License: MIT
"""

import enum
import numpy

from de_sim.checkpoint import AccessCheckpoints
from de_sim.simulation_checkpoint_object import (AccessStateObjectInterface,
                                                 CheckpointSimulationObject)
import de_sim


class SusceptibleToInfectious(de_sim.EventMessage):
    """ Event message for the S -> I transition: one susceptible subject becomes infectious """


class InfectiousToRecovered(de_sim.EventMessage):
    """ Event message for the I -> R transition: one infectious subject recovers """


# Message types sent by SIR. This module-level name is rebound further down for SIR2;
# SIR's ``messages_sent`` captures this list when the SIR class body executes.
MESSAGE_TYPES = [
    SusceptibleToInfectious,
    InfectiousToRecovered,
]


class SIR(de_sim.SimulationObject):
    """ Implement a Susceptible, Infectious, or Recovered (SIR) epidemic model

    This example uses DE-Sim to implement a continuous-time Markov chain (CTMC) SIR
    epidemic model, as described in section 3 of Allen (2017).

    Allen, L.J., 2017. A primer on stochastic epidemic models: Formulation, numerical simulation, and analysis.
    Infectious Disease Modelling, 2(2), pp.128-142.

    Only `s` and `i` are stored; the number of recovered subjects is implied by ``N - s - i``.

    Attributes:
        s (:obj:`int`): number of susceptible subjects
        i (:obj:`int`): number of infectious subjects
        N (:obj:`int`): total population size (susceptible + infectious + recovered), a constant
        beta (:obj:`float`): SIR beta parameter; the S -> I rate is ``beta * s * i / N``
        gamma (:obj:`float`): SIR gamma parameter; the I -> R rate is ``gamma * i``
        recording_period (:obj:`float`): time step for recording state
        random_state (:obj:`numpy.random.RandomState`): source of randomness for event times and choices
        history (:obj:`list`): list of recorded states; not populated by this class's own methods
    """
    def __init__(self, name, s, i, N, beta, gamma, recording_period, seed=None):
        """ Initialize an SIR instance

        Args:
            name (:obj:`str`): the instance's name
            s (:obj:`int`): initial number of susceptible subjects, s(0)
            i (:obj:`int`): initial number of infectious subjects, i(0)
            N (:obj:`int`): total population size, a constant
            beta (:obj:`float`): SIR beta parameter
            gamma (:obj:`float`): SIR gamma parameter
            recording_period (:obj:`float`): time step for recording state; stored, but not used
                by this class's own methods
            seed (:obj:`int`, optional): seed for :obj:`random_state`, which makes a run reproducible;
                if :obj:`None` (the default), numpy seeds the random state unpredictably
        """
        self.s = s
        self.i = i
        self.N = N
        self.beta = beta
        self.gamma = gamma
        self.recording_period = recording_period
        self.random_state = numpy.random.RandomState(seed)
        self.history = []
        super().__init__(name)

    def init_before_run(self):
        """ Send the initial event by scheduling the first transition

        This method does not record any state itself.
        """
        self.schedule_next_event()

    def schedule_next_event(self):
        """ Schedule the next SIR event

        The next event is sampled by the direct (Gillespie) method. The time until the event is
        drawn from an exponential distribution with rate equal to the sum of the transition rates.
        The transition is then chosen with probability proportional to its rate. Nothing is
        scheduled when the total rate is 0 (e.g., no infectious subjects remain), which ends
        this object's activity.
        """
        rates = {'s_to_i': self.beta * self.s * self.i / self.N,
                 'i_to_r': self.gamma * self.i}
        lambda_val = rates['s_to_i'] + rates['i_to_r']
        if lambda_val == 0:
            return

        # numpy's exponential() takes the scale (mean), i.e. 1 / rate
        tau = self.random_state.exponential(1.0/lambda_val)
        prob_s_to_i = rates['s_to_i'] / lambda_val
        # random_sample() is in [0, 1), so a zero S -> I rate never selects S -> I
        if self.random_state.random_sample() < prob_s_to_i:
            self.send_event(tau, self, SusceptibleToInfectious())
        else:
            self.send_event(tau, self, InfectiousToRecovered())

    def handle_s_to_i(self, event):
        """ Handle a susceptible to infectious event

        Args:
            event (:obj:`~de_sim.event.Event`): simulation event; not used
        """
        del event   # Avoid PyLint warning W0613, unused-argument
        self.s -= 1
        self.i += 1
        self.schedule_next_event()

    def handle_i_to_r(self, event):
        """ Handle an infectious to recovered event

        Args:
            event (:obj:`~de_sim.event.Event`): simulation event; not used
        """
        del event   # Avoid PyLint warning W0613, unused-argument
        # recovered subjects are implicit (N - s - i), so only i changes
        self.i -= 1
        self.schedule_next_event()

    event_handlers = [(SusceptibleToInfectious, 'handle_s_to_i'),
                      (InfectiousToRecovered, 'handle_i_to_r')]

    # register the message types sent
    messages_sent = MESSAGE_TYPES


class StateTransitionType(enum.Enum):
    """ Kinds of state transition in an SIR model

    Each member's value is a human-readable description of its transition.
    """
    # S -> I: a susceptible subject becomes infectious
    s_to_i = 'Transition from Susceptible to Infectious'
    # I -> R: an infectious subject recovers
    i_to_r = 'Transition from Infectious to Recovered'


### SIR epidemic model, version 2 ###
class TransitionMessage(de_sim.EventMessage):
    """ A single event message type that carries any model transition

    The ``transition_type`` field tells the receiver which transition to apply.
    """
    transition_type: StateTransitionType


# Message types sent by SIR2. This rebinds the module-level MESSAGE_TYPES defined earlier for SIR.
MESSAGE_TYPES = [
    TransitionMessage,
]


class SIR2(SIR):
    """ Version 2 of an SIR epidemic model

    SIR2 is similar to SIR, but uses one event message type for both transitions, and a
    single message handler to process transition events.
    """
    def schedule_next_event(self):
        """ Schedule the next SIR event

        This uses the same sampling as SIR: an exponentially distributed delay with rate equal
        to the total transition rate, then a transition chosen with probability proportional
        to its rate. The chosen transition is encoded in a single :obj:`TransitionMessage`.
        Nothing is scheduled when the total rate is 0.
        """
        rates = {'s_to_i': self.beta * self.s * self.i / self.N,
                 'i_to_r': self.gamma * self.i}
        lambda_val = rates['s_to_i'] + rates['i_to_r']
        if lambda_val == 0:
            return

        # numpy's exponential() takes the scale (mean), i.e. 1 / rate
        tau = self.random_state.exponential(1.0/lambda_val)
        prob_s_to_i = rates['s_to_i'] / lambda_val
        if self.random_state.random_sample() < prob_s_to_i:
            transition_type = StateTransitionType.s_to_i
        else:
            transition_type = StateTransitionType.i_to_r
        self.send_event(tau, self, TransitionMessage(transition_type))

    def handle_state_transition(self, event):
        """ Handle a state transition event

        Args:
            event (:obj:`~de_sim.event.Event`): simulation event whose message contains the type of transition

        Raises:
            :obj:`ValueError`: if the message's transition type is not a known :obj:`StateTransitionType`
        """
        transition_type = event.message.transition_type
        if transition_type is StateTransitionType.s_to_i:
            self.s -= 1
            self.i += 1
        elif transition_type is StateTransitionType.i_to_r:
            # recovered subjects are implicit (N - s - i), so only i changes
            self.i -= 1
        else:
            # fail loudly instead of silently ignoring a malformed event
            raise ValueError("Unknown transition type: {!r}".format(transition_type))
        self.schedule_next_event()

    event_handlers = [(TransitionMessage, 'handle_state_transition')]

    # register the message types sent
    messages_sent = MESSAGE_TYPES


class AccessSIRObjectState(AccessStateObjectInterface):
    """ Expose an SIR object's state and random state for checkpointing

    Attributes:
        sir (:obj:`SIR`): the SIR object whose state is read
        random_state (:obj:`numpy.random.RandomState`): the SIR object's random state, captured at construction
    """

    def __init__(self, sir):
        self.sir = sir
        # keep a direct reference to the SIR object's random state
        self.random_state = sir.random_state

    def get_checkpoint_state(self, time):
        """ Return the SIR object's current state

        Args:
            time (:obj:`float`): current time; ignored

        Returns:
            :obj:`dict`: the numbers of susceptible (``'s'``) and infectious (``'i'``) subjects
        """
        del time  # unused here
        model = self.sir
        return {'s': model.s, 'i': model.i}

    def get_random_state(self):
        """ Return the SIR object's random state, as given by ``RandomState.get_state()``
        """
        snapshot = self.random_state.get_state()
        return snapshot


class RunSIRs:
    """ Build and run checkpointed SIR simulations, and report their results

    Attributes:
        checkpoint_dir: path of the directory in which checkpoints are written and read
        sir (:obj:`SIR`): the SIR object of the most recent simulation; :obj:`None` until
            :obj:`simulate` has been called
    """

    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = checkpoint_dir
        # defined here (not only in simulate) so the attribute always exists
        self.sir = None

    def simulate(self, sir_class, max_time, **sir_args):
        """ Create and run an SIR simulation

        Args:
            sir_class (:obj:`type`): a type of SIR class to run
            max_time (:obj:`float`): simulation end time
            sir_args (:obj:`dict`): arguments for an SIR object; must include ``recording_period``,
                which also sets the checkpointing period

        Returns:
            :obj:`int`: the number of events executed
        """

        # create a simulator
        simulator = de_sim.Simulator()

        # create an SIR instance
        self.sir = sir = sir_class(**sir_args)
        simulator.add_object(sir)

        # create a checkpoint simulation object
        access_state_object = AccessSIRObjectState(sir)
        checkpointing_obj = CheckpointSimulationObject('checkpointing_obj', sir_args['recording_period'],
                                                       self.checkpoint_dir, access_state_object)
        simulator.add_object(checkpointing_obj)

        # initialize simulation, which sends the SIR instance an initial event message
        simulator.initialize()

        # run the simulation
        event_num = simulator.simulate(max_time).num_events
        print("Executed {} events.\n".format(event_num))
        return event_num

    def print_history(self):
        """ Print an SIR simulation's history

        Prints a tab-separated table with a header row and one row per checkpoint. Each row holds
        the checkpoint time and the numbers of susceptible, infectious, and recovered subjects.
        The recovered count is derived as ``N - s - i`` using the most recent SIR object's ``N``.
        """
        print('\t'.join(['time', 's', 'i', 'r']))

        access_checkpoints = AccessCheckpoints(self.checkpoint_dir)
        for checkpoint_time in access_checkpoints.list_checkpoints():
            state = access_checkpoints.get_checkpoint(time=checkpoint_time).state
            row = [checkpoint_time, state['s'], state['i'], self.sir.N - state['s'] - state['i']]
            print('\t'.join(str(value) for value in row))

    def last_checkpoint(self):
        """ Get the last checkpoint of the last simulation run

        Returns:
            :obj:`~de_sim.checkpoint.Checkpoint`: the last checkpoint of the last simulation run

        Raises:
            :obj:`IndexError`: if ``checkpoint_dir`` contains no checkpoints
        """
        access_checkpoints = AccessCheckpoints(self.checkpoint_dir)
        last_checkpoint_time = access_checkpoints.list_checkpoints()[-1]
        return access_checkpoints.get_checkpoint(time=last_checkpoint_time)