# zero/codegen/codegen.py
import inspect
# from pydantic import BaseModel
class CodeGen:
    """Generate the source code of a Python ``RpcClient`` class.

    The generator works from three mappings keyed by RPC function name:

    * ``rpc_router``: each value is a sequence whose first item is the
      registered function.
    * ``rpc_input_type_map``: the function's input type, or ``None`` when the
      function takes no input.
    * ``rpc_return_type_map``: the function's return type.
    """

    def __init__(self, rpc_router, rpc_input_type_map, rpc_return_type_map):
        """Store the routing and type tables used by the generator."""
        self._rpc_router = rpc_router
        self._rpc_input_type_map = rpc_input_type_map
        self._rpc_return_type_map = rpc_return_type_map
        # Names from ``typing`` seen while rendering type strings; consumed by
        # ``get_imports``.
        self._typing_imports = set()

    def generate_code(self, host="localhost", port=5559):
        """Return the source of a client module exposing every routed function.

        Args:
            host: Host name written into the generated ``ZeroClient(...)`` call.
            port: Port written into the generated ``ZeroClient(...)`` call.

        Returns:
            The generated module source as a single string. Each routed
            function becomes a method that forwards to
            ``self._zero_client.call``.
        """
        lines = [
            "# Generated by Zero",
            "# import types as per needed",
            "from zero import ZeroClient",
            f'zero_client = ZeroClient("{host}", {port})',
            "class RpcClient:",
            "    def __init__(self, zero_client: ZeroClient):",
            "        self._zero_client = zero_client",
        ]

        for func_name in self._rpc_router:
            # Functions without an input forward a literal ``None``.
            if self._rpc_input_type_map[func_name] is None:
                argument = None
            else:
                argument = self.get_function_input_param_name(func_name)

            # The generated methods must be indented to sit inside RpcClient,
            # otherwise the emitted module is not valid Python.
            lines += [
                "",
                f"    {self.get_function_str(func_name)}",
                f'        return self._zero_client.call("{func_name}", {argument})',
            ]

        # self.generate_data_classes() TODO: next feature
        return "\n".join(lines) + "\n"

    def get_imports(self):
        """Return the ``from typing import ...`` line for collected names.

        Returns:
            An import statement listing the collected names in sorted order,
            or an empty string when no ``typing`` name has been collected
            (a bare ``from typing import`` would be a syntax error).
        """
        if not self._typing_imports:
            return ""
        return f"from typing import {', '.join(sorted(self._typing_imports))}"

    def get_input_type_str(self, func_name: str):  # pragma: no cover
        """Return the ``": <type>"`` annotation suffix for the input of *func_name*.

        Returns an empty string when the function has no input type.
        """
        input_type = self._rpc_input_type_map[func_name]
        if input_type is None:
            return ""
        return ": " + self._type_name(input_type)

    def get_return_type_str(self, func_name: str):  # pragma: no cover
        """Return the name of the return type of *func_name*.

        Returns ``"None"`` when no return type is recorded, mirroring the
        ``None`` handling of ``get_input_type_str``.
        """
        return_type = self._rpc_return_type_map[func_name]
        if return_type is None:
            return "None"
        return self._type_name(return_type)

    def _type_name(self, type_hint):  # pragma: no cover
        """Return the display name of *type_hint*, recording ``typing`` names."""
        if type_hint.__module__ == "typing":
            name = type_hint._name
            self._typing_imports.add(name)
            return name
        return type_hint.__name__

    def get_function_str(self, func_name: str):
        """Return the ``def`` line of *func_name* rewritten as a method.

        ``self`` is inserted as the first parameter, a leading ``async`` is
        dropped, and surrounding whitespace is removed. Only the first
        physical line of the signature is used.

        Raises:
            ValueError: If the function's source contains no ``def`` line.
        """
        func = self._rpc_router[func_name][0]
        def_line = self._find_def_line(func)

        # The generated client method is synchronous.
        if def_line.startswith("async "):
            def_line = def_line[len("async "):]

        # A comma is only needed when another parameter follows ``self``.
        receiver = "self, " if inspect.signature(func).parameters else "self"

        # Replace only the first occurrence so a default value or annotation
        # mentioning ``name(`` is left alone.
        return def_line.replace(f"{func_name}(", f"{func_name}({receiver}", 1)

    def get_function_input_param_name(self, func_name: str):
        """Return the name of the first parameter of *func_name*.

        Returns an empty string when the function takes no parameters.
        """
        func = self._rpc_router[func_name][0]
        # Use the real signature instead of splitting the source text, which
        # broke on un-annotated parameters and default values.
        return next(iter(inspect.signature(func).parameters), "")

    @staticmethod
    def _find_def_line(func):
        """Return the stripped ``def``/``async def`` line from *func*'s source."""
        source_lines = inspect.getsourcelines(func)[0]
        for line in source_lines:
            stripped = line.strip()
            # Match the statement itself, not any line that merely contains
            # the substring "def" (e.g. a ``@default_...`` decorator).
            if stripped.startswith(("def ", "async def ")):
                return stripped
        raise ValueError(f"could not find a def line for {func!r}")

    # def generate_data_classes(self):
    #     code = ""
    #     for func_name in self._rpc_input_type_map:
    #         input_class = self._rpc_input_type_map[func_name]
    #         if input_class and is_pydantic(input_class):
    #             code += inspect.getsource(input_class)