ixxi-dante/nw2vec

View on GitHub
julia/feature-gradients.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32m\u001b[1mActivating\u001b[22m\u001b[39m environment at `~/Code/Research/nw2vec/Project.toml`\n"
     ]
    }
   ],
   "source": [
    "using Pkg; Pkg.activate(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/Flux/QdkVy.ji for Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]\n",
      "└ @ Base loading.jl:1240\n",
      "┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/LightGraphs/Xm08G.ji for LightGraphs [093fc24a-ae57-5d10-9952-331d41423f4d]\n",
      "└ @ Base loading.jl:1240\n",
      "┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/Distributions/xILW0.ji for Distributions [31c24e10-a181-5473-b8eb-7969acd0382f]\n",
      "└ @ Base loading.jl:1240\n",
      "┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/AbstractPlotting/6fydZ.ji for AbstractPlotting [537997a7-5e4e-5d89-9595-2241ea00577e]\n",
      "└ @ Base loading.jl:1240\n",
      "┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/Makie/iZ1Bl.ji for Makie [ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a]\n",
      "└ @ Base loading.jl:1240\n",
      "WARNING: using Makie.AbstractPlotting in module Main conflicts with an existing identifier.\n",
      "┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/CairoMakie/9mSey.ji for CairoMakie [13f3f980-e62b-5c42-98c6-ff1f3baf88f0]\n",
      "└ @ Base loading.jl:1240\n",
      "┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/DataFrames/AR9oZ.ji for DataFrames [a93c6f00-e57d-5684-b7b6-d8193f3e46c0]\n",
      "└ @ Base loading.jl:1240\n",
      "┌ Info: Precompiling Combinatorics [861a8166-3701-5b0c-9a16-15d98fcdc6aa]\n",
      "└ @ Base loading.jl:1242\n"
     ]
    }
   ],
   "source": [
    "include(\"layers.jl\")\n",
    "include(\"utils.jl\")\n",
    "include(\"vae.jl\")\n",
    "using .Utils, .Layers, .VAE\n",
    "\n",
    "using Flux\n",
    "using LightGraphs\n",
    "using Colors\n",
    "using AbstractPlotting\n",
    "using Makie\n",
    "using CairoMakie\n",
    "using BSON: @load\n",
    "using NPZ\n",
    "using Random\n",
    "using StatsBase\n",
    "using Distributions\n",
    "using LinearAlgebra\n",
    "using DataFrames\n",
    "using Combinatorics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dataset (generic function with 1 method)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function dataset(args)\n",
    "    data = npzread(args[\"dataset\"])\n",
    "\n",
    "    features = convert(Array{Float32}, transpose(data[\"features\"]))\n",
    "    classes = transpose(data[\"labels\"])\n",
    "\n",
    "    # Make sure we have a non-weighted graph\n",
    "    @assert Set(data[\"adjdata\"]) == Set([1])\n",
    "\n",
    "    # Remove any diagonal elements in the matrix\n",
    "    rows = data[\"adjrow\"]\n",
    "    cols = data[\"adjcol\"]\n",
    "    nondiagindices = findall(rows .!= cols)\n",
    "    rows = rows[nondiagindices]\n",
    "    cols = cols[nondiagindices]\n",
    "    # Make sure indices start at 0\n",
    "    @assert minimum(rows) == minimum(cols) == 0\n",
    "\n",
    "    # Construct the graph\n",
    "    edges = LightGraphs.SimpleEdge.(1 .+ rows, 1 .+ cols)\n",
    "    g = SimpleGraphFromIterator(edges)\n",
    "\n",
    "    # Check sizes for sanity\n",
    "    @assert nv(g) == size(g, 1) == size(g, 2) == size(features, 2)\n",
    "\n",
    "    # Randomize to the level requested\n",
    "    nnodes = nv(g)\n",
    "    correlation = args[\"forced-correlation\"]\n",
    "    nshuffle = Int(round((1 - correlation) * nnodes))\n",
    "    idx = StatsBase.sample(1:nnodes, nshuffle, replace = false)\n",
    "    shuffledidx = shuffle(idx)\n",
    "    features[:, idx] = features[:, shuffledidx]\n",
    "    classes[:, idx] = classes[:, shuffledidx]\n",
    "\n",
    "    @assert eltype(features) == Float32\n",
    "    g, features, classes\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "make_dataset (generic function with 1 method)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function make_dataset()\n",
    "    g = SimpleGraph(9)\n",
    "    add_edge!(g, 1, 2)\n",
    "    add_edge!(g, 1, 3)\n",
    "    add_edge!(g, 1, 4)\n",
    "    add_edge!(g, 1, 5)\n",
    "    add_edge!(g, 5, 6)\n",
    "    add_edge!(g, 5, 7)\n",
    "    add_edge!(g, 5, 8)\n",
    "    add_edge!(g, 8, 9)\n",
    "    \n",
    "    features = vcat(ones(Float32, (1, 9)), Array(Diagonal(ones(Float32, 9))))\n",
    "    \n",
    "    g, features\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting model state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "plotweights (generic function with 1 method)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adims(a, dims) = [a[i, :] for i = dims]\n",
    "\n",
    "function plotstate(;enc, vae, x, refx, g, dims, colors)\n",
    "    @assert length(dims) in [2, 3]\n",
    "    embμ, emblogσ = enc(x)\n",
    "    logitÂ, unormF̂ = vae(x)\n",
    "    hbox(\n",
    "        vbox(\n",
    "            Scene(),\n",
    "            heatmap(σ.(logitÂ).data, colorrange = (0, 1)),\n",
    "            heatmap(1:size(x, 1), 1:size(x, 2), softmax(unormF̂).data, colorrange = (0, 1)),\n",
    "            sizes = [.45, .45, .1]\n",
    "        ),\n",
    "        vbox(\n",
    "            scatter(adims(embμ, dims)..., color = colors, markersize = Utils.markersize(embμ)),\n",
    "            heatmap(Array(adjacency_matrix(g)), colorrange = (0, 1)),\n",
    "            heatmap(1:size(x, 1), 1:size(x, 2), refx, colorrange = (0, 1)),\n",
    "            sizes = [.45, .45, .1]\n",
    "        ),\n",
    "    )\n",
    "end\n",
    "\n",
    "function plotweights(layers...)\n",
    "    theme = Theme(align = (:left, :bottom), raw = true, camera = campixel!)\n",
    "    vbox([hbox(heatmap(l.W.data), text(theme, repr(l))) for l in layers]...)\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dict{String,Any} with 11 entries:\n",
       "  \"diml1dec\"           => 32\n",
       "  \"sharedl1\"           => false\n",
       "  \"label-distribution\" => Bernoulli\n",
       "  \"diml1enc\"           => 32\n",
       "  \"bias\"               => false\n",
       "  \"overlap\"            => 8\n",
       "  \"initb\"              => nobias\n",
       "  \"decadjdeep\"         => true\n",
       "  \"dimxiadj\"           => 16\n",
       "  \"forced-correlation\" => 1.0\n",
       "  \"dimxifeat\"          => 16"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "args = Dict(\n",
    "    #\"dataset\" => \"../data/twitter/git=679a9eb593-csv_to_npz-mt=5-tmw=3-w2v_dim=50-w2v_iter=10-cho=True-nclusters=10,20,50,80,100/dataset=retweetsrange-nclusters=10.npz\",\n",
    "    \"forced-correlation\" => 1.0, # default\n",
    "    \"label-distribution\" => VAE.label_distributions[\"bernoulli\"],\n",
    "    \"diml1enc\" => 32,\n",
    "    \"diml1dec\" => 32,\n",
    "    \"dimxiadj\" => 16,\n",
    "    \"dimxifeat\" => 16,\n",
    "    \"overlap\" => 8,\n",
    "    \"bias\" => false,\n",
    "    \"sharedl1\" => false,\n",
    "    \"decadjdeep\" => true,\n",
    "    \"initb\" => VAE.Layers.nobias\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the model and plot its state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#g, _features, _ = dataset(args)\n",
    "g, _features = make_dataset()\n",
    "labels = _features\n",
    "feature_size = size(_features, 1)\n",
    "label_size = feature_size\n",
    "fnormalise = Utils.normaliser(_features)\n",
    "features = fnormalise(_features);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Info: using unshared l1 encoder\n",
      "Info: using deep adjacency decoder\n",
      "Info: using boolean feature decoder\n"
     ]
    }
   ],
   "source": [
    "@load \"../data/twitter/git=679a9eb593-an2vec-diml1enc=32-diml1dec=32-dimxiadj=16-dimxifeat=16-overlap=0,8,16-bias=false-sharedl1=false-decadjdeep=true-nepochs=200/dataset=retweetsrange-nclusters=10-ld=bernoulli-dimxi=24-weights.bson\" weights args\n",
    "\n",
    "enc, sampleξ, dec, paramsenc, paramsdec = VAE.make_vae(\n",
    "    g = g, feature_size = feature_size, label_size = label_size, args = args)\n",
    "vae(x) = dec(sampleξ(enc(x)...))\n",
    "\n",
    "paramsvae = Tracker.Params()\n",
    "push!(paramsvae, paramsenc..., paramsdec...)\n",
    "loadparams!(paramsvae, weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GLMakie.Screen(...)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history = npzread(\"../data/twitter/git=679a9eb593-an2vec-diml1enc=32-diml1dec=32-dimxiadj=16-dimxifeat=16-overlap=0,8,16-bias=false-sharedl1=false-decadjdeep=true-nepochs=200/dataset=retweetsrange-nclusters=10-ld=bernoulli-dimxi=24-history.npz\")\n",
    "\n",
    "theme = Theme(align = (:left, :bottom), raw = true, camera = campixel!)\n",
    "scene = vbox([hbox(lines(1:length(history[name]), history[name], color = color), text(theme, name))\n",
    "        for (name, color) in [\n",
    "                (\"total loss\", :blue),\n",
    "                (\"kl\", :red),\n",
    "                (\"reg\", :red),\n",
    "                (\"adj\", :green),\n",
    "                (\"feat\", :green),\n",
    "                #(\"ap\", :cyan),\n",
    "                #(\"auc\", :cyan)\n",
    "            ]]...)\n",
    "#Makie.save(\"training-history.png\", scene)\n",
    "display(scene)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GLMakie.Screen(...)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#communities = [c for c in 1:args[\"l\"] for i in 1:args[\"k\"]]\n",
    "#palette = distinguishable_colors(args[\"l\"])\n",
    "#colors = map(i -> getindex(palette, i), communities)\n",
    "\n",
    "display(\n",
    "    plotstate(enc = enc, vae = vae, x = _features, refx = labels,\n",
    "        g = g, dims = 1:3, colors = \"black\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distributions of gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ae (generic function with 1 method)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ae(x) = dec(enc(x)[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "make_onehot (generic function with 1 method)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function make_onehot(coord, dims)\n",
    "    out = zeros(Float32, dims)\n",
    "    out[coord] = 1\n",
    "    out\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### w.r.t. features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Float32[0.97156286 0.88115096 … 0.8797154 0.6768692; 0.9004984 0.7865321 … 0.73746455 0.5788719; … ; 0.9144334 0.7934744 … 0.9269219 0.8289111; 0.7496242 0.6469048 … 0.8417549 0.7637189] (tracked), getfield(Tracker, Symbol(\"##21#23\")){getfield(Tracker, Symbol(\"##18#19\")){Tracker.Params,TrackedArray{…,Array{Float32,2}}}}(Core.Box((Float32[1.8973668 1.8973668 … 1.8973664 1.8973664; 1.8973668 -0.4743417 … -0.4743416 -0.4743416; … ; -0.4743417 -0.4743417 … 1.8973664 -0.4743416; -0.4743417 -0.4743417 … -0.4743416 1.8973664] (tracked),)), getfield(Tracker, Symbol(\"##18#19\")){Tracker.Params,TrackedArray{…,Array{Float32,2}}}(Params([Float32[1.8973668 1.8973668 … 1.8973664 1.8973664; 1.8973668 -0.4743417 … -0.4743416 -0.4743416; … ; -0.4743417 -0.4743417 … 1.8973664 -0.4743416; -0.4743417 -0.4743417 … -0.4743416 1.8973664] (tracked)]), Float32[0.97156286 0.88115096 … 0.8797154 0.6768692; 0.9004984 0.7865321 … 0.73746455 0.5788719; … ; 0.9144334 0.7934744 … 0.9269219 0.8289111; 0.7496242 0.6469048 … 0.8417549 0.7637189] (tracked))))"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Apred, back = Tracker.forward(x -> σ.(ae(x)[1]), features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pre-compute all shortest path distances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9×9 Array{Int64,2}:\n",
       " 0  1  1  1  1  2  2  2  3\n",
       " 1  0  2  2  2  3  3  3  4\n",
       " 1  2  0  2  2  3  3  3  4\n",
       " 1  2  2  0  2  3  3  3  4\n",
       " 1  2  2  2  0  1  1  1  2\n",
       " 2  3  3  3  1  0  2  2  3\n",
       " 2  3  3  3  1  2  0  2  3\n",
       " 2  3  3  3  1  2  2  0  1\n",
       " 3  4  4  4  2  3  3  1  0"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "distances = floyd_warshall_shortest_paths(g).dists"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get all gradients w.r.t. involved neighbours for each link prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "nnodes = size(Apred)[1]\n",
    "@assert nnodes == size(Apred)[2]\n",
    "nfeatures = size(features)[1]\n",
    "nhops = 2\n",
    "onehot = zeros((nnodes, nnodes))\n",
    "coord = CartesianIndex(1, 1)\n",
    "\n",
    "df = DataFrame(;((Symbol(\"grad$i\"), Float32[]) for i in 1:nfeatures)...,\n",
    "    u = Int64[], v = Int64[], n = Int64[], dist_uv = Int64[], dist_uorv = Int64[])\n",
    "\n",
    "@showprogress for (u, v) in combinations(1:nnodes, 2)\n",
    "    #print(\"$u - $v: \")\n",
    "    \n",
    "    # Get gradients for the u-v prediction\n",
    "    global coord\n",
    "    onehot[coord] = 0\n",
    "    coord = CartesianIndex(u, v)\n",
    "    onehot[coord] = 1\n",
    "    @assert onehot == make_onehot(CartesianIndex(u, v), (nnodes, nnodes))\n",
    "    grads = Tracker.data(back(onehot)[1])\n",
    "\n",
    "    # Get the gradient for each 2-hop neigbour of u, v\n",
    "    dist_uv = distances[u, v]\n",
    "    neighbours_uv = collect(union(neighborhood(g, u, nhops), neighborhood(g, v, nhops)))\n",
    "    non_neighbours_uv = setdiff(1:nnodes, neighbours_uv)\n",
    "    \n",
    "    # All gradients for neighbours should be non-null\n",
    "    @assert all(sum(grads[:, neighbours_uv] .!= 0, dims = 1) .> 0)\n",
    "    # All gradients for non-neighbours should be null\n",
    "    @assert all(sum(grads[:, non_neighbours_uv] .== 0, dims = 1) .> 0)\n",
    "\n",
    "    #println(\"$(length(neighbours_uv)) neighbours to both\")\n",
    "    for n in neighbours_uv\n",
    "        push!(df, (grads[:, n]..., u, v, n, dist_uv, min(distances[u, n], distances[v, n])))\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table class=\"data-frame\"><thead><tr><th></th><th>grad1</th><th>grad2</th><th>grad3</th><th>grad4</th><th>grad5</th><th>grad6</th><th>grad7</th><th>grad8</th><th>grad9</th><th>grad10</th><th>u</th><th>v</th><th>n</th><th>dist_uv</th><th>dist_uorv</th></tr><tr><th></th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Int64</th><th>Int64</th><th>Int64</th><th>Int64</th><th>Int64</th></tr></thead><tbody><p>276 rows × 15 columns</p><tr><th>1</th><td>0.0645079</td><td>-0.021225</td><td>0.00150802</td><td>0.00318302</td><td>-0.0319299</td><td>0.00660334</td><td>0.00870657</td><td>0.0306246</td><td>-0.00391502</td><td>-0.0416326</td><td>1</td><td>2</td><td>1</td><td>1</td><td>0</td></tr><tr><th>2</th><td>0.0653441</td><td>-0.0255881</td><td>-0.0114106</td><td>0.0108165</td><td>-0.0316851</td><td>-0.00136565</td><td>-0.00199008</td><td>0.0298564</td><td>0.00248509</td><td>-0.0377283</td><td>1</td><td>2</td><td>2</td><td>1</td><td>0</td></tr><tr><th>3</th><td>0.036182</td><td>-0.0138953</td><td>0.00824023</td><td>-0.00308798</td><td>-0.0130371</td><td>0.0059515</td><td>0.00703448</td><td>0.0184967</td><td>-0.00519035</td><td>-0.0281207</td><td>1</td><td>2</td><td>3</td><td>1</td><td>1</td></tr><tr><th>4</th><td>0.0292288</td><td>-0.0099957</td><td>0.00241025</td><td>0.00671407</td><td>-0.00923298</td><td>0.00159172</td><td>0.00227589</td><td>0.0133913</td><td>-0.00464004</td><td>-0.0223713</td><td>1</td><td>2</td><td>4</td><td>1</td><td>1</td></tr><tr><th>5</th><td>0.0139868</td><td>-0.00648846</td><td>-0.00108487</td><td>0.00196581</td><td>-0.00564524</td><td>0.00110447</td><td>0.000616112</td><td>0.00652828</td><td>-0.00118362</td><td>-0.0109085</td><td>1</td><td>2</td><td>5</td><td>1</td><td>1</td></tr><tr><th>6</th><td>0.00515707</td><td>-0.00153295</td><td>-9.54022e-5</td><td>-0.00106444</td><td>-0.00479411</td><td>0.00258529</td><td>0.00279811</t
d><td>0.00244054</td><td>-0.000862602</td><td>-0.00403407</td><td>1</td><td>2</td><td>6</td><td>1</td><td>2</td></tr><tr><th>7</th><td>0.00515707</td><td>-0.00153295</td><td>-9.54022e-5</td><td>-0.00106444</td><td>-0.00479411</td><td>0.00258529</td><td>0.00279811</td><td>0.00244054</td><td>-0.000862602</td><td>-0.00403407</td><td>1</td><td>2</td><td>7</td><td>1</td><td>2</td></tr><tr><th>8</th><td>0.00421073</td><td>-0.00125165</td><td>-7.78955e-5</td><td>-0.000869113</td><td>-0.00391437</td><td>0.00211088</td><td>0.00228465</td><td>0.00199269</td><td>-0.000704312</td><td>-0.0032938</td><td>1</td><td>2</td><td>8</td><td>1</td><td>2</td></tr><tr><th>9</th><td>0.0807522</td><td>-0.0227874</td><td>0.0254424</td><td>-0.0167195</td><td>-0.0220161</td><td>0.0101796</td><td>0.0140722</td><td>0.0501589</td><td>-0.00739399</td><td>-0.0722438</td><td>1</td><td>3</td><td>1</td><td>1</td><td>0</td></tr><tr><th>10</th><td>0.0340761</td><td>-0.0103437</td><td>-0.00257617</td><td>0.00722817</td><td>-0.0177132</td><td>0.00453164</td><td>0.00870965</td><td>0.00973111</td><td>-0.00574675</td><td>-0.0162742</td><td>1</td><td>3</td><td>2</td><td>1</td><td>1</td></tr><tr><th>11</th><td>0.0940134</td><td>-0.0318718</td><td>0.0344218</td><td>-0.0327469</td><td>-0.00542446</td><td>0.00659211</td><td>0.00370851</td><td>0.0694634</td><td>-0.00152773</td><td>-0.0996689</td><td>1</td><td>3</td><td>3</td><td>1</td><td>0</td></tr><tr><th>12</th><td>0.0324342</td><td>-0.00973636</td><td>0.00403564</td><td>0.00701846</td><td>-0.0169275</td><td>0.00535893</td><td>0.00938935</td><td>0.0100804</td><td>-0.00863605</td><td>-0.0212517</td><td>1</td><td>3</td><td>4</td><td>1</td><td>1</td></tr><tr><th>13</th><td>0.0163616</td><td>-0.00657916</td><td>-0.0015163</td><td>0.00433485</td><td>-0.00743952</td><td>0.00136546</td><td>0.003353</td><td>0.00500078</td><td>-0.00370737</td><td>-0.00774432</td><td>1</td><td>3</td><td>5</td><td>1</td><td>1</td></tr><tr><th>14</th><td>0.00629887</td><td>-0.00162777</td><
td>-0.000149417</td><td>0.00192418</td><td>-0.0060904</td><td>0.00131024</td><td>0.00368191</td><td>0.001949</td><td>-0.00250137</td><td>-0.000507388</td><td>1</td><td>3</td><td>6</td><td>1</td><td>2</td></tr><tr><th>15</th><td>0.00629887</td><td>-0.00162777</td><td>-0.000149417</td><td>0.00192418</td><td>-0.0060904</td><td>0.00131024</td><td>0.00368191</td><td>0.001949</td><td>-0.00250137</td><td>-0.000507388</td><td>1</td><td>3</td><td>7</td><td>1</td><td>2</td></tr><tr><th>16</th><td>0.00514301</td><td>-0.00132907</td><td>-0.000121998</td><td>0.00157108</td><td>-0.00497279</td><td>0.0010698</td><td>0.00300627</td><td>0.00159136</td><td>-0.00204236</td><td>-0.000414281</td><td>1</td><td>3</td><td>8</td><td>1</td><td>2</td></tr><tr><th>17</th><td>0.0617292</td><td>-0.0381997</td><td>0.0128525</td><td>0.015317</td><td>-0.00830853</td><td>7.7243e-5</td><td>-0.0030023</td><td>0.0352501</td><td>0.0032523</td><td>-0.0626972</td><td>1</td><td>4</td><td>1</td><td>1</td><td>0</td></tr><tr><th>18</th><td>0.0315697</td><td>-0.0214036</td><td>0.000929661</td><td>0.0126375</td><td>-0.0146739</td><td>0.00134248</td><td>0.00210057</td><td>0.0117977</td><td>-0.00231886</td><td>-0.0223714</td><td>1</td><td>4</td><td>2</td><td>1</td><td>1</td></tr><tr><th>19</th><td>0.0374139</td><td>-0.0274874</td><td>0.014762</td><td>0.000410167</td><td>-0.00619973</td><td>0.00214736</td><td>-1.34762e-6</td><td>0.0223506</td><td>3.35925e-5</td><td>-0.0405679</td><td>1</td><td>4</td><td>3</td><td>1</td><td>1</td></tr><tr><th>20</th><td>0.0606246</td><td>-0.0319523</td><td>0.00271798</td><td>0.0253772</td><td>-0.00479889</td><td>-0.00373943</td><td>-0.0057935</td><td>0.0348819</td><td>0.00547458</td><td>-0.0587476</td><td>1</td><td>4</td><td>4</td><td>1</td><td>0</td></tr><tr><th>21</th><td>0.0144691</td><td>-0.0119399</td><td>0.000725399</td><td>0.00435129</td><td>-0.00615814</td><td>0.00146693</td><td>0.000685575</td><td>0.00692591</td><td>0.000697119</td><td>-0.0146377</td><td>1</td><td>4</td><t
d>5</td><td>1</td><td>1</td></tr><tr><th>22</th><td>0.00458313</td><td>-0.005771</td><td>0.00140194</td><td>-0.000148894</td><td>-0.00231272</td><td>0.00167019</td><td>0.000371735</td><td>0.00286897</td><td>0.00138584</td><td>-0.00791148</td><td>1</td><td>4</td><td>6</td><td>1</td><td>2</td></tr><tr><th>23</th><td>0.00458313</td><td>-0.005771</td><td>0.00140194</td><td>-0.000148894</td><td>-0.00231272</td><td>0.00167019</td><td>0.000371735</td><td>0.00286897</td><td>0.00138584</td><td>-0.00791148</td><td>1</td><td>4</td><td>7</td><td>1</td><td>2</td></tr><tr><th>24</th><td>0.00374211</td><td>-0.00471201</td><td>0.00114468</td><td>-0.000121571</td><td>-0.00188833</td><td>0.00136371</td><td>0.000303521</td><td>0.0023425</td><td>0.00113153</td><td>-0.0064597</td><td>1</td><td>4</td><td>8</td><td>1</td><td>2</td></tr><tr><th>25</th><td>0.0135714</td><td>0.000587793</td><td>-0.000455864</td><td>-0.00280042</td><td>-0.0118951</td><td>0.00429923</td><td>0.00695602</td><td>0.0092625</td><td>1.17823e-5</td><td>-0.00800769</td><td>1</td><td>5</td><td>1</td><td>1</td><td>0</td></tr><tr><th>26</th><td>0.00842064</td><td>0.000716397</td><td>-0.0014722</td><td>-0.000474771</td><td>-0.00648818</td><td>0.0022389</td><td>0.00288162</td><td>0.00454633</td><td>0.000500505</td><td>-0.00369009</td><td>1</td><td>5</td><td>2</td><td>1</td><td>1</td></tr><tr><th>27</th><td>0.00979808</td><td>-1.31039e-5</td><td>0.000951256</td><td>-0.00217937</td><td>-0.00753009</td><td>0.00323605</td><td>0.00495289</td><td>0.0065921</td><td>-0.00113181</td><td>-0.00520114</td><td>1</td><td>5</td><td>3</td><td>1</td><td>1</td></tr><tr><th>28</th><td>0.00692593</td><td>-0.000400284</td><td>0.000315747</td><td>0.000504271</td><td>-0.00339045</td><td>0.00123777</td><td>0.00237482</td><td>0.00346345</td><td>-0.000968839</td><td>-0.00436599</td><td>1</td><td>5</td><td>4</td><td>1</td><td>1</td></tr><tr><th>29</th><td>0.00958598</td><td>-0.00149535</td><td>-0.0023221</td><td>-0.00599477</td><td>-0.00405036</td><
td>-0.00198833</td><td>-0.000829992</td><td>0.0112333</td><td>0.0046394</td><td>-0.0148051</td><td>1</td><td>5</td><td>5</td><td>1</td><td>0</td></tr><tr><th>30</th><td>0.0060058</td><td>0.00156492</td><td>-0.00212839</td><td>-0.00666747</td><td>-0.00397689</td><td>-4.11231e-5</td><td>0.00101687</td><td>0.00900213</td><td>0.00315526</td><td>-0.0106931</td><td>1</td><td>5</td><td>6</td><td>1</td><td>1</td></tr><tr><th>&vellip;</th><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td></tr></tbody></table>"
      ],
      "text/latex": [
       "\\begin{tabular}{r|ccccccccccccccc}\n",
       "\t& grad1 & grad2 & grad3 & grad4 & grad5 & grad6 & grad7 & grad8 & grad9 & grad10 & u & v & n & dist\\_uv & dist\\_uorv\\\\\n",
       "\t\\hline\n",
       "\t& Float32 & Float32 & Float32 & Float32 & Float32 & Float32 & Float32 & Float32 & Float32 & Float32 & Int64 & Int64 & Int64 & Int64 & Int64\\\\\n",
       "\t\\hline\n",
       "\t1 & 0.0645079 & -0.021225 & 0.00150802 & 0.00318302 & -0.0319299 & 0.00660334 & 0.00870657 & 0.0306246 & -0.00391502 & -0.0416326 & 1 & 2 & 1 & 1 & 0 \\\\\n",
       "\t2 & 0.0653441 & -0.0255881 & -0.0114106 & 0.0108165 & -0.0316851 & -0.00136565 & -0.00199008 & 0.0298564 & 0.00248509 & -0.0377283 & 1 & 2 & 2 & 1 & 0 \\\\\n",
       "\t3 & 0.036182 & -0.0138953 & 0.00824023 & -0.00308798 & -0.0130371 & 0.0059515 & 0.00703448 & 0.0184967 & -0.00519035 & -0.0281207 & 1 & 2 & 3 & 1 & 1 \\\\\n",
       "\t4 & 0.0292288 & -0.0099957 & 0.00241025 & 0.00671407 & -0.00923298 & 0.00159172 & 0.00227589 & 0.0133913 & -0.00464004 & -0.0223713 & 1 & 2 & 4 & 1 & 1 \\\\\n",
       "\t5 & 0.0139868 & -0.00648846 & -0.00108487 & 0.00196581 & -0.00564524 & 0.00110447 & 0.000616112 & 0.00652828 & -0.00118362 & -0.0109085 & 1 & 2 & 5 & 1 & 1 \\\\\n",
       "\t6 & 0.00515707 & -0.00153295 & -9.54022e-5 & -0.00106444 & -0.00479411 & 0.00258529 & 0.00279811 & 0.00244054 & -0.000862602 & -0.00403407 & 1 & 2 & 6 & 1 & 2 \\\\\n",
       "\t7 & 0.00515707 & -0.00153295 & -9.54022e-5 & -0.00106444 & -0.00479411 & 0.00258529 & 0.00279811 & 0.00244054 & -0.000862602 & -0.00403407 & 1 & 2 & 7 & 1 & 2 \\\\\n",
       "\t8 & 0.00421073 & -0.00125165 & -7.78955e-5 & -0.000869113 & -0.00391437 & 0.00211088 & 0.00228465 & 0.00199269 & -0.000704312 & -0.0032938 & 1 & 2 & 8 & 1 & 2 \\\\\n",
       "\t9 & 0.0807522 & -0.0227874 & 0.0254424 & -0.0167195 & -0.0220161 & 0.0101796 & 0.0140722 & 0.0501589 & -0.00739399 & -0.0722438 & 1 & 3 & 1 & 1 & 0 \\\\\n",
       "\t10 & 0.0340761 & -0.0103437 & -0.00257617 & 0.00722817 & -0.0177132 & 0.00453164 & 0.00870965 & 0.00973111 & -0.00574675 & -0.0162742 & 1 & 3 & 2 & 1 & 1 \\\\\n",
       "\t11 & 0.0940134 & -0.0318718 & 0.0344218 & -0.0327469 & -0.00542446 & 0.00659211 & 0.00370851 & 0.0694634 & -0.00152773 & -0.0996689 & 1 & 3 & 3 & 1 & 0 \\\\\n",
       "\t12 & 0.0324342 & -0.00973636 & 0.00403564 & 0.00701846 & -0.0169275 & 0.00535893 & 0.00938935 & 0.0100804 & -0.00863605 & -0.0212517 & 1 & 3 & 4 & 1 & 1 \\\\\n",
       "\t13 & 0.0163616 & -0.00657916 & -0.0015163 & 0.00433485 & -0.00743952 & 0.00136546 & 0.003353 & 0.00500078 & -0.00370737 & -0.00774432 & 1 & 3 & 5 & 1 & 1 \\\\\n",
       "\t14 & 0.00629887 & -0.00162777 & -0.000149417 & 0.00192418 & -0.0060904 & 0.00131024 & 0.00368191 & 0.001949 & -0.00250137 & -0.000507388 & 1 & 3 & 6 & 1 & 2 \\\\\n",
       "\t15 & 0.00629887 & -0.00162777 & -0.000149417 & 0.00192418 & -0.0060904 & 0.00131024 & 0.00368191 & 0.001949 & -0.00250137 & -0.000507388 & 1 & 3 & 7 & 1 & 2 \\\\\n",
       "\t16 & 0.00514301 & -0.00132907 & -0.000121998 & 0.00157108 & -0.00497279 & 0.0010698 & 0.00300627 & 0.00159136 & -0.00204236 & -0.000414281 & 1 & 3 & 8 & 1 & 2 \\\\\n",
       "\t17 & 0.0617292 & -0.0381997 & 0.0128525 & 0.015317 & -0.00830853 & 7.7243e-5 & -0.0030023 & 0.0352501 & 0.0032523 & -0.0626972 & 1 & 4 & 1 & 1 & 0 \\\\\n",
       "\t18 & 0.0315697 & -0.0214036 & 0.000929661 & 0.0126375 & -0.0146739 & 0.00134248 & 0.00210057 & 0.0117977 & -0.00231886 & -0.0223714 & 1 & 4 & 2 & 1 & 1 \\\\\n",
       "\t19 & 0.0374139 & -0.0274874 & 0.014762 & 0.000410167 & -0.00619973 & 0.00214736 & -1.34762e-6 & 0.0223506 & 3.35925e-5 & -0.0405679 & 1 & 4 & 3 & 1 & 1 \\\\\n",
       "\t20 & 0.0606246 & -0.0319523 & 0.00271798 & 0.0253772 & -0.00479889 & -0.00373943 & -0.0057935 & 0.0348819 & 0.00547458 & -0.0587476 & 1 & 4 & 4 & 1 & 0 \\\\\n",
       "\t21 & 0.0144691 & -0.0119399 & 0.000725399 & 0.00435129 & -0.00615814 & 0.00146693 & 0.000685575 & 0.00692591 & 0.000697119 & -0.0146377 & 1 & 4 & 5 & 1 & 1 \\\\\n",
       "\t22 & 0.00458313 & -0.005771 & 0.00140194 & -0.000148894 & -0.00231272 & 0.00167019 & 0.000371735 & 0.00286897 & 0.00138584 & -0.00791148 & 1 & 4 & 6 & 1 & 2 \\\\\n",
       "\t23 & 0.00458313 & -0.005771 & 0.00140194 & -0.000148894 & -0.00231272 & 0.00167019 & 0.000371735 & 0.00286897 & 0.00138584 & -0.00791148 & 1 & 4 & 7 & 1 & 2 \\\\\n",
       "\t24 & 0.00374211 & -0.00471201 & 0.00114468 & -0.000121571 & -0.00188833 & 0.00136371 & 0.000303521 & 0.0023425 & 0.00113153 & -0.0064597 & 1 & 4 & 8 & 1 & 2 \\\\\n",
       "\t25 & 0.0135714 & 0.000587793 & -0.000455864 & -0.00280042 & -0.0118951 & 0.00429923 & 0.00695602 & 0.0092625 & 1.17823e-5 & -0.00800769 & 1 & 5 & 1 & 1 & 0 \\\\\n",
       "\t26 & 0.00842064 & 0.000716397 & -0.0014722 & -0.000474771 & -0.00648818 & 0.0022389 & 0.00288162 & 0.00454633 & 0.000500505 & -0.00369009 & 1 & 5 & 2 & 1 & 1 \\\\\n",
       "\t27 & 0.00979808 & -1.31039e-5 & 0.000951256 & -0.00217937 & -0.00753009 & 0.00323605 & 0.00495289 & 0.0065921 & -0.00113181 & -0.00520114 & 1 & 5 & 3 & 1 & 1 \\\\\n",
       "\t28 & 0.00692593 & -0.000400284 & 0.000315747 & 0.000504271 & -0.00339045 & 0.00123777 & 0.00237482 & 0.00346345 & -0.000968839 & -0.00436599 & 1 & 5 & 4 & 1 & 1 \\\\\n",
       "\t29 & 0.00958598 & -0.00149535 & -0.0023221 & -0.00599477 & -0.00405036 & -0.00198833 & -0.000829992 & 0.0112333 & 0.0046394 & -0.0148051 & 1 & 5 & 5 & 1 & 0 \\\\\n",
       "\t30 & 0.0060058 & 0.00156492 & -0.00212839 & -0.00666747 & -0.00397689 & -4.11231e-5 & 0.00101687 & 0.00900213 & 0.00315526 & -0.0106931 & 1 & 5 & 6 & 1 & 1 \\\\\n",
       "\t$\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/plain": [
       "276×15 DataFrame\n",
       "│ Row │ grad1      │ grad2       │ grad3       │ grad4        │ grad5       │ grad6       │ grad7       │ grad8       │ grad9        │ grad10      │ u     │ v     │ n     │ dist_uv │ dist_uorv │\n",
       "│     │ \u001b[90mFloat32\u001b[39m    │ \u001b[90mFloat32\u001b[39m     │ \u001b[90mFloat32\u001b[39m     │ \u001b[90mFloat32\u001b[39m      │ \u001b[90mFloat32\u001b[39m     │ \u001b[90mFloat32\u001b[39m     │ \u001b[90mFloat32\u001b[39m     │ \u001b[90mFloat32\u001b[39m     │ \u001b[90mFloat32\u001b[39m      │ \u001b[90mFloat32\u001b[39m     │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m   │ \u001b[90mInt64\u001b[39m     │\n",
       "├─────┼────────────┼─────────────┼─────────────┼──────────────┼─────────────┼─────────────┼─────────────┼─────────────┼──────────────┼─────────────┼───────┼───────┼───────┼─────────┼───────────┤\n",
       "│ 1   │ 0.0645079  │ -0.021225   │ 0.00150802  │ 0.00318302   │ -0.0319299  │ 0.00660334  │ 0.00870657  │ 0.0306246   │ -0.00391502  │ -0.0416326  │ 1     │ 2     │ 1     │ 1       │ 0         │\n",
       "│ 2   │ 0.0653441  │ -0.0255881  │ -0.0114106  │ 0.0108165    │ -0.0316851  │ -0.00136565 │ -0.00199008 │ 0.0298564   │ 0.00248509   │ -0.0377283  │ 1     │ 2     │ 2     │ 1       │ 0         │\n",
       "│ 3   │ 0.036182   │ -0.0138953  │ 0.00824023  │ -0.00308798  │ -0.0130371  │ 0.0059515   │ 0.00703448  │ 0.0184967   │ -0.00519035  │ -0.0281207  │ 1     │ 2     │ 3     │ 1       │ 1         │\n",
       "│ 4   │ 0.0292288  │ -0.0099957  │ 0.00241025  │ 0.00671407   │ -0.00923298 │ 0.00159172  │ 0.00227589  │ 0.0133913   │ -0.00464004  │ -0.0223713  │ 1     │ 2     │ 4     │ 1       │ 1         │\n",
       "│ 5   │ 0.0139868  │ -0.00648846 │ -0.00108487 │ 0.00196581   │ -0.00564524 │ 0.00110447  │ 0.000616112 │ 0.00652828  │ -0.00118362  │ -0.0109085  │ 1     │ 2     │ 5     │ 1       │ 1         │\n",
       "│ 6   │ 0.00515707 │ -0.00153295 │ -9.54022e-5 │ -0.00106444  │ -0.00479411 │ 0.00258529  │ 0.00279811  │ 0.00244054  │ -0.000862602 │ -0.00403407 │ 1     │ 2     │ 6     │ 1       │ 2         │\n",
       "│ 7   │ 0.00515707 │ -0.00153295 │ -9.54022e-5 │ -0.00106444  │ -0.00479411 │ 0.00258529  │ 0.00279811  │ 0.00244054  │ -0.000862602 │ -0.00403407 │ 1     │ 2     │ 7     │ 1       │ 2         │\n",
       "│ 8   │ 0.00421073 │ -0.00125165 │ -7.78955e-5 │ -0.000869113 │ -0.00391437 │ 0.00211088  │ 0.00228465  │ 0.00199269  │ -0.000704312 │ -0.0032938  │ 1     │ 2     │ 8     │ 1       │ 2         │\n",
       "│ 9   │ 0.0807522  │ -0.0227874  │ 0.0254424   │ -0.0167195   │ -0.0220161  │ 0.0101796   │ 0.0140722   │ 0.0501589   │ -0.00739399  │ -0.0722438  │ 1     │ 3     │ 1     │ 1       │ 0         │\n",
       "│ 10  │ 0.0340761  │ -0.0103437  │ -0.00257617 │ 0.00722817   │ -0.0177132  │ 0.00453164  │ 0.00870965  │ 0.00973111  │ -0.00574675  │ -0.0162742  │ 1     │ 3     │ 2     │ 1       │ 1         │\n",
       "⋮\n",
       "│ 266 │ 0.0412439  │ 0.0124566   │ -0.0120742  │ 0.00936623   │ -0.0590286  │ 0.012509    │ 0.0284646   │ 0.0304896   │ -0.00870682  │ 0.00460341  │ 7     │ 9     │ 5     │ 3       │ 1         │\n",
       "│ 267 │ 0.00561691 │ 0.00853805  │ -0.00792131 │ 0.00419491   │ -0.0152246  │ -0.00259941 │ 0.00539298  │ 0.00800614  │ -0.003767    │ 0.0103741   │ 7     │ 9     │ 1     │ 3       │ 2         │\n",
       "│ 268 │ 0.00888112 │ 0.0134998   │ -0.0125247  │ 0.00663273   │ -0.0240722  │ -0.00411003 │ 0.00852706  │ 0.0126588   │ -0.00595615  │ 0.0164029   │ 7     │ 9     │ 6     │ 3       │ 2         │\n",
       "│ 269 │ 0.10261    │ -0.0519794  │ 0.0833485   │ -0.0359104   │ 0.0279448   │ 0.00737936  │ -0.00128965 │ -0.00660026 │ 0.0381341    │ -0.14625    │ 7     │ 9     │ 8     │ 3       │ 1         │\n",
       "│ 270 │ 0.11679    │ -0.0771613  │ 0.114605    │ -0.0506138   │ 0.0582975   │ 0.0131479   │ -0.0101065  │ -0.0207424  │ 0.0526607    │ -0.195522   │ 7     │ 9     │ 9     │ 3       │ 0         │\n",
       "│ 271 │ 0.0911527  │ -0.0259464  │ 0.0306541   │ -0.0086948   │ -0.140151   │ 0.0501676   │ 0.0964996   │ 0.026796    │ 0.015733     │ -0.0377764  │ 8     │ 9     │ 8     │ 1       │ 0         │\n",
       "│ 272 │ 0.0248515  │ -0.00836338 │ 0.00185364  │ -0.00484979  │ -0.054738   │ 0.0192284   │ 0.0367432   │ 0.00956858  │ 0.00852271   │ -0.00498861 │ 8     │ 9     │ 5     │ 1       │ 1         │\n",
       "│ 273 │ 0.105916   │ -0.0374046  │ 0.0401043   │ -0.0102666   │ -0.154377   │ 0.0593004   │ 0.109111    │ 0.0258074   │ 0.018702     │ -0.0492926  │ 8     │ 9     │ 9     │ 1       │ 0         │\n",
       "│ 274 │ 0.0036196  │ 0.00355871  │ -0.00161957 │ -0.000241777 │ -0.0109235  │ 0.00135475  │ 0.00574067  │ 0.0044341   │ 0.000358557  │ 0.00191391  │ 8     │ 9     │ 1     │ 1       │ 2         │\n",
       "│ 275 │ 0.0057231  │ 0.00562682  │ -0.00256077 │ -0.000382283 │ -0.0172716  │ 0.00214204  │ 0.00907679  │ 0.00701092  │ 0.000566929  │ 0.00302616  │ 8     │ 9     │ 6     │ 1       │ 2         │\n",
       "│ 276 │ 0.0057231  │ 0.00562682  │ -0.00256077 │ -0.000382283 │ -0.0172716  │ 0.00214204  │ 0.00907679  │ 0.00701092  │ 0.000566929  │ 0.00302616  │ 8     │ 9     │ 7     │ 1       │ 2         │"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ENV[\"COLUMNS\"] = 300\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO**\n",
    "\n",
    "Plot facets of that data (e.g. the gradients in `df`, faceted by `dist_uv` and `dist_uorv`)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### w.r.t. embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass through the decoder starting from the encoder's first output, keeping the\n",
    "# pullback `back_emb` so that gradients w.r.t. the embeddings can be taken in the next cell.\n",
    "# The trailing `;` suppresses the (tracked array, closure) repr, which is unreadable.\n",
    "Apred_emb, back_emb = Tracker.forward(x -> σ.(dec(x)[1]), enc(features)[1]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tracked 24×9 Array{Float64,2}:\n",
       "  0.0506319   0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.10493   \n",
       " -0.0837624   0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.234714  \n",
       "  0.0126053   0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.028999  \n",
       " -0.0347072   0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.223553  \n",
       "  0.111203    0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.0901013 \n",
       "  0.0937184   0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.0224072 \n",
       " -0.024323    0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.119372  \n",
       "  0.0506258   0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.102253  \n",
       " -0.00647627  0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.0671987 \n",
       "  0.0500753   0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.00589014\n",
       " -0.00719932  0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.0790304 \n",
       "  0.00181941  0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.0755756 \n",
       " -0.0445624   0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0485049 \n",
       "  0.00678587  0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.15691   \n",
       "  0.0739295   0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0631873 \n",
       " -0.0275066   0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0550068 \n",
       "  0.0         0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0       \n",
       "  0.0         0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0       \n",
       "  0.0         0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0       \n",
       "  0.0         0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0       \n",
       "  0.0         0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0       \n",
       "  0.0         0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0       \n",
       "  0.0         0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0       \n",
       "  0.0         0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0       "
      ]
     },
     "execution_count": 113,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pull back a one-hot seed placed on entry (1, 9) of the predicted adjacency `Apred_emb`\n",
    "# to get the gradient w.r.t. the embeddings. The seed must have the shape of `Apred_emb`,\n",
    "# so take it from there instead of hard-coding (9, 9).\n",
    "# NOTE(review): assumes `onehot(idx, dims)` builds a `dims`-shaped array with a 1 at `idx`.\n",
    "back_emb(onehot(CartesianIndex(1, 9), size(Apred_emb)))[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO**\n",
    "\n",
    "Precompute all mutual distances.\n",
    "\n",
    "Then:\n",
    "- for each pair of nodes u, v (linked or not):\n",
    "  - get the gradients w.r.t. embeddings\n",
    "  - for both u and v, store the gradient in a DataFrame along with the distance between u and v\n",
    "\n",
    "Then look at the average gradient, faceted by the distance between u and v."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving the plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder outputs (not used by the plot cells in this section)\n",
    "(embμ, emblogσ) = enc(features)\n",
    "# Full VAE pass: adjacency logits and unnormalised feature scores\n",
    "(logitÂ, unormF̂) = vae(features);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adjacency reconstruction: σ applied to the predicted adjacency logits\n",
    "scene = Scene(resolution = (15000, 15000))\n",
    "heatmap!(scene, σ.(logitÂ).data; colorrange = (0, 1))\n",
    "Makie.save(\"Apred.png\", scene);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adjacency reference: the dense adjacency matrix of `g`\n",
    "scene = Scene(resolution = (15000, 15000))\n",
    "heatmap!(scene, Matrix(adjacency_matrix(g)); colorrange = (0, 1))\n",
    "Makie.save(\"Aref.png\", scene);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature reconstruction: softmax of the unnormalised feature scores\n",
    "scene = Scene(resolution = (150, 15000))\n",
    "heatmap!(scene, 1:size(_features, 1), 1:size(_features, 2), softmax(unormF̂).data;\n",
    "         colorrange = (0, 1))\n",
    "Makie.save(\"Fpred.png\", scene);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature reference\n",
    "# Draw into `scene` with the mutating `heatmap!`: a bare `heatmap(...)` creates a new,\n",
    "# separate scene, so the `scene` saved below would otherwise be blank.\n",
    "scene = Scene(resolution = (150, 15000))\n",
    "heatmap!(scene, 1:size(_features, 1), 1:size(_features, 2), _features, colorrange = (0, 1))\n",
    "Makie.save(\"Fref.png\", scene);"
   ]
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": "5ba504aae1ef4fd98baab9111205e028",
   "lastKernelId": "a7b6f62a-f991-4853-8a1a-62f6b154f88e"
  },
  "kernelspec": {
   "display_name": "Julia 1.2.0",
   "language": "julia",
   "name": "julia-1.2"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.2.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}