eagerx/core/rx_message_broker.py
# IMPORT RX
from typing import Any, TYPE_CHECKING
import rx.disposable
# IMPORT OTHER
from rx import Observable, create
from rx.disposable import Disposable
from termcolor import cprint
import types
from functools import wraps
from threading import Condition
# IMPORT EAGERX
from eagerx.core.constants import DEBUG
if TYPE_CHECKING:
from eagerx.core.entities import Backend # noqa: F401
def thread_safe_wrapper(func, condition):
@wraps(func)
def wrapped(*args, **kwargs):
with condition:
return func(*args, **kwargs)
return wrapped
class RxMessageBroker(object):
def __init__(self, owner, backend: "Backend"):
self.owner = owner
self.backend = backend
# Determine log_level
self.effective_log_level = backend.log_level
# Ensure that we are not reading and writing at the same time.
self.cond = Condition()
# Structured as outputs[address][node_name] = {rx=Subject, node_name=node_name, source=RxOutput(...), etc..}
self.rx_connectable = dict()
# Structured as node_io[node_name][type][address] = {rx=Subject, disposable=rx_disposable, etc..}
self.node_io = dict()
self.disconnected = dict()
self.connected_bnd = dict()
self.connected_rx = dict()
# All publishers and subscribers (grouped to unregister when shutting down)
self._publishers = []
self.subscribers = []
self.disposables = []
# Every method is wrapped in a 'with Condition' block in order to be threadsafe
def __getattribute__(self, name):
attr = super(RxMessageBroker, self).__getattribute__(name)
if isinstance(attr, types.MethodType):
attr = thread_safe_wrapper(attr, self.cond)
return attr
def add_rx_objects(
self,
node_name,
node=None,
inputs=tuple(),
outputs=tuple(),
feedthrough=tuple(),
state_inputs=tuple(),
state_outputs=tuple(),
targets=tuple(),
node_inputs=tuple(),
node_outputs=tuple(),
disposable: rx.disposable.CompositeDisposable = None,
):
# Add disposable
if disposable:
self.disposables.append(disposable)
# Determine tick address
if node:
ns = node.ns
else:
ns = self.node_io[node_name]["node"].ns
tick_address = ns + "/engine/outputs/tick"
# Only add outputs that we would like to link with rx (i.e., skipping backend serialization)
for i in outputs:
if i["address"] == tick_address:
continue
assert i["address"] not in self.rx_connectable, (
"Non-unique output (%s). All output names must be unique." % i["address"]
)
self.rx_connectable[i["address"]] = dict(rx=i["msg"], source=i, node_name=node_name, rate=i["rate"])
# Register all I/O of node
if node_name not in self.node_io:
assert node is not None, 'No reference to Node "%s" was provided, during the first attempt to register it.'
# Prepare io dictionaries
self.node_io[node_name] = dict(
node=node,
inputs={},
outputs={},
feedthrough={},
state_inputs={},
state_outputs={},
targets={},
node_inputs={},
node_outputs={},
)
self.disconnected[node_name] = dict(inputs={}, feedthrough={}, state_inputs={}, targets={}, node_inputs={})
self.connected_bnd[node_name] = dict(inputs={}, feedthrough={}, state_inputs={}, targets={}, node_inputs={})
self.connected_rx[node_name] = dict(inputs={}, feedthrough={}, state_inputs={}, targets={}, node_inputs={})
n = dict(
inputs={},
outputs={},
feedthrough={},
state_inputs={},
state_outputs={},
targets={},
node_inputs={},
node_outputs={},
)
for i in inputs:
address = i["address"]
cname_address = f"{i['name']}:{address}"
self._assert_already_registered(cname_address, self.node_io[node_name], "inputs")
self._assert_already_registered(cname_address, n, "inputs")
n["inputs"][cname_address] = {
"rx": i["msg"],
"disposable": None,
"source": i,
"dtype": i["dtype"],
"processor": i["processor"],
"window": i["window"],
"status": "disconnected",
}
n["inputs"][cname_address + "/reset"] = {
"rx": i["reset"],
"disposable": None,
"source": i,
"dtype": "int64",
"status": "disconnected",
}
for i in outputs:
address = i["address"]
cname_address = f"{i['name']}:{address}"
self._assert_already_registered(cname_address, self.node_io[node_name], "outputs")
self._assert_already_registered(cname_address, n, "outputs")
n["outputs"][cname_address] = {
"rx": i["msg"],
"disposable": None,
"source": i,
"dtype": i["dtype"],
"rate": i["rate"],
"processor": i["processor"],
"status": "",
}
n["outputs"][cname_address + "/reset"] = {
"rx": i["reset"],
"disposable": None,
"source": i,
"dtype": "int64",
"status": "",
}
# Create publisher
i["msg_pub"] = self.backend.Publisher(i["address"], i["dtype"])
d = i["msg"].subscribe(
on_next=i["msg_pub"].publish,
on_error=lambda e: print("Error : {0}".format(e)),
)
self.disposables.append(d)
self._publishers.append(i["msg_pub"])
i["reset_pub"] = self.backend.Publisher(i["address"] + "/reset", n["outputs"][cname_address + "/reset"]["dtype"])
d = i["reset"].subscribe(
on_next=i["reset_pub"].publish,
on_error=lambda e: print("Error : {0}".format(e)),
)
self.disposables.append(d)
self._publishers.append(i["reset_pub"])
for i in feedthrough:
address = i["address"]
cname_address = f"{i['feedthrough_to']}:{address}"
self._assert_already_registered(cname_address, self.node_io[node_name], "feedthrough")
self._assert_already_registered(cname_address, n, "feedthrough")
n["feedthrough"][cname_address] = {
"rx": i["msg"],
"disposable": None,
"source": i,
"dtype": i["dtype"],
"processor": i["processor"],
"window": i["window"],
"status": "disconnected",
}
n["feedthrough"][cname_address + "/reset"] = {
"rx": i["reset"],
"disposable": None,
"source": i,
"dtype": "int64",
"status": "disconnected",
}
for i in state_outputs:
address = i["address"]
cname_address = f"{i['name']}:{address}"
self._assert_already_registered(cname_address, self.node_io[node_name], "state_outputs")
self._assert_already_registered(cname_address, n, "state_outputs")
n["state_outputs"][cname_address] = {
"rx": i["msg"],
"disposable": None,
"source": i,
"dtype": i["dtype"],
"status": "",
}
if "processor" in i:
n["state_outputs"][cname_address]["processor"] = i["processor"]
# Create publisher
i["msg_pub"] = self.backend.Publisher(i["address"], i["dtype"])
d = i["msg"].subscribe(
on_next=i["msg_pub"].publish,
on_error=lambda e: print("Error : {0}".format(e)),
)
self.disposables.append(d)
self._publishers.append(i["msg_pub"])
for i in state_inputs:
address = i["address"]
try:
cname_address = f"{i['name']}:{address}"
except KeyError:
cname_address = f"done_flag:{address}"
if "msg" in i: # Only true if sim state node (i.e. **not** for engine done flags)
self._assert_already_registered(cname_address + "/set", self.node_io[node_name], "state_inputs")
self._assert_already_registered(cname_address + "/set", n, "state_inputs")
n["state_inputs"][cname_address + "/set"] = {
"rx": i["msg"],
"disposable": None,
"source": i,
"dtype": i["dtype"],
"processor": i["processor"],
"status": "disconnected",
}
# Only true if **not** a real reset node (i.e., sim state & engine done flag)
if (cname_address + "/done") not in n["state_outputs"].keys():
self._assert_already_registered(cname_address + "/done", self.node_io[node_name], "state_inputs")
self._assert_already_registered(cname_address + "/done", n, "state_inputs")
n["state_inputs"][cname_address + "/done"] = {
"rx": i["done"],
"disposable": None,
"source": i,
"dtype": "bool",
"status": "disconnected",
}
for i in targets:
address = i["address"]
cname_address = f"{i['name']}:{address}"
self._assert_already_registered(cname_address + "/set", self.node_io[node_name], "targets")
self._assert_already_registered(cname_address + "/set", n, "targets")
n["targets"][cname_address + "/set"] = {
"rx": i["msg"],
"disposable": None,
"source": i,
"dtype": i["dtype"],
"processor": i["processor"],
"status": "disconnected",
}
for i in node_inputs:
address = i["address"]
cname_address = f"{i['name']}:{address}"
self._assert_already_registered(cname_address, self.node_io[node_name], "node_inputs")
self._assert_already_registered(cname_address, n, "node_inputs")
n["node_inputs"][cname_address] = {
"rx": i["msg"],
"disposable": None,
"source": i,
"dtype": i["dtype"],
"status": "disconnected",
}
for i in node_outputs:
address = i["address"]
cname_address = f"{i['name']}:{address}"
self._assert_already_registered(cname_address, self.node_io[node_name], "node_outputs")
self._assert_already_registered(cname_address, n, "node_outputs")
n["node_outputs"][cname_address] = {
"rx": i["msg"],
"disposable": None,
"source": i,
"dtype": i["dtype"],
"status": "",
}
# Create publisher: (latched: register, node_reset, start_reset, reset, real_reset)
i["msg_pub"] = self.backend.Publisher(i["address"], i["dtype"])
d = i["msg"].subscribe(
on_next=i["msg_pub"].publish,
on_error=lambda e: print("Error : {0}".format(e)),
)
self.disposables.append(d)
self._publishers.append(i["msg_pub"])
# Add new addresses to already registered I/Os
for key in n.keys():
self.node_io[node_name][key].update(n[key])
# Add new addresses to disconnected
for key in ("inputs", "feedthrough", "state_inputs", "targets", "node_inputs"):
self.disconnected[node_name][key].update(n[key].copy())
def print_io_status(self, node_names=None):
# Only print status for specific node
if node_names is None:
node_names = self.node_io.keys()
else:
if isinstance(node_names, str):
node_names = [node_names]
# Print status
for node_name in node_names:
cprint(
('OWNER "%s"' % self.owner).ljust(15, " ") + ('| OVERVIEW NODE "%s" ' % node_name).ljust(180, " "),
attrs=["bold", "underline"],
)
for key in (
"inputs",
"feedthrough",
"state_inputs",
"targets",
"node_inputs",
"outputs",
"state_outputs",
"node_outputs",
):
if len(self.node_io[node_name][key]) == 0:
continue
for cname_address in self.node_io[node_name][key].keys():
color = None
if key in (
"outputs",
"node_outputs",
"state_outputs",
):
color = "cyan"
else:
if cname_address in self.disconnected[node_name][key]:
color = "red"
if cname_address in self.connected_rx[node_name][key]:
assert color is None, f"Duplicate connection status for address ({cname_address})."
color = "green"
if cname_address in self.connected_bnd[node_name][key]:
assert color is None, f"Duplicate connection status for address ({cname_address})."
color = "blue"
assert (
color is not None
), "Address (cname_address) not found in self.(disconnected, connected_rx, connected_bnd)."
status = self.node_io[node_name][key][cname_address]["status"]
# Print status
entry = self.node_io[node_name][key][cname_address]
key_str = ("%s" % key).ljust(15, " ")
address_str = ("| %s " % cname_address).ljust(50, " ")
dtype_str = ("| %s " % entry["dtype"]).ljust(10, " ")
if "processor" in entry:
processor_str = ("| %s " % entry["processor"].__class__.__name__).ljust(23, " ")
else:
processor_str = ("| %s " % "").ljust(23, " ")
if "window" in entry:
window_str = ("| %s " % entry["window"]).ljust(8, " ")
else:
window_str = ("| %s " % "").ljust(8, " ")
if "rate" in entry:
rate_str = "|" + ("%s" % entry["rate"]).center(3, " ")
else:
rate_str = "|" + "".center(3, " ")
status_str = ("| %s" % status).ljust(60, " ")
log_msg = key_str + rate_str + address_str + dtype_str + processor_str + window_str + status_str
cprint(log_msg, color)
print(" ".center(140, " "))
def connect_io(self, print_status=True):
# If log_level is not high enough, overwrite print_status
if self.effective_log_level > DEBUG:
print_status = False
for node_name, node in self.disconnected.items():
# Skip if no disconnected addresses
num_disconnected = 0
for _key, addresses in node.items():
num_disconnected += len(addresses)
if num_disconnected == 0:
continue
# Else, initialize connection
print_status and cprint(
('OWNER "%s"' % self.owner).ljust(15, " ") + ('| CONNECTING NODE "%s" ' % node_name).ljust(180, " "),
attrs=["bold", "underline"],
)
for key, addresses in node.items():
for cname_address in list(addresses.keys()):
_, address = self._split_cname_address(cname_address)
entry = addresses[cname_address]
assert (
cname_address not in self.connected_rx[node_name][key]
), f"Address ({cname_address}) of this node ({node_name}) already connected via rx."
assert (
cname_address not in self.connected_bnd[node_name][key]
), f"Address ({cname_address}) of this node ({node_name}) already connected via backend."
if address in self.rx_connectable.keys():
color = "green"
status = "Rx".ljust(4, " ")
entry["rate"] = self.rx_connectable[address]["rate"]
rate_str = f"|{str(entry['rate']).center(3, ' ')}"
node_str = f'| {self.rx_connectable[address]["node_name"].ljust(40, " ")}'
dtype_str = f'| {self.rx_connectable[address]["source"]["dtype"]}'.ljust(12, " ")
processor_str = f'| {self.rx_connectable[address]["source"]["processor"].__class__.__name__}'.ljust(
12, " "
)
status += node_str + dtype_str + processor_str
self.connected_rx[node_name][key][cname_address] = entry
T = self.rx_connectable[address]["rx"]
else:
color = "blue"
status = f"{self.backend.BACKEND} |".ljust(5, " ")
rate_str = "|" + "".center(3, " ")
dtype = entry["dtype"]
self.connected_bnd[node_name][key][cname_address] = entry
T = from_topic(self.backend, dtype, address, node_name, self.subscribers)
# Subscribe and change status
entry["disposable"] = T.subscribe(entry["rx"])
self.disposables.append(entry["disposable"])
entry["status"] = status
# Print status
key_str = ("%s" % key).ljust(15, " ")
address_str = ("| %s " % cname_address).ljust(50, " ")
dtype_str = ("| %s " % entry["dtype"]).ljust(10, " ")
status_str = ("| Connected via %s" % status).ljust(60, " ")
if "processor" in entry:
processor_str = ("| %s " % entry["processor"].__class__.__name__).ljust(23, " ")
else:
processor_str = ("| %s " % "").ljust(23, " ")
if "window" in entry:
window_str = ("| %s " % entry["window"]).ljust(8, " ")
else:
window_str = ("| %s " % "").ljust(8, " ")
log_msg = key_str + rate_str + address_str + dtype_str + processor_str + window_str + status_str
print_status and cprint(log_msg, color)
# Remove address from disconnected
addresses.pop(cname_address)
print_status and print("".center(140, " "))
def _split_cname_address(self, cname_address):
res = cname_address.split(":")
if len(res) == 2:
cname, address = res
else:
cname, address = None, res[0]
return cname, address
def _assert_already_registered(self, name, d, component):
assert name not in d[component], f'Cannot re-register the same address ({name}) twice as "{component}".'
def shutdown(self):
self.backend.logdebug(f"[{self.owner}] RxMessageBroker.shutdown() called.")
[d.dispose() for d in self.disposables]
[pub.unregister() for pub in self._publishers]
[sub.unregister() for sub in self.subscribers]
def from_topic(bnd: "Backend", dtype: Any, address: str, node_name, subscribers: list) -> Observable:
def _subscribe(observer, scheduler=None) -> Disposable:
try:
def cb_from_topic(msg):
observer.on_next(msg)
sub = bnd.Subscriber(address, dtype, callback=cb_from_topic)
subscribers.append(sub)
except Exception as e:
bnd.logwarn("[%s]: %s" % (node_name, e))
raise e
return observer
return create(_subscribe)