Ananto30/zero

View on GitHub
zero/client_server/worker.py

Summary

Maintainability
A
2 hrs
Test Coverage
import asyncio
import inspect
import logging
import time
from typing import Optional

from zero import config
from zero.codegen.codegen import CodeGen
from zero.encoder.protocols import Encoder
from zero.error import SERVER_PROCESSING_ERROR
from zero.zero_mq.factory import get_worker


class _Worker:
    """Execute registered RPC functions for requests received over a worker socket.

    A worker decodes each incoming frame into ``(request id, function name,
    message)``, dispatches it through ``handle_msg`` and encodes the result
    back together with the request id.
    """

    def __init__(
        self,
        rpc_router: dict,
        device_comm_channel: str,
        encoder: Encoder,
        rpc_input_type_map: dict,
        rpc_return_type_map: dict,
    ):
        """
        Store the routing tables and prepare the helpers used while serving.

        Parameters
        ----------
        rpc_router
            Maps an rpc name to the callable (sync or ``async``) that serves it.
        device_comm_channel
            Channel handed to ``worker.listen`` in ``start_dealer_worker``.
        encoder
            Object providing ``decode(bytes)`` and ``encode(obj)``.
        rpc_input_type_map, rpc_return_type_map
            Type maps forwarded, together with the router, to ``CodeGen``.
        """
        self._rpc_router = rpc_router
        self._device_comm_channel = device_comm_channel
        self._encoder = encoder
        self._rpc_input_type_map = rpc_input_type_map
        self._rpc_return_type_map = rpc_return_type_map

        # Dedicated loop used to run coroutine RPCs to completion.
        # `new_event_loop()` always returns a loop, so no fallback is needed.
        self._loop = asyncio.new_event_loop()

        # rpc name -> True when the function cannot be called without a
        # positional argument. Filled lazily by `_requires_argument`.
        self._requires_arg_cache: dict = {}

        self.codegen = CodeGen(
            self._rpc_router,
            self._rpc_input_type_map,
            self._rpc_return_type_map,
        )

    def start_dealer_worker(self, worker_id):
        """
        Listen on the device channel and serve requests until interrupted.

        The worker object is obtained from ``get_worker`` and is always closed
        on the way out, whether listening ended through ``KeyboardInterrupt``
        or through any other exception (which is logged, not re-raised).
        """

        def process_message(data: bytes) -> Optional[bytes]:
            # Stays empty unless the frame decodes far enough to reveal the id.
            req_id = ""
            try:
                req_id, func_name, msg = self._encoder.decode(data)
                response = self.handle_msg(func_name, msg)
                return self._encoder.encode([req_id, response])
            except (
                Exception
            ) as inner_exc:  # pragma: no cover pylint: disable=broad-except
                logging.exception(inner_exc)
                # Echo the request id when it is known so that the error frame
                # can be correlated with the request that caused it.
                return self._encoder.encode(
                    [req_id, {"__zerror__server_exception": SERVER_PROCESSING_ERROR}]
                )

        worker = get_worker(config.ZEROMQ_PATTERN, worker_id)
        try:
            worker.listen(self._device_comm_channel, process_message)
        except KeyboardInterrupt:
            logging.warning(
                "Caught KeyboardInterrupt, terminating worker %d", worker_id
            )
        except Exception as exc:  # pylint: disable=broad-except
            logging.exception(exc)
        finally:
            logging.warning("Closing worker %d", worker_id)
            worker.close()

    def handle_msg(self, rpc, msg):
        """
        Dispatch one request and return its result (or an error mapping).

        ``"get_rpc_contract"`` and ``"connect"`` are handled internally. Any
        other name is looked up in the rpc router; unknown names yield a
        ``__zerror__function_not_found`` mapping. Exceptions raised by the
        target function are logged and returned as a
        ``__zerror__server_exception`` mapping holding ``repr(exc)``.

        The message is passed to the function when it is truthy, or when it is
        falsy (``0``, ``False``, ``""``, ``[]`` ...) but the function has a
        required positional parameter. Otherwise the function is called
        without arguments.
        """
        if rpc == "get_rpc_contract":
            return self.generate_rpc_contract(msg)

        if rpc == "connect":
            return "connected"

        if rpc not in self._rpc_router:
            logging.error("Function `%s` not found!", rpc)
            return {"__zerror__function_not_found": f"Function `{rpc}` not found!"}

        func = self._rpc_router[rpc]

        try:
            # Truthiness alone cannot tell "no argument" from a falsy argument,
            # so the function signature breaks the tie for falsy payloads.
            pass_msg = bool(msg) or self._requires_argument(rpc, func)
            result = func(msg) if pass_msg else func()

            # TODO: is this a bottleneck
            if inspect.iscoroutinefunction(func):
                # this is blocking
                result = self._loop.run_until_complete(result)

            return result
        except Exception as exc:  # pylint: disable=broad-except
            logging.exception(exc)
            return {"__zerror__server_exception": repr(exc)}

    def _requires_argument(self, rpc, func) -> bool:
        """Return True when ``func`` has a positional parameter without a default."""
        cached = self._requires_arg_cache.get(rpc)
        if cached is not None:
            return cached

        try:
            parameters = inspect.signature(func).parameters.values()
        except (TypeError, ValueError):
            # Signature not introspectable: keep the argument-less call.
            required = False
        else:
            positional_kinds = (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            )
            required = any(
                param.kind in positional_kinds
                and param.default is inspect.Parameter.empty
                for param in parameters
            )

        self._requires_arg_cache[rpc] = required
        return required

    def generate_rpc_contract(self, msg):
        """
        Generate client code through ``CodeGen.generate_code(msg[0], msg[1])``.

        Any failure (including a malformed ``msg``) is logged and reported as
        a ``__zerror__failed_to_generate_client_code`` mapping.
        """
        try:
            return self.codegen.generate_code(msg[0], msg[1])
        except Exception as exc:  # pylint: disable=broad-except
            logging.exception(exc)
            return {"__zerror__failed_to_generate_client_code": str(exc)}

    @classmethod
    def spawn_worker(
        cls,
        rpc_router: dict,
        device_comm_channel: str,
        encoder: Encoder,
        rpc_input_type_map: dict,
        rpc_return_type_map: dict,
        worker_id: int,
    ) -> None:
        """
        Spawn a worker process.

        A class method is used because the worker process is spawned using multiprocessing.Process.
        The class method is used to avoid pickling the class instance (which can lead to errors).

        The instance is built inside this call and then serves requests via
        ``start_dealer_worker(worker_id)`` until that method returns.
        """
        # give some time for the broker to start
        time.sleep(0.2)

        # `cls` rather than a hard-coded class name so subclasses spawn themselves.
        worker = cls(
            rpc_router,
            device_comm_channel,
            encoder,
            rpc_input_type_map,
            rpc_return_type_map,
        )
        worker.start_dealer_worker(worker_id)