fabiommendes/django-bricks

View on GitHub
src/bricks/ni/converter.py

Summary

Maintainability
C
1 day
Test Coverage
import ast
import collections
import copy
import io
import sys

import astor
from lazyutils import lazy

from bricks.ni.to_ast import to_ast


class FunctionCallTransform:

    def __init__(self, to_name):
        self.to_name = to_name

    def visit(self, transpiler, node):
        node = copy.deepcopy(node)
        node.func.id = self.to_name
        transpiler.visit(node)


def get_default_transforms():
    return {
        'builtins.print': FunctionCallTransform('console.log')
    }


class Scope(collections.MutableMapping):
    """
    An hierarchical dictionary
    """

    def __init__(self, data=None):
        self._data = [{}]
        self._data[0].update(data or {})

    def __getitem__(self, key):
        for D in reversed(self._data):
            try:
                return D[key]
            except KeyError:
                pass
        raise KeyError(key)

    def __setitem__(self, key, value):
        self._data[-1][key] = value

    def __delitem__(self, key):
        raise KeyError('cannot delete keys on Scope dictionaries')

    def __iter__(self):
        used = set()
        for D in reversed(self._data):
            for key in D:
                if key not in used:
                    used.add(key)
                    yield key

    def __len__(self):
        return sum(1 for _ in self)

    def __repr__(self):
        return '%s(%r)' % (self.__class__.__name__, dict(self))

    def push_scope(self, data=None):
        """
        Create a new scope layer.

        Can pre-populate it with a dictionary of data.
        """

        self._data.append({})
        self._data[-1].update(data or {})

    def pop_scope(self):
        """
        Remove the last scope level and return the corresponding dictionary.
        """
        return self._data.pop()


class Transpiler:

    @lazy
    def file(self):
        raise RuntimeError('cannot be accessed outside the write method.')

    def __init__(self, node, dependencies=None, transforms=None):
        self.node = node
        self.indent_level = 0
        self.indent = ''
        self.scope = Scope()
        self.line_ended = False
        self.dependencies = set() if dependencies is None else dependencies
        self.transforms = transforms or get_default_transforms()

    def startline(self, st=''):
        self.file.write(self.indent + st)

    def endline(self, st=';'):
        self.write(st + '\n')
        self.line_ended = True

    def writeln(self, st=''):
        self.file.write(self.indent + st + '\n')

    def indent_up(self):
        self.indent_level += 1
        self.indent = '    ' * self.indent_level

    def indent_down(self):
        self.indent_level -= 1
        self.indent = '    ' * self.indent_level

    def error(self, msg):
        raise ValueError(msg)

    def write(self, value=''):
        self.file.write(value)

    def visit(self, node):
        if type(node) in (list, str, type(None), ast.Store, ast.Load):
            return

        if isinstance(node, (int, str)):
            self.write(str(node))
            return

        name = node.__class__.__name__
        try:
            walker = getattr(self, 'visit_' + name)
        except AttributeError:
            print()
            print('node:', name, file=sys.stderr)
            print('data:', node.__dict__, file=sys.stderr)
            astor.dump(node)
            raise NotImplementedError('node type not supported: %s' % name)
        walker(node)

    def local_var(self, name):
        return None

    def global_var(self, name):
        return None


class CClassTranspiler(Transpiler):
    pass


class CTranspiler(CClassTranspiler):
    pass


def extract_func_args(node):
    """
    Extract argument nodes from a Call element.
    """

    return node.args


class JsTranspiler(CClassTranspiler):

    def visit_body(self, seq, indent=True):
        if indent:
            self.indent_up()

        for obj in seq:
            self.line_ended = False
            self.visit(obj)
            if not self.line_ended:
                self.endline()

        if indent:
            self.indent_down()

    def visit_args(self, args):
        stop = len(args) - 1
        for i, elem in enumerate(args):
            self.visit(elem)
            if i != stop:
                self.write(', ')

    #
    # Function related
    #
    def visit_FunctionDef(self, node):
        # Consistency checks
        if node.decorator_list:
            self.error('function decorators are not supported in JS')

        # Write header
        self.scope.push_scope()
        self.write('function ' + node.name)
        self.visit(node.args)
        self.write(' {\n')

        # Body
        self.visit_body(node.body)

        # Finish
        self.scope.pop_scope()
        self.startline()
        self.endline('}')
        self.scope[node.name] = self.global_var(node.name)

    def visit_Lambda(self, node):
        self.write('function')
        self.visit(node.args)
        self.write('{return ')
        self.visit(node.body)
        self.write('}')

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            name = node.func.id
            if name in self.transforms:
                self.transforms[name].visit(self, node)
                return

        self.visit(node.func)
        self.write('(')
        stop = len(node.args) - 1
        for i, arg in enumerate(node.args):
            self.visit(arg)
            if i != stop:
                self.write(', ')
        self.write(')')

        # We add the function into dependencies set if it was not defined in
        # any previous scope
        if node.func.id not in self.scope:
            self.dependencies.add(node.func.id)

    def visit_arg(self, node):
        self.write(node.arg)

    def visit_arguments(self, node):
        if node.vararg or node.kwonlyargs or node.kw_defaults or \
                node.kwarg or node.defaults:
            self.error('only support simple Python function signatures.')
        self.write('(')
        self.visit_args(node.args)
        self.write(')')

    def visit_Return(self, node):
        self.startline('return ')
        self.visit(node.value)
        self.endline()

    #
    # Assignments
    #
    def visit_Assign(self, node):
        if len(node.targets) == 1:
            name = node.targets[0]
            if name.id not in self.scope:
                self.startline('var ')
                self.scope[name.id] = self.local_var(name.id)
            else:
                self.startline()
            self.visit(name)
            self.write(' = ')
            self.visit(node.value)
            self.endline()
        else:
            raise NotImplementedError

    def visit_AugAssign(self, node):
        self.startline()
        self.visit(node.target)
        self.write(' ')
        self.visit(astor.get_anyop(node.op))
        self.write('= ')
        self.visit(node.value)
        self.endline()

    #
    # Operators
    #
    def visit_BinOp(self, node):
        self.visit(node.left)
        self.write(' %s ' % astor.get_anyop(node.op))
        self.visit(node.right)

    def visit_Compare(self, node):
        self.visit(node.left)
        for op, right in zip(node.ops, node.comparators):
            self.write(' %s ' % astor.get_anyop(op))
            self.visit(right)

    #
    # Loops
    #
    def visit_While(self, node):
        if node.orelse:
            self.error('else clause of while node is not supported')
        self.startline('while (')
        self.visit(node.test)
        self.write(') {\n')
        self.visit_body(node.body)
        self.startline()
        self.endline('}')

    def visit_For(self, node):
        if node.iter.func.id == 'range':
            self.visit_For_range(node)
        else:
            raise NotImplementedError

    def visit_For_range(self, node):
        target = node.target.id
        if target not in self.scope:
            self.scope[target] = self.local_var(target)

        range_args = extract_func_args(node.iter)
        self.startline('for (%s = ' % target)
        if len(range_args) == 0:
            self.error('range has no arguments')
        elif len(range_args) == 1:
            start, end, inc = 0, range_args[0], '%s++' % target
        elif len(range_args) == 2:
            start, end = range_args[0]
            inc = '%s++' % target
        elif len(range_args) == 3:
            start, end, step = range_args
            inc = '%s += %s' % (target, step)
        else:
            self.error('range accepts up to 3 parameters.')

        self.visit(start)
        self.write('; %s < ' % target)
        self.visit(end)
        self.write('; %s)  {\n' % inc)
        self.visit_body(node.body)
        self.endline('}')

    #
    # Conditionals
    #
    def visit_If(self, node, start=True):
        if start:
            self.startline('if (')
        else:
            self.write('if (')
        self.visit(node.test)
        self.write(') {\n')
        self.visit_body(node.body)
        self.startline()
        self.endline('}')

        # Treat the else or else if clauses
        if node.orelse:
            if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
                self.startline('else ')
                self.visit_If(node.orelse[0], False)
            else:
                self.startline('else {\n')
                self.visit_body(node.orelse)
                self.startline()
                self.endline('}')

    #
    # Expressions
    #
    def visit_Expr(self, node):
        self.startline()
        self.visit(node.value)
        self.endline()

    def visit_Pass(self, node):
        self.startline('null')
        self.endline()

    #
    # Primitives
    #
    def visit_Name(self, node):
        self.write(node.id)

    def visit_Num(self, node):
        self.write(str(node.n))

    def visit_Str(self, node):
        self.write(repr(node.s))

    def visit_List(self, node):
        self.write('[')
        self.visit_args(node.elts)
        self.write(']')

    def visit_Dict(self, node):
        self.write('{')
        stop = len(node.keys) - 1
        for i, (k, v) in enumerate(zip(node.keys, node.values)):
            self.visit(k)
            self.write(': ')
            self.visit(v)
            if i != stop:
                self.write(', ')
        self.write('}')

    def visit_NameConstant(self, node):
        value = node.value
        if value is True:
            self.write('true')
        elif value is False:
            self.write('false')
        elif value is None:
            self.write('null')
        else:
            self.error('javascript has no %s constant' % value)

    #
    # Public API and dumpers
    #
    def transpile(self):
        """
        Convert Python AST to JS source.
        """

        file = io.StringIO()
        self.dump(file)
        return file.getvalue()

    def dump(self, file):
        """
        Dump python ast to file.
        """

        self.file = file or sys.stdout
        self.visit(self.node)


def jsdump(pyobj, file=None, **kwargs):
    """
    Dumps JS code in the given file object.

    If no file object is given, dumps to sys.stdout
    """

    transpiler = JsTranspiler(to_ast(pyobj), **kwargs)
    transpiler.dump(file or sys.stdout)


def jstranspile(pyobj, **kwargs):
    """
    Converts python object to Javascript source code.
    """

    transpiler = JsTranspiler(to_ast(pyobj), **kwargs)
    return transpiler.transpile()


def cdump(pyobj, file=None, **kwargs):
    """
    Dumps C code in the given file object.

    If no file object is given, dumps to sys.stdout
    """

    transpiler = CTranspiler(to_ast(pyobj), **kwargs)
    transpiler.dump(file or sys.stdout)


def ctranspile(pyobj, **kwargs):
    """
    Transpile python code or object to C source.
    """

    transpiler = CTranspiler(to_ast(pyobj), **kwargs)
    return transpiler.transpile()


def dump(obj, lang, file=None, **kwargs):
    """
    Dump content of object in file as source code for the given language.
    """

    if lang in ('js', 'javascript'):
        return jsdump(obj, file=file, **kwargs)
    elif lang in ('c'):
        return cdump(obj, file=file, **kwargs)
    else:
        raise ValueError('language not supported: %r' % lang)


def transpile(obj, lang, file=None, **kwargs):
    """
    Transpile python source/object into the choosen language.

    Supported languages are JS, C.
    """

    if lang in ('js', 'javascript'):
        return jstranspile(obj, file=file, **kwargs)
    elif lang in ('c'):
        return ctranspile(obj, file=file, **kwargs)
    else:
        raise ValueError('language not supported: %r' % lang)


class VarType:
    name = None


class NamedVarType(VarType):

    def __init__(self, name):
        self.name = name


#
# Still testing...!
#
from math import sqrt


def entropy(data: 'double[]', N: int) -> 'double':
    S = 0.0
    for i in range(N):
        x = data[i]
        if x == 0.0:
            return x
        else:
            S += - x * log(x)
    return S


jsdump(entropy)


def func(x, y):
    z = sqrt(x + 1e10)
    L = [1, 2, 3]
    x = {'sdfs': 2}
    z += x

    12
    while True:
        if x == 2:
            print(2)
        elif x == 3 == 3:
            print(4)
        else:
            x = lambda x: 2 ** 2
    return z + y


D = set()
jsdump(func, dependencies=D)
print(D)


def jsfunc(func):
    """
    Marks a Python function as a js-compatible one
    """

    if not hasattr(func, '__jssource__'):
        func.__jssource__ = jstranspile(func)

    return func