eagerx/core/rx_pipelines.py
import numpy as np
import time
# RX IMPORTS
import rx
from rx import operators as ops
from rx.disposable import CompositeDisposable
from rx.scheduler import EventLoopScheduler, ThreadPoolScheduler
from rx.subject import ReplaySubject, Subject, BehaviorSubject
# EAGERX IMPORTS
import eagerx
from eagerx.core.constants import DEBUG
from eagerx.core.rx_operators import (
cb_ft,
spy,
trace_observable,
flag_dict,
switch_to_reset,
init_channels,
init_real_reset,
merge_dicts,
init_state_inputs_channel,
init_state_resets,
init_callback_pipeline,
switch_with_check_pipeline,
node_reset_flags,
filter_dict_on_key,
throttle_callback_trigger,
with_latest_from,
convert,
)
def init_node_pipeline(
    ns,
    rate_node,
    node,
    inputs,
    outputs,
    F,
    SS_ho,
    SS_CL_ho,
    R,
    RR,
    E,
    real_reset,
    feedthrough,
    state_inputs,
    state_outputs,
    targets,
    cb_ft,
    sync,
    real_time_factor,
    simulate_delays,
    disposables,
    event_scheduler=None,
):
    """Build the reactive pipeline of a regular node for a single episode.

    ``init_node`` calls this every time a reset is triggered. The function:

    - wires the input and feedthrough channels into ``node.callback_cb``;
    - publishes the callback outputs;
    - counts started/completed callbacks;
    - emits on ``Rn`` once the node is ready to be reset.

    Before storing its own disposable, it calls ``disposables.clear()``. This
    disposes whatever pipeline was stored there before.

    :param ns: Namespace, forwarded to the channel/callback helpers.
    :param rate_node: Rate of the node, used to throttle the callback trigger.
    :param node: Node object. ``node.callback_cb`` is used as the callback.
    :param inputs: Input dicts. Their channels are zipped into the callback input.
    :param outputs: Output dicts. ``o["msg"]`` receives the converted callback output ``o["name"]``.
    :param F: Subject that receives the merged input/feedthrough reset flags.
    :param SS_ho: Higher-order subject. Receives an observable of the first state-input flag,
        prefixed with ``None``.
    :param SS_CL_ho: Higher-order subject. Receives an observable of all state-input flags,
        prefixed with ``None``.
    :param R: Reset signal stream. Every emission is mapped to ``True`` and treated as a reset.
    :param RR: Real-reset stream, forwarded to ``init_real_reset``.
    :param E: End-reset stream, forwarded to the throttle and channel helpers.
    :param real_reset: True when the node has feedthroughs (``init_node`` sets it that way).
    :param feedthrough: Feedthrough dicts, forwarded to ``init_real_reset``.
    :param state_inputs: State-input dicts used to build the state reset flags.
    :param state_outputs: Forwarded to ``init_callback_pipeline``.
    :param targets: Target dicts, forwarded to ``init_real_reset`` and ``init_callback_pipeline``.
    :param cb_ft: Callback forwarded to ``init_callback_pipeline``. It shadows the
        module-level import of the same name.
    :param sync: The node's ``sync`` setting, forwarded to the helpers.
    :param real_time_factor: Real-time factor, forwarded to the throttle/channel helpers.
    :param simulate_delays: Forwarded to the channel helpers.
    :param disposables: CompositeDisposable that holds the previous episode pipeline. It is
        cleared, then filled with this pipeline's disposable.
    :param event_scheduler: Scheduler on which the pipeline observes/subscribes.
    :returns: dict with ``"Rn"`` (ReplaySubject that emits one tuple per episode) and
        ``"dispose"`` (CompositeDisposable of this pipeline).
    """
    # Node ticks
    Rn = ReplaySubject() # Reset flag for the node (Nc=Ns and r_signal)
    Nc = Subject() # Number of completed callbacks (i.e. sent topics).
    Ns = BehaviorSubject(0) # Number of started callbacks (i.e. number of planned Topics).
    # Throttle the callback trigger (based on the number of completed callbacks).
    Nct = throttle_callback_trigger(rate_node, Nc, E, sync, real_time_factor, event_scheduler, node)
    # Create input channels
    zipped_inputs, zipped_input_flags = init_channels(
        ns,
        Nct,
        rate_node,
        inputs,
        sync,
        real_time_factor,
        simulate_delays,
        E,
        node=node,
        scheduler=event_scheduler,
    )
    # Create action channels
    zipped_feedthrough, zipped_action_flags, d_rr = init_real_reset(
        ns,
        Nct,
        rate_node,
        RR,
        real_reset,
        feedthrough,
        targets,
        sync,
        real_time_factor,
        simulate_delays,
        E,
        event_scheduler,
        node=node,
    )
    # Zip inputs & action channels
    zipped_channels = rx.zip(zipped_inputs, zipped_feedthrough).pipe(
        ops.share(), ops.observe_on(event_scheduler)
    ) # observe_on is required here; without it the pipeline blocks.
    # Merge the reset signal (mapped to True) with the zipped channel messages, then pass
    # them through switch_to_reset().
    PR = R.pipe(
        ops.observe_on(event_scheduler),
        ops.map(lambda x: True),
        ops.merge(zipped_channels),
        switch_to_reset(),
        ops.share(),
    )
    # Split the merged stream: booleans are reset signals (Rr), everything else is channel data (P).
    Rr, P = PR.pipe(ops.partition(lambda value: isinstance(value, bool)))
    # Count every channel message that starts a callback and push the count into Ns.
    d_Ns = P.pipe(ops.scan(lambda acc, x: acc + 1, 0)).subscribe(Ns)
    # Pair each channel message with its running count (skip the initial 0 of Ns).
    input_stream = Ns.pipe(ops.skip(1), ops.zip(P), ops.share())
    d_msg, output_stream = init_callback_pipeline(
        ns,
        node.callback_cb,
        cb_ft,
        input_stream,
        real_reset,
        targets,
        state_outputs,
        outputs,
        event_scheduler,
        node=node,
    )
    # Publish output msg as ROS topic and to subjects if single process
    for o in outputs:
        d = output_stream.pipe(
            ops.filter(lambda x: x is not None),  # skip callbacks that produced no output
            ops.pluck(o["name"]),
            ops.filter(lambda x: x is not None),
            convert(o["space"], o["processor"], o["name"], "outputs", node, direction="out"),
            ops.share(),
        ).subscribe(o["msg"])
        # Add disposable
        d_msg += [d]
    # Count completed callbacks (Nc_obs), and callbacks whose output was None (Nc_empty).
    Nc_obs = output_stream.pipe(ops.scan(lambda acc, x: acc + 1, 0))
    Nc_empty = output_stream.pipe(ops.scan(lambda acc, x: acc + 1 if x is None else acc, 0), ops.start_with(0))
    # Increase ticks
    d_Nc = Nc_obs.subscribe(Nc, scheduler=event_scheduler)
    # Rn emits a single tuple (completed - empty, Ns, Rr, Nc_empty) once a reset signal (Rr)
    # has been received and completed == started, both corrected for empty callbacks.
    # take(1) stops after that emission; merge(rx.never()) keeps Rn from completing.
    d_Rn = Nc_obs.pipe(
        ops.start_with(0), # emit an initial 0, mimicking a BehaviorSubject(0) for Nc
        ops.zip(Nc_empty),
        ops.map(lambda x: x[0] - x[1]),
        ops.combine_latest(Ns, Rr, Nc_empty),
        ops.filter(lambda value: value[0] == value[1] - value[3]),
        ops.take(1),
        ops.merge(rx.never()),
        # spy('post-filter', node),
    ).subscribe(Rn)
    # Create reset flags for the set_states and hand them to the switch_latest subjects.
    ss_flags = init_state_inputs_channel(ns, state_inputs, event_scheduler, node=node)
    SS_ho.on_next(ss_flags.pipe(ops.take(1), ops.start_with(None)))
    SS_CL_ho.on_next(ss_flags.pipe(ops.start_with(None)))
    # Combine flags and subscribe F to all the flags zipped together
    d_flag = (
        rx.zip(zipped_input_flags, zipped_action_flags)
        .pipe(ops.map(lambda x: merge_dicts(x[0], x[1])))
        .subscribe(F, scheduler=event_scheduler)
    )
    # Dispose the previous episode pipeline and register this one.
    disposables.clear()
    d = CompositeDisposable([Rn, Nc, Ns, d_Rn, d_Nc, d_Ns, d_flag] + d_msg + d_rr)
    disposables.add(d)
    return {"Rn": Rn, "dispose": d}
def init_node(
    ns,
    rate_node,
    node,
    inputs,
    outputs,
    feedthrough=tuple(),
    state_inputs=tuple(),
    targets=tuple(),
):
    """Initialize the reactive (rx) I/O of a regular node and wire its reset procedure.

    Subjects are created for every input, output, feedthrough, state-input and target
    topic. The dicts passed in are mutated in place: ``"msg"`` and ``"reset"`` or
    ``"done"`` entries are added.

    On every reset, the episode pipeline is rebuilt with :func:`init_node_pipeline`.
    Then ``node.reset_cb`` is called and ``True`` is published on the node's
    ``<ns_name>/end_reset`` topic.

    :param ns: Namespace prefix of the reset, real_reset and end_reset addresses.
    :param rate_node: Rate of the node, forwarded to :func:`init_node_pipeline`.
    :param node: Node object. Its ``sync``, ``real_time_factor``, ``simulate_delays``,
        ``ns_name`` and ``reset_cb`` are used here.
    :param inputs: Input dicts (``name`` key required).
    :param outputs: Output dicts.
    :param feedthrough: Feedthrough dicts (``address`` key required). If non-empty, the node
        is a real reset node.
    :param state_inputs: State-input dicts.
    :param targets: Target dicts (``name`` and ``address`` keys required). A ``done`` output
        is created for each.
    :returns: dict with the (mutated) topic collections, ``state_outputs`` (done flags of the
        targets), ``node_inputs``, ``node_outputs`` and the top-level ``disposable``.
    :raises AssertionError: if ``feedthrough`` is non-empty while ``targets`` is empty.
    """
    # Initialize scheduler
    event_scheduler = EventLoopScheduler()
    # eps_disp holds the current episode pipeline; it is nested in reset_disp so that
    # disposing reset_disp also disposes the running episode.
    eps_disp = CompositeDisposable()
    reset_disp = CompositeDisposable(eps_disp)
    # Gather reactive properties
    sync = node.sync
    real_time_factor = node.real_time_factor
    simulate_delays = node.simulate_delays
    # Track node I/O
    node_inputs = []
    node_outputs = []
    # Prepare reset topic
    R = Subject()
    reset_input = dict(name="reset", address=ns + "/reset", dtype="int64", msg=R)
    node_inputs.append(reset_input)
    # Prepare real_reset topic
    RR = Subject()
    real_reset_input = dict(name="real_reset", address=ns + "/real_reset", dtype="int64", msg=RR)
    node_inputs.append(real_reset_input)
    # End reset
    E = Subject()
    end_reset = dict(name="end_reset", address=ns + "/end_reset", dtype="float64", msg=E)
    node_inputs.append(end_reset)
    # Prepare node reset topic
    node_reset = dict(name=node.ns_name, dtype="bool")
    node_reset["address"] = node.ns_name + "/end_reset"
    node_reset["msg"] = Subject()
    node_outputs.append(node_reset)
    # Real reset checks: a node with feedthroughs (real reset node) needs at least one target.
    state_outputs = []
    real_reset = len(feedthrough) > 0
    assert not real_reset or len(targets) > 0, (
        "Cannot initialize real reset node (%s). If feedthroughs are provided, then len(targets) > 0 must hold."
        % node.ns_name
    )
    # Prepare input topics
    flags = []
    for i in inputs:
        # Subscribe to input topic
        Ir = Subject()
        i["msg"] = Ir
        # Subscribe to input reset topic
        Is = Subject()
        i["reset"] = Is
        # Prepare initial flag (first reset msg only; merge(never) keeps it from completing)
        flag = i["reset"].pipe(flag_dict(i["name"]), ops.first(), ops.merge(rx.never()))
        flags.append(flag)
    # Prepare output topics
    for i in outputs:
        # Prepare output topic
        i["msg"] = Subject()
        # Initialize reset topic
        i["reset"] = Subject()
    # Prepare action topics (used by RealResetNode)
    for i in feedthrough:
        # Subscribe to input topic
        Ar = Subject()
        i["msg"] = Ar
        # Subscribe to input reset topic
        As = Subject()
        i["reset"] = As
        # Prepare initial flag (feedthrough flags are keyed by address, not name)
        flag = i["reset"].pipe(flag_dict(i["address"]), ops.first(), ops.merge(rx.never()))
        flags.append(flag)
    # Prepare state topics
    for i in state_inputs:
        # Initialize desired state message
        S = Subject()
        i["msg"] = S
        D = Subject()
        i["done"] = D
    for i in targets:
        # Initialize target state message
        S = Subject()
        i["msg"] = S
        D = Subject()
        i["done"] = D
        # Initialize done flag for desired state
        done_output = dict(name=i["name"], address=i["address"] + "/done", dtype="bool", msg=D)
        state_outputs.append(done_output)
    # Prepare initial flags
    F_init = rx.zip(*flags).pipe(spy("zip", node), ops.map(lambda x: merge_dicts({}, x)))
    # Reset flags: F is fed by each episode pipeline, F_init covers the very first reset.
    F = Subject()
    f = rx.merge(F, F_init)
    ss_flags = init_state_inputs_channel(ns, state_inputs, event_scheduler, node=node)
    check_SS_CL, SS_CL, SS_CL_ho = switch_with_check_pipeline(init_ho=ss_flags)
    latched_ss_flags = rx.combine_latest(SS_CL, ss_flags).pipe(ops.map(lambda x: x[1]), ops.take(1))
    check_SS, SS, SS_ho = switch_with_check_pipeline(init_ho=latched_ss_flags)
    # Node ticks
    Rn_ho = BehaviorSubject(BehaviorSubject((0, 0, True)))
    Rn = Rn_ho.pipe(ops.switch_latest(), ops.map(lambda x: x[0]))
    # Create reset switch_latest
    Rr = R.pipe(ops.map(lambda x: True)) # ops.observe_on(event_scheduler), # seemed to make ROS variant crash
    # Prepare "ready-to-reset" signal (i.e. after node receives reset signal and stops sending any output msgs).
    RrRn = rx.zip(Rr.pipe(spy("Rr", node)), Rn.pipe(spy("Rn", node))).pipe(
        ops.map(lambda x: x[1]), spy("SEND_FLAGS", node), ops.share()
    )
    # Send output flags
    for i in outputs:
        d = RrRn.subscribe(i["reset"])
        reset_disp.add(d)
    # Reset node pipeline
    reset_trigger = rx.zip(RrRn, f.pipe(spy("F", node)), SS.pipe(spy("SS", node))).pipe(
        with_latest_from(SS_CL),
        ops.map(lambda x: x[0][:-1] + (x[1],)),
        spy("RENEW_PIPE", node),
        ops.map(lambda x: x[-1]),
        ops.share(),
    ) # x: SS_CL
    reset_obs = reset_trigger.pipe(
        ops.map(
            lambda x: init_node_pipeline(
                ns,
                rate_node,
                node,
                inputs,
                outputs,
                F,
                SS_ho,
                SS_CL_ho,
                R,
                RR,
                E,
                real_reset,
                feedthrough,
                state_inputs,
                state_outputs,
                targets,
                cb_ft,
                sync,
                real_time_factor,
                simulate_delays,
                eps_disp,
                event_scheduler=event_scheduler,
            )
        ),
        trace_observable("init_node_pipeline", node),
        ops.share(),
    )
    d = reset_obs.pipe(ops.pluck("Rn")).subscribe(Rn_ho)
    reset_disp.add(d)
    # Run the reset callback once per rebuilt pipeline. (The previous pipeline itself is
    # disposed inside init_node_pipeline through eps_disp.clear().)
    reset_msg = reset_obs.pipe(
        ops.pluck("dispose"),
        ops.buffer_with_count(2, skip=1),
        ops.start_with(None),
        # zipped with Rn_ho so that Rn has switched before sending "reset topic"
        ops.zip(reset_trigger, Rn_ho.pipe(ops.skip(1)), check_SS, check_SS_CL),
        ops.map(lambda x: x[1]), # x[1]=latest SS_CL value from reset_trigger (kwargs of reset_cb)
        ops.share(),
        spy("RESET", node, log_level=DEBUG),
        ops.map(lambda x: node.reset_cb(**x)),
        trace_observable("cb_reset", node),
        ops.share(),
    )
    # Send node reset message
    d = reset_msg.pipe(ops.map(lambda x: True)).subscribe(node_reset["msg"])
    reset_disp.add(d)
    rx_objects = dict(
        inputs=inputs,
        outputs=outputs,
        feedthrough=feedthrough,
        state_inputs=state_inputs,
        state_outputs=state_outputs,
        targets=targets,
        node_inputs=node_inputs,
        node_outputs=node_outputs,
        disposable=reset_disp,
    )
    return rx_objects
def init_engine_pipeline(
    ns,
    rate_node,
    node,
    zipped_channels,
    outputs,
    Nct_ho,
    DF,
    RRn_ho,
    SS_ho,
    SS_CL_ho,
    state_inputs,
    sync,
    real_time_factor,
    simulate_delays,
    E,
    disposables,
    event_scheduler=None,
):
    """Build the reactive pipeline of the engine node for a single episode.

    ``init_engine`` calls this each time its input channels have been (re)initialized.
    This differs from :func:`init_node_pipeline` in two ways:

    - The zipped input channels are created by the caller. This function hands its
      callback trigger ``Nct`` back through ``Nct_ho``.
    - The pipeline is cut off by the ``DF`` stream instead of a plain reset signal.

    Before storing its own disposable, it calls ``disposables.clear()``. This disposes
    the previously stored pipeline.

    :param ns: Namespace, forwarded to the helpers.
    :param rate_node: Rate of the engine, used to throttle the callback trigger.
    :param node: Engine node. ``node.callback_cb`` is used as the callback.
    :param zipped_channels: Observable of zipped input messages. Each item is wrapped as
        ``(item, None)`` to match the (inputs, feedthrough) shape of the node pipeline.
    :param outputs: Output dicts. ``o["msg"]`` receives the converted output ``o["name"]``.
    :param Nct_ho: Higher-order subject that receives this episode's callback trigger.
    :param DF: Combined stream of the real-reset signal and done flags. Only emissions whose
        elements are all truthy are treated as a reset signal.
    :param RRn_ho: Higher-order subject that receives this episode's ``RRn`` subject.
    :param SS_ho: Higher-order subject. Receives an observable of the first state-input flag,
        prefixed with ``None``.
    :param SS_CL_ho: Higher-order subject. Receives an observable of all state-input flags,
        prefixed with ``None``.
    :param state_inputs: State-input dicts used to build the state reset flags.
    :param sync: The engine's ``sync`` setting, forwarded to the throttle.
    :param real_time_factor: Real-time factor, forwarded to the throttle.
    :param simulate_delays: Not used in this function body.
    :param E: End-reset stream, forwarded to the throttle.
    :param disposables: CompositeDisposable that holds the previous episode pipeline. It is
        cleared, then filled with this pipeline's disposable.
    :param event_scheduler: Scheduler on which the pipeline observes/subscribes.
    :returns: dict with ``"dispose"`` (CompositeDisposable of this pipeline).
    """
    # Node ticks
    RRn = Subject()
    RRn_ho.on_next(RRn)
    Nc = Subject() # Number of completed callbacks (i.e. sent topics); the initial zero is injected via start_with below
    Ns = BehaviorSubject(0) # Number of started callbacks (i.e. number of planned Topics).
    # Throttle the callback trigger
    Nct = throttle_callback_trigger(rate_node, Nc, E, sync, real_time_factor, event_scheduler, node)
    Nct_ho.on_next(Nct)
    # Create a tuple with None, to be consistent with feedthrough pipeline of init_node_pipeline
    zipped_channels = zipped_channels.pipe(ops.map(lambda i: (i, None)))
    # Merge the reset signal (all DF flags truthy -> True) with the channel messages.
    PR = DF.pipe(
        ops.filter(lambda x: all([df for df in x])),
        spy("DF filtered", node),
        ops.observe_on(event_scheduler),
        ops.map(lambda x: True),
        ops.merge(zipped_channels),
        switch_to_reset(),
        ops.share(),
    )
    # Split the merged stream: booleans are reset signals (RRr), everything else is channel data (P).
    RRr, P = PR.pipe(ops.partition(lambda value: isinstance(value, bool)))
    # Count every channel message that starts a callback and push the count into Ns.
    d_Ns = P.pipe(ops.scan(lambda acc, x: acc + 1, 0)).subscribe(Ns)
    # Pair each channel message with its running count (skip the initial 0 of Ns).
    input_stream = Ns.pipe(ops.skip(1), ops.zip(P), ops.share())
    d_msg, output_stream = init_callback_pipeline(
        ns,
        node.callback_cb,
        cb_ft,
        input_stream,
        False,
        tuple(),
        tuple(),
        outputs,
        event_scheduler,
        node,
    )
    # Publish output msg as ROS topic and to subjects if single process
    for o in outputs:
        d = output_stream.pipe(
            ops.pluck(o["name"]),
            ops.filter(lambda x: x is not None),
            convert(o["space"], o["processor"], o["name"], "outputs", node, direction="out"),
            ops.share(),
        ).subscribe(o["msg"])
        # Add disposable
        d_msg += [d]
    # After outputs have been sent, increase the completed callback counter
    Nc_obs = output_stream.pipe(ops.scan(lambda acc, x: acc + 1, 0))
    # Increase ticks
    d_Nc = Nc_obs.subscribe(Nc, scheduler=event_scheduler)
    # RRn emits a single tuple (Nc, Ns, RRr) once a reset signal has been received and
    # completed == started. take(1) stops after that; merge(rx.never()) keeps RRn from completing.
    d_RRn = Nc_obs.pipe(
        ops.start_with(0), # emit an initial 0, mimicking a BehaviorSubject(0) for Nc
        ops.combine_latest(Ns, RRr),
        ops.filter(lambda value: value[0] == value[1]),
        ops.take(1),
        ops.merge(rx.never()),
    ).subscribe(RRn)
    # Create reset flags for the set_states and hand them to the switch_latest subjects.
    ss_flags = init_state_inputs_channel(ns, state_inputs, event_scheduler, node=node)
    SS_ho.on_next(ss_flags.pipe(ops.take(1), ops.start_with(None)))
    SS_CL_ho.on_next(ss_flags.pipe(ops.start_with(None)))
    # Dispose the previous episode pipeline and register this one.
    disposables.clear()
    d = CompositeDisposable([RRn, Nc, Ns, d_RRn, d_Nc, d_Ns] + d_msg)
    disposables.add(d)
    return {"dispose": d}
def init_engine(
    ns,
    rate_node,
    node,
    inputs_init,
    outputs,
    state_inputs,
    engine_state_inputs,
    node_names,
    target_addresses,
    message_broker,
):
    """Initialize the reactive (rx) I/O of the engine node and wire its reset procedure.

    The reset procedure is driven by three topics:

    1. ``start_reset``: the engine inputs, engine states and node flags are emitted, and the
       corresponding switch pipelines are updated. Then ``message_broker.connect_io()`` is
       called and ``0`` is published on ``end_register``.
    2. ``real_reset``: once the running episode pipeline signals ``RRn`` and the state flags
       are available, ``node.pre_reset_cb`` is called. Its ``Nc`` count is published on
       ``reset`` and on every output's ``reset`` subject.
    3. ``reset``: together with the input flags, this rebuilds the input channels and the
       episode pipeline (:func:`init_engine_pipeline`). Then ``node.reset_cb`` is called and
       ``end_reset`` is published with a ``time.monotonic_ns()`` timestamp in seconds.

    :param ns: Namespace prefix of all engine addresses.
    :param rate_node: Rate of the engine.
    :param node: Engine node. Its ``sync``, ``real_time_factor``, ``simulate_delays``,
        ``pre_reset_cb`` and ``reset_cb`` are used.
    :param inputs_init: Engine input dicts. ``"msg"`` and ``"reset"`` subjects are added in place.
    :param outputs: Output dicts. ``"msg"`` and ``"reset"`` subjects are added in place.
    :param state_inputs: State dicts. ``"msg"`` and ``"done"`` subjects are added in place.
    :param engine_state_inputs: Engine state dicts. ``"msg"`` and ``"done"`` subjects are
        added in place.
    :param node_names: Names whose ``<ns>/<name>/end_reset`` flags are awaited
        (``"env/supervisor"`` is skipped).
    :param target_addresses: Addresses, relative to ``ns``. A done subject is created for
        each and combined with the real-reset signal into ``DF``.
    :param message_broker: Object whose ``connect_io()`` is called after registration.
    :returns: dict with the topic collections, ``node_inputs``, ``node_outputs``, the
        combined ``state_inputs`` (state inputs, target done inputs, engine state inputs)
        and the top-level ``disposable``.
    """
    ###########################################################################
    # Initialization ##########################################################
    ###########################################################################
    # Prepare scheduler
    tp_scheduler = ThreadPoolScheduler(max_workers=max(5, len(engine_state_inputs)))
    event_scheduler = EventLoopScheduler()
    # eps_disp holds the current episode pipeline; nested in reset_disp so both are disposed together.
    eps_disp = CompositeDisposable()
    reset_disp = CompositeDisposable(eps_disp)
    # Gather reactive properties
    sync = node.sync
    real_time_factor = node.real_time_factor
    simulate_delays = node.simulate_delays
    # Prepare input topics
    for i in inputs_init:
        # Subscribe to input topic
        Ir = Subject()
        i["msg"] = Ir
        # Subscribe to input reset topic
        Is = Subject()
        i["reset"] = Is
    # Prepare output topics
    for i in outputs:
        # Prepare output topic
        i["msg"] = Subject()
        # Initialize reset topic
        i["reset"] = Subject()
    # Prepare state topics
    for i in state_inputs:
        # Initialize desired state message
        S = Subject()
        i["msg"] = S
        D = Subject()
        i["done"] = D
    # Prepare engine state topics
    for i in engine_state_inputs:
        # Initialize desired state message
        S = Subject()
        i["msg"] = S
        D = Subject()
        i["done"] = D
    ss_flags = init_state_inputs_channel(ns, state_inputs, event_scheduler, node=node)
    check_SS_CL, SS_CL, SS_CL_ho = switch_with_check_pipeline(init_ho=ss_flags)
    latched_ss_flags = rx.combine_latest(SS_CL, ss_flags).pipe(ops.map(lambda x: x[1]), ops.take(1))
    check_SS, SS, SS_ho = switch_with_check_pipeline(init_ho=latched_ss_flags)
    # Prepare target_addresses
    df_inputs = []
    dfs = []
    for i in target_addresses:
        address = "%s/%s" % (ns, i)
        done = Subject()
        df_inputs.append(dict(address=address, done=done))
        dfs.append(done)
    # Track node I/O
    node_inputs = []
    node_outputs = []
    ###########################################################################
    # Registry: node flags ####################################################
    ###########################################################################
    init_node_flags = []
    for i in node_names:
        if i == "env/supervisor":
            # todo: skipping is a HACK!
            # - node_flag:env/supervisor/end_reset is received in engine Subscriber
            # - However, the message does not come through in node_reset_flags(...)
            continue
        nf = dict(name=i, address="%s/%s/end_reset" % (ns, i), msg=Subject(), dtype="bool")
        node_inputs.append(nf)
        init_node_flags.append(nf)
    ###########################################################################
    # Registry: real reset states #############################################
    ###########################################################################
    # Resettable real states (to dynamically add state/done flags to DF for RR)
    RR = Subject()
    RRr = RR.pipe(ops.map(lambda x: True))
    # Switch to latest state reset done flags
    DF = rx.combine_latest(RRr, *dfs)
    ###########################################################################
    # Start reset #############################################################
    ###########################################################################
    # Prepare start_reset input
    SR = Subject()
    start_reset_input = dict(name="start_reset", address=ns + "/start_reset", msg=SR, dtype="int64")
    node_inputs.append(start_reset_input)
    # Latch on '/rx/start_reset' event
    # todo: do not dynamically initialize
    rx_objects = SR.pipe(
        spy("SR", node, log_level=DEBUG),
        ops.map(
            lambda x: dict(
                inputs=list(inputs_init),
                sp_nodes=[],
                launch_nodes=[],
                state_inputs=list(engine_state_inputs),
                node_flags=init_node_flags,
            )
        ),
        ops.share(),
    )
    inputs = rx_objects.pipe(ops.pluck("inputs"))
    simstate_inputs = rx_objects.pipe(ops.pluck("state_inputs"))
    node_flags = rx_objects.pipe(ops.pluck("node_flags"))
    # Prepare output for reactive proxy
    RM = Subject()
    # Zip initial input flags
    check_F_init, F_init, F_init_ho = switch_with_check_pipeline()
    F_init = F_init.pipe(ops.first(), ops.merge(rx.never()))
    d = inputs.pipe(
        ops.map(
            lambda inputs: rx.zip(*[i["reset"].pipe(flag_dict(i["name"])) for i in inputs]).pipe(
                ops.map(lambda x: merge_dicts({}, x)), ops.start_with(None)
            )
        )
    ).subscribe(F_init_ho)
    reset_disp.add(d)
    F = Subject()
    f = rx.merge(F, F_init)
    # Zip node flags
    check_NF, NF, NF_ho = switch_with_check_pipeline()
    d = node_flags.pipe(ops.map(lambda node_flags: node_reset_flags(ns, node_flags, node))).subscribe(NF_ho)
    reset_disp.add(d)
    # Dynamically initialize new state pipeline
    ResetTrigger = Subject()
    ss_flags = simstate_inputs.pipe(
        ops.map(lambda s: init_state_resets(ns, s, ResetTrigger, event_scheduler, tp_scheduler, node)),
        trace_observable("EngineState", node),
        ops.share(),
    )
    check_simSS, simSS, simSS_ho = switch_with_check_pipeline()
    d = ss_flags.pipe(ops.map(lambda obs: obs.pipe(ops.start_with(None)))).subscribe(simSS_ho)
    reset_disp.add(d)
    # Before starting real_reset procedure, wait for EngineState pipeline to be initialized.
    # This, so that the first time, the engine states are run.
    # Some engines/simulators might require that for setting initial state.
    ER = Subject()
    end_register = dict(name="end_register", address=ns + "/end_register", msg=ER, dtype="int64")
    node_outputs.append(end_register)
    # Zip switch checks to indicate end of '/rx/start_reset' procedure, and start of '/rx/real_reset'
    # todo: If simstate_inputs is not added here, EngineState reset sometimes blocks (done flag not received).
    d = (
        rx.zip(check_F_init, simstate_inputs, check_NF, check_simSS)
        .pipe(ops.map(lambda i: message_broker.connect_io()), ops.map(lambda i: 0))
        .subscribe(ER)
    )
    reset_disp.add(d)
    ###########################################################################
    # Real reset ##############################################################
    ###########################################################################
    # Prepare real_reset output. Previously, RR was both an input and subscribed to?
    real_reset_input = dict(name="real_reset", address=ns + "/real_reset", msg=RR, dtype="int64")
    node_inputs.append(real_reset_input)
    # Real reset routine. Cuts-off tick_callback when RRr is received, instead of Rr
    check_RRn, RRn, RRn_ho = switch_with_check_pipeline(init_ho=BehaviorSubject((0, 0, True)))
    pre_reset_trigger = rx.zip(RRn.pipe(spy("RRn", node)), RRr.pipe(spy("RRr", node)), SS.pipe(spy("SS", node))).pipe(
        with_latest_from(SS_CL),
        ops.map(lambda x: x[0][:-1] + (x[1],)),
        ops.map(lambda x: (x, node.pre_reset_cb(**x[-1]))),
        # Run pre-reset callback
        spy("PRE-RESET", node, log_level=DEBUG),
        trace_observable("cb_pre_reset", node),
        ops.share(),
    )
    d = pre_reset_trigger.pipe(
        ops.map(lambda x: x[0][0][0]), # x[0][0][0]=Nc
        ops.share(),
    ).subscribe(RM)
    reset_disp.add(d)
    ss_cl = pre_reset_trigger.pipe(ops.map(lambda x: x[0][-1])) # x[0][-1]=ss_cl
    ###########################################################################
    # Reset ###################################################################
    ###########################################################################
    # Prepare reset output
    R = Subject()
    reset_output = dict(name="reset", address=ns + "/reset", msg=R, dtype="int64")
    node_outputs.append(reset_output)
    # Send reset message
    d = RM.subscribe(R)
    reset_disp.add(d)
    Rr = R.pipe(ops.map(lambda x: True))
    d = rx.zip(f.pipe(spy("F", node)), Rr.pipe(spy("Rr", node))).pipe(ops.share()).subscribe(ResetTrigger)
    reset_disp.add(d)
    # Send reset messages for all outputs (Only '/rx/engine/outputs/tick')
    for o in outputs:
        reset_disp.add(RM.subscribe(o["reset"]))
    ###########################################################################
    # Reset: initialize episode pipeline ######################################
    ###########################################################################
    # Prepare end_reset output
    end_reset = dict(name="end_reset", address=ns + "/end_reset", msg=Subject(), dtype="float64")
    node_outputs.append(end_reset)
    # Dynamically initialize new input pipeline
    check_Nct, Nct, Nct_ho = switch_with_check_pipeline()
    inputs_flags = inputs.pipe(
        ops.zip(ResetTrigger),
        ops.map(lambda i: i[0]),
        ops.map(
            lambda inputs: init_channels(
                ns,
                Nct,
                rate_node,
                inputs,
                sync,
                real_time_factor,
                simulate_delays,
                end_reset["msg"],
                event_scheduler,
                node,
            )
        ),
        ops.share(),
    )
    # Switch to latest zipped inputs pipeline
    check_z_inputs, z_inputs, z_inputs_ho = switch_with_check_pipeline()
    d = inputs_flags.pipe(ops.map(lambda i: i[0].pipe(ops.start_with(None)))).subscribe(z_inputs_ho, scheduler=event_scheduler)
    reset_disp.add(d)
    # Switch to latest zipped flags pipeline
    check_z_flags, z_flags, z_flags_ho = switch_with_check_pipeline()
    d = inputs_flags.pipe(ops.map(lambda i: i[1].pipe(ops.start_with(None)))).subscribe(z_flags_ho, scheduler=event_scheduler)
    reset_disp.add(d)
    d = z_flags.subscribe(F)
    reset_disp.add(d)
    # Initialize rest of episode pipeline
    pipeline_trigger = rx.zip(check_z_flags, check_z_inputs)
    reset_obs = pipeline_trigger.pipe(
        ops.map(
            lambda x: init_engine_pipeline(
                ns,
                rate_node,
                node,
                z_inputs,
                outputs,
                Nct_ho,
                DF,
                RRn_ho,
                SS_ho,
                SS_CL_ho,
                state_inputs,
                sync,
                real_time_factor,
                simulate_delays,
                end_reset["msg"],
                eps_disp,
                event_scheduler=event_scheduler,
            )
        ),
        trace_observable("init_engine_pipeline", node),
        ops.share(),
    )
    ###########################################################################
    # End reset ###############################################################
    ###########################################################################
    # Send '/end_reset' after reset has finished
    d = reset_obs.pipe(
        ops.pluck("dispose"),
        ops.buffer_with_count(2, skip=1),
        ops.start_with(None),
        ops.zip(
            ss_cl.pipe(spy("ER-ss_cl", node)),
            reset_obs.pipe(spy("ER-obs", node)),
            simSS.pipe(spy("ER-simSS", node)),
            NF.pipe(spy("ER-NF", node)),
            check_SS.pipe(spy("ER-ch_SS", node)),
            check_SS_CL.pipe(spy("ER-ch_SS_CL", node)),
        ),
        ops.map(lambda x: node.reset_cb(**x[1])), # x[1]=ss_cl (kwargs of reset_cb)
        spy("POST-RESET", node, log_level=DEBUG),
        trace_observable("cb_post_reset", node),
        ops.map(lambda x: np.array(time.monotonic_ns() / 1e9, dtype="float64")),
        ops.share(),
    ).subscribe(end_reset["msg"])
    reset_disp.add(d)
    rx_objects = dict(
        inputs=inputs_init,
        outputs=outputs,
        node_inputs=node_inputs,
        node_outputs=node_outputs,
        state_inputs=list(state_inputs) + df_inputs + list(engine_state_inputs),
        disposable=reset_disp,
    )
    return rx_objects
def init_supervisor(ns, node, outputs=tuple(), state_outputs=tuple()):
    """Initialize the reactive (rx) I/O of the supervisor node.

    The following topics are wired:

    - ``start_reset``: messages pushed into the returned ``env_subjects["start_reset"]``
      are forwarded to the ``start_reset`` output.
    - ``end_register``: every message is forwarded to the ``real_reset`` output. It is also
      passed to ``node._get_states``, and the resulting state messages and done flags are
      published on each state output.
    - ``reset``: every message is answered with ``True`` on ``<node.ns_name>/end_reset``.

    :param ns: Namespace prefix of the supervisor addresses.
    :param node: Supervisor node. Its ``ns_name`` and ``_get_states`` are used.
    :param outputs: Not used. The returned ``outputs`` entry is always an empty list.
    :param state_outputs: Iterable of state dicts with ``name``, ``address``, ``space`` and
        ``processor`` keys. Each dict is mutated in place:
        ``"done"`` and ``"msg"`` subjects are added, and ``"/set"`` is appended to its
        ``"address"``. Any iterable is accepted, including lists.
    :returns: ``(rx_objects, env_subjects)``. ``rx_objects["state_outputs"]`` is a tuple of
        the state dicts followed by their done-flag output dicts.
    """
    # Materialize once: state_outputs is iterated twice and concatenated with a tuple below,
    # which raised TypeError for lists and silently broke for one-shot iterators.
    state_outputs = tuple(state_outputs)
    # Initialize schedulers
    tp_scheduler = ThreadPoolScheduler(max_workers=5)
    reset_disp = CompositeDisposable()
    # Prepare states
    done_outputs = []
    for s in state_outputs:
        # Prepare done flag
        s["done"] = Subject()
        done_outputs.append(
            dict(
                name=s["name"],
                address=s["address"] + "/done",
                dtype="bool",
                msg=s["done"],
            )
        )
        # Prepare state message (IMPORTANT: after done flag, we modify address afterwards)
        s["msg"] = Subject()
        s["address"] += "/set"
    ###########################################################################
    # Start reset #############################################################
    ###########################################################################
    SR = Subject() # ---> Not a node output, but used in node.reset() to kickstart reset pipeline (send self.cum_registered).
    start_reset = dict(name="start_reset", address=ns + "/start_reset", msg=Subject(), dtype="int64")
    d = SR.subscribe(start_reset["msg"], scheduler=tp_scheduler)
    reset_disp.add(d)
    ###########################################################################
    # End register ############################################################
    ###########################################################################
    ER = Subject()
    # NOTE(review): this input is named "reset" (same as the reset input below), while
    # init_engine names its end_register output "end_register" — confirm this is intended.
    end_register = dict(name="reset", address=ns + "/end_register", dtype="int64", msg=ER)
    real_reset = dict(name="real_reset", address=ns + "/real_reset", msg=Subject(), dtype="int64")
    d = ER.subscribe(real_reset["msg"])
    reset_disp.add(d)
    # Publish state msgs
    msgs = ER.pipe(ops.map(node._get_states), ops.share())
    for s in state_outputs:
        d = msgs.pipe(
            ops.pluck(s["name"] + "/done"),
            trace_observable("done", node),
            ops.share(),
        ).subscribe(s["done"])
        reset_disp.add(d)
        d = msgs.pipe(
            filter_dict_on_key(s["name"]),
            ops.filter(lambda msg: msg is not None),
            convert(s["space"], s["processor"], s["name"], "states", node, direction="out"),
            ops.share(),
        ).subscribe(s["msg"])
        reset_disp.add(d)
    ###########################################################################
    # Reset ###################################################################
    ###########################################################################
    R = Subject()
    reset = dict(name="reset", address=ns + "/reset", dtype="int64", msg=R)
    # Prepare node reset topic
    node_reset = dict(
        name=node.ns_name,
        address=node.ns_name + "/end_reset",
        dtype="bool",
        msg=Subject(),
    )
    # Reset pipeline
    # todo: HACK! The Engine currently does not wait for this message.
    d = R.pipe(ops.map(lambda x: True)).subscribe(node_reset["msg"], scheduler=tp_scheduler)
    reset_disp.add(d)
    ###########################################################################
    # End reset ###############################################################
    ###########################################################################
    # Define tick attributes. NOTE: nothing in this function publishes on tick["msg"];
    # the subject is only registered as a node output.
    space = eagerx.Space(shape=(), dtype="int64")
    dtype = space.to_dict()["dtype"]
    tick = dict(name="tick", address=ns + "/engine/outputs/tick", msg=Subject(), dtype=dtype)
    # Create node inputs & outputs
    node_inputs = [reset, end_register]
    node_outputs = [start_reset, tick, node_reset, real_reset]
    # Create return objects
    env_subjects = dict(start_reset=SR)
    rx_objects = dict(
        node_inputs=node_inputs,
        node_outputs=node_outputs,
        outputs=[],  # the `outputs` argument is intentionally not used by the supervisor
        state_outputs=state_outputs + tuple(done_outputs),
        disposable=reset_disp,
    )
    return rx_objects, env_subjects