csirtgadgets/verbose-robot

View on GitHub
cif/router.py

Summary

Maintainability
B
4 hrs
Test Coverage
#!/usr/bin/env python3

import ujson as json
import logging
import textwrap
import traceback
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter
from time import sleep
import zmq
from zmq import POLLIN as Z_POLLIN
import os
from pprint import pprint

from cif.constants import ROUTER_ADDR, STORE_ADDR, HUNTER_ADDR,  \
    RUNTIME_PATH, ROUTER_STREAM_ENABLED, ROUTER_WEBHOOKS_ENABLED
from cifsdk.constants import CONFIG_PATH
from cifsdk.utils import setup_logging, setup_signals, setup_runtime_path, \
    settings
from cif.utils import get_argument_parser


import time
from cifsdk.msg import Msg
from cif.gatherer import Manager as GathererManager
from cif.streamer import Manager as StreamManager
from cif.webhooks import Manager as WebhooksManager
from cif.hunter import Manager as HunterManager
from cif.store import Manager as StoreManager


# Minimum indicator confidence required before an indicator coming back from
# the gatherers is forwarded to the hunters.
HUNTER_MIN_CONFIDENCE = 1

# NOTE: values taken from the environment are strings, the fallbacks are ints;
# consumers of the thread counts convert with int() where they need a number.
HUNTER_THREADS = os.getenv('CIF_HUNTER_THREADS', 0)
HUNTER_ADVANCED = os.getenv('CIF_HUNTER_ADVANCED', 0)
GATHERER_THREADS = os.getenv('CIF_GATHERER_THREADS', 2)
STORE_DEFAULT = 'sqlite'
STORE_PLUGINS = ['cif.store.dummy', 'cif.store.sqlite', 'cif.store.elasticsearch']

ZMQ_HWM = 1000000
ZMQ_SNDTIMEO = 5000
ZMQ_RCVTIMEO = 5000

# Poll timeouts handed to zmq.Poller.poll() (milliseconds).  They must be
# numeric: os.getenv() returns a string when the variable is set, which the
# poller cannot compare against 0, so convert explicitly.
FRONTEND_TIMEOUT = int(os.getenv('CIF_FRONTEND_TIMEOUT', 1))
BACKEND_TIMEOUT = int(os.getenv('CIF_BACKEND_TIMEOUT', 1))

HUNTER_TOKEN = os.getenv('CIF_HUNTER_TOKEN', None)

# The environment may override the built-in default store type.
STORE_DEFAULT = os.getenv('CIF_STORE_STORE', STORE_DEFAULT)
STORE_NODES = os.getenv('CIF_STORE_NODES')

PIDFILE = os.getenv('CIF_ROUTER_PIDFILE', '%s/cif_router.pid' % RUNTIME_PATH)

# Router config file: CIF_ROUTER_CONFIG_PATH if set, else ./router.yml, else
# ~/router.yml when the first candidate is not an existing file.  This
# rebinding shadows the CONFIG_PATH imported from cifsdk.constants.
CONFIG_PATH = os.environ.get('CIF_ROUTER_CONFIG_PATH', 'router.yml')
if not os.path.isfile(CONFIG_PATH):
    CONFIG_PATH = os.environ.get('CIF_ROUTER_CONFIG_PATH', os.path.join(os.path.expanduser('~'), 'router.yml'))

TRACE = os.getenv('CIF_ROUTER_TRACE')

logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)

# CIF_ROUTER_TRACE=1 turns on debug logging for this module.
if TRACE in [1, '1']:
    logger.setLevel(logging.DEBUG)


class Router(object):
    """ZeroMQ message router sitting between CIF clients and the backend.

    The router owns a ``zmq.ROUTER`` frontend socket plus a set of manager
    objects: store and gatherers always, hunters / streamer / webhooks when
    enabled.  ``start()`` launches the managers and then alternates between
    polling the frontend/store sockets and the gatherer/hunter sinks until
    ``stop()`` sets ``terminate``.
    """

    def __init__(self, listen=ROUTER_ADDR, test=False, **kwargs):
        """Create the router; nothing is bound or started until ``start()``.

        :param listen: address the frontend socket is bound to in ``start()``
        :param test: when truthy, ``start()`` runs a single poll iteration
        :param kwargs: forwarded to the ``_init_*`` helpers; the keys read in
            this class are ``store_address``, ``store_type``,
            ``gatherer_threads``, ``hunter_threads`` and ``hunter_token``
        """
        self.settings = settings(CONFIG_PATH)

        self.context = zmq.Context()

        # message counter used by _log_counter() for throughput logging
        self.count = 0
        self.count_start = time.time()

        self.terminate = False
        self.test = test

        self.listen = listen

        # The settings file keeps precedence (historic behaviour); otherwise
        # honour the hunter_token keyword, which used to be silently dropped.
        self.hunter_token = kwargs.get('hunter_token')
        if self.settings and self.settings.get('hunter_token'):
            self.hunter_token = self.settings['hunter_token']

        self.frontend_s = self.context.socket(zmq.ROUTER)

        # managers are created lazily by the _init_* helpers
        self.gatherers = None
        self.hunters = None
        self.streamer = None
        self.webhooks = None
        self.store = None

        self.kwargs = kwargs

    def _init_webhooks(self):
        """Start the webhooks manager unless ROUTER_WEBHOOKS_ENABLED is falsy.

        :return: False when disabled, otherwise None
        """
        if not ROUTER_WEBHOOKS_ENABLED:
            return False

        self.webhooks = WebhooksManager(self.context)
        self.webhooks.start()

    def _init_streamer(self):
        """Start the stream manager unless ROUTER_STREAM_ENABLED is falsy.

        :return: False when disabled, otherwise None
        """
        # CIF_ROUTER_STREAM_ENABLED=1|0
        if not ROUTER_STREAM_ENABLED:
            return False

        self.streamer = StreamManager(self.context)
        self.streamer.start()

    def _init_hunters(self, **kwargs):
        """Start the hunter manager when a non-zero thread count is requested.

        :param kwargs: ``hunter_threads`` overrides HUNTER_THREADS
        :return: False when the thread count is 0, otherwise None
        """
        threads = kwargs.get('hunter_threads', HUNTER_THREADS)
        if int(threads) == 0:
            return False

        logger.info('launching hunters...')

        self.hunters = HunterManager(self.context, threads)
        self.hunters.start()

    def _init_gatherers(self, **kwargs):
        """Start the gatherer manager.

        :param kwargs: ``gatherer_threads`` overrides GATHERER_THREADS
        """
        logger.info('launching gatherers...')

        threads = kwargs.get('gatherer_threads', GATHERER_THREADS)

        self.gatherers = GathererManager(self.context, threads)
        self.gatherers.start()

    def _init_store(self, **kwargs):
        """Start the store manager and wait a fixed 5 seconds for it.

        :param kwargs: ``store_address`` (default STORE_NODES) and
            ``store_type`` (default STORE_DEFAULT) are passed to the manager
        """
        logger.info('launching store...')

        store_address = kwargs.get('store_address', STORE_NODES)
        store_type = kwargs.get('store_type', STORE_DEFAULT)

        self.store = StoreManager(self.context)
        self.store.start(store_address=store_address, store_type=store_type)

        # no readiness signal is used here; just give the store time to come up
        logger.info('Waiting for Store to initialize...')
        time.sleep(5)
        logger.info("Store Ready....")

    def _init_pollers(self):
        """Create the frontend and backend pollers and register their sockets.

        Requires the store and gatherers (and hunters, if used) to be started.
        """
        self.poller = zmq.Poller()
        self.poller_backend = zmq.Poller()

        # register the front end poller
        self.poller.register(self.frontend_s, Z_POLLIN)
        self.poller.register(self.store.socket, Z_POLLIN)
        self.poller.register(self.store.s_write, Z_POLLIN)
        self.poller.register(self.store.s_hunter_write, Z_POLLIN)

        # setup the backend..
        self.poller_backend.register(self.gatherers.sink_s, Z_POLLIN)

        if self.hunters:
            self.poller_backend.register(self.hunters.sink, Z_POLLIN)

    def _poll_frontend(self):
        """Poll the frontend and store sockets once and service what is ready.

        Frontend messages are dispatched through ``handle_message``; anything
        readable on the store sockets is received with the frontend socket as
        ``relay`` (presumably forwarding the reply to the client -- see
        cifsdk.msg.Msg.recv).
        """
        items = dict(self.poller.poll(FRONTEND_TIMEOUT))

        if items.get(self.frontend_s) == Z_POLLIN:
            self.handle_message(self.frontend_s)

        if items.get(self.store.socket) == Z_POLLIN:
            Msg().recv(self.store.socket, relay=self.frontend_s)

        if items.get(self.store.s_write) == Z_POLLIN:
            Msg().recv(self.store.s_write, relay=self.frontend_s)

        if items.get(self.store.s_hunter_write) == Z_POLLIN:
            Msg().recv(self.store.s_hunter_write, relay=self.frontend_s)

    def _poll_backend(self):
        """Poll the gatherer sink (and hunter sink, if any) once."""
        items = dict(self.poller_backend.poll(BACKEND_TIMEOUT))

        if items.get(self.gatherers.sink_s) == Z_POLLIN:
            self.handle_message_gatherer(self.gatherers.sink_s)

        if not self.hunters:
            return

        if items.get(self.hunters.sink) == Z_POLLIN:
            self.handle_message(self.hunters.sink)

    def _log_counter(self):
        """Count one handled message; log throughput every 100 messages."""
        self.count += 1
        if self.count % 100:
            return

        t = time.time() - self.count_start
        # 100 messages can arrive within the clock resolution (or the clock
        # may step backwards); avoid dividing by a non-positive interval
        n = self.count / t if t > 0 else 0.0
        logger.info('processing {} msgs per {} sec'.format(round(n, 2), round(t, 2)))
        self.count = 0
        self.count_start = time.time()

    def handle_message(self, s):
        """Receive one message from socket ``s`` and dispatch it by type.

        ``indicators_create`` and ``indicators_search`` go to their dedicated
        ``handle_<mtype>`` method, everything else to
        ``handle_message_default``.  Handler exceptions are logged (with
        traceback) and swallowed so one bad message cannot stop the loop.
        """
        msg_id, token, mtype, data = Msg().recv(s)

        handler = self.handle_message_default
        if mtype in ['indicators_create', 'indicators_search']:
            handler = getattr(self, "handle_" + mtype)

        logger.debug(f"handling message: {mtype}")

        try:
            handler(msg_id, mtype, token, data)
        except Exception as e:
            # logger.exception keeps the traceback, logger.error(e) lost it
            logger.exception(e)

        self._log_counter()

    def handle_message_default(self, id, mtype, token, data='[]'):
        """Forward the message unchanged to the store's main socket."""
        Msg(id=id, mtype=mtype, token=token, data=data).send(self.store.socket)

    def handle_message_gatherer(self, s):
        """Handle one message coming back from the gatherers on socket ``s``.

        The message is written to the store (via the hunter write socket when
        it carries the hunter token) and each contained indicator is then
        fanned out to the streamer, webhooks and hunters that are running.
        """
        msg_id, token, mtype, data = Msg().recv(s)

        # messages carrying the hunter token use the dedicated write socket
        sock = self.store.s_write
        if token == self.hunter_token:
            sock = self.store.s_hunter_write

        Msg(id=msg_id, mtype=mtype, token=token, data=data).send(sock)

        # nothing to fan out to; skip the JSON round trip entirely
        if not (self.hunters or self.streamer or self.webhooks):
            return

        indicators = json.loads(data)
        if isinstance(indicators, dict):
            indicators = [indicators]

        for d in indicators:
            payload = json.dumps(d)

            if self.streamer:
                self.streamer.socket.send_multipart([payload.encode('utf-8')])

            if self.webhooks:
                self.webhooks.socket.send_string(payload)

            # a missing or null confidence counts as 0
            if self.hunters and int(d.get('confidence') or 0) \
                    >= HUNTER_MIN_CONFIDENCE:
                self.hunters.socket.send_string(payload)

    def handle_indicators_search(self, id, mtype, token, data):
        """Forward a search to the store and mirror the query to the
        hunters, streamer and webhooks that are running."""
        self.handle_message_default(id, mtype, token, data)

        # TODO- issue here with un-authorized messages, may need to
        # re-think using store success/fail status
        if self.hunters:
            self.hunters.socket.send_string(data)

        if self.streamer:
            self.streamer.socket.send_string(data)

        if self.webhooks:
            self.webhooks.socket.send_string(data)

    def handle_indicators_create(self, id, mtype, token, data):
        """Hand new indicators to the gatherers; their output comes back
        through the gatherer sink (see ``handle_message_gatherer``)."""
        Msg(id=id, mtype=mtype, token=token, data=data)\
            .send(self.gatherers.s)

    def start(self):
        """Launch all managers, bind the frontend and run the poll loop.

        Blocks until ``terminate`` is set, or returns after one iteration
        when the router was created with ``test=True``.
        """
        self._init_store(**self.kwargs)
        self._init_gatherers(**self.kwargs)

        self._init_hunters(**self.kwargs)
        self._init_streamer()
        self._init_webhooks()

        logger.info('launching frontend...')
        self.frontend_s.set_hwm(ZMQ_HWM)
        self.frontend_s.bind(self.listen)

        self._init_pollers()

        logger.debug('starting loop')

        # we use this instead of a loop so we can make sure to get front end
        # queries as they come in that way hunters don't over burden the store,
        # think of it like QoS it's weighted so front end has a higher chance
        # of getting a faster response
        while not self.terminate:
            self._poll_frontend()

            self._poll_backend()

            if self.test:
                break

    def stop(self):
        """Flag the loop to end, close the frontend and stop every manager
        that was started, pausing briefly between steps."""
        self.terminate = True
        logger.debug('shutting down front end..')

        if self.frontend_s:
            self.frontend_s.close()
            sleep(0.5)

        # hunters come first
        for m in ['hunters', 'gatherers', 'streamer', 'webhooks',
                  'store']:
            manager = getattr(self, m)
            if manager:
                logger.debug(f"stopping {m}...")
                manager.stop()
                sleep(0.5)


def main():
    """Command-line entry point for ``cif-router``.

    Parses the arguments, configures logging, the runtime path and signal
    handlers, optionally writes a pidfile, runs a ``Router`` until it stops
    or is interrupted, then stops it and removes the pidfile.

    :raises SystemExit: when the pidfile already exists or cannot be written
    """
    p = get_argument_parser()
    p = ArgumentParser(
        description=textwrap.dedent('''\
        Env Variables:
            CIF_RUNTIME_PATH
            CIF_ROUTER_CONFIG_PATH
            CIF_ROUTER_ADDR
            CIF_HUNTER_ADDR
            CIF_HUNTER_TOKEN
            CIF_HUNTER_THREADS
            CIF_GATHERER_THREADS
            CIF_STORE_ADDR

        example usage:
            $ cif-router --listen 0.0.0.0 -d
        '''),
        formatter_class=RawDescriptionHelpFormatter,
        prog='cif-router',
        parents=[p]
    )

    # NOTE(review): --config is parsed but not passed on below; Router reads
    # the module-level CONFIG_PATH instead.
    p.add_argument('--config',
                   help='specify config path [default: %(default)s]',
                   default=CONFIG_PATH)

    p.add_argument('--listen',
                   help='address to listen on [default: %(default)s]',
                   default=ROUTER_ADDR)

    p.add_argument('-G', '--gatherers', '--gatherer-threads',
                   help='specify number of gatherer threads to use '
                        '[default: %(default)s]',
                   default=GATHERER_THREADS)

    p.add_argument('--hunter', help='address hunters listen on '
                                    '[default: %(default)s]',
                   default=HUNTER_ADDR)

    p.add_argument('--hunter-token', help='specify token for hunters to use '
                                          '[default: %(default)s]',
                   default=HUNTER_TOKEN)

    p.add_argument('-H', '--hunters', '--hunter-threads',
                   help='specify number of hunter threads to use '
                        '[default: %(default)s]',
                   default=HUNTER_THREADS)

    p.add_argument("--store-address",
                   help="specify the store address cif-router is listening on "
                        "[default: %(default)s]", default=STORE_ADDR)

    p.add_argument("--store",
                   help=f"specify a store type {', '.join(STORE_PLUGINS)} "
                        "[default: %(default)s]",
                   default=STORE_DEFAULT)

    p.add_argument('--store-nodes',
                   help='specify storage nodes address [default: %(default)s]',
                   default=STORE_NODES)

    p.add_argument('--logging-ignore',
                   help='set logging to WARNING for specific modules')

    p.add_argument('--pidfile',
                   help='specify pidfile location')

    args = p.parse_args()
    setup_logging(args)

    if args.verbose:
        logger.setLevel(logging.INFO)

    if args.debug:
        logger.setLevel(logging.DEBUG)

    logger.info('loglevel is: {}'.format(logging.getLevelName(logger.getEffectiveLevel())))

    if args.logging_ignore:
        # comma separated list of logger names to quieten
        for name in args.logging_ignore.split(','):
            logging.getLogger(name).setLevel(logging.WARNING)

    setup_runtime_path(args.runtime_path)
    setup_signals(__name__)

    # http://stackoverflow.com/a/789383/7205341
    pid = str(os.getpid())
    logger.debug("pid: %s" % pid)

    # an existing pidfile is taken to mean another router is running
    if args.pidfile and os.path.isfile(args.pidfile):
        logger.critical("%s already exists, exiting" % args.pidfile)
        raise SystemExit

    r = Router(listen=args.listen, store_type=args.store,
               store_address=args.store_address,
               store_nodes=args.store_nodes, hunter_token=args.hunter_token,
               hunter_threads=args.hunters, gatherer_threads=args.gatherers)

    if args.pidfile:
        try:
            with open(args.pidfile, 'w') as pidfile:
                pidfile.write(pid)

        except OSError:
            # covers PermissionError as well as e.g. a missing directory
            logger.critical('unable to create pid %s' % args.pidfile)
            raise SystemExit

    try:
        logger.info('starting router..')
        r.start()

    except KeyboardInterrupt:
        # todo - signal to threads to shut down and wait for them to finish
        logger.info('shutting down via SIGINT...')

    except SystemExit:
        logger.info('shutting down via SystemExit...')

    except Exception as e:
        logger.critical(e)
        traceback.print_exc()

    try:
        logger.info('stopping..')
        r.stop()

    finally:
        # remove the pidfile even when stopping the router fails
        logger.info('Shutting down')
        if args.pidfile and os.path.isfile(args.pidfile):
            os.unlink(args.pidfile)


# Run the router when this module is executed as a script.
if __name__ == "__main__":
    main()