DogsTailFarmer/crypto-ws-api

View on GitHub
crypto_ws_api/ws_session.py

Summary

Maintainability
A
0 mins
Test Coverage
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import gc
import sys
import time
import logging
import ujson as json
import hmac
import hashlib
import base64
import string
import random
import websockets.client

from websockets import ConnectionClosed

from enum import Enum
from crypto_ws_api import TIMEOUT, ID_LEN_LIMIT, DELAY

logger_ws = logger = logging.getLogger(__name__)
logger_ws.level = logging.INFO
sys.tracebacklimit = 0
ALPHABET = string.ascii_letters + string.digits
CONST_3 = "userDataStream.start"


def generate_signature(exchange, api_secret, data):
    if exchange == 'bitfinex':
        sig = hmac.new(api_secret.encode("utf-8"), data.encode("utf-8"), hashlib.sha384).hexdigest()
    elif exchange in ('huobi', 'okx'):
        sig = hmac.new(api_secret.encode("utf-8"), data.encode("utf-8"), hashlib.sha256).digest()
        sig = base64.b64encode(sig).decode()
    elif exchange == 'binance_ws':
        sig = hmac.new(api_secret.encode("ascii"), data.encode("ascii"), hashlib.sha256).hexdigest()
    elif exchange == 'bybit':
        sig = hmac.new(bytes(api_secret.encode("utf-8")), data.encode("utf-8"), hashlib.sha256).hexdigest()
    else:
        sig = hmac.new(api_secret.encode("utf-8"), data.encode("utf-8"), hashlib.sha256).hexdigest()
    return sig


# https://binance-docs.github.io/apidocs/websocket_api/en/#rate-limits
class RateLimitInterval(Enum):
    SECOND = 1
    MINUTE = 60
    HOUR = 3600
    DAY = 86400


class UserWSS:
    __slots__ = (
        "init",
        "method",
        "exchange",
        "endpoint",
        "_api_key",
        "_api_secret",
        "_passphrase",
        "_ws",
        "_listen_key",
        "_retry_after",
        "ws_id",
        "operational_status",
        "order_handling",
        "request_limit_reached",
        "in_event",
        "_response_pool",
        "tasks",
    )

    def __init__(self, method, ws_id, exchange, endpoint, api_key, api_secret, passphrase=None):
        self.init = True
        self.method = method
        self.exchange = exchange
        self.endpoint = endpoint
        #
        self._api_key = api_key
        self._api_secret = api_secret
        self._passphrase = passphrase
        self._ws = None
        self._listen_key = None
        self._response_pool = {}
        self._retry_after = int(time.time() * 1000) - 1
        self.ws_id = ws_id
        self.operational_status = None
        self.order_handling = False
        self.request_limit_reached = False
        self.in_event = asyncio.Event()
        self.tasks = set()

    def tasks_manage(self, coro, name=None):
        _t = asyncio.create_task(coro, name=name)
        self.tasks.add(_t)
        _t.add_done_callback(self.tasks.discard)

    async def _ws_listener(self):
        self.tasks_manage(self.ws_login())
        async for msg in self._ws:
            # logger.info(f"_ws_listener: msg: {self.ws_id}: {msg}")
            if isinstance(msg, str):
                res = await self._handle_msg(json.loads(msg))
                # logger.info(f"_ws_listener: res: {self.ws_id}: {res}")
                if res != 'pass':
                    if res is None:
                        self._response_pool[f"NoneResponse{self.ws_id}"] = None
                    elif self.exchange == 'binance':
                        self._response_pool[res.get('id')] = res.get('result')
                    elif self.exchange in ['okx', 'bitfinex']:
                        self._response_pool[res.get('id') or self.ws_id] = res.get('data') or res
                    self.in_event.set()
                await asyncio.sleep(0)
            else:
                logger.warning(f"UserWSS: {self.ws_id}: {msg}")
                await self.stop()

    async def start_wss(self):
        async for self._ws in websockets.client.connect(self.endpoint, logger=logger_ws):
            try:
                await self._ws_listener()
            except ConnectionClosed as ex:
                if ex.code == 4000:
                    logger.info(f"WSS closed for {self.ws_id}")
                    break
                else:
                    self.operational_status = False
                    [task.cancel() for task in self.tasks if not task.done()]
                    self.tasks.clear()
                    logger.warning(f"Restart WSS for {self.ws_id}")
                    continue
            except Exception as ex:
                logger.error(f"WSS start_wss() other exception: {ex}")

    async def ws_login(self):
        res = await self.request('userDataStream.start', _api_key=True)
        if res is None:
            logger.warning(f"UserWSS: Not 'logged in' for {self.ws_id}")
            raise ConnectionClosed(None, None)
        else:
            if self.exchange == 'binance':
                self._listen_key = res.get('listenKey')
                self.tasks_manage(self.heartbeat(), f"heartbeat-{self.ws_id}")
            else:
                self._listen_key = f"{int(time.time() * 1000)}{self.ws_id}"
            self.operational_status = True
            self.order_handling = True
            self.tasks_manage(self._keepalive(), f"keepalive-{self.ws_id}")
            logger.info(f"UserWSS: 'logged in' for {self.ws_id}")

    async def request(self, _method=None, _params=None, _api_key=False, _signed=False):
        """
        Construct and handling request/response to WS API endpoint, use a description of the methods on
        https://developers.binance.com/docs/binance-trading-api/websocket_api#request-format
        :return: result: {} or None if temporary Out-of-Service state
        """
        method = _method or self.method
        if self.request_limit_reached:
            logger.warning(f"UserWSS {self.ws_id}: request limit reached, try later")
            return None
        if method != 'userDataStream.start' and not self.operational_status:
            logger.warning("UserWSS temporary in Out-of-Service state")
            return None
        if method in ('order.place', 'order.cancelReplace', 'order') and not self.order_handling:
            logger.warning("UserWSS: exceeded order placement limit, try later")
            return None
        params = _params.copy() if _params else None
        r_id = f"{self.exchange}{method}{''.join(random.choices(ALPHABET, k=8))}"
        if self.exchange in ("okx", "bitfinex") and method == CONST_3:
            _id = self.ws_id
        else:
            _id = ''.join(e for e in r_id if e.isalnum())[-ID_LEN_LIMIT[self.exchange]:]
        await self._ws.send(
            json.dumps(self.compose_request(_id, _api_key, method, params, _signed))
        )
        await asyncio.sleep(0)
        try:
            res = await asyncio.wait_for(self._response_distributor(_id), timeout=TIMEOUT)
        except asyncio.exceptions.TimeoutError:
            logger.warning(f"UserWSS: get response timeout error: {self.ws_id}")
            await self.stop()
        except asyncio.CancelledError:
            pass  # Task cancellation should not be logged as an error
        else:
            # logger.info(f"request: {self.ws_id}: {res}")
            return res

    async def _response_distributor(self, _id):
        while self.operational_status is not None:
            await self.in_event.wait()
            self.in_event.clear()
            if _id in self._response_pool:
                return self._response_pool.pop(_id)
            elif f"NoneResponse{self.ws_id}" in self._response_pool:
                return self._response_pool.pop(f"NoneResponse{self.ws_id}", None)

    def compose_request(self, _id, api_key, method, params, signed):
        if self.exchange == "binance":
            return self._compose_binance_request(_id, api_key, method, params, signed)
        elif self.exchange == "okx":
            return self._compose_okx_request(_id, method, params)
        elif self.exchange == 'bitfinex':
            return self._compose_bitfinex_request(_id, method, params)
        else:
            raise ValueError(f"Unsupported exchange: {self.exchange}")

    def _compose_binance_request(self, _id, api_key, method, params, signed):
        req = {"id": _id, "method": method}
        params = params or {}
        if api_key:
            params["apiKey"] = self._api_key
        if signed:
            params["timestamp"] = int(time.time() * 1000)
            payload = '&'.join(f"{key}={value}" for key, value in sorted(params.items()))
            params["signature"] = generate_signature('binance_ws', self._api_secret, payload)
        if params:
            req["params"] = params
        return req

    def _compose_okx_request(self, _id, method, params):
        if method == CONST_3:
            ts = int(time.time())
            signature_payload = f"{ts}GET/users/self/verify"
            signature = generate_signature(self.exchange, self._api_secret, signature_payload)
            return {
                "op": 'login',
                "args": [
                    {"apiKey": self._api_key, "passphrase": self._passphrase, "timestamp": ts, "sign": signature}
                ]
            }
        else:
            return {"id": _id, "op": method, "args": params if isinstance(params, list) else [params]}

    def _compose_bitfinex_request(self, _id, method, params):
        if method == CONST_3:
            ts = int(time.time() * 1000)
            data = f"AUTH{ts}"
            return {
                'event': "auth",
                'apiKey': self._api_key,
                'authSig': generate_signature(self.exchange, self._api_secret, data),
                'authPayload': data,
                'authNonce': ts,
                'filter': ['trading']
            }
        else:
            if method == 'on':
                params.update({"meta": {"aff_code": "v_4az2nCP"}})
            return [0, method, _id, params]

    async def _keepalive(self, interval=10):
        while self.operational_status is not None:
            if self.request_limit_reached and (int(time.time() * 1000) - self._retry_after >= 0):
                self.request_limit_reached = False
                logger.info(f"UserWSS: request limit reached restored for {self.ws_id}")
            if not self.order_handling and (int(time.time() * 1000) - self._retry_after >= 0):
                self.order_handling = True
                logger.info(f"UserWSS order handling status restored for {self.ws_id}")
            await asyncio.sleep(interval)

    async def heartbeat(self, interval=60 * 30):
        params = {
            "listenKey": self._listen_key,
        }
        while self.operational_status is not None:
            await self.request(
                "userDataStream.ping",
                params,
                _api_key=True,
            )
            await asyncio.sleep(interval)

    async def stop(self):
        """
        Stop data stream
        """
        self.operational_status = None  # Not restart and break all loops
        self.order_handling = False
        self.init = True
        [task.cancel() for task in self.tasks if not task.done()]
        self.tasks.clear()
        if self._ws and not self._ws.closed:
            await self._ws.close(code=4000)
        gc.collect()
        logger.info(f"User WSS for {self.ws_id} stopped")

    async def _handle_msg(self, msg):
        if self.exchange == 'binance':
            self._handle_rate_limits(msg.pop('rateLimits', []))
            if msg.get('status') != 200:
                await self.binance_error_handle(msg)
                msg = None
            return msg
        elif self.exchange == 'okx':
            if msg.get('code') != '0':
                await self.okx_error_handle(msg)
                msg = None
            return msg
        elif self.exchange == 'bitfinex':
            return await self.bitfinex_error_handle(msg)

    # region BitfinexErrorHandle
    async def bitfinex_error_handle(self, msg):
        if isinstance(msg, dict):
            return await self._handle_dict_message(msg)
        elif isinstance(msg, list) and msg[1] == 'n' and msg[2][1] in ('on-req', 'oc-req', 'oc_multi-req'):
            return self._transform_list_message(msg)
        else:
            return 'pass'

    async def _handle_dict_message(self, msg):
        event = msg.get('event')
        if event == 'info':
            return await self._handle_info_event(msg)
        elif event == 'auth':
            return msg if msg.get('status') == "OK" else None
        elif msg.get('code'):
            return await self._handle_error_code(msg)
        return 'pass'

    async def _handle_info_event(self, msg):
        if not msg.get('platform', {}).get('status'):
            logger.warning(f"UserWSS Bitfinex platform in maintenance mode: {msg}")
            await self.stop()
        elif msg.get('version') != 2:
            logger.critical('Bitfinex WSS platform: version change detected')
        return 'pass'

    async def _handle_error_code(self, msg):
        code = msg.get('code')
        if code == 10305:
            logger.warning('UserWSS Bitfinex: Reached limit of open channels')
            self._retry_after = int((time.time() + TIMEOUT) * 1000)
            self.request_limit_reached = True
        else:
            logger.warning(f"Malformed request for {self.ws_id}: {msg}")
        return None

    @staticmethod
    def _transform_list_message(msg):
        return {
            "id": msg[2][2],
            "data": [
                msg[2][0],
                msg[2][1],
                None,
                None,
                [msg[2][4]] if msg[2][1] == 'on-req' else msg[2][4],
                None,
                msg[2][6],
                msg[2][7]
            ]
        }
    # endregion

    async def okx_error_handle(self, msg):
        if msg.get('code') == '1':
            logger.warning(f"Operation failed: {msg}")
        elif msg.get('code') == '63999':
            logger.warning(f"An issue occurred on exchange's side: {msg}")
        elif msg.get('code') == '60014':
            self._retry_after = int((time.time() + TIMEOUT) * 1000)
            self.request_limit_reached = True
        logger.warning(f"Malformed request: status: {msg}")

    async def binance_error_handle(self, msg):
        error_msg = msg.get('error')
        logger.error(f"Malformed request: status: {error_msg}")
        if msg.get('status') == 403:
            await self.stop()
        if msg.get('status') in (418, 429):
            self._retry_after = error_msg.get('data', {}).get('retryAfter', int((time.time() + TIMEOUT) * 1000))
            self.request_limit_reached = True

    def _handle_rate_limits(self, rate_limits: []):
        def retry_after():
            return (int(time.time() / interval) + 1) * interval * 1000
        for rl in rate_limits:
            if rl.get('limit') - rl.get('count') <= 0:
                interval = rl.get('intervalNum') * RateLimitInterval[rl.get('interval')].value
                self._retry_after = max(self._retry_after, retry_after())
                if rl.get('rateLimitType') == 'REQUEST_WEIGHT':
                    self.request_limit_reached = True
                elif rl.get('rateLimitType') == 'ORDERS':
                    self.order_handling = False


class UserWSSession:
    __slots__ = (
        "exchange",
        "endpoint",
        "_api_key",
        "_api_secret",
        "_passphrase",
        "user_wss",
        "tasks_wss",
    )

    def __init__(self, exchange, endpoint, api_key, api_secret, passphrase=None):
        if exchange not in ('binance', 'okx', 'bitfinex'):
            raise UserWarning(f"UserWSSession: exchange {exchange} not serviced")
        self.exchange = exchange
        self.endpoint = endpoint
        #
        self._api_key = api_key
        self._api_secret = api_secret
        self._passphrase = passphrase
        self.user_wss = {}
        self.tasks_wss = set()

    async def handle_request(
        self,
        trade_id: str,
        method: str,
        _params=None,
        _api_key=False,
        _signed=False,
    ):
        ws_id = f"{self.exchange}-{trade_id}-{method}"
        user_wss = self.user_wss.setdefault(
            ws_id,
            UserWSS(
                method,
                ws_id,
                self.exchange,
                self.endpoint,
                self._api_key,
                self._api_secret,
                self._passphrase
            )
        )

        if user_wss.init:
            user_wss.init = False
            user_wss.operational_status = False
            _t = asyncio.create_task(user_wss.start_wss())
            self.tasks_wss.add(_t)
            _t.add_done_callback(self.tasks_wss.discard)

        duration = 0
        while not (user_wss.operational_status and user_wss.order_handling):
            await asyncio.sleep(DELAY)
            if duration > TIMEOUT:
                return None
            duration += DELAY

        try:
            return await user_wss.request(_params=_params, _api_key=_api_key, _signed=_signed)
        except asyncio.CancelledError:
            pass  # Task cancellation should not be logged as an error
        except Exception as ex:
            logger.error(f"crypto_ws_api.ws_session.handle_request(): {ex}")
        return None

    async def stop(self):
        user_wss_copy = dict(self.user_wss)
        for ws in user_wss_copy.values():
            await ws.stop()
        self.user_wss.clear()
        [task.cancel() for task in self.tasks_wss if not task.done()]
        self.tasks_wss.clear()