examples/experiments/function-calling.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"%load_ext cython"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"NOT_GIVEN = object()\n",
"impl = add_lambda = lambda x, y: x + y\n",
"\n",
"# Currying techniques\n",
"def curry2_with_exception(impl):\n",
"    \"\"\"Curry a 2-argument function, detecting partial calls via TypeError.\n",
"\n",
"    The wrapper first tries ``impl(*args)``. If that raises TypeError and\n",
"    exactly one positional argument was given, it returns a one-argument\n",
"    function that supplies the second argument. In every other case the\n",
"    TypeError propagates to the caller.\n",
"    \"\"\"\n",
"    def func(*args):\n",
"        try:\n",
"            return impl(*args)\n",
"        except TypeError:\n",
"            if len(args) == 1:\n",
"                # Bind the argument actually received instead of relying on\n",
"                # whatever global ``x`` happens to exist in the kernel.\n",
"                x = args[0]\n",
"                return lambda y: impl(x, y)\n",
"            raise\n",
"    return func\n",
"\n",
"\n",
"def curry2_with_exception_kw(impl):\n",
"    \"\"\"Keyword-aware variant of the TypeError-based 2-argument currying.\n",
"\n",
"    The wrapper first tries ``impl(*args, **kwargs)``. If that raises\n",
"    TypeError and exactly one positional argument was given, it returns a\n",
"    function awaiting the second argument (plus optional extra keywords).\n",
"    Otherwise the TypeError propagates.\n",
"    \"\"\"\n",
"    def func(*args, **kwargs):\n",
"        try:\n",
"            return impl(*args, **kwargs)\n",
"        except TypeError:\n",
"            if len(args) == 1:\n",
"                # Bind the argument actually received instead of relying on\n",
"                # whatever global ``x`` happens to exist in the kernel.\n",
"                x = args[0]\n",
"                return lambda y, **kw: impl(x, y, **kwargs, **kw)\n",
"            raise\n",
"    return func\n",
"\n",
"\n",
"def curry2_with_extra_args(impl):\n",
"    \"\"\"Curry a 2-argument function using ``(x, *rest)`` to detect partial calls.\n",
"\n",
"    With only ``x`` the wrapper returns a function awaiting ``y``; with one\n",
"    extra argument it calls ``impl(x, y)`` directly. More than one extra\n",
"    argument fails while unpacking (ValueError).\n",
"    \"\"\"\n",
"    def func(x, *rest):\n",
"        if not rest:\n",
"            return lambda y: impl(x, y)\n",
"        (y,) = rest\n",
"        return impl(x, y)\n",
"    return func\n",
" \n",
" \n",
"def curry2_with_not_given(impl):\n",
"    \"\"\"Curry a 2-argument function using the ``NOT_GIVEN`` sentinel default.\n",
"\n",
"    If ``y`` is supplied the wrapper calls ``impl(x, y)``; if it is left at\n",
"    the sentinel, a function awaiting ``y`` is returned instead.\n",
"    \"\"\"\n",
"    def func(x, y=NOT_GIVEN):\n",
"        if y is not NOT_GIVEN:\n",
"            return impl(x, y)\n",
"        return lambda y: impl(x, y)\n",
"    return func\n",
"\n",
"\n",
"def curry2_with_args(impl):\n",
"    \"\"\"Curry a 2-argument function by inspecting the number of positionals.\n",
"\n",
"    Two arguments call ``impl`` directly, one argument returns a function\n",
"    awaiting the second, and any other count raises TypeError.\n",
"    \"\"\"\n",
"    def func(*args):\n",
"        count = len(args)\n",
"        if count == 1:\n",
"            (x,) = args\n",
"            return lambda y: impl(x, y)\n",
"        if count == 2:\n",
"            return impl(*args)\n",
"        raise TypeError('invalid number of args')\n",
"    return func\n",
"\n",
"\n",
"def curry(n, func):\n",
"    \"\"\"Return a curried version of ``func`` that fires once ``n`` positionals arrive.\n",
"\n",
"    Each call appends its positional arguments to those collected so far.\n",
"    As soon as at least ``n`` have accumulated, ``func`` is called with all\n",
"    of them; until then a new collecting function is returned.\n",
"    \"\"\"\n",
"    def collector(bound):\n",
"        def step(*args):\n",
"            total = bound + args\n",
"            if len(total) < n:\n",
"                return collector(total)\n",
"            return func(*total)\n",
"        return step\n",
"    return collector(())\n",
"\n",
"\n",
"def curry_kw(n, func):\n",
"    \"\"\"Keyword-aware ``curry``: fire ``func`` once ``n`` positionals have arrived.\n",
"\n",
"    Positional and keyword arguments are accumulated across calls. When the\n",
"    same keyword is given more than once, the most recent value wins, both\n",
"    while collecting and on the final call.\n",
"    \"\"\"\n",
"    def incomplete_factory(arity, used_args, used_kwargs):\n",
"        def step(*args, **kwargs):\n",
"            all_args = used_args + args\n",
"            # Merge instead of double-splatting so a keyword repeated on the\n",
"            # final call overrides the earlier value rather than raising.\n",
"            all_kwargs = {**used_kwargs, **kwargs}\n",
"            if len(all_args) >= arity:\n",
"                return func(*all_args, **all_kwargs)\n",
"            return incomplete_factory(arity, all_args, all_kwargs)\n",
"        return step\n",
"    return incomplete_factory(n, (), {})"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"class wrapped:\n",
"    \"\"\"Minimal callable wrapper that forwards any arguments to ``func``.\"\"\"\n",
"\n",
"    __slots__ = ('func',)\n",
"\n",
"    def __init__(self, func):\n",
"        self.func = func\n",
"\n",
"    def __call__(self, *positional, **named):\n",
"        target = self.func\n",
"        return target(*positional, **named)\n",
" \n",
"class wrapped2:\n",
"    \"\"\"Callable wrapper with a fixed two-argument ``(x, y)`` signature.\"\"\"\n",
"\n",
"    __slots__ = ('func',)\n",
"\n",
"    def __init__(self, func):\n",
"        self.func = func\n",
"\n",
"    def __call__(self, x, y):\n",
"        target = self.func\n",
"        return target(x, y)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"%%cython\n",
"\n",
"cimport cython\n",
"\n",
"NOT_GIVEN = object()\n",
"\n",
" \n",
"@cython.freelist(8)\n",
"cdef class cy:\n",
"    \"\"\"Curried callable implemented as a Cython extension type.\n",
"\n",
"    ``arity`` is the number of positional arguments needed before ``func``\n",
"    is called. A call with fewer (but at least one) arguments returns a new\n",
"    ``cy`` awaiting the remaining ones; a call with no arguments raises\n",
"    TypeError.\n",
"    \"\"\"\n",
"    cdef object func\n",
"    cdef int arity\n",
"\n",
"    def __init__(self, arity, func):\n",
"        self.func = func\n",
"        self.arity = arity\n",
"\n",
"    def __call__(self, *args):\n",
"        cdef int n = len(args)\n",
"        # Honour the configured arity instead of a hard-coded 2.\n",
"        if n >= self.arity:\n",
"            return self.func(*args)\n",
"        elif n >= 1:\n",
"            f = self.func\n",
"            # Bind what we have and keep currying over the missing arguments.\n",
"            return cy(self.arity - n, lambda *rest: f(*(args + rest)))\n",
"        else:\n",
"            raise TypeError('invalid number of arguments')\n",
" \n",
"def cy_fn(int arity, impl):\n",
"    \"\"\"Closure-based currying helper compiled with Cython.\n",
"\n",
"    The returned function calls ``impl`` directly when at least ``arity``\n",
"    positional arguments are given. Otherwise it returns a function that\n",
"    takes the remaining positionals and then calls ``impl``.\n",
"    \"\"\"\n",
"    def foo(*args, **kwargs):\n",
"        cdef int n = len(args)\n",
"        if n >= arity:\n",
"            return impl(*args, **kwargs)\n",
"        else:\n",
"            # Earlier arguments must come first, and keywords given on the\n",
"            # partial call are kept for the final one.\n",
"            return lambda *xs: impl(*(args + xs), **kwargs)\n",
"    return foo"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"hideCode": false,
"hideOutput": true,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"124 ns ± 20.4 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)\n",
"72.8 ns ± 13.2 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)\n",
"202 ns ± 57.6 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)\n"
]
}
],
"source": [
"from operator import add\n",
"x = 1\n",
"\n",
"# Baseline\n",
"%timeit -n 100000 -r 100 (1).__add__(2)\n",
"%timeit -n 100000 -r 100 add(1, 2)\n",
"%timeit -n 100000 -r 100 impl(1, 2)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"259 ns ± 28.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"591 ns ± 56.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"553 ns ± 19.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"670 ns ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"504 ns ± 59.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"615 ns ± 121 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"209 ns ± 13.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"274 ns ± 42.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"467 ns ± 56.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"325 ns ± 39.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
]
}
],
"source": [
"# Function based currying techniques for 2-arg functions\n",
"f1 = curry2_with_extra_args(impl)\n",
"f2 = curry2_with_not_given(impl)\n",
"f3 = curry2_with_args(impl)\n",
"f4 = curry2_with_exception(impl)\n",
"f5 = curry2_with_exception_kw(impl)\n",
"f6 = cy(2, impl)\n",
"f7 = cy_fn(2, impl)\n",
"f8 = wrapped(impl)\n",
"f9 = wrapped2(impl)\n",
"\n",
"%timeit -n 100000 impl(1, 2)\n",
"%timeit -n 100000 f1(1, 2)\n",
"%timeit -n 100000 f2(1, 2)\n",
"%timeit -n 100000 f3(1, 2)\n",
"%timeit -n 100000 f4(1, 2)\n",
"%timeit -n 100000 f5(1, 2)\n",
"%timeit -n 100000 f6(1, 2)\n",
"%timeit -n 100000 f7(1, 2)\n",
"%timeit -n 100000 f8(1, 2)\n",
"%timeit -n 100000 f9(1, 2)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"834 ns ± 23.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
"768 ns ± 192 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
]
}
],
"source": [
"# Generic currying\n",
"f1 = curry(2, impl)\n",
"f2 = curry_kw(2, impl)\n",
"\n",
"%timeit -n 100000 f1(1, 2)\n",
"%timeit -n 100000 f2(1, 2)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'cytoolz'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-28-70cd7f04714c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Other libs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtoolz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcytoolz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfuncy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mf1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoolz\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurry\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mf2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcytoolz\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurry\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'cytoolz'"
]
}
],
"source": [
"# Other libs\n",
"import toolz, cytoolz, funcy\n",
"\n",
"f1 = toolz.curry(impl)\n",
"f2 = cytoolz.curry(impl)\n",
"f3 = funcy.curry(impl, 2)\n",
"f4 = funcy.autocurry(impl, 2)\n",
"\n",
"%timeit -n 100000 f1(1, 2)\n",
"%timeit -n 100000 f2(1, 2)\n",
"%timeit -n 100000 f3(1)(2)\n",
"%timeit -n 100000 f4(1, 2)"
]
},
{
"cell_type": "code",
"execution_count": 462,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"209 ns ± 11.2 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)\n",
"213 ns ± 13.6 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)\n"
]
}
],
"source": [
"# Cython accelerators\n",
"f1 = cy(2, impl)\n",
"f2 = cy_fn(2, impl)\n",
"\n",
"%timeit -n 100000 -r 100 f1(1, 2)\n",
"%timeit -n 100000 -r 100 f2(1, 2)"
]
},
{
"cell_type": "code",
"execution_count": 463,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [],
"source": [
"def foo(f):\n",
"    \"\"\"Re-wrap a manually curried ``f`` as two nested single-argument calls.\n",
"\n",
"    ``f(x)`` is evaluated as soon as the first argument arrives; the result\n",
"    is then applied to ``y`` on the second call.\n",
"    \"\"\"\n",
"    def take_x(x):\n",
"        applied = f(x)\n",
"        return lambda y: applied(y)\n",
"    return take_x\n",
"\n",
"\n",
"def foo2(f):\n",
"    \"\"\"Generator form of two-step application: values arrive via ``send``.\n",
"\n",
"    The first sent value is passed to ``f``; the second is passed to the\n",
"    result. The final value is delivered as the generator's return value\n",
"    (``StopIteration.value``).\n",
"    \"\"\"\n",
"    first = yield\n",
"    stage = f(first)\n",
"    second = yield\n",
"    return stage(second)\n",
" \n",
"foo(curry(2, add))(1)(2)\n",
"\n",
"\n",
"it = foo2(curry(2, add))\n",
"next(it)\n",
"it.send(1)\n",
"# it.send(2)\n",
"# next(it)"
]
},
{
"cell_type": "code",
"execution_count": 476,
"metadata": {
"hideCode": false,
"hidePrompt": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"starting\n",
"{'args': (1, 2),\n",
" 'func': <function <lambda> at 0x7f14aea0df28>,\n",
" 'kwargs': {},\n",
" 'spec': [1, 2, 1]}\n",
"\n",
"reduced\n",
"{'args': (2,),\n",
" 'argspec': [1, 2],\n",
" 'call': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,\n",
" 'func': <function <lambda> at 0x7f14aea0df28>,\n",
" 'kwargs': {},\n",
" 'n': 1,\n",
" 'n_args': 1,\n",
" 'partial_args': (1,),\n",
" 'spec': [1, 2, 1]}\n",
"\n",
"post_reduction\n",
"{'args': (2,),\n",
" 'argspec': [1, 2],\n",
" 'call': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,\n",
" 'func': <function <lambda> at 0x7f14aea0df28>,\n",
" 'kwargs': {},\n",
" 'n': 2,\n",
" 'n_args': 1,\n",
" 'partial_args': (1,),\n",
" 'spec': [1, 2, 1]}\n",
"\n",
"starting\n",
"{'args': (3,),\n",
" 'func': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,\n",
" 'kwargs': {},\n",
" 'spec': [1, 2]}\n",
"\n",
"post_reduction\n",
"{'args': (3,),\n",
" 'argspec': [1, 1],\n",
" 'call': functools.partial(<function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>, 3),\n",
" 'func': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,\n",
" 'kwargs': {},\n",
" 'n': 2,\n",
" 'n_args': 1,\n",
" 'spec': [1, 2]}\n",
"\n",
"starting\n",
"{'args': (4,),\n",
" 'func': functools.partial(<function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>, 3),\n",
" 'kwargs': {},\n",
" 'spec': [1, 1]}\n",
"\n",
"reduced\n",
"{'args': (),\n",
" 'argspec': [1],\n",
" 'call': <function <lambda>.<locals>.<lambda>.<locals>.<lambda> at 0x7f14ae9d9d90>,\n",
" 'func': functools.partial(<function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>, 3),\n",
" 'kwargs': {},\n",
" 'n': 1,\n",
" 'n_args': 0,\n",
" 'partial_args': (4,),\n",
" 'spec': [1, 1]}\n",
"\n",
"post_reduction\n",
"{'args': (),\n",
" 'argspec': [1],\n",
" 'call': <function <lambda>.<locals>.<lambda>.<locals>.<lambda> at 0x7f14ae9d9d90>,\n",
" 'func': functools.partial(<function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>, 3),\n",
" 'kwargs': {},\n",
" 'n': 1,\n",
" 'n_args': 0,\n",
" 'partial_args': (4,),\n",
" 'spec': [1, 1]}\n",
"\n"
]
},
{
"data": {
"text/plain": [
"(<function __main__.curry.<locals>.incomplete_factory.<locals>.<lambda>(*args)>,)"
]
},
"execution_count": 476,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from functools import partial\n",
"from pprint import pprint\n",
"log = lambda x, y: print(x) or pprint(y) or print()\n",
"\n",
"c_add = lambda x: lambda y: lambda z: x + y + z\n",
"c_add2 = lambda x: lambda y, z: lambda w: x + y + z + w\n",
"\n",
"\n",
"def uncurry(spec, func):\n",
"    \"\"\"Let a manually curried ``func`` be called with any grouping of arguments.\n",
"\n",
"    ``spec`` lists how many positional arguments each successive stage of\n",
"    ``func`` takes, e.g. ``[1, 2, 1]`` for\n",
"    ``lambda x: lambda y, z: lambda w: ...``. The returned callable applies\n",
"    each stage as soon as enough arguments for it are available and binds\n",
"    any leftover arguments to the next stage.\n",
"\n",
"    Keyword arguments go to the first stage applied by a call, or are bound\n",
"    together with the positionals when that stage is still incomplete.\n",
"\n",
"    Raises:\n",
"        TypeError: if positional arguments remain after the last stage.\n",
"    \"\"\"\n",
"    if len(spec) == 1:\n",
"        return curry(spec[0], func)\n",
"\n",
"    stages = list(spec)\n",
"\n",
"    def curried(*args, **kwargs):\n",
"        pending = stages.copy()  # arities still to be satisfied, in call order\n",
"        call = func\n",
"\n",
"        # Apply every stage that the supplied arguments can fill completely.\n",
"        while pending and len(args) >= pending[0]:\n",
"            n = pending.pop(0)\n",
"            call = call(*args[:n], **kwargs)\n",
"            args = args[n:]\n",
"            kwargs = {}  # keywords are consumed by the first stage applied\n",
"\n",
"        if not pending:\n",
"            if args:\n",
"                raise TypeError('too many positional arguments')\n",
"            return call\n",
"\n",
"        # Leftover arguments only partially fill the next stage: bind them\n",
"        # instead of dropping them, and shrink that stage's arity to match.\n",
"        if args or kwargs:\n",
"            call = partial(call, *args, **kwargs)\n",
"            pending[0] -= len(args)\n",
"        return uncurry(pending, call)\n",
"\n",
"    return curried\n",
"\n",
"\n",
"# (\n",
"# uncurry([1, 1, 1], c_add)(1, 2, 3),\n",
"# uncurry([1, 1, 1], c_add)(1)(2)(3), \n",
"# uncurry([1, 1, 1], c_add)(1)(2, 3),\n",
"# uncurry([1, 1, 1], c_add)(1, 2)(3),\n",
"# )\n",
"\n",
"\n",
"(\n",
"# uncurry([1, 2, 1], c_add2)(1, 2, 3, 4),\n",
"# uncurry([1, 2, 1], c_add2)(1)(2)(3)(4), \n",
"# uncurry([1, 2, 1], c_add2)(1)(2, 3)(4),\n",
" uncurry([1, 2, 1], c_add2)(1, 2)(3)(4),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 452,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.13 µs ± 366 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"3.55 µs ± 75.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"406 ns ± 10.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"r = lambda x: lambda y: x + y\n",
"cadd = uncurry([1, 1], r)\n",
"%timeit -n 10000 cadd(1, 2)\n",
"%timeit -n 10000 cadd(1)(2)\n",
"%timeit -n 10000 r(1)(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_code_all_hidden": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}