fabiommendes/sidekick

View on GitHub
examples/experiments/function-calling.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "hideCode": false,
    "hidePrompt": false
   },
   "outputs": [],
   "source": [
    "%load_ext cython"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "hideCode": false,
    "hidePrompt": false
   },
   "outputs": [],
   "source": [
    "NOT_GIVEN = object()\n",
    "impl = add_lambda = lambda x, y: x + y\n",
    "\n",
    "# Currying techniques\n",
    "def curry2_with_exception(impl):\n",
    "    def func(*args):\n",
    "        try:\n",
    "            return impl(*args)\n",
    "        except TypeError as exc:\n",
    "            if len(args) == 1:\n",
    "                return lambda y: impl(x, y)\n",
    "            else:\n",
    "                raise\n",
    "    return func\n",
    "\n",
    "\n",
    "def curry2_with_exception_kw(impl):\n",
    "    def func(*args, **kwargs):\n",
    "        try:\n",
    "            return impl(*args, **kwargs)\n",
    "        except TypeError as exc:\n",
    "            if len(args) == 1:\n",
    "                return lambda y, **kw: impl(x, y, **kwargs, **kw)\n",
    "            else:\n",
    "                raise\n",
    "    return func\n",
    "\n",
    "\n",
    "def curry2_with_extra_args(impl):\n",
    "    def func(x, *args):\n",
    "        if args:\n",
    "            y, = args\n",
    "            return impl(x, y)\n",
    "        return lambda y: impl(x, y)\n",
    "    return func\n",
    "    \n",
    "    \n",
    "def curry2_with_not_given(impl):\n",
    "    def func(x, y=NOT_GIVEN):\n",
    "        if y is NOT_GIVEN:\n",
    "            return lambda y: impl(x, y)\n",
    "        return impl(x, y)\n",
    "    return func\n",
    "\n",
    "\n",
    "def curry2_with_args(impl):\n",
    "    def func(*args):\n",
    "        if len(args) == 2:\n",
    "            return impl(*args)\n",
    "        elif len(args) == 1:\n",
    "            x = args[0]\n",
    "            return lambda y: impl(x, y)\n",
    "        else:\n",
    "            raise TypeError('invalid number of args')\n",
    "    return func\n",
    "\n",
    "\n",
    "def curry(n, func):\n",
    "    def incomplete_factory(arity, used_args):\n",
    "        return lambda *args: (\n",
    "            func(*(used_args + args))\n",
    "            if len(used_args) + len(args) >= arity\n",
    "            else incomplete_factory(arity, used_args + args)\n",
    "        )\n",
    "    return incomplete_factory(n, ())\n",
    "\n",
    "\n",
    "def curry_kw(n, func):\n",
    "    def incomplete_factory(arity, used_args, used_kwargs):\n",
    "        return lambda *args, **kwargs: (\n",
    "            func(*(used_args + args), **used_kwargs, **kwargs)\n",
    "            if len(used_args) + len(args) >= arity\n",
    "            else incomplete_factory(arity, used_args + args, {**used_kwargs, **kwargs})\n",
    "        )\n",
    "    return incomplete_factory(n, (), {})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "hideCode": false,
    "hidePrompt": false
   },
   "outputs": [],
   "source": [
    "class wrapped:\n",
    "    __slots__ = 'func',\n",
    "    def __init__(self, func):\n",
    "        self.func = func\n",
    "    \n",
    "    def __call__(self, *args, **kwargs):\n",
    "        return self.func(*args, **kwargs)\n",
    "    \n",
    "class wrapped2:\n",
    "    __slots__ = 'func',\n",
    "    def __init__(self, func):\n",
    "        self.func = func\n",
    "    \n",
    "    def __call__(self, x, y):\n",
    "        return self.func(x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "hideCode": false,
    "hidePrompt": false
   },
   "outputs": [],
   "source": [
    "%%cython\n",
    "\n",
    "cimport cython\n",
    "\n",
    "NOT_GIVEN = object()\n",
    "\n",
    "            \n",
    "@cython.freelist(8)\n",
    "cdef class cy:\n",
    "    cdef object func\n",
    "    cdef int arity\n",
    "    \n",
    "    def __init__(self, arity, func):\n",
    "        self.func = func\n",
    "        self.arity = arity\n",
    "    \n",
    "    def __call__(self, *args):\n",
    "        cdef int n = len(args)\n",
    "        if n >= 2:\n",
    "            return self.func(*args)\n",
    "        elif n == 1:\n",
    "            f = self.func \n",
    "            x = args[0]\n",
    "            return lambda y: f(x, y)\n",
    "        else:\n",
    "            raise TypeError('invalid number of arguments')\n",
    "            \n",
    "def cy_fn(int arity, impl):\n",
    "    def foo(*args, **kwargs):\n",
    "        cdef int n = len(args)\n",
    "        if n >= arity:\n",
    "            return impl(*args, **kwargs)\n",
    "        else:\n",
    "            return lambda *xs: impl(*(xs + args))\n",
    "    return foo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "hideCode": false,
    "hideOutput": true,
    "hidePrompt": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "124 ns ± 20.4 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)\n",
      "72.8 ns ± 13.2 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)\n",
      "202 ns ± 57.6 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)\n"
     ]
    }
   ],
   "source": [
    "from operator import add\n",
    "x = 1\n",
    "\n",
    "# Baseline\n",
    "%timeit -n 100000 -r 100 (1).__add__(2)\n",
    "%timeit -n 100000 -r 100 add(1, 2)\n",
    "%timeit -n 100000 -r 100 impl(1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "hideCode": false,
    "hidePrompt": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "259 ns ± 28.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "591 ns ± 56.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "553 ns ± 19.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "670 ns ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "504 ns ± 59.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "615 ns ± 121 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "209 ns ± 13.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "274 ns ± 42.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "467 ns ± 56.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "325 ns ± 39.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
     ]
    }
   ],
   "source": [
    "# Function based currying techniques for 2-arg functions\n",
    "f1 = curry2_with_extra_args(impl)\n",
    "f2 = curry2_with_not_given(impl)\n",
    "f3 = curry2_with_args(impl)\n",
    "f4 = curry2_with_exception(impl)\n",
    "f5 = curry2_with_exception_kw(impl)\n",
    "f6 = cy(2, impl)\n",
    "f7 = cy_fn(2, impl)\n",
    "f8 = wrapped(impl)\n",
    "f9 = wrapped2(impl)\n",
    "\n",
    "%timeit -n 100000 impl(1, 2)\n",
    "%timeit -n 100000 f1(1, 2)\n",
    "%timeit -n 100000 f2(1, 2)\n",
    "%timeit -n 100000 f3(1, 2)\n",
    "%timeit -n 100000 f4(1, 2)\n",
    "%timeit -n 100000 f5(1, 2)\n",
    "%timeit -n 100000 f6(1, 2)\n",
    "%timeit -n 100000 f7(1, 2)\n",
    "%timeit -n 100000 f8(1, 2)\n",
    "%timeit -n 100000 f9(1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "hideCode": false,
    "hidePrompt": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "834 ns ± 23.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n",
      "768 ns ± 192 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
     ]
    }
   ],
   "source": [
    "# Generic currying\n",
    "f1 = curry(2, impl)\n",
    "f2 = curry_kw(2, impl)\n",
    "\n",
    "%timeit -n 100000 f1(1, 2)\n",
    "%timeit -n 100000 f2(1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "hideCode": false,
    "hidePrompt": false
   },
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'cytoolz'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-28-70cd7f04714c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# Other libs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtoolz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcytoolz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfuncy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mf1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoolz\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurry\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mf2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcytoolz\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurry\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'cytoolz'"
     ]
    }
   ],
   "source": [
    "# Other libs\n",
    "import toolz, cytoolz, funcy\n",
    "\n",
    "f1 = toolz.curry(impl)\n",
    "f2 = cytoolz.curry(impl)\n",
    "f3 = funcy.curry(impl, 2)\n",
    "f4 = funcy.autocurry(impl, 2)\n",
    "\n",
    "%timeit -n 100000 f1(1, 2)\n",
    "%timeit -n 100000 f2(1, 2)\n",
    "%timeit -n 100000 f3(1)(2)\n",
    "%timeit -n 100000 f4(1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 462,
   "metadata": {
    "hideCode": false,
    "hidePrompt": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "209 ns ± 11.2 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)\n",
      "213 ns ± 13.6 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)\n"
     ]
    }
   ],
   "source": [
    "# Cython accelerators\n",
    "f1 = cy(2, impl)\n",
    "f2 = cy_fn(2, impl)\n",
    "\n",
    "%timeit -n 100000 -r 100 f1(1, 2)\n",
    "%timeit -n 100000 -r 100 f2(1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 463,
   "metadata": {
    "hideCode": false,
    "hidePrompt": false
   },
   "outputs": [],
   "source": [
    "def foo(f):\n",
    "    def foo1(x):\n",
    "        g = f(x)\n",
    "        def foo2(y):\n",
    "            return g(y)\n",
    "        return foo2\n",
    "    return foo1\n",
    "\n",
    "\n",
    "def foo2(f):\n",
    "    x = yield\n",
    "    g = f(x)\n",
    "    y = yield\n",
    "    return g(y)\n",
    "    \n",
    "foo(curry(2, add))(1)(2)\n",
    "\n",
    "\n",
    "it = foo2(curry(2, add))\n",
    "next(it)\n",
    "it.send(1)\n",
    "# it.send(2)\n",
    "# next(it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 476,
   "metadata": {
    "hideCode": false,
    "hidePrompt": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "starting\n",
      "{'args': (1, 2),\n",
      " 'func': <function <lambda> at 0x7f14aea0df28>,\n",
      " 'kwargs': {},\n",
      " 'spec': [1, 2, 1]}\n",
      "\n",
      "reduced\n",
      "{'args': (2,),\n",
      " 'argspec': [1, 2],\n",
      " 'call': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,\n",
      " 'func': <function <lambda> at 0x7f14aea0df28>,\n",
      " 'kwargs': {},\n",
      " 'n': 1,\n",
      " 'n_args': 1,\n",
      " 'partial_args': (1,),\n",
      " 'spec': [1, 2, 1]}\n",
      "\n",
      "post_reduction\n",
      "{'args': (2,),\n",
      " 'argspec': [1, 2],\n",
      " 'call': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,\n",
      " 'func': <function <lambda> at 0x7f14aea0df28>,\n",
      " 'kwargs': {},\n",
      " 'n': 2,\n",
      " 'n_args': 1,\n",
      " 'partial_args': (1,),\n",
      " 'spec': [1, 2, 1]}\n",
      "\n",
      "starting\n",
      "{'args': (3,),\n",
      " 'func': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,\n",
      " 'kwargs': {},\n",
      " 'spec': [1, 2]}\n",
      "\n",
      "post_reduction\n",
      "{'args': (3,),\n",
      " 'argspec': [1, 1],\n",
      " 'call': functools.partial(<function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>, 3),\n",
      " 'func': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,\n",
      " 'kwargs': {},\n",
      " 'n': 2,\n",
      " 'n_args': 1,\n",
      " 'spec': [1, 2]}\n",
      "\n",
      "starting\n",
      "{'args': (4,),\n",
      " 'func': functools.partial(<function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>, 3),\n",
      " 'kwargs': {},\n",
      " 'spec': [1, 1]}\n",
      "\n",
      "reduced\n",
      "{'args': (),\n",
      " 'argspec': [1],\n",
      " 'call': <function <lambda>.<locals>.<lambda>.<locals>.<lambda> at 0x7f14ae9d9d90>,\n",
      " 'func': functools.partial(<function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>, 3),\n",
      " 'kwargs': {},\n",
      " 'n': 1,\n",
      " 'n_args': 0,\n",
      " 'partial_args': (4,),\n",
      " 'spec': [1, 1]}\n",
      "\n",
      "post_reduction\n",
      "{'args': (),\n",
      " 'argspec': [1],\n",
      " 'call': <function <lambda>.<locals>.<lambda>.<locals>.<lambda> at 0x7f14ae9d9d90>,\n",
      " 'func': functools.partial(<function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>, 3),\n",
      " 'kwargs': {},\n",
      " 'n': 1,\n",
      " 'n_args': 0,\n",
      " 'partial_args': (4,),\n",
      " 'spec': [1, 1]}\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(<function __main__.curry.<locals>.incomplete_factory.<locals>.<lambda>(*args)>,)"
      ]
     },
     "execution_count": 476,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from functools import partial\n",
    "from pprint import pprint\n",
    "log = lambda x, y: print(x) or pprint(y) or print()\n",
    "\n",
    "c_add = lambda x: lambda y: lambda z: x + y + z\n",
    "c_add2 = lambda x: lambda y, z: lambda w: x + y + z + w\n",
    "\n",
    "\n",
    "def uncurry(spec, func):\n",
    "    if len(spec) == 1:\n",
    "        return curry(spec[0], func)\n",
    "    \n",
    "    spec = list(reversed(spec))\n",
    "    \n",
    "    def curried(*args, **kwargs):\n",
    "        log('starting', locals())\n",
    "        n_args = len(args)\n",
    "        argspec = spec.copy()\n",
    "        call = func\n",
    "        n = argspec.pop()\n",
    "            \n",
    "        if n > n_args:\n",
    "            argspec.append(n - n_args)\n",
    "            call = partial(call, *args, **kwargs)\n",
    "        else:\n",
    "            while n_args >= n:\n",
    "                partial_args = args[:n]\n",
    "                args = args[n:]\n",
    "                call = call(*partial_args, **kwargs)\n",
    "                n_args -= n\n",
    "                kwargs = {}\n",
    "                log('reduced', locals())\n",
    "                if argspec:\n",
    "                    n = argspec.pop()\n",
    "                else:\n",
    "                    break\n",
    "            else:\n",
    "                argspec.append(n)\n",
    "\n",
    "        log('post_reduction', locals())\n",
    "        \n",
    "        if argspec:\n",
    "            return uncurry(argspec[::-1], call)\n",
    "        else:\n",
    "            return call\n",
    "\n",
    "    return curried\n",
    "\n",
    "\n",
    "# (\n",
    "#     uncurry([1, 1, 1], c_add)(1, 2, 3),\n",
    "#     uncurry([1, 1, 1], c_add)(1)(2)(3), \n",
    "#     uncurry([1, 1, 1], c_add)(1)(2, 3),\n",
    "#     uncurry([1, 1, 1], c_add)(1, 2)(3),\n",
    "# )\n",
    "\n",
    "\n",
    "(\n",
    "#     uncurry([1, 2, 1], c_add2)(1, 2, 3, 4),\n",
    "#    uncurry([1, 2, 1], c_add2)(1)(2)(3)(4), \n",
    "#     uncurry([1, 2, 1], c_add2)(1)(2, 3)(4),\n",
    "     uncurry([1, 2, 1], c_add2)(1, 2)(3)(4),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 452,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.13 µs ± 366 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
      "3.55 µs ± 75.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
      "406 ns ± 10.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
     ]
    }
   ],
   "source": [
    "r =  lambda x: lambda y: x + y\n",
    "cadd = uncurry([1, 1], r)\n",
    "%timeit -n 10000 cadd(1, 2)\n",
    "%timeit -n 10000 cadd(1)(2)\n",
    "%timeit -n 10000 r(1)(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_code_all_hidden": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}