examples/experiments/if-explode.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"# Convert the pattern that appears in ADT methods:\n",
"#\n",
"# if self.is_state1:\n",
"# ...\n",
"# elif self.is_state2:\n",
"# ...\n",
"# else:\n",
"# ...\n",
"# \n",
"# Into n implementations and distributed among different sub-classes"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"import inspect\n",
"import ast\n",
"import types\n",
"from textwrap import dedent\n",
"from sidekick import Union, opt, record\n",
"\n",
"def example(foo, x, *args, y=None, **kwargs):\n",
" if foo.is_just:\n",
" first_expr()\n",
" second_expr()\n",
" elif foo.is_nothing:\n",
" some_expr()\n",
" else:\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"class Maybe(Union):\n",
" Just = opt(object)\n",
" Nothing = opt()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"def extract_case(test):\n",
" \"\"\"\n",
" Extracts (self variable name, case name) from test AST.\n",
" \n",
" AST represents code of the form: ``this.is_case``\n",
" \"\"\"\n",
" if not isinstance(test, ast.Attribute):\n",
" raise ValueError('case is not a simple attribute access')\n",
"\n",
" # Extract attr\n",
" attr = test.attr\n",
" if not attr.startswith('is_'):\n",
" raise ValueError(f'not a valid case attribute name: {attr}')\n",
" case = attr[3:]\n",
"\n",
" # Extract self name\n",
" self_expr = test.value\n",
" if not isinstance(self_expr, ast.Name):\n",
" name = self_expr.__class__.__name__\n",
" raise ValueError(f'invalid self expression: {name}')\n",
" self_name = self_expr.id\n",
" return (self_name, case)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"def extract_method(body, func, py_func):\n",
" \"\"\"\n",
" Extracts a method from a block of instructions, the original function\n",
" expression and the original python function implementation.\n",
" \"\"\"\n",
" func_ast = ast.FunctionDef(\n",
" decorator_list=[], \n",
" name=func.name, args=func.args, body=body, \n",
" lineno=func.lineno, col_offset=func.col_offset,\n",
" )\n",
" mod = ast.Module(body=[func_ast])\n",
" skip = py_func.__code__.co_firstlineno - 1\n",
" mod = ast.Module(body=[func_ast])\n",
" ast.increment_lineno(mod, skip)\n",
" code = compile(mod, py_func.__code__.co_filename, 'exec', optimize=1)\n",
" ns = {}\n",
" exec(code, py_func.__globals__, ns)\n",
" return ns[func.name]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"def methods_map(function):\n",
" \"\"\"\n",
" Convert case function into a mapping from cases to implementations.\n",
" \"\"\"\n",
" methods = {}\n",
" self_vars = set()\n",
"\n",
" tree = ast.parse(dedent(inspect.getsource(function)))\n",
" func = tree.body[0]\n",
" body = func.body\n",
"\n",
" # Check if body consists of a single If statement\n",
" if not (len(body) == 1 and isinstance(body[0], ast.If)):\n",
" raise ValueError('function body is not a single if statement')\n",
"\n",
" if_block = body[0]\n",
"\n",
" while if_block:\n",
" self, case = extract_case(if_block.test)\n",
" self_vars.add(self)\n",
" method = extract_method(if_block.body, func, function)\n",
" method.__name__ = f'{method.__name__}[{case}]'\n",
" method.__qualname__ = f'{case}.{method.__qualname__}'\n",
" methods[case] = method\n",
"\n",
" orelse = if_block.orelse\n",
" if len(orelse) == 1 and isinstance(orelse[0], ast.If):\n",
" if_block = orelse[0]\n",
" elif not orelse:\n",
" break\n",
" else:\n",
" method = extract_method(orelse, func, function)\n",
" method.__name__ = f'{method.__name__}[else]'\n",
" method.__qualname__ = f'else.{method.__qualname__}'\n",
" methods['else'] = method\n",
" break\n",
"\n",
" if len(self_vars) != 1:\n",
" raise ValueError(f'inconsistent self variables: {self_vars}')\n",
"\n",
" return methods"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"def update_union(adt, attr, force=False):\n",
" func = getattr(adt, attr)\n",
" methods = methods_map(func)\n",
" generic = methods.pop('else', None)\n",
" cases = dict(adt._meta.cases)\n",
" case_names = {x.lower(): x for x in cases}\n",
"\n",
" method_map = {case_names[case]: f for case, f in methods.items()}\n",
" case_classes = set(cases.values())\n",
"\n",
" for name, func in method_map.items():\n",
" mcs = cases[name]\n",
" case_classes.remove(mcs)\n",
" if attr not in mcs.__dict__ or force:\n",
" setattr(mcs, attr, method_map[name])\n",
"\n",
" if generic:\n",
" for mcs in case_classes:\n",
" if attr not in mcs.__dict__ or force:\n",
" setattr(mcs, attr, generic)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"def case_method(mcs, name=None, force=False):\n",
" # Decorating a method\n",
" if isinstance(mcs, types.FunctionType):\n",
" mcs.is_case_method = True\n",
" return mcs\n",
" \n",
" def decorator(func):\n",
" attr = name or func.__name__\n",
" if hasattr(mcs, attr) and not force:\n",
" raise TypeError(f'{mcs.__name__} already has method {attr}')\n",
" setattr(mcs, attr, func)\n",
" update_union(mcs, attr)\n",
" return func\n",
" \n",
" return decorator\n",
"\n",
"\n",
"def optimize_case_methods(mcs):\n",
" for attr in dir(mcs):\n",
" value = getattr(mcs, attr, None)\n",
" if isinstance(value, types.FunctionType) and getattr(value, 'is_case_method', False):\n",
" update_union(mcs, attr)\n",
" return mcs"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"@case_method(Maybe, force=True)\n",
"def foo(self):\n",
" if self.is_just:\n",
" return self.value + 2\n",
" else:\n",
" return 2\n",
" \n",
"@case_method(Maybe, force=True)\n",
"def bar(self, x):\n",
" if self.is_nothing:\n",
" return x\n",
" elif self.is_just:\n",
" return 42\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"class Slist1(Union):\n",
" Cons = opt(head=object, tail=object)\n",
" Nil = opt()\n",
" \n",
" def size(self):\n",
" if self.is_cons:\n",
" return 1 + self.tail.size()\n",
" else:\n",
" return 0"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"@optimize_case_methods\n",
"class Slist2(Union):\n",
" Cons = opt(head=object, tail=object)\n",
" Nil = opt()\n",
" \n",
" @case_method\n",
" def size(self):\n",
" if self.is_cons:\n",
" return 1 + self.tail.size()\n",
" else:\n",
" return 0"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"class Slist3(Union):\n",
" class Cons(this):\n",
" head: object\n",
" tail: object \n",
" \n",
" def size(self):\n",
" return 1 + self.tail.size()\n",
" \n",
" class Nil(this):\n",
" def size(self):\n",
" return 0\n",
" \n",
" def generic(self):\n",
" x = 0\n",
" while self.is_cons:\n",
" x += 1\n",
" self = self.tail\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"@optimize_case_methods\n",
"class Slist4(Union):\n",
" class Cons(this):\n",
" head: object\n",
" tail: object\n",
" \n",
" class Nil(this):\n",
" pass\n",
" \n",
" @case_method\n",
" def size(self):\n",
" if self.is_cons:\n",
" return 1 + self.tail.size()\n",
" else:\n",
" return 0"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"class Slist5:\n",
" __slots__ = ()\n",
" \n",
"class Cons(Slist5):\n",
" __slots__ = ('head', 'tail')\n",
" \n",
" def __init__(self, head, tail):\n",
" self.head = head\n",
" self.tail = tail\n",
" \n",
" def size(self):\n",
" return 1 + self.tail.size()\n",
"\n",
"class Nil(Slist5):\n",
" __slots__ = ()\n",
"\n",
" def size(self):\n",
" return 0\n",
" \n",
"Slist5.Cons = Cons\n",
"Slist5.Nil = Nil()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"class Slist6:\n",
" __slots__ = ()\n",
" is_cons = False\n",
" is_nil = False\n",
" \n",
" def size(self):\n",
" if self.is_cons:\n",
" return 1 + self.tail.size()\n",
" else:\n",
" return 0\n",
" \n",
"class Cons(Slist6):\n",
" __slots__ = ('head', 'tail')\n",
" is_cons = True\n",
" \n",
" def __init__(self, head, tail):\n",
" self.head = head\n",
" self.tail = tail\n",
" \n",
"class Nil(Slist6):\n",
" __slots__ = ()\n",
" is_nil = True\n",
"\n",
"Slist6.Cons = Cons\n",
"Slist6.Nil = Nil()\n",
"Slist6._meta = record(cases={'Cons': Cons, 'Nil': Nil})\n",
"\n",
"@case_method(Slist6)\n",
"def size2(self):\n",
" if self.is_cons:\n",
" return 1 + self.tail.size2()\n",
" else:\n",
" return 0"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"hideCode": false,
"hideOutput": true,
"hidePrompt": false
},
"outputs": [],
"source": [
"def to_list(mcs, data):\n",
" return mcs.Cons(data[0], to_list(mcs, data[1:])) if data else mcs.Nil"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"n = 5\n",
"L1 = to_list(Slist1, list(range(n)))\n",
"L2 = to_list(Slist2, list(range(n)))\n",
"L3 = to_list(Slist3, list(range(n)))\n",
"L4 = to_list(Slist4, list(range(n)))\n",
"L5 = to_list(Slist5, list(range(n)))\n",
"L6 = to_list(Slist6, list(range(n)))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.12 µs ± 201 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"1.87 µs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"2.02 µs ± 282 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"1.82 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"907 ns ± 57.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"1.12 µs ± 73.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"879 ns ± 28.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
]
}
],
"source": [
"%timeit -n 100000 L1.size()\n",
"%timeit -n 100000 L2.size()\n",
"%timeit -n 100000 L3.size()\n",
"%timeit -n 100000 L4.size()\n",
"%timeit -n 100000 L5.size()\n",
"%timeit -n 100000 L6.size()\n",
"%timeit -n 100000 L6.size2()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"227 ns ± 8.83 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"257 ns ± 17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"255 ns ± 5.75 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"260 ns ± 22.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"56.1 ns ± 2.23 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"63.4 ns ± 3.23 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
]
}
],
"source": [
"%timeit -n 1000000 L1.head\n",
"%timeit -n 1000000 L2.head\n",
"%timeit -n 1000000 L3.head\n",
"%timeit -n 1000000 L4.head\n",
"%timeit -n 1000000 L5.head\n",
"%timeit -n 1000000 L6.head"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"236 ns ± 25.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"240 ns ± 13.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"261 ns ± 11.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"248 ns ± 14 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"56.5 ns ± 5.28 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"57.9 ns ± 5.77 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
]
}
],
"source": [
"%timeit -n 1000000 L1.tail\n",
"%timeit -n 1000000 L2.tail\n",
"%timeit -n 1000000 L3.tail\n",
"%timeit -n 1000000 L4.tail\n",
"%timeit -n 1000000 L5.tail\n",
"%timeit -n 1000000 L6.tail"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"73.1 ns ± 12.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"75.1 ns ± 7.48 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"72.4 ns ± 3.61 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"73.7 ns ± 8.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"68.1 ns ± 7.82 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"71.8 ns ± 8.5 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
]
}
],
"source": [
"%timeit -n 1000000 L1.size\n",
"%timeit -n 1000000 L2.size\n",
"%timeit -n 1000000 L3.size\n",
"%timeit -n 1000000 L4.size\n",
"%timeit -n 1000000 L5.size\n",
"%timeit -n 1000000 L6.size"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 12 0 LOAD_CONST 1 (1)\n",
" 2 LOAD_FAST 0 (self)\n",
" 4 LOAD_ATTR 0 (tail)\n",
" 6 LOAD_METHOD 1 (size)\n",
" 8 CALL_METHOD 0\n",
" 10 BINARY_ADD\n",
" 12 RETURN_VALUE\n",
"\n",
" 31 0 LOAD_CONST 1 (1)\n",
" 2 LOAD_FAST 0 (self)\n",
" 4 LOAD_ATTR 0 (tail)\n",
" 6 LOAD_METHOD 1 (size2)\n",
" 8 CALL_METHOD 0\n",
" 10 BINARY_ADD\n",
" 12 RETURN_VALUE\n"
]
}
],
"source": [
"import dis\n",
"dis.dis(Slist5.Cons.size)\n",
"\n",
"print()\n",
"dis.dis(Slist6.Cons.size2)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"f2 = Slist2.Nil.size\n",
"f3 = Slist3.Nil.size\n",
"c2 = f2.__code__\n",
"c3 = f3.__code__"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"d2 = {attr: getattr(c2, attr) for attr in dir(c2) if attr.startswith('co_')}\n",
"d3 = {attr: getattr(c3, attr) for attr in dir(c3) if attr.startswith('co_')}"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'co_filename': ('<ipython-input-11-b95d254fa8b0>',\n",
" '<ipython-input-12-027eeb8dd036>'),\n",
" 'co_firstlineno': (6, 10),\n",
" 'co_lnotab': (b'\\x00\\x05', b'\\x00\\x01')}"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"{k1: (v1, v2) for (k1, v1), (k2, v2) in zip(d2.items(), d3.items()) if v2 != v1}"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 1 0 LOAD_CONST 1 (1)\n",
" 2 LOAD_FAST 0 (self)\n",
" 4 LOAD_ATTR 0 (tail)\n",
" 6 LOAD_METHOD 1 (size)\n",
" 8 CALL_METHOD 0\n",
" 10 BINARY_ADD\n",
" 12 RETURN_VALUE\n"
]
}
],
"source": [
"f = lambda self: 1 + self.tail.size()\n",
"Slist3.Cons.size = f \n",
"dis.dis(f)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"ns2 = Slist2.Cons.__dict__\n",
"ns3 = Slist3.Cons.__dict__\n",
"ns4 = Slist4.Cons.__dict__"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"73.2 ns ± 6.69 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"74.8 ns ± 3.08 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"105 ns ± 21 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"\n",
"75.1 ns ± 6.56 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"71.6 ns ± 2.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n",
"80 ns ± 8.92 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
]
}
],
"source": [
"key = 'size'\n",
"%timeit -n 1000000 ns2[key]\n",
"%timeit -n 1000000 ns3[key]\n",
"%timeit -n 1000000 ns4[key]\n",
"\n",
"print()\n",
"key = 'head'\n",
"%timeit -n 1000000 ns2[key]\n",
"%timeit -n 1000000 ns3[key]\n",
"%timeit -n 1000000 ns4[key]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'n2' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-29-18229c12e6af>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mn2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mns3\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'n2' is not defined"
]
}
],
"source": [
"n2.keys(), ns3.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"exec?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"f = lambda x: x\n",
"x = 42.0\n",
"%timeit f(...)\n",
"%timeit ... if x.real else ...\n",
"%timeit ..."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_code_all_hidden": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}