ngtt/uplink/thread.py
import logging
from concurrent.futures import Future
import minijson
from satella.coding import wraps, for_argument, silence_excs, rethrow_as
from satella.coding.optionals import Optional
from satella.exceptions import Empty
from satella.time import ExponentialBackoff
from ..orders import Order
import typing as tp
import select
from satella.coding.concurrent import TerminableThread
from ..exceptions import DataStreamSyncFailed, ConnectionFailed
from ..protocol import NGTTHeaderType
from .connection import NGTTSocket
logger = logging.getLogger(__name__)
def must_be_connected_else_raise(fun):
@wraps(fun)
def outer(self, *args, **kwargs):
if not self.connected:
raise ConnectionFailed(True)
return fun(self, *args, **kwargs)
return outer
class NGTTConnection(TerminableThread):
"""
An interface to NGTT, also a thread maintaining connection in the background.
Note that instantiating this object is the same as calling start. You do not need to call
start on this object after you initialize it.
:param cert_file: path to file with certificate. This file should contain only the device
certificate, attaching entire certificate chain is not required.
:param key_file: path to private key
:param on_new_order: a callable taking only a single argument and returning nothing, the
callable to call when a new order appears. Note that you have to call
either :meth:`~ngtt.orders.Order.acknowledge` or :meth:`~ngtt.orders.Order.nack` for each
received order. Leave at default (None) if orders are not meant to be fetched.
:ivar connected (bool) is connection opened
"""
def __init__(self, cert_file: str, key_file: str,
on_new_order: tp.Optional[tp.Callable[[Order], None]] = None):
super().__init__(name='ngtt uplink')
self.on_new_order = on_new_order
self.cert_file = cert_file
self.stopped = False
self.key_file = key_file
self.current_connection = None
self.start()
def stop(self, wait_for_completion: bool = True):
"""
Stop this thread and the connection
:param wait_for_completion: whether to wait for thread to terminate
"""
if self.stopped:
return
self.terminate()
if wait_for_completion:
self.join()
self.stopped = True
def close(self):
"""
Alias for :meth:`~ngtt.uplink.NGTTConnection.stop`.
"""
self.stop()
if self.current_connection is not None:
self.current_connection.close()
self.current_connection = None
self.op_id_to_op = {}
@property
@silence_excs(AttributeError, returns=False)
def connected(self) -> bool:
"""Are we connected to target server?"""
return self.current_connection.connected
def connect(self):
if self.connected:
return
eb = ExponentialBackoff(1, 30, self.safe_sleep)
while not self.terminating and not self.connected:
try:
self.current_connection = NGTTSocket(self.cert_file, self.key_file)
self.current_connection.connect()
if self.on_new_order:
self.current_connection.send_frame(0, NGTTHeaderType.FETCH_ORDERS)
except ConnectionFailed as e:
logger.warning('Failure reconnecting', exc_info=e)
eb.failed()
eb.sleep()
if self.terminating:
return
def inner_loop(self):
self.current_connection.try_ping()
ccon = [self.current_connection]
rx, wx, ex = select.select(ccon,
ccon if self.current_connection.wants_write else [], [],
1)
if wx:
with rethrow_as(ConnectionResetError, ConnectionFailed):
self.current_connection.try_send()
if not rx:
return
frame = self.current_connection.recv_frame()
if frame is None:
return
if frame.packet_type == NGTTHeaderType.PING:
self.current_connection.got_ping()
elif frame.packet_type == NGTTHeaderType.ORDER:
try:
data = frame.real_data
except ValueError:
logger.error('Received invalid JSON over the wire')
raise ConnectionFailed(False, 'Got invalid JSON')
order = Order(data, frame.tid, self.current_connection)
self.on_new_order(order)
elif frame.packet_type in (
NGTTHeaderType.DATA_STREAM_REJECT, NGTTHeaderType.DATA_STREAM_CONFIRM):
tid = frame.tid
if tid in self.current_connection.futures:
self.current_connection.id_assigner.mark_as_free(tid)
# Assume it's a data stream running
fut = self.current_connection.futures.pop(tid)
if frame.packet_type == NGTTHeaderType.DATA_STREAM_CONFIRM:
fut.set_result(None)
elif frame.packet_type == NGTTHeaderType.DATA_STREAM_REJECT:
fut.set_exception(DataStreamSyncFailed())
else:
logger.info('This was an unknown confirmation')
def loop(self) -> None:
try:
self.connect()
except ConnectionFailed as e:
logger.warning('Failure during connect', exc_info=e)
return
if self.terminating:
return
try:
self.inner_loop()
except ConnectionFailed as e:
logger.warning('Connection failed', exc_info=e)
self.cleanup()
def cleanup(self):
Optional(self.current_connection).close()
self.current_connection = None