# eagerx/core/rx_operators.py
# RX IMPORTS
import rx
from rx import Observable, typing, operators as ops
from rx.disposable import Disposable, SingleAssignmentDisposable, CompositeDisposable
from rx.subject import Subject, BehaviorSubject
from rx.internal.concurrency import synchronized
# EAGERX IMPORTS
import eagerx.utils.utils
from eagerx.core.constants import ( # noqa
SILENT,
DEBUG,
INFO,
ERROR,
WARN,
FATAL,
)
from eagerx.utils.utils import (
Info,
Msg,
Stamp,
)
# OTHER IMPORTS
import time
from math import floor
from collections import deque
from termcolor import colored
import datetime
import traceback
from os import getpid
from threading import current_thread, RLock
from typing import Callable, Optional, List, Any
import numpy as np
def cb_ft(cb_input, sync):
    # Fill output_msgs with the most recent message received per input
output_msgs = dict()
for key, msg in cb_input.items():
if key not in ["node_tick", "t_n"]:
if len(msg.msgs) > 0:
output_msgs[key] = msg.msgs[-1]
else:
                assert not sync, "Actions must always be fed through when running synchronized."
output_msgs[key] = None
return output_msgs
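# A minimal usage sketch of cb_ft (hypothetical input): with a single "action" input,
#   cb_ft({"action": Msg(info, [a0, a1]), "node_tick": 5, "t_n": 0.25}, sync=True)
# returns {"action": a1}; only the most recent message per input is fed through, and
# the "node_tick"/"t_n" bookkeeping keys are dropped.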
def print_info(
node_name,
color,
id=None,
trace_type=None,
value=None,
date=None,
    log_level=DEBUG,  # accepted for call-site symmetry; not used for filtering within print_info
):
msg = ""
if date:
msg += f"[{str(date)[:40].ljust(20)}]"
# Add process ID
msg += f"[{str(getpid())[:5].ljust(5)}]"
# Add thread ID
msg += f"[{current_thread().name.split('/')[-1][:15].ljust(15)}]"
msg += f"[{node_name.split('/')[-1][:12].ljust(12)}]"
if id:
msg += f"[{id.split('/')[-1][:12].ljust(12)}]"
msg += f" {trace_type}: {value}\n"
print(colored(msg, color), end="")
def spy(id: str, node, log_level: int = DEBUG, mapper: Callable = lambda msg: msg):
node_name = node.ns_name
color = node.color
effective_log_level = node.backend.log_level
def _spy(source):
def subscribe(observer, scheduler=None):
def on_next(value):
if node.log_level >= effective_log_level and log_level >= effective_log_level:
print_info(
node_name,
color,
id,
trace_type="",
value=str(mapper(value)),
log_level=log_level,
)
observer.on_next(value)
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _spy
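# A minimal usage sketch of spy (assuming `node` exposes ns_name, color, and backend
# as used above): each emission is printed via print_info, subject to the log-level
# checks, and then passed through unchanged:
#   stream = stream.pipe(spy("my_id", node, log_level=DEBUG, mapper=lambda m: m.info))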
def trace_observable(
id: str,
node,
trace_next=False,
trace_next_payload=False,
trace_subscribe=False,
date=None,
): # pragma: no cover
node_name = node.ns_name
color = node.color
def _trace(source):
def on_subscribe(observer, scheduler):
def on_next(value):
if trace_next is True:
if trace_next_payload is True:
print_info(
node_name,
color,
id,
"on_next",
value,
date=date or datetime.datetime.now(),
log_level=DEBUG,
)
else:
print_info(
node_name,
color,
id,
"on_next",
"",
date=date or datetime.datetime.now(),
log_level=DEBUG,
)
observer.on_next(value)
def on_completed():
value = ""
print_info(
node_name,
color,
id,
"on_completed",
value,
date=date or datetime.datetime.now(),
log_level=DEBUG,
)
observer.on_completed()
def on_error(error):
if isinstance(error, Exception):
                    error_traceback = "%s, %s" % (
                        error,
                        "".join(traceback.format_tb(error.__traceback__)),
                    )
print_info(
node_name,
color,
id,
"on_error",
error_traceback,
date=date or datetime.datetime.now(),
)
else:
print_info(
node_name,
color,
id,
"on_error",
error,
date=date or datetime.datetime.now(),
log_level=ERROR,
)
observer.on_error(error)
def dispose():
if trace_subscribe is True:
value = ""
print_info(
node_name,
color,
id,
"dispose",
value,
date=date or datetime.datetime.now(),
log_level=DEBUG,
)
disposable.dispose()
if trace_subscribe is True:
value = ""
print_info(
node_name,
color,
id,
"on_subscribe",
value,
date=date or datetime.datetime.now(),
log_level=DEBUG,
)
disposable = source.subscribe(
on_next=on_next,
on_error=on_error,
on_completed=on_completed,
)
return Disposable(dispose)
return rx.create(on_subscribe)
return _trace
def flag_dict(name):
def _init_flag_dict(source):
def subscribe(observer, scheduler=None):
def on_next(value):
                d = {name: value}
                observer.on_next(d)
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _init_flag_dict
def filter_dict():
def _filter_dict(source):
def subscribe(observer, scheduler=None):
def on_next(value):
d = dict()
for node_name, flag in value.items():
if flag is True:
continue
d[node_name] = flag
observer.on_next(d)
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _filter_dict
def switch_to_reset():
def _switch_to_reset(source):
def subscribe(observer, scheduler=None):
reset_mode = [False]
def on_next(value):
if isinstance(value, bool):
reset_mode[0] = True # if we receive a reset flag, turn on reset mode
observer.on_next(value)
else:
if not reset_mode[0]: # Check if we haven't previously received a reset message
observer.on_next(value)
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _switch_to_reset
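# Marble-style sketch of switch_to_reset: bool emissions act as reset flags. The flag
# itself passes through, but once it is seen, subsequent non-bool values are suppressed:
#   in:  1  2  True  3  4
#   out: 1  2  True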
def create_msg_tuple(name: str, node_tick: int, msg: List[Any], stamp: List[Stamp], done: Optional[bool] = None):
info = Info(name=name, node_tick=node_tick, t_in=stamp, done=done)
return Msg(info, msg)
def remap_state(name, sync, real_time_factor):
def _remap_state(source):
def subscribe(observer, scheduler=None):
start = time.time()
seq = [0]
def on_next(value):
node_tick = value[0][0]
msg = value[0][1]
done = value[1]
wc = time.time()
if sync:
sc = None
else:
sc = (wc - start) / real_time_factor
stamp = Stamp(seq[0], sc, wc)
res = create_msg_tuple(name, node_tick, [msg], [stamp], done=done)
seq[0] += 1
observer.on_next(res)
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _remap_state
def remap_target(name, sync, real_time_factor):
def _remap_target(source):
def subscribe(observer, scheduler=None):
start = time.time()
seq = [0]
def on_next(value):
node_tick = value[0]
msg = value[1]
wc = time.time()
if sync:
sc = None
else:
sc = (wc - start) / real_time_factor
stamp = Stamp(seq[0], sc, wc)
res = create_msg_tuple(name, node_tick, [msg], [stamp])
seq[0] += 1
observer.on_next(res)
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _remap_target
def filter_info_for_printing(info):
info_dict = dict()
if info.rate_in:
info_dict["rate_in"] = info.rate_in
info_dict["t_in"] = [t.sc for t in info.t_in]
if info.t_node:
info_dict["t_node"] = [t.sc for t in info.t_node]
if info.done:
info_dict["done"] = info.done
return info_dict
def remap_cb_input(mode=0):
def _remap_cb_input(value):
# mode=0 (info only), mode=1 (msgs only), mode=2 (all)
if mode == 2:
return value
if isinstance(value, tuple):
mapped_value = tuple([value[0].copy(), value[1].copy()])
for i in mapped_value:
for key, msg in i.items():
if key not in ["node_tick", "t_n"]:
if mode == 0:
i[key] = filter_info_for_printing(msg.info)
else:
i[key] = msg.msgs
else:
mapped_value = value.copy()
for key, msg in mapped_value.items():
if key not in ["node_tick", "t_n"]:
if mode == 0:
mapped_value[key] = filter_info_for_printing(msg.info)
else:
mapped_value[key] = msg.msgs
return mapped_value
return _remap_cb_input
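# For example (hypothetical input), with mode=0 a callback input such as
#   {"obs": Msg(info, msgs), "node_tick": 3, "t_n": 1.5}
# is mapped to {"obs": filter_info_for_printing(info), "node_tick": 3, "t_n": 1.5},
# i.e. only printable timing info is kept; mode=1 keeps the raw msgs instead.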
def regroup_inputs(node, rate_node=1, is_input=True, perform_checks=True):
node_name = node.ns_name
color = node.color
def _regroup_inputs(source):
def subscribe(observer, scheduler=None):
def on_next(value):
# Regroups all inputs into a single dict
if is_input:
node_tick = value[0].info.node_tick
t_n = round(node_tick / rate_node, 12)
res = dict(node_tick=node_tick, t_n=t_n)
else:
res = dict()
for msg in value:
res[msg.info.name] = msg
# Perform checks
if perform_checks and is_input:
node_ticks = []
for msg in value:
node_ticks.append(msg.info.node_tick)
if len(node_ticks) > 0:
if not len(set(node_ticks)) == 1:
print_info(
node_name,
color,
"regroup_inputs",
trace_type="",
value="Not all node_ticks are the same: %s" % str(value),
log_level=ERROR,
)
# Send regrouped input
observer.on_next(res)
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _regroup_inputs
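# For example (hypothetical input): with rate_node=2, two Msg tuples whose
# info.node_tick=4 are regrouped into {"node_tick": 4, "t_n": 2.0, "obs1": msg1,
# "obs2": msg2}, where t_n = node_tick / rate_node.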
def expected_inputs(N_out, rate_in, rate_out, delay, skip: bool):
    # If skip=True and the output rate is an exact multiple of the input rate, add a slight offset to break the tie
eps = 1e-9
# Constants
offset = int(skip) * (int((rate_out - eps) / rate_in) if rate_out > rate_in else -1)
# Alternative numerically unstable implementation
# dt_out, dt_in = 1 / rate_out, 1 / rate_in
# t_prev = dt_out * (N_out - 1 + offset) - delay
# t = dt_out * (N_out + offset) - delay
# N_in_prev = int(t_prev / dt_in)
# N_in = int(t / dt_in)
# delta = N_in - N_in_prev
# j = (delta - 1 + int(not skip))
# T = dt_out * N_out - delay - j * dt_in
# correction = ceil(-T / dt_in)
# Numerically stable implementation
N_in_prev = int((rate_in * (N_out + offset - 1) - rate_out * rate_in * delay) // rate_out)
N_in = int((rate_in * (N_out + offset) - rate_out * rate_in * delay) // rate_out)
# Alternative (iterative) delay correction
# delta = N_in - N_in_prev
# for i in range(int(not skip), delta + int(not skip)):
# # Alternative numerically unstable implementation
# # t_in_delayed = dt_out * N_out - delay - i * dt_in
# # Numerically stable implementation
# t_in_delayed = (N_out * rate_in - delay * rate_out * rate_in - i * rate_out) / rate_in
# if t_in_delayed < 0:
# delta -= 1
# Alternative delay correction
delta = N_in - N_in_prev
j = delta - 1 + int(not skip)
# Numerically stable implementation
T = (rate_in * N_out - rate_out * rate_in * delay - rate_out * j) / (rate_out * rate_in)
correction = -floor(T * rate_in)
    corrected = delta - min(delta, max(0, correction))  # clamp correction to the range [0, delta]
# Overwrite t=0 dependencies, because there are none when skipping.
num_est = corrected if N_out > 0 else int(not skip)
return num_est
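# Worked example (skip=False, delay=0): with rate_in=20 Hz and rate_out=10 Hz, inputs
# arrive at t = 0, 0.05, 0.1, ... and output tick N_out fires at t = N_out / 10, so:
#   expected_inputs(0, 20, 10, 0.0, False) == 1  # only the t=0 message
#   expected_inputs(1, 20, 10, 0.0, False) == 2  # t=0.05 and t=0.1
#   expected_inputs(2, 20, 10, 0.0, False) == 2  # t=0.15 and t=0.2
# With skip=True, the t=0 message is skipped, so tick 0 expects 0 messages.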
def generate_msgs(
source_Nc: Observable,
rate_node: float,
name: str,
rate_in: float,
params: dict,
sync: bool,
real_time_factor: float,
simulate_delays: bool,
node=None,
):
dt_i = 1 / rate_in
def _generate_msgs(source_msg: Observable):
window = params["window"]
skip = int(params["skip"])
def subscribe(observer: typing.Observer, scheduler: Optional[typing.Scheduler] = None) -> CompositeDisposable:
start = time.time()
msgs_queue: List = []
t_i_queue: List = []
num_queue: List = []
tick_queue: List = []
msgs_window = deque(maxlen=window)
t_i_window = deque(maxlen=window)
t_n_window = deque(maxlen=window)
lock = RLock()
@synchronized(lock)
def next(i):
if len(tick_queue) > 0:
if not sync or len(msgs_queue) >= num_queue[0]:
try:
tick = tick_queue.pop(0)
if sync:
# determine num_msgs
num_msgs = num_queue.pop(0)
msgs = msgs_queue[:num_msgs]
t_i = t_i_queue[:num_msgs]
msgs_queue[:] = msgs_queue[num_msgs:]
t_i_queue[:] = t_i_queue[num_msgs:]
else: # Empty complete buffer
msgs = msgs_queue.copy()
t_i = t_i_queue.copy()
msgs_queue[:] = []
t_i_queue[:] = []
except Exception as ex: # pylint: disable=broad-except
observer.on_error(ex)
return
# Determine t_n stamp
wc = time.time()
seq = tick
if sync:
sc = round(tick / rate_node, 12)
else:
sc = (wc - start) / real_time_factor
t_n = Stamp(seq, sc, wc)
if window > 0:
msgs_window.extend(msgs)
t_i_window.extend(t_i)
t_n_window.extend([t_n] * len(msgs))
wmsgs = list(msgs_window)
wt_i = list(t_i_window)
wt_n = list(t_n_window)
else:
wmsgs = msgs
wt_i = t_i
wt_n = [t_n] * len(msgs)
res = Msg(Info(name, tick, rate_in, wt_n, wt_i, None), wmsgs)
observer.on_next(res)
# Determine Nc logic
def on_next_Nc(x):
if sync:
                # Calculate expected number of messages to be received
delay = params["delay"] if simulate_delays else 0.0
num_msgs = expected_inputs(x, rate_in, rate_node, delay, bool(skip))
num_queue.append(num_msgs)
tick_queue.append(x)
next(x)
subscriptions = []
sad = SingleAssignmentDisposable()
sad.disposable = source_Nc.subscribe(on_next_Nc, observer.on_error, observer.on_completed, scheduler)
subscriptions.append(sad)
def on_next_msg(x):
msgs_queue.append(x[1])
wc = time.time()
seq = x[0]
if sync:
sc = round(x[0] * dt_i, 12)
else:
sc = (wc - start) / real_time_factor
t_i_queue.append(Stamp(seq, sc, wc))
next(x)
sad = SingleAssignmentDisposable()
if not sync and simulate_delays:
source_msg_delayed = source_msg.pipe(ops.delay(params["delay"] / real_time_factor))
else:
source_msg_delayed = source_msg
sad.disposable = source_msg_delayed.subscribe(on_next_msg, observer.on_error, observer.on_completed, scheduler)
subscriptions.append(sad)
return CompositeDisposable(subscriptions)
return rx.create(subscribe)
return _generate_msgs
def create_channel(
ns,
Nc,
rate_node,
inpt,
sync,
real_time_factor,
simulate_delays,
E,
scheduler,
is_feedthrough,
node,
):
if is_feedthrough:
name = inpt["feedthrough_to"]
else:
name = inpt["name"]
# Readable format
Is = inpt["reset"]
Ir = inpt["msg"].pipe(
convert(inpt["space"], inpt["processor"], name, "inputs", node, direction="in"),
# ops.combine_latest(E), # Throttle with end_reset
# ops.map(lambda x: x[0]), # Only pass through message
ops.observe_on(scheduler),
ops.scan(lambda acc, x: (acc[0] + 1, x), (-1, None)),
ops.share(),
)
    # Get rate from the backend parameter server
rate_str = "%s/rate/%s" % (ns, inpt["address"][len(ns) + 1 :])
rate = eagerx.utils.utils.get_param_with_blocking(rate_str, node.backend)
# Create input channel
if real_time_factor == 0:
Nc = Nc.pipe(ops.observe_on(scheduler), ops.start_with(0))
else:
Nc = Nc.pipe(ops.observe_on(scheduler))
channel = Ir.pipe(
generate_msgs(
Nc,
rate_node,
name,
rate,
params=inpt,
sync=sync,
real_time_factor=real_time_factor,
simulate_delays=simulate_delays,
node=node,
),
ops.share(),
)
# Create reset flag
flag = Ir.pipe(
ops.map(lambda val: val[0] + 1),
ops.start_with(0),
        ops.combine_latest(Is),  # Is emits the expected message count at reset
ops.filter(lambda value: not sync or value[0] == value[1]),
ops.map(lambda x: {name: x[0]}),
)
return channel, flag
def init_channels(
ns,
Nc,
rate_node,
inputs,
sync,
real_time_factor,
simulate_delays,
E,
scheduler,
node,
is_feedthrough=False,
):
# Create channels
channels = []
flags = []
for i in inputs:
channel, flag = create_channel(
ns,
Nc,
rate_node,
i,
sync,
real_time_factor,
simulate_delays,
E,
scheduler,
is_feedthrough,
node,
)
channels.append(channel)
if is_feedthrough:
name = i["address"]
else:
name = i["name"]
flag_name = "flag [%s]" % name.split("/")[-1][:12].ljust(4)
# flag.pipe(spy(f'sub_{flag_name}', node)).subscribe(print)
flag = flag.pipe(spy(flag_name, node)) # , ops.take(1), ops.merge(rx.never()))
flags.append(flag)
zipped_flags = rx.zip(*flags).pipe(ops.map(lambda x: merge_dicts({}, x)))
zipped_channels = rx.zip(*channels).pipe(
ops.combine_latest(
E.pipe(ops.observe_on(scheduler))
), # Latch output on '/end_reset' --> Can only receive 1 each episode.
ops.map(lambda x: x[0]),
regroup_inputs(node, rate_node=rate_node),
ops.share(),
)
return zipped_channels, zipped_flags
def init_real_reset(
ns,
Nc,
rate_node,
RR,
real_reset,
feedthrough,
targets,
sync,
real_time_factor,
simulate_delays,
E,
scheduler,
node,
):
# Create real reset pipeline
dispose = []
if real_reset:
for i in feedthrough:
rate_str = "%s/rate/%s" % (ns, i["address"][len(ns) + 1 :])
rate_in = eagerx.utils.utils.get_param_with_blocking(rate_str, node.backend)
            if rate_in != rate_node:
                raise ValueError(
                    "Rate of the reset node (%s) must exactly match the feedthrough rate (%s)." % (rate_node, rate_in)
                )
# Create zipped action channel
zipped_channels, zipped_flags = init_channels(
ns,
Nc,
rate_node,
feedthrough,
sync,
real_time_factor,
simulate_delays,
E,
scheduler,
node,
is_feedthrough=True,
)
# Create switch subject
target_signal = rx.zip(*[t["msg"] for t in targets])
RR_ho = BehaviorSubject(zipped_channels)
d_RR_ho = RR.pipe(
ops.combine_latest(target_signal), # make switch logic wait for all targets to be received.
ops.map(lambda x: Nc.pipe(ops.map(lambda x: None), ops.start_with(None))),
).subscribe(RR_ho)
rr_channel = RR_ho.pipe(ops.switch_latest())
# Add disposables
dispose += [RR_ho, d_RR_ho]
else:
        # No real reset: emit None on the reset channel and start with no flags
zipped_flags = rx.never().pipe(ops.start_with({}))
rr_channel = Nc.pipe(ops.map(lambda x: None), ops.start_with(None))
return rr_channel, zipped_flags, dispose
def init_target_channel(states, scheduler, node):
channels = []
for s in states:
c = s["msg"].pipe(
convert(s["space"], s["processor"], s["name"], "targets", node, direction="in"),
ops.share(),
ops.scan(lambda acc, x: (acc[0] + 1, x), (-1, None)),
remap_target(s["name"], node.sync, node.real_time_factor),
)
channels.append(c)
    # HACK: we sometimes receive the targets twice, and it is unclear whether the first
    # arrival is the new target or the old one; as implemented, we risk reusing a target twice.
return rx.zip(*channels).pipe(regroup_inputs(node, is_input=False), ops.take(1))
def merge_dicts(dict_1, dict_2):
if isinstance(dict_2, dict):
dict_2 = (dict_2,)
for d in dict_2:
dict_1.update(d)
return dict_1
def init_state_inputs_channel(ns, state_inputs, scheduler, node):
if len(state_inputs) > 0:
channels = []
for s in state_inputs:
d = s["done"].pipe(
ops.scan(lambda acc, x: x if x else acc, False),
)
c = s["msg"].pipe(
convert(s["space"], s["processor"], s["name"], "states", node, direction="in"),
ops.share(),
ops.scan(lambda acc, x: (acc[0] + 1, x), (-1, None)),
ops.start_with((-1, None)),
ops.combine_latest(d),
ops.filter(lambda x: x[0][0] >= 0 or x[1]),
remap_state(s["name"], node.sync, node.real_time_factor),
)
channels.append(c)
return rx.zip(*channels).pipe(regroup_inputs(node, is_input=False), ops.merge(rx.never()))
else:
return rx.never().pipe(ops.start_with(dict()))
def call_state_reset(state):
def _call_state_reset(source):
def subscribe(observer, scheduler=None):
def on_next(state_msg):
try:
state.reset(state=state_msg.msgs[0])
except Exception as e:
observer.on_error(e)
observer.on_next(state_msg)
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _call_state_reset
def init_state_resets(ns, state_inputs, trigger, scheduler, tp_scheduler, node):
if len(state_inputs) > 0:
channels = []
for s in state_inputs:
d = s["done"].pipe(
ops.scan(lambda acc, x: x if x else acc, False),
)
c = s["msg"].pipe(
convert(s["space"], s["processor"], s["name"], "states", node, direction="in"),
ops.share(),
ops.scan(lambda acc, x: (acc[0] + 1, x), (-1, None)),
ops.start_with((-1, None)),
ops.combine_latest(d),
ops.filter(lambda x: x[0][0] >= 0 or x[1]),
remap_state(s["name"], node.sync, node.real_time_factor),
)
done, reset = trigger.pipe(
with_latest_from(c),
ops.take(1),
ops.merge(rx.never()),
ops.map(lambda x: x[1]),
ops.partition(lambda x: x.info.done),
)
reset = reset.pipe(ops.observe_on(tp_scheduler), call_state_reset(s["state"]), ops.observe_on(scheduler))
rs = rx.merge(
done.pipe(spy("done [%s]" % s["name"].split("/")[-1][:12].ljust(4), node)),
reset.pipe(spy("reset [%s]" % s["name"].split("/")[-1][:12].ljust(4), node)),
)
channels.append(rs)
return rx.zip(*channels).pipe(regroup_inputs(node, is_input=False), ops.merge(rx.never()))
else:
return rx.never().pipe(ops.start_with(dict()))
def init_callback_pipeline(
ns,
cb_tick,
cb_ft,
stream,
real_reset,
targets,
state_outputs,
outputs,
scheduler,
node,
):
d_msg = []
if real_reset:
target_stream = init_target_channel(targets, scheduler, node)
# Split stream into feedthrough (ft) and reset stream
reset_stream, ft_stream = stream.pipe(ops.partition(lambda x: x[1][1] is None))
# Either feedthrough action or run callback
ft_stream = ft_stream.pipe(
ops.map(lambda x: x[1][1]),
spy("CB_FT", node, log_level=DEBUG, mapper=remap_cb_input(mode=0)),
ops.map(lambda val: cb_ft(val, node.sync)),
ops.share(),
)
reset_stream = reset_stream.pipe(
ops.map(lambda x: x[1][0]),
ops.combine_latest(target_stream),
spy("CB_RESET", node, log_level=DEBUG, mapper=remap_cb_input(mode=0)),
ops.map(lambda val: cb_tick(**val[0], **val[1])),
ops.share(),
)
output_stream = rx.merge(reset_stream, ft_stream)
# Send done flags
for s in state_outputs:
d = reset_stream.pipe(
spy("reset", node, log_level=DEBUG),
ops.pluck(s["name"] + "/done"),
ops.share(),
).subscribe(s["msg"])
# Add disposable
d_msg += [d]
else:
output_stream = stream.pipe(
ops.filter(lambda x: x[1][1] is None),
ops.map(lambda x: x[1][0]),
spy("CB_TICK", node, log_level=DEBUG, mapper=remap_cb_input(mode=0)),
ops.map(lambda val: cb_tick(**val)),
ops.share(),
)
return d_msg, output_stream
def switch_with_check_pipeline(init_ho=None):
if init_ho is None:
stream_ho = Subject()
else:
stream_ho = BehaviorSubject(init_ho)
check, stream = stream_ho.pipe(ops.switch_latest(), ops.partition(lambda event: event is None))
return check, stream, stream_ho
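# A minimal usage sketch: push a (new) observable into `stream_ho` to hot-swap the
# active stream. None emissions surface on `check` (used as a readiness signal
# elsewhere in this module); all other emissions surface on `stream`:
#   check, stream, stream_ho = switch_with_check_pipeline()
#   stream_ho.on_next(some_observable)  # `some_observable` is hypothetical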
def node_reset_flags(ns, node_flags, node):
flags = [
nf["msg"].pipe(
flag_dict(nf["name"]),
spy("has_reset", node),
ops.start_with({nf["name"]: False}),
)
for nf in node_flags
]
init_dict = dict()
for nf in node_flags:
init_dict[nf["name"]] = False
stream = rx.merge(*flags).pipe(
ops.scan(lambda acc, x: merge_dicts(acc, x), init_dict),
        ops.filter(lambda x: any(x.values())),
filter_dict(),
spy("awaiting", node, log_level=DEBUG),
ops.filter(lambda x: len(x) == 0),
ops.start_with(None),
)
return stream
def filter_dict_on_key(key):
def _filter_dict_on_key(source):
def subscribe(observer, scheduler=None):
def on_next(value):
if key in value:
observer.on_next(value[key])
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _filter_dict_on_key
def throttle_with_time(dt, node, rate_tol: float = 0.95, log_level: int = DEBUG):
time_fn = lambda: time.monotonic_ns() / 1e9 # noqa: E731
node_name = node.ns_name
color = node.color
effective_log_level = node.backend.log_level
log_time = 2 # [s]
def _throttle_with_time(source):
def subscribe(observer, scheduler=None):
# Timing
tic = [None]
cum_cbs = [0]
cum_delay = [0]
cum_sleep = [0]
# Logging
last_cum_cbs = [0]
last_cum_delay = [0]
last_cum_sleep = [0]
last_Nc = [0]
last_time = [None]
def on_next(value):
Nc, start = value
if Nc == 0: # Do not throttle before the first callback
# NOTE: This is the first time we receive a value
sleep_time = 0.0
else: # Determine sleep time
assert tic[0] is not None, "tic is None"
toc = time_fn()
dt_comp = toc - tic[0]
sleep_time = dt - dt_comp # if sleep_time > 0 then we are early, if sleep_time < 0 then we are late
# Throttle callback
if sleep_time > 0: # Sleep if we are early
time.sleep(sleep_time)
cum_sleep[0] += sleep_time
                else:  # If we are overdue, proceed without sleeping
cum_delay[0] += -sleep_time
cum_cbs[0] += 1
tic[0] = start + Nc * dt
# node.backend.loginfo(colored(f"[{node_name}] Nc: {Nc} | toc: {toc} | sleep_time: {sleep_time} | tic[0]: {tic[0]}", color))
# Logging
curr = time_fn()
last_time[0] = last_time[0] if last_time[0] is not None else curr
if (curr - last_time[0]) > log_time:
# Calculate statistics since last logged instance
log_window = curr - last_time[0]
log_cbs = cum_cbs[0] - last_cum_cbs[0]
# log_delay = cum_delay[0] - last_cum_delay[0]
log_sleep = cum_sleep[0] - last_cum_sleep[0]
log_Nc = Nc - last_Nc[0]
# Log statistics if not keeping rate
Nc_expected = log_window / dt # [ticks]
rate_ratio = log_Nc / Nc_expected
cbs_ratio = log_cbs / log_Nc
sleep_ratio = log_sleep / log_window
# delay_ratio = log_delay / log_window
                    # Build the statistics line once; the two branches differ only in color and log level.
                    print_str = (
                        f"Running at {rate_ratio*100:.2f}% of rate ({1/dt} Hz) | "
                        f"{sleep_ratio*100:.2f}% sleep | {100 - sleep_ratio*100:.2f}% computation | "
                        f"{cbs_ratio*100:.2f}% callbacks delayed |"
                    )
                    if rate_ratio < rate_tol and node.log_level >= effective_log_level and WARN >= effective_log_level:
                        print_info(
                            node_name,
                            "red",
                            f"last {log_window:.2f} s",
                            trace_type="",
                            value=print_str,
                            log_level=INFO,
                        )
                    elif node.log_level >= effective_log_level and log_level >= effective_log_level:
                        print_info(
                            node_name,
                            color,
                            f"last {log_window:.2f} s",
                            trace_type="",
                            value=print_str,
                            log_level=log_level,
                        )
# Set baseline statistics
last_time[0] = curr
last_cum_cbs[0] = cum_cbs[0]
last_cum_delay[0] = cum_delay[0]
last_cum_sleep[0] = cum_sleep[0]
last_Nc[0] = Nc
# Send tick for next callback
observer.on_next(Nc)
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _throttle_with_time
def throttle_callback_trigger(rate_node, Nc, E, sync, real_time_factor, scheduler, node):
# return Nc
if sync and real_time_factor == 0:
Nct = Nc
else:
assert (
real_time_factor > 0
), "The real_time_factor must be larger than zero when *not* running reactive (i.e. asychronous)."
wc_dt = 1 / (rate_node * real_time_factor)
Nct = Nc.pipe(
ops.scan(lambda acc, x: acc + 1, 0),
ops.start_with(0),
ops.combine_latest(E), # .pipe(spy("Nct_E", node, log_level=INFO))),
ops.observe_on(scheduler),
# spy("Nct_E_CBL", node, log_level=INFO),
throttle_with_time(wc_dt, node),
ops.share(),
)
return Nct
def with_latest_from(*sources: Observable):
def _with_latest_from(parent: Observable) -> Observable:
NO_VALUE = NotSet()
def subscribe(observer, scheduler=None):
def subscribe_all(parent, *children):
                parent_queued = [NO_VALUE]  # NO_VALUE distinguishes "nothing queued" from a queued None value
values = [NO_VALUE for _ in children]
def subscribe_child(i, child):
subscription = SingleAssignmentDisposable()
def on_next(value):
with parent.lock:
values[i] = value
                            if parent_queued[0] is not NO_VALUE and NO_VALUE not in values:
                                result = (parent_queued[0],) + tuple(values)
                                parent_queued[0] = NO_VALUE
observer.on_next(result)
subscription.disposable = child.subscribe_(on_next, observer.on_error, scheduler=scheduler)
return subscription
parent_subscription = SingleAssignmentDisposable()
def on_next(value):
with parent.lock:
if NO_VALUE not in values:
result = (value,) + tuple(values)
observer.on_next(result)
else:
parent_queued[0] = value
disp = parent.subscribe_(on_next, observer.on_error, observer.on_completed, scheduler)
parent_subscription.disposable = disp
children_subscription = [subscribe_child(i, child) for i, child in enumerate(children)]
return [parent_subscription] + children_subscription
return CompositeDisposable(subscribe_all(parent, *sources))
return Observable(subscribe)
return _with_latest_from
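# A minimal usage sketch (RxPY 3 Subjects): unlike ops.with_latest_from, a parent value
# that arrives before every child has emitted is queued (latest only) and released as
# soon as all children have a value:
#   parent, child = Subject(), Subject()
#   parent.pipe(with_latest_from(child)).subscribe(print)
#   parent.on_next(1)   # queued: child has not emitted yet
#   child.on_next("a")  # releases (1, "a")
#   parent.on_next(2)   # prints (2, "a") immediately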
class NotSet:
"""Sentinel value."""
def __eq__(self, other):
return self is other
def __repr__(self):
return "NotSet"
def convert(space: eagerx.Space, processor, name, component, node, direction="out"):
    OUTPUT = direction == "out"
    INPUT = direction == "in"
space_checked = [False]
p_msg = f" (after processing with `{processor.__class__.__qualname__}`)" if processor else ""
assert isinstance(space, eagerx.Space), f"The space of '{name}' is not of type eagerx.Space."
def _convert(source):
def subscribe(observer, scheduler=None):
def on_next(recv):
if INPUT:
# Preprocess input message
if processor is not None:
recv = processor.convert(recv)
# Check if message complies with space (after conversion)
if not space_checked[0]:
space_checked[0] = True
if not space.contains(np.array(recv)):
shape_msg = f"(msg.shape={recv.shape} vs space.shape={space.shape})"
dtype_msg = f"(msg.dtype={recv.dtype} vs space.dtype={space.dtype})"
msg = (
f"[subscriber][{node.ns_name}][{component}][{name}]: Message{p_msg} does not match the defined space. "
f"Either a mismatch in expected shape {shape_msg}, dtype {dtype_msg}, and/or the value is out of bounds (low/high)."
)
node.backend.logwarn_once(msg, identifier=f"[{node.ns_name}][{name}]")
observer.on_next(recv)
elif OUTPUT:
# Convert python native types to numpy arrays.
if isinstance(recv, float):
recv = np.array(recv, dtype="float32")
elif isinstance(recv, bool):
recv = np.array(recv, dtype="bool")
elif isinstance(recv, int):
recv = np.array(recv, dtype="int64")
# Process message
if processor is not None:
recv = processor.convert(recv)
if not space_checked[0]:
space_checked[0] = True
if not space.contains(np.array(recv)):
shape_msg = f"(msg.shape={recv.shape} vs space.shape={space.shape})"
dtype_msg = f"(msg.dtype={recv.dtype} vs space.dtype={space.dtype})"
msg = (
f"[publisher][{node.ns_name}][{component}][{name}]: Message{p_msg} does not match the defined space. "
f"Either a mismatch in expected shape {shape_msg}, dtype {dtype_msg}, and/or the value is out of bounds (low/high)."
)
node.backend.logwarn_once(msg, identifier=f"[{node.ns_name}][{name}]")
observer.on_next(recv)
else:
raise NotImplementedError(f"Direction not implemented: {direction}.")
return source.subscribe(on_next, observer.on_error, observer.on_completed, scheduler)
return rx.create(subscribe)
return _convert
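# A minimal usage sketch of convert (the Space constructor arguments below are an
# assumption; consult eagerx.Space for the exact signature): the first message on the
# stream is validated against the space, and a mismatch is logged once via
# node.backend.logwarn_once:
#   space = eagerx.Space(low=-1.0, high=1.0, shape=(2,), dtype="float32")
#   stream = stream.pipe(convert(space, None, "obs", "inputs", node, direction="in"))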