KarrLab/wc_rules

View on GitHub
wc_rules/simulator/simulator.py

Summary

Maintainability
A
2 hrs
Test Coverage
A
93%
from ..matcher.core import build_rete_net_class
from .scheduler import NextReactionMethod, PeriodicWriteObservables,CoordinatedScheduler
from ..utils.collections import DictLike, LoggableDict, subdict
from ..utils.data import NestedDict
from ..modeling.model import AggregateModel
from collections import defaultdict, deque, ChainMap
from collections.abc import Sequence
from attrdict import AttrDict
from ..schema.actions import PrimaryAction, CompositeAction, RemoveNode, CollectReferences
from ..matcher.token import convert_action_to_tokens
import sys

def get_model_name(rule_name):
    return '.'.join(rule_name.split('.')[:-1])

class PrefixSubdict:

    def __init__(self,prefix,target):
        self.prefix = prefix
        self.target = target

    def __contains__(self,key):
        return f'{self.prefix}.{key}' in self.target

    def __getitem__(self,key):
        return self.target[f'{self.prefix}.{key}']

CURRENT_ACTION_RECORD = deque()

class ActionStack:

    def __init__(self,cache,match):
        self._dq  = deque()
        self.cache = cache
        self.match = match
        self.record = deque()

    def put(self,action):
        if isinstance(action,list):
            self._dq.extend(action)
        else:
            self._dq.append(action)
        return self

    def put_first(self,action):
        if isinstance(action,list):
            self._dq.extendleft(reversed(action))
        else:
            self._dq.appendleft(action)
        return self

    def do(self):
        global CURRENT_ACTION_RECORD
        CURRENT_ACTION_RECORD = deque()

        tokens = []
        while self._dq:
            x = self._dq.popleft()
            if isinstance(x,CollectReferences):
                x.execute(self.match,self.cache)
            if isinstance(x,PrimaryAction):
                CURRENT_ACTION_RECORD.append(x)
                if isinstance(x,RemoveNode):
                    tokens.extend(convert_action_to_tokens(x,self.cache))
                    x.execute(self.cache)
                else:    
                    x.execute(self.cache)
                    tokens.extend(convert_action_to_tokens(x,self.cache))
                self.record.append(x)
            elif isinstance(x,CompositeAction):
                self.put_first(x.expand())
        return tokens



class SimulationEngine:

    ReteNetClass = build_rete_net_class()
    
    def __init__(self,model=None,parameters=None):
        print('Importing model and parameters.')
        model = AggregateModel('model',models=[model])
        model.verify(parameters)
        self.variables = LoggableDict(NestedDict.flatten(parameters))

        self.rules = dict(model.iter_rules())
        self.observables = dict(model.iter_observables())
        self.cache = DictLike()
        self.compiled_rules = dict()

        print('Initializing rete net matching engine.')
        self.net = self.ReteNetClass() \
            .initialize_start() \
            .initialize_end()    \
            .initialize_rules(self.rules,self.variables) \
            .initialize_observables(self.observables)

        self.net.validate()
        print('Linking rules to matching engine.')
        for name,rule in self.rules.items():
            self.compile_rule(name,rule)

        for node in self.net.get_nodes(type='variable'):
            self.variables.set(node.core,node.state.cache.value)

    def get_updated_variables(self):
        # this needs to be removed as it is not used in the simulator
        # it is currently being used by many tests though
        modified = list(self.variables.modified)
        self.variables.flush()
        return modified

    def compile_rule(self,rule_name,rule):
        compilation = AttrDict()
        compilation.original_object = rule
        compilation.reactants = {k:self.net.get_node(core=v).state.cache for k,v in rule.reactants.items()}
        compilation.helpers = {k:self.net.get_node(core=v).state.cache for k,v in rule.helpers.items()}
        compilation.parameters = PrefixSubdict(get_model_name(rule_name),self.variables)
        compilation.action_manager = rule.get_action_manager()
        compilation.factories = rule.factories
        self.compiled_rules[rule_name] = compilation
        return self

    def update(self,variables):
        for variable in variables:
            value = self.net.get_node(core=variable).state.cache.value
            self.variables.set(variable,value)
        return self

    def load(self,obj):
        if isinstance(obj,Sequence):
            for x in obj:
                self.load(x)
            return self
        ax = ActionStack(self.cache,{})
        for action in obj.generate_actions():
            tokens = ax.put(action).do()
            self.net.process_tokens(tokens)
        variables = self.net.get_updated_variables()
        self.update(variables)
        return self

    def fire(self,rule_name):
        rule = self.compiled_rules[rule_name]
        match = AttrDict()
        for name,pcache in rule.reactants.items():
            match[name] = AttrDict(pcache.sample())
        ax = ActionStack(self.cache,match)
        for action in rule.action_manager.exec(ax.match,rule.parameters,rule.reactants,rule.helpers,rule.factories):
            tokens = ax.put(action).do()
            self.net.process_tokens(tokens)
        variables = self.net.get_updated_variables()
        self.update(variables)
        return self
        
    def simulate(self,start=0.0,end=1.0,period=1,write_location=None):
        global CURRENT_ACTION_RECORD
        print(f'Simulating from time={start} to time={end}.')
        time = start
        sch = CoordinatedScheduler([
            NextReactionMethod(),
            PeriodicWriteObservables(
                start=start,
                period=period,
                write_location=write_location
            )])
        sch.update(start,self.variables)
        while True:
            event,time = sch.pop()
            if event=='write_observables':
                print(f'Current time={time}\tNumber of objects={len(self.cache)}')
                self.write_observables(time,write_location)
                if time >= end:
                    break
            elif event in self.rules:
                try:
                    self.fire(event)
                except Exception as error:
                    print(f'While firing {event}, an error occurred.')
                    print(f'The following actions were being implemented: ')
                    for act in CURRENT_ACTION_RECORD:
                        print(act)
                    print(f'Simulation failed. Exiting.' )
                    raise error.with_traceback(sys.exc_info()[2])
                    
                sch.update(time,subdict(self.variables,self.variables.modified))
            self.variables.flush()
        print(f'Simulation complete. Exiting.')
        return self

    def write_observables(self,time,write_location):
        variables = [x.core for x in self.net.get_nodes(type='variable') if x.data.subtype=='recompute']
        write_location.append({'time':time,'observables':subdict(self.variables,variables)})
        return self