eager-dev/eagerx

View on GitHub
eagerx/backends/single_process.py

Summary

Maintainability
A
3 hrs
Test Coverage
A
92%
import copy
import numpy as np
import typing
import time
import threading
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import cpu_count
from eagerx.utils.utils import Header
from eagerx.core.pubsub import Publisher, Subscriber, ShutdownService
import eagerx
from eagerx.core.constants import (
    WARN,
    BackendException,
    Unspecified,
)


# Module-level sentinel used as the default for ``SingleProcess.get_param``'s
# ``default`` argument, so that ``None`` remains a legitimate explicit default.
_unspecified: Unspecified = Unspecified()


def merge(a: typing.Dict, b: typing.Dict, path=None):
    """Recursively merge ``b`` into ``a`` in place and return ``a``.

    Entries that are dictionaries in both ``a`` and ``b`` are merged key by key.
    Every other entry of ``b`` (a leaf, or a dict replacing a non-dict and vice
    versa) overwrites the corresponding entry of ``a``.

    :param a: Destination dictionary; mutated in place.
    :param b: Source dictionary whose entries take precedence.
    :param path: Key path of ``a`` relative to the outermost dictionary. It is
        only threaded through the recursion and kept for backward compatibility.
    :return: ``a`` after the merge.
    """
    if path is None:
        path = []
    for key, value in b.items():
        if key in a and isinstance(a[key], dict) and isinstance(value, dict):
            merge(a[key], value, path + [str(key)])
        else:
            # Overwrite leaves unconditionally. Comparing them first
            # (``a[key] == value``) bought nothing and raised ValueError for
            # numpy arrays, whose ``==`` is element-wise.
            a[key] = value
    return a


def split(a: typing.Any):
    """Expand ``/``-separated dictionary keys into nested dictionaries, in place.

    ``{"a/b": 1, "c": {"d/e": 2}}`` becomes ``{"a": {"b": 1}, "c": {"d": {"e": 2}}}``.
    Empty path segments (leading, trailing or repeated slashes) are ignored, and
    values are expanded recursively. Non-string keys are kept unchanged (their
    values are still expanded). When an expanded key collides with an existing
    entry, the two are combined with ``merge``: nested dictionaries are merged,
    otherwise the value being inserted overwrites the existing one.

    A string key consisting only of slashes (e.g. ``"/"``) raises ``IndexError``.

    :param a: Value to expand; dictionaries are mutated in place.
    :return: ``a`` itself (non-dict values are returned unchanged).
    """
    if not isinstance(a, dict):
        return a
    for key in list(a):
        value = split(a.pop(key))
        if isinstance(key, str):
            parts = [part for part in key.split("/") if part]
        else:
            # Non-string keys (e.g. ints) carry no path; ``key.split`` used to
            # raise AttributeError here.
            parts = [key]
        head = parts[0]
        for part in reversed(parts[1:]):
            value = {part: value}
        if head in a:
            # Merge one level up so a leaf colliding with a dict (or vice versa)
            # is overwritten instead of crashing inside ``merge``.
            merge(a, {head: value})
        else:
            a[head] = value
    return a


class SingleProcess(eagerx.Backend):
    """Backend that runs every node inside the current Python process.

    Topics are kept in an in-memory dictionary and messages are delivered by
    submitting subscriber callbacks to a shared thread pool. Parameters are
    stored in a nested dictionary addressed with ``/``-separated names.
    """

    BACKEND = "SINGLE_PROCESS"
    DISTRIBUTED_SUPPORT = False
    MULTIPROCESSING_SUPPORT = False
    COLAB_SUPPORT = True

    # Lower bound on the number of worker threads that dispatch subscriber callbacks.
    MIN_THREADS = 10

    @classmethod
    def make(cls, log_level=WARN) -> eagerx.specs.BackendSpec:
        """Create a specification for this backend.

        :param log_level: Requested log level (see the note below).
        :return: The backend specification with ``config.log_level`` set.
        """
        spec = cls.get_specification()
        # NOTE(review): only *string* values of ``log_level`` are kept; anything
        # else -- possibly including the default ``WARN`` if that constant is an
        # int -- is replaced by ``eagerx.get_log_level()``. Confirm whether this
        # check was meant to be ``isinstance(log_level, int)``.
        spec.config.log_level = log_level if isinstance(log_level, str) else eagerx.get_log_level()
        return spec

    def initialize(self):
        """Set up the in-memory parameter server, topic registry and thread pool."""
        self._backend = self
        # Nested dict holding all parameters (see upload_params / get_param).
        self._pserver = dict()
        # address -> dict(pubs=<publisher count>, subs=[callbacks],
        #                 latched=(msg, header) of the last publish or None, dtype=<str>)
        self._topics = dict()
        # Held by publishers/subscribers while they (un)register on ``self._topics``.
        self._cond = threading.Condition()
        self._tpool = ThreadPoolExecutor(max_workers=max(self.MIN_THREADS, cpu_count()))

    def Publisher(self, address: str, dtype: str):
        """Create a publisher on ``address`` that delivers messages in-process."""
        return _Publisher(self._backend, self._tpool, self._topics, self._cond, address, dtype)

    def Subscriber(self, address: str, dtype: str, callback, header: bool = False, callback_args=tuple()):
        """Create a subscriber on ``address``.

        ``callback`` receives the message, then the header if ``header`` is true,
        followed by ``callback_args``.
        """
        return _Subscriber(
            self._backend, self._topics, self._cond, address, dtype, callback, header, callback_args=callback_args
        )

    def register_environment(self, name: str, force_start: bool, fn: typing.Callable):
        """Return a no-op shutdown service; the arguments are not used by this backend."""
        return _ShutdownService()

    def _resolve_param(self, keys: typing.List[str]) -> typing.Any:
        """Follow ``keys`` from the parameter root and return the value found there.

        :raises KeyError: if a key is missing or an intermediate value is not a dict.
        """
        node = self._pserver
        for key in keys:
            # Descending into a leaf (int, str, list, array, ...) is treated as
            # "not found" instead of leaking TypeError/IndexError/AttributeError.
            if not isinstance(node, dict) or key not in node:
                raise KeyError(key)
            node = node[key]
        return node

    def delete_param(self, param: str, level: int = 1) -> None:
        """Delete the parameter (or whole sub-tree) stored under ``param``.

        :param param: ``/``-separated parameter name.
        :param level: Behaviour when ``param`` does not exist: ``0`` raises
            ``BackendException``, ``1`` logs a warning, any other value is silent.
        """
        try:
            keys = [k for k in param.split("/") if len(k) > 0]
            if not keys:
                # An empty name is treated as "not found" rather than as the root;
                # it previously raised an unhandled IndexError regardless of ``level``.
                raise KeyError(param)
            parent = self._resolve_param(keys[:-1])
            if not isinstance(parent, dict):
                raise KeyError(keys[-1])
            parent.pop(keys[-1])
            self.loginfo(f'Parameters under namespace "{param}" deleted.')
        except KeyError as e:
            if level == 0:
                raise BackendException(e) from e
            elif level == 1:
                self.logwarn(e)

    def upload_params(self, ns: str, values: typing.Dict, verbose: bool = False) -> None:
        """Merge ``values`` into the parameter server under namespace ``ns``.

        ``values`` is deep-copied (the caller's dict is not mutated), its
        ``/``-separated keys are expanded into nested dicts, and the result is
        merged into the server, overwriting existing leaves.

        :param ns: ``/``-separated namespace, e.g. ``"/env/node"``.
        :param values: Parameters to upload.
        :param verbose: Not used by this backend.
        """
        ns_keys = [k for k in ns.split("/") if len(k) > 0]
        tree = split(copy.deepcopy(values))
        # Wrap innermost-first so "a/b" yields {"a": {"b": tree}}. Wrapping in
        # forward order produced the reversed path {"b": {"a": tree}}, which
        # get_param("a/b/...") could not find.
        for key in reversed(ns_keys):
            tree = {key: tree}
        merge(self._pserver, tree)

    def get_param(self, name: str, default: typing.Any = _unspecified):
        """Return the parameter (or sub-tree) stored under ``name``.

        An empty name returns the whole parameter tree. The returned object is
        the stored one, not a copy.

        :param name: ``/``-separated parameter name.
        :param default: Returned when ``name`` does not exist. If omitted, a
            missing name raises ``BackendException``.
        """
        keys = [k for k in name.split("/") if len(k) > 0]
        try:
            return self._resolve_param(keys)
        except KeyError as e:
            if not isinstance(default, Unspecified):
                return default
            raise BackendException(e) from e

    def spin(self):
        """Not supported: this backend has no separate processes to spin."""
        raise NotImplementedError(f"Not implemented, because backend '{self.BACKEND}' does not support multiprocessing.")

    def shutdown(self) -> None:
        """Shut down once, waiting for queued callbacks in the thread pool to finish."""
        # ``_has_shutdown`` is not set in this class; presumably initialised by
        # eagerx.Backend -- TODO confirm.
        if not self._has_shutdown:
            self.logdebug("Backend.shutdown() called.")
            self._has_shutdown = True
            self._tpool.shutdown(wait=True)


class _Publisher(Publisher):
    """Publisher that hands each message to in-process subscriber callbacks.

    Subscriber callbacks registered on the topic are submitted to the backend's
    shared thread pool. The last ``(msg, header)`` is stored on the topic as
    ``latched`` so that subscribers created later receive it on registration.
    """

    def __init__(
        self,
        backend: eagerx.Backend,
        tpool: ThreadPoolExecutor,
        topics: typing.Dict,
        cond: threading.Condition,
        address: str,
        dtype: str,
    ):
        super().__init__(backend)
        self._tpool = tpool
        self._cond = cond
        self._topics = topics
        with self._cond:
            # Reuse an existing topic (created by another publisher or a
            # subscriber) or create it; all participants must agree on dtype.
            if address not in topics:
                self._topic = dict(pubs=0, subs=[], latched=None, dtype=dtype)
                topics[address] = self._topic
            else:
                self._topic = topics[address]
                assert self._topic["dtype"] == dtype, f"Dtypes do not match for topic {address}."

            # Increase publisher count
            self._topic["pubs"] += 1

            self._address = address
            self._dtype = dtype
            self._name = f"{self._address}"
            self._unregistered = False

    def _publish(self, msg: typing.Union[float, bool, int, str, np.ndarray, np.number], header: Header) -> None:
        """Dispatch ``msg`` to all current subscribers and latch it on the topic.

        Messages published after ``unregister()`` are dropped.

        :raises TypeError: if ``msg`` is not an ndarray, numpy number, str or bool
            after converting Python floats/ints to numpy arrays.
        """
        if self._unregistered:
            return
        # todo: check if dtype(msg) == self._dtype?
        # Convert python native types to numpy arrays (bool is an int subclass
        # and is deliberately left as-is).
        if isinstance(msg, float):
            msg = np.array(msg, dtype="float32")
        elif isinstance(msg, int) and not isinstance(msg, bool):
            msg = np.array(msg, dtype="int64")

        # Check if message complies with space
        if not isinstance(msg, (np.ndarray, np.number, str, bool)):
            self._backend.logerr(f"[publisher][{self._name}]: type(recv)={type(msg)}")
            # Fail loudly: the former ``time.sleep(10000000)`` froze this thread for ~115 days.
            raise TypeError(f"[publisher][{self._name}]: unsupported message type {type(msg)}.")

        # Snapshot the subscribers and update the latch under the same condition
        # that subscriber registration holds while reading ``latched`` and
        # appending to ``subs``. Without it, a subscriber registering between the
        # dispatch loop and the latch update would miss this message.
        with self._cond:
            subscribers = list(self._topic["subs"])
            self._topic["latched"] = msg, header
        for cb in subscribers:
            self._tpool.submit(cb, msg, header)

    def unregister(self) -> None:
        """Stop publishing and drop the topic once it has no publishers or subscribers."""
        if not self._unregistered:
            with self._cond:
                self._unregistered = True
                assert self._topic["pubs"] > 0, "According to the counter, there should be no publishers left for this topic."
                self._topic["pubs"] -= 1

                # If no other subscribers or publishers, remove topic.
                if len(self._topic["subs"]) == 0 and self._topic["pubs"] == 0:
                    self._topics.pop(self._address)


class _Subscriber(Subscriber):
    """Subscriber whose ``callback`` is invoked in-process by publishers on the same topic.

    On registration, the topic's latched ``(msg, header)`` (if any) is delivered
    immediately and synchronously.
    """

    def __init__(
        self,
        backend: eagerx.Backend,
        topics: typing.Dict,
        cond: threading.Condition,
        address: str,
        dtype: str,
        callback,
        header: bool,
        callback_args=tuple(),
    ):
        super().__init__(backend, header)
        self._cond = cond
        self._topics = topics
        # Fully initialise the instance *before* it becomes reachable through the
        # topic: a publisher may dispatch ``callback`` on a pool thread as soon as
        # it is appended to ``subs``.
        self._unregistered = False
        self._address = address
        self._dtype = dtype
        self._cb_wrapped = callback
        self._cb_args = callback_args
        self._name = f"{self._address}"
        # Keep a single bound-method object: every ``self.callback`` access
        # creates a new one, so ``unregister`` must compare against this instance.
        self._registered_cb = self.callback
        with self._cond:
            if address not in topics:
                self._topic = dict(pubs=0, subs=[], latched=None, dtype=dtype)
                topics[address] = self._topic
                latched = None
            else:
                self._topic = topics[address]
                assert self._topic["dtype"] == dtype, f"Dtypes do not match for topic {address}."
                latched = self._topic["latched"]
            self._topic["subs"].append(self._registered_cb)

        if latched is not None:
            self._backend.logdebug(f"LATCHED: {self._address}")
            self.callback(*latched)  # todo: inside cond?

    def callback(self, msg, header):
        """Forward ``msg`` (and ``header`` if requested) to the wrapped callback.

        Messages arriving after ``unregister()`` are ignored.

        :raises TypeError: if ``msg`` is not an ndarray, numpy number, str or bool.
        """
        if self._unregistered:
            return
        # todo: check if dtype(msg) == self._dtype?
        if not isinstance(msg, (np.ndarray, np.number, str, bool)):
            self._backend.logerr(f"[subscriber][{self._name}]: type(recv)={type(msg)}")
            # Fail instead of the former ``time.sleep(10000000)``, which blocked a
            # shared pool worker essentially forever.
            raise TypeError(f"[subscriber][{self._name}]: unsupported message type {type(msg)}.")
        # ``self._header`` is presumably set by the Subscriber base from ``header`` -- TODO confirm.
        if self._header:
            self._cb_wrapped(msg, header, *self._cb_args)
        else:
            self._cb_wrapped(msg, *self._cb_args)

    def unregister(self) -> None:
        """Remove this subscriber and drop the topic once nobody uses it."""
        if self._unregistered:
            return
        with self._cond:
            self._unregistered = True
            # Identity against the stored bound method. The old
            # ``id(cb) == id(self.callback)`` compared with a fresh bound-method
            # object, never matched, and so never removed anything.
            self._topic["subs"] = [cb for cb in self._topic["subs"] if cb is not self._registered_cb]

            # If no other subscribers or publishers, remove topic.
            if len(self._topic["subs"]) == 0 and self._topic["pubs"] == 0:
                self._topics.pop(self._address)


class _ShutdownService(ShutdownService):
    """No-op shutdown service handed out by ``SingleProcess.register_environment``.

    Its constructor does not call the ``ShutdownService`` base constructor.
    """

    def __init__(self):
        """Create the service; there is nothing to register."""

    def unregister(self):
        """Do nothing: there is no registration to undo."""