Ananto30/zero

View on GitHub
zero/codegen/codegen.py

Summary

Maintainability
A
0 mins
Test Coverage
import inspect

# from pydantic import BaseModel


class CodeGen:
    def __init__(self, rpc_router, rpc_input_type_map, rpc_return_type_map):
        self._rpc_router = rpc_router
        self._rpc_input_type_map = rpc_input_type_map
        self._rpc_return_type_map = rpc_return_type_map
        self._typing_imports = set()

    def generate_code(self, host="localhost", port=5559):
        code = f"""# Generated by Zero
# import types as per needed

from zero import ZeroClient


zero_client = ZeroClient("{host}", {port})


class RpcClient:
    def __init__(self, zero_client: ZeroClient):
        self._zero_client = zero_client
"""

        for func_name in self._rpc_router:
            code += f"""
    {self.get_function_str(func_name)}
        return self._zero_client.call("{func_name}", {
            None if self._rpc_input_type_map[func_name] is None
        else self.get_function_input_param_name(func_name)
        })
"""
        # self.generate_data_classes()  TODO: next feature
        return code

    def get_imports(self):
        return f"from typing import {', '.join(i for i in self._typing_imports)}"

    def get_input_type_str(self, func_name: str):  # pragma: no cover
        if self._rpc_input_type_map[func_name] is None:
            return ""
        if self._rpc_input_type_map[func_name].__module__ == "typing":
            type_name = self._rpc_input_type_map[func_name]._name
            self._typing_imports.add(type_name)
            return ": " + type_name
        return ": " + self._rpc_input_type_map[func_name].__name__

    def get_return_type_str(self, func_name: str):  # pragma: no cover
        if self._rpc_return_type_map[func_name].__module__ == "typing":
            type_name = self._rpc_return_type_map[func_name]._name
            self._typing_imports.add(type_name)
            return type_name
        return self._rpc_return_type_map[func_name].__name__

    def get_function_str(self, func_name: str):
        func = self._rpc_router[func_name][0]
        func_lines = inspect.getsourcelines(func)[0]
        def_line = [line for line in func_lines if "def" in line][0]

        # put self after the first (
        def_line = def_line.replace(f"{func_name}(", f"{func_name}(self").replace(
            "async ", ""
        )

        # if there is input, add comma after self
        if self._rpc_input_type_map[func_name]:
            def_line = def_line.replace(f"{func_name}(self", f"{func_name}(self, ")

        return def_line.replace("\n", "")

    def get_function_input_param_name(self, func_name: str):
        func = self._rpc_router[func_name][0]
        func_lines = inspect.getsourcelines(func)[0]
        def_line = [line for line in func_lines if "def" in line][0]
        params = def_line.split("(")[1].split(")")[0]
        return params.split(":")[0].strip()

    # def generate_data_classes(self):
    #     code = ""
    #     for func_name in self._rpc_input_type_map:
    #         input_class = self._rpc_input_type_map[func_name]
    #         if input_class and is_pydantic(input_class):
    #             code += inspect.getsource(input_class)